diff --git a/server_read_test.go b/server_read_test.go index 48fd7ef3..221ea336 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -280,6 +280,75 @@ func TestServerReadSetupErrors(t *testing.T) { } } +func TestServerReadSetupErrorSameUDPPorts(t *testing.T) { + track := &TrackH264{ + PayloadType: 96, + SPS: []byte{0x01, 0x02, 0x03, 0x04}, + PPS: []byte{0x01, 0x02, 0x03, 0x04}, + } + + stream := NewServerStream(Tracks{track}) + defer stream.Close() + + s := &Server{ + Handler: &testServerHandler{ + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + }, + onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + UDPRTPAddress: "127.0.0.1:8000", + UDPRTCPAddress: "127.0.0.1:8001", + RTSPAddress: "localhost:8554", + } + + err := s.Start() + require.NoError(t, err) + defer s.Close() + + for i := 0; i < 2; i++ { + nconn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer nconn.Close() + conn := conn.NewConn(nconn) + + inTH := &headers.Transport{ + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModePlay + return &v + }(), + Protocol: headers.TransportProtocolUDP, + ClientPorts: &[2]int{35466, 35467}, + } + + res, err := writeReqReadRes(conn, base.Request{ + Method: base.Setup, + URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Transport": inTH.Marshal(), + }, + }) + require.NoError(t, err) + + if i == 0 { + require.Equal(t, base.StatusOK, res.StatusCode) + } else { + require.Equal(t, base.StatusBadRequest, res.StatusCode) + } + } +} + func TestServerRead(t *testing.T) { for _, transport := range []string{ "udp", @@ -1871,113 +1940,3 @@ func TestServerReadAdditionalInfos(t *testing.T) { }(), }, ssrcs) } - -func TestServerReadErrorUDPSamePorts(t *testing.T) { - track := &TrackH264{ - PayloadType: 96, - SPS: []byte{0x01, 0x02, 0x03, 0x04}, - PPS: []byte{0x01, 0x02, 0x03, 0x04}, - } - - stream := NewServerStream(Tracks{track}) - defer stream.Close() - - s := &Server{ - Handler: &testServerHandler{ - onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, stream, nil - }, - onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - }, - }, - UDPRTPAddress: "127.0.0.1:8000", - UDPRTCPAddress: "127.0.0.1:8001", - RTSPAddress: "localhost:8554", - } - - err := s.Start() - require.NoError(t, err) - defer s.Close() - - func() { - nconn, err := net.Dial("tcp", "localhost:8554") - require.NoError(t, err) - defer nconn.Close() - conn := conn.NewConn(nconn) - - inTH := &headers.Transport{ - Delivery: func() *headers.TransportDelivery { - v := headers.TransportDeliveryUnicast - return &v - }(), - Mode: func() *headers.TransportMode { - v := headers.TransportModePlay - return &v - }(), - Protocol: headers.TransportProtocolUDP, - ClientPorts: &[2]int{35466, 35467}, - } - - res, err := writeReqReadRes(conn, base.Request{ - Method: base.Setup, - URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), - Header: base.Header{ - "CSeq": base.HeaderValue{"1"}, - "Transport": inTH.Marshal(), - }, - }) - require.NoError(t, err) - require.Equal(t, base.StatusOK, res.StatusCode) - - var sx headers.Session - err = sx.Unmarshal(res.Header["Session"]) - require.NoError(t, err) - - res, err = writeReqReadRes(conn, base.Request{ - Method: base.Play, - URL: mustParseURL("rtsp://localhost:8554/teststream"), - Header: base.Header{ - "CSeq": base.HeaderValue{"2"}, - "Session": base.HeaderValue{sx.Session}, - }, - }) - require.NoError(t, err) - require.Equal(t, base.StatusOK, res.StatusCode) - }() - - func() { - nconn, err := net.Dial("tcp", "localhost:8554") - require.NoError(t, err) - defer nconn.Close() - conn := conn.NewConn(nconn) - - inTH := &headers.Transport{ - Delivery: func() *headers.TransportDelivery { - v := headers.TransportDeliveryUnicast - return &v - }(), - Mode: func() *headers.TransportMode { - v := headers.TransportModePlay - return &v - }(), - Protocol: headers.TransportProtocolUDP, - ClientPorts: &[2]int{35466, 35467}, - } - - res, err := writeReqReadRes(conn, base.Request{ - Method: base.Setup, - URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), - Header: base.Header{ - "CSeq": base.HeaderValue{"1"}, - "Transport": inTH.Marshal(), - }, - }) - require.NoError(t, err) - require.Equal(t, base.StatusBadRequest, res.StatusCode) - }() -} diff --git a/serverstream.go b/serverstream.go index 119cb766..50745ae7 100644 --- a/serverstream.go +++ b/serverstream.go @@ -21,7 +21,7 @@ type serverStreamTrack struct { udpRTCPSender *rtcpsender.RTCPSender } -// ServerStream represents a single stream. +// ServerStream represents a data stream. // This is in charge of // - distributing the stream to each reader // - allocating multicast listeners @@ -31,10 +31,10 @@ type ServerStream struct { mutex sync.RWMutex s *Server - readersUnicast map[*ServerSession]struct{} + activeUnicastReaders map[*ServerSession]struct{} readers map[*ServerSession]struct{} serverMulticastHandlers []*serverMulticastHandler - stTracks []*serverStreamTrack + ssTracks []*serverStreamTrack } // NewServerStream allocates a ServerStream. @@ -43,14 +43,14 @@ func NewServerStream(tracks Tracks) *ServerStream { tracks.setControls() st := &ServerStream{ - tracks: tracks, - readersUnicast: make(map[*ServerSession]struct{}), - readers: make(map[*ServerSession]struct{}), + tracks: tracks, + activeUnicastReaders: make(map[*ServerSession]struct{}), + readers: make(map[*ServerSession]struct{}), } - st.stTracks = make([]*serverStreamTrack, len(tracks)) - for i := range st.stTracks { - st.stTracks[i] = &serverStreamTrack{} + st.ssTracks = make([]*serverStreamTrack, len(tracks)) + for i := range st.ssTracks { + st.ssTracks[i] = &serverStreamTrack{} } return st @@ -73,7 +73,7 @@ func (st *ServerStream) Close() error { } st.readers = nil - st.readersUnicast = nil + st.activeUnicastReaders = nil return nil } @@ -86,14 +86,14 @@ func (st *ServerStream) Tracks() Tracks { func (st *ServerStream) ssrc(trackID int) uint32 { st.mutex.Lock() defer st.mutex.Unlock() - return st.stTracks[trackID].lastSSRC + return st.ssTracks[trackID].lastSSRC } func (st *ServerStream) rtpInfo(trackID int, now time.Time) (uint16, uint32, bool) { st.mutex.Lock() defer st.mutex.Unlock() - track := st.stTracks[trackID] + track := st.ssTracks[trackID] if !track.firstPacketSent { return 0, 0, false @@ -125,31 +125,10 @@ func (st *ServerStream) readerAdd( return fmt.Errorf("stream is closed") } - if st.s == nil { - st.s = ss.s - - for trackID, track := range st.stTracks { - cTrackID := trackID - - // always generate RTCP sender reports. - // they're mandatory needed when transport protocol is UDP or UDP-multicast. - // they're also needed when transport protocol is TCP and client is Nvidia Deepstream - // since they're used to compute NTP timestamp of frames: - // https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_NTP_Timestamp.html - track.udpRTCPSender = rtcpsender.New( - st.s.udpSenderReportPeriod, - st.tracks[trackID].ClockRate(), - func(pkt rtcp.Packet) { - st.WritePacketRTCP(cTrackID, pkt) - }, - ) - } - } - switch transport { case TransportUDP: - // check if client ports are already in use by another reader. - for r := range st.readersUnicast { + // check if client ports are already in use by another reader + for r := range st.readers { if *r.setuppedTransport == TransportUDP && r.author.ip().Equal(ss.author.ip()) && r.author.zone() == ss.author.zone() { @@ -183,6 +162,27 @@ func (st *ServerStream) readerAdd( } } + if st.s == nil { + st.s = ss.s + + for trackID, track := range st.ssTracks { + cTrackID := trackID + + // always generate RTCP sender reports. + // they're mandatory when transport protocol is UDP or UDP-multicast. + // they're also needed when transport protocol is TCP and client is Nvidia Deepstream + // since they're used to compute NTP timestamp of frames: + // https://docs.nvidia.com/metropolis/deepstream/dev-guide/text/DS_NTP_Timestamp.html + track.udpRTCPSender = rtcpsender.New( + st.s.udpSenderReportPeriod, + st.tracks[trackID].ClockRate(), + func(pkt rtcp.Packet) { + st.WritePacketRTCP(cTrackID, pkt) + }, + ) + } + } + st.readers[ss] = struct{}{} return nil @@ -209,7 +209,7 @@ func (st *ServerStream) readerSetActive(ss *ServerSession) { switch *ss.setuppedTransport { case TransportUDP, TransportTCP: - st.readersUnicast[ss] = struct{}{} + st.activeUnicastReaders[ss] = struct{}{} default: // UDPMulticast for trackID, track := range ss.setuppedTracks { @@ -225,7 +225,7 @@ func (st *ServerStream) readerSetInactive(ss *ServerSession) { switch *ss.setuppedTransport { case TransportUDP, TransportTCP: - delete(st.readersUnicast, ss) + delete(st.activeUnicastReaders, ss) default: // UDPMulticast if st.serverMulticastHandlers != nil { @@ -248,7 +248,7 @@ func (st *ServerStream) WritePacketRTP(trackID int, pkt *rtp.Packet, ptsEqualsDT st.mutex.RLock() defer st.mutex.RUnlock() - track := st.stTracks[trackID] + track := st.ssTracks[trackID] now := time.Now() if !track.firstPacketSent || ptsEqualsDTS { @@ -265,7 +265,7 @@ func (st *ServerStream) WritePacketRTP(trackID int, pkt *rtp.Packet, ptsEqualsDT } // send unicast - for r := range st.readersUnicast { + for r := range st.activeUnicastReaders { r.writePacketRTP(trackID, byts) } @@ -286,7 +286,7 @@ func (st *ServerStream) WritePacketRTCP(trackID int, pkt rtcp.Packet) { defer st.mutex.RUnlock() // send unicast - for r := range st.readersUnicast { + for r := range st.activeUnicastReaders { r.writePacketRTCP(trackID, byts) }