make most methods thread safe (#882)

Client: Stats

ServerConn: Session, Stats

ServerSession: State, Stats, Medias, Path, Query, Stream,
SetuppedSecure, SetuppedTransport, AnnouncedDescription
This commit is contained in:
Alessandro Ros
2025-09-06 15:42:07 +02:00
committed by GitHub
parent 702cd0a70f
commit 3c2625c7cf
10 changed files with 167 additions and 59 deletions

View File

@@ -543,6 +543,7 @@ type Client struct {
ctx context.Context
ctxCancel func()
propsMutex sync.RWMutex
state clientState
nconn net.Conn
conn *conn.Conn
@@ -2023,16 +2024,19 @@ func (c *Client) doSetup(
udpRTPListener = nil
udpRTCPListener = nil
c.propsMutex.Lock()
if c.setuppedMedias == nil {
c.setuppedMedias = make(map[*description.Media]*clientMedia)
}
c.setuppedMedias[medi] = cm
c.baseURL = baseURL
c.setuppedTransport = &transport
c.setuppedProfile = th.Profile
c.propsMutex.Unlock()
if medi.IsBackChannel {
c.backChannelSetupped = true
} else {
@@ -2426,6 +2430,9 @@ func (c *Client) PacketNTP(medi *description.Media, pkt *rtp.Packet) (time.Time,
// Stats returns client statistics.
func (c *Client) Stats() *ClientStats {
c.propsMutex.RLock()
defer c.propsMutex.RUnlock()
mediaStats := func() map[*description.Media]StatsSessionMedia { //nolint:dupl
ret := make(map[*description.Media]StatsSessionMedia, len(c.setuppedMedias))

View File

@@ -125,21 +125,11 @@ func (cm *clientMedia) initialize() {
f.initialize()
cm.formats[forma.PayloadType()] = f
}
}
func (cm *clientMedia) close() {
cm.stop()
for _, ct := range cm.formats {
ct.close()
}
}
func (cm *clientMedia) start() {
if cm.udpRTPListener != nil {
cm.writePacketRTCPInQueue = cm.writePacketRTCPInQueueUDP
if cm.c.state == clientStateRecord || cm.media.IsBackChannel {
if cm.c.state == clientStatePreRecord || cm.media.IsBackChannel {
cm.udpRTPListener.readFunc = cm.readPacketRTPUDPRecord
cm.udpRTCPListener.readFunc = cm.readPacketRTCPUDPRecord
} else {
@@ -153,7 +143,7 @@ func (cm *clientMedia) start() {
cm.c.tcpCallbackByChannel = make(map[int]readFunc)
}
if cm.c.state == clientStateRecord || cm.media.IsBackChannel {
if cm.c.state == clientStatePreRecord || cm.media.IsBackChannel {
cm.c.tcpCallbackByChannel[cm.tcpChannel] = cm.readPacketRTPTCPRecord
cm.c.tcpCallbackByChannel[cm.tcpChannel+1] = cm.readPacketRTCPTCPRecord
} else {
@@ -161,7 +151,17 @@ func (cm *clientMedia) start() {
cm.c.tcpCallbackByChannel[cm.tcpChannel+1] = cm.readPacketRTCPTCPPlay
}
}
}
func (cm *clientMedia) close() {
cm.stop()
for _, ct := range cm.formats {
ct.close()
}
}
func (cm *clientMedia) start() {
if cm.udpRTPListener != nil {
cm.udpRTPListener.start()
cm.udpRTCPListener.start()

View File

@@ -640,6 +640,11 @@ func TestClientPlay(t *testing.T) {
sd, _, err := c.Describe(u)
require.NoError(t, err)
// test that properties can be accessed in parallel
go func() {
c.Stats()
}()
err = c.SetupAll(sd.BaseURL, sd.Medias)
require.NoError(t, err)

View File

@@ -108,9 +108,11 @@ func (u *clientUDPListener) start() {
}
func (u *clientUDPListener) stop() {
u.pc.SetReadDeadline(time.Now())
<-u.done
u.running = false
if u.running {
u.pc.SetReadDeadline(time.Now())
<-u.done
u.running = false
}
}
func (u *clientUDPListener) run() {

View File

@@ -9,6 +9,7 @@ import (
gourl "net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/bluenviron/gortsplib/v4/pkg/auth"
@@ -199,6 +200,7 @@ type ServerConn struct {
ctx context.Context
ctxCancel func()
propsMutex sync.RWMutex
userData interface{}
remoteAddr *net.TCPAddr
bc *bytecounter.ByteCounter
@@ -268,6 +270,9 @@ func (sc *ServerConn) UserData() interface{} {
// Session returns the associated session.
func (sc *ServerConn) Session() *ServerSession {
sc.propsMutex.RLock()
defer sc.propsMutex.RUnlock()
return sc.session
}
@@ -658,7 +663,11 @@ func (sc *ServerConn) handleRequestInSession(
}
res, session, err := sc.s.handleRequest(sreq)
sc.propsMutex.Lock()
sc.session = session
sc.propsMutex.Unlock()
return res, err
}

View File

@@ -665,7 +665,21 @@ func TestServerPlay(t *testing.T) {
StatusCode: base.StatusOK,
}, stream, nil
},
onSetup: func(_ *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
// test that properties can be accessed in parallel
go func() {
ctx.Conn.Session()
ctx.Conn.Stats()
ctx.Session.State()
ctx.Session.Stats()
ctx.Session.Medias()
ctx.Session.Path()
ctx.Session.Query()
ctx.Session.Stream()
ctx.Session.SetuppedTransport()
ctx.Session.SetuppedSecure()
}()
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
@@ -1726,7 +1740,13 @@ func TestServerPlayPause(t *testing.T) {
StatusCode: base.StatusOK,
}, nil
},
onPause: func(_ *ServerHandlerOnPauseCtx) (*base.Response, error) {
onPause: func(ctx *ServerHandlerOnPauseCtx) (*base.Response, error) {
// test that properties can be accessed in parallel
go func() {
ctx.Session.State()
ctx.Session.Stats()
}()
return &base.Response{
StatusCode: base.StatusOK,
}, nil

View File

@@ -602,7 +602,14 @@ func TestServerRecord(t *testing.T) {
close(sessionClosed)
},
onAnnounce: func(_ *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
// test that properties can be accessed in parallel
go func() {
ctx.Session.State()
ctx.Session.Stats()
ctx.Session.AnnouncedDescription()
}()
return &base.Response{
StatusCode: base.StatusOK,
}, nil
@@ -1732,7 +1739,13 @@ func TestServerRecordPausePause(t *testing.T) {
StatusCode: base.StatusOK,
}, nil
},
onPause: func(_ *ServerHandlerOnPauseCtx) (*base.Response, error) {
onPause: func(ctx *ServerHandlerOnPauseCtx) (*base.Response, error) {
// test that properties can be accessed in parallel
go func() {
ctx.Session.State()
ctx.Session.Stats()
}()
return &base.Response{
StatusCode: base.StatusOK,
}, nil

View File

@@ -428,8 +428,9 @@ type ServerSession struct {
secretID string // must not be shared, allows to take ownership of the session
ctx context.Context
ctxCancel func()
userData interface{}
propsMutex sync.RWMutex
conns map[*ServerConn]struct{}
userData interface{}
state ServerSessionState
setuppedMedias map[*description.Media]*serverSessionMedia
setuppedMediasOrdered []*serverSessionMedia
@@ -506,11 +507,17 @@ func (ss *ServerSession) BytesSent() uint64 {
// State returns the state of the session.
func (ss *ServerSession) State() ServerSessionState {
ss.propsMutex.RLock()
defer ss.propsMutex.RUnlock()
return ss.state
}
// SetuppedTransport returns the transport negotiated during SETUP.
func (ss *ServerSession) SetuppedTransport() *Transport {
ss.propsMutex.RLock()
defer ss.propsMutex.RUnlock()
return ss.setuppedTransport
}
@@ -518,47 +525,62 @@ func (ss *ServerSession) SetuppedTransport() *Transport {
// If this is false, it does not mean that the stream is not secure, since
// there are some combinations that are secure nonetheless, like RTSPS+TCP+unsecure.
func (ss *ServerSession) SetuppedSecure() bool {
ss.propsMutex.RLock()
defer ss.propsMutex.RUnlock()
return isSecure(ss.setuppedProfile)
}
// SetuppedStream returns the stream associated with the session.
//
// Deprecated: replaced by Stream
// Deprecated: replaced by Stream.
func (ss *ServerSession) SetuppedStream() *ServerStream {
return ss.Stream()
}
// Stream returns the stream associated with the session.
func (ss *ServerSession) Stream() *ServerStream {
ss.propsMutex.RLock()
defer ss.propsMutex.RUnlock()
return ss.setuppedStream
}
// SetuppedPath returns the path sent during SETUP or ANNOUNCE.
//
// Deprecated: replaced by Path
// Deprecated: replaced by Path.
func (ss *ServerSession) SetuppedPath() string {
return ss.Path()
}
// Path returns the path sent during SETUP or ANNOUNCE.
func (ss *ServerSession) Path() string {
ss.propsMutex.RLock()
defer ss.propsMutex.RUnlock()
return ss.setuppedPath
}
// SetuppedQuery returns the query sent during SETUP or ANNOUNCE.
//
// Deprecated: replaced by Medias.
// Deprecated: replaced by Query.
func (ss *ServerSession) SetuppedQuery() string {
return ss.Query()
}
// Query returns the query sent during SETUP or ANNOUNCE.
func (ss *ServerSession) Query() string {
ss.propsMutex.RLock()
defer ss.propsMutex.RUnlock()
return ss.setuppedQuery
}
// AnnouncedDescription returns the announced stream description.
func (ss *ServerSession) AnnouncedDescription() *description.Session {
ss.propsMutex.RLock()
defer ss.propsMutex.RUnlock()
return ss.announcedDesc
}
@@ -571,6 +593,9 @@ func (ss *ServerSession) SetuppedMedias() []*description.Media {
// Medias returns setupped medias.
func (ss *ServerSession) Medias() []*description.Media {
ss.propsMutex.RLock()
defer ss.propsMutex.RUnlock()
ret := make([]*description.Media, len(ss.setuppedMedias))
for i, sm := range ss.setuppedMediasOrdered {
ret[i] = sm.media
@@ -590,6 +615,9 @@ func (ss *ServerSession) UserData() interface{} {
// Stats returns server session statistics.
func (ss *ServerSession) Stats() *StatsSession {
ss.propsMutex.RLock()
defer ss.propsMutex.RUnlock()
mediaStats := func() map[*description.Media]StatsSessionMedia { //nolint:dupl
ret := make(map[*description.Media]StatsSessionMedia, len(ss.setuppedMedias))
@@ -856,10 +884,14 @@ func (ss *ServerSession) run() {
ss.setuppedStream.readerRemove(ss)
}
ss.propsMutex.Lock()
for _, sm := range ss.setuppedMedias {
sm.close()
}
ss.propsMutex.Unlock()
if ss.writer != nil {
ss.destroyWriter()
}
@@ -1094,10 +1126,12 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
})
if res.StatusCode == base.StatusOK {
ss.propsMutex.Lock()
ss.state = ServerSessionStatePreRecord
ss.setuppedPath = path
ss.setuppedQuery = query
ss.announcedDesc = &desc
ss.propsMutex.Unlock()
}
return res, err
@@ -1272,6 +1306,12 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
panic("stream cannot be nil when StatusCode is StatusOK")
}
if ss.state == ServerSessionStatePrePlay {
if stream != ss.setuppedStream {
panic("stream cannot be different than the one returned in previous OnSetup call")
}
}
medi = findMediaByTrackID(stream.Desc.Medias, trackID)
default: // record
medi = findMediaByURL(ss.announcedDesc.Medias, path, query, req.URL)
@@ -1289,34 +1329,23 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
}, liberrors.ErrServerMediaAlreadySetup{}
}
ss.setuppedTransport = &transport
ss.setuppedProfile = inTH.Profile
if ss.state == ServerSessionStateInitial {
err = stream.readerAdd(ss,
inTH.ClientPorts,
transport,
)
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
ss.state = ServerSessionStatePrePlay
ss.setuppedPath = path
ss.setuppedQuery = query
ss.setuppedStream = stream
}
th := headers.Transport{
Profile: inTH.Profile,
}
if ss.state == ServerSessionStatePrePlay {
if stream != ss.setuppedStream {
panic("stream cannot be different than the one returned in previous OnSetup call")
}
if ss.state == ServerSessionStateInitial || ss.state == ServerSessionStatePrePlay {
// Fill SSRC if there is a single SSRC only
// since the Transport header does not support multiple SSRCs.
if len(stream.medias[medi].formats) == 1 {
@@ -1343,7 +1372,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
}
} else {
localSSRCs = make(map[uint8]uint32)
for forma, data := range ss.setuppedStream.medias[medi].formats {
for forma, data := range stream.medias[medi].formats {
localSSRCs[forma] = data.localSSRC
}
}
@@ -1371,7 +1400,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
}, err
}
} else {
srtpOutCtx = ss.setuppedStream.medias[medi].srtpOutCtx
srtpOutCtx = stream.medias[medi].srtpOutCtx
}
}
@@ -1429,6 +1458,11 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
th.InterleavedIDs = &[2]int{tcpChannel, tcpChannel + 1}
}
ss.propsMutex.Lock()
ss.setuppedTransport = &transport
ss.setuppedProfile = inTH.Profile
sm := &serverSessionMedia{
ss: ss,
media: medi,
@@ -1447,10 +1481,18 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
if ss.setuppedMedias == nil {
ss.setuppedMedias = make(map[*description.Media]*serverSessionMedia)
}
ss.setuppedMedias[medi] = sm
ss.setuppedMediasOrdered = append(ss.setuppedMediasOrdered, sm)
if ss.state == ServerSessionStateInitial {
ss.state = ServerSessionStatePrePlay
ss.setuppedPath = path
ss.setuppedQuery = query
ss.setuppedStream = stream
}
ss.propsMutex.Unlock()
res.Header["Transport"] = th.Marshal()
if isSecure(inTH.Profile) {
@@ -1514,7 +1556,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
if res.StatusCode == base.StatusOK {
if ss.state != ServerSessionStatePlay {
ss.propsMutex.Lock()
ss.state = ServerSessionStatePlay
ss.propsMutex.Unlock()
v := ss.s.timeNow().Unix()
ss.udpLastPacketTime = &v
@@ -1685,7 +1729,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
switch ss.state {
case ServerSessionStatePlay:
ss.propsMutex.Lock()
ss.state = ServerSessionStatePrePlay
ss.propsMutex.Unlock()
switch *ss.setuppedTransport {
case TransportUDP:
@@ -1709,7 +1755,9 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
ss.tcpConn = nil
}
ss.propsMutex.Lock()
ss.state = ServerSessionStatePreRecord
ss.propsMutex.Unlock()
}
}
}

View File

@@ -57,6 +57,26 @@ func (sm *serverSessionMedia) initialize() {
f.initialize()
sm.formats[forma.PayloadType()] = f
}
switch *sm.ss.setuppedTransport {
case TransportUDP, TransportUDPMulticast:
sm.writePacketRTCPInQueue = sm.writePacketRTCPInQueueUDP
case TransportTCP:
sm.writePacketRTCPInQueue = sm.writePacketRTCPInQueueTCP
if sm.ss.tcpCallbackByChannel == nil {
sm.ss.tcpCallbackByChannel = make(map[int]readFunc)
}
if sm.ss.state == ServerSessionStateInitial || sm.ss.state == ServerSessionStatePrePlay {
sm.ss.tcpCallbackByChannel[sm.tcpChannel] = sm.readPacketRTPTCPPlay
sm.ss.tcpCallbackByChannel[sm.tcpChannel+1] = sm.readPacketRTCPTCPPlay
} else {
sm.ss.tcpCallbackByChannel[sm.tcpChannel] = sm.readPacketRTPTCPRecord
sm.ss.tcpCallbackByChannel[sm.tcpChannel+1] = sm.readPacketRTCPTCPRecord
}
}
}
func (sm *serverSessionMedia) close() {
@@ -70,8 +90,6 @@ func (sm *serverSessionMedia) close() {
func (sm *serverSessionMedia) start() error {
switch *sm.ss.setuppedTransport {
case TransportUDP, TransportUDPMulticast:
sm.writePacketRTCPInQueue = sm.writePacketRTCPInQueueUDP
if *sm.ss.setuppedTransport == TransportUDP {
if sm.ss.state == ServerSessionStatePlay {
if sm.media.IsBackChannel {
@@ -112,21 +130,6 @@ func (sm *serverSessionMedia) start() error {
sm.ss.s.udpRTCPListener.addClient(sm.ss.author.ip(), sm.udpRTCPReadPort, sm.readPacketRTCPUDPRecord)
}
}
case TransportTCP:
sm.writePacketRTCPInQueue = sm.writePacketRTCPInQueueTCP
if sm.ss.tcpCallbackByChannel == nil {
sm.ss.tcpCallbackByChannel = make(map[int]readFunc)
}
if sm.ss.state == ServerSessionStatePlay {
sm.ss.tcpCallbackByChannel[sm.tcpChannel] = sm.readPacketRTPTCPPlay
sm.ss.tcpCallbackByChannel[sm.tcpChannel+1] = sm.readPacketRTCPTCPPlay
} else {
sm.ss.tcpCallbackByChannel[sm.tcpChannel] = sm.readPacketRTPTCPRecord
sm.ss.tcpCallbackByChannel[sm.tcpChannel+1] = sm.readPacketRTCPTCPRecord
}
}
return nil

View File

@@ -211,6 +211,7 @@ func (st *ServerStream) Stats() *ServerStreamStats {
func (st *ServerStream) readerAdd(
ss *ServerSession,
clientPorts *[2]int,
protocol Transport,
) error {
st.mutex.Lock()
defer st.mutex.Unlock()
@@ -219,11 +220,11 @@ func (st *ServerStream) readerAdd(
return liberrors.ErrServerStreamClosed{}
}
switch *ss.setuppedTransport {
switch protocol {
case TransportUDP:
// check whether UDP ports and IP are already assigned to another reader
for r := range st.readers {
if *r.setuppedTransport == TransportUDP &&
if protocol == TransportUDP &&
r.author.ip().Equal(ss.author.ip()) &&
r.author.zone() == ss.author.zone() {
for _, rt := range r.setuppedMedias {