rewrite transport header

This commit is contained in:
aler9
2020-09-05 22:19:06 +02:00
parent 37e3a1f29f
commit 8525e1e0ff
6 changed files with 266 additions and 83 deletions

View File

@@ -357,12 +357,12 @@ func (c *ConnClient) urlForTrack(baseUrl *url.URL, track *Track) *url.URL {
return u return u
} }
func (c *ConnClient) setup(u *url.URL, track *Track, transport []string) (*Response, error) { func (c *ConnClient) setup(u *url.URL, track *Track, ht *HeaderTransport) (*Response, error) {
res, err := c.Do(&Request{ res, err := c.Do(&Request{
Method: SETUP, Method: SETUP,
Url: c.urlForTrack(u, track), Url: c.urlForTrack(u, track),
Header: Header{ Header: Header{
"Transport": HeaderValue{strings.Join(transport, ";")}, "Transport": ht.Write(),
}, },
}) })
if err != nil { if err != nil {
@@ -406,10 +406,13 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int,
return nil, nil, nil, err return nil, nil, nil, err
} }
res, err := c.setup(u, track, []string{ res, err := c.setup(u, track, &HeaderTransport{
"RTP/AVP/UDP", Protocol: StreamProtocolUDP,
"unicast", Cast: func() *StreamCast {
fmt.Sprintf("client_port=%d-%d", rtpPort, rtcpPort), ret := StreamUnicast
return &ret
}(),
ClientPorts: &[2]int{rtpPort, rtcpPort},
}) })
if err != nil { if err != nil {
rtpListener.close() rtpListener.close()
@@ -424,8 +427,7 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int,
return nil, nil, nil, fmt.Errorf("SETUP: transport header: %s", err) return nil, nil, nil, fmt.Errorf("SETUP: transport header: %s", err)
} }
rtpServerPort, rtcpServerPort := th.Ports("server_port") if th.ServerPorts == nil {
if rtpServerPort == 0 {
rtpListener.close() rtpListener.close()
rtcpListener.close() rtcpListener.close()
return nil, nil, nil, fmt.Errorf("SETUP: server ports not provided") return nil, nil, nil, fmt.Errorf("SETUP: server ports not provided")
@@ -437,11 +439,11 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int,
c.rtcpReceivers[track.Id] = NewRtcpReceiver() c.rtcpReceivers[track.Id] = NewRtcpReceiver()
rtpListener.publisherIp = c.nconn.RemoteAddr().(*net.TCPAddr).IP rtpListener.publisherIp = c.nconn.RemoteAddr().(*net.TCPAddr).IP
rtpListener.publisherPort = rtpServerPort rtpListener.publisherPort = (*th.ServerPorts)[0]
c.rtpListeners[track.Id] = rtpListener c.rtpListeners[track.Id] = rtpListener
rtcpListener.publisherIp = c.nconn.RemoteAddr().(*net.TCPAddr).IP rtcpListener.publisherIp = c.nconn.RemoteAddr().(*net.TCPAddr).IP
rtcpListener.publisherPort = rtcpServerPort rtcpListener.publisherPort = (*th.ServerPorts)[1]
c.rtcpListeners[track.Id] = rtcpListener c.rtcpListeners[track.Id] = rtcpListener
return rtpListener.Read, rtcpListener.Read, res, nil return rtpListener.Read, rtcpListener.Read, res, nil
@@ -462,11 +464,14 @@ func (c *ConnClient) SetupTCP(u *url.URL, track *Track) (*Response, error) {
return nil, fmt.Errorf("track has already been setup") return nil, fmt.Errorf("track has already been setup")
} }
interleaved := fmt.Sprintf("interleaved=%d-%d", (track.Id * 2), (track.Id*2)+1) interleavedIds := &[2]int{(track.Id * 2), (track.Id * 2) + 1}
res, err := c.setup(u, track, []string{ res, err := c.setup(u, track, &HeaderTransport{
"RTP/AVP/TCP", Protocol: StreamProtocolTCP,
"unicast", Cast: func() *StreamCast {
interleaved, ret := StreamUnicast
return &ret
}(),
InterleavedIds: interleavedIds,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -477,10 +482,10 @@ func (c *ConnClient) SetupTCP(u *url.URL, track *Track) (*Response, error) {
return nil, fmt.Errorf("SETUP: transport header: %s", err) return nil, fmt.Errorf("SETUP: transport header: %s", err)
} }
_, ok := th[interleaved] if th.InterleavedIds == nil || (*th.InterleavedIds)[0] != (*interleavedIds)[0] ||
if !ok { (*th.InterleavedIds)[1] != (*interleavedIds)[1] {
return nil, fmt.Errorf("SETUP: transport header does not contain '%s' (%s)", return nil, fmt.Errorf("SETUP: transport header does not have interleaved ids %v (%s)",
interleaved, res.Header["Transport"]) *interleavedIds, res.Header["Transport"])
} }
c.streamUrl = u c.streamUrl = u

View File

@@ -8,9 +8,9 @@ import (
var casesHeaderAuth = []struct { var casesHeaderAuth = []struct {
name string name string
dec HeaderValue vin HeaderValue
enc HeaderValue vout HeaderValue
ha *HeaderAuth h *HeaderAuth
}{ }{
{ {
"basic", "basic",
@@ -112,9 +112,9 @@ var casesHeaderAuth = []struct {
func TestHeaderAuthRead(t *testing.T) { func TestHeaderAuthRead(t *testing.T) {
for _, c := range casesHeaderAuth { for _, c := range casesHeaderAuth {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
req, err := ReadHeaderAuth(c.dec) req, err := ReadHeaderAuth(c.vin)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c.ha, req) require.Equal(t, c.h, req)
}) })
} }
} }
@@ -122,8 +122,8 @@ func TestHeaderAuthRead(t *testing.T) {
func TestHeaderAuthWrite(t *testing.T) { func TestHeaderAuthWrite(t *testing.T) {
for _, c := range casesHeaderAuth { for _, c := range casesHeaderAuth {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
req := c.ha.Write() req := c.h.Write()
require.Equal(t, c.enc, req) require.Equal(t, c.vout, req)
}) })
} }
} }

View File

@@ -7,9 +7,9 @@ import (
) )
var casesHeaderSession = []struct { var casesHeaderSession = []struct {
name string name string
value HeaderValue v HeaderValue
hs *HeaderSession h *HeaderSession
}{ }{
{ {
"base", "base",
@@ -45,9 +45,9 @@ var casesHeaderSession = []struct {
func TestHeaderSession(t *testing.T) { func TestHeaderSession(t *testing.T) {
for _, c := range casesHeaderSession { for _, c := range casesHeaderSession {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
req, err := ReadHeaderSession(c.value) req, err := ReadHeaderSession(c.v)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c.hs, req) require.Equal(t, c.h, req)
}) })
} }
} }

View File

@@ -7,10 +7,47 @@ import (
) )
// HeaderTransport is a Transport header. // HeaderTransport is a Transport header.
type HeaderTransport map[string]struct{} type HeaderTransport struct {
// protocol of the stream
Protocol StreamProtocol
// cast of the stream
Cast *StreamCast
// client ports
ClientPorts *[2]int
// server ports
ServerPorts *[2]int
// interleaved frame ids
InterleavedIds *[2]int
// mode
Mode *string
}
func parsePorts(val string) (*[2]int, error) {
ports := strings.Split(val, "-")
if len(ports) != 2 {
return &[2]int{0, 0}, fmt.Errorf("invalid ports (%v)", val)
}
port1, err := strconv.ParseInt(ports[0], 10, 64)
if err != nil {
return &[2]int{0, 0}, fmt.Errorf("invalid ports (%v)", val)
}
port2, err := strconv.ParseInt(ports[1], 10, 64)
if err != nil {
return &[2]int{0, 0}, fmt.Errorf("invalid ports (%v)", val)
}
return &[2]int{int(port1), int(port2)}, nil
}
// ReadHeaderTransport parses a Transport header. // ReadHeaderTransport parses a Transport header.
func ReadHeaderTransport(v HeaderValue) (HeaderTransport, error) { func ReadHeaderTransport(v HeaderValue) (*HeaderTransport, error) {
if len(v) == 0 { if len(v) == 0 {
return nil, fmt.Errorf("value not provided") return nil, fmt.Errorf("value not provided")
} }
@@ -19,63 +56,99 @@ func ReadHeaderTransport(v HeaderValue) (HeaderTransport, error) {
return nil, fmt.Errorf("value provided multiple times (%v)", v) return nil, fmt.Errorf("value provided multiple times (%v)", v)
} }
ht := make(map[string]struct{}) ht := &HeaderTransport{}
protoSet := false
for _, t := range strings.Split(v[0], ";") { for _, t := range strings.Split(v[0], ";") {
ht[t] = struct{}{} if t == "RTP/AVP" || t == "RTP/AVP/UDP" {
ht.Protocol = StreamProtocolUDP
protoSet = true
} else if t == "RTP/AVP/TCP" {
ht.Protocol = StreamProtocolTCP
protoSet = true
} else if t == "unicast" {
ret := StreamUnicast
ht.Cast = &ret
} else if t == "multicast" {
ret := StreamMulticast
ht.Cast = &ret
} else if strings.HasPrefix(t, "client_port=") {
ports, err := parsePorts(t[len("client_port="):])
if err != nil {
return nil, err
}
ht.ClientPorts = ports
} else if strings.HasPrefix(t, "server_port=") {
ports, err := parsePorts(t[len("server_port="):])
if err != nil {
return nil, err
}
ht.ServerPorts = ports
} else if strings.HasPrefix(t, "interleaved=") {
ports, err := parsePorts(t[len("interleaved="):])
if err != nil {
return nil, err
}
ht.InterleavedIds = ports
} else if strings.HasPrefix(t, "mode=") {
ret := strings.ToLower(t[len("mode="):])
ret = strings.TrimPrefix(ret, "\"")
ret = strings.TrimSuffix(ret, "\"")
ht.Mode = &ret
}
}
// protocol is the only mandatory field
if !protoSet {
return nil, fmt.Errorf("protocol not set (%v)", v)
} }
return ht, nil return ht, nil
} }
// IsUDP check whether the header contains the UDP protocol. // Write encodes a Transport header
func (ht HeaderTransport) IsUDP() bool { func (ht *HeaderTransport) Write() HeaderValue {
if _, ok := ht["RTP/AVP"]; ok { var vals []string
return true
}
if _, ok := ht["RTP/AVP/UDP"]; ok {
return true
}
return false
}
// IsTCP check whether the header contains the TCP protocol. if ht.Protocol == StreamProtocolUDP {
func (ht HeaderTransport) IsTCP() bool { vals = append(vals, "RTP/AVP")
_, ok := ht["RTP/AVP/TCP"] } else {
return ok vals = append(vals, "RTP/AVP/TCP")
} }
// Value gets a value from the header. if ht.Cast != nil {
func (ht HeaderTransport) Value(key string) string { if *ht.Cast == StreamUnicast {
prefix := key + "=" vals = append(vals, "unicast")
for t := range ht { } else {
if strings.HasPrefix(t, prefix) { vals = append(vals, "multicast")
return t[len(prefix):]
} }
} }
return ""
} if ht.ClientPorts != nil {
ports := *ht.ClientPorts
// Ports gets a value from the header and parses its ports. vals = append(vals, "client_port="+strconv.FormatInt(int64(ports[0]), 10)+"-"+strconv.FormatInt(int64(ports[1]), 10))
func (ht HeaderTransport) Ports(key string) (int, int) { }
val := ht.Value(key)
if val == "" { if ht.ServerPorts != nil {
return 0, 0 ports := *ht.ServerPorts
} vals = append(vals, "server_port="+strconv.FormatInt(int64(ports[0]), 10)+"-"+strconv.FormatInt(int64(ports[1]), 10))
}
ports := strings.Split(val, "-")
if len(ports) != 2 { if ht.InterleavedIds != nil {
return 0, 0 ports := *ht.InterleavedIds
} vals = append(vals, "interleaved="+strconv.FormatInt(int64(ports[0]), 10)+"-"+strconv.FormatInt(int64(ports[1]), 10))
}
port1, err := strconv.ParseInt(ports[0], 10, 64)
if err != nil { if ht.Mode != nil {
return 0, 0 vals = append(vals, "mode="+*ht.Mode)
} }
port2, err := strconv.ParseInt(ports[1], 10, 64) return HeaderValue{strings.Join(vals, ";")}
if err != nil {
return 0, 0
}
return int(port1), int(port2)
} }

