Implement ICE Role conflict resolution

Detect when the remote agent is in the same ICE role as the local agent
(a role conflict) and resolve it as defined by RFC 8445, Section 7.3.1.1.

Resolves #359
This commit is contained in:
Sean DuBois
2025-07-16 16:34:30 -04:00
parent f6a1153ce7
commit 2c04474e38
2 changed files with 154 additions and 45 deletions

135
agent.go
View File

@@ -62,7 +62,7 @@ type Agent struct {
muHaveStarted sync.Mutex
startedCh <-chan struct{}
startedFn func()
isControlling bool
isControlling atomic.Bool
maxBindingRequests uint16
@@ -104,7 +104,9 @@ type Agent struct {
remoteCandidates map[NetworkType][]Candidate
checklist []*CandidatePair
selector pairCandidateSelector
selectorLock sync.RWMutex
selector pairCandidateSelector
selectedPair atomic.Value // *CandidatePair
@@ -343,21 +345,11 @@ func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remoteP
a.log.Debugf("Started agent: isControlling? %t, remoteUfrag: %q, remotePwd: %q", isControlling, remoteUfrag, remotePwd)
return a.loop.Run(a.loop, func(_ context.Context) {
a.isControlling = isControlling
a.isControlling.Store(isControlling)
a.remoteUfrag = remoteUfrag
a.remotePwd = remotePwd
a.setSelector()
if isControlling {
a.selector = &controllingSelector{agent: a, log: a.log}
} else {
a.selector = &controlledSelector{agent: a, log: a.log}
}
if a.lite {
a.selector = &liteSelector{pairCandidateSelector: a.selector}
}
a.selector.Start()
a.startedFn()
a.updateConnectionState(ConnectionStateChecking)
@@ -397,7 +389,7 @@ func (a *Agent) connectivityChecks() { //nolint:cyclop
default:
}
a.selector.ContactCandidates()
a.getSelector().ContactCandidates()
}); err != nil {
a.log.Warnf("Failed to start connectivity checks: %v", err)
}
@@ -501,7 +493,7 @@ func (a *Agent) pingAllCandidates() {
a.log.Tracef("Maximum requests reached for pair %s, marking it as failed", p)
p.state = CandidatePairStateFailed
} else {
a.selector.PingCandidate(p.Local, p.Remote)
a.getSelector().PingCandidate(p.Local, p.Remote)
p.bindingRequestCount++
}
}
@@ -542,7 +534,7 @@ func (a *Agent) getBestValidCandidatePair() *CandidatePair {
}
func (a *Agent) addPair(local, remote Candidate) *CandidatePair {
p := newCandidatePair(local, remote, a.isControlling)
p := newCandidatePair(local, remote, a.isControlling.Load())
a.checklist = append(a.checklist, p)
return p
@@ -598,7 +590,7 @@ func (a *Agent) checkKeepalive() {
if a.keepaliveInterval != 0 {
// We use binding request instead of indication to support refresh consent schemas
// see https://tools.ietf.org/html/rfc7675
a.selector.PingCandidate(selectedPair.Local, selectedPair.Remote)
a.getSelector().PingCandidate(selectedPair.Local, selectedPair.Remote)
}
}
@@ -1064,9 +1056,39 @@ func (a *Agent) handleInboundBindingSuccess(id [stun.TransactionIDSize]byte) (bo
return false, nil, 0
}
// handleRoleConflict resolves an ICE role conflict found on an inbound Binding
// request, i.e. the remote agent advertises the same role as this agent.
//
// The decision follows https://datatracker.ietf.org/doc/html/rfc8445#section-7.3.1.1
// and compares the local tiebreaker with the one carried in the request:
//   - controlling and local tiebreaker >= remote: answer with a Role Conflict Binding error.
//   - controlled and local tiebreaker < remote: answer with a Role Conflict Binding error.
//   - in every other case: flip the local role and rebuild the selector for it.
func (a *Agent) handleRoleConflict(msg *stun.Message, local, remote Candidate, remoteTieBreaker *AttrControl) {
	localIsGreaterOrEqual := a.tieBreaker >= remoteTieBreaker.Tiebreaker
	a.log.Warnf("Role conflict local and remote same role(%s), localIsGreaterOrEqual(%t)", a.role(), localIsGreaterOrEqual)

	controlling := a.isControlling.Load()

	// The local role is kept only when "controlling" and "tiebreaker >= remote"
	// agree (both true or both false); otherwise this agent switches roles.
	if controlling != localIsGreaterOrEqual {
		a.isControlling.Store(!controlling)
		a.setSelector()

		return
	}

	// Role is kept: tell the remote agent about the conflict.
	roleConflictMsg, err := stun.Build(msg, stun.BindingError,
		stun.ErrorCodeAttribute{
			Code:   stun.CodeRoleConflict,
			Reason: []byte("Role Conflict"),
		},
		stun.NewShortTermIntegrity(a.localPwd),
		stun.Fingerprint,
	)
	if err != nil {
		a.log.Warnf("Failed to generate Role Conflict message from: %s to: %s error: %s", local, remote, err)

		return
	}

	a.sendSTUN(roleConflictMsg, local, remote)
}
// handleInbound processes STUN traffic from a remote candidate.
func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Addr) { //nolint:gocognit,cyclop
var err error
if msg == nil || local == nil {
return
}
@@ -1080,27 +1102,10 @@ func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Add
return
}
if a.isControlling {
if msg.Contains(stun.AttrICEControlling) {
a.log.Debug("Inbound STUN message: isControlling && a.isControlling == true")
return
} else if msg.Contains(stun.AttrUseCandidate) {
a.log.Debug("Inbound STUN message: useCandidate && a.isControlling == true")
return
}
} else {
if msg.Contains(stun.AttrICEControlled) {
a.log.Debug("Inbound STUN message: isControlled && a.isControlling == false")
return
}
}
remoteCandidate := a.findRemoteCandidate(local.NetworkType(), remote)
if msg.Type.Class == stun.ClassSuccessResponse { //nolint:nestif
if err = stun.MessageIntegrity([]byte(a.remotePwd)).Check(msg); err != nil {
if err := stun.MessageIntegrity([]byte(a.remotePwd)).Check(msg); err != nil {
a.log.Warnf("Discard message from (%s), %v", remote, err)
return
@@ -1112,7 +1117,7 @@ func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Add
return
}
a.selector.HandleSuccessResponse(msg, local, remoteCandidate, remote)
a.getSelector().HandleSuccessResponse(msg, local, remoteCandidate, remote)
} else if msg.Type.Class == stun.ClassRequest {
a.log.Tracef(
"Inbound STUN (Request) from %s to %s, useCandidate: %v",
@@ -1121,11 +1126,11 @@ func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Add
msg.Contains(stun.AttrUseCandidate),
)
if err = stunx.AssertUsername(msg, a.localUfrag+":"+a.remoteUfrag); err != nil {
if err := stunx.AssertUsername(msg, a.localUfrag+":"+a.remoteUfrag); err != nil {
a.log.Warnf("Discard message from (%s), %v", remote, err)
return
} else if err = stun.MessageIntegrity([]byte(a.localPwd)).Check(msg); err != nil {
} else if err := stun.MessageIntegrity([]byte(a.localPwd)).Check(msg); err != nil {
a.log.Warnf("Discard message from (%s), %v", remote, err)
return
@@ -1160,7 +1165,16 @@ func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Add
a.addRemoteCandidate(remoteCandidate)
}
a.selector.HandleBindingRequest(msg, local, remoteCandidate)
// Support Remotes that don't set a TIE-BREAKER. Not standards compliant, but
// keeping to maintain backwards compat
remoteTieBreaker := &AttrControl{}
if err := remoteTieBreaker.GetFrom(msg); err == nil && remoteTieBreaker.Role == a.role() {
a.handleRoleConflict(msg, local, remoteCandidate, remoteTieBreaker)
return
}
a.getSelector().HandleBindingRequest(msg, local, remoteCandidate)
}
if remoteCandidate != nil {
@@ -1282,9 +1296,7 @@ func (a *Agent) Restart(ufrag, pwd string) error { //nolint:cyclop
a.pendingBindingRequests = make([]bindingRequest, 0)
a.setSelectedPair(nil)
a.deleteAllCandidates()
if a.selector != nil {
a.selector.Start()
}
a.setSelector()
// Restart is used by NewAgent. Accept/Connect should be used to move to checking
// for new Agents
@@ -1319,3 +1331,36 @@ func (a *Agent) setGatheringState(newState GatheringState) error {
// needsToCheckPriorityOnNominated reports true for every non-lite agent; a lite
// agent reports the value of enableUseCandidateCheckPriority.
func (a *Agent) needsToCheckPriorityOnNominated() bool {
	if a.lite {
		return a.enableUseCandidateCheckPriority
	}

	return true
}
// role maps the atomic isControlling flag to the corresponding Role value.
func (a *Agent) role() Role {
	switch a.isControlling.Load() {
	case true:
		return Controlling
	default:
		return Controlled
	}
}
// setSelector builds a fresh candidate-pair selector for the agent's current
// role, starts it, and stores it in a.selector. A lite agent gets the role
// selector wrapped in a liteSelector. The write lock is held for the whole
// build/start/publish sequence.
func (a *Agent) setSelector() {
	a.selectorLock.Lock()
	defer a.selectorLock.Unlock()

	var next pairCandidateSelector

	switch {
	case a.isControlling.Load():
		next = &controllingSelector{agent: a, log: a.log}
	default:
		next = &controlledSelector{agent: a, log: a.log}
	}

	if a.lite {
		next = &liteSelector{pairCandidateSelector: next}
	}

	// Start the selector before it is stored in a.selector.
	next.Start()
	a.selector = next
}
// getSelector returns the agent's current candidate-pair selector.
//
// It only reads a.selector, so it takes the shared (read) side of selectorLock;
// setSelector takes the exclusive side while it replaces the selector.
func (a *Agent) getSelector() pairCandidateSelector {
	a.selectorLock.RLock()
	defer a.selectorLock.RUnlock()

	return a.selector
}

View File

@@ -1947,3 +1947,67 @@ func TestAlwaysSentKeepAlive(t *testing.T) { //nolint:cyclop
newLastSent = pair.Local.LastSent()
require.NotEqual(t, lastSent, newLastSent)
}
// TestRoleConflict starts two agents that both Dial (subtest "Controlling") or
// both Accept (subtest "Controlled") and checks that aAgent still reaches
// ConnectionStateConnected and that both agents close cleanly.
func TestRoleConflict(t *testing.T) {
	defer test.CheckRoutines(t)()
	defer test.TimeOut(time.Second * 30).Stop()

	// runTest takes the subtest's own *testing.T so failures are attributed to
	// the subtest that produced them rather than to the parent test.
	runTest := func(t *testing.T, doDial bool) {
		cfg := &AgentConfig{
			NetworkTypes:     supportedNetworkTypes(),
			MulticastDNSMode: MulticastDNSModeDisabled,
			InterfaceFilter:  problematicNetworkInterfaces,
		}

		aAgent, err := NewAgent(cfg)
		require.NoError(t, err)

		bAgent, err := NewAgent(cfg)
		require.NoError(t, err)

		// Buffered and signalled with a non-blocking send: the callback may
		// report Connected more than once, and closing a channel twice panics.
		isConnected := make(chan struct{}, 1)
		err = aAgent.OnConnectionStateChange(func(c ConnectionState) {
			if c != ConnectionStateConnected {
				return
			}
			select {
			case isConnected <- struct{}{}:
			default:
			}
		})
		require.NoError(t, err)

		gatherAndExchangeCandidates(t, aAgent, bAgent)

		aDone := make(chan struct{})
		go func() {
			defer close(aDone)

			// FailNow (used by require) must only be called from the goroutine
			// running the test, so report with t.Errorf here instead.
			ufrag, pwd, routineErr := bAgent.GetLocalUserCredentials()
			if routineErr != nil {
				t.Errorf("bAgent.GetLocalUserCredentials: %v", routineErr)

				return
			}

			if doDial {
				_, routineErr = aAgent.Dial(context.TODO(), ufrag, pwd)
			} else {
				_, routineErr = aAgent.Accept(context.TODO(), ufrag, pwd)
			}
			if routineErr != nil {
				t.Errorf("aAgent failed to connect: %v", routineErr)
			}
		}()

		ufrag, pwd, err := aAgent.GetLocalUserCredentials()
		require.NoError(t, err)

		if doDial {
			_, err = bAgent.Dial(context.TODO(), ufrag, pwd)
		} else {
			_, err = bAgent.Accept(context.TODO(), ufrag, pwd)
		}
		require.NoError(t, err)

		<-isConnected
		// Wait for aAgent's Dial/Accept to return before closing the agents so
		// its result is checked while the subtest is still running.
		<-aDone

		require.NoError(t, aAgent.Close())
		require.NoError(t, bAgent.Close())
	}

	t.Run("Controlling", func(t *testing.T) {
		runTest(t, true)
	})

	t.Run("Controlled", func(t *testing.T) {
		runTest(t, false)
	})
}