diff --git a/client.go b/client.go index 2b24c17b..3a32f961 100644 --- a/client.go +++ b/client.go @@ -1237,7 +1237,25 @@ func (c *Client) doDescribe(u *base.URL) (Tracks, *base.URL, *base.Response, err return nil, nil, nil, liberrors.ErrClientContentTypeUnsupported{CT: ct} } + tracks, sd, err := ReadTracks(res.Body, true) + if err != nil { + return nil, nil, nil, err + } + baseURL, err := func() (*base.URL, error) { + // use global control attribute + if control, ok := sd.Attribute("control"); ok && control != "*" { + ret, err := base.ParseURL(control) + if err != nil { + return nil, fmt.Errorf("invalid control attribute: '%v'", control) + } + + // add credentials + ret.User = u.User + + return ret, nil + } + // use Content-Base if cb, ok := res.Header["Content-Base"]; ok { if len(cb) != 1 { @@ -1249,24 +1267,19 @@ func (c *Client) doDescribe(u *base.URL) (Tracks, *base.URL, *base.Response, err return nil, fmt.Errorf("invalid Content-Base: '%v'", cb) } - // add credentials from URL of request + // add credentials ret.User = u.User return ret, nil } - // if not provided, use URL of request + // use URL of request return u, nil }() if err != nil { return nil, nil, nil, err } - tracks, err := ReadTracks(res.Body, true) - if err != nil { - return nil, nil, nil, err - } - c.lastDescribeURL = u return tracks, baseURL, res, nil diff --git a/client_read_test.go b/client_read_test.go index 82231216..408b1ed4 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -768,115 +768,140 @@ func TestClientReadPartial(t *testing.T) { <-packetRecv } -func TestClientReadNoContentBase(t *testing.T) { - l, err := net.Listen("tcp", "localhost:8554") - require.NoError(t, err) - defer l.Close() +func TestClientReadContentBase(t *testing.T) { + for _, ca := range []string{ + "absent", + "inside control attribute", + } { + t.Run(ca, func(t *testing.T) { + l, err := net.Listen("tcp", "localhost:8554") + require.NoError(t, err) + defer l.Close() - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + go func() { + defer close(serverDone) - conn, err := l.Accept() - require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) + conn, err := l.Accept() + require.NoError(t, err) + defer conn.Close() + br := bufio.NewReader(conn) - req, err := readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Options, req.Method) + req, err := readRequest(br) + require.NoError(t, err) + require.Equal(t, base.Options, req.Method) - byts, _ := base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Public": base.HeaderValue{strings.Join([]string{ - string(base.Describe), - string(base.Setup), - string(base.Play), - }, ", ")}, - }, - }.Write() - _, err = conn.Write(byts) - require.NoError(t, err) + byts, _ := base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Describe), + string(base.Setup), + string(base.Play), + }, ", ")}, + }, + }.Write() + _, err = conn.Write(byts) + require.NoError(t, err) - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Describe, req.Method) - require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) + req, err = readRequest(br) + require.NoError(t, err) + require.Equal(t, base.Describe, req.Method) + require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) - track, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil) - require.NoError(t, err) + track, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil) + require.NoError(t, err) - tracks := Tracks{track} - tracks.setControls() + tracks := Tracks{track} + tracks.setControls() - byts, _ = base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Content-Type": base.HeaderValue{"application/sdp"}, - }, - Body: tracks.Write(false), - }.Write() - _, err = conn.Write(byts) - require.NoError(t, err) + switch ca { + case "absent": + byts, _ = base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Type": base.HeaderValue{"application/sdp"}, + }, + Body: tracks.Write(false), + }.Write() + _, err = conn.Write(byts) + require.NoError(t, err) - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Setup, req.Method) - require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL) + case "inside control attribute": + body := string(tracks.Write(false)) + body = strings.Replace(body, "t=0 0", "t=0 0\r\na=control:rtsp://localhost:8554/teststream", 1) - var inTH headers.Transport - err = inTH.Read(req.Header["Transport"]) - require.NoError(t, err) + byts, _ = base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Type": base.HeaderValue{"application/sdp"}, + "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream2/"}, + }, + Body: []byte(body), + }.Write() + _, err = conn.Write(byts) + require.NoError(t, err) + } - th := headers.Transport{ - Delivery: func() *headers.TransportDelivery { - v := headers.TransportDeliveryUnicast - return &v - }(), - Protocol: headers.TransportProtocolUDP, - ClientPorts: inTH.ClientPorts, - ServerPorts: &[2]int{34556, 34557}, - } + req, err = readRequest(br) + require.NoError(t, err) + require.Equal(t, base.Setup, req.Method) + require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL) - byts, _ = base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Transport": th.Write(), - }, - }.Write() - _, err = conn.Write(byts) - require.NoError(t, err) + var inTH headers.Transport + err = inTH.Read(req.Header["Transport"]) + require.NoError(t, err) - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Play, req.Method) - require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) + th := headers.Transport{ + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + Protocol: headers.TransportProtocolUDP, + ClientPorts: inTH.ClientPorts, + ServerPorts: &[2]int{34556, 34557}, + } - byts, _ = base.Response{ - StatusCode: base.StatusOK, - }.Write() - _, err = conn.Write(byts) - require.NoError(t, err) + byts, _ = base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Transport": th.Write(), + }, + }.Write() + _, err = conn.Write(byts) + require.NoError(t, err) - req, err = readRequest(br) - require.NoError(t, err) - require.Equal(t, base.Teardown, req.Method) - require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) + req, err = readRequest(br) + require.NoError(t, err) + require.Equal(t, base.Play, req.Method) + require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) - byts, _ = base.Response{ - StatusCode: base.StatusOK, - }.Write() - _, err = conn.Write(byts) - require.NoError(t, err) - }() + byts, _ = base.Response{ + StatusCode: base.StatusOK, + }.Write() + _, err = conn.Write(byts) + require.NoError(t, err) - c := Client{} + req, err = readRequest(br) + require.NoError(t, err) + require.Equal(t, base.Teardown, req.Method) + require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream"), req.URL) - err = c.StartReading("rtsp://localhost:8554/teststream") - require.NoError(t, err) - c.Close() + byts, _ = base.Response{ + StatusCode: base.StatusOK, + }.Write() + _, err = conn.Write(byts) + require.NoError(t, err) + }() + + c := Client{} + + err = c.StartReading("rtsp://localhost:8554/teststream") + require.NoError(t, err) + c.Close() + }) + } } func TestClientReadAnyPort(t *testing.T) { diff --git a/pkg/sdp/sdp.go b/pkg/sdp/sdp.go index df69e144..7b2afad9 100644 --- a/pkg/sdp/sdp.go +++ b/pkg/sdp/sdp.go @@ -14,6 +14,11 @@ import ( // SessionDescription is a SDP session description. type SessionDescription psdp.SessionDescription +// Attribute returns the value of an attribute and if it exists +func (s *SessionDescription) Attribute(key string) (string, bool) { + return (*psdp.SessionDescription)(s).Attribute(key) +} + // Marshal encodes a SessionDescription. func (s *SessionDescription) Marshal() ([]byte, error) { return (*psdp.SessionDescription)(s).Marshal() diff --git a/serversession.go b/serversession.go index 94820173..51ba9a76 100644 --- a/serversession.go +++ b/serversession.go @@ -503,7 +503,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base }, liberrors.ErrServerContentTypeUnsupported{CT: ct} } - tracks, err := ReadTracks(req.Body, false) + tracks, _, err := ReadTracks(req.Body, false) if err != nil { return &base.Response{ StatusCode: base.StatusBadRequest, diff --git a/track_test.go b/track_test.go index b691eb83..4dd04b3b 100644 --- a/track_test.go +++ b/track_test.go @@ -760,7 +760,7 @@ func TestTrackURL(t *testing.T) { }, } { t.Run(ca.name, func(t *testing.T) { - tracks, err := ReadTracks(ca.sdp, false) + tracks, _, err := ReadTracks(ca.sdp, false) require.NoError(t, err) ur, err := tracks[0].url(ca.baseURL) require.NoError(t, err) diff --git a/tracks.go b/tracks.go index 4e933531..53cf87e9 100644 --- a/tracks.go +++ b/tracks.go @@ -14,11 +14,12 @@ import ( type Tracks []Track // ReadTracks decodes tracks from the SDP format. -func ReadTracks(byts []byte, skipGenericTracksWithoutClockRate bool) (Tracks, error) { +// It returns the tracks and the decoded SDP. +func ReadTracks(byts []byte, skipGenericTracksWithoutClockRate bool) (Tracks, *sdp.SessionDescription, error) { var sd sdp.SessionDescription err := sd.Unmarshal(byts) if err != nil { - return nil, err + return nil, nil, err } var tracks Tracks //nolint:prealloc @@ -30,17 +31,17 @@ func ReadTracks(byts []byte, skipGenericTracksWithoutClockRate bool) (Tracks, er strings.HasPrefix(err.Error(), "unable to get clock rate") { continue } - return nil, fmt.Errorf("unable to parse track %d: %s", i+1, err) + return nil, nil, fmt.Errorf("unable to parse track %d: %s", i+1, err) } tracks = append(tracks, t) } if len(tracks) == 0 { - return nil, fmt.Errorf("no valid tracks found") + return nil, nil, fmt.Errorf("no valid tracks found") } - return tracks, nil + return tracks, &sd, nil } func (ts Tracks) clone() Tracks { diff --git a/tracks_test.go b/tracks_test.go index 8c2db03a..02251de8 100644 --- a/tracks_test.go +++ b/tracks_test.go @@ -35,7 +35,7 @@ func TestTracksReadErrors(t *testing.T) { }, } { t.Run(ca.name, func(t *testing.T) { - _, err := ReadTracks(ca.sdp, false) + _, _, err := ReadTracks(ca.sdp, false) require.EqualError(t, err, ca.err) }) } @@ -67,7 +67,7 @@ func TestTracksReadSkipGenericTracksWithoutClockRate(t *testing.T) { "m=application 42508 RTP/AVP 107\r\n" + "b=AS:8\r\n") - tracks, err := ReadTracks(sdp, true) + tracks, _, err := ReadTracks(sdp, true) require.NoError(t, err) require.Equal(t, Tracks{ &TrackH264{