From 045df4c4bfd47ae0d657af98053f18e9a707ac34 Mon Sep 17 00:00:00 2001 From: cnderrauber Date: Wed, 14 Sep 2022 15:25:16 +0800 Subject: [PATCH] Add currentDirection to RTPTransceiver add currentDirection to RTPTransceiver, don't reuse transceiver if its currentDirection is sendrecv or sendonly --- peerconnection.go | 56 ++++++++++++++++++++++++++- peerconnection_renegotiation_test.go | 57 +++++++++++++++++++++++++--- rtptransceiver.go | 22 +++++++++-- 3 files changed, 125 insertions(+), 10 deletions(-) diff --git a/peerconnection.go b/peerconnection.go index 4eeb7015..6f9e77d3 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -984,6 +984,7 @@ func (pc *PeerConnection) SetLocalDescription(desc SessionDescription) error { weAnswer := desc.Type == SDPTypeAnswer remoteDesc := pc.RemoteDescription() if weAnswer && remoteDesc != nil { + _ = setRTPTransceiverCurrentDirection(&desc, currentTransceivers, false) if err := pc.startRTPSenders(currentTransceivers); err != nil { return err } @@ -1143,6 +1144,7 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error { if isRenegotation { if weOffer { + _ = setRTPTransceiverCurrentDirection(&desc, currentTransceivers, true) if err = pc.startRTPSenders(currentTransceivers); err != nil { return err } @@ -1172,6 +1174,7 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error { // Start the networking in a new routine since it will block until // the connection is actually established. if weOffer { + _ = setRTPTransceiverCurrentDirection(&desc, currentTransceivers, true) if err := pc.startRTPSenders(currentTransceivers); err != nil { return err } @@ -1230,6 +1233,51 @@ func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPRece } } +func setRTPTransceiverCurrentDirection(answer *SessionDescription, currentTransceivers []*RTPTransceiver, weOffer bool) error { + currentTransceivers = append([]*RTPTransceiver{}, currentTransceivers...) + for _, media := range answer.parsed.MediaDescriptions { + midValue := getMidValue(media) + if midValue == "" { + return errPeerConnRemoteDescriptionWithoutMidValue + } + + if media.MediaName.Media == mediaSectionApplication { + continue + } + + var t *RTPTransceiver + t, currentTransceivers = findByMid(midValue, currentTransceivers) + + if t == nil { + return fmt.Errorf("%w: %q", errPeerConnTranscieverMidNil, midValue) + } + + direction := getPeerDirection(media) + if direction == RTPTransceiverDirection(Unknown) { + continue + } + + // reverse direction if it was a remote answer + if weOffer { + switch direction { + case RTPTransceiverDirectionSendonly: + direction = RTPTransceiverDirectionRecvonly + case RTPTransceiverDirectionRecvonly: + // Pion will answer recvonly with a offer recvonly transceiver, so we should + // not change the direction to sendonly if we are the offerer, otherwise this + // tranceiver can't be reuse for AddTrack + if t.Direction() != RTPTransceiverDirectionRecvonly { + direction = RTPTransceiverDirectionSendonly + } + default: + } + } + + t.setCurrentDirection(direction) + } + return nil +} + func runIfNewReceiver( incomingTrack trackDetails, transceivers []*RTPTransceiver, @@ -1706,7 +1754,13 @@ func (pc *PeerConnection) AddTrack(track TrackLocal) (*RTPSender, error) { pc.mu.Lock() defer pc.mu.Unlock() for _, t := range pc.rtpTransceivers { - if !t.stopped && t.kind == track.Kind() && t.Sender() == nil { + currentDirection := t.getCurrentDirection() + // According to https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-addtrack, if the + // transceiver can be reused only if it's currentDirection never be sendrecv or sendonly. + // But that will cause sdp inflate. So we only check currentDirection's current value, + // that's worked for all browsers. + if !t.stopped && t.kind == track.Kind() && t.Sender() == nil && + !(currentDirection == RTPTransceiverDirectionSendrecv || currentDirection == RTPTransceiverDirectionSendonly) { sender, err := pc.api.NewRTPSender(track, pc.dtlsTransport) if err == nil { err = t.SetSender(sender, track) diff --git a/peerconnection_renegotiation_test.go b/peerconnection_renegotiation_test.go index cc30fe1d..ef7e0e3a 100644 --- a/peerconnection_renegotiation_test.go +++ b/peerconnection_renegotiation_test.go @@ -128,14 +128,14 @@ func TestPeerConnection_Renegotiation_AddRecvonlyTransceiver(t *testing.T) { pcOffer.OnTrack(func(track *TrackRemote, r *RTPReceiver) { onTrackFiredFunc() }) + assert.NoError(t, signalPair(pcAnswer, pcOffer)) } else { pcAnswer.OnTrack(func(track *TrackRemote, r *RTPReceiver) { onTrackFiredFunc() }) + assert.NoError(t, signalPair(pcOffer, pcAnswer)) } - assert.NoError(t, signalPair(pcOffer, pcAnswer)) - sendVideoUntilDone(onTrackFired.Done(), t, []*TrackLocalStaticSample{localTrack}) closePairNow(t, pcOffer, pcAnswer) @@ -380,6 +380,7 @@ func TestPeerConnection_Transceiver_Mid(t *testing.T) { offer, err = pcOffer.CreateOffer(nil) assert.NoError(t, err) + assert.NoError(t, pcOffer.SetLocalDescription(offer)) assert.Equal(t, len(offer.parsed.MediaDescriptions), 2) @@ -391,6 +392,11 @@ func TestPeerConnection_Transceiver_Mid(t *testing.T) { pcOffer.ops.Done() pcAnswer.ops.Done() + assert.NoError(t, pcAnswer.SetRemoteDescription(offer)) + answer, err = pcAnswer.CreateAnswer(nil) + assert.NoError(t, err) + assert.NoError(t, pcOffer.SetRemoteDescription(answer)) + track3, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion3") require.NoError(t, err) @@ -468,12 +474,12 @@ func TestPeerConnection_Renegotiation_CodecChange(t *testing.T) { require.NoError(t, pcOffer.RemoveTrack(sender1)) - sender2, err := pcOffer.AddTrack(track2) - require.NoError(t, err) - require.NoError(t, signalPair(pcOffer, pcAnswer)) <-tracksClosed + sender2, err := pcOffer.AddTrack(track2) + require.NoError(t, err) + require.NoError(t, signalPair(pcOffer, pcAnswer)) transceivers = pcOffer.GetTransceivers() require.Equal(t, 1, len(transceivers)) require.Equal(t, "0", transceivers[0].Mid()) @@ -1145,3 +1151,44 @@ func TestPeerConnection_Renegotiation_Simulcast(t *testing.T) { closePairNow(t, pcOffer, pcAnswer) }) } + +func TestPeerConnection_Regegotiation_ReuseTransceiver(t *testing.T) { + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + pcOffer, pcAnswer, err := newPair() + if err != nil { + t.Fatal(err) + } + + vp8Track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "foo", "bar") + assert.NoError(t, err) + sender, err := pcOffer.AddTrack(vp8Track) + assert.NoError(t, err) + assert.NoError(t, signalPair(pcOffer, pcAnswer)) + + assert.Equal(t, len(pcOffer.GetTransceivers()), 1) + assert.Equal(t, pcOffer.GetTransceivers()[0].getCurrentDirection(), RTPTransceiverDirectionSendonly) + assert.NoError(t, pcOffer.RemoveTrack(sender)) + assert.Equal(t, pcOffer.GetTransceivers()[0].getCurrentDirection(), RTPTransceiverDirectionSendonly) + + // should not reuse tranceiver + vp8Track2, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "foo", "bar") + assert.NoError(t, err) + sender2, err := pcOffer.AddTrack(vp8Track2) + assert.NoError(t, err) + assert.Equal(t, len(pcOffer.GetTransceivers()), 2) + assert.NoError(t, signalPair(pcOffer, pcAnswer)) + assert.True(t, sender2.rtpTransceiver == pcOffer.GetTransceivers()[1]) + + // should reuse first transceiver + sender, err = pcOffer.AddTrack(vp8Track) + assert.NoError(t, err) + assert.Equal(t, len(pcOffer.GetTransceivers()), 2) + assert.True(t, sender.rtpTransceiver == pcOffer.GetTransceivers()[0]) + + closePairNow(t, pcOffer, pcAnswer) +} diff --git a/rtptransceiver.go b/rtptransceiver.go index 776e0108..f94b9ed7 100644 --- a/rtptransceiver.go +++ b/rtptransceiver.go @@ -13,10 +13,11 @@ import ( // RTPTransceiver represents a combination of an RTPSender and an RTPReceiver that share a common mid. type RTPTransceiver struct { - mid atomic.Value // string - sender atomic.Value // *RTPSender - receiver atomic.Value // *RTPReceiver - direction atomic.Value // RTPTransceiverDirection + mid atomic.Value // string + sender atomic.Value // *RTPSender + receiver atomic.Value // *RTPReceiver + direction atomic.Value // RTPTransceiverDirection + currentDirection atomic.Value // RTPTransceiverDirection codecs []RTPCodecParameters // User provided codecs via SetCodecPreferences @@ -38,6 +39,7 @@ func newRTPTransceiver( t.setReceiver(receiver) t.setSender(sender) t.setDirection(direction) + t.setCurrentDirection(RTPTransceiverDirection(Unknown)) return t } @@ -160,6 +162,7 @@ func (t *RTPTransceiver) Stop() error { } t.setDirection(RTPTransceiverDirectionInactive) + t.setCurrentDirection(RTPTransceiverDirectionInactive) return nil } @@ -179,6 +182,17 @@ func (t *RTPTransceiver) setDirection(d RTPTransceiverDirection) { t.direction.Store(d) } +func (t *RTPTransceiver) setCurrentDirection(d RTPTransceiverDirection) { + t.currentDirection.Store(d) +} + +func (t *RTPTransceiver) getCurrentDirection() RTPTransceiverDirection { + if v, ok := t.currentDirection.Load().(RTPTransceiverDirection); ok { + return v + } + return RTPTransceiverDirection(Unknown) +} + func (t *RTPTransceiver) setSendingTrack(track TrackLocal) error { if err := t.Sender().ReplaceTrack(track); err != nil { return err