Make RTPTransceiver Stopped an atomic

Add an accessor to make getting value easy. Also add
TestPeerConnection_SkipStoppedTransceiver. This commit also cleans
up RTPTransceiver creation. We used a helper function, when we should
have just used the provide constructor
This commit is contained in:
Sean DuBois
2021-09-03 22:00:40 -04:00
parent 294595aff5
commit 6c3620093d
5 changed files with 97 additions and 59 deletions

View File

@@ -12,9 +12,20 @@ func (b *atomicBool) set(value bool) { // nolint: unparam
i = 1
}
atomic.StoreInt32(&(b.val), i)
atomic.StoreInt32(&b.val, i)
}
func (b *atomicBool) get() bool {
return atomic.LoadInt32(&(b.val)) != 0
return atomic.LoadInt32(&b.val) != 0
}
func (b *atomicBool) compareAndSwap(old, new bool) (swapped bool) {
var oldval, newval int32
if old {
oldval = 1
}
if new {
newval = 1
}
return atomic.CompareAndSwapInt32(&b.val, oldval, newval)
}

View File

@@ -370,15 +370,15 @@ func (pc *PeerConnection) checkNegotiationNeeded() bool { //nolint:gocognit
for _, t := range pc.rtpTransceivers {
// https://www.w3.org/TR/webrtc/#dfn-update-the-negotiation-needed-flag
// Step 5.1
// if t.stopping && !t.stopped {
// if t.stopping && !t.Stopped() {
// return true
// }
m := getByMid(t.Mid(), localDesc)
// Step 5.2
if !t.stopped && m == nil {
if !t.Stopped() && m == nil {
return true
}
if !t.stopped && m != nil {
if !t.Stopped() && m != nil {
// Step 5.3.1
if t.Direction() == RTPTransceiverDirectionSendrecv || t.Direction() == RTPTransceiverDirectionSendonly {
descMsid, okMsid := m.Attribute(sdp.AttrKeyMsid)
@@ -407,7 +407,7 @@ func (pc *PeerConnection) checkNegotiationNeeded() bool { //nolint:gocognit
}
}
// Step 5.4
if t.stopped && t.Mid() != "" {
if t.Stopped() && t.Mid() != "" {
if getByMid(t.Mid(), localDesc) != nil || getByMid(t.Mid(), remoteDesc) != nil {
return true
}
@@ -1250,7 +1250,7 @@ func (pc *PeerConnection) startRTPReceivers(incomingTracks []trackDetails, curre
}
receiver := t.Receiver()
if (incomingTrack.kind != t.kind) ||
if (incomingTrack.kind != t.Kind()) ||
(t.Direction() != RTPTransceiverDirectionRecvonly && t.Direction() != RTPTransceiverDirectionSendrecv) ||
receiver == nil ||
(receiver.haveReceived()) {
@@ -1592,7 +1592,7 @@ func (pc *PeerConnection) AddTrack(track TrackLocal) (*RTPSender, error) {
pc.mu.Lock()
defer pc.mu.Unlock()
for _, t := range pc.rtpTransceivers {
if !t.stopped && t.kind == track.Kind() && t.Sender() == nil {
if !t.Stopped() && t.Kind() == track.Kind() && t.Sender() == nil {
sender, err := pc.api.NewRTPSender(track, pc.dtlsTransport)
if err == nil {
err = t.SetSender(sender, track)
@@ -1853,7 +1853,7 @@ func (pc *PeerConnection) Close() error {
// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #4)
pc.mu.Lock()
for _, t := range pc.rtpTransceivers {
if !t.stopped {
if !t.Stopped() {
closeErrs = append(closeErrs, t.Stop())
}
}
@@ -2157,9 +2157,9 @@ func (pc *PeerConnection) generateUnmatchedSDP(transceivers []*RTPTransceiver, u
audio := make([]*RTPTransceiver, 0)
for _, t := range transceivers {
if t.kind == RTPCodecTypeVideo {
if t.Kind() == RTPCodecTypeVideo {
video = append(video, t)
} else if t.kind == RTPCodecTypeAudio {
} else if t.Kind() == RTPCodecTypeAudio {
audio = append(audio, t)
}
if sender := t.Sender(); sender != nil {
@@ -2259,8 +2259,7 @@ func (pc *PeerConnection) generateMatchedSDP(transceivers []*RTPTransceiver, use
t, localTransceivers = satisfyTypeAndDirection(kind, direction, localTransceivers)
if t == nil {
if len(mediaTransceivers) == 0 {
t = &RTPTransceiver{kind: kind, api: pc.api, codecs: pc.api.mediaEngine.getCodecsByKind(kind)}
t.setDirection(RTPTransceiverDirectionInactive)
t = newRTPTransceiver(nil, nil, RTPTransceiverDirectionInactive, kind, pc.api)
mediaTransceivers = append(mediaTransceivers, t)
}
break

View File

@@ -178,7 +178,7 @@ func TestPeerConnection_SetConfiguration_Go(t *testing.T) {
certificate2, err := GenerateCertificate(secretKey2)
assert.Nil(t, err)
for _, test := range []struct {
for _, testcase := range []struct {
name string
init func() (*PeerConnection, error)
config Configuration
@@ -266,14 +266,14 @@ func TestPeerConnection_SetConfiguration_Go(t *testing.T) {
wantErr: &rtcerr.InvalidAccessError{Err: ErrNoTurnCredentials},
},
} {
pc, err := test.init()
pc, err := testcase.init()
if err != nil {
t.Errorf("SetConfiguration %q: init failed: %v", test.name, err)
t.Errorf("SetConfiguration %q: init failed: %v", testcase.name, err)
}
err = pc.SetConfiguration(test.config)
if got, want := err, test.wantErr; !reflect.DeepEqual(got, want) {
t.Errorf("SetConfiguration %q: err = %v, want %v", test.name, got, want)
err = pc.SetConfiguration(testcase.config)
if got, want := err, testcase.wantErr; !reflect.DeepEqual(got, want) {
t.Errorf("SetConfiguration %q: err = %v, want %v", testcase.name, got, want)
}
assert.NoError(t, pc.Close())
@@ -446,14 +446,7 @@ func TestPeerConnection_AnswerWithClosedConnection(t *testing.T) {
}
func TestPeerConnection_satisfyTypeAndDirection(t *testing.T) {
createTransceiver := func(kind RTPCodecType, direction RTPTransceiverDirection) *RTPTransceiver {
r := &RTPTransceiver{kind: kind}
r.setDirection(direction)
return r
}
for _, test := range []struct {
for _, testcase := range []struct {
name string
kinds []RTPCodecType
@@ -466,7 +459,7 @@ func TestPeerConnection_satisfyTypeAndDirection(t *testing.T) {
"Audio and Video Transceivers can not satisfy each other",
[]RTPCodecType{RTPCodecTypeVideo},
[]RTPTransceiverDirection{RTPTransceiverDirectionSendrecv},
[]*RTPTransceiver{createTransceiver(RTPCodecTypeAudio, RTPTransceiverDirectionSendrecv)},
[]*RTPTransceiver{newRTPTransceiver(nil, nil, RTPTransceiverDirectionSendrecv, RTPCodecTypeAudio, nil)},
[]*RTPTransceiver{nil},
},
{
@@ -488,9 +481,9 @@ func TestPeerConnection_satisfyTypeAndDirection(t *testing.T) {
[]RTPCodecType{RTPCodecTypeVideo},
[]RTPTransceiverDirection{RTPTransceiverDirectionSendrecv},
[]*RTPTransceiver{createTransceiver(RTPCodecTypeVideo, RTPTransceiverDirectionRecvonly)},
[]*RTPTransceiver{newRTPTransceiver(nil, nil, RTPTransceiverDirectionRecvonly, RTPCodecTypeVideo, nil)},
[]*RTPTransceiver{createTransceiver(RTPCodecTypeVideo, RTPTransceiverDirectionRecvonly)},
[]*RTPTransceiver{newRTPTransceiver(nil, nil, RTPTransceiverDirectionRecvonly, RTPCodecTypeVideo, nil)},
},
{
"Don't satisfy a Sendonly with a SendRecv, later SendRecv will be marked as Inactive",
@@ -498,39 +491,39 @@ func TestPeerConnection_satisfyTypeAndDirection(t *testing.T) {
[]RTPTransceiverDirection{RTPTransceiverDirectionSendonly, RTPTransceiverDirectionSendrecv},
[]*RTPTransceiver{
createTransceiver(RTPCodecTypeVideo, RTPTransceiverDirectionSendrecv),
createTransceiver(RTPCodecTypeVideo, RTPTransceiverDirectionRecvonly),
newRTPTransceiver(nil, nil, RTPTransceiverDirectionSendrecv, RTPCodecTypeVideo, nil),
newRTPTransceiver(nil, nil, RTPTransceiverDirectionRecvonly, RTPCodecTypeVideo, nil),
},
[]*RTPTransceiver{
createTransceiver(RTPCodecTypeVideo, RTPTransceiverDirectionRecvonly),
createTransceiver(RTPCodecTypeVideo, RTPTransceiverDirectionSendrecv),
newRTPTransceiver(nil, nil, RTPTransceiverDirectionRecvonly, RTPCodecTypeVideo, nil),
newRTPTransceiver(nil, nil, RTPTransceiverDirectionSendrecv, RTPCodecTypeVideo, nil),
},
},
} {
if len(test.kinds) != len(test.directions) {
if len(testcase.kinds) != len(testcase.directions) {
t.Fatal("Kinds and Directions must be the same length")
}
got := []*RTPTransceiver{}
for i := range test.kinds {
res, filteredLocalTransceivers := satisfyTypeAndDirection(test.kinds[i], test.directions[i], test.localTransceivers)
for i := range testcase.kinds {
res, filteredLocalTransceivers := satisfyTypeAndDirection(testcase.kinds[i], testcase.directions[i], testcase.localTransceivers)
got = append(got, res)
test.localTransceivers = filteredLocalTransceivers
testcase.localTransceivers = filteredLocalTransceivers
}
if !reflect.DeepEqual(got, test.want) {
if !reflect.DeepEqual(got, testcase.want) {
gotStr := ""
for _, t := range got {
gotStr += fmt.Sprintf("%+v\n", t)
}
wantStr := ""
for _, t := range test.want {
for _, t := range testcase.want {
wantStr += fmt.Sprintf("%+v\n", t)
}
t.Errorf("satisfyTypeAndDirection %q: \ngot\n%s \nwant\n%s", test.name, gotStr, wantStr)
t.Errorf("satisfyTypeAndDirection %q: \ngot\n%s \nwant\n%s", testcase.name, gotStr, wantStr)
}
}
}
@@ -1258,7 +1251,7 @@ func TestPeerConnection_TransceiverDirection(t *testing.T) {
return err
}
for _, test := range []struct {
for _, testcase := range []struct {
name string
offerDirection RTPTransceiverDirection
answerStartDirection RTPTransceiverDirection
@@ -1319,11 +1312,11 @@ func TestPeerConnection_TransceiverDirection(t *testing.T) {
[]RTPTransceiverDirection{RTPTransceiverDirectionRecvonly, RTPTransceiverDirectionSendonly},
},
} {
offerDirection := test.offerDirection
answerStartDirection := test.answerStartDirection
answerFinalDirections := test.answerFinalDirections
offerDirection := testcase.offerDirection
answerStartDirection := testcase.answerStartDirection
answerFinalDirections := testcase.answerFinalDirections
t.Run(test.name, func(t *testing.T) {
t.Run(testcase.name, func(t *testing.T) {
pcOffer, pcAnswer, err := newPair()
assert.NoError(t, err)
@@ -1433,3 +1426,34 @@ func TestPeerConnectionNilCallback(t *testing.T) {
assert.NoError(t, pc.Close())
}
func TestPeerConnection_SkipStoppedTransceiver(t *testing.T) {
defer test.TimeOut(time.Second).Stop()
pc, err := NewPeerConnection(Configuration{})
assert.NoError(t, err)
track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/vp8"}, "video1", "pion")
assert.NoError(t, err)
transceiver, err := pc.AddTransceiverFromTrack(track)
assert.NoError(t, err)
assert.Equal(t, 1, len(pc.GetTransceivers()))
assert.NoError(t, pc.RemoveTrack(transceiver.Sender()))
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
assert.NoError(t, transceiver.Stop())
}() // no error, no panic
}
wg.Wait()
track, err = NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/vp8"}, "video2", "pion")
assert.NoError(t, err)
_, err = pc.AddTrack(track) // should not use the above stopped transceiver
assert.NoError(t, err)
assert.Equal(t, 2, len(pc.GetTransceivers()))
assert.NoError(t, pc.Close())
}

View File

@@ -19,7 +19,7 @@ type RTPTransceiver struct {
codecs []RTPCodecParameters // User provided codecs via SetCodecPreferences
stopped bool
stopped atomicBool
kind RTPCodecType
api *API
@@ -141,21 +141,26 @@ func (t *RTPTransceiver) Direction() RTPTransceiverDirection {
// Stop irreversibly stops the RTPTransceiver
func (t *RTPTransceiver) Stop() error {
if sender := t.Sender(); sender != nil {
if err := sender.Stop(); err != nil {
return err
if t.stopped.compareAndSwap(false, true) {
if sender := t.Sender(); sender != nil {
if err := sender.Stop(); err != nil {
return err
}
}
}
if receiver := t.Receiver(); receiver != nil {
if err := receiver.Stop(); err != nil {
return err
}
}
if receiver := t.Receiver(); receiver != nil {
if err := receiver.Stop(); err != nil {
return err
}
t.setDirection(RTPTransceiverDirectionInactive)
t.setDirection(RTPTransceiverDirectionInactive)
}
}
return nil
}
// Stopped indicates whether or not RTPTransceiver has been stopped
func (t *RTPTransceiver) Stopped() bool { return t.stopped.get() }
func (t *RTPTransceiver) setReceiver(r *RTPReceiver) {
if r != nil {
r.setRTPTransceiver(t)

View File

@@ -375,8 +375,7 @@ func TestPopulateSDP(t *testing.T) {
assert.NoError(t, me.RegisterDefaultCodecs())
api := NewAPI(WithMediaEngine(me))
tr := &RTPTransceiver{kind: RTPCodecTypeVideo, api: api, codecs: me.videoCodecs}
tr.setDirection(RTPTransceiverDirectionRecvonly)
tr := newRTPTransceiver(nil, nil, RTPTransceiverDirectionRecvonly, RTPCodecTypeVideo, api)
ridMap := map[string]string{
"ridkey": "some",
}