diff --git a/README.md b/README.md index fe94cdf9..e34d9bc8 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,6 @@ Features: * Pause or seek without disconnecting from the server * Generate RTCP receiver reports (UDP only) * Reorder incoming RTP packets (UDP only) - * Clean up non-compliant streams (remove padding, re-encode RTP packets if they are too big) * Publish * Publish streams to servers with the UDP or TCP transport protocol * Publish TLS-encrypted streams (TCP only) @@ -38,7 +37,6 @@ Features: * Read TLS-encrypted streams (TCP only) * Generate RTCP receiver reports (UDP only) * Reorder incoming RTP packets (UDP only) - * Clean up non-compliant streams (remove padding, re-encode RTP packets if they are too big) * Read * Write streams to clients with the UDP, UDP-multicast or TCP transport protocol * Write TLS-encrypted streams diff --git a/client.go b/client.go index 3f9443b6..8a90d933 100644 --- a/client.go +++ b/client.go @@ -28,7 +28,6 @@ import ( "github.com/aler9/gortsplib/pkg/ringbuffer" "github.com/aler9/gortsplib/pkg/rtcpreceiver" "github.com/aler9/gortsplib/pkg/rtcpsender" - "github.com/aler9/gortsplib/pkg/rtpcleaner" "github.com/aler9/gortsplib/pkg/rtpreorderer" "github.com/aler9/gortsplib/pkg/sdp" "github.com/aler9/gortsplib/pkg/url" @@ -94,7 +93,6 @@ type clientTrack struct { udpRTPPacketBuffer *rtpPacketMultiBuffer udpRTCPReceiver *rtcpreceiver.RTCPReceiver reorderer *rtpreorderer.Reorderer - cleaner *rtpcleaner.Cleaner // record rtcpSender *rtcpsender.RTCPSender @@ -165,8 +163,6 @@ type ClientOnPacketRTPCtx struct { TrackID int Packet *rtp.Packet PTSEqualsDTS bool - H264NALUs [][]byte - H264PTS time.Duration } // ClientOnPacketRTCPCtx is the context of a RTCP packet. @@ -704,8 +700,6 @@ func (c *Client) playRecordStart() { if *c.effectiveTransport == TransportUDP || *c.effectiveTransport == TransportUDPMulticast { ct.reorderer = rtpreorderer.New() } - _, isH264 := ct.track.(*TrackH264) - ct.cleaner = rtpcleaner.New(isH264, *c.effectiveTransport == TransportTCP) } c.keepaliveTimer = time.NewTimer(c.keepalivePeriod) @@ -804,30 +798,22 @@ func (c *Client) runReader() { return err } - out, err := track.cleaner.Process(pkt) - if err != nil { - return err - } - - for _, entry := range out { - c.OnPacketRTP(&ClientOnPacketRTPCtx{ - TrackID: track.id, - Packet: entry.Packet, - PTSEqualsDTS: entry.PTSEqualsDTS, - H264NALUs: entry.H264NALUs, - H264PTS: entry.H264PTS, - }) - } + c.OnPacketRTP(&ClientOnPacketRTPCtx{ + TrackID: track.id, + Packet: pkt, + PTSEqualsDTS: ptsEqualsDTS(track.track, pkt), + }) } else { if len(payload) > maxPacketSize { - return fmt.Errorf("payload size (%d) is greater than maximum allowed (%d)", - len(payload), maxPacketSize) + c.OnDecodeError(fmt.Errorf("RTCP packet size (%d) is greater than maximum allowed (%d)", + len(payload), maxPacketSize)) + return nil } packets, err := rtcp.Unmarshal(payload) if err != nil { // some cameras send invalid RTCP packets. - // skip them. + // ignore them. c.OnDecodeError(err) return nil } @@ -846,8 +832,9 @@ func (c *Client) runReader() { processFunc = func(track *clientTrack, isRTP bool, payload []byte) error { if !isRTP { if len(payload) > maxPacketSize { - return fmt.Errorf("payload size (%d) is greater than maximum allowed (%d)", - len(payload), maxPacketSize) + c.OnDecodeError(fmt.Errorf("RTCP packet size (%d) is greater than maximum allowed (%d)", + len(payload), maxPacketSize)) + return nil } packets, err := rtcp.Unmarshal(payload) @@ -929,7 +916,6 @@ func (c *Client) playRecordStop(isClosing bool) { ct.rtcpSender.Close() ct.rtcpSender = nil } - ct.cleaner = nil ct.reorderer = nil } diff --git a/client_read_test.go b/client_read_test.go index ad7d4483..139575f3 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -2717,10 +2717,10 @@ func TestClientReadDecodeErrors(t *testing.T) { for _, ca := range []string{ "rtp invalid", "rtcp invalid", - "packets lost", + "rtp packets lost", "rtp too big", "rtcp too big", - "cleaner error", + "rtcp too big tcp", } { t.Run(ca, func(t *testing.T) { errorRecv := make(chan struct{}) @@ -2761,23 +2761,13 @@ func TestClientReadDecodeErrors(t *testing.T) { require.Equal(t, base.Describe, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/stream"), req.URL) - var track Track - if ca != "cleaner error" { - track = &TrackGeneric{ - Media: "application", - Payloads: []TrackGenericPayload{{ - Type: 97, - RTPMap: "private/90000", - }}, - } - } else { - track = &TrackH264{ - PayloadType: 96, - SPS: []byte{0x01, 0x02, 0x03, 0x04}, - PPS: []byte{0x01, 0x02, 0x03, 0x04}, - } - } - tracks := Tracks{track} + tracks := Tracks{&TrackGeneric{ + Media: "application", + Payloads: []TrackGenericPayload{{ + Type: 97, + RTPMap: "private/90000", + }}, + }} tracks.setControls() err = conn.WriteResponse(&base.Response{ @@ -2804,18 +2794,29 @@ func TestClientReadDecodeErrors(t *testing.T) { v := headers.TransportDeliveryUnicast return &v }(), - Protocol: headers.TransportProtocolUDP, - ClientPorts: inTH.ClientPorts, - ServerPorts: &[2]int{34556, 34557}, } - l1, err := net.ListenPacket("udp", "127.0.0.1:34556") - require.NoError(t, err) - defer l1.Close() + if ca != "rtcp too big tcp" { + th.Protocol = headers.TransportProtocolUDP + th.ClientPorts = inTH.ClientPorts + th.ServerPorts = &[2]int{34556, 34557} + } else { + th.Protocol = headers.TransportProtocolTCP + th.InterleavedIDs = inTH.InterleavedIDs + } - l2, err := net.ListenPacket("udp", "127.0.0.1:34557") - require.NoError(t, err) - defer l2.Close() + var l1 net.PacketConn + var l2 net.PacketConn + + if ca != "rtcp too big tcp" { + l1, err = net.ListenPacket("udp", "127.0.0.1:34556") + require.NoError(t, err) + defer l1.Close() + + l2, err = net.ListenPacket("udp", "127.0.0.1:34557") + require.NoError(t, err) + defer l2.Close() + } err = conn.WriteResponse(&base.Response{ StatusCode: base.StatusOK, @@ -2848,7 +2849,7 @@ func TestClientReadDecodeErrors(t *testing.T) { Port: th.ClientPorts[1], }) - case "packets lost": + case "rtp packets lost": byts, _ := rtp.Packet{ Header: rtp.Header{ SequenceNumber: 30, @@ -2881,17 +2882,12 @@ func TestClientReadDecodeErrors(t *testing.T) { Port: th.ClientPorts[1], }) - case "cleaner error": - byts, _ := rtp.Packet{ - Header: rtp.Header{ - SequenceNumber: 100, - }, - Payload: []byte{0x99}, - }.Marshal() - l1.WriteTo(byts, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: th.ClientPorts[0], - }) + case "rtcp too big tcp": + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 1, + Payload: bytes.Repeat([]byte{0x01, 0x02}, 2000/2), + }, make([]byte, 2048)) + require.NoError(t, err) } req, err = conn.ReadRequest() @@ -2907,7 +2903,11 @@ func TestClientReadDecodeErrors(t *testing.T) { c := Client{ Transport: func() *Transport { - v := TransportUDP + if ca != "rtcp too big tcp" { + v := TransportUDP + return &v + } + v := TransportTCP return &v }(), OnDecodeError: func(err error) { @@ -2916,14 +2916,14 @@ func TestClientReadDecodeErrors(t *testing.T) { require.EqualError(t, err, "RTP header size insufficient: 2 < 4") case "rtcp invalid": require.EqualError(t, err, "rtcp: packet too short") - case "packets lost": + case "rtp packets lost": require.EqualError(t, err, "69 RTP packet(s) lost") case "rtp too big": require.EqualError(t, err, "RTP packet is too big to be read with UDP") case "rtcp too big": require.EqualError(t, err, "RTCP packet is too big to be read with UDP") - case "cleaner error": - require.EqualError(t, err, "packet type not supported (STAP-B)") + case "rtcp too big tcp": + require.EqualError(t, err, "RTCP packet size (2000) is greater than maximum allowed (1472)") } close(errorRecv) }, diff --git a/clientudpl.go b/clientudpl.go index 38e5a9b6..ae8a1162 100644 --- a/clientudpl.go +++ b/clientudpl.go @@ -210,25 +210,14 @@ func (u *clientUDPListener) processPlayRTP(now time.Time, payload []byte) { } for _, pkt := range packets { - out, err := u.ct.cleaner.Process(pkt) - if err != nil { - u.c.OnDecodeError(err) - // do not return - } + ptsEqualsDTS := ptsEqualsDTS(u.ct.track, pkt) + u.ct.udpRTCPReceiver.ProcessPacketRTP(time.Now(), pkt, ptsEqualsDTS) - if out != nil { - out0 := out[0] - - u.ct.udpRTCPReceiver.ProcessPacketRTP(time.Now(), pkt, out0.PTSEqualsDTS) - - u.c.OnPacketRTP(&ClientOnPacketRTPCtx{ - TrackID: u.ct.id, - Packet: out0.Packet, - PTSEqualsDTS: out0.PTSEqualsDTS, - H264NALUs: out0.H264NALUs, - H264PTS: out0.H264PTS, - }) - } + u.c.OnPacketRTP(&ClientOnPacketRTPCtx{ + TrackID: u.ct.id, + Packet: pkt, + PTSEqualsDTS: ptsEqualsDTS, + }) } } diff --git a/examples/client-read-h264-convert-to-jpeg/main.go b/examples/client-read-h264-convert-to-jpeg/main.go index 8baa6baf..de7452b6 100644 --- a/examples/client-read-h264-convert-to-jpeg/main.go +++ b/examples/client-read-h264-convert-to-jpeg/main.go @@ -9,6 +9,7 @@ import ( "time" "github.com/aler9/gortsplib" + "github.com/aler9/gortsplib/pkg/rtph264" "github.com/aler9/gortsplib/pkg/url" ) @@ -73,21 +74,25 @@ func main() { panic("H264 track not found") } + // setup RTP/H264->H264 decoder + rtpDec := &rtph264.Decoder{} + rtpDec.Init() + // setup H264->raw frames decoder - h264dec, err := newH264Decoder() + h264RawDec, err := newH264Decoder() if err != nil { panic(err) } - defer h264dec.close() + defer h264RawDec.close() - // if present, send SPS and PPS from the SDP to the decoder + // if SPS and PPS are present into the SDP, send them to the decoder sps := h264track.SafeSPS() if sps != nil { - h264dec.decode(sps) + h264RawDec.decode(sps) } pps := h264track.SafePPS() if pps != nil { - h264dec.decode(pps) + h264RawDec.decode(pps) } // called when a RTP packet arrives @@ -97,13 +102,15 @@ func main() { return } - if ctx.H264NALUs == nil { + // convert RTP packets into NALUs + nalus, _, err := rtpDec.Decode(ctx.Packet) + if err != nil { return } - for _, nalu := range ctx.H264NALUs { - // convert H264 NALUs to RGBA frames - img, err := h264dec.decode(nalu) + for _, nalu := range nalus { + // convert NALUs into RGBA frames + img, err := h264RawDec.decode(nalu) if err != nil { panic(err) } diff --git a/examples/client-read-h264-save-to-disk/main.go b/examples/client-read-h264-save-to-disk/main.go index 7aab1b4e..8791b849 100644 --- a/examples/client-read-h264-save-to-disk/main.go +++ b/examples/client-read-h264-save-to-disk/main.go @@ -2,6 +2,7 @@ package main import ( "github.com/aler9/gortsplib" + "github.com/aler9/gortsplib/pkg/rtph264" "github.com/aler9/gortsplib/pkg/url" ) @@ -45,8 +46,12 @@ func main() { panic("H264 track not found") } - // setup H264->MPEGTS encoder - enc, err := newMPEGTSMuxer(h264track.SafeSPS(), h264track.SafePPS()) + // setup RTP/H264->H264 decoder + rtpDec := &rtph264.Decoder{} + rtpDec.Init() + + // setup H264->MPEGTS muxer + mpegtsMuxer, err := newMPEGTSMuxer(h264track.SafeSPS(), h264track.SafePPS()) if err != nil { panic(err) } @@ -57,15 +62,14 @@ func main() { return } - if ctx.H264NALUs == nil { + // convert RTP packets into NALUs + nalus, pts, err := rtpDec.Decode(ctx.Packet) + if err != nil { return } // encode H264 NALUs into MPEG-TS - err = enc.encode(ctx.H264NALUs, ctx.H264PTS) - if err != nil { - return - } + mpegtsMuxer.encode(nalus, pts) } // setup and read all tracks diff --git a/examples/client-read-h264/main.go b/examples/client-read-h264/main.go index eb0a3608..34d5dc17 100644 --- a/examples/client-read-h264/main.go +++ b/examples/client-read-h264/main.go @@ -4,6 +4,7 @@ import ( "log" "github.com/aler9/gortsplib" + "github.com/aler9/gortsplib/pkg/rtph264" "github.com/aler9/gortsplib/pkg/url" ) @@ -50,21 +51,25 @@ func main() { panic("H264 track not found") } + // setup RTP/H264->H264 decoder + rtpDec := &rtph264.Decoder{} + rtpDec.Init() + // setup H264->raw frames decoder - h264dec, err := newH264Decoder() + h264RawDec, err := newH264Decoder() if err != nil { panic(err) } - defer h264dec.close() + defer h264RawDec.close() - // if present, send SPS and PPS from the SDP to the decoder + // if SPS and PPS are present into the SDP, send them to the decoder sps := h264track.SafeSPS() if sps != nil { - h264dec.decode(sps) + h264RawDec.decode(sps) } pps := h264track.SafePPS() if pps != nil { - h264dec.decode(pps) + h264RawDec.decode(pps) } // called when a RTP packet arrives @@ -73,13 +78,15 @@ func main() { return } - if ctx.H264NALUs == nil { + // convert RTP packets into NALUs + nalus, _, err := rtpDec.Decode(ctx.Packet) + if err != nil { return } - for _, nalu := range ctx.H264NALUs { - // convert H264 NALUs to RGBA frames - img, err := h264dec.decode(nalu) + for _, nalu := range nalus { + // convert NALUs into RGBA frames + img, err := h264RawDec.decode(nalu) if err != nil { panic(err) } diff --git a/examples/server-h264-save-to-disk/main.go b/examples/server-h264-save-to-disk/main.go index 0f4f5bca..fbcaa61c 100644 --- a/examples/server-h264-save-to-disk/main.go +++ b/examples/server-h264-save-to-disk/main.go @@ -7,6 +7,7 @@ import ( "github.com/aler9/gortsplib" "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/rtph264" ) // This example shows how to @@ -19,6 +20,7 @@ type serverHandler struct { publisher *gortsplib.ServerSession h264TrackID int h264track *gortsplib.TrackH264 + rtpDec *rtph264.Decoder mpegtsMuxer *mpegtsMuxer } @@ -76,7 +78,11 @@ func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) ( }, fmt.Errorf("H264 track not found") } - // setup H264->MPEGTS encoder + // setup RTP/H264->H264 decoder + rtpDec := &rtph264.Decoder{} + rtpDec.Init() + + // setup H264->MPEGTS muxer mpegtsMuxer, err := newMPEGTSMuxer(h264track.SafeSPS(), h264track.SafePPS()) if err != nil { return &base.Response{ @@ -86,6 +92,7 @@ func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) ( sh.publisher = ctx.Session sh.h264TrackID = h264TrackID + sh.rtpDec = rtpDec sh.mpegtsMuxer = mpegtsMuxer return &base.Response{ @@ -120,15 +127,13 @@ func (sh *serverHandler) OnPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) return } - if ctx.H264NALUs == nil { + nalus, pts, err := sh.rtpDec.Decode(ctx.Packet) + if err != nil { return } // encode H264 NALUs into MPEG-TS - err := sh.mpegtsMuxer.encode(ctx.H264NALUs, ctx.H264PTS) - if err != nil { - return - } + sh.mpegtsMuxer.encode(nalus, pts) } func main() { diff --git a/pkg/rtpcleaner/cleaner.go b/pkg/rtpcleaner/cleaner.go deleted file mode 100644 index bbae4702..00000000 --- a/pkg/rtpcleaner/cleaner.go +++ /dev/null @@ -1,149 +0,0 @@ -// Package rtpcleaner contains a cleaning utility. -package rtpcleaner - -import ( - "fmt" - "time" - - "github.com/pion/rtp" - - "github.com/aler9/gortsplib/pkg/h264" - "github.com/aler9/gortsplib/pkg/rtph264" -) - -const ( - // 1500 (UDP MTU) - 20 (IP header) - 8 (UDP header) - maxPacketSize = 1472 -) - -// Output is the output of Clear(). -type Output struct { - Packet *rtp.Packet - PTSEqualsDTS bool - H264NALUs [][]byte - H264PTS time.Duration -} - -// Cleaner is used to clean incoming RTP packets, in order to: -// - remove padding -// - re-encode them if they are bigger than maximum allowed -type Cleaner struct { - isH264 bool - isTCP bool - - h264Decoder *rtph264.Decoder - h264Encoder *rtph264.Encoder -} - -// New allocates a Cleaner. -func New(isH264 bool, isTCP bool) *Cleaner { - p := &Cleaner{ - isH264: isH264, - isTCP: isTCP, - } - - if isH264 { - p.h264Decoder = &rtph264.Decoder{} - p.h264Decoder.Init() - } - - return p -} - -func (p *Cleaner) processH264(pkt *rtp.Packet) ([]*Output, error) { - // check if we need to re-encode - if p.isTCP && p.h264Encoder == nil && pkt.MarshalSize() > maxPacketSize { - v1 := pkt.SSRC - v2 := pkt.SequenceNumber - v3 := pkt.Timestamp - p.h264Encoder = &rtph264.Encoder{ - PayloadType: pkt.PayloadType, - SSRC: &v1, - InitialSequenceNumber: &v2, - InitialTimestamp: &v3, - } - p.h264Encoder.Init() - } - - // re-encode - if p.h264Encoder != nil { - // decode - nalus, pts, err := p.h264Decoder.DecodeUntilMarker(pkt) - if err != nil { - if err == rtph264.ErrNonStartingPacketAndNoPrevious || - err == rtph264.ErrMorePacketsNeeded { // hide standard errors - err = nil - } - - return nil, err // original packets are oversized, do not return them - } - - packets, err := p.h264Encoder.Encode(nalus, pts) - if err != nil { - return nil, err // original packets are oversized, do not return them - } - - ptsEqualsDTS := h264.IDRPresent(nalus) - output := make([]*Output, len(packets)) - - for i, pkt := range packets { - if i != len(packets)-1 { - output[i] = &Output{ - Packet: pkt, - PTSEqualsDTS: false, - } - } else { - output[i] = &Output{ - Packet: pkt, - PTSEqualsDTS: ptsEqualsDTS, - H264NALUs: nalus, - H264PTS: pts, - } - } - } - - return output, nil - } - - // decode - nalus, pts, err := p.h264Decoder.DecodeUntilMarker(pkt) - if err != nil { - if err == rtph264.ErrNonStartingPacketAndNoPrevious || - err == rtph264.ErrMorePacketsNeeded { // hide standard errors - err = nil - } - - return []*Output{{ - Packet: pkt, - PTSEqualsDTS: false, - }}, err - } - - return []*Output{{ - Packet: pkt, - PTSEqualsDTS: h264.IDRPresent(nalus), - H264NALUs: nalus, - H264PTS: pts, - }}, nil -} - -// Process processes a RTP packet. -func (p *Cleaner) Process(pkt *rtp.Packet) ([]*Output, error) { - // remove padding - pkt.Header.Padding = false - pkt.PaddingSize = 0 - - if p.h264Decoder != nil { - return p.processH264(pkt) - } - - if p.isTCP && pkt.MarshalSize() > maxPacketSize { - return nil, fmt.Errorf("payload size (%d) is greater than maximum allowed (%d)", - pkt.MarshalSize(), maxPacketSize) - } - - return []*Output{{ - Packet: pkt, - PTSEqualsDTS: true, - }}, nil -} diff --git a/pkg/rtpcleaner/cleaner_test.go b/pkg/rtpcleaner/cleaner_test.go deleted file mode 100644 index 5c580aa3..00000000 --- a/pkg/rtpcleaner/cleaner_test.go +++ /dev/null @@ -1,214 +0,0 @@ -package rtpcleaner - -import ( - "bytes" - "testing" - - "github.com/pion/rtp" - "github.com/stretchr/testify/require" -) - -func TestRemovePadding(t *testing.T) { - cleaner := New(false, false) - - out, err := cleaner.Process(&rtp.Packet{ - Header: rtp.Header{ - Version: 2, - PayloadType: 96, - Marker: true, - SequenceNumber: 34572, - Padding: true, - }, - Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 64/4), - PaddingSize: 64, - }) - require.NoError(t, err) - require.Equal(t, []*Output{{ - Packet: &rtp.Packet{ - Header: rtp.Header{ - Version: 2, - PayloadType: 96, - Marker: true, - SequenceNumber: 34572, - }, - Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04}, 64/4), - }, - PTSEqualsDTS: true, - }}, out) -} - -func TestGenericOversized(t *testing.T) { - cleaner := New(false, true) - - _, err := cleaner.Process(&rtp.Packet{ - Header: rtp.Header{ - Version: 2, - PayloadType: 96, - Marker: false, - SequenceNumber: 34572, - }, - Payload: bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 2050/5), - }) - require.EqualError(t, err, "payload size (2062) is greater than maximum allowed (1472)") -} - -func TestH264Oversized(t *testing.T) { - cleaner := New(true, true) - - out, err := cleaner.Process(&rtp.Packet{ - Header: rtp.Header{ - Version: 2, - PayloadType: 96, - Marker: false, - SequenceNumber: 34572, - }, - Payload: append( - []byte{0x1C, 1<<7 | 0x05}, - bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 2050/5)..., - ), - }) - require.NoError(t, err) - require.Equal(t, []*Output(nil), out) - - out, err = cleaner.Process(&rtp.Packet{ - Header: rtp.Header{ - Version: 2, - PayloadType: 96, - Marker: true, - SequenceNumber: 34573, - }, - Payload: append( - []byte{0x1C, 1 << 6}, - bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 2050/5)..., - ), - }) - require.NoError(t, err) - require.Equal(t, []*Output{ - { - Packet: &rtp.Packet{ - Header: rtp.Header{ - Version: 2, - PayloadType: 96, - Marker: false, - SequenceNumber: 34572, - }, - Payload: append( - append( - []byte{0x1c, 0x85}, - bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 291)..., - ), - []byte{0x01, 0x02, 0x03}..., - ), - }, - }, - { - Packet: &rtp.Packet{ - Header: rtp.Header{ - Version: 2, - PayloadType: 96, - Marker: false, - SequenceNumber: 34573, - }, - Payload: append( - append( - []byte{0x1c, 0x05, 0x04, 0x05}, - bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 291)..., - ), - []byte{0x01}..., - ), - }, - }, - { - Packet: &rtp.Packet{ - Header: rtp.Header{ - Version: 2, - PayloadType: 96, - Marker: true, - SequenceNumber: 34574, - }, - Payload: append( - []byte{0x1c, 0x45, 0x02, 0x03, 0x04, 0x05}, - bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 236)..., - ), - }, - H264NALUs: [][]byte{ - append( - []byte{0x05}, - bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 4100/5)..., - ), - }, - PTSEqualsDTS: true, - }, - }, out) -} - -func TestH264ProcessEvenIfInvalid(t *testing.T) { - cleaner := New(true, true) - - out, err := cleaner.Process(&rtp.Packet{ - Header: rtp.Header{ - Version: 2, - PayloadType: 96, - Marker: false, - SequenceNumber: 34572, - }, - Payload: []byte{25}, - }) - require.Error(t, err) - require.Equal(t, []*Output{{ - Packet: &rtp.Packet{ - Header: rtp.Header{ - Version: 2, - PayloadType: 96, - Marker: false, - SequenceNumber: 34572, - }, - Payload: []byte{25}, - }, - }}, out) -} - -func TestH264RandomAccess(t *testing.T) { - for _, ca := range []string{ - "standard", - "oversized", - } { - t.Run(ca, func(t *testing.T) { - cleaner := New(true, true) - - var payload []byte - if ca == "standard" { - payload = append([]byte{0x1C, 1 << 6}, - bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 10/5)...) - } else { - payload = append([]byte{0x1C, 1 << 6}, - bytes.Repeat([]byte{0x01, 0x02, 0x03, 0x04, 0x05}, 2048/5)...) - } - - out, err := cleaner.Process(&rtp.Packet{ - Header: rtp.Header{ - Version: 2, - PayloadType: 96, - SequenceNumber: 34572, - }, - Payload: payload, - }) - require.NoError(t, err) - - if ca == "standard" { - require.Equal(t, []*Output{{ - Packet: &rtp.Packet{ - Header: rtp.Header{ - Version: 2, - PayloadType: 96, - SequenceNumber: 34572, - }, - Payload: payload, - }, - }}, out) - } else { - require.Equal(t, []*Output(nil), out) - } - }) - } -} diff --git a/ptsequalsdts.go b/ptsequalsdts.go new file mode 100644 index 00000000..f3d71d7a --- /dev/null +++ b/ptsequalsdts.go @@ -0,0 +1,71 @@ +package gortsplib + +import ( + "github.com/pion/rtp" + + "github.com/aler9/gortsplib/pkg/h264" +) + +// find IDR NALUs without decoding RTP +func rtpH264ContainsIDR(pkt *rtp.Packet) bool { + if len(pkt.Payload) == 0 { + return false + } + + typ := h264.NALUType(pkt.Payload[0] & 0x1F) + + switch typ { + case h264.NALUTypeIDR: + return true + + case 24: // STAP-A + payload := pkt.Payload[1:] + + for len(payload) > 0 { + if len(payload) < 2 { + return false + } + + size := uint16(payload[0])<<8 | uint16(payload[1]) + payload = payload[2:] + + if size == 0 || int(size) > len(payload) { + return false + } + + nalu := payload[:size] + payload = payload[size:] + + typ = h264.NALUType(nalu[0] & 0x1F) + if typ == h264.NALUTypeIDR { + return true + } + } + + return false + + case 28: // FU-A + if len(pkt.Payload) < 2 { + return false + } + + start := pkt.Payload[1] >> 7 + if start != 1 { + return false + } + + typ := h264.NALUType(pkt.Payload[1] & 0x1F) + return (typ == h264.NALUTypeIDR) + + default: + return false + } +} + +func ptsEqualsDTS(track Track, pkt *rtp.Packet) bool { + if _, ok := track.(*TrackH264); ok { + return rtpH264ContainsIDR(pkt) + } + + return true +} diff --git a/server_publish_test.go b/server_publish_test.go index 9f83cf79..dd314b4f 100644 --- a/server_publish_test.go +++ b/server_publish_test.go @@ -1476,10 +1476,10 @@ func TestServerPublishDecodeErrors(t *testing.T) { for _, ca := range []string{ "rtp invalid", "rtcp invalid", - "packets lost", + "rtp packets lost", "rtp too big", "rtcp too big", - "cleaner error", + "rtcp too big tcp", } { t.Run(ca, func(t *testing.T) { errorRecv := make(chan struct{}) @@ -1507,14 +1507,14 @@ func TestServerPublishDecodeErrors(t *testing.T) { require.EqualError(t, ctx.Error, "RTP header size insufficient: 2 < 4") case "rtcp invalid": require.EqualError(t, ctx.Error, "rtcp: packet too short") - case "packets lost": + case "rtp packets lost": require.EqualError(t, ctx.Error, "69 RTP packet(s) lost") case "rtp too big": require.EqualError(t, ctx.Error, "RTP packet is too big to be read with UDP") case "rtcp too big": require.EqualError(t, ctx.Error, "RTCP packet is too big to be read with UDP") - case "cleaner error": - require.EqualError(t, ctx.Error, "packet type not supported (STAP-B)") + case "rtcp too big tcp": + require.EqualError(t, ctx.Error, "RTCP packet size (2000) is greater than maximum allowed (1472)") } close(errorRecv) }, @@ -1533,23 +1533,13 @@ func TestServerPublishDecodeErrors(t *testing.T) { defer nconn.Close() conn := conn.NewConn(nconn) - var track Track - if ca != "cleaner error" { - track = &TrackGeneric{ - Media: "application", - Payloads: []TrackGenericPayload{{ - Type: 97, - RTPMap: "private/90000", - }}, - } - } else { - track = &TrackH264{ - PayloadType: 96, - SPS: []byte{0x01, 0x02, 0x03, 0x04}, - PPS: []byte{0x01, 0x02, 0x03, 0x04}, - } - } - tracks := Tracks{track} + tracks := Tracks{&TrackGeneric{ + Media: "application", + Payloads: []TrackGenericPayload{{ + Type: 97, + RTPMap: "private/90000", + }}, + }} tracks.setControls() res, err := writeReqReadRes(conn, base.Request{ @@ -1573,17 +1563,28 @@ func TestServerPublishDecodeErrors(t *testing.T) { v := headers.TransportModeRecord return &v }(), - Protocol: headers.TransportProtocolUDP, - ClientPorts: &[2]int{35466, 35467}, } - l1, err := net.ListenPacket("udp", "127.0.0.1:35466") - require.NoError(t, err) - defer l1.Close() + if ca != "rtcp too big tcp" { + inTH.Protocol = headers.TransportProtocolUDP + inTH.ClientPorts = &[2]int{35466, 35467} + } else { + inTH.Protocol = headers.TransportProtocolTCP + inTH.InterleavedIDs = &[2]int{0, 1} + } - l2, err := net.ListenPacket("udp", "127.0.0.1:35467") - require.NoError(t, err) - defer l2.Close() + var l1 net.PacketConn + var l2 net.PacketConn + + if ca != "rtcp too big tcp" { + l1, err = net.ListenPacket("udp", "127.0.0.1:35466") + require.NoError(t, err) + defer l1.Close() + + l2, err = net.ListenPacket("udp", "127.0.0.1:35467") + require.NoError(t, err) + defer l2.Close() + } res, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, @@ -1628,7 +1629,7 @@ func TestServerPublishDecodeErrors(t *testing.T) { Port: resTH.ServerPorts[1], }) - case "packets lost": + case "rtp packets lost": byts, _ := rtp.Packet{ Header: rtp.Header{ SequenceNumber: 30, @@ -1661,17 +1662,12 @@ func TestServerPublishDecodeErrors(t *testing.T) { Port: resTH.ServerPorts[1], }) - case "cleaner error": - byts, _ := rtp.Packet{ - Header: rtp.Header{ - SequenceNumber: 100, - }, - Payload: []byte{0x99}, - }.Marshal() - l1.WriteTo(byts, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: resTH.ServerPorts[0], - }) + case "rtcp too big tcp": + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 1, + Payload: bytes.Repeat([]byte{0x01, 0x02}, 2000/2), + }, make([]byte, 2048)) + require.NoError(t, err) } <-errorRecv diff --git a/serverconn.go b/serverconn.go index 528ee1d4..874592cd 100644 --- a/serverconn.go +++ b/serverconn.go @@ -216,14 +216,15 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { case <-sc.session.ctx.Done(): } - var processFunc func(int, bool, []byte) error + var processFunc func(*ServerSessionSetuppedTrack, bool, []byte) error if sc.session.state == ServerSessionStatePlay { - processFunc = func(trackID int, isRTP bool, payload []byte) error { + processFunc = func(track *ServerSessionSetuppedTrack, isRTP bool, payload []byte) error { if !isRTP { if len(payload) > maxPacketSize { - return fmt.Errorf("payload size (%d) is greater than maximum allowed (%d)", - len(payload), maxPacketSize) + onDecodeError(sc.session, fmt.Errorf("RTCP packet size (%d) is greater than maximum allowed (%d)", + len(payload), maxPacketSize)) + return nil } packets, err := rtcp.Unmarshal(payload) @@ -235,7 +236,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { for _, pkt := range packets { h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ Session: sc.session, - TrackID: trackID, + TrackID: track.id, Packet: pkt, }) } @@ -247,7 +248,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { } else { tcpRTPPacketBuffer := newRTPPacketMultiBuffer(uint64(sc.s.ReadBufferCount)) - processFunc = func(trackID int, isRTP bool, payload []byte) error { + processFunc = func(track *ServerSessionSetuppedTrack, isRTP bool, payload []byte) error { if isRTP { pkt := tcpRTPPacketBuffer.next() err := pkt.Unmarshal(payload) @@ -255,28 +256,19 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { return err } - out, err := sc.session.setuppedTracks[trackID].cleaner.Process(pkt) - if err != nil { - onDecodeError(sc.session, err) - // do not return - } - if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok { - for _, entry := range out { - h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ - Session: sc.session, - TrackID: trackID, - Packet: entry.Packet, - PTSEqualsDTS: entry.PTSEqualsDTS, - H264NALUs: entry.H264NALUs, - H264PTS: entry.H264PTS, - }) - } + h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ + Session: sc.session, + TrackID: track.id, + Packet: pkt, + PTSEqualsDTS: ptsEqualsDTS(track.track, pkt), + }) } } else { if len(payload) > maxPacketSize { - return fmt.Errorf("payload size (%d) is greater than maximum allowed (%d)", - len(payload), maxPacketSize) + onDecodeError(sc.session, fmt.Errorf("RTCP packet size (%d) is greater than maximum allowed (%d)", + len(payload), maxPacketSize)) + return nil } packets, err := rtcp.Unmarshal(payload) @@ -285,7 +277,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { } for _, pkt := range packets { - sc.session.onPacketRTCP(trackID, pkt) + sc.session.onPacketRTCP(track.id, pkt) } } @@ -313,8 +305,8 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { } // forward frame only if it has been set up - if trackID, ok := sc.session.tcpTracksByChannel[channel]; ok { - err := processFunc(trackID, isRTP, twhat.Payload) + if track, ok := sc.session.tcpTracksByChannel[channel]; ok { + err := processFunc(track, isRTP, twhat.Payload) if err != nil { return err } diff --git a/serverhandler.go b/serverhandler.go index 16dbe5a9..d90dfadf 100644 --- a/serverhandler.go +++ b/serverhandler.go @@ -1,8 +1,6 @@ package gortsplib import ( - "time" - "github.com/pion/rtcp" "github.com/pion/rtp" @@ -205,8 +203,6 @@ type ServerHandlerOnPacketRTPCtx struct { TrackID int Packet *rtp.Packet PTSEqualsDTS bool - H264NALUs [][]byte - H264PTS time.Duration } // ServerHandlerOnPacketRTP can be implemented by a ServerHandler. diff --git a/serversession.go b/serversession.go index d625bd12..7aeaa194 100644 --- a/serversession.go +++ b/serversession.go @@ -18,7 +18,6 @@ import ( "github.com/aler9/gortsplib/pkg/liberrors" "github.com/aler9/gortsplib/pkg/ringbuffer" "github.com/aler9/gortsplib/pkg/rtcpreceiver" - "github.com/aler9/gortsplib/pkg/rtpcleaner" "github.com/aler9/gortsplib/pkg/rtpreorderer" "github.com/aler9/gortsplib/pkg/url" ) @@ -144,6 +143,7 @@ func (s ServerSessionState) String() string { // ServerSessionSetuppedTrack is a setupped track of a ServerSession. type ServerSessionSetuppedTrack struct { id int + track Track // filled only when publishing tcpChannel int udpRTPReadPort int udpRTPWriteAddr *net.UDPAddr @@ -153,7 +153,6 @@ type ServerSessionSetuppedTrack struct { // publish udpRTCPReceiver *rtcpreceiver.RTCPReceiver reorderer *rtpreorderer.Reorderer - cleaner *rtpcleaner.Cleaner } // ServerSession is a server-side RTSP session. @@ -167,7 +166,7 @@ type ServerSession struct { conns map[*ServerConn]struct{} state ServerSessionState setuppedTracks map[int]*ServerSessionSetuppedTrack - tcpTracksByChannel map[int]int + tcpTracksByChannel map[int]*ServerSessionSetuppedTrack setuppedTransport *Transport setuppedBaseURL *url.URL // publish setuppedStream *ServerStream // read @@ -742,6 +741,10 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base id: trackID, } + if ss.state == ServerSessionStatePreRecord { + sst.track = ss.announcedTracks[trackID] + } + switch transport { case TransportUDP: sst.udpRTPReadPort = inTH.ClientPorts[0] @@ -779,10 +782,10 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base sst.tcpChannel = inTH.InterleavedIDs[0] if ss.tcpTracksByChannel == nil { - ss.tcpTracksByChannel = make(map[int]int) + ss.tcpTracksByChannel = make(map[int]*ServerSessionSetuppedTrack) } - ss.tcpTracksByChannel[inTH.InterleavedIDs[0]] = trackID + ss.tcpTracksByChannel[inTH.InterleavedIDs[0]] = sst th.Protocol = headers.TransportProtocolTCP de := headers.TransportDeliveryUnicast @@ -793,7 +796,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base if ss.setuppedTracks == nil { ss.setuppedTracks = make(map[int]*ServerSessionSetuppedTrack) } - ss.setuppedTracks[trackID] = sst res.Header["Transport"] = th.Marshal() @@ -961,12 +963,10 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.state = ServerSessionStateRecord - for trackID, st := range ss.setuppedTracks { + for _, st := range ss.setuppedTracks { if *ss.setuppedTransport == TransportUDP { st.reorderer = rtpreorderer.New() } - _, isH264 := ss.announcedTracks[trackID].(*TrackH264) - st.cleaner = rtpcleaner.New(isH264, *ss.setuppedTransport == TransportTCP) } switch *ss.setuppedTransport { @@ -987,7 +987,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base st.udpRTCPReceiver = rtcpreceiver.New( ss.s.udpReceiverReportPeriod, nil, - ss.announcedTracks[trackID].ClockRate(), + st.track.ClockRate(), func(pkt rtcp.Packet) { ss.WritePacketRTCP(ctrackID, pkt) }) @@ -1078,7 +1078,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base } for _, st := range ss.setuppedTracks { - st.cleaner = nil st.reorderer = nil } diff --git a/serverudpl.go b/serverudpl.go index 42eff354..2f1c780b 100644 --- a/serverudpl.go +++ b/serverudpl.go @@ -222,28 +222,19 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) { // do not return } + track := clientData.track.track + for _, pkt := range packets { - out, err := clientData.track.cleaner.Process(pkt) - if err != nil { - onDecodeError(clientData.session, err) - // do not return - } + ptsEqualsDTS := ptsEqualsDTS(track, pkt) + clientData.track.udpRTCPReceiver.ProcessPacketRTP(now, pkt, ptsEqualsDTS) - if out != nil { - out0 := out[0] - - clientData.track.udpRTCPReceiver.ProcessPacketRTP(now, pkt, out0.PTSEqualsDTS) - - if h, ok := clientData.session.s.Handler.(ServerHandlerOnPacketRTP); ok { - h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ - Session: clientData.session, - TrackID: clientData.track.id, - Packet: out0.Packet, - PTSEqualsDTS: out0.PTSEqualsDTS, - H264NALUs: out0.H264NALUs, - H264PTS: out0.H264PTS, - }) - } + if h, ok := clientData.session.s.Handler.(ServerHandlerOnPacketRTP); ok { + h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ + Session: clientData.session, + TrackID: clientData.track.id, + Packet: pkt, + PTSEqualsDTS: ptsEqualsDTS, + }) } } }