automatically remux oversized RTP/H264 packets; drop parameter ReadBufferSize

This commit is contained in:
aler9
2022-04-09 12:11:38 +02:00
committed by Alessandro Ros
parent b1a4b52090
commit bfe4e8cdaa
21 changed files with 390 additions and 376 deletions

190
client.go
View File

@@ -28,18 +28,12 @@ import (
"github.com/aler9/gortsplib/pkg/h264" "github.com/aler9/gortsplib/pkg/h264"
"github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/headers"
"github.com/aler9/gortsplib/pkg/liberrors" "github.com/aler9/gortsplib/pkg/liberrors"
"github.com/aler9/gortsplib/pkg/multibuffer"
"github.com/aler9/gortsplib/pkg/ringbuffer" "github.com/aler9/gortsplib/pkg/ringbuffer"
"github.com/aler9/gortsplib/pkg/rtcpreceiver" "github.com/aler9/gortsplib/pkg/rtcpreceiver"
"github.com/aler9/gortsplib/pkg/rtcpsender" "github.com/aler9/gortsplib/pkg/rtcpsender"
"github.com/aler9/gortsplib/pkg/rtph264" "github.com/aler9/gortsplib/pkg/rtph264"
) )
const (
clientReadBufferSize = 4096
clientUDPKernelReadBufferSize = 0x80000 // same size as gstreamer's rtspsrc
)
func isAnyPort(p int) bool { func isAnyPort(p int) bool {
return p == 0 || p == 1 return p == 0 || p == 1
} }
@@ -62,6 +56,7 @@ type clientTrack struct {
rtcpReceiver *rtcpreceiver.RTCPReceiver rtcpReceiver *rtcpreceiver.RTCPReceiver
rtcpSender *rtcpsender.RTCPSender rtcpSender *rtcpsender.RTCPSender
h264Decoder *rtph264.Decoder h264Decoder *rtph264.Decoder
h264Encoder *rtph264.Encoder
} }
func (s clientState) String() string { func (s clientState) String() string {
@@ -187,10 +182,6 @@ type Client struct {
// that is reading frames. // that is reading frames.
// It defaults to 256. // It defaults to 256.
ReadBufferCount int ReadBufferCount int
// read buffer size.
// This must be touched only when the server reports errors about buffer sizes.
// It defaults to 2048.
ReadBufferSize int
// write buffer count. // write buffer count.
// It allows to queue packets before sending them. // It allows to queue packets before sending them.
// It defaults to 8. // It defaults to 8.
@@ -291,9 +282,6 @@ func (c *Client) Start(scheme string, host string) error {
if c.ReadBufferCount == 0 { if c.ReadBufferCount == 0 {
c.ReadBufferCount = 256 c.ReadBufferCount = 256
} }
if c.ReadBufferSize == 0 {
c.ReadBufferSize = 2048
}
if c.WriteBufferCount == 0 { if c.WriteBufferCount == 0 {
c.WriteBufferCount = 256 c.WriteBufferCount = 256
} }
@@ -760,14 +748,12 @@ func (c *Client) runReader() {
} }
} }
} else { } else {
var tcpReadBuffer *multibuffer.MultiBuffer var processFunc func(int, bool, []byte) error
var processFunc func(int, bool, []byte)
if c.state == clientStatePlay { if c.state == clientStatePlay {
tcpReadBuffer = multibuffer.New(uint64(c.ReadBufferCount), uint64(c.ReadBufferSize))
tcpRTPPacketBuffer := newRTPPacketMultiBuffer(uint64(c.ReadBufferCount)) tcpRTPPacketBuffer := newRTPPacketMultiBuffer(uint64(c.ReadBufferCount))
processFunc = func(trackID int, isRTP bool, payload []byte) { processFunc = func(trackID int, isRTP bool, payload []byte) error {
now := time.Now() now := time.Now()
atomic.StoreInt64(c.tcpLastFrameTime, now.Unix()) atomic.StoreInt64(c.tcpLastFrameTime, now.Unix())
@@ -775,38 +761,105 @@ func (c *Client) runReader() {
pkt := tcpRTPPacketBuffer.next() pkt := tcpRTPPacketBuffer.next()
err := pkt.Unmarshal(payload) err := pkt.Unmarshal(payload)
if err != nil { if err != nil {
return return err
} }
c.onPacketRTP(trackID, pkt) ctx := ClientOnPacketRTPCtx{
TrackID: trackID,
Packet: pkt,
}
c.processPacketRTP(&ctx)
ct := c.tracks[trackID]
if ct.h264Decoder != nil {
if ct.h264Encoder == nil && len(payload) > udpReadBufferSize {
v1 := pkt.SSRC
v2 := pkt.SequenceNumber
v3 := pkt.Timestamp
ct.h264Encoder = &rtph264.Encoder{
PayloadType: pkt.PayloadType,
SSRC: &v1,
InitialSequenceNumber: &v2,
InitialTimestamp: &v3,
}
ct.h264Encoder.Init()
}
if ct.h264Encoder != nil {
if ctx.H264NALUs != nil {
packets, err := ct.h264Encoder.Encode(ctx.H264NALUs, ctx.H264PTS)
if err != nil {
return err
}
for i, pkt := range packets {
if i != len(packets)-1 {
c.OnPacketRTP(&ClientOnPacketRTPCtx{
TrackID: trackID,
Packet: pkt,
PTSEqualsDTS: false,
})
} else {
ctx.Packet = pkt
c.OnPacketRTP(&ctx)
}
}
}
} else {
c.OnPacketRTP(&ctx)
}
} else {
if len(payload) > udpReadBufferSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
len(payload), udpReadBufferSize)
}
c.OnPacketRTP(&ctx)
}
} else { } else {
if len(payload) > udpReadBufferSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
len(payload), udpReadBufferSize)
}
packets, err := rtcp.Unmarshal(payload) packets, err := rtcp.Unmarshal(payload)
if err != nil { if err != nil {
return return err
} }
for _, pkt := range packets { for _, pkt := range packets {
c.onPacketRTCP(trackID, pkt) c.OnPacketRTCP(&ClientOnPacketRTCPCtx{
TrackID: trackID,
Packet: pkt,
})
} }
} }
return nil
} }
} else { } else {
// when recording, tcpReadBuffer is only used to receive RTCP receiver reports, processFunc = func(trackID int, isRTP bool, payload []byte) error {
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
tcpReadBuffer = multibuffer.New(8, uint64(c.ReadBufferSize))
processFunc = func(trackID int, isRTP bool, payload []byte) {
if !isRTP { if !isRTP {
if len(payload) > udpReadBufferSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
len(payload), udpReadBufferSize)
}
packets, err := rtcp.Unmarshal(payload) packets, err := rtcp.Unmarshal(payload)
if err != nil { if err != nil {
return return err
} }
for _, pkt := range packets { for _, pkt := range packets {
c.onPacketRTCP(trackID, pkt) c.OnPacketRTCP(&ClientOnPacketRTCPCtx{
TrackID: trackID,
Packet: pkt,
})
} }
} }
return nil
} }
} }
@@ -814,8 +867,7 @@ func (c *Client) runReader() {
var res base.Response var res base.Response
for { for {
frame.Payload = tcpReadBuffer.Next() what, err := base.ReadInterleavedFrameOrResponse(&frame, tcpMaxFramePayloadSize, &res, c.br)
what, err := base.ReadInterleavedFrameOrResponse(&frame, &res, c.br)
if err != nil { if err != nil {
return err return err
} }
@@ -833,7 +885,10 @@ func (c *Client) runReader() {
continue continue
} }
processFunc(trackID, isRTP, frame.Payload) err := processFunc(trackID, isRTP, frame.Payload)
if err != nil {
return err
}
} }
} }
} }
@@ -874,6 +929,7 @@ func (c *Client) playRecordStop(isClosing bool) {
for _, ct := range c.tracks { for _, ct := range c.tracks {
ct.h264Decoder = nil ct.h264Decoder = nil
ct.h264Encoder = nil
} }
// stop timers // stop timers
@@ -929,7 +985,7 @@ func (c *Client) connOpen() error {
return nconn return nconn
}() }()
c.br = bufio.NewReaderSize(c.conn, clientReadBufferSize) c.br = bufio.NewReaderSize(c.conn, tcpReadBufferSize)
c.connCloserStart() c.connCloserStart()
return nil return nil
} }
@@ -1008,11 +1064,10 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba
if allowFrames { if allowFrames {
// read the response and ignore interleaved frames in between; // read the response and ignore interleaved frames in between;
// interleaved frames are sent in two scenarios: // interleaved frames are sent in two cases:
// * when the server is v4lrtspserver, before the PLAY response // * when the server is v4lrtspserver, before the PLAY response
// * when the stream is already playing // * when the stream is already playing
buf := make([]byte, c.ReadBufferSize) err = res.ReadIgnoreFrames(tcpMaxFramePayloadSize, c.br)
err = res.ReadIgnoreFrames(c.br, buf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1230,7 +1285,7 @@ func (c *Client) doAnnounce(u *base.URL, tracks Tracks) (*base.Response, error)
} }
// in case of ANNOUNCE, the base URL doesn't have a trailing slash. // in case of ANNOUNCE, the base URL doesn't have a trailing slash.
// (tested with ffmpeg and gstreamer) // (tested with ffmpeg and GStreamer)
baseURL := u.Clone() baseURL := u.Clone()
tracks.setControls() tracks.setControls()
@@ -1847,62 +1902,25 @@ func (c *Client) runWriter() {
} }
} }
func (c *Client) onPacketRTP(trackID int, pkt *rtp.Packet) { func (c *Client) processPacketRTP(ctx *ClientOnPacketRTPCtx) {
// remove padding // remove padding
pkt.Header.Padding = false ctx.Packet.Header.Padding = false
pkt.PaddingSize = 0 ctx.Packet.PaddingSize = 0
ct := c.tracks[trackID]
// decode
ct := c.tracks[ctx.TrackID]
if ct.h264Decoder != nil { if ct.h264Decoder != nil {
nalus, pts, err := ct.h264Decoder.DecodeUntilMarker(pkt) nalus, pts, err := ct.h264Decoder.DecodeUntilMarker(ctx.Packet)
if err == nil { if err == nil {
ptsEqualsDTS := h264.IDRPresent(nalus) ctx.PTSEqualsDTS = h264.IDRPresent(nalus)
ctx.H264NALUs = append([][]byte(nil), nalus...)
rr := ct.rtcpReceiver ctx.H264PTS = pts
if rr != nil {
rr.ProcessPacketRTP(time.Now(), pkt, ptsEqualsDTS)
}
c.OnPacketRTP(&ClientOnPacketRTPCtx{
TrackID: trackID,
Packet: pkt,
PTSEqualsDTS: ptsEqualsDTS,
H264NALUs: append([][]byte(nil), nalus...),
H264PTS: pts,
})
} else { } else {
rr := ct.rtcpReceiver ctx.PTSEqualsDTS = false
if rr != nil {
rr.ProcessPacketRTP(time.Now(), pkt, false)
}
c.OnPacketRTP(&ClientOnPacketRTPCtx{
TrackID: trackID,
Packet: pkt,
PTSEqualsDTS: false,
})
} }
return } else {
ctx.PTSEqualsDTS = true
} }
rr := ct.rtcpReceiver
if rr != nil {
rr.ProcessPacketRTP(time.Now(), pkt, true)
}
c.OnPacketRTP(&ClientOnPacketRTPCtx{
TrackID: trackID,
Packet: pkt,
PTSEqualsDTS: true,
})
}
func (c *Client) onPacketRTCP(trackID int, pkt rtcp.Packet) {
c.OnPacketRTCP(&ClientOnPacketRTCPCtx{
TrackID: trackID,
Packet: pkt,
})
} }
// WritePacketRTP writes a RTP packet. // WritePacketRTP writes a RTP packet.

