rewrite transport header

This commit is contained in:
aler9
2020-09-05 22:19:06 +02:00
parent 37e3a1f29f
commit 8525e1e0ff
6 changed files with 266 additions and 83 deletions

View File

@@ -357,12 +357,12 @@ func (c *ConnClient) urlForTrack(baseUrl *url.URL, track *Track) *url.URL {
return u
}
func (c *ConnClient) setup(u *url.URL, track *Track, transport []string) (*Response, error) {
func (c *ConnClient) setup(u *url.URL, track *Track, ht *HeaderTransport) (*Response, error) {
res, err := c.Do(&Request{
Method: SETUP,
Url: c.urlForTrack(u, track),
Header: Header{
"Transport": HeaderValue{strings.Join(transport, ";")},
"Transport": ht.Write(),
},
})
if err != nil {
@@ -406,10 +406,13 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int,
return nil, nil, nil, err
}
res, err := c.setup(u, track, []string{
"RTP/AVP/UDP",
"unicast",
fmt.Sprintf("client_port=%d-%d", rtpPort, rtcpPort),
res, err := c.setup(u, track, &HeaderTransport{
Protocol: StreamProtocolUDP,
Cast: func() *StreamCast {
ret := StreamUnicast
return &ret
}(),
ClientPorts: &[2]int{rtpPort, rtcpPort},
})
if err != nil {
rtpListener.close()
@@ -424,8 +427,7 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int,
return nil, nil, nil, fmt.Errorf("SETUP: transport header: %s", err)
}
rtpServerPort, rtcpServerPort := th.Ports("server_port")
if rtpServerPort == 0 {
if th.ServerPorts == nil {
rtpListener.close()
rtcpListener.close()
return nil, nil, nil, fmt.Errorf("SETUP: server ports not provided")
@@ -437,11 +439,11 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int,
c.rtcpReceivers[track.Id] = NewRtcpReceiver()
rtpListener.publisherIp = c.nconn.RemoteAddr().(*net.TCPAddr).IP
rtpListener.publisherPort = rtpServerPort
rtpListener.publisherPort = (*th.ServerPorts)[0]
c.rtpListeners[track.Id] = rtpListener
rtcpListener.publisherIp = c.nconn.RemoteAddr().(*net.TCPAddr).IP
rtcpListener.publisherPort = rtcpServerPort
rtcpListener.publisherPort = (*th.ServerPorts)[1]
c.rtcpListeners[track.Id] = rtcpListener
return rtpListener.Read, rtcpListener.Read, res, nil
@@ -462,11 +464,14 @@ func (c *ConnClient) SetupTCP(u *url.URL, track *Track) (*Response, error) {
return nil, fmt.Errorf("track has already been setup")
}
interleaved := fmt.Sprintf("interleaved=%d-%d", (track.Id * 2), (track.Id*2)+1)
res, err := c.setup(u, track, []string{
"RTP/AVP/TCP",
"unicast",
interleaved,
interleavedIds := &[2]int{(track.Id * 2), (track.Id * 2) + 1}
res, err := c.setup(u, track, &HeaderTransport{
Protocol: StreamProtocolTCP,
Cast: func() *StreamCast {
ret := StreamUnicast
return &ret
}(),
InterleavedIds: interleavedIds,
})
if err != nil {
return nil, err
@@ -477,10 +482,10 @@ func (c *ConnClient) SetupTCP(u *url.URL, track *Track) (*Response, error) {
return nil, fmt.Errorf("SETUP: transport header: %s", err)
}
_, ok := th[interleaved]
if !ok {
return nil, fmt.Errorf("SETUP: transport header does not contain '%s' (%s)",
interleaved, res.Header["Transport"])
if th.InterleavedIds == nil || (*th.InterleavedIds)[0] != (*interleavedIds)[0] ||
(*th.InterleavedIds)[1] != (*interleavedIds)[1] {
return nil, fmt.Errorf("SETUP: transport header does not have interleaved ids %v (%s)",
*interleavedIds, res.Header["Transport"])
}
c.streamUrl = u

View File

@@ -8,9 +8,9 @@ import (
var casesHeaderAuth = []struct {
name string
dec HeaderValue
enc HeaderValue
ha *HeaderAuth
vin HeaderValue
vout HeaderValue
h *HeaderAuth
}{
{
"basic",
@@ -112,9 +112,9 @@ var casesHeaderAuth = []struct {
func TestHeaderAuthRead(t *testing.T) {
for _, c := range casesHeaderAuth {
t.Run(c.name, func(t *testing.T) {
req, err := ReadHeaderAuth(c.dec)
req, err := ReadHeaderAuth(c.vin)
require.NoError(t, err)
require.Equal(t, c.ha, req)
require.Equal(t, c.h, req)
})
}
}
@@ -122,8 +122,8 @@ func TestHeaderAuthRead(t *testing.T) {
func TestHeaderAuthWrite(t *testing.T) {
for _, c := range casesHeaderAuth {
t.Run(c.name, func(t *testing.T) {
req := c.ha.Write()
require.Equal(t, c.enc, req)
req := c.h.Write()
require.Equal(t, c.vout, req)
})
}
}

View File

@@ -7,9 +7,9 @@ import (
)
var casesHeaderSession = []struct {
name string
value HeaderValue
hs *HeaderSession
name string
v HeaderValue
h *HeaderSession
}{
{
"base",
@@ -45,9 +45,9 @@ var casesHeaderSession = []struct {
func TestHeaderSession(t *testing.T) {
for _, c := range casesHeaderSession {
t.Run(c.name, func(t *testing.T) {
req, err := ReadHeaderSession(c.value)
req, err := ReadHeaderSession(c.v)
require.NoError(t, err)
require.Equal(t, c.hs, req)
require.Equal(t, c.h, req)
})
}
}

View File

@@ -7,10 +7,47 @@ import (
)
// HeaderTransport is a Transport header.
type HeaderTransport map[string]struct{}
type HeaderTransport struct {
// protocol of the stream
Protocol StreamProtocol
// cast of the stream
Cast *StreamCast
// client ports
ClientPorts *[2]int
// server ports
ServerPorts *[2]int
// interleaved frame ids
InterleavedIds *[2]int
// mode
Mode *string
}
func parsePorts(val string) (*[2]int, error) {
ports := strings.Split(val, "-")
if len(ports) != 2 {
return &[2]int{0, 0}, fmt.Errorf("invalid ports (%v)", val)
}
port1, err := strconv.ParseInt(ports[0], 10, 64)
if err != nil {
return &[2]int{0, 0}, fmt.Errorf("invalid ports (%v)", val)
}
port2, err := strconv.ParseInt(ports[1], 10, 64)
if err != nil {
return &[2]int{0, 0}, fmt.Errorf("invalid ports (%v)", val)
}
return &[2]int{int(port1), int(port2)}, nil
}
// ReadHeaderTransport parses a Transport header.
func ReadHeaderTransport(v HeaderValue) (HeaderTransport, error) {
func ReadHeaderTransport(v HeaderValue) (*HeaderTransport, error) {
if len(v) == 0 {
return nil, fmt.Errorf("value not provided")
}
@@ -19,63 +56,99 @@ func ReadHeaderTransport(v HeaderValue) (HeaderTransport, error) {
return nil, fmt.Errorf("value provided multiple times (%v)", v)
}
ht := make(map[string]struct{})
ht := &HeaderTransport{}
protoSet := false
for _, t := range strings.Split(v[0], ";") {
ht[t] = struct{}{}
if t == "RTP/AVP" || t == "RTP/AVP/UDP" {
ht.Protocol = StreamProtocolUDP
protoSet = true
} else if t == "RTP/AVP/TCP" {
ht.Protocol = StreamProtocolTCP
protoSet = true
} else if t == "unicast" {
ret := StreamUnicast
ht.Cast = &ret
} else if t == "multicast" {
ret := StreamMulticast
ht.Cast = &ret
} else if strings.HasPrefix(t, "client_port=") {
ports, err := parsePorts(t[len("client_port="):])
if err != nil {
return nil, err
}
ht.ClientPorts = ports
} else if strings.HasPrefix(t, "server_port=") {
ports, err := parsePorts(t[len("server_port="):])
if err != nil {
return nil, err
}
ht.ServerPorts = ports
} else if strings.HasPrefix(t, "interleaved=") {
ports, err := parsePorts(t[len("interleaved="):])
if err != nil {
return nil, err
}
ht.InterleavedIds = ports
} else if strings.HasPrefix(t, "mode=") {
ret := strings.ToLower(t[len("mode="):])
ret = strings.TrimPrefix(ret, "\"")
ret = strings.TrimSuffix(ret, "\"")
ht.Mode = &ret
}
}
// protocol is the only mandatory field
if !protoSet {
return nil, fmt.Errorf("protocol not set (%v)", v)
}
return ht, nil
}
// IsUDP check whether the header contains the UDP protocol.
func (ht HeaderTransport) IsUDP() bool {
if _, ok := ht["RTP/AVP"]; ok {
return true
}
if _, ok := ht["RTP/AVP/UDP"]; ok {
return true
}
return false
}
// Write encodes a Transport header
func (ht *HeaderTransport) Write() HeaderValue {
var vals []string
// IsTCP check whether the header contains the TCP protocol.
func (ht HeaderTransport) IsTCP() bool {
_, ok := ht["RTP/AVP/TCP"]
return ok
}
if ht.Protocol == StreamProtocolUDP {
vals = append(vals, "RTP/AVP")
} else {
vals = append(vals, "RTP/AVP/TCP")
}
// Value gets a value from the header.
func (ht HeaderTransport) Value(key string) string {
prefix := key + "="
for t := range ht {
if strings.HasPrefix(t, prefix) {
return t[len(prefix):]
if ht.Cast != nil {
if *ht.Cast == StreamUnicast {
vals = append(vals, "unicast")
} else {
vals = append(vals, "multicast")
}
}
return ""
}
// Ports gets a value from the header and parses its ports.
func (ht HeaderTransport) Ports(key string) (int, int) {
val := ht.Value(key)
if val == "" {
return 0, 0
}
ports := strings.Split(val, "-")
if len(ports) != 2 {
return 0, 0
}
port1, err := strconv.ParseInt(ports[0], 10, 64)
if err != nil {
return 0, 0
}
port2, err := strconv.ParseInt(ports[1], 10, 64)
if err != nil {
return 0, 0
}
return int(port1), int(port2)
if ht.ClientPorts != nil {
ports := *ht.ClientPorts
vals = append(vals, "client_port="+strconv.FormatInt(int64(ports[0]), 10)+"-"+strconv.FormatInt(int64(ports[1]), 10))
}
if ht.ServerPorts != nil {
ports := *ht.ServerPorts
vals = append(vals, "server_port="+strconv.FormatInt(int64(ports[0]), 10)+"-"+strconv.FormatInt(int64(ports[1]), 10))
}
if ht.InterleavedIds != nil {
ports := *ht.InterleavedIds
vals = append(vals, "interleaved="+strconv.FormatInt(int64(ports[0]), 10)+"-"+strconv.FormatInt(int64(ports[1]), 10))
}
if ht.Mode != nil {
vals = append(vals, "mode="+*ht.Mode)
}
return HeaderValue{strings.Join(vals, ";")}
}

86
header-transport_test.go Normal file
View File

@@ -0,0 +1,86 @@
package gortsplib
import (
"testing"
"github.com/stretchr/testify/require"
)
var casesHeaderTransport = []struct {
name string
vin HeaderValue
vout HeaderValue
h *HeaderTransport
}{
{
"udp unicast play request",
HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode="PLAY"`},
HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode=play`},
&HeaderTransport{
Protocol: StreamProtocolUDP,
Cast: func() *StreamCast {
ret := StreamUnicast
return &ret
}(),
ClientPorts: &[2]int{3456, 3457},
Mode: func() *string {
ret := "play"
return &ret
}(),
},
},
{
"udp unicast play response",
HeaderValue{`RTP/AVP/UDP;unicast;client_port=3056-3057;server_port=5000-5001`},
HeaderValue{`RTP/AVP;unicast;client_port=3056-3057;server_port=5000-5001`},
&HeaderTransport{
Protocol: StreamProtocolUDP,
Cast: func() *StreamCast {
ret := StreamUnicast
return &ret
}(),
ClientPorts: &[2]int{3056, 3057},
ServerPorts: &[2]int{5000, 5001},
},
},
{
"udp multicast play request / response",
HeaderValue{`RTP/AVP;multicast;destination=225.219.201.15;port=7000-7001;ttl=127`},
HeaderValue{`RTP/AVP;multicast`},
&HeaderTransport{
Protocol: StreamProtocolUDP,
Cast: func() *StreamCast {
ret := StreamMulticast
return &ret
}(),
},
},
{
"tcp play request / response",
HeaderValue{`RTP/AVP/TCP;interleaved=0-1`},
HeaderValue{`RTP/AVP/TCP;interleaved=0-1`},
&HeaderTransport{
Protocol: StreamProtocolTCP,
InterleavedIds: &[2]int{0, 1},
},
},
}
func TestHeaderTransportRead(t *testing.T) {
for _, c := range casesHeaderTransport {
t.Run(c.name, func(t *testing.T) {
req, err := ReadHeaderTransport(c.vin)
require.NoError(t, err)
require.Equal(t, c.h, req)
})
}
}
func TestHeaderTransportWrite(t *testing.T) {
for _, c := range casesHeaderTransport {
t.Run(c.name, func(t *testing.T) {
req := c.h.Write()
require.Equal(t, c.vout, req)
})
}
}

View File

@@ -11,7 +11,7 @@ const (
rtspMaxContentLength = 4096
)
// StreamProtocol is the protocol of a stream
// StreamProtocol is the protocol of a stream.
type StreamProtocol int
const (
@@ -30,6 +30,25 @@ func (sp StreamProtocol) String() string {
return "tcp"
}
// StreamCast is the cast of a stream.
type StreamCast int
const (
// Unicast means that the stream will be unicasted
StreamUnicast StreamCast = iota
// Multicast means that the stream will be multicasted
StreamMulticast
)
// String implements fmt.Stringer
func (sc StreamCast) String() string {
if sc == StreamUnicast {
return "unicast"
}
return "multicast"
}
// StreamType is the type of a stream.
type StreamType int