86
header-transport_test.go Normal file
View File

@@ -0,0 +1,86 @@
package gortsplib
import (
"testing"
"github.com/stretchr/testify/require"
)
var casesHeaderTransport = []struct {
name string
vin HeaderValue
vout HeaderValue
h *HeaderTransport
}{
{
"udp unicast play request",
HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode="PLAY"`},
HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode=play`},
&HeaderTransport{
Protocol: StreamProtocolUDP,
Cast: func() *StreamCast {
ret := StreamUnicast
return &ret
}(),
ClientPorts: &[2]int{3456, 3457},
Mode: func() *string {
ret := "play"
return &ret
}(),
},
},
{
"udp unicast play response",
HeaderValue{`RTP/AVP/UDP;unicast;client_port=3056-3057;server_port=5000-5001`},
HeaderValue{`RTP/AVP;unicast;client_port=3056-3057;server_port=5000-5001`},
&HeaderTransport{
Protocol: StreamProtocolUDP,
Cast: func() *StreamCast {
ret := StreamUnicast
return &ret
}(),
ClientPorts: &[2]int{3056, 3057},
ServerPorts: &[2]int{5000, 5001},
},
},
{
"udp multicast play request / response",
HeaderValue{`RTP/AVP;multicast;destination=225.219.201.15;port=7000-7001;ttl=127`},
HeaderValue{`RTP/AVP;multicast`},
&HeaderTransport{
Protocol: StreamProtocolUDP,
Cast: func() *StreamCast {
ret := StreamMulticast
return &ret
}(),
},
},
{
"tcp play request / response",
HeaderValue{`RTP/AVP/TCP;interleaved=0-1`},
HeaderValue{`RTP/AVP/TCP;interleaved=0-1`},
&HeaderTransport{
Protocol: StreamProtocolTCP,
InterleavedIds: &[2]int{0, 1},
},
},
}
func TestHeaderTransportRead(t *testing.T) {
for _, c := range casesHeaderTransport {
t.Run(c.name, func(t *testing.T) {
req, err := ReadHeaderTransport(c.vin)
require.NoError(t, err)
require.Equal(t, c.h, req)
})
}
}
func TestHeaderTransportWrite(t *testing.T) {
for _, c := range casesHeaderTransport {
t.Run(c.name, func(t *testing.T) {
req := c.h.Write()
require.Equal(t, c.vout, req)
})
}
}

View File

@@ -11,7 +11,7 @@ const (
rtspMaxContentLength = 4096 rtspMaxContentLength = 4096
) )
// StreamProtocol is the protocol of a stream // StreamProtocol is the protocol of a stream.
type StreamProtocol int type StreamProtocol int
const ( const (
@@ -30,6 +30,25 @@ func (sp StreamProtocol) String() string {
return "tcp" return "tcp"
} }
// StreamCast is the cast of a stream.
type StreamCast int
const (
// Unicast means that the stream will be unicasted
StreamUnicast StreamCast = iota
// Multicast means that the stream will be multicasted
StreamMulticast
)
// String implements fmt.Stringer
func (sc StreamCast) String() string {
if sc == StreamUnicast {
return "unicast"
}
return "multicast"
}
// StreamType is the type of a stream. // StreamType is the type of a stream.
type StreamType int type StreamType int