View File

@@ -182,8 +182,7 @@ func TestClientPublishSerial(t *testing.T) {
require.Equal(t, testRTPPacket, pkt) require.Equal(t, testRTPPacket, pkt)
} else { } else {
var f base.InterleavedFrame var f base.InterleavedFrame
f.Payload = make([]byte, 2048) err = f.Read(1024, br)
err = f.Read(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, f.Channel) require.Equal(t, 0, f.Channel)
var pkt rtp.Packet var pkt rtp.Packet
@@ -823,8 +822,7 @@ func TestClientPublishAutomaticProtocol(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
var f base.InterleavedFrame var f base.InterleavedFrame
f.Payload = make([]byte, 2048) err = f.Read(2048, br)
err = f.Read(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, f.Channel) require.Equal(t, 0, f.Channel)
var pkt rtp.Packet var pkt rtp.Packet

View File

@@ -21,6 +21,22 @@ import (
"github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/headers"
) )
func mergeBytes(vals ...[]byte) []byte {
size := 0
for _, v := range vals {
size += len(v)
}
res := make([]byte, size)
pos := 0
for _, v := range vals {
n := copy(res[pos:], v)
pos += n
}
return res
}
func TestClientReadTracks(t *testing.T) { func TestClientReadTracks(t *testing.T) {
track1, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil) track1, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil)
require.NoError(t, err) require.NoError(t, err)
@@ -359,8 +375,7 @@ func TestClientRead(t *testing.T) {
case "tcp", "tls": case "tcp", "tls":
var f base.InterleavedFrame var f base.InterleavedFrame
f.Payload = make([]byte, 2048) err := f.Read(2048, br)
err := f.Read(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, f.Channel) require.Equal(t, 1, f.Channel)
packets, err := rtcp.Unmarshal(f.Payload) packets, err := rtcp.Unmarshal(f.Payload)
@@ -429,15 +444,59 @@ func TestClientRead(t *testing.T) {
} }
} }
func TestClientReadNonStandardFrameSize(t *testing.T) { var oversizedPacketRTPIn = rtp.Packet{
refRTPPacket := rtp.Packet{ Header: rtp.Header{
Version: 2,
PayloadType: 96,
Marker: true,
SequenceNumber: 34572,
},
Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 4096/5),
}
var oversizedPacketsRTPOut = []rtp.Packet{
{
Header: rtp.Header{ Header: rtp.Header{
Version: 2, Version: 2,
PayloadType: 96, PayloadType: 96,
CSRC: []uint32{}, Marker: false,
SequenceNumber: 34572,
}, },
Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 4096/5), Payload: mergeBytes(
} []byte{0x1c, 0x81, 0x02, 0x03, 0x04, 0x05},
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 290),
[]byte{0x01, 0x02, 0x03, 0x04},
),
},
{
Header: rtp.Header{
Version: 2,
PayloadType: 96,
Marker: false,
SequenceNumber: 34573,
},
Payload: mergeBytes(
[]byte{0x1c, 0x01, 0x05},
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 291),
[]byte{0x01, 0x02},
),
},
{
Header: rtp.Header{
Version: 2,
PayloadType: 96,
Marker: true,
SequenceNumber: 34574,
},
Payload: mergeBytes(
[]byte{0x1c, 0x41, 0x03, 0x04, 0x05},
bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 235),
),
},
}
func TestClientReadOversizedPacket(t *testing.T) {
oversizedPacketsRTPOut := append([]rtp.Packet(nil), oversizedPacketsRTPOut...)
l, err := net.Listen("tcp", "localhost:8554") l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err) require.NoError(t, err)
@@ -529,7 +588,7 @@ func TestClientReadNonStandardFrameSize(t *testing.T) {
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(bb.Bytes())
require.NoError(t, err) require.NoError(t, err)
byts, _ := refRTPPacket.Marshal() byts, _ := oversizedPacketRTPIn.Marshal()
base.InterleavedFrame{ base.InterleavedFrame{
Channel: 0, Channel: 0,
Payload: byts, Payload: byts,
@@ -541,15 +600,18 @@ func TestClientReadNonStandardFrameSize(t *testing.T) {
packetRecv := make(chan struct{}) packetRecv := make(chan struct{})
c := &Client{ c := &Client{
ReadBufferSize: 4500 + 4,
Transport: func() *Transport { Transport: func() *Transport {
v := TransportTCP v := TransportTCP
return &v return &v
}(), }(),
OnPacketRTP: func(ctx *ClientOnPacketRTPCtx) { OnPacketRTP: func(ctx *ClientOnPacketRTPCtx) {
require.Equal(t, 0, ctx.TrackID) require.Equal(t, 0, ctx.TrackID)
require.Equal(t, &refRTPPacket, ctx.Packet) cmp := oversizedPacketsRTPOut[0]
close(packetRecv) oversizedPacketsRTPOut = oversizedPacketsRTPOut[1:]
require.Equal(t, &cmp, ctx.Packet)
if len(oversizedPacketsRTPOut) == 0 {
close(packetRecv)
}
}, },
} }

View File

@@ -29,9 +29,8 @@ func readRequest(br *bufio.Reader) (*base.Request, error) {
} }
func readRequestIgnoreFrames(br *bufio.Reader) (*base.Request, error) { func readRequestIgnoreFrames(br *bufio.Reader) (*base.Request, error) {
buf := make([]byte, 2048)
var req base.Request var req base.Request
err := req.ReadIgnoreFrames(br, buf) err := req.ReadIgnoreFrames(2048, br)
return &req, err return &req, err
} }

View File

@@ -76,7 +76,7 @@ func newClientUDPListener(c *Client, multicast bool, address string) (*clientUDP
p := ipv4.NewPacketConn(tmp) p := ipv4.NewPacketConn(tmp)
err = p.SetMulticastTTL(16) err = p.SetMulticastTTL(multicastTTL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -102,7 +102,7 @@ func newClientUDPListener(c *Client, multicast bool, address string) (*clientUDP
pc = tmp.(*net.UDPConn) pc = tmp.(*net.UDPConn)
} }
err := pc.SetReadBuffer(clientUDPKernelReadBufferSize) err := pc.SetReadBuffer(udpKernelReadBufferSize)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -110,7 +110,7 @@ func newClientUDPListener(c *Client, multicast bool, address string) (*clientUDP
return &clientUDPListener{ return &clientUDPListener{
c: c, c: c,
pc: pc, pc: pc,
readBuffer: multibuffer.New(uint64(c.ReadBufferCount), uint64(c.ReadBufferSize)), readBuffer: multibuffer.New(uint64(c.ReadBufferCount), uint64(udpReadBufferSize)),
rtpPacketBuffer: newRTPPacketMultiBuffer(uint64(c.ReadBufferCount)), rtpPacketBuffer: newRTPPacketMultiBuffer(uint64(c.ReadBufferCount)),
lastPacketTime: func() *int64 { lastPacketTime: func() *int64 {
v := int64(0) v := int64(0)
@@ -182,7 +182,13 @@ func (u *clientUDPListener) processPlayRTP(now time.Time, payload []byte) {
return return
} }
u.c.onPacketRTP(u.trackID, pkt) ctx := ClientOnPacketRTPCtx{
TrackID: u.trackID,
Packet: pkt,
}
u.c.processPacketRTP(&ctx)
u.c.tracks[u.trackID].rtcpReceiver.ProcessPacketRTP(time.Now(), pkt, ctx.PTSEqualsDTS)
u.c.OnPacketRTP(&ctx)
} }
func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) { func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) {
@@ -193,7 +199,10 @@ func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) {
for _, pkt := range packets { for _, pkt := range packets {
u.c.tracks[u.trackID].rtcpReceiver.ProcessPacketRTCP(now, pkt) u.c.tracks[u.trackID].rtcpReceiver.ProcessPacketRTCP(now, pkt)
u.c.onPacketRTCP(u.trackID, pkt) u.c.OnPacketRTCP(&ClientOnPacketRTCPCtx{
TrackID: u.trackID,
Packet: pkt,
})
} }
} }
@@ -204,7 +213,10 @@ func (u *clientUDPListener) processRecordRTCP(now time.Time, payload []byte) {
} }
for _, pkt := range packets { for _, pkt := range packets {
u.c.onPacketRTCP(u.trackID, pkt) u.c.OnPacketRTCP(&ClientOnPacketRTCPCtx{
TrackID: u.trackID,
Packet: pkt,
})
} }
} }

9
constants.go Normal file
View File

@@ -0,0 +1,9 @@
package gortsplib
const (
tcpReadBufferSize = 4096
tcpMaxFramePayloadSize = 60 * 1024 * 1024
udpKernelReadBufferSize = 0x80000 // same size as GStreamer's rtspsrc
udpReadBufferSize = 1472 // 1500 (UDP MTU) - 20 (IP header) - 8 (UDP header)
multicastTTL = 16
)

View File

@@ -23,7 +23,7 @@ func main() {
} }
defer pc.Close() defer pc.Close()
log.Println("Waiting for a rtp/h264 stream on port 9000 - you can send one with gstreamer:\n" + log.Println("Waiting for a RTP/H264 stream on port 9000 - you can send one with GStreamer:\n" +
"gst-launch-1.0 filesrc location=video.mp4 ! qtdemux ! video/x-h264" + "gst-launch-1.0 filesrc location=video.mp4 ! qtdemux ! video/x-h264" +
" ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000") " ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000")

