mirror of
https://github.com/pion/ice.git
synced 2025-09-26 19:41:11 +08:00
Implement ICE Role conflict resolution
Detect if remote has a role conflict and resolve it as defined by RFC 8445 section-7.3.1.1 Resolves #359
This commit is contained in:
135
agent.go
135
agent.go
@@ -62,7 +62,7 @@ type Agent struct {
|
||||
muHaveStarted sync.Mutex
|
||||
startedCh <-chan struct{}
|
||||
startedFn func()
|
||||
isControlling bool
|
||||
isControlling atomic.Bool
|
||||
|
||||
maxBindingRequests uint16
|
||||
|
||||
@@ -104,7 +104,9 @@ type Agent struct {
|
||||
remoteCandidates map[NetworkType][]Candidate
|
||||
|
||||
checklist []*CandidatePair
|
||||
selector pairCandidateSelector
|
||||
|
||||
selectorLock sync.RWMutex
|
||||
selector pairCandidateSelector
|
||||
|
||||
selectedPair atomic.Value // *CandidatePair
|
||||
|
||||
@@ -343,21 +345,11 @@ func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remoteP
|
||||
a.log.Debugf("Started agent: isControlling? %t, remoteUfrag: %q, remotePwd: %q", isControlling, remoteUfrag, remotePwd)
|
||||
|
||||
return a.loop.Run(a.loop, func(_ context.Context) {
|
||||
a.isControlling = isControlling
|
||||
a.isControlling.Store(isControlling)
|
||||
a.remoteUfrag = remoteUfrag
|
||||
a.remotePwd = remotePwd
|
||||
a.setSelector()
|
||||
|
||||
if isControlling {
|
||||
a.selector = &controllingSelector{agent: a, log: a.log}
|
||||
} else {
|
||||
a.selector = &controlledSelector{agent: a, log: a.log}
|
||||
}
|
||||
|
||||
if a.lite {
|
||||
a.selector = &liteSelector{pairCandidateSelector: a.selector}
|
||||
}
|
||||
|
||||
a.selector.Start()
|
||||
a.startedFn()
|
||||
|
||||
a.updateConnectionState(ConnectionStateChecking)
|
||||
@@ -397,7 +389,7 @@ func (a *Agent) connectivityChecks() { //nolint:cyclop
|
||||
default:
|
||||
}
|
||||
|
||||
a.selector.ContactCandidates()
|
||||
a.getSelector().ContactCandidates()
|
||||
}); err != nil {
|
||||
a.log.Warnf("Failed to start connectivity checks: %v", err)
|
||||
}
|
||||
@@ -501,7 +493,7 @@ func (a *Agent) pingAllCandidates() {
|
||||
a.log.Tracef("Maximum requests reached for pair %s, marking it as failed", p)
|
||||
p.state = CandidatePairStateFailed
|
||||
} else {
|
||||
a.selector.PingCandidate(p.Local, p.Remote)
|
||||
a.getSelector().PingCandidate(p.Local, p.Remote)
|
||||
p.bindingRequestCount++
|
||||
}
|
||||
}
|
||||
@@ -542,7 +534,7 @@ func (a *Agent) getBestValidCandidatePair() *CandidatePair {
|
||||
}
|
||||
|
||||
func (a *Agent) addPair(local, remote Candidate) *CandidatePair {
|
||||
p := newCandidatePair(local, remote, a.isControlling)
|
||||
p := newCandidatePair(local, remote, a.isControlling.Load())
|
||||
a.checklist = append(a.checklist, p)
|
||||
|
||||
return p
|
||||
@@ -598,7 +590,7 @@ func (a *Agent) checkKeepalive() {
|
||||
if a.keepaliveInterval != 0 {
|
||||
// We use binding request instead of indication to support refresh consent schemas
|
||||
// see https://tools.ietf.org/html/rfc7675
|
||||
a.selector.PingCandidate(selectedPair.Local, selectedPair.Remote)
|
||||
a.getSelector().PingCandidate(selectedPair.Local, selectedPair.Remote)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1064,9 +1056,39 @@ func (a *Agent) handleInboundBindingSuccess(id [stun.TransactionIDSize]byte) (bo
|
||||
return false, nil, 0
|
||||
}
|
||||
|
||||
func (a *Agent) handleRoleConflict(msg *stun.Message, local, remote Candidate, remoteTieBreaker *AttrControl) {
|
||||
localIsGreaterOrEqual := a.tieBreaker >= remoteTieBreaker.Tiebreaker
|
||||
a.log.Warnf("Role conflict local and remote same role(%s), localIsGreaterOrEqual(%t)", a.role(), localIsGreaterOrEqual)
|
||||
|
||||
// https://datatracker.ietf.org/doc/html/rfc8445#section-7.3.1.1
|
||||
// An agent MUST examine the Binding request for either the ICE-
|
||||
// CONTROLLING or ICE-CONTROLLED attribute. It MUST follow these
|
||||
// procedures:
|
||||
|
||||
// If the agent's tiebreaker value is larger than or equal to the contents of the ICE-CONTROLLING attribute
|
||||
// If the agent's tiebreaker value is less than the contents of the ICE-CONTROLLED attribute
|
||||
// the agent generates a Binding error response
|
||||
if (a.isControlling.Load() && localIsGreaterOrEqual) || (!a.isControlling.Load() && !localIsGreaterOrEqual) {
|
||||
if roleConflictMsg, err := stun.Build(msg, stun.BindingError,
|
||||
stun.ErrorCodeAttribute{
|
||||
Code: stun.CodeRoleConflict,
|
||||
Reason: []byte("Role Conflict"),
|
||||
},
|
||||
stun.NewShortTermIntegrity(a.localPwd),
|
||||
stun.Fingerprint,
|
||||
); err != nil {
|
||||
a.log.Warnf("Failed to generate Role Conflict message from: %s to: %s error: %s", local, remote, err)
|
||||
} else {
|
||||
a.sendSTUN(roleConflictMsg, local, remote)
|
||||
}
|
||||
} else {
|
||||
a.isControlling.Store(!a.isControlling.Load())
|
||||
a.setSelector()
|
||||
}
|
||||
}
|
||||
|
||||
// handleInbound processes STUN traffic from a remote candidate.
|
||||
func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Addr) { //nolint:gocognit,cyclop
|
||||
var err error
|
||||
if msg == nil || local == nil {
|
||||
return
|
||||
}
|
||||
@@ -1080,27 +1102,10 @@ func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Add
|
||||
return
|
||||
}
|
||||
|
||||
if a.isControlling {
|
||||
if msg.Contains(stun.AttrICEControlling) {
|
||||
a.log.Debug("Inbound STUN message: isControlling && a.isControlling == true")
|
||||
|
||||
return
|
||||
} else if msg.Contains(stun.AttrUseCandidate) {
|
||||
a.log.Debug("Inbound STUN message: useCandidate && a.isControlling == true")
|
||||
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if msg.Contains(stun.AttrICEControlled) {
|
||||
a.log.Debug("Inbound STUN message: isControlled && a.isControlling == false")
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
remoteCandidate := a.findRemoteCandidate(local.NetworkType(), remote)
|
||||
|
||||
if msg.Type.Class == stun.ClassSuccessResponse { //nolint:nestif
|
||||
if err = stun.MessageIntegrity([]byte(a.remotePwd)).Check(msg); err != nil {
|
||||
if err := stun.MessageIntegrity([]byte(a.remotePwd)).Check(msg); err != nil {
|
||||
a.log.Warnf("Discard message from (%s), %v", remote, err)
|
||||
|
||||
return
|
||||
@@ -1112,7 +1117,7 @@ func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Add
|
||||
return
|
||||
}
|
||||
|
||||
a.selector.HandleSuccessResponse(msg, local, remoteCandidate, remote)
|
||||
a.getSelector().HandleSuccessResponse(msg, local, remoteCandidate, remote)
|
||||
} else if msg.Type.Class == stun.ClassRequest {
|
||||
a.log.Tracef(
|
||||
"Inbound STUN (Request) from %s to %s, useCandidate: %v",
|
||||
@@ -1121,11 +1126,11 @@ func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Add
|
||||
msg.Contains(stun.AttrUseCandidate),
|
||||
)
|
||||
|
||||
if err = stunx.AssertUsername(msg, a.localUfrag+":"+a.remoteUfrag); err != nil {
|
||||
if err := stunx.AssertUsername(msg, a.localUfrag+":"+a.remoteUfrag); err != nil {
|
||||
a.log.Warnf("Discard message from (%s), %v", remote, err)
|
||||
|
||||
return
|
||||
} else if err = stun.MessageIntegrity([]byte(a.localPwd)).Check(msg); err != nil {
|
||||
} else if err := stun.MessageIntegrity([]byte(a.localPwd)).Check(msg); err != nil {
|
||||
a.log.Warnf("Discard message from (%s), %v", remote, err)
|
||||
|
||||
return
|
||||
@@ -1160,7 +1165,16 @@ func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Add
|
||||
a.addRemoteCandidate(remoteCandidate)
|
||||
}
|
||||
|
||||
a.selector.HandleBindingRequest(msg, local, remoteCandidate)
|
||||
// Support Remotes that don't set a TIE-BREAKER. Not standards compliant, but
|
||||
// keeping to maintain backwards compat
|
||||
remoteTieBreaker := &AttrControl{}
|
||||
if err := remoteTieBreaker.GetFrom(msg); err == nil && remoteTieBreaker.Role == a.role() {
|
||||
a.handleRoleConflict(msg, local, remoteCandidate, remoteTieBreaker)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
a.getSelector().HandleBindingRequest(msg, local, remoteCandidate)
|
||||
}
|
||||
|
||||
if remoteCandidate != nil {
|
||||
@@ -1282,9 +1296,7 @@ func (a *Agent) Restart(ufrag, pwd string) error { //nolint:cyclop
|
||||
a.pendingBindingRequests = make([]bindingRequest, 0)
|
||||
a.setSelectedPair(nil)
|
||||
a.deleteAllCandidates()
|
||||
if a.selector != nil {
|
||||
a.selector.Start()
|
||||
}
|
||||
a.setSelector()
|
||||
|
||||
// Restart is used by NewAgent. Accept/Connect should be used to move to checking
|
||||
// for new Agents
|
||||
@@ -1319,3 +1331,36 @@ func (a *Agent) setGatheringState(newState GatheringState) error {
|
||||
func (a *Agent) needsToCheckPriorityOnNominated() bool {
|
||||
return !a.lite || a.enableUseCandidateCheckPriority
|
||||
}
|
||||
|
||||
func (a *Agent) role() Role {
|
||||
if a.isControlling.Load() {
|
||||
return Controlling
|
||||
}
|
||||
|
||||
return Controlled
|
||||
}
|
||||
|
||||
func (a *Agent) setSelector() {
|
||||
a.selectorLock.Lock()
|
||||
defer a.selectorLock.Unlock()
|
||||
|
||||
var s pairCandidateSelector
|
||||
if a.isControlling.Load() {
|
||||
s = &controllingSelector{agent: a, log: a.log}
|
||||
} else {
|
||||
s = &controlledSelector{agent: a, log: a.log}
|
||||
}
|
||||
if a.lite {
|
||||
s = &liteSelector{pairCandidateSelector: s}
|
||||
}
|
||||
|
||||
s.Start()
|
||||
a.selector = s
|
||||
}
|
||||
|
||||
func (a *Agent) getSelector() pairCandidateSelector {
|
||||
a.selectorLock.Lock()
|
||||
defer a.selectorLock.Unlock()
|
||||
|
||||
return a.selector
|
||||
}
|
||||
|
@@ -1947,3 +1947,67 @@ func TestAlwaysSentKeepAlive(t *testing.T) { //nolint:cyclop
|
||||
newLastSent = pair.Local.LastSent()
|
||||
require.NotEqual(t, lastSent, newLastSent)
|
||||
}
|
||||
|
||||
func TestRoleConflict(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
defer test.TimeOut(time.Second * 30).Stop()
|
||||
|
||||
runTest := func(doDial bool) {
|
||||
cfg := &AgentConfig{
|
||||
NetworkTypes: supportedNetworkTypes(),
|
||||
MulticastDNSMode: MulticastDNSModeDisabled,
|
||||
InterfaceFilter: problematicNetworkInterfaces,
|
||||
}
|
||||
|
||||
aAgent, err := NewAgent(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
bAgent, err := NewAgent(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
isConnected := make(chan any)
|
||||
err = aAgent.OnConnectionStateChange(func(c ConnectionState) {
|
||||
if c == ConnectionStateConnected {
|
||||
close(isConnected)
|
||||
}
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
gatherAndExchangeCandidates(t, aAgent, bAgent)
|
||||
|
||||
go func() {
|
||||
ufrag, pwd, routineErr := bAgent.GetLocalUserCredentials()
|
||||
require.NoError(t, routineErr)
|
||||
|
||||
if doDial {
|
||||
_, routineErr = aAgent.Dial(context.TODO(), ufrag, pwd)
|
||||
} else {
|
||||
_, routineErr = aAgent.Accept(context.TODO(), ufrag, pwd)
|
||||
}
|
||||
require.NoError(t, routineErr)
|
||||
}()
|
||||
|
||||
ufrag, pwd, err := aAgent.GetLocalUserCredentials()
|
||||
require.NoError(t, err)
|
||||
|
||||
if doDial {
|
||||
_, err = bAgent.Dial(context.TODO(), ufrag, pwd)
|
||||
} else {
|
||||
_, err = bAgent.Accept(context.TODO(), ufrag, pwd)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
<-isConnected
|
||||
|
||||
require.NoError(t, aAgent.Close())
|
||||
require.NoError(t, bAgent.Close())
|
||||
}
|
||||
|
||||
t.Run("Controlling", func(t *testing.T) {
|
||||
runTest(true)
|
||||
})
|
||||
|
||||
t.Run("Controlled", func(t *testing.T) {
|
||||
runTest(false)
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user