diff --git a/serversession.go b/serversession.go index 6d0b668e..62c32d0c 100644 --- a/serversession.go +++ b/serversession.go @@ -771,7 +771,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base th.Delivery = &de v := uint(127) th.TTL = &v - d := stream.serverMulticastHandlers[trackID].ip() + d := stream.streamTracks[trackID].multicastHandler.ip() th.Destination = &d th.Ports = &[2]int{ss.s.MulticastRTPPort, ss.s.MulticastRTCPPort} diff --git a/serverstream.go b/serverstream.go index 50745ae7..13cd8495 100644 --- a/serverstream.go +++ b/serverstream.go @@ -18,23 +18,23 @@ type serverStreamTrack struct { lastSSRC uint32 lastTimeRTP uint32 lastTimeNTP time.Time - udpRTCPSender *rtcpsender.RTCPSender + rtcpSender *rtcpsender.RTCPSender + multicastHandler *serverMulticastHandler } // ServerStream represents a data stream. // This is in charge of // - distributing the stream to each reader // - allocating multicast listeners -// - gathering infos about the stream to generate SSRC and RTP-Info +// - gathering infos about the stream in order to generate SSRC and RTP-Info type ServerStream struct { tracks Tracks - mutex sync.RWMutex - s *Server - activeUnicastReaders map[*ServerSession]struct{} - readers map[*ServerSession]struct{} - serverMulticastHandlers []*serverMulticastHandler - ssTracks []*serverStreamTrack + mutex sync.RWMutex + s *Server + activeUnicastReaders map[*ServerSession]struct{} + readers map[*ServerSession]struct{} + streamTracks []*serverStreamTrack } // NewServerStream allocates a ServerStream. @@ -48,9 +48,9 @@ func NewServerStream(tracks Tracks) *ServerStream { readers: make(map[*ServerSession]struct{}), } - st.ssTracks = make([]*serverStreamTrack, len(tracks)) - for i := range st.ssTracks { - st.ssTracks[i] = &serverStreamTrack{} + st.streamTracks = make([]*serverStreamTrack, len(tracks)) + for i := range st.streamTracks { + st.streamTracks[i] = &serverStreamTrack{} } return st @@ -65,15 +65,13 @@ func (st *ServerStream) Close() error { ss.Close() } - if st.serverMulticastHandlers != nil { - for _, h := range st.serverMulticastHandlers { - h.close() + for _, track := range st.streamTracks { + if track.multicastHandler != nil { + track.multicastHandler.close() } - st.serverMulticastHandlers = nil } st.readers = nil - st.activeUnicastReaders = nil return nil } @@ -86,14 +84,14 @@ func (st *ServerStream) Tracks() Tracks { func (st *ServerStream) ssrc(trackID int) uint32 { st.mutex.Lock() defer st.mutex.Unlock() - return st.ssTracks[trackID].lastSSRC + return st.streamTracks[trackID].lastSSRC } func (st *ServerStream) rtpInfo(trackID int, now time.Time) (uint16, uint32, bool) { st.mutex.Lock() defer st.mutex.Unlock() - track := st.ssTracks[trackID] + track := st.streamTracks[trackID] if !track.firstPacketSent { return 0, 0, false @@ -125,9 +123,8 @@ func (st *ServerStream) readerAdd( return fmt.Errorf("stream is closed") } - switch transport { - case TransportUDP: - // check if client ports are already in use by another reader + // check whether UDP ports are already in use by another reader + if transport == TransportUDP { for r := range st.readers { if *r.setuppedTransport == TransportUDP && r.author.ip().Equal(ss.author.ip()) && @@ -139,33 +136,12 @@ func (st *ServerStream) readerAdd( } } } - - case TransportUDPMulticast: - // allocate multicast listeners - if st.serverMulticastHandlers == nil { - st.serverMulticastHandlers = make([]*serverMulticastHandler, len(st.tracks)) - - for i := range st.tracks { - h, err := newServerMulticastHandler(st.s) - if err != nil { - for _, h := range st.serverMulticastHandlers { - if h != nil { - h.close() - } - } - st.serverMulticastHandlers = nil - return err - } - - st.serverMulticastHandlers[i] = h - } - } } if st.s == nil { st.s = ss.s - for trackID, track := range st.ssTracks { + for trackID, track := range st.streamTracks { cTrackID := trackID // always generate RTCP sender reports. @@ -173,7 +149,7 @@ func (st *ServerStream) readerAdd( // they're also needed when transport protocol is TCP and client is Nvidia Deepstream // since they're used to compute NTP timestamp of frames: // https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_NTP_Timestamp.html - track.udpRTCPSender = rtcpsender.New( + track.rtcpSender = rtcpsender.New( st.s.udpSenderReportPeriod, st.tracks[trackID].ClockRate(), func(pkt rtcp.Packet) { @@ -183,6 +159,19 @@ func (st *ServerStream) readerAdd( } } + // allocate multicast listeners + if transport == TransportUDPMulticast { + for _, track := range st.streamTracks { + if track.multicastHandler == nil { + mh, err := newServerMulticastHandler(st.s) + if err != nil { + return err + } + track.multicastHandler = mh + } + } + } + st.readers[ss] = struct{}{} return nil @@ -194,12 +183,13 @@ func (st *ServerStream) readerRemove(ss *ServerSession) { delete(st.readers, ss) - if len(st.readers) == 0 && st.serverMulticastHandlers != nil { - for _, l := range st.serverMulticastHandlers { - l.rtpl.close() - l.rtcpl.close() + if len(st.readers) == 0 { + for _, track := range st.streamTracks { + if track.multicastHandler != nil { + track.multicastHandler.close() + track.multicastHandler = nil + } } - st.serverMulticastHandlers = nil } } @@ -207,15 +197,13 @@ func (st *ServerStream) readerSetActive(ss *ServerSession) { st.mutex.Lock() defer st.mutex.Unlock() - switch *ss.setuppedTransport { - case TransportUDP, TransportTCP: - st.activeUnicastReaders[ss] = struct{}{} - - default: // UDPMulticast + if *ss.setuppedTransport == TransportUDPMulticast { for trackID, track := range ss.setuppedTracks { - st.serverMulticastHandlers[trackID].rtcpl.addClient( - ss.author.ip(), st.serverMulticastHandlers[trackID].rtcpl.port(), ss, track, false) + st.streamTracks[trackID].multicastHandler.rtcpl.addClient( + ss.author.ip(), st.streamTracks[trackID].multicastHandler.rtcpl.port(), ss, track, false) } + } else { + st.activeUnicastReaders[ss] = struct{}{} } } @@ -223,16 +211,12 @@ func (st *ServerStream) readerSetInactive(ss *ServerSession) { st.mutex.Lock() defer st.mutex.Unlock() - switch *ss.setuppedTransport { - case TransportUDP, TransportTCP: - delete(st.activeUnicastReaders, ss) - - default: // UDPMulticast - if st.serverMulticastHandlers != nil { - for trackID := range ss.setuppedTracks { - st.serverMulticastHandlers[trackID].rtcpl.removeClient(ss) - } + if *ss.setuppedTransport == TransportUDPMulticast { + for trackID := range ss.setuppedTracks { + st.streamTracks[trackID].multicastHandler.rtcpl.removeClient(ss) } + } else { + delete(st.activeUnicastReaders, ss) } } @@ -248,7 +232,7 @@ func (st *ServerStream) WritePacketRTP(trackID int, pkt *rtp.Packet, ptsEqualsDT st.mutex.RLock() defer st.mutex.RUnlock() - track := st.ssTracks[trackID] + track := st.streamTracks[trackID] now := time.Now() if !track.firstPacketSent || ptsEqualsDTS { @@ -260,8 +244,8 @@ func (st *ServerStream) WritePacketRTP(trackID int, pkt *rtp.Packet, ptsEqualsDT track.lastSequenceNumber = pkt.Header.SequenceNumber track.lastSSRC = pkt.Header.SSRC - if track.udpRTCPSender != nil { - track.udpRTCPSender.ProcessPacketRTP(now, pkt, ptsEqualsDTS) + if track.rtcpSender != nil { + track.rtcpSender.ProcessPacketRTP(now, pkt, ptsEqualsDTS) } // send unicast @@ -270,8 +254,8 @@ func (st *ServerStream) WritePacketRTP(trackID int, pkt *rtp.Packet, ptsEqualsDT } // send multicast - if st.serverMulticastHandlers != nil { - st.serverMulticastHandlers[trackID].writePacketRTP(byts) + if track.multicastHandler != nil { + track.multicastHandler.writePacketRTP(byts) } } @@ -291,7 +275,8 @@ func (st *ServerStream) WritePacketRTCP(trackID int, pkt rtcp.Packet) { } // send multicast - if st.serverMulticastHandlers != nil { - st.serverMulticastHandlers[trackID].writePacketRTCP(byts) + track := st.streamTracks[trackID] + if track.multicastHandler != nil { + track.multicastHandler.writePacketRTCP(byts) } }