diff --git a/pkg/liberrors/server.go b/pkg/liberrors/server.go index 7c4a8421..a973a99e 100644 --- a/pkg/liberrors/server.go +++ b/pkg/liberrors/server.go @@ -48,16 +48,6 @@ func (e ErrServerCSeqMissing) Error() string { return "CSeq is missing" } -// ErrServerUnhandledRequest is an error that can be returned by a server. -type ErrServerUnhandledRequest struct { - Request *base.Request -} - -// Error implements the error interface. -func (e ErrServerUnhandledRequest) Error() string { - return fmt.Sprintf("unhandled request: %v %v", e.Request.Method, e.Request.URL) -} - // ErrServerInvalidState is an error that can be returned by a server. type ErrServerInvalidState struct { AllowedList []fmt.Stringer diff --git a/server_test.go b/server_test.go index 7c6e3bb1..263ab54b 100644 --- a/server_test.go +++ b/server_test.go @@ -316,38 +316,103 @@ func TestServerErrorCSeqMissing(t *testing.T) { <-connClosed } -func TestServerErrorInvalidMethod(t *testing.T) { - connClosed := make(chan struct{}) +type testServerErrMethodNotImplemented struct { + stream *ServerStream +} - s := &Server{ - Handler: &testServerHandler{ - onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { - require.EqualError(t, ctx.Error, "unhandled request: INVALID rtsp://localhost:8554/") - close(connClosed) - }, - }, - RTSPAddress: "localhost:8554", +func (s *testServerErrMethodNotImplemented) OnSetup( + ctx *ServerHandlerOnSetupCtx, +) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, s.stream, nil +} + +func TestServerErrorMethodNotImplemented(t *testing.T) { + for _, ca := range []string{"outside session", "inside session"} { + t.Run(ca, func(t *testing.T) { + track := &TrackH264{ + PayloadType: 96, + SPS: []byte{0x01, 0x02, 0x03, 0x04}, + PPS: []byte{0x01, 0x02, 0x03, 0x04}, + } + + stream := NewServerStream(Tracks{track}) + defer stream.Close() + + s := &Server{ + Handler: &testServerErrMethodNotImplemented{stream}, + RTSPAddress: "localhost:8554", + } + + err := s.Start() + require.NoError(t, err) + defer s.Close() + + conn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer conn.Close() + br := bufio.NewReader(conn) + + var sx headers.Session + + if ca == "inside session" { + res, err := writeReqReadRes(conn, br, base.Request{ + Method: base.Setup, + URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Transport": headers.Transport{ + Protocol: headers.TransportProtocolTCP, + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModePlay + return &v + }(), + InterleavedIDs: &[2]int{0, 1}, + }.Marshal(), + }, + }) + require.NoError(t, err) + + err = sx.Unmarshal(res.Header["Session"]) + require.NoError(t, err) + } + + headers := base.Header{ + "CSeq": base.HeaderValue{"2"}, + } + if ca == "inside session" { + headers["Session"] = base.HeaderValue{sx.Session} + } + + res, err := writeReqReadRes(conn, br, base.Request{ + Method: base.SetParameter, + URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: headers, + }) + require.NoError(t, err) + require.Equal(t, base.StatusNotImplemented, res.StatusCode) + + headers = base.Header{ + "CSeq": base.HeaderValue{"3"}, + } + if ca == "inside session" { + headers["Session"] = base.HeaderValue{sx.Session} + } + + res, err = writeReqReadRes(conn, br, base.Request{ + Method: base.Options, + URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: headers, + }) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + }) } - err := s.Start() - require.NoError(t, err) - defer s.Close() - - conn, err := net.Dial("tcp", "localhost:8554") - require.NoError(t, err) - defer conn.Close() - br := bufio.NewReader(conn) - - res, err := writeReqReadRes(conn, br, base.Request{ - Method: "INVALID", - URL: mustParseURL("rtsp://localhost:8554/"), - Header: base.Header{ - "CSeq": base.HeaderValue{"1"}, - }, - }) - require.NoError(t, err) - require.Equal(t, base.StatusBadRequest, res.StatusCode) - - <-connClosed } func TestServerErrorTCPTwoConnOneSession(t *testing.T) { diff --git a/serverconn.go b/serverconn.go index bfc665b1..9e45089f 100644 --- a/serverconn.go +++ b/serverconn.go @@ -505,8 +505,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } return &base.Response{ - StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerUnhandledRequest{Request: req} + StatusCode: base.StatusNotImplemented, + }, nil } func (sc *ServerConn) handleRequestOuter(req *base.Request) error { diff --git a/serversession.go b/serversession.go index e44ea8fc..77396fc9 100644 --- a/serversession.go +++ b/serversession.go @@ -1133,8 +1133,8 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base } return &base.Response{ - StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerUnhandledRequest{Request: req} + StatusCode: base.StatusNotImplemented, + }, nil } func (ss *ServerSession) runWriter() {