diff --git a/client.go b/client.go index 04fa20e4..0fcd0660 100644 --- a/client.go +++ b/client.go @@ -28,18 +28,12 @@ import ( "github.com/aler9/gortsplib/pkg/h264" "github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/liberrors" - "github.com/aler9/gortsplib/pkg/multibuffer" "github.com/aler9/gortsplib/pkg/ringbuffer" "github.com/aler9/gortsplib/pkg/rtcpreceiver" "github.com/aler9/gortsplib/pkg/rtcpsender" "github.com/aler9/gortsplib/pkg/rtph264" ) -const ( - clientReadBufferSize = 4096 - clientUDPKernelReadBufferSize = 0x80000 // same size as gstreamer's rtspsrc -) - func isAnyPort(p int) bool { return p == 0 || p == 1 } @@ -62,6 +56,7 @@ type clientTrack struct { rtcpReceiver *rtcpreceiver.RTCPReceiver rtcpSender *rtcpsender.RTCPSender h264Decoder *rtph264.Decoder + h264Encoder *rtph264.Encoder } func (s clientState) String() string { @@ -187,10 +182,6 @@ type Client struct { // that is reading frames. // It defaults to 256. ReadBufferCount int - // read buffer size. - // This must be touched only when the server reports errors about buffer sizes. - // It defaults to 2048. - ReadBufferSize int // write buffer count. // It allows to queue packets before sending them. // It defaults to 8. @@ -291,9 +282,6 @@ func (c *Client) Start(scheme string, host string) error { if c.ReadBufferCount == 0 { c.ReadBufferCount = 256 } - if c.ReadBufferSize == 0 { - c.ReadBufferSize = 2048 - } if c.WriteBufferCount == 0 { c.WriteBufferCount = 256 } @@ -760,14 +748,12 @@ func (c *Client) runReader() { } } } else { - var tcpReadBuffer *multibuffer.MultiBuffer - var processFunc func(int, bool, []byte) + var processFunc func(int, bool, []byte) error if c.state == clientStatePlay { - tcpReadBuffer = multibuffer.New(uint64(c.ReadBufferCount), uint64(c.ReadBufferSize)) tcpRTPPacketBuffer := newRTPPacketMultiBuffer(uint64(c.ReadBufferCount)) - processFunc = func(trackID int, isRTP bool, payload []byte) { + processFunc = func(trackID int, isRTP bool, payload []byte) error { now := time.Now() atomic.StoreInt64(c.tcpLastFrameTime, now.Unix()) @@ -775,38 +761,105 @@ func (c *Client) runReader() { pkt := tcpRTPPacketBuffer.next() err := pkt.Unmarshal(payload) if err != nil { - return + return err } - c.onPacketRTP(trackID, pkt) + ctx := ClientOnPacketRTPCtx{ + TrackID: trackID, + Packet: pkt, + } + c.processPacketRTP(&ctx) + + ct := c.tracks[trackID] + + if ct.h264Decoder != nil { + if ct.h264Encoder == nil && len(payload) > udpReadBufferSize { + v1 := pkt.SSRC + v2 := pkt.SequenceNumber + v3 := pkt.Timestamp + ct.h264Encoder = &rtph264.Encoder{ + PayloadType: pkt.PayloadType, + SSRC: &v1, + InitialSequenceNumber: &v2, + InitialTimestamp: &v3, + } + ct.h264Encoder.Init() + } + + if ct.h264Encoder != nil { + if ctx.H264NALUs != nil { + packets, err := ct.h264Encoder.Encode(ctx.H264NALUs, ctx.H264PTS) + if err != nil { + return err + } + + for i, pkt := range packets { + if i != len(packets)-1 { + c.OnPacketRTP(&ClientOnPacketRTPCtx{ + TrackID: trackID, + Packet: pkt, + PTSEqualsDTS: false, + }) + } else { + ctx.Packet = pkt + c.OnPacketRTP(&ctx) + } + } + } + } else { + c.OnPacketRTP(&ctx) + } + } else { + if len(payload) > udpReadBufferSize { + return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)", + len(payload), udpReadBufferSize) + } + + c.OnPacketRTP(&ctx) + } } else { + if len(payload) > udpReadBufferSize { + return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)", + len(payload), udpReadBufferSize) + } + packets, err := rtcp.Unmarshal(payload) if err != nil { - return + return err } for _, pkt := range packets { - c.onPacketRTCP(trackID, pkt) + c.OnPacketRTCP(&ClientOnPacketRTCPCtx{ + TrackID: trackID, + Packet: pkt, + }) } } + + return nil } } else { - // when recording, tcpReadBuffer is only used to receive RTCP receiver reports, - // that are much smaller than RTP packets and are sent at a fixed interval. - // decrease RAM consumption by allocating less buffers. - tcpReadBuffer = multibuffer.New(8, uint64(c.ReadBufferSize)) - - processFunc = func(trackID int, isRTP bool, payload []byte) { + processFunc = func(trackID int, isRTP bool, payload []byte) error { if !isRTP { + if len(payload) > udpReadBufferSize { + return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)", + len(payload), udpReadBufferSize) + } + packets, err := rtcp.Unmarshal(payload) if err != nil { - return + return err } for _, pkt := range packets { - c.onPacketRTCP(trackID, pkt) + c.OnPacketRTCP(&ClientOnPacketRTCPCtx{ + TrackID: trackID, + Packet: pkt, + }) } } + + return nil } } @@ -814,8 +867,7 @@ func (c *Client) runReader() { var res base.Response for { - frame.Payload = tcpReadBuffer.Next() - what, err := base.ReadInterleavedFrameOrResponse(&frame, &res, c.br) + what, err := base.ReadInterleavedFrameOrResponse(&frame, tcpMaxFramePayloadSize, &res, c.br) if err != nil { return err } @@ -833,7 +885,10 @@ func (c *Client) runReader() { continue } - processFunc(trackID, isRTP, frame.Payload) + err := processFunc(trackID, isRTP, frame.Payload) + if err != nil { + return err + } } } } @@ -874,6 +929,7 @@ func (c *Client) playRecordStop(isClosing bool) { for _, ct := range c.tracks { ct.h264Decoder = nil + ct.h264Encoder = nil } // stop timers @@ -929,7 +985,7 @@ func (c *Client) connOpen() error { return nconn }() - c.br = bufio.NewReaderSize(c.conn, clientReadBufferSize) + c.br = bufio.NewReaderSize(c.conn, tcpReadBufferSize) c.connCloserStart() return nil } @@ -1008,11 +1064,10 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba if allowFrames { // read the response and ignore interleaved frames in between; - // interleaved frames are sent in two scenarios: + // interleaved frames are sent in two cases: // * when the server is v4lrtspserver, before the PLAY response // * when the stream is already playing - buf := make([]byte, c.ReadBufferSize) - err = res.ReadIgnoreFrames(c.br, buf) + err = res.ReadIgnoreFrames(tcpMaxFramePayloadSize, c.br) if err != nil { return nil, err } @@ -1230,7 +1285,7 @@ func (c *Client) doAnnounce(u *base.URL, tracks Tracks) (*base.Response, error) } // in case of ANNOUNCE, the base URL doesn't have a trailing slash. - // (tested with ffmpeg and gstreamer) + // (tested with ffmpeg and GStreamer) baseURL := u.Clone() tracks.setControls() @@ -1847,62 +1902,25 @@ func (c *Client) runWriter() { } } -func (c *Client) onPacketRTP(trackID int, pkt *rtp.Packet) { +func (c *Client) processPacketRTP(ctx *ClientOnPacketRTPCtx) { // remove padding - pkt.Header.Padding = false - pkt.PaddingSize = 0 - - ct := c.tracks[trackID] + ctx.Packet.Header.Padding = false + ctx.Packet.PaddingSize = 0 + // decode + ct := c.tracks[ctx.TrackID] if ct.h264Decoder != nil { - nalus, pts, err := ct.h264Decoder.DecodeUntilMarker(pkt) + nalus, pts, err := ct.h264Decoder.DecodeUntilMarker(ctx.Packet) if err == nil { - ptsEqualsDTS := h264.IDRPresent(nalus) - - rr := ct.rtcpReceiver - if rr != nil { - rr.ProcessPacketRTP(time.Now(), pkt, ptsEqualsDTS) - } - - c.OnPacketRTP(&ClientOnPacketRTPCtx{ - TrackID: trackID, - Packet: pkt, - PTSEqualsDTS: ptsEqualsDTS, - H264NALUs: append([][]byte(nil), nalus...), - H264PTS: pts, - }) + ctx.PTSEqualsDTS = h264.IDRPresent(nalus) + ctx.H264NALUs = append([][]byte(nil), nalus...) + ctx.H264PTS = pts } else { - rr := ct.rtcpReceiver - if rr != nil { - rr.ProcessPacketRTP(time.Now(), pkt, false) - } - - c.OnPacketRTP(&ClientOnPacketRTPCtx{ - TrackID: trackID, - Packet: pkt, - PTSEqualsDTS: false, - }) + ctx.PTSEqualsDTS = false } - return + } else { + ctx.PTSEqualsDTS = true } - - rr := ct.rtcpReceiver - if rr != nil { - rr.ProcessPacketRTP(time.Now(), pkt, true) - } - - c.OnPacketRTP(&ClientOnPacketRTPCtx{ - TrackID: trackID, - Packet: pkt, - PTSEqualsDTS: true, - }) -} - -func (c *Client) onPacketRTCP(trackID int, pkt rtcp.Packet) { - c.OnPacketRTCP(&ClientOnPacketRTCPCtx{ - TrackID: trackID, - Packet: pkt, - }) } // WritePacketRTP writes a RTP packet. diff --git a/client_publish_test.go b/client_publish_test.go index 65e9d436..388da87d 100644 --- a/client_publish_test.go +++ b/client_publish_test.go @@ -182,8 +182,7 @@ func TestClientPublishSerial(t *testing.T) { require.Equal(t, testRTPPacket, pkt) } else { var f base.InterleavedFrame - f.Payload = make([]byte, 2048) - err = f.Read(br) + err = f.Read(1024, br) require.NoError(t, err) require.Equal(t, 0, f.Channel) var pkt rtp.Packet @@ -823,8 +822,7 @@ func TestClientPublishAutomaticProtocol(t *testing.T) { require.NoError(t, err) var f base.InterleavedFrame - f.Payload = make([]byte, 2048) - err = f.Read(br) + err = f.Read(2048, br) require.NoError(t, err) require.Equal(t, 0, f.Channel) var pkt rtp.Packet diff --git a/client_read_test.go b/client_read_test.go index 581c65fe..edf2ee82 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -21,6 +21,22 @@ import ( "github.com/aler9/gortsplib/pkg/headers" ) +func mergeBytes(vals ...[]byte) []byte { + size := 0 + for _, v := range vals { + size += len(v) + } + res := make([]byte, size) + + pos := 0 + for _, v := range vals { + n := copy(res[pos:], v) + pos += n + } + + return res +} + func TestClientReadTracks(t *testing.T) { track1, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil) require.NoError(t, err) @@ -359,8 +375,7 @@ func TestClientRead(t *testing.T) { case "tcp", "tls": var f base.InterleavedFrame - f.Payload = make([]byte, 2048) - err := f.Read(br) + err := f.Read(2048, br) require.NoError(t, err) require.Equal(t, 1, f.Channel) packets, err := rtcp.Unmarshal(f.Payload) @@ -429,15 +444,59 @@ func TestClientRead(t *testing.T) { } } -func TestClientReadNonStandardFrameSize(t *testing.T) { - refRTPPacket := rtp.Packet{ +var oversizedPacketRTPIn = rtp.Packet{ + Header: rtp.Header{ + Version: 2, + PayloadType: 96, + Marker: true, + SequenceNumber: 34572, + }, + Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 4096/5), +} + +var oversizedPacketsRTPOut = []rtp.Packet{ + { Header: rtp.Header{ - Version: 2, - PayloadType: 96, - CSRC: []uint32{}, + Version: 2, + PayloadType: 96, + Marker: false, + SequenceNumber: 34572, }, - Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 4096/5), - } + Payload: mergeBytes( + []byte{0x1c, 0x81, 0x02, 0x03, 0x04, 0x05}, + bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 290), + []byte{0x01, 0x02, 0x03, 0x04}, + ), + }, + { + Header: rtp.Header{ + Version: 2, + PayloadType: 96, + Marker: false, + SequenceNumber: 34573, + }, + Payload: mergeBytes( + []byte{0x1c, 0x01, 0x05}, + bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 291), + []byte{0x01, 0x02}, + ), + }, + { + Header: rtp.Header{ + Version: 2, + PayloadType: 96, + Marker: true, + SequenceNumber: 34574, + }, + Payload: mergeBytes( + []byte{0x1c, 0x41, 0x03, 0x04, 0x05}, + bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 235), + ), + }, +} + +func TestClientReadOversizedPacket(t *testing.T) { + oversizedPacketsRTPOut := append([]rtp.Packet(nil), oversizedPacketsRTPOut...) l, err := net.Listen("tcp", "localhost:8554") require.NoError(t, err) @@ -529,7 +588,7 @@ func TestClientReadNonStandardFrameSize(t *testing.T) { _, err = conn.Write(bb.Bytes()) require.NoError(t, err) - byts, _ := refRTPPacket.Marshal() + byts, _ := oversizedPacketRTPIn.Marshal() base.InterleavedFrame{ Channel: 0, Payload: byts, @@ -541,15 +600,18 @@ func TestClientReadNonStandardFrameSize(t *testing.T) { packetRecv := make(chan struct{}) c := &Client{ - ReadBufferSize: 4500 + 4, Transport: func() *Transport { v := TransportTCP return &v }(), OnPacketRTP: func(ctx *ClientOnPacketRTPCtx) { require.Equal(t, 0, ctx.TrackID) - require.Equal(t, &refRTPPacket, ctx.Packet) - close(packetRecv) + cmp := oversizedPacketsRTPOut[0] + oversizedPacketsRTPOut = oversizedPacketsRTPOut[1:] + require.Equal(t, &cmp, ctx.Packet) + if len(oversizedPacketsRTPOut) == 0 { + close(packetRecv) + } }, } diff --git a/client_test.go b/client_test.go index 263f15be..9c4403bf 100644 --- a/client_test.go +++ b/client_test.go @@ -29,9 +29,8 @@ func readRequest(br *bufio.Reader) (*base.Request, error) { } func readRequestIgnoreFrames(br *bufio.Reader) (*base.Request, error) { - buf := make([]byte, 2048) var req base.Request - err := req.ReadIgnoreFrames(br, buf) + err := req.ReadIgnoreFrames(2048, br) return &req, err } diff --git a/clientudpl.go b/clientudpl.go index 64d26a73..a87939c9 100644 --- a/clientudpl.go +++ b/clientudpl.go @@ -76,7 +76,7 @@ func newClientUDPListener(c *Client, multicast bool, address string) (*clientUDP p := ipv4.NewPacketConn(tmp) - err = p.SetMulticastTTL(16) + err = p.SetMulticastTTL(multicastTTL) if err != nil { return nil, err } @@ -102,7 +102,7 @@ func newClientUDPListener(c *Client, multicast bool, address string) (*clientUDP pc = tmp.(*net.UDPConn) } - err := pc.SetReadBuffer(clientUDPKernelReadBufferSize) + err := pc.SetReadBuffer(udpKernelReadBufferSize) if err != nil { return nil, err } @@ -110,7 +110,7 @@ func newClientUDPListener(c *Client, multicast bool, address string) (*clientUDP return &clientUDPListener{ c: c, pc: pc, - readBuffer: multibuffer.New(uint64(c.ReadBufferCount), uint64(c.ReadBufferSize)), + readBuffer: multibuffer.New(uint64(c.ReadBufferCount), uint64(udpReadBufferSize)), rtpPacketBuffer: newRTPPacketMultiBuffer(uint64(c.ReadBufferCount)), lastPacketTime: func() *int64 { v := int64(0) @@ -182,7 +182,13 @@ func (u *clientUDPListener) processPlayRTP(now time.Time, payload []byte) { return } - u.c.onPacketRTP(u.trackID, pkt) + ctx := ClientOnPacketRTPCtx{ + TrackID: u.trackID, + Packet: pkt, + } + u.c.processPacketRTP(&ctx) + u.c.tracks[u.trackID].rtcpReceiver.ProcessPacketRTP(time.Now(), pkt, ctx.PTSEqualsDTS) + u.c.OnPacketRTP(&ctx) } func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) { @@ -193,7 +199,10 @@ func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) { for _, pkt := range packets { u.c.tracks[u.trackID].rtcpReceiver.ProcessPacketRTCP(now, pkt) - u.c.onPacketRTCP(u.trackID, pkt) + u.c.OnPacketRTCP(&ClientOnPacketRTCPCtx{ + TrackID: u.trackID, + Packet: pkt, + }) } } @@ -204,7 +213,10 @@ func (u *clientUDPListener) processRecordRTCP(now time.Time, payload []byte) { } for _, pkt := range packets { - u.c.onPacketRTCP(u.trackID, pkt) + u.c.OnPacketRTCP(&ClientOnPacketRTCPCtx{ + TrackID: u.trackID, + Packet: pkt, + }) } } diff --git a/constants.go b/constants.go new file mode 100644 index 00000000..7e33c896 --- /dev/null +++ b/constants.go @@ -0,0 +1,9 @@ +package gortsplib + +const ( + tcpReadBufferSize = 4096 + tcpMaxFramePayloadSize = 60 * 1024 * 1024 + udpKernelReadBufferSize = 0x80000 // same size as GStreamer's rtspsrc + udpReadBufferSize = 1472 // 1500 (UDP MTU) - 20 (IP header) - 8 (UDP header) + multicastTTL = 16 +) diff --git a/examples/client-publish-options/main.go b/examples/client-publish-options/main.go index 0fa025b9..ba5712b3 100644 --- a/examples/client-publish-options/main.go +++ b/examples/client-publish-options/main.go @@ -23,7 +23,7 @@ func main() { } defer pc.Close() - log.Println("Waiting for a rtp/h264 stream on port 9000 - you can send one with gstreamer:\n" + + log.Println("Waiting for a RTP/H264 stream on port 9000 - you can send one with GStreamer:\n" + "gst-launch-1.0 filesrc location=video.mp4 ! qtdemux ! video/x-h264" + " ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000") diff --git a/examples/client-publish-pause/main.go b/examples/client-publish-pause/main.go index e9d49807..8051523b 100644 --- a/examples/client-publish-pause/main.go +++ b/examples/client-publish-pause/main.go @@ -24,7 +24,7 @@ func main() { } defer pc.Close() - log.Println("Waiting for a rtp/h264 stream on port 9000 - you can send one with gstreamer:\n" + + log.Println("Waiting for a RTP/H264 stream on port 9000 - you can send one with GStreamer:\n" + "gst-launch-1.0 filesrc location=video.mp4 ! qtdemux ! video/x-h264" + " ! h264parse config-interval=1 ! rtph264pay ! udpsink host=127.0.0.1 port=9000") diff --git a/pkg/base/interleavedframe.go b/pkg/base/interleavedframe.go index 55bc175a..4cecb08a 100644 --- a/pkg/base/interleavedframe.go +++ b/pkg/base/interleavedframe.go @@ -13,7 +13,12 @@ const ( ) // ReadInterleavedFrameOrRequest reads an InterleavedFrame or a Response. -func ReadInterleavedFrameOrRequest(frame *InterleavedFrame, req *Request, br *bufio.Reader) (interface{}, error) { +func ReadInterleavedFrameOrRequest( + frame *InterleavedFrame, + maxPayloadSize int, + req *Request, + br *bufio.Reader, +) (interface{}, error) { b, err := br.ReadByte() if err != nil { return nil, err @@ -21,7 +26,7 @@ func ReadInterleavedFrameOrRequest(frame *InterleavedFrame, req *Request, br *bu br.UnreadByte() if b == interleavedFrameMagicByte { - err := frame.Read(br) + err := frame.Read(maxPayloadSize, br) if err != nil { return nil, err } @@ -36,7 +41,12 @@ func ReadInterleavedFrameOrRequest(frame *InterleavedFrame, req *Request, br *bu } // ReadInterleavedFrameOrResponse reads an InterleavedFrame or a Response. -func ReadInterleavedFrameOrResponse(frame *InterleavedFrame, res *Response, br *bufio.Reader) (interface{}, error) { +func ReadInterleavedFrameOrResponse( + frame *InterleavedFrame, + maxPayloadSize int, + res *Response, + br *bufio.Reader, +) (interface{}, error) { b, err := br.ReadByte() if err != nil { return nil, err @@ -44,7 +54,7 @@ func ReadInterleavedFrameOrResponse(frame *InterleavedFrame, res *Response, br * br.UnreadByte() if b == interleavedFrameMagicByte { - err := frame.Read(br) + err := frame.Read(maxPayloadSize, br) if err != nil { return nil, err } @@ -61,15 +71,15 @@ func ReadInterleavedFrameOrResponse(frame *InterleavedFrame, res *Response, br * // InterleavedFrame is an interleaved frame, and allows to transfer binary data // within RTSP/TCP connections. It is used to send and receive RTP and RTCP packets with TCP. type InterleavedFrame struct { - // channel id + // channel ID Channel int - // frame payload + // payload Payload []byte } // Read reads an interleaved frame. -func (f *InterleavedFrame) Read(br *bufio.Reader) error { +func (f *InterleavedFrame) Read(maxPayloadSize int, br *bufio.Reader) error { var header [4]byte _, err := io.ReadFull(br, header[:]) if err != nil { @@ -80,14 +90,14 @@ func (f *InterleavedFrame) Read(br *bufio.Reader) error { return fmt.Errorf("invalid magic byte (0x%.2x)", header[0]) } - framelen := int(binary.BigEndian.Uint16(header[2:])) - if framelen > len(f.Payload) { - return fmt.Errorf("payload size greater than maximum allowed (%d vs %d)", - framelen, len(f.Payload)) + payloadLen := int(binary.BigEndian.Uint16(header[2:])) + if payloadLen > maxPayloadSize { + return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)", + payloadLen, maxPayloadSize) } f.Channel = int(header[1]) - f.Payload = f.Payload[:framelen] + f.Payload = make([]byte, payloadLen) _, err = io.ReadFull(br, f.Payload) if err != nil { diff --git a/pkg/base/interleavedframe_test.go b/pkg/base/interleavedframe_test.go index ed887eec..047aa572 100644 --- a/pkg/base/interleavedframe_test.go +++ b/pkg/base/interleavedframe_test.go @@ -37,8 +37,7 @@ func TestInterleavedFrameRead(t *testing.T) { for _, ca := range casesInterleavedFrame { t.Run(ca.name, func(t *testing.T) { - f.Payload = make([]byte, 1024) - err := f.Read(bufio.NewReader(bytes.NewBuffer(ca.enc))) + err := f.Read(1024, bufio.NewReader(bytes.NewBuffer(ca.enc))) require.NoError(t, err) require.Equal(t, ca.dec, f) }) @@ -64,7 +63,7 @@ func TestInterleavedFrameReadErrors(t *testing.T) { { "payload size too big", []byte{0x24, 0x00, 0x00, 0x08}, - "payload size greater than maximum allowed (8 vs 5)", + "payload size (8) greater than maximum allowed (5)", }, { "payload invalid", @@ -74,8 +73,7 @@ func TestInterleavedFrameReadErrors(t *testing.T) { } { t.Run(ca.name, func(t *testing.T) { var f InterleavedFrame - f.Payload = make([]byte, 5) - err := f.Read(bufio.NewReader(bytes.NewBuffer(ca.byts))) + err := f.Read(5, bufio.NewReader(bytes.NewBuffer(ca.byts))) require.EqualError(t, err, ca.err) }) } @@ -99,15 +97,14 @@ func TestReadInterleavedFrameOrRequest(t *testing.T) { byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...) var f InterleavedFrame - f.Payload = make([]byte, 10) var req Request br := bufio.NewReader(bytes.NewBuffer(byts)) - out, err := ReadInterleavedFrameOrRequest(&f, &req, br) + out, err := ReadInterleavedFrameOrRequest(&f, 10, &req, br) require.NoError(t, err) require.Equal(t, &req, out) - out, err = ReadInterleavedFrameOrRequest(&f, &req, br) + out, err = ReadInterleavedFrameOrRequest(&f, 10, &req, br) require.NoError(t, err) require.Equal(t, &f, out) } @@ -136,11 +133,9 @@ func TestReadInterleavedFrameOrRequestErrors(t *testing.T) { } { t.Run(ca.name, func(t *testing.T) { var f InterleavedFrame - f.Payload = make([]byte, 10) var req Request br := bufio.NewReader(bytes.NewBuffer(ca.byts)) - - _, err := ReadInterleavedFrameOrRequest(&f, &req, br) + _, err := ReadInterleavedFrameOrRequest(&f, 10, &req, br) require.EqualError(t, err, ca.err) }) } @@ -154,15 +149,13 @@ func TestReadInterleavedFrameOrResponse(t *testing.T) { byts = append(byts, []byte{0x24, 0x6, 0x0, 0x4, 0x1, 0x2, 0x3, 0x4}...) var f InterleavedFrame - f.Payload = make([]byte, 10) var res Response br := bufio.NewReader(bytes.NewBuffer(byts)) - - out, err := ReadInterleavedFrameOrResponse(&f, &res, br) + out, err := ReadInterleavedFrameOrResponse(&f, 10, &res, br) require.NoError(t, err) require.Equal(t, &res, out) - out, err = ReadInterleavedFrameOrResponse(&f, &res, br) + out, err = ReadInterleavedFrameOrResponse(&f, 10, &res, br) require.NoError(t, err) require.Equal(t, &f, out) } @@ -191,11 +184,9 @@ func TestReadInterleavedFrameOrResponseErrors(t *testing.T) { } { t.Run(ca.name, func(t *testing.T) { var f InterleavedFrame - f.Payload = make([]byte, 10) var res Response br := bufio.NewReader(bytes.NewBuffer(ca.byts)) - - _, err := ReadInterleavedFrameOrResponse(&f, &res, br) + _, err := ReadInterleavedFrameOrResponse(&f, 10, &res, br) require.EqualError(t, err, ca.err) }) } diff --git a/pkg/base/request.go b/pkg/base/request.go index a261eee0..926c1721 100644 --- a/pkg/base/request.go +++ b/pkg/base/request.go @@ -101,15 +101,11 @@ func (req *Request) Read(rb *bufio.Reader) error { // ReadIgnoreFrames reads a request and ignores any interleaved frame sent // before the request. -func (req *Request) ReadIgnoreFrames(rb *bufio.Reader, buf []byte) error { - buflen := len(buf) - f := InterleavedFrame{ - Payload: buf, - } +func (req *Request) ReadIgnoreFrames(maxPayloadSize int, rb *bufio.Reader) error { + var f InterleavedFrame for { - f.Payload = f.Payload[:buflen] - recv, err := ReadInterleavedFrameOrRequest(&f, req, rb) + recv, err := ReadInterleavedFrameOrRequest(&f, maxPayloadSize, req, rb) if err != nil { return err } diff --git a/pkg/base/request_test.go b/pkg/base/request_test.go index 137bb483..673a9777 100644 --- a/pkg/base/request_test.go +++ b/pkg/base/request_test.go @@ -237,9 +237,8 @@ func TestRequestReadIgnoreFrames(t *testing.T) { "\r\n")...) rb := bufio.NewReader(bytes.NewBuffer(byts)) - buf := make([]byte, 10) var req Request - err := req.ReadIgnoreFrames(rb, buf) + err := req.ReadIgnoreFrames(10, rb) require.NoError(t, err) } @@ -247,9 +246,8 @@ func TestRequestReadIgnoreFramesErrors(t *testing.T) { byts := []byte{0x25} rb := bufio.NewReader(bytes.NewBuffer(byts)) - buf := make([]byte, 10) var req Request - err := req.ReadIgnoreFrames(rb, buf) + err := req.ReadIgnoreFrames(10, rb) require.EqualError(t, err, "EOF") } diff --git a/pkg/base/response.go b/pkg/base/response.go index 1bc19c1a..9884705c 100644 --- a/pkg/base/response.go +++ b/pkg/base/response.go @@ -187,15 +187,11 @@ func (res *Response) Read(rb *bufio.Reader) error { // ReadIgnoreFrames reads a response and ignores any interleaved frame sent // before the response. -func (res *Response) ReadIgnoreFrames(rb *bufio.Reader, buf []byte) error { - buflen := len(buf) - f := InterleavedFrame{ - Payload: buf, - } +func (res *Response) ReadIgnoreFrames(maxPayloadSize int, rb *bufio.Reader) error { + var f InterleavedFrame for { - f.Payload = f.Payload[:buflen] - recv, err := ReadInterleavedFrameOrResponse(&f, res, rb) + recv, err := ReadInterleavedFrameOrResponse(&f, maxPayloadSize, res, rb) if err != nil { return err } diff --git a/pkg/base/response_test.go b/pkg/base/response_test.go index 0aeb69f4..5926b2de 100644 --- a/pkg/base/response_test.go +++ b/pkg/base/response_test.go @@ -220,9 +220,8 @@ func TestResponseReadIgnoreFrames(t *testing.T) { "\r\n")...) rb := bufio.NewReader(bytes.NewBuffer(byts)) - buf := make([]byte, 10) var res Response - err := res.ReadIgnoreFrames(rb, buf) + err := res.ReadIgnoreFrames(10, rb) require.NoError(t, err) } @@ -230,9 +229,8 @@ func TestResponseReadIgnoreFramesErrors(t *testing.T) { byts := []byte{0x25} rb := bufio.NewReader(bytes.NewBuffer(byts)) - buf := make([]byte, 10) var res Response - err := res.ReadIgnoreFrames(rb, buf) + err := res.ReadIgnoreFrames(10, rb) require.EqualError(t, err, "EOF") } diff --git a/server.go b/server.go index f06cf750..7a7316e0 100644 --- a/server.go +++ b/server.go @@ -15,11 +15,6 @@ import ( "github.com/aler9/gortsplib/pkg/liberrors" ) -const ( - serverReadBufferSize = 4096 - serverUDPKernelReadBufferSize = 0x80000 // same as gstreamer's rtspsrc -) - func extractPort(address string) (int, error) { _, tmp, err := net.SplitHostPort(address) if err != nil { @@ -115,10 +110,6 @@ type Server struct { // that are particularly relevant when using UDP. // It defaults to 256. ReadBufferCount int - // read buffer size. - // This must be touched only when the server reports errors about buffer sizes. - // It defaults to 2048. - ReadBufferSize int // write buffer count. // It allows to queue packets before sending them. // It defaults to 256. @@ -174,9 +165,6 @@ func (s *Server) Start() error { if s.ReadBufferCount == 0 { s.ReadBufferCount = 256 } - if s.ReadBufferSize == 0 { - s.ReadBufferSize = 2048 - } if s.WriteBufferCount == 0 { s.WriteBufferCount = 256 } diff --git a/server_publish_test.go b/server_publish_test.go index d388c220..d7957b80 100644 --- a/server_publish_test.go +++ b/server_publish_test.go @@ -736,8 +736,7 @@ func TestServerPublish(t *testing.T) { require.Equal(t, testRTCPPacketMarshaled, buf[:n]) } else { var f base.InterleavedFrame - f.Payload = make([]byte, 2048) - err := f.Read(br) + err := f.Read(2048, br) require.NoError(t, err) require.Equal(t, 1, f.Channel) require.Equal(t, testRTCPPacketMarshaled, f.Payload) @@ -789,8 +788,7 @@ func TestServerPublish(t *testing.T) { require.Equal(t, testRTCPPacketMarshaled, buf[:n]) } else { var f base.InterleavedFrame - f.Payload = make([]byte, 2048) - err := f.Read(br) + err := f.Read(2048, br) require.NoError(t, err) require.Equal(t, 1, f.Channel) require.Equal(t, testRTCPPacketMarshaled, f.Payload) @@ -815,21 +813,12 @@ func TestServerPublish(t *testing.T) { } } -func TestServerPublishNonStandardFrameSize(t *testing.T) { - packet := rtp.Packet{ - Header: rtp.Header{ - Version: 2, - PayloadType: 97, - CSRC: []uint32{}, - }, - Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 4096/5), - } - packetMarshaled, _ := packet.Marshal() - frameReceived := make(chan struct{}) +func TestServerPublishOversizedPacket(t *testing.T) { + oversizedPacketsRTPOut := append([]rtp.Packet(nil), oversizedPacketsRTPOut...) + packetRecv := make(chan struct{}) s := &Server{ - RTSPAddress: "localhost:8554", - ReadBufferSize: 4500, + RTSPAddress: "localhost:8554", Handler: &testServerHandler{ onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { return &base.Response{ @@ -848,8 +837,12 @@ func TestServerPublishNonStandardFrameSize(t *testing.T) { }, onPacketRTP: func(ctx *ServerHandlerOnPacketRTPCtx) { require.Equal(t, 0, ctx.TrackID) - require.Equal(t, &packet, ctx.Packet) - close(frameReceived) + cmp := oversizedPacketsRTPOut[0] + oversizedPacketsRTPOut = oversizedPacketsRTPOut[1:] + require.Equal(t, &cmp, ctx.Packet) + if len(oversizedPacketsRTPOut) == 0 { + close(packetRecv) + } }, }, } @@ -921,14 +914,15 @@ func TestServerPublishNonStandardFrameSize(t *testing.T) { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) + byts, _ := oversizedPacketRTPIn.Marshal() base.InterleavedFrame{ Channel: 0, - Payload: packetMarshaled, + Payload: byts, }.Write(&bb) _, err = conn.Write(bb.Bytes()) require.NoError(t, err) - <-frameReceived + <-packetRecv } func TestServerPublishErrorInvalidProtocol(t *testing.T) { diff --git a/server_read_test.go b/server_read_test.go index 2904580f..18bcb93b 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -484,9 +484,7 @@ func TestServerRead(t *testing.T) { case "tcp", "tls": var f base.InterleavedFrame - - f.Payload = make([]byte, 2048) - err := f.Read(br) + err := f.Read(2048, br) require.NoError(t, err) switch f.Channel { @@ -516,8 +514,7 @@ func TestServerRead(t *testing.T) { var f base.InterleavedFrame for i := 0; i < 2; i++ { - f.Payload = make([]byte, 2048) - err := f.Read(br) + err := f.Read(2048, br) require.NoError(t, err) switch f.Channel { @@ -763,99 +760,6 @@ func TestServerReadVLCMulticast(t *testing.T) { require.Equal(t, "224.1.0.0", desc.ConnectionInformation.Address.Address) } -func TestServerReadNonStandardFrameSize(t *testing.T) { - packet := rtp.Packet{ - Header: rtp.Header{ - Version: 2, - PayloadType: 97, - CSRC: []uint32{}, - }, - Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 4096/5), - } - packetMarshaled, _ := packet.Marshal() - - track, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil) - require.NoError(t, err) - - stream := NewServerStream(Tracks{track}) - defer stream.Close() - - s := &Server{ - Handler: &testServerHandler{ - onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, stream, nil - }, - onPlay: func(ctx *ServerHandlerOnPlayCtx) (*base.Response, error) { - go func() { - time.Sleep(1 * time.Second) - stream.WritePacketRTP(0, &packet, true) - }() - - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - }, - }, - RTSPAddress: "localhost:8554", - } - - err = s.Start() - require.NoError(t, err) - defer s.Close() - - conn, err := net.Dial("tcp", "localhost:8554") - require.NoError(t, err) - br := bufio.NewReader(conn) - - inTH := &headers.Transport{ - Mode: func() *headers.TransportMode { - v := headers.TransportModePlay - return &v - }(), - Delivery: func() *headers.TransportDelivery { - v := headers.TransportDeliveryUnicast - return &v - }(), - Protocol: headers.TransportProtocolTCP, - InterleavedIDs: &[2]int{0, 1}, - } - - res, err := writeReqReadRes(conn, br, base.Request{ - Method: base.Setup, - URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), - Header: base.Header{ - "CSeq": base.HeaderValue{"1"}, - "Transport": inTH.Write(), - }, - }) - require.NoError(t, err) - require.Equal(t, base.StatusOK, res.StatusCode) - - var sx headers.Session - err = sx.Read(res.Header["Session"]) - require.NoError(t, err) - - res, err = writeReqReadRes(conn, br, base.Request{ - Method: base.Play, - URL: mustParseURL("rtsp://localhost:8554/teststream"), - Header: base.Header{ - "CSeq": base.HeaderValue{"2"}, - "Session": base.HeaderValue{sx.Session}, - }, - }) - require.NoError(t, err) - require.Equal(t, base.StatusOK, res.StatusCode) - - var f base.InterleavedFrame - f.Payload = make([]byte, 4500) - err = f.Read(br) - require.NoError(t, err) - require.Equal(t, 0, f.Channel) - require.Equal(t, packetMarshaled, f.Payload) -} - func TestServerReadTCPResponseBeforeFrames(t *testing.T) { writerDone := make(chan struct{}) writerTerminate := make(chan struct{}) @@ -953,8 +857,7 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) { require.Equal(t, base.StatusOK, res.StatusCode) var fr base.InterleavedFrame - fr.Payload = make([]byte, 2048) - err = fr.Read(br) + err = fr.Read(2048, br) require.NoError(t, err) } @@ -1701,8 +1604,7 @@ func TestServerReadPartialTracks(t *testing.T) { require.Equal(t, base.StatusOK, res.StatusCode) var f base.InterleavedFrame - f.Payload = make([]byte, 2048) - err = f.Read(br) + err = f.Read(2048, br) require.NoError(t, err) require.Equal(t, 4, f.Channel) require.Equal(t, testRTPPacketMarshaled, f.Payload) diff --git a/server_test.go b/server_test.go index 868a3c02..5ab9a157 100644 --- a/server_test.go +++ b/server_test.go @@ -37,9 +37,8 @@ func writeReqReadRes(conn net.Conn, } func readResIgnoreFrames(br *bufio.Reader) (*base.Response, error) { - buf := make([]byte, 2048) var res base.Response - err := res.ReadIgnoreFrames(br, buf) + err := res.ReadIgnoreFrames(2048, br) return &res, err } diff --git a/serverconn.go b/serverconn.go index d1bded4b..7b503ab3 100644 --- a/serverconn.go +++ b/serverconn.go @@ -6,6 +6,7 @@ import ( "context" "crypto/tls" "errors" + "fmt" "net" "net/url" "strings" @@ -15,7 +16,7 @@ import ( "github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/liberrors" - "github.com/aler9/gortsplib/pkg/multibuffer" + "github.com/aler9/gortsplib/pkg/rtph264" ) func getSessionID(header base.Header) string { @@ -109,7 +110,7 @@ func (sc *ServerConn) run() { }) } - sc.br = bufio.NewReaderSize(sc.conn, serverReadBufferSize) + sc.br = bufio.NewReaderSize(sc.conn, tcpReadBufferSize) readRequest := make(chan readReq) readErr := make(chan error) @@ -218,20 +219,19 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { case <-sc.session.ctx.Done(): } - var tcpReadBuffer *multibuffer.MultiBuffer - var processFunc func(int, bool, []byte) + var processFunc func(int, bool, []byte) error if sc.session.state == ServerSessionStatePlay { - // when playing, tcpReadBuffer is only used to receive RTCP receiver reports, - // that are much smaller than RTP packets and are sent at a fixed interval. - // decrease RAM consumption by allocating less buffers. - tcpReadBuffer = multibuffer.New(8, uint64(sc.s.ReadBufferSize)) - - processFunc = func(trackID int, isRTP bool, payload []byte) { + processFunc = func(trackID int, isRTP bool, payload []byte) error { if !isRTP { + if len(payload) > udpReadBufferSize { + return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)", + len(payload), udpReadBufferSize) + } + packets, err := rtcp.Unmarshal(payload) if err != nil { - return + return err } if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok { @@ -244,30 +244,100 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { } } } + + return nil } } else { - tcpReadBuffer = multibuffer.New(uint64(sc.s.ReadBufferCount), uint64(sc.s.ReadBufferSize)) tcpRTPPacketBuffer := newRTPPacketMultiBuffer(uint64(sc.s.ReadBufferCount)) - processFunc = func(trackID int, isRTP bool, payload []byte) { + processFunc = func(trackID int, isRTP bool, payload []byte) error { if isRTP { pkt := tcpRTPPacketBuffer.next() err := pkt.Unmarshal(payload) if err != nil { - return + return err } - sc.session.onPacketRTP(time.Time{}, trackID, pkt) + ctx := ServerHandlerOnPacketRTPCtx{ + Session: sc.session, + TrackID: trackID, + Packet: pkt, + } + sc.session.processPacketRTP(&ctx) + + at := sc.session.announcedTracks[trackID] + + if at.h264Decoder != nil { + if at.h264Encoder == nil && len(payload) > udpReadBufferSize { + v1 := pkt.SSRC + v2 := pkt.SequenceNumber + v3 := pkt.Timestamp + at.h264Encoder = &rtph264.Encoder{ + PayloadType: pkt.PayloadType, + SSRC: &v1, + InitialSequenceNumber: &v2, + InitialTimestamp: &v3, + } + at.h264Encoder.Init() + } + + if at.h264Encoder != nil { + if ctx.H264NALUs != nil { + packets, err := at.h264Encoder.Encode(ctx.H264NALUs, ctx.H264PTS) + if err != nil { + return err + } + + for i, pkt := range packets { + if i != len(packets)-1 { + if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok { + h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ + Session: sc.session, + TrackID: trackID, + Packet: pkt, + PTSEqualsDTS: false, + }) + } + } else { + ctx.Packet = pkt + if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok { + h.OnPacketRTP(&ctx) + } + } + } + } + } else { + if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok { + h.OnPacketRTP(&ctx) + } + } + } else { + if len(payload) > udpReadBufferSize { + return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)", + len(payload), udpReadBufferSize) + } + + if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok { + h.OnPacketRTP(&ctx) + } + } } else { + if len(payload) > udpReadBufferSize { + return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)", + len(payload), udpReadBufferSize) + } + packets, err := rtcp.Unmarshal(payload) if err != nil { - return + return err } for _, pkt := range packets { sc.session.onPacketRTCP(trackID, pkt) } } + + return nil } } @@ -279,8 +349,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { sc.conn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout)) } - frame.Payload = tcpReadBuffer.Next() - what, err := base.ReadInterleavedFrameOrRequest(&frame, &req, sc.br) + what, err := base.ReadInterleavedFrameOrRequest(&frame, tcpMaxFramePayloadSize, &req, sc.br) if err != nil { return err } @@ -296,7 +365,10 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { // forward frame only if it has been set up if trackID, ok := sc.session.tcpTracksByChannel[channel]; ok { - processFunc(trackID, isRTP, frame.Payload) + err := processFunc(trackID, isRTP, frame.Payload) + if err != nil { + return err + } } case *base.Request: diff --git a/serversession.go b/serversession.go index 544ffb17..4bce0b6f 100644 --- a/serversession.go +++ b/serversession.go @@ -155,6 +155,7 @@ type ServerSessionAnnouncedTrack struct { track Track rtcpReceiver *rtcpreceiver.RTCPReceiver h264Decoder *rtph264.Decoder + h264Encoder *rtph264.Encoder } // ServerSession is a server-side RTSP session. @@ -1103,6 +1104,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base for _, at := range ss.announcedTracks { at.h264Decoder = nil + at.h264Encoder = nil } ss.state = ServerSessionStatePreRecord @@ -1212,63 +1214,24 @@ func (ss *ServerSession) runWriter() { } } -func (ss *ServerSession) onPacketRTP(now time.Time, trackID int, pkt *rtp.Packet) { +func (ss *ServerSession) processPacketRTP(ctx *ServerHandlerOnPacketRTPCtx) { // remove padding - pkt.Header.Padding = false - pkt.PaddingSize = 0 - - at := ss.announcedTracks[trackID] + ctx.Packet.Header.Padding = false + ctx.Packet.PaddingSize = 0 + // decode + at := ss.announcedTracks[ctx.TrackID] if at.h264Decoder != nil { - nalus, pts, err := at.h264Decoder.DecodeUntilMarker(pkt) + nalus, pts, err := at.h264Decoder.DecodeUntilMarker(ctx.Packet) if err == nil { - ptsEqualsDTS := h264.IDRPresent(nalus) - - rr := at.rtcpReceiver - if rr != nil { - rr.ProcessPacketRTP(now, pkt, ptsEqualsDTS) - } - - if h, ok := ss.s.Handler.(ServerHandlerOnPacketRTP); ok { - h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ - Session: ss, - TrackID: trackID, - Packet: pkt, - PTSEqualsDTS: ptsEqualsDTS, - H264NALUs: append([][]byte(nil), nalus...), - H264PTS: pts, - }) - } + ctx.PTSEqualsDTS = h264.IDRPresent(nalus) + ctx.H264NALUs = append([][]byte(nil), nalus...) + ctx.H264PTS = pts } else { - rr := at.rtcpReceiver - if rr != nil { - rr.ProcessPacketRTP(now, pkt, false) - } - - if h, ok := ss.s.Handler.(ServerHandlerOnPacketRTP); ok { - h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ - Session: ss, - TrackID: trackID, - Packet: pkt, - PTSEqualsDTS: false, - }) - } + ctx.PTSEqualsDTS = false } - return - } - - rr := at.rtcpReceiver - if rr != nil { - rr.ProcessPacketRTP(now, pkt, true) - } - - if h, ok := ss.s.Handler.(ServerHandlerOnPacketRTP); ok { - h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ - Session: ss, - TrackID: trackID, - Packet: pkt, - PTSEqualsDTS: true, - }) + } else { + ctx.PTSEqualsDTS = false } } diff --git a/serverudpl.go b/serverudpl.go index 7727526e..e120e2c8 100644 --- a/serverudpl.go +++ b/serverudpl.go @@ -98,7 +98,7 @@ func newServerUDPListener( p := ipv4.NewPacketConn(tmp) - err = p.SetMulticastTTL(16) + err = p.SetMulticastTTL(multicastTTL) if err != nil { return nil, err } @@ -130,7 +130,7 @@ func newServerUDPListener( listenIP = tmp.LocalAddr().(*net.UDPAddr).IP } - err := pc.SetReadBuffer(serverUDPKernelReadBufferSize) + err := pc.SetReadBuffer(udpKernelReadBufferSize) if err != nil { return nil, err } @@ -142,7 +142,7 @@ func newServerUDPListener( clients: make(map[clientAddr]*clientData), isRTP: isRTP, writeTimeout: s.WriteTimeout, - readBuffer: multibuffer.New(uint64(s.ReadBufferCount), uint64(s.ReadBufferSize)), + readBuffer: multibuffer.New(uint64(s.ReadBufferCount), uint64(udpReadBufferSize)), rtpPacketBuffer: newRTPPacketMultiBuffer(uint64(s.ReadBufferCount)), readerDone: make(chan struct{}), } @@ -207,7 +207,16 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) { now := time.Now() atomic.StoreInt64(clientData.ss.udpLastFrameTime, now.Unix()) - clientData.ss.onPacketRTP(now, clientData.trackID, pkt) + ctx := ServerHandlerOnPacketRTPCtx{ + Session: clientData.ss, + TrackID: clientData.trackID, + Packet: pkt, + } + clientData.ss.processPacketRTP(&ctx) + clientData.ss.announcedTracks[clientData.trackID].rtcpReceiver.ProcessPacketRTP(now, ctx.Packet, ctx.PTSEqualsDTS) + if h, ok := clientData.ss.s.Handler.(ServerHandlerOnPacketRTP); ok { + h.OnPacketRTP(&ctx) + } } func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) {