fix overriding of previously-received RTP packets that leaded to crashes

RTP packets were previously take from a buffer pool. This was messing
up the Client, since that buffer pool was used by multiple routines at
once, and was probably messing up the Server too, since packets can be
pushed to different queues and there's no guarantee that these queues
have an overall size less than ReadBufferCount.

This buffer pool is removed; this decreases performance but avoids bugs.
This commit is contained in:
aler9
2022-12-19 13:46:43 +01:00
parent b3de3cf80e
commit ffe8c87c38
9 changed files with 305 additions and 285 deletions

View File

@@ -236,7 +236,6 @@ type Client struct {
medias map[*media.Media]*clientMedia medias map[*media.Media]*clientMedia
tcpMediasByChannel map[int]*clientMedia tcpMediasByChannel map[int]*clientMedia
lastRange *headers.Range lastRange *headers.Range
rtpPacketBuffer *rtpPacketMultiBuffer // play
checkStreamTimer *time.Timer checkStreamTimer *time.Timer
checkStreamInitial bool checkStreamInitial bool
tcpLastFrameTime *int64 tcpLastFrameTime *int64
@@ -630,7 +629,6 @@ func (c *Client) playRecordStart() {
if c.state == clientStatePlay { if c.state == clientStatePlay {
c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) c.keepaliveTimer = time.NewTimer(c.keepalivePeriod)
c.rtpPacketBuffer = newRTPPacketMultiBuffer(uint64(c.ReadBufferCount))
switch *c.effectiveTransport { switch *c.effectiveTransport {
case TransportUDP: case TransportUDP:

View File

@@ -5,6 +5,7 @@ import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net" "net"
"strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
"testing" "testing"
@@ -274,10 +275,16 @@ func TestClientPlay(t *testing.T) {
err = forma.Init() err = forma.Init()
require.NoError(t, err) require.NoError(t, err)
medias := media.Medias{&media.Media{ medias := media.Medias{
&media.Media{
Type: "application", Type: "application",
Formats: []format.Format{forma}, Formats: []format.Format{forma},
}} },
&media.Media{
Type: "application",
Formats: []format.Format{forma},
},
}
medias.SetControls() medias.SetControls()
err = conn.WriteResponse(&base.Response{ err = conn.WriteResponse(&base.Response{
@@ -290,19 +297,22 @@ func TestClientPlay(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
var l1s [2]net.PacketConn
var l2s [2]net.PacketConn
var clientPorts [2]*[2]int
for i := 0; i < 2; i++ {
req, err = conn.ReadRequest() req, err = conn.ReadRequest()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Setup, req.Method) require.Equal(t, base.Setup, req.Method)
require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value/mediaID=0"), req.URL) require.Equal(t, mustParseURL(
scheme+"://"+listenIP+":8554/test/stream?param=value/mediaID="+strconv.FormatInt(int64(i), 10)), req.URL)
var inTH headers.Transport var inTH headers.Transport
err = inTH.Unmarshal(req.Header["Transport"]) err = inTH.Unmarshal(req.Header["Transport"])
require.NoError(t, err) require.NoError(t, err)
th := headers.Transport{} var th headers.Transport
var l1 net.PacketConn
var l2 net.PacketConn
switch transport { switch transport {
case "udp": case "udp":
@@ -310,15 +320,16 @@ func TestClientPlay(t *testing.T) {
th.Delivery = &v th.Delivery = &v
th.Protocol = headers.TransportProtocolUDP th.Protocol = headers.TransportProtocolUDP
th.ClientPorts = inTH.ClientPorts th.ClientPorts = inTH.ClientPorts
th.ServerPorts = &[2]int{34556, 34557} clientPorts[i] = inTH.ClientPorts
th.ServerPorts = &[2]int{34556 + i*2, 34557 + i*2}
l1, err = net.ListenPacket("udp", listenIP+":34556") l1s[i], err = net.ListenPacket("udp", listenIP+":"+strconv.FormatInt(int64(th.ServerPorts[0]), 10))
require.NoError(t, err) require.NoError(t, err)
defer l1.Close() defer l1s[i].Close()
l2, err = net.ListenPacket("udp", listenIP+":34557") l2s[i], err = net.ListenPacket("udp", listenIP+":"+strconv.FormatInt(int64(th.ServerPorts[1]), 10))
require.NoError(t, err) require.NoError(t, err)
defer l2.Close() defer l2s[i].Close()
case "multicast": case "multicast":
v := headers.TransportDeliveryMulticast v := headers.TransportDeliveryMulticast
@@ -326,13 +337,13 @@ func TestClientPlay(t *testing.T) {
th.Protocol = headers.TransportProtocolUDP th.Protocol = headers.TransportProtocolUDP
v2 := net.ParseIP("224.1.0.1") v2 := net.ParseIP("224.1.0.1")
th.Destination = &v2 th.Destination = &v2
th.Ports = &[2]int{25000, 25001} th.Ports = &[2]int{25000 + i*2, 25001 + i*2}
l1, err = net.ListenPacket("udp", "224.0.0.0:25000") l1s[i], err = net.ListenPacket("udp", "224.0.0.0:"+strconv.FormatInt(int64(th.Ports[0]), 10))
require.NoError(t, err) require.NoError(t, err)
defer l1.Close() defer l1s[i].Close()
p := ipv4.NewPacketConn(l1) p := ipv4.NewPacketConn(l1s[i])
intfs, err := net.Interfaces() intfs, err := net.Interfaces()
require.NoError(t, err) require.NoError(t, err)
@@ -342,11 +353,11 @@ func TestClientPlay(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
l2, err = net.ListenPacket("udp", "224.0.0.0:25001") l2s[i], err = net.ListenPacket("udp", "224.0.0.0:25001")
require.NoError(t, err) require.NoError(t, err)
defer l2.Close() defer l2s[i].Close()
p = ipv4.NewPacketConn(l2) p = ipv4.NewPacketConn(l2s[i])
intfs, err = net.Interfaces() intfs, err = net.Interfaces()
require.NoError(t, err) require.NoError(t, err)
@@ -360,7 +371,7 @@ func TestClientPlay(t *testing.T) {
v := headers.TransportDeliveryUnicast v := headers.TransportDeliveryUnicast
th.Delivery = &v th.Delivery = &v
th.Protocol = headers.TransportProtocolTCP th.Protocol = headers.TransportProtocolTCP
th.InterleavedIDs = &[2]int{0, 1} th.InterleavedIDs = &[2]int{0 + i*2, 1 + i*2}
} }
err = conn.WriteResponse(&base.Response{ err = conn.WriteResponse(&base.Response{
@@ -370,6 +381,7 @@ func TestClientPlay(t *testing.T) {
}, },
}) })
require.NoError(t, err) require.NoError(t, err)
}
req, err = conn.ReadRequest() req, err = conn.ReadRequest()
require.NoError(t, err) require.NoError(t, err)
@@ -382,23 +394,24 @@ func TestClientPlay(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
for i := 0; i < 2; i++ {
// server -> client (RTP) // server -> client (RTP)
switch transport { switch transport {
case "udp": case "udp":
l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[0], Port: clientPorts[i][0],
}) })
case "multicast": case "multicast":
l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{
IP: net.ParseIP("224.1.0.1"), IP: net.ParseIP("224.1.0.1"),
Port: 25000, Port: 25000,
}) })
case "tcp", "tls": case "tcp", "tls":
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 0, Channel: 0 + i*2,
Payload: testRTPPacketMarshaled, Payload: testRTPPacketMarshaled,
}, make([]byte, 1024)) }, make([]byte, 1024))
require.NoError(t, err) require.NoError(t, err)
@@ -410,27 +423,28 @@ func TestClientPlay(t *testing.T) {
// skip firewall opening // skip firewall opening
if transport == "udp" { if transport == "udp" {
buf := make([]byte, 2048) buf := make([]byte, 2048)
_, _, err := l2.ReadFrom(buf) _, _, err := l2s[i].ReadFrom(buf)
require.NoError(t, err) require.NoError(t, err)
} }
buf := make([]byte, 2048) buf := make([]byte, 2048)
n, _, err := l2.ReadFrom(buf) n, _, err := l2s[i].ReadFrom(buf)
require.NoError(t, err) require.NoError(t, err)
packets, err := rtcp.Unmarshal(buf[:n]) packets, err := rtcp.Unmarshal(buf[:n])
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &testRTCPPacket, packets[0]) require.Equal(t, &testRTCPPacket, packets[0])
close(packetRecv)
case "tcp", "tls": case "tcp", "tls":
f, err := conn.ReadInterleavedFrame() f, err := conn.ReadInterleavedFrame()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, f.Channel) require.Equal(t, 1+i*2, f.Channel)
packets, err := rtcp.Unmarshal(f.Payload) packets, err := rtcp.Unmarshal(f.Payload)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &testRTCPPacket, packets[0]) require.Equal(t, &testRTCPPacket, packets[0])
close(packetRecv)
} }
}
close(packetRecv)
req, err = conn.ReadRequest() req, err = conn.ReadRequest()
require.NoError(t, err) require.NoError(t, err)
@@ -464,15 +478,27 @@ func TestClientPlay(t *testing.T) {
}(), }(),
} }
err = readAll(&c, u, err := url.Parse(scheme + "://" + listenIP + ":8554/test/stream?param=value")
scheme+"://"+listenIP+":8554/test/stream?param=value", require.NoError(t, err)
func(medi *media.Media, forma format.Format, pkt *rtp.Packet) {
err = c.Start(u.Scheme, u.Host)
require.NoError(t, err)
defer c.Close()
medias, baseURL, _, err := c.Describe(u)
require.NoError(t, err)
err = c.SetupAll(medias, baseURL)
require.NoError(t, err)
c.OnPacketRTPAny(func(medi *media.Media, forma format.Format, pkt *rtp.Packet) {
require.Equal(t, &testRTPPacket, pkt) require.Equal(t, &testRTPPacket, pkt)
err := c.WritePacketRTCP(medi, &testRTCPPacket) err := c.WritePacketRTCP(medi, &testRTCPPacket)
require.NoError(t, err) require.NoError(t, err)
}) })
_, err = c.Play(nil)
require.NoError(t, err) require.NoError(t, err)
defer c.Close()
<-packetRecv <-packetRecv
}) })

