emit a decode error in case of packets with wrong SSRC

This commit is contained in:
aler9
2023-08-17 14:41:47 +02:00
committed by Alessandro Ros
parent 8b047b545b
commit 4e000eb2dd
13 changed files with 323 additions and 289 deletions

View File

@@ -103,7 +103,12 @@ func (ct *clientFormat) readRTPUDP(pkt *rtp.Packet) {
now := ct.cm.c.timeNow() now := ct.cm.c.timeNow()
for _, pkt := range packets { for _, pkt := range packets {
ct.rtcpReceiver.ProcessPacket(pkt, now, ct.format.PTSEqualsDTS(pkt)) err := ct.rtcpReceiver.ProcessPacket(pkt, now, ct.format.PTSEqualsDTS(pkt))
if err != nil {
ct.cm.c.OnDecodeError(err)
continue
}
ct.onPacketRTP(pkt) ct.onPacketRTP(pkt)
} }
} }
@@ -123,6 +128,12 @@ func (ct *clientFormat) readRTPTCP(pkt *rtp.Packet) {
} }
now := ct.cm.c.timeNow() now := ct.cm.c.timeNow()
ct.rtcpReceiver.ProcessPacket(pkt, now, ct.format.PTSEqualsDTS(pkt))
err := ct.rtcpReceiver.ProcessPacket(pkt, now, ct.format.PTSEqualsDTS(pkt))
if err != nil {
ct.cm.c.OnDecodeError(err)
return
}
ct.onPacketRTP(pkt) ct.onPacketRTP(pkt)
} }

View File

