fix overriding of previously-received RTP packets that leaded to crashes

RTP packets were previously take from a buffer pool. This was messing
up the Client, since that buffer pool was used by multiple routines at
once, and was probably messing up the Server too, since packets can be
pushed to different queues and there's no guarantee that these queues
have an overall size less than ReadBufferCount.

This buffer pool is removed; this decreases performance but avoids bugs.
This commit is contained in:
aler9
2022-12-19 13:46:43 +01:00
parent b3de3cf80e
commit ffe8c87c38
9 changed files with 305 additions and 285 deletions

View File

@@ -236,7 +236,6 @@ type Client struct {
medias map[*media.Media]*clientMedia medias map[*media.Media]*clientMedia
tcpMediasByChannel map[int]*clientMedia tcpMediasByChannel map[int]*clientMedia
lastRange *headers.Range lastRange *headers.Range
rtpPacketBuffer *rtpPacketMultiBuffer // play
checkStreamTimer *time.Timer checkStreamTimer *time.Timer
checkStreamInitial bool checkStreamInitial bool
tcpLastFrameTime *int64 tcpLastFrameTime *int64
@@ -630,7 +629,6 @@ func (c *Client) playRecordStart() {
if c.state == clientStatePlay { if c.state == clientStatePlay {
c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) c.keepaliveTimer = time.NewTimer(c.keepalivePeriod)
c.rtpPacketBuffer = newRTPPacketMultiBuffer(uint64(c.ReadBufferCount))
switch *c.effectiveTransport { switch *c.effectiveTransport {
case TransportUDP: case TransportUDP:

View File

@@ -5,6 +5,7 @@ import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net" "net"
"strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
"testing" "testing"
@@ -274,10 +275,16 @@ func TestClientPlay(t *testing.T) {
err = forma.Init() err = forma.Init()
require.NoError(t, err) require.NoError(t, err)
medias := media.Medias{&media.Media{ medias := media.Medias{
Type: "application", &media.Media{
Formats: []format.Format{forma}, Type: "application",
}} Formats: []format.Format{forma},
},
&media.Media{
Type: "application",
Formats: []format.Format{forma},
},
}
medias.SetControls() medias.SetControls()
err = conn.WriteResponse(&base.Response{ err = conn.WriteResponse(&base.Response{
@@ -290,87 +297,92 @@ func TestClientPlay(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
req, err = conn.ReadRequest() var l1s [2]net.PacketConn
require.NoError(t, err) var l2s [2]net.PacketConn
require.Equal(t, base.Setup, req.Method) var clientPorts [2]*[2]int
require.Equal(t, mustParseURL(scheme+"://"+listenIP+":8554/test/stream?param=value/mediaID=0"), req.URL)
var inTH headers.Transport for i := 0; i < 2; i++ {
err = inTH.Unmarshal(req.Header["Transport"]) req, err = conn.ReadRequest()
require.NoError(t, err)
th := headers.Transport{}
var l1 net.PacketConn
var l2 net.PacketConn
switch transport {
case "udp":
v := headers.TransportDeliveryUnicast
th.Delivery = &v
th.Protocol = headers.TransportProtocolUDP
th.ClientPorts = inTH.ClientPorts
th.ServerPorts = &[2]int{34556, 34557}
l1, err = net.ListenPacket("udp", listenIP+":34556")
require.NoError(t, err) require.NoError(t, err)
defer l1.Close() require.Equal(t, base.Setup, req.Method)
require.Equal(t, mustParseURL(
scheme+"://"+listenIP+":8554/test/stream?param=value/mediaID="+strconv.FormatInt(int64(i), 10)), req.URL)
l2, err = net.ListenPacket("udp", listenIP+":34557") var inTH headers.Transport
require.NoError(t, err) err = inTH.Unmarshal(req.Header["Transport"])
defer l2.Close()
case "multicast":
v := headers.TransportDeliveryMulticast
th.Delivery = &v
th.Protocol = headers.TransportProtocolUDP
v2 := net.ParseIP("224.1.0.1")
th.Destination = &v2
th.Ports = &[2]int{25000, 25001}
l1, err = net.ListenPacket("udp", "224.0.0.0:25000")
require.NoError(t, err)
defer l1.Close()
p := ipv4.NewPacketConn(l1)
intfs, err := net.Interfaces()
require.NoError(t, err) require.NoError(t, err)
for _, intf := range intfs { var th headers.Transport
err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP("224.1.0.1")})
switch transport {
case "udp":
v := headers.TransportDeliveryUnicast
th.Delivery = &v
th.Protocol = headers.TransportProtocolUDP
th.ClientPorts = inTH.ClientPorts
clientPorts[i] = inTH.ClientPorts
th.ServerPorts = &[2]int{34556 + i*2, 34557 + i*2}
l1s[i], err = net.ListenPacket("udp", listenIP+":"+strconv.FormatInt(int64(th.ServerPorts[0]), 10))
require.NoError(t, err) require.NoError(t, err)
defer l1s[i].Close()
l2s[i], err = net.ListenPacket("udp", listenIP+":"+strconv.FormatInt(int64(th.ServerPorts[1]), 10))
require.NoError(t, err)
defer l2s[i].Close()
case "multicast":
v := headers.TransportDeliveryMulticast
th.Delivery = &v
th.Protocol = headers.TransportProtocolUDP
v2 := net.ParseIP("224.1.0.1")
th.Destination = &v2
th.Ports = &[2]int{25000 + i*2, 25001 + i*2}
l1s[i], err = net.ListenPacket("udp", "224.0.0.0:"+strconv.FormatInt(int64(th.Ports[0]), 10))
require.NoError(t, err)
defer l1s[i].Close()
p := ipv4.NewPacketConn(l1s[i])
intfs, err := net.Interfaces()
require.NoError(t, err)
for _, intf := range intfs {
err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP("224.1.0.1")})
require.NoError(t, err)
}
l2s[i], err = net.ListenPacket("udp", "224.0.0.0:25001")
require.NoError(t, err)
defer l2s[i].Close()
p = ipv4.NewPacketConn(l2s[i])
intfs, err = net.Interfaces()
require.NoError(t, err)
for _, intf := range intfs {
err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP("224.1.0.1")})
require.NoError(t, err)
}
case "tcp", "tls":
v := headers.TransportDeliveryUnicast
th.Delivery = &v
th.Protocol = headers.TransportProtocolTCP
th.InterleavedIDs = &[2]int{0 + i*2, 1 + i*2}
} }
l2, err = net.ListenPacket("udp", "224.0.0.0:25001") err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
},
})
require.NoError(t, err) require.NoError(t, err)
defer l2.Close()
p = ipv4.NewPacketConn(l2)
intfs, err = net.Interfaces()
require.NoError(t, err)
for _, intf := range intfs {
err := p.JoinGroup(&intf, &net.UDPAddr{IP: net.ParseIP("224.1.0.1")})
require.NoError(t, err)
}
case "tcp", "tls":
v := headers.TransportDeliveryUnicast
th.Delivery = &v
th.Protocol = headers.TransportProtocolTCP
th.InterleavedIDs = &[2]int{0, 1}
} }
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
},
})
require.NoError(t, err)
req, err = conn.ReadRequest() req, err = conn.ReadRequest()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Play, req.Method) require.Equal(t, base.Play, req.Method)
@@ -382,56 +394,58 @@ func TestClientPlay(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
// server -> client (RTP) for i := 0; i < 2; i++ {
switch transport { // server -> client (RTP)
case "udp": switch transport {
l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ case "udp":
IP: net.ParseIP("127.0.0.1"), l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{
Port: th.ClientPorts[0], IP: net.ParseIP("127.0.0.1"),
}) Port: clientPorts[i][0],
})
case "multicast": case "multicast":
l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{
IP: net.ParseIP("224.1.0.1"), IP: net.ParseIP("224.1.0.1"),
Port: 25000, Port: 25000,
}) })
case "tcp", "tls": case "tcp", "tls":
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 0, Channel: 0 + i*2,
Payload: testRTPPacketMarshaled, Payload: testRTPPacketMarshaled,
}, make([]byte, 1024)) }, make([]byte, 1024))
require.NoError(t, err)
}
// client -> server (RTCP)
switch transport {
case "udp", "multicast":
// skip firewall opening
if transport == "udp" {
buf := make([]byte, 2048)
_, _, err := l2.ReadFrom(buf)
require.NoError(t, err) require.NoError(t, err)
} }
buf := make([]byte, 2048) // client -> server (RTCP)
n, _, err := l2.ReadFrom(buf) switch transport {
require.NoError(t, err) case "udp", "multicast":
packets, err := rtcp.Unmarshal(buf[:n]) // skip firewall opening
require.NoError(t, err) if transport == "udp" {
require.Equal(t, &testRTCPPacket, packets[0]) buf := make([]byte, 2048)
close(packetRecv) _, _, err := l2s[i].ReadFrom(buf)
require.NoError(t, err)
}
case "tcp", "tls": buf := make([]byte, 2048)
f, err := conn.ReadInterleavedFrame() n, _, err := l2s[i].ReadFrom(buf)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, f.Channel) packets, err := rtcp.Unmarshal(buf[:n])
packets, err := rtcp.Unmarshal(f.Payload) require.NoError(t, err)
require.NoError(t, err) require.Equal(t, &testRTCPPacket, packets[0])
require.Equal(t, &testRTCPPacket, packets[0])
close(packetRecv) case "tcp", "tls":
f, err := conn.ReadInterleavedFrame()
require.NoError(t, err)
require.Equal(t, 1+i*2, f.Channel)
packets, err := rtcp.Unmarshal(f.Payload)
require.NoError(t, err)
require.Equal(t, &testRTCPPacket, packets[0])
}
} }
close(packetRecv)
req, err = conn.ReadRequest() req, err = conn.ReadRequest()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method) require.Equal(t, base.Teardown, req.Method)
@@ -464,16 +478,28 @@ func TestClientPlay(t *testing.T) {
}(), }(),
} }
err = readAll(&c, u, err := url.Parse(scheme + "://" + listenIP + ":8554/test/stream?param=value")
scheme+"://"+listenIP+":8554/test/stream?param=value", require.NoError(t, err)
func(medi *media.Media, forma format.Format, pkt *rtp.Packet) {
require.Equal(t, &testRTPPacket, pkt) err = c.Start(u.Scheme, u.Host)
err := c.WritePacketRTCP(medi, &testRTCPPacket)
require.NoError(t, err)
})
require.NoError(t, err) require.NoError(t, err)
defer c.Close() defer c.Close()
medias, baseURL, _, err := c.Describe(u)
require.NoError(t, err)
err = c.SetupAll(medias, baseURL)
require.NoError(t, err)
c.OnPacketRTPAny(func(medi *media.Media, forma format.Format, pkt *rtp.Packet) {
require.Equal(t, &testRTPPacket, pkt)
err := c.WritePacketRTCP(medi, &testRTCPPacket)
require.NoError(t, err)
})
_, err = c.Play(nil)
require.NoError(t, err)
<-packetRecv <-packetRecv
}) })
} }

