Mirror of https://github.com/pion/ice.git
Move taskloop into dedicated package
Reduce size of Agent and simplify code
committed by Sean DuBois
parent b36d33253b
commit fdca6c47c0
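The user-visible shape of the change: instead of funneling work through the Agent's own chanTask/run/taskLoop plumbing, the Agent now owns a *taskloop.Loop and submits closures to it. A minimal sketch of the calling convention, adapted from the SetRemoteCredentials hunk below (schematic, not a standalone program):

    // Before (removed in this commit): work goes through Agent.run, which
    // hands the Agent back to the callback.
    err := a.run(a.context(), func(_ context.Context, agent *Agent) {
        agent.remoteUfrag = remoteUfrag
        agent.remotePwd = remotePwd
    })

    // After: callbacks capture `a` directly, and the Loop itself implements
    // context.Context, so it is passed as the ctx argument.
    err := a.loop.Run(a.loop, func(_ context.Context) {
        a.remoteUfrag = remoteUfrag
        a.remotePwd = remotePwd
    })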
agent.go (181 lines changed)
@@ -15,8 +15,8 @@ import (
"sync/atomic"
"time"

atomicx "github.com/pion/ice/v3/internal/atomic"
stunx "github.com/pion/ice/v3/internal/stun"
"github.com/pion/ice/v3/internal/taskloop"
"github.com/pion/logging"
"github.com/pion/mdns/v2"
"github.com/pion/stun/v2"
@@ -36,13 +36,12 @@ type bindingRequest struct {

// Agent represents the ICE agent
type Agent struct {
chanTask chan task
loop *taskloop.Loop

onConnectionStateChangeHdlr atomic.Value // func(ConnectionState)
onSelectedCandidatePairChangeHdlr atomic.Value // func(Candidate, Candidate)
onCandidateHdlr atomic.Value // func(Candidate)

// State owned by the taskLoop
onConnected chan struct{}
onConnectedOnce sync.Once

@@ -118,11 +117,6 @@ type Agent struct {
// 1:1 D-NAT IP address mapping
extIPMapper *externalIPMapper

// State for closing
done chan struct{}
taskLoopDone chan struct{}
err atomicx.Error

gatherCandidateCancel func()
gatherCandidateDone chan struct{}
@@ -147,74 +141,6 @@ type Agent struct {
proxyDialer proxy.Dialer
}

type task struct {
fn func(context.Context, *Agent)
done chan struct{}
}

func (a *Agent) ok() error {
select {
case <-a.done:
return a.getErr()
default:
}
return nil
}

func (a *Agent) getErr() error {
if err := a.err.Load(); err != nil {
return err
}
return ErrClosed
}

// Run task in serial. Blocking tasks must be cancelable by context.
func (a *Agent) run(ctx context.Context, t func(context.Context, *Agent)) error {
if err := a.ok(); err != nil {
return err
}
done := make(chan struct{})
select {
case <-ctx.Done():
return ctx.Err()
case a.chanTask <- task{t, done}:
<-done
return nil
}
}

// taskLoop handles registered tasks and agent close.
func (a *Agent) taskLoop() {
defer func() {
a.deleteAllCandidates()
a.startedFn()

if err := a.buf.Close(); err != nil {
a.log.Warnf("Failed to close buffer: %v", err)
}

a.closeMulticastConn()
a.updateConnectionState(ConnectionStateClosed)

a.gatherCandidateCancel()
if a.gatherCandidateDone != nil {
<-a.gatherCandidateDone
}

close(a.taskLoopDone)
}()

for {
select {
case <-a.done:
return
case t := <-a.chanTask:
t.fn(a.context(), a)
close(t.done)
}
}
}

// NewAgent creates a new Agent
func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
var err error
@@ -247,7 +173,6 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
startedCtx, startedFn := context.WithCancel(context.Background())

a := &Agent{
chanTask: make(chan task),
tieBreaker: globalMathRandomGenerator.Uint64(),
lite: config.Lite,
gatheringState: GatheringStateNew,
@@ -258,8 +183,6 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
networkTypes: config.NetworkTypes,
onConnected: make(chan struct{}),
buf: packetio.NewBuffer(),
done: make(chan struct{}),
taskLoopDone: make(chan struct{}),
startedCh: startedCtx.Done(),
startedFn: startedFn,
portMin: config.PortMin,
@@ -333,7 +256,23 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
return nil, err
}

go a.taskLoop()
a.loop = taskloop.New(func() {
a.removeUfragFromMux()
a.deleteAllCandidates()
a.startedFn()

if err := a.buf.Close(); err != nil {
a.log.Warnf("Failed to close buffer: %v", err)
}

a.closeMulticastConn()
a.updateConnectionState(ConnectionStateClosed)

a.gatherCandidateCancel()
if a.gatherCandidateDone != nil {
<-a.gatherCandidateDone
}
})

// Restart is also used to initialize the agent for the first time
if err := a.Restart(config.LocalUfrag, config.LocalPwd); err != nil {
@@ -359,10 +298,10 @@ func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remoteP

a.log.Debugf("Started agent: isControlling? %t, remoteUfrag: %q, remotePwd: %q", isControlling, remoteUfrag, remotePwd)

return a.run(a.context(), func(_ context.Context, agent *Agent) {
agent.isControlling = isControlling
agent.remoteUfrag = remoteUfrag
agent.remotePwd = remotePwd
return a.loop.Run(a.loop, func(_ context.Context) {
a.isControlling = isControlling
a.remoteUfrag = remoteUfrag
a.remotePwd = remotePwd

if isControlling {
a.selector = &controllingSelector{agent: a, log: a.log}
@@ -377,7 +316,7 @@ func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remoteP
a.selector.Start()
a.startedFn()

agent.updateConnectionState(ConnectionStateChecking)
a.updateConnectionState(ConnectionStateChecking)

a.requestConnectivityCheck()
go a.connectivityChecks() //nolint:contextcheck
@@ -389,7 +328,7 @@ func (a *Agent) connectivityChecks() {
checkingDuration := time.Time{}

contact := func() {
if err := a.run(a.context(), func(_ context.Context, a *Agent) {
if err := a.loop.Run(a.loop, func(_ context.Context) {
defer func() {
lastConnectionState = a.connectionState
}()
@@ -446,7 +385,7 @@ func (a *Agent) connectivityChecks() {
contact()
case <-t.C:
contact()
case <-a.done:
case <-a.loop.Done():
t.Stop()
return
}
@@ -638,9 +577,9 @@ func (a *Agent) AddRemoteCandidate(c Candidate) error {
}

go func() {
if err := a.run(a.context(), func(_ context.Context, agent *Agent) {
if err := a.loop.Run(a.loop, func(_ context.Context) {
// nolint: contextcheck
agent.addRemoteCandidate(c)
a.addRemoteCandidate(c)
}); err != nil {
a.log.Warnf("Failed to add remote candidate %s: %v", c.Address(), err)
return
@@ -670,9 +609,9 @@ func (a *Agent) resolveAndAddMulticastCandidate(c *CandidateHost) {
return
}

if err = a.run(a.context(), func(_ context.Context, agent *Agent) {
if err = a.loop.Run(a.loop, func(_ context.Context) {
// nolint: contextcheck
agent.addRemoteCandidate(c)
a.addRemoteCandidate(c)
}); err != nil {
a.log.Warnf("Failed to add mDNS candidate %s: %v", c.Address(), err)
return
@@ -695,7 +634,7 @@ func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) {

for i := range localIPs {
conn := newActiveTCPConn(
a.context(),
a.loop,
net.JoinHostPort(localIPs[i].String(), "0"),
net.JoinHostPort(remoteCandidate.Address(), strconv.Itoa(remoteCandidate.Port())),
a.log,
@@ -763,7 +702,7 @@ func (a *Agent) addRemoteCandidate(c Candidate) {
}

func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net.PacketConn) error {
return a.run(ctx, func(context.Context, *Agent) {
return a.loop.Run(ctx, func(context.Context) {
set := a.localCandidates[c.NetworkType()]
for _, candidate := range set {
if candidate.Equal(c) {
@@ -799,9 +738,9 @@ func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net
func (a *Agent) GetRemoteCandidates() ([]Candidate, error) {
var res []Candidate

err := a.run(a.context(), func(_ context.Context, agent *Agent) {
err := a.loop.Run(a.loop, func(_ context.Context) {
var candidates []Candidate
for _, set := range agent.remoteCandidates {
for _, set := range a.remoteCandidates {
candidates = append(candidates, set...)
}
res = candidates
@@ -817,9 +756,9 @@ func (a *Agent) GetRemoteCandidates() ([]Candidate, error) {
func (a *Agent) GetLocalCandidates() ([]Candidate, error) {
var res []Candidate

err := a.run(a.context(), func(_ context.Context, agent *Agent) {
err := a.loop.Run(a.loop, func(_ context.Context) {
var candidates []Candidate
for _, set := range agent.localCandidates {
for _, set := range a.localCandidates {
candidates = append(candidates, set...)
}
res = candidates
@@ -834,9 +773,9 @@ func (a *Agent) GetLocalCandidates() ([]Candidate, error) {
// GetLocalUserCredentials returns the local user credentials
func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) {
valSet := make(chan struct{})
err = a.run(a.context(), func(_ context.Context, agent *Agent) {
frag = agent.localUfrag
pwd = agent.localPwd
err = a.loop.Run(a.loop, func(_ context.Context) {
frag = a.localUfrag
pwd = a.localPwd
close(valSet)
})

@@ -849,9 +788,9 @@ func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) {
// GetRemoteUserCredentials returns the remote user credentials
func (a *Agent) GetRemoteUserCredentials() (frag string, pwd string, err error) {
valSet := make(chan struct{})
err = a.run(a.context(), func(_ context.Context, agent *Agent) {
frag = agent.remoteUfrag
pwd = agent.remotePwd
err = a.loop.Run(a.loop, func(_ context.Context) {
frag = a.remoteUfrag
pwd = a.remotePwd
close(valSet)
})
@@ -875,17 +814,7 @@ func (a *Agent) removeUfragFromMux() {

// Close cleans up the Agent
func (a *Agent) Close() error {
if err := a.ok(); err != nil {
return err
}

a.err.Store(ErrClosed)

a.removeUfragFromMux()

close(a.done)
<-a.taskLoopDone
return nil
return a.loop.Close()
}

// Remove all candidates. This closes any listening sockets
@@ -1092,7 +1021,7 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
// and returns true if it is an actual remote candidate
func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) (Candidate, bool) {
var remoteCandidate Candidate
if err := a.run(local.context(), func(context.Context, *Agent) {
if err := a.loop.Run(local.context(), func(context.Context) {
remoteCandidate = a.findRemoteCandidate(local.NetworkType(), remote)
if remoteCandidate != nil {
remoteCandidate.seen(false)
@@ -1149,9 +1078,9 @@ func (a *Agent) SetRemoteCredentials(remoteUfrag, remotePwd string) error {
return ErrRemotePwdEmpty
}

return a.run(a.context(), func(_ context.Context, agent *Agent) {
agent.remoteUfrag = remoteUfrag
agent.remotePwd = remotePwd
return a.loop.Run(a.loop, func(_ context.Context) {
a.remoteUfrag = remoteUfrag
a.remotePwd = remotePwd
})
}

@@ -1186,17 +1115,17 @@ func (a *Agent) Restart(ufrag, pwd string) error {
}

var err error
if runErr := a.run(a.context(), func(_ context.Context, agent *Agent) {
if agent.gatheringState == GatheringStateGathering {
agent.gatherCandidateCancel()
if runErr := a.loop.Run(a.loop, func(_ context.Context) {
if a.gatheringState == GatheringStateGathering {
a.gatherCandidateCancel()
}

// Clear all agent needed to take back to fresh state
a.removeUfragFromMux()
agent.localUfrag = ufrag
agent.localPwd = pwd
agent.remoteUfrag = ""
agent.remotePwd = ""
a.localUfrag = ufrag
a.localPwd = pwd
a.remoteUfrag = ""
a.remotePwd = ""
a.gatheringState = GatheringStateNew
a.checklist = make([]*CandidatePair, 0)
a.pendingBindingRequests = make([]bindingRequest, 0)
@@ -1219,7 +1148,7 @@ func (a *Agent) Restart(ufrag, pwd string) error {

func (a *Agent) setGatheringState(newState GatheringState) error {
done := make(chan struct{})
if err := a.run(a.context(), func(context.Context, *Agent) {
if err := a.loop.Run(a.loop, func(context.Context) {
if a.gatheringState != newState && newState == GatheringStateComplete {
a.candidateNotifier.EnqueueCandidate(nil)
}
@@ -22,7 +22,7 @@ func TestOnSelectedCandidatePairChange(t *testing.T) {
})
require.NoError(t, err)

err = agent.run(context.Background(), func(_ context.Context, agent *Agent) {
err = agent.loop.Run(context.Background(), func(_ context.Context) {
agent.setSelectedPair(candidatePair)
})
require.NoError(t, err)
@@ -11,9 +11,9 @@ import (
// GetCandidatePairsStats returns a list of candidate pair stats
func (a *Agent) GetCandidatePairsStats() []CandidatePairStats {
var res []CandidatePairStats
err := a.run(a.context(), func(_ context.Context, agent *Agent) {
result := make([]CandidatePairStats, 0, len(agent.checklist))
for _, cp := range agent.checklist {
err := a.loop.Run(a.loop, func(_ context.Context) {
result := make([]CandidatePairStats, 0, len(a.checklist))
for _, cp := range a.checklist {
stat := CandidatePairStats{
Timestamp: time.Now(),
LocalCandidateID: cp.Local.ID(),
@@ -57,9 +57,9 @@ func (a *Agent) GetCandidatePairsStats() []CandidatePairStats {
// GetLocalCandidatesStats returns a list of local candidates stats
func (a *Agent) GetLocalCandidatesStats() []CandidateStats {
var res []CandidateStats
err := a.run(a.context(), func(_ context.Context, agent *Agent) {
result := make([]CandidateStats, 0, len(agent.localCandidates))
for networkType, localCandidates := range agent.localCandidates {
err := a.loop.Run(a.loop, func(_ context.Context) {
result := make([]CandidateStats, 0, len(a.localCandidates))
for networkType, localCandidates := range a.localCandidates {
for _, c := range localCandidates {
relayProtocol := ""
if c.Type() == CandidateTypeRelay {
@@ -94,9 +94,9 @@ func (a *Agent) GetLocalCandidatesStats() []CandidateStats {
// GetRemoteCandidatesStats returns a list of remote candidates stats
func (a *Agent) GetRemoteCandidatesStats() []CandidateStats {
var res []CandidateStats
err := a.run(a.context(), func(_ context.Context, agent *Agent) {
result := make([]CandidateStats, 0, len(agent.remoteCandidates))
for networkType, remoteCandidates := range agent.remoteCandidates {
err := a.loop.Run(a.loop, func(_ context.Context) {
result := make([]CandidateStats, 0, len(a.remoteCandidates))
for networkType, remoteCandidates := range a.remoteCandidates {
for _, c := range remoteCandidates {
stat := CandidateStats{
Timestamp: time.Now(),
@@ -34,19 +34,6 @@ func (ba *BadAddr) String() string {
return "yyy"
}

func runAgentTest(t *testing.T, config *AgentConfig, task func(ctx context.Context, a *Agent)) {
a, err := NewAgent(config)
if err != nil {
t.Fatalf("Error constructing ice.Agent")
}

if err := a.run(context.Background(), task); err != nil {
t.Fatalf("Agent run failure: %v", err)
}

assert.NoError(t, a.Close())
}

func TestHandlePeerReflexive(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
@@ -56,8 +43,10 @@ func TestHandlePeerReflexive(t *testing.T) {
defer lim.Stop()

t.Run("UDP prflx candidate from handleInbound()", func(t *testing.T) {
var config AgentConfig
runAgentTest(t, &config, func(_ context.Context, a *Agent) {
a, err := NewAgent(&AgentConfig{})
assert.NoError(t, err)

assert.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}

hostConfig := CandidateHostConfig{
@@ -113,12 +102,15 @@ func TestHandlePeerReflexive(t *testing.T) {
if c.Port() != 999 {
t.Fatal("Port number mismatch")
}
})
}))
assert.NoError(t, a.Close())
})

t.Run("Bad network type with handleInbound()", func(t *testing.T) {
var config AgentConfig
runAgentTest(t, &config, func(_ context.Context, a *Agent) {
a, err := NewAgent(&AgentConfig{})
assert.NoError(t, err)

assert.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}

hostConfig := CandidateHostConfig{
@@ -140,12 +132,16 @@ func TestHandlePeerReflexive(t *testing.T) {
if len(a.remoteCandidates) != 0 {
t.Fatal("bad address should not be added to the remote candidate list")
}
})
}))

assert.NoError(t, a.Close())
})

t.Run("Success from unknown remote, prflx candidate MUST only be created via Binding Request", func(t *testing.T) {
var config AgentConfig
runAgentTest(t, &config, func(_ context.Context, a *Agent) {
a, err := NewAgent(&AgentConfig{})
assert.NoError(t, err)

assert.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}
tID := [stun.TransactionIDSize]byte{}
copy(tID[:], "ABC")
@@ -179,7 +175,9 @@ func TestHandlePeerReflexive(t *testing.T) {
if len(a.remoteCandidates) != 0 {
t.Fatal("unknown remote was able to create a candidate")
}
})
}))

assert.NoError(t, a.Close())
})
}
@@ -440,7 +438,7 @@ func TestInboundValidity(t *testing.T) {
t.Fatalf("Error constructing ice.Agent")
}

err = a.run(context.Background(), func(_ context.Context, a *Agent) {
err = a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}
// nolint: contextcheck
a.handleInbound(buildMsg(stun.ClassRequest, a.localUfrag+":"+a.remoteUfrag, a.localPwd), local, remote)
@@ -454,8 +452,10 @@ func TestInboundValidity(t *testing.T) {
})

t.Run("Valid bind without fingerprint", func(t *testing.T) {
var config AgentConfig
runAgentTest(t, &config, func(_ context.Context, a *Agent) {
a, err := NewAgent(&AgentConfig{})
assert.NoError(t, err)

assert.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}
msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
stun.NewUsername(a.localUfrag+":"+a.remoteUfrag),
@@ -470,7 +470,9 @@ func TestInboundValidity(t *testing.T) {
if len(a.remoteCandidates) != 1 {
t.Fatal("Binding with valid values (but no fingerprint) was unable to create prflx candidate")
}
})
}))

assert.NoError(t, a.Close())
})

t.Run("Success with invalid TransactionID", func(t *testing.T) {
@@ -1120,7 +1122,7 @@ func TestConnectionStateFailedDeleteAllCandidates(t *testing.T) {
<-isFailed

done := make(chan struct{})
assert.NoError(t, aAgent.run(context.Background(), func(context.Context, *Agent) {
assert.NoError(t, aAgent.loop.Run(aAgent.loop, func(context.Context) {
assert.Equal(t, len(aAgent.remoteCandidates), 0)
assert.Equal(t, len(aAgent.localCandidates), 0)
close(done)
@@ -267,7 +267,7 @@ func (c *candidateBase) handleInboundPacket(buf []byte, srcAddr net.Addr) {
return
}

if err := a.run(c, func(_ context.Context, a *Agent) {
if err := a.loop.Run(c, func(_ context.Context) {
// nolint: contextcheck
a.handleInbound(m, c, srcAddr)
}); err != nil {
context.go (deleted, 40 lines)
@@ -1,40 +0,0 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package ice

import (
"context"
"time"
)

func (a *Agent) context() context.Context {
return agentContext(a.done)
}

type agentContext chan struct{}

// Done implements context.Context
func (a agentContext) Done() <-chan struct{} {
return (chan struct{})(a)
}

// Err implements context.Context
func (a agentContext) Err() error {
select {
case <-(chan struct{})(a):
return ErrRunCanceled
default:
return nil
}
}

// Deadline implements context.Context
func (a agentContext) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}

// Value implements context.Context
func (a agentContext) Value(interface{}) interface{} {
return nil
}
@@ -42,7 +42,7 @@ func closeConnAndLog(c io.Closer, log logging.LeveledLogger, msg string, args ..
func (a *Agent) GatherCandidates() error {
var gatherErr error

if runErr := a.run(a.context(), func(ctx context.Context, _ *Agent) {
if runErr := a.loop.Run(a.loop, func(ctx context.Context) {
if a.gatheringState != GatheringStateNew {
gatherErr = ErrMultipleGatherAttempted
return
@@ -495,7 +495,7 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
select {
case <-cancelCtx.Done():
return
case <-a.done:
case <-a.loop.Done():
_ = conn.Close()
}
}()
internal/taskloop/taskloop.go (new file, 121 lines)
@@ -0,0 +1,121 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

// Package taskloop implements a task loop to run
// tasks sequentially in a separate Goroutine.
package taskloop

import (
"context"
"errors"
"time"

atomicx "github.com/pion/ice/v3/internal/atomic"
)

// errClosed indicates that the loop has been stopped
var errClosed = errors.New("the agent is closed")

type task struct {
fn func(context.Context)
done chan struct{}
}

// Loop runs submitted task serially in a dedicated Goroutine
type Loop struct {
tasks chan task

// State for closing
done chan struct{}
taskLoopDone chan struct{}
err atomicx.Error
}

// New creates and starts a new task loop
func New(onClose func()) *Loop {
l := &Loop{
tasks: make(chan task),
done: make(chan struct{}),
taskLoopDone: make(chan struct{}),
}

go l.runLoop(onClose)
return l
}

// runLoop handles registered tasks and agent close.
func (l *Loop) runLoop(onClose func()) {
defer func() {
onClose()
close(l.taskLoopDone)
}()

for {
select {
case <-l.done:
return
case t := <-l.tasks:
t.fn(l)
close(t.done)
}
}
}

// Close stops the loop after finishing the execution of the current task.
// Other pending tasks will not be executed.
func (l *Loop) Close() error {
if err := l.Err(); err != nil {
return err
}

l.err.Store(errClosed)

close(l.done)
<-l.taskLoopDone

return nil
}

// Run serially executes the submitted callback.
// Blocking tasks must be cancelable by context.
func (l *Loop) Run(ctx context.Context, t func(context.Context)) error {
if err := l.Err(); err != nil {
return err
}
done := make(chan struct{})
select {
case <-ctx.Done():
return ctx.Err()
case l.tasks <- task{t, done}:
<-done
return nil
}
}

// The following methods implement context.Context for TaskLoop

// Done returns a channel that's closed when the task loop has been stopped.
func (l *Loop) Done() <-chan struct{} {
return l.done
}

// Err returns nil if the task loop is still running.
// Otherwise it return errClosed if the loop has been closed/stopped.
func (l *Loop) Err() error {
select {
case <-l.done:
return errClosed
default:
return nil
}
}

// Deadline returns the no valid time as task loops have no deadline.
func (l *Loop) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}

// Value is not supported for task loops
func (l *Loop) Value(interface{}) interface{} {
return nil
}
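For reference, a minimal usage sketch of the new Loop API. This is an assumed example test, not part of the commit: since the package lives under internal/, it would only compile inside the pion/ice module, and the file and function names here are illustrative.

    package taskloop_test

    import (
        "context"
        "fmt"
        "time"

        "github.com/pion/ice/v3/internal/taskloop"
    )

    func Example() {
        // onClose runs once, on the loop goroutine, after Close() stops the loop.
        loop := taskloop.New(func() { fmt.Println("cleanup ran") })

        // Tasks execute serially on the loop goroutine. The Loop satisfies
        // context.Context, so it can be passed as ctx; Run then fails fast
        // once the loop has been closed.
        shared := 0
        if err := loop.Run(loop, func(_ context.Context) {
            shared++ // only ever touched from the loop goroutine
        }); err != nil {
            fmt.Println("run failed:", err)
        }

        // A caller-supplied context bounds how long Run may block while the
        // task waits to be picked up.
        ctx, cancel := context.WithTimeout(context.Background(), time.Second)
        defer cancel()
        _ = loop.Run(ctx, func(_ context.Context) { shared++ })

        fmt.Println(shared)

        // Close stops the loop; afterwards Err reports the closed error.
        _ = loop.Close()
        fmt.Println(loop.Err() != nil)

        // Output:
        // 2
        // cleanup ran
        // true
    }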
transport.go (14 lines changed)
@@ -43,7 +43,7 @@ func (c *Conn) BytesReceived() uint64 {
}

func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, remotePwd string) (*Conn, error) {
err := a.ok()
err := a.loop.Err()
if err != nil {
return nil, err
}
@@ -54,8 +54,8 @@ func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, re

// Block until pair selected
select {
case <-a.done:
return nil, a.getErr()
case <-a.loop.Done():
return nil, a.loop.Err()
case <-ctx.Done():
return nil, ErrCanceledByCaller
case <-a.onConnected:
@@ -68,7 +68,7 @@ func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, re

// Read implements the Conn Read method.
func (c *Conn) Read(p []byte) (int, error) {
err := c.agent.ok()
err := c.agent.loop.Err()
if err != nil {
return 0, err
}
@@ -80,7 +80,7 @@ func (c *Conn) Read(p []byte) (int, error) {

// Write implements the Conn Write method.
func (c *Conn) Write(p []byte) (int, error) {
err := c.agent.ok()
err := c.agent.loop.Err()
if err != nil {
return 0, err
}
@@ -91,8 +91,8 @@ func (c *Conn) Write(p []byte) (int, error) {

pair := c.agent.getSelectedPair()
if pair == nil {
if err = c.agent.run(c.agent.context(), func(_ context.Context, a *Agent) {
pair = a.getBestValidCandidatePair()
if err = c.agent.loop.Run(c.agent.loop, func(_ context.Context) {
pair = c.agent.getBestValidCandidatePair()
}); err != nil {
return 0, err
}
@@ -49,8 +49,8 @@ func testTimeout(t *testing.T, c *Conn, timeout time.Duration) {

var cs ConnectionState

err := c.agent.run(context.Background(), func(_ context.Context, agent *Agent) {
cs = agent.connectionState
err := c.agent.loop.Run(context.Background(), func(_ context.Context) {
cs = c.agent.connectionState
})
if err != nil {
// We should never get here.