@@ -136,7 +136,7 @@ func (cm *clientMedia) stop() {
func (cm *clientMedia) findFormatWithSSRC(ssrc uint32) *clientFormat { func (cm *clientMedia) findFormatWithSSRC(ssrc uint32) *clientFormat {
for _, format := range cm.formats { for _, format := range cm.formats {
tssrc, ok := format.rtcpReceiver.LastSSRC() tssrc, ok := format.rtcpReceiver.SenderSSRC()
if ok && tssrc == ssrc { if ok && tssrc == ssrc {
return format return format
} }

View File

@@ -34,7 +34,22 @@ func mustMarshalMedias(medias media.Medias) []byte {
if err != nil { if err != nil {
panic(err) panic(err)
} }
return byts
}
func mustMarshalPacketRTP(pkt *rtp.Packet) []byte {
byts, err := pkt.Marshal()
if err != nil {
panic(err)
}
return byts
}
func mustMarshalPacketRTCP(pkt rtcp.Packet) []byte {
byts, err := pkt.Marshal()
if err != nil {
panic(err)
}
return byts return byts
} }
@@ -2119,7 +2134,7 @@ func TestClientPlayRTCPReport(t *testing.T) {
_, _, err = l2.ReadFrom(buf) _, _, err = l2.ReadFrom(buf)
require.NoError(t, err) require.NoError(t, err)
pkt := rtp.Packet{ _, err = l1.WriteTo(mustMarshalPacketRTP(&rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
Version: 2, Version: 2,
Marker: true, Marker: true,
@@ -2129,9 +2144,7 @@ func TestClientPlayRTCPReport(t *testing.T) {
SSRC: 753621, SSRC: 753621,
}, },
Payload: []byte{0x05, 0x02, 0x03, 0x04}, Payload: []byte{0x05, 0x02, 0x03, 0x04},
} }), &net.UDPAddr{
byts, _ := pkt.Marshal()
_, err = l1.WriteTo(byts, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: inTH.ClientPorts[0], Port: inTH.ClientPorts[0],
}) })
@@ -2140,15 +2153,13 @@ func TestClientPlayRTCPReport(t *testing.T) {
// wait for the packet's SSRC to be saved // wait for the packet's SSRC to be saved
time.Sleep(200 * time.Millisecond) time.Sleep(200 * time.Millisecond)
sr := &rtcp.SenderReport{ _, err = l2.WriteTo(mustMarshalPacketRTCP(&rtcp.SenderReport{
SSRC: 753621, SSRC: 753621,
NTPTime: ntpTimeGoToRTCP(time.Date(2017, 8, 12, 15, 30, 0, 0, time.UTC)), NTPTime: ntpTimeGoToRTCP(time.Date(2017, 8, 12, 15, 30, 0, 0, time.UTC)),
RTPTime: 54352, RTPTime: 54352,
PacketCount: 1, PacketCount: 1,
OctetCount: 4, OctetCount: 4,
} }), &net.UDPAddr{
byts, _ = sr.Marshal()
_, err = l2.WriteTo(byts, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: inTH.ClientPorts[1], Port: inTH.ClientPorts[1],
}) })
@@ -2895,12 +2906,15 @@ func TestClientPlayDecodeErrors(t *testing.T) {
{"udp", "rtp invalid"}, {"udp", "rtp invalid"},
{"udp", "rtcp invalid"}, {"udp", "rtcp invalid"},
{"udp", "rtp packets lost"}, {"udp", "rtp packets lost"},
{"udp", "rtp too big"},
{"udp", "rtcp too big"},
{"udp", "rtp unknown format"}, {"udp", "rtp unknown format"},
{"udp", "wrong ssrc"},
{"udp", "rtcp too big"},
{"udp", "rtp too big"},
{"tcp", "rtp invalid"},
{"tcp", "rtcp invalid"}, {"tcp", "rtcp invalid"},
{"tcp", "rtcp too big"},
{"tcp", "rtp unknown format"}, {"tcp", "rtp unknown format"},
{"tcp", "wrong ssrc"},
{"tcp", "rtcp too big"},
} { } {
t.Run(ca.proto+" "+ca.name, func(t *testing.T) { t.Run(ca.proto+" "+ca.name, func(t *testing.T) {
errorRecv := make(chan struct{}) errorRecv := make(chan struct{})
@@ -3012,47 +3026,91 @@ func TestClientPlayDecodeErrors(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
var writeRTP func(buf []byte)
var writeRTCP func(byts []byte)
if ca.proto == "udp" { //nolint:dupl
writeRTP = func(byts []byte) {
_, err = l1.WriteTo(byts, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[0],
})
require.NoError(t, err)
}
writeRTCP = func(byts []byte) {
_, err = l2.WriteTo(byts, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[1],
})
require.NoError(t, err)
}
} else {
writeRTP = func(byts []byte) {
err = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 0,
Payload: byts,
}, make([]byte, 2048))
require.NoError(t, err)
}
writeRTCP = func(byts []byte) {
err = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 1,
Payload: byts,
}, make([]byte, 2048))
require.NoError(t, err)
}
}
switch { //nolint:dupl switch { //nolint:dupl
case ca.proto == "udp" && ca.name == "rtp invalid": case ca.name == "rtp invalid":
_, err := l1.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ writeRTP([]byte{0x01, 0x02})
IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[0],
})
require.NoError(t, err)
case ca.proto == "udp" && ca.name == "rtcp invalid": case ca.name == "rtcp invalid":
_, err := l2.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ writeRTCP([]byte{0x01, 0x02})
IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[1],
})
require.NoError(t, err)
case ca.proto == "udp" && ca.name == "rtp packets lost": case ca.name == "rtcp too big":
byts, _ := rtp.Packet{ writeRTCP(bytes.Repeat([]byte{0x01, 0x02}, 2000/2))
case ca.name == "rtp packets lost":
writeRTP(mustMarshalPacketRTP(&rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
PayloadType: 97, PayloadType: 97,
SequenceNumber: 30, SequenceNumber: 30,
}, },
}.Marshal() }))
_, err := l1.WriteTo(byts, &net.UDPAddr{ writeRTP(mustMarshalPacketRTP(&rtp.Packet{
IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[0],
})
require.NoError(t, err)
byts, _ = rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
PayloadType: 97, PayloadType: 97,
SequenceNumber: 100, SequenceNumber: 100,
}, },
}.Marshal() }))
_, err = l1.WriteTo(byts, &net.UDPAddr{ case ca.name == "rtp unknown format":
IP: net.ParseIP("127.0.0.1"), writeRTP(mustMarshalPacketRTP(&rtp.Packet{
Port: th.ClientPorts[0], Header: rtp.Header{
}) PayloadType: 111,
require.NoError(t, err) },
}))
case ca.name == "wrong ssrc":
writeRTP(mustMarshalPacketRTP(&rtp.Packet{
Header: rtp.Header{
PayloadType: 97,
SequenceNumber: 1,
SSRC: 123,
},
}))
writeRTP(mustMarshalPacketRTP(&rtp.Packet{
Header: rtp.Header{
PayloadType: 97,
SequenceNumber: 2,
SSRC: 456,
},
}))
case ca.proto == "udp" && ca.name == "rtp too big": case ca.proto == "udp" && ca.name == "rtp too big":
_, err := l1.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{ _, err := l1.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{
@@ -3060,53 +3118,6 @@ func TestClientPlayDecodeErrors(t *testing.T) {
Port: th.ClientPorts[0], Port: th.ClientPorts[0],
}) })
require.NoError(t, err) require.NoError(t, err)
case ca.proto == "udp" && ca.name == "rtcp too big":
_, err := l2.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[1],
})
require.NoError(t, err)
case ca.proto == "udp" && ca.name == "rtp unknown format":
byts, _ := rtp.Packet{
Header: rtp.Header{
PayloadType: 111,
},
}.Marshal()
_, err := l1.WriteTo(byts, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: th.ClientPorts[0],
})
require.NoError(t, err)
case ca.proto == "tcp" && ca.name == "rtcp invalid":
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 1,
Payload: []byte{0x01, 0x02},
}, make([]byte, 2048))
require.NoError(t, err)
case ca.proto == "tcp" && ca.name == "rtcp too big":
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 1,
Payload: bytes.Repeat([]byte{0x01, 0x02}, 2000/2),
}, make([]byte, 2048))
require.NoError(t, err)
case ca.proto == "tcp" && ca.name == "rtp unknown format":
byts, _ := rtp.Packet{
Header: rtp.Header{
PayloadType: 111,
},
}.Marshal()
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 0,
Payload: byts,
}, make([]byte, 2048))
require.NoError(t, err)
} }
req, err = conn.ReadRequest() req, err = conn.ReadRequest()
@@ -3129,21 +3140,22 @@ func TestClientPlayDecodeErrors(t *testing.T) {
return &v return &v
}(), }(),
OnPacketLost: func(err error) { OnPacketLost: func(err error) {
if ca.proto == "udp" && ca.name == "rtp packets lost" { require.EqualError(t, err, "69 RTP packets lost")
require.EqualError(t, err, "69 RTP packets lost")
}
close(errorRecv) close(errorRecv)
}, },
OnDecodeError: func(err error) { OnDecodeError: func(err error) {
switch { switch {
case ca.proto == "udp" && ca.name == "rtp invalid": case ca.name == "rtp invalid":
require.EqualError(t, err, "RTP header size insufficient: 2 < 4") require.EqualError(t, err, "RTP header size insufficient: 2 < 4")
case ca.name == "rtcp invalid": case ca.name == "rtcp invalid":
require.EqualError(t, err, "rtcp: packet too short") require.EqualError(t, err, "rtcp: packet too short")
case ca.proto == "udp" && ca.name == "rtp too big": case ca.name == "rtp unknown format":
require.EqualError(t, err, "RTP packet is too big to be read with UDP") require.EqualError(t, err, "received RTP packet with unknown format: 111")
case ca.name == "wrong ssrc":
require.EqualError(t, err, "received packet with wrong SSRC 456, expected 123")
case ca.proto == "udp" && ca.name == "rtcp too big": case ca.proto == "udp" && ca.name == "rtcp too big":
require.EqualError(t, err, "RTCP packet is too big to be read with UDP") require.EqualError(t, err, "RTCP packet is too big to be read with UDP")
@@ -3151,8 +3163,11 @@ func TestClientPlayDecodeErrors(t *testing.T) {
case ca.proto == "tcp" && ca.name == "rtcp too big": case ca.proto == "tcp" && ca.name == "rtcp too big":
require.EqualError(t, err, "RTCP packet size (2000) is greater than maximum allowed (1472)") require.EqualError(t, err, "RTCP packet size (2000) is greater than maximum allowed (1472)")
case ca.name == "rtp unknown format": case ca.proto == "udp" && ca.name == "rtp too big":
require.EqualError(t, err, "received RTP packet with unknown format: 111") require.EqualError(t, err, "RTP packet is too big to be read with UDP")
default:
t.Errorf("unexpected")
} }
close(errorRecv) close(errorRecv)
}, },
@@ -3261,7 +3276,7 @@ func TestClientPlayPacketNTP(t *testing.T) {
_, _, err = l2.ReadFrom(buf) _, _, err = l2.ReadFrom(buf)
require.NoError(t, err) require.NoError(t, err)
pkt := rtp.Packet{ _, err = l1.WriteTo(mustMarshalPacketRTP(&rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
Version: 2, Version: 2,
Marker: true, Marker: true,
@@ -3271,9 +3286,7 @@ func TestClientPlayPacketNTP(t *testing.T) {
SSRC: 753621, SSRC: 753621,
}, },
Payload: []byte{1, 2, 3, 4}, Payload: []byte{1, 2, 3, 4},
} }), &net.UDPAddr{
byts, _ := pkt.Marshal()
_, err = l1.WriteTo(byts, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: inTH.ClientPorts[0], Port: inTH.ClientPorts[0],
}) })
@@ -3282,15 +3295,13 @@ func TestClientPlayPacketNTP(t *testing.T) {
// wait for the packet's SSRC to be saved // wait for the packet's SSRC to be saved
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
sr := &rtcp.SenderReport{ _, err = l2.WriteTo(mustMarshalPacketRTCP(&rtcp.SenderReport{
SSRC: 753621, SSRC: 753621,
NTPTime: ntpTimeGoToRTCP(time.Date(2017, 8, 12, 15, 30, 0, 0, time.UTC)), NTPTime: ntpTimeGoToRTCP(time.Date(2017, 8, 12, 15, 30, 0, 0, time.UTC)),
RTPTime: 54352, RTPTime: 54352,
PacketCount: 1, PacketCount: 1,
OctetCount: 4, OctetCount: 4,
} }), &net.UDPAddr{
byts, _ = sr.Marshal()
_, err = l2.WriteTo(byts, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: inTH.ClientPorts[1], Port: inTH.ClientPorts[1],
}) })
@@ -3298,7 +3309,7 @@ func TestClientPlayPacketNTP(t *testing.T) {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
pkt = rtp.Packet{ _, err = l1.WriteTo(mustMarshalPacketRTP(&rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
Version: 2, Version: 2,
Marker: true, Marker: true,
@@ -3308,9 +3319,7 @@ func TestClientPlayPacketNTP(t *testing.T) {
SSRC: 753621, SSRC: 753621,
}, },
Payload: []byte{5, 6, 7, 8}, Payload: []byte{5, 6, 7, 8},
} }), &net.UDPAddr{
byts, _ = pkt.Marshal()
_, err = l1.WriteTo(byts, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: inTH.ClientPorts[0], Port: inTH.ClientPorts[0],
}) })