View File

@@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/pion/rtcp" "github.com/pion/rtcp"
"github.com/pion/rtp"
"github.com/aler9/gortsplib/v2/pkg/base" "github.com/aler9/gortsplib/v2/pkg/base"
"github.com/aler9/gortsplib/v2/pkg/media" "github.com/aler9/gortsplib/v2/pkg/media"
@@ -187,7 +188,7 @@ func (cm *clientMedia) readRTPTCPPlay(payload []byte) error {
now := time.Now() now := time.Now()
atomic.StoreInt64(cm.c.tcpLastFrameTime, now.Unix()) atomic.StoreInt64(cm.c.tcpLastFrameTime, now.Unix())
pkt := cm.c.rtpPacketBuffer.next() pkt := &rtp.Packet{}
err := pkt.Unmarshal(payload) err := pkt.Unmarshal(payload)
if err != nil { if err != nil {
return err return err
@@ -259,7 +260,7 @@ func (cm *clientMedia) readRTPUDPPlay(payload []byte) error {
return nil return nil
} }
pkt := cm.c.rtpPacketBuffer.next() pkt := &rtp.Packet{}
err := pkt.Unmarshal(payload) err := pkt.Unmarshal(payload)
if err != nil { if err != nil {
cm.c.OnDecodeError(err) cm.c.OnDecodeError(err)

View File

@@ -1,24 +0,0 @@
package gortsplib
import (
"github.com/pion/rtp"
)
type rtpPacketMultiBuffer struct {
count uint64
buffers []rtp.Packet
cur uint64
}
func newRTPPacketMultiBuffer(count uint64) *rtpPacketMultiBuffer {
return &rtpPacketMultiBuffer{
count: count,
buffers: make([]rtp.Packet, count),
}
}
func (mb *rtpPacketMultiBuffer) next() *rtp.Packet {
ret := &mb.buffers[mb.cur%mb.count]
mb.cur++
return ret
}

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"net" "net"
"strconv"
"testing" "testing"
"time" "time"
@@ -499,18 +500,24 @@ func TestServerRecord(t *testing.T) {
// send RTCP packets directly to the session. // send RTCP packets directly to the session.
// these are sent after the response, only if onRecord returns StatusOK. // these are sent after the response, only if onRecord returns StatusOK.
ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[0], &testRTCPPacket) ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[0], &testRTCPPacket)
ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[1], &testRTCPPacket)
ctx.Session.OnPacketRTPAny(func(medi *media.Media, forma format.Format, pkt *rtp.Packet) { for i := 0; i < 2; i++ {
require.Equal(t, ctx.Session.AnnouncedMedias()[0], medi) ctx.Session.OnPacketRTP(
require.Equal(t, ctx.Session.AnnouncedMedias()[0].Formats[0], forma) ctx.Session.AnnouncedMedias()[i],
require.Equal(t, &testRTPPacket, pkt) ctx.Session.AnnouncedMedias()[i].Formats[0],
}) func(pkt *rtp.Packet) {
require.Equal(t, &testRTPPacket, pkt)
})
ctx.Session.OnPacketRTCPAny(func(medi *media.Media, pkt rtcp.Packet) { ci := i
require.Equal(t, ctx.Session.AnnouncedMedias()[0], medi) ctx.Session.OnPacketRTCP(
require.Equal(t, &testRTCPPacket, pkt) ctx.Session.AnnouncedMedias()[i],
ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[0], &testRTCPPacket) func(pkt rtcp.Packet) {
}) require.Equal(t, &testRTCPPacket, pkt)
ctx.Session.WritePacketRTCP(ctx.Session.AnnouncedMedias()[ci], &testRTCPPacket)
})
}
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
@@ -549,7 +556,7 @@ func TestServerRecord(t *testing.T) {
<-nconnOpened <-nconnOpened
medias := media.Medias{testH264Media.Clone()} medias := media.Medias{testH264Media.Clone(), testH264Media.Clone()}
medias.SetControls() medias.SetControls()
res, err := writeReqReadRes(conn, base.Request{ res, err := writeReqReadRes(conn, base.Request{
@@ -566,54 +573,61 @@ func TestServerRecord(t *testing.T) {
<-sessionOpened <-sessionOpened
inTH := &headers.Transport{ var l1s [2]net.PacketConn
Delivery: func() *headers.TransportDelivery { var l2s [2]net.PacketConn
v := headers.TransportDeliveryUnicast
return &v
}(),
Mode: func() *headers.TransportMode {
v := headers.TransportModeRecord
return &v
}(),
}
var l1 net.PacketConn
var l2 net.PacketConn
if transport == "udp" {
inTH.Protocol = headers.TransportProtocolUDP
inTH.ClientPorts = &[2]int{35466, 35467}
l1, err = net.ListenPacket("udp", "localhost:35466")
require.NoError(t, err)
defer l1.Close()
l2, err = net.ListenPacket("udp", "localhost:35467")
require.NoError(t, err)
defer l2.Close()
} else {
inTH.Protocol = headers.TransportProtocolTCP
inTH.InterleavedIDs = &[2]int{0, 1}
}
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/mediaID=0"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Transport": inTH.Marshal(),
},
})
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
var sx headers.Session var sx headers.Session
err = sx.Unmarshal(res.Header["Session"]) var serverPorts [2]*[2]int
require.NoError(t, err)
var th headers.Transport for i := 0; i < 2; i++ {
err = th.Unmarshal(res.Header["Transport"]) inTH := &headers.Transport{
require.NoError(t, err) Delivery: func() *headers.TransportDelivery {
v := headers.TransportDeliveryUnicast
return &v
}(),
Mode: func() *headers.TransportMode {
v := headers.TransportModeRecord
return &v
}(),
}
if transport == "udp" {
inTH.Protocol = headers.TransportProtocolUDP
inTH.ClientPorts = &[2]int{35466 + i*2, 35467 + i*2}
l1s[i], err = net.ListenPacket("udp", "localhost:"+strconv.FormatInt(int64(inTH.ClientPorts[0]), 10))
require.NoError(t, err)
defer l1s[i].Close()
l2s[i], err = net.ListenPacket("udp", "localhost:"+strconv.FormatInt(int64(inTH.ClientPorts[1]), 10))
require.NoError(t, err)
defer l2s[i].Close()
} else {
inTH.Protocol = headers.TransportProtocolTCP
inTH.InterleavedIDs = &[2]int{2 + i*2, 3 + i*2}
}
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/mediaID=" + strconv.FormatInt(int64(i), 10)),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Transport": inTH.Marshal(),
},
})
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
var th headers.Transport
err = th.Unmarshal(res.Header["Transport"])
require.NoError(t, err)
if transport == "udp" {
serverPorts[i] = th.ServerPorts
}
}
res, err = writeReqReadRes(conn, base.Request{ res, err = writeReqReadRes(conn, base.Request{
Method: base.Record, Method: base.Record,
@@ -626,62 +640,66 @@ func TestServerRecord(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
// server -> client (direct) for i := 0; i < 2; i++ {
if transport == "udp" { // server -> client (direct)
buf := make([]byte, 2048) if transport == "udp" {
n, _, err := l2.ReadFrom(buf) buf := make([]byte, 2048)
require.NoError(t, err) n, _, err := l2s[i].ReadFrom(buf)
require.Equal(t, testRTCPPacketMarshaled, buf[:n]) require.NoError(t, err)
} else { require.Equal(t, testRTCPPacketMarshaled, buf[:n])
f, err := conn.ReadInterleavedFrame() } else {
require.NoError(t, err) f, err := conn.ReadInterleavedFrame()
require.Equal(t, 1, f.Channel) require.NoError(t, err)
require.Equal(t, testRTCPPacketMarshaled, f.Payload) require.Equal(t, 3+i*2, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload)
}
// skip firewall opening
if transport == "udp" {
buf := make([]byte, 2048)
_, _, err := l2s[i].ReadFrom(buf)
require.NoError(t, err)
}
// client -> server
if transport == "udp" {
l1s[i].WriteTo(testRTPPacketMarshaled, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: serverPorts[i][0],
})
l2s[i].WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: serverPorts[i][1],
})
} else {
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 2 + i*2,
Payload: testRTPPacketMarshaled,
}, make([]byte, 1024))
require.NoError(t, err)
err = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 3 + i*2,
Payload: testRTCPPacketMarshaled,
}, make([]byte, 1024))
require.NoError(t, err)
}
} }
// skip firewall opening for i := 0; i < 2; i++ {
if transport == "udp" { // server -> client (RTCP)
buf := make([]byte, 2048) if transport == "udp" {
_, _, err := l2.ReadFrom(buf) buf := make([]byte, 2048)
require.NoError(t, err) n, _, err := l2s[i].ReadFrom(buf)
} require.NoError(t, err)
require.Equal(t, testRTCPPacketMarshaled, buf[:n])
// client -> server } else {
if transport == "udp" { f, err := conn.ReadInterleavedFrame()
l1.WriteTo(testRTPPacketMarshaled, &net.UDPAddr{ require.NoError(t, err)
IP: net.ParseIP("127.0.0.1"), require.Equal(t, 3+i*2, f.Channel)
Port: th.ServerPorts[0], require.Equal(t, testRTCPPacketMarshaled, f.Payload)
}) }
l2.WriteTo(testRTCPPacketMarshaled, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: th.ServerPorts[1],
})
} else {
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 0,
Payload: testRTPPacketMarshaled,
}, make([]byte, 1024))
require.NoError(t, err)
err = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 1,
Payload: testRTCPPacketMarshaled,
}, make([]byte, 1024))
require.NoError(t, err)
}
// server -> client (RTCP)
if transport == "udp" {
buf := make([]byte, 2048)
n, _, err := l2.ReadFrom(buf)
require.NoError(t, err)
require.Equal(t, testRTCPPacketMarshaled, buf[:n])
} else {
f, err := conn.ReadInterleavedFrame()
require.NoError(t, err)
require.Equal(t, 1, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload)
} }
res, err = writeReqReadRes(conn, base.Request{ res, err = writeReqReadRes(conn, base.Request{

View File

@@ -173,7 +173,6 @@ type ServerSession struct {
udpLastPacketTime *int64 // publish udpLastPacketTime *int64 // publish
udpCheckStreamTimer *time.Timer udpCheckStreamTimer *time.Timer
writer writer writer writer
rtpPacketBuffer *rtpPacketMultiBuffer
// in // in
request chan sessionRequestReq request chan sessionRequestReq
@@ -948,8 +947,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.state = ServerSessionStateRecord ss.state = ServerSessionStateRecord
ss.rtpPacketBuffer = newRTPPacketMultiBuffer(uint64(ss.s.ReadBufferCount))
for _, sm := range ss.setuppedMedias { for _, sm := range ss.setuppedMedias {
sm.start() sm.start()
} }

View File

@@ -185,7 +185,7 @@ func (sm *serverSessionMedia) readRTPUDPRecord(payload []byte) error {
return nil return nil
} }
pkt := sm.ss.rtpPacketBuffer.next() pkt := &rtp.Packet{}
err := pkt.Unmarshal(payload) err := pkt.Unmarshal(payload)
if err != nil { if err != nil {
onDecodeError(sm.ss, err) onDecodeError(sm.ss, err)
@@ -265,7 +265,7 @@ func (sm *serverSessionMedia) readRTCPTCPPlay(payload []byte) error {
} }
func (sm *serverSessionMedia) readRTPTCPRecord(payload []byte) error { func (sm *serverSessionMedia) readRTPTCPRecord(payload []byte) error {
pkt := sm.ss.rtpPacketBuffer.next() pkt := &rtp.Packet{}
err := pkt.Unmarshal(payload) err := pkt.Unmarshal(payload)
if err != nil { if err != nil {
return err return err

View File

@@ -292,11 +292,6 @@ func (st *ServerStream) WritePacketRTPWithNTP(medi *media.Media, pkt *rtp.Packet
// WritePacketRTCP writes a RTCP packet to all the readers of the stream. // WritePacketRTCP writes a RTCP packet to all the readers of the stream.
func (st *ServerStream) WritePacketRTCP(medi *media.Media, pkt rtcp.Packet) { func (st *ServerStream) WritePacketRTCP(medi *media.Media, pkt rtcp.Packet) {
byts, err := pkt.Marshal()
if err != nil {
return
}
st.mutex.RLock() st.mutex.RLock()
defer st.mutex.RUnlock() defer st.mutex.RUnlock()
@@ -305,17 +300,5 @@ func (st *ServerStream) WritePacketRTCP(medi *media.Media, pkt rtcp.Packet) {
} }
sm := st.streamMedias[medi] sm := st.streamMedias[medi]
sm.writePacketRTCP(st, pkt)
// send unicast
for r := range st.activeUnicastReaders {
sm, ok := r.setuppedMedias[medi]
if ok {
sm.writePacketRTCP(byts)
}
}
// send multicast
if sm.multicastHandler != nil {
sm.multicastHandler.writePacketRTCP(byts)
}
} }

View File

@@ -3,6 +3,7 @@ package gortsplib
import ( import (
"time" "time"
"github.com/pion/rtcp"
"github.com/pion/rtp" "github.com/pion/rtp"
"github.com/aler9/gortsplib/v2/pkg/media" "github.com/aler9/gortsplib/v2/pkg/media"
@@ -69,3 +70,23 @@ func (sm *serverStreamMedia) WritePacketRTPWithNTP(ss *ServerStream, pkt *rtp.Pa
sm.multicastHandler.writePacketRTP(byts) sm.multicastHandler.writePacketRTP(byts)
} }
} }
func (sm *serverStreamMedia) writePacketRTCP(ss *ServerStream, pkt rtcp.Packet) {
byts, err := pkt.Marshal()
if err != nil {
return
}
// send unicast
for r := range ss.activeUnicastReaders {
sm, ok := r.setuppedMedias[sm.media]
if ok {
sm.writePacketRTCP(byts)
}
}
// send multicast
if sm.multicastHandler != nil {
sm.multicastHandler.writePacketRTCP(byts)
}
}