return error in OnDecodeError when there are oversized UDP packets

This commit is contained in:
aler9
2022-10-31 15:38:23 +01:00
parent 5a5f454814
commit b1f72f9392
8 changed files with 93 additions and 20 deletions

View File

@@ -820,7 +820,7 @@ func (c *Client) runReader() {
} }
} else { } else {
if len(payload) > maxPacketSize { if len(payload) > maxPacketSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)", return fmt.Errorf("payload size (%d) is greater than maximum allowed (%d)",
len(payload), maxPacketSize) len(payload), maxPacketSize)
} }
@@ -846,7 +846,7 @@ func (c *Client) runReader() {
processFunc = func(track *clientTrack, isRTP bool, payload []byte) error { processFunc = func(track *clientTrack, isRTP bool, payload []byte) error {
if !isRTP { if !isRTP {
if len(payload) > maxPacketSize { if len(payload) > maxPacketSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)", return fmt.Errorf("payload size (%d) is greater than maximum allowed (%d)",
len(payload), maxPacketSize) len(payload), maxPacketSize)
} }

View File

@@ -1,6 +1,7 @@
package gortsplib package gortsplib
import ( import (
"bytes"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net" "net"
@@ -2715,9 +2716,11 @@ func TestClientReadDifferentSource(t *testing.T) {
func TestClientReadDecodeErrors(t *testing.T) { func TestClientReadDecodeErrors(t *testing.T) {
for _, ca := range []string{ for _, ca := range []string{
"invalid rtp", "rtp invalid",
"invalid rtcp", "rtcp invalid",
"packets lost", "packets lost",
"rtp too big",
"rtcp too big",
} { } {
t.Run(ca, func(t *testing.T) { t.Run(ca, func(t *testing.T) {
errorRecv := make(chan struct{}) errorRecv := make(chan struct{})
@@ -2821,13 +2824,13 @@ func TestClientReadDecodeErrors(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
switch ca { //nolint:dupl switch ca { //nolint:dupl
case "invalid rtp": case "rtp invalid":
l1.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ l1.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[0], Port: th.ClientPorts[0],
}) })
case "invalid rtcp": case "rtcp invalid":
l2.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ l2.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[1], Port: th.ClientPorts[1],
@@ -2853,6 +2856,18 @@ func TestClientReadDecodeErrors(t *testing.T) {
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[0], Port: th.ClientPorts[0],
}) })
case "rtp too big":
l1.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[0],
})
case "rtcp too big":
l2.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[1],
})
} }
req, err = conn.ReadRequest() req, err = conn.ReadRequest()
@@ -2873,12 +2888,16 @@ func TestClientReadDecodeErrors(t *testing.T) {
}(), }(),
OnDecodeError: func(err error) { OnDecodeError: func(err error) {
switch ca { switch ca {
case "invalid rtp": case "rtp invalid":
require.EqualError(t, err, "RTP header size insufficient: 2 < 4") require.EqualError(t, err, "RTP header size insufficient: 2 < 4")
case "invalid rtcp": case "rtcp invalid":
require.EqualError(t, err, "rtcp: packet too short") require.EqualError(t, err, "rtcp: packet too short")
case "packets lost": case "packets lost":
require.EqualError(t, err, "69 RTP packet(s) lost") require.EqualError(t, err, "69 RTP packet(s) lost")
case "rtp too big":
require.EqualError(t, err, "RTP packet is too big to be read with UDP")
case "rtcp too big":
require.EqualError(t, err, "RTCP packet is too big to be read with UDP")
} }
close(errorRecv) close(errorRecv)
}, },

View File

