add OnDecodeError callback to both client and server

This allows to detect decode errors of RTP and RTCP packets
This commit is contained in:
aler9
2022-10-31 12:36:30 +01:00
parent 9b5071f505
commit 30e029011b
7 changed files with 372 additions and 30 deletions

View File

@@ -240,6 +240,8 @@ type Client struct {
OnPacketRTP func(*ClientOnPacketRTPCtx) OnPacketRTP func(*ClientOnPacketRTPCtx)
// called when receiving a RTCP packet. // called when receiving a RTCP packet.
OnPacketRTCP func(*ClientOnPacketRTCPCtx) OnPacketRTCP func(*ClientOnPacketRTCPCtx)
// called when there's a non-fatal decoding error of RTP or RTCP packets.
OnDecodeError func(error)
// //
// private // private
@@ -335,12 +337,24 @@ func (c *Client) Start(scheme string, host string) error {
} }
// callbacks // callbacks
if c.OnRequest == nil {
c.OnRequest = func(*base.Request) {
}
}
if c.OnResponse == nil {
c.OnResponse = func(*base.Response) {
}
}
if c.OnPacketRTP == nil { if c.OnPacketRTP == nil {
c.OnPacketRTP = func(ctx *ClientOnPacketRTPCtx) { c.OnPacketRTP = func(*ClientOnPacketRTPCtx) {
} }
} }
if c.OnPacketRTCP == nil { if c.OnPacketRTCP == nil {
c.OnPacketRTCP = func(ctx *ClientOnPacketRTCPCtx) { c.OnPacketRTCP = func(*ClientOnPacketRTCPCtx) {
}
}
if c.OnDecodeError == nil {
c.OnDecodeError = func(error) {
} }
} }
@@ -814,6 +828,7 @@ func (c *Client) runReader() {
if err != nil { if err != nil {
// some cameras send invalid RTCP packets. // some cameras send invalid RTCP packets.
// skip them. // skip them.
c.OnDecodeError(err)
return nil return nil
} }
@@ -1038,9 +1053,7 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba
c.sender.AddAuthorization(req) c.sender.AddAuthorization(req)
} }
if c.OnRequest != nil { c.OnRequest(req)
c.OnRequest(req)
}
c.nconn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) c.nconn.SetWriteDeadline(time.Now().Add(c.WriteTimeout))
err := c.conn.WriteRequest(req) err := c.conn.WriteRequest(req)
@@ -1067,9 +1080,7 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba
return nil, err return nil, err
} }
if c.OnResponse != nil { c.OnResponse(res)
c.OnResponse(res)
}
// get session from response // get session from response
if v, ok := res.Header["Session"]; ok { if v, ok := res.Header["Session"]; ok {

View File

@@ -2669,7 +2669,6 @@ func TestClientReadDifferentSource(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Play, req.Method) require.Equal(t, base.Play, req.Method)
require.Equal(t, mustParseURL("rtsp://localhost:8554/test/stream?param=value/"), req.URL) require.Equal(t, mustParseURL("rtsp://localhost:8554/test/stream?param=value/"), req.URL)
require.Equal(t, base.HeaderValue{"npt=0-"}, req.Header["Range"])
err = conn.WriteResponse(&base.Response{ err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
@@ -2695,9 +2694,6 @@ func TestClientReadDifferentSource(t *testing.T) {
}() }()
c := Client{ c := Client{
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
Transport: func() *Transport { Transport: func() *Transport {
v := TransportUDP v := TransportUDP
return &v return &v
@@ -2716,3 +2712,159 @@ func TestClientReadDifferentSource(t *testing.T) {
<-packetRecv <-packetRecv
} }
func TestClientReadDecodeErrors(t *testing.T) {
for _, ca := range []string{
"invalid rtp",
"invalid rtcp",
} {
t.Run(ca, func(t *testing.T) {
errorRecv := make(chan struct{})
l, err := net.Listen("tcp", "localhost:8554")
require.NoError(t, err)
defer l.Close()
serverDone := make(chan struct{})
defer func() { <-serverDone }()
go func() {
defer close(serverDone)
nconn, err := l.Accept()
require.NoError(t, err)
defer nconn.Close()
conn := conn.NewConn(nconn)
req, err := conn.ReadRequest()
require.NoError(t, err)
require.Equal(t, base.Options, req.Method)
require.Equal(t, mustParseURL("rtsp://localhost:8554/stream"), req.URL)
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Public": base.HeaderValue{strings.Join([]string{
string(base.Describe),
string(base.Setup),
string(base.Play),
}, ", ")},
},
})
require.NoError(t, err)
req, err = conn.ReadRequest()
require.NoError(t, err)
require.Equal(t, base.Describe, req.Method)
require.Equal(t, mustParseURL("rtsp://localhost:8554/stream"), req.URL)
tracks := Tracks{&TrackH264{
PayloadType: 96,
SPS: []byte{0x01, 0x02, 0x03, 0x04},
PPS: []byte{0x01, 0x02, 0x03, 0x04},
}}
tracks.setControls()
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Content-Type": base.HeaderValue{"application/sdp"},
"Content-Base": base.HeaderValue{"rtsp://localhost:8554/stream/"},
},
Body: tracks.Marshal(false),
})
require.NoError(t, err)
req, err = conn.ReadRequest()
require.NoError(t, err)
require.Equal(t, base.Setup, req.Method)
require.Equal(t, mustParseURL("rtsp://localhost:8554/stream/trackID=0"), req.URL)
var inTH headers.Transport
err = inTH.Unmarshal(req.Header["Transport"])
require.NoError(t, err)
th := headers.Transport{
Delivery: func() *headers.TransportDelivery {
v := headers.TransportDeliveryUnicast
return &v
}(),
Protocol: headers.TransportProtocolUDP,
ClientPorts: inTH.ClientPorts,
ServerPorts: &[2]int{34556, 34557},
}
l1, err := net.ListenPacket("udp", "127.0.0.1:34556")
require.NoError(t, err)
defer l1.Close()
l2, err := net.ListenPacket("udp", "127.0.0.1:34557")
require.NoError(t, err)
defer l2.Close()
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
Header: base.Header{
"Transport": th.Marshal(),
},
})
require.NoError(t, err)
req, err = conn.ReadRequest()
require.NoError(t, err)
require.Equal(t, base.Play, req.Method)
require.Equal(t, mustParseURL("rtsp://localhost:8554/stream/"), req.URL)
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err)
switch ca {
case "invalid rtp":
l1.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[0],
})
case "invalid rtcp":
l2.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[1],
})
}
req, err = conn.ReadRequest()
require.NoError(t, err)
require.Equal(t, base.Teardown, req.Method)
require.Equal(t, mustParseURL("rtsp://localhost:8554/stream/"), req.URL)
err = conn.WriteResponse(&base.Response{
StatusCode: base.StatusOK,
})
require.NoError(t, err)
}()
c := Client{
Transport: func() *Transport {
v := TransportUDP
return &v
}(),
OnDecodeError: func(err error) {
switch ca {
case "invalid rtp":
require.EqualError(t, err, "RTP header size insufficient: 2 < 4")
case "invalid rtcp":
require.EqualError(t, err, "rtcp: packet too short")
}
close(errorRecv)
},
}
err = startReading(&c, "rtsp://localhost:8554/stream")
require.NoError(t, err)
defer c.Close()
<-errorRecv
})
}
}

