diff --git a/serverstream.go b/serverstream.go index fdc68603..95bf363b 100644 --- a/serverstream.go +++ b/serverstream.go @@ -35,6 +35,7 @@ type ServerStream struct { activeUnicastReaders map[*ServerSession]struct{} readers map[*ServerSession]struct{} streamTracks []*serverStreamTrack + closed bool } // NewServerStream allocates a ServerStream. @@ -59,7 +60,8 @@ func NewServerStream(tracks Tracks) *ServerStream { // Close closes a ServerStream. func (st *ServerStream) Close() error { st.mutex.Lock() - defer st.mutex.Unlock() + st.closed = true + st.mutex.Unlock() for ss := range st.readers { ss.Close() @@ -71,8 +73,6 @@ func (st *ServerStream) Close() error { } } - st.readers = nil - return nil } @@ -119,12 +119,33 @@ func (st *ServerStream) readerAdd( st.mutex.Lock() defer st.mutex.Unlock() - if st.readers == nil { + if st.closed { return fmt.Errorf("stream is closed") } - // check whether UDP ports are already assigned to another reader - if transport == TransportUDP { + if st.s == nil { + st.s = ss.s + + for trackID, track := range st.streamTracks { + // always generate RTCP sender reports. + // they're mandatory when transport protocol is UDP or UDP-multicast. + // they're also needed when transport protocol is TCP and client is Nvidia Deepstream + // since they're used to compute NTP timestamp of frames: + // https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_NTP_Timestamp.html + cTrackID := trackID + track.rtcpSender = rtcpsender.New( + st.s.udpSenderReportPeriod, + st.tracks[trackID].ClockRate(), + func(pkt rtcp.Packet) { + st.WritePacketRTCP(cTrackID, pkt) + }, + ) + } + } + + switch transport { + case TransportUDP: + // check whether UDP ports are already assigned to another reader for r := range st.readers { if *r.setuppedTransport == TransportUDP && r.author.ip().Equal(ss.author.ip()) && @@ -136,31 +157,9 @@ func (st *ServerStream) readerAdd( } } } - } - if st.s == nil { - st.s = ss.s - - for trackID, track := range st.streamTracks { - cTrackID := trackID - - // always generate RTCP sender reports. - // they're mandatory when transport protocol is UDP or UDP-multicast. - // they're also needed when transport protocol is TCP and client is Nvidia Deepstream - // since they're used to compute NTP timestamp of frames: - // https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_NTP_Timestamp.html - track.rtcpSender = rtcpsender.New( - st.s.udpSenderReportPeriod, - st.tracks[trackID].ClockRate(), - func(pkt rtcp.Packet) { - st.WritePacketRTCP(cTrackID, pkt) - }, - ) - } - } - - // allocate multicast listeners - if transport == TransportUDPMulticast { + case TransportUDPMulticast: + // allocate multicast listeners for _, track := range st.streamTracks { if track.multicastHandler == nil { mh, err := newServerMulticastHandler(st.s) @@ -181,6 +180,10 @@ func (st *ServerStream) readerRemove(ss *ServerSession) { st.mutex.Lock() defer st.mutex.Unlock() + if st.closed { + return + } + delete(st.readers, ss) if len(st.readers) == 0 { @@ -197,6 +200,10 @@ func (st *ServerStream) readerSetActive(ss *ServerSession) { st.mutex.Lock() defer st.mutex.Unlock() + if st.closed { + return + } + if *ss.setuppedTransport == TransportUDPMulticast { for trackID, track := range ss.setuppedTracks { st.streamTracks[trackID].multicastHandler.rtcpl.addClient( @@ -211,6 +218,10 @@ func (st *ServerStream) readerSetInactive(ss *ServerSession) { st.mutex.Lock() defer st.mutex.Unlock() + if st.closed { + return + } + if *ss.setuppedTransport == TransportUDPMulticast { for trackID := range ss.setuppedTracks { st.streamTracks[trackID].multicastHandler.rtcpl.removeClient(ss) @@ -232,6 +243,10 @@ func (st *ServerStream) WritePacketRTP(trackID int, pkt *rtp.Packet, ptsEqualsDT st.mutex.RLock() defer st.mutex.RUnlock() + if st.closed { + return + } + track := st.streamTracks[trackID] now := time.Now() @@ -269,6 +284,10 @@ func (st *ServerStream) WritePacketRTCP(trackID int, pkt rtcp.Packet) { st.mutex.RLock() defer st.mutex.RUnlock() + if st.closed { + return + } + // send unicast for r := range st.activeUnicastReaders { r.writePacketRTCP(trackID, byts)