diff --git a/server_test.go b/server_test.go index a828d77b..8c6407c3 100644 --- a/server_test.go +++ b/server_test.go @@ -252,8 +252,18 @@ func TestServerConnClose(t *testing.T) { nconn, err := net.Dial("tcp", "localhost:8554") require.NoError(t, err) defer nconn.Close() + conn := conn.NewConn(nconn) <-nconnClosed + + _, err = writeReqReadRes(conn, base.Request{ + Method: base.Options, + URL: mustParseURL("rtsp://localhost:8554/"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + }, + }) + require.Error(t, err) } func TestServerCSeq(t *testing.T) { @@ -795,21 +805,24 @@ func TestServerErrorInvalidSession(t *testing.T) { } func TestServerSessionClose(t *testing.T) { - sessionClosed := make(chan struct{}) + stream := NewServerStream(Tracks{&TrackH264{ + PayloadType: 96, + SPS: []byte{0x01, 0x02, 0x03, 0x04}, + PPS: []byte{0x01, 0x02, 0x03, 0x04}, + }}) + defer stream.Close() + + var session *ServerSession s := &Server{ Handler: &testServerHandler{ onSessionOpen: func(ctx *ServerHandlerOnSessionOpenCtx) { - ctx.Session.Close() - ctx.Session.Close() - }, - onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { - close(sessionClosed) + session = ctx.Session }, onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { return &base.Response{ StatusCode: base.StatusOK, - }, nil, nil + }, stream, nil }, }, RTSPAddress: "localhost:8554", @@ -824,7 +837,7 @@ func TestServerSessionClose(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - err = conn.WriteRequest(&base.Request{ + res, err := writeReqReadRes(conn, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ @@ -844,8 +857,19 @@ func TestServerSessionClose(t *testing.T) { }, }) require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) - <-sessionClosed + session.Close() + session.Close() + + _, err = writeReqReadRes(conn, base.Request{ + Method: base.Options, + URL: mustParseURL("rtsp://localhost:8554/"), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + }, + }) + require.Error(t, err) } func TestServerSessionAutoClose(t *testing.T) { @@ -855,13 +879,11 @@ func TestServerSessionAutoClose(t *testing.T) { t.Run(ca, func(t *testing.T) { sessionClosed := make(chan struct{}) - track := &TrackH264{ + stream := NewServerStream(Tracks{&TrackH264{ PayloadType: 96, SPS: []byte{0x01, 0x02, 0x03, 0x04}, PPS: []byte{0x01, 0x02, 0x03, 0x04}, - } - - stream := NewServerStream(Tracks{track}) + }}) defer stream.Close() s := &Server{ @@ -920,6 +942,82 @@ func TestServerSessionAutoClose(t *testing.T) { } } +func TestServerSessionTeardown(t *testing.T) { + stream := NewServerStream(Tracks{&TrackH264{ + PayloadType: 96, + SPS: []byte{0x01, 0x02, 0x03, 0x04}, + PPS: []byte{0x01, 0x02, 0x03, 0x04}, + }}) + defer stream.Close() + + s := &Server{ + Handler: &testServerHandler{ + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + }, + }, + RTSPAddress: "localhost:8554", + } + + err := s.Start() + require.NoError(t, err) + defer s.Close() + + nconn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer nconn.Close() + conn := conn.NewConn(nconn) + + res, err := writeReqReadRes(conn, base.Request{ + Method: base.Setup, + URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Transport": headers.Transport{ + Protocol: headers.TransportProtocolTCP, + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModePlay + return &v + }(), + InterleavedIDs: &[2]int{0, 1}, + }.Marshal(), + }, + }) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + var sx headers.Session + err = sx.Unmarshal(res.Header["Session"]) + require.NoError(t, err) + + res, err = writeReqReadRes(conn, base.Request{ + Method: base.Teardown, + URL: mustParseURL("rtsp://localhost:8554/"), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + "Session": base.HeaderValue{sx.Session}, + }, + }) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + res, err = writeReqReadRes(conn, base.Request{ + Method: base.Options, + URL: mustParseURL("rtsp://localhost:8554/"), + Header: base.Header{ + "CSeq": base.HeaderValue{"3"}, + }, + }) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) +} + func TestServerErrorInvalidPath(t *testing.T) { for _, ca := range []string{"inside session", "outside session"} { t.Run(ca, func(t *testing.T) { diff --git a/serversession.go b/serversession.go index b2b2b475..efb27299 100644 --- a/serversession.go +++ b/serversession.go @@ -310,13 +310,14 @@ func (ss *ServerSession) run() { <-ss.writerDone } + // close all associated connections, both UDP and TCP + // except for the ones that called TEARDOWN + // (that are detached from the session just after the request) for sc := range ss.conns { - if sc == ss.tcpConn { - sc.Close() + sc.Close() - // make sure that OnFrame() is never called after OnSessionClose() - <-sc.done - } + // make sure that OnFrame() is never called after OnSessionClose() + <-sc.done select { case sc.sessionRemove <- ss: @@ -379,6 +380,7 @@ func (ss *ServerSession) runInner() error { // after a TEARDOWN, session must be unpaired with the connection if req.req.Method == base.Teardown { + delete(ss.conns, req.sc) returnedSession = nil } }