View File

@@ -42,10 +42,7 @@ var testRTPPacket = rtp.Packet{
Payload: []byte{0x01, 0x02, 0x03, 0x04}, Payload: []byte{0x01, 0x02, 0x03, 0x04},
} }
var testRTPPacketMarshaled = func() []byte { var testRTPPacketMarshaled = mustMarshalPacketRTP(&testRTPPacket)
byts, _ := testRTPPacket.Marshal()
return byts
}()
var testRTCPPacket = rtcp.SourceDescription{ var testRTCPPacket = rtcp.SourceDescription{
Chunks: []rtcp.SourceDescriptionChunk{ Chunks: []rtcp.SourceDescriptionChunk{
@@ -61,10 +58,7 @@ var testRTCPPacket = rtcp.SourceDescription{
}, },
} }
var testRTCPPacketMarshaled = func() []byte { var testRTCPPacketMarshaled = mustMarshalPacketRTCP(&testRTCPPacket)
byts, _ := testRTCPPacket.Marshal()
return byts
}()
func ntpTimeGoToRTCP(v time.Time) uint64 { func ntpTimeGoToRTCP(v time.Time) uint64 {
s := uint64(v.UnixNano()) + 2208988800*1000000000 s := uint64(v.UnixNano()) + 2208988800*1000000000

View File

@@ -3,6 +3,7 @@ package rtcpreceiver
import ( import (
"crypto/rand" "crypto/rand"
"fmt"
"sync" "sync"
"time" "time"
@@ -33,14 +34,14 @@ type RTCPReceiver struct {
period time.Duration period time.Duration
timeNow func() time.Time timeNow func() time.Time
writePacketRTCP func(rtcp.Packet) writePacketRTCP func(rtcp.Packet)
mutex sync.Mutex mutex sync.RWMutex
// data from RTP packets // data from RTP packets
firstRTPPacketReceived bool firstRTPPacketReceived bool
timeInitialized bool timeInitialized bool
sequenceNumberCycles uint16 sequenceNumberCycles uint16
lastSSRC uint32
lastSequenceNumber uint16 lastSequenceNumber uint16
senderSSRC uint32
lastTimeRTP uint32 lastTimeRTP uint32
lastTimeSystem time.Time lastTimeSystem time.Time
totalLost uint32 totalLost uint32
@@ -133,7 +134,7 @@ func (rr *RTCPReceiver) report() rtcp.Packet {
SSRC: rr.receiverSSRC, SSRC: rr.receiverSSRC,
Reports: []rtcp.ReceptionReport{ Reports: []rtcp.ReceptionReport{
{ {
SSRC: rr.lastSSRC, SSRC: rr.senderSSRC,
LastSequenceNumber: uint32(rr.sequenceNumberCycles)<<16 | uint32(rr.lastSequenceNumber), LastSequenceNumber: uint32(rr.sequenceNumberCycles)<<16 | uint32(rr.lastSequenceNumber),
// equivalent to taking the integer part after multiplying the // equivalent to taking the integer part after multiplying the
// loss fraction by 256 // loss fraction by 256
@@ -161,7 +162,7 @@ func (rr *RTCPReceiver) report() rtcp.Packet {
} }
// ProcessPacket extracts the needed data from RTP packets. // ProcessPacket extracts the needed data from RTP packets.
func (rr *RTCPReceiver) ProcessPacket(pkt *rtp.Packet, system time.Time, ptsEqualsDTS bool) { func (rr *RTCPReceiver) ProcessPacket(pkt *rtp.Packet, system time.Time, ptsEqualsDTS bool) error {
rr.mutex.Lock() rr.mutex.Lock()
defer rr.mutex.Unlock() defer rr.mutex.Unlock()
@@ -169,8 +170,8 @@ func (rr *RTCPReceiver) ProcessPacket(pkt *rtp.Packet, system time.Time, ptsEqua
if !rr.firstRTPPacketReceived { if !rr.firstRTPPacketReceived {
rr.firstRTPPacketReceived = true rr.firstRTPPacketReceived = true
rr.totalSinceReport = 1 rr.totalSinceReport = 1
rr.lastSSRC = pkt.SSRC
rr.lastSequenceNumber = pkt.SequenceNumber rr.lastSequenceNumber = pkt.SequenceNumber
rr.senderSSRC = pkt.SSRC
if ptsEqualsDTS { if ptsEqualsDTS {
rr.timeInitialized = true rr.timeInitialized = true
@@ -180,6 +181,10 @@ func (rr *RTCPReceiver) ProcessPacket(pkt *rtp.Packet, system time.Time, ptsEqua
// subsequent packets // subsequent packets
} else { } else {
if pkt.SSRC != rr.senderSSRC {
return fmt.Errorf("received packet with wrong SSRC %d, expected %d", pkt.SSRC, rr.senderSSRC)
}
diff := int32(pkt.SequenceNumber) - int32(rr.lastSequenceNumber) diff := int32(pkt.SequenceNumber) - int32(rr.lastSequenceNumber)
// overflow // overflow
@@ -202,7 +207,6 @@ func (rr *RTCPReceiver) ProcessPacket(pkt *rtp.Packet, system time.Time, ptsEqua
} }
rr.totalSinceReport += uint32(uint16(diff)) rr.totalSinceReport += uint32(uint16(diff))
rr.lastSSRC = pkt.SSRC
rr.lastSequenceNumber = pkt.SequenceNumber rr.lastSequenceNumber = pkt.SequenceNumber
if ptsEqualsDTS { if ptsEqualsDTS {
@@ -220,9 +224,10 @@ func (rr *RTCPReceiver) ProcessPacket(pkt *rtp.Packet, system time.Time, ptsEqua
rr.timeInitialized = true rr.timeInitialized = true
rr.lastTimeRTP = pkt.Timestamp rr.lastTimeRTP = pkt.Timestamp
rr.lastTimeSystem = system rr.lastTimeSystem = system
rr.lastSSRC = pkt.SSRC
} }
} }
return nil
} }
// ProcessSenderReport extracts the needed data from RTCP sender reports. // ProcessSenderReport extracts the needed data from RTCP sender reports.
@@ -236,13 +241,6 @@ func (rr *RTCPReceiver) ProcessSenderReport(sr *rtcp.SenderReport, system time.T
rr.lastSenderReportTimeSystem = system rr.lastSenderReportTimeSystem = system
} }
// LastSSRC returns the SSRC of the last RTP packet.
func (rr *RTCPReceiver) LastSSRC() (uint32, bool) {
rr.mutex.Lock()
defer rr.mutex.Unlock()
return rr.lastSSRC, rr.firstRTPPacketReceived
}
// PacketNTP returns the NTP timestamp of the packet. // PacketNTP returns the NTP timestamp of the packet.
func (rr *RTCPReceiver) PacketNTP(ts uint32) (time.Time, bool) { func (rr *RTCPReceiver) PacketNTP(ts uint32) (time.Time, bool) {
rr.mutex.Lock() rr.mutex.Lock()
@@ -257,3 +255,10 @@ func (rr *RTCPReceiver) PacketNTP(ts uint32) (time.Time, bool) {
return ntpTimeRTCPToGo(rr.lastSenderReportTimeNTP).Add(timeDiffGo), true return ntpTimeRTCPToGo(rr.lastSenderReportTimeNTP).Add(timeDiffGo), true
} }
// SenderSSRC returns the SSRC of outgoing RTP packets.
func (rr *RTCPReceiver) SenderSSRC() (uint32, bool) {
rr.mutex.RLock()
defer rr.mutex.RUnlock()
return rr.senderSSRC, rr.firstRTPPacketReceived
}

View File

@@ -62,7 +62,8 @@ func TestRTCPReceiverBase(t *testing.T) {
Payload: []byte("\x00\x00"), Payload: []byte("\x00\x00"),
} }
ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC)
rr.ProcessPacket(&rtpPkt, ts, true) err = rr.ProcessPacket(&rtpPkt, ts, true)
require.NoError(t, err)
rtpPkt = rtp.Packet{ rtpPkt = rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
@@ -76,7 +77,8 @@ func TestRTCPReceiverBase(t *testing.T) {
Payload: []byte("\x00\x00"), Payload: []byte("\x00\x00"),
} }
ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC)
rr.ProcessPacket(&rtpPkt, ts, true) err = rr.ProcessPacket(&rtpPkt, ts, true)
require.NoError(t, err)
<-done <-done
} }
@@ -132,7 +134,8 @@ func TestRTCPReceiverOverflow(t *testing.T) {
Payload: []byte("\x00\x00"), Payload: []byte("\x00\x00"),
} }
ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC)
rr.ProcessPacket(&rtpPkt, ts, true) err = rr.ProcessPacket(&rtpPkt, ts, true)
require.NoError(t, err)
rtpPkt = rtp.Packet{ rtpPkt = rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
@@ -146,7 +149,8 @@ func TestRTCPReceiverOverflow(t *testing.T) {
Payload: []byte("\x00\x00"), Payload: []byte("\x00\x00"),
} }
ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC)
rr.ProcessPacket(&rtpPkt, ts, true) err = rr.ProcessPacket(&rtpPkt, ts, true)
require.NoError(t, err)
<-done <-done
} }
@@ -205,7 +209,8 @@ func TestRTCPReceiverPacketLost(t *testing.T) {
Payload: []byte("\x00\x00"), Payload: []byte("\x00\x00"),
} }
ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC)
rr.ProcessPacket(&rtpPkt, ts, true) err = rr.ProcessPacket(&rtpPkt, ts, true)
require.NoError(t, err)
rtpPkt = rtp.Packet{ rtpPkt = rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
@@ -219,7 +224,8 @@ func TestRTCPReceiverPacketLost(t *testing.T) {
Payload: []byte("\x00\x00"), Payload: []byte("\x00\x00"),
} }
ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC)
rr.ProcessPacket(&rtpPkt, ts, true) err = rr.ProcessPacket(&rtpPkt, ts, true)
require.NoError(t, err)
<-done <-done
} }
@@ -278,7 +284,8 @@ func TestRTCPReceiverOverflowPacketLost(t *testing.T) {
Payload: []byte("\x00\x00"), Payload: []byte("\x00\x00"),
} }
ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC)
rr.ProcessPacket(&rtpPkt, ts, true) err = rr.ProcessPacket(&rtpPkt, ts, true)
require.NoError(t, err)
rtpPkt = rtp.Packet{ rtpPkt = rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
@@ -292,7 +299,8 @@ func TestRTCPReceiverOverflowPacketLost(t *testing.T) {
Payload: []byte("\x00\x00"), Payload: []byte("\x00\x00"),
} }
ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC)
rr.ProcessPacket(&rtpPkt, ts, true) err = rr.ProcessPacket(&rtpPkt, ts, true)
require.NoError(t, err)
<-done <-done
} }
@@ -347,7 +355,8 @@ func TestRTCPReceiverJitter(t *testing.T) {
Payload: []byte("\x00\x00"), Payload: []byte("\x00\x00"),
} }
ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC) ts = time.Date(2008, 0o5, 20, 22, 15, 20, 0, time.UTC)
rr.ProcessPacket(&rtpPkt, ts, true) err = rr.ProcessPacket(&rtpPkt, ts, true)
require.NoError(t, err)
rtpPkt = rtp.Packet{ rtpPkt = rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
@@ -361,7 +370,8 @@ func TestRTCPReceiverJitter(t *testing.T) {
Payload: []byte("\x00\x00"), Payload: []byte("\x00\x00"),
} }
ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC) ts = time.Date(2008, 0o5, 20, 22, 15, 21, 0, time.UTC)
rr.ProcessPacket(&rtpPkt, ts, true) err = rr.ProcessPacket(&rtpPkt, ts, true)
require.NoError(t, err)
rtpPkt = rtp.Packet{ rtpPkt = rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
@@ -375,7 +385,8 @@ func TestRTCPReceiverJitter(t *testing.T) {
Payload: []byte("\x00\x00"), Payload: []byte("\x00\x00"),
} }
ts = time.Date(2008, 0o5, 20, 22, 15, 22, 0, time.UTC) ts = time.Date(2008, 0o5, 20, 22, 15, 22, 0, time.UTC)
rr.ProcessPacket(&rtpPkt, ts, false) err = rr.ProcessPacket(&rtpPkt, ts, false)
require.NoError(t, err)
<-done <-done
} }