View File

@@ -193,6 +193,7 @@ func (u *clientUDPListener) processPlayRTP(now time.Time, payload []byte) {
pkt := u.ct.udpRTPPacketBuffer.next() pkt := u.ct.udpRTPPacketBuffer.next()
err := pkt.Unmarshal(payload) err := pkt.Unmarshal(payload)
if err != nil { if err != nil {
u.c.OnDecodeError(err)
return return
} }
@@ -201,8 +202,10 @@ func (u *clientUDPListener) processPlayRTP(now time.Time, payload []byte) {
for _, pkt := range packets { for _, pkt := range packets {
out, err := u.ct.cleaner.Process(pkt) out, err := u.ct.cleaner.Process(pkt)
if err != nil { if err != nil {
return u.c.OnDecodeError(err)
continue
} }
out0 := out[0] out0 := out[0]
u.ct.udpRTCPReceiver.ProcessPacketRTP(time.Now(), pkt, out0.PTSEqualsDTS) u.ct.udpRTCPReceiver.ProcessPacketRTP(time.Now(), pkt, out0.PTSEqualsDTS)
@@ -220,6 +223,7 @@ func (u *clientUDPListener) processPlayRTP(now time.Time, payload []byte) {
func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) { func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) {
packets, err := rtcp.Unmarshal(payload) packets, err := rtcp.Unmarshal(payload)
if err != nil { if err != nil {
u.c.OnDecodeError(err)
return return
} }
@@ -235,6 +239,7 @@ func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) {
func (u *clientUDPListener) processRecordRTCP(now time.Time, payload []byte) { func (u *clientUDPListener) processRecordRTCP(now time.Time, payload []byte) {
packets, err := rtcp.Unmarshal(payload) packets, err := rtcp.Unmarshal(payload)
if err != nil { if err != nil {
u.c.OnDecodeError(err)
return return
} }

View File

@@ -1375,8 +1375,6 @@ func TestServerPublishUDPChangeConn(t *testing.T) {
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}, nil }, nil
}, },
onPacketRTP: func(ctx *ServerHandlerOnPacketRTPCtx) {
},
}, },
UDPRTPAddress: "127.0.0.1:8000", UDPRTPAddress: "127.0.0.1:8000",
UDPRTCPAddress: "127.0.0.1:8001", UDPRTCPAddress: "127.0.0.1:8001",
@@ -1476,3 +1474,141 @@ func TestServerPublishUDPChangeConn(t *testing.T) {
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
}() }()
} }
func TestServerPublishDecodeErrors(t *testing.T) {
for _, ca := range []string{
"invalid rtp",
"invalid rtcp",
} {
t.Run(ca, func(t *testing.T) {
errorRecv := make(chan struct{})
s := &Server{
Handler: &testServerHandler{
onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil, nil
},
onRecord: func(ctx *ServerHandlerOnRecordCtx) (*base.Response, error) {
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
onDecodeError: func(ctx *ServerHandlerOnDecodeErrorCtx) {
switch ca {
case "invalid rtp":
require.EqualError(t, ctx.Error, "RTP header size insufficient: 2 < 4")
case "invalid rtcp":
require.EqualError(t, ctx.Error, "rtcp: packet too short")
}
close(errorRecv)
},
},
UDPRTPAddress: "127.0.0.1:8000",
UDPRTCPAddress: "127.0.0.1:8001",
RTSPAddress: "localhost:8554",
}
err := s.Start()
require.NoError(t, err)
defer s.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer nconn.Close()
conn := conn.NewConn(nconn)
tracks := Tracks{&TrackH264{
PayloadType: 96,
SPS: []byte{0x01, 0x02, 0x03, 0x04},
PPS: []byte{0x01, 0x02, 0x03, 0x04},
}}
tracks.setControls()
res, err := writeReqReadRes(conn, base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Content-Type": base.HeaderValue{"application/sdp"},
},
Body: tracks.Marshal(false),
})
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
inTH := &headers.Transport{
Delivery: func() *headers.TransportDelivery {
v := headers.TransportDeliveryUnicast
return &v
}(),
Mode: func() *headers.TransportMode {
v := headers.TransportModeRecord
return &v
}(),
Protocol: headers.TransportProtocolUDP,
ClientPorts: &[2]int{35466, 35467},
}
l1, err := net.ListenPacket("udp", "127.0.0.1:35466")
require.NoError(t, err)
defer l1.Close()
l2, err := net.ListenPacket("udp", "127.0.0.1:35467")
require.NoError(t, err)
defer l2.Close()
res, err = writeReqReadRes(conn, base.Request{
Method: base.Setup,
URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
Header: base.Header{
"CSeq": base.HeaderValue{"2"},
"Transport": inTH.Marshal(),
},
})
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
var sx headers.Session
err = sx.Unmarshal(res.Header["Session"])
require.NoError(t, err)
var resTH headers.Transport
err = resTH.Unmarshal(res.Header["Transport"])
require.NoError(t, err)
res, err = writeReqReadRes(conn, base.Request{
Method: base.Record,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"3"},
"Session": base.HeaderValue{sx.Session},
},
})
require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode)
switch ca {
case "invalid rtp":
l1.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: resTH.ServerPorts[0],
})
case "invalid rtcp":
l2.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: resTH.ServerPorts[1],
})
}
<-errorRecv
})
}
}