View File

@@ -24,7 +24,7 @@ func main() {
} }
defer pc.Close() defer pc.Close()
log.Println("Waiting for a rtp/h264 stream on port 9000 - you can send one with gstreamer:\n" + log.Println("Waiting for a RTP/H264 stream on port 9000 - you can send one with GStreamer:\n" +
"gst-launch-1.0 filesrc location=video.mp4 ! qtdemux ! video/x-h264" + "gst-launch-1.0 filesrc location=video.mp4 ! qtdemux ! video/x-h264" +
" ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000") " ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000")

View File

@@ -13,7 +13,12 @@ const (
) )
// ReadInterleavedFrameOrRequest reads an InterleavedFrame or a Response. // ReadInterleavedFrameOrRequest reads an InterleavedFrame or a Response.
func ReadInterleavedFrameOrRequest(frame *InterleavedFrame, req *Request, br *bufio.Reader) (interface{}, error) { func ReadInterleavedFrameOrRequest(
frame *InterleavedFrame,
maxPayloadSize int,
req *Request,
br *bufio.Reader,
) (interface{}, error) {
b, err := br.ReadByte() b, err := br.ReadByte()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -21,7 +26,7 @@ func ReadInterleavedFrameOrRequest(frame *InterleavedFrame, req *Request, br *bu
br.UnreadByte() br.UnreadByte()
if b == interleavedFrameMagicByte { if b == interleavedFrameMagicByte {
err := frame.Read(br) err := frame.Read(maxPayloadSize, br)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -36,7 +41,12 @@ func ReadInterleavedFrameOrRequest(frame *InterleavedFrame, req *Request, br *bu
} }
// ReadInterleavedFrameOrResponse reads an InterleavedFrame or a Response. // ReadInterleavedFrameOrResponse reads an InterleavedFrame or a Response.
func ReadInterleavedFrameOrResponse(frame *InterleavedFrame, res *Response, br *bufio.Reader) (interface{}, error) { func ReadInterleavedFrameOrResponse(
frame *InterleavedFrame,
maxPayloadSize int,
res *Response,
br *bufio.Reader,
) (interface{}, error) {
b, err := br.ReadByte() b, err := br.ReadByte()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -44,7 +54,7 @@ func ReadInterleavedFrameOrResponse(frame *InterleavedFrame, res *Response, br *
br.UnreadByte() br.UnreadByte()
if b == interleavedFrameMagicByte { if b == interleavedFrameMagicByte {
err := frame.Read(br) err := frame.Read(maxPayloadSize, br)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -61,15 +71,15 @@ func ReadInterleavedFrameOrResponse(frame *InterleavedFrame, res *Response, br *
// InterleavedFrame is an interleaved frame, and allows to transfer binary data // InterleavedFrame is an interleaved frame, and allows to transfer binary data
// within RTSP/TCP connections. It is used to send and receive RTP and RTCP packets with TCP. // within RTSP/TCP connections. It is used to send and receive RTP and RTCP packets with TCP.
type InterleavedFrame struct { type InterleavedFrame struct {
// channel id // channel ID
Channel int Channel int
// frame payload // payload
Payload []byte Payload []byte
} }
// Read reads an interleaved frame. // Read reads an interleaved frame.
func (f *InterleavedFrame) Read(br *bufio.Reader) error { func (f *InterleavedFrame) Read(maxPayloadSize int, br *bufio.Reader) error {
var header [4]byte var header [4]byte
_, err := io.ReadFull(br, header[:]) _, err := io.ReadFull(br, header[:])
if err != nil { if err != nil {
@@ -80,14 +90,14 @@ func (f *InterleavedFrame) Read(br *bufio.Reader) error {
return fmt.Errorf("invalid magic byte (0x%.2x)", header[0]) return fmt.Errorf("invalid magic byte (0x%.2x)", header[0])
} }
framelen := int(binary.BigEndian.Uint16(header[2:])) payloadLen := int(binary.BigEndian.Uint16(header[2:]))
if framelen > len(f.Payload) { if payloadLen > maxPayloadSize {
return fmt.Errorf("payload size greater than maximum allowed (%d vs %d)", return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
framelen, len(f.Payload)) payloadLen, maxPayloadSize)
} }
f.Channel = int(header[1]) f.Channel = int(header[1])
f.Payload = f.Payload[:framelen] f.Payload = make([]byte, payloadLen)
_, err = io.ReadFull(br, f.Payload) _, err = io.ReadFull(br, f.Payload)
if err != nil { if err != nil {

View File

@@ -37,8 +37,7 @@ func TestInterleavedFrameRead(t *testing.T) {
for _, ca := range casesInterleavedFrame { for _, ca := range casesInterleavedFrame {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
f.Payload = make([]byte, 1024) err := f.Read(1024, bufio.NewReader(bytes.NewBuffer(ca.enc)))
err := f.Read(bufio.NewReader(bytes.NewBuffer(ca.enc)))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, ca.dec, f) require.Equal(t, ca.dec, f)
}) })
@@ -64,7 +63,7 @@ func TestInterleavedFrameReadErrors(t *testing.T) {
{ {
"payload size too big", "payload size too big",
[]byte{0x24, 0x00, 0x00, 0x08}, []byte{0x24, 0x00, 0x00, 0x08},
"payload size greater than maximum allowed (8 vs 5)", "payload size (8) greater than maximum allowed (5)",
}, },
{ {
"payload invalid", "payload invalid",
@@ -74,8 +73,7 @@ func TestInterleavedFrameReadErrors(t *testing.T) {
} { } {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
var f InterleavedFrame var f InterleavedFrame
f.Payload = make([]byte, 5) err := f.Read(5, bufio.NewReader(bytes.NewBuffer(ca.byts)))
err := f.Read(bufio.NewReader(bytes.NewBuffer(ca.byts)))
require.EqualError(t, err, ca.err) require.EqualError(t, err, ca.err)
}) })
} }
@@ -99,15 +97,14 @@ func TestReadInterleavedFrameOrRequest(t *testing.T) {
byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...) byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...)
var f InterleavedFrame var f InterleavedFrame
f.Payload = make([]byte, 10)
var req Request var req Request
br := bufio.NewReader(bytes.NewBuffer(byts)) br := bufio.NewReader(bytes.NewBuffer(byts))
out, err := ReadInterleavedFrameOrRequest(&f, &req, br) out, err := ReadInterleavedFrameOrRequest(&f, 10, &req, br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &req, out) require.Equal(t, &req, out)
out, err = ReadInterleavedFrameOrRequest(&f, &req, br) out, err = ReadInterleavedFrameOrRequest(&f, 10, &req, br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &f, out) require.Equal(t, &f, out)
} }
@@ -136,11 +133,9 @@ func TestReadInterleavedFrameOrRequestErrors(t *testing.T) {
} { } {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
var f InterleavedFrame var f InterleavedFrame
f.Payload = make([]byte, 10)
var req Request var req Request
br := bufio.NewReader(bytes.NewBuffer(ca.byts)) br := bufio.NewReader(bytes.NewBuffer(ca.byts))
_, err := ReadInterleavedFrameOrRequest(&f, 10, &req, br)
_, err := ReadInterleavedFrameOrRequest(&f, &req, br)
require.EqualError(t, err, ca.err) require.EqualError(t, err, ca.err)
}) })
} }
@@ -154,15 +149,13 @@ func TestReadInterleavedFrameOrResponse(t *testing.T) {
byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...) byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...)
var f InterleavedFrame var f InterleavedFrame
f.Payload = make([]byte, 10)
var res Response var res Response
br := bufio.NewReader(bytes.NewBuffer(byts)) br := bufio.NewReader(bytes.NewBuffer(byts))
out, err := ReadInterleavedFrameOrResponse(&f, 10, &res, br)
out, err := ReadInterleavedFrameOrResponse(&f, &res, br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &res, out) require.Equal(t, &res, out)
out, err = ReadInterleavedFrameOrResponse(&f, &res, br) out, err = ReadInterleavedFrameOrResponse(&f, 10, &res, br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, &f, out) require.Equal(t, &f, out)
} }
@@ -191,11 +184,9 @@ func TestReadInterleavedFrameOrResponseErrors(t *testing.T) {
} { } {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
var f InterleavedFrame var f InterleavedFrame
f.Payload = make([]byte, 10)
var res Response var res Response
br := bufio.NewReader(bytes.NewBuffer(ca.byts)) br := bufio.NewReader(bytes.NewBuffer(ca.byts))
_, err := ReadInterleavedFrameOrResponse(&f, 10, &res, br)
_, err := ReadInterleavedFrameOrResponse(&f, &res, br)
require.EqualError(t, err, ca.err) require.EqualError(t, err, ca.err)
}) })
} }

