PeerConnection: more thread-safe

now proctected by lock:
	- CreateOffer
	- CreateAnswer
	- AddTransceiverFromKind
	- AddTransceiverFromTrack

newRTPTransceiver is no longer a PeerConnection method;
pc.addRTPTransceiver would fire onNegotiationNeeded;
pc.AddTrack, pc.RemoveTrack now hold lock for the entire function;

Fixes TestNegotiationNeededStressOneSided() by waiting til all
tracks added to pcA and the negotiation completed
This commit is contained in:
Markus Tzoe
2021-04-10 17:10:53 +08:00
committed by Markus
parent 9ab64a04ac
commit b6ca48ea6d
3 changed files with 111 additions and 125 deletions

View File

@@ -279,20 +279,18 @@ func (pc *PeerConnection) OnNegotiationNeeded(f func()) {
pc.onNegotiationNeededHandler.Store(f)
}
// onNegotiationNeeded enqueues negotiationNeededOp if necessary
// caller of this method should hold `pc.mu` lock
func (pc *PeerConnection) onNegotiationNeeded() {
// https://w3c.github.io/webrtc-pc/#updating-the-negotiation-needed-flag
// non-canon step 1
pc.mu.Lock()
defer pc.mu.Unlock()
if pc.negotiationNeededState == negotiationNeededStateRun {
pc.negotiationNeededState = negotiationNeededStateQueue
return
} else if pc.negotiationNeededState == negotiationNeededStateQueue {
return
}
pc.negotiationNeededState = negotiationNeededStateRun
pc.ops.Enqueue(pc.negotiationNeededOp)
}
@@ -316,11 +314,11 @@ func (pc *PeerConnection) negotiationNeededOp() {
// non-canon, run again if there was a request
defer func() {
pc.mu.Lock()
defer pc.mu.Unlock()
if pc.negotiationNeededState == negotiationNeededStateQueue {
defer pc.onNegotiationNeeded()
}
pc.negotiationNeededState = negotiationNeededStateEmpty
pc.mu.Unlock()
}()
// Step 2.3
@@ -572,10 +570,9 @@ func (pc *PeerConnection) getStatsID() string {
return pc.statsID
}
// hasLocalDescriptionChanged returns whether local media (rtpTransceivers) has changed
// caller of this method should hold `pc.mu` lock
func (pc *PeerConnection) hasLocalDescriptionChanged(desc *SessionDescription) bool {
pc.mu.Lock()
defer pc.mu.Unlock()
for _, t := range pc.rtpTransceivers {
m := getByMid(t.Mid(), desc)
if m == nil {
@@ -586,7 +583,6 @@ func (pc *PeerConnection) hasLocalDescriptionChanged(desc *SessionDescription) b
return true
}
}
return false
}
@@ -620,17 +616,19 @@ func (pc *PeerConnection) CreateOffer(options *OfferOptions) (SessionDescription
// steps to create an offer, a video RTCRtpTransceiver was added, requiring additional
// inspection of video system resources.
count := 0
pc.mu.Lock()
defer pc.mu.Unlock()
for {
// We cache current transceivers to ensure they aren't
// mutated during offer generation. We later check if they have
// been mutated and recompute the offer if necessary.
currentTransceivers := pc.GetTransceivers()
currentTransceivers := pc.rtpTransceivers
// in-parallel steps to create an offer
// https://w3c.github.io/webrtc-pc/#dfn-in-parallel-steps-to-create-an-offer
isPlanB := pc.configuration.SDPSemantics == SDPSemanticsPlanB
if pc.currentRemoteDescription != nil {
isPlanB = descriptionIsPlanB(pc.RemoteDescription())
isPlanB = descriptionIsPlanB(pc.currentRemoteDescription)
}
// include unmatched local transceivers
@@ -805,9 +803,10 @@ func (pc *PeerConnection) CreateAnswer(options *AnswerOptions) (SessionDescripti
if connectionRole == sdp.ConnectionRole(0) {
connectionRole = connectionRoleFromDtlsRole(defaultDtlsRoleAnswer)
}
pc.mu.Lock()
defer pc.mu.Unlock()
currentTransceivers := pc.GetTransceivers()
d, err := pc.generateMatchedSDP(currentTransceivers, useIdentity, false /*includeUnmatched */, connectionRole)
d, err := pc.generateMatchedSDP(pc.rtpTransceivers, useIdentity, false /*includeUnmatched */, connectionRole)
if err != nil {
return SessionDescription{}, err
}
@@ -933,7 +932,9 @@ func (pc *PeerConnection) setDescription(sd *SessionDescription, op stateChangeO
pc.signalingState.Set(nextState)
if pc.signalingState.Get() == SignalingStateStable {
pc.isNegotiationNeeded.set(false)
pc.mu.Lock()
pc.onNegotiationNeeded()
pc.mu.Unlock()
}
pc.onSignalingStateChange(nextState)
}
@@ -1063,9 +1064,10 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error {
localDirection = RTPTransceiverDirectionSendonly
}
t = pc.newRTPTransceiver(receiver, nil, localDirection, kind)
pc.onNegotiationNeeded()
t = newRTPTransceiver(receiver, nil, localDirection, kind)
pc.mu.Lock()
pc.addRTPTransceiver(t)
pc.mu.Unlock()
case direction == RTPTransceiverDirectionRecvonly:
if t.Direction() == RTPTransceiverDirectionSendrecv {
t.setDirection(RTPTransceiverDirectionSendonly)
@@ -1578,38 +1580,31 @@ func (pc *PeerConnection) AddTrack(track TrackLocal) (*RTPSender, error) {
return nil, &rtcerr.InvalidStateError{Err: ErrConnectionClosed}
}
var transceiver *RTPTransceiver
pc.mu.Lock()
defer pc.mu.Unlock()
for _, t := range pc.rtpTransceivers {
if !t.stopped && t.kind == track.Kind() && t.Sender() == nil {
transceiver = t
break
}
}
if transceiver != nil {
sender, err := pc.api.NewRTPSender(track, pc.dtlsTransport)
if err == nil {
err = transceiver.SetSender(sender, track)
if err != nil {
_ = sender.Stop()
transceiver.setSender(nil)
sender, err := pc.api.NewRTPSender(track, pc.dtlsTransport)
if err == nil {
err = t.SetSender(sender, track)
if err != nil {
_ = sender.Stop()
t.setSender(nil)
}
}
if err != nil {
return nil, err
}
pc.onNegotiationNeeded()
return sender, nil
}
pc.mu.Unlock()
if err != nil {
return nil, err
}
pc.onNegotiationNeeded()
return sender, nil
}
pc.mu.Unlock()
transceiver, err := pc.AddTransceiverFromTrack(track)
transceiver, err := pc.newTransceiverFromTrack(RTPTransceiverDirectionSendrecv, track)
if err != nil {
return nil, err
}
pc.addRTPTransceiver(transceiver)
return transceiver.Sender(), nil
}
@@ -1621,6 +1616,7 @@ func (pc *PeerConnection) RemoveTrack(sender *RTPSender) (err error) {
var transceiver *RTPTransceiver
pc.mu.Lock()
defer pc.mu.Unlock()
for _, t := range pc.rtpTransceivers {
if t.Sender() == sender {
transceiver = t
@@ -1628,21 +1624,41 @@ func (pc *PeerConnection) RemoveTrack(sender *RTPSender) (err error) {
}
}
if transceiver == nil {
err = &rtcerr.InvalidAccessError{Err: ErrSenderNotCreatedByConnection}
return &rtcerr.InvalidAccessError{Err: ErrSenderNotCreatedByConnection}
} else if err = sender.Stop(); err == nil {
err = transceiver.setSendingTrack(nil)
if err == nil {
pc.onNegotiationNeeded()
}
}
pc.mu.Unlock()
if err != nil {
return err
}
return
}
pc.onNegotiationNeeded()
return nil
func (pc *PeerConnection) newTransceiverFromTrack(direction RTPTransceiverDirection, track TrackLocal) (t *RTPTransceiver, err error) {
var (
r *RTPReceiver
s *RTPSender
)
switch direction {
case RTPTransceiverDirectionSendrecv:
r, err = pc.api.NewRTPReceiver(track.Kind(), pc.dtlsTransport)
if err != nil {
return
}
s, err = pc.api.NewRTPSender(track, pc.dtlsTransport)
case RTPTransceiverDirectionSendonly:
s, err = pc.api.NewRTPSender(track, pc.dtlsTransport)
default:
err = errPeerConnAddTransceiverFromTrackSupport
}
if err != nil {
return
}
return newRTPTransceiver(r, s, direction, track.Kind()), nil
}
// AddTransceiverFromKind Create a new RtpTransceiver and adds it to the set of transceivers.
func (pc *PeerConnection) AddTransceiverFromKind(kind RTPCodecType, init ...RTPTransceiverInit) (*RTPTransceiver, error) {
func (pc *PeerConnection) AddTransceiverFromKind(kind RTPCodecType, init ...RTPTransceiverInit) (t *RTPTransceiver, err error) {
if pc.isClosed.get() {
return nil, &rtcerr.InvalidStateError{Err: ErrConnectionClosed}
}
@@ -1653,43 +1669,37 @@ func (pc *PeerConnection) AddTransceiverFromKind(kind RTPCodecType, init ...RTPT
} else if len(init) == 1 {
direction = init[0].Direction
}
switch direction {
case RTPTransceiverDirectionSendonly, RTPTransceiverDirectionSendrecv:
codecs := pc.api.mediaEngine.getCodecsByKind(kind)
if len(codecs) == 0 {
return nil, ErrNoCodecsAvailable
}
track, err := NewTrackLocalStaticSample(codecs[0].RTPCodecCapability, util.MathRandAlpha(16), util.MathRandAlpha(16))
if err != nil {
return nil, err
}
return pc.AddTransceiverFromTrack(track, init...)
t, err = pc.newTransceiverFromTrack(direction, track)
if err != nil {
return nil, err
}
case RTPTransceiverDirectionRecvonly:
receiver, err := pc.api.NewRTPReceiver(kind, pc.dtlsTransport)
if err != nil {
return nil, err
}
t := pc.newRTPTransceiver(
receiver,
nil,
RTPTransceiverDirectionRecvonly,
kind,
)
pc.onNegotiationNeeded()
return t, nil
t = newRTPTransceiver(receiver, nil, RTPTransceiverDirectionRecvonly, kind)
default:
return nil, errPeerConnAddTransceiverFromKindSupport
}
pc.mu.Lock()
pc.addRTPTransceiver(t)
pc.mu.Unlock()
return t, nil
}
// AddTransceiverFromTrack Create a new RtpTransceiver(SendRecv or SendOnly) and add it to the set of transceivers.
func (pc *PeerConnection) AddTransceiverFromTrack(track TrackLocal, init ...RTPTransceiverInit) (*RTPTransceiver, error) {
func (pc *PeerConnection) AddTransceiverFromTrack(track TrackLocal, init ...RTPTransceiverInit) (t *RTPTransceiver, err error) {
if pc.isClosed.get() {
return nil, &rtcerr.InvalidStateError{Err: ErrConnectionClosed}
}
@@ -1701,48 +1711,13 @@ func (pc *PeerConnection) AddTransceiverFromTrack(track TrackLocal, init ...RTPT
direction = init[0].Direction
}
switch direction {
case RTPTransceiverDirectionSendrecv:
receiver, err := pc.api.NewRTPReceiver(track.Kind(), pc.dtlsTransport)
if err != nil {
return nil, err
}
sender, err := pc.api.NewRTPSender(track, pc.dtlsTransport)
if err != nil {
return nil, err
}
t := pc.newRTPTransceiver(
receiver,
sender,
RTPTransceiverDirectionSendrecv,
track.Kind(),
)
pc.onNegotiationNeeded()
return t, nil
case RTPTransceiverDirectionSendonly:
sender, err := pc.api.NewRTPSender(track, pc.dtlsTransport)
if err != nil {
return nil, err
}
t := pc.newRTPTransceiver(
nil,
sender,
RTPTransceiverDirectionSendonly,
track.Kind(),
)
pc.onNegotiationNeeded()
return t, nil
default:
return nil, errPeerConnAddTransceiverFromTrackSupport
t, err = pc.newTransceiverFromTrack(direction, track)
if err == nil {
pc.mu.Lock()
pc.addRTPTransceiver(t)
pc.mu.Unlock()
}
return
}
// CreateDataChannel creates a new DataChannel object with the given label
@@ -1820,7 +1795,9 @@ func (pc *PeerConnection) CreateDataChannel(label string, options *DataChannelIn
}
}
pc.mu.Lock()
pc.onNegotiationNeeded()
pc.mu.Unlock()
return d, nil
}
@@ -1899,22 +1876,12 @@ func (pc *PeerConnection) Close() error {
return util.FlattenErrs(closeErrs)
}
func (pc *PeerConnection) newRTPTransceiver(
receiver *RTPReceiver,
sender *RTPSender,
direction RTPTransceiverDirection,
kind RTPCodecType,
) *RTPTransceiver {
t := &RTPTransceiver{kind: kind}
t.setReceiver(receiver)
t.setSender(sender)
t.setDirection(direction)
pc.mu.Lock()
// addRTPTransceiver appends t into rtpTransceivers
// and fires onNegotiationNeeded;
// caller of this method should hold `pc.mu` lock
func (pc *PeerConnection) addRTPTransceiver(t *RTPTransceiver) {
pc.rtpTransceivers = append(pc.rtpTransceivers, t)
pc.mu.Unlock()
return t
pc.onNegotiationNeeded()
}
// CurrentLocalDescription represents the local description that was
@@ -2223,12 +2190,15 @@ func (pc *PeerConnection) generateMatchedSDP(transceivers []*RTPTransceiver, use
}
var t *RTPTransceiver
remoteDescription := pc.currentRemoteDescription
if pc.pendingRemoteDescription != nil {
remoteDescription = pc.pendingRemoteDescription
}
localTransceivers := append([]*RTPTransceiver{}, transceivers...)
detectedPlanB := descriptionIsPlanB(pc.RemoteDescription())
detectedPlanB := descriptionIsPlanB(remoteDescription)
mediaSections := []mediaSection{}
alreadyHaveApplicationMediaSection := false
for _, media := range pc.RemoteDescription().parsed.MediaDescriptions {
for _, media := range remoteDescription.parsed.MediaDescriptions {
midValue := getMidValue(media)
if midValue == "" {
return nil, errPeerConnRemoteDescriptionWithoutMidValue

View File

@@ -928,11 +928,16 @@ func TestNegotiationNeededStressOneSided(t *testing.T) {
pcA, pcB, err := newPair()
assert.NoError(t, err)
const expectedTrackCount = 500
ctx, done := context.WithCancel(context.Background())
pcA.OnNegotiationNeeded(func() {
count := len(pcA.GetTransceivers())
assert.NoError(t, signalPair(pcA, pcB))
if count == expectedTrackCount {
done()
}
})
const expectedTrackCount = 500
for i := 0; i < expectedTrackCount; i++ {
track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/vp8"}, "video", "pion")
assert.NoError(t, err)
@@ -940,9 +945,7 @@ func TestNegotiationNeededStressOneSided(t *testing.T) {
_, err = pcA.AddTrack(track)
assert.NoError(t, err)
}
pcA.ops.Done()
<-ctx.Done()
assert.Equal(t, expectedTrackCount, len(pcB.GetTransceivers()))
closePairNow(t, pcA, pcB)
}

View File

@@ -20,6 +20,19 @@ type RTPTransceiver struct {
kind RTPCodecType
}
func newRTPTransceiver(
receiver *RTPReceiver,
sender *RTPSender,
direction RTPTransceiverDirection,
kind RTPCodecType,
) *RTPTransceiver {
t := &RTPTransceiver{kind: kind}
t.setReceiver(receiver)
t.setSender(sender)
t.setDirection(direction)
return t
}
// Sender returns the RTPTransceiver's RTPSender if it has one
func (t *RTPTransceiver) Sender() *RTPSender {
if v := t.sender.Load(); v != nil {