View File

@@ -22,14 +22,14 @@ type RTCPSender struct {
period time.Duration period time.Duration
timeNow func() time.Time timeNow func() time.Time
writePacketRTCP func(rtcp.Packet) writePacketRTCP func(rtcp.Packet)
mutex sync.Mutex mutex sync.RWMutex
// data from RTP packets // data from RTP packets
initialized bool initialized bool
lastTimeRTP uint32 lastTimeRTP uint32
lastTimeNTP time.Time lastTimeNTP time.Time
lastTimeSystem time.Time lastTimeSystem time.Time
lastSSRC uint32 senderSSRC uint32
lastSequenceNumber uint16 lastSequenceNumber uint16
packetCount uint32 packetCount uint32
octetCount uint32 octetCount uint32
@@ -102,7 +102,7 @@ func (rs *RTCPSender) report() rtcp.Packet {
rtpTime := rs.lastTimeRTP + uint32(systemTimeDiff.Seconds()*rs.clockRate) rtpTime := rs.lastTimeRTP + uint32(systemTimeDiff.Seconds()*rs.clockRate)
return &rtcp.SenderReport{ return &rtcp.SenderReport{
SSRC: rs.lastSSRC, SSRC: rs.senderSSRC,
NTPTime: ntpTimeGoToRTCP(ntpTime), NTPTime: ntpTimeGoToRTCP(ntpTime),
RTPTime: rtpTime, RTPTime: rtpTime,
PacketCount: rs.packetCount, PacketCount: rs.packetCount,
@@ -120,25 +120,25 @@ func (rs *RTCPSender) ProcessPacket(pkt *rtp.Packet, ntp time.Time, ptsEqualsDTS
rs.lastTimeRTP = pkt.Timestamp rs.lastTimeRTP = pkt.Timestamp
rs.lastTimeNTP = ntp rs.lastTimeNTP = ntp
rs.lastTimeSystem = rs.timeNow() rs.lastTimeSystem = rs.timeNow()
rs.senderSSRC = pkt.SSRC
} }
rs.lastSSRC = pkt.SSRC
rs.lastSequenceNumber = pkt.SequenceNumber rs.lastSequenceNumber = pkt.SequenceNumber
rs.packetCount++ rs.packetCount++
rs.octetCount += uint32(len(pkt.Payload)) rs.octetCount += uint32(len(pkt.Payload))
} }
// LastSSRC returns the SSRC of the last RTP packet. // SenderSSRC returns the SSRC of outgoing RTP packets.
func (rs *RTCPSender) LastSSRC() (uint32, bool) { func (rs *RTCPSender) SenderSSRC() (uint32, bool) {
rs.mutex.Lock() rs.mutex.RLock()
defer rs.mutex.Unlock() defer rs.mutex.RUnlock()
return rs.lastSSRC, rs.initialized return rs.senderSSRC, rs.initialized
} }
// LastPacketData returns metadata of the last RTP packet. // LastPacketData returns metadata of the last RTP packet.
func (rs *RTCPSender) LastPacketData() (uint16, uint32, time.Time, bool) { func (rs *RTCPSender) LastPacketData() (uint16, uint32, time.Time, bool) {
rs.mutex.Lock() rs.mutex.RLock()
defer rs.mutex.Unlock() defer rs.mutex.RUnlock()
return rs.lastSequenceNumber, rs.lastTimeRTP, rs.lastTimeNTP, rs.initialized return rs.lastSequenceNumber, rs.lastTimeRTP, rs.lastTimeNTP, rs.initialized
} }

View File

@@ -865,7 +865,7 @@ func TestServerRecordRTCPReport(t *testing.T) {
doRecord(t, conn, "rtsp://localhost:8554/teststream", session) doRecord(t, conn, "rtsp://localhost:8554/teststream", session)
byts, _ := (&rtp.Packet{ _, err = l1.WriteTo(mustMarshalPacketRTP(&rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
Version: 2, Version: 2,
Marker: true, Marker: true,
@@ -875,8 +875,7 @@ func TestServerRecordRTCPReport(t *testing.T) {
SSRC: 753621, SSRC: 753621,
}, },
Payload: []byte{1, 2, 3, 4}, Payload: []byte{1, 2, 3, 4},
}).Marshal() }), &net.UDPAddr{
_, err = l1.WriteTo(byts, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: th.ServerPorts[0], Port: th.ServerPorts[0],
}) })
@@ -885,14 +884,13 @@ func TestServerRecordRTCPReport(t *testing.T) {
// wait for the packet's SSRC to be saved // wait for the packet's SSRC to be saved
time.Sleep(200 * time.Millisecond) time.Sleep(200 * time.Millisecond)
byts, _ = (&rtcp.SenderReport{ _, err = l2.WriteTo(mustMarshalPacketRTCP(&rtcp.SenderReport{
SSRC: 753621, SSRC: 753621,
NTPTime: ntpTimeGoToRTCP(time.Date(2018, 2, 20, 19, 0, 0, 0, time.UTC)), NTPTime: ntpTimeGoToRTCP(time.Date(2018, 2, 20, 19, 0, 0, 0, time.UTC)),
RTPTime: 54352, RTPTime: 54352,
PacketCount: 1, PacketCount: 1,
OctetCount: 4, OctetCount: 4,
}).Marshal() }), &net.UDPAddr{
_, err = l2.WriteTo(byts, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: th.ServerPorts[1], Port: th.ServerPorts[1],
}) })
@@ -1198,12 +1196,15 @@ func TestServerRecordDecodeErrors(t *testing.T) {
{"udp", "rtp invalid"}, {"udp", "rtp invalid"},
{"udp", "rtcp invalid"}, {"udp", "rtcp invalid"},
{"udp", "rtp packets lost"}, {"udp", "rtp packets lost"},
{"udp", "rtp too big"},
{"udp", "rtcp too big"},
{"udp", "rtp unknown format"}, {"udp", "rtp unknown format"},
{"udp", "wrong ssrc"},
{"udp", "rtcp too big"},
{"udp", "rtp too big"},
{"tcp", "rtcp invalid"}, {"tcp", "rtcp invalid"},
{"tcp", "rtcp too big"}, {"tcp", "rtp packets lost"},
{"tcp", "rtp unknown format"}, {"tcp", "rtp unknown format"},
{"tcp", "wrong ssrc"},
{"tcp", "rtcp too big"},
} { } {
t.Run(ca.proto+" "+ca.name, func(t *testing.T) { t.Run(ca.proto+" "+ca.name, func(t *testing.T) {
errorRecv := make(chan struct{}) errorRecv := make(chan struct{})
@@ -1226,21 +1227,22 @@ func TestServerRecordDecodeErrors(t *testing.T) {
}, nil }, nil
}, },
onPacketLost: func(ctx *ServerHandlerOnPacketLostCtx) { onPacketLost: func(ctx *ServerHandlerOnPacketLostCtx) {
if ca.proto == "udp" && ca.name == "rtp packets lost" { require.EqualError(t, ctx.Error, "69 RTP packets lost")
require.EqualError(t, ctx.Error, "69 RTP packets lost")
}
close(errorRecv) close(errorRecv)
}, },
onDecodeError: func(ctx *ServerHandlerOnDecodeErrorCtx) { onDecodeError: func(ctx *ServerHandlerOnDecodeErrorCtx) {
switch { switch {
case ca.proto == "udp" && ca.name == "rtp invalid": case ca.name == "rtp invalid":
require.EqualError(t, ctx.Error, "RTP header size insufficient: 2 < 4") require.EqualError(t, ctx.Error, "RTP header size insufficient: 2 < 4")
case ca.name == "rtcp invalid": case ca.name == "rtcp invalid":
require.EqualError(t, ctx.Error, "rtcp: packet too short") require.EqualError(t, ctx.Error, "rtcp: packet too short")
case ca.proto == "udp" && ca.name == "rtp too big": case ca.name == "rtp unknown format":
require.EqualError(t, ctx.Error, "RTP packet is too big to be read with UDP") require.EqualError(t, ctx.Error, "received RTP packet with unknown format: 111")
case ca.name == "wrong ssrc":
require.EqualError(t, ctx.Error, "received packet with wrong SSRC 456, expected 123")
case ca.proto == "udp" && ca.name == "rtcp too big": case ca.proto == "udp" && ca.name == "rtcp too big":
require.EqualError(t, ctx.Error, "RTCP packet is too big to be read with UDP") require.EqualError(t, ctx.Error, "RTCP packet is too big to be read with UDP")
@@ -1248,8 +1250,11 @@ func TestServerRecordDecodeErrors(t *testing.T) {
case ca.proto == "tcp" && ca.name == "rtcp too big": case ca.proto == "tcp" && ca.name == "rtcp too big":
require.EqualError(t, ctx.Error, "RTCP packet size (2000) is greater than maximum allowed (1472)") require.EqualError(t, ctx.Error, "RTCP packet size (2000) is greater than maximum allowed (1472)")
case ca.name == "rtp unknown format": case ca.proto == "udp" && ca.name == "rtp too big":
require.EqualError(t, ctx.Error, "received RTP packet with unknown format: 111") require.EqualError(t, ctx.Error, "RTP packet is too big to be read with UDP")
default:
t.Errorf("unexpected")
} }
close(errorRecv) close(errorRecv)
}, },
@@ -1317,47 +1322,91 @@ func TestServerRecordDecodeErrors(t *testing.T) {
doRecord(t, conn, "rtsp://localhost:8554/teststream", session) doRecord(t, conn, "rtsp://localhost:8554/teststream", session)
var writeRTP func(buf []byte)
var writeRTCP func(byts []byte)
if ca.proto == "udp" { //nolint:dupl
writeRTP = func(byts []byte) {
_, err = l1.WriteTo(byts, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: resTH.ServerPorts[0],
})
require.NoError(t, err)
}
writeRTCP = func(byts []byte) {
_, err = l2.WriteTo(byts, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: resTH.ServerPorts[1],
})
require.NoError(t, err)
}
} else {
writeRTP = func(byts []byte) {
err = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 0,
Payload: byts,
}, make([]byte, 2048))
require.NoError(t, err)
}
writeRTCP = func(byts []byte) {
err = conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 1,
Payload: byts,
}, make([]byte, 2048))
require.NoError(t, err)
}
}
switch { //nolint:dupl switch { //nolint:dupl
case ca.proto == "udp" && ca.name == "rtp invalid": case ca.name == "rtp invalid":
_, err := l1.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ writeRTP([]byte{0x01, 0x02})
IP: net.ParseIP("127.0.0.1"),
Port: resTH.ServerPorts[0],
})
require.NoError(t, err)
case ca.proto == "udp" && ca.name == "rtcp invalid": case ca.name == "rtcp invalid":
_, err := l2.WriteTo([]byte{0x01, 0x02}, &net.UDPAddr{ writeRTCP([]byte{0x01, 0x02})
IP: net.ParseIP("127.0.0.1"),
Port: resTH.ServerPorts[1],
})
require.NoError(t, err)
case ca.proto == "udp" && ca.name == "rtp packets lost": case ca.name == "rtcp too big":
byts, _ := rtp.Packet{ writeRTCP(bytes.Repeat([]byte{0x01, 0x02}, 2000/2))
case ca.name == "rtp packets lost":
writeRTP(mustMarshalPacketRTP(&rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
PayloadType: 97, PayloadType: 97,
SequenceNumber: 30, SequenceNumber: 30,
}, },
}.Marshal() }))
_, err := l1.WriteTo(byts, &net.UDPAddr{ writeRTP(mustMarshalPacketRTP(&rtp.Packet{
IP: net.ParseIP("127.0.0.1"),
Port: resTH.ServerPorts[0],
})
require.NoError(t, err)
byts, _ = rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
PayloadType: 97, PayloadType: 97,
SequenceNumber: 100, SequenceNumber: 100,
}, },
}.Marshal() }))
_, err = l1.WriteTo(byts, &net.UDPAddr{ case ca.name == "rtp unknown format":
IP: net.ParseIP("127.0.0.1"), writeRTP(mustMarshalPacketRTP(&rtp.Packet{
Port: resTH.ServerPorts[0], Header: rtp.Header{
}) PayloadType: 111,
require.NoError(t, err) },
}))
case ca.name == "wrong ssrc":
writeRTP(mustMarshalPacketRTP(&rtp.Packet{
Header: rtp.Header{
PayloadType: 97,
SequenceNumber: 1,
SSRC: 123,
},
}))
writeRTP(mustMarshalPacketRTP(&rtp.Packet{
Header: rtp.Header{
PayloadType: 97,
SequenceNumber: 2,
SSRC: 456,
},
}))
case ca.proto == "udp" && ca.name == "rtp too big": case ca.proto == "udp" && ca.name == "rtp too big":
_, err := l1.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{ _, err := l1.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{
@@ -1365,53 +1414,6 @@ func TestServerRecordDecodeErrors(t *testing.T) {
Port: resTH.ServerPorts[0], Port: resTH.ServerPorts[0],
}) })
require.NoError(t, err) require.NoError(t, err)
case ca.proto == "udp" && ca.name == "rtcp too big":
_, err := l2.WriteTo(bytes.Repeat([]byte{0x01, 0x02}, 2000/2), &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: resTH.ServerPorts[1],
})
require.NoError(t, err)
case ca.proto == "udp" && ca.name == "rtp unknown format":
byts, _ := rtp.Packet{
Header: rtp.Header{
PayloadType: 111,
},
}.Marshal()
_, err := l1.WriteTo(byts, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: resTH.ServerPorts[0],
})
require.NoError(t, err)
case ca.proto == "tcp" && ca.name == "rtcp invalid":
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 1,
Payload: []byte{0x01, 0x02},
}, make([]byte, 2048))
require.NoError(t, err)
case ca.proto == "tcp" && ca.name == "rtcp too big":
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 1,
Payload: bytes.Repeat([]byte{0x01, 0x02}, 2000/2),
}, make([]byte, 2048))
require.NoError(t, err)
case ca.proto == "tcp" && ca.name == "rtp unknown format":
byts, _ := rtp.Packet{
Header: rtp.Header{
PayloadType: 111,
},
}.Marshal()
err := conn.WriteInterleavedFrame(&base.InterleavedFrame{
Channel: 0,
Payload: byts,
}, make([]byte, 2048))
require.NoError(t, err)
} }
<-errorRecv <-errorRecv
@@ -1498,7 +1500,7 @@ func TestServerRecordPacketNTP(t *testing.T) {
doRecord(t, conn, "rtsp://localhost:8554/teststream", session) doRecord(t, conn, "rtsp://localhost:8554/teststream", session)
byts, _ := (&rtp.Packet{ _, err = l1.WriteTo(mustMarshalPacketRTP(&rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
Version: 2, Version: 2,
Marker: true, Marker: true,
@@ -1508,8 +1510,7 @@ func TestServerRecordPacketNTP(t *testing.T) {
SSRC: 753621, SSRC: 753621,
}, },
Payload: []byte{1, 2, 3, 4}, Payload: []byte{1, 2, 3, 4},
}).Marshal() }), &net.UDPAddr{
_, err = l1.WriteTo(byts, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: th.ServerPorts[0], Port: th.ServerPorts[0],
}) })
@@ -1518,14 +1519,13 @@ func TestServerRecordPacketNTP(t *testing.T) {
// wait for the packet's SSRC to be saved // wait for the packet's SSRC to be saved
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
byts, _ = (&rtcp.SenderReport{ _, err = l2.WriteTo(mustMarshalPacketRTCP(&rtcp.SenderReport{
SSRC: 753621, SSRC: 753621,
NTPTime: ntpTimeGoToRTCP(time.Date(2018, 2, 20, 19, 0, 0, 0, time.UTC)), NTPTime: ntpTimeGoToRTCP(time.Date(2018, 2, 20, 19, 0, 0, 0, time.UTC)),
RTPTime: 54352, RTPTime: 54352,
PacketCount: 1, PacketCount: 1,
OctetCount: 4, OctetCount: 4,
}).Marshal() }), &net.UDPAddr{
_, err = l2.WriteTo(byts, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: th.ServerPorts[1], Port: th.ServerPorts[1],
}) })
@@ -1533,7 +1533,7 @@ func TestServerRecordPacketNTP(t *testing.T) {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
byts, _ = (&rtp.Packet{ _, err = l1.WriteTo(mustMarshalPacketRTP(&rtp.Packet{
Header: rtp.Header{ Header: rtp.Header{
Version: 2, Version: 2,
Marker: true, Marker: true,
@@ -1543,8 +1543,7 @@ func TestServerRecordPacketNTP(t *testing.T) {
SSRC: 753621, SSRC: 753621,
}, },
Payload: []byte{1, 2, 3, 4}, Payload: []byte{1, 2, 3, 4},
}).Marshal() }), &net.UDPAddr{
_, err = l1.WriteTo(byts, &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: th.ServerPorts[0], Port: th.ServerPorts[0],
}) })

View File

@@ -758,7 +758,7 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
th := headers.Transport{} th := headers.Transport{}
if ss.state == ServerSessionStatePrePlay { if ss.state == ServerSessionStatePrePlay {
ssrc, ok := stream.lastSSRC(medi) ssrc, ok := stream.senderSSRC(medi)
if ok { if ok {
th.SSRC = &ssrc th.SSRC = &ssrc
} }

View File

@@ -58,7 +58,6 @@ func (sf *serverSessionFormat) start() {
func (sf *serverSessionFormat) stop() { func (sf *serverSessionFormat) stop() {
if sf.rtcpReceiver != nil { if sf.rtcpReceiver != nil {
sf.rtcpReceiver.Close() sf.rtcpReceiver.Close()
sf.rtcpReceiver = nil
} }
} }
@@ -77,7 +76,12 @@ func (sf *serverSessionFormat) readRTPUDP(pkt *rtp.Packet, now time.Time) {
} }
for _, pkt := range packets { for _, pkt := range packets {
sf.rtcpReceiver.ProcessPacket(pkt, now, sf.format.PTSEqualsDTS(pkt)) err := sf.rtcpReceiver.ProcessPacket(pkt, now, sf.format.PTSEqualsDTS(pkt))
if err != nil {
sf.sm.ss.onDecodeError(err)
continue
}
sf.onPacketRTP(pkt) sf.onPacketRTP(pkt)
} }
} }
@@ -97,6 +101,12 @@ func (sf *serverSessionFormat) readRTPTCP(pkt *rtp.Packet) {
} }
now := sf.sm.ss.s.timeNow() now := sf.sm.ss.s.timeNow()
sf.rtcpReceiver.ProcessPacket(pkt, now, sf.format.PTSEqualsDTS(pkt))
err := sf.rtcpReceiver.ProcessPacket(pkt, now, sf.format.PTSEqualsDTS(pkt))
if err != nil {
sf.sm.ss.onDecodeError(err)
return
}
sf.onPacketRTP(pkt) sf.onPacketRTP(pkt)
} }