View File

@@ -101,15 +101,11 @@ func (req *Request) Read(rb *bufio.Reader) error {
// ReadIgnoreFrames reads a request and ignores any interleaved frame sent // ReadIgnoreFrames reads a request and ignores any interleaved frame sent
// before the request. // before the request.
func (req *Request) ReadIgnoreFrames(rb *bufio.Reader, buf []byte) error { func (req *Request) ReadIgnoreFrames(maxPayloadSize int, rb *bufio.Reader) error {
buflen := len(buf) var f InterleavedFrame
f := InterleavedFrame{
Payload: buf,
}
for { for {
f.Payload = f.Payload[:buflen] recv, err := ReadInterleavedFrameOrRequest(&f, maxPayloadSize, req, rb)
recv, err := ReadInterleavedFrameOrRequest(&f, req, rb)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -237,9 +237,8 @@ func TestRequestReadIgnoreFrames(t *testing.T) {
"\r\n")...) "\r\n")...)
rb := bufio.NewReader(bytes.NewBuffer(byts)) rb := bufio.NewReader(bytes.NewBuffer(byts))
buf := make([]byte, 10)
var req Request var req Request
err := req.ReadIgnoreFrames(rb, buf) err := req.ReadIgnoreFrames(10, rb)
require.NoError(t, err) require.NoError(t, err)
} }
@@ -247,9 +246,8 @@ func TestRequestReadIgnoreFramesErrors(t *testing.T) {
byts := []byte{0x25} byts := []byte{0x25}
rb := bufio.NewReader(bytes.NewBuffer(byts)) rb := bufio.NewReader(bytes.NewBuffer(byts))
buf := make([]byte, 10)
var req Request var req Request
err := req.ReadIgnoreFrames(rb, buf) err := req.ReadIgnoreFrames(10, rb)
require.EqualError(t, err, "EOF") require.EqualError(t, err, "EOF")
} }

