From 3c2625c7cfd7b099e912c8b8275bdcc294a75f61 Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Sat, 6 Sep 2025 15:42:07 +0200 Subject: [PATCH] make most methods thread safe (#882) Client: Stats ServerConn: Session, Stats ServerSession: State, Stats, Medias, Path, Query, Stream, SetuppedSecure, SetuppedTransport, AnnouncedDescription --- client.go | 9 ++++- client_media.go | 24 +++++------ client_play_test.go | 5 +++ client_udp_listener.go | 8 ++-- server_conn.go | 9 +++++ server_play_test.go | 24 ++++++++++- server_record_test.go | 17 +++++++- server_session.go | 88 +++++++++++++++++++++++++++++++---------- server_session_media.go | 37 +++++++++-------- server_stream.go | 5 ++- 10 files changed, 167 insertions(+), 59 deletions(-) diff --git a/client.go b/client.go index bc9f43f2..c6020762 100644 --- a/client.go +++ b/client.go @@ -543,6 +543,7 @@ type Client struct { ctx context.Context ctxCancel func() + propsMutex sync.RWMutex state clientState nconn net.Conn conn *conn.Conn @@ -2023,16 +2024,19 @@ func (c *Client) doSetup( udpRTPListener = nil udpRTCPListener = nil + c.propsMutex.Lock() + if c.setuppedMedias == nil { c.setuppedMedias = make(map[*description.Media]*clientMedia) } - c.setuppedMedias[medi] = cm c.baseURL = baseURL c.setuppedTransport = &transport c.setuppedProfile = th.Profile + c.propsMutex.Unlock() + if medi.IsBackChannel { c.backChannelSetupped = true } else { @@ -2426,6 +2430,9 @@ func (c *Client) PacketNTP(medi *description.Media, pkt *rtp.Packet) (time.Time, // Stats returns client statistics. func (c *Client) Stats() *ClientStats { + c.propsMutex.RLock() + defer c.propsMutex.RUnlock() + mediaStats := func() map[*description.Media]StatsSessionMedia { //nolint:dupl ret := make(map[*description.Media]StatsSessionMedia, len(c.setuppedMedias)) diff --git a/client_media.go b/client_media.go index c39a162d..0a376a72 100644 --- a/client_media.go +++ b/client_media.go @@ -125,21 +125,11 @@ func (cm *clientMedia) initialize() { f.initialize() cm.formats[forma.PayloadType()] = f } -} -func (cm *clientMedia) close() { - cm.stop() - - for _, ct := range cm.formats { - ct.close() - } -} - -func (cm *clientMedia) start() { if cm.udpRTPListener != nil { cm.writePacketRTCPInQueue = cm.writePacketRTCPInQueueUDP - if cm.c.state == clientStateRecord || cm.media.IsBackChannel { + if cm.c.state == clientStatePreRecord || cm.media.IsBackChannel { cm.udpRTPListener.readFunc = cm.readPacketRTPUDPRecord cm.udpRTCPListener.readFunc = cm.readPacketRTCPUDPRecord } else { @@ -153,7 +143,7 @@ func (cm *clientMedia) start() { cm.c.tcpCallbackByChannel = make(map[int]readFunc) } - if cm.c.state == clientStateRecord || cm.media.IsBackChannel { + if cm.c.state == clientStatePreRecord || cm.media.IsBackChannel { cm.c.tcpCallbackByChannel[cm.tcpChannel] = cm.readPacketRTPTCPRecord cm.c.tcpCallbackByChannel[cm.tcpChannel+1] = cm.readPacketRTCPTCPRecord } else { @@ -161,7 +151,17 @@ func (cm *clientMedia) start() { cm.c.tcpCallbackByChannel[cm.tcpChannel+1] = cm.readPacketRTCPTCPPlay } } +} +func (cm *clientMedia) close() { + cm.stop() + + for _, ct := range cm.formats { + ct.close() + } +} + +func (cm *clientMedia) start() { if cm.udpRTPListener != nil { cm.udpRTPListener.start() cm.udpRTCPListener.start() diff --git a/client_play_test.go b/client_play_test.go index fb1c43f4..77465d0c 100644 --- a/client_play_test.go +++ b/client_play_test.go @@ -640,6 +640,11 @@ func TestClientPlay(t *testing.T) { sd, _, err := c.Describe(u) require.NoError(t, err) + // test that properties can be accessed in parallel + go func() { + c.Stats() + }() + err = c.SetupAll(sd.BaseURL, sd.Medias) require.NoError(t, err) diff --git a/client_udp_listener.go b/client_udp_listener.go index 0233a891..853e3e93 100644 --- a/client_udp_listener.go +++ b/client_udp_listener.go @@ -108,9 +108,11 @@ func (u *clientUDPListener) start() { } func (u *clientUDPListener) stop() { - u.pc.SetReadDeadline(time.Now()) - <-u.done - u.running = false + if u.running { + u.pc.SetReadDeadline(time.Now()) + <-u.done + u.running = false + } } func (u *clientUDPListener) run() { diff --git a/server_conn.go b/server_conn.go index cab14c3a..5be5087f 100644 --- a/server_conn.go +++ b/server_conn.go @@ -9,6 +9,7 @@ import ( gourl "net/url" "strconv" "strings" + "sync" "time" "github.com/bluenviron/gortsplib/v4/pkg/auth" @@ -199,6 +200,7 @@ type ServerConn struct { ctx context.Context ctxCancel func() + propsMutex sync.RWMutex userData interface{} remoteAddr *net.TCPAddr bc *bytecounter.ByteCounter @@ -268,6 +270,9 @@ func (sc *ServerConn) UserData() interface{} { // Session returns the associated session. func (sc *ServerConn) Session() *ServerSession { + sc.propsMutex.RLock() + defer sc.propsMutex.RUnlock() + return sc.session } @@ -658,7 +663,11 @@ func (sc *ServerConn) handleRequestInSession( } res, session, err := sc.s.handleRequest(sreq) + + sc.propsMutex.Lock() sc.session = session + sc.propsMutex.Unlock() + return res, err } diff --git a/server_play_test.go b/server_play_test.go index 0cb572ca..c9c5e74c 100644 --- a/server_play_test.go +++ b/server_play_test.go @@ -665,7 +665,21 @@ func TestServerPlay(t *testing.T) { StatusCode: base.StatusOK, }, stream, nil }, - onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { + // test that properties can be accessed in parallel + go func() { + ctx.Conn.Session() + ctx.Conn.Stats() + ctx.Session.State() + ctx.Session.Stats() + ctx.Session.Medias() + ctx.Session.Path() + ctx.Session.Query() + ctx.Session.Stream() + ctx.Session.SetuppedTransport() + ctx.Session.SetuppedSecure() + }() + return &base.Response{ StatusCode: base.StatusOK, }, stream, nil @@ -1726,7 +1740,13 @@ func TestServerPlayPause(t *testing.T) { StatusCode: base.StatusOK, }, nil }, - onPause: func(_ *ServerHandlerOnPauseCtx) (*base.Response, error) { + onPause: func(ctx *ServerHandlerOnPauseCtx) (*base.Response, error) { + // test that properties can be accessed in parallel + go func() { + ctx.Session.State() + ctx.Session.Stats() + }() + return &base.Response{ StatusCode: base.StatusOK, }, nil diff --git a/server_record_test.go b/server_record_test.go index 193eb560..7607d36b 100644 --- a/server_record_test.go +++ b/server_record_test.go @@ -602,7 +602,14 @@ func TestServerRecord(t *testing.T) { close(sessionClosed) }, - onAnnounce: func(_ *ServerHandlerOnAnnounceCtx) (*base.Response, error) { + onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { + // test that properties can be accessed in parallel + go func() { + ctx.Session.State() + ctx.Session.Stats() + ctx.Session.AnnouncedDescription() + }() + return &base.Response{ StatusCode: base.StatusOK, }, nil @@ -1732,7 +1739,13 @@ func TestServerRecordPausePause(t *testing.T) { StatusCode: base.StatusOK, }, nil }, - onPause: func(_ *ServerHandlerOnPauseCtx) (*base.Response, error) { + onPause: func(ctx *ServerHandlerOnPauseCtx) (*base.Response, error) { + // test that properties can be accessed in parallel + go func() { + ctx.Session.State() + ctx.Session.Stats() + }() + return &base.Response{ StatusCode: base.StatusOK, }, nil diff --git a/server_session.go b/server_session.go index 48daf9c7..833bc701 100644 --- a/server_session.go +++ b/server_session.go @@ -428,8 +428,9 @@ type ServerSession struct { secretID string // must not be shared, allows to take ownership of the session ctx context.Context ctxCancel func() - userData interface{} + propsMutex sync.RWMutex conns map[*ServerConn]struct{} + userData interface{} state ServerSessionState setuppedMedias map[*description.Media]*serverSessionMedia setuppedMediasOrdered []*serverSessionMedia @@ -506,11 +507,17 @@ func (ss *ServerSession) BytesSent() uint64 { // State returns the state of the session. func (ss *ServerSession) State() ServerSessionState { + ss.propsMutex.RLock() + defer ss.propsMutex.RUnlock() + return ss.state } // SetuppedTransport returns the transport negotiated during SETUP. func (ss *ServerSession) SetuppedTransport() *Transport { + ss.propsMutex.RLock() + defer ss.propsMutex.RUnlock() + return ss.setuppedTransport } @@ -518,47 +525,62 @@ func (ss *ServerSession) SetuppedTransport() *Transport { // If this is false, it does not mean that the stream is not secure, since // there are some combinations that are secure nonetheless, like RTSPS+TCP+unsecure. func (ss *ServerSession) SetuppedSecure() bool { + ss.propsMutex.RLock() + defer ss.propsMutex.RUnlock() + return isSecure(ss.setuppedProfile) } // SetuppedStream returns the stream associated with the session. // -// Deprecated: replaced by Stream +// Deprecated: replaced by Stream. func (ss *ServerSession) SetuppedStream() *ServerStream { return ss.Stream() } // Stream returns the stream associated with the session. func (ss *ServerSession) Stream() *ServerStream { + ss.propsMutex.RLock() + defer ss.propsMutex.RUnlock() + return ss.setuppedStream } // SetuppedPath returns the path sent during SETUP or ANNOUNCE. // -// Deprecated: replaced by Path +// Deprecated: replaced by Path. func (ss *ServerSession) SetuppedPath() string { return ss.Path() } // Path returns the path sent during SETUP or ANNOUNCE. func (ss *ServerSession) Path() string { + ss.propsMutex.RLock() + defer ss.propsMutex.RUnlock() + return ss.setuppedPath } // SetuppedQuery returns the query sent during SETUP or ANNOUNCE. // -// Deprecated: replaced by Medias. +// Deprecated: replaced by Query. func (ss *ServerSession) SetuppedQuery() string { return ss.Query() } // Query returns the query sent during SETUP or ANNOUNCE. func (ss *ServerSession) Query() string { + ss.propsMutex.RLock() + defer ss.propsMutex.RUnlock() + return ss.setuppedQuery } // AnnouncedDescription returns the announced stream description. func (ss *ServerSession) AnnouncedDescription() *description.Session { + ss.propsMutex.RLock() + defer ss.propsMutex.RUnlock() + return ss.announcedDesc } @@ -571,6 +593,9 @@ func (ss *ServerSession) SetuppedMedias() []*description.Media { // Medias returns setupped medias. func (ss *ServerSession) Medias() []*description.Media { + ss.propsMutex.RLock() + defer ss.propsMutex.RUnlock() + ret := make([]*description.Media, len(ss.setuppedMedias)) for i, sm := range ss.setuppedMediasOrdered { ret[i] = sm.media @@ -590,6 +615,9 @@ func (ss *ServerSession) UserData() interface{} { // Stats returns server session statistics. func (ss *ServerSession) Stats() *StatsSession { + ss.propsMutex.RLock() + defer ss.propsMutex.RUnlock() + mediaStats := func() map[*description.Media]StatsSessionMedia { //nolint:dupl ret := make(map[*description.Media]StatsSessionMedia, len(ss.setuppedMedias)) @@ -856,10 +884,14 @@ func (ss *ServerSession) run() { ss.setuppedStream.readerRemove(ss) } + ss.propsMutex.Lock() + for _, sm := range ss.setuppedMedias { sm.close() } + ss.propsMutex.Unlock() + if ss.writer != nil { ss.destroyWriter() } @@ -1094,10 +1126,12 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( }) if res.StatusCode == base.StatusOK { + ss.propsMutex.Lock() ss.state = ServerSessionStatePreRecord ss.setuppedPath = path ss.setuppedQuery = query ss.announcedDesc = &desc + ss.propsMutex.Unlock() } return res, err @@ -1272,6 +1306,12 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( panic("stream cannot be nil when StatusCode is StatusOK") } + if ss.state == ServerSessionStatePrePlay { + if stream != ss.setuppedStream { + panic("stream cannot be different than the one returned in previous OnSetup call") + } + } + medi = findMediaByTrackID(stream.Desc.Medias, trackID) default: // record medi = findMediaByURL(ss.announcedDesc.Medias, path, query, req.URL) @@ -1289,34 +1329,23 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( }, liberrors.ErrServerMediaAlreadySetup{} } - ss.setuppedTransport = &transport - ss.setuppedProfile = inTH.Profile - if ss.state == ServerSessionStateInitial { err = stream.readerAdd(ss, inTH.ClientPorts, + transport, ) if err != nil { return &base.Response{ StatusCode: base.StatusBadRequest, }, err } - - ss.state = ServerSessionStatePrePlay - ss.setuppedPath = path - ss.setuppedQuery = query - ss.setuppedStream = stream } th := headers.Transport{ Profile: inTH.Profile, } - if ss.state == ServerSessionStatePrePlay { - if stream != ss.setuppedStream { - panic("stream cannot be different than the one returned in previous OnSetup call") - } - + if ss.state == ServerSessionStateInitial || ss.state == ServerSessionStatePrePlay { // Fill SSRC if there is a single SSRC only // since the Transport header does not support multiple SSRCs. if len(stream.medias[medi].formats) == 1 { @@ -1343,7 +1372,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( } } else { localSSRCs = make(map[uint8]uint32) - for forma, data := range ss.setuppedStream.medias[medi].formats { + for forma, data := range stream.medias[medi].formats { localSSRCs[forma] = data.localSSRC } } @@ -1371,7 +1400,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( }, err } } else { - srtpOutCtx = ss.setuppedStream.medias[medi].srtpOutCtx + srtpOutCtx = stream.medias[medi].srtpOutCtx } } @@ -1429,6 +1458,11 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( th.InterleavedIDs = &[2]int{tcpChannel, tcpChannel + 1} } + ss.propsMutex.Lock() + + ss.setuppedTransport = &transport + ss.setuppedProfile = inTH.Profile + sm := &serverSessionMedia{ ss: ss, media: medi, @@ -1447,10 +1481,18 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( if ss.setuppedMedias == nil { ss.setuppedMedias = make(map[*description.Media]*serverSessionMedia) } - ss.setuppedMedias[medi] = sm ss.setuppedMediasOrdered = append(ss.setuppedMediasOrdered, sm) + if ss.state == ServerSessionStateInitial { + ss.state = ServerSessionStatePrePlay + ss.setuppedPath = path + ss.setuppedQuery = query + ss.setuppedStream = stream + } + + ss.propsMutex.Unlock() + res.Header["Transport"] = th.Marshal() if isSecure(inTH.Profile) { @@ -1514,7 +1556,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( if res.StatusCode == base.StatusOK { if ss.state != ServerSessionStatePlay { + ss.propsMutex.Lock() ss.state = ServerSessionStatePlay + ss.propsMutex.Unlock() v := ss.s.timeNow().Unix() ss.udpLastPacketTime = &v @@ -1685,7 +1729,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( switch ss.state { case ServerSessionStatePlay: + ss.propsMutex.Lock() ss.state = ServerSessionStatePrePlay + ss.propsMutex.Unlock() switch *ss.setuppedTransport { case TransportUDP: @@ -1709,7 +1755,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( ss.tcpConn = nil } + ss.propsMutex.Lock() ss.state = ServerSessionStatePreRecord + ss.propsMutex.Unlock() } } } diff --git a/server_session_media.go b/server_session_media.go index fe0c2355..07077f35 100644 --- a/server_session_media.go +++ b/server_session_media.go @@ -57,6 +57,26 @@ func (sm *serverSessionMedia) initialize() { f.initialize() sm.formats[forma.PayloadType()] = f } + + switch *sm.ss.setuppedTransport { + case TransportUDP, TransportUDPMulticast: + sm.writePacketRTCPInQueue = sm.writePacketRTCPInQueueUDP + + case TransportTCP: + sm.writePacketRTCPInQueue = sm.writePacketRTCPInQueueTCP + + if sm.ss.tcpCallbackByChannel == nil { + sm.ss.tcpCallbackByChannel = make(map[int]readFunc) + } + + if sm.ss.state == ServerSessionStateInitial || sm.ss.state == ServerSessionStatePrePlay { + sm.ss.tcpCallbackByChannel[sm.tcpChannel] = sm.readPacketRTPTCPPlay + sm.ss.tcpCallbackByChannel[sm.tcpChannel+1] = sm.readPacketRTCPTCPPlay + } else { + sm.ss.tcpCallbackByChannel[sm.tcpChannel] = sm.readPacketRTPTCPRecord + sm.ss.tcpCallbackByChannel[sm.tcpChannel+1] = sm.readPacketRTCPTCPRecord + } + } } func (sm *serverSessionMedia) close() { @@ -70,8 +90,6 @@ func (sm *serverSessionMedia) close() { func (sm *serverSessionMedia) start() error { switch *sm.ss.setuppedTransport { case TransportUDP, TransportUDPMulticast: - sm.writePacketRTCPInQueue = sm.writePacketRTCPInQueueUDP - if *sm.ss.setuppedTransport == TransportUDP { if sm.ss.state == ServerSessionStatePlay { if sm.media.IsBackChannel { @@ -112,21 +130,6 @@ func (sm *serverSessionMedia) start() error { sm.ss.s.udpRTCPListener.addClient(sm.ss.author.ip(), sm.udpRTCPReadPort, sm.readPacketRTCPUDPRecord) } } - - case TransportTCP: - sm.writePacketRTCPInQueue = sm.writePacketRTCPInQueueTCP - - if sm.ss.tcpCallbackByChannel == nil { - sm.ss.tcpCallbackByChannel = make(map[int]readFunc) - } - - if sm.ss.state == ServerSessionStatePlay { - sm.ss.tcpCallbackByChannel[sm.tcpChannel] = sm.readPacketRTPTCPPlay - sm.ss.tcpCallbackByChannel[sm.tcpChannel+1] = sm.readPacketRTCPTCPPlay - } else { - sm.ss.tcpCallbackByChannel[sm.tcpChannel] = sm.readPacketRTPTCPRecord - sm.ss.tcpCallbackByChannel[sm.tcpChannel+1] = sm.readPacketRTCPTCPRecord - } } return nil diff --git a/server_stream.go b/server_stream.go index f1549044..6283181a 100644 --- a/server_stream.go +++ b/server_stream.go @@ -211,6 +211,7 @@ func (st *ServerStream) Stats() *ServerStreamStats { func (st *ServerStream) readerAdd( ss *ServerSession, clientPorts *[2]int, + protocol Transport, ) error { st.mutex.Lock() defer st.mutex.Unlock() @@ -219,11 +220,11 @@ func (st *ServerStream) readerAdd( return liberrors.ErrServerStreamClosed{} } - switch *ss.setuppedTransport { + switch protocol { case TransportUDP: // check whether UDP ports and IP are already assigned to another reader for r := range st.readers { - if *r.setuppedTransport == TransportUDP && + if protocol == TransportUDP && r.author.ip().Equal(ss.author.ip()) && r.author.zone() == ss.author.zone() { for _, rt := range r.setuppedMedias {