diff --git a/client.go b/client.go index 399215f7..67f9ad8a 100644 --- a/client.go +++ b/client.go @@ -236,7 +236,6 @@ type Client struct { medias map[*media.Media]*clientMedia tcpMediasByChannel map[int]*clientMedia lastRange *headers.Range - rtpPacketBuffer *rtpPacketMultiBuffer // play checkStreamTimer *time.Timer checkStreamInitial bool tcpLastFrameTime *int64 @@ -630,7 +629,6 @@ func (c *Client) playRecordStart() { if c.state == clientStatePlay { c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) - c.rtpPacketBuffer = newRTPPacketMultiBuffer(uint64(c.ReadBufferCount)) switch *c.effectiveTransport { case TransportUDP: diff --git a/client_play_test.go b/client_play_test.go index 3b4b328e..8fc0f528 100644 --- a/client_play_test.go +++ b/client_play_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "fmt" "net" + "strconv" "strings" "sync/atomic" "testing" @@ -274,10 +275,16 @@ func TestClientPlay(t *testing.T) { err = forma.Init() require.NoError(t, err) - medias := media.Medias{&media.Media{ - Type: "application", - Formats: []format.Format{forma}, - }} + medias := media.Medias{ + &media.Media{ + Type: "application", + Formats: []format.Format{forma}, + }, + &media.Media{ + Type: "application", + Formats: []format.Format{forma}, + }, + } medias.SetControls() err = conn.WriteResponse(&base.Response{ @@ -290,87 +297,92 @@ func TestClientPlay(t *testing.T) { }) require.NoError(t, err) - req, err = conn.ReadRequest() - require.NoError(t, err) - require.Equal(t, base.Setup, req.Method) - require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value/mediaID=0"), req.URL) + var l1s [2]net.PacketConn + var l2s [2]net.PacketConn + var clientPorts [2]*[2]int - var inTH headers.Transport - err = inTH.Unmarshal(req.Header["Transport"]) - require.NoError(t, err) - - th := headers.Transport{} - - var l1 net.PacketConn - var l2 net.PacketConn - - switch transport { - case "udp": - v := headers.TransportDeliveryUnicast - th.Delivery = &v - th.Protocol = headers.TransportProtocolUDP - th.ClientPorts = inTH.ClientPorts - th.ServerPorts = &[2]int{34556, 34557} - - l1, err = net.ListenPacket("udp", listenIP+":34556") + for i := 0; i < 2; i++ { + req, err = conn.ReadRequest() require.NoError(t, err) - defer l1.Close() + require.Equal(t, base.Setup, req.Method) + require.Equal(t, mustParseURL( + scheme+"://"+listenIP+":8554/test/stream?param=value/mediaID="+strconv.FormatInt(int64(i), 10)), req.URL) - l2, err = net.ListenPacket("udp", listenIP+":34557") - require.NoError(t, err) - defer l2.Close() - - case "multicast": - v := headers.TransportDeliveryMulticast - th.Delivery = &v - th.Protocol = headers.TransportProtocolUDP - v2 := net.ParseIP("224.1.0.1") - th.Destination = &v2 - th.Ports = &[2]int{25000, 25001} - - l1, err = net.ListenPacket("udp", "224.0.0.0:25000") - require.NoError(t, err) - defer l1.Close() - - p := ipv4.NewPacketConn(l1) - - intfs, err := net.Interfaces() + var inTH headers.Transport + err = inTH.Unmarshal(req.Header["Transport"]) require.NoError(t, err) - for _, intf := range intfs { - err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP("224.1.0.1")}) + var th headers.Transport + + switch transport { + case "udp": + v := headers.TransportDeliveryUnicast + th.Delivery = &v + th.Protocol = headers.TransportProtocolUDP + th.ClientPorts = inTH.ClientPorts + clientPorts[i] = inTH.ClientPorts + th.ServerPorts = &[2]int{34556 + i*2, 34557 + i*2} + + l1s[i], err = net.ListenPacket("udp", listenIP+":"+strconv.FormatInt(int64(th.ServerPorts[0]), 10)) require.NoError(t, err) + defer l1s[i].Close() + + l2s[i], err = net.ListenPacket("udp", listenIP+":"+strconv.FormatInt(int64(th.ServerPorts[1]), 10)) + require.NoError(t, err) + defer l2s[i].Close() + + case "multicast": + v := headers.TransportDeliveryMulticast + th.Delivery = &v + th.Protocol = headers.TransportProtocolUDP + v2 := net.ParseIP("224.1.0.1") + th.Destination = &v2 + th.Ports = &[2]int{25000 + i*2, 25001 + i*2} + + l1s[i], err = net.ListenPacket("udp", "224.0.0.0:"+strconv.FormatInt(int64(th.Ports[0]), 10)) + require.NoError(t, err) + defer l1s[i].Close() + + p := ipv4.NewPacketConn(l1s[i]) + + intfs, err := net.Interfaces() + require.NoError(t, err) + + for _, intf := range intfs { + err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP("224.1.0.1")}) + require.NoError(t, err) + } + + l2s[i], err = net.ListenPacket("udp", "224.0.0.0:25001") + require.NoError(t, err) + defer l2s[i].Close() + + p = ipv4.NewPacketConn(l2s[i]) + + intfs, err = net.Interfaces() + require.NoError(t, err) + + for _, intf := range intfs { + err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP("224.1.0.1")}) + require.NoError(t, err) + } + + case "tcp", "tls": + v := headers.TransportDeliveryUnicast + th.Delivery = &v + th.Protocol = headers.TransportProtocolTCP + th.InterleavedIDs = &[2]int{0 + i*2, 1 + i*2} } - l2, err = net.ListenPacket("udp", "224.0.0.0:25001") + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Transport": th.Marshal(), + }, + }) require.NoError(t, err) - defer l2.Close() - - p = ipv4.NewPacketConn(l2) - - intfs, err = net.Interfaces() - require.NoError(t, err) - - for _, intf := range intfs { - err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP("224.1.0.1")}) - require.NoError(t, err) - } - - case "tcp", "tls": - v := headers.TransportDeliveryUnicast - th.Delivery = &v - th.Protocol = headers.TransportProtocolTCP - th.InterleavedIDs = &[2]int{0, 1} } - err = conn.WriteResponse(&base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Transport": th.Marshal(), - }, - }) - require.NoError(t, err) - req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) @@ -382,56 +394,58 @@ func TestClientPlay(t *testing.T) { }) require.NoError(t, err) - // server -> client (RTP) - switch transport { - case "udp": - l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: th.ClientPorts[0], - }) + for i := 0; i < 2; i++ { + // server -> client (RTP) + switch transport { + case "udp": + l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: clientPorts[i][0], + }) - case "multicast": - l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ - IP: net.ParseIP("224.1.0.1"), - Port: 25000, - }) + case "multicast": + l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ + IP: net.ParseIP("224.1.0.1"), + Port: 25000, + }) - case "tcp", "tls": - err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ - Channel: 0, - Payload: testRTPPacketMarshaled, - }, make([]byte, 1024)) - require.NoError(t, err) - } - - // client -> server (RTCP) - switch transport { - case "udp", "multicast": - // skip firewall opening - if transport == "udp" { - buf := make([]byte, 2048) - _, _, err := l2.ReadFrom(buf) + case "tcp", "tls": + err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 0 + i*2, + Payload: testRTPPacketMarshaled, + }, make([]byte, 1024)) require.NoError(t, err) } - buf := make([]byte, 2048) - n, _, err := l2.ReadFrom(buf) - require.NoError(t, err) - packets, err := rtcp.Unmarshal(buf[:n]) - require.NoError(t, err) - require.Equal(t, &testRTCPPacket, packets[0]) - close(packetRecv) + // client -> server (RTCP) + switch transport { + case "udp", "multicast": + // skip firewall opening + if transport == "udp" { + buf := make([]byte, 2048) + _, _, err := l2s[i].ReadFrom(buf) + require.NoError(t, err) + } - case "tcp", "tls": - f, err := conn.ReadInterleavedFrame() - require.NoError(t, err) - require.Equal(t, 1, f.Channel) - packets, err := rtcp.Unmarshal(f.Payload) - require.NoError(t, err) - require.Equal(t, &testRTCPPacket, packets[0]) - close(packetRecv) + buf := make([]byte, 2048) + n, _, err := l2s[i].ReadFrom(buf) + require.NoError(t, err) + packets, err := rtcp.Unmarshal(buf[:n]) + require.NoError(t, err) + require.Equal(t, &testRTCPPacket, packets[0]) + + case "tcp", "tls": + f, err := conn.ReadInterleavedFrame() + require.NoError(t, err) + require.Equal(t, 1+i*2, f.Channel) + packets, err := rtcp.Unmarshal(f.Payload) + require.NoError(t, err) + require.Equal(t, &testRTCPPacket, packets[0]) + } } + close(packetRecv) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) @@ -464,16 +478,28 @@ func TestClientPlay(t *testing.T) { }(), } - err = readAll(&c, - scheme+"://"+listenIP+":8554/test/stream?param=value", - func(medi *media.Media, forma format.Format, pkt *rtp.Packet) { - require.Equal(t, &testRTPPacket, pkt) - err := c.WritePacketRTCP(medi, &testRTCPPacket) - require.NoError(t, err) - }) + u, err := url.Parse(scheme + "://" + listenIP + ":8554/test/stream?param=value") + require.NoError(t, err) + + err = c.Start(u.Scheme, u.Host) require.NoError(t, err) defer c.Close() + medias, baseURL, _, err := c.Describe(u) + require.NoError(t, err) + + err = c.SetupAll(medias, baseURL) + require.NoError(t, err) + + c.OnPacketRTPAny(func(medi *media.Media, forma format.Format, pkt *rtp.Packet) { + require.Equal(t, &testRTPPacket, pkt) + err := c.WritePacketRTCP(medi, &testRTCPPacket) + require.NoError(t, err) + }) + + _, err = c.Play(nil) + require.NoError(t, err) + <-packetRecv }) } diff --git a/clientmedia.go b/clientmedia.go index d9e62d1a..5ee754eb 100644 --- a/clientmedia.go +++ b/clientmedia.go @@ -6,6 +6,7 @@ import ( "time" "github.com/pion/rtcp" + "github.com/pion/rtp" "github.com/aler9/gortsplib/v2/pkg/base" "github.com/aler9/gortsplib/v2/pkg/media" @@ -187,7 +188,7 @@ func (cm *clientMedia) readRTPTCPPlay(payload []byte) error { now := time.Now() atomic.StoreInt64(cm.c.tcpLastFrameTime, now.Unix()) - pkt := cm.c.rtpPacketBuffer.next() + pkt := &rtp.Packet{} err := pkt.Unmarshal(payload) if err != nil { return err @@ -259,7 +260,7 @@ func (cm *clientMedia) readRTPUDPPlay(payload []byte) error { return nil } - pkt := cm.c.rtpPacketBuffer.next() + pkt := &rtp.Packet{} err := pkt.Unmarshal(payload) if err != nil { cm.c.OnDecodeError(err) diff --git a/rtppacketmultibuffer.go b/rtppacketmultibuffer.go deleted file mode 100644 index a36f2233..00000000 --- a/rtppacketmultibuffer.go +++ /dev/null @@ -1,24 +0,0 @@ -package gortsplib - -import ( - "github.com/pion/rtp" -) - -type rtpPacketMultiBuffer struct { - count uint64 - buffers []rtp.Packet - cur uint64 -} - -func newRTPPacketMultiBuffer(count uint64) *rtpPacketMultiBuffer { - return &rtpPacketMultiBuffer{ - count: count, - buffers: make([]rtp.Packet, count), - } -} - -func (mb *rtpPacketMultiBuffer) next() *rtp.Packet { - ret := &mb.buffers[mb.cur%mb.count] - mb.cur++ - return ret -} diff --git a/server_record_test.go b/server_record_test.go index d9022594..9c5b7347 100644 --- a/server_record_test.go +++ b/server_record_test.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/tls" "net" + "strconv" "testing" "time" @@ -499,18 +500,24 @@ func TestServerRecord(t *testing.T) { // send RTCP packets directly to the session. // these are sent after the response, only if onRecord returns StatusOK. ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[0], &testRTCPPacket) + ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[1], &testRTCPPacket) - ctx.Session.OnPacketRTPAny(func(medi *media.Media, forma format.Format, pkt *rtp.Packet) { - require.Equal(t, ctx.Session.AnnouncedMedias()[0], medi) - require.Equal(t, ctx.Session.AnnouncedMedias()[0].Formats[0], forma) - require.Equal(t, &testRTPPacket, pkt) - }) + for i := 0; i < 2; i++ { + ctx.Session.OnPacketRTP( + ctx.Session.AnnouncedMedias()[i], + ctx.Session.AnnouncedMedias()[i].Formats[0], + func(pkt *rtp.Packet) { + require.Equal(t, &testRTPPacket, pkt) + }) - ctx.Session.OnPacketRTCPAny(func(medi *media.Media, pkt rtcp.Packet) { - require.Equal(t, ctx.Session.AnnouncedMedias()[0], medi) - require.Equal(t, &testRTCPPacket, pkt) - ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[0], &testRTCPPacket) - }) + ci := i + ctx.Session.OnPacketRTCP( + ctx.Session.AnnouncedMedias()[i], + func(pkt rtcp.Packet) { + require.Equal(t, &testRTCPPacket, pkt) + ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[ci], &testRTCPPacket) + }) + } return &base.Response{ StatusCode: base.StatusOK, @@ -549,7 +556,7 @@ func TestServerRecord(t *testing.T) { <-nconnOpened - medias := media.Medias{testH264Media.Clone()} + medias := media.Medias{testH264Media.Clone(), testH264Media.Clone()} medias.SetControls() res, err := writeReqReadRes(conn, base.Request{ @@ -566,54 +573,61 @@ func TestServerRecord(t *testing.T) { <-sessionOpened - inTH := &headers.Transport{ - Delivery: func() *headers.TransportDelivery { - v := headers.TransportDeliveryUnicast - return &v - }(), - Mode: func() *headers.TransportMode { - v := headers.TransportModeRecord - return &v - }(), - } - - var l1 net.PacketConn - var l2 net.PacketConn - - if transport == "udp" { - inTH.Protocol = headers.TransportProtocolUDP - inTH.ClientPorts = &[2]int{35466, 35467} - - l1, err = net.ListenPacket("udp", "localhost:35466") - require.NoError(t, err) - defer l1.Close() - - l2, err = net.ListenPacket("udp", "localhost:35467") - require.NoError(t, err) - defer l2.Close() - } else { - inTH.Protocol = headers.TransportProtocolTCP - inTH.InterleavedIDs = &[2]int{0, 1} - } - - res, err = writeReqReadRes(conn, base.Request{ - Method: base.Setup, - URL: mustParseURL("rtsp://localhost:8554/teststream/mediaID=0"), - Header: base.Header{ - "CSeq": base.HeaderValue{"2"}, - "Transport": inTH.Marshal(), - }, - }) - require.NoError(t, err) - require.Equal(t, base.StatusOK, res.StatusCode) - + var l1s [2]net.PacketConn + var l2s [2]net.PacketConn var sx headers.Session - err = sx.Unmarshal(res.Header["Session"]) - require.NoError(t, err) + var serverPorts [2]*[2]int - var th headers.Transport - err = th.Unmarshal(res.Header["Transport"]) - require.NoError(t, err) + for i := 0; i < 2; i++ { + inTH := &headers.Transport{ + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModeRecord + return &v + }(), + } + + if transport == "udp" { + inTH.Protocol = headers.TransportProtocolUDP + inTH.ClientPorts = &[2]int{35466 + i*2, 35467 + i*2} + + l1s[i], err = net.ListenPacket("udp", "localhost:"+strconv.FormatInt(int64(inTH.ClientPorts[0]), 10)) + require.NoError(t, err) + defer l1s[i].Close() + + l2s[i], err = net.ListenPacket("udp", "localhost:"+strconv.FormatInt(int64(inTH.ClientPorts[1]), 10)) + require.NoError(t, err) + defer l2s[i].Close() + } else { + inTH.Protocol = headers.TransportProtocolTCP + inTH.InterleavedIDs = &[2]int{2 + i*2, 3 + i*2} + } + + res, err = writeReqReadRes(conn, base.Request{ + Method: base.Setup, + URL: mustParseURL("rtsp://localhost:8554/teststream/mediaID=" + strconv.FormatInt(int64(i), 10)), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + "Transport": inTH.Marshal(), + }, + }) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + err = sx.Unmarshal(res.Header["Session"]) + require.NoError(t, err) + + var th headers.Transport + err = th.Unmarshal(res.Header["Transport"]) + require.NoError(t, err) + + if transport == "udp" { + serverPorts[i] = th.ServerPorts + } + } res, err = writeReqReadRes(conn, base.Request{ Method: base.Record, @@ -626,62 +640,66 @@ func TestServerRecord(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - // server -> client (direct) - if transport == "udp" { - buf := make([]byte, 2048) - n, _, err := l2.ReadFrom(buf) - require.NoError(t, err) - require.Equal(t, testRTCPPacketMarshaled, buf[:n]) - } else { - f, err := conn.ReadInterleavedFrame() - require.NoError(t, err) - require.Equal(t, 1, f.Channel) - require.Equal(t, testRTCPPacketMarshaled, f.Payload) + for i := 0; i < 2; i++ { + // server -> client (direct) + if transport == "udp" { + buf := make([]byte, 2048) + n, _, err := l2s[i].ReadFrom(buf) + require.NoError(t, err) + require.Equal(t, testRTCPPacketMarshaled, buf[:n]) + } else { + f, err := conn.ReadInterleavedFrame() + require.NoError(t, err) + require.Equal(t, 3+i*2, f.Channel) + require.Equal(t, testRTCPPacketMarshaled, f.Payload) + } + + // skip firewall opening + if transport == "udp" { + buf := make([]byte, 2048) + _, _, err := l2s[i].ReadFrom(buf) + require.NoError(t, err) + } + + // client -> server + if transport == "udp" { + l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: serverPorts[i][0], + }) + + l2s[i].WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: serverPorts[i][1], + }) + } else { + err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 2 + i*2, + Payload: testRTPPacketMarshaled, + }, make([]byte, 1024)) + require.NoError(t, err) + + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 3 + i*2, + Payload: testRTCPPacketMarshaled, + }, make([]byte, 1024)) + require.NoError(t, err) + } } - // skip firewall opening - if transport == "udp" { - buf := make([]byte, 2048) - _, _, err := l2.ReadFrom(buf) - require.NoError(t, err) - } - - // client -> server - if transport == "udp" { - l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: th.ServerPorts[0], - }) - - l2.WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: th.ServerPorts[1], - }) - } else { - err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ - Channel: 0, - Payload: testRTPPacketMarshaled, - }, make([]byte, 1024)) - require.NoError(t, err) - - err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ - Channel: 1, - Payload: testRTCPPacketMarshaled, - }, make([]byte, 1024)) - require.NoError(t, err) - } - - // server -> client (RTCP) - if transport == "udp" { - buf := make([]byte, 2048) - n, _, err := l2.ReadFrom(buf) - require.NoError(t, err) - require.Equal(t, testRTCPPacketMarshaled, buf[:n]) - } else { - f, err := conn.ReadInterleavedFrame() - require.NoError(t, err) - require.Equal(t, 1, f.Channel) - require.Equal(t, testRTCPPacketMarshaled, f.Payload) + for i := 0; i < 2; i++ { + // server -> client (RTCP) + if transport == "udp" { + buf := make([]byte, 2048) + n, _, err := l2s[i].ReadFrom(buf) + require.NoError(t, err) + require.Equal(t, testRTCPPacketMarshaled, buf[:n]) + } else { + f, err := conn.ReadInterleavedFrame() + require.NoError(t, err) + require.Equal(t, 3+i*2, f.Channel) + require.Equal(t, testRTCPPacketMarshaled, f.Payload) + } } res, err = writeReqReadRes(conn, base.Request{ diff --git a/serversession.go b/serversession.go index d6fc31d8..e844bbce 100644 --- a/serversession.go +++ b/serversession.go @@ -173,7 +173,6 @@ type ServerSession struct { udpLastPacketTime *int64 // publish udpCheckStreamTimer *time.Timer writer writer - rtpPacketBuffer *rtpPacketMultiBuffer // in request chan sessionRequestReq @@ -948,8 +947,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.state = ServerSessionStateRecord - ss.rtpPacketBuffer = newRTPPacketMultiBuffer(uint64(ss.s.ReadBufferCount)) - for _, sm := range ss.setuppedMedias { sm.start() } diff --git a/serversessionmedia.go b/serversessionmedia.go index ed7916bc..e44ce85b 100644 --- a/serversessionmedia.go +++ b/serversessionmedia.go @@ -185,7 +185,7 @@ func (sm *serverSessionMedia) readRTPUDPRecord(payload []byte) error { return nil } - pkt := sm.ss.rtpPacketBuffer.next() + pkt := &rtp.Packet{} err := pkt.Unmarshal(payload) if err != nil { onDecodeError(sm.ss, err) @@ -265,7 +265,7 @@ func (sm *serverSessionMedia) readRTCPTCPPlay(payload []byte) error { } func (sm *serverSessionMedia) readRTPTCPRecord(payload []byte) error { - pkt := sm.ss.rtpPacketBuffer.next() + pkt := &rtp.Packet{} err := pkt.Unmarshal(payload) if err != nil { return err diff --git a/serverstream.go b/serverstream.go index bc351d56..30ab9938 100644 --- a/serverstream.go +++ b/serverstream.go @@ -292,11 +292,6 @@ func (st *ServerStream) WritePacketRTPWithNTP(medi *media.Media, pkt *rtp.Packet // WritePacketRTCP writes a RTCP packet to all the readers of the stream. func (st *ServerStream) WritePacketRTCP(medi *media.Media, pkt rtcp.Packet) { - byts, err := pkt.Marshal() - if err != nil { - return - } - st.mutex.RLock() defer st.mutex.RUnlock() @@ -305,17 +300,5 @@ func (st *ServerStream) WritePacketRTCP(medi *media.Media, pkt rtcp.Packet) { } sm := st.streamMedias[medi] - - // send unicast - for r := range st.activeUnicastReaders { - sm, ok := r.setuppedMedias[medi] - if ok { - sm.writePacketRTCP(byts) - } - } - - // send multicast - if sm.multicastHandler != nil { - sm.multicastHandler.writePacketRTCP(byts) - } + sm.writePacketRTCP(st, pkt) } diff --git a/serverstreammedia.go b/serverstreammedia.go index 8832b5ea..04330210 100644 --- a/serverstreammedia.go +++ b/serverstreammedia.go @@ -3,6 +3,7 @@ package gortsplib import ( "time" + "github.com/pion/rtcp" "github.com/pion/rtp" "github.com/aler9/gortsplib/v2/pkg/media" @@ -69,3 +70,23 @@ func (sm *serverStreamMedia) WritePacketRTPWithNTP(ss *ServerStream, pkt *rtp.Pa sm.multicastHandler.writePacketRTP(byts) } } + +func (sm *serverStreamMedia) writePacketRTCP(ss *ServerStream, pkt rtcp.Packet) { + byts, err := pkt.Marshal() + if err != nil { + return + } + + // send unicast + for r := range ss.activeUnicastReaders { + sm, ok := r.setuppedMedias[sm.media] + if ok { + sm.writePacketRTCP(byts) + } + } + + // send multicast + if sm.multicastHandler != nil { + sm.multicastHandler.writePacketRTCP(byts) + } +}