View File

@@ -108,6 +108,16 @@ func (sm *serverSessionMedia) stop() {
} }
} }
func (sm *serverSessionMedia) findFormatWithSSRC(ssrc uint32) *serverSessionFormat {
for _, format := range sm.formats {
tssrc, ok := format.rtcpReceiver.SenderSSRC()
if ok && tssrc == ssrc {
return format
}
}
return nil
}
func (sm *serverSessionMedia) writePacketRTPInQueueUDP(payload []byte) { func (sm *serverSessionMedia) writePacketRTPInQueueUDP(payload []byte) {
atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload))) atomic.AddUint64(sm.ss.bytesSent, uint64(len(payload)))
sm.ss.s.udpRTPListener.write(payload, sm.udpRTPWriteAddr) //nolint:errcheck sm.ss.s.udpRTPListener.write(payload, sm.udpRTPWriteAddr) //nolint:errcheck
@@ -218,7 +228,7 @@ func (sm *serverSessionMedia) readRTCPUDPRecord(payload []byte) {
for _, pkt := range packets { for _, pkt := range packets {
if sr, ok := pkt.(*rtcp.SenderReport); ok { if sr, ok := pkt.(*rtcp.SenderReport); ok {
format := serverFindFormatWithSSRC(sm.formats, sr.SSRC) format := sm.findFormatWithSSRC(sr.SSRC)
if format != nil { if format != nil {
format.rtcpReceiver.ProcessSenderReport(sr, now) format.rtcpReceiver.ProcessSenderReport(sr, now)
} }
@@ -283,7 +293,7 @@ func (sm *serverSessionMedia) readRTCPTCPRecord(payload []byte) {
for _, pkt := range packets { for _, pkt := range packets {
if sr, ok := pkt.(*rtcp.SenderReport); ok { if sr, ok := pkt.(*rtcp.SenderReport); ok {
format := serverFindFormatWithSSRC(sm.formats, sr.SSRC) format := sm.findFormatWithSSRC(sr.SSRC)
if format != nil { if format != nil {
format.rtcpReceiver.ProcessSenderReport(sr, now) format.rtcpReceiver.ProcessSenderReport(sr, now)
} }

View File

@@ -13,6 +13,16 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/media" "github.com/bluenviron/gortsplib/v4/pkg/media"
) )
func firstFormat(formats map[uint8]*serverStreamFormat) *serverStreamFormat {
var firstKey uint8
for key := range formats {
firstKey = key
break
}
return formats[firstKey]
}
// ServerStream represents a data stream. // ServerStream represents a data stream.
// This is in charge of // This is in charge of
// - distributing the stream to each reader // - distributing the stream to each reader
@@ -66,26 +76,20 @@ func (st *ServerStream) Medias() media.Medias {
return st.medias return st.medias
} }
func (st *ServerStream) lastSSRC(medi *media.Media) (uint32, bool) { func (st *ServerStream) senderSSRC(medi *media.Media) (uint32, bool) {
st.mutex.Lock() st.mutex.Lock()
defer st.mutex.Unlock() defer st.mutex.Unlock()
sm := st.streamMedias[medi] sm := st.streamMedias[medi]
// since lastSSRC() is used to fill SSRC inside the Transport header, // senderSSRC() is used to fill SSRC inside the Transport header.
// if there are multiple formats inside a single media stream, // if there are multiple formats inside a single media stream,
// do not return anything, since Transport headers don't support multiple SSRCs. // do not return anything, since Transport headers don't support multiple SSRCs.
if len(sm.formats) > 1 { if len(sm.formats) > 1 {
return 0, false return 0, false
} }
var firstKey uint8 return firstFormat(sm.formats).rtcpSender.SenderSSRC()
for key := range sm.formats {
firstKey = key
break
}
return sm.formats[firstKey].rtcpSender.LastSSRC()
} }
func (st *ServerStream) rtpInfoEntry(medi *media.Media, now time.Time) *headers.RTPInfoEntry { func (st *ServerStream) rtpInfoEntry(medi *media.Media, now time.Time) *headers.RTPInfoEntry {
@@ -101,13 +105,7 @@ func (st *ServerStream) rtpInfoEntry(medi *media.Media, now time.Time) *headers.
return nil return nil
} }
var firstKey uint8 format := firstFormat(sm.formats)
for key := range sm.formats {
firstKey = key
break
}
format := sm.formats[firstKey]
lastSeqNum, lastTimeRTP, lastTimeNTP, ok := format.rtcpSender.LastPacketData() lastSeqNum, lastTimeRTP, lastTimeNTP, ok := format.rtcpSender.LastPacketData()
if !ok { if !ok {

View File

@@ -10,19 +10,6 @@ import (
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
func serverFindFormatWithSSRC(
formats map[uint8]*serverSessionFormat,
ssrc uint32,
) *serverSessionFormat {
for _, format := range formats {
tssrc, ok := format.rtcpReceiver.LastSSRC()
if ok && tssrc == ssrc {
return format
}
}
return nil
}
func joinMulticastGroupOnAtLeastOneInterface(p *ipv4.PacketConn, listenIP net.IP) error { func joinMulticastGroupOnAtLeastOneInterface(p *ipv4.PacketConn, listenIP net.IP) error {
intfs, err := net.Interfaces() intfs, err := net.Interfaces()
if err != nil { if err != nil {