View File

@@ -187,15 +187,11 @@ func (res *Response) Read(rb *bufio.Reader) error {
// ReadIgnoreFrames reads a response and ignores any interleaved frame sent // ReadIgnoreFrames reads a response and ignores any interleaved frame sent
// before the response. // before the response.
func (res *Response) ReadIgnoreFrames(rb *bufio.Reader, buf []byte) error { func (res *Response) ReadIgnoreFrames(maxPayloadSize int, rb *bufio.Reader) error {
buflen := len(buf) var f InterleavedFrame
f := InterleavedFrame{
Payload: buf,
}
for { for {
f.Payload = f.Payload[:buflen] recv, err := ReadInterleavedFrameOrResponse(&f, maxPayloadSize, res, rb)
recv, err := ReadInterleavedFrameOrResponse(&f, res, rb)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -220,9 +220,8 @@ func TestResponseReadIgnoreFrames(t *testing.T) {
"\r\n")...) "\r\n")...)
rb := bufio.NewReader(bytes.NewBuffer(byts)) rb := bufio.NewReader(bytes.NewBuffer(byts))
buf := make([]byte, 10)
var res Response var res Response
err := res.ReadIgnoreFrames(rb, buf) err := res.ReadIgnoreFrames(10, rb)
require.NoError(t, err) require.NoError(t, err)
} }
@@ -230,9 +229,8 @@ func TestResponseReadIgnoreFramesErrors(t *testing.T) {
byts := []byte{0x25} byts := []byte{0x25}
rb := bufio.NewReader(bytes.NewBuffer(byts)) rb := bufio.NewReader(bytes.NewBuffer(byts))
buf := make([]byte, 10)
var res Response var res Response
err := res.ReadIgnoreFrames(rb, buf) err := res.ReadIgnoreFrames(10, rb)
require.EqualError(t, err, "EOF") require.EqualError(t, err, "EOF")
} }

