Improve nomination

This implements a basic validation scheme using a checklist. We try
every pair up to maxTries times, and mark it as failed if we don't get a
success response after that many requests. Once we get a success
response, we check whether it belongs to the best candidate available so
far; if it does, we nominate it, otherwise we continue.

Also, after a given timeout, if no candidate has been nominated, we
simply choose the best valid candidate we got so far (if no candidate is
valid, we mark the connection as failed).

Finally, the nomination request is also limited to maxTries attempts: we
mark the connection as failed if, after that many attempts, we still have
not received a success response.
This commit is contained in:
Hugo Arregui
2019-05-14 16:20:17 -03:00
committed by Sean DuBois
parent a58a281d3a
commit bf57064619
5 changed files with 311 additions and 90 deletions

232
agent.go
View File

@@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"sort"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -26,6 +25,24 @@ const (
// defaultConnectionTimeout used to declare a connection dead // defaultConnectionTimeout used to declare a connection dead
defaultConnectionTimeout = 30 * time.Second defaultConnectionTimeout = 30 * time.Second
// timeout for candidate selection, after this time, the best candidate is used
defaultCandidateSelectionTimeout = 10 * time.Second
// wait time before nominating a host candidate
defaultHostAcceptanceMinWait = 0
// wait time before nominating a srflx candidate
defaultSrflxAcceptanceMinWait = 500 * time.Millisecond
// wait time before nominating a prflx candidate
defaultPrflxAcceptanceMinWait = 1000 * time.Millisecond
// wait time before nominating a relay candidate
defaultRelayAcceptanceMinWait = 2000 * time.Millisecond
// max binding request before considering a pair failed
defaultMaxBindingRequests = 7
// the number of bytes that can be buffered before we start to error // the number of bytes that can be buffered before we start to error
maxBufferSize = 1000 * 1000 // 1MB maxBufferSize = 1000 * 1000 // 1MB
@@ -37,18 +54,6 @@ var (
defaultCandidateTypes = []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive, CandidateTypeRelay} defaultCandidateTypes = []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive, CandidateTypeRelay}
) )
// candidatePairs is a list of candidate pairs. It supplies the Len and Swap
// halves of a sortable collection; ordering is provided by wrapper types.
type candidatePairs []*candidatePair

// Len reports how many pairs the list holds.
func (cp candidatePairs) Len() int {
	return len(cp)
}

// Swap exchanges the pairs stored at positions i and j.
func (cp candidatePairs) Swap(i, j int) {
	cp[j], cp[i] = cp[i], cp[j]
}
// byPairPriority orders the embedded candidatePairs by pair priority.
type byPairPriority struct{ candidatePairs }