View File

@@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/pion/rtcp" "github.com/pion/rtcp"
"github.com/pion/rtp"
"github.com/aler9/gortsplib/v2/pkg/base" "github.com/aler9/gortsplib/v2/pkg/base"
"github.com/aler9/gortsplib/v2/pkg/media" "github.com/aler9/gortsplib/v2/pkg/media"
@@ -187,7 +188,7 @@ func (cm *clientMedia) readRTPTCPPlay(payload []byte) error {
now := time.Now() now := time.Now()
atomic.StoreInt64(cm.c.tcpLastFrameTime, now.Unix()) atomic.StoreInt64(cm.c.tcpLastFrameTime, now.Unix())
pkt := cm.c.rtpPacketBuffer.next() pkt := &rtp.Packet{}
err := pkt.Unmarshal(payload) err := pkt.Unmarshal(payload)
if err != nil { if err != nil {
return err return err
@@ -259,7 +260,7 @@ func (cm *clientMedia) readRTPUDPPlay(payload []byte) error {
return nil return nil
} }
pkt := cm.c.rtpPacketBuffer.next() pkt := &rtp.Packet{}
err := pkt.Unmarshal(payload) err := pkt.Unmarshal(payload)
if err != nil { if err != nil {
cm.c.OnDecodeError(err) cm.c.OnDecodeError(err)

View File

@@ -1,24 +0,0 @@
package gortsplib
import (
"github.com/pion/rtp"
)
type rtpPacketMultiBuffer struct {
count uint64
buffers []rtp.Packet
cur uint64
}
func newRTPPacketMultiBuffer(count uint64) *rtpPacketMultiBuffer {
return &rtpPacketMultiBuffer{
count: count,
buffers: make([]rtp.Packet, count),
}
}
func (mb *rtpPacketMultiBuffer) next() *rtp.Packet {
ret := &mb.buffers[mb.cur%mb.count]
mb.cur++
return ret
}

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"net" "net"
"strconv"
"testing" "testing"
"time" "time"
@@ -499,18 +500,24 @@ func TestServerRecord(t *testing.T) {
// send RTCP packets directly to the session. // send RTCP packets directly to the session.
// these are sent after the response, only if onRecord returns StatusOK. // these are sent after the response, only if onRecord returns StatusOK.
ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[0], &testRTCPPacket) ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[0], &testRTCPPacket)
ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[1], &testRTCPPacket)
ctx.Session.OnPacketRTPAny(func(medi *media.Media, forma format.Format, pkt *rtp.Packet) { for i := 0; i < 2; i++ {
require.Equal(t, ctx.Session.AnnouncedMedias()[0], medi) ctx.Session.OnPacketRTP(
require.Equal(t, ctx.Session.AnnouncedMedias()[0].Formats[0], forma) ctx.Session.AnnouncedMedias()[i],
ctx.Session.AnnouncedMedias()[i].Formats[0],
func(pkt *rtp.Packet) {
require.Equal(t, &testRTPPacket, pkt) require.Equal(t, &testRTPPacket, pkt)
}) })
ctx.Session.OnPacketRTCPAny(func(medi *media.Media, pkt rtcp.Packet) { ci := i
require.Equal(t, ctx.Session.AnnouncedMedias()[0], medi) ctx.Session.OnPacketRTCP(
ctx.Session.AnnouncedMedias()[i],
func(pkt rtcp.Packet) {
require.Equal(t, &testRTCPPacket, pkt) require.Equal(t, &testRTCPPacket, pkt)
ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[0], &testRTCPPacket) ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[ci], &testRTCPPacket)
}) })
}
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
@@ -549,7 +556,7 @@ func TestServerRecord(t *testing.T) {
<-nconnOpened <-nconnOpened
medias := media.Medias{testH264Media.Clone()} medias := media.Medias{testH264Media.Clone(), testH264Media.Clone()}
medias.SetControls() medias.SetControls()
res, err := writeReqReadRes(conn, base.Request{ res, err := writeReqReadRes(conn, base.Request{
@@ -566,6 +573,12 @@ func TestServerRecord(t *testing.T) {
<-sessionOpened <-sessionOpened
var l1s [2]net.PacketConn
var l2s [2]net.PacketConn
var sx headers.Session
var serverPorts [2]*[2]int
for i := 0; i < 2; i++ {
inTH := &headers.Transport{ inTH := &headers.Transport{
Delivery: func() *headers.TransportDelivery { Delivery: func() *headers.TransportDelivery {
v := headers.TransportDeliveryUnicast v := headers.TransportDeliveryUnicast
@@ -577,28 +590,25 @@ func TestServerRecord(t *testing.T) {
}(), }(),
} }
var l1 net.PacketConn
var l2 net.PacketConn
if transport == "udp" { if transport == "udp" {
inTH.Protocol = headers.TransportProtocolUDP inTH.Protocol = headers.TransportProtocolUDP
inTH.ClientPorts = &[2]int{35466, 35467} inTH.ClientPorts = &[2]int{35466 + i*2, 35467 + i*2}
l1, err = net.ListenPacket("udp", "localhost:35466") l1s[i], err = net.ListenPacket("udp", "localhost:"+strconv.FormatInt(int64(inTH.ClientPorts[0]), 10))
require.NoError(t, err) require.NoError(t, err)
defer l1.Close() defer l1s[i].Close()
l2, err = net.ListenPacket("udp", "localhost:35467") l2s[i], err = net.ListenPacket("udp", "localhost:"+strconv.FormatInt(int64(inTH.ClientPorts[1]), 10))
require.NoError(t, err) require.NoError(t, err)
defer l2.Close() defer l2s[i].Close()
} else { } else {
inTH.Protocol = headers.TransportProtocolTCP inTH.Protocol = headers.TransportProtocolTCP
inTH.InterleavedIDs = &[2]int{0, 1} inTH.InterleavedIDs = &[2]int{2 + i*2, 3 + i*2}
} }
res, err = writeReqReadRes(conn, base.Request{ res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup, Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/mediaID=0"), URL: mustParseURL("rtsp://localhost:8554/teststream/mediaID=" + strconv.FormatInt(int64(i), 10)),
Header: base.Header{ Header: base.Header{
"CSeq": base.HeaderValue{"2"}, "CSeq": base.HeaderValue{"2"},
"Transport": inTH.Marshal(), "Transport": inTH.Marshal(),
@@ -607,7 +617,6 @@ func TestServerRecord(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
var sx headers.Session
err = sx.Unmarshal(res.Header["Session"]) err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err) require.NoError(t, err)
@@ -615,6 +624,11 @@ func TestServerRecord(t *testing.T) {
err = th.Unmarshal(res.Header["Transport"]) err = th.Unmarshal(res.Header["Transport"])
require.NoError(t, err) require.NoError(t, err)
if transport == "udp" {
serverPorts[i] = th.ServerPorts
}
}
res, err = writeReqReadRes(conn, base.Request{ res, err = writeReqReadRes(conn, base.Request{
Method: base.Record, Method: base.Record,
URL: mustParseURL("rtsp://localhost:8554/teststream"), URL: mustParseURL("rtsp://localhost:8554/teststream"),
@@ -626,63 +640,67 @@ func TestServerRecord(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
for i := 0; i < 2; i++ {
// server -> client (direct) // server -> client (direct)
if transport == "udp" { if transport == "udp" {
buf := make([]byte, 2048) buf := make([]byte, 2048)
n, _, err := l2.ReadFrom(buf) n, _, err := l2s[i].ReadFrom(buf)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, testRTCPPacketMarshaled, buf[:n]) require.Equal(t, testRTCPPacketMarshaled, buf[:n])
} else { } else {
f, err := conn.ReadInterleavedFrame() f, err := conn.ReadInterleavedFrame()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, f.Channel) require.Equal(t, 3+i*2, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload) require.Equal(t, testRTCPPacketMarshaled, f.Payload)
} }
// skip firewall opening // skip firewall opening
if transport == "udp" { if transport == "udp" {
buf := make([]byte, 2048) buf := make([]byte, 2048)
_, _, err := l2.ReadFrom(buf) _, _, err := l2s[i].ReadFrom(buf)
require.NoError(t, err) require.NoError(t, err)
} }
// client -> server // client -> server
if transport == "udp" { if transport == "udp" {
l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: th.ServerPorts[0], Port: serverPorts[i][0],
}) })
l2.WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{ l2s[i].WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: th.ServerPorts[1], Port: serverPorts[i][1],
}) })
} else { } else {
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 0, Channel: 2 + i*2,
Payload: testRTPPacketMarshaled, Payload: testRTPPacketMarshaled,
}, make([]byte, 1024)) }, make([]byte, 1024))
require.NoError(t, err) require.NoError(t, err)
err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ err = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 1, Channel: 3 + i*2,
Payload: testRTCPPacketMarshaled, Payload: testRTCPPacketMarshaled,
}, make([]byte, 1024)) }, make([]byte, 1024))
require.NoError(t, err) require.NoError(t, err)
} }
}
for i := 0; i < 2; i++ {
// server -> client (RTCP) // server -> client (RTCP)
if transport == "udp" { if transport == "udp" {
buf := make([]byte, 2048) buf := make([]byte, 2048)
n, _, err := l2.ReadFrom(buf) n, _, err := l2s[i].ReadFrom(buf)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, testRTCPPacketMarshaled, buf[:n]) require.Equal(t, testRTCPPacketMarshaled, buf[:n])
} else { } else {
f, err := conn.ReadInterleavedFrame() f, err := conn.ReadInterleavedFrame()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, f.Channel) require.Equal(t, 3+i*2, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload) require.Equal(t, testRTCPPacketMarshaled, f.Payload)
} }
}
res, err = writeReqReadRes(conn, base.Request{ res, err = writeReqReadRes(conn, base.Request{
Method: base.Teardown, Method: base.Teardown,

View File

@@ -173,7 +173,6 @@ type ServerSession struct {
udpLastPacketTime *int64 // publish udpLastPacketTime *int64 // publish
udpCheckStreamTimer *time.Timer udpCheckStreamTimer *time.Timer
writer writer writer writer
rtpPacketBuffer *rtpPacketMultiBuffer
// in // in
request chan sessionRequestReq request chan sessionRequestReq
@@ -948,8 +947,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.state = ServerSessionStateRecord ss.state = ServerSessionStateRecord
ss.rtpPacketBuffer = newRTPPacketMultiBuffer(uint64(ss.s.ReadBufferCount))
for _, sm := range ss.setuppedMedias { for _, sm := range ss.setuppedMedias {
sm.start() sm.start()
} }

View File

@@ -185,7 +185,7 @@ func (sm *serverSessionMedia) readRTPUDPRecord(payload []byte) error {
return nil return nil
} }
pkt := sm.ss.rtpPacketBuffer.next() pkt := &rtp.Packet{}
err := pkt.Unmarshal(payload) err := pkt.Unmarshal(payload)
if err != nil { if err != nil {
onDecodeError(sm.ss, err) onDecodeError(sm.ss, err)
@@ -265,7 +265,7 @@ func (sm *serverSessionMedia) readRTCPTCPPlay(payload []byte) error {
} }
func (sm *serverSessionMedia) readRTPTCPRecord(payload []byte) error { func (sm *serverSessionMedia) readRTPTCPRecord(payload []byte) error {
pkt := sm.ss.rtpPacketBuffer.next() pkt := &rtp.Packet{}
err := pkt.Unmarshal(payload) err := pkt.Unmarshal(payload)
if err != nil { if err != nil {
return err return err

View File

@@ -292,11 +292,6 @@ func (st *ServerStream) WritePacketRTPWithNTP(medi *media.Media, pkt *rtp.Packet
// WritePacketRTCP writes a RTCP packet to all the readers of the stream. // WritePacketRTCP writes a RTCP packet to all the readers of the stream.
func (st *ServerStream) WritePacketRTCP(medi *media.Media, pkt rtcp.Packet) { func (st *ServerStream) WritePacketRTCP(medi *media.Media, pkt rtcp.Packet) {
byts, err := pkt.Marshal()
if err != nil {
return
}
st.mutex.RLock() st.mutex.RLock()
defer st.mutex.RUnlock() defer st.mutex.RUnlock()
@@ -305,17 +300,5 @@ func (st *ServerStream) WritePacketRTCP(medi *media.Media, pkt rtcp.Packet) {
} }
sm := st.streamMedias[medi] sm := st.streamMedias[medi]
sm.writePacketRTCP(st, pkt)
// send unicast
for r := range st.activeUnicastReaders {
sm, ok := r.setuppedMedias[medi]
if ok {
sm.writePacketRTCP(byts)
}
}
// send multicast
if sm.multicastHandler != nil {
sm.multicastHandler.writePacketRTCP(byts)
}
} }

View File

@@ -3,6 +3,7 @@ package gortsplib
import ( import (
"time" "time"
"github.com/pion/rtcp"
"github.com/pion/rtp" "github.com/pion/rtp"
"github.com/aler9/gortsplib/v2/pkg/media" "github.com/aler9/gortsplib/v2/pkg/media"
@@ -69,3 +70,23 @@ func (sm *serverStreamMedia) WritePacketRTPWithNTP(ss *ServerStream, pkt *rtp.Pa
sm.multicastHandler.writePacketRTP(byts) sm.multicastHandler.writePacketRTP(byts)
} }
} }
func (sm *serverStreamMedia) writePacketRTCP(ss *ServerStream, pkt rtcp.Packet) {
byts, err := pkt.Marshal()
if err != nil {
return
}
// send unicast
for r := range ss.activeUnicastReaders {
sm, ok := r.setuppedMedias[sm.media]
if ok {
sm.writePacketRTCP(byts)
}
}
// send multicast
if sm.multicastHandler != nil {
sm.multicastHandler.writePacketRTCP(byts)
}
}