View File

@@ -15,11 +15,6 @@ import (
"github.com/aler9/gortsplib/pkg/liberrors" "github.com/aler9/gortsplib/pkg/liberrors"
) )
const (
serverReadBufferSize = 4096
serverUDPKernelReadBufferSize = 0x80000 // same as gstreamer's rtspsrc
)
func extractPort(address string) (int, error) { func extractPort(address string) (int, error) {
_, tmp, err := net.SplitHostPort(address) _, tmp, err := net.SplitHostPort(address)
if err != nil { if err != nil {
@@ -115,10 +110,6 @@ type Server struct {
// that are particularly relevant when using UDP. // that are particularly relevant when using UDP.
// It defaults to 256. // It defaults to 256.
ReadBufferCount int ReadBufferCount int
// read buffer size.
// This must be touched only when the server reports errors about buffer sizes.
// It defaults to 2048.
ReadBufferSize int
// write buffer count. // write buffer count.
// It allows to queue packets before sending them. // It allows to queue packets before sending them.
// It defaults to 256. // It defaults to 256.
@@ -174,9 +165,6 @@ func (s *Server) Start() error {
if s.ReadBufferCount == 0 { if s.ReadBufferCount == 0 {
s.ReadBufferCount = 256 s.ReadBufferCount = 256
} }
if s.ReadBufferSize == 0 {
s.ReadBufferSize = 2048
}
if s.WriteBufferCount == 0 { if s.WriteBufferCount == 0 {
s.WriteBufferCount = 256 s.WriteBufferCount = 256
} }

View File

@@ -736,8 +736,7 @@ func TestServerPublish(t *testing.T) {
require.Equal(t, testRTCPPacketMarshaled, buf[:n]) require.Equal(t, testRTCPPacketMarshaled, buf[:n])
} else { } else {
var f base.InterleavedFrame var f base.InterleavedFrame
f.Payload = make([]byte, 2048) err := f.Read(2048, br)
err := f.Read(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, f.Channel) require.Equal(t, 1, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload) require.Equal(t, testRTCPPacketMarshaled, f.Payload)
@@ -789,8 +788,7 @@ func TestServerPublish(t *testing.T) {
require.Equal(t, testRTCPPacketMarshaled, buf[:n]) require.Equal(t, testRTCPPacketMarshaled, buf[:n])
} else { } else {
var f base.InterleavedFrame var f base.InterleavedFrame
f.Payload = make([]byte, 2048) err := f.Read(2048, br)
err := f.Read(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, f.Channel) require.Equal(t, 1, f.Channel)
require.Equal(t, testRTCPPacketMarshaled, f.Payload) require.Equal(t, testRTCPPacketMarshaled, f.Payload)
@@ -815,21 +813,12 @@ func TestServerPublish(t *testing.T) {
} }
} }
func TestServerPublishNonStandardFrameSize(t *testing.T) { func TestServerPublishOversizedPacket(t *testing.T) {
packet := rtp.Packet{ oversizedPacketsRTPOut := append([]rtp.Packet(nil), oversizedPacketsRTPOut...)
Header: rtp.Header{ packetRecv := make(chan struct{})
Version: 2,
PayloadType: 97,
CSRC: []uint32{},
},
Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 4096/5),
}
packetMarshaled, _ := packet.Marshal()
frameReceived := make(chan struct{})
s := &Server{ s := &Server{
RTSPAddress: "localhost:8554", RTSPAddress: "localhost:8554",
ReadBufferSize: 4500,
Handler: &testServerHandler{ Handler: &testServerHandler{
onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
return &base.Response{ return &base.Response{
@@ -848,8 +837,12 @@ func TestServerPublishNonStandardFrameSize(t *testing.T) {
}, },
onPacketRTP: func(ctx *ServerHandlerOnPacketRTPCtx) { onPacketRTP: func(ctx *ServerHandlerOnPacketRTPCtx) {
require.Equal(t, 0, ctx.TrackID) require.Equal(t, 0, ctx.TrackID)
require.Equal(t, &packet, ctx.Packet) cmp := oversizedPacketsRTPOut[0]
close(frameReceived) oversizedPacketsRTPOut = oversizedPacketsRTPOut[1:]
require.Equal(t, &cmp, ctx.Packet)
if len(oversizedPacketsRTPOut) == 0 {
close(packetRecv)
}
}, },
}, },
} }
@@ -921,14 +914,15 @@ func TestServerPublishNonStandardFrameSize(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
byts, _ := oversizedPacketRTPIn.Marshal()
base.InterleavedFrame{ base.InterleavedFrame{
Channel: 0, Channel: 0,
Payload: packetMarshaled, Payload: byts,
}.Write(&bb) }.Write(&bb)
_, err = conn.Write(bb.Bytes()) _, err = conn.Write(bb.Bytes())
require.NoError(t, err) require.NoError(t, err)
<-frameReceived <-packetRecv
} }
func TestServerPublishErrorInvalidProtocol(t *testing.T) { func TestServerPublishErrorInvalidProtocol(t *testing.T) {

View File

@@ -484,9 +484,7 @@ func TestServerRead(t *testing.T) {
case "tcp", "tls": case "tcp", "tls":
var f base.InterleavedFrame var f base.InterleavedFrame
err := f.Read(2048, br)
f.Payload = make([]byte, 2048)
err := f.Read(br)
require.NoError(t, err) require.NoError(t, err)
switch f.Channel { switch f.Channel {
@@ -516,8 +514,7 @@ func TestServerRead(t *testing.T) {
var f base.InterleavedFrame var f base.InterleavedFrame
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
f.Payload = make([]byte, 2048) err := f.Read(2048, br)
err := f.Read(br)
require.NoError(t, err) require.NoError(t, err)
switch f.Channel { switch f.Channel {
@@ -763,99 +760,6 @@ func TestServerReadVLCMulticast(t *testing.T) {
require.Equal(t, "224.1.0.0", desc.ConnectionInformation.Address.Address) require.Equal(t, "224.1.0.0", desc.ConnectionInformation.Address.Address)
} }
func TestServerReadNonStandardFrameSize(t *testing.T) {
packet := rtp.Packet{
Header: rtp.Header{
Version: 2,
PayloadType: 97,
CSRC: []uint32{},
},
Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 4096/5),
}
packetMarshaled, _ := packet.Marshal()
track, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil)
require.NoError(t, err)
stream := NewServerStream(Tracks{track})
defer stream.Close()
s := &Server{
Handler: &testServerHandler{
onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, stream, nil
},
onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) {
go func() {
time.Sleep(1 * time.Second)
stream.WritePacketRTP(0, &packet, true)
}()
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
},
RTSPAddress: "localhost:8554",
}
err = s.Start()
require.NoError(t, err)
defer s.Close()
conn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
br := bufio.NewReader(conn)
inTH := &headers.Transport{
Mode: func() *headers.TransportMode {
v := headers.TransportModePlay
return &v
}(),
Delivery: func() *headers.TransportDelivery {
v := headers.TransportDeliveryUnicast
return &v
}(),
Protocol: headers.TransportProtocolTCP,
InterleavedIDs: &[2]int{0, 1},
}
res, err := writeReqReadRes(conn, br, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Transport": inTH.Write(),
},
})
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
var sx headers.Session
err = sx.Read(res.Header["Session"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, br, base.Request{
Method: base.Play,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Session": base.HeaderValue{sx.Session},
},
})
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
var f base.InterleavedFrame
f.Payload = make([]byte, 4500)
err = f.Read(br)
require.NoError(t, err)
require.Equal(t, 0, f.Channel)
require.Equal(t, packetMarshaled, f.Payload)
}
func TestServerReadTCPResponseBeforeFrames(t *testing.T) { func TestServerReadTCPResponseBeforeFrames(t *testing.T) {
writerDone := make(chan struct{}) writerDone := make(chan struct{})
writerTerminate := make(chan struct{}) writerTerminate := make(chan struct{})
@@ -953,8 +857,7 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) {
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
var fr base.InterleavedFrame var fr base.InterleavedFrame
fr.Payload = make([]byte, 2048) err = fr.Read(2048, br)
err = fr.Read(br)
require.NoError(t, err) require.NoError(t, err)
} }
@@ -1701,8 +1604,7 @@ func TestServerReadPartialTracks(t *testing.T) {
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
var f base.InterleavedFrame var f base.InterleavedFrame
f.Payload = make([]byte, 2048) err = f.Read(2048, br)
err = f.Read(br)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 4, f.Channel) require.Equal(t, 4, f.Channel)
require.Equal(t, testRTPPacketMarshaled, f.Payload) require.Equal(t, testRTPPacketMarshaled, f.Payload)

View File

@@ -37,9 +37,8 @@ func writeReqReadRes(conn net.Conn,
} }
func readResIgnoreFrames(br *bufio.Reader) (*base.Response, error) { func readResIgnoreFrames(br *bufio.Reader) (*base.Response, error) {
buf := make([]byte, 2048)
var res base.Response var res base.Response
err := res.ReadIgnoreFrames(br, buf) err := res.ReadIgnoreFrames(2048, br)
return &res, err return &res, err
} }

View File

@@ -6,6 +6,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt"
"net" "net"
"net/url" "net/url"
"strings" "strings"
@@ -15,7 +16,7 @@ import (
"github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/liberrors" "github.com/aler9/gortsplib/pkg/liberrors"
"github.com/aler9/gortsplib/pkg/multibuffer" "github.com/aler9/gortsplib/pkg/rtph264"
) )
func getSessionID(header base.Header) string { func getSessionID(header base.Header) string {
@@ -109,7 +110,7 @@ func (sc *ServerConn) run() {
}) })
} }
sc.br = bufio.NewReaderSize(sc.conn, serverReadBufferSize) sc.br = bufio.NewReaderSize(sc.conn, tcpReadBufferSize)
readRequest := make(chan readReq) readRequest := make(chan readReq)
readErr := make(chan error) readErr := make(chan error)
@@ -218,20 +219,19 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
case <-sc.session.ctx.Done(): case <-sc.session.ctx.Done():
} }
var tcpReadBuffer *multibuffer.MultiBuffer var processFunc func(int, bool, []byte) error
var processFunc func(int, bool, []byte)
if sc.session.state == ServerSessionStatePlay { if sc.session.state == ServerSessionStatePlay {
// when playing, tcpReadBuffer is only used to receive RTCP receiver reports, processFunc = func(trackID int, isRTP bool, payload []byte) error {
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
tcpReadBuffer = multibuffer.New(8, uint64(sc.s.ReadBufferSize))
processFunc = func(trackID int, isRTP bool, payload []byte) {
if !isRTP { if !isRTP {
if len(payload) > udpReadBufferSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
len(payload), udpReadBufferSize)
}
packets, err := rtcp.Unmarshal(payload) packets, err := rtcp.Unmarshal(payload)
if err != nil { if err != nil {
return return err
} }
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok { if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok {
@@ -244,30 +244,100 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
} }
} }
} }
return nil
} }
} else { } else {
tcpReadBuffer = multibuffer.New(uint64(sc.s.ReadBufferCount), uint64(sc.s.ReadBufferSize))
tcpRTPPacketBuffer := newRTPPacketMultiBuffer(uint64(sc.s.ReadBufferCount)) tcpRTPPacketBuffer := newRTPPacketMultiBuffer(uint64(sc.s.ReadBufferCount))
processFunc = func(trackID int, isRTP bool, payload []byte) { processFunc = func(trackID int, isRTP bool, payload []byte) error {
if isRTP { if isRTP {
pkt := tcpRTPPacketBuffer.next() pkt := tcpRTPPacketBuffer.next()
err := pkt.Unmarshal(payload) err := pkt.Unmarshal(payload)
if err != nil { if err != nil {
return return err
} }
sc.session.onPacketRTP(time.Time{}, trackID, pkt) ctx := ServerHandlerOnPacketRTPCtx{
Session: sc.session,
TrackID: trackID,
Packet: pkt,
}
sc.session.processPacketRTP(&ctx)
at := sc.session.announcedTracks[trackID]
if at.h264Decoder != nil {
if at.h264Encoder == nil && len(payload) > udpReadBufferSize {
v1 := pkt.SSRC
v2 := pkt.SequenceNumber
v3 := pkt.Timestamp
at.h264Encoder = &rtph264.Encoder{
PayloadType: pkt.PayloadType,
SSRC: &v1,
InitialSequenceNumber: &v2,
InitialTimestamp: &v3,
}
at.h264Encoder.Init()
}
if at.h264Encoder != nil {
if ctx.H264NALUs != nil {
packets, err := at.h264Encoder.Encode(ctx.H264NALUs, ctx.H264PTS)
if err != nil {
return err
}
for i, pkt := range packets {
if i != len(packets)-1 {
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
Session: sc.session,
TrackID: trackID,
Packet: pkt,
PTSEqualsDTS: false,
})
}
} else {
ctx.Packet = pkt
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ctx)
}
}
}
}
} else {
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ctx)
}
}
} else {
if len(payload) > udpReadBufferSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
len(payload), udpReadBufferSize)
}
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ctx)
}
}
} else { } else {
if len(payload) > udpReadBufferSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)",
len(payload), udpReadBufferSize)
}
packets, err := rtcp.Unmarshal(payload) packets, err := rtcp.Unmarshal(payload)
if err != nil { if err != nil {
return return err
} }
for _, pkt := range packets { for _, pkt := range packets {
sc.session.onPacketRTCP(trackID, pkt) sc.session.onPacketRTCP(trackID, pkt)
} }
} }
return nil
} }
} }
@@ -279,8 +349,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
sc.conn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout)) sc.conn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout))
} }
frame.Payload = tcpReadBuffer.Next() what, err := base.ReadInterleavedFrameOrRequest(&frame, tcpMaxFramePayloadSize, &req, sc.br)
what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br)
if err != nil { if err != nil {
return err return err
} }
@@ -296,7 +365,10 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
// forward frame only if it has been set up // forward frame only if it has been set up
if trackID, ok := sc.session.tcpTracksByChannel[channel]; ok { if trackID, ok := sc.session.tcpTracksByChannel[channel]; ok {
processFunc(trackID, isRTP, frame.Payload) err := processFunc(trackID, isRTP, frame.Payload)
if err != nil {
return err
}
} }
case *base.Request: case *base.Request:

