automatically remux oversized RTP/H264 packets; drop parameter ReadBufferSize

This commit is contained in:
aler9
2022-04-09 12:11:38 +02:00
committed by Alessandro Ros
parent b1a4b52090
commit bfe4e8cdaa
21 changed files with 390 additions and 376 deletions

190
client.go
View File

@@ -28,18 +28,12 @@ import (
"github.com/aler9/gortsplib/pkg/h264"
"github.com/aler9/gortsplib/pkg/headers"
"github.com/aler9/gortsplib/pkg/liberrors"
"github.com/aler9/gortsplib/pkg/multibuffer"
"github.com/aler9/gortsplib/pkg/ringbuffer"
"github.com/aler9/gortsplib/pkg/rtcpreceiver"
"github.com/aler9/gortsplib/pkg/rtcpsender"
"github.com/aler9/gortsplib/pkg/rtph264"
)
const (
clientReadBufferSize = 4096
clientUDPKernelReadBufferSize = 0x80000 // same size as gstreamer's rtspsrc
)
func isAnyPort(p int) bool {
return p == 0 || p == 1
}
@@ -62,6 +56,7 @@ type clientTrack struct {
rtcpReceiver *rtcpreceiver.RTCPReceiver
rtcpSender *rtcpsender.RTCPSender
h264Decoder *rtph264.Decoder
h264Encoder *rtph264.Encoder
}
func (s clientState) String() string {
@@ -187,10 +182,6 @@ type Client struct {
// that is reading frames.
// It defaults to 256.
ReadBufferCount int
// read buffer size.
// This must be touched only when the server reports errors about buffer sizes.
// It defaults to 2048.
ReadBufferSize int
// write buffer count.
// It allows to queue packets before sending them.
// It defaults to 8.
@@ -291,9 +282,6 @@ func (c *Client) Start(scheme string, host string) error {
if c.ReadBufferCount == 0 {
c.ReadBufferCount = 256
}
if c.ReadBufferSize == 0 {
c.ReadBufferSize = 2048
}
if c.WriteBufferCount == 0 {
c.WriteBufferCount = 256
}
@@ -760,14 +748,12 @@ func (c *Client) runReader() {
}
}
} else {
var tcpReadBuffer *multibuffer.MultiBuffer
var processFunc func(int, bool, []byte)
var processFunc func(int, bool, []byte) error
if c.state == clientStatePlay {
tcpReadBuffer = multibuffer.New(uint64(c.ReadBufferCount), uint64(c.ReadBufferSize))
tcpRTPPacketBuffer := newRTPPacketMultiBuffer(uint64(c.ReadBufferCount))
processFunc = func(trackID int, isRTP bool, payload []byte) {
processFunc = func(trackID int, isRTP bool, payload []byte) error {
now := time.Now()
atomic.StoreInt64(c.tcpLastFrameTime, now.Unix())
@@ -775,38 +761,105 @@ func (c *Client) runReader() {
pkt := tcpRTPPacketBuffer.next()
err := pkt.Unmarshal(payload)
if err != nil {
return
return err
}
c.onPacketRTP(trackID, pkt)
ctx := ClientOnPacketRTPCtx{
TrackID: trackID,
Packet: pkt,
}
c.processPacketRTP(&ctx)
ct := c.tracks[trackID]
if ct.h264Decoder != nil {
if ct.h264Encoder == nil && len(payload) > udpReadBufferSize {
v1 := pkt.SSRC
v2 := pkt.SequenceNumber
v3 := pkt.Timestamp
ct.h264Encoder = &rtph264.Encoder{
PayloadType: pkt.PayloadType,
SSRC: &v1,
InitialSequenceNumber: &v2,
InitialTimestamp: &v3,
}
ct.h264Encoder.Init()
}
if ct.h264Encoder != nil {
if ctx.H264NALUs != nil {
packets, err := ct.h264Encoder.Encode(ctx.H264NALUs, ctx.H264PTS)
if err != nil {
return err
}
for i, pkt := range packets {
if i != len(packets)-1 {
c.OnPacketRTP(&ClientOnPacketRTPCtx{
TrackID: trackID,
Packet: pkt,
PTSEqualsDTS: false,
})
} else {
ctx.Packet = pkt
c.OnPacketRTP(&ctx)
}
}
}
} else {
c.OnPacketRTP(&ctx)
}
} else {
if len(payload) > udpReadBufferSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
len(payload), udpReadBufferSize)
}
c.OnPacketRTP(&ctx)
}
} else {
if len(payload) > udpReadBufferSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
len(payload), udpReadBufferSize)
}
packets, err := rtcp.Unmarshal(payload)
if err != nil {
return
return err
}
for _, pkt := range packets {
c.onPacketRTCP(trackID, pkt)
c.OnPacketRTCP(&ClientOnPacketRTCPCtx{
TrackID: trackID,
Packet: pkt,
})
}
}
return nil
}
} else {
// when recording, tcpReadBuffer is only used to receive RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
tcpReadBuffer = multibuffer.New(8, uint64(c.ReadBufferSize))
processFunc = func(trackID int, isRTP bool, payload []byte) {
processFunc = func(trackID int, isRTP bool, payload []byte) error {
if !isRTP {
if len(payload) > udpReadBufferSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
len(payload), udpReadBufferSize)
}
packets, err := rtcp.Unmarshal(payload)
if err != nil {
return
return err
}
for _, pkt := range packets {
c.onPacketRTCP(trackID, pkt)
c.OnPacketRTCP(&ClientOnPacketRTCPCtx{
TrackID: trackID,
Packet: pkt,
})
}
}
return nil
}
}
@@ -814,8 +867,7 @@ func (c *Client) runReader() {
var res base.Response
for {
frame.Payload = tcpReadBuffer.Next()
what, err := base.ReadInterleavedFrameOrResponse(&frame, &res, c.br)
what, err := base.ReadInterleavedFrameOrResponse(&frame, tcpMaxFramePayloadSize, &res, c.br)
if err != nil {
return err
}
@@ -833,7 +885,10 @@ func (c *Client) runReader() {
continue
}
processFunc(trackID, isRTP, frame.Payload)
err := processFunc(trackID, isRTP, frame.Payload)
if err != nil {
return err
}
}
}
}
@@ -874,6 +929,7 @@ func (c *Client) playRecordStop(isClosing bool) {
for _, ct := range c.tracks {
ct.h264Decoder = nil
ct.h264Encoder = nil
}
// stop timers
@@ -929,7 +985,7 @@ func (c *Client) connOpen() error {
return nconn
}()
c.br = bufio.NewReaderSize(c.conn, clientReadBufferSize)
c.br = bufio.NewReaderSize(c.conn, tcpReadBufferSize)
c.connCloserStart()
return nil
}
@@ -1008,11 +1064,10 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba
if allowFrames {
// read the response and ignore interleaved frames in between;
// interleaved frames are sent in two scenarios:
// interleaved frames are sent in two cases:
// * when the server is v4lrtspserver, before the PLAY response
// * when the stream is already playing
buf := make([]byte, c.ReadBufferSize)
err = res.ReadIgnoreFrames(c.br, buf)
err = res.ReadIgnoreFrames(tcpMaxFramePayloadSize, c.br)
if err != nil {
return nil, err
}
@@ -1230,7 +1285,7 @@ func (c *Client) doAnnounce(u *base.URL, tracks Tracks) (*base.Response, error)
}
// in case of ANNOUNCE, the base URL doesn't have a trailing slash.
// (tested with ffmpeg and gstreamer)
// (tested with ffmpeg and GStreamer)
baseURL := u.Clone()
tracks.setControls()
@@ -1847,62 +1902,25 @@ func (c *Client) runWriter() {
}
}
func (c *Client) onPacketRTP(trackID int, pkt *rtp.Packet) {
func (c *Client) processPacketRTP(ctx *ClientOnPacketRTPCtx) {
// remove padding
pkt.Header.Padding = false
pkt.PaddingSize = 0
ct := c.tracks[trackID]
ctx.Packet.Header.Padding = false
ctx.Packet.PaddingSize = 0
// decode
ct := c.tracks[ctx.TrackID]
if ct.h264Decoder != nil {
nalus, pts, err := ct.h264Decoder.DecodeUntilMarker(pkt)
nalus, pts, err := ct.h264Decoder.DecodeUntilMarker(ctx.Packet)
if err == nil {
ptsEqualsDTS := h264.IDRPresent(nalus)
rr := ct.rtcpReceiver
if rr != nil {
rr.ProcessPacketRTP(time.Now(), pkt, ptsEqualsDTS)
}
c.OnPacketRTP(&ClientOnPacketRTPCtx{
TrackID: trackID,
Packet: pkt,
PTSEqualsDTS: ptsEqualsDTS,
H264NALUs: append([][]byte(nil), nalus...),
H264PTS: pts,
})
ctx.PTSEqualsDTS = h264.IDRPresent(nalus)
ctx.H264NALUs = append([][]byte(nil), nalus...)
ctx.H264PTS = pts
} else {
rr := ct.rtcpReceiver
if rr != nil {
rr.ProcessPacketRTP(time.Now(), pkt, false)
ctx.PTSEqualsDTS = false
}
c.OnPacketRTP(&ClientOnPacketRTPCtx{
TrackID: trackID,
Packet: pkt,
PTSEqualsDTS: false,
})
} else {
ctx.PTSEqualsDTS = true
}
return
}
rr := ct.rtcpReceiver
if rr != nil {
rr.ProcessPacketRTP(time.Now(), pkt, true)
}
c.OnPacketRTP(&ClientOnPacketRTPCtx{
TrackID: trackID,
Packet: pkt,
PTSEqualsDTS: true,
})
}
func (c *Client) onPacketRTCP(trackID int, pkt rtcp.Packet) {
c.OnPacketRTCP(&ClientOnPacketRTCPCtx{
TrackID: trackID,
Packet: pkt,
})
}
// WritePacketRTP writes a RTP packet.

View File

@@ -182,8 +182,7 @@ func TestClientPublishSerial(t *testing.T) {
require.Equal(t, testRTPPacket, pkt)
} else {
var f base.InterleavedFrame
f.Payload = make([]byte, 2048)
err = f.Read(br)
err = f.Read(1024, br)
require.NoError(t, err)
require.Equal(t, 0, f.Channel)
var pkt rtp.Packet
@@ -823,8 +822,7 @@ func TestClientPublishAutomaticProtocol(t *testing.T) {
require.NoError(t, err)
var f base.InterleavedFrame
f.Payload = make([]byte, 2048)
err = f.Read(br)
err = f.Read(2048, br)
require.NoError(t, err)
require.Equal(t, 0, f.Channel)
var pkt rtp.Packet

View File

@@ -21,6 +21,22 @@ import (
"github.com/aler9/gortsplib/pkg/headers"
)
func mergeBytes(vals ...[]byte) []byte {
size := 0
for _, v := range vals {
size += len(v)
}
res := make([]byte, size)
pos := 0
for _, v := range vals {
n := copy(res[pos:], v)
pos += n
}
return res
}
func TestClientReadTracks(t *testing.T) {
track1, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil)
require.NoError(t, err)
@@ -359,8 +375,7 @@ func TestClientRead(t *testing.T) {
case "tcp", "tls":
var f base.InterleavedFrame
f.Payload = make([]byte, 2048)
err := f.Read(br)
err := f.Read(2048, br)
require.NoError(t, err)
require.Equal(t, 1, f.Channel)
packets, err := rtcp.Unmarshal(f.Payload)
@@ -429,16 +444,60 @@ func TestClientRead(t *testing.T) {
}
}
func TestClientReadNonStandardFrameSize(t *testing.T) {
refRTPPacket := rtp.Packet{
var oversizedPacketRTPIn = rtp.Packet{
Header: rtp.Header{
Version: 2,
PayloadType: 96,
CSRC: []uint32{},
Marker: true,
SequenceNumber: 34572,
},
Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 4096/5),
}
var oversizedPacketsRTPOut = []rtp.Packet{
{
Header: rtp.Header{
Version: 2,
PayloadType: 96,
Marker: false,
SequenceNumber: 34572,
},
Payload: mergeBytes(
[]byte{0x1c, 0x81, 0x02, 0x03, 0x04, 0x05},
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 290),
[]byte{0x01, 0x02, 0x03, 0x04},
),
},
{
Header: rtp.Header{
Version: 2,
PayloadType: 96,
Marker: false,
SequenceNumber: 34573,
},
Payload: mergeBytes(
[]byte{0x1c, 0x01, 0x05},
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 291),
[]byte{0x01, 0x02},
),
},
{
Header: rtp.Header{
Version: 2,
PayloadType: 96,
Marker: true,
SequenceNumber: 34574,
},
Payload: mergeBytes(
[]byte{0x1c, 0x41, 0x03, 0x04, 0x05},
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 235),
),
},
}
func TestClientReadOversizedPacket(t *testing.T) {
oversizedPacketsRTPOut := append([]rtp.Packet(nil), oversizedPacketsRTPOut...)
l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err)
defer l.Close()
@@ -529,7 +588,7 @@ func TestClientReadNonStandardFrameSize(t *testing.T) {
_, err = conn.Write(bb.Bytes())
require.NoError(t, err)
byts, _ := refRTPPacket.Marshal()
byts, _ := oversizedPacketRTPIn.Marshal()
base.InterleavedFrame{
Channel: 0,
Payload: byts,
@@ -541,15 +600,18 @@ func TestClientReadNonStandardFrameSize(t *testing.T) {
packetRecv := make(chan struct{})
c := &Client{
ReadBufferSize: 4500 + 4,
Transport: func() *Transport {
v := TransportTCP
return &v
}(),
OnPacketRTP: func(ctx *ClientOnPacketRTPCtx) {
require.Equal(t, 0, ctx.TrackID)
require.Equal(t, &refRTPPacket, ctx.Packet)
cmp := oversizedPacketsRTPOut[0]
oversizedPacketsRTPOut = oversizedPacketsRTPOut[1:]
require.Equal(t, &cmp, ctx.Packet)
if len(oversizedPacketsRTPOut) == 0 {
close(packetRecv)
}
},
}

View File

@@ -29,9 +29,8 @@ func readRequest(br *bufio.Reader) (*base.Request, error) {
}
func readRequestIgnoreFrames(br *bufio.Reader) (*base.Request, error) {
buf := make([]byte, 2048)
var req base.Request
err := req.ReadIgnoreFrames(br, buf)
err := req.ReadIgnoreFrames(2048, br)
return &req, err
}

View File

@@ -76,7 +76,7 @@ func newClientUDPListener(c *Client, multicast bool, address string) (*clientUDP
p := ipv4.NewPacketConn(tmp)
err = p.SetMulticastTTL(16)
err = p.SetMulticastTTL(multicastTTL)
if err != nil {
return nil, err
}
@@ -102,7 +102,7 @@ func newClientUDPListener(c *Client, multicast bool, address string) (*clientUDP
pc = tmp.(*net.UDPConn)
}
err := pc.SetReadBuffer(clientUDPKernelReadBufferSize)
err := pc.SetReadBuffer(udpKernelReadBufferSize)
if err != nil {
return nil, err
}
@@ -110,7 +110,7 @@ func newClientUDPListener(c *Client, multicast bool, address string) (*clientUDP
return &clientUDPListener{
c: c,
pc: pc,
readBuffer: multibuffer.New(uint64(c.ReadBufferCount), uint64(c.ReadBufferSize)),
readBuffer: multibuffer.New(uint64(c.ReadBufferCount), uint64(udpReadBufferSize)),
rtpPacketBuffer: newRTPPacketMultiBuffer(uint64(c.ReadBufferCount)),
lastPacketTime: func() *int64 {
v := int64(0)
@@ -182,7 +182,13 @@ func (u *clientUDPListener) processPlayRTP(now time.Time, payload []byte) {
return
}
u.c.onPacketRTP(u.trackID, pkt)
ctx := ClientOnPacketRTPCtx{
TrackID: u.trackID,
Packet: pkt,
}
u.c.processPacketRTP(&ctx)
u.c.tracks[u.trackID].rtcpReceiver.ProcessPacketRTP(time.Now(), pkt, ctx.PTSEqualsDTS)
u.c.OnPacketRTP(&ctx)
}
func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) {
@@ -193,7 +199,10 @@ func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) {
for _, pkt := range packets {
u.c.tracks[u.trackID].rtcpReceiver.ProcessPacketRTCP(now, pkt)
u.c.onPacketRTCP(u.trackID, pkt)
u.c.OnPacketRTCP(&ClientOnPacketRTCPCtx{
TrackID: u.trackID,
Packet: pkt,
})
}
}
@@ -204,7 +213,10 @@ func (u *clientUDPListener) processRecordRTCP(now time.Time, payload []byte) {
}
for _, pkt := range packets {
u.c.onPacketRTCP(u.trackID, pkt)
u.c.OnPacketRTCP(&ClientOnPacketRTCPCtx{
TrackID: u.trackID,
Packet: pkt,
})
}
}

9
constants.go Normal file
View File

@@ -0,0 +1,9 @@
package gortsplib
const (
tcpReadBufferSize = 4096
tcpMaxFramePayloadSize = 60 * 1024 * 1024
udpKernelReadBufferSize = 0x80000 // same size as GStreamer's rtspsrc
udpReadBufferSize = 1472 // 1500 (UDP MTU) - 20 (IP header) - 8 (UDP header)
multicastTTL = 16
)

View File

@@ -23,7 +23,7 @@ func main() {
}
defer pc.Close()
log.Println("Waiting for a rtp/h264 stream on port 9000 - you can send one with gstreamer:\n" +
log.Println("Waiting for a RTP/H264 stream on port 9000 - you can send one with GStreamer:\n" +
"gst-launch-1.0 filesrc location=video.mp4 ! qtdemux ! video/x-h264" +
" ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000")

View File

@@ -24,7 +24,7 @@ func main() {
}
defer pc.Close()
log.Println("Waiting for a rtp/h264 stream on port 9000 - you can send one with gstreamer:\n" +
log.Println("Waiting for a RTP/H264 stream on port 9000 - you can send one with GStreamer:\n" +
"gst-launch-1.0 filesrc location=video.mp4 ! qtdemux ! video/x-h264" +
" ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000")

View File

@@ -13,7 +13,12 @@ const (
)
// ReadInterleavedFrameOrRequest reads an InterleavedFrame or a Response.
func ReadInterleavedFrameOrRequest(frame *InterleavedFrame, req *Request, br *bufio.Reader) (interface{}, error) {
func ReadInterleavedFrameOrRequest(
frame *InterleavedFrame,
maxPayloadSize int,
req *Request,
br *bufio.Reader,
) (interface{}, error) {
b, err := br.ReadByte()
if err != nil {
return nil, err
@@ -21,7 +26,7 @@ func ReadInterleavedFrameOrRequest(frame *InterleavedFrame, req *Request, br *bu
br.UnreadByte()
if b == interleavedFrameMagicByte {
err := frame.Read(br)
err := frame.Read(maxPayloadSize, br)
if err != nil {
return nil, err
}
@@ -36,7 +41,12 @@ func ReadInterleavedFrameOrRequest(frame *InterleavedFrame, req *Request, br *bu
}
// ReadInterleavedFrameOrResponse reads an InterleavedFrame or a Response.
func ReadInterleavedFrameOrResponse(frame *InterleavedFrame, res *Response, br *bufio.Reader) (interface{}, error) {
func ReadInterleavedFrameOrResponse(
frame *InterleavedFrame,
maxPayloadSize int,
res *Response,
br *bufio.Reader,
) (interface{}, error) {
b, err := br.ReadByte()
if err != nil {
return nil, err
@@ -44,7 +54,7 @@ func ReadInterleavedFrameOrResponse(frame *InterleavedFrame, res *Response, br *
br.UnreadByte()
if b == interleavedFrameMagicByte {
err := frame.Read(br)
err := frame.Read(maxPayloadSize, br)
if err != nil {
return nil, err
}
@@ -61,15 +71,15 @@ func ReadInterleavedFrameOrResponse(frame *InterleavedFrame, res *Response, br *
// InterleavedFrame is an interleaved frame, and allows to transfer binary data
// within RTSP/TCP connections. It is used to send and receive RTP and RTCP packets with TCP.
type InterleavedFrame struct {
// channel id
// channel ID
Channel int
// frame payload
// payload
Payload []byte
}
// Read reads an interleaved frame.
func (f *InterleavedFrame) Read(br *bufio.Reader) error {
func (f *InterleavedFrame) Read(maxPayloadSize int, br *bufio.Reader) error {
var header [4]byte
_, err := io.ReadFull(br, header[:])
if err != nil {
@@ -80,14 +90,14 @@ func (f *InterleavedFrame) Read(br *bufio.Reader) error {
return fmt.Errorf("invalid magic byte (0x%.2x)", header[0])
}
framelen := int(binary.BigEndian.Uint16(header[2:]))
if framelen > len(f.Payload) {
return fmt.Errorf("payload size greater than maximum allowed (%d vs %d)",
framelen, len(f.Payload))
payloadLen := int(binary.BigEndian.Uint16(header[2:]))
if payloadLen > maxPayloadSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
payloadLen, maxPayloadSize)
}
f.Channel = int(header[1])
f.Payload = f.Payload[:framelen]
f.Payload = make([]byte, payloadLen)
_, err = io.ReadFull(br, f.Payload)
if err != nil {

View File

@@ -37,8 +37,7 @@ func TestInterleavedFrameRead(t *testing.T) {
for _, ca := range casesInterleavedFrame {
t.Run(ca.name, func(t *testing.T) {
f.Payload = make([]byte, 1024)
err := f.Read(bufio.NewReader(bytes.NewBuffer(ca.enc)))
err := f.Read(1024, bufio.NewReader(bytes.NewBuffer(ca.enc)))
require.NoError(t, err)
require.Equal(t, ca.dec, f)
})
@@ -64,7 +63,7 @@ func TestInterleavedFrameReadErrors(t *testing.T) {
{
"payload size too big",
[]byte{0x24, 0x00, 0x00, 0x08},
"payload size greater than maximum allowed (8 vs 5)",
"payload size (8) greater than maximum allowed (5)",
},
{
"payload invalid",
@@ -74,8 +73,7 @@ func TestInterleavedFrameReadErrors(t *testing.T) {
} {
t.Run(ca.name, func(t *testing.T) {
var f InterleavedFrame
f.Payload = make([]byte, 5)
err := f.Read(bufio.NewReader(bytes.NewBuffer(ca.byts)))
err := f.Read(5, bufio.NewReader(bytes.NewBuffer(ca.byts)))
require.EqualError(t, err, ca.err)
})
}
@@ -99,15 +97,14 @@ func TestReadInterleavedFrameOrRequest(t *testing.T) {
byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...)
var f InterleavedFrame
f.Payload = make([]byte, 10)
var req Request
br := bufio.NewReader(bytes.NewBuffer(byts))
out, err := ReadInterleavedFrameOrRequest(&f, &req, br)
out, err := ReadInterleavedFrameOrRequest(&f, 10, &req, br)
require.NoError(t, err)
require.Equal(t, &req, out)
out, err = ReadInterleavedFrameOrRequest(&f, &req, br)
out, err = ReadInterleavedFrameOrRequest(&f, 10, &req, br)
require.NoError(t, err)
require.Equal(t, &f, out)
}
@@ -136,11 +133,9 @@ func TestReadInterleavedFrameOrRequestErrors(t *testing.T) {
} {
t.Run(ca.name, func(t *testing.T) {
var f InterleavedFrame
f.Payload = make([]byte, 10)
var req Request
br := bufio.NewReader(bytes.NewBuffer(ca.byts))
_, err := ReadInterleavedFrameOrRequest(&f, &req, br)
_, err := ReadInterleavedFrameOrRequest(&f, 10, &req, br)
require.EqualError(t, err, ca.err)
})
}
@@ -154,15 +149,13 @@ func TestReadInterleavedFrameOrResponse(t *testing.T) {
byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...)
var f InterleavedFrame
f.Payload = make([]byte, 10)
var res Response
br := bufio.NewReader(bytes.NewBuffer(byts))
out, err := ReadInterleavedFrameOrResponse(&f, &res, br)
out, err := ReadInterleavedFrameOrResponse(&f, 10, &res, br)
require.NoError(t, err)
require.Equal(t, &res, out)
out, err = ReadInterleavedFrameOrResponse(&f, &res, br)
out, err = ReadInterleavedFrameOrResponse(&f, 10, &res, br)
require.NoError(t, err)
require.Equal(t, &f, out)
}
@@ -191,11 +184,9 @@ func TestReadInterleavedFrameOrResponseErrors(t *testing.T) {
} {
t.Run(ca.name, func(t *testing.T) {
var f InterleavedFrame
f.Payload = make([]byte, 10)
var res Response
br := bufio.NewReader(bytes.NewBuffer(ca.byts))
_, err := ReadInterleavedFrameOrResponse(&f, &res, br)
_, err := ReadInterleavedFrameOrResponse(&f, 10, &res, br)
require.EqualError(t, err, ca.err)
})
}

View File

@@ -101,15 +101,11 @@ func (req *Request) Read(rb *bufio.Reader) error {
// ReadIgnoreFrames reads a request and ignores any interleaved frame sent
// before the request.
func (req *Request) ReadIgnoreFrames(rb *bufio.Reader, buf []byte) error {
buflen := len(buf)
f := InterleavedFrame{
Payload: buf,
}
func (req *Request) ReadIgnoreFrames(maxPayloadSize int, rb *bufio.Reader) error {
var f InterleavedFrame
for {
f.Payload = f.Payload[:buflen]
recv, err := ReadInterleavedFrameOrRequest(&f, req, rb)
recv, err := ReadInterleavedFrameOrRequest(&f, maxPayloadSize, req, rb)
if err != nil {
return err
}

View File

@@ -237,9 +237,8 @@ func TestRequestReadIgnoreFrames(t *testing.T) {
"\r\n")...)
rb := bufio.NewReader(bytes.NewBuffer(byts))
buf := make([]byte, 10)
var req Request
err := req.ReadIgnoreFrames(rb, buf)
err := req.ReadIgnoreFrames(10, rb)
require.NoError(t, err)
}
@@ -247,9 +246,8 @@ func TestRequestReadIgnoreFramesErrors(t *testing.T) {
byts := []byte{0x25}
rb := bufio.NewReader(bytes.NewBuffer(byts))
buf := make([]byte, 10)
var req Request
err := req.ReadIgnoreFrames(rb, buf)
err := req.ReadIgnoreFrames(10, rb)
require.EqualError(t, err, "EOF")
}

View File

@@ -187,15 +187,11 @@ func (res *Response) Read(rb *bufio.Reader) error {
// ReadIgnoreFrames reads a response and ignores any interleaved frame sent
// before the response.
func (res *Response) ReadIgnoreFrames(rb *bufio.Reader, buf []byte) error {
buflen := len(buf)
f := InterleavedFrame{
Payload: buf,
}
func (res *Response) ReadIgnoreFrames(maxPayloadSize int, rb *bufio.Reader) error {
var f InterleavedFrame
for {
f.Payload = f.Payload[:buflen]
recv, err := ReadInterleavedFrameOrResponse(&f, res, rb)
recv, err := ReadInterleavedFrameOrResponse(&f, maxPayloadSize, res, rb)
if err != nil {
return err
}

View File

@@ -220,9 +220,8 @@ func TestResponseReadIgnoreFrames(t *testing.T) {
"\r\n")...)
rb := bufio.NewReader(bytes.NewBuffer(byts))
buf := make([]byte, 10)
var res Response
err := res.ReadIgnoreFrames(rb, buf)
err := res.ReadIgnoreFrames(10, rb)
require.NoError(t, err)
}
@@ -230,9 +229,8 @@ func TestResponseReadIgnoreFramesErrors(t *testing.T) {
byts := []byte{0x25}
rb := bufio.NewReader(bytes.NewBuffer(byts))
buf := make([]byte, 10)
var res Response
err := res.ReadIgnoreFrames(rb, buf)
err := res.ReadIgnoreFrames(10, rb)
require.EqualError(t, err, "EOF")
}

View File

@@ -15,11 +15,6 @@ import (
"github.com/aler9/gortsplib/pkg/liberrors"
)
const (
serverReadBufferSize = 4096
serverUDPKernelReadBufferSize = 0x80000 // same as gstreamer's rtspsrc
)
func extractPort(address string) (int, error) {
_, tmp, err := net.SplitHostPort(address)
if err != nil {
@@ -115,10 +110,6 @@ type Server struct {
// that are particularly relevant when using UDP.
// It defaults to 256.
ReadBufferCount int
// read buffer size.
// This must be touched only when the server reports errors about buffer sizes.
// It defaults to 2048.
ReadBufferSize int
// write buffer count.
// It allows to queue packets before sending them.
// It defaults to 256.
@@ -174,9 +165,6 @@ func (s *Server) Start() error {
if s.ReadBufferCount == 0 {
s.ReadBufferCount = 256
}
if s.ReadBufferSize == 0 {
s.ReadBufferSize = 2048
}
if s.WriteBufferCount == 0 {
s.WriteBufferCount = 256
}

View File

@@ -736,8 +736,7 @@ func TestServerPublish(t *testing.T) {
require.Equal(t, testRTCPPacketMarshaled, buf[:n])
} else {
var f base.InterleavedFrame
f.Payload = make([]byte, 2048)
err := f.Read(br)
err := f.Read(2048, br)
require.NoError(t, err)
require.Equal(t, 1, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload)
@@ -789,8 +788,7 @@ func TestServerPublish(t *testing.T) {
require.Equal(t, testRTCPPacketMarshaled, buf[:n])
} else {
var f base.InterleavedFrame
f.Payload = make([]byte, 2048)
err := f.Read(br)
err := f.Read(2048, br)
require.NoError(t, err)
require.Equal(t, 1, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload)
@@ -815,21 +813,12 @@ func TestServerPublish(t *testing.T) {
}
}
func TestServerPublishNonStandardFrameSize(t *testing.T) {
packet := rtp.Packet{
Header: rtp.Header{
Version: 2,
PayloadType: 97,
CSRC: []uint32{},
},
Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 4096/5),
}
packetMarshaled, _ := packet.Marshal()
frameReceived := make(chan struct{})
func TestServerPublishOversizedPacket(t *testing.T) {
oversizedPacketsRTPOut := append([]rtp.Packet(nil), oversizedPacketsRTPOut...)
packetRecv := make(chan struct{})
s := &Server{
RTSPAddress: "localhost:8554",
ReadBufferSize: 4500,
Handler: &testServerHandler{
onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
return &base.Response{
@@ -848,8 +837,12 @@ func TestServerPublishNonStandardFrameSize(t *testing.T) {
},
onPacketRTP: func(ctx *ServerHandlerOnPacketRTPCtx) {
require.Equal(t, 0, ctx.TrackID)
require.Equal(t, &packet, ctx.Packet)
close(frameReceived)
cmp := oversizedPacketsRTPOut[0]
oversizedPacketsRTPOut = oversizedPacketsRTPOut[1:]
require.Equal(t, &cmp, ctx.Packet)
if len(oversizedPacketsRTPOut) == 0 {
close(packetRecv)
}
},
},
}
@@ -921,14 +914,15 @@ func TestServerPublishNonStandardFrameSize(t *testing.T) {
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
byts, _ := oversizedPacketRTPIn.Marshal()
base.InterleavedFrame{
Channel: 0,
Payload: packetMarshaled,
Payload: byts,
}.Write(&bb)
_, err = conn.Write(bb.Bytes())
require.NoError(t, err)
<-frameReceived
<-packetRecv
}
func TestServerPublishErrorInvalidProtocol(t *testing.T) {

View File

@@ -484,9 +484,7 @@ func TestServerRead(t *testing.T) {
case "tcp", "tls":
var f base.InterleavedFrame
f.Payload = make([]byte, 2048)
err := f.Read(br)
err := f.Read(2048, br)
require.NoError(t, err)
switch f.Channel {
@@ -516,8 +514,7 @@ func TestServerRead(t *testing.T) {
var f base.InterleavedFrame
for i := 0; i < 2; i++ {
f.Payload = make([]byte, 2048)
err := f.Read(br)
err := f.Read(2048, br)
require.NoError(t, err)
switch f.Channel {
@@ -763,99 +760,6 @@ func TestServerReadVLCMulticast(t *testing.T) {
require.Equal(t, "224.1.0.0", desc.ConnectionInformation.Address.Address)
}
func TestServerReadNonStandardFrameSize(t *testing.T) {
packet := rtp.Packet{
Header: rtp.Header{
Version: 2,
PayloadType: 97,
CSRC: []uint32{},
},
Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 4096/5),
}
packetMarshaled, _ := packet.Marshal()
track, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil)
require.NoError(t, err)
stream := NewServerStream(Tracks{track})
defer stream.Close()
s := &Server{
Handler: &testServerHandler{
onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) {
go func() {
time.Sleep(1 * time.Second)
stream.WritePacketRTP(0, &packet, true)
}()
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
},
RTSPAddress: "localhost:8554",
}
err = s.Start()
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
br := bufio.NewReader(conn)
inTH := &headers.Transport{
Mode: func() *headers.TransportMode {
v := headers.TransportModePlay
return &v
}(),
Delivery: func() *headers.TransportDelivery {
v := headers.TransportDeliveryUnicast
return &v
}(),
Protocol: headers.TransportProtocolTCP,
InterleavedIDs: &[2]int{0, 1},
}
res, err := writeReqReadRes(conn, br, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Transport": inTH.Write(),
},
})
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
var sx headers.Session
err = sx.Read(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Session": base.HeaderValue{sx.Session},
},
})
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
var f base.InterleavedFrame
f.Payload = make([]byte, 4500)
err = f.Read(br)
require.NoError(t, err)
require.Equal(t, 0, f.Channel)
require.Equal(t, packetMarshaled, f.Payload)
}
func TestServerReadTCPResponseBeforeFrames(t *testing.T) {
writerDone := make(chan struct{})
writerTerminate := make(chan struct{})
@@ -953,8 +857,7 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) {
require.Equal(t, base.StatusOK, res.StatusCode)
var fr base.InterleavedFrame
fr.Payload = make([]byte, 2048)
err = fr.Read(br)
err = fr.Read(2048, br)
require.NoError(t, err)
}
@@ -1701,8 +1604,7 @@ func TestServerReadPartialTracks(t *testing.T) {
require.Equal(t, base.StatusOK, res.StatusCode)
var f base.InterleavedFrame
f.Payload = make([]byte, 2048)
err = f.Read(br)
err = f.Read(2048, br)
require.NoError(t, err)
require.Equal(t, 4, f.Channel)
require.Equal(t, testRTPPacketMarshaled, f.Payload)

View File

@@ -37,9 +37,8 @@ func writeReqReadRes(conn net.Conn,
}
func readResIgnoreFrames(br *bufio.Reader) (*base.Response, error) {
buf := make([]byte, 2048)
var res base.Response
err := res.ReadIgnoreFrames(br, buf)
err := res.ReadIgnoreFrames(2048, br)
return &res, err
}

View File

@@ -6,6 +6,7 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
"strings"
@@ -15,7 +16,7 @@ import (
"github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/liberrors"
"github.com/aler9/gortsplib/pkg/multibuffer"
"github.com/aler9/gortsplib/pkg/rtph264"
)
func getSessionID(header base.Header) string {
@@ -109,7 +110,7 @@ func (sc *ServerConn) run() {
})
}
sc.br = bufio.NewReaderSize(sc.conn, serverReadBufferSize)
sc.br = bufio.NewReaderSize(sc.conn, tcpReadBufferSize)
readRequest := make(chan readReq)
readErr := make(chan error)
@@ -218,20 +219,19 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
case <-sc.session.ctx.Done():
}
var tcpReadBuffer *multibuffer.MultiBuffer
var processFunc func(int, bool, []byte)
var processFunc func(int, bool, []byte) error
if sc.session.state == ServerSessionStatePlay {
// when playing, tcpReadBuffer is only used to receive RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
tcpReadBuffer = multibuffer.New(8, uint64(sc.s.ReadBufferSize))
processFunc = func(trackID int, isRTP bool, payload []byte) {
processFunc = func(trackID int, isRTP bool, payload []byte) error {
if !isRTP {
if len(payload) > udpReadBufferSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
len(payload), udpReadBufferSize)
}
packets, err := rtcp.Unmarshal(payload)
if err != nil {
return
return err
}
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok {
@@ -244,30 +244,100 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
}
}
}
return nil
}
} else {
tcpReadBuffer = multibuffer.New(uint64(sc.s.ReadBufferCount), uint64(sc.s.ReadBufferSize))
tcpRTPPacketBuffer := newRTPPacketMultiBuffer(uint64(sc.s.ReadBufferCount))
processFunc = func(trackID int, isRTP bool, payload []byte) {
processFunc = func(trackID int, isRTP bool, payload []byte) error {
if isRTP {
pkt := tcpRTPPacketBuffer.next()
err := pkt.Unmarshal(payload)
if err != nil {
return
return err
}
sc.session.onPacketRTP(time.Time{}, trackID, pkt)
ctx := ServerHandlerOnPacketRTPCtx{
Session: sc.session,
TrackID: trackID,
Packet: pkt,
}
sc.session.processPacketRTP(&ctx)
at := sc.session.announcedTracks[trackID]
if at.h264Decoder != nil {
if at.h264Encoder == nil && len(payload) > udpReadBufferSize {
v1 := pkt.SSRC
v2 := pkt.SequenceNumber
v3 := pkt.Timestamp
at.h264Encoder = &rtph264.Encoder{
PayloadType: pkt.PayloadType,
SSRC: &v1,
InitialSequenceNumber: &v2,
InitialTimestamp: &v3,
}
at.h264Encoder.Init()
}
if at.h264Encoder != nil {
if ctx.H264NALUs != nil {
packets, err := at.h264Encoder.Encode(ctx.H264NALUs, ctx.H264PTS)
if err != nil {
return err
}
for i, pkt := range packets {
if i != len(packets)-1 {
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
Session: sc.session,
TrackID: trackID,
Packet: pkt,
PTSEqualsDTS: false,
})
}
} else {
ctx.Packet = pkt
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ctx)
}
}
}
}
} else {
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ctx)
}
}
} else {
if len(payload) > udpReadBufferSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
len(payload), udpReadBufferSize)
}
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ctx)
}
}
} else {
if len(payload) > udpReadBufferSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
len(payload), udpReadBufferSize)
}
packets, err := rtcp.Unmarshal(payload)
if err != nil {
return
return err
}
for _, pkt := range packets {
sc.session.onPacketRTCP(trackID, pkt)
}
}
return nil
}
}
@@ -279,8 +349,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
sc.conn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout))
}
frame.Payload = tcpReadBuffer.Next()
what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br)
what, err := base.ReadInterleavedFrameOrRequest(&frame, tcpMaxFramePayloadSize, &req, sc.br)
if err != nil {
return err
}
@@ -296,7 +365,10 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
// forward frame only if it has been set up
if trackID, ok := sc.session.tcpTracksByChannel[channel]; ok {
processFunc(trackID, isRTP, frame.Payload)
err := processFunc(trackID, isRTP, frame.Payload)
if err != nil {
return err
}
}
case *base.Request:

View File

@@ -155,6 +155,7 @@ type ServerSessionAnnouncedTrack struct {
track Track
rtcpReceiver *rtcpreceiver.RTCPReceiver
h264Decoder *rtph264.Decoder
h264Encoder *rtph264.Encoder
}
// ServerSession is a server-side RTSP session.
@@ -1103,6 +1104,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
for _, at := range ss.announcedTracks {
at.h264Decoder = nil
at.h264Encoder = nil
}
ss.state = ServerSessionStatePreRecord
@@ -1212,63 +1214,24 @@ func (ss *ServerSession) runWriter() {
}
}
func (ss *ServerSession) onPacketRTP(now time.Time, trackID int, pkt *rtp.Packet) {
func (ss *ServerSession) processPacketRTP(ctx *ServerHandlerOnPacketRTPCtx) {
// remove padding
pkt.Header.Padding = false
pkt.PaddingSize = 0
at := ss.announcedTracks[trackID]
ctx.Packet.Header.Padding = false
ctx.Packet.PaddingSize = 0
// decode
at := ss.announcedTracks[ctx.TrackID]
if at.h264Decoder != nil {
nalus, pts, err := at.h264Decoder.DecodeUntilMarker(pkt)
nalus, pts, err := at.h264Decoder.DecodeUntilMarker(ctx.Packet)
if err == nil {
ptsEqualsDTS := h264.IDRPresent(nalus)
rr := at.rtcpReceiver
if rr != nil {
rr.ProcessPacketRTP(now, pkt, ptsEqualsDTS)
}
if h, ok := ss.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
Session: ss,
TrackID: trackID,
Packet: pkt,
PTSEqualsDTS: ptsEqualsDTS,
H264NALUs: append([][]byte(nil), nalus...),
H264PTS: pts,
})
ctx.PTSEqualsDTS = h264.IDRPresent(nalus)
ctx.H264NALUs = append([][]byte(nil), nalus...)
ctx.H264PTS = pts
} else {
ctx.PTSEqualsDTS = false
}
} else {
rr := at.rtcpReceiver
if rr != nil {
rr.ProcessPacketRTP(now, pkt, false)
}
if h, ok := ss.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
Session: ss,
TrackID: trackID,
Packet: pkt,
PTSEqualsDTS: false,
})
}
}
return
}
rr := at.rtcpReceiver
if rr != nil {
rr.ProcessPacketRTP(now, pkt, true)
}
if h, ok := ss.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
Session: ss,
TrackID: trackID,
Packet: pkt,
PTSEqualsDTS: true,
})
ctx.PTSEqualsDTS = false
}
}

View File

@@ -98,7 +98,7 @@ func newServerUDPListener(
p := ipv4.NewPacketConn(tmp)
err = p.SetMulticastTTL(16)
err = p.SetMulticastTTL(multicastTTL)
if err != nil {
return nil, err
}
@@ -130,7 +130,7 @@ func newServerUDPListener(
listenIP = tmp.LocalAddr().(*net.UDPAddr).IP
}
err := pc.SetReadBuffer(serverUDPKernelReadBufferSize)
err := pc.SetReadBuffer(udpKernelReadBufferSize)
if err != nil {
return nil, err
}
@@ -142,7 +142,7 @@ func newServerUDPListener(
clients: make(map[clientAddr]*clientData),
isRTP: isRTP,
writeTimeout: s.WriteTimeout,
readBuffer: multibuffer.New(uint64(s.ReadBufferCount), uint64(s.ReadBufferSize)),
readBuffer: multibuffer.New(uint64(s.ReadBufferCount), uint64(udpReadBufferSize)),
rtpPacketBuffer: newRTPPacketMultiBuffer(uint64(s.ReadBufferCount)),
readerDone: make(chan struct{}),
}
@@ -207,7 +207,16 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) {
now := time.Now()
atomic.StoreInt64(clientData.ss.udpLastFrameTime, now.Unix())
clientData.ss.onPacketRTP(now, clientData.trackID, pkt)
ctx := ServerHandlerOnPacketRTPCtx{
Session: clientData.ss,
TrackID: clientData.trackID,
Packet: pkt,
}
clientData.ss.processPacketRTP(&ctx)
clientData.ss.announcedTracks[clientData.trackID].rtcpReceiver.ProcessPacketRTP(now, ctx.Packet, ctx.PTSEqualsDTS)
if h, ok := clientData.ss.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ctx)
}
}
func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) {