diff --git a/conn-client.go b/conn-client.go index b25a764b..201c60b2 100644 --- a/conn-client.go +++ b/conn-client.go @@ -357,12 +357,12 @@ func (c *ConnClient) urlForTrack(baseUrl *url.URL, track *Track) *url.URL { return u } -func (c *ConnClient) setup(u *url.URL, track *Track, transport []string) (*Response, error) { +func (c *ConnClient) setup(u *url.URL, track *Track, ht *HeaderTransport) (*Response, error) { res, err := c.Do(&Request{ Method: SETUP, Url: c.urlForTrack(u, track), Header: Header{ - "Transport": HeaderValue{strings.Join(transport, ";")}, + "Transport": ht.Write(), }, }) if err != nil { @@ -406,10 +406,13 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int, return nil, nil, nil, err } - res, err := c.setup(u, track, []string{ - "RTP/AVP/UDP", - "unicast", - fmt.Sprintf("client_port=%d-%d", rtpPort, rtcpPort), + res, err := c.setup(u, track, &HeaderTransport{ + Protocol: StreamProtocolUDP, + Cast: func() *StreamCast { + ret := StreamUnicast + return &ret + }(), + ClientPorts: &[2]int{rtpPort, rtcpPort}, }) if err != nil { rtpListener.close() @@ -424,8 +427,7 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int, return nil, nil, nil, fmt.Errorf("SETUP: transport header: %s", err) } - rtpServerPort, rtcpServerPort := th.Ports("server_port") - if rtpServerPort == 0 { + if th.ServerPorts == nil { rtpListener.close() rtcpListener.close() return nil, nil, nil, fmt.Errorf("SETUP: server ports not provided") @@ -437,11 +439,11 @@ func (c *ConnClient) SetupUDP(u *url.URL, track *Track, rtpPort int, c.rtcpReceivers[track.Id] = NewRtcpReceiver() rtpListener.publisherIp = c.nconn.RemoteAddr().(*net.TCPAddr).IP - rtpListener.publisherPort = rtpServerPort + rtpListener.publisherPort = (*th.ServerPorts)[0] c.rtpListeners[track.Id] = rtpListener rtcpListener.publisherIp = c.nconn.RemoteAddr().(*net.TCPAddr).IP - rtcpListener.publisherPort = rtcpServerPort + rtcpListener.publisherPort = (*th.ServerPorts)[1] c.rtcpListeners[track.Id] = rtcpListener return rtpListener.Read, rtcpListener.Read, res, nil @@ -462,11 +464,14 @@ func (c *ConnClient) SetupTCP(u *url.URL, track *Track) (*Response, error) { return nil, fmt.Errorf("track has already been setup") } - interleaved := fmt.Sprintf("interleaved=%d-%d", (track.Id * 2), (track.Id*2)+1) - res, err := c.setup(u, track, []string{ - "RTP/AVP/TCP", - "unicast", - interleaved, + interleavedIds := &[2]int{(track.Id * 2), (track.Id * 2) + 1} + res, err := c.setup(u, track, &HeaderTransport{ + Protocol: StreamProtocolTCP, + Cast: func() *StreamCast { + ret := StreamUnicast + return &ret + }(), + InterleavedIds: interleavedIds, }) if err != nil { return nil, err @@ -477,10 +482,10 @@ func (c *ConnClient) SetupTCP(u *url.URL, track *Track) (*Response, error) { return nil, fmt.Errorf("SETUP: transport header: %s", err) } - _, ok := th[interleaved] - if !ok { - return nil, fmt.Errorf("SETUP: transport header does not contain '%s' (%s)", - interleaved, res.Header["Transport"]) + if th.InterleavedIds == nil || (*th.InterleavedIds)[0] != (*interleavedIds)[0] || + (*th.InterleavedIds)[1] != (*interleavedIds)[1] { + return nil, fmt.Errorf("SETUP: transport header does not have interleaved ids %v (%s)", + *interleavedIds, res.Header["Transport"]) } c.streamUrl = u diff --git a/header-auth_test.go b/header-auth_test.go index a1c700ab..bf91d969 100644 --- a/header-auth_test.go +++ b/header-auth_test.go @@ -8,9 +8,9 @@ import ( var casesHeaderAuth = []struct { name string - dec HeaderValue - enc HeaderValue - ha *HeaderAuth + vin HeaderValue + vout HeaderValue + h *HeaderAuth }{ { "basic", @@ -112,9 +112,9 @@ var casesHeaderAuth = []struct { func TestHeaderAuthRead(t *testing.T) { for _, c := range casesHeaderAuth { t.Run(c.name, func(t *testing.T) { - req, err := ReadHeaderAuth(c.dec) + req, err := ReadHeaderAuth(c.vin) require.NoError(t, err) - require.Equal(t, c.ha, req) + require.Equal(t, c.h, req) }) } } @@ -122,8 +122,8 @@ func TestHeaderAuthRead(t *testing.T) { func TestHeaderAuthWrite(t *testing.T) { for _, c := range casesHeaderAuth { t.Run(c.name, func(t *testing.T) { - req := c.ha.Write() - require.Equal(t, c.enc, req) + req := c.h.Write() + require.Equal(t, c.vout, req) }) } } diff --git a/header-session_test.go b/header-session_test.go index b8cfd3dd..7b27108d 100644 --- a/header-session_test.go +++ b/header-session_test.go @@ -7,9 +7,9 @@ import ( ) var casesHeaderSession = []struct { - name string - value HeaderValue - hs *HeaderSession + name string + v HeaderValue + h *HeaderSession }{ { "base", @@ -45,9 +45,9 @@ var casesHeaderSession = []struct { func TestHeaderSession(t *testing.T) { for _, c := range casesHeaderSession { t.Run(c.name, func(t *testing.T) { - req, err := ReadHeaderSession(c.value) + req, err := ReadHeaderSession(c.v) require.NoError(t, err) - require.Equal(t, c.hs, req) + require.Equal(t, c.h, req) }) } } diff --git a/header-transport.go b/header-transport.go index 0e1daef0..939b10fe 100644 --- a/header-transport.go +++ b/header-transport.go @@ -7,10 +7,47 @@ import ( ) // HeaderTransport is a Transport header. -type HeaderTransport map[string]struct{} +type HeaderTransport struct { + // protocol of the stream + Protocol StreamProtocol + + // cast of the stream + Cast *StreamCast + + // client ports + ClientPorts *[2]int + + // server ports + ServerPorts *[2]int + + // interleaved frame ids + InterleavedIds *[2]int + + // mode + Mode *string +} + +func parsePorts(val string) (*[2]int, error) { + ports := strings.Split(val, "-") + if len(ports) != 2 { + return &[2]int{0, 0}, fmt.Errorf("invalid ports (%v)", val) + } + + port1, err := strconv.ParseInt(ports[0], 10, 64) + if err != nil { + return &[2]int{0, 0}, fmt.Errorf("invalid ports (%v)", val) + } + + port2, err := strconv.ParseInt(ports[1], 10, 64) + if err != nil { + return &[2]int{0, 0}, fmt.Errorf("invalid ports (%v)", val) + } + + return &[2]int{int(port1), int(port2)}, nil +} // ReadHeaderTransport parses a Transport header. -func ReadHeaderTransport(v HeaderValue) (HeaderTransport, error) { +func ReadHeaderTransport(v HeaderValue) (*HeaderTransport, error) { if len(v) == 0 { return nil, fmt.Errorf("value not provided") } @@ -19,63 +56,99 @@ func ReadHeaderTransport(v HeaderValue) (HeaderTransport, error) { return nil, fmt.Errorf("value provided multiple times (%v)", v) } - ht := make(map[string]struct{}) + ht := &HeaderTransport{} + protoSet := false + for _, t := range strings.Split(v[0], ";") { - ht[t] = struct{}{} + if t == "RTP/AVP" || t == "RTP/AVP/UDP" { + ht.Protocol = StreamProtocolUDP + protoSet = true + + } else if t == "RTP/AVP/TCP" { + ht.Protocol = StreamProtocolTCP + protoSet = true + + } else if t == "unicast" { + ret := StreamUnicast + ht.Cast = &ret + + } else if t == "multicast" { + ret := StreamMulticast + ht.Cast = &ret + + } else if strings.HasPrefix(t, "client_port=") { + ports, err := parsePorts(t[len("client_port="):]) + if err != nil { + return nil, err + } + ht.ClientPorts = ports + + } else if strings.HasPrefix(t, "server_port=") { + ports, err := parsePorts(t[len("server_port="):]) + if err != nil { + return nil, err + } + ht.ServerPorts = ports + + } else if strings.HasPrefix(t, "interleaved=") { + ports, err := parsePorts(t[len("interleaved="):]) + if err != nil { + return nil, err + } + ht.InterleavedIds = ports + + } else if strings.HasPrefix(t, "mode=") { + ret := strings.ToLower(t[len("mode="):]) + ret = strings.TrimPrefix(ret, "\"") + ret = strings.TrimSuffix(ret, "\"") + ht.Mode = &ret + } + } + + // protocol is the only mandatory field + if !protoSet { + return nil, fmt.Errorf("protocol not set (%v)", v) } return ht, nil } -// IsUDP check whether the header contains the UDP protocol. -func (ht HeaderTransport) IsUDP() bool { - if _, ok := ht["RTP/AVP"]; ok { - return true - } - if _, ok := ht["RTP/AVP/UDP"]; ok { - return true - } - return false -} +// Write encodes a Transport header +func (ht *HeaderTransport) Write() HeaderValue { + var vals []string -// IsTCP check whether the header contains the TCP protocol. -func (ht HeaderTransport) IsTCP() bool { - _, ok := ht["RTP/AVP/TCP"] - return ok -} + if ht.Protocol == StreamProtocolUDP { + vals = append(vals, "RTP/AVP") + } else { + vals = append(vals, "RTP/AVP/TCP") + } -// Value gets a value from the header. -func (ht HeaderTransport) Value(key string) string { - prefix := key + "=" - for t := range ht { - if strings.HasPrefix(t, prefix) { - return t[len(prefix):] + if ht.Cast != nil { + if *ht.Cast == StreamUnicast { + vals = append(vals, "unicast") + } else { + vals = append(vals, "multicast") } } - return "" -} - -// Ports gets a value from the header and parses its ports. -func (ht HeaderTransport) Ports(key string) (int, int) { - val := ht.Value(key) - if val == "" { - return 0, 0 - } - - ports := strings.Split(val, "-") - if len(ports) != 2 { - return 0, 0 - } - - port1, err := strconv.ParseInt(ports[0], 10, 64) - if err != nil { - return 0, 0 - } - - port2, err := strconv.ParseInt(ports[1], 10, 64) - if err != nil { - return 0, 0 - } - - return int(port1), int(port2) + + if ht.ClientPorts != nil { + ports := *ht.ClientPorts + vals = append(vals, "client_port="+strconv.FormatInt(int64(ports[0]), 10)+"-"+strconv.FormatInt(int64(ports[1]), 10)) + } + + if ht.ServerPorts != nil { + ports := *ht.ServerPorts + vals = append(vals, "server_port="+strconv.FormatInt(int64(ports[0]), 10)+"-"+strconv.FormatInt(int64(ports[1]), 10)) + } + + if ht.InterleavedIds != nil { + ports := *ht.InterleavedIds + vals = append(vals, "interleaved="+strconv.FormatInt(int64(ports[0]), 10)+"-"+strconv.FormatInt(int64(ports[1]), 10)) + } + + if ht.Mode != nil { + vals = append(vals, "mode="+*ht.Mode) + } + + return HeaderValue{strings.Join(vals, ";")} } diff --git a/header-transport_test.go b/header-transport_test.go new file mode 100644 index 00000000..42407234 --- /dev/null +++ b/header-transport_test.go @@ -0,0 +1,86 @@ +package gortsplib + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +var casesHeaderTransport = []struct { + name string + vin HeaderValue + vout HeaderValue + h *HeaderTransport +}{ + { + "udp unicast play request", + HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode="PLAY"`}, + HeaderValue{`RTP/AVP;unicast;client_port=3456-3457;mode=play`}, + &HeaderTransport{ + Protocol: StreamProtocolUDP, + Cast: func() *StreamCast { + ret := StreamUnicast + return &ret + }(), + ClientPorts: &[2]int{3456, 3457}, + Mode: func() *string { + ret := "play" + return &ret + }(), + }, + }, + { + "udp unicast play response", + HeaderValue{`RTP/AVP/UDP;unicast;client_port=3056-3057;server_port=5000-5001`}, + HeaderValue{`RTP/AVP;unicast;client_port=3056-3057;server_port=5000-5001`}, + &HeaderTransport{ + Protocol: StreamProtocolUDP, + Cast: func() *StreamCast { + ret := StreamUnicast + return &ret + }(), + ClientPorts: &[2]int{3056, 3057}, + ServerPorts: &[2]int{5000, 5001}, + }, + }, + { + "udp multicast play request / response", + HeaderValue{`RTP/AVP;multicast;destination=225.219.201.15;port=7000-7001;ttl=127`}, + HeaderValue{`RTP/AVP;multicast`}, + &HeaderTransport{ + Protocol: StreamProtocolUDP, + Cast: func() *StreamCast { + ret := StreamMulticast + return &ret + }(), + }, + }, + { + "tcp play request / response", + HeaderValue{`RTP/AVP/TCP;interleaved=0-1`}, + HeaderValue{`RTP/AVP/TCP;interleaved=0-1`}, + &HeaderTransport{ + Protocol: StreamProtocolTCP, + InterleavedIds: &[2]int{0, 1}, + }, + }, +} + +func TestHeaderTransportRead(t *testing.T) { + for _, c := range casesHeaderTransport { + t.Run(c.name, func(t *testing.T) { + req, err := ReadHeaderTransport(c.vin) + require.NoError(t, err) + require.Equal(t, c.h, req) + }) + } +} + +func TestHeaderTransportWrite(t *testing.T) { + for _, c := range casesHeaderTransport { + t.Run(c.name, func(t *testing.T) { + req := c.h.Write() + require.Equal(t, c.vout, req) + }) + } +} diff --git a/utils.go b/utils.go index 4ad52e61..58a2d5b0 100644 --- a/utils.go +++ b/utils.go @@ -11,7 +11,7 @@ const ( rtspMaxContentLength = 4096 ) -// StreamProtocol is the protocol of a stream +// StreamProtocol is the protocol of a stream. type StreamProtocol int const ( @@ -30,6 +30,25 @@ func (sp StreamProtocol) String() string { return "tcp" } +// StreamCast is the cast of a stream. +type StreamCast int + +const ( + // Unicast means that the stream will be unicasted + StreamUnicast StreamCast = iota + + // Multicast means that the stream will be multicasted + StreamMulticast +) + +// String implements fmt.Stringer +func (sc StreamCast) String() string { + if sc == StreamUnicast { + return "unicast" + } + return "multicast" +} + // StreamType is the type of a stream. type StreamType int