make ServerStream return an error if initialized before Server (#719) (#728)

This commit is contained in:
Alessandro Ros
2025-03-23 16:17:34 +01:00
committed by GitHub
parent fa94080e84
commit fcb018151b
9 changed files with 303 additions and 245 deletions

View File

@@ -102,7 +102,10 @@ func (s *server) setStreamReady(desc *description.Session) *gortsplib.ServerStre
Server: s.s,
Desc: desc,
}
s.stream.Initialize()
err := s.stream.Initialize()
if err != nil {
panic(err)
}
return s.stream
}

View File

@@ -121,7 +121,10 @@ func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) (
Server: sh.s,
Desc: ctx.Description,
}
sh.stream.Initialize()
err := sh.stream.Initialize()
if err != nil {
panic(err)
}
sh.publisher = ctx.Session
return &base.Response{

View File

@@ -93,7 +93,10 @@ func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) (
Server: sh.s,
Desc: ctx.Description,
}
sh.stream.Initialize()
err := sh.stream.Initialize()
if err != nil {
panic(err)
}
sh.publisher = ctx.Session
return &base.Response{

View File

@@ -92,7 +92,10 @@ func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) (
Server: sh.s,
Desc: ctx.Description,
}
sh.stream.Initialize()
err := sh.stream.Initialize()
if err != nil {
panic(err)
}
sh.publisher = ctx.Session
return &base.Response{

View File

@@ -332,7 +332,8 @@ func TestServerRecordRead(t *testing.T) {
Server: s,
Desc: ctx.Description,
}
stream.Initialize()
err := stream.Initialize()
require.NoError(t, err)
publisher = ctx.Session
return &base.Response{

View File

@@ -302,7 +302,8 @@ func TestServerPlayPath(t *testing.T) {
},
},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
@@ -388,7 +389,8 @@ func TestServerPlaySetupErrors(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
if ca == "closed stream" {
stream.Close()
@@ -560,7 +562,8 @@ func TestServerPlaySetupErrorSameUDPPortsAndIP(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
for i := 0; i < 2; i++ {
@@ -740,7 +743,8 @@ func TestServerPlay(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn, err := net.Dial("tcp", listenIP+":8554")
@@ -1038,7 +1042,8 @@ func TestServerPlaySocketError(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
func() {
nconn, err := net.Dial("tcp", listenIP+":8554")
@@ -1208,7 +1213,8 @@ func TestServerPlayDecodeErrors(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
@@ -1330,7 +1336,8 @@ func TestServerPlayRTCPReport(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
@@ -1453,7 +1460,8 @@ func TestServerPlayVLCMulticast(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn, err := net.Dial("tcp", listenIP+":8554")
@@ -1538,7 +1546,8 @@ func TestServerPlayTCPResponseBeforeFrames(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
@@ -1629,7 +1638,8 @@ func TestServerPlayPause(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
@@ -1726,7 +1736,8 @@ func TestServerPlayPlayPausePausePlay(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
@@ -1813,7 +1824,8 @@ func TestServerPlayTimeout(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
@@ -1903,7 +1915,8 @@ func TestServerPlayWithoutTeardown(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
@@ -1979,7 +1992,8 @@ func TestServerPlayUDPChangeConn(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
sxID := ""
@@ -2067,7 +2081,8 @@ func TestServerPlayPartialMedias(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media, testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
@@ -2188,7 +2203,8 @@ func TestServerPlayAdditionalInfos(t *testing.T) {
},
},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
err = stream.WritePacketRTP(stream.Description().Medias[0], &rtp.Packet{
@@ -2318,7 +2334,8 @@ func TestServerPlayNoInterleavedIDs(t *testing.T) {
},
},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
@@ -2392,7 +2409,8 @@ func TestServerPlayStreamStats(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
for _, transport := range []string{"tcp", "multicast"} {

View File

@@ -297,7 +297,8 @@ func TestServerRecordPath(t *testing.T) {
Server: s,
Desc: ctx.Description,
}
stream.Initialize()
err := stream.Initialize()
require.NoError(t, err)
defer stream.Close()
return &base.Response{

View File

@@ -1,6 +1,7 @@
package gortsplib
import (
"fmt"
"sync"
"sync/atomic"
"time"
@@ -32,7 +33,10 @@ func NewServerStream(s *Server, desc *description.Session) *ServerStream {
Server: s,
Desc: desc,
}
st.Initialize()
err := st.Initialize()
if err != nil {
panic(err)
}
return st
}
@@ -54,7 +58,11 @@ type ServerStream struct {
}
// Initialize initializes a ServerStream.
func (st *ServerStream) Initialize() {
func (st *ServerStream) Initialize() error {
if st.Server == nil || st.Server.sessions == nil {
return fmt.Errorf("server not present or not initialized")
}
st.readers = make(map[*ServerSession]struct{})
st.activeUnicastReaders = make(map[*ServerSession]struct{})
@@ -68,6 +76,8 @@ func (st *ServerStream) Initialize() {
sm.initialize()
st.medias[medi] = sm
}
return nil
}
// Close closes a ServerStream.

View File

@@ -392,7 +392,8 @@ func TestServerErrorMethodNotImplemented(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
h.stream = stream
@@ -489,7 +490,8 @@ func TestServerErrorTCPTwoConnOneSession(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn1, err := net.Dial("tcp", "localhost:8554")
@@ -574,7 +576,8 @@ func TestServerErrorTCPOneConnTwoSessions(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
@@ -641,7 +644,8 @@ func TestServerSetupMultipleTransports(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
@@ -741,7 +745,8 @@ func TestServerGetSetParameter(t *testing.T) {
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
@@ -854,220 +859,6 @@ func TestServerErrorInvalidSession(t *testing.T) {
}
}
func TestServerSessionClose(t *testing.T) {
var stream *ServerStream
var session *ServerSession
connClosed := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onSessionOpen: func(ctx *ServerHandlerOnSessionOpenCtx) {
session = ctx.Session
},
onConnClose: func(_ *ServerHandlerOnConnCloseCtx) {
close(connClosed)
},
onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
},
RTSPAddress: "localhost:8554",
}
err := s.Start()
require.NoError(t, err)
defer s.Close()
stream = &ServerStream{
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer nconn.Close()
conn := conn.NewConn(nconn)
desc := doDescribe(t, conn)
inTH := &headers.Transport{
Protocol: headers.TransportProtocolTCP,
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
Mode: transportModePtr(headers.TransportModePlay),
InterleavedIDs: &[2]int{0, 1},
}
doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "")
session.Close()
session.Close()
select {
case <-connClosed:
case <-time.After(2 * time.Second):
t.Errorf("should not happen")
}
_, err = writeReqReadRes(conn, base.Request{
Method: base.Options,
URL: mustParseURL("rtsp://localhost:8554/"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
},
})
require.Error(t, err)
}
func TestServerSessionAutoClose(t *testing.T) {
for _, ca := range []string{
"200", "400",
} {
t.Run(ca, func(t *testing.T) {
var stream *ServerStream
sessionClosed := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onSessionClose: func(_ *ServerHandlerOnSessionCloseCtx) {
close(sessionClosed)
},
onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
if ca == "200" {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
}
return &base.Response{
StatusCode: base.StatusBadRequest,
}, nil, fmt.Errorf("error")
},
},
RTSPAddress: "localhost:8554",
}
err := s.Start()
require.NoError(t, err)
defer s.Close()
stream = &ServerStream{
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
conn := conn.NewConn(nconn)
desc := doDescribe(t, conn)
inTH := &headers.Transport{
Protocol: headers.TransportProtocolTCP,
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
Mode: transportModePtr(headers.TransportModePlay),
InterleavedIDs: &[2]int{0, 1},
}
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mediaURL(t, desc.BaseURL, desc.Medias[0]),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Transport": inTH.Marshal(),
},
})
require.NoError(t, err)
if ca == "200" {
require.Equal(t, base.StatusOK, res.StatusCode)
} else {
require.Equal(t, base.StatusBadRequest, res.StatusCode)
}
nconn.Close()
<-sessionClosed
})
}
}
func TestServerSessionTeardown(t *testing.T) {
var stream *ServerStream
s := &Server{
Handler: &testServerHandler{
onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
},
RTSPAddress: "localhost:8554",
}
err := s.Start()
require.NoError(t, err)
defer s.Close()
stream = &ServerStream{
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
stream.Initialize()
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer nconn.Close()
conn := conn.NewConn(nconn)
desc := doDescribe(t, conn)
inTH := &headers.Transport{
Protocol: headers.TransportProtocolTCP,
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
Mode: transportModePtr(headers.TransportModePlay),
InterleavedIDs: &[2]int{0, 1},
}
res, _ := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "")
session := readSession(t, res)
doTeardown(t, conn, "rtsp://localhost:8554/", session)
res, err = writeReqReadRes(conn, base.Request{
Method: base.Options,
URL: mustParseURL("rtsp://localhost:8554/"),
Header: base.Header{
"CSeq": base.HeaderValue{"3"},
},
})
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
}
func TestServerAuth(t *testing.T) {
for _, method := range []string{"all", "basic", "digest_md5", "digest_sha256"} {
t.Run(method, func(t *testing.T) {
@@ -1207,3 +998,228 @@ func TestServerAuthFail(t *testing.T) {
_, err = writeReqReadRes(conn, req)
require.Error(t, err)
}
func TestServerSessionClose(t *testing.T) {
var stream *ServerStream
var session *ServerSession
connClosed := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onSessionOpen: func(ctx *ServerHandlerOnSessionOpenCtx) {
session = ctx.Session
},
onConnClose: func(_ *ServerHandlerOnConnCloseCtx) {
close(connClosed)
},
onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
},
RTSPAddress: "localhost:8554",
}
err := s.Start()
require.NoError(t, err)
defer s.Close()
stream = &ServerStream{
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer nconn.Close()
conn := conn.NewConn(nconn)
desc := doDescribe(t, conn)
inTH := &headers.Transport{
Protocol: headers.TransportProtocolTCP,
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
Mode: transportModePtr(headers.TransportModePlay),
InterleavedIDs: &[2]int{0, 1},
}
doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "")
session.Close()
session.Close()
select {
case <-connClosed:
case <-time.After(2 * time.Second):
t.Errorf("should not happen")
}
_, err = writeReqReadRes(conn, base.Request{
Method: base.Options,
URL: mustParseURL("rtsp://localhost:8554/"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
},
})
require.Error(t, err)
}
func TestServerSessionAutoClose(t *testing.T) {
for _, ca := range []string{
"200", "400",
} {
t.Run(ca, func(t *testing.T) {
var stream *ServerStream
sessionClosed := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onSessionClose: func(_ *ServerHandlerOnSessionCloseCtx) {
close(sessionClosed)
},
onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
if ca == "200" {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
}
return &base.Response{
StatusCode: base.StatusBadRequest,
}, nil, fmt.Errorf("error")
},
},
RTSPAddress: "localhost:8554",
}
err := s.Start()
require.NoError(t, err)
defer s.Close()
stream = &ServerStream{
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
conn := conn.NewConn(nconn)
desc := doDescribe(t, conn)
inTH := &headers.Transport{
Protocol: headers.TransportProtocolTCP,
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
Mode: transportModePtr(headers.TransportModePlay),
InterleavedIDs: &[2]int{0, 1},
}
res, err := writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mediaURL(t, desc.BaseURL, desc.Medias[0]),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Transport": inTH.Marshal(),
},
})
require.NoError(t, err)
if ca == "200" {
require.Equal(t, base.StatusOK, res.StatusCode)
} else {
require.Equal(t, base.StatusBadRequest, res.StatusCode)
}
nconn.Close()
<-sessionClosed
})
}
}
func TestServerSessionTeardown(t *testing.T) {
var stream *ServerStream
s := &Server{
Handler: &testServerHandler{
onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
},
RTSPAddress: "localhost:8554",
}
err := s.Start()
require.NoError(t, err)
defer s.Close()
stream = &ServerStream{
Server: s,
Desc: &description.Session{Medias: []*description.Media{testH264Media}},
}
err = stream.Initialize()
require.NoError(t, err)
defer stream.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer nconn.Close()
conn := conn.NewConn(nconn)
desc := doDescribe(t, conn)
inTH := &headers.Transport{
Protocol: headers.TransportProtocolTCP,
Delivery: deliveryPtr(headers.TransportDeliveryUnicast),
Mode: transportModePtr(headers.TransportModePlay),
InterleavedIDs: &[2]int{0, 1},
}
res, _ := doSetup(t, conn, mediaURL(t, desc.BaseURL, desc.Medias[0]).String(), inTH, "")
session := readSession(t, res)
doTeardown(t, conn, "rtsp://localhost:8554/", session)
res, err = writeReqReadRes(conn, base.Request{
Method: base.Options,
URL: mustParseURL("rtsp://localhost:8554/"),
Header: base.Header{
"CSeq": base.HeaderValue{"3"},
},
})
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
}
func TestServerStreamErrorNoServer(t *testing.T) {
s := &Server{}
stream := &ServerStream{Server: s}
err := stream.Initialize()
require.Error(t, err)
}