From ffe8c87c38a5834d30efab24853b1fdd0410d738 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Mon, 19 Dec 2022 13:46:43 +0100 Subject: [PATCH] fix overriding of previously-received RTP packets that leaded to crashes RTP packets were previously take from a buffer pool. This was messing up the Client, since that buffer pool was used by multiple routines at once, and was probably messing up the Server too, since packets can be pushed to different queues and there's no guarantee that these queues have an overall size less than ReadBufferCount. This buffer pool is removed; this decreases performance but avoids bugs. --- client.go | 2 - client_play_test.go | 272 ++++++++++++++++++++++------------------ clientmedia.go | 5 +- rtppacketmultibuffer.go | 24 ---- server_record_test.go | 240 +++++++++++++++++++---------------- serversession.go | 3 - serversessionmedia.go | 4 +- serverstream.go | 19 +-- serverstreammedia.go | 21 ++++ 9 files changed, 305 insertions(+), 285 deletions(-) delete mode 100644 rtppacketmultibuffer.go diff --git a/client.go b/client.go index 399215f7..67f9ad8a 100644 --- a/client.go +++ b/client.go @@ -236,7 +236,6 @@ type Client struct { medias map[*media.Media]*clientMedia tcpMediasByChannel map[int]*clientMedia lastRange *headers.Range - rtpPacketBuffer *rtpPacketMultiBuffer // play checkStreamTimer *time.Timer checkStreamInitial bool tcpLastFrameTime *int64 @@ -630,7 +629,6 @@ func (c *Client) playRecordStart() { if c.state == clientStatePlay { c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) - c.rtpPacketBuffer = newRTPPacketMultiBuffer(uint64(c.ReadBufferCount)) switch *c.effectiveTransport { case TransportUDP: diff --git a/client_play_test.go b/client_play_test.go index 3b4b328e..8fc0f528 100644 --- a/client_play_test.go +++ b/client_play_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "fmt" "net" + "strconv" "strings" "sync/atomic" "testing" @@ -274,10 +275,16 @@ func TestClientPlay(t *testing.T) { err = forma.Init() require.NoError(t, err) - medias := media.Medias{&media.Media{ - Type: "application", - Formats: []format.Format{forma}, - }} + medias := media.Medias{ + &media.Media{ + Type: "application", + Formats: []format.Format{forma}, + }, + &media.Media{ + Type: "application", + Formats: []format.Format{forma}, + }, + } medias.SetControls() err = conn.WriteResponse(&base.Response{ @@ -290,87 +297,92 @@ func TestClientPlay(t *testing.T) { }) require.NoError(t, err) - req, err = conn.ReadRequest() - require.NoError(t, err) - require.Equal(t, base.Setup, req.Method) - require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value/mediaID=0"), req.URL) + var l1s [2]net.PacketConn + var l2s [2]net.PacketConn + var clientPorts [2]*[2]int - var inTH headers.Transport - err = inTH.Unmarshal(req.Header["Transport"]) - require.NoError(t, err) - - th := headers.Transport{} - - var l1 net.PacketConn - var l2 net.PacketConn - - switch transport { - case "udp": - v := headers.TransportDeliveryUnicast - th.Delivery = &v - th.Protocol = headers.TransportProtocolUDP - th.ClientPorts = inTH.ClientPorts - th.ServerPorts = &[2]int{34556, 34557} - - l1, err = net.ListenPacket("udp", listenIP+":34556") + for i := 0; i < 2; i++ { + req, err = conn.ReadRequest() require.NoError(t, err) - defer l1.Close() + require.Equal(t, base.Setup, req.Method) + require.Equal(t, mustParseURL( + scheme+"://"+listenIP+":8554/test/stream?param=value/mediaID="+strconv.FormatInt(int64(i), 10)), req.URL) - l2, err = net.ListenPacket("udp", listenIP+":34557") - require.NoError(t, err) - defer l2.Close() - - case "multicast": - v := headers.TransportDeliveryMulticast - th.Delivery = &v - th.Protocol = headers.TransportProtocolUDP - v2 := net.ParseIP("224.1.0.1") - th.Destination = &v2 - th.Ports = &[2]int{25000, 25001} - - l1, err = net.ListenPacket("udp", "224.0.0.0:25000") - require.NoError(t, err) - defer l1.Close() - - p := ipv4.NewPacketConn(l1) - - intfs, err := net.Interfaces() + var inTH headers.Transport + err = inTH.Unmarshal(req.Header["Transport"]) require.NoError(t, err) - for _, intf := range intfs { - err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP("224.1.0.1")}) + var th headers.Transport + + switch transport { + case "udp": + v := headers.TransportDeliveryUnicast + th.Delivery = &v + th.Protocol = headers.TransportProtocolUDP + th.ClientPorts = inTH.ClientPorts + clientPorts[i] = inTH.ClientPorts + th.ServerPorts = &[2]int{34556 + i*2, 34557 + i*2} + + l1s[i], err = net.ListenPacket("udp", listenIP+":"+strconv.FormatInt(int64(th.ServerPorts[0]), 10)) require.NoError(t, err) + defer l1s[i].Close() + + l2s[i], err = net.ListenPacket("udp", listenIP+":"+strconv.FormatInt(int64(th.ServerPorts[1]), 10)) + require.NoError(t, err) + defer l2s[i].Close() + + case "multicast": + v := headers.TransportDeliveryMulticast + th.Delivery = &v + th.Protocol = headers.TransportProtocolUDP + v2 := net.ParseIP("224.1.0.1") + th.Destination = &v2 + th.Ports = &[2]int{25000 + i*2, 25001 + i*2} + + l1s[i], err = net.ListenPacket("udp", "224.0.0.0:"+strconv.FormatInt(int64(th.Ports[0]), 10)) + require.NoError(t, err) + defer l1s[i].Close() + + p := ipv4.NewPacketConn(l1s[i]) + + intfs, err := net.Interfaces() + require.NoError(t, err) + + for _, intf := range intfs { + err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP("224.1.0.1")}) + require.NoError(t, err) + } + + l2s[i], err = net.ListenPacket("udp", "224.0.0.0:25001") + require.NoError(t, err) + defer l2s[i].Close() + + p = ipv4.NewPacketConn(l2s[i]) + + intfs, err = net.Interfaces() + require.NoError(t, err) + + for _, intf := range intfs { + err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP("224.1.0.1")}) + require.NoError(t, err) + } + + case "tcp", "tls": + v := headers.TransportDeliveryUnicast + th.Delivery = &v + th.Protocol = headers.TransportProtocolTCP + th.InterleavedIDs = &[2]int{0 + i*2, 1 + i*2} } - l2, err = net.ListenPacket("udp", "224.0.0.0:25001") + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Transport": th.Marshal(), + }, + }) require.NoError(t, err) - defer l2.Close() - - p = ipv4.NewPacketConn(l2) - - intfs, err = net.Interfaces() - require.NoError(t, err) - - for _, intf := range intfs { - err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP("224.1.0.1")}) - require.NoError(t, err) - } - - case "tcp", "tls": - v := headers.TransportDeliveryUnicast - th.Delivery = &v - th.Protocol = headers.TransportProtocolTCP - th.InterleavedIDs = &[2]int{0, 1} } - err = conn.WriteResponse(&base.Response{ - StatusCode: base.StatusOK, - Header: base.Header{ - "Transport": th.Marshal(), - }, - }) - require.NoError(t, err) - req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Play, req.Method) @@ -382,56 +394,58 @@ func TestClientPlay(t *testing.T) { }) require.NoError(t, err) - // server -> client (RTP) - switch transport { - case "udp": - l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: th.ClientPorts[0], - }) + for i := 0; i < 2; i++ { + // server -> client (RTP) + switch transport { + case "udp": + l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: clientPorts[i][0], + }) - case "multicast": - l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ - IP: net.ParseIP("224.1.0.1"), - Port: 25000, - }) + case "multicast": + l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ + IP: net.ParseIP("224.1.0.1"), + Port: 25000, + }) - case "tcp", "tls": - err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ - Channel: 0, - Payload: testRTPPacketMarshaled, - }, make([]byte, 1024)) - require.NoError(t, err) - } - - // client -> server (RTCP) - switch transport { - case "udp", "multicast": - // skip firewall opening - if transport == "udp" { - buf := make([]byte, 2048) - _, _, err := l2.ReadFrom(buf) + case "tcp", "tls": + err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 0 + i*2, + Payload: testRTPPacketMarshaled, + }, make([]byte, 1024)) require.NoError(t, err) } - buf := make([]byte, 2048) - n, _, err := l2.ReadFrom(buf) - require.NoError(t, err) - packets, err := rtcp.Unmarshal(buf[:n]) - require.NoError(t, err) - require.Equal(t, &testRTCPPacket, packets[0]) - close(packetRecv) + // client -> server (RTCP) + switch transport { + case "udp", "multicast": + // skip firewall opening + if transport == "udp" { + buf := make([]byte, 2048) + _, _, err := l2s[i].ReadFrom(buf) + require.NoError(t, err) + } - case "tcp", "tls": - f, err := conn.ReadInterleavedFrame() - require.NoError(t, err) - require.Equal(t, 1, f.Channel) - packets, err := rtcp.Unmarshal(f.Payload) - require.NoError(t, err) - require.Equal(t, &testRTCPPacket, packets[0]) - close(packetRecv) + buf := make([]byte, 2048) + n, _, err := l2s[i].ReadFrom(buf) + require.NoError(t, err) + packets, err := rtcp.Unmarshal(buf[:n]) + require.NoError(t, err) + require.Equal(t, &testRTCPPacket, packets[0]) + + case "tcp", "tls": + f, err := conn.ReadInterleavedFrame() + require.NoError(t, err) + require.Equal(t, 1+i*2, f.Channel) + packets, err := rtcp.Unmarshal(f.Payload) + require.NoError(t, err) + require.Equal(t, &testRTCPPacket, packets[0]) + } } + close(packetRecv) + req, err = conn.ReadRequest() require.NoError(t, err) require.Equal(t, base.Teardown, req.Method) @@ -464,16 +478,28 @@ func TestClientPlay(t *testing.T) { }(), } - err = readAll(&c, - scheme+"://"+listenIP+":8554/test/stream?param=value", - func(medi *media.Media, forma format.Format, pkt *rtp.Packet) { - require.Equal(t, &testRTPPacket, pkt) - err := c.WritePacketRTCP(medi, &testRTCPPacket) - require.NoError(t, err) - }) + u, err := url.Parse(scheme + "://" + listenIP + ":8554/test/stream?param=value") + require.NoError(t, err) + + err = c.Start(u.Scheme, u.Host) require.NoError(t, err) defer c.Close() + medias, baseURL, _, err := c.Describe(u) + require.NoError(t, err) + + err = c.SetupAll(medias, baseURL) + require.NoError(t, err) + + c.OnPacketRTPAny(func(medi *media.Media, forma format.Format, pkt *rtp.Packet) { + require.Equal(t, &testRTPPacket, pkt) + err := c.WritePacketRTCP(medi, &testRTCPPacket) + require.NoError(t, err) + }) + + _, err = c.Play(nil) + require.NoError(t, err) + <-packetRecv }) } diff --git a/clientmedia.go b/clientmedia.go index d9e62d1a..5ee754eb 100644 --- a/clientmedia.go +++ b/clientmedia.go @@ -6,6 +6,7 @@ import ( "time" "github.com/pion/rtcp" + "github.com/pion/rtp" "github.com/aler9/gortsplib/v2/pkg/base" "github.com/aler9/gortsplib/v2/pkg/media" @@ -187,7 +188,7 @@ func (cm *clientMedia) readRTPTCPPlay(payload []byte) error { now := time.Now() atomic.StoreInt64(cm.c.tcpLastFrameTime, now.Unix()) - pkt := cm.c.rtpPacketBuffer.next() + pkt := &rtp.Packet{} err := pkt.Unmarshal(payload) if err != nil { return err @@ -259,7 +260,7 @@ func (cm *clientMedia) readRTPUDPPlay(payload []byte) error { return nil } - pkt := cm.c.rtpPacketBuffer.next() + pkt := &rtp.Packet{} err := pkt.Unmarshal(payload) if err != nil { cm.c.OnDecodeError(err) diff --git a/rtppacketmultibuffer.go b/rtppacketmultibuffer.go deleted file mode 100644 index a36f2233..00000000 --- a/rtppacketmultibuffer.go +++ /dev/null @@ -1,24 +0,0 @@ -package gortsplib - -import ( - "github.com/pion/rtp" -) - -type rtpPacketMultiBuffer struct { - count uint64 - buffers []rtp.Packet - cur uint64 -} - -func newRTPPacketMultiBuffer(count uint64) *rtpPacketMultiBuffer { - return &rtpPacketMultiBuffer{ - count: count, - buffers: make([]rtp.Packet, count), - } -} - -func (mb *rtpPacketMultiBuffer) next() *rtp.Packet { - ret := &mb.buffers[mb.cur%mb.count] - mb.cur++ - return ret -} diff --git a/server_record_test.go b/server_record_test.go index d9022594..9c5b7347 100644 --- a/server_record_test.go +++ b/server_record_test.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/tls" "net" + "strconv" "testing" "time" @@ -499,18 +500,24 @@ func TestServerRecord(t *testing.T) { // send RTCP packets directly to the session. // these are sent after the response, only if onRecord returns StatusOK. ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[0], &testRTCPPacket) + ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[1], &testRTCPPacket) - ctx.Session.OnPacketRTPAny(func(medi *media.Media, forma format.Format, pkt *rtp.Packet) { - require.Equal(t, ctx.Session.AnnouncedMedias()[0], medi) - require.Equal(t, ctx.Session.AnnouncedMedias()[0].Formats[0], forma) - require.Equal(t, &testRTPPacket, pkt) - }) + for i := 0; i < 2; i++ { + ctx.Session.OnPacketRTP( + ctx.Session.AnnouncedMedias()[i], + ctx.Session.AnnouncedMedias()[i].Formats[0], + func(pkt *rtp.Packet) { + require.Equal(t, &testRTPPacket, pkt) + }) - ctx.Session.OnPacketRTCPAny(func(medi *media.Media, pkt rtcp.Packet) { - require.Equal(t, ctx.Session.AnnouncedMedias()[0], medi) - require.Equal(t, &testRTCPPacket, pkt) - ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[0], &testRTCPPacket) - }) + ci := i + ctx.Session.OnPacketRTCP( + ctx.Session.AnnouncedMedias()[i], + func(pkt rtcp.Packet) { + require.Equal(t, &testRTCPPacket, pkt) + ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[ci], &testRTCPPacket) + }) + } return &base.Response{ StatusCode: base.StatusOK, @@ -549,7 +556,7 @@ func TestServerRecord(t *testing.T) { <-nconnOpened - medias := media.Medias{testH264Media.Clone()} + medias := media.Medias{testH264Media.Clone(), testH264Media.Clone()} medias.SetControls() res, err := writeReqReadRes(conn, base.Request{ @@ -566,54 +573,61 @@ func TestServerRecord(t *testing.T) { <-sessionOpened - inTH := &headers.Transport{ - Delivery: func() *headers.TransportDelivery { - v := headers.TransportDeliveryUnicast - return &v - }(), - Mode: func() *headers.TransportMode { - v := headers.TransportModeRecord - return &v - }(), - } - - var l1 net.PacketConn - var l2 net.PacketConn - - if transport == "udp" { - inTH.Protocol = headers.TransportProtocolUDP - inTH.ClientPorts = &[2]int{35466, 35467} - - l1, err = net.ListenPacket("udp", "localhost:35466") - require.NoError(t, err) - defer l1.Close() - - l2, err = net.ListenPacket("udp", "localhost:35467") - require.NoError(t, err) - defer l2.Close() - } else { - inTH.Protocol = headers.TransportProtocolTCP - inTH.InterleavedIDs = &[2]int{0, 1} - } - - res, err = writeReqReadRes(conn, base.Request{ - Method: base.Setup, - URL: mustParseURL("rtsp://localhost:8554/teststream/mediaID=0"), - Header: base.Header{ - "CSeq": base.HeaderValue{"2"}, - "Transport": inTH.Marshal(), - }, - }) - require.NoError(t, err) - require.Equal(t, base.StatusOK, res.StatusCode) - + var l1s [2]net.PacketConn + var l2s [2]net.PacketConn var sx headers.Session - err = sx.Unmarshal(res.Header["Session"]) - require.NoError(t, err) + var serverPorts [2]*[2]int - var th headers.Transport - err = th.Unmarshal(res.Header["Transport"]) - require.NoError(t, err) + for i := 0; i < 2; i++ { + inTH := &headers.Transport{ + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModeRecord + return &v + }(), + } + + if transport == "udp" { + inTH.Protocol = headers.TransportProtocolUDP + inTH.ClientPorts = &[2]int{35466 + i*2, 35467 + i*2} + + l1s[i], err = net.ListenPacket("udp", "localhost:"+strconv.FormatInt(int64(inTH.ClientPorts[0]), 10)) + require.NoError(t, err) + defer l1s[i].Close() + + l2s[i], err = net.ListenPacket("udp", "localhost:"+strconv.FormatInt(int64(inTH.ClientPorts[1]), 10)) + require.NoError(t, err) + defer l2s[i].Close() + } else { + inTH.Protocol = headers.TransportProtocolTCP + inTH.InterleavedIDs = &[2]int{2 + i*2, 3 + i*2} + } + + res, err = writeReqReadRes(conn, base.Request{ + Method: base.Setup, + URL: mustParseURL("rtsp://localhost:8554/teststream/mediaID=" + strconv.FormatInt(int64(i), 10)), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + "Transport": inTH.Marshal(), + }, + }) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + err = sx.Unmarshal(res.Header["Session"]) + require.NoError(t, err) + + var th headers.Transport + err = th.Unmarshal(res.Header["Transport"]) + require.NoError(t, err) + + if transport == "udp" { + serverPorts[i] = th.ServerPorts + } + } res, err = writeReqReadRes(conn, base.Request{ Method: base.Record, @@ -626,62 +640,66 @@ func TestServerRecord(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - // server -> client (direct) - if transport == "udp" { - buf := make([]byte, 2048) - n, _, err := l2.ReadFrom(buf) - require.NoError(t, err) - require.Equal(t, testRTCPPacketMarshaled, buf[:n]) - } else { - f, err := conn.ReadInterleavedFrame() - require.NoError(t, err) - require.Equal(t, 1, f.Channel) - require.Equal(t, testRTCPPacketMarshaled, f.Payload) + for i := 0; i < 2; i++ { + // server -> client (direct) + if transport == "udp" { + buf := make([]byte, 2048) + n, _, err := l2s[i].ReadFrom(buf) + require.NoError(t, err) + require.Equal(t, testRTCPPacketMarshaled, buf[:n]) + } else { + f, err := conn.ReadInterleavedFrame() + require.NoError(t, err) + require.Equal(t, 3+i*2, f.Channel) + require.Equal(t, testRTCPPacketMarshaled, f.Payload) + } + + // skip firewall opening + if transport == "udp" { + buf := make([]byte, 2048) + _, _, err := l2s[i].ReadFrom(buf) + require.NoError(t, err) + } + + // client -> server + if transport == "udp" { + l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: serverPorts[i][0], + }) + + l2s[i].WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: serverPorts[i][1], + }) + } else { + err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 2 + i*2, + Payload: testRTPPacketMarshaled, + }, make([]byte, 1024)) + require.NoError(t, err) + + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 3 + i*2, + Payload: testRTCPPacketMarshaled, + }, make([]byte, 1024)) + require.NoError(t, err) + } } - // skip firewall opening - if transport == "udp" { - buf := make([]byte, 2048) - _, _, err := l2.ReadFrom(buf) - require.NoError(t, err) - } - - // client -> server - if transport == "udp" { - l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: th.ServerPorts[0], - }) - - l2.WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: th.ServerPorts[1], - }) - } else { - err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ - Channel: 0, - Payload: testRTPPacketMarshaled, - }, make([]byte, 1024)) - require.NoError(t, err) - - err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ - Channel: 1, - Payload: testRTCPPacketMarshaled, - }, make([]byte, 1024)) - require.NoError(t, err) - } - - // server -> client (RTCP) - if transport == "udp" { - buf := make([]byte, 2048) - n, _, err := l2.ReadFrom(buf) - require.NoError(t, err) - require.Equal(t, testRTCPPacketMarshaled, buf[:n]) - } else { - f, err := conn.ReadInterleavedFrame() - require.NoError(t, err) - require.Equal(t, 1, f.Channel) - require.Equal(t, testRTCPPacketMarshaled, f.Payload) + for i := 0; i < 2; i++ { + // server -> client (RTCP) + if transport == "udp" { + buf := make([]byte, 2048) + n, _, err := l2s[i].ReadFrom(buf) + require.NoError(t, err) + require.Equal(t, testRTCPPacketMarshaled, buf[:n]) + } else { + f, err := conn.ReadInterleavedFrame() + require.NoError(t, err) + require.Equal(t, 3+i*2, f.Channel) + require.Equal(t, testRTCPPacketMarshaled, f.Payload) + } } res, err = writeReqReadRes(conn, base.Request{ diff --git a/serversession.go b/serversession.go index d6fc31d8..e844bbce 100644 --- a/serversession.go +++ b/serversession.go @@ -173,7 +173,6 @@ type ServerSession struct { udpLastPacketTime *int64 // publish udpCheckStreamTimer *time.Timer writer writer - rtpPacketBuffer *rtpPacketMultiBuffer // in request chan sessionRequestReq @@ -948,8 +947,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.state = ServerSessionStateRecord - ss.rtpPacketBuffer = newRTPPacketMultiBuffer(uint64(ss.s.ReadBufferCount)) - for _, sm := range ss.setuppedMedias { sm.start() } diff --git a/serversessionmedia.go b/serversessionmedia.go index ed7916bc..e44ce85b 100644 --- a/serversessionmedia.go +++ b/serversessionmedia.go @@ -185,7 +185,7 @@ func (sm *serverSessionMedia) readRTPUDPRecord(payload []byte) error { return nil } - pkt := sm.ss.rtpPacketBuffer.next() + pkt := &rtp.Packet{} err := pkt.Unmarshal(payload) if err != nil { onDecodeError(sm.ss, err) @@ -265,7 +265,7 @@ func (sm *serverSessionMedia) readRTCPTCPPlay(payload []byte) error { } func (sm *serverSessionMedia) readRTPTCPRecord(payload []byte) error { - pkt := sm.ss.rtpPacketBuffer.next() + pkt := &rtp.Packet{} err := pkt.Unmarshal(payload) if err != nil { return err diff --git a/serverstream.go b/serverstream.go index bc351d56..30ab9938 100644 --- a/serverstream.go +++ b/serverstream.go @@ -292,11 +292,6 @@ func (st *ServerStream) WritePacketRTPWithNTP(medi *media.Media, pkt *rtp.Packet // WritePacketRTCP writes a RTCP packet to all the readers of the stream. func (st *ServerStream) WritePacketRTCP(medi *media.Media, pkt rtcp.Packet) { - byts, err := pkt.Marshal() - if err != nil { - return - } - st.mutex.RLock() defer st.mutex.RUnlock() @@ -305,17 +300,5 @@ func (st *ServerStream) WritePacketRTCP(medi *media.Media, pkt rtcp.Packet) { } sm := st.streamMedias[medi] - - // send unicast - for r := range st.activeUnicastReaders { - sm, ok := r.setuppedMedias[medi] - if ok { - sm.writePacketRTCP(byts) - } - } - - // send multicast - if sm.multicastHandler != nil { - sm.multicastHandler.writePacketRTCP(byts) - } + sm.writePacketRTCP(st, pkt) } diff --git a/serverstreammedia.go b/serverstreammedia.go index 8832b5ea..04330210 100644 --- a/serverstreammedia.go +++ b/serverstreammedia.go @@ -3,6 +3,7 @@ package gortsplib import ( "time" + "github.com/pion/rtcp" "github.com/pion/rtp" "github.com/aler9/gortsplib/v2/pkg/media" @@ -69,3 +70,23 @@ func (sm *serverStreamMedia) WritePacketRTPWithNTP(ss *ServerStream, pkt *rtp.Pa sm.multicastHandler.writePacketRTP(byts) } } + +func (sm *serverStreamMedia) writePacketRTCP(ss *ServerStream, pkt rtcp.Packet) { + byts, err := pkt.Marshal() + if err != nil { + return + } + + // send unicast + for r := range ss.activeUnicastReaders { + sm, ok := r.setuppedMedias[sm.media] + if ok { + sm.writePacketRTCP(byts) + } + } + + // send multicast + if sm.multicastHandler != nil { + sm.multicastHandler.writePacketRTCP(byts) + } +}