diff --git a/errors.go b/errors.go index f97d26b6..3dcd04f0 100644 --- a/errors.go +++ b/errors.go @@ -87,4 +87,8 @@ var ( // ErrProtocolTooLarge indicates that value given for a DataChannelInit protocol is //longer then 65535 bytes ErrProtocolTooLarge = errors.New("protocol is larger then 65535 bytes") + + // ErrSenderNotCreatedByConnection indicates RemoveTrack was called with a RtpSender not created + // by this PeerConnection + ErrSenderNotCreatedByConnection = errors.New("RtpSender not created by this PeerConnection") ) diff --git a/peerconnection.go b/peerconnection.go index 4168edc0..6a2bdabb 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -1085,20 +1085,22 @@ func (pc *PeerConnection) drainSRTP() { } }() - for { - srtcpSession, err := pc.dtlsTransport.getSRTCPSession() - if err != nil { - pc.log.Warnf("drainSRTP failed to open SrtcpSession: %v", err) - return - } + go func() { + for { + srtcpSession, err := pc.dtlsTransport.getSRTCPSession() + if err != nil { + pc.log.Warnf("drainSRTP failed to open SrtcpSession: %v", err) + return + } - _, ssrc, err := srtcpSession.AcceptStream() - if err != nil { - pc.log.Warnf("Failed to accept RTCP %v \n", err) - return + _, ssrc, err := srtcpSession.AcceptStream() + if err != nil { + pc.log.Warnf("Failed to accept RTCP %v \n", err) + return + } + pc.log.Errorf("Incoming unhandled RTCP ssrc(%d)", ssrc) } - pc.log.Errorf("Incoming unhandled RTCP ssrc(%d)", ssrc) - } + }() } // RemoteDescription returns pendingRemoteDescription if it is not null and @@ -1217,6 +1219,29 @@ func (pc *PeerConnection) AddTransceiver(trackOrKind RTPCodecType, init ...RtpTr return pc.AddTransceiverFromKind(trackOrKind, init...) } +// RemoveTrack removes a Track from the PeerConnection +func (pc *PeerConnection) RemoveTrack(sender *RTPSender) error { + if pc.isClosed.get() { + return &rtcerr.InvalidStateError{Err: ErrConnectionClosed} + } + + var transceiver *RTPTransceiver + for _, t := range pc.GetTransceivers() { + if t.Sender == sender { + transceiver = t + break + } + } + + if transceiver == nil { + return &rtcerr.InvalidAccessError{Err: ErrSenderNotCreatedByConnection} + } else if err := sender.Stop(); err != nil { + return err + } + + return transceiver.setSendingTrack(nil) +} + // AddTransceiverFromKind Create a new RTCRtpTransceiver(SendRecv or RecvOnly) and add it to the set of transceivers. func (pc *PeerConnection) AddTransceiverFromKind(kind RTPCodecType, init ...RtpTransceiverInit) (*RTPTransceiver, error) { if pc.isClosed.get() { @@ -1650,15 +1675,34 @@ func (pc *PeerConnection) startTransports(iceRole ICERole, dtlsRole DTLSRole, re pc.startRTPReceivers(trackDetailsFromSDP(pc.log, pc.RemoteDescription().parsed), currentTransceivers) pc.startRTPSenders(currentTransceivers) - go pc.drainSRTP() + pc.drainSRTP() pc.startSCTP() } func (pc *PeerConnection) startRenegotation(currentTransceivers []*RTPTransceiver) { - // Delete orphaned Receivers TODO + trackDetails := trackDetailsFromSDP(pc.log, pc.RemoteDescription().parsed) + for _, t := range currentTransceivers { + if t.Receiver == nil || t.Receiver.Track() == nil { + continue + } else if _, ok := trackDetails[t.Receiver.Track().ssrc]; ok { + continue + } + + if err := t.Receiver.Stop(); err != nil { + pc.log.Warnf("Failed to stop RtpReceiver: %s", err) + continue + } + + receiver, err := pc.api.NewRTPReceiver(t.Receiver.kind, pc.dtlsTransport) + if err != nil { + pc.log.Warnf("Failed to create new RtpReceiver: %s", err) + continue + } + t.Receiver = receiver + } pc.startRTPSenders(currentTransceivers) - pc.startRTPReceivers(trackDetailsFromSDP(pc.log, pc.RemoteDescription().parsed), currentTransceivers) + pc.startRTPReceivers(trackDetails, currentTransceivers) } // GetRegisteredRTPCodecs gets a list of registered RTPCodec from the underlying constructed MediaEngine diff --git a/peerconnection_media_test.go b/peerconnection_media_test.go index d25d228d..f24978d5 100644 --- a/peerconnection_media_test.go +++ b/peerconnection_media_test.go @@ -994,4 +994,6 @@ func TestGetRegisteredRTPCodecs(t *testing.T) { if actualCodec != expectedCodec { t.Errorf("expected to get %v but got %v", expectedCodec, actualCodec) } + + assert.NoError(t, pc.Close()) } diff --git a/peerconnection_renegotation_test.go b/peerconnection_renegotation_test.go index bd789176..689c000d 100644 --- a/peerconnection_renegotation_test.go +++ b/peerconnection_renegotation_test.go @@ -4,14 +4,27 @@ package webrtc import ( "context" + "io" "math/rand" "testing" "time" + "github.com/pion/transport/test" "github.com/pion/webrtc/v2/pkg/media" "github.com/stretchr/testify/assert" ) +func sendVideoUntilDone(c context.Context, t *testing.T, track *Track) { + for { + select { + case <-time.After(20 * time.Millisecond): + assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Samples: 1})) + case <-c.Done(): + return + } + } +} + /* * Assert the following behaviors * - We are able to call AddTrack after signaling @@ -19,17 +32,12 @@ import ( * - We are able to re-negotiate and AddTrack is properly called */ func TestPeerConnection_Renegotation_AddTrack(t *testing.T) { - const ( - expectedTrackID = "video" - expectedTrackLabel = "pion" - ) - api := NewAPI() - // lim := test.TimeOut(time.Second * 30) - // defer lim.Stop() + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() - // report := test.CheckRoutines(t) - // defer report() + report := test.CheckRoutines(t) + defer report() api.mediaEngine.RegisterDefaultCodecs() pcOffer, pcAnswer, err := api.newPair() @@ -46,20 +54,12 @@ func TestPeerConnection_Renegotation_AddTrack(t *testing.T) { onTrackFiredFunc() }) - haveConnected, haveConnectedFunc := context.WithCancel(context.Background()) - pcOffer.OnICEConnectionStateChange(func(i ICEConnectionState) { - if i == ICEConnectionStateConnected { - haveConnectedFunc() - } - }) - assert.NoError(t, signalPair(pcOffer, pcAnswer)) - <-haveConnected.Done() _, err = pcAnswer.AddTransceiverFromKind(RTPCodecTypeVideo, RtpTransceiverInit{Direction: RTPTransceiverDirectionRecvonly}) assert.NoError(t, err) - vp8Track, err := pcOffer.NewTrack(DefaultPayloadTypeVP8, rand.Uint32(), expectedTrackID, expectedTrackLabel) + vp8Track, err := pcOffer.NewTrack(DefaultPayloadTypeVP8, rand.Uint32(), "foo", "bar") assert.NoError(t, err) _, err = pcOffer.AddTrack(vp8Track) @@ -74,17 +74,56 @@ func TestPeerConnection_Renegotation_AddTrack(t *testing.T) { haveRenegotiated.set(true) assert.NoError(t, signalPair(pcOffer, pcAnswer)) - func() { - for { - select { - case <-time.After(20 * time.Millisecond): - assert.NoError(t, vp8Track.WriteSample(media.Sample{Data: []byte{0x00}, Samples: 1})) - case <-onTrackFired.Done(): - return - } - } - }() + sendVideoUntilDone(onTrackFired, t, vp8Track) assert.NoError(t, pcOffer.Close()) assert.NoError(t, pcAnswer.Close()) } + +func TestPeerConnection_Renegotation_RemoveTrack(t *testing.T) { + api := NewAPI() + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + api.mediaEngine.RegisterDefaultCodecs() + pcOffer, pcAnswer, err := api.newPair() + if err != nil { + t.Fatal(err) + } + + _, err = pcAnswer.AddTransceiverFromKind(RTPCodecTypeVideo, RtpTransceiverInit{Direction: RTPTransceiverDirectionRecvonly}) + assert.NoError(t, err) + + vp8Track, err := pcOffer.NewTrack(DefaultPayloadTypeVP8, rand.Uint32(), "foo", "bar") + assert.NoError(t, err) + + rtpSender, err := pcOffer.AddTrack(vp8Track) + assert.NoError(t, err) + + onTrackFired, onTrackFiredFunc := context.WithCancel(context.Background()) + trackClosed, trackClosedFunc := context.WithCancel(context.Background()) + + pcAnswer.OnTrack(func(track *Track, r *RTPReceiver) { + onTrackFiredFunc() + + for { + if _, err := track.ReadRTP(); err == io.EOF { + trackClosedFunc() + return + } + } + }) + + assert.NoError(t, signalPair(pcOffer, pcAnswer)) + sendVideoUntilDone(onTrackFired, t, vp8Track) + + assert.NoError(t, pcOffer.RemoveTrack(rtpSender)) + assert.NoError(t, signalPair(pcOffer, pcAnswer)) + + <-trackClosed.Done() + assert.NoError(t, pcOffer.Close()) + assert.NoError(t, pcAnswer.Close()) +} diff --git a/rtptransceiver.go b/rtptransceiver.go index 168acba5..be8e27c4 100644 --- a/rtptransceiver.go +++ b/rtptransceiver.go @@ -19,17 +19,17 @@ type RTPTransceiver struct { } func (t *RTPTransceiver) setSendingTrack(track *Track) error { - if track == nil { - return fmt.Errorf("track must not be nil") - } - t.Sender.track = track - switch t.Direction { - case RTPTransceiverDirectionRecvonly: + switch { + case track != nil && t.Direction == RTPTransceiverDirectionRecvonly: t.Direction = RTPTransceiverDirectionSendrecv - case RTPTransceiverDirectionInactive: + case track != nil && t.Direction == RTPTransceiverDirectionInactive: t.Direction = RTPTransceiverDirectionSendonly + case track == nil && t.Direction == RTPTransceiverDirectionSendrecv: + t.Direction = RTPTransceiverDirectionRecvonly + case track == nil && t.Direction == RTPTransceiverDirectionSendonly: + t.Direction = RTPTransceiverDirectionInactive default: return fmt.Errorf("invalid state change in RTPTransceiver.setSending") }