// Less reports whether the pair at i has a strictly higher priority than the
// pair at j. The comparison is deliberately reversed so that sorting places
// the highest-priority pair first.
func (bp byPairPriority) Less(i, j int) bool {
	left := bp.candidatePairs[i].Priority()
	right := bp.candidatePairs[j].Priority()
	return left > right
}
type bindingRequest struct { type bindingRequest struct {
transactionID [stun.TransactionIDSize]byte transactionID [stun.TransactionIDSize]byte
destination net.Addr destination net.Addr
@@ -81,6 +86,14 @@ type Agent struct {
haveStarted atomic.Value haveStarted atomic.Value
isControlling bool isControlling bool
maxBindingRequests uint16
candidateSelectionTimeout time.Duration
hostAcceptanceMinWait time.Duration
srflxAcceptanceMinWait time.Duration
prflxAcceptanceMinWait time.Duration
relayAcceptanceMinWait time.Duration
portmin uint16 portmin uint16
portmax uint16 portmax uint16
@@ -105,10 +118,9 @@ type Agent struct {
remotePwd string remotePwd string
remoteCandidates map[NetworkType][]Candidate remoteCandidates map[NetworkType][]Candidate
checklist []*candidatePair
selector pairCandidateSelector selector pairCandidateSelector
selectedPair *candidatePair selectedPair *candidatePair
validPairs candidatePairs
urls []*URL urls []*URL
networkTypes []NetworkType networkTypes []NetworkType
@@ -177,6 +189,25 @@ type AgentConfig struct {
// task loop handles things like sending keepAlives. This is only value for testing // task loop handles things like sending keepAlives. This is only value for testing
// keepAlive behavior should be modified with KeepaliveInterval and ConnectionTimeout // keepAlive behavior should be modified with KeepaliveInterval and ConnectionTimeout
taskLoopInterval time.Duration taskLoopInterval time.Duration
// MaxBindingRequests is the max amount of binding requests the agent will send
// over a candidate pair for validation or nomination, if after MaxBindingRequests
// the candidate is yet to answer a binding request or a nomination we set the pair as failed
MaxBindingRequests *uint16
// CandidateSelectionTimeout specifies a timeout for selecting candidates. If no nomination
// has happened before this timeout is hit, we nominate the best valid candidate available,
// or mark the connection as failed if no valid candidate is available
CandidateSelectionTimeout *time.Duration
// HostAcceptanceMinWait specify a minimum wait time before selecting host candidates
HostAcceptanceMinWait *time.Duration
// SrflxAcceptanceMinWait specifies a minimum wait time before selecting srflx candidates
SrflxAcceptanceMinWait *time.Duration
// PrflxAcceptanceMinWait specifies a minimum wait time before selecting prflx candidates
PrflxAcceptanceMinWait *time.Duration
// RelayAcceptanceMinWait specifies a minimum wait time before selecting relay candidates
RelayAcceptanceMinWait *time.Duration
} }
// NewAgent creates a new Agent // NewAgent creates a new Agent
@@ -215,6 +246,42 @@ func NewAgent(config *AgentConfig) (*Agent, error) {
} }
a.haveStarted.Store(false) a.haveStarted.Store(false)
if config.MaxBindingRequests == nil {
a.maxBindingRequests = defaultMaxBindingRequests
} else {
a.maxBindingRequests = *config.MaxBindingRequests
}
if config.CandidateSelectionTimeout == nil {
a.candidateSelectionTimeout = defaultCandidateSelectionTimeout
} else {
a.candidateSelectionTimeout = *config.CandidateSelectionTimeout
}
if config.HostAcceptanceMinWait == nil {
a.hostAcceptanceMinWait = defaultHostAcceptanceMinWait
} else {
a.hostAcceptanceMinWait = *config.HostAcceptanceMinWait
}
if config.SrflxAcceptanceMinWait == nil {
a.srflxAcceptanceMinWait = defaultSrflxAcceptanceMinWait
} else {
a.srflxAcceptanceMinWait = *config.SrflxAcceptanceMinWait
}
if config.PrflxAcceptanceMinWait == nil {
a.prflxAcceptanceMinWait = defaultPrflxAcceptanceMinWait
} else {
a.prflxAcceptanceMinWait = *config.PrflxAcceptanceMinWait
}
if config.RelayAcceptanceMinWait == nil {
a.relayAcceptanceMinWait = defaultRelayAcceptanceMinWait
} else {
a.relayAcceptanceMinWait = *config.RelayAcceptanceMinWait
}
// Make sure the buffer doesn't grow indefinitely. // Make sure the buffer doesn't grow indefinitely.
// NOTE: We actually won't get anywhere close to this limit. // NOTE: We actually won't get anywhere close to this limit.
// SRTP will constantly read from the endpoint and drop packets if it's full. // SRTP will constantly read from the endpoint and drop packets if it's full.
@@ -299,16 +366,30 @@ func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remoteP
a.log.Debugf("Started agent: isControlling? %t, remoteUfrag: %q, remotePwd: %q", isControlling, remoteUfrag, remotePwd) a.log.Debugf("Started agent: isControlling? %t, remoteUfrag: %q, remotePwd: %q", isControlling, remoteUfrag, remotePwd)
return a.run(func(agent *Agent) { return a.run(func(agent *Agent) {
agent.isControlling = isControlling
agent.remoteUfrag = remoteUfrag
agent.remotePwd = remotePwd
a.checklist = make([]*candidatePair, 0)
for networkType, localCandidates := range a.localCandidates {
if remoteCandidates, ok := a.remoteCandidates[networkType]; ok {
for _, localCandidate := range localCandidates {
for _, remoteCandidate := range remoteCandidates {
a.addPair(localCandidate, remoteCandidate)
}
}
}
}
if isControlling { if isControlling {
a.selector = &controllingSelector{agent: a, log: a.log} a.selector = &controllingSelector{agent: a, log: a.log}
} else { } else {
a.selector = &controlledSelector{agent: a, log: a.log} a.selector = &controlledSelector{agent: a, log: a.log}
} }
a.selector.Start()
agent.isControlling = isControlling a.selector.Start()
agent.remoteUfrag = remoteUfrag
agent.remotePwd = remotePwd
agent.updateConnectionState(ConnectionStateChecking) agent.updateConnectionState(ConnectionStateChecking)
@@ -332,32 +413,6 @@ func (a *Agent) updateConnectionState(newState ConnectionState) {
} }
} }
// findValidPair scans the agent's valid pairs for one whose local and remote
// candidates both compare equal (==) to the given ones. It returns the first
// match, or nil when there is none.
func (a *Agent) findValidPair(local, remote Candidate) *candidatePair {
	for i := range a.validPairs {
		pair := a.validPairs[i]
		if pair.local != local || pair.remote != remote {
			continue
		}
		return pair
	}
	return nil
}
// addValidPair records (local, remote) as a valid pair and returns it.
// If the pair is already in the valid list, the existing entry is returned
// and nothing is appended.
func (a *Agent) addValidPair(local, remote Candidate) *candidatePair {
	if existing := a.findValidPair(local, remote); existing != nil {
		a.log.Tracef("Candidate pair is already valid: %s", existing)
		return existing
	}

	created := newCandidatePair(local, remote, a.isControlling)
	a.log.Tracef("Found valid candidate pair: %s", created)

	// Keep track of pairs with successful bindings, since any of them
	// can be used for communication until the final pair is selected:
	// https://tools.ietf.org/html/draft-ietf-ice-rfc5245bis-20#section-12
	a.validPairs = append(a.validPairs, created)
	return created
}
func (a *Agent) setSelectedPair(p *candidatePair) { func (a *Agent) setSelectedPair(p *candidatePair) {
a.log.Tracef("Set selected candidate pair: %s", p) a.log.Tracef("Set selected candidate pair: %s", p)
// Notify when the selected pair changes // Notify when the selected pair changes
@@ -370,12 +425,67 @@ func (a *Agent) setSelectedPair(p *candidatePair) {
a.onConnectedOnce.Do(func() { close(a.onConnected) }) a.onConnectedOnce.Do(func() { close(a.onConnected) })
} }
func (a *Agent) getBestValidPair() *candidatePair { func (a *Agent) pingAllCandidates() {
if len(a.validPairs) == 0 { for _, p := range a.checklist {
return nil if p.state != candidatePairStateChecking {
continue
}
if p.bindingRequestCount > a.maxBindingRequests {
a.log.Tracef("max requests reached for pair %s, marking it as failed\n", p)
p.state = candidatePairStateFailed
} else {
a.selector.PingCandidate(p.local, p.remote)
p.bindingRequestCount++
}
} }
sort.Sort(byPairPriority{a.validPairs}) }
return a.validPairs[0]
// getBestAvailableCandidatePair returns the highest-priority pair in the
// checklist whose state is not failed, or nil if every pair has failed or the
// checklist is empty. When priorities tie, the pair that appears earlier in
// the checklist is kept.
func (a *Agent) getBestAvailableCandidatePair() *candidatePair {
	var top *candidatePair
	for i := range a.checklist {
		pair := a.checklist[i]
		if pair.state == candidatePairStateFailed {
			continue
		}
		if top == nil || top.Priority() < pair.Priority() {
			top = pair
		}
	}
	return top
}
// getBestValidCandidatePair returns the highest-priority pair in the
// checklist whose state is valid, or nil if no pair is valid. When
// priorities tie, the pair that appears earlier in the checklist is kept.
func (a *Agent) getBestValidCandidatePair() *candidatePair {
	var winner *candidatePair
	for _, candidate := range a.checklist {
		// Only validated pairs are eligible; among those, a later pair
		// replaces the current winner only with a strictly higher priority.
		if candidate.state == candidatePairStateValid &&
			(winner == nil || winner.Priority() < candidate.Priority()) {
			winner = candidate
		}
	}
	return winner
}
// addPair creates a candidate pair for (local, remote) using the agent's
// current controlling role, appends it to the checklist and returns it.
// It does not check whether an equal pair is already on the checklist.
func (a *Agent) addPair(local, remote Candidate) *candidatePair {
	pair := newCandidatePair(local, remote, a.isControlling)
	a.checklist = append(a.checklist, pair)
	return pair
}
func (a *Agent) findPair(local, remote Candidate) *candidatePair {
for _, p := range a.checklist {
if p.local == local && p.remote == remote {
return p
}
}
return nil
} }
// A task is a // A task is a
@@ -455,22 +565,6 @@ func (a *Agent) checkKeepalive() {
} }
} }
// pingAllCandidates asks the selector to ping every (local, remote) candidate
// combination that shares a network type.
// Note: the caller should hold the agent lock.
func (a *Agent) pingAllCandidates() {
	for networkType, locals := range a.localCandidates {
		remotes, found := a.remoteCandidates[networkType]
		if !found {
			// No remote candidates on this network type yet.
			continue
		}
		for i := range locals {
			for j := range remotes {
				a.selector.PingCandidate(locals[i], remotes[j])
			}
		}
	}
}
// AddRemoteCandidate adds a new remote candidate // AddRemoteCandidate adds a new remote candidate
func (a *Agent) AddRemoteCandidate(c Candidate) error { func (a *Agent) AddRemoteCandidate(c Candidate) error {
return a.run(func(agent *Agent) { return a.run(func(agent *Agent) {
@@ -498,6 +592,12 @@ func (a *Agent) addRemoteCandidate(c Candidate) {
} }
} }
} }
if localCandidates, ok := a.localCandidates[c.NetworkType()]; ok {
for _, localCandidate := range localCandidates {
a.addPair(localCandidate, c)
}
}
} }
// GetLocalCandidates returns the local candidates // GetLocalCandidates returns the local candidates

