server: make ServerStream.Close() thread-safe

This commit is contained in:
aler9
2022-10-28 16:22:31 +02:00
parent 630b1ebce1
commit cc9dcae08a

View File

@@ -35,6 +35,7 @@ type ServerStream struct {
activeUnicastReaders map[*ServerSession]struct{} activeUnicastReaders map[*ServerSession]struct{}
readers map[*ServerSession]struct{} readers map[*ServerSession]struct{}
streamTracks []*serverStreamTrack streamTracks []*serverStreamTrack
closed bool
} }
// NewServerStream allocates a ServerStream. // NewServerStream allocates a ServerStream.
@@ -59,7 +60,8 @@ func NewServerStream(tracks Tracks) *ServerStream {
// Close closes a ServerStream. // Close closes a ServerStream.
func (st *ServerStream) Close() error { func (st *ServerStream) Close() error {
st.mutex.Lock() st.mutex.Lock()
defer st.mutex.Unlock() st.closed = true
st.mutex.Unlock()
for ss := range st.readers { for ss := range st.readers {
ss.Close() ss.Close()
@@ -71,8 +73,6 @@ func (st *ServerStream) Close() error {
} }
} }
st.readers = nil
return nil return nil
} }
@@ -119,12 +119,33 @@ func (st *ServerStream) readerAdd(
st.mutex.Lock() st.mutex.Lock()
defer st.mutex.Unlock() defer st.mutex.Unlock()
if st.readers == nil { if st.closed {
return fmt.Errorf("stream is closed") return fmt.Errorf("stream is closed")
} }
// check whether UDP ports are already assigned to another reader if st.s == nil {
if transport == TransportUDP { st.s = ss.s
for trackID, track := range st.streamTracks {
// always generate RTCP sender reports.
// they're mandatory when transport protocol is UDP or UDP-multicast.
// they're also needed when transport protocol is TCP and client is Nvidia Deepstream
// since they're used to compute NTP timestamp of frames:
// https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_NTP_Timestamp.html
cTrackID := trackID
track.rtcpSender = rtcpsender.New(
st.s.udpSenderReportPeriod,
st.tracks[trackID].ClockRate(),
func(pkt rtcp.Packet) {
st.WritePacketRTCP(cTrackID, pkt)
},
)
}
}
switch transport {
case TransportUDP:
// check whether UDP ports are already assigned to another reader
for r := range st.readers { for r := range st.readers {
if *r.setuppedTransport == TransportUDP && if *r.setuppedTransport == TransportUDP &&
r.author.ip().Equal(ss.author.ip()) && r.author.ip().Equal(ss.author.ip()) &&
@@ -136,31 +157,9 @@ func (st *ServerStream) readerAdd(
} }
} }
} }
}
if st.s == nil { case TransportUDPMulticast:
st.s = ss.s // allocate multicast listeners
for trackID, track := range st.streamTracks {
cTrackID := trackID
// always generate RTCP sender reports.
// they're mandatory when transport protocol is UDP or UDP-multicast.
// they're also needed when transport protocol is TCP and client is Nvidia Deepstream
// since they're used to compute NTP timestamp of frames:
// https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_NTP_Timestamp.html
track.rtcpSender = rtcpsender.New(
st.s.udpSenderReportPeriod,
st.tracks[trackID].ClockRate(),
func(pkt rtcp.Packet) {
st.WritePacketRTCP(cTrackID, pkt)
},
)
}
}
// allocate multicast listeners
if transport == TransportUDPMulticast {
for _, track := range st.streamTracks { for _, track := range st.streamTracks {
if track.multicastHandler == nil { if track.multicastHandler == nil {
mh, err := newServerMulticastHandler(st.s) mh, err := newServerMulticastHandler(st.s)
@@ -181,6 +180,10 @@ func (st *ServerStream) readerRemove(ss *ServerSession) {
st.mutex.Lock() st.mutex.Lock()
defer st.mutex.Unlock() defer st.mutex.Unlock()
if st.closed {
return
}
delete(st.readers, ss) delete(st.readers, ss)
if len(st.readers) == 0 { if len(st.readers) == 0 {
@@ -197,6 +200,10 @@ func (st *ServerStream) readerSetActive(ss *ServerSession) {
st.mutex.Lock() st.mutex.Lock()
defer st.mutex.Unlock() defer st.mutex.Unlock()
if st.closed {
return
}
if *ss.setuppedTransport == TransportUDPMulticast { if *ss.setuppedTransport == TransportUDPMulticast {
for trackID, track := range ss.setuppedTracks { for trackID, track := range ss.setuppedTracks {
st.streamTracks[trackID].multicastHandler.rtcpl.addClient( st.streamTracks[trackID].multicastHandler.rtcpl.addClient(
@@ -211,6 +218,10 @@ func (st *ServerStream) readerSetInactive(ss *ServerSession) {
st.mutex.Lock() st.mutex.Lock()
defer st.mutex.Unlock() defer st.mutex.Unlock()
if st.closed {
return
}
if *ss.setuppedTransport == TransportUDPMulticast { if *ss.setuppedTransport == TransportUDPMulticast {
for trackID := range ss.setuppedTracks { for trackID := range ss.setuppedTracks {
st.streamTracks[trackID].multicastHandler.rtcpl.removeClient(ss) st.streamTracks[trackID].multicastHandler.rtcpl.removeClient(ss)
@@ -232,6 +243,10 @@ func (st *ServerStream) WritePacketRTP(trackID int, pkt *rtp.Packet, ptsEqualsDT
st.mutex.RLock() st.mutex.RLock()
defer st.mutex.RUnlock() defer st.mutex.RUnlock()
if st.closed {
return
}
track := st.streamTracks[trackID] track := st.streamTracks[trackID]
now := time.Now() now := time.Now()
@@ -269,6 +284,10 @@ func (st *ServerStream) WritePacketRTCP(trackID int, pkt rtcp.Packet) {
st.mutex.RLock() st.mutex.RLock()
defer st.mutex.RUnlock() defer st.mutex.RUnlock()
if st.closed {
return
}
// send unicast // send unicast
for r := range st.activeUnicastReaders { for r := range st.activeUnicastReaders {
r.writePacketRTCP(trackID, byts) r.writePacketRTCP(trackID, byts)