From b4dba4bfddebb48ce4245d01a56591eb8f8730fe Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Tue, 4 May 2021 12:11:34 +0200 Subject: [PATCH] server: test invalid methods and double sessions --- pkg/liberrors/server.go | 8 +++ server_test.go | 130 +++++++++++++++++++++++++++++++++++++++- serverconn.go | 23 +++---- 3 files changed, 146 insertions(+), 15 deletions(-) diff --git a/pkg/liberrors/server.go b/pkg/liberrors/server.go index 4358b995..fcf778ca 100644 --- a/pkg/liberrors/server.go +++ b/pkg/liberrors/server.go @@ -31,6 +31,14 @@ func (e ErrServerCSeqMissing) Error() string { return "CSeq is missing" } +// ErrServerInvalidMethod is an error that can be returned by a server. +type ErrServerInvalidMethod struct{} + +// Error implements the error interface. +func (e ErrServerInvalidMethod) Error() string { + return "invalid method" +} + // ErrServerWrongState is an error that can be returned by a server. type ErrServerWrongState struct { AllowedList []fmt.Stringer diff --git a/server_test.go b/server_test.go index f386e379..e58e1525 100644 --- a/server_test.go +++ b/server_test.go @@ -485,7 +485,39 @@ func TestServerErrorCSeqMissing(t *testing.T) { require.Equal(t, base.StatusBadRequest, res.StatusCode) } -func TestServerErrorTCPSameSession(t *testing.T) { +func TestServerErrorInvalidMethod(t *testing.T) { + h := &testServerHandler{ + onConnClose: func(sc *ServerConn, err error) { + require.Equal(t, "invalid method", err.Error()) + }, + } + + s := &Server{Handler: h} + err := s.Start("127.0.0.1:8554") + require.NoError(t, err) + defer s.Close() + + conn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer conn.Close() + bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + + err = base.Request{ + Method: "INVALID", + URL: base.MustParseURL("rtsp://localhost:8554/"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + var res base.Response + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusBadRequest, res.StatusCode) +} + +func TestServerErrorTCPTwoConnOneSession(t *testing.T) { s := &Server{ Handler: &testServerHandler{ onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { @@ -587,6 +619,102 @@ func TestServerErrorTCPSameSession(t *testing.T) { require.Equal(t, base.StatusBadRequest, res.StatusCode) } +func TestServerErrorTCPOneConnTwoSessions(t *testing.T) { + s := &Server{ + Handler: &testServerHandler{ + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onPause: func(ctx *ServerHandlerOnPauseCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + } + + err := s.Start("127.0.0.1:8554") + require.NoError(t, err) + defer s.Close() + + conn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer conn.Close() + bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + + err = base.Request{ + Method: base.Setup, + URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Transport": headers.Transport{ + Protocol: StreamProtocolTCP, + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModePlay + return &v + }(), + InterleavedIDs: &[2]int{0, 1}, + }.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + var res base.Response + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + err = base.Request{ + Method: base.Play, + URL: base.MustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + "Session": res.Header["Session"], + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + err = base.Request{ + Method: base.Setup, + URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"3"}, + "Transport": headers.Transport{ + Protocol: StreamProtocolTCP, + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModePlay + return &v + }(), + InterleavedIDs: &[2]int{0, 1}, + }.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusBadRequest, res.StatusCode) +} + func TestServerGetSetParameter(t *testing.T) { var params []byte diff --git a/serverconn.go b/serverconn.go index a4cc0f0d..29310a39 100644 --- a/serverconn.go +++ b/serverconn.go @@ -156,9 +156,6 @@ func (sc *ServerConn) run() { } else { err := req.Read(sc.br) if err != nil { - /*if atomic.LoadInt32(&sc.udpTimeout) == 1 { - return liberrors.ErrServerNoUDPPacketsRecently{} - }*/ return err } @@ -174,26 +171,24 @@ func (sc *ServerConn) run() { var err error select { case err = <-readDone: + if sc.tcpFrameEnabled { + sc.tcpFrameWriteBuffer.Close() + <-sc.tcpFrameBackgroundWriteDone + } sc.nconn.Close() sc.s.connClose <- sc <-sc.terminate case <-sc.terminate: + if sc.tcpFrameEnabled { + sc.tcpFrameWriteBuffer.Close() + <-sc.tcpFrameBackgroundWriteDone + } sc.nconn.Close() err = <-readDone } if sc.tcpFrameEnabled { - if sc.tcpFrameIsRecording { - sc.tcpFrameTimeout = false - sc.nconn.SetReadDeadline(time.Time{}) - } - - sc.tcpFrameEnabled = false - sc.tcpFrameWriteBuffer.Close() - <-sc.tcpFrameBackgroundWriteDone - sc.tcpFrameWriteBuffer.Reset() - sc.s.sessionClose <- sc.tcpFrameLinkedSession } @@ -467,7 +462,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { return &base.Response{ StatusCode: base.StatusBadRequest, - }, fmt.Errorf("unhandled method: %v", req.Method) + }, liberrors.ErrServerInvalidMethod{} } func (sc *ServerConn) handleRequestOuter(req *base.Request) error {