diff --git a/serverconn.go b/serverconn.go index fba627e4..c2474ad0 100644 --- a/serverconn.go +++ b/serverconn.go @@ -435,20 +435,20 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { if err != nil { return &base.Response{ StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid track URL") + }, fmt.Errorf("unable to generate track URL") } trackPath, ok := trackURL.RTSPPath() if !ok { return &base.Response{ StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid track URL") + }, fmt.Errorf("invalid track URL (%v)", trackURL) } if !strings.HasPrefix(trackPath, reqPath) { return &base.Response{ StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid track URL: must begin with '%s', but is '%s'", + }, fmt.Errorf("invalid track path: must begin with '%s', but is '%s'", reqPath, trackPath) } } @@ -528,7 +528,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { if th.Mode == nil || *th.Mode == headers.TransportModePlay { trackID, _, ok := base.PathSplitControlAttribute(pathAndQuery) if !ok { - return 0, fmt.Errorf("invalid track (%s)", pathAndQuery) + return 0, fmt.Errorf("invalid track path (%s)", pathAndQuery) } return trackID, nil @@ -541,7 +541,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } } - return 0, fmt.Errorf("invalid track (%s)", pathAndQuery) + return 0, fmt.Errorf("invalid track path (%s)", pathAndQuery) }() if err != nil { return &base.Response{ diff --git a/serverconnpublish_test.go b/serverconnpublish_test.go index 389e94ea..15ff63ea 100644 --- a/serverconnpublish_test.go +++ b/serverconnpublish_test.go @@ -16,6 +16,153 @@ import ( "github.com/aler9/gortsplib/pkg/headers" ) +func TestServerConnPublishSetupPath(t *testing.T) { + for _, ca := range []struct { + name string + control string + url string + trackID int + }{ + { + "normal", + "trackID=0", + "rtsp://localhost:8554/teststream/trackID=0", + 0, + }, + { + "unordered id", + "trackID=2", + "rtsp://localhost:8554/teststream/trackID=2", + 0, + }, + { + "custom param name", + "testing=0", + "rtsp://localhost:8554/teststream/testing=0", + 0, + }, + { + "query", + "?testing=0", + "rtsp://localhost:8554/teststream?testing=0", + 0, + }, + } { + t.Run(ca.name, func(t *testing.T) { + setupDone := make(chan int) + + s, err := Serve("127.0.0.1:8554") + require.NoError(t, err) + defer s.Close() + + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + go func() { + defer close(serverDone) + + conn, err := s.Accept() + require.NoError(t, err) + defer conn.Close() + + onAnnounce := func(req *base.Request, tracks Tracks) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + } + + onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { + setupDone <- trackID + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + } + + err = <-conn.Read(ServerConnReadHandlers{ + OnAnnounce: onAnnounce, + OnSetup: onSetup, + }) + require.Equal(t, io.EOF, err) + }() + + conn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer conn.Close() + bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + + track, err := NewTrackH264(96, []byte("123456"), []byte("123456")) + require.NoError(t, err) + track.Media.Attributes = append(track.Media.Attributes, psdp.Attribute{ + Key: "control", + Value: ca.control, + }) + + sout := &psdp.SessionDescription{ + SessionName: psdp.SessionName("Stream"), + Origin: psdp.Origin{ + Username: "-", + NetworkType: "IN", + AddressType: "IP4", + UnicastAddress: "127.0.0.1", + }, + TimeDescriptions: []psdp.TimeDescription{ + {Timing: psdp.Timing{0, 0}}, //nolint:govet + }, + MediaDescriptions: []*psdp.MediaDescription{ + track.Media, + }, + } + + byts, _ := sout.Marshal() + + err = base.Request{ + Method: base.Announce, + URL: base.MustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Content-Type": base.HeaderValue{"application/sdp"}, + }, + Body: byts, + }.Write(bconn.Writer) + require.NoError(t, err) + + var res base.Response + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + th := &headers.Transport{ + Protocol: StreamProtocolTCP, + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModeRecord + return &v + }(), + InterleavedIds: &[2]int{0, 1}, + } + + err = base.Request{ + Method: base.Setup, + URL: base.MustParseURL(ca.url), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + "Transport": th.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + trackID := <-setupDone + require.Equal(t, ca.trackID, trackID) + + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + }) + } +} + func TestServerConnPublishReceivePackets(t *testing.T) { for _, proto := range []string{ "udp", @@ -24,9 +171,11 @@ func TestServerConnPublishReceivePackets(t *testing.T) { t.Run(proto, func(t *testing.T) { packetsReceived := make(chan struct{}) - conf := ServerConf{ - UDPRTPAddress: "127.0.0.1:8000", - UDPRTCPAddress: "127.0.0.1:8001", + conf := ServerConf{} + + if proto == "udp" { + conf.UDPRTPAddress = "127.0.0.1:8000" + conf.UDPRTCPAddress = "127.0.0.1:8001" } s, err := conf.Serve("127.0.0.1:8554") diff --git a/serverconnread_test.go b/serverconnread_test.go index 98dd94d4..ed51698d 100644 --- a/serverconnread_test.go +++ b/serverconnread_test.go @@ -13,6 +13,91 @@ import ( "github.com/aler9/gortsplib/pkg/headers" ) +func TestServerConnReadSetupPath(t *testing.T) { + for _, ca := range []struct { + name string + url string + trackID int + }{ + { + "normal", + "rtsp://localhost:8554/teststream/trackID=0", + 0, + }, + { + "unordered id", + "rtsp://localhost:8554/teststream/trackID=2", + 2, + }, + } { + t.Run(ca.name, func(t *testing.T) { + setupDone := make(chan int) + + s, err := Serve("127.0.0.1:8554") + require.NoError(t, err) + defer s.Close() + + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + go func() { + defer close(serverDone) + + conn, err := s.Accept() + require.NoError(t, err) + defer conn.Close() + + onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { + setupDone <- trackID + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + } + + err = <-conn.Read(ServerConnReadHandlers{ + OnSetup: onSetup, + }) + require.Equal(t, io.EOF, err) + }() + + conn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer conn.Close() + bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + + th := &headers.Transport{ + Protocol: StreamProtocolTCP, + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModePlay + return &v + }(), + InterleavedIds: &[2]int{ca.trackID * 2, (ca.trackID * 2) + 1}, + } + + err = base.Request{ + Method: base.Setup, + URL: base.MustParseURL(ca.url), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Transport": th.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + trackID := <-setupDone + require.Equal(t, ca.trackID, trackID) + + var res base.Response + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + }) + } +} + func TestServerConnReadReceivePackets(t *testing.T) { for _, proto := range []string{ "udp",