diff --git a/server_test.go b/server_test.go index 4803dc75..21eb5d57 100644 --- a/server_test.go +++ b/server_test.go @@ -721,61 +721,72 @@ func TestServerSessionClose(t *testing.T) { } func TestServerSessionAutoClose(t *testing.T) { - sessionClosed := make(chan struct{}) + for _, ca := range []string{ + "200", "400", + } { + t.Run(ca, func(t *testing.T) { + sessionClosed := make(chan struct{}) - track, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil) - require.NoError(t, err) + track, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil) + require.NoError(t, err) - stream := NewServerStream(Tracks{track}) - defer stream.Close() + stream := NewServerStream(Tracks{track}) + defer stream.Close() - s := &Server{ - Handler: &testServerHandler{ - onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { - close(sessionClosed) - }, - onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, stream, nil - }, - }, - RTSPAddress: "localhost:8554", + s := &Server{ + Handler: &testServerHandler{ + onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { + close(sessionClosed) + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { + if ca == "200" { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + } + + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, nil, fmt.Errorf("error") + }, + }, + RTSPAddress: "localhost:8554", + } + + err = s.Start() + require.NoError(t, err) + defer s.Close() + + conn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + br := bufio.NewReader(conn) + + _, err = writeReqReadRes(conn, br, base.Request{ + Method: base.Setup, + URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Transport": headers.Transport{ + Protocol: headers.TransportProtocolTCP, + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModePlay + return &v + }(), + InterleavedIDs: &[2]int{0, 1}, + }.Write(), + }, + }) + require.NoError(t, err) + + conn.Close() + + <-sessionClosed + }) } - - err = s.Start() - require.NoError(t, err) - defer s.Close() - - conn, err := net.Dial("tcp", "localhost:8554") - require.NoError(t, err) - br := bufio.NewReader(conn) - - res, err := writeReqReadRes(conn, br, base.Request{ - Method: base.Setup, - URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), - Header: base.Header{ - "CSeq": base.HeaderValue{"1"}, - "Transport": headers.Transport{ - Protocol: headers.TransportProtocolTCP, - Delivery: func() *headers.TransportDelivery { - v := headers.TransportDeliveryUnicast - return &v - }(), - Mode: func() *headers.TransportMode { - v := headers.TransportModePlay - return &v - }(), - InterleavedIDs: &[2]int{0, 1}, - }.Write(), - }, - }) - require.NoError(t, err) - require.Equal(t, base.StatusOK, res.StatusCode) - - conn.Close() - - <-sessionClosed } func TestServerErrorInvalidPath(t *testing.T) { diff --git a/serversession.go b/serversession.go index f14780cb..d4d6b93d 100644 --- a/serversession.go +++ b/serversession.go @@ -337,7 +337,8 @@ func (ss *ServerSession) runInner() error { res, err := ss.handleRequest(req.sc, req.req) - var returnedSession *ServerSession + returnedSession := ss + if err == nil || err == errSwitchReadFunc { // ANNOUNCE responses don't contain the session header. if req.req.Method != base.Announce && @@ -364,9 +365,9 @@ func (ss *ServerSession) runInner() error { }.Write() } - // after a TEARDOWN, session must be unpaired with the connection. - if req.req.Method != base.Teardown { - returnedSession = ss + // after a TEARDOWN, session must be unpaired with the connection + if req.req.Method == base.Teardown { + returnedSession = nil } }