Move taskloop into dedicated package

Reduce size of Agent and simplify code
This commit is contained in:
Steffen Vogel
2024-03-21 11:47:51 -04:00
committed by Sean DuBois
parent b36d33253b
commit fdca6c47c0
10 changed files with 227 additions and 215 deletions

181
agent.go
View File

@@ -15,8 +15,8 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
atomicx "github.com/pion/ice/v3/internal/atomic"
stunx "github.com/pion/ice/v3/internal/stun" stunx "github.com/pion/ice/v3/internal/stun"
"github.com/pion/ice/v3/internal/taskloop"
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/mdns/v2" "github.com/pion/mdns/v2"
"github.com/pion/stun/v2" "github.com/pion/stun/v2"
@@ -36,13 +36,12 @@ type bindingRequest struct {
// Agent represents the ICE agent // Agent represents the ICE agent
type Agent struct { type Agent struct {
chanTask chan task loop *taskloop.Loop
onConnectionStateChangeHdlr atomic.Value // func(ConnectionState) onConnectionStateChangeHdlr atomic.Value // func(ConnectionState)
onSelectedCandidatePairChangeHdlr atomic.Value // func(Candidate, Candidate) onSelectedCandidatePairChangeHdlr atomic.Value // func(Candidate, Candidate)
onCandidateHdlr atomic.Value // func(Candidate) onCandidateHdlr atomic.Value // func(Candidate)
// State owned by the taskLoop
onConnected chan struct{} onConnected chan struct{}
onConnectedOnce sync.Once onConnectedOnce sync.Once
@@ -118,11 +117,6 @@ type Agent struct {
// 1:1 D-NAT IP address mapping // 1:1 D-NAT IP address mapping
extIPMapper *externalIPMapper extIPMapper *externalIPMapper
// State for closing
done chan struct{}
taskLoopDone chan struct{}
err atomicx.Error
gatherCandidateCancel func() gatherCandidateCancel func()
gatherCandidateDone chan struct{} gatherCandidateDone chan struct{}
@@ -147,74 +141,6 @@ type Agent struct {
proxyDialer proxy.Dialer proxyDialer proxy.Dialer
} }
type task struct {
fn func(context.Context, *Agent)
done chan struct{}
}
func (a *Agent) ok() error {
select {
case <-a.done:
return a.getErr()
default:
}
return nil
}
func (a *Agent) getErr() error {
if err := a.err.Load(); err != nil {
return err
}
return ErrClosed
}
// Run task in serial. Blocking tasks must be cancelable by context.
func (a *Agent) run(ctx context.Context, t func(context.Context, *Agent)) error {
if err := a.ok(); err != nil {
return err
}
done := make(chan struct{})
select {
case <-ctx.Done():
return ctx.Err()
case a.chanTask <- task{t, done}:
<-done
return nil
}
}
// taskLoop handles registered tasks and agent close.
func (a *Agent) taskLoop() {
defer func() {
a.deleteAllCandidates()
a.startedFn()
if err := a.buf.Close(); err != nil {
a.log.Warnf("Failed to close buffer: %v", err)
}
a.closeMulticastConn()
a.updateConnectionState(ConnectionStateClosed)
a.gatherCandidateCancel()
if a.gatherCandidateDone != nil {
<-a.gatherCandidateDone
}
close(a.taskLoopDone)
}()
for {
select {
case <-a.done:
return
case t := <-a.chanTask:
t.fn(a.context(), a)
close(t.done)
}
}
}
// NewAgent creates a new Agent // NewAgent creates a new Agent
func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
var err error var err error
@@ -247,7 +173,6 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
startedCtx, startedFn := context.WithCancel(context.Background()) startedCtx, startedFn := context.WithCancel(context.Background())
a := &Agent{ a := &Agent{
chanTask: make(chan task),
tieBreaker: globalMathRandomGenerator.Uint64(), tieBreaker: globalMathRandomGenerator.Uint64(),
lite: config.Lite, lite: config.Lite,
gatheringState: GatheringStateNew, gatheringState: GatheringStateNew,
@@ -258,8 +183,6 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
networkTypes: config.NetworkTypes, networkTypes: config.NetworkTypes,
onConnected: make(chan struct{}), onConnected: make(chan struct{}),
buf: packetio.NewBuffer(), buf: packetio.NewBuffer(),
done: make(chan struct{}),
taskLoopDone: make(chan struct{}),
startedCh: startedCtx.Done(), startedCh: startedCtx.Done(),
startedFn: startedFn, startedFn: startedFn,
portMin: config.PortMin, portMin: config.PortMin,
@@ -333,7 +256,23 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
return nil, err return nil, err
} }
go a.taskLoop() a.loop = taskloop.New(func() {
a.removeUfragFromMux()
a.deleteAllCandidates()
a.startedFn()
if err := a.buf.Close(); err != nil {
a.log.Warnf("Failed to close buffer: %v", err)
}
a.closeMulticastConn()
a.updateConnectionState(ConnectionStateClosed)
a.gatherCandidateCancel()
if a.gatherCandidateDone != nil {
<-a.gatherCandidateDone
}
})
// Restart is also used to initialize the agent for the first time // Restart is also used to initialize the agent for the first time
if err := a.Restart(config.LocalUfrag, config.LocalPwd); err != nil { if err := a.Restart(config.LocalUfrag, config.LocalPwd); err != nil {
@@ -359,10 +298,10 @@ func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remoteP
a.log.Debugf("Started agent: isControlling? %t, remoteUfrag: %q, remotePwd: %q", isControlling, remoteUfrag, remotePwd) a.log.Debugf("Started agent: isControlling? %t, remoteUfrag: %q, remotePwd: %q", isControlling, remoteUfrag, remotePwd)
return a.run(a.context(), func(_ context.Context, agent *Agent) { return a.loop.Run(a.loop, func(_ context.Context) {
agent.isControlling = isControlling a.isControlling = isControlling
agent.remoteUfrag = remoteUfrag a.remoteUfrag = remoteUfrag
agent.remotePwd = remotePwd a.remotePwd = remotePwd
if isControlling { if isControlling {
a.selector = &controllingSelector{agent: a, log: a.log} a.selector = &controllingSelector{agent: a, log: a.log}
@@ -377,7 +316,7 @@ func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remoteP
a.selector.Start() a.selector.Start()
a.startedFn() a.startedFn()
agent.updateConnectionState(ConnectionStateChecking) a.updateConnectionState(ConnectionStateChecking)
a.requestConnectivityCheck() a.requestConnectivityCheck()
go a.connectivityChecks() //nolint:contextcheck go a.connectivityChecks() //nolint:contextcheck
@@ -389,7 +328,7 @@ func (a *Agent) connectivityChecks() {
checkingDuration := time.Time{} checkingDuration := time.Time{}
contact := func() { contact := func() {
if err := a.run(a.context(), func(_ context.Context, a *Agent) { if err := a.loop.Run(a.loop, func(_ context.Context) {
defer func() { defer func() {
lastConnectionState = a.connectionState lastConnectionState = a.connectionState
}() }()
@@ -446,7 +385,7 @@ func (a *Agent) connectivityChecks() {
contact() contact()
case <-t.C: case <-t.C:
contact() contact()
case <-a.done: case <-a.loop.Done():
t.Stop() t.Stop()
return return
} }
@@ -638,9 +577,9 @@ func (a *Agent) AddRemoteCandidate(c Candidate) error {
} }
go func() { go func() {
if err := a.run(a.context(), func(_ context.Context, agent *Agent) { if err := a.loop.Run(a.loop, func(_ context.Context) {
// nolint: contextcheck // nolint: contextcheck
agent.addRemoteCandidate(c) a.addRemoteCandidate(c)
}); err != nil { }); err != nil {
a.log.Warnf("Failed to add remote candidate %s: %v", c.Address(), err) a.log.Warnf("Failed to add remote candidate %s: %v", c.Address(), err)
return return
@@ -670,9 +609,9 @@ func (a *Agent) resolveAndAddMulticastCandidate(c *CandidateHost) {
return return
} }
if err = a.run(a.context(), func(_ context.Context, agent *Agent) { if err = a.loop.Run(a.loop, func(_ context.Context) {
// nolint: contextcheck // nolint: contextcheck
agent.addRemoteCandidate(c) a.addRemoteCandidate(c)
}); err != nil { }); err != nil {
a.log.Warnf("Failed to add mDNS candidate %s: %v", c.Address(), err) a.log.Warnf("Failed to add mDNS candidate %s: %v", c.Address(), err)
return return
@@ -695,7 +634,7 @@ func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) {
for i := range localIPs { for i := range localIPs {
conn := newActiveTCPConn( conn := newActiveTCPConn(
a.context(), a.loop,
net.JoinHostPort(localIPs[i].String(), "0"), net.JoinHostPort(localIPs[i].String(), "0"),
net.JoinHostPort(remoteCandidate.Address(), strconv.Itoa(remoteCandidate.Port())), net.JoinHostPort(remoteCandidate.Address(), strconv.Itoa(remoteCandidate.Port())),
a.log, a.log,
@@ -763,7 +702,7 @@ func (a *Agent) addRemoteCandidate(c Candidate) {
} }
func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net.PacketConn) error { func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net.PacketConn) error {
return a.run(ctx, func(context.Context, *Agent) { return a.loop.Run(ctx, func(context.Context) {
set := a.localCandidates[c.NetworkType()] set := a.localCandidates[c.NetworkType()]
for _, candidate := range set { for _, candidate := range set {
if candidate.Equal(c) { if candidate.Equal(c) {
@@ -799,9 +738,9 @@ func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net
func (a *Agent) GetRemoteCandidates() ([]Candidate, error) { func (a *Agent) GetRemoteCandidates() ([]Candidate, error) {
var res []Candidate var res []Candidate
err := a.run(a.context(), func(_ context.Context, agent *Agent) { err := a.loop.Run(a.loop, func(_ context.Context) {
var candidates []Candidate var candidates []Candidate
for _, set := range agent.remoteCandidates { for _, set := range a.remoteCandidates {
candidates = append(candidates, set...) candidates = append(candidates, set...)
} }
res = candidates res = candidates
@@ -817,9 +756,9 @@ func (a *Agent) GetRemoteCandidates() ([]Candidate, error) {
func (a *Agent) GetLocalCandidates() ([]Candidate, error) { func (a *Agent) GetLocalCandidates() ([]Candidate, error) {
var res []Candidate var res []Candidate
err := a.run(a.context(), func(_ context.Context, agent *Agent) { err := a.loop.Run(a.loop, func(_ context.Context) {
var candidates []Candidate var candidates []Candidate
for _, set := range agent.localCandidates { for _, set := range a.localCandidates {
candidates = append(candidates, set...) candidates = append(candidates, set...)
} }
res = candidates res = candidates
@@ -834,9 +773,9 @@ func (a *Agent) GetLocalCandidates() ([]Candidate, error) {
// GetLocalUserCredentials returns the local user credentials // GetLocalUserCredentials returns the local user credentials
func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) { func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) {
valSet := make(chan struct{}) valSet := make(chan struct{})
err = a.run(a.context(), func(_ context.Context, agent *Agent) { err = a.loop.Run(a.loop, func(_ context.Context) {
frag = agent.localUfrag frag = a.localUfrag
pwd = agent.localPwd pwd = a.localPwd
close(valSet) close(valSet)
}) })
@@ -849,9 +788,9 @@ func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) {
// GetRemoteUserCredentials returns the remote user credentials // GetRemoteUserCredentials returns the remote user credentials
func (a *Agent) GetRemoteUserCredentials() (frag string, pwd string, err error) { func (a *Agent) GetRemoteUserCredentials() (frag string, pwd string, err error) {
valSet := make(chan struct{}) valSet := make(chan struct{})
err = a.run(a.context(), func(_ context.Context, agent *Agent) { err = a.loop.Run(a.loop, func(_ context.Context) {
frag = agent.remoteUfrag frag = a.remoteUfrag
pwd = agent.remotePwd pwd = a.remotePwd
close(valSet) close(valSet)
}) })
@@ -875,17 +814,7 @@ func (a *Agent) removeUfragFromMux() {
// Close cleans up the Agent // Close cleans up the Agent
func (a *Agent) Close() error { func (a *Agent) Close() error {
if err := a.ok(); err != nil { return a.loop.Close()
return err
}
a.err.Store(ErrClosed)
a.removeUfragFromMux()
close(a.done)
<-a.taskLoopDone
return nil
} }
// Remove all candidates. This closes any listening sockets // Remove all candidates. This closes any listening sockets
@@ -1092,7 +1021,7 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
// and returns true if it is an actual remote candidate // and returns true if it is an actual remote candidate
func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) (Candidate, bool) { func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) (Candidate, bool) {
var remoteCandidate Candidate var remoteCandidate Candidate
if err := a.run(local.context(), func(context.Context, *Agent) { if err := a.loop.Run(local.context(), func(context.Context) {
remoteCandidate = a.findRemoteCandidate(local.NetworkType(), remote) remoteCandidate = a.findRemoteCandidate(local.NetworkType(), remote)
if remoteCandidate != nil { if remoteCandidate != nil {
remoteCandidate.seen(false) remoteCandidate.seen(false)
@@ -1149,9 +1078,9 @@ func (a *Agent) SetRemoteCredentials(remoteUfrag, remotePwd string) error {
return ErrRemotePwdEmpty return ErrRemotePwdEmpty
} }
return a.run(a.context(), func(_ context.Context, agent *Agent) { return a.loop.Run(a.loop, func(_ context.Context) {
agent.remoteUfrag = remoteUfrag a.remoteUfrag = remoteUfrag
agent.remotePwd = remotePwd a.remotePwd = remotePwd
}) })
} }
@@ -1186,17 +1115,17 @@ func (a *Agent) Restart(ufrag, pwd string) error {
} }
var err error var err error
if runErr := a.run(a.context(), func(_ context.Context, agent *Agent) { if runErr := a.loop.Run(a.loop, func(_ context.Context) {
if agent.gatheringState == GatheringStateGathering { if a.gatheringState == GatheringStateGathering {
agent.gatherCandidateCancel() a.gatherCandidateCancel()
} }
// Clear all agent needed to take back to fresh state // Clear all agent needed to take back to fresh state
a.removeUfragFromMux() a.removeUfragFromMux()
agent.localUfrag = ufrag a.localUfrag = ufrag
agent.localPwd = pwd a.localPwd = pwd
agent.remoteUfrag = "" a.remoteUfrag = ""
agent.remotePwd = "" a.remotePwd = ""
a.gatheringState = GatheringStateNew a.gatheringState = GatheringStateNew
a.checklist = make([]*CandidatePair, 0) a.checklist = make([]*CandidatePair, 0)
a.pendingBindingRequests = make([]bindingRequest, 0) a.pendingBindingRequests = make([]bindingRequest, 0)
@@ -1219,7 +1148,7 @@ func (a *Agent) Restart(ufrag, pwd string) error {
func (a *Agent) setGatheringState(newState GatheringState) error { func (a *Agent) setGatheringState(newState GatheringState) error {
done := make(chan struct{}) done := make(chan struct{})
if err := a.run(a.context(), func(context.Context, *Agent) { if err := a.loop.Run(a.loop, func(context.Context) {
if a.gatheringState != newState && newState == GatheringStateComplete { if a.gatheringState != newState && newState == GatheringStateComplete {
a.candidateNotifier.EnqueueCandidate(nil) a.candidateNotifier.EnqueueCandidate(nil)
} }

View File

@@ -22,7 +22,7 @@ func TestOnSelectedCandidatePairChange(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
err = agent.run(context.Background(), func(_ context.Context, agent *Agent) { err = agent.loop.Run(context.Background(), func(_ context.Context) {
agent.setSelectedPair(candidatePair) agent.setSelectedPair(candidatePair)
}) })
require.NoError(t, err) require.NoError(t, err)

View File

@@ -11,9 +11,9 @@ import (
// GetCandidatePairsStats returns a list of candidate pair stats // GetCandidatePairsStats returns a list of candidate pair stats
func (a *Agent) GetCandidatePairsStats() []CandidatePairStats { func (a *Agent) GetCandidatePairsStats() []CandidatePairStats {
var res []CandidatePairStats var res []CandidatePairStats
err := a.run(a.context(), func(_ context.Context, agent *Agent) { err := a.loop.Run(a.loop, func(_ context.Context) {
result := make([]CandidatePairStats, 0, len(agent.checklist)) result := make([]CandidatePairStats, 0, len(a.checklist))
for _, cp := range agent.checklist { for _, cp := range a.checklist {
stat := CandidatePairStats{ stat := CandidatePairStats{
Timestamp: time.Now(), Timestamp: time.Now(),
LocalCandidateID: cp.Local.ID(), LocalCandidateID: cp.Local.ID(),
@@ -57,9 +57,9 @@ func (a *Agent) GetCandidatePairsStats() []CandidatePairStats {
// GetLocalCandidatesStats returns a list of local candidates stats // GetLocalCandidatesStats returns a list of local candidates stats
func (a *Agent) GetLocalCandidatesStats() []CandidateStats { func (a *Agent) GetLocalCandidatesStats() []CandidateStats {
var res []CandidateStats var res []CandidateStats
err := a.run(a.context(), func(_ context.Context, agent *Agent) { err := a.loop.Run(a.loop, func(_ context.Context) {
result := make([]CandidateStats, 0, len(agent.localCandidates)) result := make([]CandidateStats, 0, len(a.localCandidates))
for networkType, localCandidates := range agent.localCandidates { for networkType, localCandidates := range a.localCandidates {
for _, c := range localCandidates { for _, c := range localCandidates {
relayProtocol := "" relayProtocol := ""
if c.Type() == CandidateTypeRelay { if c.Type() == CandidateTypeRelay {
@@ -94,9 +94,9 @@ func (a *Agent) GetLocalCandidatesStats() []CandidateStats {
// GetRemoteCandidatesStats returns a list of remote candidates stats // GetRemoteCandidatesStats returns a list of remote candidates stats
func (a *Agent) GetRemoteCandidatesStats() []CandidateStats { func (a *Agent) GetRemoteCandidatesStats() []CandidateStats {
var res []CandidateStats var res []CandidateStats
err := a.run(a.context(), func(_ context.Context, agent *Agent) { err := a.loop.Run(a.loop, func(_ context.Context) {
result := make([]CandidateStats, 0, len(agent.remoteCandidates)) result := make([]CandidateStats, 0, len(a.remoteCandidates))
for networkType, remoteCandidates := range agent.remoteCandidates { for networkType, remoteCandidates := range a.remoteCandidates {
for _, c := range remoteCandidates { for _, c := range remoteCandidates {
stat := CandidateStats{ stat := CandidateStats{
Timestamp: time.Now(), Timestamp: time.Now(),

View File

@@ -34,19 +34,6 @@ func (ba *BadAddr) String() string {
return "yyy" return "yyy"
} }
func runAgentTest(t *testing.T, config *AgentConfig, task func(ctx context.Context, a *Agent)) {
a, err := NewAgent(config)
if err != nil {
t.Fatalf("Error constructing ice.Agent")
}
if err := a.run(context.Background(), task); err != nil {
t.Fatalf("Agent run failure: %v", err)
}
assert.NoError(t, a.Close())
}
func TestHandlePeerReflexive(t *testing.T) { func TestHandlePeerReflexive(t *testing.T) {
report := test.CheckRoutines(t) report := test.CheckRoutines(t)
defer report() defer report()
@@ -56,8 +43,10 @@ func TestHandlePeerReflexive(t *testing.T) {
defer lim.Stop() defer lim.Stop()
t.Run("UDP prflx candidate from handleInbound()", func(t *testing.T) { t.Run("UDP prflx candidate from handleInbound()", func(t *testing.T) {
var config AgentConfig a, err := NewAgent(&AgentConfig{})
runAgentTest(t, &config, func(_ context.Context, a *Agent) { assert.NoError(t, err)
assert.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log} a.selector = &controllingSelector{agent: a, log: a.log}
hostConfig := CandidateHostConfig{ hostConfig := CandidateHostConfig{
@@ -113,12 +102,15 @@ func TestHandlePeerReflexive(t *testing.T) {
if c.Port() != 999 { if c.Port() != 999 {
t.Fatal("Port number mismatch") t.Fatal("Port number mismatch")
} }
}) }))
assert.NoError(t, a.Close())
}) })
t.Run("Bad network type with handleInbound()", func(t *testing.T) { t.Run("Bad network type with handleInbound()", func(t *testing.T) {
var config AgentConfig a, err := NewAgent(&AgentConfig{})
runAgentTest(t, &config, func(_ context.Context, a *Agent) { assert.NoError(t, err)
assert.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log} a.selector = &controllingSelector{agent: a, log: a.log}
hostConfig := CandidateHostConfig{ hostConfig := CandidateHostConfig{
@@ -140,12 +132,16 @@ func TestHandlePeerReflexive(t *testing.T) {
if len(a.remoteCandidates) != 0 { if len(a.remoteCandidates) != 0 {
t.Fatal("bad address should not be added to the remote candidate list") t.Fatal("bad address should not be added to the remote candidate list")
} }
}) }))
assert.NoError(t, a.Close())
}) })
t.Run("Success from unknown remote, prflx candidate MUST only be created via Binding Request", func(t *testing.T) { t.Run("Success from unknown remote, prflx candidate MUST only be created via Binding Request", func(t *testing.T) {
var config AgentConfig a, err := NewAgent(&AgentConfig{})
runAgentTest(t, &config, func(_ context.Context, a *Agent) { assert.NoError(t, err)
assert.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log} a.selector = &controllingSelector{agent: a, log: a.log}
tID := [stun.TransactionIDSize]byte{} tID := [stun.TransactionIDSize]byte{}
copy(tID[:], "ABC") copy(tID[:], "ABC")
@@ -179,7 +175,9 @@ func TestHandlePeerReflexive(t *testing.T) {
if len(a.remoteCandidates) != 0 { if len(a.remoteCandidates) != 0 {
t.Fatal("unknown remote was able to create a candidate") t.Fatal("unknown remote was able to create a candidate")
} }
}) }))
assert.NoError(t, a.Close())
}) })
} }
@@ -440,7 +438,7 @@ func TestInboundValidity(t *testing.T) {
t.Fatalf("Error constructing ice.Agent") t.Fatalf("Error constructing ice.Agent")
} }
err = a.run(context.Background(), func(_ context.Context, a *Agent) { err = a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log} a.selector = &controllingSelector{agent: a, log: a.log}
// nolint: contextcheck // nolint: contextcheck
a.handleInbound(buildMsg(stun.ClassRequest, a.localUfrag+":"+a.remoteUfrag, a.localPwd), local, remote) a.handleInbound(buildMsg(stun.ClassRequest, a.localUfrag+":"+a.remoteUfrag, a.localPwd), local, remote)
@@ -454,8 +452,10 @@ func TestInboundValidity(t *testing.T) {
}) })
t.Run("Valid bind without fingerprint", func(t *testing.T) { t.Run("Valid bind without fingerprint", func(t *testing.T) {
var config AgentConfig a, err := NewAgent(&AgentConfig{})
runAgentTest(t, &config, func(_ context.Context, a *Agent) { assert.NoError(t, err)
assert.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log} a.selector = &controllingSelector{agent: a, log: a.log}
msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
stun.NewUsername(a.localUfrag+":"+a.remoteUfrag), stun.NewUsername(a.localUfrag+":"+a.remoteUfrag),
@@ -470,7 +470,9 @@ func TestInboundValidity(t *testing.T) {
if len(a.remoteCandidates) != 1 { if len(a.remoteCandidates) != 1 {
t.Fatal("Binding with valid values (but no fingerprint) was unable to create prflx candidate") t.Fatal("Binding with valid values (but no fingerprint) was unable to create prflx candidate")
} }
}) }))
assert.NoError(t, a.Close())
}) })
t.Run("Success with invalid TransactionID", func(t *testing.T) { t.Run("Success with invalid TransactionID", func(t *testing.T) {
@@ -1120,7 +1122,7 @@ func TestConnectionStateFailedDeleteAllCandidates(t *testing.T) {
<-isFailed <-isFailed
done := make(chan struct{}) done := make(chan struct{})
assert.NoError(t, aAgent.run(context.Background(), func(context.Context, *Agent) { assert.NoError(t, aAgent.loop.Run(aAgent.loop, func(context.Context) {
assert.Equal(t, len(aAgent.remoteCandidates), 0) assert.Equal(t, len(aAgent.remoteCandidates), 0)
assert.Equal(t, len(aAgent.localCandidates), 0) assert.Equal(t, len(aAgent.localCandidates), 0)
close(done) close(done)

View File

@@ -267,7 +267,7 @@ func (c *candidateBase) handleInboundPacket(buf []byte, srcAddr net.Addr) {
return return
} }
if err := a.run(c, func(_ context.Context, a *Agent) { if err := a.loop.Run(c, func(_ context.Context) {
// nolint: contextcheck // nolint: contextcheck
a.handleInbound(m, c, srcAddr) a.handleInbound(m, c, srcAddr)
}); err != nil { }); err != nil {

View File

@@ -1,40 +0,0 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package ice
import (
"context"
"time"
)
func (a *Agent) context() context.Context {
return agentContext(a.done)
}
type agentContext chan struct{}
// Done implements context.Context
func (a agentContext) Done() <-chan struct{} {
return (chan struct{})(a)
}
// Err implements context.Context
func (a agentContext) Err() error {
select {
case <-(chan struct{})(a):
return ErrRunCanceled
default:
return nil
}
}
// Deadline implements context.Context
func (a agentContext) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}
// Value implements context.Context
func (a agentContext) Value(interface{}) interface{} {
return nil
}

View File

@@ -42,7 +42,7 @@ func closeConnAndLog(c io.Closer, log logging.LeveledLogger, msg string, args ..
func (a *Agent) GatherCandidates() error { func (a *Agent) GatherCandidates() error {
var gatherErr error var gatherErr error
if runErr := a.run(a.context(), func(ctx context.Context, _ *Agent) { if runErr := a.loop.Run(a.loop, func(ctx context.Context) {
if a.gatheringState != GatheringStateNew { if a.gatheringState != GatheringStateNew {
gatherErr = ErrMultipleGatherAttempted gatherErr = ErrMultipleGatherAttempted
return return
@@ -495,7 +495,7 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
select { select {
case <-cancelCtx.Done(): case <-cancelCtx.Done():
return return
case <-a.done: case <-a.loop.Done():
_ = conn.Close() _ = conn.Close()
} }
}() }()

View File

@@ -0,0 +1,121 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
// Package taskloop implements a task loop to run
// tasks sequentially in a separate Goroutine.
package taskloop
import (
"context"
"errors"
"time"
atomicx "github.com/pion/ice/v3/internal/atomic"
)
// errClosed indicates that the loop has been stopped
var errClosed = errors.New("the agent is closed")
type task struct {
fn func(context.Context)
done chan struct{}
}
// Loop runs submitted task serially in a dedicated Goroutine
type Loop struct {
tasks chan task
// State for closing
done chan struct{}
taskLoopDone chan struct{}
err atomicx.Error
}
// New creates and starts a new task loop
func New(onClose func()) *Loop {
l := &Loop{
tasks: make(chan task),
done: make(chan struct{}),
taskLoopDone: make(chan struct{}),
}
go l.runLoop(onClose)
return l
}
// runLoop handles registered tasks and agent close.
func (l *Loop) runLoop(onClose func()) {
defer func() {
onClose()
close(l.taskLoopDone)
}()
for {
select {
case <-l.done:
return
case t := <-l.tasks:
t.fn(l)
close(t.done)
}
}
}
// Close stops the loop after finishing the execution of the current task.
// Other pending tasks will not be executed.
func (l *Loop) Close() error {
if err := l.Err(); err != nil {
return err
}
l.err.Store(errClosed)
close(l.done)
<-l.taskLoopDone
return nil
}
// Run serially executes the submitted callback.
// Blocking tasks must be cancelable by context.
func (l *Loop) Run(ctx context.Context, t func(context.Context)) error {
if err := l.Err(); err != nil {
return err
}
done := make(chan struct{})
select {
case <-ctx.Done():
return ctx.Err()
case l.tasks <- task{t, done}:
<-done
return nil
}
}
// The following methods implement context.Context for TaskLoop
// Done returns a channel that's closed when the task loop has been stopped.
func (l *Loop) Done() <-chan struct{} {
return l.done
}
// Err returns nil if the task loop is still running.
// Otherwise it return errClosed if the loop has been closed/stopped.
func (l *Loop) Err() error {
select {
case <-l.done:
return errClosed
default:
return nil
}
}
// Deadline returns the no valid time as task loops have no deadline.
func (l *Loop) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}
// Value is not supported for task loops
func (l *Loop) Value(interface{}) interface{} {
return nil
}

View File

@@ -43,7 +43,7 @@ func (c *Conn) BytesReceived() uint64 {
} }
func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, remotePwd string) (*Conn, error) { func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, remotePwd string) (*Conn, error) {
err := a.ok() err := a.loop.Err()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -54,8 +54,8 @@ func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, re
// Block until pair selected // Block until pair selected
select { select {
case <-a.done: case <-a.loop.Done():
return nil, a.getErr() return nil, a.loop.Err()
case <-ctx.Done(): case <-ctx.Done():
return nil, ErrCanceledByCaller return nil, ErrCanceledByCaller
case <-a.onConnected: case <-a.onConnected:
@@ -68,7 +68,7 @@ func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, re
// Read implements the Conn Read method. // Read implements the Conn Read method.
func (c *Conn) Read(p []byte) (int, error) { func (c *Conn) Read(p []byte) (int, error) {
err := c.agent.ok() err := c.agent.loop.Err()
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -80,7 +80,7 @@ func (c *Conn) Read(p []byte) (int, error) {
// Write implements the Conn Write method. // Write implements the Conn Write method.
func (c *Conn) Write(p []byte) (int, error) { func (c *Conn) Write(p []byte) (int, error) {
err := c.agent.ok() err := c.agent.loop.Err()
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -91,8 +91,8 @@ func (c *Conn) Write(p []byte) (int, error) {
pair := c.agent.getSelectedPair() pair := c.agent.getSelectedPair()
if pair == nil { if pair == nil {
if err = c.agent.run(c.agent.context(), func(_ context.Context, a *Agent) { if err = c.agent.loop.Run(c.agent.loop, func(_ context.Context) {
pair = a.getBestValidCandidatePair() pair = c.agent.getBestValidCandidatePair()
}); err != nil { }); err != nil {
return 0, err return 0, err
} }

View File

@@ -49,8 +49,8 @@ func testTimeout(t *testing.T, c *Conn, timeout time.Duration) {
var cs ConnectionState var cs ConnectionState
err := c.agent.run(context.Background(), func(_ context.Context, agent *Agent) { err := c.agent.loop.Run(context.Background(), func(_ context.Context) {
cs = agent.connectionState cs = c.agent.connectionState
}) })
if err != nil { if err != nil {
// We should never get here. // We should never get here.