diff --git a/pkg/liberrors/server.go b/pkg/liberrors/server.go index e9c365f2..c36ceb93 100644 --- a/pkg/liberrors/server.go +++ b/pkg/liberrors/server.go @@ -77,12 +77,12 @@ func (e ErrServerWrongState) Error() string { e.AllowedList, e.State) } -// ErrServerNoPath is an error that can be returned by a server. -type ErrServerNoPath struct{} +// ErrServerInvalidPath is an error that can be returned by a server. +type ErrServerInvalidPath struct{} // Error implements the error interface. -func (e ErrServerNoPath) Error() string { - return "RTSP path can't be retrieved" +func (e ErrServerInvalidPath) Error() string { + return "invalid path" } // ErrServerContentTypeMissing is an error that can be returned by a server. diff --git a/server_test.go b/server_test.go index c3272a4e..702d5a51 100644 --- a/server_test.go +++ b/server_test.go @@ -5,10 +5,12 @@ import ( "crypto/tls" "fmt" "net" + "strconv" "sync" "testing" "time" + psdp "github.com/pion/sdp/v3" "github.com/stretchr/testify/require" "github.com/aler9/gortsplib/pkg/base" @@ -927,6 +929,10 @@ func TestServerSessionClose(t *testing.T) { }.Write(bconn.Writer) require.NoError(t, err) + res, err := readResponse(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + <-sessionClosed } @@ -975,7 +981,151 @@ func TestServerSessionAutoClose(t *testing.T) { }.Write(bconn.Writer) require.NoError(t, err) + res, err := readResponse(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + conn.Close() <-sessionClosed } + +func TestServerErrorInvalidPath(t *testing.T) { + for _, method := range []base.Method{ + base.Describe, + base.Announce, + base.Play, + base.Record, + base.Pause, + //base.GetParameter, + //base.SetParameter, + } { + t.Run(string(method), func(t *testing.T) { + s := &Server{ + Handler: &testServerHandler{ + onConnClose: func(sc *ServerConn, err error) { + require.Equal(t, "invalid path", err.Error()) + }, + onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + } + + err := s.Start("127.0.0.1:8554") + require.NoError(t, err) + defer s.Close() + + conn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer conn.Close() + bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + + sxID := "" + + if method == base.Record { + track, err := NewTrackH264(96, []byte("123456"), []byte("123456")) + require.NoError(t, err) + + tracks := Tracks{track} + for i, t := range tracks { + t.Media.Attributes = append(t.Media.Attributes, psdp.Attribute{ + Key: "control", + Value: "trackID=" + strconv.FormatInt(int64(i), 10), + }) + } + + err = base.Request{ + Method: base.Announce, + URL: base.MustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Content-Type": base.HeaderValue{"application/sdp"}, + }, + Body: tracks.Write(), + }.Write(bconn.Writer) + require.NoError(t, err) + + res, err := readResponse(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + sxID = res.Header["Session"][0] + } + + if method == base.Play || method == base.Record || method == base.Pause { + err = base.Request{ + Method: base.Setup, + URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + "Session": base.HeaderValue{sxID}, + "Transport": headers.Transport{ + Protocol: StreamProtocolTCP, + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + if method == base.Play || method == base.Pause { + v := headers.TransportModePlay + return &v + } + v := headers.TransportModeRecord + return &v + }(), + InterleavedIDs: &[2]int{0, 1}, + }.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + res, err := readResponse(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + sxID = res.Header["Session"][0] + } + + if method == base.Pause { + err = base.Request{ + Method: base.Play, + URL: base.MustParseURL("rtsp://localhost:8554/teststream/"), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + "Session": base.HeaderValue{sxID}, + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + res, err := readResponse(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + } + + err = base.Request{ + Method: method, + URL: base.MustParseURL("rtsp://localhost:8554"), + Header: base.Header{ + "CSeq": base.HeaderValue{"3"}, + "Session": base.HeaderValue{sxID}, + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + res, err := readResponse(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusBadRequest, res.StatusCode) + }) + } +} diff --git a/serverconn.go b/serverconn.go index 402c0153..fa8e35f6 100644 --- a/serverconn.go +++ b/serverconn.go @@ -245,14 +245,14 @@ func (sc *ServerConn) run() { } }() - sc.nconn.Close() - <-readDone - if sc.tcpFrameEnabled { sc.tcpFrameWriteBuffer.Close() <-sc.tcpFrameBackgroundWriteDone } + sc.nconn.Close() + <-readDone + for _, ss := range sc.sessions { ss.connRemove <- sc } @@ -336,7 +336,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { if !ok { return &base.Response{ StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerNoPath{} + }, liberrors.ErrServerInvalidPath{} } path, query := base.PathSplitQuery(pathAndQuery) @@ -430,7 +430,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { if !ok { return &base.Response{ StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerNoPath{} + }, liberrors.ErrServerInvalidPath{} } path, query := base.PathSplitQuery(pathAndQuery) @@ -449,7 +449,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { if !ok { return &base.Response{ StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerNoPath{} + }, liberrors.ErrServerInvalidPath{} } path, query := base.PathSplitQuery(pathAndQuery) diff --git a/serversession.go b/serversession.go index 896e35c0..83cb6f62 100644 --- a/serversession.go +++ b/serversession.go @@ -26,7 +26,7 @@ func setupGetTrackIDPathQuery(url *base.URL, pathAndQuery, ok := url.RTSPPathAndQuery() if !ok { - return 0, "", "", liberrors.ErrServerNoPath{} + return 0, "", "", liberrors.ErrServerInvalidPath{} } if thMode == nil || *thMode == headers.TransportModePlay { @@ -417,6 +417,15 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base }, err } + pathAndQuery, ok := req.URL.RTSPPath() + if !ok { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, liberrors.ErrServerInvalidPath{} + } + + path, query := base.PathSplitQuery(pathAndQuery) + ct, ok := req.Header["Content-Type"] if !ok || len(ct) != 1 { return &base.Response{ @@ -443,15 +452,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base }, liberrors.ErrServerSDPNoTracksDefined{} } - pathAndQuery, ok := req.URL.RTSPPath() - if !ok { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerNoPath{} - } - - path, query := base.PathSplitQuery(pathAndQuery) - for _, track := range tracks { trackURL, err := track.URL() if err != nil { @@ -681,7 +681,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base if !ok { return &base.Response{ StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerNoPath{} + }, liberrors.ErrServerInvalidPath{} } // path can end with a slash due to Content-Base, remove it @@ -750,7 +750,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base if !ok { return &base.Response{ StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerNoPath{} + }, liberrors.ErrServerInvalidPath{} } // path can end with a slash due to Content-Base, remove it @@ -818,7 +818,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base if !ok { return &base.Response{ StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerNoPath{} + }, liberrors.ErrServerInvalidPath{} } // path can end with a slash due to Content-Base, remove it @@ -875,7 +875,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base if !ok { return &base.Response{ StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerNoPath{} + }, liberrors.ErrServerInvalidPath{} } path, query := base.PathSplitQuery(pathAndQuery)