mirror of
				https://github.com/pion/ice.git
				synced 2025-11-01 03:02:39 +08:00 
			
		
		
		
	Implement ICE Role conflict resolution
Detect if remote has a role conflict and resolve it as defined by RFC 8445 section-7.3.1.1 Resolves #359
This commit is contained in:
		
							
								
								
									
										133
									
								
								agent.go
									
									
									
									
									
								
							
							
						
						
									
										133
									
								
								agent.go
									
									
									
									
									
								
							| @@ -62,7 +62,7 @@ type Agent struct { | |||||||
| 	muHaveStarted sync.Mutex | 	muHaveStarted sync.Mutex | ||||||
| 	startedCh     <-chan struct{} | 	startedCh     <-chan struct{} | ||||||
| 	startedFn     func() | 	startedFn     func() | ||||||
| 	isControlling bool | 	isControlling atomic.Bool | ||||||
|  |  | ||||||
| 	maxBindingRequests uint16 | 	maxBindingRequests uint16 | ||||||
|  |  | ||||||
| @@ -104,6 +104,8 @@ type Agent struct { | |||||||
| 	remoteCandidates map[NetworkType][]Candidate | 	remoteCandidates map[NetworkType][]Candidate | ||||||
|  |  | ||||||
| 	checklist []*CandidatePair | 	checklist []*CandidatePair | ||||||
|  |  | ||||||
|  | 	selectorLock sync.RWMutex | ||||||
| 	selector     pairCandidateSelector | 	selector     pairCandidateSelector | ||||||
|  |  | ||||||
| 	selectedPair atomic.Value // *CandidatePair | 	selectedPair atomic.Value // *CandidatePair | ||||||
| @@ -343,21 +345,11 @@ func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remoteP | |||||||
| 	a.log.Debugf("Started agent: isControlling? %t, remoteUfrag: %q, remotePwd: %q", isControlling, remoteUfrag, remotePwd) | 	a.log.Debugf("Started agent: isControlling? %t, remoteUfrag: %q, remotePwd: %q", isControlling, remoteUfrag, remotePwd) | ||||||
|  |  | ||||||
| 	return a.loop.Run(a.loop, func(_ context.Context) { | 	return a.loop.Run(a.loop, func(_ context.Context) { | ||||||
| 		a.isControlling = isControlling | 		a.isControlling.Store(isControlling) | ||||||
| 		a.remoteUfrag = remoteUfrag | 		a.remoteUfrag = remoteUfrag | ||||||
| 		a.remotePwd = remotePwd | 		a.remotePwd = remotePwd | ||||||
|  | 		a.setSelector() | ||||||
|  |  | ||||||
| 		if isControlling { |  | ||||||
| 			a.selector = &controllingSelector{agent: a, log: a.log} |  | ||||||
| 		} else { |  | ||||||
| 			a.selector = &controlledSelector{agent: a, log: a.log} |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		if a.lite { |  | ||||||
| 			a.selector = &liteSelector{pairCandidateSelector: a.selector} |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		a.selector.Start() |  | ||||||
| 		a.startedFn() | 		a.startedFn() | ||||||
|  |  | ||||||
| 		a.updateConnectionState(ConnectionStateChecking) | 		a.updateConnectionState(ConnectionStateChecking) | ||||||
| @@ -397,7 +389,7 @@ func (a *Agent) connectivityChecks() { //nolint:cyclop | |||||||
| 			default: | 			default: | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			a.selector.ContactCandidates() | 			a.getSelector().ContactCandidates() | ||||||
| 		}); err != nil { | 		}); err != nil { | ||||||
| 			a.log.Warnf("Failed to start connectivity checks: %v", err) | 			a.log.Warnf("Failed to start connectivity checks: %v", err) | ||||||
| 		} | 		} | ||||||
| @@ -501,7 +493,7 @@ func (a *Agent) pingAllCandidates() { | |||||||
| 			a.log.Tracef("Maximum requests reached for pair %s, marking it as failed", p) | 			a.log.Tracef("Maximum requests reached for pair %s, marking it as failed", p) | ||||||
| 			p.state = CandidatePairStateFailed | 			p.state = CandidatePairStateFailed | ||||||
| 		} else { | 		} else { | ||||||
| 			a.selector.PingCandidate(p.Local, p.Remote) | 			a.getSelector().PingCandidate(p.Local, p.Remote) | ||||||
| 			p.bindingRequestCount++ | 			p.bindingRequestCount++ | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| @@ -542,7 +534,7 @@ func (a *Agent) getBestValidCandidatePair() *CandidatePair { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (a *Agent) addPair(local, remote Candidate) *CandidatePair { | func (a *Agent) addPair(local, remote Candidate) *CandidatePair { | ||||||
| 	p := newCandidatePair(local, remote, a.isControlling) | 	p := newCandidatePair(local, remote, a.isControlling.Load()) | ||||||
| 	a.checklist = append(a.checklist, p) | 	a.checklist = append(a.checklist, p) | ||||||
|  |  | ||||||
| 	return p | 	return p | ||||||
| @@ -598,7 +590,7 @@ func (a *Agent) checkKeepalive() { | |||||||
| 	if a.keepaliveInterval != 0 { | 	if a.keepaliveInterval != 0 { | ||||||
| 		// We use binding request instead of indication to support refresh consent schemas | 		// We use binding request instead of indication to support refresh consent schemas | ||||||
| 		// see https://tools.ietf.org/html/rfc7675 | 		// see https://tools.ietf.org/html/rfc7675 | ||||||
| 		a.selector.PingCandidate(selectedPair.Local, selectedPair.Remote) | 		a.getSelector().PingCandidate(selectedPair.Local, selectedPair.Remote) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -1064,9 +1056,39 @@ func (a *Agent) handleInboundBindingSuccess(id [stun.TransactionIDSize]byte) (bo | |||||||
| 	return false, nil, 0 | 	return false, nil, 0 | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (a *Agent) handleRoleConflict(msg *stun.Message, local, remote Candidate, remoteTieBreaker *AttrControl) { | ||||||
|  | 	localIsGreaterOrEqual := a.tieBreaker >= remoteTieBreaker.Tiebreaker | ||||||
|  | 	a.log.Warnf("Role conflict local and remote same role(%s), localIsGreaterOrEqual(%t)", a.role(), localIsGreaterOrEqual) | ||||||
|  |  | ||||||
|  | 	// https://datatracker.ietf.org/doc/html/rfc8445#section-7.3.1.1 | ||||||
|  | 	//  An agent MUST examine the Binding request for either the ICE- | ||||||
|  | 	//  CONTROLLING or ICE-CONTROLLED attribute.  It MUST follow these | ||||||
|  | 	// procedures: | ||||||
|  |  | ||||||
|  | 	// If the agent's tiebreaker value is larger than or equal to the contents of the ICE-CONTROLLING attribute | ||||||
|  | 	// If the agent's tiebreaker value is less than the contents of the ICE-CONTROLLED attribute | ||||||
|  | 	//  the agent generates a Binding error response | ||||||
|  | 	if (a.isControlling.Load() && localIsGreaterOrEqual) || (!a.isControlling.Load() && !localIsGreaterOrEqual) { | ||||||
|  | 		if roleConflictMsg, err := stun.Build(msg, stun.BindingError, | ||||||
|  | 			stun.ErrorCodeAttribute{ | ||||||
|  | 				Code:   stun.CodeRoleConflict, | ||||||
|  | 				Reason: []byte("Role Conflict"), | ||||||
|  | 			}, | ||||||
|  | 			stun.NewShortTermIntegrity(a.localPwd), | ||||||
|  | 			stun.Fingerprint, | ||||||
|  | 		); err != nil { | ||||||
|  | 			a.log.Warnf("Failed to generate Role Conflict message from: %s to: %s error: %s", local, remote, err) | ||||||
|  | 		} else { | ||||||
|  | 			a.sendSTUN(roleConflictMsg, local, remote) | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		a.isControlling.Store(!a.isControlling.Load()) | ||||||
|  | 		a.setSelector() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| // handleInbound processes STUN traffic from a remote candidate. | // handleInbound processes STUN traffic from a remote candidate. | ||||||
| func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Addr) { //nolint:gocognit,cyclop | func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Addr) { //nolint:gocognit,cyclop | ||||||
| 	var err error |  | ||||||
| 	if msg == nil || local == nil { | 	if msg == nil || local == nil { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| @@ -1080,27 +1102,10 @@ func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Add | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if a.isControlling { |  | ||||||
| 		if msg.Contains(stun.AttrICEControlling) { |  | ||||||
| 			a.log.Debug("Inbound STUN message: isControlling && a.isControlling == true") |  | ||||||
|  |  | ||||||
| 			return |  | ||||||
| 		} else if msg.Contains(stun.AttrUseCandidate) { |  | ||||||
| 			a.log.Debug("Inbound STUN message: useCandidate && a.isControlling == true") |  | ||||||
|  |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 	} else { |  | ||||||
| 		if msg.Contains(stun.AttrICEControlled) { |  | ||||||
| 			a.log.Debug("Inbound STUN message: isControlled && a.isControlling == false") |  | ||||||
|  |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	remoteCandidate := a.findRemoteCandidate(local.NetworkType(), remote) | 	remoteCandidate := a.findRemoteCandidate(local.NetworkType(), remote) | ||||||
|  |  | ||||||
| 	if msg.Type.Class == stun.ClassSuccessResponse { //nolint:nestif | 	if msg.Type.Class == stun.ClassSuccessResponse { //nolint:nestif | ||||||
| 		if err = stun.MessageIntegrity([]byte(a.remotePwd)).Check(msg); err != nil { | 		if err := stun.MessageIntegrity([]byte(a.remotePwd)).Check(msg); err != nil { | ||||||
| 			a.log.Warnf("Discard message from (%s), %v", remote, err) | 			a.log.Warnf("Discard message from (%s), %v", remote, err) | ||||||
|  |  | ||||||
| 			return | 			return | ||||||
| @@ -1112,7 +1117,7 @@ func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Add | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		a.selector.HandleSuccessResponse(msg, local, remoteCandidate, remote) | 		a.getSelector().HandleSuccessResponse(msg, local, remoteCandidate, remote) | ||||||
| 	} else if msg.Type.Class == stun.ClassRequest { | 	} else if msg.Type.Class == stun.ClassRequest { | ||||||
| 		a.log.Tracef( | 		a.log.Tracef( | ||||||
| 			"Inbound STUN (Request) from %s to %s, useCandidate: %v", | 			"Inbound STUN (Request) from %s to %s, useCandidate: %v", | ||||||
| @@ -1121,11 +1126,11 @@ func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Add | |||||||
| 			msg.Contains(stun.AttrUseCandidate), | 			msg.Contains(stun.AttrUseCandidate), | ||||||
| 		) | 		) | ||||||
|  |  | ||||||
| 		if err = stunx.AssertUsername(msg, a.localUfrag+":"+a.remoteUfrag); err != nil { | 		if err := stunx.AssertUsername(msg, a.localUfrag+":"+a.remoteUfrag); err != nil { | ||||||
| 			a.log.Warnf("Discard message from (%s), %v", remote, err) | 			a.log.Warnf("Discard message from (%s), %v", remote, err) | ||||||
|  |  | ||||||
| 			return | 			return | ||||||
| 		} else if err = stun.MessageIntegrity([]byte(a.localPwd)).Check(msg); err != nil { | 		} else if err := stun.MessageIntegrity([]byte(a.localPwd)).Check(msg); err != nil { | ||||||
| 			a.log.Warnf("Discard message from (%s), %v", remote, err) | 			a.log.Warnf("Discard message from (%s), %v", remote, err) | ||||||
|  |  | ||||||
| 			return | 			return | ||||||
| @@ -1160,7 +1165,16 @@ func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Add | |||||||
| 			a.addRemoteCandidate(remoteCandidate) | 			a.addRemoteCandidate(remoteCandidate) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		a.selector.HandleBindingRequest(msg, local, remoteCandidate) | 		// Support Remotes that don't set a TIE-BREAKER. Not standards compliant, but | ||||||
|  | 		// keeping to maintain backwards compat | ||||||
|  | 		remoteTieBreaker := &AttrControl{} | ||||||
|  | 		if err := remoteTieBreaker.GetFrom(msg); err == nil && remoteTieBreaker.Role == a.role() { | ||||||
|  | 			a.handleRoleConflict(msg, local, remoteCandidate, remoteTieBreaker) | ||||||
|  |  | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		a.getSelector().HandleBindingRequest(msg, local, remoteCandidate) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if remoteCandidate != nil { | 	if remoteCandidate != nil { | ||||||
| @@ -1282,9 +1296,7 @@ func (a *Agent) Restart(ufrag, pwd string) error { //nolint:cyclop | |||||||
| 		a.pendingBindingRequests = make([]bindingRequest, 0) | 		a.pendingBindingRequests = make([]bindingRequest, 0) | ||||||
| 		a.setSelectedPair(nil) | 		a.setSelectedPair(nil) | ||||||
| 		a.deleteAllCandidates() | 		a.deleteAllCandidates() | ||||||
| 		if a.selector != nil { | 		a.setSelector() | ||||||
| 			a.selector.Start() |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// Restart is used by NewAgent. Accept/Connect should be used to move to checking | 		// Restart is used by NewAgent. Accept/Connect should be used to move to checking | ||||||
| 		// for new Agents | 		// for new Agents | ||||||
| @@ -1319,3 +1331,36 @@ func (a *Agent) setGatheringState(newState GatheringState) error { | |||||||
| func (a *Agent) needsToCheckPriorityOnNominated() bool { | func (a *Agent) needsToCheckPriorityOnNominated() bool { | ||||||
| 	return !a.lite || a.enableUseCandidateCheckPriority | 	return !a.lite || a.enableUseCandidateCheckPriority | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (a *Agent) role() Role { | ||||||
|  | 	if a.isControlling.Load() { | ||||||
|  | 		return Controlling | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return Controlled | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *Agent) setSelector() { | ||||||
|  | 	a.selectorLock.Lock() | ||||||
|  | 	defer a.selectorLock.Unlock() | ||||||
|  |  | ||||||
|  | 	var s pairCandidateSelector | ||||||
|  | 	if a.isControlling.Load() { | ||||||
|  | 		s = &controllingSelector{agent: a, log: a.log} | ||||||
|  | 	} else { | ||||||
|  | 		s = &controlledSelector{agent: a, log: a.log} | ||||||
|  | 	} | ||||||
|  | 	if a.lite { | ||||||
|  | 		s = &liteSelector{pairCandidateSelector: s} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.Start() | ||||||
|  | 	a.selector = s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (a *Agent) getSelector() pairCandidateSelector { | ||||||
|  | 	a.selectorLock.Lock() | ||||||
|  | 	defer a.selectorLock.Unlock() | ||||||
|  |  | ||||||
|  | 	return a.selector | ||||||
|  | } | ||||||
|   | |||||||
| @@ -1947,3 +1947,67 @@ func TestAlwaysSentKeepAlive(t *testing.T) { //nolint:cyclop | |||||||
| 	newLastSent = pair.Local.LastSent() | 	newLastSent = pair.Local.LastSent() | ||||||
| 	require.NotEqual(t, lastSent, newLastSent) | 	require.NotEqual(t, lastSent, newLastSent) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestRoleConflict(t *testing.T) { | ||||||
|  | 	defer test.CheckRoutines(t)() | ||||||
|  | 	defer test.TimeOut(time.Second * 30).Stop() | ||||||
|  |  | ||||||
|  | 	runTest := func(doDial bool) { | ||||||
|  | 		cfg := &AgentConfig{ | ||||||
|  | 			NetworkTypes:     supportedNetworkTypes(), | ||||||
|  | 			MulticastDNSMode: MulticastDNSModeDisabled, | ||||||
|  | 			InterfaceFilter:  problematicNetworkInterfaces, | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		aAgent, err := NewAgent(cfg) | ||||||
|  | 		require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 		bAgent, err := NewAgent(cfg) | ||||||
|  | 		require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 		isConnected := make(chan any) | ||||||
|  | 		err = aAgent.OnConnectionStateChange(func(c ConnectionState) { | ||||||
|  | 			if c == ConnectionStateConnected { | ||||||
|  | 				close(isConnected) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 		require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 		gatherAndExchangeCandidates(t, aAgent, bAgent) | ||||||
|  |  | ||||||
|  | 		go func() { | ||||||
|  | 			ufrag, pwd, routineErr := bAgent.GetLocalUserCredentials() | ||||||
|  | 			require.NoError(t, routineErr) | ||||||
|  |  | ||||||
|  | 			if doDial { | ||||||
|  | 				_, routineErr = aAgent.Dial(context.TODO(), ufrag, pwd) | ||||||
|  | 			} else { | ||||||
|  | 				_, routineErr = aAgent.Accept(context.TODO(), ufrag, pwd) | ||||||
|  | 			} | ||||||
|  | 			require.NoError(t, routineErr) | ||||||
|  | 		}() | ||||||
|  |  | ||||||
|  | 		ufrag, pwd, err := aAgent.GetLocalUserCredentials() | ||||||
|  | 		require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 		if doDial { | ||||||
|  | 			_, err = bAgent.Dial(context.TODO(), ufrag, pwd) | ||||||
|  | 		} else { | ||||||
|  | 			_, err = bAgent.Accept(context.TODO(), ufrag, pwd) | ||||||
|  | 		} | ||||||
|  | 		require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 		<-isConnected | ||||||
|  |  | ||||||
|  | 		require.NoError(t, aAgent.Close()) | ||||||
|  | 		require.NoError(t, bAgent.Close()) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	t.Run("Controlling", func(t *testing.T) { | ||||||
|  | 		runTest(true) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("Controlled", func(t *testing.T) { | ||||||
|  | 		runTest(false) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Sean DuBois
					Sean DuBois