Replace custom atomicBool with sync/atomic.Bool

- Remove custom atomicBool implementation
- Replace all atomicBool usages with standard library sync/atomic.Bool

Signed-off-by: Xiaobo Liu <cppcoffee@gmail.com>
This commit is contained in:
Xiaobo Liu
2025-06-25 16:14:29 +08:00
parent 887f5c6e0c
commit 4f67c90d22
8 changed files with 66 additions and 94 deletions

View File

@@ -1,32 +0,0 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package webrtc
import "sync/atomic"
type atomicBool struct {
val int32
}
func (b *atomicBool) set(value bool) { // nolint: unparam
var i int32
if value {
i = 1
}
atomic.StoreInt32(&(b.val), i)
}
func (b *atomicBool) get() bool {
return atomic.LoadInt32(&(b.val)) != 0
}
func (b *atomicBool) swap(value bool) bool {
var i int32
if value {
i = 1
}
return atomic.SwapInt32(&(b.val), i) != 0
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"sync"
"sync/atomic"
"testing"
"time"
@@ -450,23 +451,23 @@ func TestDataChannelParameters(t *testing.T) { //nolint:cyclop
assert.Fail(t, "OnDataChannel must not be fired when negotiated == true")
})
seenAnswerMessage := &atomicBool{}
seenOfferMessage := &atomicBool{}
seenAnswerMessage := &atomic.Bool{}
seenOfferMessage := &atomic.Bool{}
answerDatachannel.OnMessage(func(msg DataChannelMessage) {
if msg.IsString && string(msg.Data) == expectedMessage {
seenAnswerMessage.set(true)
seenAnswerMessage.Store(true)
}
})
offerDatachannel.OnMessage(func(msg DataChannelMessage) {
if msg.IsString && string(msg.Data) == expectedMessage {
seenOfferMessage.set(true)
seenOfferMessage.Store(true)
}
})
go func() {
for seenAnswerMessage.get() && seenOfferMessage.get() {
for seenAnswerMessage.Load() && seenOfferMessage.Load() {
if offerDatachannel.ReadyState() == DataChannelStateOpen {
assert.NoError(t, offerDatachannel.SendText(expectedMessage))
}

View File

@@ -6,6 +6,7 @@ package webrtc
import (
"container/list"
"sync"
"sync/atomic"
)
// Operation is a function.
@@ -17,13 +18,13 @@ type operations struct {
busyCh chan struct{}
ops *list.List
updateNegotiationNeededFlagOnEmptyChain *atomicBool
updateNegotiationNeededFlagOnEmptyChain *atomic.Bool
onNegotiationNeeded func()
isClosed bool
}
func newOperations(
updateNegotiationNeededFlagOnEmptyChain *atomicBool,
updateNegotiationNeededFlagOnEmptyChain *atomic.Bool,
onNegotiationNeeded func(),
) *operations {
return &operations{
@@ -150,9 +151,9 @@ func (o *operations) start() {
fn()
fn = o.pop()
}
if !o.updateNegotiationNeededFlagOnEmptyChain.get() {
if !o.updateNegotiationNeededFlagOnEmptyChain.Load() {
return
}
o.updateNegotiationNeededFlagOnEmptyChain.set(false)
o.updateNegotiationNeededFlagOnEmptyChain.Store(false)
o.onNegotiationNeeded()
}

View File

@@ -5,13 +5,14 @@ package webrtc
import (
"sync"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
)
func TestOperations_Enqueue(t *testing.T) {
updateNegotiationNeededFlagOnEmptyChain := &atomicBool{}
updateNegotiationNeededFlagOnEmptyChain := &atomic.Bool{}
onNegotiationNeededCalledCount := 0
var onNegotiationNeededCalledCountMu sync.Mutex
ops := newOperations(updateNegotiationNeededFlagOnEmptyChain, func() {
@@ -29,7 +30,7 @@ func TestOperations_Enqueue(t *testing.T) {
ops.Enqueue(func() {
results[j] = j * j
if resultSetCopy > 50 {
updateNegotiationNeededFlagOnEmptyChain.set(true)
updateNegotiationNeededFlagOnEmptyChain.Store(true)
}
})
}(i)
@@ -46,14 +47,14 @@ func TestOperations_Enqueue(t *testing.T) {
}
func TestOperations_Done(*testing.T) {
ops := newOperations(&atomicBool{}, func() {
ops := newOperations(&atomic.Bool{}, func() {
})
defer ops.GracefulClose()
ops.Done()
}
func TestOperations_GracefulClose(t *testing.T) {
ops := newOperations(&atomicBool{}, func() {
ops := newOperations(&atomic.Bool{}, func() {
})
counter := 0

View File

@@ -55,12 +55,12 @@ type PeerConnection struct {
idpLoginURL *string
isClosed *atomicBool
isClosed *atomic.Bool
isGracefullyClosingOrClosed bool
isCloseDone chan struct{}
isGracefulCloseDone chan struct{}
isNegotiationNeeded *atomicBool
updateNegotiationNeededFlagOnEmptyChain *atomicBool
isNegotiationNeeded *atomic.Bool
updateNegotiationNeededFlagOnEmptyChain *atomic.Bool
lastOffer string
lastAnswer string
@@ -124,11 +124,11 @@ func (api *API) NewPeerConnection(configuration Configuration) (*PeerConnection,
Certificates: []Certificate{},
ICECandidatePoolSize: 0,
},
isClosed: &atomicBool{},
isClosed: &atomic.Bool{},
isCloseDone: make(chan struct{}),
isGracefulCloseDone: make(chan struct{}),
isNegotiationNeeded: &atomicBool{},
updateNegotiationNeededFlagOnEmptyChain: &atomicBool{},
isNegotiationNeeded: &atomic.Bool{},
updateNegotiationNeededFlagOnEmptyChain: &atomic.Bool{},
lastOffer: "",
lastAnswer: "",
greaterMid: -1,
@@ -296,7 +296,7 @@ func (pc *PeerConnection) onNegotiationNeeded() {
// 4.7.3.1 If the length of connection.[[Operations]] is not 0, then set
// connection.[[UpdateNegotiationNeededFlagOnEmptyChain]] to true, and abort these steps.
if !pc.ops.IsEmpty() {
pc.updateNegotiationNeededFlagOnEmptyChain.set(true)
pc.updateNegotiationNeededFlagOnEmptyChain.Store(true)
return
}
@@ -306,7 +306,7 @@ func (pc *PeerConnection) onNegotiationNeeded() {
// https://www.w3.org/TR/webrtc/#dfn-update-the-negotiation-needed-flag
func (pc *PeerConnection) negotiationNeededOp() {
// 4.7.3.2.1 If connection.[[IsClosed]] is true, abort these steps.
if pc.isClosed.get() {
if pc.isClosed.Load() {
return
}
@@ -314,7 +314,7 @@ func (pc *PeerConnection) negotiationNeededOp() {
// then set connection.[[UpdateNegotiationNeededFlagOnEmptyChain]] to
// true, and abort these steps.
if !pc.ops.IsEmpty() {
pc.updateNegotiationNeededFlagOnEmptyChain.set(true)
pc.updateNegotiationNeededFlagOnEmptyChain.Store(true)
return
}
@@ -328,18 +328,18 @@ func (pc *PeerConnection) negotiationNeededOp() {
// clear the negotiation-needed flag by setting connection.[[NegotiationNeeded]]
// to false, and abort these steps.
if !pc.checkNegotiationNeeded() {
pc.isNegotiationNeeded.set(false)
pc.isNegotiationNeeded.Store(false)
return
}
// 4.7.3.2.5 If connection.[[NegotiationNeeded]] is already true, abort these steps.
if pc.isNegotiationNeeded.get() {
if pc.isNegotiationNeeded.Load() {
return
}
// 4.7.3.2.6 Set connection.[[NegotiationNeeded]] to true.
pc.isNegotiationNeeded.set(true)
pc.isNegotiationNeeded.Store(true)
// 4.7.3.2.7 Fire an event named negotiationneeded at connection.
if handler, ok := pc.onNegotiationNeededHandler.Load().(func()); ok && handler != nil {
@@ -513,7 +513,7 @@ func (pc *PeerConnection) onConnectionStateChange(cs PeerConnectionState) {
// SetConfiguration updates the configuration of this PeerConnection object.
func (pc *PeerConnection) SetConfiguration(configuration Configuration) error { //nolint:gocognit,cyclop
// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-setconfiguration (step #2)
if pc.isClosed.get() {
if pc.isClosed.Load() {
return &rtcerr.InvalidStateError{Err: ErrConnectionClosed}
}
@@ -623,7 +623,7 @@ func (pc *PeerConnection) CreateOffer(options *OfferOptions) (SessionDescription
switch {
case useIdentity:
return SessionDescription{}, errIdentityProviderNotImplemented
case pc.isClosed.get():
case pc.isClosed.Load():
return SessionDescription{}, &rtcerr.InvalidStateError{Err: ErrConnectionClosed}
}
@@ -763,7 +763,7 @@ func (pc *PeerConnection) updateConnectionState(
connectionState := PeerConnectionStateNew
switch {
// The RTCPeerConnection object's [[IsClosed]] slot is true.
case pc.isClosed.get():
case pc.isClosed.Load():
connectionState = PeerConnectionStateClosed
// Any of the RTCIceTransports or RTCDtlsTransports are in a "failed" state.
@@ -844,7 +844,7 @@ func (pc *PeerConnection) CreateAnswer(*AnswerOptions) (SessionDescription, erro
return SessionDescription{}, &rtcerr.InvalidStateError{Err: ErrNoRemoteDescription}
case useIdentity:
return SessionDescription{}, errIdentityProviderNotImplemented
case pc.isClosed.get():
case pc.isClosed.Load():
return SessionDescription{}, &rtcerr.InvalidStateError{Err: ErrConnectionClosed}
case pc.signalingState.Get() != SignalingStateHaveRemoteOffer &&
pc.signalingState.Get() != SignalingStateHaveLocalPranswer:
@@ -891,7 +891,7 @@ func (pc *PeerConnection) CreateAnswer(*AnswerOptions) (SessionDescription, erro
//nolint:gocognit,cyclop
func (pc *PeerConnection) setDescription(sd *SessionDescription, op stateChangeOp) error {
switch {
case pc.isClosed.get():
case pc.isClosed.Load():
return &rtcerr.InvalidStateError{Err: ErrConnectionClosed}
case NewSDPType(sd.Type.String()) == SDPTypeUnknown:
return &rtcerr.TypeError{
@@ -995,7 +995,7 @@ func (pc *PeerConnection) setDescription(sd *SessionDescription, op stateChangeO
if err == nil {
pc.signalingState.Set(nextState)
if pc.signalingState.Get() == SignalingStateStable {
pc.isNegotiationNeeded.set(false)
pc.isNegotiationNeeded.Store(false)
pc.mu.Lock()
pc.onNegotiationNeeded()
pc.mu.Unlock()
@@ -1010,7 +1010,7 @@ func (pc *PeerConnection) setDescription(sd *SessionDescription, op stateChangeO
//
//nolint:cyclop
func (pc *PeerConnection) SetLocalDescription(desc SessionDescription) error {
if pc.isClosed.get() {
if pc.isClosed.Load() {
return &rtcerr.InvalidStateError{Err: ErrConnectionClosed}
}
@@ -1081,7 +1081,7 @@ func (pc *PeerConnection) LocalDescription() *SessionDescription {
//
//nolint:gocognit,gocyclo,cyclop,maintidx
func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error {
if pc.isClosed.get() {
if pc.isClosed.Load() {
return &rtcerr.InvalidStateError{Err: ErrConnectionClosed}
}
@@ -1886,7 +1886,7 @@ func (pc *PeerConnection) undeclaredRTPMediaProcessor() { //nolint:cyclop
return
}
if pc.isClosed.get() {
if pc.isClosed.Load() {
if err = srtpReadStream.Close(); err != nil {
pc.log.Warnf("Failed to close RTP stream %v", err)
}
@@ -2076,7 +2076,7 @@ func (pc *PeerConnection) GetTransceivers() []*RTPTransceiver {
//
//nolint:cyclop
func (pc *PeerConnection) AddTrack(track TrackLocal) (*RTPSender, error) {
if pc.isClosed.get() {
if pc.isClosed.Load() {
return nil, &rtcerr.InvalidStateError{Err: ErrConnectionClosed}
}
@@ -2118,7 +2118,7 @@ func (pc *PeerConnection) AddTrack(track TrackLocal) (*RTPSender, error) {
// RemoveTrack removes a Track from the PeerConnection.
func (pc *PeerConnection) RemoveTrack(sender *RTPSender) (err error) {
if pc.isClosed.get() {
if pc.isClosed.Load() {
return &rtcerr.InvalidStateError{Err: ErrConnectionClosed}
}
@@ -2186,7 +2186,7 @@ func (pc *PeerConnection) AddTransceiverFromKind(
kind RTPCodecType,
init ...RTPTransceiverInit,
) (t *RTPTransceiver, err error) {
if pc.isClosed.get() {
if pc.isClosed.Load() {
return nil, &rtcerr.InvalidStateError{Err: ErrConnectionClosed}
}
@@ -2231,7 +2231,7 @@ func (pc *PeerConnection) AddTransceiverFromTrack(
track TrackLocal,
init ...RTPTransceiverInit,
) (t *RTPTransceiver, err error) {
if pc.isClosed.get() {
if pc.isClosed.Load() {
return nil, &rtcerr.InvalidStateError{Err: ErrConnectionClosed}
}
@@ -2259,7 +2259,7 @@ func (pc *PeerConnection) AddTransceiverFromTrack(
//nolint:cyclop
func (pc *PeerConnection) CreateDataChannel(label string, options *DataChannelInit) (*DataChannel, error) {
// https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #2)
if pc.isClosed.get() {
if pc.isClosed.Load() {
return nil, &rtcerr.InvalidStateError{Err: ErrConnectionClosed}
}
@@ -2380,7 +2380,7 @@ func (pc *PeerConnection) close(shouldGracefullyClose bool) error { //nolint:cyc
// some overlapping close cases when both normal and graceful close are used
// that should be idempotent, but be cautioned when writing new close behavior
// to preserve this property.
isAlreadyClosingOrClosed := pc.isClosed.swap(true)
isAlreadyClosingOrClosed := pc.isClosed.Swap(true)
isAlreadyGracefullyClosingOrClosed := pc.isGracefullyClosingOrClosed
if shouldGracefullyClose && !isAlreadyGracefullyClosingOrClosed {
pc.isGracefullyClosingOrClosed = true
@@ -2668,7 +2668,7 @@ func (pc *PeerConnection) startTransports(
}
pc.dtlsTransport.internalOnCloseHandler = func() {
if pc.isClosed.get() || pc.api.settingEngine.disableCloseByDTLS {
if pc.isClosed.Load() || pc.api.settingEngine.disableCloseByDTLS {
return
}

View File

@@ -19,6 +19,7 @@ import (
"regexp"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@@ -592,8 +593,8 @@ func TestPeerConnection_IceLite(t *testing.T) {
}
func TestOnICEGatheringStateChange(t *testing.T) {
seenGathering := &atomicBool{}
seenComplete := &atomicBool{}
seenGathering := &atomic.Bool{}
seenComplete := &atomic.Bool{}
seenGatheringAndComplete := make(chan any)
@@ -607,13 +608,13 @@ func TestOnICEGatheringStateChange(t *testing.T) {
switch s { // nolint:exhaustive
case ICEGatheringStateGathering:
assert.False(t, seenGathering.get(), "Completed before gathering")
seenGathering.set(true)
assert.False(t, seenGathering.Load(), "Completed before gathering")
seenGathering.Store(true)
case ICEGatheringStateComplete:
seenComplete.set(true)
seenComplete.Store(true)
}
if seenGathering.get() && seenComplete.get() {
if seenGathering.Load() && seenComplete.Load() {
close(seenGatheringAndComplete)
}
}
@@ -913,12 +914,12 @@ func TestICERestart_Error_Handling(t *testing.T) {
offerPeerConnection.OnICEConnectionStateChange(pushICEState)
answerPeerConnection.OnICEConnectionStateChange(pushICEState)
keepPackets := &atomicBool{}
keepPackets.set(true)
keepPackets := &atomic.Bool{}
keepPackets.Store(true)
// Add a filter that monitors the traffic on the router
wan.AddChunkFilter(func(vnet.Chunk) bool {
return keepPackets.get()
return keepPackets.Load()
})
const testMessage = "testMessage"
@@ -960,13 +961,13 @@ func TestICERestart_Error_Handling(t *testing.T) {
// Drop all packets, assert we have disconnected
// and send a DataChannel message when disconnected
keepPackets.set(false)
keepPackets.Store(false)
blockUntilICEState(ICEConnectionStateFailed)
assert.NoError(t, dataChannel.SendText(testMessage))
// ICE Restart and assert we have reconnected
// block until our DataChannel message is delivered
keepPackets.set(true)
keepPackets.Store(true)
connectWithICERestart(offerPeerConnection, answerPeerConnection)
blockUntilICEState(ICEConnectionStateConnected)
assert.Equal(t, testMessage, <-dataChannelMessages)

View File

@@ -319,12 +319,12 @@ func TestPeerConnection_Media_Disconnected(t *testing.T) { //nolint:cyclop
pcOffer, pcAnswer, wan := createVNetPair(t, nil)
keepPackets := &atomicBool{}
keepPackets.set(true)
keepPackets := &atomic.Bool{}
keepPackets.Store(true)
// Add a filter that monitors the traffic on the router
wan.AddChunkFilter(func(vnet.Chunk) bool {
return keepPackets.get()
return keepPackets.Load()
})
vp8Track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2")
@@ -349,7 +349,7 @@ func TestPeerConnection_Media_Disconnected(t *testing.T) { //nolint:cyclop
time.Sleep(time.Second)
}
keepPackets.set(false)
keepPackets.Store(false)
}
})

View File

@@ -162,10 +162,10 @@ func TestPeerConnection_Renegotiation_AddTrack(t *testing.T) {
pcOffer, pcAnswer, err := newPair()
assert.NoError(t, err)
haveRenegotiated := &atomicBool{}
haveRenegotiated := &atomic.Bool{}
onTrackFired, onTrackFiredFunc := context.WithCancel(context.Background())
pcAnswer.OnTrack(func(*TrackRemote, *RTPReceiver) {
assert.True(t, haveRenegotiated.get(), "OnTrack was called before renegotiation")
assert.True(t, haveRenegotiated.Load(), "OnTrack was called before renegotiation")
onTrackFiredFunc()
})
@@ -189,7 +189,7 @@ func TestPeerConnection_Renegotiation_AddTrack(t *testing.T) {
time.Sleep(20 * time.Millisecond)
}
haveRenegotiated.set(true)
haveRenegotiated.Store(true)
assert.False(t, sender.isNegotiated())
offer, err := pcOffer.CreateOffer(nil)
assert.True(t, sender.isNegotiated())
@@ -283,11 +283,11 @@ func TestPeerConnection_Renegotiation_AddTrack_Rename(t *testing.T) {
pcOffer, pcAnswer, err := newPair()
assert.NoError(t, err)
haveRenegotiated := &atomicBool{}
haveRenegotiated := &atomic.Bool{}
onTrackFired, onTrackFiredFunc := context.WithCancel(context.Background())
var atomicRemoteTrack atomic.Value
pcOffer.OnTrack(func(track *TrackRemote, _ *RTPReceiver) {
assert.True(t, haveRenegotiated.get(), "OnTrack was called before renegotiation")
assert.True(t, haveRenegotiated.Load(), "OnTrack was called before renegotiation")
onTrackFiredFunc()
atomicRemoteTrack.Store(track)
})
@@ -307,7 +307,7 @@ func TestPeerConnection_Renegotiation_AddTrack_Rename(t *testing.T) {
vp8Track.rtpTrack.id = "foo2"
vp8Track.rtpTrack.streamID = "bar2"
haveRenegotiated.set(true)
haveRenegotiated.Store(true)
assert.NoError(t, signalPair(pcOffer, pcAnswer))
sendVideoUntilDone(t, onTrackFired.Done(), []*TrackLocalStaticSample{vp8Track})