View File

@@ -90,10 +90,11 @@ type testServerHandler struct {
onPlay func(*ServerHandlerOnPlayCtx) (*base.Response, error) onPlay func(*ServerHandlerOnPlayCtx) (*base.Response, error)
onRecord func(*ServerHandlerOnRecordCtx) (*base.Response, error) onRecord func(*ServerHandlerOnRecordCtx) (*base.Response, error)
onPause func(*ServerHandlerOnPauseCtx) (*base.Response, error) onPause func(*ServerHandlerOnPauseCtx) (*base.Response, error)
onPacketRTP func(*ServerHandlerOnPacketRTPCtx)
onPacketRTCP func(*ServerHandlerOnPacketRTCPCtx)
onSetParameter func(*ServerHandlerOnSetParameterCtx) (*base.Response, error) onSetParameter func(*ServerHandlerOnSetParameterCtx) (*base.Response, error)
onGetParameter func(*ServerHandlerOnGetParameterCtx) (*base.Response, error) onGetParameter func(*ServerHandlerOnGetParameterCtx) (*base.Response, error)
onPacketRTP func(*ServerHandlerOnPacketRTPCtx)
onPacketRTCP func(*ServerHandlerOnPacketRTCPCtx)
onDecodeError func(*ServerHandlerOnDecodeErrorCtx)
} }
func (sh *testServerHandler) OnConnOpen(ctx *ServerHandlerOnConnOpenCtx) { func (sh *testServerHandler) OnConnOpen(ctx *ServerHandlerOnConnOpenCtx) {
@@ -162,18 +163,6 @@ func (sh *testServerHandler) OnPause(ctx *ServerHandlerOnPauseCtx) (*base.Respon
return nil, fmt.Errorf("unimplemented") return nil, fmt.Errorf("unimplemented")
} }
func (sh *testServerHandler) OnPacketRTP(ctx *ServerHandlerOnPacketRTPCtx) {
if sh.onPacketRTP != nil {
sh.onPacketRTP(ctx)
}
}
func (sh *testServerHandler) OnPacketRTCP(ctx *ServerHandlerOnPacketRTCPCtx) {
if sh.onPacketRTCP != nil {
sh.onPacketRTCP(ctx)
}
}
func (sh *testServerHandler) OnSetParameter(ctx *ServerHandlerOnSetParameterCtx) (*base.Response, error) { func (sh *testServerHandler) OnSetParameter(ctx *ServerHandlerOnSetParameterCtx) (*base.Response, error) {
if sh.onSetParameter != nil { if sh.onSetParameter != nil {
return sh.onSetParameter(ctx) return sh.onSetParameter(ctx)
@@ -188,6 +177,24 @@ func (sh *testServerHandler) OnGetParameter(ctx *ServerHandlerOnGetParameterCtx)
return nil, fmt.Errorf("unimplemented") return nil, fmt.Errorf("unimplemented")
} }
func (sh *testServerHandler) OnPacketRTP(ctx *ServerHandlerOnPacketRTPCtx) {
if sh.onPacketRTP != nil {
sh.onPacketRTP(ctx)
}
}
func (sh *testServerHandler) OnPacketRTCP(ctx *ServerHandlerOnPacketRTCPCtx) {
if sh.onPacketRTCP != nil {
sh.onPacketRTCP(ctx)
}
}
func (sh *testServerHandler) OnDecodeError(ctx *ServerHandlerOnDecodeErrorCtx) {
if sh.onDecodeError != nil {
sh.onDecodeError(ctx)
}
}
func TestServerClose(t *testing.T) { func TestServerClose(t *testing.T) {
s := &Server{ s := &Server{
Handler: &testServerHandler{}, Handler: &testServerHandler{},

View File

@@ -227,3 +227,15 @@ type ServerHandlerOnPacketRTCP interface {
// called when receiving a RTCP packet. // called when receiving a RTCP packet.
OnPacketRTCP(*ServerHandlerOnPacketRTCPCtx) OnPacketRTCP(*ServerHandlerOnPacketRTCPCtx)
} }
// ServerHandlerOnDecodeErrorCtx is the context of OnDecodeError.
type ServerHandlerOnDecodeErrorCtx struct {
Session *ServerSession
Error error
}
// ServerHandlerOnDecodeError can be implemented by a ServerHandler.
type ServerHandlerOnDecodeError interface {
// called when there's a non-fatal decoding error of RTP or RTCP packets.
OnDecodeError(*ServerHandlerOnDecodeErrorCtx)
}

View File

@@ -195,6 +195,12 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) {
pkt := u.s.udpRTPPacketBuffer.next() pkt := u.s.udpRTPPacketBuffer.next()
err := pkt.Unmarshal(payload) err := pkt.Unmarshal(payload)
if err != nil { if err != nil {
if h, ok := clientData.ss.s.Handler.(ServerHandlerOnDecodeError); ok {
h.OnDecodeError(&ServerHandlerOnDecodeErrorCtx{
Session: clientData.ss,
Error: err,
})
}
return return
} }
@@ -206,8 +212,15 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) {
out, err := clientData.track.cleaner.Process(pkt) out, err := clientData.track.cleaner.Process(pkt)
if err != nil { if err != nil {
return if h, ok := clientData.ss.s.Handler.(ServerHandlerOnDecodeError); ok {
h.OnDecodeError(&ServerHandlerOnDecodeErrorCtx{
Session: clientData.ss,
Error: err,
})
}
continue
} }
out0 := out[0] out0 := out[0]
clientData.track.udpRTCPReceiver.ProcessPacketRTP(now, pkt, out0.PTSEqualsDTS) clientData.track.udpRTCPReceiver.ProcessPacketRTP(now, pkt, out0.PTSEqualsDTS)
@@ -228,6 +241,12 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) {
func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) { func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) {
packets, err := rtcp.Unmarshal(payload) packets, err := rtcp.Unmarshal(payload)
if err != nil { if err != nil {
if h, ok := clientData.ss.s.Handler.(ServerHandlerOnDecodeError); ok {
h.OnDecodeError(&ServerHandlerOnDecodeErrorCtx{
Session: clientData.ss,
Error: err,
})
}
return return
} }