Support PayloadTypes changing a TrackRemote

If the PayloadType changes for a SSRC update the codec on the
TrackRemote.

Resolves #1850
This commit is contained in:
digitalix
2021-06-27 13:40:51 -04:00
committed by Sean DuBois
parent 7948437b0b
commit f8a8c09949
6 changed files with 156 additions and 127 deletions

View File

@@ -24,6 +24,8 @@ const (
mediaSectionApplication = "application"
rtpOutboundMTU = 1200
rtpPayloadTypeBitmask = 0x7F
)
func defaultSrtpProtectionProfiles() []dtls.SRTPProtectionProfile {

View File

@@ -219,4 +219,6 @@ var (
errICETransportNotInNew = errors.New("ICETransport can only be called in ICETransportStateNew")
errCertificatePEMFormatError = errors.New("bad Certificate PEM format")
errRTPTooShort = errors.New("not long enough to be a RTP Packet")
)

View File

@@ -1178,22 +1178,17 @@ func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPRece
}
go func() {
if err := receiver.Track().determinePayloadType(); err != nil {
pc.log.Warnf("Could not determine PayloadType for SSRC %d", receiver.Track().SSRC())
return
}
params, err := pc.api.mediaEngine.getRTPParametersByPayloadType(receiver.Track().PayloadType())
b := make([]byte, receiveMTU)
n, _, err := receiver.Track().peek(b)
if err != nil {
pc.log.Warnf("no codec could be found for payloadType %d", receiver.Track().PayloadType())
pc.log.Warnf("Could not determine PayloadType for SSRC %d (%s)", receiver.Track().SSRC(), err)
return
}
receiver.Track().mu.Lock()
receiver.Track().kind = receiver.kind
receiver.Track().codec = params.Codecs[0]
receiver.Track().params = params
receiver.Track().mu.Unlock()
if err = receiver.Track().checkAndUpdateTrack(b[:n]); err != nil {
pc.log.Warnf("Failed to set codec settings for track SSRC %d (%s)", receiver.Track().SSRC(), err)
return
}
pc.onTrack(receiver.Track(), receiver)
}()

View File