@@ -171,7 +171,7 @@ func (u *clientUDPListener) runReader(forPlay bool) {
} }
for { for {
buf := make([]byte, maxPacketSize) buf := make([]byte, maxPacketSize+1)
n, addr, err := u.pc.ReadFrom(buf) n, addr, err := u.pc.ReadFrom(buf)
if err != nil { if err != nil {
return return
@@ -191,6 +191,11 @@ func (u *clientUDPListener) runReader(forPlay bool) {
} }
func (u *clientUDPListener) processPlayRTP(now time.Time, payload []byte) { func (u *clientUDPListener) processPlayRTP(now time.Time, payload []byte) {
if len(payload) == (maxPacketSize + 1) {
u.c.OnDecodeError(fmt.Errorf("RTP packet is too big to be read with UDP"))
return
}
pkt := u.ct.udpRTPPacketBuffer.next() pkt := u.ct.udpRTPPacketBuffer.next()
err := pkt.Unmarshal(payload) err := pkt.Unmarshal(payload)
if err != nil { if err != nil {
@@ -226,6 +231,11 @@ func (u *clientUDPListener) processPlayRTP(now time.Time, payload []byte) {
} }
func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) { func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) {
if len(payload) == (maxPacketSize + 1) {
u.c.OnDecodeError(fmt.Errorf("RTCP packet is too big to be read with UDP"))
return
}
packets, err := rtcp.Unmarshal(payload) packets, err := rtcp.Unmarshal(payload)
if err != nil { if err != nil {
u.c.OnDecodeError(err) u.c.OnDecodeError(err)
@@ -242,6 +252,11 @@ func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) {
} }
func (u *clientUDPListener) processRecordRTCP(now time.Time, payload []byte) { func (u *clientUDPListener) processRecordRTCP(now time.Time, payload []byte) {
if len(payload) == (maxPacketSize + 1) {
u.c.OnDecodeError(fmt.Errorf("RTCP packet is too big to be read with UDP"))
return
}
packets, err := rtcp.Unmarshal(payload) packets, err := rtcp.Unmarshal(payload)
if err != nil { if err != nil {
u.c.OnDecodeError(err) u.c.OnDecodeError(err)

View File

@@ -134,7 +134,7 @@ func (p *Cleaner) Process(pkt *rtp.Packet) ([]*Output, error) {
} }
if p.isTCP && pkt.MarshalSize() > maxPacketSize { if p.isTCP && pkt.MarshalSize() > maxPacketSize {
return nil, fmt.Errorf("payload size (%d) greater than maximum allowed (%d)", return nil, fmt.Errorf("payload size (%d) is greater than maximum allowed (%d)",
pkt.MarshalSize(), maxPacketSize) pkt.MarshalSize(), maxPacketSize)
} }

View File

@@ -49,7 +49,7 @@ func TestGenericOversized(t *testing.T) {
}, },
Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 2050/5), Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 2050/5),
}) })
require.EqualError(t, err, "payload size (2062) greater than maximum allowed (1472)") require.EqualError(t, err, "payload size (2062) is greater than maximum allowed (1472)")
} }
func TestH264Oversized(t *testing.T) { func TestH264Oversized(t *testing.T) {

View File

@@ -1,6 +1,7 @@
package gortsplib package gortsplib
import ( import (
"bytes"
"crypto/tls" "crypto/tls"
"net" "net"
"testing" "testing"
@@ -1477,9 +1478,11 @@ func TestServerPublishUDPChangeConn(t *testing.T) {
func TestServerPublishDecodeErrors(t *testing.T) { func TestServerPublishDecodeErrors(t *testing.T) {
for _, ca := range []string{ for _, ca := range []string{
"invalid rtp", "rtp invalid",
"invalid rtcp", "rtcp invalid",
"packets lost", "packets lost",
"rtp too big",
"rtcp too big",
} { } {
t.Run(ca, func(t *testing.T) { t.Run(ca, func(t *testing.T) {
errorRecv := make(chan struct{}) errorRecv := make(chan struct{})
@@ -1503,12 +1506,16 @@ func TestServerPublishDecodeErrors(t *testing.T) {
}, },
onDecodeError: func(ctx *ServerHandlerOnDecodeErrorCtx) { onDecodeError: func(ctx *ServerHandlerOnDecodeErrorCtx) {
switch ca { switch ca {
case "invalid rtp": case "rtp invalid":
require.EqualError(t, ctx.Error, "RTP header size insufficient: 2 < 4") require.EqualError(t, ctx.Error, "RTP header size insufficient: 2 < 4")
case "invalid rtcp": case "rtcp invalid":
require.EqualError(t, ctx.Error, "rtcp: packet too short") require.EqualError(t, ctx.Error, "rtcp: packet too short")
case "packets lost": case "packets lost":
require.EqualError(t, ctx.Error, "69 RTP packet(s) lost") require.EqualError(t, ctx.Error, "69 RTP packet(s) lost")
case "rtp too big":
require.EqualError(t, ctx.Error, "RTP packet is too big to be read with UDP")
case "rtcp too big":
require.EqualError(t, ctx.Error, "RTCP packet is too big to be read with UDP")
} }
close(errorRecv) close(errorRecv)
}, },
@@ -1598,13 +1605,13 @@ func TestServerPublishDecodeErrors(t *testing.T) {
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
switch ca { //nolint:dupl switch ca { //nolint:dupl
case "invalid rtp": case "rtp invalid":
l1.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ l1.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: resTH.ServerPorts[0], Port: resTH.ServerPorts[0],
}) })
case "invalid rtcp": case "rtcp invalid":
l2.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ l2.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: resTH.ServerPorts[1], Port: resTH.ServerPorts[1],
@@ -1630,6 +1637,18 @@ func TestServerPublishDecodeErrors(t *testing.T) {
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: resTH.ServerPorts[0], Port: resTH.ServerPorts[0],
}) })
case "rtp too big":
l1.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: resTH.ServerPorts[0],
})
case "rtcp too big":
l2.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: resTH.ServerPorts[1],
})
} }
<-errorRecv <-errorRecv

View File

@@ -222,7 +222,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
processFunc = func(trackID int, isRTP bool, payload []byte) error { processFunc = func(trackID int, isRTP bool, payload []byte) error {
if !isRTP { if !isRTP {
if len(payload) > maxPacketSize { if len(payload) > maxPacketSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)", return fmt.Errorf("payload size (%d) is greater than maximum allowed (%d)",
len(payload), maxPacketSize) len(payload), maxPacketSize)
} }
@@ -274,7 +274,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
} }
} else { } else {
if len(payload) > maxPacketSize { if len(payload) > maxPacketSize {
return fmt.Errorf("payload size (%d) greater than maximum allowed (%d)", return fmt.Errorf("payload size (%d) is greater than maximum allowed (%d)",
len(payload), maxPacketSize) len(payload), maxPacketSize)
} }

View File

@@ -169,7 +169,7 @@ func (u *serverUDPListener) runReader() {
} }
for { for {
buf := make([]byte, maxPacketSize) buf := make([]byte, maxPacketSize+1)
n, addr, err := u.pc.ReadFromUDP(buf) n, addr, err := u.pc.ReadFromUDP(buf)
if err != nil { if err != nil {
break break
@@ -192,6 +192,16 @@ func (u *serverUDPListener) runReader() {
} }
func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) { func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) {
if len(payload) == (maxPacketSize + 1) {
if h, ok := clientData.ss.s.Handler.(ServerHandlerOnDecodeError); ok {
h.OnDecodeError(&ServerHandlerOnDecodeErrorCtx{
Session: clientData.ss,
Error: fmt.Errorf("RTP packet is too big to be read with UDP"),
})
}
return
}
pkt := u.s.udpRTPPacketBuffer.next() pkt := u.s.udpRTPPacketBuffer.next()
err := pkt.Unmarshal(payload) err := pkt.Unmarshal(payload)
if err != nil { if err != nil {
@@ -248,6 +258,16 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) {
} }
func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) { func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) {
if len(payload) == (maxPacketSize + 1) {
if h, ok := clientData.ss.s.Handler.(ServerHandlerOnDecodeError); ok {
h.OnDecodeError(&ServerHandlerOnDecodeErrorCtx{
Session: clientData.ss,
Error: fmt.Errorf("RTCP packet is too big to be read with UDP"),
})
}
return
}
packets, err := rtcp.Unmarshal(payload) packets, err := rtcp.Unmarshal(payload)
if err != nil { if err != nil {
if h, ok := clientData.ss.s.Handler.(ServerHandlerOnDecodeError); ok { if h, ok := clientData.ss.s.Handler.(ServerHandlerOnDecodeError); ok {