View File

@@ -155,6 +155,7 @@ type ServerSessionAnnouncedTrack struct {
track Track track Track
rtcpReceiver *rtcpreceiver.RTCPReceiver rtcpReceiver *rtcpreceiver.RTCPReceiver
h264Decoder *rtph264.Decoder h264Decoder *rtph264.Decoder
h264Encoder *rtph264.Encoder
} }
// ServerSession is a server-side RTSP session. // ServerSession is a server-side RTSP session.
@@ -1103,6 +1104,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
for _, at := range ss.announcedTracks { for _, at := range ss.announcedTracks {
at.h264Decoder = nil at.h264Decoder = nil
at.h264Encoder = nil
} }
ss.state = ServerSessionStatePreRecord ss.state = ServerSessionStatePreRecord
@@ -1212,63 +1214,24 @@ func (ss *ServerSession) runWriter() {
} }
} }
func (ss *ServerSession) onPacketRTP(now time.Time, trackID int, pkt *rtp.Packet) { func (ss *ServerSession) processPacketRTP(ctx *ServerHandlerOnPacketRTPCtx) {
// remove padding // remove padding
pkt.Header.Padding = false ctx.Packet.Header.Padding = false
pkt.PaddingSize = 0 ctx.Packet.PaddingSize = 0
at := ss.announcedTracks[trackID]
// decode
at := ss.announcedTracks[ctx.TrackID]
if at.h264Decoder != nil { if at.h264Decoder != nil {
nalus, pts, err := at.h264Decoder.DecodeUntilMarker(pkt) nalus, pts, err := at.h264Decoder.DecodeUntilMarker(ctx.Packet)
if err == nil { if err == nil {
ptsEqualsDTS := h264.IDRPresent(nalus) ctx.PTSEqualsDTS = h264.IDRPresent(nalus)
ctx.H264NALUs = append([][]byte(nil), nalus...)
rr := at.rtcpReceiver ctx.H264PTS = pts
if rr != nil {
rr.ProcessPacketRTP(now, pkt, ptsEqualsDTS)
}
if h, ok := ss.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
Session: ss,
TrackID: trackID,
Packet: pkt,
PTSEqualsDTS: ptsEqualsDTS,
H264NALUs: append([][]byte(nil), nalus...),
H264PTS: pts,
})
}
} else { } else {
rr := at.rtcpReceiver ctx.PTSEqualsDTS = false
if rr != nil {
rr.ProcessPacketRTP(now, pkt, false)
}
if h, ok := ss.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
Session: ss,
TrackID: trackID,
Packet: pkt,
PTSEqualsDTS: false,
})
}
} }
return } else {
} ctx.PTSEqualsDTS = false
rr := at.rtcpReceiver
if rr != nil {
rr.ProcessPacketRTP(now, pkt, true)
}
if h, ok := ss.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
Session: ss,
TrackID: trackID,
Packet: pkt,
PTSEqualsDTS: true,
})
} }
} }