@@ -140,7 +140,13 @@ func (r *RTPSender) ReplaceTrack(track TrackLocal) error {
return nil
}
if _, err := track.Bind(r.context); err != nil {
codec, err := track.Bind(TrackLocalContext{
id: r.context.id,
params: r.api.mediaEngine.getRTPParametersByKind(r.track.Kind(), []RTPTransceiverDirection{RTPTransceiverDirectionSendonly}),
ssrc: r.context.ssrc,
writeStream: r.context.writeStream,
})
if err != nil {
// Re-bind the original track
if _, reBindErr := r.track.Bind(r.context); reBindErr != nil {
return reBindErr
@@ -149,6 +155,11 @@ func (r *RTPSender) ReplaceTrack(track TrackLocal) error {
return err
}
// Codec has changed
if r.payloadType != codec.PayloadType {
r.context.params.Codecs = []RTPCodecParameters{codec}
}
r.track = track
return nil
}

View File

@@ -3,7 +3,6 @@
package webrtc
import (
"bytes"
"context"
"errors"
"io"
@@ -40,114 +39,80 @@ func Test_RTPSender_ReplaceTrack(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
t.Run("Basic", func(t *testing.T) {
s := SettingEngine{}
s.DisableSRTPReplayProtection(true)
s := SettingEngine{}
s.DisableSRTPReplayProtection(true)
m := &MediaEngine{}
assert.NoError(t, m.RegisterDefaultCodecs())
m := &MediaEngine{}
assert.NoError(t, m.RegisterDefaultCodecs())
sender, receiver, err := NewAPI(WithMediaEngine(m), WithSettingEngine(s)).newPair(Configuration{})
assert.NoError(t, err)
sender, receiver, err := NewAPI(WithMediaEngine(m), WithSettingEngine(s)).newPair(Configuration{})
assert.NoError(t, err)
trackA, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/vp8"}, "video", "pion")
assert.NoError(t, err)
trackA, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion")
assert.NoError(t, err)
trackB, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/vp8"}, "video", "pion")
assert.NoError(t, err)
trackB, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeH264}, "video", "pion")
assert.NoError(t, err)
rtpSender, err := sender.AddTrack(trackA)
assert.NoError(t, err)
rtpSender, err := sender.AddTrack(trackA)
assert.NoError(t, err)
seenPacketA, seenPacketACancel := context.WithCancel(context.Background())
seenPacketB, seenPacketBCancel := context.WithCancel(context.Background())
seenPacketA, seenPacketACancel := context.WithCancel(context.Background())
seenPacketB, seenPacketBCancel := context.WithCancel(context.Background())
var onTrackCount uint64
receiver.OnTrack(func(track *TrackRemote, _ *RTPReceiver) {
assert.Equal(t, uint64(1), atomic.AddUint64(&onTrackCount, 1))
var onTrackCount uint64
receiver.OnTrack(func(track *TrackRemote, _ *RTPReceiver) {
assert.Equal(t, uint64(1), atomic.AddUint64(&onTrackCount, 1))
for {
pkt, _, err := track.ReadRTP()
if err != nil {
assert.True(t, errors.Is(io.EOF, err))
return
}
switch {
case bytes.Equal(pkt.Payload, []byte{0x10, 0xAA}):
seenPacketACancel()
case bytes.Equal(pkt.Payload, []byte{0x10, 0xBB}):
seenPacketBCancel()
}
for {
pkt, _, err := track.ReadRTP()
if err != nil {
assert.True(t, errors.Is(io.EOF, err))
return
}
})
assert.NoError(t, signalPair(sender, receiver))
// Block Until packet with 0xAA has been seen
func() {
for range time.Tick(time.Millisecond * 20) {
select {
case <-seenPacketA.Done():
return
default:
assert.NoError(t, trackA.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second}))
}
switch {
case pkt.Payload[len(pkt.Payload)-1] == 0xAA:
assert.Equal(t, track.Codec().MimeType, MimeTypeVP8)
seenPacketACancel()
case pkt.Payload[len(pkt.Payload)-1] == 0xBB:
assert.Equal(t, track.Codec().MimeType, MimeTypeH264)
seenPacketBCancel()
default:
t.Fatalf("Unexpected RTP Data % 02x", pkt.Payload[len(pkt.Payload)-1])
}
}()
assert.NoError(t, rtpSender.ReplaceTrack(trackB))
// Block Until packet with 0xBB has been seen
func() {
for range time.Tick(time.Millisecond * 20) {
select {
case <-seenPacketB.Done():
return
default:
assert.NoError(t, trackB.WriteSample(media.Sample{Data: []byte{0xBB}, Duration: time.Second}))
}
}
}()
closePairNow(t, sender, receiver)
}
})
t.Run("Invalid Codec Change", func(t *testing.T) {
sender, receiver, err := newPair()
assert.NoError(t, err)
assert.NoError(t, signalPair(sender, receiver))
trackA, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/vp8"}, "video", "pion")
assert.NoError(t, err)
trackB, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/h264", SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42001f"}, "video", "pion")
assert.NoError(t, err)
rtpSender, err := sender.AddTrack(trackA)
assert.NoError(t, err)
assert.NoError(t, signalPair(sender, receiver))
seenPacket, seenPacketCancel := context.WithCancel(context.Background())
receiver.OnTrack(func(_ *TrackRemote, _ *RTPReceiver) {
seenPacketCancel()
})
func() {
for range time.Tick(time.Millisecond * 20) {
select {
case <-seenPacket.Done():
return
default:
assert.NoError(t, trackA.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second}))
}
// Block Until packet with 0xAA has been seen
func() {
for range time.Tick(time.Millisecond * 20) {
select {
case <-seenPacketA.Done():
return
default:
assert.NoError(t, trackA.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second}))
}
}()
}
}()
assert.True(t, errors.Is(rtpSender.ReplaceTrack(trackB), ErrUnsupportedCodec))
assert.NoError(t, rtpSender.ReplaceTrack(trackB))
closePairNow(t, sender, receiver)
})
// Block Until packet with 0xBB has been seen
func() {
for range time.Tick(time.Millisecond * 20) {
select {
case <-seenPacketB.Done():
return
default:
assert.NoError(t, trackB.WriteSample(media.Sample{Data: []byte{0xBB}, Duration: time.Second}))
}
}
}()
closePairNow(t, sender, receiver)
}
func Test_RTPSender_GetParameters(t *testing.T) {
@@ -182,7 +147,7 @@ func Test_RTPSender_SetReadDeadline(t *testing.T) {
sender, receiver, wan := createVNetPair(t)
track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/vp8"}, "video", "pion")
track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion")
assert.NoError(t, err)
rtpSender, err := sender.AddTrack(track)
@@ -201,3 +166,45 @@ func Test_RTPSender_SetReadDeadline(t *testing.T) {
assert.NoError(t, wan.Stop())
closePairNow(t, sender, receiver)
}
func Test_RTPSender_ReplaceTrack_InvalidCodecChange(t *testing.T) {
lim := test.TimeOut(time.Second * 10)
defer lim.Stop()
report := test.CheckRoutines(t)
defer report()
sender, receiver, err := newPair()
assert.NoError(t, err)
trackA, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion")
assert.NoError(t, err)
trackB, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeOpus}, "audio", "pion")
assert.NoError(t, err)
rtpSender, err := sender.AddTrack(trackA)
assert.NoError(t, err)
assert.NoError(t, signalPair(sender, receiver))
seenPacket, seenPacketCancel := context.WithCancel(context.Background())
receiver.OnTrack(func(_ *TrackRemote, _ *RTPReceiver) {
seenPacketCancel()
})
func() {
for range time.Tick(time.Millisecond * 20) {
select {
case <-seenPacket.Done():
return
default:
assert.NoError(t, trackA.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second}))
}
}
}()
assert.True(t, errors.Is(rtpSender.ReplaceTrack(trackB), ErrUnsupportedCodec))
closePairNow(t, sender, receiver)
}

