diff --git a/client.go b/client.go index 0c7ebbda..0a1254b3 100644 --- a/client.go +++ b/client.go @@ -240,6 +240,8 @@ type Client struct { OnPacketRTP func(*ClientOnPacketRTPCtx) // called when receiving a RTCP packet. OnPacketRTCP func(*ClientOnPacketRTCPCtx) + // called when there's a non-fatal decoding error of RTP or RTCP packets. + OnDecodeError func(error) // // private @@ -335,12 +337,24 @@ func (c *Client) Start(scheme string, host string) error { } // callbacks + if c.OnRequest == nil { + c.OnRequest = func(*base.Request) { + } + } + if c.OnResponse == nil { + c.OnResponse = func(*base.Response) { + } + } if c.OnPacketRTP == nil { - c.OnPacketRTP = func(ctx *ClientOnPacketRTPCtx) { + c.OnPacketRTP = func(*ClientOnPacketRTPCtx) { } } if c.OnPacketRTCP == nil { - c.OnPacketRTCP = func(ctx *ClientOnPacketRTCPCtx) { + c.OnPacketRTCP = func(*ClientOnPacketRTCPCtx) { + } + } + if c.OnDecodeError == nil { + c.OnDecodeError = func(error) { } } @@ -814,6 +828,7 @@ func (c *Client) runReader() { if err != nil { // some cameras send invalid RTCP packets. // skip them. + c.OnDecodeError(err) return nil } @@ -1038,9 +1053,7 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba c.sender.AddAuthorization(req) } - if c.OnRequest != nil { - c.OnRequest(req) - } + c.OnRequest(req) c.nconn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)) err := c.conn.WriteRequest(req) @@ -1067,9 +1080,7 @@ func (c *Client) do(req *base.Request, skipResponse bool, allowFrames bool) (*ba return nil, err } - if c.OnResponse != nil { - c.OnResponse(res) - } + c.OnResponse(res) // get session from response if v, ok := res.Header["Session"]; ok { diff --git a/client_read_test.go b/client_read_test.go index 65148102..46d5a4b0 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -2669,7 +2669,6 @@ func TestClientReadDifferentSource(t *testing.T) { require.NoError(t, err) require.Equal(t, base.Play, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/test/stream?param=value/"), req.URL) - require.Equal(t, base.HeaderValue{"npt=0-"}, req.Header["Range"]) err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, @@ -2695,9 +2694,6 @@ func TestClientReadDifferentSource(t *testing.T) { }() c := Client{ - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, Transport: func() *Transport { v := TransportUDP return &v @@ -2716,3 +2712,159 @@ func TestClientReadDifferentSource(t *testing.T) { <-packetRecv } + +func TestClientReadDecodeErrors(t *testing.T) { + for _, ca := range []string{ + "invalid rtp", + "invalid rtcp", + } { + t.Run(ca, func(t *testing.T) { + errorRecv := make(chan struct{}) + + l, err := net.Listen("tcp", "localhost:8554") + require.NoError(t, err) + defer l.Close() + + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + go func() { + defer close(serverDone) + + nconn, err := l.Accept() + require.NoError(t, err) + defer nconn.Close() + conn := conn.NewConn(nconn) + + req, err := conn.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Options, req.Method) + require.Equal(t, mustParseURL("rtsp://localhost:8554/stream"), req.URL) + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Describe), + string(base.Setup), + string(base.Play), + }, ", ")}, + }, + }) + require.NoError(t, err) + + req, err = conn.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Describe, req.Method) + require.Equal(t, mustParseURL("rtsp://localhost:8554/stream"), req.URL) + + tracks := Tracks{&TrackH264{ + PayloadType: 96, + SPS: []byte{0x01, 0x02, 0x03, 0x04}, + PPS: []byte{0x01, 0x02, 0x03, 0x04}, + }} + tracks.setControls() + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Type": base.HeaderValue{"application/sdp"}, + "Content-Base": base.HeaderValue{"rtsp://localhost:8554/stream/"}, + }, + Body: tracks.Marshal(false), + }) + require.NoError(t, err) + + req, err = conn.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Setup, req.Method) + require.Equal(t, mustParseURL("rtsp://localhost:8554/stream/trackID=0"), req.URL) + + var inTH headers.Transport + err = inTH.Unmarshal(req.Header["Transport"]) + require.NoError(t, err) + + th := headers.Transport{ + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + Protocol: headers.TransportProtocolUDP, + ClientPorts: inTH.ClientPorts, + ServerPorts: &[2]int{34556, 34557}, + } + + l1, err := net.ListenPacket("udp", "127.0.0.1:34556") + require.NoError(t, err) + defer l1.Close() + + l2, err := net.ListenPacket("udp", "127.0.0.1:34557") + require.NoError(t, err) + defer l2.Close() + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Transport": th.Marshal(), + }, + }) + require.NoError(t, err) + + req, err = conn.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Play, req.Method) + require.Equal(t, mustParseURL("rtsp://localhost:8554/stream/"), req.URL) + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err) + + switch ca { + case "invalid rtp": + l1.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: th.ClientPorts[0], + }) + + case "invalid rtcp": + l2.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: th.ClientPorts[1], + }) + } + + req, err = conn.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Teardown, req.Method) + require.Equal(t, mustParseURL("rtsp://localhost:8554/stream/"), req.URL) + + err = conn.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err) + }() + + c := Client{ + Transport: func() *Transport { + v := TransportUDP + return &v + }(), + OnDecodeError: func(err error) { + switch ca { + case "invalid rtp": + require.EqualError(t, err, "RTP header size insufficient: 2 < 4") + case "invalid rtcp": + require.EqualError(t, err, "rtcp: packet too short") + } + close(errorRecv) + }, + } + + err = startReading(&c, "rtsp://localhost:8554/stream") + require.NoError(t, err) + defer c.Close() + + <-errorRecv + }) + } +} diff --git a/clientudpl.go b/clientudpl.go index 30670a26..182b3663 100644 --- a/clientudpl.go +++ b/clientudpl.go @@ -193,6 +193,7 @@ func (u *clientUDPListener) processPlayRTP(now time.Time, payload []byte) { pkt := u.ct.udpRTPPacketBuffer.next() err := pkt.Unmarshal(payload) if err != nil { + u.c.OnDecodeError(err) return } @@ -201,8 +202,10 @@ func (u *clientUDPListener) processPlayRTP(now time.Time, payload []byte) { for _, pkt := range packets { out, err := u.ct.cleaner.Process(pkt) if err != nil { - return + u.c.OnDecodeError(err) + continue } + out0 := out[0] u.ct.udpRTCPReceiver.ProcessPacketRTP(time.Now(), pkt, out0.PTSEqualsDTS) @@ -220,6 +223,7 @@ func (u *clientUDPListener) processPlayRTP(now time.Time, payload []byte) { func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) { packets, err := rtcp.Unmarshal(payload) if err != nil { + u.c.OnDecodeError(err) return } @@ -235,6 +239,7 @@ func (u *clientUDPListener) processPlayRTCP(now time.Time, payload []byte) { func (u *clientUDPListener) processRecordRTCP(now time.Time, payload []byte) { packets, err := rtcp.Unmarshal(payload) if err != nil { + u.c.OnDecodeError(err) return } diff --git a/server_publish_test.go b/server_publish_test.go index 1fe600d9..9dbf7ac4 100644 --- a/server_publish_test.go +++ b/server_publish_test.go @@ -1375,8 +1375,6 @@ func TestServerPublishUDPChangeConn(t *testing.T) { StatusCode: base.StatusOK, }, nil }, - onPacketRTP: func(ctx *ServerHandlerOnPacketRTPCtx) { - }, }, UDPRTPAddress: "127.0.0.1:8000", UDPRTCPAddress: "127.0.0.1:8001", @@ -1476,3 +1474,141 @@ func TestServerPublishUDPChangeConn(t *testing.T) { require.Equal(t, base.StatusOK, res.StatusCode) }() } + +func TestServerPublishDecodeErrors(t *testing.T) { + for _, ca := range []string{ + "invalid rtp", + "invalid rtcp", + } { + t.Run(ca, func(t *testing.T) { + errorRecv := make(chan struct{}) + + s := &Server{ + Handler: &testServerHandler{ + onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, *ServerStream, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil, nil + }, + onRecord: func(ctx *ServerHandlerOnRecordCtx) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + onDecodeError: func(ctx *ServerHandlerOnDecodeErrorCtx) { + switch ca { + case "invalid rtp": + require.EqualError(t, ctx.Error, "RTP header size insufficient: 2 < 4") + case "invalid rtcp": + require.EqualError(t, ctx.Error, "rtcp: packet too short") + } + close(errorRecv) + }, + }, + UDPRTPAddress: "127.0.0.1:8000", + UDPRTCPAddress: "127.0.0.1:8001", + RTSPAddress: "localhost:8554", + } + + err := s.Start() + require.NoError(t, err) + defer s.Close() + + nconn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer nconn.Close() + conn := conn.NewConn(nconn) + + tracks := Tracks{&TrackH264{ + PayloadType: 96, + SPS: []byte{0x01, 0x02, 0x03, 0x04}, + PPS: []byte{0x01, 0x02, 0x03, 0x04}, + }} + tracks.setControls() + + res, err := writeReqReadRes(conn, base.Request{ + Method: base.Announce, + URL: mustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Content-Type": base.HeaderValue{"application/sdp"}, + }, + Body: tracks.Marshal(false), + }) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + inTH := &headers.Transport{ + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModeRecord + return &v + }(), + Protocol: headers.TransportProtocolUDP, + ClientPorts: &[2]int{35466, 35467}, + } + + l1, err := net.ListenPacket("udp", "127.0.0.1:35466") + require.NoError(t, err) + defer l1.Close() + + l2, err := net.ListenPacket("udp", "127.0.0.1:35467") + require.NoError(t, err) + defer l2.Close() + + res, err = writeReqReadRes(conn, base.Request{ + Method: base.Setup, + URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + "Transport": inTH.Marshal(), + }, + }) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + var sx headers.Session + err = sx.Unmarshal(res.Header["Session"]) + require.NoError(t, err) + + var resTH headers.Transport + err = resTH.Unmarshal(res.Header["Transport"]) + require.NoError(t, err) + + res, err = writeReqReadRes(conn, base.Request{ + Method: base.Record, + URL: mustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"3"}, + "Session": base.HeaderValue{sx.Session}, + }, + }) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + switch ca { + case "invalid rtp": + l1.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: resTH.ServerPorts[0], + }) + + case "invalid rtcp": + l2.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: resTH.ServerPorts[1], + }) + } + + <-errorRecv + }) + } +} diff --git a/server_test.go b/server_test.go index 05481111..e251ac13 100644 --- a/server_test.go +++ b/server_test.go @@ -90,10 +90,11 @@ type testServerHandler struct { onPlay func(*ServerHandlerOnPlayCtx) (*base.Response, error) onRecord func(*ServerHandlerOnRecordCtx) (*base.Response, error) onPause func(*ServerHandlerOnPauseCtx) (*base.Response, error) - onPacketRTP func(*ServerHandlerOnPacketRTPCtx) - onPacketRTCP func(*ServerHandlerOnPacketRTCPCtx) onSetParameter func(*ServerHandlerOnSetParameterCtx) (*base.Response, error) onGetParameter func(*ServerHandlerOnGetParameterCtx) (*base.Response, error) + onPacketRTP func(*ServerHandlerOnPacketRTPCtx) + onPacketRTCP func(*ServerHandlerOnPacketRTCPCtx) + onDecodeError func(*ServerHandlerOnDecodeErrorCtx) } func (sh *testServerHandler) OnConnOpen(ctx *ServerHandlerOnConnOpenCtx) { @@ -162,18 +163,6 @@ func (sh *testServerHandler) OnPause(ctx *ServerHandlerOnPauseCtx) (*base.Respon return nil, fmt.Errorf("unimplemented") } -func (sh *testServerHandler) OnPacketRTP(ctx *ServerHandlerOnPacketRTPCtx) { - if sh.onPacketRTP != nil { - sh.onPacketRTP(ctx) - } -} - -func (sh *testServerHandler) OnPacketRTCP(ctx *ServerHandlerOnPacketRTCPCtx) { - if sh.onPacketRTCP != nil { - sh.onPacketRTCP(ctx) - } -} - func (sh *testServerHandler) OnSetParameter(ctx *ServerHandlerOnSetParameterCtx) (*base.Response, error) { if sh.onSetParameter != nil { return sh.onSetParameter(ctx) @@ -188,6 +177,24 @@ func (sh *testServerHandler) OnGetParameter(ctx *ServerHandlerOnGetParameterCtx) return nil, fmt.Errorf("unimplemented") } +func (sh *testServerHandler) OnPacketRTP(ctx *ServerHandlerOnPacketRTPCtx) { + if sh.onPacketRTP != nil { + sh.onPacketRTP(ctx) + } +} + +func (sh *testServerHandler) OnPacketRTCP(ctx *ServerHandlerOnPacketRTCPCtx) { + if sh.onPacketRTCP != nil { + sh.onPacketRTCP(ctx) + } +} + +func (sh *testServerHandler) OnDecodeError(ctx *ServerHandlerOnDecodeErrorCtx) { + if sh.onDecodeError != nil { + sh.onDecodeError(ctx) + } +} + func TestServerClose(t *testing.T) { s := &Server{ Handler: &testServerHandler{}, diff --git a/serverhandler.go b/serverhandler.go index 6a896648..16dbe5a9 100644 --- a/serverhandler.go +++ b/serverhandler.go @@ -227,3 +227,15 @@ type ServerHandlerOnPacketRTCP interface { // called when receiving a RTCP packet. OnPacketRTCP(*ServerHandlerOnPacketRTCPCtx) } + +// ServerHandlerOnDecodeErrorCtx is the context of OnDecodeError. +type ServerHandlerOnDecodeErrorCtx struct { + Session *ServerSession + Error error +} + +// ServerHandlerOnDecodeError can be implemented by a ServerHandler. +type ServerHandlerOnDecodeError interface { + // called when there's a non-fatal decoding error of RTP or RTCP packets. + OnDecodeError(*ServerHandlerOnDecodeErrorCtx) +} diff --git a/serverudpl.go b/serverudpl.go index 70516aea..6c8e48b7 100644 --- a/serverudpl.go +++ b/serverudpl.go @@ -195,6 +195,12 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) { pkt := u.s.udpRTPPacketBuffer.next() err := pkt.Unmarshal(payload) if err != nil { + if h, ok := clientData.ss.s.Handler.(ServerHandlerOnDecodeError); ok { + h.OnDecodeError(&ServerHandlerOnDecodeErrorCtx{ + Session: clientData.ss, + Error: err, + }) + } return } @@ -206,8 +212,15 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) { out, err := clientData.track.cleaner.Process(pkt) if err != nil { - return + if h, ok := clientData.ss.s.Handler.(ServerHandlerOnDecodeError); ok { + h.OnDecodeError(&ServerHandlerOnDecodeErrorCtx{ + Session: clientData.ss, + Error: err, + }) + } + continue } + out0 := out[0] clientData.track.udpRTCPReceiver.ProcessPacketRTP(now, pkt, out0.PTSEqualsDTS) @@ -228,6 +241,12 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) { func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) { packets, err := rtcp.Unmarshal(payload) if err != nil { + if h, ok := clientData.ss.s.Handler.(ServerHandlerOnDecodeError); ok { + h.OnDecodeError(&ServerHandlerOnDecodeErrorCtx{ + Session: clientData.ss, + Error: err, + }) + } return }