diff --git a/serverconnpublish_test.go b/serverconnpublish_test.go index 34b2a260..eb344e00 100644 --- a/serverconnpublish_test.go +++ b/serverconnpublish_test.go @@ -283,6 +283,118 @@ func TestServerConnPublishSetupDifferentPaths(t *testing.T) { require.Equal(t, "invalid track path (test2stream/trackID=0)", err.Error()) } +func TestServerConnPublishSetupDouble(t *testing.T) { + serverErr := make(chan error) + + s, err := Serve("127.0.0.1:8554") + require.NoError(t, err) + defer s.Close() + + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + go func() { + defer close(serverDone) + + conn, err := s.Accept() + require.NoError(t, err) + defer conn.Close() + + onAnnounce := func(req *base.Request, tracks Tracks) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + } + + onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + } + + err = <-conn.Read(ServerConnReadHandlers{ + OnAnnounce: onAnnounce, + OnSetup: onSetup, + }) + serverErr <- err + }() + + conn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer conn.Close() + bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + + track, err := NewTrackH264(96, []byte("123456"), []byte("123456")) + require.NoError(t, err) + + tracks := Tracks{track} + for i, t := range tracks { + t.Media.Attributes = append(t.Media.Attributes, psdp.Attribute{ + Key: "control", + Value: "trackID=" + strconv.FormatInt(int64(i), 10), + }) + } + + err = base.Request{ + Method: base.Announce, + URL: base.MustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Content-Type": base.HeaderValue{"application/sdp"}, + }, + Body: tracks.Write(), + }.Write(bconn.Writer) + require.NoError(t, err) + + var res base.Response + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + th := &headers.Transport{ + Protocol: StreamProtocolTCP, + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModeRecord + return &v + }(), + InterleavedIds: &[2]int{0, 1}, + } + + err = base.Request{ + Method: base.Setup, + URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + "Transport": th.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + err = base.Request{ + Method: base.Setup, + URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"3"}, + "Transport": th.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusBadRequest, res.StatusCode) + + err = <-serverErr + require.Equal(t, "track 0 has already been setup", err.Error()) +} + func TestServerConnPublishRecordPartialTracks(t *testing.T) { serverErr := make(chan error) diff --git a/serverconnread_test.go b/serverconnread_test.go index f0c7d099..1b24e6c1 100644 --- a/serverconnread_test.go +++ b/serverconnread_test.go @@ -130,6 +130,8 @@ func TestServerConnReadSetupPath(t *testing.T) { } func TestServerConnReadSetupDifferentPaths(t *testing.T) { + serverErr := make(chan error) + s, err := Serve("127.0.0.1:8554") require.NoError(t, err) defer s.Close() @@ -149,9 +151,10 @@ func TestServerConnReadSetupDifferentPaths(t *testing.T) { }, nil } - <-conn.Read(ServerConnReadHandlers{ + err = <-conn.Read(ServerConnReadHandlers{ OnSetup: onSetup, }) + serverErr <- err }() conn, err := net.Dial("tcp", "localhost:8554") @@ -202,6 +205,90 @@ func TestServerConnReadSetupDifferentPaths(t *testing.T) { err = res.Read(bconn.Reader) require.NoError(t, err) require.Equal(t, base.StatusBadRequest, res.StatusCode) + + err = <-serverErr + require.Equal(t, "can't setup tracks with different paths", err.Error()) +} + +func TestServerConnReadSetupDouble(t *testing.T) { + serverErr := make(chan error) + + s, err := Serve("127.0.0.1:8554") + require.NoError(t, err) + defer s.Close() + + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + go func() { + defer close(serverDone) + + conn, err := s.Accept() + require.NoError(t, err) + defer conn.Close() + + onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + } + + err = <-conn.Read(ServerConnReadHandlers{ + OnSetup: onSetup, + }) + serverErr <- err + }() + + conn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer conn.Close() + bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + + th := &headers.Transport{ + Protocol: StreamProtocolTCP, + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModePlay + return &v + }(), + InterleavedIds: &[2]int{0, 1}, + } + + err = base.Request{ + Method: base.Setup, + URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Transport": th.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + var res base.Response + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + th.InterleavedIds = &[2]int{2, 3} + + err = base.Request{ + Method: base.Setup, + URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + "Transport": th.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusBadRequest, res.StatusCode) + + err = <-serverErr + require.Equal(t, "track 0 has already been setup", err.Error()) } func TestServerConnReadReceivePackets(t *testing.T) {