View File

@@ -116,11 +116,43 @@ func (t *TrackRemote) Read(b []byte) (n int, attributes interceptor.Attributes,
// released the lock. Deal with it.
if data != nil {
n = copy(b, data)
err = t.checkAndUpdateTrack(b)
return
}
}
return r.readRTP(b, t)
n, attributes, err = r.readRTP(b, t)
if err != nil {
return
}
err = t.checkAndUpdateTrack(b)
return
}
// checkAndUpdateTrack checks payloadType for every incoming packet
// once a different payloadType is detected the track will be updated
func (t *TrackRemote) checkAndUpdateTrack(b []byte) error {
if len(b) < 2 {
return errRTPTooShort
}
if payloadType := PayloadType(b[1] & rtpPayloadTypeBitmask); payloadType != t.PayloadType() {
t.mu.Lock()
defer t.mu.Unlock()
params, err := t.receiver.api.mediaEngine.getRTPParametersByPayloadType(payloadType)
if err != nil {
return err
}
t.kind = t.receiver.kind
t.payloadType = payloadType
t.codec = params.Codecs[0]
t.params = params
}
return nil
}
// ReadRTP is a convenience method that wraps Read and unmarshals for you.
@@ -138,26 +170,6 @@ func (t *TrackRemote) ReadRTP() (*rtp.Packet, interceptor.Attributes, error) {
return r, attributes, nil
}
// determinePayloadType blocks and reads a single packet to determine the PayloadType for this Track
// this is useful because we can't announce it to the user until we know the payloadType
func (t *TrackRemote) determinePayloadType() error {
b := make([]byte, receiveMTU)
n, _, err := t.peek(b)
if err != nil {
return err
}
r := rtp.Packet{}
if err := r.Unmarshal(b[:n]); err != nil {
return err
}
t.mu.Lock()
t.payloadType = PayloadType(r.PayloadType)
defer t.mu.Unlock()
return nil
}
// peek is like Read, but it doesn't discard the packet read
func (t *TrackRemote) peek(b []byte) (n int, a interceptor.Attributes, err error) {
n, a, err = t.Read(b)