View File

@@ -35,11 +35,11 @@ func TestPairSearch(t *testing.T) {
t.Fatalf("Error constructing ice.Agent") t.Fatalf("Error constructing ice.Agent")
} }
if len(a.validPairs) != 0 { if len(a.checklist) != 0 {
t.Fatalf("TestPairSearch is only a valid test if a.validPairs is empty on construction") t.Fatalf("TestPairSearch is only a valid test if a.validPairs is empty on construction")
} }
cp := a.getBestValidPair() cp := a.getBestAvailableCandidatePair()
if cp != nil { if cp != nil {
t.Fatalf("No Candidate pairs should exist") t.Fatalf("No Candidate pairs should exist")
@@ -110,8 +110,14 @@ func TestPairPriority(t *testing.T) {
} }
for _, remote := range []Candidate{relayRemote, srflxRemote, prflxRemote, hostRemote} { for _, remote := range []Candidate{relayRemote, srflxRemote, prflxRemote, hostRemote} {
a.addValidPair(hostLocal, remote) p := a.findPair(hostLocal, remote)
bestPair := a.getBestValidPair()
if p == nil {
p = a.addPair(hostLocal, remote)
}
p.state = candidatePairStateValid
bestPair := a.getBestValidCandidatePair()
if bestPair.String() != (&candidatePair{remote: remote, local: hostLocal}).String() { if bestPair.String() != (&candidatePair{remote: remote, local: hostLocal}).String() {
t.Fatalf("Unexpected bestPair %s (expected remote: %s)", bestPair, remote) t.Fatalf("Unexpected bestPair %s (expected remote: %s)", bestPair, remote)
} }

View File

@@ -6,19 +6,42 @@ import (
"github.com/pion/stun" "github.com/pion/stun"
) )
// candidatePairState is the connectivity-check state of a candidate pair.
type candidatePairState int

const (
	// candidatePairStateChecking is the first state; values start at 1, so
	// the zero value of candidatePairState is not a named state.
	candidatePairStateChecking candidatePairState = iota + 1
	// candidatePairStateFailed marks a pair whose checks have failed.
	candidatePairStateFailed
	// candidatePairStateValid marks a pair that has been validated.
	candidatePairStateValid
)

// String returns the lowercase name of the state, or
// "Unknown candidate pair state" for any value that is not a named state.
func (c candidatePairState) String() string {
	// Index 0 is intentionally left empty: no named state has value 0.
	names := [...]string{
		candidatePairStateChecking: "checking",
		candidatePairStateFailed:   "failed",
		candidatePairStateValid:    "valid",
	}
	if c >= candidatePairStateChecking && int(c) < len(names) {
		return names[c]
	}
	return "Unknown candidate pair state"
}
func newCandidatePair(local, remote Candidate, controlling bool) *candidatePair { func newCandidatePair(local, remote Candidate, controlling bool) *candidatePair {
return &candidatePair{ return &candidatePair{
iceRoleControlling: controlling, iceRoleControlling: controlling,
remote: remote, remote: remote,
local: local, local: local,
state: candidatePairStateChecking,
} }
} }
// candidatePair represents a combination of a local and remote candidate // candidatePair represents a combination of a local and remote candidate
type candidatePair struct { type candidatePair struct {
iceRoleControlling bool iceRoleControlling bool
remote Candidate remote Candidate
local Candidate local Candidate
bindingRequestCount uint16
state candidatePairState
} }
func (p *candidatePair) String() string { func (p *candidatePair) String() string {

View File

@@ -2,6 +2,7 @@ package ice
import ( import (
"net" "net"
"time"
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun" "github.com/pion/stun"
@@ -16,12 +17,51 @@ type pairCandidateSelector interface {
} }
type controllingSelector struct { type controllingSelector struct {
agent *Agent startTime time.Time
nominatedPair *candidatePair agent *Agent
log logging.LeveledLogger nominatedPair *candidatePair
nominationRequestCount uint16
log logging.LeveledLogger
} }
func (s *controllingSelector) Start() { func (s *controllingSelector) Start() {
s.startTime = time.Now()
go func() {
time.Sleep(s.agent.candidateSelectionTimeout)
err := s.agent.run(func(a *Agent) {
if s.nominatedPair == nil {
p := s.agent.getBestValidCandidatePair()
if p == nil {
s.log.Trace("check timeout reached and no valid candidate pair found, marking connection as failed")
s.agent.updateConnectionState(ConnectionStateFailed)
} else {
s.log.Tracef("check timeout reached, nominating (%s, %s)", p.local.String(), p.remote.String())
s.nominatedPair = p
s.nominatePair(p)
}
}
})
if err != nil {
s.log.Errorf("error processing checkCandidatesTimeout handler %v", err.Error())
}
}()
}
func (s *controllingSelector) isNominatable(c Candidate) bool {
switch {
case c.Type() == CandidateTypeHost:
return time.Since(s.startTime).Nanoseconds() > s.agent.hostAcceptanceMinWait.Nanoseconds()
case c.Type() == CandidateTypeServerReflexive:
return time.Since(s.startTime).Nanoseconds() > s.agent.srflxAcceptanceMinWait.Nanoseconds()
case c.Type() == CandidateTypePeerReflexive:
return time.Since(s.startTime).Nanoseconds() > s.agent.prflxAcceptanceMinWait.Nanoseconds()
case c.Type() == CandidateTypeRelay:
return time.Since(s.startTime).Nanoseconds() > s.agent.relayAcceptanceMinWait.Nanoseconds()
}
s.log.Errorf("isNominatable invalid candidate type %s", c.Type().String())
return false
} }
func (s *controllingSelector) ContactCandidates() { func (s *controllingSelector) ContactCandidates() {
@@ -32,8 +72,21 @@ func (s *controllingSelector) ContactCandidates() {
s.agent.checkKeepalive() s.agent.checkKeepalive()
} }
case s.nominatedPair != nil: case s.nominatedPair != nil:
if s.nominationRequestCount > s.agent.maxBindingRequests {
s.log.Trace("max nomination requests reached, setting the connection state to failed")
s.agent.updateConnectionState(ConnectionStateFailed)
return
}
s.nominatePair(s.nominatedPair) s.nominatePair(s.nominatedPair)
default: default:
p := s.agent.getBestValidCandidatePair()
if p != nil && s.isNominatable(p.local) && s.isNominatable(p.remote) {
s.log.Tracef("Nominatable pair found, nominating (%s, %s)", p.local.String(), p.remote.String())
s.nominatedPair = p
s.nominatePair(p)
return
}
s.log.Trace("pinging all candidates") s.log.Trace("pinging all candidates")
s.agent.pingAllCandidates() s.agent.pingAllCandidates()
} }
@@ -60,15 +113,29 @@ func (s *controllingSelector) nominatePair(pair *candidatePair) {
s.log.Tracef("ping STUN (nominate candidate pair) from %s to %s\n", pair.local.String(), pair.remote.String()) s.log.Tracef("ping STUN (nominate candidate pair) from %s to %s\n", pair.local.String(), pair.remote.String())
s.agent.sendBindingRequest(msg, pair.local, pair.remote) s.agent.sendBindingRequest(msg, pair.local, pair.remote)
s.nominationRequestCount++
} }
func (s *controllingSelector) HandleBindingRequest(m *stun.Message, local, remote Candidate) { func (s *controllingSelector) HandleBindingRequest(m *stun.Message, local, remote Candidate) {
s.agent.sendBindingSuccess(m, local, remote) s.agent.sendBindingSuccess(m, local, remote)
p := s.agent.findValidPair(local, remote) p := s.agent.findPair(local, remote)
if p != nil && s.nominatedPair == nil && s.agent.selectedPair == nil {
s.nominatedPair = p if p == nil {
s.nominatePair(p) s.agent.addPair(local, remote)
return
}
if p.state == candidatePairStateValid && s.nominatedPair == nil && s.agent.selectedPair == nil {
bestPair := s.agent.getBestAvailableCandidatePair()
if bestPair == nil {
s.log.Tracef("No best pair available\n")
} else if bestPair.Equal(p) && s.isNominatable(p.local) && s.isNominatable(p.remote) {
s.log.Tracef("The candidate (%s, %s) is the best candidate available, marking it as nominated\n",
p.local.String(), p.remote.String())
s.nominatedPair = p
s.nominatePair(p)
}
} }
} }
@@ -89,8 +156,16 @@ func (s *controllingSelector) HandleSucessResponse(m *stun.Message, local, remot
} }
s.log.Tracef("inbound STUN (SuccessResponse) from %s to %s", remote.String(), local.String()) s.log.Tracef("inbound STUN (SuccessResponse) from %s to %s", remote.String(), local.String())
p := s.agent.addValidPair(local, remote) p := s.agent.findPair(local, remote)
if p == nil {
// This shouldn't happen
s.log.Error("Success response from invalid candidate pair")
return
}
p.state = candidatePairStateValid
s.log.Tracef("Found valid candidate pair: %s", p)
if pendingRequest.isUseCandidate && s.agent.selectedPair == nil { if pendingRequest.isUseCandidate && s.agent.selectedPair == nil {
s.agent.setSelectedPair(p) s.agent.setSelectedPair(p)
} }
@@ -173,15 +248,31 @@ func (s *controlledSelector) HandleSucessResponse(m *stun.Message, local, remote
} }
s.log.Tracef("inbound STUN (SuccessResponse) from %s to %s", remote.String(), local.String()) s.log.Tracef("inbound STUN (SuccessResponse) from %s to %s", remote.String(), local.String())
s.agent.addValidPair(local, remote)
p := s.agent.findPair(local, remote)
if p == nil {
// This shouldn't happen
s.log.Error("Success response from invalid candidate pair")
return
}
p.state = candidatePairStateValid
s.log.Tracef("Found valid candidate pair: %s", p)
} }
func (s *controlledSelector) HandleBindingRequest(m *stun.Message, local, remote Candidate) { func (s *controlledSelector) HandleBindingRequest(m *stun.Message, local, remote Candidate) {
if m.Contains(stun.AttrUseCandidate) { useCandidate := m.Contains(stun.AttrUseCandidate)
// https://tools.ietf.org/html/rfc8445#section-7.3.1.5
p := s.agent.findValidPair(local, remote)
if p != nil { p := s.agent.findPair(local, remote)
if p == nil {
p = s.agent.addPair(local, remote)
}
if useCandidate {
// https://tools.ietf.org/html/rfc8445#section-7.3.1.5
if p.state == candidatePairStateValid {
// If the state of this pair is Succeeded, it means that the check // If the state of this pair is Succeeded, it means that the check
// previously sent by this pair produced a successful response and // previously sent by this pair produced a successful response and
// generated a valid pair (Section 7.2.5.3.2). The agent sets the // generated a valid pair (Section 7.2.5.3.2). The agent sets the

View File

@@ -334,10 +334,11 @@ func copyCandidate(o Candidate) Candidate {
case *CandidateHost: case *CandidateHost:
return &CandidateHost{ return &CandidateHost{
candidateBase{ candidateBase{
networkType: orig.networkType, candidateType: orig.candidateType,
ip: orig.ip, networkType: orig.networkType,
port: orig.port, ip: orig.ip,
component: orig.component, port: orig.port,
component: orig.component,
}, },
} }
default: default: