diff --git a/atomicbool.go b/atomicbool.go index c5ace62d..76caf17d 100644 --- a/atomicbool.go +++ b/atomicbool.go @@ -12,9 +12,20 @@ func (b *atomicBool) set(value bool) { // nolint: unparam i = 1 } - atomic.StoreInt32(&(b.val), i) + atomic.StoreInt32(&b.val, i) } func (b *atomicBool) get() bool { - return atomic.LoadInt32(&(b.val)) != 0 + return atomic.LoadInt32(&b.val) != 0 +} + +func (b *atomicBool) compareAndSwap(old, new bool) (swapped bool) { + var oldval, newval int32 + if old { + oldval = 1 + } + if new { + newval = 1 + } + return atomic.CompareAndSwapInt32(&b.val, oldval, newval) } diff --git a/peerconnection.go b/peerconnection.go index e708eb0f..aaa3143b 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -370,15 +370,15 @@ func (pc *PeerConnection) checkNegotiationNeeded() bool { //nolint:gocognit for _, t := range pc.rtpTransceivers { // https://www.w3.org/TR/webrtc/#dfn-update-the-negotiation-needed-flag // Step 5.1 - // if t.stopping && !t.stopped { + // if t.stopping && !t.Stopped() { // return true // } m := getByMid(t.Mid(), localDesc) // Step 5.2 - if !t.stopped && m == nil { + if !t.Stopped() && m == nil { return true } - if !t.stopped && m != nil { + if !t.Stopped() && m != nil { // Step 5.3.1 if t.Direction() == RTPTransceiverDirectionSendrecv || t.Direction() == RTPTransceiverDirectionSendonly { descMsid, okMsid := m.Attribute(sdp.AttrKeyMsid) @@ -407,7 +407,7 @@ func (pc *PeerConnection) checkNegotiationNeeded() bool { //nolint:gocognit } } // Step 5.4 - if t.stopped && t.Mid() != "" { + if t.Stopped() && t.Mid() != "" { if getByMid(t.Mid(), localDesc) != nil || getByMid(t.Mid(), remoteDesc) != nil { return true } @@ -1250,7 +1250,7 @@ func (pc *PeerConnection) startRTPReceivers(incomingTracks []trackDetails, curre } receiver := t.Receiver() - if (incomingTrack.kind != t.kind) || + if (incomingTrack.kind != t.Kind()) || (t.Direction() != RTPTransceiverDirectionRecvonly && t.Direction() != RTPTransceiverDirectionSendrecv) || receiver == nil || (receiver.haveReceived()) { @@ -1592,7 +1592,7 @@ func (pc *PeerConnection) AddTrack(track TrackLocal) (*RTPSender, error) { pc.mu.Lock() defer pc.mu.Unlock() for _, t := range pc.rtpTransceivers { - if !t.stopped && t.kind == track.Kind() && t.Sender() == nil { + if !t.Stopped() && t.Kind() == track.Kind() && t.Sender() == nil { sender, err := pc.api.NewRTPSender(track, pc.dtlsTransport) if err == nil { err = t.SetSender(sender, track) @@ -1853,7 +1853,7 @@ func (pc *PeerConnection) Close() error { // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #4) pc.mu.Lock() for _, t := range pc.rtpTransceivers { - if !t.stopped { + if !t.Stopped() { closeErrs = append(closeErrs, t.Stop()) } } @@ -2157,9 +2157,9 @@ func (pc *PeerConnection) generateUnmatchedSDP(transceivers []*RTPTransceiver, u audio := make([]*RTPTransceiver, 0) for _, t := range transceivers { - if t.kind == RTPCodecTypeVideo { + if t.Kind() == RTPCodecTypeVideo { video = append(video, t) - } else if t.kind == RTPCodecTypeAudio { + } else if t.Kind() == RTPCodecTypeAudio { audio = append(audio, t) } if sender := t.Sender(); sender != nil { @@ -2259,8 +2259,7 @@ func (pc *PeerConnection) generateMatchedSDP(transceivers []*RTPTransceiver, use t, localTransceivers = satisfyTypeAndDirection(kind, direction, localTransceivers) if t == nil { if len(mediaTransceivers) == 0 { - t = &RTPTransceiver{kind: kind, api: pc.api, codecs: pc.api.mediaEngine.getCodecsByKind(kind)} - t.setDirection(RTPTransceiverDirectionInactive) + t = newRTPTransceiver(nil, nil, RTPTransceiverDirectionInactive, kind, pc.api) mediaTransceivers = append(mediaTransceivers, t) } break diff --git a/peerconnection_go_test.go b/peerconnection_go_test.go index f4e9cb99..13aa530b 100644 --- a/peerconnection_go_test.go +++ b/peerconnection_go_test.go @@ -178,7 +178,7 @@ func TestPeerConnection_SetConfiguration_Go(t *testing.T) { certificate2, err := GenerateCertificate(secretKey2) assert.Nil(t, err) - for _, test := range []struct { + for _, testcase := range []struct { name string init func() (*PeerConnection, error) config Configuration @@ -266,14 +266,14 @@ func TestPeerConnection_SetConfiguration_Go(t *testing.T) { wantErr: &rtcerr.InvalidAccessError{Err: ErrNoTurnCredentials}, }, } { - pc, err := test.init() + pc, err := testcase.init() if err != nil { - t.Errorf("SetConfiguration %q: init failed: %v", test.name, err) + t.Errorf("SetConfiguration %q: init failed: %v", testcase.name, err) } - err = pc.SetConfiguration(test.config) - if got, want := err, test.wantErr; !reflect.DeepEqual(got, want) { - t.Errorf("SetConfiguration %q: err = %v, want %v", test.name, got, want) + err = pc.SetConfiguration(testcase.config) + if got, want := err, testcase.wantErr; !reflect.DeepEqual(got, want) { + t.Errorf("SetConfiguration %q: err = %v, want %v", testcase.name, got, want) } assert.NoError(t, pc.Close()) @@ -446,14 +446,7 @@ func TestPeerConnection_AnswerWithClosedConnection(t *testing.T) { } func TestPeerConnection_satisfyTypeAndDirection(t *testing.T) { - createTransceiver := func(kind RTPCodecType, direction RTPTransceiverDirection) *RTPTransceiver { - r := &RTPTransceiver{kind: kind} - r.setDirection(direction) - - return r - } - - for _, test := range []struct { + for _, testcase := range []struct { name string kinds []RTPCodecType @@ -466,7 +459,7 @@ func TestPeerConnection_satisfyTypeAndDirection(t *testing.T) { "Audio and Video Transceivers can not satisfy each other", []RTPCodecType{RTPCodecTypeVideo}, []RTPTransceiverDirection{RTPTransceiverDirectionSendrecv}, - []*RTPTransceiver{createTransceiver(RTPCodecTypeAudio, RTPTransceiverDirectionSendrecv)}, + []*RTPTransceiver{newRTPTransceiver(nil, nil, RTPTransceiverDirectionSendrecv, RTPCodecTypeAudio, nil)}, []*RTPTransceiver{nil}, }, { @@ -488,9 +481,9 @@ func TestPeerConnection_satisfyTypeAndDirection(t *testing.T) { []RTPCodecType{RTPCodecTypeVideo}, []RTPTransceiverDirection{RTPTransceiverDirectionSendrecv}, - []*RTPTransceiver{createTransceiver(RTPCodecTypeVideo, RTPTransceiverDirectionRecvonly)}, + []*RTPTransceiver{newRTPTransceiver(nil, nil, RTPTransceiverDirectionRecvonly, RTPCodecTypeVideo, nil)}, - []*RTPTransceiver{createTransceiver(RTPCodecTypeVideo, RTPTransceiverDirectionRecvonly)}, + []*RTPTransceiver{newRTPTransceiver(nil, nil, RTPTransceiverDirectionRecvonly, RTPCodecTypeVideo, nil)}, }, { "Don't satisfy a Sendonly with a SendRecv, later SendRecv will be marked as Inactive", @@ -498,39 +491,39 @@ func TestPeerConnection_satisfyTypeAndDirection(t *testing.T) { []RTPTransceiverDirection{RTPTransceiverDirectionSendonly, RTPTransceiverDirectionSendrecv}, []*RTPTransceiver{ - createTransceiver(RTPCodecTypeVideo, RTPTransceiverDirectionSendrecv), - createTransceiver(RTPCodecTypeVideo, RTPTransceiverDirectionRecvonly), + newRTPTransceiver(nil, nil, RTPTransceiverDirectionSendrecv, RTPCodecTypeVideo, nil), + newRTPTransceiver(nil, nil, RTPTransceiverDirectionRecvonly, RTPCodecTypeVideo, nil), }, []*RTPTransceiver{ - createTransceiver(RTPCodecTypeVideo, RTPTransceiverDirectionRecvonly), - createTransceiver(RTPCodecTypeVideo, RTPTransceiverDirectionSendrecv), + newRTPTransceiver(nil, nil, RTPTransceiverDirectionRecvonly, RTPCodecTypeVideo, nil), + newRTPTransceiver(nil, nil, RTPTransceiverDirectionSendrecv, RTPCodecTypeVideo, nil), }, }, } { - if len(test.kinds) != len(test.directions) { + if len(testcase.kinds) != len(testcase.directions) { t.Fatal("Kinds and Directions must be the same length") } got := []*RTPTransceiver{} - for i := range test.kinds { - res, filteredLocalTransceivers := satisfyTypeAndDirection(test.kinds[i], test.directions[i], test.localTransceivers) + for i := range testcase.kinds { + res, filteredLocalTransceivers := satisfyTypeAndDirection(testcase.kinds[i], testcase.directions[i], testcase.localTransceivers) got = append(got, res) - test.localTransceivers = filteredLocalTransceivers + testcase.localTransceivers = filteredLocalTransceivers } - if !reflect.DeepEqual(got, test.want) { + if !reflect.DeepEqual(got, testcase.want) { gotStr := "" for _, t := range got { gotStr += fmt.Sprintf("%+v\n", t) } wantStr := "" - for _, t := range test.want { + for _, t := range testcase.want { wantStr += fmt.Sprintf("%+v\n", t) } - t.Errorf("satisfyTypeAndDirection %q: \ngot\n%s \nwant\n%s", test.name, gotStr, wantStr) + t.Errorf("satisfyTypeAndDirection %q: \ngot\n%s \nwant\n%s", testcase.name, gotStr, wantStr) } } } @@ -1258,7 +1251,7 @@ func TestPeerConnection_TransceiverDirection(t *testing.T) { return err } - for _, test := range []struct { + for _, testcase := range []struct { name string offerDirection RTPTransceiverDirection answerStartDirection RTPTransceiverDirection @@ -1319,11 +1312,11 @@ func TestPeerConnection_TransceiverDirection(t *testing.T) { []RTPTransceiverDirection{RTPTransceiverDirectionRecvonly, RTPTransceiverDirectionSendonly}, }, } { - offerDirection := test.offerDirection - answerStartDirection := test.answerStartDirection - answerFinalDirections := test.answerFinalDirections + offerDirection := testcase.offerDirection + answerStartDirection := testcase.answerStartDirection + answerFinalDirections := testcase.answerFinalDirections - t.Run(test.name, func(t *testing.T) { + t.Run(testcase.name, func(t *testing.T) { pcOffer, pcAnswer, err := newPair() assert.NoError(t, err) @@ -1433,3 +1426,34 @@ func TestPeerConnectionNilCallback(t *testing.T) { assert.NoError(t, pc.Close()) } + +func TestPeerConnection_SkipStoppedTransceiver(t *testing.T) { + defer test.TimeOut(time.Second).Stop() + + pc, err := NewPeerConnection(Configuration{}) + assert.NoError(t, err) + + track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/vp8"}, "video1", "pion") + assert.NoError(t, err) + + transceiver, err := pc.AddTransceiverFromTrack(track) + assert.NoError(t, err) + assert.Equal(t, 1, len(pc.GetTransceivers())) + assert.NoError(t, pc.RemoveTrack(transceiver.Sender())) + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + assert.NoError(t, transceiver.Stop()) + }() // no error, no panic + } + wg.Wait() + track, err = NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/vp8"}, "video2", "pion") + assert.NoError(t, err) + _, err = pc.AddTrack(track) // should not use the above stopped transceiver + assert.NoError(t, err) + assert.Equal(t, 2, len(pc.GetTransceivers())) + + assert.NoError(t, pc.Close()) +} diff --git a/rtptransceiver.go b/rtptransceiver.go index c4d9b0ac..d596f43e 100644 --- a/rtptransceiver.go +++ b/rtptransceiver.go @@ -19,7 +19,7 @@ type RTPTransceiver struct { codecs []RTPCodecParameters // User provided codecs via SetCodecPreferences - stopped bool + stopped atomicBool kind RTPCodecType api *API @@ -141,21 +141,26 @@ func (t *RTPTransceiver) Direction() RTPTransceiverDirection { // Stop irreversibly stops the RTPTransceiver func (t *RTPTransceiver) Stop() error { - if sender := t.Sender(); sender != nil { - if err := sender.Stop(); err != nil { - return err + if t.stopped.compareAndSwap(false, true) { + if sender := t.Sender(); sender != nil { + if err := sender.Stop(); err != nil { + return err + } } - } - if receiver := t.Receiver(); receiver != nil { - if err := receiver.Stop(); err != nil { - return err - } - } + if receiver := t.Receiver(); receiver != nil { + if err := receiver.Stop(); err != nil { + return err + } - t.setDirection(RTPTransceiverDirectionInactive) + t.setDirection(RTPTransceiverDirectionInactive) + } + } return nil } +// Stopped indicates whether or not RTPTransceiver has been stopped +func (t *RTPTransceiver) Stopped() bool { return t.stopped.get() } + func (t *RTPTransceiver) setReceiver(r *RTPReceiver) { if r != nil { r.setRTPTransceiver(t) diff --git a/sdp_test.go b/sdp_test.go index 1dde04d9..4b5b1444 100644 --- a/sdp_test.go +++ b/sdp_test.go @@ -375,8 +375,7 @@ func TestPopulateSDP(t *testing.T) { assert.NoError(t, me.RegisterDefaultCodecs()) api := NewAPI(WithMediaEngine(me)) - tr := &RTPTransceiver{kind: RTPCodecTypeVideo, api: api, codecs: me.videoCodecs} - tr.setDirection(RTPTransceiverDirectionRecvonly) + tr := newRTPTransceiver(nil, nil, RTPTransceiverDirectionRecvonly, RTPCodecTypeVideo, api) ridMap := map[string]string{ "ridkey": "some", }