Make RTPTransceiver Stopped an atomic

Add an accessor to make getting value easy. Also add
TestPeerConnection_SkipStoppedTransceiver. This commit also cleans
up RTPTransceiver creation. We used a helper function, when we should
have just used the provide constructor
This commit is contained in:
Sean DuBois
2021-09-03 22:00:40 -04:00
parent 294595aff5
commit 6c3620093d
5 changed files with 97 additions and 59 deletions

View File

@@ -12,9 +12,20 @@ func (b *atomicBool) set(value bool) { // nolint: unparam
i = 1 i = 1
} }
atomic.StoreInt32(&(b.val), i) atomic.StoreInt32(&b.val, i)
} }
func (b *atomicBool) get() bool { func (b *atomicBool) get() bool {
return atomic.LoadInt32(&(b.val)) != 0 return atomic.LoadInt32(&b.val) != 0
}
func (b *atomicBool) compareAndSwap(old, new bool) (swapped bool) {
var oldval, newval int32
if old {
oldval = 1
}
if new {
newval = 1
}
return atomic.CompareAndSwapInt32(&b.val, oldval, newval)
} }

View File

@@ -370,15 +370,15 @@ func (pc *PeerConnection) checkNegotiationNeeded() bool { //nolint:gocognit
for _, t := range pc.rtpTransceivers { for _, t := range pc.rtpTransceivers {
// https://www.w3.org/TR/webrtc/#dfn-update-the-negotiation-needed-flag // https://www.w3.org/TR/webrtc/#dfn-update-the-negotiation-needed-flag
// Step 5.1 // Step 5.1
// if t.stopping && !t.stopped { // if t.stopping && !t.Stopped() {
// return true // return true
// } // }
m := getByMid(t.Mid(), localDesc) m := getByMid(t.Mid(), localDesc)
// Step 5.2 // Step 5.2
if !t.stopped && m == nil { if !t.Stopped() && m == nil {
return true return true
} }
if !t.stopped && m != nil { if !t.Stopped() && m != nil {
// Step 5.3.1 // Step 5.3.1
if t.Direction() == RTPTransceiverDirectionSendrecv || t.Direction() == RTPTransceiverDirectionSendonly { if t.Direction() == RTPTransceiverDirectionSendrecv || t.Direction() == RTPTransceiverDirectionSendonly {
descMsid, okMsid := m.Attribute(sdp.AttrKeyMsid) descMsid, okMsid := m.Attribute(sdp.AttrKeyMsid)
@@ -407,7 +407,7 @@ func (pc *PeerConnection) checkNegotiationNeeded() bool { //nolint:gocognit
} }
} }
// Step 5.4 // Step 5.4
if t.stopped && t.Mid() != "" { if t.Stopped() && t.Mid() != "" {
if getByMid(t.Mid(), localDesc) != nil || getByMid(t.Mid(), remoteDesc) != nil { if getByMid(t.Mid(), localDesc) != nil || getByMid(t.Mid(), remoteDesc) != nil {
return true return true
} }
@@ -1250,7 +1250,7 @@ func (pc *PeerConnection) startRTPReceivers(incomingTracks []trackDetails, curre
} }
receiver := t.Receiver() receiver := t.Receiver()
if (incomingTrack.kind != t.kind) || if (incomingTrack.kind != t.Kind()) ||
(t.Direction() != RTPTransceiverDirectionRecvonly && t.Direction() != RTPTransceiverDirectionSendrecv) || (t.Direction() != RTPTransceiverDirectionRecvonly && t.Direction() != RTPTransceiverDirectionSendrecv) ||
receiver == nil || receiver == nil ||
(receiver.haveReceived()) { (receiver.haveReceived()) {
@@ -1592,7 +1592,7 @@ func (pc *PeerConnection) AddTrack(track TrackLocal) (*RTPSender, error) {
pc.mu.Lock() pc.mu.Lock()
defer pc.mu.Unlock() defer pc.mu.Unlock()
for _, t := range pc.rtpTransceivers { for _, t := range pc.rtpTransceivers {
if !t.stopped && t.kind == track.Kind() && t.Sender() == nil { if !t.Stopped() && t.Kind() == track.Kind() && t.Sender() == nil {
sender, err := pc.api.NewRTPSender(track, pc.dtlsTransport) sender, err := pc.api.NewRTPSender(track, pc.dtlsTransport)
if err == nil { if err == nil {
err = t.SetSender(sender, track) err = t.SetSender(sender, track)
@@ -1853,7 +1853,7 @@ func (pc *PeerConnection) Close() error {
// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #4) // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #4)
pc.mu.Lock() pc.mu.Lock()
for _, t := range pc.rtpTransceivers { for _, t := range pc.rtpTransceivers {
if !t.stopped { if !t.Stopped() {
closeErrs = append(closeErrs, t.Stop()) closeErrs = append(closeErrs, t.Stop())
} }
} }
@@ -2157,9 +2157,9 @@ func (pc *PeerConnection) generateUnmatchedSDP(transceivers []*RTPTransceiver, u
audio := make([]*RTPTransceiver, 0) audio := make([]*RTPTransceiver, 0)
for _, t := range transceivers { for _, t := range transceivers {
if t.kind == RTPCodecTypeVideo { if t.Kind() == RTPCodecTypeVideo {
video = append(video, t) video = append(video, t)
} else if t.kind == RTPCodecTypeAudio { } else if t.Kind() == RTPCodecTypeAudio {
audio = append(audio, t) audio = append(audio, t)
} }
if sender := t.Sender(); sender != nil { if sender := t.Sender(); sender != nil {
@@ -2259,8 +2259,7 @@ func (pc *PeerConnection) generateMatchedSDP(transceivers []*RTPTransceiver, use
t, localTransceivers = satisfyTypeAndDirection(kind, direction, localTransceivers) t, localTransceivers = satisfyTypeAndDirection(kind, direction, localTransceivers)
if t == nil { if t == nil {
if len(mediaTransceivers) == 0 { if len(mediaTransceivers) == 0 {
t = &RTPTransceiver{kind: kind, api: pc.api, codecs: pc.api.mediaEngine.getCodecsByKind(kind)} t = newRTPTransceiver(nil, nil, RTPTransceiverDirectionInactive, kind, pc.api)
t.setDirection(RTPTransceiverDirectionInactive)
mediaTransceivers = append(mediaTransceivers, t) mediaTransceivers = append(mediaTransceivers, t)
} }
break break

View File

@@ -178,7 +178,7 @@ func TestPeerConnection_SetConfiguration_Go(t *testing.T) {
certificate2, err := GenerateCertificate(secretKey2) certificate2, err := GenerateCertificate(secretKey2)
assert.Nil(t, err) assert.Nil(t, err)
for _, test := range []struct { for _, testcase := range []struct {
name string name string
init func() (*PeerConnection, error) init func() (*PeerConnection, error)
config Configuration config Configuration
@@ -266,14 +266,14 @@ func TestPeerConnection_SetConfiguration_Go(t *testing.T) {
wantErr: &rtcerr.InvalidAccessError{Err: ErrNoTurnCredentials}, wantErr: &rtcerr.InvalidAccessError{Err: ErrNoTurnCredentials},
}, },
} { } {
pc, err := test.init() pc, err := testcase.init()
if err != nil { if err != nil {
t.Errorf("SetConfiguration %q: init failed: %v", test.name, err) t.Errorf("SetConfiguration %q: init failed: %v", testcase.name, err)
} }
err = pc.SetConfiguration(test.config) err = pc.SetConfiguration(testcase.config)
if got, want := err, test.wantErr; !reflect.DeepEqual(got, want) { if got, want := err, testcase.wantErr; !reflect.DeepEqual(got, want) {
t.Errorf("SetConfiguration %q: err = %v, want %v", test.name, got, want) t.Errorf("SetConfiguration %q: err = %v, want %v", testcase.name, got, want)
} }
assert.NoError(t, pc.Close()) assert.NoError(t, pc.Close())
@@ -446,14 +446,7 @@ func TestPeerConnection_AnswerWithClosedConnection(t *testing.T) {
} }
func TestPeerConnection_satisfyTypeAndDirection(t *testing.T) { func TestPeerConnection_satisfyTypeAndDirection(t *testing.T) {
createTransceiver := func(kind RTPCodecType, direction RTPTransceiverDirection) *RTPTransceiver { for _, testcase := range []struct {
r := &RTPTransceiver{kind: kind}
r.setDirection(direction)
return r
}
for _, test := range []struct {
name string name string
kinds []RTPCodecType kinds []RTPCodecType
@@ -466,7 +459,7 @@ func TestPeerConnection_satisfyTypeAndDirection(t *testing.T) {
"Audio and Video Transceivers can not satisfy each other", "Audio and Video Transceivers can not satisfy each other",
[]RTPCodecType{RTPCodecTypeVideo}, []RTPCodecType{RTPCodecTypeVideo},
[]RTPTransceiverDirection{RTPTransceiverDirectionSendrecv}, []RTPTransceiverDirection{RTPTransceiverDirectionSendrecv},
[]*RTPTransceiver{createTransceiver(RTPCodecTypeAudio, RTPTransceiverDirectionSendrecv)}, []*RTPTransceiver{newRTPTransceiver(nil, nil, RTPTransceiverDirectionSendrecv, RTPCodecTypeAudio, nil)},
[]*RTPTransceiver{nil}, []*RTPTransceiver{nil},
}, },
{ {
@@ -488,9 +481,9 @@ func TestPeerConnection_satisfyTypeAndDirection(t *testing.T) {
[]RTPCodecType{RTPCodecTypeVideo}, []RTPCodecType{RTPCodecTypeVideo},
[]RTPTransceiverDirection{RTPTransceiverDirectionSendrecv}, []RTPTransceiverDirection{RTPTransceiverDirectionSendrecv},
[]*RTPTransceiver{createTransceiver(RTPCodecTypeVideo, RTPTransceiverDirectionRecvonly)}, []*RTPTransceiver{newRTPTransceiver(nil, nil, RTPTransceiverDirectionRecvonly, RTPCodecTypeVideo, nil)},
[]*RTPTransceiver{createTransceiver(RTPCodecTypeVideo, RTPTransceiverDirectionRecvonly)}, []*RTPTransceiver{newRTPTransceiver(nil, nil, RTPTransceiverDirectionRecvonly, RTPCodecTypeVideo, nil)},
}, },
{ {
"Don't satisfy a Sendonly with a SendRecv, later SendRecv will be marked as Inactive", "Don't satisfy a Sendonly with a SendRecv, later SendRecv will be marked as Inactive",
@@ -498,39 +491,39 @@ func TestPeerConnection_satisfyTypeAndDirection(t *testing.T) {
[]RTPTransceiverDirection{RTPTransceiverDirectionSendonly, RTPTransceiverDirectionSendrecv}, []RTPTransceiverDirection{RTPTransceiverDirectionSendonly, RTPTransceiverDirectionSendrecv},
[]*RTPTransceiver{ []*RTPTransceiver{
createTransceiver(RTPCodecTypeVideo, RTPTransceiverDirectionSendrecv), newRTPTransceiver(nil, nil, RTPTransceiverDirectionSendrecv, RTPCodecTypeVideo, nil),
createTransceiver(RTPCodecTypeVideo, RTPTransceiverDirectionRecvonly), newRTPTransceiver(nil, nil, RTPTransceiverDirectionRecvonly, RTPCodecTypeVideo, nil),
}, },
[]*RTPTransceiver{ []*RTPTransceiver{
createTransceiver(RTPCodecTypeVideo, RTPTransceiverDirectionRecvonly), newRTPTransceiver(nil, nil, RTPTransceiverDirectionRecvonly, RTPCodecTypeVideo, nil),
createTransceiver(RTPCodecTypeVideo, RTPTransceiverDirectionSendrecv), newRTPTransceiver(nil, nil, RTPTransceiverDirectionSendrecv, RTPCodecTypeVideo, nil),
}, },
}, },
} { } {
if len(test.kinds) != len(test.directions) { if len(testcase.kinds) != len(testcase.directions) {
t.Fatal("Kinds and Directions must be the same length") t.Fatal("Kinds and Directions must be the same length")
} }
got := []*RTPTransceiver{} got := []*RTPTransceiver{}
for i := range test.kinds { for i := range testcase.kinds {
res, filteredLocalTransceivers := satisfyTypeAndDirection(test.kinds[i], test.directions[i], test.localTransceivers) res, filteredLocalTransceivers := satisfyTypeAndDirection(testcase.kinds[i], testcase.directions[i], testcase.localTransceivers)
got = append(got, res) got = append(got, res)
test.localTransceivers = filteredLocalTransceivers testcase.localTransceivers = filteredLocalTransceivers
} }
if !reflect.DeepEqual(got, test.want) { if !reflect.DeepEqual(got, testcase.want) {
gotStr := "" gotStr := ""
for _, t := range got { for _, t := range got {
gotStr += fmt.Sprintf("%+v\n", t) gotStr += fmt.Sprintf("%+v\n", t)
} }
wantStr := "" wantStr := ""
for _, t := range test.want { for _, t := range testcase.want {
wantStr += fmt.Sprintf("%+v\n", t) wantStr += fmt.Sprintf("%+v\n", t)
} }
t.Errorf("satisfyTypeAndDirection %q: \ngot\n%s \nwant\n%s", test.name, gotStr, wantStr) t.Errorf("satisfyTypeAndDirection %q: \ngot\n%s \nwant\n%s", testcase.name, gotStr, wantStr)
} }
} }
} }
@@ -1258,7 +1251,7 @@ func TestPeerConnection_TransceiverDirection(t *testing.T) {
return err return err
} }
for _, test := range []struct { for _, testcase := range []struct {
name string name string
offerDirection RTPTransceiverDirection offerDirection RTPTransceiverDirection
answerStartDirection RTPTransceiverDirection answerStartDirection RTPTransceiverDirection
@@ -1319,11 +1312,11 @@ func TestPeerConnection_TransceiverDirection(t *testing.T) {
[]RTPTransceiverDirection{RTPTransceiverDirectionRecvonly, RTPTransceiverDirectionSendonly}, []RTPTransceiverDirection{RTPTransceiverDirectionRecvonly, RTPTransceiverDirectionSendonly},
}, },
} { } {
offerDirection := test.offerDirection offerDirection := testcase.offerDirection
answerStartDirection := test.answerStartDirection answerStartDirection := testcase.answerStartDirection
answerFinalDirections := test.answerFinalDirections answerFinalDirections := testcase.answerFinalDirections
t.Run(test.name, func(t *testing.T) { t.Run(testcase.name, func(t *testing.T) {
pcOffer, pcAnswer, err := newPair() pcOffer, pcAnswer, err := newPair()
assert.NoError(t, err) assert.NoError(t, err)
@@ -1433,3 +1426,34 @@ func TestPeerConnectionNilCallback(t *testing.T) {
assert.NoError(t, pc.Close()) assert.NoError(t, pc.Close())
} }
func TestPeerConnection_SkipStoppedTransceiver(t *testing.T) {
defer test.TimeOut(time.Second).Stop()
pc, err := NewPeerConnection(Configuration{})
assert.NoError(t, err)
track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/vp8"}, "video1", "pion")
assert.NoError(t, err)
transceiver, err := pc.AddTransceiverFromTrack(track)
assert.NoError(t, err)
assert.Equal(t, 1, len(pc.GetTransceivers()))
assert.NoError(t, pc.RemoveTrack(transceiver.Sender()))
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
assert.NoError(t, transceiver.Stop())
}() // no error, no panic
}
wg.Wait()
track, err = NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/vp8"}, "video2", "pion")
assert.NoError(t, err)
_, err = pc.AddTrack(track) // should not use the above stopped transceiver
assert.NoError(t, err)
assert.Equal(t, 2, len(pc.GetTransceivers()))
assert.NoError(t, pc.Close())
}

View File

@@ -19,7 +19,7 @@ type RTPTransceiver struct {
codecs []RTPCodecParameters // User provided codecs via SetCodecPreferences codecs []RTPCodecParameters // User provided codecs via SetCodecPreferences
stopped bool stopped atomicBool
kind RTPCodecType kind RTPCodecType
api *API api *API
@@ -141,21 +141,26 @@ func (t *RTPTransceiver) Direction() RTPTransceiverDirection {
// Stop irreversibly stops the RTPTransceiver // Stop irreversibly stops the RTPTransceiver
func (t *RTPTransceiver) Stop() error { func (t *RTPTransceiver) Stop() error {
if sender := t.Sender(); sender != nil { if t.stopped.compareAndSwap(false, true) {
if err := sender.Stop(); err != nil { if sender := t.Sender(); sender != nil {
return err if err := sender.Stop(); err != nil {
return err
}
} }
} if receiver := t.Receiver(); receiver != nil {
if receiver := t.Receiver(); receiver != nil { if err := receiver.Stop(); err != nil {
if err := receiver.Stop(); err != nil { return err
return err }
}
}
t.setDirection(RTPTransceiverDirectionInactive) t.setDirection(RTPTransceiverDirectionInactive)
}
}
return nil return nil
} }
// Stopped indicates whether or not RTPTransceiver has been stopped
func (t *RTPTransceiver) Stopped() bool { return t.stopped.get() }
func (t *RTPTransceiver) setReceiver(r *RTPReceiver) { func (t *RTPTransceiver) setReceiver(r *RTPReceiver) {
if r != nil { if r != nil {
r.setRTPTransceiver(t) r.setRTPTransceiver(t)

View File

@@ -375,8 +375,7 @@ func TestPopulateSDP(t *testing.T) {
assert.NoError(t, me.RegisterDefaultCodecs()) assert.NoError(t, me.RegisterDefaultCodecs())
api := NewAPI(WithMediaEngine(me)) api := NewAPI(WithMediaEngine(me))
tr := &RTPTransceiver{kind: RTPCodecTypeVideo, api: api, codecs: me.videoCodecs} tr := newRTPTransceiver(nil, nil, RTPTransceiverDirectionRecvonly, RTPCodecTypeVideo, api)
tr.setDirection(RTPTransceiverDirectionRecvonly)
ridMap := map[string]string{ ridMap := map[string]string{
"ridkey": "some", "ridkey": "some",
} }