diff --git a/client_publish_test.go b/client_publish_test.go index cca6281e..507b8e0f 100644 --- a/client_publish_test.go +++ b/client_publish_test.go @@ -22,6 +22,7 @@ var testRTPPacket = rtp.Packet{ Version: 2, PayloadType: 97, CSRC: []uint32{}, + SSRC: 0x38F27A2F, }, Payload: []byte{0x01, 0x02, 0x03, 0x04}, } diff --git a/examples/server-tls/main.go b/examples/server-tls/main.go index c28ad1fc..488815ec 100644 --- a/examples/server-tls/main.go +++ b/examples/server-tls/main.go @@ -135,17 +135,6 @@ func (sh *serverHandler) OnPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) } } -// called after receiving a RTCP packet. -func (sh *serverHandler) OnPacketRTCP(ctx *gortsplib.ServerHandlerOnPacketRTCPCtx) { - sh.mutex.Lock() - defer sh.mutex.Unlock() - - // if we are the publisher, route the RTCP packet to readers - if ctx.Session == sh.publisher { - sh.stream.WritePacketRTCP(ctx.TrackID, ctx.Packet) - } -} - func main() { // setup certificates - they can be generated with // openssl genrsa -out server.key 2048 diff --git a/examples/server/main.go b/examples/server/main.go index 8caaca54..952b0e51 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -134,17 +134,6 @@ func (sh *serverHandler) OnPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) } } -// called after receiving a RTCP packet. -func (sh *serverHandler) OnPacketRTCP(ctx *gortsplib.ServerHandlerOnPacketRTCPCtx) { - sh.mutex.Lock() - defer sh.mutex.Unlock() - - // if we are the publisher, route the RTCP packet to readers - if ctx.Session == sh.publisher { - sh.stream.WritePacketRTCP(ctx.TrackID, ctx.Packet) - } -} - func main() { // configure server s := &gortsplib.Server{ diff --git a/server.go b/server.go index 32bdf71e..f06cf750 100644 --- a/server.go +++ b/server.go @@ -139,6 +139,7 @@ type Server struct { // udpReceiverReportPeriod time.Duration + udpSenderReportPeriod time.Duration sessionTimeout time.Duration checkStreamPeriod time.Duration @@ -192,6 +193,9 @@ func (s *Server) Start() error { if s.udpReceiverReportPeriod == 0 { s.udpReceiverReportPeriod = 10 * time.Second } + if s.udpSenderReportPeriod == 0 { + s.udpSenderReportPeriod = 10 * time.Second + } if s.sessionTimeout == 0 { s.sessionTimeout = 1 * 60 * time.Second } diff --git a/server_read_test.go b/server_read_test.go index 01ad1069..f80d872a 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/pion/rtcp" "github.com/pion/rtp/v2" psdp "github.com/pion/sdp/v3" "github.com/stretchr/testify/require" @@ -500,13 +501,6 @@ func TestServerRead(t *testing.T) { } } - // skip firewall opening - if transport == "udp" { - buf := make([]byte, 2048) - _, _, err := l2.ReadFrom(buf) - require.NoError(t, err) - } - // server -> client (through stream) if transport == "udp" || transport == "multicast" { buf := make([]byte, 2048) @@ -610,6 +604,116 @@ func TestServerRead(t *testing.T) { } } +func TestServerReadRTCPReport(t *testing.T) { + track, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil) + require.NoError(t, err) + + stream := NewServerStream(Tracks{track}) + defer stream.Close() + + s := &Server{ + Handler: &testServerHandler{ + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, stream, nil + }, + onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + udpSenderReportPeriod: 1 * time.Second, + RTSPAddress: "localhost:8554", + UDPRTPAddress: "127.0.0.1:8000", + UDPRTCPAddress: "127.0.0.1:8001", + } + + err = s.Start() + require.NoError(t, err) + defer s.Close() + + conn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer conn.Close() + br := bufio.NewReader(conn) + + inTH := &headers.Transport{ + Mode: func() *headers.TransportMode { + v := headers.TransportModePlay + return &v + }(), + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + Protocol: headers.TransportProtocolUDP, + ClientPorts: &[2]int{35466, 35467}, + } + + res, err := writeReqReadRes(conn, br, base.Request{ + Method: base.Setup, + URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Transport": inTH.Write(), + }, + }) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + l1, err := net.ListenPacket("udp", "localhost:35466") + require.NoError(t, err) + defer l1.Close() + + l2, err := net.ListenPacket("udp", "localhost:35467") + require.NoError(t, err) + defer l2.Close() + + var sx headers.Session + err = sx.Read(res.Header["Session"]) + require.NoError(t, err) + + res, err = writeReqReadRes(conn, br, base.Request{ + Method: base.Play, + URL: mustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + "Session": base.HeaderValue{sx.Session}, + }, + }) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + stream.WritePacketRTP(0, &testRTPPacket) + stream.WritePacketRTP(0, &testRTPPacket) + + buf := make([]byte, 2048) + n, _, err := l2.ReadFrom(buf) + require.NoError(t, err) + packets, err := rtcp.Unmarshal(buf[:n]) + require.NoError(t, err) + require.Equal(t, &rtcp.SenderReport{ + SSRC: 0x38F27A2F, + NTPTime: packets[0].(*rtcp.SenderReport).NTPTime, + RTPTime: packets[0].(*rtcp.SenderReport).RTPTime, + PacketCount: 2, + OctetCount: 8, + }, packets[0]) + + res, err = writeReqReadRes(conn, br, base.Request{ + Method: base.Teardown, + URL: mustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"3"}, + "Session": base.HeaderValue{sx.Session}, + }, + }) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) +} + func TestServerReadVLCMulticast(t *testing.T) { track, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil) require.NoError(t, err) diff --git a/server_test.go b/server_test.go index 9e83f1ba..9b0e9dfa 100644 --- a/server_test.go +++ b/server_test.go @@ -414,14 +414,6 @@ func TestServerHighLevelPublishRead(t *testing.T) { stream.WritePacketRTP(ctx.TrackID, ctx.Packet) } }, - onPacketRTCP: func(ctx *ServerHandlerOnPacketRTCPCtx) { - mutex.Lock() - defer mutex.Unlock() - - if ctx.Session == publisher { - stream.WritePacketRTCP(ctx.TrackID, ctx.Packet) - } - }, }, RTSPAddress: "localhost:8554", } diff --git a/serversession.go b/serversession.go index 92436ae1..12298454 100644 --- a/serversession.go +++ b/serversession.go @@ -863,8 +863,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base // readers can send RTCP packets only sc.s.udpRTCPListener.addClient(ss.author.ip(), track.udpRTCPPort, ss, trackID, false) - // open the firewall by sending packets to the counterpart - ss.WritePacketRTCP(trackID, &rtcp.ReceiverReport{}) + // firewall opening is performed by RTCP sender reports generated by ServerStream } case TransportUDPMulticast: diff --git a/serverstream.go b/serverstream.go index 99bc5a73..12067f1c 100644 --- a/serverstream.go +++ b/serverstream.go @@ -9,13 +9,15 @@ import ( "github.com/pion/rtp/v2" "github.com/aler9/gortsplib/pkg/liberrors" + "github.com/aler9/gortsplib/pkg/rtcpsender" ) -type trackInfo struct { +type serverStreamTrack struct { lastSequenceNumber uint32 lastTimeRTP uint32 lastTimeNTP int64 lastSSRC uint32 + rtcpSender *rtcpsender.RTCPSender } // ServerStream represents a single stream. @@ -24,14 +26,14 @@ type trackInfo struct { // - allocating multicast listeners // - gathering infos about the stream to generate SSRC and RTP-Info type ServerStream struct { - s *Server tracks Tracks mutex sync.RWMutex + s *Server readersUnicast map[*ServerSession]struct{} readers map[*ServerSession]struct{} serverMulticastHandlers []*serverMulticastHandler - trackInfos []*trackInfo + stTracks []*serverStreamTrack } // NewServerStream allocates a ServerStream. @@ -45,9 +47,9 @@ func NewServerStream(tracks Tracks) *ServerStream { readers: make(map[*ServerSession]struct{}), } - st.trackInfos = make([]*trackInfo, len(tracks)) - for i := range st.trackInfos { - st.trackInfos[i] = &trackInfo{} + st.stTracks = make([]*serverStreamTrack, len(tracks)) + for i := range st.stTracks { + st.stTracks[i] = &serverStreamTrack{} } return st @@ -81,12 +83,12 @@ func (st *ServerStream) Tracks() Tracks { } func (st *ServerStream) ssrc(trackID int) uint32 { - return atomic.LoadUint32(&st.trackInfos[trackID].lastSSRC) + return atomic.LoadUint32(&st.stTracks[trackID].lastSSRC) } func (st *ServerStream) timestamp(trackID int) uint32 { - lastTimeRTP := atomic.LoadUint32(&st.trackInfos[trackID].lastTimeRTP) - lastTimeNTP := atomic.LoadInt64(&st.trackInfos[trackID].lastTimeNTP) + lastTimeRTP := atomic.LoadUint32(&st.stTracks[trackID].lastTimeRTP) + lastTimeNTP := atomic.LoadInt64(&st.stTracks[trackID].lastTimeNTP) if lastTimeRTP == 0 || lastTimeNTP == 0 { return 0 @@ -97,7 +99,7 @@ func (st *ServerStream) timestamp(trackID int) uint32 { } func (st *ServerStream) lastSequenceNumber(trackID int) uint16 { - return uint16(atomic.LoadUint32(&st.trackInfos[trackID].lastSequenceNumber)) + return uint16(atomic.LoadUint32(&st.stTracks[trackID].lastSequenceNumber)) } func (st *ServerStream) readerAdd( @@ -110,6 +112,17 @@ func (st *ServerStream) readerAdd( if st.s == nil { st.s = ss.s + + for trackID, track := range st.stTracks { + cTrackID := trackID + track.rtcpSender = rtcpsender.New( + st.s.udpSenderReportPeriod, + st.tracks[trackID].ClockRate(), + func(pkt rtcp.Packet) { + st.writePacketRTCPSenderReport(cTrackID, pkt) + }, + ) + } } switch transport { @@ -209,17 +222,22 @@ func (st *ServerStream) WritePacketRTP(trackID int, pkt *rtp.Packet) { return } - track := st.trackInfos[trackID] + track := st.stTracks[trackID] + now := time.Now() atomic.StoreUint32(&track.lastSequenceNumber, uint32(pkt.Header.SequenceNumber)) atomic.StoreUint32(&track.lastTimeRTP, pkt.Header.Timestamp) - atomic.StoreInt64(&track.lastTimeNTP, time.Now().Unix()) + atomic.StoreInt64(&track.lastTimeNTP, now.Unix()) atomic.StoreUint32(&track.lastSSRC, pkt.Header.SSRC) st.mutex.RLock() defer st.mutex.RUnlock() + if track.rtcpSender != nil { + track.rtcpSender.ProcessPacketRTP(now, pkt) + } + // send unicast for r := range st.readersUnicast { r.writePacketRTP(trackID, byts) @@ -251,3 +269,25 @@ func (st *ServerStream) WritePacketRTCP(trackID int, pkt rtcp.Packet) { st.serverMulticastHandlers[trackID].writePacketRTCP(byts) } } + +func (st *ServerStream) writePacketRTCPSenderReport(trackID int, pkt rtcp.Packet) { + byts, err := pkt.Marshal() + if err != nil { + return + } + + st.mutex.RLock() + defer st.mutex.RUnlock() + + // send unicast (UDP only) + for r := range st.readersUnicast { + if *r.setuppedTransport == TransportUDP { + r.writePacketRTCP(trackID, byts) + } + } + + // send multicast + if st.serverMulticastHandlers != nil { + st.serverMulticastHandlers[trackID].writePacketRTCP(byts) + } +}