diff --git a/examples/proxy/server.go b/examples/proxy/server.go index 32334fb6..d7ee5235 100644 --- a/examples/proxy/server.go +++ b/examples/proxy/server.go @@ -102,7 +102,10 @@ func (s *server) setStreamReady(desc *description.Session) *gortsplib.ServerStre Server: s.s, Desc: desc, } - s.stream.Initialize() + err := s.stream.Initialize() + if err != nil { + panic(err) + } return s.stream } diff --git a/examples/server-auth/main.go b/examples/server-auth/main.go index fb1b4b33..b15901c9 100644 --- a/examples/server-auth/main.go +++ b/examples/server-auth/main.go @@ -121,7 +121,10 @@ func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) ( Server: sh.s, Desc: ctx.Description, } - sh.stream.Initialize() + err := sh.stream.Initialize() + if err != nil { + panic(err) + } sh.publisher = ctx.Session return &base.Response{ diff --git a/examples/server-tls/main.go b/examples/server-tls/main.go index cb62d501..751f1d83 100644 --- a/examples/server-tls/main.go +++ b/examples/server-tls/main.go @@ -93,7 +93,10 @@ func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) ( Server: sh.s, Desc: ctx.Description, } - sh.stream.Initialize() + err := sh.stream.Initialize() + if err != nil { + panic(err) + } sh.publisher = ctx.Session return &base.Response{ diff --git a/examples/server/main.go b/examples/server/main.go index 270edfdc..2ecca102 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -92,7 +92,10 @@ func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) ( Server: sh.s, Desc: ctx.Description, } - sh.stream.Initialize() + err := sh.stream.Initialize() + if err != nil { + panic(err) + } sh.publisher = ctx.Session return &base.Response{ diff --git a/internal/highleveltests/server_test.go b/internal/highleveltests/server_test.go index 9507e29a..0980230f 100644 --- a/internal/highleveltests/server_test.go +++ b/internal/highleveltests/server_test.go @@ -332,7 +332,8 @@ func TestServerRecordRead(t *testing.T) { Server: s, Desc: ctx.Description, } - stream.Initialize() + err := stream.Initialize() + require.NoError(t, err) publisher = ctx.Session return &base.Response{ diff --git a/server_play_test.go b/server_play_test.go index 108844c8..e87ba4d0 100644 --- a/server_play_test.go +++ b/server_play_test.go @@ -302,7 +302,8 @@ func TestServerPlayPath(t *testing.T) { }, }, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() nconn, err := net.Dial("tcp", "localhost:8554") @@ -388,7 +389,8 @@ func TestServerPlaySetupErrors(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) if ca == "closed stream" { stream.Close() @@ -560,7 +562,8 @@ func TestServerPlaySetupErrorSameUDPPortsAndIP(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() for i := 0; i < 2; i++ { @@ -740,7 +743,8 @@ func TestServerPlay(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() nconn, err := net.Dial("tcp", listenIP+":8554") @@ -1038,7 +1042,8 @@ func TestServerPlaySocketError(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) func() { nconn, err := net.Dial("tcp", listenIP+":8554") @@ -1208,7 +1213,8 @@ func TestServerPlayDecodeErrors(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() nconn, err := net.Dial("tcp", "localhost:8554") @@ -1330,7 +1336,8 @@ func TestServerPlayRTCPReport(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() nconn, err := net.Dial("tcp", "localhost:8554") @@ -1453,7 +1460,8 @@ func TestServerPlayVLCMulticast(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() nconn, err := net.Dial("tcp", listenIP+":8554") @@ -1538,7 +1546,8 @@ func TestServerPlayTCPResponseBeforeFrames(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() nconn, err := net.Dial("tcp", "localhost:8554") @@ -1629,7 +1638,8 @@ func TestServerPlayPause(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() nconn, err := net.Dial("tcp", "localhost:8554") @@ -1726,7 +1736,8 @@ func TestServerPlayPlayPausePausePlay(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() nconn, err := net.Dial("tcp", "localhost:8554") @@ -1813,7 +1824,8 @@ func TestServerPlayTimeout(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() nconn, err := net.Dial("tcp", "localhost:8554") @@ -1903,7 +1915,8 @@ func TestServerPlayWithoutTeardown(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() nconn, err := net.Dial("tcp", "localhost:8554") @@ -1979,7 +1992,8 @@ func TestServerPlayUDPChangeConn(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() sxID := "" @@ -2067,7 +2081,8 @@ func TestServerPlayPartialMedias(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media, testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() nconn, err := net.Dial("tcp", "localhost:8554") @@ -2188,7 +2203,8 @@ func TestServerPlayAdditionalInfos(t *testing.T) { }, }, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() err = stream.WritePacketRTP(stream.Description().Medias[0], &rtp.Packet{ @@ -2318,7 +2334,8 @@ func TestServerPlayNoInterleavedIDs(t *testing.T) { }, }, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() nconn, err := net.Dial("tcp", "localhost:8554") @@ -2392,7 +2409,8 @@ func TestServerPlayStreamStats(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() for _, transport := range []string{"tcp", "multicast"} { diff --git a/server_record_test.go b/server_record_test.go index 800ffb68..e1dc813d 100644 --- a/server_record_test.go +++ b/server_record_test.go @@ -297,7 +297,8 @@ func TestServerRecordPath(t *testing.T) { Server: s, Desc: ctx.Description, } - stream.Initialize() + err := stream.Initialize() + require.NoError(t, err) defer stream.Close() return &base.Response{ diff --git a/server_stream.go b/server_stream.go index b10c43f0..574d112d 100644 --- a/server_stream.go +++ b/server_stream.go @@ -1,6 +1,7 @@ package gortsplib import ( + "fmt" "sync" "sync/atomic" "time" @@ -32,7 +33,10 @@ func NewServerStream(s *Server, desc *description.Session) *ServerStream { Server: s, Desc: desc, } - st.Initialize() + err := st.Initialize() + if err != nil { + panic(err) + } return st } @@ -54,7 +58,11 @@ type ServerStream struct { } // Initialize initializes a ServerStream. -func (st *ServerStream) Initialize() { +func (st *ServerStream) Initialize() error { + if st.Server == nil || st.Server.sessions == nil { + return fmt.Errorf("server not present or not initialized") + } + st.readers = make(map[*ServerSession]struct{}) st.activeUnicastReaders = make(map[*ServerSession]struct{}) @@ -68,6 +76,8 @@ func (st *ServerStream) Initialize() { sm.initialize() st.medias[medi] = sm } + + return nil } // Close closes a ServerStream. diff --git a/server_test.go b/server_test.go index f3045194..4e2a8916 100644 --- a/server_test.go +++ b/server_test.go @@ -392,7 +392,8 @@ func TestServerErrorMethodNotImplemented(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() h.stream = stream @@ -489,7 +490,8 @@ func TestServerErrorTCPTwoConnOneSession(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() nconn1, err := net.Dial("tcp", "localhost:8554") @@ -574,7 +576,8 @@ func TestServerErrorTCPOneConnTwoSessions(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() nconn, err := net.Dial("tcp", "localhost:8554") @@ -641,7 +644,8 @@ func TestServerSetupMultipleTransports(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() nconn, err := net.Dial("tcp", "localhost:8554") @@ -741,7 +745,8 @@ func TestServerGetSetParameter(t *testing.T) { Server: s, Desc: &description.Session{Medias: []*description.Media{testH264Media}}, } - stream.Initialize() + err = stream.Initialize() + require.NoError(t, err) defer stream.Close() nconn, err := net.Dial("tcp", "localhost:8554") @@ -854,220 +859,6 @@ func TestServerErrorInvalidSession(t *testing.T) { } } -func TestServerSessionClose(t *testing.T) { - var stream *ServerStream - var session *ServerSession - connClosed := make(chan struct{}) - - s := &Server{ - Handler: &testServerHandler{ - onSessionOpen: func(ctx *ServerHandlerOnSessionOpenCtx) { - session = ctx.Session - }, - onConnClose: func(_ *ServerHandlerOnConnCloseCtx) { - close(connClosed) - }, - onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, stream, nil - }, - onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, stream, nil - }, - }, - RTSPAddress: "localhost:8554", - } - - err := s.Start() - require.NoError(t, err) - defer s.Close() - - stream = &ServerStream{ - Server: s, - Desc: &description.Session{Medias: []*description.Media{testH264Media}}, - } - stream.Initialize() - defer stream.Close() - - nconn, err := net.Dial("tcp", "localhost:8554") - require.NoError(t, err) - defer nconn.Close() - conn := conn.NewConn(nconn) - - desc := doDescribe(t, conn) - - inTH := &headers.Transport{ - Protocol: headers.TransportProtocolTCP, - Delivery: deliveryPtr(headers.TransportDeliveryUnicast), - Mode: transportModePtr(headers.TransportModePlay), - InterleavedIDs: &[2]int{0, 1}, - } - - doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "") - - session.Close() - session.Close() - - select { - case <-connClosed: - case <-time.After(2 * time.Second): - t.Errorf("should not happen") - } - - _, err = writeReqReadRes(conn, base.Request{ - Method: base.Options, - URL: mustParseURL("rtsp://localhost:8554/"), - Header: base.Header{ - "CSeq": base.HeaderValue{"2"}, - }, - }) - require.Error(t, err) -} - -func TestServerSessionAutoClose(t *testing.T) { - for _, ca := range []string{ - "200", "400", - } { - t.Run(ca, func(t *testing.T) { - var stream *ServerStream - sessionClosed := make(chan struct{}) - - s := &Server{ - Handler: &testServerHandler{ - onSessionClose: func(_ *ServerHandlerOnSessionCloseCtx) { - close(sessionClosed) - }, - onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, stream, nil - }, - onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { - if ca == "200" { - return &base.Response{ - StatusCode: base.StatusOK, - }, stream, nil - } - - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, nil, fmt.Errorf("error") - }, - }, - RTSPAddress: "localhost:8554", - } - - err := s.Start() - require.NoError(t, err) - defer s.Close() - - stream = &ServerStream{ - Server: s, - Desc: &description.Session{Medias: []*description.Media{testH264Media}}, - } - stream.Initialize() - defer stream.Close() - - nconn, err := net.Dial("tcp", "localhost:8554") - require.NoError(t, err) - conn := conn.NewConn(nconn) - - desc := doDescribe(t, conn) - - inTH := &headers.Transport{ - Protocol: headers.TransportProtocolTCP, - Delivery: deliveryPtr(headers.TransportDeliveryUnicast), - Mode: transportModePtr(headers.TransportModePlay), - InterleavedIDs: &[2]int{0, 1}, - } - - res, err := writeReqReadRes(conn, base.Request{ - Method: base.Setup, - URL: mediaURL(t, desc.BaseURL, desc.Medias[0]), - Header: base.Header{ - "CSeq": base.HeaderValue{"1"}, - "Transport": inTH.Marshal(), - }, - }) - require.NoError(t, err) - - if ca == "200" { - require.Equal(t, base.StatusOK, res.StatusCode) - } else { - require.Equal(t, base.StatusBadRequest, res.StatusCode) - } - - nconn.Close() - - <-sessionClosed - }) - } -} - -func TestServerSessionTeardown(t *testing.T) { - var stream *ServerStream - - s := &Server{ - Handler: &testServerHandler{ - onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, stream, nil - }, - onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, stream, nil - }, - }, - RTSPAddress: "localhost:8554", - } - - err := s.Start() - require.NoError(t, err) - defer s.Close() - - stream = &ServerStream{ - Server: s, - Desc: &description.Session{Medias: []*description.Media{testH264Media}}, - } - stream.Initialize() - defer stream.Close() - - nconn, err := net.Dial("tcp", "localhost:8554") - require.NoError(t, err) - defer nconn.Close() - conn := conn.NewConn(nconn) - - desc := doDescribe(t, conn) - - inTH := &headers.Transport{ - Protocol: headers.TransportProtocolTCP, - Delivery: deliveryPtr(headers.TransportDeliveryUnicast), - Mode: transportModePtr(headers.TransportModePlay), - InterleavedIDs: &[2]int{0, 1}, - } - - res, _ := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "") - - session := readSession(t, res) - - doTeardown(t, conn, "rtsp://localhost:8554/", session) - - res, err = writeReqReadRes(conn, base.Request{ - Method: base.Options, - URL: mustParseURL("rtsp://localhost:8554/"), - Header: base.Header{ - "CSeq": base.HeaderValue{"3"}, - }, - }) - require.NoError(t, err) - require.Equal(t, base.StatusOK, res.StatusCode) -} - func TestServerAuth(t *testing.T) { for _, method := range []string{"all", "basic", "digest_md5", "digest_sha256"} { t.Run(method, func(t *testing.T) { @@ -1207,3 +998,228 @@ func TestServerAuthFail(t *testing.T) { _, err = writeReqReadRes(conn, req) require.Error(t, err) } + +func TestServerSessionClose(t *testing.T) { + var stream *ServerStream + var session *ServerSession + connClosed := make(chan struct{}) + + s := &Server{ + Handler: &testServerHandler{ + onSessionOpen: func(ctx *ServerHandlerOnSessionOpenCtx) { + session = ctx.Session + }, + onConnClose: func(_ *ServerHandlerOnConnCloseCtx) { + close(connClosed) + }, + onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + }, + onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + }, + }, + RTSPAddress: "localhost:8554", + } + + err := s.Start() + require.NoError(t, err) + defer s.Close() + + stream = &ServerStream{ + Server: s, + Desc: &description.Session{Medias: []*description.Media{testH264Media}}, + } + err = stream.Initialize() + require.NoError(t, err) + defer stream.Close() + + nconn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer nconn.Close() + conn := conn.NewConn(nconn) + + desc := doDescribe(t, conn) + + inTH := &headers.Transport{ + Protocol: headers.TransportProtocolTCP, + Delivery: deliveryPtr(headers.TransportDeliveryUnicast), + Mode: transportModePtr(headers.TransportModePlay), + InterleavedIDs: &[2]int{0, 1}, + } + + doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "") + + session.Close() + session.Close() + + select { + case <-connClosed: + case <-time.After(2 * time.Second): + t.Errorf("should not happen") + } + + _, err = writeReqReadRes(conn, base.Request{ + Method: base.Options, + URL: mustParseURL("rtsp://localhost:8554/"), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + }, + }) + require.Error(t, err) +} + +func TestServerSessionAutoClose(t *testing.T) { + for _, ca := range []string{ + "200", "400", + } { + t.Run(ca, func(t *testing.T) { + var stream *ServerStream + sessionClosed := make(chan struct{}) + + s := &Server{ + Handler: &testServerHandler{ + onSessionClose: func(_ *ServerHandlerOnSessionCloseCtx) { + close(sessionClosed) + }, + onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + }, + onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { + if ca == "200" { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + } + + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, nil, fmt.Errorf("error") + }, + }, + RTSPAddress: "localhost:8554", + } + + err := s.Start() + require.NoError(t, err) + defer s.Close() + + stream = &ServerStream{ + Server: s, + Desc: &description.Session{Medias: []*description.Media{testH264Media}}, + } + err = stream.Initialize() + require.NoError(t, err) + defer stream.Close() + + nconn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + conn := conn.NewConn(nconn) + + desc := doDescribe(t, conn) + + inTH := &headers.Transport{ + Protocol: headers.TransportProtocolTCP, + Delivery: deliveryPtr(headers.TransportDeliveryUnicast), + Mode: transportModePtr(headers.TransportModePlay), + InterleavedIDs: &[2]int{0, 1}, + } + + res, err := writeReqReadRes(conn, base.Request{ + Method: base.Setup, + URL: mediaURL(t, desc.BaseURL, desc.Medias[0]), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Transport": inTH.Marshal(), + }, + }) + require.NoError(t, err) + + if ca == "200" { + require.Equal(t, base.StatusOK, res.StatusCode) + } else { + require.Equal(t, base.StatusBadRequest, res.StatusCode) + } + + nconn.Close() + + <-sessionClosed + }) + } +} + +func TestServerSessionTeardown(t *testing.T) { + var stream *ServerStream + + s := &Server{ + Handler: &testServerHandler{ + onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + }, + onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + }, + }, + RTSPAddress: "localhost:8554", + } + + err := s.Start() + require.NoError(t, err) + defer s.Close() + + stream = &ServerStream{ + Server: s, + Desc: &description.Session{Medias: []*description.Media{testH264Media}}, + } + err = stream.Initialize() + require.NoError(t, err) + defer stream.Close() + + nconn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer nconn.Close() + conn := conn.NewConn(nconn) + + desc := doDescribe(t, conn) + + inTH := &headers.Transport{ + Protocol: headers.TransportProtocolTCP, + Delivery: deliveryPtr(headers.TransportDeliveryUnicast), + Mode: transportModePtr(headers.TransportModePlay), + InterleavedIDs: &[2]int{0, 1}, + } + + res, _ := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "") + + session := readSession(t, res) + + doTeardown(t, conn, "rtsp://localhost:8554/", session) + + res, err = writeReqReadRes(conn, base.Request{ + Method: base.Options, + URL: mustParseURL("rtsp://localhost:8554/"), + Header: base.Header{ + "CSeq": base.HeaderValue{"3"}, + }, + }) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) +} + +func TestServerStreamErrorNoServer(t *testing.T) { + s := &Server{} + + stream := &ServerStream{Server: s} + err := stream.Initialize() + require.Error(t, err) +}