From b1f72f9392045b79f7fc044bd1fcd39455d3b6df Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Mon, 31 Oct 2022 15:38:23 +0100 Subject: [PATCH] return error in OnDecodeError when there are oversized UDP packets --- client.go | 4 ++-- client_read_test.go | 31 +++++++++++++++++++++++++------ clientudpl.go | 17 ++++++++++++++++- pkg/rtpcleaner/cleaner.go | 2 +- pkg/rtpcleaner/cleaner_test.go | 2 +- server_publish_test.go | 31 +++++++++++++++++++++++++------ serverconn.go | 4 ++-- serverudpl.go | 22 +++++++++++++++++++++- 8 files changed, 93 insertions(+), 20 deletions(-) diff --git a/client.go b/client.go index 0a1254b3..730e37d9 100644 --- a/client.go +++ b/client.go @@ -820,7 +820,7 @@ func (c *Client) runReader() { } } else { if len(payload) > maxPacketSize { - return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)", + return fmt.Errorf("payload size (%d) is greater than maximum allowed (%d)", len(payload), maxPacketSize) } @@ -846,7 +846,7 @@ func (c *Client) runReader() { processFunc = func(track *clientTrack, isRTP bool, payload []byte) error { if !isRTP { if len(payload) > maxPacketSize { - return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)", + return fmt.Errorf("payload size (%d) is greater than maximum allowed (%d)", len(payload), maxPacketSize) } diff --git a/client_read_test.go b/client_read_test.go index e6feab0b..b340c1a4 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -1,6 +1,7 @@ package gortsplib import ( + "bytes" "crypto/tls" "fmt" "net" @@ -2715,9 +2716,11 @@ func TestClientReadDifferentSource(t *testing.T) { func TestClientReadDecodeErrors(t *testing.T) { for _, ca := range []string{ - "invalid rtp", - "invalid rtcp", + "rtp invalid", + "rtcp invalid", "packets lost", + "rtp too big", + "rtcp too big", } { t.Run(ca, func(t *testing.T) { errorRecv := make(chan struct{}) @@ -2821,13 +2824,13 @@ func TestClientReadDecodeErrors(t *testing.T) { require.NoError(t, err) switch ca { //nolint:dupl - case "invalid rtp": + case "rtp invalid": l1.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ClientPorts[0], }) - case "invalid rtcp": + case "rtcp invalid": l2.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ClientPorts[1], @@ -2853,6 +2856,18 @@ func TestClientReadDecodeErrors(t *testing.T) { IP: net.ParseIP("127.0.0.1"), Port: th.ClientPorts[0], }) + + case "rtp too big": + l1.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: th.ClientPorts[0], + }) + + case "rtcp too big": + l2.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: th.ClientPorts[1], + }) } req, err = conn.ReadRequest() @@ -2873,12 +2888,16 @@ func TestClientReadDecodeErrors(t *testing.T) { }(), OnDecodeError: func(err error) { switch ca { - case "invalid rtp": + case "rtp invalid": require.EqualError(t, err, "RTP header size insufficient: 2 < 4") - case "invalid rtcp": + case "rtcp invalid": require.EqualError(t, err, "rtcp: packet too short") case "packets lost": require.EqualError(t, err, "69 RTP packet(s) lost") + case "rtp too big": + require.EqualError(t, err, "RTP packet is too big to be read with UDP") + case "rtcp too big": + require.EqualError(t, err, "RTCP packet is too big to be read with UDP") } close(errorRecv) }, diff --git a/clientudpl.go b/clientudpl.go index d32d20ac..b27a5b34 100644 --- a/clientudpl.go +++ b/clientudpl.go @@ -171,7 +171,7 @@ func (u *clientUDPListener) runReader(forPlay bool) { } for { - buf := make([]byte, maxPacketSize) + buf := make([]byte, maxPacketSize+1) n, addr, err := u.pc.ReadFrom(buf) if err != nil { return @@ -191,6 +191,11 @@ func (u *clientUDPListener) runReader(forPlay bool) { } func (u *clientUDPListener) processPlayRTP(now time.Time, payload []byte) { + if len(payload) == (maxPacketSize + 1) { + u.c.OnDecodeError(fmt.Errorf("RTP packet is too big to be read with UDP")) + return + } + pkt := u.ct.udpRTPPacketBuffer.next() err := pkt.Unmarshal(payload) if err != nil { @@ -226,6 +231,11 @@ func (u *clientUDPListener) processPlayRTP(now time.Time, payload []byte) { } func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) { + if len(payload) == (maxPacketSize + 1) { + u.c.OnDecodeError(fmt.Errorf("RTCP packet is too big to be read with UDP")) + return + } + packets, err := rtcp.Unmarshal(payload) if err != nil { u.c.OnDecodeError(err) @@ -242,6 +252,11 @@ func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) { } func (u *clientUDPListener) processRecordRTCP(now time.Time, payload []byte) { + if len(payload) == (maxPacketSize + 1) { + u.c.OnDecodeError(fmt.Errorf("RTCP packet is too big to be read with UDP")) + return + } + packets, err := rtcp.Unmarshal(payload) if err != nil { u.c.OnDecodeError(err) diff --git a/pkg/rtpcleaner/cleaner.go b/pkg/rtpcleaner/cleaner.go index 9e032c44..4cf9ae73 100644 --- a/pkg/rtpcleaner/cleaner.go +++ b/pkg/rtpcleaner/cleaner.go @@ -134,7 +134,7 @@ func (p *Cleaner) Process(pkt *rtp.Packet) ([]*Output, error) { } if p.isTCP && pkt.MarshalSize() > maxPacketSize { - return nil, fmt.Errorf("payload size (%d) greater than maximum allowed (%d)", + return nil, fmt.Errorf("payload size (%d) is greater than maximum allowed (%d)", pkt.MarshalSize(), maxPacketSize) } diff --git a/pkg/rtpcleaner/cleaner_test.go b/pkg/rtpcleaner/cleaner_test.go index 249e6903..91eae3cb 100644 --- a/pkg/rtpcleaner/cleaner_test.go +++ b/pkg/rtpcleaner/cleaner_test.go @@ -49,7 +49,7 @@ func TestGenericOversized(t *testing.T) { }, Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 2050/5), }) - require.EqualError(t, err, "payload size (2062) greater than maximum allowed (1472)") + require.EqualError(t, err, "payload size (2062) is greater than maximum allowed (1472)") } func TestH264Oversized(t *testing.T) { diff --git a/server_publish_test.go b/server_publish_test.go index cc0069dc..2abc8557 100644 --- a/server_publish_test.go +++ b/server_publish_test.go @@ -1,6 +1,7 @@ package gortsplib import ( + "bytes" "crypto/tls" "net" "testing" @@ -1477,9 +1478,11 @@ func TestServerPublishUDPChangeConn(t *testing.T) { func TestServerPublishDecodeErrors(t *testing.T) { for _, ca := range []string{ - "invalid rtp", - "invalid rtcp", + "rtp invalid", + "rtcp invalid", "packets lost", + "rtp too big", + "rtcp too big", } { t.Run(ca, func(t *testing.T) { errorRecv := make(chan struct{}) @@ -1503,12 +1506,16 @@ func TestServerPublishDecodeErrors(t *testing.T) { }, onDecodeError: func(ctx *ServerHandlerOnDecodeErrorCtx) { switch ca { - case "invalid rtp": + case "rtp invalid": require.EqualError(t, ctx.Error, "RTP header size insufficient: 2 < 4") - case "invalid rtcp": + case "rtcp invalid": require.EqualError(t, ctx.Error, "rtcp: packet too short") case "packets lost": require.EqualError(t, ctx.Error, "69 RTP packet(s) lost") + case "rtp too big": + require.EqualError(t, ctx.Error, "RTP packet is too big to be read with UDP") + case "rtcp too big": + require.EqualError(t, ctx.Error, "RTCP packet is too big to be read with UDP") } close(errorRecv) }, @@ -1598,13 +1605,13 @@ func TestServerPublishDecodeErrors(t *testing.T) { require.Equal(t, base.StatusOK, res.StatusCode) switch ca { //nolint:dupl - case "invalid rtp": + case "rtp invalid": l1.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: resTH.ServerPorts[0], }) - case "invalid rtcp": + case "rtcp invalid": l2.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: resTH.ServerPorts[1], @@ -1630,6 +1637,18 @@ func TestServerPublishDecodeErrors(t *testing.T) { IP: net.ParseIP("127.0.0.1"), Port: resTH.ServerPorts[0], }) + + case "rtp too big": + l1.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: resTH.ServerPorts[0], + }) + + case "rtcp too big": + l2.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: resTH.ServerPorts[1], + }) } <-errorRecv diff --git a/serverconn.go b/serverconn.go index cefe55fd..6e052adc 100644 --- a/serverconn.go +++ b/serverconn.go @@ -222,7 +222,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { processFunc = func(trackID int, isRTP bool, payload []byte) error { if !isRTP { if len(payload) > maxPacketSize { - return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)", + return fmt.Errorf("payload size (%d) is greater than maximum allowed (%d)", len(payload), maxPacketSize) } @@ -274,7 +274,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { } } else { if len(payload) > maxPacketSize { - return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)", + return fmt.Errorf("payload size (%d) is greater than maximum allowed (%d)", len(payload), maxPacketSize) } diff --git a/serverudpl.go b/serverudpl.go index a577ad43..b2cdff94 100644 --- a/serverudpl.go +++ b/serverudpl.go @@ -169,7 +169,7 @@ func (u *serverUDPListener) runReader() { } for { - buf := make([]byte, maxPacketSize) + buf := make([]byte, maxPacketSize+1) n, addr, err := u.pc.ReadFromUDP(buf) if err != nil { break @@ -192,6 +192,16 @@ func (u *serverUDPListener) runReader() { } func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) { + if len(payload) == (maxPacketSize + 1) { + if h, ok := clientData.ss.s.Handler.(ServerHandlerOnDecodeError); ok { + h.OnDecodeError(&ServerHandlerOnDecodeErrorCtx{ + Session: clientData.ss, + Error: fmt.Errorf("RTP packet is too big to be read with UDP"), + }) + } + return + } + pkt := u.s.udpRTPPacketBuffer.next() err := pkt.Unmarshal(payload) if err != nil { @@ -248,6 +258,16 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) { } func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) { + if len(payload) == (maxPacketSize + 1) { + if h, ok := clientData.ss.s.Handler.(ServerHandlerOnDecodeError); ok { + h.OnDecodeError(&ServerHandlerOnDecodeErrorCtx{ + Session: clientData.ss, + Error: fmt.Errorf("RTCP packet is too big to be read with UDP"), + }) + } + return + } + packets, err := rtcp.Unmarshal(payload) if err != nil { if h, ok := clientData.ss.s.Handler.(ServerHandlerOnDecodeError); ok {