Move taskloop into dedicated package

Reduce size of Agent and simplify code
This commit is contained in:
Steffen Vogel
2024-03-21 11:47:51 -04:00
committed by Sean DuBois
parent b36d33253b
commit fdca6c47c0
10 changed files with 227 additions and 215 deletions

181
agent.go
View File

@@ -15,8 +15,8 @@ import (
"sync/atomic"
"time"
atomicx "github.com/pion/ice/v3/internal/atomic"
stunx "github.com/pion/ice/v3/internal/stun"
"github.com/pion/ice/v3/internal/taskloop"
"github.com/pion/logging"
"github.com/pion/mdns/v2"
"github.com/pion/stun/v2"
@@ -36,13 +36,12 @@ type bindingRequest struct {
// Agent represents the ICE agent
type Agent struct {
chanTask chan task
loop *taskloop.Loop
onConnectionStateChangeHdlr atomic.Value // func(ConnectionState)
onSelectedCandidatePairChangeHdlr atomic.Value // func(Candidate, Candidate)
onCandidateHdlr atomic.Value // func(Candidate)
// State owned by the taskLoop
onConnected chan struct{}
onConnectedOnce sync.Once
@@ -118,11 +117,6 @@ type Agent struct {
// 1:1 D-NAT IP address mapping
extIPMapper *externalIPMapper
// State for closing
done chan struct{}
taskLoopDone chan struct{}
err atomicx.Error
gatherCandidateCancel func()
gatherCandidateDone chan struct{}
@@ -147,74 +141,6 @@ type Agent struct {
proxyDialer proxy.Dialer
}
type task struct {
fn func(context.Context, *Agent)
done chan struct{}
}
func (a *Agent) ok() error {
select {
case <-a.done:
return a.getErr()
default:
}
return nil
}
func (a *Agent) getErr() error {
if err := a.err.Load(); err != nil {
return err
}
return ErrClosed
}
// Run task in serial. Blocking tasks must be cancelable by context.
func (a *Agent) run(ctx context.Context, t func(context.Context, *Agent)) error {
if err := a.ok(); err != nil {
return err
}
done := make(chan struct{})
select {
case <-ctx.Done():
return ctx.Err()
case a.chanTask <- task{t, done}:
<-done
return nil
}
}
// taskLoop handles registered tasks and agent close.
func (a *Agent) taskLoop() {
defer func() {
a.deleteAllCandidates()
a.startedFn()
if err := a.buf.Close(); err != nil {
a.log.Warnf("Failed to close buffer: %v", err)
}
a.closeMulticastConn()
a.updateConnectionState(ConnectionStateClosed)
a.gatherCandidateCancel()
if a.gatherCandidateDone != nil {
<-a.gatherCandidateDone
}
close(a.taskLoopDone)
}()
for {
select {
case <-a.done:
return
case t := <-a.chanTask:
t.fn(a.context(), a)
close(t.done)
}
}
}
// NewAgent creates a new Agent
func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
var err error
@@ -247,7 +173,6 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
startedCtx, startedFn := context.WithCancel(context.Background())
a := &Agent{
chanTask: make(chan task),
tieBreaker: globalMathRandomGenerator.Uint64(),
lite: config.Lite,
gatheringState: GatheringStateNew,
@@ -258,8 +183,6 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
networkTypes: config.NetworkTypes,
onConnected: make(chan struct{}),
buf: packetio.NewBuffer(),
done: make(chan struct{}),
taskLoopDone: make(chan struct{}),
startedCh: startedCtx.Done(),
startedFn: startedFn,
portMin: config.PortMin,
@@ -333,7 +256,23 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
return nil, err
}
go a.taskLoop()
a.loop = taskloop.New(func() {
a.removeUfragFromMux()
a.deleteAllCandidates()
a.startedFn()
if err := a.buf.Close(); err != nil {
a.log.Warnf("Failed to close buffer: %v", err)
}
a.closeMulticastConn()
a.updateConnectionState(ConnectionStateClosed)
a.gatherCandidateCancel()
if a.gatherCandidateDone != nil {
<-a.gatherCandidateDone
}
})
// Restart is also used to initialize the agent for the first time
if err := a.Restart(config.LocalUfrag, config.LocalPwd); err != nil {
@@ -359,10 +298,10 @@ func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remoteP
a.log.Debugf("Started agent: isControlling? %t, remoteUfrag: %q, remotePwd: %q", isControlling, remoteUfrag, remotePwd)
return a.run(a.context(), func(_ context.Context, agent *Agent) {
agent.isControlling = isControlling
agent.remoteUfrag = remoteUfrag
agent.remotePwd = remotePwd
return a.loop.Run(a.loop, func(_ context.Context) {
a.isControlling = isControlling
a.remoteUfrag = remoteUfrag
a.remotePwd = remotePwd
if isControlling {
a.selector = &controllingSelector{agent: a, log: a.log}
@@ -377,7 +316,7 @@ func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remoteP
a.selector.Start()
a.startedFn()
agent.updateConnectionState(ConnectionStateChecking)
a.updateConnectionState(ConnectionStateChecking)
a.requestConnectivityCheck()
go a.connectivityChecks() //nolint:contextcheck
@@ -389,7 +328,7 @@ func (a *Agent) connectivityChecks() {
checkingDuration := time.Time{}
contact := func() {
if err := a.run(a.context(), func(_ context.Context, a *Agent) {
if err := a.loop.Run(a.loop, func(_ context.Context) {
defer func() {
lastConnectionState = a.connectionState
}()
@@ -446,7 +385,7 @@ func (a *Agent) connectivityChecks() {
contact()
case <-t.C:
contact()
case <-a.done:
case <-a.loop.Done():
t.Stop()
return
}
@@ -638,9 +577,9 @@ func (a *Agent) AddRemoteCandidate(c Candidate) error {
}
go func() {
if err := a.run(a.context(), func(_ context.Context, agent *Agent) {
if err := a.loop.Run(a.loop, func(_ context.Context) {
// nolint: contextcheck
agent.addRemoteCandidate(c)
a.addRemoteCandidate(c)
}); err != nil {
a.log.Warnf("Failed to add remote candidate %s: %v", c.Address(), err)
return
@@ -670,9 +609,9 @@ func (a *Agent) resolveAndAddMulticastCandidate(c *CandidateHost) {
return
}
if err = a.run(a.context(), func(_ context.Context, agent *Agent) {
if err = a.loop.Run(a.loop, func(_ context.Context) {
// nolint: contextcheck
agent.addRemoteCandidate(c)
a.addRemoteCandidate(c)
}); err != nil {
a.log.Warnf("Failed to add mDNS candidate %s: %v", c.Address(), err)
return
@@ -695,7 +634,7 @@ func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) {
for i := range localIPs {
conn := newActiveTCPConn(
a.context(),
a.loop,
net.JoinHostPort(localIPs[i].String(), "0"),
net.JoinHostPort(remoteCandidate.Address(), strconv.Itoa(remoteCandidate.Port())),
a.log,
@@ -763,7 +702,7 @@ func (a *Agent) addRemoteCandidate(c Candidate) {
}
func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net.PacketConn) error {
return a.run(ctx, func(context.Context, *Agent) {
return a.loop.Run(ctx, func(context.Context) {
set := a.localCandidates[c.NetworkType()]
for _, candidate := range set {
if candidate.Equal(c) {
@@ -799,9 +738,9 @@ func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net
func (a *Agent) GetRemoteCandidates() ([]Candidate, error) {
var res []Candidate
err := a.run(a.context(), func(_ context.Context, agent *Agent) {
err := a.loop.Run(a.loop, func(_ context.Context) {
var candidates []Candidate
for _, set := range agent.remoteCandidates {
for _, set := range a.remoteCandidates {
candidates = append(candidates, set...)
}
res = candidates
@@ -817,9 +756,9 @@ func (a *Agent) GetRemoteCandidates() ([]Candidate, error) {
func (a *Agent) GetLocalCandidates() ([]Candidate, error) {
var res []Candidate
err := a.run(a.context(), func(_ context.Context, agent *Agent) {
err := a.loop.Run(a.loop, func(_ context.Context) {
var candidates []Candidate
for _, set := range agent.localCandidates {
for _, set := range a.localCandidates {
candidates = append(candidates, set...)
}
res = candidates
@@ -834,9 +773,9 @@ func (a *Agent) GetLocalCandidates() ([]Candidate, error) {
// GetLocalUserCredentials returns the local user credentials
func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) {
valSet := make(chan struct{})
err = a.run(a.context(), func(_ context.Context, agent *Agent) {
frag = agent.localUfrag
pwd = agent.localPwd
err = a.loop.Run(a.loop, func(_ context.Context) {
frag = a.localUfrag
pwd = a.localPwd
close(valSet)
})
@@ -849,9 +788,9 @@ func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) {
// GetRemoteUserCredentials returns the remote user credentials
func (a *Agent) GetRemoteUserCredentials() (frag string, pwd string, err error) {
valSet := make(chan struct{})
err = a.run(a.context(), func(_ context.Context, agent *Agent) {
frag = agent.remoteUfrag
pwd = agent.remotePwd
err = a.loop.Run(a.loop, func(_ context.Context) {
frag = a.remoteUfrag
pwd = a.remotePwd
close(valSet)
})
@@ -875,17 +814,7 @@ func (a *Agent) removeUfragFromMux() {
// Close cleans up the Agent
func (a *Agent) Close() error {
if err := a.ok(); err != nil {
return err
}
a.err.Store(ErrClosed)
a.removeUfragFromMux()
close(a.done)
<-a.taskLoopDone
return nil
return a.loop.Close()
}
// Remove all candidates. This closes any listening sockets
@@ -1092,7 +1021,7 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
// and returns true if it is an actual remote candidate
func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) (Candidate, bool) {
var remoteCandidate Candidate
if err := a.run(local.context(), func(context.Context, *Agent) {
if err := a.loop.Run(local.context(), func(context.Context) {
remoteCandidate = a.findRemoteCandidate(local.NetworkType(), remote)
if remoteCandidate != nil {
remoteCandidate.seen(false)
@@ -1149,9 +1078,9 @@ func (a *Agent) SetRemoteCredentials(remoteUfrag, remotePwd string) error {
return ErrRemotePwdEmpty
}
return a.run(a.context(), func(_ context.Context, agent *Agent) {
agent.remoteUfrag = remoteUfrag
agent.remotePwd = remotePwd
return a.loop.Run(a.loop, func(_ context.Context) {
a.remoteUfrag = remoteUfrag
a.remotePwd = remotePwd
})
}
@@ -1186,17 +1115,17 @@ func (a *Agent) Restart(ufrag, pwd string) error {
}
var err error
if runErr := a.run(a.context(), func(_ context.Context, agent *Agent) {
if agent.gatheringState == GatheringStateGathering {
agent.gatherCandidateCancel()
if runErr := a.loop.Run(a.loop, func(_ context.Context) {
if a.gatheringState == GatheringStateGathering {
a.gatherCandidateCancel()
}
// Clear all agent needed to take back to fresh state
a.removeUfragFromMux()
agent.localUfrag = ufrag
agent.localPwd = pwd
agent.remoteUfrag = ""
agent.remotePwd = ""
a.localUfrag = ufrag
a.localPwd = pwd
a.remoteUfrag = ""
a.remotePwd = ""
a.gatheringState = GatheringStateNew
a.checklist = make([]*CandidatePair, 0)
a.pendingBindingRequests = make([]bindingRequest, 0)
@@ -1219,7 +1148,7 @@ func (a *Agent) Restart(ufrag, pwd string) error {
func (a *Agent) setGatheringState(newState GatheringState) error {
done := make(chan struct{})
if err := a.run(a.context(), func(context.Context, *Agent) {
if err := a.loop.Run(a.loop, func(context.Context) {
if a.gatheringState != newState && newState == GatheringStateComplete {
a.candidateNotifier.EnqueueCandidate(nil)
}

View File

@@ -22,7 +22,7 @@ func TestOnSelectedCandidatePairChange(t *testing.T) {
})
require.NoError(t, err)
err = agent.run(context.Background(), func(_ context.Context, agent *Agent) {
err = agent.loop.Run(context.Background(), func(_ context.Context) {
agent.setSelectedPair(candidatePair)
})
require.NoError(t, err)

View File

@@ -11,9 +11,9 @@ import (
// GetCandidatePairsStats returns a list of candidate pair stats
func (a *Agent) GetCandidatePairsStats() []CandidatePairStats {
var res []CandidatePairStats
err := a.run(a.context(), func(_ context.Context, agent *Agent) {
result := make([]CandidatePairStats, 0, len(agent.checklist))
for _, cp := range agent.checklist {
err := a.loop.Run(a.loop, func(_ context.Context) {
result := make([]CandidatePairStats, 0, len(a.checklist))
for _, cp := range a.checklist {
stat := CandidatePairStats{
Timestamp: time.Now(),
LocalCandidateID: cp.Local.ID(),
@@ -57,9 +57,9 @@ func (a *Agent) GetCandidatePairsStats() []CandidatePairStats {
// GetLocalCandidatesStats returns a list of local candidates stats
func (a *Agent) GetLocalCandidatesStats() []CandidateStats {
var res []CandidateStats
err := a.run(a.context(), func(_ context.Context, agent *Agent) {
result := make([]CandidateStats, 0, len(agent.localCandidates))
for networkType, localCandidates := range agent.localCandidates {
err := a.loop.Run(a.loop, func(_ context.Context) {
result := make([]CandidateStats, 0, len(a.localCandidates))
for networkType, localCandidates := range a.localCandidates {
for _, c := range localCandidates {
relayProtocol := ""
if c.Type() == CandidateTypeRelay {
@@ -94,9 +94,9 @@ func (a *Agent) GetLocalCandidatesStats() []CandidateStats {
// GetRemoteCandidatesStats returns a list of remote candidates stats
func (a *Agent) GetRemoteCandidatesStats() []CandidateStats {
var res []CandidateStats
err := a.run(a.context(), func(_ context.Context, agent *Agent) {
result := make([]CandidateStats, 0, len(agent.remoteCandidates))
for networkType, remoteCandidates := range agent.remoteCandidates {
err := a.loop.Run(a.loop, func(_ context.Context) {
result := make([]CandidateStats, 0, len(a.remoteCandidates))
for networkType, remoteCandidates := range a.remoteCandidates {
for _, c := range remoteCandidates {
stat := CandidateStats{
Timestamp: time.Now(),

View File

@@ -34,19 +34,6 @@ func (ba *BadAddr) String() string {
return "yyy"
}
func runAgentTest(t *testing.T, config *AgentConfig, task func(ctx context.Context, a *Agent)) {
a, err := NewAgent(config)
if err != nil {
t.Fatalf("Error constructing ice.Agent")
}
if err := a.run(context.Background(), task); err != nil {
t.Fatalf("Agent run failure: %v", err)
}
assert.NoError(t, a.Close())
}
func TestHandlePeerReflexive(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
@@ -56,8 +43,10 @@ func TestHandlePeerReflexive(t *testing.T) {
defer lim.Stop()
t.Run("UDP prflx candidate from handleInbound()", func(t *testing.T) {
var config AgentConfig
runAgentTest(t, &config, func(_ context.Context, a *Agent) {
a, err := NewAgent(&AgentConfig{})
assert.NoError(t, err)
assert.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}
hostConfig := CandidateHostConfig{
@@ -113,12 +102,15 @@ func TestHandlePeerReflexive(t *testing.T) {
if c.Port() != 999 {
t.Fatal("Port number mismatch")
}
})
}))
assert.NoError(t, a.Close())
})
t.Run("Bad network type with handleInbound()", func(t *testing.T) {
var config AgentConfig
runAgentTest(t, &config, func(_ context.Context, a *Agent) {
a, err := NewAgent(&AgentConfig{})
assert.NoError(t, err)
assert.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}
hostConfig := CandidateHostConfig{
@@ -140,12 +132,16 @@ func TestHandlePeerReflexive(t *testing.T) {
if len(a.remoteCandidates) != 0 {
t.Fatal("bad address should not be added to the remote candidate list")
}
})
}))
assert.NoError(t, a.Close())
})
t.Run("Success from unknown remote, prflx candidate MUST only be created via Binding Request", func(t *testing.T) {
var config AgentConfig
runAgentTest(t, &config, func(_ context.Context, a *Agent) {
a, err := NewAgent(&AgentConfig{})
assert.NoError(t, err)
assert.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}
tID := [stun.TransactionIDSize]byte{}
copy(tID[:], "ABC")
@@ -179,7 +175,9 @@ func TestHandlePeerReflexive(t *testing.T) {
if len(a.remoteCandidates) != 0 {
t.Fatal("unknown remote was able to create a candidate")
}
})
}))
assert.NoError(t, a.Close())
})
}
@@ -440,7 +438,7 @@ func TestInboundValidity(t *testing.T) {
t.Fatalf("Error constructing ice.Agent")
}
err = a.run(context.Background(), func(_ context.Context, a *Agent) {
err = a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}
// nolint: contextcheck
a.handleInbound(buildMsg(stun.ClassRequest, a.localUfrag+":"+a.remoteUfrag, a.localPwd), local, remote)
@@ -454,8 +452,10 @@ func TestInboundValidity(t *testing.T) {
})
t.Run("Valid bind without fingerprint", func(t *testing.T) {
var config AgentConfig
runAgentTest(t, &config, func(_ context.Context, a *Agent) {
a, err := NewAgent(&AgentConfig{})
assert.NoError(t, err)
assert.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}
msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
stun.NewUsername(a.localUfrag+":"+a.remoteUfrag),
@@ -470,7 +470,9 @@ func TestInboundValidity(t *testing.T) {
if len(a.remoteCandidates) != 1 {
t.Fatal("Binding with valid values (but no fingerprint) was unable to create prflx candidate")
}
})
}))
assert.NoError(t, a.Close())
})
t.Run("Success with invalid TransactionID", func(t *testing.T) {
@@ -1120,7 +1122,7 @@ func TestConnectionStateFailedDeleteAllCandidates(t *testing.T) {
<-isFailed
done := make(chan struct{})
assert.NoError(t, aAgent.run(context.Background(), func(context.Context, *Agent) {
assert.NoError(t, aAgent.loop.Run(aAgent.loop, func(context.Context) {
assert.Equal(t, len(aAgent.remoteCandidates), 0)
assert.Equal(t, len(aAgent.localCandidates), 0)
close(done)

View File

@@ -267,7 +267,7 @@ func (c *candidateBase) handleInboundPacket(buf []byte, srcAddr net.Addr) {
return
}
if err := a.run(c, func(_ context.Context, a *Agent) {
if err := a.loop.Run(c, func(_ context.Context) {
// nolint: contextcheck
a.handleInbound(m, c, srcAddr)
}); err != nil {

View File

@@ -1,40 +0,0 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package ice
import (
"context"
"time"
)
func (a *Agent) context() context.Context {
return agentContext(a.done)
}
type agentContext chan struct{}
// Done implements context.Context
func (a agentContext) Done() <-chan struct{} {
return (chan struct{})(a)
}
// Err implements context.Context
func (a agentContext) Err() error {
select {
case <-(chan struct{})(a):
return ErrRunCanceled
default:
return nil
}
}
// Deadline implements context.Context
func (a agentContext) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}
// Value implements context.Context
func (a agentContext) Value(interface{}) interface{} {
return nil
}

View File

@@ -42,7 +42,7 @@ func closeConnAndLog(c io.Closer, log logging.LeveledLogger, msg string, args ..
func (a *Agent) GatherCandidates() error {
var gatherErr error
if runErr := a.run(a.context(), func(ctx context.Context, _ *Agent) {
if runErr := a.loop.Run(a.loop, func(ctx context.Context) {
if a.gatheringState != GatheringStateNew {
gatherErr = ErrMultipleGatherAttempted
return
@@ -495,7 +495,7 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
select {
case <-cancelCtx.Done():
return
case <-a.done:
case <-a.loop.Done():
_ = conn.Close()
}
}()

View File

@@ -0,0 +1,121 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
// Package taskloop implements a task loop to run
// tasks sequentially in a separate Goroutine.
package taskloop
import (
"context"
"errors"
"time"
atomicx "github.com/pion/ice/v3/internal/atomic"
)
// errClosed indicates that the loop has been stopped
var errClosed = errors.New("the agent is closed")
type task struct {
fn func(context.Context)
done chan struct{}
}
// Loop runs submitted task serially in a dedicated Goroutine
type Loop struct {
tasks chan task
// State for closing
done chan struct{}
taskLoopDone chan struct{}
err atomicx.Error
}
// New creates and starts a new task loop
func New(onClose func()) *Loop {
l := &Loop{
tasks: make(chan task),
done: make(chan struct{}),
taskLoopDone: make(chan struct{}),
}
go l.runLoop(onClose)
return l
}
// runLoop handles registered tasks and agent close.
func (l *Loop) runLoop(onClose func()) {
defer func() {
onClose()
close(l.taskLoopDone)
}()
for {
select {
case <-l.done:
return
case t := <-l.tasks:
t.fn(l)
close(t.done)
}
}
}
// Close stops the loop after finishing the execution of the current task.
// Other pending tasks will not be executed.
func (l *Loop) Close() error {
if err := l.Err(); err != nil {
return err
}
l.err.Store(errClosed)
close(l.done)
<-l.taskLoopDone
return nil
}
// Run serially executes the submitted callback.
// Blocking tasks must be cancelable by context.
func (l *Loop) Run(ctx context.Context, t func(context.Context)) error {
if err := l.Err(); err != nil {
return err
}
done := make(chan struct{})
select {
case <-ctx.Done():
return ctx.Err()
case l.tasks <- task{t, done}:
<-done
return nil
}
}
// The following methods implement context.Context for TaskLoop
// Done returns a channel that's closed when the task loop has been stopped.
func (l *Loop) Done() <-chan struct{} {
return l.done
}
// Err returns nil if the task loop is still running.
// Otherwise it return errClosed if the loop has been closed/stopped.
func (l *Loop) Err() error {
select {
case <-l.done:
return errClosed
default:
return nil
}
}
// Deadline returns the no valid time as task loops have no deadline.
func (l *Loop) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}
// Value is not supported for task loops
func (l *Loop) Value(interface{}) interface{} {
return nil
}

View File

@@ -43,7 +43,7 @@ func (c *Conn) BytesReceived() uint64 {
}
func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, remotePwd string) (*Conn, error) {
err := a.ok()
err := a.loop.Err()
if err != nil {
return nil, err
}
@@ -54,8 +54,8 @@ func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, re
// Block until pair selected
select {
case <-a.done:
return nil, a.getErr()
case <-a.loop.Done():
return nil, a.loop.Err()
case <-ctx.Done():
return nil, ErrCanceledByCaller
case <-a.onConnected:
@@ -68,7 +68,7 @@ func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, re
// Read implements the Conn Read method.
func (c *Conn) Read(p []byte) (int, error) {
err := c.agent.ok()
err := c.agent.loop.Err()
if err != nil {
return 0, err
}
@@ -80,7 +80,7 @@ func (c *Conn) Read(p []byte) (int, error) {
// Write implements the Conn Write method.
func (c *Conn) Write(p []byte) (int, error) {
err := c.agent.ok()
err := c.agent.loop.Err()
if err != nil {
return 0, err
}
@@ -91,8 +91,8 @@ func (c *Conn) Write(p []byte) (int, error) {
pair := c.agent.getSelectedPair()
if pair == nil {
if err = c.agent.run(c.agent.context(), func(_ context.Context, a *Agent) {
pair = a.getBestValidCandidatePair()
if err = c.agent.loop.Run(c.agent.loop, func(_ context.Context) {
pair = c.agent.getBestValidCandidatePair()
}); err != nil {
return 0, err
}

View File

@@ -49,8 +49,8 @@ func testTimeout(t *testing.T, c *Conn, timeout time.Duration) {
var cs ConnectionState
err := c.agent.run(context.Background(), func(_ context.Context, agent *Agent) {
cs = agent.connectionState
err := c.agent.loop.Run(context.Background(), func(_ context.Context) {
cs = c.agent.connectionState
})
if err != nil {
// We should never get here.