From 4e000eb2dd39b24b6b6e16f596f1f45570c5e8c2 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Thu, 17 Aug 2023 14:41:47 +0200 Subject: [PATCH] emit a decode error in case of packets with wrong SSRC --- client_format.go | 15 +- client_media.go | 2 +- client_play_test.go | 221 ++++++++++++++------------ client_record_test.go | 10 +- pkg/rtcpreceiver/rtcpreceiver.go | 33 ++-- pkg/rtcpreceiver/rtcpreceiver_test.go | 33 ++-- pkg/rtcpsender/rtcpsender.go | 22 +-- server_record_test.go | 201 ++++++++++++----------- server_session.go | 2 +- server_session_format.go | 16 +- server_session_media.go | 14 +- server_stream.go | 30 ++-- server_udp_listener.go | 13 -- 13 files changed, 323 insertions(+), 289 deletions(-) diff --git a/client_format.go b/client_format.go index cd0ea071..62496c47 100644 --- a/client_format.go +++ b/client_format.go @@ -103,7 +103,12 @@ func (ct *clientFormat) readRTPUDP(pkt *rtp.Packet) { now := ct.cm.c.timeNow() for _, pkt := range packets { - ct.rtcpReceiver.ProcessPacket(pkt, now, ct.format.PTSEqualsDTS(pkt)) + err := ct.rtcpReceiver.ProcessPacket(pkt, now, ct.format.PTSEqualsDTS(pkt)) + if err != nil { + ct.cm.c.OnDecodeError(err) + continue + } + ct.onPacketRTP(pkt) } } @@ -123,6 +128,12 @@ func (ct *clientFormat) readRTPTCP(pkt *rtp.Packet) { } now := ct.cm.c.timeNow() - ct.rtcpReceiver.ProcessPacket(pkt, now, ct.format.PTSEqualsDTS(pkt)) + + err := ct.rtcpReceiver.ProcessPacket(pkt, now, ct.format.PTSEqualsDTS(pkt)) + if err != nil { + ct.cm.c.OnDecodeError(err) + return + } + ct.onPacketRTP(pkt) } diff --git a/client_media.go b/client_media.go index c6046abe..044311da 100644 --- a/client_media.go +++ b/client_media.go @@ -136,7 +136,7 @@ func (cm *clientMedia) stop() { func (cm *clientMedia) findFormatWithSSRC(ssrc uint32) *clientFormat { for _, format := range cm.formats { - tssrc, ok := format.rtcpReceiver.LastSSRC() + tssrc, ok := format.rtcpReceiver.SenderSSRC() if ok && tssrc == ssrc { return format } diff --git a/client_play_test.go b/client_play_test.go index 2f3197e6..ea8887f2 100644 --- a/client_play_test.go +++ b/client_play_test.go @@ -34,7 +34,22 @@ func mustMarshalMedias(medias media.Medias) []byte { if err != nil { panic(err) } + return byts +} +func mustMarshalPacketRTP(pkt *rtp.Packet) []byte { + byts, err := pkt.Marshal() + if err != nil { + panic(err) + } + return byts +} + +func mustMarshalPacketRTCP(pkt rtcp.Packet) []byte { + byts, err := pkt.Marshal() + if err != nil { + panic(err) + } return byts } @@ -2119,7 +2134,7 @@ func TestClientPlayRTCPReport(t *testing.T) { _, _, err = l2.ReadFrom(buf) require.NoError(t, err) - pkt := rtp.Packet{ + _, err = l1.WriteTo(mustMarshalPacketRTP(&rtp.Packet{ Header: rtp.Header{ Version: 2, Marker: true, @@ -2129,9 +2144,7 @@ func TestClientPlayRTCPReport(t *testing.T) { SSRC: 753621, }, Payload: []byte{0x05, 0x02, 0x03, 0x04}, - } - byts, _ := pkt.Marshal() - _, err = l1.WriteTo(byts, &net.UDPAddr{ + }), &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: inTH.ClientPorts[0], }) @@ -2140,15 +2153,13 @@ func TestClientPlayRTCPReport(t *testing.T) { // wait for the packet's SSRC to be saved time.Sleep(200 * time.Millisecond) - sr := &rtcp.SenderReport{ + _, err = l2.WriteTo(mustMarshalPacketRTCP(&rtcp.SenderReport{ SSRC: 753621, NTPTime: ntpTimeGoToRTCP(time.Date(2017, 8, 12, 15, 30, 0, 0, time.UTC)), RTPTime: 54352, PacketCount: 1, OctetCount: 4, - } - byts, _ = sr.Marshal() - _, err = l2.WriteTo(byts, &net.UDPAddr{ + }), &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: inTH.ClientPorts[1], }) @@ -2895,12 +2906,15 @@ func TestClientPlayDecodeErrors(t *testing.T) { {"udp", "rtp invalid"}, {"udp", "rtcp invalid"}, {"udp", "rtp packets lost"}, - {"udp", "rtp too big"}, - {"udp", "rtcp too big"}, {"udp", "rtp unknown format"}, + {"udp", "wrong ssrc"}, + {"udp", "rtcp too big"}, + {"udp", "rtp too big"}, + {"tcp", "rtp invalid"}, {"tcp", "rtcp invalid"}, - {"tcp", "rtcp too big"}, {"tcp", "rtp unknown format"}, + {"tcp", "wrong ssrc"}, + {"tcp", "rtcp too big"}, } { t.Run(ca.proto+" "+ca.name, func(t *testing.T) { errorRecv := make(chan struct{}) @@ -3012,47 +3026,91 @@ func TestClientPlayDecodeErrors(t *testing.T) { }) require.NoError(t, err) + var writeRTP func(buf []byte) + var writeRTCP func(byts []byte) + + if ca.proto == "udp" { //nolint:dupl + writeRTP = func(byts []byte) { + _, err = l1.WriteTo(byts, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: th.ClientPorts[0], + }) + require.NoError(t, err) + } + + writeRTCP = func(byts []byte) { + _, err = l2.WriteTo(byts, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: th.ClientPorts[1], + }) + require.NoError(t, err) + } + } else { + writeRTP = func(byts []byte) { + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 0, + Payload: byts, + }, make([]byte, 2048)) + require.NoError(t, err) + } + + writeRTCP = func(byts []byte) { + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 1, + Payload: byts, + }, make([]byte, 2048)) + require.NoError(t, err) + } + } + switch { //nolint:dupl - case ca.proto == "udp" && ca.name == "rtp invalid": - _, err := l1.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: th.ClientPorts[0], - }) - require.NoError(t, err) + case ca.name == "rtp invalid": + writeRTP([]byte{0x01, 0x02}) - case ca.proto == "udp" && ca.name == "rtcp invalid": - _, err := l2.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: th.ClientPorts[1], - }) - require.NoError(t, err) + case ca.name == "rtcp invalid": + writeRTCP([]byte{0x01, 0x02}) - case ca.proto == "udp" && ca.name == "rtp packets lost": - byts, _ := rtp.Packet{ + case ca.name == "rtcp too big": + writeRTCP(bytes.Repeat([]byte{0x01, 0x02}, 2000/2)) + + case ca.name == "rtp packets lost": + writeRTP(mustMarshalPacketRTP(&rtp.Packet{ Header: rtp.Header{ PayloadType: 97, SequenceNumber: 30, }, - }.Marshal() + })) - _, err := l1.WriteTo(byts, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: th.ClientPorts[0], - }) - require.NoError(t, err) - - byts, _ = rtp.Packet{ + writeRTP(mustMarshalPacketRTP(&rtp.Packet{ Header: rtp.Header{ PayloadType: 97, SequenceNumber: 100, }, - }.Marshal() + })) - _, err = l1.WriteTo(byts, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: th.ClientPorts[0], - }) - require.NoError(t, err) + case ca.name == "rtp unknown format": + writeRTP(mustMarshalPacketRTP(&rtp.Packet{ + Header: rtp.Header{ + PayloadType: 111, + }, + })) + + case ca.name == "wrong ssrc": + writeRTP(mustMarshalPacketRTP(&rtp.Packet{ + Header: rtp.Header{ + PayloadType: 97, + SequenceNumber: 1, + SSRC: 123, + }, + })) + + writeRTP(mustMarshalPacketRTP(&rtp.Packet{ + Header: rtp.Header{ + PayloadType: 97, + SequenceNumber: 2, + SSRC: 456, + }, + })) case ca.proto == "udp" && ca.name == "rtp too big": _, err := l1.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{ @@ -3060,53 +3118,6 @@ func TestClientPlayDecodeErrors(t *testing.T) { Port: th.ClientPorts[0], }) require.NoError(t, err) - - case ca.proto == "udp" && ca.name == "rtcp too big": - _, err := l2.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: th.ClientPorts[1], - }) - require.NoError(t, err) - - case ca.proto == "udp" && ca.name == "rtp unknown format": - byts, _ := rtp.Packet{ - Header: rtp.Header{ - PayloadType: 111, - }, - }.Marshal() - - _, err := l1.WriteTo(byts, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: th.ClientPorts[0], - }) - require.NoError(t, err) - - case ca.proto == "tcp" && ca.name == "rtcp invalid": - err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ - Channel: 1, - Payload: []byte{0x01, 0x02}, - }, make([]byte, 2048)) - require.NoError(t, err) - - case ca.proto == "tcp" && ca.name == "rtcp too big": - err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ - Channel: 1, - Payload: bytes.Repeat([]byte{0x01, 0x02}, 2000/2), - }, make([]byte, 2048)) - require.NoError(t, err) - - case ca.proto == "tcp" && ca.name == "rtp unknown format": - byts, _ := rtp.Packet{ - Header: rtp.Header{ - PayloadType: 111, - }, - }.Marshal() - - err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ - Channel: 0, - Payload: byts, - }, make([]byte, 2048)) - require.NoError(t, err) } req, err = conn.ReadRequest() @@ -3129,21 +3140,22 @@ func TestClientPlayDecodeErrors(t *testing.T) { return &v }(), OnPacketLost: func(err error) { - if ca.proto == "udp" && ca.name == "rtp packets lost" { - require.EqualError(t, err, "69 RTP packets lost") - } + require.EqualError(t, err, "69 RTP packets lost") close(errorRecv) }, OnDecodeError: func(err error) { switch { - case ca.proto == "udp" && ca.name == "rtp invalid": + case ca.name == "rtp invalid": require.EqualError(t, err, "RTP header size insufficient: 2 < 4") case ca.name == "rtcp invalid": require.EqualError(t, err, "rtcp: packet too short") - case ca.proto == "udp" && ca.name == "rtp too big": - require.EqualError(t, err, "RTP packet is too big to be read with UDP") + case ca.name == "rtp unknown format": + require.EqualError(t, err, "received RTP packet with unknown format: 111") + + case ca.name == "wrong ssrc": + require.EqualError(t, err, "received packet with wrong SSRC 456, expected 123") case ca.proto == "udp" && ca.name == "rtcp too big": require.EqualError(t, err, "RTCP packet is too big to be read with UDP") @@ -3151,8 +3163,11 @@ func TestClientPlayDecodeErrors(t *testing.T) { case ca.proto == "tcp" && ca.name == "rtcp too big": require.EqualError(t, err, "RTCP packet size (2000) is greater than maximum allowed (1472)") - case ca.name == "rtp unknown format": - require.EqualError(t, err, "received RTP packet with unknown format: 111") + case ca.proto == "udp" && ca.name == "rtp too big": + require.EqualError(t, err, "RTP packet is too big to be read with UDP") + + default: + t.Errorf("unexpected") } close(errorRecv) }, @@ -3261,7 +3276,7 @@ func TestClientPlayPacketNTP(t *testing.T) { _, _, err = l2.ReadFrom(buf) require.NoError(t, err) - pkt := rtp.Packet{ + _, err = l1.WriteTo(mustMarshalPacketRTP(&rtp.Packet{ Header: rtp.Header{ Version: 2, Marker: true, @@ -3271,9 +3286,7 @@ func TestClientPlayPacketNTP(t *testing.T) { SSRC: 753621, }, Payload: []byte{1, 2, 3, 4}, - } - byts, _ := pkt.Marshal() - _, err = l1.WriteTo(byts, &net.UDPAddr{ + }), &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: inTH.ClientPorts[0], }) @@ -3282,15 +3295,13 @@ func TestClientPlayPacketNTP(t *testing.T) { // wait for the packet's SSRC to be saved time.Sleep(100 * time.Millisecond) - sr := &rtcp.SenderReport{ + _, err = l2.WriteTo(mustMarshalPacketRTCP(&rtcp.SenderReport{ SSRC: 753621, NTPTime: ntpTimeGoToRTCP(time.Date(2017, 8, 12, 15, 30, 0, 0, time.UTC)), RTPTime: 54352, PacketCount: 1, OctetCount: 4, - } - byts, _ = sr.Marshal() - _, err = l2.WriteTo(byts, &net.UDPAddr{ + }), &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: inTH.ClientPorts[1], }) @@ -3298,7 +3309,7 @@ func TestClientPlayPacketNTP(t *testing.T) { time.Sleep(100 * time.Millisecond) - pkt = rtp.Packet{ + _, err = l1.WriteTo(mustMarshalPacketRTP(&rtp.Packet{ Header: rtp.Header{ Version: 2, Marker: true, @@ -3308,9 +3319,7 @@ func TestClientPlayPacketNTP(t *testing.T) { SSRC: 753621, }, Payload: []byte{5, 6, 7, 8}, - } - byts, _ = pkt.Marshal() - _, err = l1.WriteTo(byts, &net.UDPAddr{ + }), &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: inTH.ClientPorts[0], }) diff --git a/client_record_test.go b/client_record_test.go index 96684646..02251942 100644 --- a/client_record_test.go +++ b/client_record_test.go @@ -42,10 +42,7 @@ var testRTPPacket = rtp.Packet{ Payload: []byte{0x01, 0x02, 0x03, 0x04}, } -var testRTPPacketMarshaled = func() []byte { - byts, _ := testRTPPacket.Marshal() - return byts -}() +var testRTPPacketMarshaled = mustMarshalPacketRTP(&testRTPPacket) var testRTCPPacket = rtcp.SourceDescription{ Chunks: []rtcp.SourceDescriptionChunk{ @@ -61,10 +58,7 @@ var testRTCPPacket = rtcp.SourceDescription{ }, } -var testRTCPPacketMarshaled = func() []byte { - byts, _ := testRTCPPacket.Marshal() - return byts -}() +var testRTCPPacketMarshaled = mustMarshalPacketRTCP(&testRTCPPacket) func ntpTimeGoToRTCP(v time.Time) uint64 { s := uint64(v.UnixNano()) + 2208988800*1000000000 diff --git a/pkg/rtcpreceiver/rtcpreceiver.go b/pkg/rtcpreceiver/rtcpreceiver.go index caf8cf57..53234457 100644 --- a/pkg/rtcpreceiver/rtcpreceiver.go +++ b/pkg/rtcpreceiver/rtcpreceiver.go @@ -3,6 +3,7 @@ package rtcpreceiver import ( "crypto/rand" + "fmt" "sync" "time" @@ -33,14 +34,14 @@ type RTCPReceiver struct { period time.Duration timeNow func() time.Time writePacketRTCP func(rtcp.Packet) - mutex sync.Mutex + mutex sync.RWMutex // data from RTP packets firstRTPPacketReceived bool timeInitialized bool sequenceNumberCycles uint16 - lastSSRC uint32 lastSequenceNumber uint16 + senderSSRC uint32 lastTimeRTP uint32 lastTimeSystem time.Time totalLost uint32 @@ -133,7 +134,7 @@ func (rr *RTCPReceiver) report() rtcp.Packet { SSRC: rr.receiverSSRC, Reports: []rtcp.ReceptionReport{ { - SSRC: rr.lastSSRC, + SSRC: rr.senderSSRC, LastSequenceNumber: uint32(rr.sequenceNumberCycles)<<16 | uint32(rr.lastSequenceNumber), // equivalent to taking the integer part after multiplying the // loss fraction by 256 @@ -161,7 +162,7 @@ func (rr *RTCPReceiver) report() rtcp.Packet { } // ProcessPacket extracts the needed data from RTP packets. -func (rr *RTCPReceiver) ProcessPacket(pkt *rtp.Packet, system time.Time, ptsEqualsDTS bool) { +func (rr *RTCPReceiver) ProcessPacket(pkt *rtp.Packet, system time.Time, ptsEqualsDTS bool) error { rr.mutex.Lock() defer rr.mutex.Unlock() @@ -169,8 +170,8 @@ func (rr *RTCPReceiver) ProcessPacket(pkt *rtp.Packet, system time.Time, ptsEqua if !rr.firstRTPPacketReceived { rr.firstRTPPacketReceived = true rr.totalSinceReport = 1 - rr.lastSSRC = pkt.SSRC rr.lastSequenceNumber = pkt.SequenceNumber + rr.senderSSRC = pkt.SSRC if ptsEqualsDTS { rr.timeInitialized = true @@ -180,6 +181,10 @@ func (rr *RTCPReceiver) ProcessPacket(pkt *rtp.Packet, system time.Time, ptsEqua // subsequent packets } else { + if pkt.SSRC != rr.senderSSRC { + return fmt.Errorf("received packet with wrong SSRC %d, expected %d", pkt.SSRC, rr.senderSSRC) + } + diff := int32(pkt.SequenceNumber) - int32(rr.lastSequenceNumber) // overflow @@ -202,7 +207,6 @@ func (rr *RTCPReceiver) ProcessPacket(pkt *rtp.Packet, system time.Time, ptsEqua } rr.totalSinceReport += uint32(uint16(diff)) - rr.lastSSRC = pkt.SSRC rr.lastSequenceNumber = pkt.SequenceNumber if ptsEqualsDTS { @@ -220,9 +224,10 @@ func (rr *RTCPReceiver) ProcessPacket(pkt *rtp.Packet, system time.Time, ptsEqua rr.timeInitialized = true rr.lastTimeRTP = pkt.Timestamp rr.lastTimeSystem = system - rr.lastSSRC = pkt.SSRC } } + + return nil } // ProcessSenderReport extracts the needed data from RTCP sender reports. @@ -236,13 +241,6 @@ func (rr *RTCPReceiver) ProcessSenderReport(sr *rtcp.SenderReport, system time.T rr.lastSenderReportTimeSystem = system } -// LastSSRC returns the SSRC of the last RTP packet. -func (rr *RTCPReceiver) LastSSRC() (uint32, bool) { - rr.mutex.Lock() - defer rr.mutex.Unlock() - return rr.lastSSRC, rr.firstRTPPacketReceived -} - // PacketNTP returns the NTP timestamp of the packet. func (rr *RTCPReceiver) PacketNTP(ts uint32) (time.Time, bool) { rr.mutex.Lock() @@ -257,3 +255,10 @@ func (rr *RTCPReceiver) PacketNTP(ts uint32) (time.Time, bool) { return ntpTimeRTCPToGo(rr.lastSenderReportTimeNTP).Add(timeDiffGo), true } + +// SenderSSRC returns the SSRC of outgoing RTP packets. +func (rr *RTCPReceiver) SenderSSRC() (uint32, bool) { + rr.mutex.RLock() + defer rr.mutex.RUnlock() + return rr.senderSSRC, rr.firstRTPPacketReceived +} diff --git a/pkg/rtcpreceiver/rtcpreceiver_test.go b/pkg/rtcpreceiver/rtcpreceiver_test.go index 7b2a2a08..36570516 100644 --- a/pkg/rtcpreceiver/rtcpreceiver_test.go +++ b/pkg/rtcpreceiver/rtcpreceiver_test.go @@ -62,7 +62,8 @@ func TestRTCPReceiverBase(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacket(&rtpPkt, ts, true) + err = rr.ProcessPacket(&rtpPkt, ts, true) + require.NoError(t, err) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -76,7 +77,8 @@ func TestRTCPReceiverBase(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) - rr.ProcessPacket(&rtpPkt, ts, true) + err = rr.ProcessPacket(&rtpPkt, ts, true) + require.NoError(t, err) <-done } @@ -132,7 +134,8 @@ func TestRTCPReceiverOverflow(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacket(&rtpPkt, ts, true) + err = rr.ProcessPacket(&rtpPkt, ts, true) + require.NoError(t, err) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -146,7 +149,8 @@ func TestRTCPReceiverOverflow(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacket(&rtpPkt, ts, true) + err = rr.ProcessPacket(&rtpPkt, ts, true) + require.NoError(t, err) <-done } @@ -205,7 +209,8 @@ func TestRTCPReceiverPacketLost(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacket(&rtpPkt, ts, true) + err = rr.ProcessPacket(&rtpPkt, ts, true) + require.NoError(t, err) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -219,7 +224,8 @@ func TestRTCPReceiverPacketLost(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacket(&rtpPkt, ts, true) + err = rr.ProcessPacket(&rtpPkt, ts, true) + require.NoError(t, err) <-done } @@ -278,7 +284,8 @@ func TestRTCPReceiverOverflowPacketLost(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacket(&rtpPkt, ts, true) + err = rr.ProcessPacket(&rtpPkt, ts, true) + require.NoError(t, err) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -292,7 +299,8 @@ func TestRTCPReceiverOverflowPacketLost(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacket(&rtpPkt, ts, true) + err = rr.ProcessPacket(&rtpPkt, ts, true) + require.NoError(t, err) <-done } @@ -347,7 +355,8 @@ func TestRTCPReceiverJitter(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) - rr.ProcessPacket(&rtpPkt, ts, true) + err = rr.ProcessPacket(&rtpPkt, ts, true) + require.NoError(t, err) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -361,7 +370,8 @@ func TestRTCPReceiverJitter(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) - rr.ProcessPacket(&rtpPkt, ts, true) + err = rr.ProcessPacket(&rtpPkt, ts, true) + require.NoError(t, err) rtpPkt = rtp.Packet{ Header: rtp.Header{ @@ -375,7 +385,8 @@ func TestRTCPReceiverJitter(t *testing.T) { Payload: []byte("\x00\x00"), } ts = time.Date(2008, 0o5, 20, 22, 15, 22, 0, time.UTC) - rr.ProcessPacket(&rtpPkt, ts, false) + err = rr.ProcessPacket(&rtpPkt, ts, false) + require.NoError(t, err) <-done } diff --git a/pkg/rtcpsender/rtcpsender.go b/pkg/rtcpsender/rtcpsender.go index 805d80e7..9af5f4e3 100644 --- a/pkg/rtcpsender/rtcpsender.go +++ b/pkg/rtcpsender/rtcpsender.go @@ -22,14 +22,14 @@ type RTCPSender struct { period time.Duration timeNow func() time.Time writePacketRTCP func(rtcp.Packet) - mutex sync.Mutex + mutex sync.RWMutex // data from RTP packets initialized bool lastTimeRTP uint32 lastTimeNTP time.Time lastTimeSystem time.Time - lastSSRC uint32 + senderSSRC uint32 lastSequenceNumber uint16 packetCount uint32 octetCount uint32 @@ -102,7 +102,7 @@ func (rs *RTCPSender) report() rtcp.Packet { rtpTime := rs.lastTimeRTP + uint32(systemTimeDiff.Seconds()*rs.clockRate) return &rtcp.SenderReport{ - SSRC: rs.lastSSRC, + SSRC: rs.senderSSRC, NTPTime: ntpTimeGoToRTCP(ntpTime), RTPTime: rtpTime, PacketCount: rs.packetCount, @@ -120,25 +120,25 @@ func (rs *RTCPSender) ProcessPacket(pkt *rtp.Packet, ntp time.Time, ptsEqualsDTS rs.lastTimeRTP = pkt.Timestamp rs.lastTimeNTP = ntp rs.lastTimeSystem = rs.timeNow() + rs.senderSSRC = pkt.SSRC } - rs.lastSSRC = pkt.SSRC rs.lastSequenceNumber = pkt.SequenceNumber rs.packetCount++ rs.octetCount += uint32(len(pkt.Payload)) } -// LastSSRC returns the SSRC of the last RTP packet. -func (rs *RTCPSender) LastSSRC() (uint32, bool) { - rs.mutex.Lock() - defer rs.mutex.Unlock() - return rs.lastSSRC, rs.initialized +// SenderSSRC returns the SSRC of outgoing RTP packets. +func (rs *RTCPSender) SenderSSRC() (uint32, bool) { + rs.mutex.RLock() + defer rs.mutex.RUnlock() + return rs.senderSSRC, rs.initialized } // LastPacketData returns metadata of the last RTP packet. func (rs *RTCPSender) LastPacketData() (uint16, uint32, time.Time, bool) { - rs.mutex.Lock() - defer rs.mutex.Unlock() + rs.mutex.RLock() + defer rs.mutex.RUnlock() return rs.lastSequenceNumber, rs.lastTimeRTP, rs.lastTimeNTP, rs.initialized } diff --git a/server_record_test.go b/server_record_test.go index 143c2dff..a9910ac5 100644 --- a/server_record_test.go +++ b/server_record_test.go @@ -865,7 +865,7 @@ func TestServerRecordRTCPReport(t *testing.T) { doRecord(t, conn, "rtsp://localhost:8554/teststream", session) - byts, _ := (&rtp.Packet{ + _, err = l1.WriteTo(mustMarshalPacketRTP(&rtp.Packet{ Header: rtp.Header{ Version: 2, Marker: true, @@ -875,8 +875,7 @@ func TestServerRecordRTCPReport(t *testing.T) { SSRC: 753621, }, Payload: []byte{1, 2, 3, 4}, - }).Marshal() - _, err = l1.WriteTo(byts, &net.UDPAddr{ + }), &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ServerPorts[0], }) @@ -885,14 +884,13 @@ func TestServerRecordRTCPReport(t *testing.T) { // wait for the packet's SSRC to be saved time.Sleep(200 * time.Millisecond) - byts, _ = (&rtcp.SenderReport{ + _, err = l2.WriteTo(mustMarshalPacketRTCP(&rtcp.SenderReport{ SSRC: 753621, NTPTime: ntpTimeGoToRTCP(time.Date(2018, 2, 20, 19, 0, 0, 0, time.UTC)), RTPTime: 54352, PacketCount: 1, OctetCount: 4, - }).Marshal() - _, err = l2.WriteTo(byts, &net.UDPAddr{ + }), &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ServerPorts[1], }) @@ -1198,12 +1196,15 @@ func TestServerRecordDecodeErrors(t *testing.T) { {"udp", "rtp invalid"}, {"udp", "rtcp invalid"}, {"udp", "rtp packets lost"}, - {"udp", "rtp too big"}, - {"udp", "rtcp too big"}, {"udp", "rtp unknown format"}, + {"udp", "wrong ssrc"}, + {"udp", "rtcp too big"}, + {"udp", "rtp too big"}, {"tcp", "rtcp invalid"}, - {"tcp", "rtcp too big"}, + {"tcp", "rtp packets lost"}, {"tcp", "rtp unknown format"}, + {"tcp", "wrong ssrc"}, + {"tcp", "rtcp too big"}, } { t.Run(ca.proto+" "+ca.name, func(t *testing.T) { errorRecv := make(chan struct{}) @@ -1226,21 +1227,22 @@ func TestServerRecordDecodeErrors(t *testing.T) { }, nil }, onPacketLost: func(ctx *ServerHandlerOnPacketLostCtx) { - if ca.proto == "udp" && ca.name == "rtp packets lost" { - require.EqualError(t, ctx.Error, "69 RTP packets lost") - } + require.EqualError(t, ctx.Error, "69 RTP packets lost") close(errorRecv) }, onDecodeError: func(ctx *ServerHandlerOnDecodeErrorCtx) { switch { - case ca.proto == "udp" && ca.name == "rtp invalid": + case ca.name == "rtp invalid": require.EqualError(t, ctx.Error, "RTP header size insufficient: 2 < 4") case ca.name == "rtcp invalid": require.EqualError(t, ctx.Error, "rtcp: packet too short") - case ca.proto == "udp" && ca.name == "rtp too big": - require.EqualError(t, ctx.Error, "RTP packet is too big to be read with UDP") + case ca.name == "rtp unknown format": + require.EqualError(t, ctx.Error, "received RTP packet with unknown format: 111") + + case ca.name == "wrong ssrc": + require.EqualError(t, ctx.Error, "received packet with wrong SSRC 456, expected 123") case ca.proto == "udp" && ca.name == "rtcp too big": require.EqualError(t, ctx.Error, "RTCP packet is too big to be read with UDP") @@ -1248,8 +1250,11 @@ func TestServerRecordDecodeErrors(t *testing.T) { case ca.proto == "tcp" && ca.name == "rtcp too big": require.EqualError(t, ctx.Error, "RTCP packet size (2000) is greater than maximum allowed (1472)") - case ca.name == "rtp unknown format": - require.EqualError(t, ctx.Error, "received RTP packet with unknown format: 111") + case ca.proto == "udp" && ca.name == "rtp too big": + require.EqualError(t, ctx.Error, "RTP packet is too big to be read with UDP") + + default: + t.Errorf("unexpected") } close(errorRecv) }, @@ -1317,47 +1322,91 @@ func TestServerRecordDecodeErrors(t *testing.T) { doRecord(t, conn, "rtsp://localhost:8554/teststream", session) + var writeRTP func(buf []byte) + var writeRTCP func(byts []byte) + + if ca.proto == "udp" { //nolint:dupl + writeRTP = func(byts []byte) { + _, err = l1.WriteTo(byts, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: resTH.ServerPorts[0], + }) + require.NoError(t, err) + } + + writeRTCP = func(byts []byte) { + _, err = l2.WriteTo(byts, &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: resTH.ServerPorts[1], + }) + require.NoError(t, err) + } + } else { + writeRTP = func(byts []byte) { + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 0, + Payload: byts, + }, make([]byte, 2048)) + require.NoError(t, err) + } + + writeRTCP = func(byts []byte) { + err = conn.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 1, + Payload: byts, + }, make([]byte, 2048)) + require.NoError(t, err) + } + } + switch { //nolint:dupl - case ca.proto == "udp" && ca.name == "rtp invalid": - _, err := l1.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: resTH.ServerPorts[0], - }) - require.NoError(t, err) + case ca.name == "rtp invalid": + writeRTP([]byte{0x01, 0x02}) - case ca.proto == "udp" && ca.name == "rtcp invalid": - _, err := l2.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: resTH.ServerPorts[1], - }) - require.NoError(t, err) + case ca.name == "rtcp invalid": + writeRTCP([]byte{0x01, 0x02}) - case ca.proto == "udp" && ca.name == "rtp packets lost": - byts, _ := rtp.Packet{ + case ca.name == "rtcp too big": + writeRTCP(bytes.Repeat([]byte{0x01, 0x02}, 2000/2)) + + case ca.name == "rtp packets lost": + writeRTP(mustMarshalPacketRTP(&rtp.Packet{ Header: rtp.Header{ PayloadType: 97, SequenceNumber: 30, }, - }.Marshal() + })) - _, err := l1.WriteTo(byts, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: resTH.ServerPorts[0], - }) - require.NoError(t, err) - - byts, _ = rtp.Packet{ + writeRTP(mustMarshalPacketRTP(&rtp.Packet{ Header: rtp.Header{ PayloadType: 97, SequenceNumber: 100, }, - }.Marshal() + })) - _, err = l1.WriteTo(byts, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: resTH.ServerPorts[0], - }) - require.NoError(t, err) + case ca.name == "rtp unknown format": + writeRTP(mustMarshalPacketRTP(&rtp.Packet{ + Header: rtp.Header{ + PayloadType: 111, + }, + })) + + case ca.name == "wrong ssrc": + writeRTP(mustMarshalPacketRTP(&rtp.Packet{ + Header: rtp.Header{ + PayloadType: 97, + SequenceNumber: 1, + SSRC: 123, + }, + })) + + writeRTP(mustMarshalPacketRTP(&rtp.Packet{ + Header: rtp.Header{ + PayloadType: 97, + SequenceNumber: 2, + SSRC: 456, + }, + })) case ca.proto == "udp" && ca.name == "rtp too big": _, err := l1.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{ @@ -1365,53 +1414,6 @@ func TestServerRecordDecodeErrors(t *testing.T) { Port: resTH.ServerPorts[0], }) require.NoError(t, err) - - case ca.proto == "udp" && ca.name == "rtcp too big": - _, err := l2.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: resTH.ServerPorts[1], - }) - require.NoError(t, err) - - case ca.proto == "udp" && ca.name == "rtp unknown format": - byts, _ := rtp.Packet{ - Header: rtp.Header{ - PayloadType: 111, - }, - }.Marshal() - - _, err := l1.WriteTo(byts, &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: resTH.ServerPorts[0], - }) - require.NoError(t, err) - - case ca.proto == "tcp" && ca.name == "rtcp invalid": - err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ - Channel: 1, - Payload: []byte{0x01, 0x02}, - }, make([]byte, 2048)) - require.NoError(t, err) - - case ca.proto == "tcp" && ca.name == "rtcp too big": - err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ - Channel: 1, - Payload: bytes.Repeat([]byte{0x01, 0x02}, 2000/2), - }, make([]byte, 2048)) - require.NoError(t, err) - - case ca.proto == "tcp" && ca.name == "rtp unknown format": - byts, _ := rtp.Packet{ - Header: rtp.Header{ - PayloadType: 111, - }, - }.Marshal() - - err := conn.WriteInterleavedFrame(&base.InterleavedFrame{ - Channel: 0, - Payload: byts, - }, make([]byte, 2048)) - require.NoError(t, err) } <-errorRecv @@ -1498,7 +1500,7 @@ func TestServerRecordPacketNTP(t *testing.T) { doRecord(t, conn, "rtsp://localhost:8554/teststream", session) - byts, _ := (&rtp.Packet{ + _, err = l1.WriteTo(mustMarshalPacketRTP(&rtp.Packet{ Header: rtp.Header{ Version: 2, Marker: true, @@ -1508,8 +1510,7 @@ func TestServerRecordPacketNTP(t *testing.T) { SSRC: 753621, }, Payload: []byte{1, 2, 3, 4}, - }).Marshal() - _, err = l1.WriteTo(byts, &net.UDPAddr{ + }), &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ServerPorts[0], }) @@ -1518,14 +1519,13 @@ func TestServerRecordPacketNTP(t *testing.T) { // wait for the packet's SSRC to be saved time.Sleep(100 * time.Millisecond) - byts, _ = (&rtcp.SenderReport{ + _, err = l2.WriteTo(mustMarshalPacketRTCP(&rtcp.SenderReport{ SSRC: 753621, NTPTime: ntpTimeGoToRTCP(time.Date(2018, 2, 20, 19, 0, 0, 0, time.UTC)), RTPTime: 54352, PacketCount: 1, OctetCount: 4, - }).Marshal() - _, err = l2.WriteTo(byts, &net.UDPAddr{ + }), &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ServerPorts[1], }) @@ -1533,7 +1533,7 @@ func TestServerRecordPacketNTP(t *testing.T) { time.Sleep(100 * time.Millisecond) - byts, _ = (&rtp.Packet{ + _, err = l1.WriteTo(mustMarshalPacketRTP(&rtp.Packet{ Header: rtp.Header{ Version: 2, Marker: true, @@ -1543,8 +1543,7 @@ func TestServerRecordPacketNTP(t *testing.T) { SSRC: 753621, }, Payload: []byte{1, 2, 3, 4}, - }).Marshal() - _, err = l1.WriteTo(byts, &net.UDPAddr{ + }), &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: th.ServerPorts[0], }) diff --git a/server_session.go b/server_session.go index db49e45d..41a3d751 100644 --- a/server_session.go +++ b/server_session.go @@ -758,7 +758,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( th := headers.Transport{} if ss.state == ServerSessionStatePrePlay { - ssrc, ok := stream.lastSSRC(medi) + ssrc, ok := stream.senderSSRC(medi) if ok { th.SSRC = &ssrc } diff --git a/server_session_format.go b/server_session_format.go index b2f45ecd..560f63d5 100644 --- a/server_session_format.go +++ b/server_session_format.go @@ -58,7 +58,6 @@ func (sf *serverSessionFormat) start() { func (sf *serverSessionFormat) stop() { if sf.rtcpReceiver != nil { sf.rtcpReceiver.Close() - sf.rtcpReceiver = nil } } @@ -77,7 +76,12 @@ func (sf *serverSessionFormat) readRTPUDP(pkt *rtp.Packet, now time.Time) { } for _, pkt := range packets { - sf.rtcpReceiver.ProcessPacket(pkt, now, sf.format.PTSEqualsDTS(pkt)) + err := sf.rtcpReceiver.ProcessPacket(pkt, now, sf.format.PTSEqualsDTS(pkt)) + if err != nil { + sf.sm.ss.onDecodeError(err) + continue + } + sf.onPacketRTP(pkt) } } @@ -97,6 +101,12 @@ func (sf *serverSessionFormat) readRTPTCP(pkt *rtp.Packet) { } now := sf.sm.ss.s.timeNow() - sf.rtcpReceiver.ProcessPacket(pkt, now, sf.format.PTSEqualsDTS(pkt)) + + err := sf.rtcpReceiver.ProcessPacket(pkt, now, sf.format.PTSEqualsDTS(pkt)) + if err != nil { + sf.sm.ss.onDecodeError(err) + return + } + sf.onPacketRTP(pkt) } diff --git a/server_session_media.go b/server_session_media.go index cf484b67..e959929a 100644 --- a/server_session_media.go +++ b/server_session_media.go @@ -108,6 +108,16 @@ func (sm *serverSessionMedia) stop() { } } +func (sm *serverSessionMedia) findFormatWithSSRC(ssrc uint32) *serverSessionFormat { + for _, format := range sm.formats { + tssrc, ok := format.rtcpReceiver.SenderSSRC() + if ok && tssrc == ssrc { + return format + } + } + return nil +} + func (sm *serverSessionMedia) writePacketRTPInQueueUDP(payload []byte) { atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload))) sm.ss.s.udpRTPListener.write(payload, sm.udpRTPWriteAddr) //nolint:errcheck @@ -218,7 +228,7 @@ func (sm *serverSessionMedia) readRTCPUDPRecord(payload []byte) { for _, pkt := range packets { if sr, ok := pkt.(*rtcp.SenderReport); ok { - format := serverFindFormatWithSSRC(sm.formats, sr.SSRC) + format := sm.findFormatWithSSRC(sr.SSRC) if format != nil { format.rtcpReceiver.ProcessSenderReport(sr, now) } @@ -283,7 +293,7 @@ func (sm *serverSessionMedia) readRTCPTCPRecord(payload []byte) { for _, pkt := range packets { if sr, ok := pkt.(*rtcp.SenderReport); ok { - format := serverFindFormatWithSSRC(sm.formats, sr.SSRC) + format := sm.findFormatWithSSRC(sr.SSRC) if format != nil { format.rtcpReceiver.ProcessSenderReport(sr, now) } diff --git a/server_stream.go b/server_stream.go index 6d389861..b90dae2e 100644 --- a/server_stream.go +++ b/server_stream.go @@ -13,6 +13,16 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/media" ) +func firstFormat(formats map[uint8]*serverStreamFormat) *serverStreamFormat { + var firstKey uint8 + for key := range formats { + firstKey = key + break + } + + return formats[firstKey] +} + // ServerStream represents a data stream. // This is in charge of // - distributing the stream to each reader @@ -66,26 +76,20 @@ func (st *ServerStream) Medias() media.Medias { return st.medias } -func (st *ServerStream) lastSSRC(medi *media.Media) (uint32, bool) { +func (st *ServerStream) senderSSRC(medi *media.Media) (uint32, bool) { st.mutex.Lock() defer st.mutex.Unlock() sm := st.streamMedias[medi] - // since lastSSRC() is used to fill SSRC inside the Transport header, + // senderSSRC() is used to fill SSRC inside the Transport header. // if there are multiple formats inside a single media stream, // do not return anything, since Transport headers don't support multiple SSRCs. if len(sm.formats) > 1 { return 0, false } - var firstKey uint8 - for key := range sm.formats { - firstKey = key - break - } - - return sm.formats[firstKey].rtcpSender.LastSSRC() + return firstFormat(sm.formats).rtcpSender.SenderSSRC() } func (st *ServerStream) rtpInfoEntry(medi *media.Media, now time.Time) *headers.RTPInfoEntry { @@ -101,13 +105,7 @@ func (st *ServerStream) rtpInfoEntry(medi *media.Media, now time.Time) *headers. return nil } - var firstKey uint8 - for key := range sm.formats { - firstKey = key - break - } - - format := sm.formats[firstKey] + format := firstFormat(sm.formats) lastSeqNum, lastTimeRTP, lastTimeNTP, ok := format.rtcpSender.LastPacketData() if !ok { diff --git a/server_udp_listener.go b/server_udp_listener.go index c0f4e43c..a281877c 100644 --- a/server_udp_listener.go +++ b/server_udp_listener.go @@ -10,19 +10,6 @@ import ( "golang.org/x/net/ipv4" ) -func serverFindFormatWithSSRC( - formats map[uint8]*serverSessionFormat, - ssrc uint32, -) *serverSessionFormat { - for _, format := range formats { - tssrc, ok := format.rtcpReceiver.LastSSRC() - if ok && tssrc == ssrc { - return format - } - } - return nil -} - func joinMulticastGroupOnAtLeastOneInterface(p *ipv4.PacketConn, listenIP net.IP) error { intfs, err := net.Interfaces() if err != nil {