Upgrade golangci-lint, more linters

Introduces new linters, upgrade golangci-lint to version (v1.63.4)
This commit is contained in:
Joe Turki
2025-01-17 08:21:15 -06:00
parent 647b9786dd
commit cad1676659
67 changed files with 1619 additions and 1019 deletions

View File

@@ -25,17 +25,32 @@ linters-settings:
- ^os.Exit$
- ^panic$
- ^print(ln)?$
varnamelen:
max-distance: 12
min-name-length: 2
ignore-type-assert-ok: true
ignore-map-index-ok: true
ignore-chan-recv-ok: true
ignore-decls:
- i int
- n int
- w io.Writer
- r io.Reader
- b []byte
linters:
enable:
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
- bidichk # Checks for dangerous unicode character sequences
- bodyclose # checks whether HTTP response body is closed successfully
- containedctx # containedctx is a linter that detects struct contained context.Context field
- contextcheck # check the function whether use a non-inherited context
- cyclop # checks function and package cyclomatic complexity
- decorder # check declaration order and count of types, constants, variables and functions
- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
- dupl # Tool for code clone detection
- durationcheck # check for two durations multiplied together
- err113 # Golang linter to check the errors handling expressions
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
- errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted.
- errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`.
@@ -46,18 +61,17 @@ linters:
- forcetypeassert # finds forced type assertions
- gci # Gci control golang package import order and make it always deterministic.
- gochecknoglobals # Checks that no globals are present in Go code
- gochecknoinits # Checks that no init functions are present in Go code
- gocognit # Computes and checks the cognitive complexity of functions
- goconst # Finds repeated strings that could be replaced by a constant
- gocritic # The most opinionated Go source code linter
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- godox # Tool for detection of FIXME, TODO and other comment keywords
- err113 # Golang linter to check the errors handling expressions
- gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
- gofumpt # Gofumpt checks whether code was gofumpt-ed.
- goheader # Checks is file header matches to pattern
- goimports # Goimports does everything that gofmt does. Additionally it checks unused imports
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- goprintffuncname # Checks that printf-like functions are named with `f` at the end
- gosec # Inspects source code for security problems
- gosimple # Linter for Go source code that specializes in simplifying a code
@@ -65,9 +79,15 @@ linters:
- grouper # An analyzer to analyze expression groups.
- importas # Enforces consistent import aliases
- ineffassign # Detects when assignments to existing variables are not used
- lll # Reports long lines
- maintidx # maintidx measures the maintainability index of each function.
- makezero # Finds slice declarations with non-zero initial length
- misspell # Finds commonly misspelled English words in comments
- nakedret # Finds naked returns in functions greater than a specified function length
- nestif # Reports deeply nested if statements
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
- nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value.
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
- noctx # noctx finds sending http request without context.Context
- predeclared # find code that shadows one of Go's predeclared identifiers
- revive # golint replacement, finds style mistakes
@@ -75,28 +95,22 @@ linters:
- stylecheck # Stylecheck is a replacement for golint
- tagliatelle # Checks the struct tags.
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
- unconvert # Remove unnecessary type conversions
- unparam # Reports unused function parameters
- unused # Checks Go code for unused constants, variables, functions and types
- varnamelen # checks that the length of a variable's name matches its scope
- wastedassign # wastedassign finds wasted assignment statements
- whitespace # Tool for detection of leading and trailing whitespace
disable:
- depguard # Go linter that checks if package imports are in a list of acceptable packages
- containedctx # containedctx is a linter that detects struct contained context.Context field
- cyclop # checks function and package cyclomatic complexity
- funlen # Tool for detection of long functions
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- gomnd # An analyzer to detect magic numbers.
- gochecknoinits # Checks that no init functions are present in Go code
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- interfacebloat # A linter that checks length of interface.
- ireturn # Accept Interfaces, Return Concrete Types
- lll # Reports long lines
- maintidx # maintidx measures the maintainability index of each function.
- makezero # Finds slice declarations with non-zero initial length
- nakedret # Finds naked returns in functions greater than a specified function length
- nestif # Reports deeply nested if statements
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
- mnd # An analyzer to detect magic numbers
- nolintlint # Reports ill-formed or insufficient nolint directives
- paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test
- prealloc # Finds slice declarations that could potentially be preallocated
@@ -104,8 +118,7 @@ linters:
- rowserrcheck # checks whether Err of rows is checked successfully
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
- testpackage # linter that makes you use a separate _test package
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
- varnamelen # checks that the length of a variable's name matches its scope
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- wrapcheck # Checks that errors returned from external packages are wrapped
- wsl # Whitespace Linter - Forces you to use empty lines!

View File

@@ -21,7 +21,12 @@ type activeTCPConn struct {
closed int32
}
func newActiveTCPConn(ctx context.Context, localAddress string, remoteAddress netip.AddrPort, log logging.LeveledLogger) (a *activeTCPConn) {
func newActiveTCPConn(
ctx context.Context,
localAddress string,
remoteAddress netip.AddrPort,
log logging.LeveledLogger,
) (a *activeTCPConn) {
a = &activeTCPConn{
readBuffer: packetio.NewBuffer(),
writeBuffer: packetio.NewBuffer(),
@@ -31,7 +36,8 @@ func newActiveTCPConn(ctx context.Context, localAddress string, remoteAddress ne
if err != nil {
atomic.StoreInt32(&a.closed, 1)
log.Infof("Failed to dial TCP address %s: %v", remoteAddress, err)
return
return a
}
a.localAddr.Store(laddr)
@@ -46,6 +52,7 @@ func newActiveTCPConn(ctx context.Context, localAddress string, remoteAddress ne
conn, err := dialer.DialContext(ctx, "tcp", remoteAddress.String())
if err != nil {
log.Infof("Failed to dial TCP address %s: %v", remoteAddress, err)
return
}
a.remoteAddr.Store(conn.RemoteAddr())
@@ -57,11 +64,13 @@ func newActiveTCPConn(ctx context.Context, localAddress string, remoteAddress ne
n, err := readStreamingPacket(conn, buff)
if err != nil {
log.Infof("Failed to read streaming packet: %s", err)
break
}
if _, err := a.readBuffer.Write(buff[:n]); err != nil {
log.Infof("Failed to write to buffer: %s", err)
break
}
}
@@ -73,11 +82,13 @@ func newActiveTCPConn(ctx context.Context, localAddress string, remoteAddress ne
n, err := a.writeBuffer.Read(buff)
if err != nil {
log.Infof("Failed to read from buffer: %s", err)
break
}
if _, err = writeStreamingPacket(conn, buff[:n]); err != nil {
log.Infof("Failed to write streaming packet: %s", err)
break
}
}
@@ -98,6 +109,7 @@ func (a *activeTCPConn) ReadFrom(buff []byte) (n int, srcAddr net.Addr, err erro
n, err = a.readBuffer.Read(buff)
// RemoteAddr is assuredly set *after* we can read from the buffer
srcAddr = a.RemoteAddr()
return
}
@@ -113,6 +125,7 @@ func (a *activeTCPConn) Close() error {
atomic.StoreInt32(&a.closed, 1)
_ = a.readBuffer.Close()
_ = a.writeBuffer.Close()
return nil
}

View File

@@ -21,19 +21,25 @@ import (
)
func getLocalIPAddress(t *testing.T, networkType NetworkType) netip.Addr {
t.Helper()
net, err := stdnet.NewNet()
require.NoError(t, err)
_, localAddrs, err := localInterfaces(net, problematicNetworkInterfaces, nil, []NetworkType{networkType}, false)
require.NoError(t, err)
require.NotEmpty(t, localAddrs)
return localAddrs[0]
}
func ipv6Available(t *testing.T) bool {
t.Helper()
net, err := stdnet.NewNet()
require.NoError(t, err)
_, localAddrs, err := localInterfaces(net, problematicNetworkInterfaces, nil, []NetworkType{NetworkTypeTCP6}, false)
require.NoError(t, err)
return len(localAddrs) > 0
}
@@ -89,14 +95,14 @@ func TestActiveTCP(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
r := require.New(t)
req := require.New(t)
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: testCase.listenIPAddress.AsSlice(),
Port: listenPort,
Zone: testCase.listenIPAddress.Zone(),
})
r.NoError(err)
req.NoError(err)
defer func() {
_ = listener.Close()
}()
@@ -113,7 +119,7 @@ func TestActiveTCP(t *testing.T) {
_ = tcpMux.Close()
}()
r.NotNil(tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
req.NotNil(tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
hostAcceptanceMinWait := 100 * time.Millisecond
cfg := &AgentConfig{
@@ -128,8 +134,8 @@ func TestActiveTCP(t *testing.T) {
cfg.MulticastDNSMode = MulticastDNSModeQueryAndGather
}
passiveAgent, err := NewAgent(cfg)
r.NoError(err)
r.NotNil(passiveAgent)
req.NoError(err)
req.NotNil(passiveAgent)
activeAgent, err := NewAgent(&AgentConfig{
CandidateTypes: []CandidateType{CandidateTypeHost},
@@ -138,44 +144,44 @@ func TestActiveTCP(t *testing.T) {
HostAcceptanceMinWait: &hostAcceptanceMinWait,
InterfaceFilter: problematicNetworkInterfaces,
})
r.NoError(err)
r.NotNil(activeAgent)
req.NoError(err)
req.NotNil(activeAgent)
passiveAgentConn, activeAgenConn := connect(passiveAgent, activeAgent)
r.NotNil(passiveAgentConn)
r.NotNil(activeAgenConn)
req.NotNil(passiveAgentConn)
req.NotNil(activeAgenConn)
defer func() {
r.NoError(activeAgenConn.Close())
r.NoError(passiveAgentConn.Close())
req.NoError(activeAgenConn.Close())
req.NoError(passiveAgentConn.Close())
}()
pair := passiveAgent.getSelectedPair()
r.NotNil(pair)
r.Equal(testCase.selectedPairNetworkType, pair.Local.NetworkType().NetworkShort())
req.NotNil(pair)
req.Equal(testCase.selectedPairNetworkType, pair.Local.NetworkType().NetworkShort())
foo := []byte("foo")
_, err = passiveAgentConn.Write(foo)
r.NoError(err)
req.NoError(err)
buffer := make([]byte, 1024)
n, err := activeAgenConn.Read(buffer)
r.NoError(err)
r.Equal(foo, buffer[:n])
req.NoError(err)
req.Equal(foo, buffer[:n])
bar := []byte("bar")
_, err = activeAgenConn.Write(bar)
r.NoError(err)
req.NoError(err)
n, err = passiveAgentConn.Read(buffer)
r.NoError(err)
r.Equal(bar, buffer[:n])
req.NoError(err)
req.Equal(bar, buffer[:n])
})
}
}
// Assert that Active TCP connectivity isn't established inside
// the main thread of the Agent
// Assert that Active TCP connectivity isn't established inside.
// the main thread of the Agent.
func TestActiveTCP_NonBlocking(t *testing.T) {
defer test.CheckRoutines(t)()
@@ -219,7 +225,7 @@ func TestActiveTCP_NonBlocking(t *testing.T) {
<-isConnected
}
// Assert that we ignore remote TCP candidates when running a UDP Only Agent
// Assert that we ignore remote TCP candidates when running a UDP Only Agent.
func TestActiveTCP_Respect_NetworkTypes(t *testing.T) {
defer test.CheckRoutines(t)()
defer test.TimeOut(time.Second * 5).Stop()
@@ -271,7 +277,9 @@ func TestActiveTCP_Respect_NetworkTypes(t *testing.T) {
})
require.NoError(t, err)
invalidCandidate, err := UnmarshalCandidate(fmt.Sprintf("1052353102 1 tcp 1675624447 127.0.0.1 %s typ host tcptype passive", port))
invalidCandidate, err := UnmarshalCandidate(
fmt.Sprintf("1052353102 1 tcp 1675624447 127.0.0.1 %s typ host tcptype passive", port),
)
require.NoError(t, err)
require.NoError(t, aAgent.AddRemoteCandidate(invalidCandidate))
require.NoError(t, bAgent.AddRemoteCandidate(invalidCandidate))

18
addr.go
View File

@@ -16,6 +16,7 @@ func addrWithOptionalZone(addr netip.Addr, zone string) netip.Addr {
if addr.Is6() && (addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast()) {
return addr.WithZone(zone)
}
return addr
}
@@ -30,22 +31,25 @@ func parseAddrFromIface(in net.Addr, ifcName string) (netip.Addr, int, NetworkTy
// net.IPNet does not have a Zone but we provide it from the interface
addr = addrWithOptionalZone(addr, ifcName)
}
return addr, port, nt, nil
}
func parseAddr(in net.Addr) (netip.Addr, int, NetworkType, error) {
func parseAddr(in net.Addr) (netip.Addr, int, NetworkType, error) { //nolint:cyclop
switch addr := in.(type) {
case *net.IPNet:
ipAddr, err := ipAddrToNetIP(addr.IP, "")
if err != nil {
return netip.Addr{}, 0, 0, err
}
return ipAddr, 0, 0, nil
case *net.IPAddr:
ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone)
if err != nil {
return netip.Addr{}, 0, 0, err
}
return ipAddr, 0, 0, nil
case *net.UDPAddr:
ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone)
@@ -58,6 +62,7 @@ func parseAddr(in net.Addr) (netip.Addr, int, NetworkType, error) {
} else {
nt = NetworkTypeUDP6
}
return ipAddr, addr.Port, nt, nil
case *net.TCPAddr:
ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone)
@@ -70,6 +75,7 @@ func parseAddr(in net.Addr) (netip.Addr, int, NetworkType, error) {
} else {
nt = NetworkTypeTCP6
}
return ipAddr, addr.Port, nt, nil
default:
return netip.Addr{}, 0, 0, addrParseError{in}
@@ -100,6 +106,7 @@ func ipAddrToNetIP(ip []byte, zone string) (netip.Addr, error) {
// we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable.
netIPAddr = netIPAddr.Unmap()
netIPAddr = addrWithOptionalZone(netIPAddr, zone)
return netIPAddr, nil
}
@@ -134,12 +141,13 @@ func toAddrPort(addr net.Addr) AddrPort {
switch addr := addr.(type) {
case *net.UDPAddr:
copy(ap[:16], addr.IP.To16())
ap[16] = uint8(addr.Port >> 8)
ap[17] = uint8(addr.Port)
ap[16] = uint8(addr.Port >> 8) //nolint:gosec // G115 false positive
ap[17] = uint8(addr.Port) //nolint:gosec // G115 false positive
case *net.TCPAddr:
copy(ap[:16], addr.IP.To16())
ap[16] = uint8(addr.Port >> 8)
ap[17] = uint8(addr.Port)
ap[16] = uint8(addr.Port >> 8) //nolint:gosec // G115 false positive
ap[17] = uint8(addr.Port) //nolint:gosec // G115 false positive
}
return ap
}

330
agent.go
View File

@@ -35,7 +35,7 @@ type bindingRequest struct {
isUseCandidate bool
}
// Agent represents the ICE agent
// Agent represents the ICE agent.
type Agent struct {
loop *taskloop.Loop
@@ -149,8 +149,8 @@ type Agent struct {
enableUseCandidateCheckPriority bool
}
// NewAgent creates a new Agent
func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
// NewAgent creates a new Agent.
func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit,cyclop
var err error
if config.PortMax < config.PortMin {
return nil, ErrPort
@@ -180,7 +180,7 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
startedCtx, startedFn := context.WithCancel(context.Background())
a := &Agent{
agent := &Agent{
tieBreaker: globalMathRandomGenerator.Uint64(),
lite: config.Lite,
gatheringState: GatheringStateNew,
@@ -224,34 +224,46 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
enableUseCandidateCheckPriority: config.EnableUseCandidateCheckPriority,
}
a.connectionStateNotifier = &handlerNotifier{connectionStateFunc: a.onConnectionStateChange, done: make(chan struct{})}
a.candidateNotifier = &handlerNotifier{candidateFunc: a.onCandidate, done: make(chan struct{})}
a.selectedCandidatePairNotifier = &handlerNotifier{candidatePairFunc: a.onSelectedCandidatePairChange, done: make(chan struct{})}
agent.connectionStateNotifier = &handlerNotifier{
connectionStateFunc: agent.onConnectionStateChange,
done: make(chan struct{}),
}
agent.candidateNotifier = &handlerNotifier{candidateFunc: agent.onCandidate, done: make(chan struct{})}
agent.selectedCandidatePairNotifier = &handlerNotifier{
candidatePairFunc: agent.onSelectedCandidatePairChange,
done: make(chan struct{}),
}
if a.net == nil {
a.net, err = stdnet.NewNet()
if agent.net == nil {
agent.net, err = stdnet.NewNet()
if err != nil {
return nil, fmt.Errorf("failed to create network: %w", err)
}
} else if _, isVirtual := a.net.(*vnet.Net); isVirtual {
a.log.Warn("Virtual network is enabled")
if a.mDNSMode != MulticastDNSModeDisabled {
a.log.Warn("Virtual network does not support mDNS yet")
} else if _, isVirtual := agent.net.(*vnet.Net); isVirtual {
agent.log.Warn("Virtual network is enabled")
if agent.mDNSMode != MulticastDNSModeDisabled {
agent.log.Warn("Virtual network does not support mDNS yet")
}
}
localIfcs, _, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, a.networkTypes, a.includeLoopback)
localIfcs, _, err := localInterfaces(
agent.net,
agent.interfaceFilter,
agent.ipFilter,
agent.networkTypes,
agent.includeLoopback,
)
if err != nil {
return nil, fmt.Errorf("error getting local interfaces: %w", err)
}
// Opportunistic mDNS: If we can't open the connection, that's ok: we
// can continue without it.
if a.mDNSConn, a.mDNSMode, err = createMulticastDNS(
a.net,
a.networkTypes,
if agent.mDNSConn, agent.mDNSMode, err = createMulticastDNS(
agent.net,
agent.networkTypes,
localIfcs,
a.includeLoopback,
agent.includeLoopback,
mDNSMode,
mDNSName,
log,
@@ -259,54 +271,60 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
log.Warnf("Failed to initialize mDNS %s: %v", mDNSName, err)
}
config.initWithDefaults(a)
config.initWithDefaults(agent)
// Make sure the buffer doesn't grow indefinitely.
// NOTE: We actually won't get anywhere close to this limit.
// SRTP will constantly read from the endpoint and drop packets if it's full.
a.buf.SetLimitSize(maxBufferSize)
agent.buf.SetLimitSize(maxBufferSize)
if agent.lite && (len(agent.candidateTypes) != 1 || agent.candidateTypes[0] != CandidateTypeHost) {
agent.closeMulticastConn()
if a.lite && (len(a.candidateTypes) != 1 || a.candidateTypes[0] != CandidateTypeHost) {
a.closeMulticastConn()
return nil, ErrLiteUsingNonHostCandidates
}
if len(config.Urls) > 0 && !containsCandidateType(CandidateTypeServerReflexive, a.candidateTypes) && !containsCandidateType(CandidateTypeRelay, a.candidateTypes) {
a.closeMulticastConn()
if len(config.Urls) > 0 &&
!containsCandidateType(CandidateTypeServerReflexive, agent.candidateTypes) &&
!containsCandidateType(CandidateTypeRelay, agent.candidateTypes) {
agent.closeMulticastConn()
return nil, ErrUselessUrlsProvided
}
if err = config.initExtIPMapping(a); err != nil {
a.closeMulticastConn()
if err = config.initExtIPMapping(agent); err != nil {
agent.closeMulticastConn()
return nil, err
}
a.loop = taskloop.New(func() {
a.removeUfragFromMux()
a.deleteAllCandidates()
a.startedFn()
agent.loop = taskloop.New(func() {
agent.removeUfragFromMux()
agent.deleteAllCandidates()
agent.startedFn()
if err := a.buf.Close(); err != nil {
a.log.Warnf("Failed to close buffer: %v", err)
if err := agent.buf.Close(); err != nil {
agent.log.Warnf("Failed to close buffer: %v", err)
}
a.closeMulticastConn()
a.updateConnectionState(ConnectionStateClosed)
agent.closeMulticastConn()
agent.updateConnectionState(ConnectionStateClosed)
a.gatherCandidateCancel()
if a.gatherCandidateDone != nil {
<-a.gatherCandidateDone
agent.gatherCandidateCancel()
if agent.gatherCandidateDone != nil {
<-agent.gatherCandidateDone
}
})
// Restart is also used to initialize the agent for the first time
if err := a.Restart(config.LocalUfrag, config.LocalPwd); err != nil {
a.closeMulticastConn()
_ = a.Close()
if err := agent.Restart(config.LocalUfrag, config.LocalPwd); err != nil {
agent.closeMulticastConn()
_ = agent.Close()
return nil, err
}
return a, nil
return agent, nil
}
func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remotePwd string) error {
@@ -348,7 +366,7 @@ func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remoteP
})
}
func (a *Agent) connectivityChecks() {
func (a *Agent) connectivityChecks() { //nolint:cyclop
lastConnectionState := ConnectionState(0)
checkingDuration := time.Time{}
@@ -372,6 +390,7 @@ func (a *Agent) connectivityChecks() {
// We have been in checking longer then Disconnect+Failed timeout, set the connection to Failed
if time.Since(checkingDuration) > a.disconnectedTimeout+a.failedTimeout {
a.updateConnectionState(ConnectionStateFailed)
return
}
default:
@@ -383,8 +402,8 @@ func (a *Agent) connectivityChecks() {
}
}
t := time.NewTimer(math.MaxInt64)
t.Stop()
timer := time.NewTimer(math.MaxInt64)
timer.Stop()
for {
interval := defaultKeepaliveInterval
@@ -406,18 +425,19 @@ func (a *Agent) connectivityChecks() {
updateInterval(a.disconnectedTimeout)
updateInterval(a.failedTimeout)
t.Reset(interval)
timer.Reset(interval)
select {
case <-a.forceCandidateContact:
if !t.Stop() {
<-t.C
if !timer.Stop() {
<-timer.C
}
contact()
case <-t.C:
case <-timer.C:
contact()
case <-a.loop.Done():
t.Stop()
timer.Stop()
return
}
}
@@ -440,22 +460,23 @@ func (a *Agent) updateConnectionState(newState ConnectionState) {
}
}
func (a *Agent) setSelectedPair(p *CandidatePair) {
if p == nil {
func (a *Agent) setSelectedPair(pair *CandidatePair) {
if pair == nil {
var nilPair *CandidatePair
a.selectedPair.Store(nilPair)
a.log.Tracef("Unset selected candidate pair")
return
}
p.nominated = true
a.selectedPair.Store(p)
a.log.Tracef("Set selected candidate pair: %s", p)
pair.nominated = true
a.selectedPair.Store(pair)
a.log.Tracef("Set selected candidate pair: %s", pair)
a.updateConnectionState(ConnectionStateConnected)
// Notify when the selected pair changes
a.selectedCandidatePairNotifier.EnqueueSelectedCandidatePair(p)
a.selectedCandidatePairNotifier.EnqueueSelectedCandidatePair(pair)
// Signal connected
a.onConnectedOnce.Do(func() { close(a.onConnected) })
@@ -498,6 +519,7 @@ func (a *Agent) getBestAvailableCandidatePair() *CandidatePair {
best = p
}
}
return best
}
@@ -514,12 +536,14 @@ func (a *Agent) getBestValidCandidatePair() *CandidatePair {
best = p
}
}
return best
}
func (a *Agent) addPair(local, remote Candidate) *CandidatePair {
p := newCandidatePair(local, remote, a.isControlling)
a.checklist = append(a.checklist, p)
return p
}
@@ -529,6 +553,7 @@ func (a *Agent) findPair(local, remote Candidate) *CandidatePair {
return p
}
}
return nil
}
@@ -578,68 +603,76 @@ func (a *Agent) checkKeepalive() {
}
}
// AddRemoteCandidate adds a new remote candidate
func (a *Agent) AddRemoteCandidate(c Candidate) error {
if c == nil {
// AddRemoteCandidate adds a new remote candidate.
func (a *Agent) AddRemoteCandidate(cand Candidate) error {
if cand == nil {
return nil
}
// TCP Candidates with TCP type active will probe server passive ones, so
// no need to do anything with them.
if c.TCPType() == TCPTypeActive {
a.log.Infof("Ignoring remote candidate with tcpType active: %s", c)
if cand.TCPType() == TCPTypeActive {
a.log.Infof("Ignoring remote candidate with tcpType active: %s", cand)
return nil
}
// If we have a mDNS Candidate lets fully resolve it before adding it locally
if c.Type() == CandidateTypeHost && strings.HasSuffix(c.Address(), ".local") {
if cand.Type() == CandidateTypeHost && strings.HasSuffix(cand.Address(), ".local") {
if a.mDNSMode == MulticastDNSModeDisabled {
a.log.Warnf("Remote mDNS candidate added, but mDNS is disabled: (%s)", c.Address())
a.log.Warnf("Remote mDNS candidate added, but mDNS is disabled: (%s)", cand.Address())
return nil
}
hostCandidate, ok := c.(*CandidateHost)
hostCandidate, ok := cand.(*CandidateHost)
if !ok {
return ErrAddressParseFailed
}
go a.resolveAndAddMulticastCandidate(hostCandidate)
return nil
}
go func() {
if err := a.loop.Run(a.loop, func(_ context.Context) {
// nolint: contextcheck
a.addRemoteCandidate(c)
a.addRemoteCandidate(cand)
}); err != nil {
a.log.Warnf("Failed to add remote candidate %s: %v", c.Address(), err)
a.log.Warnf("Failed to add remote candidate %s: %v", cand.Address(), err)
return
}
}()
return nil
}
func (a *Agent) resolveAndAddMulticastCandidate(c *CandidateHost) {
func (a *Agent) resolveAndAddMulticastCandidate(cand *CandidateHost) {
if a.mDNSConn == nil {
return
}
_, src, err := a.mDNSConn.QueryAddr(c.context(), c.Address())
_, src, err := a.mDNSConn.QueryAddr(cand.context(), cand.Address())
if err != nil {
a.log.Warnf("Failed to discover mDNS candidate %s: %v", c.Address(), err)
a.log.Warnf("Failed to discover mDNS candidate %s: %v", cand.Address(), err)
return
}
if err = c.setIPAddr(src); err != nil {
a.log.Warnf("Failed to discover mDNS candidate %s: %v", c.Address(), err)
if err = cand.setIPAddr(src); err != nil {
a.log.Warnf("Failed to discover mDNS candidate %s: %v", cand.Address(), err)
return
}
if err = a.loop.Run(a.loop, func(_ context.Context) {
// nolint: contextcheck
a.addRemoteCandidate(c)
a.addRemoteCandidate(cand)
}); err != nil {
a.log.Warnf("Failed to add mDNS candidate %s: %v", c.Address(), err)
a.log.Warnf("Failed to add mDNS candidate %s: %v", cand.Address(), err)
return
}
}
@@ -652,9 +685,16 @@ func (a *Agent) requestConnectivityCheck() {
}
func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) {
_, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{remoteCandidate.NetworkType()}, a.includeLoopback)
_, localIPs, err := localInterfaces(
a.net,
a.interfaceFilter,
a.ipFilter,
[]NetworkType{remoteCandidate.NetworkType()},
a.includeLoopback,
)
if err != nil {
a.log.Warnf("Failed to iterate local interfaces, host candidates will not be gathered %s", err)
return
}
@@ -662,19 +702,21 @@ func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) {
ip, _, _, err := parseAddr(remoteCandidate.addr())
if err != nil {
a.log.Warnf("Failed to parse address: %s; error: %s", remoteCandidate.addr(), err)
continue
}
conn := newActiveTCPConn(
a.loop,
net.JoinHostPort(localIPs[i].String(), "0"),
netip.AddrPortFrom(ip, uint16(remoteCandidate.Port())),
netip.AddrPortFrom(ip, uint16(remoteCandidate.Port())), //nolint:gosec // G115, no overflow, a port
a.log,
)
tcpAddr, ok := conn.LocalAddr().(*net.TCPAddr)
if !ok {
closeConnAndLog(conn, a.log, "Failed to create Active ICE-TCP Candidate: %v", errInvalidAddress)
continue
}
@@ -687,48 +729,52 @@ func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) {
})
if err != nil {
closeConnAndLog(conn, a.log, "Failed to create Active ICE-TCP Candidate: %v", err)
continue
}
localCandidate.start(a, conn, a.startedCh)
a.localCandidates[localCandidate.NetworkType()] = append(a.localCandidates[localCandidate.NetworkType()], localCandidate)
a.localCandidates[localCandidate.NetworkType()] = append(
a.localCandidates[localCandidate.NetworkType()],
localCandidate,
)
a.candidateNotifier.EnqueueCandidate(localCandidate)
a.addPair(localCandidate, remoteCandidate)
}
}
// addRemoteCandidate assumes you are holding the lock (must be execute using a.run)
func (a *Agent) addRemoteCandidate(c Candidate) {
set := a.remoteCandidates[c.NetworkType()]
// addRemoteCandidate assumes you are holding the lock (must be execute using a.run).
func (a *Agent) addRemoteCandidate(cand Candidate) { //nolint:cyclop
set := a.remoteCandidates[cand.NetworkType()]
for _, candidate := range set {
if candidate.Equal(c) {
if candidate.Equal(cand) {
return
}
}
acceptRemotePassiveTCPCandidate := false
// Assert that TCP4 or TCP6 is a enabled NetworkType locally
if !a.disableActiveTCP && c.TCPType() == TCPTypePassive {
if !a.disableActiveTCP && cand.TCPType() == TCPTypePassive {
for _, networkType := range a.networkTypes {
if c.NetworkType() == networkType {
if cand.NetworkType() == networkType {
acceptRemotePassiveTCPCandidate = true
}
}
}
if acceptRemotePassiveTCPCandidate {
a.addRemotePassiveTCPCandidate(c)
a.addRemotePassiveTCPCandidate(cand)
}
set = append(set, c)
a.remoteCandidates[c.NetworkType()] = set
set = append(set, cand)
a.remoteCandidates[cand.NetworkType()] = set
if c.TCPType() != TCPTypePassive {
if localCandidates, ok := a.localCandidates[c.NetworkType()]; ok {
if cand.TCPType() != TCPTypePassive {
if localCandidates, ok := a.localCandidates[cand.NetworkType()]; ok {
for _, localCandidate := range localCandidates {
a.addPair(localCandidate, c)
a.addPair(localCandidate, cand)
}
}
}
@@ -736,42 +782,43 @@ func (a *Agent) addRemoteCandidate(c Candidate) {
a.requestConnectivityCheck()
}
func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net.PacketConn) error {
func (a *Agent) addCandidate(ctx context.Context, cand Candidate, candidateConn net.PacketConn) error {
return a.loop.Run(ctx, func(context.Context) {
set := a.localCandidates[c.NetworkType()]
set := a.localCandidates[cand.NetworkType()]
for _, candidate := range set {
if candidate.Equal(c) {
a.log.Debugf("Ignore duplicate candidate: %s", c)
if err := c.close(); err != nil {
if candidate.Equal(cand) {
a.log.Debugf("Ignore duplicate candidate: %s", cand)
if err := cand.close(); err != nil {
a.log.Warnf("Failed to close duplicate candidate: %v", err)
}
if err := candidateConn.Close(); err != nil {
a.log.Warnf("Failed to close duplicate candidate connection: %v", err)
}
return
}
}
c.start(a, candidateConn, a.startedCh)
cand.start(a, candidateConn, a.startedCh)
set = append(set, c)
a.localCandidates[c.NetworkType()] = set
set = append(set, cand)
a.localCandidates[cand.NetworkType()] = set
if remoteCandidates, ok := a.remoteCandidates[c.NetworkType()]; ok {
if remoteCandidates, ok := a.remoteCandidates[cand.NetworkType()]; ok {
for _, remoteCandidate := range remoteCandidates {
a.addPair(c, remoteCandidate)
a.addPair(cand, remoteCandidate)
}
}
a.requestConnectivityCheck()
if !c.filterForLocationTracking() {
a.candidateNotifier.EnqueueCandidate(c)
if !cand.filterForLocationTracking() {
a.candidateNotifier.EnqueueCandidate(cand)
}
})
}
// GetRemoteCandidates returns the remote candidates
// GetRemoteCandidates returns the remote candidates.
func (a *Agent) GetRemoteCandidates() ([]Candidate, error) {
var res []Candidate
@@ -789,7 +836,7 @@ func (a *Agent) GetRemoteCandidates() ([]Candidate, error) {
return res, nil
}
// GetLocalCandidates returns the local candidates
// GetLocalCandidates returns the local candidates.
func (a *Agent) GetLocalCandidates() ([]Candidate, error) {
var res []Candidate
@@ -812,7 +859,7 @@ func (a *Agent) GetLocalCandidates() ([]Candidate, error) {
return res, nil
}
// GetLocalUserCredentials returns the local user credentials
// GetLocalUserCredentials returns the local user credentials.
func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) {
valSet := make(chan struct{})
err = a.loop.Run(a.loop, func(_ context.Context) {
@@ -824,10 +871,11 @@ func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) {
if err == nil {
<-valSet
}
return
}
// GetRemoteUserCredentials returns the remote user credentials
// GetRemoteUserCredentials returns the remote user credentials.
func (a *Agent) GetRemoteUserCredentials() (frag string, pwd string, err error) {
valSet := make(chan struct{})
err = a.loop.Run(a.loop, func(_ context.Context) {
@@ -839,6 +887,7 @@ func (a *Agent) GetRemoteUserCredentials() (frag string, pwd string, err error)
if err == nil {
<-valSet
}
return
}
@@ -854,7 +903,7 @@ func (a *Agent) removeUfragFromMux() {
}
}
// Close cleans up the Agent
// Close cleans up the Agent.
func (a *Agent) Close() error {
return a.close(false)
}
@@ -875,13 +924,14 @@ func (a *Agent) close(graceful bool) error {
a.connectionStateNotifier.Close(graceful)
a.candidateNotifier.Close(graceful)
a.selectedCandidatePairNotifier.Close(graceful)
return nil
}
// Remove all candidates. This closes any listening sockets
// and removes both the local and remote candidate lists.
//
// This is used for restarts, failures and on close
// This is used for restarts, failures and on close.
func (a *Agent) deleteAllCandidates() {
for net, cs := range a.localCandidates {
for _, c := range cs {
@@ -905,6 +955,7 @@ func (a *Agent) findRemoteCandidate(networkType NetworkType, addr net.Addr) Cand
ip, port, _, err := parseAddr(addr)
if err != nil {
a.log.Warnf("Failed to parse address: %s; error: %s", addr, err)
return nil
}
@@ -914,6 +965,7 @@ func (a *Agent) findRemoteCandidate(networkType NetworkType, addr net.Addr) Cand
return c
}
}
return nil
}
@@ -937,6 +989,7 @@ func (a *Agent) sendBindingSuccess(m *stun.Message, local, remote Candidate) {
ip, port, _, err := parseAddr(base.addr())
if err != nil {
a.log.Warnf("Failed to parse address: %s; error: %s", base.addr(), err)
return
}
@@ -976,70 +1029,85 @@ func (a *Agent) invalidatePendingBindingRequests(filterTime time.Time) {
}
// Assert that the passed TransactionID is in our pendingBindingRequests and returns the destination
// If the bindingRequest was valid remove it from our pending cache
// If the bindingRequest was valid remove it from our pending cache.
func (a *Agent) handleInboundBindingSuccess(id [stun.TransactionIDSize]byte) (bool, *bindingRequest, time.Duration) {
a.invalidatePendingBindingRequests(time.Now())
for i := range a.pendingBindingRequests {
if a.pendingBindingRequests[i].transactionID == id {
validBindingRequest := a.pendingBindingRequests[i]
a.pendingBindingRequests = append(a.pendingBindingRequests[:i], a.pendingBindingRequests[i+1:]...)
return true, &validBindingRequest, time.Since(validBindingRequest.timestamp)
}
}
return false, nil, 0
}
// handleInbound processes STUN traffic from a remote candidate
func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr) { //nolint:gocognit
// handleInbound processes STUN traffic from a remote candidate.
func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Addr) { //nolint:gocognit,cyclop
var err error
if m == nil || local == nil {
if msg == nil || local == nil {
return
}
if m.Type.Method != stun.MethodBinding ||
!(m.Type.Class == stun.ClassSuccessResponse ||
m.Type.Class == stun.ClassRequest ||
m.Type.Class == stun.ClassIndication) {
a.log.Tracef("Unhandled STUN from %s to %s class(%s) method(%s)", remote, local, m.Type.Class, m.Type.Method)
if msg.Type.Method != stun.MethodBinding ||
!(msg.Type.Class == stun.ClassSuccessResponse ||
msg.Type.Class == stun.ClassRequest ||
msg.Type.Class == stun.ClassIndication) {
a.log.Tracef("Unhandled STUN from %s to %s class(%s) method(%s)", remote, local, msg.Type.Class, msg.Type.Method)
return
}
if a.isControlling {
if m.Contains(stun.AttrICEControlling) {
if msg.Contains(stun.AttrICEControlling) {
a.log.Debug("Inbound STUN message: isControlling && a.isControlling == true")
return
} else if m.Contains(stun.AttrUseCandidate) {
} else if msg.Contains(stun.AttrUseCandidate) {
a.log.Debug("Inbound STUN message: useCandidate && a.isControlling == true")
return
}
} else {
if m.Contains(stun.AttrICEControlled) {
if msg.Contains(stun.AttrICEControlled) {
a.log.Debug("Inbound STUN message: isControlled && a.isControlling == false")
return
}
}
remoteCandidate := a.findRemoteCandidate(local.NetworkType(), remote)
if m.Type.Class == stun.ClassSuccessResponse {
if err = stun.MessageIntegrity([]byte(a.remotePwd)).Check(m); err != nil {
if msg.Type.Class == stun.ClassSuccessResponse { //nolint:nestif
if err = stun.MessageIntegrity([]byte(a.remotePwd)).Check(msg); err != nil {
a.log.Warnf("Discard message from (%s), %v", remote, err)
return
}
if remoteCandidate == nil {
a.log.Warnf("Discard success message from (%s), no such remote", remote)
return
}
a.selector.HandleSuccessResponse(m, local, remoteCandidate, remote)
} else if m.Type.Class == stun.ClassRequest {
a.log.Tracef("Inbound STUN (Request) from %s to %s, useCandidate: %v", remote, local, m.Contains(stun.AttrUseCandidate))
a.selector.HandleSuccessResponse(msg, local, remoteCandidate, remote)
} else if msg.Type.Class == stun.ClassRequest {
a.log.Tracef(
"Inbound STUN (Request) from %s to %s, useCandidate: %v",
remote,
local,
msg.Contains(stun.AttrUseCandidate),
)
if err = stunx.AssertUsername(m, a.localUfrag+":"+a.remoteUfrag); err != nil {
if err = stunx.AssertUsername(msg, a.localUfrag+":"+a.remoteUfrag); err != nil {
a.log.Warnf("Discard message from (%s), %v", remote, err)
return
} else if err = stun.MessageIntegrity([]byte(a.localPwd)).Check(m); err != nil {
} else if err = stun.MessageIntegrity([]byte(a.localPwd)).Check(msg); err != nil {
a.log.Warnf("Discard message from (%s), %v", remote, err)
return
}
@@ -1047,6 +1115,7 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
ip, port, networkType, err := parseAddr(remote)
if err != nil {
a.log.Errorf("Failed to create parse remote net.Addr when creating remote prflx candidate: %s", err)
return
}
@@ -1062,6 +1131,7 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
prflxCandidate, err := NewCandidatePeerReflexive(&prflxCandidateConfig)
if err != nil {
a.log.Errorf("Failed to create new remote prflx candidate (%s)", err)
return
}
remoteCandidate = prflxCandidate
@@ -1070,7 +1140,7 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
a.addRemoteCandidate(remoteCandidate)
}
a.selector.HandleBindingRequest(m, local, remoteCandidate)
a.selector.HandleBindingRequest(msg, local, remoteCandidate)
}
if remoteCandidate != nil {
@@ -1079,7 +1149,7 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
}
// validateNonSTUNTraffic processes non STUN traffic from a remote candidate,
// and returns true if it is an actual remote candidate
// and returns true if it is an actual remote candidate.
func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) (Candidate, bool) {
var remoteCandidate Candidate
if err := a.loop.Run(local.context(), func(context.Context) {
@@ -1094,7 +1164,7 @@ func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) (Candid
return remoteCandidate, remoteCandidate != nil
}
// GetSelectedCandidatePair returns the selected pair or nil if there is none
// GetSelectedCandidatePair returns the selected pair or nil if there is none.
func (a *Agent) GetSelectedCandidatePair() (*CandidatePair, error) {
selectedPair := a.getSelectedPair()
if selectedPair == nil {
@@ -1130,7 +1200,7 @@ func (a *Agent) closeMulticastConn() {
}
}
// SetRemoteCredentials sets the credentials of the remote agent
// SetRemoteCredentials sets the credentials of the remote agent.
func (a *Agent) SetRemoteCredentials(remoteUfrag, remotePwd string) error {
switch {
case remoteUfrag == "":
@@ -1152,7 +1222,7 @@ func (a *Agent) SetRemoteCredentials(remoteUfrag, remotePwd string) error {
// cancel it.
// After a Restart, the user must then call GatherCandidates explicitly
// to start generating new ones.
func (a *Agent) Restart(ufrag, pwd string) error {
func (a *Agent) Restart(ufrag, pwd string) error { //nolint:cyclop
if ufrag == "" {
var err error
ufrag, err = generateUFrag()
@@ -1204,6 +1274,7 @@ func (a *Agent) Restart(ufrag, pwd string) error {
}); runErr != nil {
return runErr
}
return err
}
@@ -1221,6 +1292,7 @@ func (a *Agent) setGatheringState(newState GatheringState) error {
}
<-done
return nil
}

View File

@@ -14,44 +14,44 @@ import (
)
const (
// defaultCheckInterval is the interval at which the agent performs candidate checks in the connecting phase
// defaultCheckInterval is the interval at which the agent performs candidate checks in the connecting phase.
defaultCheckInterval = 200 * time.Millisecond
// keepaliveInterval used to keep candidates alive
// keepaliveInterval used to keep candidates alive.
defaultKeepaliveInterval = 2 * time.Second
// defaultDisconnectedTimeout is the default time till an Agent transitions disconnected
// defaultDisconnectedTimeout is the default time till an Agent transitions disconnected.
defaultDisconnectedTimeout = 5 * time.Second
// defaultFailedTimeout is the default time till an Agent transitions to failed after disconnected
// defaultFailedTimeout is the default time till an Agent transitions to failed after disconnected.
defaultFailedTimeout = 25 * time.Second
// defaultHostAcceptanceMinWait is the wait time before nominating a host candidate
// defaultHostAcceptanceMinWait is the wait time before nominating a host candidate.
defaultHostAcceptanceMinWait = 0
// defaultSrflxAcceptanceMinWait is the wait time before nominating a srflx candidate
// defaultSrflxAcceptanceMinWait is the wait time before nominating a srflx candidate.
defaultSrflxAcceptanceMinWait = 500 * time.Millisecond
// defaultPrflxAcceptanceMinWait is the wait time before nominating a prflx candidate
// defaultPrflxAcceptanceMinWait is the wait time before nominating a prflx candidate.
defaultPrflxAcceptanceMinWait = 1000 * time.Millisecond
// defaultRelayAcceptanceMinWait is the wait time before nominating a relay candidate
// defaultRelayAcceptanceMinWait is the wait time before nominating a relay candidate.
defaultRelayAcceptanceMinWait = 2000 * time.Millisecond
// defaultSTUNGatherTimeout is the wait time for STUN responses
// defaultSTUNGatherTimeout is the wait time for STUN responses.
defaultSTUNGatherTimeout = 5 * time.Second
// defaultMaxBindingRequests is the maximum number of binding requests before considering a pair failed
// defaultMaxBindingRequests is the maximum number of binding requests before considering a pair failed.
defaultMaxBindingRequests = 7
// TCPPriorityOffset is a number which is subtracted from the default (UDP) candidate type preference
// for host, srflx and prfx candidate types.
defaultTCPPriorityOffset = 27
// maxBufferSize is the number of bytes that can be buffered before we start to error
// maxBufferSize is the number of bytes that can be buffered before we start to error.
maxBufferSize = 1000 * 1000 // 1MB
// maxBindingRequestTimeout is the wait time before binding requests can be deleted
// maxBindingRequestTimeout is the wait time before binding requests can be deleted.
maxBindingRequestTimeout = 4000 * time.Millisecond
)
@@ -60,7 +60,7 @@ func defaultCandidateTypes() []CandidateType {
}
// AgentConfig collects the arguments to ice.Agent construction into
// a single structure, for future-proofness of the interface
// a single structure, for future-proofness of the interface.
type AgentConfig struct {
Urls []*stun.URI
@@ -209,109 +209,111 @@ type AgentConfig struct {
EnableUseCandidateCheckPriority bool
}
// initWithDefaults populates an agent and falls back to defaults if fields are unset
func (config *AgentConfig) initWithDefaults(a *Agent) {
// initWithDefaults populates an agent and falls back to defaults if fields are unset.
func (config *AgentConfig) initWithDefaults(agent *Agent) { //nolint:cyclop
if config.MaxBindingRequests == nil {
a.maxBindingRequests = defaultMaxBindingRequests
agent.maxBindingRequests = defaultMaxBindingRequests
} else {
a.maxBindingRequests = *config.MaxBindingRequests
agent.maxBindingRequests = *config.MaxBindingRequests
}
if config.HostAcceptanceMinWait == nil {
a.hostAcceptanceMinWait = defaultHostAcceptanceMinWait
agent.hostAcceptanceMinWait = defaultHostAcceptanceMinWait
} else {
a.hostAcceptanceMinWait = *config.HostAcceptanceMinWait
agent.hostAcceptanceMinWait = *config.HostAcceptanceMinWait
}
if config.SrflxAcceptanceMinWait == nil {
a.srflxAcceptanceMinWait = defaultSrflxAcceptanceMinWait
agent.srflxAcceptanceMinWait = defaultSrflxAcceptanceMinWait
} else {
a.srflxAcceptanceMinWait = *config.SrflxAcceptanceMinWait
agent.srflxAcceptanceMinWait = *config.SrflxAcceptanceMinWait
}
if config.PrflxAcceptanceMinWait == nil {
a.prflxAcceptanceMinWait = defaultPrflxAcceptanceMinWait
agent.prflxAcceptanceMinWait = defaultPrflxAcceptanceMinWait
} else {
a.prflxAcceptanceMinWait = *config.PrflxAcceptanceMinWait
agent.prflxAcceptanceMinWait = *config.PrflxAcceptanceMinWait
}
if config.RelayAcceptanceMinWait == nil {
a.relayAcceptanceMinWait = defaultRelayAcceptanceMinWait
agent.relayAcceptanceMinWait = defaultRelayAcceptanceMinWait
} else {
a.relayAcceptanceMinWait = *config.RelayAcceptanceMinWait
agent.relayAcceptanceMinWait = *config.RelayAcceptanceMinWait
}
if config.STUNGatherTimeout == nil {
a.stunGatherTimeout = defaultSTUNGatherTimeout
agent.stunGatherTimeout = defaultSTUNGatherTimeout
} else {
a.stunGatherTimeout = *config.STUNGatherTimeout
agent.stunGatherTimeout = *config.STUNGatherTimeout
}
if config.TCPPriorityOffset == nil {
a.tcpPriorityOffset = defaultTCPPriorityOffset
agent.tcpPriorityOffset = defaultTCPPriorityOffset
} else {
a.tcpPriorityOffset = *config.TCPPriorityOffset
agent.tcpPriorityOffset = *config.TCPPriorityOffset
}
if config.DisconnectedTimeout == nil {
a.disconnectedTimeout = defaultDisconnectedTimeout
agent.disconnectedTimeout = defaultDisconnectedTimeout
} else {
a.disconnectedTimeout = *config.DisconnectedTimeout
agent.disconnectedTimeout = *config.DisconnectedTimeout
}
if config.FailedTimeout == nil {
a.failedTimeout = defaultFailedTimeout
agent.failedTimeout = defaultFailedTimeout
} else {
a.failedTimeout = *config.FailedTimeout
agent.failedTimeout = *config.FailedTimeout
}
if config.KeepaliveInterval == nil {
a.keepaliveInterval = defaultKeepaliveInterval
agent.keepaliveInterval = defaultKeepaliveInterval
} else {
a.keepaliveInterval = *config.KeepaliveInterval
agent.keepaliveInterval = *config.KeepaliveInterval
}
if config.CheckInterval == nil {
a.checkInterval = defaultCheckInterval
agent.checkInterval = defaultCheckInterval
} else {
a.checkInterval = *config.CheckInterval
agent.checkInterval = *config.CheckInterval
}
if len(config.CandidateTypes) == 0 {
a.candidateTypes = defaultCandidateTypes()
agent.candidateTypes = defaultCandidateTypes()
} else {
a.candidateTypes = config.CandidateTypes
agent.candidateTypes = config.CandidateTypes
}
}
func (config *AgentConfig) initExtIPMapping(a *Agent) error {
func (config *AgentConfig) initExtIPMapping(agent *Agent) error { //nolint:cyclop
var err error
a.extIPMapper, err = newExternalIPMapper(config.NAT1To1IPCandidateType, config.NAT1To1IPs)
agent.extIPMapper, err = newExternalIPMapper(config.NAT1To1IPCandidateType, config.NAT1To1IPs)
if err != nil {
return err
}
if a.extIPMapper == nil {
if agent.extIPMapper == nil {
return nil // This may happen when config.NAT1To1IPs is an empty array
}
if a.extIPMapper.candidateType == CandidateTypeHost {
if a.mDNSMode == MulticastDNSModeQueryAndGather {
if agent.extIPMapper.candidateType == CandidateTypeHost { //nolint:nestif
if agent.mDNSMode == MulticastDNSModeQueryAndGather {
return ErrMulticastDNSWithNAT1To1IPMapping
}
candiHostEnabled := false
for _, candiType := range a.candidateTypes {
for _, candiType := range agent.candidateTypes {
if candiType == CandidateTypeHost {
candiHostEnabled = true
break
}
}
if !candiHostEnabled {
return ErrIneffectiveNAT1To1IPMappingHost
}
} else if a.extIPMapper.candidateType == CandidateTypeServerReflexive {
} else if agent.extIPMapper.candidateType == CandidateTypeServerReflexive {
candiSrflxEnabled := false
for _, candiType := range a.candidateTypes {
for _, candiType := range agent.candidateTypes {
if candiType == CandidateTypeServerReflexive {
candiSrflxEnabled = true
break
}
}
@@ -319,5 +321,6 @@ func (config *AgentConfig) initExtIPMapping(a *Agent) error {
return ErrIneffectiveNAT1To1IPMappingSrflx
}
}
return nil
}

View File

@@ -32,6 +32,8 @@ func TestAgentGetBestValidCandidatePair(t *testing.T) {
}
func setupTestAgentGetBestValidCandidatePair(t *testing.T) *TestAgentGetBestValidCandidatePairFixture {
t.Helper()
fixture := new(TestAgentGetBestValidCandidatePairFixture)
fixture.hostLocal = newHostLocal(t)
fixture.relayRemote = newRelayRemote(t)

View File

@@ -5,16 +5,18 @@ package ice
import "sync"
// OnConnectionStateChange sets a handler that is fired when the connection state changes
// OnConnectionStateChange sets a handler that is fired when the connection state changes.
func (a *Agent) OnConnectionStateChange(f func(ConnectionState)) error {
a.onConnectionStateChangeHdlr.Store(f)
return nil
}
// OnSelectedCandidatePairChange sets a handler that is fired when the final candidate
// pair is selected
// OnSelectedCandidatePairChange sets a handler that is fired when the final candidate.
// pair is selected.
func (a *Agent) OnSelectedCandidatePairChange(f func(Candidate, Candidate)) error {
a.onSelectedCandidatePairChangeHdlr.Store(f)
return nil
}
@@ -22,6 +24,7 @@ func (a *Agent) OnSelectedCandidatePairChange(f func(Candidate, Candidate)) erro
// the gathering process complete the last candidate is nil.
func (a *Agent) OnCandidate(f func(Candidate)) error {
a.onCandidateHdlr.Store(f)
return nil
}
@@ -73,6 +76,7 @@ func (h *handlerNotifier) Close(graceful bool) {
select {
case <-h.done:
h.Unlock()
return
default:
}
@@ -80,7 +84,7 @@ func (h *handlerNotifier) Close(graceful bool) {
h.Unlock()
}
func (h *handlerNotifier) EnqueueConnectionState(s ConnectionState) {
func (h *handlerNotifier) EnqueueConnectionState(state ConnectionState) {
h.Lock()
defer h.Unlock()
@@ -97,6 +101,7 @@ func (h *handlerNotifier) EnqueueConnectionState(s ConnectionState) {
if len(h.connectionStates) == 0 {
h.running = false
h.Unlock()
return
}
notification := h.connectionStates[0]
@@ -106,7 +111,7 @@ func (h *handlerNotifier) EnqueueConnectionState(s ConnectionState) {
}
}
h.connectionStates = append(h.connectionStates, s)
h.connectionStates = append(h.connectionStates, state)
if !h.running {
h.running = true
h.notifiers.Add(1)
@@ -114,7 +119,7 @@ func (h *handlerNotifier) EnqueueConnectionState(s ConnectionState) {
}
}
func (h *handlerNotifier) EnqueueCandidate(c Candidate) {
func (h *handlerNotifier) EnqueueCandidate(cand Candidate) {
h.Lock()
defer h.Unlock()
@@ -131,6 +136,7 @@ func (h *handlerNotifier) EnqueueCandidate(c Candidate) {
if len(h.candidates) == 0 {
h.running = false
h.Unlock()
return
}
notification := h.candidates[0]
@@ -140,7 +146,7 @@ func (h *handlerNotifier) EnqueueCandidate(c Candidate) {
}
}
h.candidates = append(h.candidates, c)
h.candidates = append(h.candidates, cand)
if !h.running {
h.running = true
h.notifiers.Add(1)
@@ -148,7 +154,7 @@ func (h *handlerNotifier) EnqueueCandidate(c Candidate) {
}
}
func (h *handlerNotifier) EnqueueSelectedCandidatePair(p *CandidatePair) {
func (h *handlerNotifier) EnqueueSelectedCandidatePair(pair *CandidatePair) {
h.Lock()
defer h.Unlock()
@@ -165,6 +171,7 @@ func (h *handlerNotifier) EnqueueSelectedCandidatePair(p *CandidatePair) {
if len(h.selectedCandidatePairs) == 0 {
h.running = false
h.Unlock()
return
}
notification := h.selectedCandidatePairs[0]
@@ -174,7 +181,7 @@ func (h *handlerNotifier) EnqueueSelectedCandidatePair(p *CandidatePair) {
}
}
h.selectedCandidatePairs = append(h.selectedCandidatePairs, p)
h.selectedCandidatePairs = append(h.selectedCandidatePairs, pair)
if !h.running {
h.running = true
h.notifiers.Add(1)

View File

@@ -15,7 +15,7 @@ func TestConnectionStateNotifier(t *testing.T) {
defer test.CheckRoutines(t)()
updates := make(chan struct{}, 1)
c := &handlerNotifier{
notifier := &handlerNotifier{
connectionStateFunc: func(_ ConnectionState) {
updates <- struct{}{}
},
@@ -24,7 +24,7 @@ func TestConnectionStateNotifier(t *testing.T) {
// Enqueue all updates upfront to ensure that it
// doesn't block
for i := 0; i < 10000; i++ {
c.EnqueueConnectionState(ConnectionStateNew)
notifier.EnqueueConnectionState(ConnectionStateNew)
}
done := make(chan struct{})
go func() {
@@ -39,12 +39,12 @@ func TestConnectionStateNotifier(t *testing.T) {
close(done)
}()
<-done
c.Close(true)
notifier.Close(true)
})
t.Run("TestUpdateOrdering", func(t *testing.T) {
defer test.CheckRoutines(t)()
updates := make(chan ConnectionState)
c := &handlerNotifier{
notifer := &handlerNotifier{
connectionStateFunc: func(cs ConnectionState) {
updates <- cs
},
@@ -66,9 +66,9 @@ func TestConnectionStateNotifier(t *testing.T) {
close(done)
}()
for i := 0; i < 10000; i++ {
c.EnqueueConnectionState(ConnectionState(i))
notifer.EnqueueConnectionState(ConnectionState(i))
}
<-done
c.Close(true)
notifer.Close(true)
})
}

View File

@@ -34,17 +34,23 @@ func TestOnSelectedCandidatePairChange(t *testing.T) {
}
func fixtureTestOnSelectedCandidatePairChange(t *testing.T) (*Agent, *CandidatePair) {
t.Helper()
agent, err := NewAgent(&AgentConfig{})
require.NoError(t, err)
candidatePair := makeCandidatePair(t)
return agent, candidatePair
}
func makeCandidatePair(t *testing.T) *CandidatePair {
t.Helper()
hostLocal := newHostLocal(t)
relayRemote := newRelayRemote(t)
candidatePair := newCandidatePair(hostLocal, relayRemote, false)
return candidatePair
}

View File

@@ -8,7 +8,7 @@ import (
"time"
)
// GetCandidatePairsStats returns a list of candidate pair stats
// GetCandidatePairsStats returns a list of candidate pair stats.
func (a *Agent) GetCandidatePairsStats() []CandidatePairStats {
var res []CandidatePairStats
err := a.loop.Run(a.loop, func(_ context.Context) {
@@ -49,13 +49,15 @@ func (a *Agent) GetCandidatePairsStats() []CandidatePairStats {
})
if err != nil {
a.log.Errorf("Failed to get candidate pairs stats: %v", err)
return []CandidatePairStats{}
}
return res
}
// GetSelectedCandidatePairStats returns a candidate pair stats for selected candidate pair.
// Returns false if there is no selected pair
// Returns false if there is no selected pair.
func (a *Agent) GetSelectedCandidatePairStats() (CandidatePairStats, bool) {
isAvailable := false
var res CandidatePairStats
@@ -98,33 +100,34 @@ func (a *Agent) GetSelectedCandidatePairStats() (CandidatePairStats, bool) {
})
if err != nil {
a.log.Errorf("Failed to get selected candidate pair stats: %v", err)
return CandidatePairStats{}, false
}
return res, isAvailable
}
// GetLocalCandidatesStats returns a list of local candidates stats
// GetLocalCandidatesStats returns a list of local candidates stats.
func (a *Agent) GetLocalCandidatesStats() []CandidateStats {
var res []CandidateStats
err := a.loop.Run(a.loop, func(_ context.Context) {
result := make([]CandidateStats, 0, len(a.localCandidates))
for networkType, localCandidates := range a.localCandidates {
for _, c := range localCandidates {
for _, cand := range localCandidates {
relayProtocol := ""
if c.Type() == CandidateTypeRelay {
if cRelay, ok := c.(*CandidateRelay); ok {
if cand.Type() == CandidateTypeRelay {
if cRelay, ok := cand.(*CandidateRelay); ok {
relayProtocol = cRelay.RelayProtocol()
}
}
stat := CandidateStats{
Timestamp: time.Now(),
ID: c.ID(),
ID: cand.ID(),
NetworkType: networkType,
IP: c.Address(),
Port: c.Port(),
CandidateType: c.Type(),
Priority: c.Priority(),
IP: cand.Address(),
Port: cand.Port(),
CandidateType: cand.Type(),
Priority: cand.Priority(),
// URL string
RelayProtocol: relayProtocol,
// Deleted bool
@@ -136,12 +139,14 @@ func (a *Agent) GetLocalCandidatesStats() []CandidateStats {
})
if err != nil {
a.log.Errorf("Failed to get candidate pair stats: %v", err)
return []CandidateStats{}
}
return res
}
// GetRemoteCandidatesStats returns a list of remote candidates stats
// GetRemoteCandidatesStats returns a list of remote candidates stats.
func (a *Agent) GetRemoteCandidatesStats() []CandidateStats {
var res []CandidateStats
err := a.loop.Run(a.loop, func(_ context.Context) {
@@ -166,7 +171,9 @@ func (a *Agent) GetRemoteCandidatesStats() []CandidateStats {
})
if err != nil {
a.log.Errorf("Failed to get candidate pair stats: %v", err)
return []CandidateStats{}
}
return res
}

View File

@@ -33,21 +33,21 @@ func (ba *BadAddr) String() string {
return "yyy"
}
func TestHandlePeerReflexive(t *testing.T) {
func TestHandlePeerReflexive(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)()
// Limit runtime in case of deadlocks
defer test.TimeOut(time.Second * 2).Stop()
t.Run("UDP prflx candidate from handleInbound()", func(t *testing.T) {
a, err := NewAgent(&AgentConfig{})
agent, err := NewAgent(&AgentConfig{})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
require.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}
require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
agent.selector = &controllingSelector{agent: agent, log: agent.log}
hostConfig := CandidateHostConfig{
Network: "udp",
@@ -64,25 +64,25 @@ func TestHandlePeerReflexive(t *testing.T) {
remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}
msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
stun.NewUsername(a.localUfrag+":"+a.remoteUfrag),
stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
UseCandidate(),
AttrControlling(a.tieBreaker),
AttrControlling(agent.tieBreaker),
PriorityAttr(local.Priority()),
stun.NewShortTermIntegrity(a.localPwd),
stun.NewShortTermIntegrity(agent.localPwd),
stun.Fingerprint,
)
require.NoError(t, err)
// nolint: contextcheck
a.handleInbound(msg, local, remote)
agent.handleInbound(msg, local, remote)
// Length of remote candidate list must be one now
if len(a.remoteCandidates) != 1 {
if len(agent.remoteCandidates) != 1 {
t.Fatal("failed to add a network type to the remote candidate list")
}
// Length of remote candidate list for a network type must be 1
set := a.remoteCandidates[local.NetworkType()]
set := agent.remoteCandidates[local.NetworkType()]
if len(set) != 1 {
t.Fatal("failed to add prflx candidate to remote candidate list")
}
@@ -104,14 +104,14 @@ func TestHandlePeerReflexive(t *testing.T) {
})
t.Run("Bad network type with handleInbound()", func(t *testing.T) {
a, err := NewAgent(&AgentConfig{})
agent, err := NewAgent(&AgentConfig{})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
require.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}
require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
agent.selector = &controllingSelector{agent: agent, log: agent.log}
hostConfig := CandidateHostConfig{
Network: "tcp",
@@ -127,26 +127,26 @@ func TestHandlePeerReflexive(t *testing.T) {
remote := &BadAddr{}
// nolint: contextcheck
a.handleInbound(nil, local, remote)
agent.handleInbound(nil, local, remote)
if len(a.remoteCandidates) != 0 {
if len(agent.remoteCandidates) != 0 {
t.Fatal("bad address should not be added to the remote candidate list")
}
}))
})
t.Run("Success from unknown remote, prflx candidate MUST only be created via Binding Request", func(t *testing.T) {
a, err := NewAgent(&AgentConfig{})
agent, err := NewAgent(&AgentConfig{})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
require.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}
require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
agent.selector = &controllingSelector{agent: agent, log: agent.log}
tID := [stun.TransactionIDSize]byte{}
copy(tID[:], "ABC")
a.pendingBindingRequests = []bindingRequest{
agent.pendingBindingRequests = []bindingRequest{
{time.Now(), tID, &net.UDPAddr{}, false},
}
@@ -164,14 +164,14 @@ func TestHandlePeerReflexive(t *testing.T) {
remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}
msg, err := stun.Build(stun.BindingSuccess, stun.NewTransactionIDSetter(tID),
stun.NewShortTermIntegrity(a.remotePwd),
stun.NewShortTermIntegrity(agent.remotePwd),
stun.Fingerprint,
)
require.NoError(t, err)
// nolint: contextcheck
a.handleInbound(msg, local, remote)
if len(a.remoteCandidates) != 0 {
agent.handleInbound(msg, local, remote)
if len(agent.remoteCandidates) != 0 {
t.Fatal("unknown remote was able to create a candidate")
}
}))
@@ -281,6 +281,7 @@ func TestConnectivityOnStartup(t *testing.T) {
// Ensure accepted
<-accepted
return aConn, bConn
}(aAgent, bAgent)
@@ -308,9 +309,9 @@ func TestConnectivityLite(t *testing.T) {
MappingBehavior: vnet.EndpointIndependent,
FilteringBehavior: vnet.EndpointIndependent,
}
v, err := buildVNet(natType, natType)
vent, err := buildVNet(natType, natType)
require.NoError(t, err, "should succeed")
defer v.close()
defer vent.close()
aNotifier, aConnected := onConnected()
bNotifier, bConnected := onConnected()
@@ -319,7 +320,7 @@ func TestConnectivityLite(t *testing.T) {
Urls: []*stun.URI{stunServerURL},
NetworkTypes: supportedNetworkTypes(),
MulticastDNSMode: MulticastDNSModeDisabled,
Net: v.net0,
Net: vent.net0,
}
aAgent, err := NewAgent(cfg0)
@@ -335,7 +336,7 @@ func TestConnectivityLite(t *testing.T) {
CandidateTypes: []CandidateType{CandidateTypeHost},
NetworkTypes: supportedNetworkTypes(),
MulticastDNSMode: MulticastDNSModeDisabled,
Net: v.net1,
Net: vent.net1,
}
bAgent, err := NewAgent(cfg1)
@@ -353,7 +354,7 @@ func TestConnectivityLite(t *testing.T) {
<-bConnected
}
func TestInboundValidity(t *testing.T) {
func TestInboundValidity(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)()
buildMsg := func(class stun.MessageClass, username, key string) *stun.Message {
@@ -381,21 +382,21 @@ func TestInboundValidity(t *testing.T) {
}
t.Run("Invalid Binding requests should be discarded", func(t *testing.T) {
a, err := NewAgent(&AgentConfig{})
agent, err := NewAgent(&AgentConfig{})
if err != nil {
t.Fatalf("Error constructing ice.Agent")
}
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
a.handleInbound(buildMsg(stun.ClassRequest, "invalid", a.localPwd), local, remote)
if len(a.remoteCandidates) == 1 {
agent.handleInbound(buildMsg(stun.ClassRequest, "invalid", agent.localPwd), local, remote)
if len(agent.remoteCandidates) == 1 {
t.Fatal("Binding with invalid Username was able to create prflx candidate")
}
a.handleInbound(buildMsg(stun.ClassRequest, a.localUfrag+":"+a.remoteUfrag, "Invalid"), local, remote)
if len(a.remoteCandidates) == 1 {
agent.handleInbound(buildMsg(stun.ClassRequest, agent.localUfrag+":"+agent.remoteUfrag, "Invalid"), local, remote)
if len(agent.remoteCandidates) == 1 {
t.Fatal("Binding with invalid MessageIntegrity was able to create prflx candidate")
}
})
@@ -452,35 +453,35 @@ func TestInboundValidity(t *testing.T) {
})
t.Run("Valid bind without fingerprint", func(t *testing.T) {
a, err := NewAgent(&AgentConfig{})
agent, err := NewAgent(&AgentConfig{})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
require.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}
require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
agent.selector = &controllingSelector{agent: agent, log: agent.log}
msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
stun.NewUsername(a.localUfrag+":"+a.remoteUfrag),
stun.NewShortTermIntegrity(a.localPwd),
stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
stun.NewShortTermIntegrity(agent.localPwd),
)
require.NoError(t, err)
// nolint: contextcheck
a.handleInbound(msg, local, remote)
if len(a.remoteCandidates) != 1 {
agent.handleInbound(msg, local, remote)
if len(agent.remoteCandidates) != 1 {
t.Fatal("Binding with valid values (but no fingerprint) was unable to create prflx candidate")
}
}))
})
t.Run("Success with invalid TransactionID", func(t *testing.T) {
a, err := NewAgent(&AgentConfig{})
agent, err := NewAgent(&AgentConfig{})
if err != nil {
t.Fatalf("Error constructing ice.Agent")
}
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
hostConfig := CandidateHostConfig{
@@ -499,13 +500,13 @@ func TestInboundValidity(t *testing.T) {
tID := [stun.TransactionIDSize]byte{}
copy(tID[:], "ABC")
msg, err := stun.Build(stun.BindingSuccess, stun.NewTransactionIDSetter(tID),
stun.NewShortTermIntegrity(a.remotePwd),
stun.NewShortTermIntegrity(agent.remotePwd),
stun.Fingerprint,
)
require.NoError(t, err)
a.handleInbound(msg, local, remote)
if len(a.remoteCandidates) != 0 {
agent.handleInbound(msg, local, remote)
if len(agent.remoteCandidates) != 0 {
t.Fatal("unknown remote was able to create a candidate")
}
})
@@ -514,35 +515,35 @@ func TestInboundValidity(t *testing.T) {
func TestInvalidAgentStarts(t *testing.T) {
defer test.CheckRoutines(t)()
a, err := NewAgent(&AgentConfig{})
agent, err := NewAgent(&AgentConfig{})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()
if _, err = a.Dial(ctx, "", "bar"); err != nil && !errors.Is(err, ErrRemoteUfragEmpty) {
if _, err = agent.Dial(ctx, "", "bar"); err != nil && !errors.Is(err, ErrRemoteUfragEmpty) {
t.Fatal(err)
}
if _, err = a.Dial(ctx, "foo", ""); err != nil && !errors.Is(err, ErrRemotePwdEmpty) {
if _, err = agent.Dial(ctx, "foo", ""); err != nil && !errors.Is(err, ErrRemotePwdEmpty) {
t.Fatal(err)
}
if _, err = a.Dial(ctx, "foo", "bar"); err != nil && !errors.Is(err, ErrCanceledByCaller) {
if _, err = agent.Dial(ctx, "foo", "bar"); err != nil && !errors.Is(err, ErrCanceledByCaller) {
t.Fatal(err)
}
if _, err = a.Dial(context.TODO(), "foo", "bar"); err != nil && !errors.Is(err, ErrMultipleStart) {
if _, err = agent.Dial(context.TODO(), "foo", "bar"); err != nil && !errors.Is(err, ErrMultipleStart) {
t.Fatal(err)
}
}
// Assert that Agent emits Connecting/Connected/Disconnected/Failed/Closed messages
func TestConnectionStateCallback(t *testing.T) {
// Assert that Agent emits Connecting/Connected/Disconnected/Failed/Closed messages.
func TestConnectionStateCallback(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)()
defer test.TimeOut(time.Second * 5).Stop()
@@ -635,18 +636,18 @@ func TestInvalidGather(t *testing.T) {
})
}
func TestCandidatePairsStats(t *testing.T) {
func TestCandidatePairsStats(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)()
// Avoid deadlocks?
defer test.TimeOut(1 * time.Second).Stop()
a, err := NewAgent(&AgentConfig{})
agent, err := NewAgent(&AgentConfig{})
if err != nil {
t.Fatalf("Failed to create agent: %s", err)
}
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
hostConfig := &CandidateHostConfig{
@@ -711,21 +712,21 @@ func TestCandidatePairsStats(t *testing.T) {
}
for _, remote := range []Candidate{relayRemote, srflxRemote, prflxRemote, hostRemote} {
p := a.findPair(hostLocal, remote)
p := agent.findPair(hostLocal, remote)
if p == nil {
a.addPair(hostLocal, remote)
agent.addPair(hostLocal, remote)
}
}
p := a.findPair(hostLocal, prflxRemote)
p := agent.findPair(hostLocal, prflxRemote)
p.state = CandidatePairStateFailed
for i := 0; i < 10; i++ {
p.UpdateRoundTripTime(time.Duration(i+1) * time.Second)
}
stats := a.GetCandidatePairsStats()
stats := agent.GetCandidatePairsStats()
if len(stats) != 4 {
t.Fatal("expected 4 candidate pairs stats")
}
@@ -789,18 +790,18 @@ func TestCandidatePairsStats(t *testing.T) {
}
}
func TestSelectedCandidatePairStats(t *testing.T) {
func TestSelectedCandidatePairStats(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)()
// Avoid deadlocks?
defer test.TimeOut(1 * time.Second).Stop()
a, err := NewAgent(&AgentConfig{})
agent, err := NewAgent(&AgentConfig{})
if err != nil {
t.Fatalf("Failed to create agent: %s", err)
}
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
hostConfig := &CandidateHostConfig{
@@ -828,23 +829,23 @@ func TestSelectedCandidatePairStats(t *testing.T) {
}
// no selected pair, should return not available
_, ok := a.GetSelectedCandidatePairStats()
_, ok := agent.GetSelectedCandidatePairStats()
require.False(t, ok)
// add pair and populate some RTT stats
p := a.findPair(hostLocal, srflxRemote)
p := agent.findPair(hostLocal, srflxRemote)
if p == nil {
a.addPair(hostLocal, srflxRemote)
p = a.findPair(hostLocal, srflxRemote)
agent.addPair(hostLocal, srflxRemote)
p = agent.findPair(hostLocal, srflxRemote)
}
for i := 0; i < 10; i++ {
p.UpdateRoundTripTime(time.Duration(i+1) * time.Second)
}
// set the pair as selected
a.setSelectedPair(p)
agent.setSelectedPair(p)
stats, ok := a.GetSelectedCandidatePairStats()
stats, ok := agent.GetSelectedCandidatePairStats()
require.True(t, ok)
if stats.LocalCandidateID != hostLocal.ID() {
@@ -872,18 +873,18 @@ func TestSelectedCandidatePairStats(t *testing.T) {
}
}
func TestLocalCandidateStats(t *testing.T) {
func TestLocalCandidateStats(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)()
// Avoid deadlocks?
defer test.TimeOut(1 * time.Second).Stop()
a, err := NewAgent(&AgentConfig{})
agent, err := NewAgent(&AgentConfig{})
if err != nil {
t.Fatalf("Failed to create agent: %s", err)
}
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
hostConfig := &CandidateHostConfig{
@@ -910,9 +911,9 @@ func TestLocalCandidateStats(t *testing.T) {
t.Fatalf("Failed to construct local srflx candidate: %s", err)
}
a.localCandidates[NetworkTypeUDP4] = []Candidate{hostLocal, srflxLocal}
agent.localCandidates[NetworkTypeUDP4] = []Candidate{hostLocal, srflxLocal}
localStats := a.GetLocalCandidatesStats()
localStats := agent.GetLocalCandidatesStats()
if len(localStats) != 2 {
t.Fatalf("expected 2 local candidates stats, got %d instead", len(localStats))
}
@@ -953,18 +954,18 @@ func TestLocalCandidateStats(t *testing.T) {
}
}
func TestRemoteCandidateStats(t *testing.T) {
func TestRemoteCandidateStats(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)()
// Avoid deadlocks?
defer test.TimeOut(1 * time.Second).Stop()
a, err := NewAgent(&AgentConfig{})
agent, err := NewAgent(&AgentConfig{})
if err != nil {
t.Fatalf("Failed to create agent: %s", err)
}
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
relayConfig := &CandidateRelayConfig{
@@ -1017,9 +1018,9 @@ func TestRemoteCandidateStats(t *testing.T) {
t.Fatalf("Failed to construct remote host candidate: %s", err)
}
a.remoteCandidates[NetworkTypeUDP4] = []Candidate{relayRemote, srflxRemote, prflxRemote, hostRemote}
agent.remoteCandidates[NetworkTypeUDP4] = []Candidate{relayRemote, srflxRemote, prflxRemote, hostRemote}
remoteStats := a.GetRemoteCandidatesStats()
remoteStats := agent.GetRemoteCandidatesStats()
if len(remoteStats) != 4 {
t.Fatalf("expected 4 remote candidates stats, got %d instead", len(remoteStats))
}
@@ -1076,31 +1077,31 @@ func TestRemoteCandidateStats(t *testing.T) {
func TestInitExtIPMapping(t *testing.T) {
defer test.CheckRoutines(t)()
// a.extIPMapper should be nil by default
a, err := NewAgent(&AgentConfig{})
// agent.extIPMapper should be nil by default
agent, err := NewAgent(&AgentConfig{})
if err != nil {
t.Fatalf("Failed to create agent: %v", err)
}
if a.extIPMapper != nil {
require.NoError(t, a.Close())
if agent.extIPMapper != nil {
require.NoError(t, agent.Close())
t.Fatal("a.extIPMapper should be nil by default")
}
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
// a.extIPMapper should be nil when NAT1To1IPs is a non-nil empty array
a, err = NewAgent(&AgentConfig{
agent, err = NewAgent(&AgentConfig{
NAT1To1IPs: []string{},
NAT1To1IPCandidateType: CandidateTypeHost,
})
if err != nil {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
t.Fatalf("Failed to create agent: %v", err)
}
if a.extIPMapper != nil {
require.NoError(t, a.Close())
if agent.extIPMapper != nil {
require.NoError(t, agent.Close())
t.Fatal("a.extIPMapper should be nil by default")
}
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
// NewAgent should return an error when 1:1 NAT for host candidate is enabled
// but the candidate type does not appear in the CandidateTypes.
@@ -1150,32 +1151,38 @@ func TestBindingRequestTimeout(t *testing.T) {
const expectedRemovalCount = 2
a, err := NewAgent(&AgentConfig{})
agent, err := NewAgent(&AgentConfig{})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
now := time.Now()
a.pendingBindingRequests = append(a.pendingBindingRequests, bindingRequest{
agent.pendingBindingRequests = append(agent.pendingBindingRequests, bindingRequest{
timestamp: now, // Valid
})
a.pendingBindingRequests = append(a.pendingBindingRequests, bindingRequest{
agent.pendingBindingRequests = append(agent.pendingBindingRequests, bindingRequest{
timestamp: now.Add(-3900 * time.Millisecond), // Valid
})
a.pendingBindingRequests = append(a.pendingBindingRequests, bindingRequest{
agent.pendingBindingRequests = append(agent.pendingBindingRequests, bindingRequest{
timestamp: now.Add(-4100 * time.Millisecond), // Invalid
})
a.pendingBindingRequests = append(a.pendingBindingRequests, bindingRequest{
agent.pendingBindingRequests = append(agent.pendingBindingRequests, bindingRequest{
timestamp: now.Add(-75 * time.Hour), // Invalid
})
a.invalidatePendingBindingRequests(now)
require.Equal(t, expectedRemovalCount, len(a.pendingBindingRequests), "Binding invalidation due to timeout did not remove the correct number of binding requests")
agent.invalidatePendingBindingRequests(now)
require.Equal(
t,
expectedRemovalCount,
len(agent.pendingBindingRequests),
"Binding invalidation due to timeout did not remove the correct number of binding requests",
)
}
// TestAgentCredentials checks if local username fragments and passwords (if set) meet RFC standard
// and ensure it's backwards compatible with previous versions of the pion/ice
// and ensure it's backwards compatible with previous versions of the pion/ice.
func TestAgentCredentials(t *testing.T) {
defer test.CheckRoutines(t)()
@@ -1207,7 +1214,7 @@ func TestAgentCredentials(t *testing.T) {
}
// Assert that Agent on Failure deletes all existing candidates
// User can then do an ICE Restart to bring agent back
// User can then do an ICE Restart to bring agent back.
func TestConnectionStateFailedDeleteAllCandidates(t *testing.T) {
defer test.CheckRoutines(t)()
@@ -1254,7 +1261,7 @@ func TestConnectionStateFailedDeleteAllCandidates(t *testing.T) {
<-done
}
// Assert that the ICE Agent can go directly from Connecting -> Failed on both sides
// Assert that the ICE Agent can go directly from Connecting -> Failed on both sides.
func TestConnectionStateConnectingToFailed(t *testing.T) {
defer test.CheckRoutines(t)()
@@ -1378,6 +1385,7 @@ func TestAgentRestart(t *testing.T) {
out += c.Address() + ":"
out += strconv.Itoa(c.Port())
}
return
}
@@ -1423,33 +1431,33 @@ func TestAgentRestart(t *testing.T) {
func TestGetRemoteCredentials(t *testing.T) {
var config AgentConfig
a, err := NewAgent(&config)
agent, err := NewAgent(&config)
if err != nil {
t.Fatalf("Error constructing ice.Agent: %v", err)
}
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
a.remoteUfrag = "remoteUfrag"
a.remotePwd = "remotePwd"
agent.remoteUfrag = "remoteUfrag"
agent.remotePwd = "remotePwd"
actualUfrag, actualPwd, err := a.GetRemoteUserCredentials()
actualUfrag, actualPwd, err := agent.GetRemoteUserCredentials()
require.NoError(t, err)
require.Equal(t, actualUfrag, a.remoteUfrag)
require.Equal(t, actualPwd, a.remotePwd)
require.Equal(t, actualUfrag, agent.remoteUfrag)
require.Equal(t, actualPwd, agent.remotePwd)
}
func TestGetRemoteCandidates(t *testing.T) {
var config AgentConfig
a, err := NewAgent(&config)
agent, err := NewAgent(&config)
if err != nil {
t.Fatalf("Error constructing ice.Agent: %v", err)
}
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
expectedCandidates := []Candidate{}
@@ -1467,10 +1475,10 @@ func TestGetRemoteCandidates(t *testing.T) {
expectedCandidates = append(expectedCandidates, cand)
a.addRemoteCandidate(cand)
agent.addRemoteCandidate(cand)
}
actualCandidates, err := a.GetRemoteCandidates()
actualCandidates, err := agent.GetRemoteCandidates()
require.NoError(t, err)
require.ElementsMatch(t, expectedCandidates, actualCandidates)
}
@@ -1478,12 +1486,12 @@ func TestGetRemoteCandidates(t *testing.T) {
func TestGetLocalCandidates(t *testing.T) {
var config AgentConfig
a, err := NewAgent(&config)
agent, err := NewAgent(&config)
if err != nil {
t.Fatalf("Error constructing ice.Agent: %v", err)
}
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
dummyConn := &net.UDPConn{}
@@ -1502,11 +1510,11 @@ func TestGetLocalCandidates(t *testing.T) {
expectedCandidates = append(expectedCandidates, cand)
err = a.addCandidate(context.Background(), cand, dummyConn)
err = agent.addCandidate(context.Background(), cand, dummyConn)
require.NoError(t, err)
}
actualCandidates, err := a.GetLocalCandidates()
actualCandidates, err := agent.GetLocalCandidates()
require.NoError(t, err)
require.ElementsMatch(t, expectedCandidates, actualCandidates)
}
@@ -1666,7 +1674,7 @@ func TestRunTaskInSelectedCandidatePairChangeCallback(t *testing.T) {
<-isTested
}
// Assert that a Lite agent goes to disconnected and failed
// Assert that a Lite agent goes to disconnected and failed.
func TestLiteLifecycle(t *testing.T) {
defer test.CheckRoutines(t)()
@@ -1818,7 +1826,7 @@ func TestGetSelectedCandidatePair(t *testing.T) {
require.NoError(t, wan.Stop())
}
func TestAcceptAggressiveNomination(t *testing.T) {
func TestAcceptAggressiveNomination(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)()
defer test.TimeOut(time.Second * 30).Stop()
@@ -1932,24 +1940,25 @@ func TestAcceptAggressiveNomination(t *testing.T) {
bcandidates, err = bAgent.GetLocalCandidates()
require.NoError(t, err)
for _, c := range bcandidates {
if c != bAgent.getSelectedPair().Local {
for _, cand := range bcandidates {
if cand != bAgent.getSelectedPair().Local { //nolint:nestif
if expectNewSelectedCandidate == nil {
expected_change_priority:
for _, candidates := range aAgent.remoteCandidates {
for _, candidate := range candidates {
if candidate.Equal(c) {
if candidate.Equal(cand) {
if tc.useHigherPriority {
candidate.(*CandidateHost).priorityOverride += 1000 //nolint:forcetypeassert
} else {
candidate.(*CandidateHost).priorityOverride -= 1000 //nolint:forcetypeassert
}
break expected_change_priority
}
}
}
if tc.isExpectedToSwitch {
expectNewSelectedCandidate = c
expectNewSelectedCandidate = cand
} else {
expectNewSelectedCandidate = aAgent.getSelectedPair().Remote
}
@@ -1958,18 +1967,27 @@ func TestAcceptAggressiveNomination(t *testing.T) {
change_priority:
for _, candidates := range aAgent.remoteCandidates {
for _, candidate := range candidates {
if candidate.Equal(c) {
if candidate.Equal(cand) {
if tc.useHigherPriority {
candidate.(*CandidateHost).priorityOverride += 500 //nolint:forcetypeassert
} else {
candidate.(*CandidateHost).priorityOverride -= 500 //nolint:forcetypeassert
}
break change_priority
}
}
}
}
_, err = c.writeTo(buildMsg(stun.ClassRequest, aAgent.localUfrag+":"+aAgent.remoteUfrag, aAgent.localPwd, c.Priority()).Raw, bAgent.getSelectedPair().Remote)
_, err = cand.writeTo(
buildMsg(
stun.ClassRequest,
aAgent.localUfrag+":"+aAgent.remoteUfrag,
aAgent.localPwd,
cand.Priority(),
).Raw,
bAgent.getSelectedPair().Remote,
)
require.NoError(t, err)
}
}
@@ -1991,7 +2009,7 @@ func TestAcceptAggressiveNomination(t *testing.T) {
require.NoError(t, wan.Stop())
}
// Close can deadlock but GracefulClose must not
// Close can deadlock but GracefulClose must not.
func TestAgentGracefulCloseDeadlock(t *testing.T) {
defer test.CheckRoutinesStrict(t)()
defer test.TimeOut(time.Second * 5).Stop()

View File

@@ -16,7 +16,7 @@ import (
"github.com/stretchr/testify/require"
)
// TestMuxAgent is an end to end test over UDP mux, ensuring two agents could connect over mux
// TestMuxAgent is an end to end test over UDP mux, ensuring two agents could connect over mux.
func TestMuxAgent(t *testing.T) {
defer test.CheckRoutines(t)()
@@ -58,7 +58,7 @@ func TestMuxAgent(t *testing.T) {
require.NoError(t, muxedA.Close())
}()
a, err := NewAgent(&AgentConfig{
agent, err := NewAgent(&AgentConfig{
CandidateTypes: []CandidateType{CandidateTypeHost},
NetworkTypes: supportedNetworkTypes(),
})
@@ -68,10 +68,10 @@ func TestMuxAgent(t *testing.T) {
if aClosed {
return
}
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
conn, muxedConn := connect(a, muxedA)
conn, muxedConn := connect(agent, muxedA)
pair := muxedA.getSelectedPair()
require.NotNil(t, pair)

View File

@@ -13,13 +13,13 @@ const (
receiveMTU = 8192
defaultLocalPreference = 65535
// ComponentRTP indicates that the candidate is used for RTP
// ComponentRTP indicates that the candidate is used for RTP.
ComponentRTP uint16 = 1
// ComponentRTCP indicates that the candidate is used for RTCP
// ComponentRTCP indicates that the candidate is used for RTCP.
ComponentRTCP
)
// Candidate represents an ICE candidate
// Candidate represents an ICE candidate.
type Candidate interface {
// An arbitrary string used in the freezing algorithm to
// group similar candidates. It is the same for two candidates that

View File

@@ -48,12 +48,12 @@ type candidateBase struct {
extensions []CandidateExtension
}
// Done implements context.Context
// Done implements context.Context.
func (c *candidateBase) Done() <-chan struct{} {
return c.closeCh
}
// Err implements context.Context
// Err implements context.Context.
func (c *candidateBase) Err() error {
select {
case <-c.closedCh:
@@ -63,17 +63,17 @@ func (c *candidateBase) Err() error {
}
}
// Deadline implements context.Context
// Deadline implements context.Context.
func (c *candidateBase) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}
// Value implements context.Context
// Value implements context.Context.
func (c *candidateBase) Value(interface{}) interface{} {
return nil
}
// ID returns Candidate ID
// ID returns Candidate ID.
func (c *candidateBase) ID() string {
return c.id
}
@@ -86,27 +86,27 @@ func (c *candidateBase) Foundation() string {
return fmt.Sprintf("%d", crc32.ChecksumIEEE([]byte(c.Type().String()+c.address+c.networkType.String())))
}
// Address returns Candidate Address
// Address returns Candidate Address.
func (c *candidateBase) Address() string {
return c.address
}
// Port returns Candidate Port
// Port returns Candidate Port.
func (c *candidateBase) Port() int {
return c.port
}
// Type returns candidate type
// Type returns candidate type.
func (c *candidateBase) Type() CandidateType {
return c.candidateType
}
// NetworkType returns candidate NetworkType
// NetworkType returns candidate NetworkType.
func (c *candidateBase) NetworkType() NetworkType {
return c.networkType
}
// Component returns candidate component
// Component returns candidate component.
func (c *candidateBase) Component() uint16 {
return c.component
}
@@ -115,8 +115,8 @@ func (c *candidateBase) SetComponent(component uint16) {
c.component = component
}
// LocalPreference returns the local preference for this candidate
func (c *candidateBase) LocalPreference() uint16 {
// LocalPreference returns the local preference for this candidate.
func (c *candidateBase) LocalPreference() uint16 { //nolint:cyclop
if c.NetworkType().IsTCP() {
// RFC 6544, section 4.2
//
@@ -182,6 +182,7 @@ func (c *candidateBase) LocalPreference() uint16 {
case CandidateTypeUnspecified:
return 0
}
return 0
}()
@@ -191,7 +192,7 @@ func (c *candidateBase) LocalPreference() uint16 {
return defaultLocalPreference
}
// RelatedAddress returns *CandidateRelatedAddress
// RelatedAddress returns *CandidateRelatedAddress.
func (c *candidateBase) RelatedAddress() *CandidateRelatedAddress {
return c.relatedAddress
}
@@ -200,10 +201,11 @@ func (c *candidateBase) TCPType() TCPType {
return c.tcpType
}
// start runs the candidate using the provided connection
// start runs the candidate using the provided connection.
func (c *candidateBase) start(a *Agent, conn net.PacketConn, initializedCh <-chan struct{}) {
if c.conn != nil {
c.agent().log.Warn("Can't start already started candidateBase")
return
}
c.currAgent = a
@@ -221,7 +223,7 @@ var bufferPool = sync.Pool{ // nolint:gochecknoglobals
}
func (c *candidateBase) recvLoop(initializedCh <-chan struct{}) {
a := c.agent()
agent := c.agent()
defer close(c.closedCh)
@@ -242,8 +244,9 @@ func (c *candidateBase) recvLoop(initializedCh <-chan struct{}) {
n, srcAddr, err := c.conn.ReadFrom(buf)
if err != nil {
if !(errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed)) {
a.log.Warnf("Failed to read from candidate %s: %v", c, err)
agent.log.Warnf("Failed to read from candidate %s: %v", c, err)
}
return
}
@@ -254,8 +257,10 @@ func (c *candidateBase) recvLoop(initializedCh <-chan struct{}) {
func (c *candidateBase) validateSTUNTrafficCache(addr net.Addr) bool {
if candidate, ok := c.remoteCandidateCaches[toAddrPort(addr)]; ok {
candidate.seen(false)
return true
}
return false
}
@@ -267,48 +272,51 @@ func (c *candidateBase) addRemoteCandidateCache(candidate Candidate, srcAddr net
}
func (c *candidateBase) handleInboundPacket(buf []byte, srcAddr net.Addr) {
a := c.agent()
agent := c.agent()
if stun.IsMessage(buf) {
m := &stun.Message{
msg := &stun.Message{
Raw: make([]byte, len(buf)),
}
// Explicitly copy raw buffer so Message can own the memory.
copy(m.Raw, buf)
copy(msg.Raw, buf)
if err := msg.Decode(); err != nil {
agent.log.Warnf("Failed to handle decode ICE from %s to %s: %v", c.addr(), srcAddr, err)
if err := m.Decode(); err != nil {
a.log.Warnf("Failed to handle decode ICE from %s to %s: %v", c.addr(), srcAddr, err)
return
}
if err := a.loop.Run(c, func(_ context.Context) {
if err := agent.loop.Run(c, func(_ context.Context) {
// nolint: contextcheck
a.handleInbound(m, c, srcAddr)
agent.handleInbound(msg, c, srcAddr)
}); err != nil {
a.log.Warnf("Failed to handle message: %v", err)
agent.log.Warnf("Failed to handle message: %v", err)
}
return
}
if !c.validateSTUNTrafficCache(srcAddr) {
remoteCandidate, valid := a.validateNonSTUNTraffic(c, srcAddr) //nolint:contextcheck
remoteCandidate, valid := agent.validateNonSTUNTraffic(c, srcAddr) //nolint:contextcheck
if !valid {
a.log.Warnf("Discarded message from %s, not a valid remote candidate", c.addr())
agent.log.Warnf("Discarded message from %s, not a valid remote candidate", c.addr())
return
}
c.addRemoteCandidateCache(remoteCandidate, srcAddr)
}
// Note: This will return packetio.ErrFull if the buffer ever manages to fill up.
if _, err := a.buf.Write(buf); err != nil {
a.log.Warnf("Failed to write packet: %s", err)
if _, err := agent.buf.Write(buf); err != nil {
agent.log.Warnf("Failed to write packet: %s", err)
return
}
}
// close stops the recvLoop
// close stops the recvLoop.
func (c *candidateBase) close() error {
// If conn has never been started will be nil
if c.Done() == nil {
@@ -353,13 +361,15 @@ func (c *candidateBase) writeTo(raw []byte, dst Candidate) (int, error) {
return n, err
}
c.agent().log.Infof("Failed to send packet: %v", err)
return n, nil
}
c.seen(true)
return n, nil
}
// TypePreference returns the type preference for this candidate
// TypePreference returns the type preference for this candidate.
func (c *candidateBase) TypePreference() uint16 {
pref := c.Type().Preference()
if pref == 0 {
@@ -397,7 +407,7 @@ func (c *candidateBase) Priority() uint32 {
(1<<0)*uint32(256-c.Component())
}
// Equal is used to compare two candidateBases
// Equal is used to compare two candidateBases.
func (c *candidateBase) Equal(other Candidate) bool {
if c.addr() != other.addr() {
if c.addr() == nil || other.addr() == nil {
@@ -416,22 +426,30 @@ func (c *candidateBase) Equal(other Candidate) bool {
c.RelatedAddress().Equal(other.RelatedAddress())
}
// DeepEqual is same as Equal but also compares the extensions
// DeepEqual is same as Equal but also compares the extensions.
func (c *candidateBase) DeepEqual(other Candidate) bool {
return c.Equal(other) && c.extensionsEqual(other.Extensions())
}
// String makes the candidateBase printable
// String makes the candidateBase printable.
func (c *candidateBase) String() string {
return fmt.Sprintf("%s %s %s%s (resolved: %v)", c.NetworkType(), c.Type(), net.JoinHostPort(c.Address(), strconv.Itoa(c.Port())), c.relatedAddress, c.resolvedAddr)
return fmt.Sprintf(
"%s %s %s%s (resolved: %v)",
c.NetworkType(),
c.Type(),
net.JoinHostPort(c.Address(), strconv.Itoa(c.Port())),
c.relatedAddress,
c.resolvedAddr,
)
}
// LastReceived returns a time.Time indicating the last time
// this candidate was received
// this candidate was received.
func (c *candidateBase) LastReceived() time.Time {
if lastReceived, ok := c.lastReceived.Load().(time.Time); ok {
return lastReceived
}
return time.Time{}
}
@@ -440,11 +458,12 @@ func (c *candidateBase) setLastReceived(t time.Time) {
}
// LastSent returns a time.Time indicating the last time
// this candidate was sent
// this candidate was sent.
func (c *candidateBase) LastSent() time.Time {
if lastSent, ok := c.lastSent.Load().(time.Time); ok {
return lastSent
}
return time.Time{}
}
@@ -484,10 +503,11 @@ func removeZoneIDFromAddress(addr string) string {
if i := strings.Index(addr, "%"); i != -1 {
return addr[:i]
}
return addr
}
// Marshal returns the string representation of the ICECandidate
// Marshal returns the string representation of the ICECandidate.
func (c *candidateBase) Marshal() string {
val := c.Foundation()
if val == " " {
@@ -618,9 +638,7 @@ func (c *candidateBase) setExtensions(extensions []CandidateExtension) {
// UnmarshalCandidate Parses a candidate from a string
// https://datatracker.ietf.org/doc/html/rfc5245#section-15.1
func UnmarshalCandidate(raw string) (Candidate, error) {
// rfc5245
func UnmarshalCandidate(raw string) (Candidate, error) { //nolint:cyclop
pos := 0
// foundation ( 1*32ice-char ) But we allow for empty foundation,
@@ -805,15 +823,16 @@ func UnmarshalCandidate(raw string) (Candidate, error) {
// Read an ice-char token from the raw string
// ice-char = ALPHA / DIGIT / "+" / "/"
// stop reading when a space is encountered or the end of the string
func readCandidateCharToken(raw string, start int, limit int) (string, int, error) {
// stop reading when a space is encountered or the end of the string.
func readCandidateCharToken(raw string, start int, limit int) (string, int, error) { //nolint:cyclop
for i, char := range raw[start:] {
if char == 0x20 { // SP
return raw[start : start+i], start + i + 1, nil
}
if i == limit {
return "", 0, fmt.Errorf("token too long: %s expected 1x%d", raw[start:start+i], limit) //nolint: err113 // handled by caller
//nolint: err113 // handled by caller
return "", 0, fmt.Errorf("token too long: %s expected 1x%d", raw[start:start+i], limit)
}
if !(char >= 'A' && char <= 'Z' ||
@@ -828,7 +847,7 @@ func readCandidateCharToken(raw string, start int, limit int) (string, int, erro
}
// Read an ice string token from the raw string until a space is encountered
// Or the end of the string, we imply that ice string are UTF-8 encoded
// Or the end of the string, we imply that ice string are UTF-8 encoded.
func readCandidateStringToken(raw string, start int) (string, int) {
for i, char := range raw[start:] {
if char == 0x20 { // SP
@@ -840,7 +859,7 @@ func readCandidateStringToken(raw string, start int) (string, int) {
}
// Read a digit token from the raw string
// stop reading when a space is encountered or the end of the string
// stop reading when a space is encountered or the end of the string.
func readCandidateDigitToken(raw string, start, limit int) (int, int, error) {
var val int
for i, char := range raw[start:] {
@@ -849,7 +868,8 @@ func readCandidateDigitToken(raw string, start, limit int) (int, int, error) {
}
if i == limit {
return 0, 0, fmt.Errorf("token too long: %s expected 1x%d", raw[start:start+i], limit) //nolint: err113 // handled by caller
//nolint: err113 // handled by caller
return 0, 0, fmt.Errorf("token too long: %s expected 1x%d", raw[start:start+i], limit)
}
if !(char >= '0' && char <= '9') {
@@ -862,7 +882,7 @@ func readCandidateDigitToken(raw string, start, limit int) (int, int, error) {
return val, len(raw), nil
}
// Read and validate RFC 4566 port from the raw string
// Read and validate RFC 4566 port from the raw string.
func readCandidatePort(raw string, start int) (int, int, error) {
port, pos, err := readCandidateDigitToken(raw, start, 5)
if err != nil {
@@ -878,7 +898,7 @@ func readCandidatePort(raw string, start int) (int, int, error) {
// Read a byte-string token from the raw string
// As defined in RFC 4566 1*(%x01-09/%x0B-0C/%x0E-FF) ;any byte except NUL, CR, or LF
// we imply that extensions byte-string are UTF-8 encoded
// we imply that extensions byte-string are UTF-8 encoded.
func readCandidateByteString(raw string, start int) (string, int, error) {
for i, char := range raw[start:] {
if char == 0x20 { // SP
@@ -952,17 +972,23 @@ func unmarshalCandidateExtensions(raw string) (extensions []CandidateExtension,
for i := 0; i < len(raw); {
key, next, err := readCandidateByteString(raw, i)
if err != nil {
return extensions, "", fmt.Errorf("%w: failed to read key %v", errParseExtension, err) //nolint: errorlint // we wrap the error
return extensions, "", fmt.Errorf(
"%w: failed to read key %v", errParseExtension, err, //nolint: errorlint // we wrap the error
)
}
i = next
if i >= len(raw) {
return extensions, "", fmt.Errorf("%w: missing value for %s in %s", errParseExtension, key, raw)
return extensions, "", fmt.Errorf(
"%w: missing value for %s in %s", errParseExtension, key, raw, //nolint: errorlint // we are wrapping the error
)
}
value, next, err := readCandidateByteString(raw, i)
if err != nil {
return extensions, "", fmt.Errorf("%w: failed to read value %v", errParseExtension, err) //nolint: errorlint // we are wrapping the error
return extensions, "", fmt.Errorf(
"%w: failed to read value %v", errParseExtension, err, //nolint: errorlint // we are wrapping the error
)
}
i = next
@@ -973,5 +999,5 @@ func unmarshalCandidateExtensions(raw string) (extensions []CandidateExtension,
extensions = append(extensions, CandidateExtension{key, value})
}
return
return extensions, rawTCPTypeRaw, nil
}

View File

@@ -8,14 +8,14 @@ import (
"strings"
)
// CandidateHost is a candidate of type host
// CandidateHost is a candidate of type host.
type CandidateHost struct {
candidateBase
network string
}
// CandidateHostConfig is the config required to create a new CandidateHost
// CandidateHostConfig is the config required to create a new CandidateHost.
type CandidateHostConfig struct {
CandidateID string
Network string
@@ -28,7 +28,7 @@ type CandidateHostConfig struct {
IsLocationTracked bool
}
// NewCandidateHost creates a new host candidate
// NewCandidateHost creates a new host candidate.
func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) {
candidateID := config.CandidateID
@@ -36,7 +36,7 @@ func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) {
candidateID = globalCandidateIDGenerator.Generate()
}
c := &CandidateHost{
candidateHost := &CandidateHost{
candidateBase: candidateBase{
id: candidateID,
address: config.Address,
@@ -58,15 +58,15 @@ func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) {
return nil, err
}
if err := c.setIPAddr(ipAddr); err != nil {
if err := candidateHost.setIPAddr(ipAddr); err != nil {
return nil, err
}
} else {
// Until mDNS candidate is resolved assume it is UDPv4
c.candidateBase.networkType = NetworkTypeUDP4
candidateHost.candidateBase.networkType = NetworkTypeUDP4
}
return c, nil
return candidateHost, nil
}
func (c *CandidateHost) setIPAddr(addr netip.Addr) error {

View File

@@ -15,7 +15,7 @@ type CandidatePeerReflexive struct {
candidateBase
}
// CandidatePeerReflexiveConfig is the config required to create a new CandidatePeerReflexive
// CandidatePeerReflexiveConfig is the config required to create a new CandidatePeerReflexive.
type CandidatePeerReflexiveConfig struct {
CandidateID string
Network string
@@ -28,7 +28,7 @@ type CandidatePeerReflexiveConfig struct {
RelPort int
}
// NewCandidatePeerReflexive creates a new peer reflective candidate
// NewCandidatePeerReflexive creates a new peer reflective candidate.
func NewCandidatePeerReflexive(config *CandidatePeerReflexiveConfig) (*CandidatePeerReflexive, error) {
ipAddr, err := netip.ParseAddr(config.Address)
if err != nil {

View File

@@ -16,7 +16,7 @@ type CandidateRelay struct {
onClose func() error
}
// CandidateRelayConfig is the config required to create a new CandidateRelay
// CandidateRelayConfig is the config required to create a new CandidateRelay.
type CandidateRelayConfig struct {
CandidateID string
Network string
@@ -31,7 +31,7 @@ type CandidateRelayConfig struct {
OnClose func() error
}
// NewCandidateRelay creates a new relay candidate
// NewCandidateRelay creates a new relay candidate.
func NewCandidateRelay(config *CandidateRelayConfig) (*CandidateRelay, error) {
candidateID := config.CandidateID
@@ -75,7 +75,7 @@ func NewCandidateRelay(config *CandidateRelayConfig) (*CandidateRelay, error) {
}, nil
}
// LocalPreference returns the local preference for this candidate
// LocalPreference returns the local preference for this candidate.
func (c *CandidateRelay) LocalPreference() uint16 {
// These preference values come from libwebrtc
// https://github.com/mozilla/libwebrtc/blob/1389c76d9c79839a2ca069df1db48aa3f2e6a1ac/p2p/base/turn_port.cc#L61
@@ -103,6 +103,7 @@ func (c *CandidateRelay) close() error {
err = c.onClose()
c.onClose = nil
}
return err
}

View File

@@ -13,7 +13,7 @@ type CandidateServerReflexive struct {
candidateBase
}
// CandidateServerReflexiveConfig is the config required to create a new CandidateServerReflexive
// CandidateServerReflexiveConfig is the config required to create a new CandidateServerReflexive.
type CandidateServerReflexiveConfig struct {
CandidateID string
Network string
@@ -26,7 +26,7 @@ type CandidateServerReflexiveConfig struct {
RelPort int
}
// NewCandidateServerReflexive creates a new server reflective candidate
// NewCandidateServerReflexive creates a new server reflective candidate.
func NewCandidateServerReflexive(config *CandidateServerReflexiveConfig) (*CandidateServerReflexive, error) {
ipAddr, err := netip.ParseAddr(config.Address)
if err != nil {

View File

@@ -16,7 +16,7 @@ import (
const localhostIPStr = "127.0.0.1"
func TestCandidateTypePreference(t *testing.T) {
r := require.New(t)
req := require.New(t)
hostDefaultPreference := uint16(126)
prflxDefaultPreference := uint16(110)
@@ -53,16 +53,16 @@ func TestCandidateTypePreference(t *testing.T) {
}
if networkType.IsTCP() {
r.Equal(hostDefaultPreference-tcpOffset, hostCandidate.TypePreference())
r.Equal(prflxDefaultPreference-tcpOffset, prflxCandidate.TypePreference())
r.Equal(srflxDefaultPreference-tcpOffset, srflxCandidate.TypePreference())
req.Equal(hostDefaultPreference-tcpOffset, hostCandidate.TypePreference())
req.Equal(prflxDefaultPreference-tcpOffset, prflxCandidate.TypePreference())
req.Equal(srflxDefaultPreference-tcpOffset, srflxCandidate.TypePreference())
} else {
r.Equal(hostDefaultPreference, hostCandidate.TypePreference())
r.Equal(prflxDefaultPreference, prflxCandidate.TypePreference())
r.Equal(srflxDefaultPreference, srflxCandidate.TypePreference())
req.Equal(hostDefaultPreference, hostCandidate.TypePreference())
req.Equal(prflxDefaultPreference, prflxCandidate.TypePreference())
req.Equal(srflxDefaultPreference, srflxCandidate.TypePreference())
}
r.Equal(relayDefaultPreference, relayCandidate.TypePreference())
req.Equal(relayDefaultPreference, relayCandidate.TypePreference())
}
}
}
@@ -266,20 +266,27 @@ func TestCandidateFoundation(t *testing.T) {
}).Foundation())
}
func mustCandidateHost(conf *CandidateHostConfig) Candidate {
cand, err := NewCandidateHost(conf)
if err != nil {
panic(err)
}
return cand
}
func mustCandidateHostWithExtensions(t *testing.T, conf *CandidateHostConfig, extensions []CandidateExtension) Candidate {
func mustCandidateHost(t *testing.T, conf *CandidateHostConfig) Candidate {
t.Helper()
cand, err := NewCandidateHost(conf)
if err != nil {
panic(err)
t.Fatal(err)
}
return cand
}
func mustCandidateHostWithExtensions(
t *testing.T,
conf *CandidateHostConfig,
extensions []CandidateExtension,
) Candidate {
t.Helper()
cand, err := NewCandidateHost(conf)
if err != nil {
t.Fatal(err)
}
cand.setExtensions(extensions)
@@ -287,20 +294,27 @@ func mustCandidateHostWithExtensions(t *testing.T, conf *CandidateHostConfig, ex
return cand
}
func mustCandidateRelay(conf *CandidateRelayConfig) Candidate {
cand, err := NewCandidateRelay(conf)
if err != nil {
panic(err)
}
return cand
}
func mustCandidateRelayWithExtensions(t *testing.T, conf *CandidateRelayConfig, extensions []CandidateExtension) Candidate {
func mustCandidateRelay(t *testing.T, conf *CandidateRelayConfig) Candidate {
t.Helper()
cand, err := NewCandidateRelay(conf)
if err != nil {
panic(err)
t.Fatal(err)
}
return cand
}
func mustCandidateRelayWithExtensions(
t *testing.T,
conf *CandidateRelayConfig,
extensions []CandidateExtension,
) Candidate {
t.Helper()
cand, err := NewCandidateRelay(conf)
if err != nil {
t.Fatal(err)
}
cand.setExtensions(extensions)
@@ -308,20 +322,27 @@ func mustCandidateRelayWithExtensions(t *testing.T, conf *CandidateRelayConfig,
return cand
}
func mustCandidateServerReflexive(conf *CandidateServerReflexiveConfig) Candidate {
cand, err := NewCandidateServerReflexive(conf)
if err != nil {
panic(err)
}
return cand
}
func mustCandidateServerReflexiveWithExtensions(t *testing.T, conf *CandidateServerReflexiveConfig, extensions []CandidateExtension) Candidate {
func mustCandidateServerReflexive(t *testing.T, conf *CandidateServerReflexiveConfig) Candidate {
t.Helper()
cand, err := NewCandidateServerReflexive(conf)
if err != nil {
panic(err)
t.Fatal(err)
}
return cand
}
func mustCandidateServerReflexiveWithExtensions(
t *testing.T,
conf *CandidateServerReflexiveConfig,
extensions []CandidateExtension,
) Candidate {
t.Helper()
cand, err := NewCandidateServerReflexive(conf)
if err != nil {
t.Fatal(err)
}
cand.setExtensions(extensions)
@@ -329,12 +350,16 @@ func mustCandidateServerReflexiveWithExtensions(t *testing.T, conf *CandidateSer
return cand
}
func mustCandidatePeerReflexiveWithExtensions(t *testing.T, conf *CandidatePeerReflexiveConfig, extensions []CandidateExtension) Candidate {
func mustCandidatePeerReflexiveWithExtensions(
t *testing.T,
conf *CandidatePeerReflexiveConfig,
extensions []CandidateExtension,
) Candidate {
t.Helper()
cand, err := NewCandidatePeerReflexive(conf)
if err != nil {
panic(err)
t.Fatal(err)
}
cand.setExtensions(extensions)
@@ -349,7 +374,7 @@ func TestCandidateMarshal(t *testing.T) {
expectError bool
}{
{
mustCandidateHost(&CandidateHostConfig{
mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeUDP6.String(),
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
Port: 53987,
@@ -360,7 +385,7 @@ func TestCandidateMarshal(t *testing.T) {
false,
},
{
mustCandidateHost(&CandidateHostConfig{
mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeUDP4.String(),
Address: "10.0.75.1",
Port: 53634,
@@ -369,7 +394,7 @@ func TestCandidateMarshal(t *testing.T) {
false,
},
{
mustCandidateServerReflexive(&CandidateServerReflexiveConfig{
mustCandidateServerReflexive(t, &CandidateServerReflexiveConfig{
Network: NetworkTypeUDP4.String(),
Address: "191.228.238.68",
Port: 53991,
@@ -395,11 +420,12 @@ func TestCandidateMarshal(t *testing.T) {
{"network-cost", "10"},
},
),
//nolint: lll
"4207374052 1 tcp 1685790463 192.0.2.15 50000 typ prflx raddr 10.0.0.1 rport 12345 generation 0 network-id 2 network-cost 10",
false,
},
{
mustCandidateRelay(&CandidateRelayConfig{
mustCandidateRelay(t, &CandidateRelayConfig{
Network: NetworkTypeUDP4.String(),
Address: "50.0.0.1",
Port: 5000,
@@ -410,7 +436,7 @@ func TestCandidateMarshal(t *testing.T) {
false,
},
{
mustCandidateHost(&CandidateHostConfig{
mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeTCP4.String(),
Address: "192.168.0.196",
Port: 0,
@@ -420,7 +446,7 @@ func TestCandidateMarshal(t *testing.T) {
false,
},
{
mustCandidateHost(&CandidateHostConfig{
mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeUDP4.String(),
Address: "e2494022-4d9a-4c1e-a750-cc48d4f8d6ee.local",
Port: 60542,
@@ -429,7 +455,7 @@ func TestCandidateMarshal(t *testing.T) {
},
// Missing Foundation
{
mustCandidateHost(&CandidateHostConfig{
mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeUDP4.String(),
Address: localhostIPStr,
Port: 80,
@@ -440,7 +466,7 @@ func TestCandidateMarshal(t *testing.T) {
false,
},
{
mustCandidateHost(&CandidateHostConfig{
mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeUDP4.String(),
Address: localhostIPStr,
Port: 80,
@@ -451,7 +477,7 @@ func TestCandidateMarshal(t *testing.T) {
false,
},
{
mustCandidateHost(&CandidateHostConfig{
mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeTCP4.String(),
Address: "172.28.142.173",
Port: 7686,
@@ -467,8 +493,10 @@ func TestCandidateMarshal(t *testing.T) {
{nil, "1938809241", true},
{nil, "1986380506 99999999 udp 2122063615 10.0.75.1 53634 typ host generation 0 network-id 2", true},
{nil, "1986380506 1 udp 99999999999 10.0.75.1 53634 typ host", true},
//nolint: lll
{nil, "4207374051 1 udp 1685790463 191.228.238.68 99999999 typ srflx raddr 192.168.0.278 rport 53991 generation 0 network-id 3", true},
{nil, "4207374051 1 udp 1685790463 191.228.238.68 53991 typ srflx raddr", true},
//nolint: lll
{nil, "4207374051 1 udp 1685790463 191.228.238.68 53991 typ srflx raddr 192.168.0.278 rport 99999999 generation 0 network-id 3", true},
{nil, "4207374051 INVALID udp 2130706431 10.0.75.1 53634 typ host", true},
{nil, "4207374051 1 udp INVALID 10.0.75.1 53634 typ host", true},
@@ -521,12 +549,19 @@ func TestCandidateMarshal(t *testing.T) {
actualCandidate, err := UnmarshalCandidate(test.marshaled)
if test.expectError {
require.Error(t, err, "expected error", test.marshaled)
return
}
require.NoError(t, err)
require.Truef(t, test.candidate.Equal(actualCandidate), "%s != %s", test.candidate.String(), actualCandidate.String())
require.Truef(
t,
test.candidate.Equal(actualCandidate),
"%s != %s",
test.candidate.String(),
actualCandidate.String(),
)
require.Equal(t, test.marshaled, actualCandidate.Marshal())
})
}
@@ -573,7 +608,7 @@ func TestCandidateWriteTo(t *testing.T) {
}
func TestMarshalUnmarshalCandidateWithZoneID(t *testing.T) {
candidateWithZoneID := mustCandidateHost(&CandidateHostConfig{
candidateWithZoneID := mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeUDP6.String(),
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a%Local Connection",
Port: 53987,
@@ -583,7 +618,7 @@ func TestMarshalUnmarshalCandidateWithZoneID(t *testing.T) {
candidateStr := "750 0 udp 500 fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a 53987 typ host"
require.Equal(t, candidateStr, candidateWithZoneID.Marshal())
candidate := mustCandidateHost(&CandidateHostConfig{
candidate := mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeUDP6.String(),
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
Port: 53987,
@@ -612,6 +647,7 @@ func TestCandidateExtensionsMarshal(t *testing.T) {
{"ufrag", "QNvE"},
{"network-id", "4"},
},
//nolint: lll
"1299692247 1 udp 2122134271 fdc8:cc8:c835:e400:343c:feb:32c8:17b9 58240 typ host generation 0 ufrag QNvE network-id 4",
},
{
@@ -620,6 +656,7 @@ func TestCandidateExtensionsMarshal(t *testing.T) {
{"network-id", "2"},
{"network-cost", "50"},
},
//nolint:lll
"647372371 1 udp 1694498815 191.228.238.68 53991 typ srflx raddr 192.168.0.274 rport 53991 generation 1 network-id 2 network-cost 50",
},
{
@@ -628,6 +665,7 @@ func TestCandidateExtensionsMarshal(t *testing.T) {
{"network-id", "2"},
{"network-cost", "10"},
},
//nolint:lll
"4207374052 1 tcp 1685790463 192.0.2.15 50000 typ prflx raddr 10.0.0.1 rport 12345 generation 0 network-id 2 network-cost 10",
},
{
@@ -638,6 +676,7 @@ func TestCandidateExtensionsMarshal(t *testing.T) {
{"ufrag", "frag42abcdef"},
{"password", "abc123exp123"},
},
//nolint: lll
"848194626 1 udp 16777215 50.0.0.1 5000 typ relay raddr 192.168.0.1 rport 5001 generation 0 network-id 1 network-cost 20 ufrag frag42abcdef password abc123exp123",
},
{
@@ -703,7 +742,7 @@ func TestCandidateExtensionsDeepEqual(t *testing.T) {
equal bool
}{
{
mustCandidateHost(&CandidateHostConfig{
mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeUDP4.String(),
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
Port: 53987,

View File

@@ -20,8 +20,7 @@ func newCandidatePair(local, remote Candidate, controlling bool) *CandidatePair
}
}
// CandidatePair is a combination of a
// local and remote candidate
// CandidatePair is a combination of a local and remote candidate.
type CandidatePair struct {
iceRoleControlling bool
Remote Candidate
@@ -42,8 +41,17 @@ func (p *CandidatePair) String() string {
return ""
}
return fmt.Sprintf("prio %d (local, prio %d) %s <-> %s (remote, prio %d), state: %s, nominated: %v, nominateOnBindingSuccess: %v",
p.priority(), p.Local.Priority(), p.Local, p.Remote, p.Remote.Priority(), p.state, p.nominated, p.nominateOnBindingSuccess)
return fmt.Sprintf(
"prio %d (local, prio %d) %s <-> %s (remote, prio %d), state: %s, nominated: %v, nominateOnBindingSuccess: %v",
p.priority(),
p.Local.Priority(),
p.Local,
p.Remote,
p.Remote.Priority(),
p.state,
p.nominated,
p.nominateOnBindingSuccess,
)
}
func (p *CandidatePair) equal(other *CandidatePair) bool {
@@ -53,6 +61,7 @@ func (p *CandidatePair) equal(other *CandidatePair) bool {
if p == nil || other == nil {
return false
}
return p.Local.Equal(other.Local) && p.Remote.Equal(other.Remote)
}
@@ -60,9 +69,9 @@ func (p *CandidatePair) equal(other *CandidatePair) bool {
// Let G be the priority for the candidate provided by the controlling
// agent. Let D be the priority for the candidate provided by the
// controlled agent.
// pair priority = 2^32*MIN(G,D) + 2*MAX(G,D) + (G>D?1:0)
// pair priority = 2^32*MIN(G,D) + 2*MAX(G,D) + (G>D?1:0).
func (p *CandidatePair) priority() uint64 {
var g, d uint32
var g, d uint32 //nolint:varnamelen // clearer to use g and d here
if p.iceRoleControlling {
g = p.Local.Priority()
d = p.Remote.Priority()
@@ -77,18 +86,21 @@ func (p *CandidatePair) priority() uint64 {
if x < y {
return uint64(x)
}
return uint64(y)
}
localMax := func(x, y uint32) uint64 {
if x > y {
return uint64(x)
}
return uint64(y)
}
cmp := func(x, y uint32) uint64 {
if x > y {
return uint64(1)
}
return uint64(0)
}
@@ -109,7 +121,7 @@ func (a *Agent) sendSTUN(msg *stun.Message, local, remote Candidate) {
}
// UpdateRoundTripTime sets the current round time of this pair and
// accumulates total round trip time and responses received
// accumulates total round trip time and responses received.
func (p *CandidatePair) UpdateRoundTripTime(rtt time.Duration) {
rttNs := rtt.Nanoseconds()
atomic.StoreInt64(&p.currentRoundTripTime, rttNs)

View File

@@ -3,12 +3,12 @@
package ice
// CandidatePairState represent the ICE candidate pair state
// CandidatePairState represent the ICE candidate pair state.
type CandidatePairState int
const (
// CandidatePairStateWaiting means a check has not been performed for
// this pair
// this pair.
CandidatePairStateWaiting CandidatePairState = iota + 1
// CandidatePairStateInProgress means a check has been sent for this pair,
@@ -36,5 +36,6 @@ func (c CandidatePairState) String() string {
case CandidatePairStateSucceeded:
return "succeeded"
}
return "Unknown candidate pair state"
}

View File

@@ -12,7 +12,7 @@ type CandidateRelatedAddress struct {
Port int
}
// String makes CandidateRelatedAddress printable
// String makes CandidateRelatedAddress printable.
func (c *CandidateRelatedAddress) String() string {
if c == nil {
return ""
@@ -27,6 +27,7 @@ func (c *CandidateRelatedAddress) Equal(other *CandidateRelatedAddress) bool {
if c == nil && other == nil {
return true
}
return c != nil && other != nil &&
c.Address == other.Address &&
c.Port == other.Port

View File

@@ -3,10 +3,10 @@
package ice
// CandidateType represents the type of candidate
// CandidateType represents the type of candidate.
type CandidateType byte
// CandidateType enum
// CandidateType enum.
const (
CandidateTypeUnspecified CandidateType = iota
CandidateTypeHost
@@ -15,7 +15,7 @@ const (
CandidateTypeRelay
)
// String makes CandidateType printable
// String makes CandidateType printable.
func (c CandidateType) String() string {
switch c {
case CandidateTypeHost:
@@ -29,6 +29,7 @@ func (c CandidateType) String() string {
case CandidateTypeUnspecified:
return "Unknown candidate type"
}
return "Unknown candidate type"
}
@@ -49,6 +50,7 @@ func (c CandidateType) Preference() uint16 {
case CandidateTypeRelay, CandidateTypeUnspecified:
return 0
}
return 0
}
@@ -61,5 +63,6 @@ func containsCandidateType(candidateType CandidateType, candidateTypeList []Cand
return true
}
}
return false
}

View File

@@ -45,7 +45,7 @@ func (v *virtualNet) close() {
v.wan.Stop() //nolint:errcheck,gosec
}
func buildVNet(natType0, natType1 *vnet.NATType) (*virtualNet, error) {
func buildVNet(natType0, natType1 *vnet.NATType) (*virtualNet, error) { //nolint:cyclop
loggerFactory := logging.NewDefaultLoggerFactory()
// WAN
@@ -77,6 +77,7 @@ func buildVNet(natType0, natType1 *vnet.NATType) (*virtualNet, error) {
vnetGlobalIPA + "/" + vnetLocalIPA,
}
}
return []string{
vnetGlobalIPA,
}
@@ -114,6 +115,7 @@ func buildVNet(natType0, natType1 *vnet.NATType) (*virtualNet, error) {
vnetGlobalIPB + "/" + vnetLocalIPB,
}
}
return []string{
vnetGlobalIPB,
}
@@ -175,6 +177,7 @@ func addVNetSTUN(wanNet *vnet.Net, loggerFactory logging.LoggerFactory) (*turn.S
if pw, ok := credMap[username]; ok {
return turn.GenerateAuthKey(username, realm, pw), true
}
return nil, false
},
PacketConnConfigs: []turn.PacketConnConfig{
@@ -222,6 +225,7 @@ func connectWithVNet(aAgent, bAgent *Agent) (*Conn, *Conn) {
// Ensure accepted
<-accepted
return aConn, bConn
}
@@ -230,7 +234,7 @@ type agentTestConfig struct {
nat1To1IPCandidateType CandidateType
}
func pipeWithVNet(v *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig) (*Conn, *Conn) {
func pipeWithVNet(vnet *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig) (*Conn, *Conn) {
aNotifier, aConnected := onConnected()
bNotifier, bConnected := onConnected()
@@ -247,7 +251,7 @@ func pipeWithVNet(v *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig) (*
MulticastDNSMode: MulticastDNSModeDisabled,
NAT1To1IPs: nat1To1IPs,
NAT1To1IPCandidateType: a0TestConfig.nat1To1IPCandidateType,
Net: v.net0,
Net: vnet.net0,
}
aAgent, err := NewAgent(cfg0)
@@ -270,7 +274,7 @@ func pipeWithVNet(v *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig) (*
MulticastDNSMode: MulticastDNSModeDisabled,
NAT1To1IPs: nat1To1IPs,
NAT1To1IPCandidateType: a1TestConfig.nat1To1IPCandidateType,
Net: v.net1,
Net: vnet.net1,
}
bAgent, err := NewAgent(cfg1)
@@ -293,6 +297,8 @@ func pipeWithVNet(v *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig) (*
}
func closePipe(t *testing.T, ca *Conn, cb *Conn) {
t.Helper()
require.NoError(t, ca.Close())
require.NoError(t, cb.Close())
}
@@ -325,10 +331,10 @@ func TestConnectivityVNet(t *testing.T) {
MappingBehavior: vnet.EndpointIndependent,
FilteringBehavior: vnet.EndpointIndependent,
}
v, err := buildVNet(natType, natType)
vnet, err := buildVNet(natType, natType)
require.NoError(t, err, "should succeed")
defer v.close()
defer vnet.close()
log.Debug("Connecting...")
a0TestConfig := &agentTestConfig{
@@ -341,7 +347,7 @@ func TestConnectivityVNet(t *testing.T) {
stunServerURL,
},
}
ca, cb := pipeWithVNet(v, a0TestConfig, a1TestConfig)
ca, cb := pipeWithVNet(vnet, a0TestConfig, a1TestConfig)
time.Sleep(1 * time.Second)
@@ -358,10 +364,10 @@ func TestConnectivityVNet(t *testing.T) {
MappingBehavior: vnet.EndpointAddrPortDependent,
FilteringBehavior: vnet.EndpointAddrPortDependent,
}
v, err := buildVNet(natType, natType)
vnet, err := buildVNet(natType, natType)
require.NoError(t, err, "should succeed")
defer v.close()
defer vnet.close()
log.Debug("Connecting...")
a0TestConfig := &agentTestConfig{
@@ -375,7 +381,7 @@ func TestConnectivityVNet(t *testing.T) {
stunServerURL,
},
}
ca, cb := pipeWithVNet(v, a0TestConfig, a1TestConfig)
ca, cb := pipeWithVNet(vnet, a0TestConfig, a1TestConfig)
log.Debug("Closing...")
closePipe(t, ca, cb)
@@ -394,10 +400,10 @@ func TestConnectivityVNet(t *testing.T) {
MappingBehavior: vnet.EndpointAddrPortDependent,
FilteringBehavior: vnet.EndpointAddrPortDependent,
}
v, err := buildVNet(natType0, natType1)
vnet, err := buildVNet(natType0, natType1)
require.NoError(t, err, "should succeed")
defer v.close()
defer vnet.close()
log.Debug("Connecting...")
a0TestConfig := &agentTestConfig{
@@ -407,7 +413,7 @@ func TestConnectivityVNet(t *testing.T) {
a1TestConfig := &agentTestConfig{
urls: []*stun.URI{},
}
ca, cb := pipeWithVNet(v, a0TestConfig, a1TestConfig)
ca, cb := pipeWithVNet(vnet, a0TestConfig, a1TestConfig)
log.Debug("Closing...")
closePipe(t, ca, cb)
@@ -426,10 +432,10 @@ func TestConnectivityVNet(t *testing.T) {
MappingBehavior: vnet.EndpointAddrPortDependent,
FilteringBehavior: vnet.EndpointAddrPortDependent,
}
v, err := buildVNet(natType0, natType1)
vnet, err := buildVNet(natType0, natType1)
require.NoError(t, err, "should succeed")
defer v.close()
defer vnet.close()
log.Debug("Connecting...")
a0TestConfig := &agentTestConfig{
@@ -439,14 +445,15 @@ func TestConnectivityVNet(t *testing.T) {
a1TestConfig := &agentTestConfig{
urls: []*stun.URI{},
}
ca, cb := pipeWithVNet(v, a0TestConfig, a1TestConfig)
ca, cb := pipeWithVNet(vnet, a0TestConfig, a1TestConfig)
log.Debug("Closing...")
closePipe(t, ca, cb)
})
}
// TestDisconnectedToConnected requires that an agent can go to disconnected, and then return to connected successfully
// TestDisconnectedToConnected requires that an agent can go to disconnected,
// and then return to connected successfully.
func TestDisconnectedToConnected(t *testing.T) {
defer test.CheckRoutines(t)()
@@ -546,7 +553,7 @@ func TestDisconnectedToConnected(t *testing.T) {
require.NoError(t, wan.Stop())
}
// Agent.Write should use the best valid pair if a selected pair is not yet available
// Agent.Write should use the best valid pair if a selected pair is not yet available.
func TestWriteUseValidPair(t *testing.T) {
defer test.CheckRoutines(t)()

View File

@@ -29,69 +29,71 @@ var (
ErrPort = errors.New("invalid port")
// ErrLocalUfragInsufficientBits indicates local username fragment insufficient bits are provided.
// Have to be at least 24 bits long
// Have to be at least 24 bits long.
ErrLocalUfragInsufficientBits = errors.New("local username fragment is less than 24 bits long")
// ErrLocalPwdInsufficientBits indicates local password insufficient bits are provided.
// Have to be at least 128 bits long
// Have to be at least 128 bits long.
ErrLocalPwdInsufficientBits = errors.New("local password is less than 128 bits long")
// ErrProtoType indicates an unsupported transport type was provided.
ErrProtoType = errors.New("invalid transport protocol type")
// ErrClosed indicates the agent is closed
// ErrClosed indicates the agent is closed.
ErrClosed = taskloop.ErrClosed
// ErrNoCandidatePairs indicates agent does not have a valid candidate pair
// ErrNoCandidatePairs indicates agent does not have a valid candidate pair.
ErrNoCandidatePairs = errors.New("no candidate pairs available")
// ErrCanceledByCaller indicates agent connection was canceled by the caller
// ErrCanceledByCaller indicates agent connection was canceled by the caller.
ErrCanceledByCaller = errors.New("connecting canceled by caller")
// ErrMultipleStart indicates agent was started twice
// ErrMultipleStart indicates agent was started twice.
ErrMultipleStart = errors.New("attempted to start agent twice")
// ErrRemoteUfragEmpty indicates agent was started with an empty remote ufrag
// ErrRemoteUfragEmpty indicates agent was started with an empty remote ufrag.
ErrRemoteUfragEmpty = errors.New("remote ufrag is empty")
// ErrRemotePwdEmpty indicates agent was started with an empty remote pwd
// ErrRemotePwdEmpty indicates agent was started with an empty remote pwd.
ErrRemotePwdEmpty = errors.New("remote pwd is empty")
// ErrNoOnCandidateHandler indicates agent was started without OnCandidate
// ErrNoOnCandidateHandler indicates agent was started without OnCandidate.
ErrNoOnCandidateHandler = errors.New("no OnCandidate provided")
// ErrMultipleGatherAttempted indicates GatherCandidates has been called multiple times
// ErrMultipleGatherAttempted indicates GatherCandidates has been called multiple times.
ErrMultipleGatherAttempted = errors.New("attempting to gather candidates during gathering state")
// ErrUsernameEmpty indicates agent was give TURN URL with an empty Username
// ErrUsernameEmpty indicates agent was give TURN URL with an empty Username.
ErrUsernameEmpty = errors.New("username is empty")
// ErrPasswordEmpty indicates agent was give TURN URL with an empty Password
// ErrPasswordEmpty indicates agent was give TURN URL with an empty Password.
ErrPasswordEmpty = errors.New("password is empty")
// ErrAddressParseFailed indicates we were unable to parse a candidate address
// ErrAddressParseFailed indicates we were unable to parse a candidate address.
ErrAddressParseFailed = errors.New("failed to parse address")
// ErrLiteUsingNonHostCandidates indicates non host candidates were selected for a lite agent
// ErrLiteUsingNonHostCandidates indicates non host candidates were selected for a lite agent.
ErrLiteUsingNonHostCandidates = errors.New("lite agents must only use host candidates")
// ErrUselessUrlsProvided indicates that one or more URL was provided to the agent but no host
// candidate required them
// candidate required them.
ErrUselessUrlsProvided = errors.New("agent does not need URL with selected candidate types")
// ErrUnsupportedNAT1To1IPCandidateType indicates that the specified NAT1To1IPCandidateType is
// unsupported
// unsupported.
ErrUnsupportedNAT1To1IPCandidateType = errors.New("unsupported 1:1 NAT IP candidate type")
// ErrInvalidNAT1To1IPMapping indicates that the given 1:1 NAT IP mapping is invalid
// ErrInvalidNAT1To1IPMapping indicates that the given 1:1 NAT IP mapping is invalid.
ErrInvalidNAT1To1IPMapping = errors.New("invalid 1:1 NAT IP mapping")
// ErrExternalMappedIPNotFound in NAT1To1IPMapping
// ErrExternalMappedIPNotFound in NAT1To1IPMapping.
ErrExternalMappedIPNotFound = errors.New("external mapped IP not found")
// ErrMulticastDNSWithNAT1To1IPMapping indicates that the mDNS gathering cannot be used along
// with 1:1 NAT IP mapping for host candidate.
ErrMulticastDNSWithNAT1To1IPMapping = errors.New("mDNS gathering cannot be used with 1:1 NAT IP mapping for host candidate")
ErrMulticastDNSWithNAT1To1IPMapping = errors.New(
"mDNS gathering cannot be used with 1:1 NAT IP mapping for host candidate",
)
// ErrIneffectiveNAT1To1IPMappingHost indicates that 1:1 NAT IP mapping for host candidate is
// requested, but the host candidate type is disabled.
@@ -101,10 +103,12 @@ var (
// requested, but the srflx candidate type is disabled.
ErrIneffectiveNAT1To1IPMappingSrflx = errors.New("1:1 NAT IP mapping for srflx candidate ineffective")
// ErrInvalidMulticastDNSHostName indicates an invalid MulticastDNSHostName
ErrInvalidMulticastDNSHostName = errors.New("invalid mDNS HostName, must end with .local and can only contain a single '.'")
// ErrInvalidMulticastDNSHostName indicates an invalid MulticastDNSHostName.
ErrInvalidMulticastDNSHostName = errors.New(
"invalid mDNS HostName, must end with .local and can only contain a single '.'",
)
// ErrRunCanceled indicates a run operation was canceled by its individual done
// ErrRunCanceled indicates a run operation was canceled by its individual done.
ErrRunCanceled = errors.New("run was canceled by done")
// ErrTCPRemoteAddrAlreadyExists indicates we already have the connection with same remote addr.
@@ -113,7 +117,7 @@ var (
// ErrUnknownCandidateTyp indicates that a candidate had a unknown type value.
ErrUnknownCandidateTyp = errors.New("unknown candidate typ")
// ErrDetermineNetworkType indicates that the NetworkType was not able to be parsed
// ErrDetermineNetworkType indicates that the NetworkType was not able to be parsed.
ErrDetermineNetworkType = errors.New("unable to determine networkType")
errAttributeTooShortICECandidate = errors.New("attribute not long enough to be ICE candidate")
@@ -144,5 +148,5 @@ var (
// UDPMuxDefault should not listen on unspecified address, but to keep backward compatibility, don't return error now.
// will be used in the future.
// errListenUnspecified = errors.New("can't listen on unspecified address")
// errListenUnspecified = errors.New("can't listen on unspecified address").
)

View File

@@ -26,7 +26,7 @@ var (
localHTTPPort, remoteHTTPPort int
)
// HTTP Listener to get ICE Credentials from remote Peer
// HTTP Listener to get ICE Credentials from remote Peer.
func remoteAuth(_ http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
panic(err)
@@ -36,7 +36,7 @@ func remoteAuth(_ http.ResponseWriter, r *http.Request) {
remoteAuthChannel <- r.PostForm["pwd"][0]
}
// HTTP Listener to get ICE Candidate from remote Peer
// HTTP Listener to get ICE Candidate from remote Peer.
func remoteCandidate(_ http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
panic(err)

View File

@@ -13,10 +13,13 @@ func validateIPString(ipStr string) (net.IP, bool, error) {
if ip == nil {
return nil, false, ErrInvalidNAT1To1IPMapping
}
return ip, (ip.To4() != nil), nil
}
// ipMapping holds the mapping of local and external IP address for a particular IP family
// ipMapping holds the mapping of local and external IP address
//
// for a particular IP family.
type ipMapping struct {
ipSole net.IP // When non-nil, this is the sole external IP for one local IP assumed
ipMap map[string]net.IP // Local-to-external IP mapping (k: local, v: external)
@@ -75,7 +78,11 @@ type externalIPMapper struct {
candidateType CandidateType
}
func newExternalIPMapper(candidateType CandidateType, ips []string) (*externalIPMapper, error) { //nolint:gocognit
//nolint:gocognit,cyclop
func newExternalIPMapper(
candidateType CandidateType,
ips []string,
) (*externalIPMapper, error) {
if len(ips) == 0 {
return nil, nil //nolint:nilnil
}
@@ -85,7 +92,7 @@ func newExternalIPMapper(candidateType CandidateType, ips []string) (*externalIP
return nil, ErrUnsupportedNAT1To1IPCandidateType
}
m := &externalIPMapper{
mapper := &externalIPMapper{
ipv4Mapping: ipMapping{ipMap: map[string]net.IP{}},
ipv6Mapping: ipMapping{ipMap: map[string]net.IP{}},
candidateType: candidateType,
@@ -101,13 +108,13 @@ func newExternalIPMapper(candidateType CandidateType, ips []string) (*externalIP
if err != nil {
return nil, err
}
if len(ipPair) == 1 {
if len(ipPair) == 1 { //nolint:nestif
if isExtIPv4 {
if err := m.ipv4Mapping.setSoleIP(extIP); err != nil {
if err := mapper.ipv4Mapping.setSoleIP(extIP); err != nil {
return nil, err
}
} else {
if err := m.ipv6Mapping.setSoleIP(extIP); err != nil {
if err := mapper.ipv6Mapping.setSoleIP(extIP); err != nil {
return nil, err
}
}
@@ -121,7 +128,7 @@ func newExternalIPMapper(candidateType CandidateType, ips []string) (*externalIP
return nil, ErrInvalidNAT1To1IPMapping
}
if err := m.ipv4Mapping.addIPMapping(locIP, extIP); err != nil {
if err := mapper.ipv4Mapping.addIPMapping(locIP, extIP); err != nil {
return nil, err
}
} else {
@@ -129,14 +136,14 @@ func newExternalIPMapper(candidateType CandidateType, ips []string) (*externalIP
return nil, ErrInvalidNAT1To1IPMapping
}
if err := m.ipv6Mapping.addIPMapping(locIP, extIP); err != nil {
if err := mapper.ipv6Mapping.addIPMapping(locIP, extIP); err != nil {
return nil, err
}
}
}
}
return m, nil
return mapper, nil
}
func (m *externalIPMapper) findExternalIP(localIPStr string) (net.IP, error) {

View File

@@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestExternalIPMapper(t *testing.T) {
func TestExternalIPMapper(t *testing.T) { //nolint:maintidx
t.Run("validateIPString", func(t *testing.T) {
var ip net.IP
var isIPv4 bool
@@ -31,165 +31,165 @@ func TestExternalIPMapper(t *testing.T) {
})
t.Run("newExternalIPMapper", func(t *testing.T) {
var m *externalIPMapper
var mapper *externalIPMapper
var err error
// ips being nil should succeed but mapper will be nil also
m, err = newExternalIPMapper(CandidateTypeUnspecified, nil)
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, nil)
require.NoError(t, err, "should succeed")
require.Nil(t, m, "should be nil")
require.Nil(t, mapper, "should be nil")
// ips being empty should succeed but mapper will still be nil
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{})
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{})
require.NoError(t, err, "should succeed")
require.Nil(t, m, "should be nil")
require.Nil(t, mapper, "should be nil")
// IPv4 with no explicit local IP, defaults to CandidateTypeHost
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"1.2.3.4",
})
require.NoError(t, err, "should succeed")
require.NotNil(t, m, "should not be nil")
require.Equal(t, CandidateTypeHost, m.candidateType, "should match")
require.NotNil(t, m.ipv4Mapping.ipSole)
require.Nil(t, m.ipv6Mapping.ipSole)
require.Equal(t, 0, len(m.ipv4Mapping.ipMap), "should match")
require.Equal(t, 0, len(m.ipv6Mapping.ipMap), "should match")
require.NotNil(t, mapper, "should not be nil")
require.Equal(t, CandidateTypeHost, mapper.candidateType, "should match")
require.NotNil(t, mapper.ipv4Mapping.ipSole)
require.Nil(t, mapper.ipv6Mapping.ipSole)
require.Equal(t, 0, len(mapper.ipv4Mapping.ipMap), "should match")
require.Equal(t, 0, len(mapper.ipv6Mapping.ipMap), "should match")
// IPv4 with no explicit local IP, using CandidateTypeServerReflexive
m, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
mapper, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
"1.2.3.4",
})
require.NoError(t, err, "should succeed")
require.NotNil(t, m, "should not be nil")
require.Equal(t, CandidateTypeServerReflexive, m.candidateType, "should match")
require.NotNil(t, m.ipv4Mapping.ipSole)
require.Nil(t, m.ipv6Mapping.ipSole)
require.Equal(t, 0, len(m.ipv4Mapping.ipMap), "should match")
require.Equal(t, 0, len(m.ipv6Mapping.ipMap), "should match")
require.NotNil(t, mapper, "should not be nil")
require.Equal(t, CandidateTypeServerReflexive, mapper.candidateType, "should match")
require.NotNil(t, mapper.ipv4Mapping.ipSole)
require.Nil(t, mapper.ipv6Mapping.ipSole)
require.Equal(t, 0, len(mapper.ipv4Mapping.ipMap), "should match")
require.Equal(t, 0, len(mapper.ipv6Mapping.ipMap), "should match")
// IPv4 with no explicit local IP, defaults to CandidateTypeHost
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"2601:4567::5678",
})
require.NoError(t, err, "should succeed")
require.NotNil(t, m, "should not be nil")
require.Equal(t, CandidateTypeHost, m.candidateType, "should match")
require.Nil(t, m.ipv4Mapping.ipSole)
require.NotNil(t, m.ipv6Mapping.ipSole)
require.Equal(t, 0, len(m.ipv4Mapping.ipMap), "should match")
require.Equal(t, 0, len(m.ipv6Mapping.ipMap), "should match")
require.NotNil(t, mapper, "should not be nil")
require.Equal(t, CandidateTypeHost, mapper.candidateType, "should match")
require.Nil(t, mapper.ipv4Mapping.ipSole)
require.NotNil(t, mapper.ipv6Mapping.ipSole)
require.Equal(t, 0, len(mapper.ipv4Mapping.ipMap), "should match")
require.Equal(t, 0, len(mapper.ipv6Mapping.ipMap), "should match")
// IPv4 and IPv6 in the mix
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"1.2.3.4",
"2601:4567::5678",
})
require.NoError(t, err, "should succeed")
require.NotNil(t, m, "should not be nil")
require.Equal(t, CandidateTypeHost, m.candidateType, "should match")
require.NotNil(t, m.ipv4Mapping.ipSole)
require.NotNil(t, m.ipv6Mapping.ipSole)
require.Equal(t, 0, len(m.ipv4Mapping.ipMap), "should match")
require.Equal(t, 0, len(m.ipv6Mapping.ipMap), "should match")
require.NotNil(t, mapper, "should not be nil")
require.Equal(t, CandidateTypeHost, mapper.candidateType, "should match")
require.NotNil(t, mapper.ipv4Mapping.ipSole)
require.NotNil(t, mapper.ipv6Mapping.ipSole)
require.Equal(t, 0, len(mapper.ipv4Mapping.ipMap), "should match")
require.Equal(t, 0, len(mapper.ipv6Mapping.ipMap), "should match")
// Unsupported candidate type - CandidateTypePeerReflexive
m, err = newExternalIPMapper(CandidateTypePeerReflexive, []string{
mapper, err = newExternalIPMapper(CandidateTypePeerReflexive, []string{
"1.2.3.4",
})
require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil")
require.Nil(t, mapper, "should be nil")
// Unsupported candidate type - CandidateTypeRelay
m, err = newExternalIPMapper(CandidateTypePeerReflexive, []string{
mapper, err = newExternalIPMapper(CandidateTypePeerReflexive, []string{
"1.2.3.4",
})
require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil")
require.Nil(t, mapper, "should be nil")
// Cannot duplicate mapping IPv4 family
m, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
mapper, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
"1.2.3.4",
"5.6.7.8",
})
require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil")
require.Nil(t, mapper, "should be nil")
// Cannot duplicate mapping IPv6 family
m, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
mapper, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
"2201::1",
"2201::0002",
})
require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil")
require.Nil(t, mapper, "should be nil")
// Invalide external IP string
m, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
mapper, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
"bad.2.3.4",
})
require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil")
require.Nil(t, mapper, "should be nil")
// Invalide local IP string
m, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
mapper, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
"1.2.3.4/10.0.0.bad",
})
require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil")
require.Nil(t, mapper, "should be nil")
})
t.Run("newExternalIPMapper with explicit local IP", func(t *testing.T) {
var m *externalIPMapper
var mapper *externalIPMapper
var err error
// IPv4 with explicit local IP, defaults to CandidateTypeHost
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"1.2.3.4/10.0.0.1",
})
require.NoError(t, err, "should succeed")
require.NotNil(t, m, "should not be nil")
require.Equal(t, CandidateTypeHost, m.candidateType, "should match")
require.Nil(t, m.ipv4Mapping.ipSole)
require.Nil(t, m.ipv6Mapping.ipSole)
require.Equal(t, 1, len(m.ipv4Mapping.ipMap), "should match")
require.Equal(t, 0, len(m.ipv6Mapping.ipMap), "should match")
require.NotNil(t, mapper, "should not be nil")
require.Equal(t, CandidateTypeHost, mapper.candidateType, "should match")
require.Nil(t, mapper.ipv4Mapping.ipSole)
require.Nil(t, mapper.ipv6Mapping.ipSole)
require.Equal(t, 1, len(mapper.ipv4Mapping.ipMap), "should match")
require.Equal(t, 0, len(mapper.ipv6Mapping.ipMap), "should match")
// Cannot assign two ext IPs for one local IPv4
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"1.2.3.4/10.0.0.1",
"1.2.3.5/10.0.0.1",
})
require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil")
require.Nil(t, mapper, "should be nil")
// Cannot assign two ext IPs for one local IPv6
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"2200::1/fe80::1",
"2200::0002/fe80::1",
})
require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil")
require.Nil(t, mapper, "should be nil")
// Cannot mix different IP family in a pair (1)
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"2200::1/10.0.0.1",
})
require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil")
require.Nil(t, mapper, "should be nil")
// Cannot mix different IP family in a pair (2)
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"1.2.3.4/fe80::1",
})
require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil")
require.Nil(t, mapper, "should be nil")
// Invalid pair
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"1.2.3.4/192.168.0.2/10.0.0.1",
})
require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil")
require.Nil(t, mapper, "should be nil")
})
t.Run("newExternalIPMapper with implicit and explicit local IP", func(t *testing.T) {
@@ -209,100 +209,100 @@ func TestExternalIPMapper(t *testing.T) {
})
t.Run("findExternalIP without explicit local IP", func(t *testing.T) {
var m *externalIPMapper
var mapper *externalIPMapper
var err error
var extIP net.IP
// IPv4 with explicit local IP, defaults to CandidateTypeHost
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"1.2.3.4",
"2200::1",
})
require.NoError(t, err, "should succeed")
require.NotNil(t, m, "should not be nil")
require.NotNil(t, m.ipv4Mapping.ipSole)
require.NotNil(t, m.ipv6Mapping.ipSole)
require.NotNil(t, mapper, "should not be nil")
require.NotNil(t, mapper.ipv4Mapping.ipSole)
require.NotNil(t, mapper.ipv6Mapping.ipSole)
// Find external IPv4
extIP, err = m.findExternalIP("10.0.0.1")
extIP, err = mapper.findExternalIP("10.0.0.1")
require.NoError(t, err, "should succeed")
require.Equal(t, "1.2.3.4", extIP.String(), "should match")
// Find external IPv6
extIP, err = m.findExternalIP("fe80::0001") // Use '0001' instead of '1' on purpose
extIP, err = mapper.findExternalIP("fe80::0001") // Use '0001' instead of '1' on purpose
require.NoError(t, err, "should succeed")
require.Equal(t, "2200::1", extIP.String(), "should match")
// Bad local IP string
_, err = m.findExternalIP("really.bad")
_, err = mapper.findExternalIP("really.bad")
require.Error(t, err, "should fail")
})
t.Run("findExternalIP with explicit local IP", func(t *testing.T) {
var m *externalIPMapper
var mapper *externalIPMapper
var err error
var extIP net.IP
// IPv4 with explicit local IP, defaults to CandidateTypeHost
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"1.2.3.4/10.0.0.1",
"1.2.3.5/10.0.0.2",
"2200::1/fe80::1",
"2200::2/fe80::2",
})
require.NoError(t, err, "should succeed")
require.NotNil(t, m, "should not be nil")
require.NotNil(t, mapper, "should not be nil")
// Find external IPv4
extIP, err = m.findExternalIP("10.0.0.1")
extIP, err = mapper.findExternalIP("10.0.0.1")
require.NoError(t, err, "should succeed")
require.Equal(t, "1.2.3.4", extIP.String(), "should match")
extIP, err = m.findExternalIP("10.0.0.2")
extIP, err = mapper.findExternalIP("10.0.0.2")
require.NoError(t, err, "should succeed")
require.Equal(t, "1.2.3.5", extIP.String(), "should match")
_, err = m.findExternalIP("10.0.0.3")
_, err = mapper.findExternalIP("10.0.0.3")
require.Error(t, err, "should fail")
// Find external IPv6
extIP, err = m.findExternalIP("fe80::0001") // Use '0001' instead of '1' on purpose
extIP, err = mapper.findExternalIP("fe80::0001") // Use '0001' instead of '1' on purpose
require.NoError(t, err, "should succeed")
require.Equal(t, "2200::1", extIP.String(), "should match")
extIP, err = m.findExternalIP("fe80::0002") // Use '0002' instead of '2' on purpose
extIP, err = mapper.findExternalIP("fe80::0002") // Use '0002' instead of '2' on purpose
require.NoError(t, err, "should succeed")
require.Equal(t, "2200::2", extIP.String(), "should match")
_, err = m.findExternalIP("fe80::3")
_, err = mapper.findExternalIP("fe80::3")
require.Error(t, err, "should fail")
// Bad local IP string
_, err = m.findExternalIP("really.bad")
_, err = mapper.findExternalIP("really.bad")
require.Error(t, err, "should fail")
})
t.Run("findExternalIP with empty map", func(t *testing.T) {
var m *externalIPMapper
var mapper *externalIPMapper
var err error
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"1.2.3.4",
})
require.NoError(t, err, "should succeed")
// Attempt to find IPv6 that does not exist in the map
extIP, err := m.findExternalIP("fe80::1")
extIP, err := mapper.findExternalIP("fe80::1")
require.NoError(t, err, "should succeed")
require.Equal(t, "fe80::1", extIP.String(), "should match")
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"2200::1",
})
require.NoError(t, err, "should succeed")
// Attempt to find IPv4 that does not exist in the map
extIP, err = m.findExternalIP("10.0.0.1")
extIP, err = mapper.findExternalIP("10.0.0.1")
require.NoError(t, err, "should succeed")
require.Equal(t, "10.0.0.1", extIP.String(), "should match")
})

119
gather.go
View File

@@ -21,10 +21,11 @@ import (
"github.com/pion/turn/v4"
)
// Close a net.Conn and log if we have a failure
// Close a net.Conn and log if we have a failure.
func closeConnAndLog(c io.Closer, log logging.LeveledLogger, msg string, args ...interface{}) {
if c == nil || (reflect.ValueOf(c).Kind() == reflect.Ptr && reflect.ValueOf(c).IsNil()) {
log.Warnf("Connection is not allocated: "+msg, args...)
return
}
@@ -41,9 +42,11 @@ func (a *Agent) GatherCandidates() error {
if runErr := a.loop.Run(a.loop, func(ctx context.Context) {
if a.gatheringState != GatheringStateNew {
gatherErr = ErrMultipleGatherAttempted
return
} else if a.onCandidateHdlr.Load() == nil {
gatherErr = ErrNoOnCandidateHandler
return
}
@@ -57,13 +60,15 @@ func (a *Agent) GatherCandidates() error {
}); runErr != nil {
return runErr
}
return gatherErr
}
func (a *Agent) gatherCandidates(ctx context.Context, done chan struct{}) {
func (a *Agent) gatherCandidates(ctx context.Context, done chan struct{}) { //nolint:cyclop
defer close(done)
if err := a.setGatheringState(GatheringStateGathering); err != nil { //nolint:contextcheck
a.log.Warnf("Failed to set gatheringState to GatheringStateGathering: %v", err)
return
}
@@ -111,7 +116,8 @@ func (a *Agent) gatherCandidates(ctx context.Context, done chan struct{}) {
}
}
func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []NetworkType) { //nolint:gocognit
//nolint:gocognit,gocyclo,cyclop
func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []NetworkType) {
networks := map[string]struct{}{}
for _, networkType := range networkTypes {
if networkType.IsTCP() {
@@ -132,16 +138,19 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
_, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback)
if err != nil {
a.log.Warnf("Failed to iterate local interfaces, host candidates will not be gathered %s", err)
return
}
for _, addr := range localAddrs {
mappedIP := addr
if a.mDNSMode != MulticastDNSModeQueryAndGather && a.extIPMapper != nil && a.extIPMapper.candidateType == CandidateTypeHost {
if a.mDNSMode != MulticastDNSModeQueryAndGather &&
a.extIPMapper != nil && a.extIPMapper.candidateType == CandidateTypeHost {
if _mappedIP, innerErr := a.extIPMapper.findExternalIP(addr.String()); innerErr == nil {
conv, ok := netip.AddrFromSlice(_mappedIP)
if !ok {
a.log.Warnf("failed to convert mapped external IP to netip.Addr'%s'", addr.String())
continue
}
// we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable
@@ -186,6 +195,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.Is6(), addr.AsSlice())
if err != nil {
a.log.Warnf("Failed to get all TCP connections by ufrag: %s %s %s", network, addr, a.localUfrag)
continue
}
} else {
@@ -194,6 +204,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.Is6(), addr.AsSlice())
if err != nil {
a.log.Warnf("Failed to get TCP connections by ufrag: %s %s %s", network, addr, a.localUfrag)
continue
}
muxConns = []net.PacketConn{conn}
@@ -222,6 +233,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
})
if err != nil {
a.log.Warnf("Failed to listen %s %s", network, addr)
continue
}
@@ -229,6 +241,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
conns = append(conns, connAndPort{conn, udpConn.Port})
} else {
a.log.Warnf("Failed to get port of UDPAddr from ListenUDPInPortRange: %s %s %s", network, addr, a.localUfrag)
continue
}
}
@@ -245,21 +258,38 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
IsLocationTracked: isLocationTracked,
}
c, err := NewCandidateHost(&hostConfig)
candidateHost, err := NewCandidateHost(&hostConfig)
if err != nil {
closeConnAndLog(connAndPort.conn, a.log, "failed to create host candidate: %s %s %d: %v", network, mappedIP, connAndPort.port, err)
closeConnAndLog(
connAndPort.conn,
a.log,
"failed to create host candidate: %s %s %d: %v",
network, mappedIP,
connAndPort.port,
err,
)
continue
}
if a.mDNSMode == MulticastDNSModeQueryAndGather {
if err = c.setIPAddr(addr); err != nil {
closeConnAndLog(connAndPort.conn, a.log, "failed to create host candidate: %s %s %d: %v", network, mappedIP, connAndPort.port, err)
if err = candidateHost.setIPAddr(addr); err != nil {
closeConnAndLog(
connAndPort.conn,
a.log,
"failed to create host candidate: %s %s %d: %v",
network,
mappedIP,
connAndPort.port,
err,
)
continue
}
}
if err := a.addCandidate(ctx, c, connAndPort.conn); err != nil {
if closeErr := c.close(); closeErr != nil {
if err := a.addCandidate(ctx, candidateHost, connAndPort.conn); err != nil {
if closeErr := candidateHost.close(); closeErr != nil {
a.log.Warnf("Failed to close candidate: %v", closeErr)
}
a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err)
@@ -287,10 +317,11 @@ func shouldFilterLocationTracked(candidateIP net.IP) bool {
if !ok {
return false
}
return shouldFilterLocationTrackedIP(addr)
}
func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolint:gocognit
func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolint:gocognit,cyclop
if a.udpMux == nil {
return errUDPMuxDisabled
}
@@ -317,6 +348,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin
mappedIP, err := a.extIPMapper.findExternalIP(candidateIP.String())
if err != nil {
a.log.Warnf("1:1 NAT mapping is enabled but no external IP is found for %s", candidateIP.String())
continue
}
@@ -359,6 +391,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin
c, err := NewCandidateHost(&hostConfig)
if err != nil {
closeConnAndLog(conn, a.log, "failed to create host mux candidate: %s %d: %v", candidateIP, udpAddr.Port, err)
continue
}
@@ -368,6 +401,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin
}
closeConnAndLog(conn, a.log, "failed to add candidate: %s %d: %v", candidateIP, udpAddr.Port, err)
continue
}
@@ -391,26 +425,37 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes []
go func() {
defer wg.Done()
conn, err := listenUDPInPortRange(a.net, a.log, int(a.portMax), int(a.portMin), network, &net.UDPAddr{IP: nil, Port: 0})
conn, err := listenUDPInPortRange(
a.net,
a.log,
int(a.portMax),
int(a.portMin),
network,
&net.UDPAddr{IP: nil, Port: 0},
)
if err != nil {
a.log.Warnf("Failed to listen %s: %v", network, err)
return
}
lAddr, ok := conn.LocalAddr().(*net.UDPAddr)
if !ok {
closeConnAndLog(conn, a.log, "1:1 NAT mapping is enabled but LocalAddr is not a UDPAddr")
return
}
mappedIP, err := a.extIPMapper.findExternalIP(lAddr.IP.String())
if err != nil {
closeConnAndLog(conn, a.log, "1:1 NAT mapping is enabled but no external IP is found for %s", lAddr.IP.String())
return
}
if shouldFilterLocationTracked(mappedIP) {
closeConnAndLog(conn, a.log, "external IP is somehow filtered for location tracking reasons %s", mappedIP)
return
}
@@ -429,6 +474,7 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes []
mappedIP.String(),
lAddr.Port,
err)
return
}
@@ -442,7 +488,8 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes []
}
}
func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.URI, networkTypes []NetworkType) { //nolint:gocognit
//nolint:gocognit,cyclop
func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.URI, networkTypes []NetworkType) {
var wg sync.WaitGroup
defer wg.Wait()
@@ -456,6 +503,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.UR
udpAddr, ok := listenAddr.(*net.UDPAddr)
if !ok {
a.log.Warn("Failed to cast udpMuxSrflx listen address to UDPAddr")
continue
}
wg.Add(1)
@@ -466,23 +514,27 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.UR
serverAddr, err := a.net.ResolveUDPAddr(network, hostPort)
if err != nil {
a.log.Debugf("Failed to resolve STUN host: %s %s: %v", network, hostPort, err)
return
}
if shouldFilterLocationTracked(serverAddr.IP) {
a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort)
return
}
xorAddr, err := a.udpMuxSrflx.GetXORMappedAddr(serverAddr, a.stunGatherTimeout)
if err != nil {
a.log.Warnf("Failed get server reflexive address %s %s: %v", network, url, err)
return
}
conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String(), localAddr)
if err != nil {
a.log.Warnf("Failed to find connection in UDPMuxSrflx %s %s: %v", network, url, err)
return
}
@@ -500,6 +552,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.UR
c, err := NewCandidateServerReflexive(&srflxConfig)
if err != nil {
closeConnAndLog(conn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, ip, port, err)
return
}
@@ -515,7 +568,8 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.UR
}
}
func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, networkTypes []NetworkType) { //nolint:gocognit
//nolint:cyclop,gocognit
func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, networkTypes []NetworkType) {
var wg sync.WaitGroup
defer wg.Wait()
@@ -533,17 +587,27 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
serverAddr, err := a.net.ResolveUDPAddr(network, hostPort)
if err != nil {
a.log.Debugf("Failed to resolve STUN host: %s %s: %v", network, hostPort, err)
return
}
if shouldFilterLocationTracked(serverAddr.IP) {
a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort)
return
}
conn, err := listenUDPInPortRange(a.net, a.log, int(a.portMax), int(a.portMin), network, &net.UDPAddr{IP: nil, Port: 0})
conn, err := listenUDPInPortRange(
a.net,
a.log,
int(a.portMax),
int(a.portMin),
network,
&net.UDPAddr{IP: nil, Port: 0},
)
if err != nil {
closeConnAndLog(conn, a.log, "failed to listen for %s: %v", serverAddr.String(), err)
return
}
// If the agent closes midway through the connection
@@ -562,6 +626,7 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
xorAddr, err := stunx.GetXORMappedAddr(conn, serverAddr, a.stunGatherTimeout)
if err != nil {
closeConnAndLog(conn, a.log, "failed to get server reflexive address %s %s: %v", network, url, err)
return
}
@@ -580,6 +645,7 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
c, err := NewCandidateServerReflexive(&srflxConfig)
if err != nil {
closeConnAndLog(conn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, ip, port, err)
return
}
@@ -594,7 +660,8 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
}
}
func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { //nolint:gocognit
//nolint:maintidx,gocognit,gocyclo,cyclop
func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) {
var wg sync.WaitGroup
defer wg.Wait()
@@ -605,9 +672,11 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
continue
case urls[i].Username == "":
a.log.Errorf("Failed to gather relay candidates: %v", ErrUsernameEmpty)
return
case urls[i].Password == "":
a.log.Errorf("Failed to gather relay candidates: %v", ErrPasswordEmpty)
return
}
@@ -627,6 +696,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
case url.Proto == stun.ProtoTypeUDP && url.Scheme == stun.SchemeTypeTURN:
if locConn, err = a.net.ListenPacket(network, "0.0.0.0:0"); err != nil {
a.log.Warnf("Failed to listen %s: %v", network, err)
return
}
@@ -638,6 +708,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
conn, connectErr := a.proxyDialer.Dial(NetworkTypeTCP4.String(), turnServerAddr)
if connectErr != nil {
a.log.Warnf("Failed to dial TCP address %s via proxy dialer: %v", turnServerAddr, connectErr)
return
}
@@ -654,12 +725,14 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
tcpAddr, connectErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), turnServerAddr)
if connectErr != nil {
a.log.Warnf("Failed to resolve TCP address %s: %v", turnServerAddr, connectErr)
return
}
conn, connectErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr)
if connectErr != nil {
a.log.Warnf("Failed to dial TCP address %s: %v", turnServerAddr, connectErr)
return
}
@@ -671,12 +744,14 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
udpAddr, connectErr := a.net.ResolveUDPAddr(network, turnServerAddr)
if connectErr != nil {
a.log.Warnf("Failed to resolve UDP address %s: %v", turnServerAddr, connectErr)
return
}
udpConn, dialErr := a.net.DialUDP("udp", nil, udpAddr)
if dialErr != nil {
a.log.Warnf("Failed to dial DTLS address %s: %v", turnServerAddr, dialErr)
return
}
@@ -686,11 +761,13 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
})
if connectErr != nil {
a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr)
return
}
if connectErr = conn.HandshakeContext(ctx); connectErr != nil {
a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr)
return
}
@@ -702,12 +779,14 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
tcpAddr, resolvErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), turnServerAddr)
if resolvErr != nil {
a.log.Warnf("Failed to resolve relay address %s: %v", turnServerAddr, resolvErr)
return
}
tcpConn, dialErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr)
if dialErr != nil {
a.log.Warnf("Failed to connect to relay: %v", dialErr)
return
}
@@ -721,6 +800,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
a.log.Errorf("Failed to close relay connection: %v", closeErr)
}
a.log.Warnf("Failed to connect to relay: %v", hsErr)
return
}
@@ -730,6 +810,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
locConn = turn.NewSTUNConn(conn)
default:
a.log.Warnf("Unable to handle URL in gatherCandidatesRelay %v", url)
return
}
@@ -743,12 +824,14 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
})
if err != nil {
closeConnAndLog(locConn, a.log, "failed to create new TURN client %s %s", turnServerAddr, err)
return
}
if err = client.Listen(); err != nil {
client.Close()
closeConnAndLog(locConn, a.log, "failed to listen on TURN client %s %s", turnServerAddr, err)
return
}
@@ -756,6 +839,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
if err != nil {
client.Close()
closeConnAndLog(locConn, a.log, "failed to allocate on TURN client %s %s", turnServerAddr, err)
return
}
@@ -763,6 +847,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
if shouldFilterLocationTracked(rAddr.IP) {
a.log.Warnf("TURN address %s is somehow filtered for location tracking reasons", rAddr.IP)
return
}
@@ -776,6 +861,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
RelayProtocol: relayProtocol,
OnClose: func() error {
client.Close()
return locConn.Close()
},
}
@@ -790,6 +876,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
client.Close()
closeConnAndLog(locConn, a.log, "failed to create relay candidate: %s %s: %v", network, rAddr.String(), err)
return
}

View File

@@ -31,26 +31,32 @@ import (
)
func TestListenUDP(t *testing.T) {
a, err := NewAgent(&AgentConfig{})
agent, err := NewAgent(&AgentConfig{})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
_, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
_, localAddrs, err := localInterfaces(
agent.net,
agent.interfaceFilter,
agent.ipFilter,
[]NetworkType{NetworkTypeUDP4},
false,
)
require.NotEqual(t, len(localAddrs), 0, "localInterfaces found no interfaces, unable to test")
require.NoError(t, err)
ip := localAddrs[0].AsSlice()
conn, err := listenUDPInPortRange(a.net, a.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0})
conn, err := listenUDPInPortRange(agent.net, agent.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0})
require.NoError(t, err, "listenUDP error with no port restriction")
require.NotNil(t, conn, "listenUDP error with no port restriction return a nil conn")
_, err = listenUDPInPortRange(a.net, a.log, 4999, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
_, err = listenUDPInPortRange(agent.net, agent.log, 4999, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
require.Equal(t, err, ErrPort, "listenUDP with invalid port range did not return ErrPort")
conn, err = listenUDPInPortRange(a.net, a.log, 5000, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
conn, err = listenUDPInPortRange(agent.net, agent.log, 5000, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
require.NoError(t, err, "listenUDP error with no port restriction")
require.NotNil(t, conn, "listenUDP error with no port restriction return a nil conn")
@@ -64,7 +70,7 @@ func TestListenUDP(t *testing.T) {
result := make([]int, 0, total)
portRange := make([]int, 0, total)
for i := 0; i < total; i++ {
conn, err = listenUDPInPortRange(a.net, a.log, portMax, portMin, udp, &net.UDPAddr{IP: ip, Port: 0})
conn, err = listenUDPInPortRange(agent.net, agent.log, portMax, portMin, udp, &net.UDPAddr{IP: ip, Port: 0})
require.NoError(t, err, "listenUDP error with no port restriction")
require.NotNil(t, conn, "listenUDP error with no port restriction return a nil conn")
@@ -85,7 +91,7 @@ func TestListenUDP(t *testing.T) {
if !reflect.DeepEqual(result, portRange) {
t.Fatalf("listenUDP with port restriction [%d, %d], got:%v, want:%v", portMin, portMax, result, portRange)
}
_, err = listenUDPInPortRange(a.net, a.log, portMax, portMin, udp, &net.UDPAddr{IP: ip, Port: 0})
_, err = listenUDPInPortRange(agent.net, agent.log, portMax, portMin, udp, &net.UDPAddr{IP: ip, Port: 0})
require.Equal(t, err, ErrPort, "listenUDP with port restriction [%d, %d], did not return ErrPort", portMin, portMax)
}
@@ -94,23 +100,23 @@ func TestGatherConcurrency(t *testing.T) {
defer test.TimeOut(time.Second * 30).Stop()
a, err := NewAgent(&AgentConfig{
agent, err := NewAgent(&AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
IncludeLoopback: true,
})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
require.NoError(t, a.OnCandidate(func(Candidate) {
require.NoError(t, agent.OnCandidate(func(Candidate) {
candidateGatheredFunc()
}))
// Testing for panic
for i := 0; i < 10; i++ {
_ = a.GatherCandidates()
_ = agent.GatherCandidates()
}
<-candidateGathered.Done()
@@ -194,26 +200,27 @@ func TestLoopbackCandidate(t *testing.T) {
for _, tc := range testCases {
tcase := tc
t.Run(tcase.name, func(t *testing.T) {
a, err := NewAgent(tc.agentConfig)
agent, err := NewAgent(tc.agentConfig)
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
var loopback int32
require.NoError(t, a.OnCandidate(func(c Candidate) {
require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c != nil {
if net.ParseIP(c.Address()).IsLoopback() {
atomic.StoreInt32(&loopback, 1)
}
} else {
candidateGatheredFunc()
return
}
t.Log(c.NetworkType(), c.Priority(), c)
}))
require.NoError(t, a.GatherCandidates())
require.NoError(t, agent.GatherCandidates())
<-candidateGathered.Done()
@@ -226,7 +233,7 @@ func TestLoopbackCandidate(t *testing.T) {
require.NoError(t, muxUnspecDefault.Close())
}
// Assert that STUN gathering is done concurrently
// Assert that STUN gathering is done concurrently.
func TestSTUNConcurrency(t *testing.T) {
defer test.CheckRoutines(t)()
@@ -273,7 +280,7 @@ func TestSTUNConcurrency(t *testing.T) {
_ = listener.Close()
}()
a, err := NewAgent(&AgentConfig{
agent, err := NewAgent(&AgentConfig{
NetworkTypes: supportedNetworkTypes(),
Urls: urls,
CandidateTypes: []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive},
@@ -287,29 +294,36 @@ func TestSTUNConcurrency(t *testing.T) {
})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
require.NoError(t, a.OnCandidate(func(c Candidate) {
require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c == nil {
candidateGatheredFunc()
return
}
t.Log(c.NetworkType(), c.Priority(), c)
}))
require.NoError(t, a.GatherCandidates())
require.NoError(t, agent.GatherCandidates())
<-candidateGathered.Done()
}
// Assert that TURN gathering is done concurrently
// Assert that TURN gathering is done concurrently.
func TestTURNConcurrency(t *testing.T) {
defer test.CheckRoutines(t)()
defer test.TimeOut(time.Second * 30).Stop()
runTest := func(protocol stun.ProtoType, scheme stun.SchemeType, packetConn net.PacketConn, listener net.Listener, serverPort int) {
runTest := func(
protocol stun.ProtoType,
scheme stun.SchemeType,
packetConn net.PacketConn,
listener net.Listener,
serverPort int,
) {
packetConnConfigs := []turn.PacketConnConfig{}
if packetConn != nil {
packetConnConfigs = append(packetConnConfigs, turn.PacketConnConfig{
@@ -357,7 +371,7 @@ func TestTURNConcurrency(t *testing.T) {
Port: serverPort,
})
a, err := NewAgent(&AgentConfig{
agent, err := NewAgent(&AgentConfig{
CandidateTypes: []CandidateType{CandidateTypeRelay},
InsecureSkipVerify: true,
NetworkTypes: supportedNetworkTypes(),
@@ -365,16 +379,16 @@ func TestTURNConcurrency(t *testing.T) {
})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
require.NoError(t, a.OnCandidate(func(c Candidate) {
require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c != nil {
candidateGatheredFunc()
}
}))
require.NoError(t, a.GatherCandidates())
require.NoError(t, agent.GatherCandidates())
<-candidateGathered.Done()
}
@@ -413,16 +427,20 @@ func TestTURNConcurrency(t *testing.T) {
require.NoError(t, genErr)
serverPort := randomPort(t)
serverListener, err := dtls.Listen("udp", &net.UDPAddr{IP: net.ParseIP(localhostIPStr), Port: serverPort}, &dtls.Config{
Certificates: []tls.Certificate{certificate},
})
serverListener, err := dtls.Listen(
"udp",
&net.UDPAddr{IP: net.ParseIP(localhostIPStr), Port: serverPort},
&dtls.Config{
Certificates: []tls.Certificate{certificate},
},
)
require.NoError(t, err)
runTest(stun.ProtoTypeUDP, stun.SchemeTypeTURNS, nil, serverListener, serverPort)
})
}
// Assert that STUN and TURN gathering are done concurrently
// Assert that STUN and TURN gathering are done concurrently.
func TestSTUNTURNConcurrency(t *testing.T) {
defer test.CheckRoutines(t)()
@@ -464,25 +482,26 @@ func TestSTUNTURNConcurrency(t *testing.T) {
Password: "password",
})
a, err := NewAgent(&AgentConfig{
agent, err := NewAgent(&AgentConfig{
NetworkTypes: supportedNetworkTypes(),
Urls: urls,
CandidateTypes: []CandidateType{CandidateTypeServerReflexive, CandidateTypeRelay},
})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
{
gatherLim := test.TimeOut(time.Second * 3) // As TURN and STUN should be checked in parallel, this should complete before the default STUN timeout (5s)
// As TURN and STUN should be checked in parallel, this should complete before the default STUN timeout (5s)
gatherLim := test.TimeOut(time.Second * 3)
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
require.NoError(t, a.OnCandidate(func(c Candidate) {
require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c != nil {
candidateGatheredFunc()
}
}))
require.NoError(t, a.GatherCandidates())
require.NoError(t, agent.GatherCandidates())
<-candidateGathered.Done()
gatherLim.Stop()
@@ -528,24 +547,24 @@ func TestTURNSrflx(t *testing.T) {
Password: "password",
}}
a, err := NewAgent(&AgentConfig{
agent, err := NewAgent(&AgentConfig{
NetworkTypes: supportedNetworkTypes(),
Urls: urls,
CandidateTypes: []CandidateType{CandidateTypeServerReflexive, CandidateTypeRelay},
})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
require.NoError(t, a.OnCandidate(func(c Candidate) {
require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c != nil && c.Type() == CandidateTypeServerReflexive {
candidateGatheredFunc()
}
}))
require.NoError(t, a.GatherCandidates())
require.NoError(t, agent.GatherCandidates())
<-candidateGathered.Done()
}
@@ -580,6 +599,7 @@ func (m *mockConn) SetWriteDeadline(time.Time) error { return io.EOF }
func (m *mockProxy) Dial(string, string) (net.Conn, error) {
m.proxyWasDialed()
return &mockConn{}, nil
}
@@ -599,7 +619,7 @@ func TestTURNProxyDialer(t *testing.T) {
proxyDialer, err := proxy.FromURL(tcpProxyURI, proxy.Direct)
require.NoError(t, err)
a, err := NewAgent(&AgentConfig{
agent, err := NewAgent(&AgentConfig{
CandidateTypes: []CandidateType{CandidateTypeRelay},
NetworkTypes: supportedNetworkTypes(),
Urls: []*stun.URI{
@@ -616,17 +636,17 @@ func TestTURNProxyDialer(t *testing.T) {
})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
candidateGatherFinish, candidateGatherFinishFunc := context.WithCancel(context.Background())
require.NoError(t, a.OnCandidate(func(c Candidate) {
require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c == nil {
candidateGatherFinishFunc()
}
}))
require.NoError(t, a.GatherCandidates())
require.NoError(t, agent.GatherCandidates())
<-candidateGatherFinish.Done()
<-proxyWasDialed.Done()
}
@@ -651,31 +671,31 @@ func TestUDPMuxDefaultWithNAT1To1IPsUsage(t *testing.T) {
_ = mux.Close()
}()
a, err := NewAgent(&AgentConfig{
agent, err := NewAgent(&AgentConfig{
NAT1To1IPs: []string{"1.2.3.4"},
NAT1To1IPCandidateType: CandidateTypeHost,
UDPMux: mux,
})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
gatherCandidateDone := make(chan struct{})
require.NoError(t, a.OnCandidate(func(c Candidate) {
require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c == nil {
close(gatherCandidateDone)
} else {
require.Equal(t, "1.2.3.4", c.Address())
}
}))
require.NoError(t, a.GatherCandidates())
require.NoError(t, agent.GatherCandidates())
<-gatherCandidateDone
require.NotEqual(t, 0, len(mux.connsIPv4))
}
// Assert that candidates are given for each mux in a MultiUDPMux
// Assert that candidates are given for each mux in a MultiUDPMux.
func TestMultiUDPMuxUsage(t *testing.T) {
defer test.CheckRoutines(t)()
@@ -700,25 +720,26 @@ func TestMultiUDPMuxUsage(t *testing.T) {
}()
}
a, err := NewAgent(&AgentConfig{
agent, err := NewAgent(&AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
CandidateTypes: []CandidateType{CandidateTypeHost},
UDPMux: NewMultiUDPMuxDefault(udpMuxInstances...),
})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
candidateCh := make(chan Candidate)
require.NoError(t, a.OnCandidate(func(c Candidate) {
require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c == nil {
close(candidateCh)
return
}
candidateCh <- c
}))
require.NoError(t, a.GatherCandidates())
require.NoError(t, agent.GatherCandidates())
portFound := make(map[int]bool)
for c := range candidateCh {
@@ -731,7 +752,7 @@ func TestMultiUDPMuxUsage(t *testing.T) {
}
}
// Assert that candidates are given for each mux in a MultiTCPMux
// Assert that candidates are given for each mux in a MultiTCPMux.
func TestMultiTCPMuxUsage(t *testing.T) {
defer test.CheckRoutines(t)()
@@ -757,25 +778,26 @@ func TestMultiTCPMuxUsage(t *testing.T) {
}))
}
a, err := NewAgent(&AgentConfig{
agent, err := NewAgent(&AgentConfig{
NetworkTypes: supportedNetworkTypes(),
CandidateTypes: []CandidateType{CandidateTypeHost},
TCPMux: NewMultiTCPMuxDefault(tcpMuxInstances...),
})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
candidateCh := make(chan Candidate)
require.NoError(t, a.OnCandidate(func(c Candidate) {
require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c == nil {
close(candidateCh)
return
}
candidateCh <- c
}))
require.NoError(t, a.GatherCandidates())
require.NoError(t, agent.GatherCandidates())
portFound := make(map[int]bool)
for c := range candidateCh {
@@ -790,7 +812,7 @@ func TestMultiTCPMuxUsage(t *testing.T) {
}
}
// Assert that UniversalUDPMux is used while gathering when configured in the Agent
// Assert that UniversalUDPMux is used while gathering when configured in the Agent.
func TestUniversalUDPMuxUsage(t *testing.T) {
defer test.CheckRoutines(t)()
@@ -816,7 +838,7 @@ func TestUniversalUDPMuxUsage(t *testing.T) {
})
}
a, err := NewAgent(&AgentConfig{
agent, err := NewAgent(&AgentConfig{
NetworkTypes: supportedNetworkTypes(),
Urls: urls,
CandidateTypes: []CandidateType{CandidateTypeServerReflexive},
@@ -828,26 +850,32 @@ func TestUniversalUDPMuxUsage(t *testing.T) {
if aClosed {
return
}
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
require.NoError(t, a.OnCandidate(func(c Candidate) {
require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c == nil {
candidateGatheredFunc()
return
}
t.Log(c.NetworkType(), c.Priority(), c)
}))
require.NoError(t, a.GatherCandidates())
require.NoError(t, agent.GatherCandidates())
<-candidateGathered.Done()
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
aClosed = true
// Twice because of 2 STUN servers configured
require.Equal(t, numSTUNS, udpMuxSrflx.getXORMappedAddrUsedTimes, "expected times that GetXORMappedAddr should be called")
require.Equal(
t,
numSTUNS,
udpMuxSrflx.getXORMappedAddrUsedTimes,
"expected times that GetXORMappedAddr should be called",
)
// One for Restart() when agent has been initialized and one time when Close() the agent
require.Equal(t, 2, udpMuxSrflx.removeConnByUfragTimes, "expected times that RemoveConnByUfrag should be called")
// Twice because of 2 STUN servers configured
@@ -871,6 +899,7 @@ func (m *universalUDPMuxMock) GetConnForURL(string, string, net.Addr) (net.Packe
m.mu.Lock()
defer m.mu.Unlock()
m.getConnForURLTimes++
return m.conn, nil
}
@@ -878,6 +907,7 @@ func (m *universalUDPMuxMock) GetXORMappedAddr(net.Addr, time.Duration) (*stun.X
m.mu.Lock()
defer m.mu.Unlock()
m.getXORMappedAddrUsedTimes++
return &stun.XORMappedAddress{IP: net.IP{100, 64, 0, 1}, Port: 77878}, nil
}

View File

@@ -20,7 +20,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestVNetGather(t *testing.T) {
func TestVNetGather(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)()
loggerFactory := logging.NewDefaultLoggerFactory()
@@ -51,7 +51,7 @@ func TestVNetGather(t *testing.T) {
t.Fatalf("Failed to parse CIDR: %s", err)
}
r, err := vnet.NewRouter(&vnet.RouterConfig{
router, err := vnet.NewRouter(&vnet.RouterConfig{
CIDR: cider,
LoggerFactory: loggerFactory,
})
@@ -64,7 +64,7 @@ func TestVNetGather(t *testing.T) {
t.Fatalf("Failed to create a Net: %s", err)
}
err = r.AddNet(nw)
err = router.AddNet(nw)
if err != nil {
t.Fatalf("Failed to add a Net to the router: %s", err)
}
@@ -94,7 +94,7 @@ func TestVNetGather(t *testing.T) {
})
t.Run("listenUDP", func(t *testing.T) {
r, err := vnet.NewRouter(&vnet.RouterConfig{
router, err := vnet.NewRouter(&vnet.RouterConfig{
CIDR: "1.2.3.0/24",
LoggerFactory: loggerFactory,
})
@@ -107,20 +107,26 @@ func TestVNetGather(t *testing.T) {
t.Fatalf("Failed to create a Net: %s", err)
}
err = r.AddNet(nw)
err = router.AddNet(nw)
if err != nil {
t.Fatalf("Failed to add a Net to the router: %s", err)
}
a, err := NewAgent(&AgentConfig{Net: nw})
agent, err := NewAgent(&AgentConfig{Net: nw})
if err != nil {
t.Fatalf("Failed to create agent: %s", err)
}
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
_, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
_, localAddrs, err := localInterfaces(
agent.net,
agent.interfaceFilter,
agent.ipFilter,
[]NetworkType{NetworkTypeUDP4},
false,
)
if len(localAddrs) == 0 {
t.Fatal("localInterfaces found no interfaces, unable to test")
}
@@ -128,7 +134,7 @@ func TestVNetGather(t *testing.T) {
ip := localAddrs[0].AsSlice()
conn, err := listenUDPInPortRange(a.net, a.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0})
conn, err := listenUDPInPortRange(agent.net, agent.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0})
if err != nil {
t.Fatalf("listenUDP error with no port restriction %v", err)
} else if conn == nil {
@@ -139,12 +145,12 @@ func TestVNetGather(t *testing.T) {
t.Fatalf("failed to close conn")
}
_, err = listenUDPInPortRange(a.net, a.log, 4999, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
_, err = listenUDPInPortRange(agent.net, agent.log, 4999, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
if !errors.Is(err, ErrPort) {
t.Fatal("listenUDP with invalid port range did not return ErrPort")
}
conn, err = listenUDPInPortRange(a.net, a.log, 5000, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
conn, err = listenUDPInPortRange(agent.net, agent.log, 5000, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
if err != nil {
t.Fatalf("listenUDP error with no port restriction %v", err)
} else if conn == nil {
@@ -163,7 +169,7 @@ func TestVNetGather(t *testing.T) {
})
}
func TestVNetGatherWithNAT1To1(t *testing.T) {
func TestVNetGatherWithNAT1To1(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)()
loggerFactory := logging.NewDefaultLoggerFactory()
@@ -206,7 +212,7 @@ func TestVNetGatherWithNAT1To1(t *testing.T) {
err = lan.AddNet(nw)
require.NoError(t, err, "should succeed")
a, err := NewAgent(&AgentConfig{
agent, err := NewAgent(&AgentConfig{
NetworkTypes: []NetworkType{
NetworkTypeUDP4,
},
@@ -215,25 +221,25 @@ func TestVNetGatherWithNAT1To1(t *testing.T) {
})
require.NoError(t, err, "should succeed")
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
done := make(chan struct{})
err = a.OnCandidate(func(c Candidate) {
err = agent.OnCandidate(func(c Candidate) {
if c == nil {
close(done)
}
})
require.NoError(t, err, "should succeed")
err = a.GatherCandidates()
err = agent.GatherCandidates()
require.NoError(t, err, "should succeed")
log.Debug("Wait until gathering is complete...")
<-done
log.Debug("Gathering is done")
candidates, err := a.GetLocalCandidates()
candidates, err := agent.GetLocalCandidates()
require.NoError(t, err, "should succeed")
if len(candidates) != 2 {
@@ -248,7 +254,7 @@ func TestVNetGatherWithNAT1To1(t *testing.T) {
}
}
if candidates[0].Address() == externalIP0 {
if candidates[0].Address() == externalIP0 { //nolint:nestif
if candidates[1].Address() != externalIP1 {
t.Fatalf("Unexpected candidate IP: %s", candidates[1].Address())
}
@@ -305,7 +311,7 @@ func TestVNetGatherWithNAT1To1(t *testing.T) {
err = lan.AddNet(nw)
require.NoError(t, err, "should succeed")
a, err := NewAgent(&AgentConfig{
agent, err := NewAgent(&AgentConfig{
NetworkTypes: []NetworkType{
NetworkTypeUDP4,
},
@@ -317,25 +323,25 @@ func TestVNetGatherWithNAT1To1(t *testing.T) {
})
require.NoError(t, err, "should succeed")
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
done := make(chan struct{})
err = a.OnCandidate(func(c Candidate) {
err = agent.OnCandidate(func(c Candidate) {
if c == nil {
close(done)
}
})
require.NoError(t, err, "should succeed")
err = a.GatherCandidates()
err = agent.GatherCandidates()
require.NoError(t, err, "should succeed")
log.Debug("Wait until gathering is complete...")
<-done
log.Debug("Gathering is done")
candidates, err := a.GetLocalCandidates()
candidates, err := agent.GetLocalCandidates()
require.NoError(t, err, "should succeed")
if len(candidates) != 2 {
@@ -367,7 +373,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
defer test.CheckRoutines(t)()
loggerFactory := logging.NewDefaultLoggerFactory()
r, err := vnet.NewRouter(&vnet.RouterConfig{
router, err := vnet.NewRouter(&vnet.RouterConfig{
CIDR: "1.2.3.0/24",
LoggerFactory: loggerFactory,
})
@@ -380,24 +386,31 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
t.Fatalf("Failed to create a Net: %s", err)
}
if err = r.AddNet(nw); err != nil {
if err = router.AddNet(nw); err != nil {
t.Fatalf("Failed to add a Net to the router: %s", err)
}
t.Run("InterfaceFilter should exclude the interface", func(t *testing.T) {
a, err := NewAgent(&AgentConfig{
agent, err := NewAgent(&AgentConfig{
Net: nw,
InterfaceFilter: func(interfaceName string) (keep bool) {
require.Equal(t, "eth0", interfaceName)
return false
},
})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
_, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
_, localIPs, err := localInterfaces(
agent.net,
agent.interfaceFilter,
agent.ipFilter,
[]NetworkType{NetworkTypeUDP4},
false,
)
require.NoError(t, err)
if len(localIPs) != 0 {
@@ -406,19 +419,26 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
})
t.Run("IPFilter should exclude the IP", func(t *testing.T) {
a, err := NewAgent(&AgentConfig{
agent, err := NewAgent(&AgentConfig{
Net: nw,
IPFilter: func(ip net.IP) (keep bool) {
require.Equal(t, net.IP{1, 2, 3, 1}, ip)
return false
},
})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
_, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
_, localIPs, err := localInterfaces(
agent.net,
agent.interfaceFilter,
agent.ipFilter,
[]NetworkType{NetworkTypeUDP4},
false,
)
require.NoError(t, err)
if len(localIPs) != 0 {
@@ -427,19 +447,26 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
})
t.Run("InterfaceFilter should not exclude the interface", func(t *testing.T) {
a, err := NewAgent(&AgentConfig{
agent, err := NewAgent(&AgentConfig{
Net: nw,
InterfaceFilter: func(interfaceName string) (keep bool) {
require.Equal(t, "eth0", interfaceName)
return true
},
})
require.NoError(t, err)
defer func() {
require.NoError(t, a.Close())
require.NoError(t, agent.Close())
}()
_, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
_, localIPs, err := localInterfaces(
agent.net,
agent.interfaceFilter,
agent.ipFilter,
[]NetworkType{NetworkTypeUDP4},
false,
)
require.NoError(t, err)
if len(localIPs) == 0 {

30
ice.go
View File

@@ -3,33 +3,33 @@
package ice
// ConnectionState is an enum showing the state of a ICE Connection
// ConnectionState is an enum showing the state of a ICE Connection.
type ConnectionState int
// List of supported States
// List of supported States.
const (
// ConnectionStateUnknown represents an unknown state
// ConnectionStateUnknown represents an unknown state.
ConnectionStateUnknown ConnectionState = iota
// ConnectionStateNew ICE agent is gathering addresses
// ConnectionStateNew ICE agent is gathering addresses.
ConnectionStateNew
// ConnectionStateChecking ICE agent has been given local and remote candidates, and is attempting to find a match
// ConnectionStateChecking ICE agent has been given local and remote candidates, and is attempting to find a match.
ConnectionStateChecking
// ConnectionStateConnected ICE agent has a pairing, but is still checking other pairs
// ConnectionStateConnected ICE agent has a pairing, but is still checking other pairs.
ConnectionStateConnected
// ConnectionStateCompleted ICE agent has finished
// ConnectionStateCompleted ICE agent has finished.
ConnectionStateCompleted
// ConnectionStateFailed ICE agent never could successfully connect
// ConnectionStateFailed ICE agent never could successfully connect.
ConnectionStateFailed
// ConnectionStateDisconnected ICE agent connected successfully, but has entered a failed state
// ConnectionStateDisconnected ICE agent connected successfully, but has entered a failed state.
ConnectionStateDisconnected
// ConnectionStateClosed ICE agent has finished and is no longer handling requests
// ConnectionStateClosed ICE agent has finished and is no longer handling requests.
ConnectionStateClosed
)
@@ -54,20 +54,20 @@ func (c ConnectionState) String() string {
}
}
// GatheringState describes the state of the candidate gathering process
// GatheringState describes the state of the candidate gathering process.
type GatheringState int
const (
// GatheringStateUnknown represents an unknown state
// GatheringStateUnknown represents an unknown state.
GatheringStateUnknown GatheringState = iota
// GatheringStateNew indicates candidate gathering is not yet started
// GatheringStateNew indicates candidate gathering is not yet started.
GatheringStateNew
// GatheringStateGathering indicates candidate gathering is ongoing
// GatheringStateGathering indicates candidate gathering is ongoing.
GatheringStateGathering
// GatheringStateComplete indicates candidate gathering has been completed
// GatheringStateComplete indicates candidate gathering has been completed.
GatheringStateComplete
)

View File

@@ -20,6 +20,7 @@ func (a tiebreaker) AddToAs(m *stun.Message, t stun.AttrType) error {
v := make([]byte, tiebreakerSize)
binary.BigEndian.PutUint64(v, uint64(a))
m.Add(t, v)
return nil
}
@@ -33,6 +34,7 @@ func (a *tiebreaker) GetFromAs(m *stun.Message, t stun.AttrType) error {
return err
}
*a = tiebreaker(binary.BigEndian.Uint64(v))
return nil
}
@@ -73,6 +75,7 @@ func (c AttrControl) AddTo(m *stun.Message) error {
if c.Role == Controlling {
return tiebreaker(c.Tiebreaker).AddToAs(m, stun.AttrICEControlling)
}
return tiebreaker(c.Tiebreaker).AddToAs(m, stun.AttrICEControlled)
}
@@ -80,11 +83,14 @@ func (c AttrControl) AddTo(m *stun.Message) error {
func (c *AttrControl) GetFrom(m *stun.Message) error {
if m.Contains(stun.AttrICEControlling) {
c.Role = Controlling
return (*tiebreaker)(&c.Tiebreaker).GetFromAs(m, stun.AttrICEControlling)
}
if m.Contains(stun.AttrICEControlled) {
c.Role = Controlled
return (*tiebreaker)(&c.Tiebreaker).GetFromAs(m, stun.AttrICEControlled)
}
return stun.ErrAttributeNotFound
}

View File

@@ -12,11 +12,11 @@ import (
func TestControlled_GetFrom(t *testing.T) { //nolint:dupl
m := new(stun.Message)
var c AttrControlled
if err := c.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
var attrCtr AttrControlled
if err := attrCtr.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
t.Error("unexpected error")
}
if err := m.Build(stun.BindingRequest, &c); err != nil {
if err := m.Build(stun.BindingRequest, &attrCtr); err != nil {
t.Error(err)
}
m1 := new(stun.Message)
@@ -27,7 +27,7 @@ func TestControlled_GetFrom(t *testing.T) { //nolint:dupl
if err := c1.GetFrom(m1); err != nil {
t.Error(err)
}
if c1 != c {
if c1 != attrCtr {
t.Error("not equal")
}
t.Run("IncorrectSize", func(t *testing.T) {
@@ -42,11 +42,11 @@ func TestControlled_GetFrom(t *testing.T) { //nolint:dupl
func TestControlling_GetFrom(t *testing.T) { //nolint:dupl
m := new(stun.Message)
var c AttrControlling
if err := c.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
var attrCtr AttrControlling
if err := attrCtr.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
t.Error("unexpected error")
}
if err := m.Build(stun.BindingRequest, &c); err != nil {
if err := m.Build(stun.BindingRequest, &attrCtr); err != nil {
t.Error(err)
}
m1 := new(stun.Message)
@@ -57,7 +57,7 @@ func TestControlling_GetFrom(t *testing.T) { //nolint:dupl
if err := c1.GetFrom(m1); err != nil {
t.Error(err)
}
if c1 != c {
if c1 != attrCtr {
t.Error("not equal")
}
t.Run("IncorrectSize", func(t *testing.T) {
@@ -70,7 +70,7 @@ func TestControlling_GetFrom(t *testing.T) { //nolint:dupl
})
}
func TestControl_GetFrom(t *testing.T) {
func TestControl_GetFrom(t *testing.T) { //nolint:cyclop
t.Run("Blank", func(t *testing.T) {
m := new(stun.Message)
var c AttrControl
@@ -80,13 +80,13 @@ func TestControl_GetFrom(t *testing.T) {
})
t.Run("Controlling", func(t *testing.T) { //nolint:dupl
m := new(stun.Message)
var c AttrControl
if err := c.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
var attCtr AttrControl
if err := attCtr.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
t.Error("unexpected error")
}
c.Role = Controlling
c.Tiebreaker = 4321
if err := m.Build(stun.BindingRequest, &c); err != nil {
attCtr.Role = Controlling
attCtr.Tiebreaker = 4321
if err := m.Build(stun.BindingRequest, &attCtr); err != nil {
t.Error(err)
}
m1 := new(stun.Message)
@@ -97,7 +97,7 @@ func TestControl_GetFrom(t *testing.T) {
if err := c1.GetFrom(m1); err != nil {
t.Error(err)
}
if c1 != c {
if c1 != attCtr {
t.Error("not equal")
}
t.Run("IncorrectSize", func(t *testing.T) {
@@ -111,13 +111,13 @@ func TestControl_GetFrom(t *testing.T) {
})
t.Run("Controlled", func(t *testing.T) { //nolint:dupl
m := new(stun.Message)
var c AttrControl
if err := c.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
var attrCtrl AttrControl
if err := attrCtrl.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
t.Error("unexpected error")
}
c.Role = Controlled
c.Tiebreaker = 1234
if err := m.Build(stun.BindingRequest, &c); err != nil {
attrCtrl.Role = Controlled
attrCtrl.Tiebreaker = 1234
if err := m.Build(stun.BindingRequest, &attrCtrl); err != nil {
t.Error(err)
}
m1 := new(stun.Message)
@@ -128,7 +128,7 @@ func TestControl_GetFrom(t *testing.T) {
if err := c1.GetFrom(m1); err != nil {
t.Error(err)
}
if c1 != c {
if c1 != attrCtrl {
t.Error("not equal")
}
t.Run("IncorrectSize", func(t *testing.T) {

View File

@@ -6,18 +6,19 @@ package atomic
import "sync/atomic"
// Error is an atomic error
// Error is an atomic error.
type Error struct {
v atomic.Value
}
// Store updates the value of the atomic variable
// Store updates the value of the atomic variable.
func (a *Error) Store(err error) {
a.v.Store(struct{ error }{err})
}
// Load retrieves the current value of the atomic variable
// Load retrieves the current value of the atomic variable.
func (a *Error) Load() error {
err, _ := a.v.Load().(struct{ error })
return err.error
}

View File

@@ -11,7 +11,7 @@ import (
"time"
)
// MockPacketConn for tests
// MockPacketConn for tests.
type MockPacketConn struct{}
func (m *MockPacketConn) ReadFrom([]byte) (n int, addr net.Addr, err error) { return 0, nil, nil } //nolint:revive

View File

@@ -8,18 +8,19 @@ import (
"net"
)
// Compile-time assertion
// Compile-time assertion.
var _ net.PacketConn = (*PacketConn)(nil)
// PacketConn wraps a net.Conn and emulates net.PacketConn
// PacketConn wraps a net.Conn and emulates net.PacketConn.
type PacketConn struct {
net.Conn
}
// ReadFrom reads a packet from the connection,
// ReadFrom reads a packet from the connection.
func (f *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, err = f.Conn.Read(p)
addr = f.Conn.RemoteAddr()
return
}

View File

@@ -59,7 +59,7 @@ func GetXORMappedAddr(conn net.PacketConn, serverAddr net.Addr, timeout time.Dur
return &addr, nil
}
// AssertUsername checks that the given STUN message m has a USERNAME attribute with a given value
// AssertUsername checks that the given STUN message m has a USERNAME attribute with a given value.
func AssertUsername(m *stun.Message, expectedUsername string) error {
var username stun.Username
if err := username.GetFrom(m); err != nil {

View File

@@ -13,7 +13,7 @@ import (
atomicx "github.com/pion/ice/v4/internal/atomic"
)
// ErrClosed indicates that the loop has been stopped
// ErrClosed indicates that the loop has been stopped.
var ErrClosed = errors.New("the agent is closed")
type task struct {
@@ -21,7 +21,7 @@ type task struct {
done chan struct{}
}
// Loop runs submitted task serially in a dedicated Goroutine
// Loop runs submitted task serially in a dedicated Goroutine.
type Loop struct {
tasks chan task
@@ -31,7 +31,7 @@ type Loop struct {
err atomicx.Error
}
// New creates and starts a new task loop
// New creates and starts a new task loop.
func New(onClose func()) *Loop {
l := &Loop{
tasks: make(chan task),
@@ -40,6 +40,7 @@ func New(onClose func()) *Loop {
}
go l.runLoop(onClose)
return l
}
@@ -86,6 +87,7 @@ func (l *Loop) Run(ctx context.Context, t func(context.Context)) error {
return ctx.Err()
case l.tasks <- task{t, done}:
<-done
return nil
}
}
@@ -113,7 +115,7 @@ func (l *Loop) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}
// Value is not supported for task loops
// Value is not supported for task loops.
func (l *Loop) Value(interface{}) interface{} {
return nil
}

28
mdns.go
View File

@@ -14,18 +14,19 @@ import (
"golang.org/x/net/ipv6"
)
// MulticastDNSMode represents the different Multicast modes ICE can run in
// MulticastDNSMode represents the different Multicast modes ICE can run in.
type MulticastDNSMode byte
// MulticastDNSMode enum
// MulticastDNSMode enum.
const (
// MulticastDNSModeDisabled means remote mDNS candidates will be discarded, and local host candidates will use IPs
// MulticastDNSModeDisabled means remote mDNS candidates will be discarded, and local host candidates will use IPs.
MulticastDNSModeDisabled MulticastDNSMode = iota + 1
// MulticastDNSModeQueryOnly means remote mDNS candidates will be accepted, and local host candidates will use IPs
// MulticastDNSModeQueryOnly means remote mDNS candidates will be accepted, and local host candidates will use IPs.
MulticastDNSModeQueryOnly
// MulticastDNSModeQueryAndGather means remote mDNS candidates will be accepted, and local host candidates will use mDNS
// MulticastDNSModeQueryAndGather means remote mDNS candidates will be accepted,
// and local host candidates will use mDNS.
MulticastDNSModeQueryAndGather
)
@@ -33,11 +34,13 @@ func generateMulticastDNSName() (string, error) {
// https://tools.ietf.org/id/draft-ietf-rtcweb-mdns-ice-candidates-02.html#gathering
// The unique name MUST consist of a version 4 UUID as defined in [RFC4122], followed by “.local”.
u, err := uuid.NewRandom()
return u.String() + ".local", err
}
//nolint:cyclop
func createMulticastDNS(
n transport.Net,
netTransport transport.Net,
networkTypes []NetworkType,
interfaces []*transport.Interface,
includeLoopback bool,
@@ -57,6 +60,7 @@ func createMulticastDNS(
for _, nt := range networkTypes {
if nt.IsIPv4() {
useV4 = true
continue
}
if nt.IsIPv6() {
@@ -65,11 +69,11 @@ func createMulticastDNS(
}
}
addr4, mdnsErr := n.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4)
addr4, mdnsErr := netTransport.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4)
if mdnsErr != nil {
return nil, mDNSMode, mdnsErr
}
addr6, mdnsErr := n.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6)
addr6, mdnsErr := netTransport.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6)
if mdnsErr != nil {
return nil, mDNSMode, mdnsErr
}
@@ -78,10 +82,11 @@ func createMulticastDNS(
var mdns4Err error
if useV4 {
var l transport.UDPConn
l, mdns4Err = n.ListenUDP("udp4", addr4)
l, mdns4Err = netTransport.ListenUDP("udp4", addr4)
if mdns4Err != nil {
// If ICE fails to start MulticastDNS server just warn the user and continue
log.Errorf("Failed to enable mDNS over IPv4: (%s)", mdns4Err)
return nil, MulticastDNSModeDisabled, nil
}
pktConnV4 = ipv4.NewPacketConn(l)
@@ -91,9 +96,10 @@ func createMulticastDNS(
var mdns6Err error
if useV6 {
var l transport.UDPConn
l, mdns6Err = n.ListenUDP("udp6", addr6)
l, mdns6Err = netTransport.ListenUDP("udp6", addr6)
if mdns6Err != nil {
log.Errorf("Failed to enable mDNS over IPv6: (%s)", mdns6Err)
return nil, MulticastDNSModeDisabled, nil
}
pktConnV6 = ipv6.NewPacketConn(l)
@@ -119,6 +125,7 @@ func createMulticastDNS(
Interfaces: ifcs,
IncludeLoopback: includeLoopback,
})
return conn, mDNSMode, err
case MulticastDNSModeQueryAndGather:
conn, err := mdns.Server(pktConnV4, pktConnV6, &mdns.Config{
@@ -126,6 +133,7 @@ func createMulticastDNS(
IncludeLoopback: includeLoopback,
LocalNames: []string{mDNSName},
})
return conn, mDNSMode, err
default:
return nil, mDNSMode, nil

41
net.go
View File

@@ -24,6 +24,7 @@ func isSupportedIPv6Partial(ip net.IP) bool {
ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 { // !(IPv6 site-local unicast)
return false
}
return true
}
@@ -33,10 +34,11 @@ func isZeros(ip net.IP) bool {
return false
}
}
return true
}
//nolint:gocognit
//nolint:gocognit,cyclop
func localInterfaces(
n transport.Net,
interfaceFilter func(string) (keep bool),
@@ -114,27 +116,35 @@ func localInterfaces(
filteredIfaces = append(filteredIfaces, ifaceCopy)
}
}
return filteredIfaces, ipAddrs, nil
}
func listenUDPInPortRange(n transport.Net, log logging.LeveledLogger, portMax, portMin int, network string, lAddr *net.UDPAddr) (transport.UDPConn, error) {
//nolint:cyclop
func listenUDPInPortRange(
netTransport transport.Net,
log logging.LeveledLogger,
portMax, portMin int,
network string,
lAddr *net.UDPAddr,
) (transport.UDPConn, error) {
if (lAddr.Port != 0) || ((portMin == 0) && (portMax == 0)) {
return n.ListenUDP(network, lAddr)
return netTransport.ListenUDP(network, lAddr)
}
var i, j int
i = portMin
if i == 0 {
i = 1024 // Start at 1024 which is non-privileged
if portMin == 0 {
portMin = 1024 // Start at 1024 which is non-privileged
}
j = portMax
if j == 0 {
j = 0xFFFF
if portMax == 0 {
portMax = 0xFFFF
}
if i > j {
if portMin > portMax {
return nil, ErrPort
}
portStart := globalMathRandomGenerator.Intn(j-i+1) + i
portStart := globalMathRandomGenerator.Intn(portMax-portMin+1) + portMin
portCurrent := portStart
for {
addr := &net.UDPAddr{
@@ -143,18 +153,19 @@ func listenUDPInPortRange(n transport.Net, log logging.LeveledLogger, portMax, p
Port: portCurrent,
}
c, e := n.ListenUDP(network, addr)
c, e := netTransport.ListenUDP(network, addr)
if e == nil {
return c, e //nolint:nilerr
}
log.Debugf("Failed to listen %s: %v", lAddr.String(), e)
portCurrent++
if portCurrent > j {
portCurrent = i
if portCurrent > portMax {
portCurrent = portMin
}
if portCurrent == portStart {
break
}
}
return nil, ErrPort
}

View File

@@ -54,6 +54,7 @@ func problematicNetworkInterfaces(s string) (keep bool) {
appleWirelessDirectLink := strings.Contains(s, "awdl")
appleLowLatencyWLANInterface := strings.Contains(s, "llw")
appleTunnelingInterface := strings.Contains(s, "utun")
return !defaultDockerBridgeNetwork &&
!customDockerBridgeNetwork &&
!accessPoint &&
@@ -68,5 +69,6 @@ func mustAddr(t *testing.T, ip net.IP) netip.Addr {
if !ok {
t.Fatal(ipConvertError{ip})
}
return addr
}

View File

@@ -27,7 +27,7 @@ func supportedNetworkTypes() []NetworkType {
}
}
// NetworkType represents the type of network
// NetworkType represents the type of network.
type NetworkType int
const (
@@ -69,7 +69,7 @@ func (t NetworkType) IsTCP() bool {
return t == NetworkTypeTCP4 || t == NetworkTypeTCP6
}
// NetworkShort returns the short network description
// NetworkShort returns the short network description.
func (t NetworkType) NetworkShort() string {
switch t {
case NetworkTypeUDP4, NetworkTypeUDP6:
@@ -81,7 +81,7 @@ func (t NetworkType) NetworkShort() string {
}
}
// IsReliable returns true if the network is reliable
// IsReliable returns true if the network is reliable.
func (t NetworkType) IsReliable() bool {
switch t {
case NetworkTypeUDP4, NetworkTypeUDP6:
@@ -89,6 +89,7 @@ func (t NetworkType) IsReliable() bool {
case NetworkTypeTCP4, NetworkTypeTCP6:
return true
}
return false
}
@@ -100,6 +101,7 @@ func (t NetworkType) IsIPv4() bool {
case NetworkTypeUDP6, NetworkTypeTCP6:
return false
}
return false
}
@@ -111,6 +113,7 @@ func (t NetworkType) IsIPv6() bool {
case NetworkTypeUDP6, NetworkTypeTCP6:
return true
}
return false
}
@@ -124,12 +127,14 @@ func determineNetworkType(network string, ip netip.Addr) (NetworkType, error) {
if ip.Is4() {
return NetworkTypeUDP4, nil
}
return NetworkTypeUDP6, nil
case strings.HasPrefix(strings.ToLower(network), tcp):
if ip.Is4() {
return NetworkTypeTCP4, nil
}
return NetworkTypeTCP6, nil
}

View File

@@ -19,6 +19,7 @@ func (p PriorityAttr) AddTo(m *stun.Message) error {
v := make([]byte, prioritySize)
binary.BigEndian.PutUint32(v, uint32(p))
m.Add(stun.AttrPriority, v)
return nil
}
@@ -32,5 +33,6 @@ func (p *PriorityAttr) GetFrom(m *stun.Message) error {
return err
}
*p = PriorityAttr(binary.BigEndian.Uint32(v))
return nil
}

View File

@@ -12,11 +12,11 @@ import (
func TestPriority_GetFrom(t *testing.T) { //nolint:dupl
m := new(stun.Message)
var p PriorityAttr
if err := p.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
var priority PriorityAttr
if err := priority.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
t.Error("unexpected error")
}
if err := m.Build(stun.BindingRequest, &p); err != nil {
if err := m.Build(stun.BindingRequest, &priority); err != nil {
t.Error(err)
}
m1 := new(stun.Message)
@@ -27,7 +27,7 @@ func TestPriority_GetFrom(t *testing.T) { //nolint:dupl
if err := p1.GetFrom(m1); err != nil {
t.Error(err)
}
if p1 != p {
if p1 != priority {
t.Error("not equal")
}
t.Run("IncorrectSize", func(t *testing.T) {

View File

@@ -23,21 +23,27 @@ func TestRandomGeneratorCollision(t *testing.T) {
},
"PWD": {
gen: func(t *testing.T) string {
t.Helper()
s, err := generatePwd()
require.NoError(t, err)
return s
},
},
"Ufrag": {
gen: func(t *testing.T) string {
t.Helper()
s, err := generateUFrag()
require.NoError(t, err)
return s
},
},
}
const N = 100
const num = 100
const iteration = 100
for name, testCase := range testCases {
@@ -47,9 +53,9 @@ func TestRandomGeneratorCollision(t *testing.T) {
var wg sync.WaitGroup
var mu sync.Mutex
rands := make([]string, 0, N)
rands := make([]string, 0, num)
for i := 0; i < N; i++ {
for i := 0; i < num; i++ {
wg.Add(1)
go func() {
r := testCase.gen(t)
@@ -61,12 +67,12 @@ func TestRandomGeneratorCollision(t *testing.T) {
}
wg.Wait()
if len(rands) != N {
if len(rands) != num {
t.Fatal("Failed to generate randoms")
}
for i := 0; i < N; i++ {
for j := i + 1; j < N; j++ {
for i := 0; i < num; i++ {
for j := i + 1; j < num; j++ {
if rands[i] == rands[j] {
t.Fatalf("generateRandString caused collision: %s == %s", rands[i], rands[j])
}

View File

@@ -26,6 +26,7 @@ func (r *Role) UnmarshalText(text []byte) error {
default:
return fmt.Errorf("%w %q", errUnknownRole, text)
}
return nil
}

View File

@@ -44,6 +44,7 @@ func (s *controllingSelector) isNominatable(c Candidate) bool {
}
s.log.Errorf("Invalid candidate type: %s", c.Type())
return false
}
@@ -63,6 +64,7 @@ func (s *controllingSelector) ContactCandidates() {
p.nominated = true
s.nominatedPair = p
s.nominatePair(p)
return
}
s.agent.pingAllCandidates()
@@ -84,6 +86,7 @@ func (s *controllingSelector) nominatePair(pair *CandidatePair) {
)
if err != nil {
s.log.Error(err.Error())
return
}
@@ -91,30 +94,35 @@ func (s *controllingSelector) nominatePair(pair *CandidatePair) {
s.agent.sendBindingRequest(msg, pair.Local, pair.Remote)
}
func (s *controllingSelector) HandleBindingRequest(m *stun.Message, local, remote Candidate) {
s.agent.sendBindingSuccess(m, local, remote)
func (s *controllingSelector) HandleBindingRequest(message *stun.Message, local, remote Candidate) { //nolint:cyclop
s.agent.sendBindingSuccess(message, local, remote)
p := s.agent.findPair(local, remote)
pair := s.agent.findPair(local, remote)
if p == nil {
if pair == nil {
s.agent.addPair(local, remote)
return
}
if p.state == CandidatePairStateSucceeded && s.nominatedPair == nil && s.agent.getSelectedPair() == nil {
if pair.state == CandidatePairStateSucceeded && s.nominatedPair == nil && s.agent.getSelectedPair() == nil {
bestPair := s.agent.getBestAvailableCandidatePair()
if bestPair == nil {
s.log.Tracef("No best pair available")
} else if bestPair.equal(p) && s.isNominatable(p.Local) && s.isNominatable(p.Remote) {
s.log.Tracef("The candidate (%s, %s) is the best candidate available, marking it as nominated", p.Local, p.Remote)
s.nominatedPair = p
s.nominatePair(p)
} else if bestPair.equal(pair) && s.isNominatable(pair.Local) && s.isNominatable(pair.Remote) {
s.log.Tracef(
"The candidate (%s, %s) is the best candidate available, marking it as nominated",
pair.Local,
pair.Remote,
)
s.nominatedPair = pair
s.nominatePair(pair)
}
}
if s.agent.userBindingRequestHandler != nil {
if shouldSwitch := s.agent.userBindingRequestHandler(m, local, remote, p); shouldSwitch {
s.agent.setSelectedPair(p)
if shouldSwitch := s.agent.userBindingRequestHandler(message, local, remote, pair); shouldSwitch {
s.agent.setSelectedPair(pair)
}
}
}
@@ -123,6 +131,7 @@ func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remo
ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(m.TransactionID)
if !ok {
s.log.Warnf("Discard message from (%s), unknown TransactionID 0x%x", remote, m.TransactionID)
return
}
@@ -131,26 +140,32 @@ func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remo
// Assert that NAT is not symmetric
// https://tools.ietf.org/html/rfc8445#section-7.2.5.2.1
if !addrEqual(transactionAddr, remoteAddr) {
s.log.Debugf("Discard message: transaction source and destination does not match expected(%s), actual(%s)", transactionAddr, remote)
s.log.Debugf(
"Discard message: transaction source and destination does not match expected(%s), actual(%s)",
transactionAddr,
remote,
)
return
}
s.log.Tracef("Inbound STUN (SuccessResponse) from %s to %s", remote, local)
p := s.agent.findPair(local, remote)
pair := s.agent.findPair(local, remote)
if p == nil {
if pair == nil {
// This shouldn't happen
s.log.Error("Success response from invalid candidate pair")
return
}
p.state = CandidatePairStateSucceeded
s.log.Tracef("Found valid candidate pair: %s", p)
pair.state = CandidatePairStateSucceeded
s.log.Tracef("Found valid candidate pair: %s", pair)
if pendingRequest.isUseCandidate && s.agent.getSelectedPair() == nil {
s.agent.setSelectedPair(p)
s.agent.setSelectedPair(pair)
}
p.UpdateRoundTripTime(rtt)
pair.UpdateRoundTripTime(rtt)
}
func (s *controllingSelector) PingCandidate(local, remote Candidate) {
@@ -163,6 +178,7 @@ func (s *controllingSelector) PingCandidate(local, remote Candidate) {
)
if err != nil {
s.log.Error(err.Error())
return
}
@@ -198,6 +214,7 @@ func (s *controlledSelector) PingCandidate(local, remote Candidate) {
)
if err != nil {
s.log.Error(err.Error())
return
}
@@ -216,6 +233,7 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot
ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(m.TransactionID)
if !ok {
s.log.Warnf("Discard message from (%s), unknown TransactionID 0x%x", remote, m.TransactionID)
return
}
@@ -224,52 +242,62 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot
// Assert that NAT is not symmetric
// https://tools.ietf.org/html/rfc8445#section-7.2.5.2.1
if !addrEqual(transactionAddr, remoteAddr) {
s.log.Debugf("Discard message: transaction source and destination does not match expected(%s), actual(%s)", transactionAddr, remote)
s.log.Debugf(
"Discard message: transaction source and destination does not match expected(%s), actual(%s)",
transactionAddr,
remote,
)
return
}
s.log.Tracef("Inbound STUN (SuccessResponse) from %s to %s", remote, local)
p := s.agent.findPair(local, remote)
if p == nil {
pair := s.agent.findPair(local, remote)
if pair == nil {
// This shouldn't happen
s.log.Error("Success response from invalid candidate pair")
return
}
p.state = CandidatePairStateSucceeded
s.log.Tracef("Found valid candidate pair: %s", p)
if p.nominateOnBindingSuccess {
pair.state = CandidatePairStateSucceeded
s.log.Tracef("Found valid candidate pair: %s", pair)
if pair.nominateOnBindingSuccess {
if selectedPair := s.agent.getSelectedPair(); selectedPair == nil ||
(selectedPair != p && (!s.agent.needsToCheckPriorityOnNominated() || selectedPair.priority() <= p.priority())) {
s.agent.setSelectedPair(p)
} else if selectedPair != p {
s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", p, selectedPair)
(selectedPair != pair &&
(!s.agent.needsToCheckPriorityOnNominated() || selectedPair.priority() <= pair.priority())) {
s.agent.setSelectedPair(pair)
} else if selectedPair != pair {
s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", pair, selectedPair)
}
}
p.UpdateRoundTripTime(rtt)
pair.UpdateRoundTripTime(rtt)
}
func (s *controlledSelector) HandleBindingRequest(m *stun.Message, local, remote Candidate) {
p := s.agent.findPair(local, remote)
if p == nil {
p = s.agent.addPair(local, remote)
func (s *controlledSelector) HandleBindingRequest(message *stun.Message, local, remote Candidate) { //nolint:cyclop
pair := s.agent.findPair(local, remote)
if pair == nil {
pair = s.agent.addPair(local, remote)
}
if m.Contains(stun.AttrUseCandidate) {
if message.Contains(stun.AttrUseCandidate) { //nolint:nestif
// https://tools.ietf.org/html/rfc8445#section-7.3.1.5
if p.state == CandidatePairStateSucceeded {
if pair.state == CandidatePairStateSucceeded {
// If the state of this pair is Succeeded, it means that the check
// previously sent by this pair produced a successful response and
// generated a valid pair (Section 7.2.5.3.2). The agent sets the
// nominated flag value of the valid pair to true.
selectedPair := s.agent.getSelectedPair()
if selectedPair == nil || (selectedPair != p && (!s.agent.needsToCheckPriorityOnNominated() || selectedPair.priority() <= p.priority())) {
s.agent.setSelectedPair(p)
} else if selectedPair != p {
s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", p, selectedPair)
if selectedPair == nil ||
(selectedPair != pair &&
(!s.agent.needsToCheckPriorityOnNominated() ||
selectedPair.priority() <= pair.priority())) {
s.agent.setSelectedPair(pair)
} else if selectedPair != pair {
s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", pair, selectedPair)
}
} else {
// If the received Binding request triggered a new check to be
@@ -280,16 +308,16 @@ func (s *controlledSelector) HandleBindingRequest(m *stun.Message, local, remote
// MUST remove the candidate pair from the valid list, set the
// candidate pair state to Failed, and set the checklist state to
// Failed.
p.nominateOnBindingSuccess = true
pair.nominateOnBindingSuccess = true
}
}
s.agent.sendBindingSuccess(m, local, remote)
s.agent.sendBindingSuccess(message, local, remote)
s.PingCandidate(local, remote)
if s.agent.userBindingRequestHandler != nil {
if shouldSwitch := s.agent.userBindingRequestHandler(m, local, remote, p); shouldSwitch {
s.agent.setSelectedPair(p)
if shouldSwitch := s.agent.userBindingRequestHandler(message, local, remote, pair); shouldSwitch {
s.agent.setSelectedPair(pair)
}
}
}
@@ -298,7 +326,7 @@ type liteSelector struct {
pairCandidateSelector
}
// A lite selector should not contact candidates
// A lite selector should not contact candidates.
func (s *liteSelector) ContactCandidates() {
if _, ok := s.pairCandidateSelector.(*controllingSelector); ok {
//nolint:godox

View File

@@ -23,6 +23,8 @@ import (
)
func sendUntilDone(t *testing.T, writingConn, readingConn net.Conn, maxAttempts int) bool {
t.Helper()
testMessage := []byte("Hello World")
testBuffer := make([]byte, len(testMessage))
@@ -73,6 +75,7 @@ func TestBindingRequestHandler(t *testing.T) {
CheckInterval: &oneHour,
BindingRequestHandler: func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool {
controlledLoggingFired.Store(true)
return false
},
})
@@ -87,6 +90,7 @@ func TestBindingRequestHandler(t *testing.T) {
BindingRequestHandler: func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool {
// Don't switch candidate pair until we are ready
val, ok := switchToNewCandidatePair.Load().(bool)
return ok && val
},
})

View File

@@ -7,7 +7,7 @@ import (
"time"
)
// CandidatePairStats contains ICE candidate pair statistics
// CandidatePairStats contains ICE candidate pair statistics.
type CandidatePairStats struct {
// Timestamp is the timestamp associated with this object.
Timestamp time.Time

View File

@@ -79,20 +79,20 @@ func NewTCPMuxDefault(params TCPMuxParams) *TCPMuxDefault {
params.AliveDurationForConnFromStun = 30 * time.Second
}
m := &TCPMuxDefault{
mux := &TCPMuxDefault{
params: &params,
connsIPv4: map[string]map[ipAddr]*tcpPacketConn{},
connsIPv6: map[string]map[ipAddr]*tcpPacketConn{},
}
m.wg.Add(1)
mux.wg.Add(1)
go func() {
defer m.wg.Done()
m.start()
defer mux.wg.Done()
mux.start()
}()
return m
return mux
}
func (m *TCPMuxDefault) start() {
@@ -101,6 +101,7 @@ func (m *TCPMuxDefault) start() {
conn, err := m.params.Listener.Accept()
if err != nil {
m.params.Logger.Infof("Error accepting connection: %s", err)
return
}
@@ -130,6 +131,7 @@ func (m *TCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP)
if conn, ok := m.getConn(ufrag, isIPv6, local); ok {
conn.ClearAliveTimer()
return conn, nil
}
@@ -191,12 +193,17 @@ func (m *TCPMuxDefault) closeAndLogError(closer io.Closer) {
}
}
func (m *TCPMuxDefault) handleConn(conn net.Conn) {
func (m *TCPMuxDefault) handleConn(conn net.Conn) { //nolint:cyclop
buf := make([]byte, 512)
if m.params.FirstStunBindTimeout > 0 {
if err := conn.SetReadDeadline(time.Now().Add(m.params.FirstStunBindTimeout)); err != nil {
m.params.Logger.Warnf("Failed to set read deadline for first STUN message: %s to %s , err: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
m.params.Logger.Warnf(
"Failed to set read deadline for first STUN message: %s to %s , err: %s",
conn.RemoteAddr(),
conn.LocalAddr(),
err,
)
}
}
n, err := readStreamingPacket(conn, buf)
@@ -207,6 +214,7 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
m.params.Logger.Warnf("Error reading first packet from %s: %s", conn.RemoteAddr(), err)
}
m.closeAndLogError(conn)
return
}
if err = conn.SetReadDeadline(time.Time{}); err != nil {
@@ -223,12 +231,14 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
if err = msg.Decode(); err != nil {
m.closeAndLogError(conn)
m.params.Logger.Warnf("Failed to handle decode ICE from %s to %s: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
return
}
if m == nil || msg.Type.Method != stun.MethodBinding { // Not a STUN
m.closeAndLogError(conn)
m.params.Logger.Warnf("Not a STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
return
}
@@ -239,7 +249,12 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
attr, err := msg.Get(stun.AttrUsername)
if err != nil {
m.closeAndLogError(conn)
m.params.Logger.Warnf("No Username attribute in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
m.params.Logger.Warnf(
"No Username attribute in STUN message from %s to %s",
conn.RemoteAddr(),
conn.LocalAddr(),
)
return
}
@@ -249,7 +264,12 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
m.closeAndLogError(conn)
m.params.Logger.Warnf("Failed to get host in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
m.params.Logger.Warnf(
"Failed to get host in STUN message from %s to %s",
conn.RemoteAddr(),
conn.LocalAddr(),
)
return
}
@@ -258,7 +278,12 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
localAddr, ok := conn.LocalAddr().(*net.TCPAddr)
if !ok {
m.closeAndLogError(conn)
m.params.Logger.Warnf("Failed to get local tcp address in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
m.params.Logger.Warnf(
"Failed to get local tcp address in STUN message from %s to %s",
conn.RemoteAddr(),
conn.LocalAddr(),
)
return
}
m.mu.Lock()
@@ -269,7 +294,12 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
if err != nil {
m.mu.Unlock()
m.closeAndLogError(conn)
m.params.Logger.Warnf("Failed to create packetConn for STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
m.params.Logger.Warnf(
"Failed to create packetConn for STUN message from %s to %s",
conn.RemoteAddr(),
conn.LocalAddr(),
)
return
}
}
@@ -277,7 +307,13 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
if err := packetConn.AddConn(conn, buf); err != nil {
m.closeAndLogError(conn)
m.params.Logger.Warnf("Error adding conn to tcpPacketConn from %s to %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
m.params.Logger.Warnf(
"Error adding conn to tcpPacketConn from %s to %s: %s",
conn.RemoteAddr(),
conn.LocalAddr(),
err,
)
return
}
}
@@ -428,7 +464,7 @@ func readStreamingPacket(conn net.Conn, buf []byte) (int, error) {
func writeStreamingPacket(conn net.Conn, buf []byte) (int, error) {
bufCopy := make([]byte, streamingPacketHeaderLen+len(buf))
binary.BigEndian.PutUint16(bufCopy, uint16(len(buf)))
binary.BigEndian.PutUint16(bufCopy, uint16(len(buf))) //nolint:gosec // G115
copy(bufCopy[2:], buf)
n, err := conn.Write(bufCopy)

View File

@@ -40,6 +40,7 @@ func (m *MultiTCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local net
if len(m.muxes) == 0 {
return nil, errNoTCPMuxAvailable
}
return m.muxes[0].GetConnByUfrag(ufrag, isIPv6, local)
}
@@ -51,7 +52,7 @@ func (m *MultiTCPMuxDefault) RemoveConnByUfrag(ufrag string) {
}
}
// GetAllConns returns a PacketConn for each underlying TCPMux
// GetAllConns returns a PacketConn for each underlying TCPMux.
func (m *MultiTCPMuxDefault) GetAllConns(ufrag string, isIPv6 bool, local net.IP) ([]net.PacketConn, error) {
if len(m.muxes) == 0 {
// Make sure that we either return at least one connection or an error.
@@ -68,10 +69,11 @@ func (m *MultiTCPMuxDefault) GetAllConns(ufrag string, isIPv6 bool, local net.IP
conns = append(conns, conn)
}
}
return conns, nil
}
// Close the multi mux, no further connections could be created
// Close the multi mux, no further connections could be created.
func (m *MultiTCPMuxDefault) Close() error {
var err error
for _, mux := range m.muxes {
@@ -79,5 +81,6 @@ func (m *MultiTCPMuxDefault) Close() error {
err = e
}
}
return err
}

View File

@@ -36,6 +36,7 @@ func newBufferedConn(conn net.Conn, bufSize int, logger logging.LeveledLogger) n
}
go bc.writeProcess()
return bc
}
@@ -44,6 +45,7 @@ func (bc *bufferedConn) Write(b []byte) (int, error) {
if err != nil {
return n, err
}
return n, nil
}
@@ -57,11 +59,13 @@ func (bc *bufferedConn) writeProcess() {
if err != nil {
bc.logger.Warnf("Failed to read from buffer: %s", err)
continue
}
if _, err := bc.Conn.Write(pktBuf[:n]); err != nil {
bc.logger.Warnf("Failed to write: %s", err)
continue
}
}
@@ -70,6 +74,7 @@ func (bc *bufferedConn) writeProcess() {
func (bc *bufferedConn) Close() error {
atomic.StoreInt32(&bc.closed, 1)
_ = bc.buf.Close()
return bc.Conn.Close()
}
@@ -103,7 +108,7 @@ type tcpPacketParams struct {
}
func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn {
p := &tcpPacketConn{
packet := &tcpPacketConn{
params: &params,
conns: map[string]net.Conn{},
@@ -113,13 +118,13 @@ func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn {
}
if params.AliveDuration > 0 {
p.aliveTimer = time.AfterFunc(params.AliveDuration, func() {
p.params.Logger.Warn("close tcp packet conn by alive timeout")
_ = p.Close()
packet.aliveTimer = time.AfterFunc(params.AliveDuration, func() {
packet.params.Logger.Warn("close tcp packet conn by alive timeout")
_ = packet.Close()
})
}
return p
return packet
}
func (t *tcpPacketConn) ClearAliveTimer() {
@@ -131,7 +136,12 @@ func (t *tcpPacketConn) ClearAliveTimer() {
}
func (t *tcpPacketConn) AddConn(conn net.Conn, firstPacketData []byte) error {
t.params.Logger.Infof("Added connection: %s remote %s to local %s", conn.RemoteAddr().Network(), conn.RemoteAddr(), conn.LocalAddr())
t.params.Logger.Infof(
"Added connection: %s remote %s to local %s",
conn.RemoteAddr().Network(),
conn.RemoteAddr(),
conn.LocalAddr(),
)
t.mu.Lock()
defer t.mu.Unlock()
@@ -183,6 +193,7 @@ func (t *tcpPacketConn) startReading(conn net.Conn) {
if last || !(errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed)) {
t.handleRecv(streamingPacket{nil, conn.RemoteAddr(), err})
}
return
}
@@ -236,6 +247,7 @@ func (t *tcpPacketConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) {
n = len(pkt.Data)
copy(b, pkt.Data[:n])
return n, pkt.RAddr, err
}
@@ -252,6 +264,7 @@ func (t *tcpPacketConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
n, err = writeStreamingPacket(conn, buf)
if err != nil {
t.params.Logger.Tracef("%w %s", errWrite, rAddr)
return n, err
}
@@ -272,6 +285,7 @@ func (t *tcpPacketConn) removeConn(conn net.Conn) bool {
t.closeAndLogError(conn)
delete(t.conns, conn.RemoteAddr().String())
return len(t.conns) == 0
}

View File

@@ -32,12 +32,12 @@ type Conn struct {
agent *Agent
}
// BytesSent returns the number of bytes sent
// BytesSent returns the number of bytes sent.
func (c *Conn) BytesSent() uint64 {
return atomic.LoadUint64(&c.bytesSent)
}
// BytesReceived returns the number of bytes received
// BytesReceived returns the number of bytes received.
func (c *Conn) BytesReceived() uint64 {
return atomic.LoadUint64(&c.bytesReceived)
}
@@ -74,18 +74,19 @@ func (c *Conn) Read(p []byte) (int, error) {
}
n, err := c.agent.buf.Read(p)
atomic.AddUint64(&c.bytesReceived, uint64(n))
atomic.AddUint64(&c.bytesReceived, uint64(n)) //nolint:gosec // G115
return n, err
}
// Write implements the Conn Write method.
func (c *Conn) Write(p []byte) (int, error) {
func (c *Conn) Write(packet []byte) (int, error) {
err := c.agent.loop.Err()
if err != nil {
return 0, err
}
if stun.IsMessage(p) {
if stun.IsMessage(packet) {
return 0, errWriteSTUNMessageToIceConn
}
@@ -102,8 +103,9 @@ func (c *Conn) Write(p []byte) (int, error) {
}
}
atomic.AddUint64(&c.bytesSent, uint64(len(p)))
return pair.Write(p)
atomic.AddUint64(&c.bytesSent, uint64(len(packet)))
return pair.Write(packet)
}
// Close implements the Conn Close method. It is used to close
@@ -132,17 +134,17 @@ func (c *Conn) RemoteAddr() net.Addr {
return pair.Remote.addr()
}
// SetDeadline is a stub
// SetDeadline is a stub.
func (c *Conn) SetDeadline(time.Time) error {
return nil
}
// SetReadDeadline is a stub
// SetReadDeadline is a stub.
func (c *Conn) SetReadDeadline(time.Time) error {
return nil
}
// SetWriteDeadline is a stub
// SetWriteDeadline is a stub.
func (c *Conn) SetWriteDeadline(time.Time) error {
return nil
}

View File

@@ -30,13 +30,15 @@ func TestStressDuplex(t *testing.T) {
stressDuplex(t)
}
func testTimeout(t *testing.T, c *Conn, timeout time.Duration) {
func testTimeout(t *testing.T, conn *Conn, timeout time.Duration) {
t.Helper()
const pollRate = 100 * time.Millisecond
const margin = 20 * time.Millisecond // Allow 20msec error in time
ticker := time.NewTicker(pollRate)
defer func() {
ticker.Stop()
err := c.Close()
err := conn.Close()
if err != nil {
t.Error(err)
}
@@ -49,8 +51,8 @@ func testTimeout(t *testing.T, c *Conn, timeout time.Duration) {
var cs ConnectionState
err := c.agent.loop.Run(context.Background(), func(_ context.Context) {
cs = c.agent.connectionState
err := conn.agent.loop.Run(context.Background(), func(_ context.Context) {
cs = conn.agent.connectionState
})
if err != nil {
// We should never get here.
@@ -63,6 +65,7 @@ func testTimeout(t *testing.T, c *Conn, timeout time.Duration) {
t.Fatalf("Connection timed out %f msec early", elapsed.Seconds()*1000)
} else {
t.Logf("Connection timed out in %f msec", elapsed.Seconds()*1000)
return
}
}
@@ -133,6 +136,8 @@ func TestReadClosed(t *testing.T) {
}
func stressDuplex(t *testing.T) {
t.Helper()
ca, cb := pipe(nil)
defer func() {
@@ -219,6 +224,7 @@ func connect(aAgent, bAgent *Agent) (*Conn, *Conn) {
// Ensure accepted
<-accepted
return aConn, bConn
}
@@ -288,6 +294,7 @@ func pipeWithTimeout(disconnectTimeout time.Duration, iceKeepalive time.Duration
func onConnected() (func(ConnectionState), chan struct{}) {
done := make(chan struct{})
return func(state ConnectionState) {
if state == ConnectionStateConnected {
close(done)
@@ -295,11 +302,11 @@ func onConnected() (func(ConnectionState), chan struct{}) {
}, done
}
func randomPort(t testing.TB) int {
t.Helper()
func randomPort(tb testing.TB) int {
tb.Helper()
conn, err := net.ListenPacket("udp4", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to pickPort: %v", err)
tb.Fatalf("failed to pickPort: %v", err)
}
defer func() {
_ = conn.Close()
@@ -308,7 +315,8 @@ func randomPort(t testing.TB) int {
case *net.UDPAddr:
return addr.Port
default:
t.Fatalf("unknown addr type %T", addr)
tb.Fatalf("unknown addr type %T", addr)
return 0
}
}

View File

@@ -30,9 +30,9 @@ func TestRemoteLocalAddr(t *testing.T) {
// Agent1 is behind 1:1 NAT
natType1 := &vnet.NATType{Mode: vnet.NATModeNAT1To1}
v, errVnet := buildVNet(natType0, natType1)
builtVnet, errVnet := buildVNet(natType0, natType1)
require.NoError(t, errVnet, "should succeed")
defer v.close()
defer builtVnet.close()
stunServerURL := &stun.URI{
Scheme: stun.SchemeTypeSTUN,
@@ -53,7 +53,7 @@ func TestRemoteLocalAddr(t *testing.T) {
})
t.Run("Remote/Local Pair Match between Agents", func(t *testing.T) {
ca, cb := pipeWithVNet(v,
ca, cb := pipeWithVNet(builtVnet,
&agentTestConfig{
urls: []*stun.URI{stunServerURL},
},

View File

@@ -18,7 +18,7 @@ import (
"github.com/pion/transport/v3/stdnet"
)
// UDPMux allows multiple connections to go over a single UDP port
// UDPMux allows multiple connections to go over a single UDP port.
type UDPMux interface {
io.Closer
GetConn(ufrag string, addr net.Addr) (net.PacketConn, error)
@@ -26,7 +26,7 @@ type UDPMux interface {
GetListenAddresses() []net.Addr
}
// UDPMuxDefault is an implementation of the interface
// UDPMuxDefault is an implementation of the interface.
type UDPMuxDefault struct {
params UDPMuxParams
@@ -60,14 +60,14 @@ type UDPMuxParams struct {
Net transport.Net
}
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
// NewUDPMuxDefault creates an implementation of UDPMux.
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { //nolint:cyclop
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
var localAddrsForUnspecified []net.Addr
if udpAddr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
if udpAddr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { //nolint:nestif
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
} else if ok && udpAddr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
@@ -109,7 +109,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
}
params.UDPConnString = params.UDPConn.LocalAddr().String()
m := &UDPMuxDefault{
mux := &UDPMuxDefault{
addressMap: map[ipPort]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
@@ -124,17 +124,17 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
localAddrsForUnspecified: localAddrsForUnspecified,
}
go m.connWorker()
go mux.connWorker()
return m
return mux
}
// LocalAddr returns the listening address of this UDPMuxDefault
// LocalAddr returns the listening address of this UDPMuxDefault.
func (m *UDPMuxDefault) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr()
}
// GetListenAddresses returns the list of addresses that this mux is listening on
// GetListenAddresses returns the list of addresses that this mux is listening on.
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
if len(m.localAddrsForUnspecified) > 0 {
return m.localAddrsForUnspecified
@@ -143,8 +143,8 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
return []net.Addr{m.LocalAddr()}
}
// GetConn returns a PacketConn given the connection's ufrag and network address
// creates the connection if an existing one can't be found
// GetConn returns a PacketConn given the connection's ufrag and network address.
// creates the connection if an existing one can't be found.
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
// don't check addr for mux using unspecified address
if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConnString != addr.String() {
@@ -181,11 +181,11 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
return c, nil
}
// RemoveConnByUfrag stops and removes the muxed packet connection
// RemoveConnByUfrag stops and removes the muxed packet connection.
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 2)
// Keep lock section small to avoid deadlock with conn lock
// Keep lock section small to avoid deadlock with conn lock.
m.mu.Lock()
if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag)
@@ -198,7 +198,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
m.mu.Unlock()
if len(removedConns) == 0 {
// No need to lock if no connection was found
// No need to lock if no connection was found.
return
}
@@ -213,7 +213,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
}
}
// IsClosed returns true if the mux had been closed
// IsClosed returns true if the mux had been closed.
func (m *UDPMuxDefault) IsClosed() bool {
select {
case <-m.closedChan:
@@ -223,7 +223,7 @@ func (m *UDPMuxDefault) IsClosed() bool {
}
}
// Close the mux, no further connections could be created
// Close the mux, no further connections could be created.
func (m *UDPMuxDefault) Close() error {
var err error
m.closeOnce.Do(func() {
@@ -244,6 +244,7 @@ func (m *UDPMuxDefault) Close() error {
_ = m.params.UDPConn.Close()
})
return err
}
@@ -276,10 +277,11 @@ func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
})
return c
}
func (m *UDPMuxDefault) connWorker() {
func (m *UDPMuxDefault) connWorker() { //nolint:cyclop
logger := m.params.Logger
defer func() {
@@ -304,11 +306,13 @@ func (m *UDPMuxDefault) connWorker() {
netUDPAddr, ok := addr.(*net.UDPAddr)
if !ok {
logger.Errorf("Underlying PacketConn did not return a UDPAddr")
return
}
udpAddr, err := newIPPort(netUDPAddr.IP, netUDPAddr.Zone, uint16(netUDPAddr.Port))
udpAddr, err := newIPPort(netUDPAddr.IP, netUDPAddr.Zone, uint16(netUDPAddr.Port)) //nolint:gosec
if err != nil {
logger.Errorf("Failed to create a new IP/Port host pair")
return
}
@@ -325,12 +329,14 @@ func (m *UDPMuxDefault) connWorker() {
if err = msg.Decode(); err != nil {
m.params.Logger.Warnf("Failed to handle decode ICE from %s: %v", addr.String(), err)
continue
}
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr != nil {
m.params.Logger.Warnf("No Username attribute in STUN message from %s", addr.String())
continue
}
@@ -344,6 +350,7 @@ func (m *UDPMuxDefault) connWorker() {
if destinationConn == nil {
m.params.Logger.Tracef("Dropping packet from %s, addr: %s", udpAddr.addr, addr)
continue
}
@@ -359,6 +366,7 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o
} else {
val, ok = m.connsIPv4[ufrag]
}
return
}
@@ -386,7 +394,7 @@ type ipPort struct {
// newIPPort create a custom type of address based on netip.Addr and
// port. The underlying ip address passed is converted to IPv6 format
// to simplify ip address handling
// to simplify ip address handling.
func newIPPort(ip net.IP, zone string, port uint16) (ipPort, error) {
n, ok := netip.AddrFromSlice(ip.To16())
if !ok {

View File

@@ -29,6 +29,7 @@ func NewMultiUDPMuxDefault(muxes ...UDPMux) *MultiUDPMuxDefault {
addrToMux[addr.String()] = mux
}
}
return &MultiUDPMuxDefault{
muxes: muxes,
localAddrToMux: addrToMux,
@@ -42,6 +43,7 @@ func (m *MultiUDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketCon
if !ok {
return nil, errNoUDPMuxAvailable
}
return mux.GetConn(ufrag, addr)
}
@@ -53,7 +55,7 @@ func (m *MultiUDPMuxDefault) RemoveConnByUfrag(ufrag string) {
}
}
// Close the multi mux, no further connections could be created
// Close the multi mux, no further connections could be created.
func (m *MultiUDPMuxDefault) Close() error {
var err error
for _, mux := range m.muxes {
@@ -61,21 +63,23 @@ func (m *MultiUDPMuxDefault) Close() error {
err = e
}
}
return err
}
// GetListenAddresses returns the list of addresses that this mux is listening on
// GetListenAddresses returns the list of addresses that this mux is listening on.
func (m *MultiUDPMuxDefault) GetListenAddresses() []net.Addr {
addrs := make([]net.Addr, 0, len(m.localAddrToMux))
for _, mux := range m.muxes {
addrs = append(addrs, mux.GetListenAddresses()...)
}
return addrs
}
// NewMultiUDPMuxFromPort creates an instance of MultiUDPMuxDefault that
// listen all interfaces on the provided port.
func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMuxDefault, error) {
func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMuxDefault, error) { //nolint:cyclop
params := multiUDPMuxFromPortParam{
networks: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
}
@@ -104,6 +108,7 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu
})
if listenErr != nil {
err = listenErr
break
}
if params.readBufferSize > 0 {
@@ -119,6 +124,7 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu
for _, conn := range conns {
_ = conn.Close()
}
return nil, err
}
@@ -135,7 +141,7 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu
return NewMultiUDPMuxDefault(muxes...), nil
}
// UDPMuxFromPortOption provide options for NewMultiUDPMuxFromPort
// UDPMuxFromPortOption provide options for NewMultiUDPMuxFromPort.
type UDPMuxFromPortOption interface {
apply(*multiUDPMuxFromPortParam)
}
@@ -159,7 +165,7 @@ func (o *udpMuxFromPortOption) apply(p *multiUDPMuxFromPortParam) {
o.f(p)
}
// UDPMuxFromPortWithInterfaceFilter set the filter to filter out interfaces that should not be used
// UDPMuxFromPortWithInterfaceFilter set the filter to filter out interfaces that should not be used.
func UDPMuxFromPortWithInterfaceFilter(f func(string) (keep bool)) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
@@ -168,7 +174,7 @@ func UDPMuxFromPortWithInterfaceFilter(f func(string) (keep bool)) UDPMuxFromPor
}
}
// UDPMuxFromPortWithIPFilter set the filter to filter out IP addresses that should not be used
// UDPMuxFromPortWithIPFilter set the filter to filter out IP addresses that should not be used.
func UDPMuxFromPortWithIPFilter(f func(ip net.IP) (keep bool)) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
@@ -177,7 +183,7 @@ func UDPMuxFromPortWithIPFilter(f func(ip net.IP) (keep bool)) UDPMuxFromPortOpt
}
}
// UDPMuxFromPortWithNetworks set the networks that should be used. default is both IPv4 and IPv6
// UDPMuxFromPortWithNetworks set the networks that should be used. default is both IPv4 and IPv6.
func UDPMuxFromPortWithNetworks(networks ...NetworkType) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
@@ -186,7 +192,7 @@ func UDPMuxFromPortWithNetworks(networks ...NetworkType) UDPMuxFromPortOption {
}
}
// UDPMuxFromPortWithReadBufferSize set the UDP connection read buffer size
// UDPMuxFromPortWithReadBufferSize set the UDP connection read buffer size.
func UDPMuxFromPortWithReadBufferSize(size int) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
@@ -195,7 +201,7 @@ func UDPMuxFromPortWithReadBufferSize(size int) UDPMuxFromPortOption {
}
}
// UDPMuxFromPortWithWriteBufferSize set the UDP connection write buffer size
// UDPMuxFromPortWithWriteBufferSize set the UDP connection write buffer size.
func UDPMuxFromPortWithWriteBufferSize(size int) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
@@ -204,7 +210,7 @@ func UDPMuxFromPortWithWriteBufferSize(size int) UDPMuxFromPortOption {
}
}
// UDPMuxFromPortWithLogger set the logger for the created UDPMux
// UDPMuxFromPortWithLogger set the logger for the created UDPMux.
func UDPMuxFromPortWithLogger(logger logging.LeveledLogger) UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
@@ -213,7 +219,7 @@ func UDPMuxFromPortWithLogger(logger logging.LeveledLogger) UDPMuxFromPortOption
}
}
// UDPMuxFromPortWithLoopback set loopback interface should be included
// UDPMuxFromPortWithLoopback set loopback interface should be included.
func UDPMuxFromPortWithLoopback() UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {

View File

@@ -79,6 +79,8 @@ func TestMultiUDPMux(t *testing.T) {
}
func testMultiUDPMuxConnections(t *testing.T, udpMuxMulti *MultiUDPMuxDefault, ufrag string, network string) {
t.Helper()
addrs := udpMuxMulti.GetListenAddresses()
pktConns := make([]net.PacketConn, 0, len(addrs))
for _, addr := range addrs {

View File

@@ -20,7 +20,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestUDPMux(t *testing.T) {
func TestUDPMux(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)()
defer test.TimeOut(time.Second * 30).Stop()
@@ -127,6 +127,8 @@ func TestUDPMux(t *testing.T) {
}
func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) {
t.Helper()
pktConn, err := udpMux.GetConn(ufrag, udpMux.LocalAddr())
require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() {
@@ -145,6 +147,8 @@ func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, networ
}
func testMuxConnectionPair(t *testing.T, pktConn net.PacketConn, remoteConn *net.UDPConn, ufrag string) {
t.Helper()
// Initial messages are dropped
_, err := remoteConn.Write([]byte("dropped bytes"))
require.NoError(t, err)
@@ -222,7 +226,7 @@ func testMuxConnectionPair(t *testing.T, pktConn net.PacketConn, remoteConn *net
require.NoError(t, err)
h := sha256.Sum256(buf[36:])
copy(buf[4:36], h[:])
binary.LittleEndian.PutUint32(buf[0:4], uint32(sequence))
binary.LittleEndian.PutUint32(buf[0:4], uint32(sequence)) //nolint:gosec // G115
_, err = remoteConn.Write(buf)
require.NoError(t, err)
@@ -238,6 +242,8 @@ func testMuxConnectionPair(t *testing.T, pktConn net.PacketConn, remoteConn *net
}
func verifyPacket(t *testing.T, b []byte, nextSeq uint32) {
t.Helper()
readSeq := binary.LittleEndian.Uint32(b[0:4])
require.Equal(t, nextSeq, readSeq)
h := sha256.Sum256(b[36:])

View File

@@ -29,7 +29,8 @@ type UniversalUDPMuxDefault struct {
*UDPMuxDefault
params UniversalUDPMuxParams
// Since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
// Since we have a shared socket, for srflx candidates it makes sense
// to have a shared mapped address across all the agents
// stun.XORMappedAddress indexed by the STUN server addr
xorMappedMap map[string]*xorMapped
}
@@ -42,7 +43,7 @@ type UniversalUDPMuxParams struct {
Net transport.Net
}
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux.
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
@@ -51,31 +52,31 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
params.XORMappedAddrCacheTTL = time.Second * 25
}
m := &UniversalUDPMuxDefault{
mux := &UniversalUDPMuxDefault{
params: params,
xorMappedMap: make(map[string]*xorMapped),
}
// Wrap UDP connection, process server reflexive messages
// before they are passed to the UDPMux connection handler (connWorker)
m.params.UDPConn = &udpConn{
mux.params.UDPConn = &udpConn{
PacketConn: params.UDPConn,
mux: m,
mux: mux,
logger: params.Logger,
}
// Embed UDPMux
udpMuxParams := UDPMuxParams{
Logger: params.Logger,
UDPConn: m.params.UDPConn,
Net: m.params.Net,
UDPConn: mux.params.UDPConn,
Net: mux.params.Net,
}
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
mux.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
return m
return mux
}
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets.
type udpConn struct {
net.PacketConn
mux *UniversalUDPMuxDefault
@@ -88,7 +89,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(net.Addr, time.Duration) (*net.A
return nil, errNotImplemented
}
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL
// (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
@@ -99,24 +101,24 @@ func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr ne
func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.PacketConn.ReadFrom(p)
if err != nil {
return
return n, addr, err
}
if stun.IsMessage(p[:n]) {
if stun.IsMessage(p[:n]) { //nolint:nestif
msg := &stun.Message{
Raw: append([]byte{}, p[:n]...),
}
if err = msg.Decode(); err != nil {
c.logger.Warnf("Failed to handle decode ICE from %s: %v", addr.String(), err)
err = nil
return
return n, addr, nil
}
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
// Message about this err will be logged in the UDPMux
return
return n, addr, err
}
if c.mux.isXORMappedResponse(msg, udpAddr.String()) {
@@ -125,9 +127,11 @@ func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
c.logger.Debugf("%w: %v", errGetXorMappedAddrResponse, err)
err = nil
}
return
return n, addr, err
}
}
return n, addr, err
}
@@ -135,14 +139,16 @@ func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool {
m.mu.Lock()
defer m.mu.Unlock()
// Check first if it is a STUN server address because remote peer can also send similar messages but as a BindingSuccess
// Check first if it is a STUN server address,
// because remote peer can also send similar messages but as a BindingSuccess.
_, ok := m.xorMappedMap[stunAddr]
_, err := msg.Get(stun.AttrXORMappedAddress)
return err == nil && ok
}
// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute
// and set the mapped address for the server
// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute.
// and set the mapped address for the server.
func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error {
m.mu.Lock()
defer m.mu.Unlock()
@@ -167,7 +173,10 @@ func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr,
// Makes a STUN binding request to discover mapped address otherwise.
// Blocks until the stun.XORMappedAddress has been discovered or deadline.
// Method is safe for concurrent use.
func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) {
func (m *UniversalUDPMuxDefault) GetXORMappedAddr(
serverAddr net.Addr,
deadline time.Duration,
) (*stun.XORMappedAddress, error) {
m.mu.Lock()
mappedAddr, ok := m.xorMappedMap[serverAddr.String()]
// If we already have a mapping for this STUN server (address already received)
@@ -203,6 +212,7 @@ func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline
if mappedAddr.addr == nil {
return nil, errNoXorAddrMapping
}
return mappedAddr.addr, nil
case <-time.After(deadline):
return nil, errXORMappedAddrTimeout

View File

@@ -43,6 +43,8 @@ func TestUniversalUDPMux(t *testing.T) {
}
func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) {
t.Helper()
pktConn, err := udpMux.GetConn(ufrag, udpMux.LocalAddr())
require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() {

View File

@@ -28,7 +28,7 @@ type udpMuxedConnParams struct {
Logger logging.LeveledLogger
}
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag.
type udpMuxedConn struct {
params *udpMuxedConnParams
// Remote addresses that we have sent to on this conn
@@ -72,11 +72,12 @@ func (c *udpMuxedConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) {
pkt.reset()
c.params.AddrPool.Put(pkt)
return
return n, rAddr, err
}
if c.state == udpMuxedConnClosed {
c.mu.Unlock()
return 0, nil, io.EOF
}
@@ -101,6 +102,7 @@ func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
return 0, errFailedToCastUDPAddr
}
//nolint:gosec // TODO add port validation G115
ipAndPort, err := newIPPort(netUDPAddr.IP, netUDPAddr.Zone, uint16(netUDPAddr.Port))
if err != nil {
return 0, err
@@ -150,12 +152,14 @@ func (c *udpMuxedConn) Close() error {
c.state = udpMuxedConnClosed
close(c.closedChan)
}
return nil
}
func (c *udpMuxedConn) isClosed() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.state == udpMuxedConnClosed
}
@@ -164,6 +168,7 @@ func (c *udpMuxedConn) getAddresses() []ipPort {
defer c.mu.Unlock()
addresses := make([]ipPort, len(c.addresses))
copy(addresses, c.addresses)
return addresses
}
@@ -198,6 +203,7 @@ func (c *udpMuxedConn) containsAddress(addr ipPort) bool {
return true
}
}
return false
}
@@ -205,6 +211,7 @@ func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error {
pkt := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
if cap(pkt.buf) < len(data) {
c.params.AddrPool.Put(pkt)
return io.ErrShortBuffer
}

36
url.go
View File

@@ -6,77 +6,77 @@ package ice
import "github.com/pion/stun/v3"
type (
// URL represents a STUN (rfc7064) or TURN (rfc7065) URI
// URL represents a STUN (rfc7064) or TURN (rfc7065) URI.
//
// Deprecated: Please use pion/stun.URI
// Deprecated: Please use pion/stun.URI.
URL = stun.URI
// ProtoType indicates the transport protocol type that is used in the ice.URL
// structure.
//
// Deprecated: TPlease use pion/stun.ProtoType
// Deprecated: TPlease use pion/stun.ProtoType.
ProtoType = stun.ProtoType
// SchemeType indicates the type of server used in the ice.URL structure.
//
// Deprecated: Please use pion/stun.SchemeType
// Deprecated: Please use pion/stun.SchemeType.
SchemeType = stun.SchemeType
)
const (
// SchemeTypeSTUN indicates the URL represents a STUN server.
//
// Deprecated: Please use pion/stun.SchemeTypeSTUN
// Deprecated: Please use pion/stun.SchemeTypeSTUN.
SchemeTypeSTUN = stun.SchemeTypeSTUN
// SchemeTypeSTUNS indicates the URL represents a STUNS (secure) server.
//
// Deprecated: Please use pion/stun.SchemeTypeSTUNS
// Deprecated: Please use pion/stun.SchemeTypeSTUNS.
SchemeTypeSTUNS = stun.SchemeTypeSTUNS
// SchemeTypeTURN indicates the URL represents a TURN server.
//
// Deprecated: Please use pion/stun.SchemeTypeTURN
// Deprecated: Please use pion/stun.SchemeTypeTURN.
SchemeTypeTURN = stun.SchemeTypeTURN
// SchemeTypeTURNS indicates the URL represents a TURNS (secure) server.
//
// Deprecated: Please use pion/stun.SchemeTypeTURNS
// Deprecated: Please use pion/stun.SchemeTypeTURNS.
SchemeTypeTURNS = stun.SchemeTypeTURNS
)
const (
// ProtoTypeUDP indicates the URL uses a UDP transport.
//
// Deprecated: Please use pion/stun.ProtoTypeUDP
// Deprecated: Please use pion/stun.ProtoTypeUDP.
ProtoTypeUDP = stun.ProtoTypeUDP
// ProtoTypeTCP indicates the URL uses a TCP transport.
//
// Deprecated: Please use pion/stun.ProtoTypeTCP
// Deprecated: Please use pion/stun.ProtoTypeTCP.
ProtoTypeTCP = stun.ProtoTypeTCP
)
// Unknown represents and unknown ProtoType or SchemeType
// Unknown represents and unknown ProtoType or SchemeType.
//
// Deprecated: Please use pion/stun.SchemeTypeUnknown or pion/stun.ProtoTypeUnknown
// Deprecated: Please use pion/stun.SchemeTypeUnknown or pion/stun.ProtoTypeUnknown.
const Unknown = 0
// ParseURL parses a STUN or TURN urls following the ABNF syntax described in
// ParseURL parses a STUN or TURN urls following the ABNF syntax described in.
// https://tools.ietf.org/html/rfc7064 and https://tools.ietf.org/html/rfc7065
// respectively.
//
// Deprecated: Please use pion/stun.ParseURI
// Deprecated: Please use pion/stun.ParseURI.
var ParseURL = stun.ParseURI //nolint:gochecknoglobals
// NewSchemeType defines a procedure for creating a new SchemeType from a raw
// NewSchemeType defines a procedure for creating a new SchemeType from a raw.
// string naming the scheme type.
//
// Deprecated: Please use pion/stun.NewSchemeType
// Deprecated: Please use pion/stun.NewSchemeType.
var NewSchemeType = stun.NewSchemeType //nolint:gochecknoglobals
// NewProtoType defines a procedure for creating a new ProtoType from a raw
// NewProtoType defines a procedure for creating a new ProtoType from a raw.
// string naming the transport protocol type.
//
// Deprecated: Please use pion/stun.NewProtoType
// Deprecated: Please use pion/stun.NewProtoType.
var NewProtoType = stun.NewProtoType //nolint:gochecknoglobals

View File

@@ -11,12 +11,14 @@ type UseCandidateAttr struct{}
// AddTo adds USE-CANDIDATE attribute to message.
func (UseCandidateAttr) AddTo(m *stun.Message) error {
m.Add(stun.AttrUseCandidate, nil)
return nil
}
// IsSet returns true if USE-CANDIDATE attribute is set.
func (UseCandidateAttr) IsSet(m *stun.Message) bool {
_, err := m.Get(stun.AttrUseCandidate)
return err == nil
}

View File

@@ -13,6 +13,8 @@ import (
)
func newHostRemote(t *testing.T) *CandidateHost {
t.Helper()
remoteHostConfig := &CandidateHostConfig{
Network: "udp",
Address: "1.2.3.5",
@@ -21,10 +23,13 @@ func newHostRemote(t *testing.T) *CandidateHost {
}
hostRemote, err := NewCandidateHost(remoteHostConfig)
require.NoError(t, err)
return hostRemote
}
func newPrflxRemote(t *testing.T) *CandidatePeerReflexive {
t.Helper()
prflxConfig := &CandidatePeerReflexiveConfig{
Network: "udp",
Address: "10.10.10.2",
@@ -35,10 +40,13 @@ func newPrflxRemote(t *testing.T) *CandidatePeerReflexive {
}
prflxRemote, err := NewCandidatePeerReflexive(prflxConfig)
require.NoError(t, err)
return prflxRemote
}
func newSrflxRemote(t *testing.T) *CandidateServerReflexive {
t.Helper()
srflxConfig := &CandidateServerReflexiveConfig{
Network: "udp",
Address: "10.10.10.2",
@@ -49,10 +57,13 @@ func newSrflxRemote(t *testing.T) *CandidateServerReflexive {
}
srflxRemote, err := NewCandidateServerReflexive(srflxConfig)
require.NoError(t, err)
return srflxRemote
}
func newRelayRemote(t *testing.T) *CandidateRelay {
t.Helper()
relayConfig := &CandidateRelayConfig{
Network: "udp",
Address: "1.2.3.4",
@@ -63,10 +74,13 @@ func newRelayRemote(t *testing.T) *CandidateRelay {
}
relayRemote, err := NewCandidateRelay(relayConfig)
require.NoError(t, err)
return relayRemote
}
func newHostLocal(t *testing.T) *CandidateHost {
t.Helper()
localHostConfig := &CandidateHostConfig{
Network: "udp",
Address: "192.168.1.1",
@@ -75,5 +89,6 @@ func newHostLocal(t *testing.T) *CandidateHost {
}
hostLocal, err := NewCandidateHost(localHostConfig)
require.NoError(t, err)
return hostLocal
}