View File

@@ -98,7 +98,7 @@ func newServerUDPListener(
p := ipv4.NewPacketConn(tmp) p := ipv4.NewPacketConn(tmp)
err = p.SetMulticastTTL(16) err = p.SetMulticastTTL(multicastTTL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -130,7 +130,7 @@ func newServerUDPListener(
listenIP = tmp.LocalAddr().(*net.UDPAddr).IP listenIP = tmp.LocalAddr().(*net.UDPAddr).IP
} }
err := pc.SetReadBuffer(serverUDPKernelReadBufferSize) err := pc.SetReadBuffer(udpKernelReadBufferSize)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -142,7 +142,7 @@ func newServerUDPListener(
clients: make(map[clientAddr]*clientData), clients: make(map[clientAddr]*clientData),
isRTP: isRTP, isRTP: isRTP,
writeTimeout: s.WriteTimeout, writeTimeout: s.WriteTimeout,
readBuffer: multibuffer.New(uint64(s.ReadBufferCount), uint64(s.ReadBufferSize)), readBuffer: multibuffer.New(uint64(s.ReadBufferCount), uint64(udpReadBufferSize)),
rtpPacketBuffer: newRTPPacketMultiBuffer(uint64(s.ReadBufferCount)), rtpPacketBuffer: newRTPPacketMultiBuffer(uint64(s.ReadBufferCount)),
readerDone: make(chan struct{}), readerDone: make(chan struct{}),
} }
@@ -207,7 +207,16 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) {
now := time.Now() now := time.Now()
atomic.StoreInt64(clientData.ss.udpLastFrameTime, now.Unix()) atomic.StoreInt64(clientData.ss.udpLastFrameTime, now.Unix())
clientData.ss.onPacketRTP(now, clientData.trackID, pkt) ctx := ServerHandlerOnPacketRTPCtx{
Session: clientData.ss,
TrackID: clientData.trackID,
Packet: pkt,
}
clientData.ss.processPacketRTP(&ctx)
clientData.ss.announcedTracks[clientData.trackID].rtcpReceiver.ProcessPacketRTP(now, ctx.Packet, ctx.PTSEqualsDTS)
if h, ok := clientData.ss.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ctx)
}
} }
func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) { func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) {