Upgrade golangci-lint, more linters

Introduces new linters, upgrade golangci-lint to version (v1.63.4)
This commit is contained in:
Joe Turki
2025-01-17 08:21:15 -06:00
parent 647b9786dd
commit cad1676659
67 changed files with 1619 additions and 1019 deletions

View File

@@ -25,17 +25,32 @@ linters-settings:
- ^os.Exit$ - ^os.Exit$
- ^panic$ - ^panic$
- ^print(ln)?$ - ^print(ln)?$
varnamelen:
max-distance: 12
min-name-length: 2
ignore-type-assert-ok: true
ignore-map-index-ok: true
ignore-chan-recv-ok: true
ignore-decls:
- i int
- n int
- w io.Writer
- r io.Reader
- b []byte
linters: linters:
enable: enable:
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
- bidichk # Checks for dangerous unicode character sequences - bidichk # Checks for dangerous unicode character sequences
- bodyclose # checks whether HTTP response body is closed successfully - bodyclose # checks whether HTTP response body is closed successfully
- containedctx # containedctx is a linter that detects struct contained context.Context field
- contextcheck # check the function whether use a non-inherited context - contextcheck # check the function whether use a non-inherited context
- cyclop # checks function and package cyclomatic complexity
- decorder # check declaration order and count of types, constants, variables and functions - decorder # check declaration order and count of types, constants, variables and functions
- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
- dupl # Tool for code clone detection - dupl # Tool for code clone detection
- durationcheck # check for two durations multiplied together - durationcheck # check for two durations multiplied together
- err113 # Golang linter to check the errors handling expressions
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
- errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted.
- errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`.
@@ -46,18 +61,17 @@ linters:
- forcetypeassert # finds forced type assertions - forcetypeassert # finds forced type assertions
- gci # Gci control golang package import order and make it always deterministic. - gci # Gci control golang package import order and make it always deterministic.
- gochecknoglobals # Checks that no globals are present in Go code - gochecknoglobals # Checks that no globals are present in Go code
- gochecknoinits # Checks that no init functions are present in Go code
- gocognit # Computes and checks the cognitive complexity of functions - gocognit # Computes and checks the cognitive complexity of functions
- goconst # Finds repeated strings that could be replaced by a constant - goconst # Finds repeated strings that could be replaced by a constant
- gocritic # The most opinionated Go source code linter - gocritic # The most opinionated Go source code linter
- gocyclo # Computes and checks the cyclomatic complexity of functions
- godot # Check if comments end in a period
- godox # Tool for detection of FIXME, TODO and other comment keywords - godox # Tool for detection of FIXME, TODO and other comment keywords
- err113 # Golang linter to check the errors handling expressions
- gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
- gofumpt # Gofumpt checks whether code was gofumpt-ed. - gofumpt # Gofumpt checks whether code was gofumpt-ed.
- goheader # Checks is file header matches to pattern - goheader # Checks is file header matches to pattern
- goimports # Goimports does everything that gofmt does. Additionally it checks unused imports - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- goprintffuncname # Checks that printf-like functions are named with `f` at the end - goprintffuncname # Checks that printf-like functions are named with `f` at the end
- gosec # Inspects source code for security problems - gosec # Inspects source code for security problems
- gosimple # Linter for Go source code that specializes in simplifying a code - gosimple # Linter for Go source code that specializes in simplifying a code
@@ -65,9 +79,15 @@ linters:
- grouper # An analyzer to analyze expression groups. - grouper # An analyzer to analyze expression groups.
- importas # Enforces consistent import aliases - importas # Enforces consistent import aliases
- ineffassign # Detects when assignments to existing variables are not used - ineffassign # Detects when assignments to existing variables are not used
- lll # Reports long lines
- maintidx # maintidx measures the maintainability index of each function.
- makezero # Finds slice declarations with non-zero initial length
- misspell # Finds commonly misspelled English words in comments - misspell # Finds commonly misspelled English words in comments
- nakedret # Finds naked returns in functions greater than a specified function length
- nestif # Reports deeply nested if statements
- nilerr # Finds the code that returns nil even if it checks that the error is not nil. - nilerr # Finds the code that returns nil even if it checks that the error is not nil.
- nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value.
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
- noctx # noctx finds sending http request without context.Context - noctx # noctx finds sending http request without context.Context
- predeclared # find code that shadows one of Go's predeclared identifiers - predeclared # find code that shadows one of Go's predeclared identifiers
- revive # golint replacement, finds style mistakes - revive # golint replacement, finds style mistakes
@@ -75,28 +95,22 @@ linters:
- stylecheck # Stylecheck is a replacement for golint - stylecheck # Stylecheck is a replacement for golint
- tagliatelle # Checks the struct tags. - tagliatelle # Checks the struct tags.
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
- unconvert # Remove unnecessary type conversions - unconvert # Remove unnecessary type conversions
- unparam # Reports unused function parameters - unparam # Reports unused function parameters
- unused # Checks Go code for unused constants, variables, functions and types - unused # Checks Go code for unused constants, variables, functions and types
- varnamelen # checks that the length of a variable's name matches its scope
- wastedassign # wastedassign finds wasted assignment statements - wastedassign # wastedassign finds wasted assignment statements
- whitespace # Tool for detection of leading and trailing whitespace - whitespace # Tool for detection of leading and trailing whitespace
disable: disable:
- depguard # Go linter that checks if package imports are in a list of acceptable packages - depguard # Go linter that checks if package imports are in a list of acceptable packages
- containedctx # containedctx is a linter that detects struct contained context.Context field
- cyclop # checks function and package cyclomatic complexity
- funlen # Tool for detection of long functions - funlen # Tool for detection of long functions
- gocyclo # Computes and checks the cyclomatic complexity of functions - gochecknoinits # Checks that no init functions are present in Go code
- godot # Check if comments end in a period - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- gomnd # An analyzer to detect magic numbers. - interfacebloat # A linter that checks length of interface.
- ireturn # Accept Interfaces, Return Concrete Types - ireturn # Accept Interfaces, Return Concrete Types
- lll # Reports long lines - mnd # An analyzer to detect magic numbers
- maintidx # maintidx measures the maintainability index of each function.
- makezero # Finds slice declarations with non-zero initial length
- nakedret # Finds naked returns in functions greater than a specified function length
- nestif # Reports deeply nested if statements
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
- nolintlint # Reports ill-formed or insufficient nolint directives - nolintlint # Reports ill-formed or insufficient nolint directives
- paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test
- prealloc # Finds slice declarations that could potentially be preallocated - prealloc # Finds slice declarations that could potentially be preallocated
@@ -104,8 +118,7 @@ linters:
- rowserrcheck # checks whether Err of rows is checked successfully - rowserrcheck # checks whether Err of rows is checked successfully
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
- testpackage # linter that makes you use a separate _test package - testpackage # linter that makes you use a separate _test package
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- varnamelen # checks that the length of a variable's name matches its scope
- wrapcheck # Checks that errors returned from external packages are wrapped - wrapcheck # Checks that errors returned from external packages are wrapped
- wsl # Whitespace Linter - Forces you to use empty lines! - wsl # Whitespace Linter - Forces you to use empty lines!

View File

@@ -21,7 +21,12 @@ type activeTCPConn struct {
closed int32 closed int32
} }
func newActiveTCPConn(ctx context.Context, localAddress string, remoteAddress netip.AddrPort, log logging.LeveledLogger) (a *activeTCPConn) { func newActiveTCPConn(
ctx context.Context,
localAddress string,
remoteAddress netip.AddrPort,
log logging.LeveledLogger,
) (a *activeTCPConn) {
a = &activeTCPConn{ a = &activeTCPConn{
readBuffer: packetio.NewBuffer(), readBuffer: packetio.NewBuffer(),
writeBuffer: packetio.NewBuffer(), writeBuffer: packetio.NewBuffer(),
@@ -31,7 +36,8 @@ func newActiveTCPConn(ctx context.Context, localAddress string, remoteAddress ne
if err != nil { if err != nil {
atomic.StoreInt32(&a.closed, 1) atomic.StoreInt32(&a.closed, 1)
log.Infof("Failed to dial TCP address %s: %v", remoteAddress, err) log.Infof("Failed to dial TCP address %s: %v", remoteAddress, err)
return
return a
} }
a.localAddr.Store(laddr) a.localAddr.Store(laddr)
@@ -46,6 +52,7 @@ func newActiveTCPConn(ctx context.Context, localAddress string, remoteAddress ne
conn, err := dialer.DialContext(ctx, "tcp", remoteAddress.String()) conn, err := dialer.DialContext(ctx, "tcp", remoteAddress.String())
if err != nil { if err != nil {
log.Infof("Failed to dial TCP address %s: %v", remoteAddress, err) log.Infof("Failed to dial TCP address %s: %v", remoteAddress, err)
return return
} }
a.remoteAddr.Store(conn.RemoteAddr()) a.remoteAddr.Store(conn.RemoteAddr())
@@ -57,11 +64,13 @@ func newActiveTCPConn(ctx context.Context, localAddress string, remoteAddress ne
n, err := readStreamingPacket(conn, buff) n, err := readStreamingPacket(conn, buff)
if err != nil { if err != nil {
log.Infof("Failed to read streaming packet: %s", err) log.Infof("Failed to read streaming packet: %s", err)
break break
} }
if _, err := a.readBuffer.Write(buff[:n]); err != nil { if _, err := a.readBuffer.Write(buff[:n]); err != nil {
log.Infof("Failed to write to buffer: %s", err) log.Infof("Failed to write to buffer: %s", err)
break break
} }
} }
@@ -73,11 +82,13 @@ func newActiveTCPConn(ctx context.Context, localAddress string, remoteAddress ne
n, err := a.writeBuffer.Read(buff) n, err := a.writeBuffer.Read(buff)
if err != nil { if err != nil {
log.Infof("Failed to read from buffer: %s", err) log.Infof("Failed to read from buffer: %s", err)
break break
} }
if _, err = writeStreamingPacket(conn, buff[:n]); err != nil { if _, err = writeStreamingPacket(conn, buff[:n]); err != nil {
log.Infof("Failed to write streaming packet: %s", err) log.Infof("Failed to write streaming packet: %s", err)
break break
} }
} }
@@ -98,6 +109,7 @@ func (a *activeTCPConn) ReadFrom(buff []byte) (n int, srcAddr net.Addr, err erro
n, err = a.readBuffer.Read(buff) n, err = a.readBuffer.Read(buff)
// RemoteAddr is assuredly set *after* we can read from the buffer // RemoteAddr is assuredly set *after* we can read from the buffer
srcAddr = a.RemoteAddr() srcAddr = a.RemoteAddr()
return return
} }
@@ -113,6 +125,7 @@ func (a *activeTCPConn) Close() error {
atomic.StoreInt32(&a.closed, 1) atomic.StoreInt32(&a.closed, 1)
_ = a.readBuffer.Close() _ = a.readBuffer.Close()
_ = a.writeBuffer.Close() _ = a.writeBuffer.Close()
return nil return nil
} }

View File

@@ -21,19 +21,25 @@ import (
) )
func getLocalIPAddress(t *testing.T, networkType NetworkType) netip.Addr { func getLocalIPAddress(t *testing.T, networkType NetworkType) netip.Addr {
t.Helper()
net, err := stdnet.NewNet() net, err := stdnet.NewNet()
require.NoError(t, err) require.NoError(t, err)
_, localAddrs, err := localInterfaces(net, problematicNetworkInterfaces, nil, []NetworkType{networkType}, false) _, localAddrs, err := localInterfaces(net, problematicNetworkInterfaces, nil, []NetworkType{networkType}, false)
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, localAddrs) require.NotEmpty(t, localAddrs)
return localAddrs[0] return localAddrs[0]
} }
func ipv6Available(t *testing.T) bool { func ipv6Available(t *testing.T) bool {
t.Helper()
net, err := stdnet.NewNet() net, err := stdnet.NewNet()
require.NoError(t, err) require.NoError(t, err)
_, localAddrs, err := localInterfaces(net, problematicNetworkInterfaces, nil, []NetworkType{NetworkTypeTCP6}, false) _, localAddrs, err := localInterfaces(net, problematicNetworkInterfaces, nil, []NetworkType{NetworkTypeTCP6}, false)
require.NoError(t, err) require.NoError(t, err)
return len(localAddrs) > 0 return len(localAddrs) > 0
} }
@@ -89,14 +95,14 @@ func TestActiveTCP(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
r := require.New(t) req := require.New(t)
listener, err := net.ListenTCP("tcp", &net.TCPAddr{ listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: testCase.listenIPAddress.AsSlice(), IP: testCase.listenIPAddress.AsSlice(),
Port: listenPort, Port: listenPort,
Zone: testCase.listenIPAddress.Zone(), Zone: testCase.listenIPAddress.Zone(),
}) })
r.NoError(err) req.NoError(err)
defer func() { defer func() {
_ = listener.Close() _ = listener.Close()
}() }()
@@ -113,7 +119,7 @@ func TestActiveTCP(t *testing.T) {
_ = tcpMux.Close() _ = tcpMux.Close()
}() }()
r.NotNil(tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil") req.NotNil(tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
hostAcceptanceMinWait := 100 * time.Millisecond hostAcceptanceMinWait := 100 * time.Millisecond
cfg := &AgentConfig{ cfg := &AgentConfig{
@@ -128,8 +134,8 @@ func TestActiveTCP(t *testing.T) {
cfg.MulticastDNSMode = MulticastDNSModeQueryAndGather cfg.MulticastDNSMode = MulticastDNSModeQueryAndGather
} }
passiveAgent, err := NewAgent(cfg) passiveAgent, err := NewAgent(cfg)
r.NoError(err) req.NoError(err)
r.NotNil(passiveAgent) req.NotNil(passiveAgent)
activeAgent, err := NewAgent(&AgentConfig{ activeAgent, err := NewAgent(&AgentConfig{
CandidateTypes: []CandidateType{CandidateTypeHost}, CandidateTypes: []CandidateType{CandidateTypeHost},
@@ -138,44 +144,44 @@ func TestActiveTCP(t *testing.T) {
HostAcceptanceMinWait: &hostAcceptanceMinWait, HostAcceptanceMinWait: &hostAcceptanceMinWait,
InterfaceFilter: problematicNetworkInterfaces, InterfaceFilter: problematicNetworkInterfaces,
}) })
r.NoError(err) req.NoError(err)
r.NotNil(activeAgent) req.NotNil(activeAgent)
passiveAgentConn, activeAgenConn := connect(passiveAgent, activeAgent) passiveAgentConn, activeAgenConn := connect(passiveAgent, activeAgent)
r.NotNil(passiveAgentConn) req.NotNil(passiveAgentConn)
r.NotNil(activeAgenConn) req.NotNil(activeAgenConn)
defer func() { defer func() {
r.NoError(activeAgenConn.Close()) req.NoError(activeAgenConn.Close())
r.NoError(passiveAgentConn.Close()) req.NoError(passiveAgentConn.Close())
}() }()
pair := passiveAgent.getSelectedPair() pair := passiveAgent.getSelectedPair()
r.NotNil(pair) req.NotNil(pair)
r.Equal(testCase.selectedPairNetworkType, pair.Local.NetworkType().NetworkShort()) req.Equal(testCase.selectedPairNetworkType, pair.Local.NetworkType().NetworkShort())
foo := []byte("foo") foo := []byte("foo")
_, err = passiveAgentConn.Write(foo) _, err = passiveAgentConn.Write(foo)
r.NoError(err) req.NoError(err)
buffer := make([]byte, 1024) buffer := make([]byte, 1024)
n, err := activeAgenConn.Read(buffer) n, err := activeAgenConn.Read(buffer)
r.NoError(err) req.NoError(err)
r.Equal(foo, buffer[:n]) req.Equal(foo, buffer[:n])
bar := []byte("bar") bar := []byte("bar")
_, err = activeAgenConn.Write(bar) _, err = activeAgenConn.Write(bar)
r.NoError(err) req.NoError(err)
n, err = passiveAgentConn.Read(buffer) n, err = passiveAgentConn.Read(buffer)
r.NoError(err) req.NoError(err)
r.Equal(bar, buffer[:n]) req.Equal(bar, buffer[:n])
}) })
} }
} }
// Assert that Active TCP connectivity isn't established inside // Assert that Active TCP connectivity isn't established inside.
// the main thread of the Agent // the main thread of the Agent.
func TestActiveTCP_NonBlocking(t *testing.T) { func TestActiveTCP_NonBlocking(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
@@ -219,7 +225,7 @@ func TestActiveTCP_NonBlocking(t *testing.T) {
<-isConnected <-isConnected
} }
// Assert that we ignore remote TCP candidates when running a UDP Only Agent // Assert that we ignore remote TCP candidates when running a UDP Only Agent.
func TestActiveTCP_Respect_NetworkTypes(t *testing.T) { func TestActiveTCP_Respect_NetworkTypes(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
defer test.TimeOut(time.Second * 5).Stop() defer test.TimeOut(time.Second * 5).Stop()
@@ -271,7 +277,9 @@ func TestActiveTCP_Respect_NetworkTypes(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
invalidCandidate, err := UnmarshalCandidate(fmt.Sprintf("1052353102 1 tcp 1675624447 127.0.0.1 %s typ host tcptype passive", port)) invalidCandidate, err := UnmarshalCandidate(
fmt.Sprintf("1052353102 1 tcp 1675624447 127.0.0.1 %s typ host tcptype passive", port),
)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, aAgent.AddRemoteCandidate(invalidCandidate)) require.NoError(t, aAgent.AddRemoteCandidate(invalidCandidate))
require.NoError(t, bAgent.AddRemoteCandidate(invalidCandidate)) require.NoError(t, bAgent.AddRemoteCandidate(invalidCandidate))

18
addr.go
View File

@@ -16,6 +16,7 @@ func addrWithOptionalZone(addr netip.Addr, zone string) netip.Addr {
if addr.Is6() && (addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast()) { if addr.Is6() && (addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast()) {
return addr.WithZone(zone) return addr.WithZone(zone)
} }
return addr return addr
} }
@@ -30,22 +31,25 @@ func parseAddrFromIface(in net.Addr, ifcName string) (netip.Addr, int, NetworkTy
// net.IPNet does not have a Zone but we provide it from the interface // net.IPNet does not have a Zone but we provide it from the interface
addr = addrWithOptionalZone(addr, ifcName) addr = addrWithOptionalZone(addr, ifcName)
} }
return addr, port, nt, nil return addr, port, nt, nil
} }
func parseAddr(in net.Addr) (netip.Addr, int, NetworkType, error) { func parseAddr(in net.Addr) (netip.Addr, int, NetworkType, error) { //nolint:cyclop
switch addr := in.(type) { switch addr := in.(type) {
case *net.IPNet: case *net.IPNet:
ipAddr, err := ipAddrToNetIP(addr.IP, "") ipAddr, err := ipAddrToNetIP(addr.IP, "")
if err != nil { if err != nil {
return netip.Addr{}, 0, 0, err return netip.Addr{}, 0, 0, err
} }
return ipAddr, 0, 0, nil return ipAddr, 0, 0, nil
case *net.IPAddr: case *net.IPAddr:
ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone) ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone)
if err != nil { if err != nil {
return netip.Addr{}, 0, 0, err return netip.Addr{}, 0, 0, err
} }
return ipAddr, 0, 0, nil return ipAddr, 0, 0, nil
case *net.UDPAddr: case *net.UDPAddr:
ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone) ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone)
@@ -58,6 +62,7 @@ func parseAddr(in net.Addr) (netip.Addr, int, NetworkType, error) {
} else { } else {
nt = NetworkTypeUDP6 nt = NetworkTypeUDP6
} }
return ipAddr, addr.Port, nt, nil return ipAddr, addr.Port, nt, nil
case *net.TCPAddr: case *net.TCPAddr:
ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone) ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone)
@@ -70,6 +75,7 @@ func parseAddr(in net.Addr) (netip.Addr, int, NetworkType, error) {
} else { } else {
nt = NetworkTypeTCP6 nt = NetworkTypeTCP6
} }
return ipAddr, addr.Port, nt, nil return ipAddr, addr.Port, nt, nil
default: default:
return netip.Addr{}, 0, 0, addrParseError{in} return netip.Addr{}, 0, 0, addrParseError{in}
@@ -100,6 +106,7 @@ func ipAddrToNetIP(ip []byte, zone string) (netip.Addr, error) {
// we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable. // we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable.
netIPAddr = netIPAddr.Unmap() netIPAddr = netIPAddr.Unmap()
netIPAddr = addrWithOptionalZone(netIPAddr, zone) netIPAddr = addrWithOptionalZone(netIPAddr, zone)
return netIPAddr, nil return netIPAddr, nil
} }
@@ -134,12 +141,13 @@ func toAddrPort(addr net.Addr) AddrPort {
switch addr := addr.(type) { switch addr := addr.(type) {
case *net.UDPAddr: case *net.UDPAddr:
copy(ap[:16], addr.IP.To16()) copy(ap[:16], addr.IP.To16())
ap[16] = uint8(addr.Port >> 8) ap[16] = uint8(addr.Port >> 8) //nolint:gosec // G115 false positive
ap[17] = uint8(addr.Port) ap[17] = uint8(addr.Port) //nolint:gosec // G115 false positive
case *net.TCPAddr: case *net.TCPAddr:
copy(ap[:16], addr.IP.To16()) copy(ap[:16], addr.IP.To16())
ap[16] = uint8(addr.Port >> 8) ap[16] = uint8(addr.Port >> 8) //nolint:gosec // G115 false positive
ap[17] = uint8(addr.Port) ap[17] = uint8(addr.Port) //nolint:gosec // G115 false positive
} }
return ap return ap
} }

330
agent.go
View File

@@ -35,7 +35,7 @@ type bindingRequest struct {
isUseCandidate bool isUseCandidate bool
} }
// Agent represents the ICE agent // Agent represents the ICE agent.
type Agent struct { type Agent struct {
loop *taskloop.Loop loop *taskloop.Loop
@@ -149,8 +149,8 @@ type Agent struct {
enableUseCandidateCheckPriority bool enableUseCandidateCheckPriority bool
} }
// NewAgent creates a new Agent // NewAgent creates a new Agent.
func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit,cyclop
var err error var err error
if config.PortMax < config.PortMin { if config.PortMax < config.PortMin {
return nil, ErrPort return nil, ErrPort
@@ -180,7 +180,7 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
startedCtx, startedFn := context.WithCancel(context.Background()) startedCtx, startedFn := context.WithCancel(context.Background())
a := &Agent{ agent := &Agent{
tieBreaker: globalMathRandomGenerator.Uint64(), tieBreaker: globalMathRandomGenerator.Uint64(),
lite: config.Lite, lite: config.Lite,
gatheringState: GatheringStateNew, gatheringState: GatheringStateNew,
@@ -224,34 +224,46 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
enableUseCandidateCheckPriority: config.EnableUseCandidateCheckPriority, enableUseCandidateCheckPriority: config.EnableUseCandidateCheckPriority,
} }
a.connectionStateNotifier = &handlerNotifier{connectionStateFunc: a.onConnectionStateChange, done: make(chan struct{})} agent.connectionStateNotifier = &handlerNotifier{
a.candidateNotifier = &handlerNotifier{candidateFunc: a.onCandidate, done: make(chan struct{})} connectionStateFunc: agent.onConnectionStateChange,
a.selectedCandidatePairNotifier = &handlerNotifier{candidatePairFunc: a.onSelectedCandidatePairChange, done: make(chan struct{})} done: make(chan struct{}),
}
agent.candidateNotifier = &handlerNotifier{candidateFunc: agent.onCandidate, done: make(chan struct{})}
agent.selectedCandidatePairNotifier = &handlerNotifier{
candidatePairFunc: agent.onSelectedCandidatePairChange,
done: make(chan struct{}),
}
if a.net == nil { if agent.net == nil {
a.net, err = stdnet.NewNet() agent.net, err = stdnet.NewNet()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create network: %w", err) return nil, fmt.Errorf("failed to create network: %w", err)
} }
} else if _, isVirtual := a.net.(*vnet.Net); isVirtual { } else if _, isVirtual := agent.net.(*vnet.Net); isVirtual {
a.log.Warn("Virtual network is enabled") agent.log.Warn("Virtual network is enabled")
if a.mDNSMode != MulticastDNSModeDisabled { if agent.mDNSMode != MulticastDNSModeDisabled {
a.log.Warn("Virtual network does not support mDNS yet") agent.log.Warn("Virtual network does not support mDNS yet")
} }
} }
localIfcs, _, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, a.networkTypes, a.includeLoopback) localIfcs, _, err := localInterfaces(
agent.net,
agent.interfaceFilter,
agent.ipFilter,
agent.networkTypes,
agent.includeLoopback,
)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting local interfaces: %w", err) return nil, fmt.Errorf("error getting local interfaces: %w", err)
} }
// Opportunistic mDNS: If we can't open the connection, that's ok: we // Opportunistic mDNS: If we can't open the connection, that's ok: we
// can continue without it. // can continue without it.
if a.mDNSConn, a.mDNSMode, err = createMulticastDNS( if agent.mDNSConn, agent.mDNSMode, err = createMulticastDNS(
a.net, agent.net,
a.networkTypes, agent.networkTypes,
localIfcs, localIfcs,
a.includeLoopback, agent.includeLoopback,
mDNSMode, mDNSMode,
mDNSName, mDNSName,
log, log,
@@ -259,54 +271,60 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
log.Warnf("Failed to initialize mDNS %s: %v", mDNSName, err) log.Warnf("Failed to initialize mDNS %s: %v", mDNSName, err)
} }
config.initWithDefaults(a) config.initWithDefaults(agent)
// Make sure the buffer doesn't grow indefinitely. // Make sure the buffer doesn't grow indefinitely.
// NOTE: We actually won't get anywhere close to this limit. // NOTE: We actually won't get anywhere close to this limit.
// SRTP will constantly read from the endpoint and drop packets if it's full. // SRTP will constantly read from the endpoint and drop packets if it's full.
a.buf.SetLimitSize(maxBufferSize) agent.buf.SetLimitSize(maxBufferSize)
if agent.lite && (len(agent.candidateTypes) != 1 || agent.candidateTypes[0] != CandidateTypeHost) {
agent.closeMulticastConn()
if a.lite && (len(a.candidateTypes) != 1 || a.candidateTypes[0] != CandidateTypeHost) {
a.closeMulticastConn()
return nil, ErrLiteUsingNonHostCandidates return nil, ErrLiteUsingNonHostCandidates
} }
if len(config.Urls) > 0 && !containsCandidateType(CandidateTypeServerReflexive, a.candidateTypes) && !containsCandidateType(CandidateTypeRelay, a.candidateTypes) { if len(config.Urls) > 0 &&
a.closeMulticastConn() !containsCandidateType(CandidateTypeServerReflexive, agent.candidateTypes) &&
!containsCandidateType(CandidateTypeRelay, agent.candidateTypes) {
agent.closeMulticastConn()
return nil, ErrUselessUrlsProvided return nil, ErrUselessUrlsProvided
} }
if err = config.initExtIPMapping(a); err != nil { if err = config.initExtIPMapping(agent); err != nil {
a.closeMulticastConn() agent.closeMulticastConn()
return nil, err return nil, err
} }
a.loop = taskloop.New(func() { agent.loop = taskloop.New(func() {
a.removeUfragFromMux() agent.removeUfragFromMux()
a.deleteAllCandidates() agent.deleteAllCandidates()
a.startedFn() agent.startedFn()
if err := a.buf.Close(); err != nil { if err := agent.buf.Close(); err != nil {
a.log.Warnf("Failed to close buffer: %v", err) agent.log.Warnf("Failed to close buffer: %v", err)
} }
a.closeMulticastConn() agent.closeMulticastConn()
a.updateConnectionState(ConnectionStateClosed) agent.updateConnectionState(ConnectionStateClosed)
a.gatherCandidateCancel() agent.gatherCandidateCancel()
if a.gatherCandidateDone != nil { if agent.gatherCandidateDone != nil {
<-a.gatherCandidateDone <-agent.gatherCandidateDone
} }
}) })
// Restart is also used to initialize the agent for the first time // Restart is also used to initialize the agent for the first time
if err := a.Restart(config.LocalUfrag, config.LocalPwd); err != nil { if err := agent.Restart(config.LocalUfrag, config.LocalPwd); err != nil {
a.closeMulticastConn() agent.closeMulticastConn()
_ = a.Close() _ = agent.Close()
return nil, err return nil, err
} }
return a, nil return agent, nil
} }
func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remotePwd string) error { func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remotePwd string) error {
@@ -348,7 +366,7 @@ func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remoteP
}) })
} }
func (a *Agent) connectivityChecks() { func (a *Agent) connectivityChecks() { //nolint:cyclop
lastConnectionState := ConnectionState(0) lastConnectionState := ConnectionState(0)
checkingDuration := time.Time{} checkingDuration := time.Time{}
@@ -372,6 +390,7 @@ func (a *Agent) connectivityChecks() {
// We have been in checking longer then Disconnect+Failed timeout, set the connection to Failed // We have been in checking longer then Disconnect+Failed timeout, set the connection to Failed
if time.Since(checkingDuration) > a.disconnectedTimeout+a.failedTimeout { if time.Since(checkingDuration) > a.disconnectedTimeout+a.failedTimeout {
a.updateConnectionState(ConnectionStateFailed) a.updateConnectionState(ConnectionStateFailed)
return return
} }
default: default:
@@ -383,8 +402,8 @@ func (a *Agent) connectivityChecks() {
} }
} }
t := time.NewTimer(math.MaxInt64) timer := time.NewTimer(math.MaxInt64)
t.Stop() timer.Stop()
for { for {
interval := defaultKeepaliveInterval interval := defaultKeepaliveInterval
@@ -406,18 +425,19 @@ func (a *Agent) connectivityChecks() {
updateInterval(a.disconnectedTimeout) updateInterval(a.disconnectedTimeout)
updateInterval(a.failedTimeout) updateInterval(a.failedTimeout)
t.Reset(interval) timer.Reset(interval)
select { select {
case <-a.forceCandidateContact: case <-a.forceCandidateContact:
if !t.Stop() { if !timer.Stop() {
<-t.C <-timer.C
} }
contact() contact()
case <-t.C: case <-timer.C:
contact() contact()
case <-a.loop.Done(): case <-a.loop.Done():
t.Stop() timer.Stop()
return return
} }
} }
@@ -440,22 +460,23 @@ func (a *Agent) updateConnectionState(newState ConnectionState) {
} }
} }
func (a *Agent) setSelectedPair(p *CandidatePair) { func (a *Agent) setSelectedPair(pair *CandidatePair) {
if p == nil { if pair == nil {
var nilPair *CandidatePair var nilPair *CandidatePair
a.selectedPair.Store(nilPair) a.selectedPair.Store(nilPair)
a.log.Tracef("Unset selected candidate pair") a.log.Tracef("Unset selected candidate pair")
return return
} }
p.nominated = true pair.nominated = true
a.selectedPair.Store(p) a.selectedPair.Store(pair)
a.log.Tracef("Set selected candidate pair: %s", p) a.log.Tracef("Set selected candidate pair: %s", pair)
a.updateConnectionState(ConnectionStateConnected) a.updateConnectionState(ConnectionStateConnected)
// Notify when the selected pair changes // Notify when the selected pair changes
a.selectedCandidatePairNotifier.EnqueueSelectedCandidatePair(p) a.selectedCandidatePairNotifier.EnqueueSelectedCandidatePair(pair)
// Signal connected // Signal connected
a.onConnectedOnce.Do(func() { close(a.onConnected) }) a.onConnectedOnce.Do(func() { close(a.onConnected) })
@@ -498,6 +519,7 @@ func (a *Agent) getBestAvailableCandidatePair() *CandidatePair {
best = p best = p
} }
} }
return best return best
} }
@@ -514,12 +536,14 @@ func (a *Agent) getBestValidCandidatePair() *CandidatePair {
best = p best = p
} }
} }
return best return best
} }
func (a *Agent) addPair(local, remote Candidate) *CandidatePair { func (a *Agent) addPair(local, remote Candidate) *CandidatePair {
p := newCandidatePair(local, remote, a.isControlling) p := newCandidatePair(local, remote, a.isControlling)
a.checklist = append(a.checklist, p) a.checklist = append(a.checklist, p)
return p return p
} }
@@ -529,6 +553,7 @@ func (a *Agent) findPair(local, remote Candidate) *CandidatePair {
return p return p
} }
} }
return nil return nil
} }
@@ -578,68 +603,76 @@ func (a *Agent) checkKeepalive() {
} }
} }
// AddRemoteCandidate adds a new remote candidate // AddRemoteCandidate adds a new remote candidate.
func (a *Agent) AddRemoteCandidate(c Candidate) error { func (a *Agent) AddRemoteCandidate(cand Candidate) error {
if c == nil { if cand == nil {
return nil return nil
} }
// TCP Candidates with TCP type active will probe server passive ones, so // TCP Candidates with TCP type active will probe server passive ones, so
// no need to do anything with them. // no need to do anything with them.
if c.TCPType() == TCPTypeActive { if cand.TCPType() == TCPTypeActive {
a.log.Infof("Ignoring remote candidate with tcpType active: %s", c) a.log.Infof("Ignoring remote candidate with tcpType active: %s", cand)
return nil return nil
} }
// If we have a mDNS Candidate lets fully resolve it before adding it locally // If we have a mDNS Candidate lets fully resolve it before adding it locally
if c.Type() == CandidateTypeHost && strings.HasSuffix(c.Address(), ".local") { if cand.Type() == CandidateTypeHost && strings.HasSuffix(cand.Address(), ".local") {
if a.mDNSMode == MulticastDNSModeDisabled { if a.mDNSMode == MulticastDNSModeDisabled {
a.log.Warnf("Remote mDNS candidate added, but mDNS is disabled: (%s)", c.Address()) a.log.Warnf("Remote mDNS candidate added, but mDNS is disabled: (%s)", cand.Address())
return nil return nil
} }
hostCandidate, ok := c.(*CandidateHost) hostCandidate, ok := cand.(*CandidateHost)
if !ok { if !ok {
return ErrAddressParseFailed return ErrAddressParseFailed
} }
go a.resolveAndAddMulticastCandidate(hostCandidate) go a.resolveAndAddMulticastCandidate(hostCandidate)
return nil return nil
} }
go func() { go func() {
if err := a.loop.Run(a.loop, func(_ context.Context) { if err := a.loop.Run(a.loop, func(_ context.Context) {
// nolint: contextcheck // nolint: contextcheck
a.addRemoteCandidate(c) a.addRemoteCandidate(cand)
}); err != nil { }); err != nil {
a.log.Warnf("Failed to add remote candidate %s: %v", c.Address(), err) a.log.Warnf("Failed to add remote candidate %s: %v", cand.Address(), err)
return return
} }
}() }()
return nil return nil
} }
func (a *Agent) resolveAndAddMulticastCandidate(c *CandidateHost) { func (a *Agent) resolveAndAddMulticastCandidate(cand *CandidateHost) {
if a.mDNSConn == nil { if a.mDNSConn == nil {
return return
} }
_, src, err := a.mDNSConn.QueryAddr(c.context(), c.Address()) _, src, err := a.mDNSConn.QueryAddr(cand.context(), cand.Address())
if err != nil { if err != nil {
a.log.Warnf("Failed to discover mDNS candidate %s: %v", c.Address(), err) a.log.Warnf("Failed to discover mDNS candidate %s: %v", cand.Address(), err)
return return
} }
if err = c.setIPAddr(src); err != nil { if err = cand.setIPAddr(src); err != nil {
a.log.Warnf("Failed to discover mDNS candidate %s: %v", c.Address(), err) a.log.Warnf("Failed to discover mDNS candidate %s: %v", cand.Address(), err)
return return
} }
if err = a.loop.Run(a.loop, func(_ context.Context) { if err = a.loop.Run(a.loop, func(_ context.Context) {
// nolint: contextcheck // nolint: contextcheck
a.addRemoteCandidate(c) a.addRemoteCandidate(cand)
}); err != nil { }); err != nil {
a.log.Warnf("Failed to add mDNS candidate %s: %v", c.Address(), err) a.log.Warnf("Failed to add mDNS candidate %s: %v", cand.Address(), err)
return return
} }
} }
@@ -652,9 +685,16 @@ func (a *Agent) requestConnectivityCheck() {
} }
func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) { func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) {
_, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{remoteCandidate.NetworkType()}, a.includeLoopback) _, localIPs, err := localInterfaces(
a.net,
a.interfaceFilter,
a.ipFilter,
[]NetworkType{remoteCandidate.NetworkType()},
a.includeLoopback,
)
if err != nil { if err != nil {
a.log.Warnf("Failed to iterate local interfaces, host candidates will not be gathered %s", err) a.log.Warnf("Failed to iterate local interfaces, host candidates will not be gathered %s", err)
return return
} }
@@ -662,19 +702,21 @@ func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) {
ip, _, _, err := parseAddr(remoteCandidate.addr()) ip, _, _, err := parseAddr(remoteCandidate.addr())
if err != nil { if err != nil {
a.log.Warnf("Failed to parse address: %s; error: %s", remoteCandidate.addr(), err) a.log.Warnf("Failed to parse address: %s; error: %s", remoteCandidate.addr(), err)
continue continue
} }
conn := newActiveTCPConn( conn := newActiveTCPConn(
a.loop, a.loop,
net.JoinHostPort(localIPs[i].String(), "0"), net.JoinHostPort(localIPs[i].String(), "0"),
netip.AddrPortFrom(ip, uint16(remoteCandidate.Port())), netip.AddrPortFrom(ip, uint16(remoteCandidate.Port())), //nolint:gosec // G115, no overflow, a port
a.log, a.log,
) )
tcpAddr, ok := conn.LocalAddr().(*net.TCPAddr) tcpAddr, ok := conn.LocalAddr().(*net.TCPAddr)
if !ok { if !ok {
closeConnAndLog(conn, a.log, "Failed to create Active ICE-TCP Candidate: %v", errInvalidAddress) closeConnAndLog(conn, a.log, "Failed to create Active ICE-TCP Candidate: %v", errInvalidAddress)
continue continue
} }
@@ -687,48 +729,52 @@ func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) {
}) })
if err != nil { if err != nil {
closeConnAndLog(conn, a.log, "Failed to create Active ICE-TCP Candidate: %v", err) closeConnAndLog(conn, a.log, "Failed to create Active ICE-TCP Candidate: %v", err)
continue continue
} }
localCandidate.start(a, conn, a.startedCh) localCandidate.start(a, conn, a.startedCh)
a.localCandidates[localCandidate.NetworkType()] = append(a.localCandidates[localCandidate.NetworkType()], localCandidate) a.localCandidates[localCandidate.NetworkType()] = append(
a.localCandidates[localCandidate.NetworkType()],
localCandidate,
)
a.candidateNotifier.EnqueueCandidate(localCandidate) a.candidateNotifier.EnqueueCandidate(localCandidate)
a.addPair(localCandidate, remoteCandidate) a.addPair(localCandidate, remoteCandidate)
} }
} }
// addRemoteCandidate assumes you are holding the lock (must be execute using a.run) // addRemoteCandidate assumes you are holding the lock (must be execute using a.run).
func (a *Agent) addRemoteCandidate(c Candidate) { func (a *Agent) addRemoteCandidate(cand Candidate) { //nolint:cyclop
set := a.remoteCandidates[c.NetworkType()] set := a.remoteCandidates[cand.NetworkType()]
for _, candidate := range set { for _, candidate := range set {
if candidate.Equal(c) { if candidate.Equal(cand) {
return return
} }
} }
acceptRemotePassiveTCPCandidate := false acceptRemotePassiveTCPCandidate := false
// Assert that TCP4 or TCP6 is a enabled NetworkType locally // Assert that TCP4 or TCP6 is a enabled NetworkType locally
if !a.disableActiveTCP && c.TCPType() == TCPTypePassive { if !a.disableActiveTCP && cand.TCPType() == TCPTypePassive {
for _, networkType := range a.networkTypes { for _, networkType := range a.networkTypes {
if c.NetworkType() == networkType { if cand.NetworkType() == networkType {
acceptRemotePassiveTCPCandidate = true acceptRemotePassiveTCPCandidate = true
} }
} }
} }
if acceptRemotePassiveTCPCandidate { if acceptRemotePassiveTCPCandidate {
a.addRemotePassiveTCPCandidate(c) a.addRemotePassiveTCPCandidate(cand)
} }
set = append(set, c) set = append(set, cand)
a.remoteCandidates[c.NetworkType()] = set a.remoteCandidates[cand.NetworkType()] = set
if c.TCPType() != TCPTypePassive { if cand.TCPType() != TCPTypePassive {
if localCandidates, ok := a.localCandidates[c.NetworkType()]; ok { if localCandidates, ok := a.localCandidates[cand.NetworkType()]; ok {
for _, localCandidate := range localCandidates { for _, localCandidate := range localCandidates {
a.addPair(localCandidate, c) a.addPair(localCandidate, cand)
} }
} }
} }
@@ -736,42 +782,43 @@ func (a *Agent) addRemoteCandidate(c Candidate) {
a.requestConnectivityCheck() a.requestConnectivityCheck()
} }
func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net.PacketConn) error { func (a *Agent) addCandidate(ctx context.Context, cand Candidate, candidateConn net.PacketConn) error {
return a.loop.Run(ctx, func(context.Context) { return a.loop.Run(ctx, func(context.Context) {
set := a.localCandidates[c.NetworkType()] set := a.localCandidates[cand.NetworkType()]
for _, candidate := range set { for _, candidate := range set {
if candidate.Equal(c) { if candidate.Equal(cand) {
a.log.Debugf("Ignore duplicate candidate: %s", c) a.log.Debugf("Ignore duplicate candidate: %s", cand)
if err := c.close(); err != nil { if err := cand.close(); err != nil {
a.log.Warnf("Failed to close duplicate candidate: %v", err) a.log.Warnf("Failed to close duplicate candidate: %v", err)
} }
if err := candidateConn.Close(); err != nil { if err := candidateConn.Close(); err != nil {
a.log.Warnf("Failed to close duplicate candidate connection: %v", err) a.log.Warnf("Failed to close duplicate candidate connection: %v", err)
} }
return return
} }
} }
c.start(a, candidateConn, a.startedCh) cand.start(a, candidateConn, a.startedCh)
set = append(set, c) set = append(set, cand)
a.localCandidates[c.NetworkType()] = set a.localCandidates[cand.NetworkType()] = set
if remoteCandidates, ok := a.remoteCandidates[c.NetworkType()]; ok { if remoteCandidates, ok := a.remoteCandidates[cand.NetworkType()]; ok {
for _, remoteCandidate := range remoteCandidates { for _, remoteCandidate := range remoteCandidates {
a.addPair(c, remoteCandidate) a.addPair(cand, remoteCandidate)
} }
} }
a.requestConnectivityCheck() a.requestConnectivityCheck()
if !c.filterForLocationTracking() { if !cand.filterForLocationTracking() {
a.candidateNotifier.EnqueueCandidate(c) a.candidateNotifier.EnqueueCandidate(cand)
} }
}) })
} }
// GetRemoteCandidates returns the remote candidates // GetRemoteCandidates returns the remote candidates.
func (a *Agent) GetRemoteCandidates() ([]Candidate, error) { func (a *Agent) GetRemoteCandidates() ([]Candidate, error) {
var res []Candidate var res []Candidate
@@ -789,7 +836,7 @@ func (a *Agent) GetRemoteCandidates() ([]Candidate, error) {
return res, nil return res, nil
} }
// GetLocalCandidates returns the local candidates // GetLocalCandidates returns the local candidates.
func (a *Agent) GetLocalCandidates() ([]Candidate, error) { func (a *Agent) GetLocalCandidates() ([]Candidate, error) {
var res []Candidate var res []Candidate
@@ -812,7 +859,7 @@ func (a *Agent) GetLocalCandidates() ([]Candidate, error) {
return res, nil return res, nil
} }
// GetLocalUserCredentials returns the local user credentials // GetLocalUserCredentials returns the local user credentials.
func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) { func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) {
valSet := make(chan struct{}) valSet := make(chan struct{})
err = a.loop.Run(a.loop, func(_ context.Context) { err = a.loop.Run(a.loop, func(_ context.Context) {
@@ -824,10 +871,11 @@ func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) {
if err == nil { if err == nil {
<-valSet <-valSet
} }
return return
} }
// GetRemoteUserCredentials returns the remote user credentials // GetRemoteUserCredentials returns the remote user credentials.
func (a *Agent) GetRemoteUserCredentials() (frag string, pwd string, err error) { func (a *Agent) GetRemoteUserCredentials() (frag string, pwd string, err error) {
valSet := make(chan struct{}) valSet := make(chan struct{})
err = a.loop.Run(a.loop, func(_ context.Context) { err = a.loop.Run(a.loop, func(_ context.Context) {
@@ -839,6 +887,7 @@ func (a *Agent) GetRemoteUserCredentials() (frag string, pwd string, err error)
if err == nil { if err == nil {
<-valSet <-valSet
} }
return return
} }
@@ -854,7 +903,7 @@ func (a *Agent) removeUfragFromMux() {
} }
} }
// Close cleans up the Agent // Close cleans up the Agent.
func (a *Agent) Close() error { func (a *Agent) Close() error {
return a.close(false) return a.close(false)
} }
@@ -875,13 +924,14 @@ func (a *Agent) close(graceful bool) error {
a.connectionStateNotifier.Close(graceful) a.connectionStateNotifier.Close(graceful)
a.candidateNotifier.Close(graceful) a.candidateNotifier.Close(graceful)
a.selectedCandidatePairNotifier.Close(graceful) a.selectedCandidatePairNotifier.Close(graceful)
return nil return nil
} }
// Remove all candidates. This closes any listening sockets // Remove all candidates. This closes any listening sockets
// and removes both the local and remote candidate lists. // and removes both the local and remote candidate lists.
// //
// This is used for restarts, failures and on close // This is used for restarts, failures and on close.
func (a *Agent) deleteAllCandidates() { func (a *Agent) deleteAllCandidates() {
for net, cs := range a.localCandidates { for net, cs := range a.localCandidates {
for _, c := range cs { for _, c := range cs {
@@ -905,6 +955,7 @@ func (a *Agent) findRemoteCandidate(networkType NetworkType, addr net.Addr) Cand
ip, port, _, err := parseAddr(addr) ip, port, _, err := parseAddr(addr)
if err != nil { if err != nil {
a.log.Warnf("Failed to parse address: %s; error: %s", addr, err) a.log.Warnf("Failed to parse address: %s; error: %s", addr, err)
return nil return nil
} }
@@ -914,6 +965,7 @@ func (a *Agent) findRemoteCandidate(networkType NetworkType, addr net.Addr) Cand
return c return c
} }
} }
return nil return nil
} }
@@ -937,6 +989,7 @@ func (a *Agent) sendBindingSuccess(m *stun.Message, local, remote Candidate) {
ip, port, _, err := parseAddr(base.addr()) ip, port, _, err := parseAddr(base.addr())
if err != nil { if err != nil {
a.log.Warnf("Failed to parse address: %s; error: %s", base.addr(), err) a.log.Warnf("Failed to parse address: %s; error: %s", base.addr(), err)
return return
} }
@@ -976,70 +1029,85 @@ func (a *Agent) invalidatePendingBindingRequests(filterTime time.Time) {
} }
// Assert that the passed TransactionID is in our pendingBindingRequests and returns the destination // Assert that the passed TransactionID is in our pendingBindingRequests and returns the destination
// If the bindingRequest was valid remove it from our pending cache // If the bindingRequest was valid remove it from our pending cache.
func (a *Agent) handleInboundBindingSuccess(id [stun.TransactionIDSize]byte) (bool, *bindingRequest, time.Duration) { func (a *Agent) handleInboundBindingSuccess(id [stun.TransactionIDSize]byte) (bool, *bindingRequest, time.Duration) {
a.invalidatePendingBindingRequests(time.Now()) a.invalidatePendingBindingRequests(time.Now())
for i := range a.pendingBindingRequests { for i := range a.pendingBindingRequests {
if a.pendingBindingRequests[i].transactionID == id { if a.pendingBindingRequests[i].transactionID == id {
validBindingRequest := a.pendingBindingRequests[i] validBindingRequest := a.pendingBindingRequests[i]
a.pendingBindingRequests = append(a.pendingBindingRequests[:i], a.pendingBindingRequests[i+1:]...) a.pendingBindingRequests = append(a.pendingBindingRequests[:i], a.pendingBindingRequests[i+1:]...)
return true, &validBindingRequest, time.Since(validBindingRequest.timestamp) return true, &validBindingRequest, time.Since(validBindingRequest.timestamp)
} }
} }
return false, nil, 0 return false, nil, 0
} }
// handleInbound processes STUN traffic from a remote candidate // handleInbound processes STUN traffic from a remote candidate.
func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr) { //nolint:gocognit func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Addr) { //nolint:gocognit,cyclop
var err error var err error
if m == nil || local == nil { if msg == nil || local == nil {
return return
} }
if m.Type.Method != stun.MethodBinding || if msg.Type.Method != stun.MethodBinding ||
!(m.Type.Class == stun.ClassSuccessResponse || !(msg.Type.Class == stun.ClassSuccessResponse ||
m.Type.Class == stun.ClassRequest || msg.Type.Class == stun.ClassRequest ||
m.Type.Class == stun.ClassIndication) { msg.Type.Class == stun.ClassIndication) {
a.log.Tracef("Unhandled STUN from %s to %s class(%s) method(%s)", remote, local, m.Type.Class, m.Type.Method) a.log.Tracef("Unhandled STUN from %s to %s class(%s) method(%s)", remote, local, msg.Type.Class, msg.Type.Method)
return return
} }
if a.isControlling { if a.isControlling {
if m.Contains(stun.AttrICEControlling) { if msg.Contains(stun.AttrICEControlling) {
a.log.Debug("Inbound STUN message: isControlling && a.isControlling == true") a.log.Debug("Inbound STUN message: isControlling && a.isControlling == true")
return return
} else if m.Contains(stun.AttrUseCandidate) { } else if msg.Contains(stun.AttrUseCandidate) {
a.log.Debug("Inbound STUN message: useCandidate && a.isControlling == true") a.log.Debug("Inbound STUN message: useCandidate && a.isControlling == true")
return return
} }
} else { } else {
if m.Contains(stun.AttrICEControlled) { if msg.Contains(stun.AttrICEControlled) {
a.log.Debug("Inbound STUN message: isControlled && a.isControlling == false") a.log.Debug("Inbound STUN message: isControlled && a.isControlling == false")
return return
} }
} }
remoteCandidate := a.findRemoteCandidate(local.NetworkType(), remote) remoteCandidate := a.findRemoteCandidate(local.NetworkType(), remote)
if m.Type.Class == stun.ClassSuccessResponse { if msg.Type.Class == stun.ClassSuccessResponse { //nolint:nestif
if err = stun.MessageIntegrity([]byte(a.remotePwd)).Check(m); err != nil { if err = stun.MessageIntegrity([]byte(a.remotePwd)).Check(msg); err != nil {
a.log.Warnf("Discard message from (%s), %v", remote, err) a.log.Warnf("Discard message from (%s), %v", remote, err)
return return
} }
if remoteCandidate == nil { if remoteCandidate == nil {
a.log.Warnf("Discard success message from (%s), no such remote", remote) a.log.Warnf("Discard success message from (%s), no such remote", remote)
return return
} }
a.selector.HandleSuccessResponse(m, local, remoteCandidate, remote) a.selector.HandleSuccessResponse(msg, local, remoteCandidate, remote)
} else if m.Type.Class == stun.ClassRequest { } else if msg.Type.Class == stun.ClassRequest {
a.log.Tracef("Inbound STUN (Request) from %s to %s, useCandidate: %v", remote, local, m.Contains(stun.AttrUseCandidate)) a.log.Tracef(
"Inbound STUN (Request) from %s to %s, useCandidate: %v",
remote,
local,
msg.Contains(stun.AttrUseCandidate),
)
if err = stunx.AssertUsername(m, a.localUfrag+":"+a.remoteUfrag); err != nil { if err = stunx.AssertUsername(msg, a.localUfrag+":"+a.remoteUfrag); err != nil {
a.log.Warnf("Discard message from (%s), %v", remote, err) a.log.Warnf("Discard message from (%s), %v", remote, err)
return return
} else if err = stun.MessageIntegrity([]byte(a.localPwd)).Check(m); err != nil { } else if err = stun.MessageIntegrity([]byte(a.localPwd)).Check(msg); err != nil {
a.log.Warnf("Discard message from (%s), %v", remote, err) a.log.Warnf("Discard message from (%s), %v", remote, err)
return return
} }
@@ -1047,6 +1115,7 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
ip, port, networkType, err := parseAddr(remote) ip, port, networkType, err := parseAddr(remote)
if err != nil { if err != nil {
a.log.Errorf("Failed to create parse remote net.Addr when creating remote prflx candidate: %s", err) a.log.Errorf("Failed to create parse remote net.Addr when creating remote prflx candidate: %s", err)
return return
} }
@@ -1062,6 +1131,7 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
prflxCandidate, err := NewCandidatePeerReflexive(&prflxCandidateConfig) prflxCandidate, err := NewCandidatePeerReflexive(&prflxCandidateConfig)
if err != nil { if err != nil {
a.log.Errorf("Failed to create new remote prflx candidate (%s)", err) a.log.Errorf("Failed to create new remote prflx candidate (%s)", err)
return return
} }
remoteCandidate = prflxCandidate remoteCandidate = prflxCandidate
@@ -1070,7 +1140,7 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
a.addRemoteCandidate(remoteCandidate) a.addRemoteCandidate(remoteCandidate)
} }
a.selector.HandleBindingRequest(m, local, remoteCandidate) a.selector.HandleBindingRequest(msg, local, remoteCandidate)
} }
if remoteCandidate != nil { if remoteCandidate != nil {
@@ -1079,7 +1149,7 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
} }
// validateNonSTUNTraffic processes non STUN traffic from a remote candidate, // validateNonSTUNTraffic processes non STUN traffic from a remote candidate,
// and returns true if it is an actual remote candidate // and returns true if it is an actual remote candidate.
func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) (Candidate, bool) { func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) (Candidate, bool) {
var remoteCandidate Candidate var remoteCandidate Candidate
if err := a.loop.Run(local.context(), func(context.Context) { if err := a.loop.Run(local.context(), func(context.Context) {
@@ -1094,7 +1164,7 @@ func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) (Candid
return remoteCandidate, remoteCandidate != nil return remoteCandidate, remoteCandidate != nil
} }
// GetSelectedCandidatePair returns the selected pair or nil if there is none // GetSelectedCandidatePair returns the selected pair or nil if there is none.
func (a *Agent) GetSelectedCandidatePair() (*CandidatePair, error) { func (a *Agent) GetSelectedCandidatePair() (*CandidatePair, error) {
selectedPair := a.getSelectedPair() selectedPair := a.getSelectedPair()
if selectedPair == nil { if selectedPair == nil {
@@ -1130,7 +1200,7 @@ func (a *Agent) closeMulticastConn() {
} }
} }
// SetRemoteCredentials sets the credentials of the remote agent // SetRemoteCredentials sets the credentials of the remote agent.
func (a *Agent) SetRemoteCredentials(remoteUfrag, remotePwd string) error { func (a *Agent) SetRemoteCredentials(remoteUfrag, remotePwd string) error {
switch { switch {
case remoteUfrag == "": case remoteUfrag == "":
@@ -1152,7 +1222,7 @@ func (a *Agent) SetRemoteCredentials(remoteUfrag, remotePwd string) error {
// cancel it. // cancel it.
// After a Restart, the user must then call GatherCandidates explicitly // After a Restart, the user must then call GatherCandidates explicitly
// to start generating new ones. // to start generating new ones.
func (a *Agent) Restart(ufrag, pwd string) error { func (a *Agent) Restart(ufrag, pwd string) error { //nolint:cyclop
if ufrag == "" { if ufrag == "" {
var err error var err error
ufrag, err = generateUFrag() ufrag, err = generateUFrag()
@@ -1204,6 +1274,7 @@ func (a *Agent) Restart(ufrag, pwd string) error {
}); runErr != nil { }); runErr != nil {
return runErr return runErr
} }
return err return err
} }
@@ -1221,6 +1292,7 @@ func (a *Agent) setGatheringState(newState GatheringState) error {
} }
<-done <-done
return nil return nil
} }

View File

@@ -14,44 +14,44 @@ import (
) )
const ( const (
// defaultCheckInterval is the interval at which the agent performs candidate checks in the connecting phase // defaultCheckInterval is the interval at which the agent performs candidate checks in the connecting phase.
defaultCheckInterval = 200 * time.Millisecond defaultCheckInterval = 200 * time.Millisecond
// keepaliveInterval used to keep candidates alive // keepaliveInterval used to keep candidates alive.
defaultKeepaliveInterval = 2 * time.Second defaultKeepaliveInterval = 2 * time.Second
// defaultDisconnectedTimeout is the default time till an Agent transitions disconnected // defaultDisconnectedTimeout is the default time till an Agent transitions disconnected.
defaultDisconnectedTimeout = 5 * time.Second defaultDisconnectedTimeout = 5 * time.Second
// defaultFailedTimeout is the default time till an Agent transitions to failed after disconnected // defaultFailedTimeout is the default time till an Agent transitions to failed after disconnected.
defaultFailedTimeout = 25 * time.Second defaultFailedTimeout = 25 * time.Second
// defaultHostAcceptanceMinWait is the wait time before nominating a host candidate // defaultHostAcceptanceMinWait is the wait time before nominating a host candidate.
defaultHostAcceptanceMinWait = 0 defaultHostAcceptanceMinWait = 0
// defaultSrflxAcceptanceMinWait is the wait time before nominating a srflx candidate // defaultSrflxAcceptanceMinWait is the wait time before nominating a srflx candidate.
defaultSrflxAcceptanceMinWait = 500 * time.Millisecond defaultSrflxAcceptanceMinWait = 500 * time.Millisecond
// defaultPrflxAcceptanceMinWait is the wait time before nominating a prflx candidate // defaultPrflxAcceptanceMinWait is the wait time before nominating a prflx candidate.
defaultPrflxAcceptanceMinWait = 1000 * time.Millisecond defaultPrflxAcceptanceMinWait = 1000 * time.Millisecond
// defaultRelayAcceptanceMinWait is the wait time before nominating a relay candidate // defaultRelayAcceptanceMinWait is the wait time before nominating a relay candidate.
defaultRelayAcceptanceMinWait = 2000 * time.Millisecond defaultRelayAcceptanceMinWait = 2000 * time.Millisecond
// defaultSTUNGatherTimeout is the wait time for STUN responses // defaultSTUNGatherTimeout is the wait time for STUN responses.
defaultSTUNGatherTimeout = 5 * time.Second defaultSTUNGatherTimeout = 5 * time.Second
// defaultMaxBindingRequests is the maximum number of binding requests before considering a pair failed // defaultMaxBindingRequests is the maximum number of binding requests before considering a pair failed.
defaultMaxBindingRequests = 7 defaultMaxBindingRequests = 7
// TCPPriorityOffset is a number which is subtracted from the default (UDP) candidate type preference // TCPPriorityOffset is a number which is subtracted from the default (UDP) candidate type preference
// for host, srflx and prfx candidate types. // for host, srflx and prfx candidate types.
defaultTCPPriorityOffset = 27 defaultTCPPriorityOffset = 27
// maxBufferSize is the number of bytes that can be buffered before we start to error // maxBufferSize is the number of bytes that can be buffered before we start to error.
maxBufferSize = 1000 * 1000 // 1MB maxBufferSize = 1000 * 1000 // 1MB
// maxBindingRequestTimeout is the wait time before binding requests can be deleted // maxBindingRequestTimeout is the wait time before binding requests can be deleted.
maxBindingRequestTimeout = 4000 * time.Millisecond maxBindingRequestTimeout = 4000 * time.Millisecond
) )
@@ -60,7 +60,7 @@ func defaultCandidateTypes() []CandidateType {
} }
// AgentConfig collects the arguments to ice.Agent construction into // AgentConfig collects the arguments to ice.Agent construction into
// a single structure, for future-proofness of the interface // a single structure, for future-proofness of the interface.
type AgentConfig struct { type AgentConfig struct {
Urls []*stun.URI Urls []*stun.URI
@@ -209,109 +209,111 @@ type AgentConfig struct {
EnableUseCandidateCheckPriority bool EnableUseCandidateCheckPriority bool
} }
// initWithDefaults populates an agent and falls back to defaults if fields are unset // initWithDefaults populates an agent and falls back to defaults if fields are unset.
func (config *AgentConfig) initWithDefaults(a *Agent) { func (config *AgentConfig) initWithDefaults(agent *Agent) { //nolint:cyclop
if config.MaxBindingRequests == nil { if config.MaxBindingRequests == nil {
a.maxBindingRequests = defaultMaxBindingRequests agent.maxBindingRequests = defaultMaxBindingRequests
} else { } else {
a.maxBindingRequests = *config.MaxBindingRequests agent.maxBindingRequests = *config.MaxBindingRequests
} }
if config.HostAcceptanceMinWait == nil { if config.HostAcceptanceMinWait == nil {
a.hostAcceptanceMinWait = defaultHostAcceptanceMinWait agent.hostAcceptanceMinWait = defaultHostAcceptanceMinWait
} else { } else {
a.hostAcceptanceMinWait = *config.HostAcceptanceMinWait agent.hostAcceptanceMinWait = *config.HostAcceptanceMinWait
} }
if config.SrflxAcceptanceMinWait == nil { if config.SrflxAcceptanceMinWait == nil {
a.srflxAcceptanceMinWait = defaultSrflxAcceptanceMinWait agent.srflxAcceptanceMinWait = defaultSrflxAcceptanceMinWait
} else { } else {
a.srflxAcceptanceMinWait = *config.SrflxAcceptanceMinWait agent.srflxAcceptanceMinWait = *config.SrflxAcceptanceMinWait
} }
if config.PrflxAcceptanceMinWait == nil { if config.PrflxAcceptanceMinWait == nil {
a.prflxAcceptanceMinWait = defaultPrflxAcceptanceMinWait agent.prflxAcceptanceMinWait = defaultPrflxAcceptanceMinWait
} else { } else {
a.prflxAcceptanceMinWait = *config.PrflxAcceptanceMinWait agent.prflxAcceptanceMinWait = *config.PrflxAcceptanceMinWait
} }
if config.RelayAcceptanceMinWait == nil { if config.RelayAcceptanceMinWait == nil {
a.relayAcceptanceMinWait = defaultRelayAcceptanceMinWait agent.relayAcceptanceMinWait = defaultRelayAcceptanceMinWait
} else { } else {
a.relayAcceptanceMinWait = *config.RelayAcceptanceMinWait agent.relayAcceptanceMinWait = *config.RelayAcceptanceMinWait
} }
if config.STUNGatherTimeout == nil { if config.STUNGatherTimeout == nil {
a.stunGatherTimeout = defaultSTUNGatherTimeout agent.stunGatherTimeout = defaultSTUNGatherTimeout
} else { } else {
a.stunGatherTimeout = *config.STUNGatherTimeout agent.stunGatherTimeout = *config.STUNGatherTimeout
} }
if config.TCPPriorityOffset == nil { if config.TCPPriorityOffset == nil {
a.tcpPriorityOffset = defaultTCPPriorityOffset agent.tcpPriorityOffset = defaultTCPPriorityOffset
} else { } else {
a.tcpPriorityOffset = *config.TCPPriorityOffset agent.tcpPriorityOffset = *config.TCPPriorityOffset
} }
if config.DisconnectedTimeout == nil { if config.DisconnectedTimeout == nil {
a.disconnectedTimeout = defaultDisconnectedTimeout agent.disconnectedTimeout = defaultDisconnectedTimeout
} else { } else {
a.disconnectedTimeout = *config.DisconnectedTimeout agent.disconnectedTimeout = *config.DisconnectedTimeout
} }
if config.FailedTimeout == nil { if config.FailedTimeout == nil {
a.failedTimeout = defaultFailedTimeout agent.failedTimeout = defaultFailedTimeout
} else { } else {
a.failedTimeout = *config.FailedTimeout agent.failedTimeout = *config.FailedTimeout
} }
if config.KeepaliveInterval == nil { if config.KeepaliveInterval == nil {
a.keepaliveInterval = defaultKeepaliveInterval agent.keepaliveInterval = defaultKeepaliveInterval
} else { } else {
a.keepaliveInterval = *config.KeepaliveInterval agent.keepaliveInterval = *config.KeepaliveInterval
} }
if config.CheckInterval == nil { if config.CheckInterval == nil {
a.checkInterval = defaultCheckInterval agent.checkInterval = defaultCheckInterval
} else { } else {
a.checkInterval = *config.CheckInterval agent.checkInterval = *config.CheckInterval
} }
if len(config.CandidateTypes) == 0 { if len(config.CandidateTypes) == 0 {
a.candidateTypes = defaultCandidateTypes() agent.candidateTypes = defaultCandidateTypes()
} else { } else {
a.candidateTypes = config.CandidateTypes agent.candidateTypes = config.CandidateTypes
} }
} }
func (config *AgentConfig) initExtIPMapping(a *Agent) error { func (config *AgentConfig) initExtIPMapping(agent *Agent) error { //nolint:cyclop
var err error var err error
a.extIPMapper, err = newExternalIPMapper(config.NAT1To1IPCandidateType, config.NAT1To1IPs) agent.extIPMapper, err = newExternalIPMapper(config.NAT1To1IPCandidateType, config.NAT1To1IPs)
if err != nil { if err != nil {
return err return err
} }
if a.extIPMapper == nil { if agent.extIPMapper == nil {
return nil // This may happen when config.NAT1To1IPs is an empty array return nil // This may happen when config.NAT1To1IPs is an empty array
} }
if a.extIPMapper.candidateType == CandidateTypeHost { if agent.extIPMapper.candidateType == CandidateTypeHost { //nolint:nestif
if a.mDNSMode == MulticastDNSModeQueryAndGather { if agent.mDNSMode == MulticastDNSModeQueryAndGather {
return ErrMulticastDNSWithNAT1To1IPMapping return ErrMulticastDNSWithNAT1To1IPMapping
} }
candiHostEnabled := false candiHostEnabled := false
for _, candiType := range a.candidateTypes { for _, candiType := range agent.candidateTypes {
if candiType == CandidateTypeHost { if candiType == CandidateTypeHost {
candiHostEnabled = true candiHostEnabled = true
break break
} }
} }
if !candiHostEnabled { if !candiHostEnabled {
return ErrIneffectiveNAT1To1IPMappingHost return ErrIneffectiveNAT1To1IPMappingHost
} }
} else if a.extIPMapper.candidateType == CandidateTypeServerReflexive { } else if agent.extIPMapper.candidateType == CandidateTypeServerReflexive {
candiSrflxEnabled := false candiSrflxEnabled := false
for _, candiType := range a.candidateTypes { for _, candiType := range agent.candidateTypes {
if candiType == CandidateTypeServerReflexive { if candiType == CandidateTypeServerReflexive {
candiSrflxEnabled = true candiSrflxEnabled = true
break break
} }
} }
@@ -319,5 +321,6 @@ func (config *AgentConfig) initExtIPMapping(a *Agent) error {
return ErrIneffectiveNAT1To1IPMappingSrflx return ErrIneffectiveNAT1To1IPMappingSrflx
} }
} }
return nil return nil
} }

View File

@@ -32,6 +32,8 @@ func TestAgentGetBestValidCandidatePair(t *testing.T) {
} }
func setupTestAgentGetBestValidCandidatePair(t *testing.T) *TestAgentGetBestValidCandidatePairFixture { func setupTestAgentGetBestValidCandidatePair(t *testing.T) *TestAgentGetBestValidCandidatePairFixture {
t.Helper()
fixture := new(TestAgentGetBestValidCandidatePairFixture) fixture := new(TestAgentGetBestValidCandidatePairFixture)
fixture.hostLocal = newHostLocal(t) fixture.hostLocal = newHostLocal(t)
fixture.relayRemote = newRelayRemote(t) fixture.relayRemote = newRelayRemote(t)

View File

@@ -5,16 +5,18 @@ package ice
import "sync" import "sync"
// OnConnectionStateChange sets a handler that is fired when the connection state changes // OnConnectionStateChange sets a handler that is fired when the connection state changes.
func (a *Agent) OnConnectionStateChange(f func(ConnectionState)) error { func (a *Agent) OnConnectionStateChange(f func(ConnectionState)) error {
a.onConnectionStateChangeHdlr.Store(f) a.onConnectionStateChangeHdlr.Store(f)
return nil return nil
} }
// OnSelectedCandidatePairChange sets a handler that is fired when the final candidate // OnSelectedCandidatePairChange sets a handler that is fired when the final candidate.
// pair is selected // pair is selected.
func (a *Agent) OnSelectedCandidatePairChange(f func(Candidate, Candidate)) error { func (a *Agent) OnSelectedCandidatePairChange(f func(Candidate, Candidate)) error {
a.onSelectedCandidatePairChangeHdlr.Store(f) a.onSelectedCandidatePairChangeHdlr.Store(f)
return nil return nil
} }
@@ -22,6 +24,7 @@ func (a *Agent) OnSelectedCandidatePairChange(f func(Candidate, Candidate)) erro
// the gathering process complete the last candidate is nil. // the gathering process complete the last candidate is nil.
func (a *Agent) OnCandidate(f func(Candidate)) error { func (a *Agent) OnCandidate(f func(Candidate)) error {
a.onCandidateHdlr.Store(f) a.onCandidateHdlr.Store(f)
return nil return nil
} }
@@ -73,6 +76,7 @@ func (h *handlerNotifier) Close(graceful bool) {
select { select {
case <-h.done: case <-h.done:
h.Unlock() h.Unlock()
return return
default: default:
} }
@@ -80,7 +84,7 @@ func (h *handlerNotifier) Close(graceful bool) {
h.Unlock() h.Unlock()
} }
func (h *handlerNotifier) EnqueueConnectionState(s ConnectionState) { func (h *handlerNotifier) EnqueueConnectionState(state ConnectionState) {
h.Lock() h.Lock()
defer h.Unlock() defer h.Unlock()
@@ -97,6 +101,7 @@ func (h *handlerNotifier) EnqueueConnectionState(s ConnectionState) {
if len(h.connectionStates) == 0 { if len(h.connectionStates) == 0 {
h.running = false h.running = false
h.Unlock() h.Unlock()
return return
} }
notification := h.connectionStates[0] notification := h.connectionStates[0]
@@ -106,7 +111,7 @@ func (h *handlerNotifier) EnqueueConnectionState(s ConnectionState) {
} }
} }
h.connectionStates = append(h.connectionStates, s) h.connectionStates = append(h.connectionStates, state)
if !h.running { if !h.running {
h.running = true h.running = true
h.notifiers.Add(1) h.notifiers.Add(1)
@@ -114,7 +119,7 @@ func (h *handlerNotifier) EnqueueConnectionState(s ConnectionState) {
} }
} }
func (h *handlerNotifier) EnqueueCandidate(c Candidate) { func (h *handlerNotifier) EnqueueCandidate(cand Candidate) {
h.Lock() h.Lock()
defer h.Unlock() defer h.Unlock()
@@ -131,6 +136,7 @@ func (h *handlerNotifier) EnqueueCandidate(c Candidate) {
if len(h.candidates) == 0 { if len(h.candidates) == 0 {
h.running = false h.running = false
h.Unlock() h.Unlock()
return return
} }
notification := h.candidates[0] notification := h.candidates[0]
@@ -140,7 +146,7 @@ func (h *handlerNotifier) EnqueueCandidate(c Candidate) {
} }
} }
h.candidates = append(h.candidates, c) h.candidates = append(h.candidates, cand)
if !h.running { if !h.running {
h.running = true h.running = true
h.notifiers.Add(1) h.notifiers.Add(1)
@@ -148,7 +154,7 @@ func (h *handlerNotifier) EnqueueCandidate(c Candidate) {
} }
} }
func (h *handlerNotifier) EnqueueSelectedCandidatePair(p *CandidatePair) { func (h *handlerNotifier) EnqueueSelectedCandidatePair(pair *CandidatePair) {
h.Lock() h.Lock()
defer h.Unlock() defer h.Unlock()
@@ -165,6 +171,7 @@ func (h *handlerNotifier) EnqueueSelectedCandidatePair(p *CandidatePair) {
if len(h.selectedCandidatePairs) == 0 { if len(h.selectedCandidatePairs) == 0 {
h.running = false h.running = false
h.Unlock() h.Unlock()
return return
} }
notification := h.selectedCandidatePairs[0] notification := h.selectedCandidatePairs[0]
@@ -174,7 +181,7 @@ func (h *handlerNotifier) EnqueueSelectedCandidatePair(p *CandidatePair) {
} }
} }
h.selectedCandidatePairs = append(h.selectedCandidatePairs, p) h.selectedCandidatePairs = append(h.selectedCandidatePairs, pair)
if !h.running { if !h.running {
h.running = true h.running = true
h.notifiers.Add(1) h.notifiers.Add(1)

View File

@@ -15,7 +15,7 @@ func TestConnectionStateNotifier(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
updates := make(chan struct{}, 1) updates := make(chan struct{}, 1)
c := &handlerNotifier{ notifier := &handlerNotifier{
connectionStateFunc: func(_ ConnectionState) { connectionStateFunc: func(_ ConnectionState) {
updates <- struct{}{} updates <- struct{}{}
}, },
@@ -24,7 +24,7 @@ func TestConnectionStateNotifier(t *testing.T) {
// Enqueue all updates upfront to ensure that it // Enqueue all updates upfront to ensure that it
// doesn't block // doesn't block
for i := 0; i < 10000; i++ { for i := 0; i < 10000; i++ {
c.EnqueueConnectionState(ConnectionStateNew) notifier.EnqueueConnectionState(ConnectionStateNew)
} }
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@@ -39,12 +39,12 @@ func TestConnectionStateNotifier(t *testing.T) {
close(done) close(done)
}() }()
<-done <-done
c.Close(true) notifier.Close(true)
}) })
t.Run("TestUpdateOrdering", func(t *testing.T) { t.Run("TestUpdateOrdering", func(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
updates := make(chan ConnectionState) updates := make(chan ConnectionState)
c := &handlerNotifier{ notifer := &handlerNotifier{
connectionStateFunc: func(cs ConnectionState) { connectionStateFunc: func(cs ConnectionState) {
updates <- cs updates <- cs
}, },
@@ -66,9 +66,9 @@ func TestConnectionStateNotifier(t *testing.T) {
close(done) close(done)
}() }()
for i := 0; i < 10000; i++ { for i := 0; i < 10000; i++ {
c.EnqueueConnectionState(ConnectionState(i)) notifer.EnqueueConnectionState(ConnectionState(i))
} }
<-done <-done
c.Close(true) notifer.Close(true)
}) })
} }

View File

@@ -34,17 +34,23 @@ func TestOnSelectedCandidatePairChange(t *testing.T) {
} }
func fixtureTestOnSelectedCandidatePairChange(t *testing.T) (*Agent, *CandidatePair) { func fixtureTestOnSelectedCandidatePairChange(t *testing.T) (*Agent, *CandidatePair) {
t.Helper()
agent, err := NewAgent(&AgentConfig{}) agent, err := NewAgent(&AgentConfig{})
require.NoError(t, err) require.NoError(t, err)
candidatePair := makeCandidatePair(t) candidatePair := makeCandidatePair(t)
return agent, candidatePair return agent, candidatePair
} }
func makeCandidatePair(t *testing.T) *CandidatePair { func makeCandidatePair(t *testing.T) *CandidatePair {
t.Helper()
hostLocal := newHostLocal(t) hostLocal := newHostLocal(t)
relayRemote := newRelayRemote(t) relayRemote := newRelayRemote(t)
candidatePair := newCandidatePair(hostLocal, relayRemote, false) candidatePair := newCandidatePair(hostLocal, relayRemote, false)
return candidatePair return candidatePair
} }

View File

@@ -8,7 +8,7 @@ import (
"time" "time"
) )
// GetCandidatePairsStats returns a list of candidate pair stats // GetCandidatePairsStats returns a list of candidate pair stats.
func (a *Agent) GetCandidatePairsStats() []CandidatePairStats { func (a *Agent) GetCandidatePairsStats() []CandidatePairStats {
var res []CandidatePairStats var res []CandidatePairStats
err := a.loop.Run(a.loop, func(_ context.Context) { err := a.loop.Run(a.loop, func(_ context.Context) {
@@ -49,13 +49,15 @@ func (a *Agent) GetCandidatePairsStats() []CandidatePairStats {
}) })
if err != nil { if err != nil {
a.log.Errorf("Failed to get candidate pairs stats: %v", err) a.log.Errorf("Failed to get candidate pairs stats: %v", err)
return []CandidatePairStats{} return []CandidatePairStats{}
} }
return res return res
} }
// GetSelectedCandidatePairStats returns a candidate pair stats for selected candidate pair. // GetSelectedCandidatePairStats returns a candidate pair stats for selected candidate pair.
// Returns false if there is no selected pair // Returns false if there is no selected pair.
func (a *Agent) GetSelectedCandidatePairStats() (CandidatePairStats, bool) { func (a *Agent) GetSelectedCandidatePairStats() (CandidatePairStats, bool) {
isAvailable := false isAvailable := false
var res CandidatePairStats var res CandidatePairStats
@@ -98,33 +100,34 @@ func (a *Agent) GetSelectedCandidatePairStats() (CandidatePairStats, bool) {
}) })
if err != nil { if err != nil {
a.log.Errorf("Failed to get selected candidate pair stats: %v", err) a.log.Errorf("Failed to get selected candidate pair stats: %v", err)
return CandidatePairStats{}, false return CandidatePairStats{}, false
} }
return res, isAvailable return res, isAvailable
} }
// GetLocalCandidatesStats returns a list of local candidates stats // GetLocalCandidatesStats returns a list of local candidates stats.
func (a *Agent) GetLocalCandidatesStats() []CandidateStats { func (a *Agent) GetLocalCandidatesStats() []CandidateStats {
var res []CandidateStats var res []CandidateStats
err := a.loop.Run(a.loop, func(_ context.Context) { err := a.loop.Run(a.loop, func(_ context.Context) {
result := make([]CandidateStats, 0, len(a.localCandidates)) result := make([]CandidateStats, 0, len(a.localCandidates))
for networkType, localCandidates := range a.localCandidates { for networkType, localCandidates := range a.localCandidates {
for _, c := range localCandidates { for _, cand := range localCandidates {
relayProtocol := "" relayProtocol := ""
if c.Type() == CandidateTypeRelay { if cand.Type() == CandidateTypeRelay {
if cRelay, ok := c.(*CandidateRelay); ok { if cRelay, ok := cand.(*CandidateRelay); ok {
relayProtocol = cRelay.RelayProtocol() relayProtocol = cRelay.RelayProtocol()
} }
} }
stat := CandidateStats{ stat := CandidateStats{
Timestamp: time.Now(), Timestamp: time.Now(),
ID: c.ID(), ID: cand.ID(),
NetworkType: networkType, NetworkType: networkType,
IP: c.Address(), IP: cand.Address(),
Port: c.Port(), Port: cand.Port(),
CandidateType: c.Type(), CandidateType: cand.Type(),
Priority: c.Priority(), Priority: cand.Priority(),
// URL string // URL string
RelayProtocol: relayProtocol, RelayProtocol: relayProtocol,
// Deleted bool // Deleted bool
@@ -136,12 +139,14 @@ func (a *Agent) GetLocalCandidatesStats() []CandidateStats {
}) })
if err != nil { if err != nil {
a.log.Errorf("Failed to get candidate pair stats: %v", err) a.log.Errorf("Failed to get candidate pair stats: %v", err)
return []CandidateStats{} return []CandidateStats{}
} }
return res return res
} }
// GetRemoteCandidatesStats returns a list of remote candidates stats // GetRemoteCandidatesStats returns a list of remote candidates stats.
func (a *Agent) GetRemoteCandidatesStats() []CandidateStats { func (a *Agent) GetRemoteCandidatesStats() []CandidateStats {
var res []CandidateStats var res []CandidateStats
err := a.loop.Run(a.loop, func(_ context.Context) { err := a.loop.Run(a.loop, func(_ context.Context) {
@@ -166,7 +171,9 @@ func (a *Agent) GetRemoteCandidatesStats() []CandidateStats {
}) })
if err != nil { if err != nil {
a.log.Errorf("Failed to get candidate pair stats: %v", err) a.log.Errorf("Failed to get candidate pair stats: %v", err)
return []CandidateStats{} return []CandidateStats{}
} }
return res return res
} }

View File

@@ -33,21 +33,21 @@ func (ba *BadAddr) String() string {
return "yyy" return "yyy"
} }
func TestHandlePeerReflexive(t *testing.T) { func TestHandlePeerReflexive(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
// Limit runtime in case of deadlocks // Limit runtime in case of deadlocks
defer test.TimeOut(time.Second * 2).Stop() defer test.TimeOut(time.Second * 2).Stop()
t.Run("UDP prflx candidate from handleInbound()", func(t *testing.T) { t.Run("UDP prflx candidate from handleInbound()", func(t *testing.T) {
a, err := NewAgent(&AgentConfig{}) agent, err := NewAgent(&AgentConfig{})
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
require.NoError(t, a.loop.Run(a.loop, func(_ context.Context) { require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log} agent.selector = &controllingSelector{agent: agent, log: agent.log}
hostConfig := CandidateHostConfig{ hostConfig := CandidateHostConfig{
Network: "udp", Network: "udp",
@@ -64,25 +64,25 @@ func TestHandlePeerReflexive(t *testing.T) {
remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999} remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}
msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
stun.NewUsername(a.localUfrag+":"+a.remoteUfrag), stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
UseCandidate(), UseCandidate(),
AttrControlling(a.tieBreaker), AttrControlling(agent.tieBreaker),
PriorityAttr(local.Priority()), PriorityAttr(local.Priority()),
stun.NewShortTermIntegrity(a.localPwd), stun.NewShortTermIntegrity(agent.localPwd),
stun.Fingerprint, stun.Fingerprint,
) )
require.NoError(t, err) require.NoError(t, err)
// nolint: contextcheck // nolint: contextcheck
a.handleInbound(msg, local, remote) agent.handleInbound(msg, local, remote)
// Length of remote candidate list must be one now // Length of remote candidate list must be one now
if len(a.remoteCandidates) != 1 { if len(agent.remoteCandidates) != 1 {
t.Fatal("failed to add a network type to the remote candidate list") t.Fatal("failed to add a network type to the remote candidate list")
} }
// Length of remote candidate list for a network type must be 1 // Length of remote candidate list for a network type must be 1
set := a.remoteCandidates[local.NetworkType()] set := agent.remoteCandidates[local.NetworkType()]
if len(set) != 1 { if len(set) != 1 {
t.Fatal("failed to add prflx candidate to remote candidate list") t.Fatal("failed to add prflx candidate to remote candidate list")
} }
@@ -104,14 +104,14 @@ func TestHandlePeerReflexive(t *testing.T) {
}) })
t.Run("Bad network type with handleInbound()", func(t *testing.T) { t.Run("Bad network type with handleInbound()", func(t *testing.T) {
a, err := NewAgent(&AgentConfig{}) agent, err := NewAgent(&AgentConfig{})
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
require.NoError(t, a.loop.Run(a.loop, func(_ context.Context) { require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log} agent.selector = &controllingSelector{agent: agent, log: agent.log}
hostConfig := CandidateHostConfig{ hostConfig := CandidateHostConfig{
Network: "tcp", Network: "tcp",
@@ -127,26 +127,26 @@ func TestHandlePeerReflexive(t *testing.T) {
remote := &BadAddr{} remote := &BadAddr{}
// nolint: contextcheck // nolint: contextcheck
a.handleInbound(nil, local, remote) agent.handleInbound(nil, local, remote)
if len(a.remoteCandidates) != 0 { if len(agent.remoteCandidates) != 0 {
t.Fatal("bad address should not be added to the remote candidate list") t.Fatal("bad address should not be added to the remote candidate list")
} }
})) }))
}) })
t.Run("Success from unknown remote, prflx candidate MUST only be created via Binding Request", func(t *testing.T) { t.Run("Success from unknown remote, prflx candidate MUST only be created via Binding Request", func(t *testing.T) {
a, err := NewAgent(&AgentConfig{}) agent, err := NewAgent(&AgentConfig{})
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
require.NoError(t, a.loop.Run(a.loop, func(_ context.Context) { require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log} agent.selector = &controllingSelector{agent: agent, log: agent.log}
tID := [stun.TransactionIDSize]byte{} tID := [stun.TransactionIDSize]byte{}
copy(tID[:], "ABC") copy(tID[:], "ABC")
a.pendingBindingRequests = []bindingRequest{ agent.pendingBindingRequests = []bindingRequest{
{time.Now(), tID, &net.UDPAddr{}, false}, {time.Now(), tID, &net.UDPAddr{}, false},
} }
@@ -164,14 +164,14 @@ func TestHandlePeerReflexive(t *testing.T) {
remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999} remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}
msg, err := stun.Build(stun.BindingSuccess, stun.NewTransactionIDSetter(tID), msg, err := stun.Build(stun.BindingSuccess, stun.NewTransactionIDSetter(tID),
stun.NewShortTermIntegrity(a.remotePwd), stun.NewShortTermIntegrity(agent.remotePwd),
stun.Fingerprint, stun.Fingerprint,
) )
require.NoError(t, err) require.NoError(t, err)
// nolint: contextcheck // nolint: contextcheck
a.handleInbound(msg, local, remote) agent.handleInbound(msg, local, remote)
if len(a.remoteCandidates) != 0 { if len(agent.remoteCandidates) != 0 {
t.Fatal("unknown remote was able to create a candidate") t.Fatal("unknown remote was able to create a candidate")
} }
})) }))
@@ -281,6 +281,7 @@ func TestConnectivityOnStartup(t *testing.T) {
// Ensure accepted // Ensure accepted
<-accepted <-accepted
return aConn, bConn return aConn, bConn
}(aAgent, bAgent) }(aAgent, bAgent)
@@ -308,9 +309,9 @@ func TestConnectivityLite(t *testing.T) {
MappingBehavior: vnet.EndpointIndependent, MappingBehavior: vnet.EndpointIndependent,
FilteringBehavior: vnet.EndpointIndependent, FilteringBehavior: vnet.EndpointIndependent,
} }
v, err := buildVNet(natType, natType) vent, err := buildVNet(natType, natType)
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
defer v.close() defer vent.close()
aNotifier, aConnected := onConnected() aNotifier, aConnected := onConnected()
bNotifier, bConnected := onConnected() bNotifier, bConnected := onConnected()
@@ -319,7 +320,7 @@ func TestConnectivityLite(t *testing.T) {
Urls: []*stun.URI{stunServerURL}, Urls: []*stun.URI{stunServerURL},
NetworkTypes: supportedNetworkTypes(), NetworkTypes: supportedNetworkTypes(),
MulticastDNSMode: MulticastDNSModeDisabled, MulticastDNSMode: MulticastDNSModeDisabled,
Net: v.net0, Net: vent.net0,
} }
aAgent, err := NewAgent(cfg0) aAgent, err := NewAgent(cfg0)
@@ -335,7 +336,7 @@ func TestConnectivityLite(t *testing.T) {
CandidateTypes: []CandidateType{CandidateTypeHost}, CandidateTypes: []CandidateType{CandidateTypeHost},
NetworkTypes: supportedNetworkTypes(), NetworkTypes: supportedNetworkTypes(),
MulticastDNSMode: MulticastDNSModeDisabled, MulticastDNSMode: MulticastDNSModeDisabled,
Net: v.net1, Net: vent.net1,
} }
bAgent, err := NewAgent(cfg1) bAgent, err := NewAgent(cfg1)
@@ -353,7 +354,7 @@ func TestConnectivityLite(t *testing.T) {
<-bConnected <-bConnected
} }
func TestInboundValidity(t *testing.T) { func TestInboundValidity(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
buildMsg := func(class stun.MessageClass, username, key string) *stun.Message { buildMsg := func(class stun.MessageClass, username, key string) *stun.Message {
@@ -381,21 +382,21 @@ func TestInboundValidity(t *testing.T) {
} }
t.Run("Invalid Binding requests should be discarded", func(t *testing.T) { t.Run("Invalid Binding requests should be discarded", func(t *testing.T) {
a, err := NewAgent(&AgentConfig{}) agent, err := NewAgent(&AgentConfig{})
if err != nil { if err != nil {
t.Fatalf("Error constructing ice.Agent") t.Fatalf("Error constructing ice.Agent")
} }
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
a.handleInbound(buildMsg(stun.ClassRequest, "invalid", a.localPwd), local, remote) agent.handleInbound(buildMsg(stun.ClassRequest, "invalid", agent.localPwd), local, remote)
if len(a.remoteCandidates) == 1 { if len(agent.remoteCandidates) == 1 {
t.Fatal("Binding with invalid Username was able to create prflx candidate") t.Fatal("Binding with invalid Username was able to create prflx candidate")
} }
a.handleInbound(buildMsg(stun.ClassRequest, a.localUfrag+":"+a.remoteUfrag, "Invalid"), local, remote) agent.handleInbound(buildMsg(stun.ClassRequest, agent.localUfrag+":"+agent.remoteUfrag, "Invalid"), local, remote)
if len(a.remoteCandidates) == 1 { if len(agent.remoteCandidates) == 1 {
t.Fatal("Binding with invalid MessageIntegrity was able to create prflx candidate") t.Fatal("Binding with invalid MessageIntegrity was able to create prflx candidate")
} }
}) })
@@ -452,35 +453,35 @@ func TestInboundValidity(t *testing.T) {
}) })
t.Run("Valid bind without fingerprint", func(t *testing.T) { t.Run("Valid bind without fingerprint", func(t *testing.T) {
a, err := NewAgent(&AgentConfig{}) agent, err := NewAgent(&AgentConfig{})
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
require.NoError(t, a.loop.Run(a.loop, func(_ context.Context) { require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log} agent.selector = &controllingSelector{agent: agent, log: agent.log}
msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
stun.NewUsername(a.localUfrag+":"+a.remoteUfrag), stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
stun.NewShortTermIntegrity(a.localPwd), stun.NewShortTermIntegrity(agent.localPwd),
) )
require.NoError(t, err) require.NoError(t, err)
// nolint: contextcheck // nolint: contextcheck
a.handleInbound(msg, local, remote) agent.handleInbound(msg, local, remote)
if len(a.remoteCandidates) != 1 { if len(agent.remoteCandidates) != 1 {
t.Fatal("Binding with valid values (but no fingerprint) was unable to create prflx candidate") t.Fatal("Binding with valid values (but no fingerprint) was unable to create prflx candidate")
} }
})) }))
}) })
t.Run("Success with invalid TransactionID", func(t *testing.T) { t.Run("Success with invalid TransactionID", func(t *testing.T) {
a, err := NewAgent(&AgentConfig{}) agent, err := NewAgent(&AgentConfig{})
if err != nil { if err != nil {
t.Fatalf("Error constructing ice.Agent") t.Fatalf("Error constructing ice.Agent")
} }
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
hostConfig := CandidateHostConfig{ hostConfig := CandidateHostConfig{
@@ -499,13 +500,13 @@ func TestInboundValidity(t *testing.T) {
tID := [stun.TransactionIDSize]byte{} tID := [stun.TransactionIDSize]byte{}
copy(tID[:], "ABC") copy(tID[:], "ABC")
msg, err := stun.Build(stun.BindingSuccess, stun.NewTransactionIDSetter(tID), msg, err := stun.Build(stun.BindingSuccess, stun.NewTransactionIDSetter(tID),
stun.NewShortTermIntegrity(a.remotePwd), stun.NewShortTermIntegrity(agent.remotePwd),
stun.Fingerprint, stun.Fingerprint,
) )
require.NoError(t, err) require.NoError(t, err)
a.handleInbound(msg, local, remote) agent.handleInbound(msg, local, remote)
if len(a.remoteCandidates) != 0 { if len(agent.remoteCandidates) != 0 {
t.Fatal("unknown remote was able to create a candidate") t.Fatal("unknown remote was able to create a candidate")
} }
}) })
@@ -514,35 +515,35 @@ func TestInboundValidity(t *testing.T) {
func TestInvalidAgentStarts(t *testing.T) { func TestInvalidAgentStarts(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
a, err := NewAgent(&AgentConfig{}) agent, err := NewAgent(&AgentConfig{})
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
ctx := context.Background() ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel() defer cancel()
if _, err = a.Dial(ctx, "", "bar"); err != nil && !errors.Is(err, ErrRemoteUfragEmpty) { if _, err = agent.Dial(ctx, "", "bar"); err != nil && !errors.Is(err, ErrRemoteUfragEmpty) {
t.Fatal(err) t.Fatal(err)
} }
if _, err = a.Dial(ctx, "foo", ""); err != nil && !errors.Is(err, ErrRemotePwdEmpty) { if _, err = agent.Dial(ctx, "foo", ""); err != nil && !errors.Is(err, ErrRemotePwdEmpty) {
t.Fatal(err) t.Fatal(err)
} }
if _, err = a.Dial(ctx, "foo", "bar"); err != nil && !errors.Is(err, ErrCanceledByCaller) { if _, err = agent.Dial(ctx, "foo", "bar"); err != nil && !errors.Is(err, ErrCanceledByCaller) {
t.Fatal(err) t.Fatal(err)
} }
if _, err = a.Dial(context.TODO(), "foo", "bar"); err != nil && !errors.Is(err, ErrMultipleStart) { if _, err = agent.Dial(context.TODO(), "foo", "bar"); err != nil && !errors.Is(err, ErrMultipleStart) {
t.Fatal(err) t.Fatal(err)
} }
} }
// Assert that Agent emits Connecting/Connected/Disconnected/Failed/Closed messages // Assert that Agent emits Connecting/Connected/Disconnected/Failed/Closed messages.
func TestConnectionStateCallback(t *testing.T) { func TestConnectionStateCallback(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
defer test.TimeOut(time.Second * 5).Stop() defer test.TimeOut(time.Second * 5).Stop()
@@ -635,18 +636,18 @@ func TestInvalidGather(t *testing.T) {
}) })
} }
func TestCandidatePairsStats(t *testing.T) { func TestCandidatePairsStats(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
// Avoid deadlocks? // Avoid deadlocks?
defer test.TimeOut(1 * time.Second).Stop() defer test.TimeOut(1 * time.Second).Stop()
a, err := NewAgent(&AgentConfig{}) agent, err := NewAgent(&AgentConfig{})
if err != nil { if err != nil {
t.Fatalf("Failed to create agent: %s", err) t.Fatalf("Failed to create agent: %s", err)
} }
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
hostConfig := &CandidateHostConfig{ hostConfig := &CandidateHostConfig{
@@ -711,21 +712,21 @@ func TestCandidatePairsStats(t *testing.T) {
} }
for _, remote := range []Candidate{relayRemote, srflxRemote, prflxRemote, hostRemote} { for _, remote := range []Candidate{relayRemote, srflxRemote, prflxRemote, hostRemote} {
p := a.findPair(hostLocal, remote) p := agent.findPair(hostLocal, remote)
if p == nil { if p == nil {
a.addPair(hostLocal, remote) agent.addPair(hostLocal, remote)
} }
} }
p := a.findPair(hostLocal, prflxRemote) p := agent.findPair(hostLocal, prflxRemote)
p.state = CandidatePairStateFailed p.state = CandidatePairStateFailed
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
p.UpdateRoundTripTime(time.Duration(i+1) * time.Second) p.UpdateRoundTripTime(time.Duration(i+1) * time.Second)
} }
stats := a.GetCandidatePairsStats() stats := agent.GetCandidatePairsStats()
if len(stats) != 4 { if len(stats) != 4 {
t.Fatal("expected 4 candidate pairs stats") t.Fatal("expected 4 candidate pairs stats")
} }
@@ -789,18 +790,18 @@ func TestCandidatePairsStats(t *testing.T) {
} }
} }
func TestSelectedCandidatePairStats(t *testing.T) { func TestSelectedCandidatePairStats(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
// Avoid deadlocks? // Avoid deadlocks?
defer test.TimeOut(1 * time.Second).Stop() defer test.TimeOut(1 * time.Second).Stop()
a, err := NewAgent(&AgentConfig{}) agent, err := NewAgent(&AgentConfig{})
if err != nil { if err != nil {
t.Fatalf("Failed to create agent: %s", err) t.Fatalf("Failed to create agent: %s", err)
} }
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
hostConfig := &CandidateHostConfig{ hostConfig := &CandidateHostConfig{
@@ -828,23 +829,23 @@ func TestSelectedCandidatePairStats(t *testing.T) {
} }
// no selected pair, should return not available // no selected pair, should return not available
_, ok := a.GetSelectedCandidatePairStats() _, ok := agent.GetSelectedCandidatePairStats()
require.False(t, ok) require.False(t, ok)
// add pair and populate some RTT stats // add pair and populate some RTT stats
p := a.findPair(hostLocal, srflxRemote) p := agent.findPair(hostLocal, srflxRemote)
if p == nil { if p == nil {
a.addPair(hostLocal, srflxRemote) agent.addPair(hostLocal, srflxRemote)
p = a.findPair(hostLocal, srflxRemote) p = agent.findPair(hostLocal, srflxRemote)
} }
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
p.UpdateRoundTripTime(time.Duration(i+1) * time.Second) p.UpdateRoundTripTime(time.Duration(i+1) * time.Second)
} }
// set the pair as selected // set the pair as selected
a.setSelectedPair(p) agent.setSelectedPair(p)
stats, ok := a.GetSelectedCandidatePairStats() stats, ok := agent.GetSelectedCandidatePairStats()
require.True(t, ok) require.True(t, ok)
if stats.LocalCandidateID != hostLocal.ID() { if stats.LocalCandidateID != hostLocal.ID() {
@@ -872,18 +873,18 @@ func TestSelectedCandidatePairStats(t *testing.T) {
} }
} }
func TestLocalCandidateStats(t *testing.T) { func TestLocalCandidateStats(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
// Avoid deadlocks? // Avoid deadlocks?
defer test.TimeOut(1 * time.Second).Stop() defer test.TimeOut(1 * time.Second).Stop()
a, err := NewAgent(&AgentConfig{}) agent, err := NewAgent(&AgentConfig{})
if err != nil { if err != nil {
t.Fatalf("Failed to create agent: %s", err) t.Fatalf("Failed to create agent: %s", err)
} }
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
hostConfig := &CandidateHostConfig{ hostConfig := &CandidateHostConfig{
@@ -910,9 +911,9 @@ func TestLocalCandidateStats(t *testing.T) {
t.Fatalf("Failed to construct local srflx candidate: %s", err) t.Fatalf("Failed to construct local srflx candidate: %s", err)
} }
a.localCandidates[NetworkTypeUDP4] = []Candidate{hostLocal, srflxLocal} agent.localCandidates[NetworkTypeUDP4] = []Candidate{hostLocal, srflxLocal}
localStats := a.GetLocalCandidatesStats() localStats := agent.GetLocalCandidatesStats()
if len(localStats) != 2 { if len(localStats) != 2 {
t.Fatalf("expected 2 local candidates stats, got %d instead", len(localStats)) t.Fatalf("expected 2 local candidates stats, got %d instead", len(localStats))
} }
@@ -953,18 +954,18 @@ func TestLocalCandidateStats(t *testing.T) {
} }
} }
func TestRemoteCandidateStats(t *testing.T) { func TestRemoteCandidateStats(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
// Avoid deadlocks? // Avoid deadlocks?
defer test.TimeOut(1 * time.Second).Stop() defer test.TimeOut(1 * time.Second).Stop()
a, err := NewAgent(&AgentConfig{}) agent, err := NewAgent(&AgentConfig{})
if err != nil { if err != nil {
t.Fatalf("Failed to create agent: %s", err) t.Fatalf("Failed to create agent: %s", err)
} }
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
relayConfig := &CandidateRelayConfig{ relayConfig := &CandidateRelayConfig{
@@ -1017,9 +1018,9 @@ func TestRemoteCandidateStats(t *testing.T) {
t.Fatalf("Failed to construct remote host candidate: %s", err) t.Fatalf("Failed to construct remote host candidate: %s", err)
} }
a.remoteCandidates[NetworkTypeUDP4] = []Candidate{relayRemote, srflxRemote, prflxRemote, hostRemote} agent.remoteCandidates[NetworkTypeUDP4] = []Candidate{relayRemote, srflxRemote, prflxRemote, hostRemote}
remoteStats := a.GetRemoteCandidatesStats() remoteStats := agent.GetRemoteCandidatesStats()
if len(remoteStats) != 4 { if len(remoteStats) != 4 {
t.Fatalf("expected 4 remote candidates stats, got %d instead", len(remoteStats)) t.Fatalf("expected 4 remote candidates stats, got %d instead", len(remoteStats))
} }
@@ -1076,31 +1077,31 @@ func TestRemoteCandidateStats(t *testing.T) {
func TestInitExtIPMapping(t *testing.T) { func TestInitExtIPMapping(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
// a.extIPMapper should be nil by default // agent.extIPMapper should be nil by default
a, err := NewAgent(&AgentConfig{}) agent, err := NewAgent(&AgentConfig{})
if err != nil { if err != nil {
t.Fatalf("Failed to create agent: %v", err) t.Fatalf("Failed to create agent: %v", err)
} }
if a.extIPMapper != nil { if agent.extIPMapper != nil {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
t.Fatal("a.extIPMapper should be nil by default") t.Fatal("a.extIPMapper should be nil by default")
} }
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
// a.extIPMapper should be nil when NAT1To1IPs is a non-nil empty array // a.extIPMapper should be nil when NAT1To1IPs is a non-nil empty array
a, err = NewAgent(&AgentConfig{ agent, err = NewAgent(&AgentConfig{
NAT1To1IPs: []string{}, NAT1To1IPs: []string{},
NAT1To1IPCandidateType: CandidateTypeHost, NAT1To1IPCandidateType: CandidateTypeHost,
}) })
if err != nil { if err != nil {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
t.Fatalf("Failed to create agent: %v", err) t.Fatalf("Failed to create agent: %v", err)
} }
if a.extIPMapper != nil { if agent.extIPMapper != nil {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
t.Fatal("a.extIPMapper should be nil by default") t.Fatal("a.extIPMapper should be nil by default")
} }
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
// NewAgent should return an error when 1:1 NAT for host candidate is enabled // NewAgent should return an error when 1:1 NAT for host candidate is enabled
// but the candidate type does not appear in the CandidateTypes. // but the candidate type does not appear in the CandidateTypes.
@@ -1150,32 +1151,38 @@ func TestBindingRequestTimeout(t *testing.T) {
const expectedRemovalCount = 2 const expectedRemovalCount = 2
a, err := NewAgent(&AgentConfig{}) agent, err := NewAgent(&AgentConfig{})
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
now := time.Now() now := time.Now()
a.pendingBindingRequests = append(a.pendingBindingRequests, bindingRequest{ agent.pendingBindingRequests = append(agent.pendingBindingRequests, bindingRequest{
timestamp: now, // Valid timestamp: now, // Valid
}) })
a.pendingBindingRequests = append(a.pendingBindingRequests, bindingRequest{ agent.pendingBindingRequests = append(agent.pendingBindingRequests, bindingRequest{
timestamp: now.Add(-3900 * time.Millisecond), // Valid timestamp: now.Add(-3900 * time.Millisecond), // Valid
}) })
a.pendingBindingRequests = append(a.pendingBindingRequests, bindingRequest{ agent.pendingBindingRequests = append(agent.pendingBindingRequests, bindingRequest{
timestamp: now.Add(-4100 * time.Millisecond), // Invalid timestamp: now.Add(-4100 * time.Millisecond), // Invalid
}) })
a.pendingBindingRequests = append(a.pendingBindingRequests, bindingRequest{ agent.pendingBindingRequests = append(agent.pendingBindingRequests, bindingRequest{
timestamp: now.Add(-75 * time.Hour), // Invalid timestamp: now.Add(-75 * time.Hour), // Invalid
}) })
a.invalidatePendingBindingRequests(now) agent.invalidatePendingBindingRequests(now)
require.Equal(t, expectedRemovalCount, len(a.pendingBindingRequests), "Binding invalidation due to timeout did not remove the correct number of binding requests")
require.Equal(
t,
expectedRemovalCount,
len(agent.pendingBindingRequests),
"Binding invalidation due to timeout did not remove the correct number of binding requests",
)
} }
// TestAgentCredentials checks if local username fragments and passwords (if set) meet RFC standard // TestAgentCredentials checks if local username fragments and passwords (if set) meet RFC standard
// and ensure it's backwards compatible with previous versions of the pion/ice // and ensure it's backwards compatible with previous versions of the pion/ice.
func TestAgentCredentials(t *testing.T) { func TestAgentCredentials(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
@@ -1207,7 +1214,7 @@ func TestAgentCredentials(t *testing.T) {
} }
// Assert that Agent on Failure deletes all existing candidates // Assert that Agent on Failure deletes all existing candidates
// User can then do an ICE Restart to bring agent back // User can then do an ICE Restart to bring agent back.
func TestConnectionStateFailedDeleteAllCandidates(t *testing.T) { func TestConnectionStateFailedDeleteAllCandidates(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
@@ -1254,7 +1261,7 @@ func TestConnectionStateFailedDeleteAllCandidates(t *testing.T) {
<-done <-done
} }
// Assert that the ICE Agent can go directly from Connecting -> Failed on both sides // Assert that the ICE Agent can go directly from Connecting -> Failed on both sides.
func TestConnectionStateConnectingToFailed(t *testing.T) { func TestConnectionStateConnectingToFailed(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
@@ -1378,6 +1385,7 @@ func TestAgentRestart(t *testing.T) {
out += c.Address() + ":" out += c.Address() + ":"
out += strconv.Itoa(c.Port()) out += strconv.Itoa(c.Port())
} }
return return
} }
@@ -1423,33 +1431,33 @@ func TestAgentRestart(t *testing.T) {
func TestGetRemoteCredentials(t *testing.T) { func TestGetRemoteCredentials(t *testing.T) {
var config AgentConfig var config AgentConfig
a, err := NewAgent(&config) agent, err := NewAgent(&config)
if err != nil { if err != nil {
t.Fatalf("Error constructing ice.Agent: %v", err) t.Fatalf("Error constructing ice.Agent: %v", err)
} }
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
a.remoteUfrag = "remoteUfrag" agent.remoteUfrag = "remoteUfrag"
a.remotePwd = "remotePwd" agent.remotePwd = "remotePwd"
actualUfrag, actualPwd, err := a.GetRemoteUserCredentials() actualUfrag, actualPwd, err := agent.GetRemoteUserCredentials()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, actualUfrag, a.remoteUfrag) require.Equal(t, actualUfrag, agent.remoteUfrag)
require.Equal(t, actualPwd, a.remotePwd) require.Equal(t, actualPwd, agent.remotePwd)
} }
func TestGetRemoteCandidates(t *testing.T) { func TestGetRemoteCandidates(t *testing.T) {
var config AgentConfig var config AgentConfig
a, err := NewAgent(&config) agent, err := NewAgent(&config)
if err != nil { if err != nil {
t.Fatalf("Error constructing ice.Agent: %v", err) t.Fatalf("Error constructing ice.Agent: %v", err)
} }
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
expectedCandidates := []Candidate{} expectedCandidates := []Candidate{}
@@ -1467,10 +1475,10 @@ func TestGetRemoteCandidates(t *testing.T) {
expectedCandidates = append(expectedCandidates, cand) expectedCandidates = append(expectedCandidates, cand)
a.addRemoteCandidate(cand) agent.addRemoteCandidate(cand)
} }
actualCandidates, err := a.GetRemoteCandidates() actualCandidates, err := agent.GetRemoteCandidates()
require.NoError(t, err) require.NoError(t, err)
require.ElementsMatch(t, expectedCandidates, actualCandidates) require.ElementsMatch(t, expectedCandidates, actualCandidates)
} }
@@ -1478,12 +1486,12 @@ func TestGetRemoteCandidates(t *testing.T) {
func TestGetLocalCandidates(t *testing.T) { func TestGetLocalCandidates(t *testing.T) {
var config AgentConfig var config AgentConfig
a, err := NewAgent(&config) agent, err := NewAgent(&config)
if err != nil { if err != nil {
t.Fatalf("Error constructing ice.Agent: %v", err) t.Fatalf("Error constructing ice.Agent: %v", err)
} }
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
dummyConn := &net.UDPConn{} dummyConn := &net.UDPConn{}
@@ -1502,11 +1510,11 @@ func TestGetLocalCandidates(t *testing.T) {
expectedCandidates = append(expectedCandidates, cand) expectedCandidates = append(expectedCandidates, cand)
err = a.addCandidate(context.Background(), cand, dummyConn) err = agent.addCandidate(context.Background(), cand, dummyConn)
require.NoError(t, err) require.NoError(t, err)
} }
actualCandidates, err := a.GetLocalCandidates() actualCandidates, err := agent.GetLocalCandidates()
require.NoError(t, err) require.NoError(t, err)
require.ElementsMatch(t, expectedCandidates, actualCandidates) require.ElementsMatch(t, expectedCandidates, actualCandidates)
} }
@@ -1666,7 +1674,7 @@ func TestRunTaskInSelectedCandidatePairChangeCallback(t *testing.T) {
<-isTested <-isTested
} }
// Assert that a Lite agent goes to disconnected and failed // Assert that a Lite agent goes to disconnected and failed.
func TestLiteLifecycle(t *testing.T) { func TestLiteLifecycle(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
@@ -1818,7 +1826,7 @@ func TestGetSelectedCandidatePair(t *testing.T) {
require.NoError(t, wan.Stop()) require.NoError(t, wan.Stop())
} }
func TestAcceptAggressiveNomination(t *testing.T) { func TestAcceptAggressiveNomination(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
defer test.TimeOut(time.Second * 30).Stop() defer test.TimeOut(time.Second * 30).Stop()
@@ -1932,24 +1940,25 @@ func TestAcceptAggressiveNomination(t *testing.T) {
bcandidates, err = bAgent.GetLocalCandidates() bcandidates, err = bAgent.GetLocalCandidates()
require.NoError(t, err) require.NoError(t, err)
for _, c := range bcandidates { for _, cand := range bcandidates {
if c != bAgent.getSelectedPair().Local { if cand != bAgent.getSelectedPair().Local { //nolint:nestif
if expectNewSelectedCandidate == nil { if expectNewSelectedCandidate == nil {
expected_change_priority: expected_change_priority:
for _, candidates := range aAgent.remoteCandidates { for _, candidates := range aAgent.remoteCandidates {
for _, candidate := range candidates { for _, candidate := range candidates {
if candidate.Equal(c) { if candidate.Equal(cand) {
if tc.useHigherPriority { if tc.useHigherPriority {
candidate.(*CandidateHost).priorityOverride += 1000 //nolint:forcetypeassert candidate.(*CandidateHost).priorityOverride += 1000 //nolint:forcetypeassert
} else { } else {
candidate.(*CandidateHost).priorityOverride -= 1000 //nolint:forcetypeassert candidate.(*CandidateHost).priorityOverride -= 1000 //nolint:forcetypeassert
} }
break expected_change_priority break expected_change_priority
} }
} }
} }
if tc.isExpectedToSwitch { if tc.isExpectedToSwitch {
expectNewSelectedCandidate = c expectNewSelectedCandidate = cand
} else { } else {
expectNewSelectedCandidate = aAgent.getSelectedPair().Remote expectNewSelectedCandidate = aAgent.getSelectedPair().Remote
} }
@@ -1958,18 +1967,27 @@ func TestAcceptAggressiveNomination(t *testing.T) {
change_priority: change_priority:
for _, candidates := range aAgent.remoteCandidates { for _, candidates := range aAgent.remoteCandidates {
for _, candidate := range candidates { for _, candidate := range candidates {
if candidate.Equal(c) { if candidate.Equal(cand) {
if tc.useHigherPriority { if tc.useHigherPriority {
candidate.(*CandidateHost).priorityOverride += 500 //nolint:forcetypeassert candidate.(*CandidateHost).priorityOverride += 500 //nolint:forcetypeassert
} else { } else {
candidate.(*CandidateHost).priorityOverride -= 500 //nolint:forcetypeassert candidate.(*CandidateHost).priorityOverride -= 500 //nolint:forcetypeassert
} }
break change_priority break change_priority
} }
} }
} }
} }
_, err = c.writeTo(buildMsg(stun.ClassRequest, aAgent.localUfrag+":"+aAgent.remoteUfrag, aAgent.localPwd, c.Priority()).Raw, bAgent.getSelectedPair().Remote) _, err = cand.writeTo(
buildMsg(
stun.ClassRequest,
aAgent.localUfrag+":"+aAgent.remoteUfrag,
aAgent.localPwd,
cand.Priority(),
).Raw,
bAgent.getSelectedPair().Remote,
)
require.NoError(t, err) require.NoError(t, err)
} }
} }
@@ -1991,7 +2009,7 @@ func TestAcceptAggressiveNomination(t *testing.T) {
require.NoError(t, wan.Stop()) require.NoError(t, wan.Stop())
} }
// Close can deadlock but GracefulClose must not // Close can deadlock but GracefulClose must not.
func TestAgentGracefulCloseDeadlock(t *testing.T) { func TestAgentGracefulCloseDeadlock(t *testing.T) {
defer test.CheckRoutinesStrict(t)() defer test.CheckRoutinesStrict(t)()
defer test.TimeOut(time.Second * 5).Stop() defer test.TimeOut(time.Second * 5).Stop()

View File

@@ -16,7 +16,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// TestMuxAgent is an end to end test over UDP mux, ensuring two agents could connect over mux // TestMuxAgent is an end to end test over UDP mux, ensuring two agents could connect over mux.
func TestMuxAgent(t *testing.T) { func TestMuxAgent(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
@@ -58,7 +58,7 @@ func TestMuxAgent(t *testing.T) {
require.NoError(t, muxedA.Close()) require.NoError(t, muxedA.Close())
}() }()
a, err := NewAgent(&AgentConfig{ agent, err := NewAgent(&AgentConfig{
CandidateTypes: []CandidateType{CandidateTypeHost}, CandidateTypes: []CandidateType{CandidateTypeHost},
NetworkTypes: supportedNetworkTypes(), NetworkTypes: supportedNetworkTypes(),
}) })
@@ -68,10 +68,10 @@ func TestMuxAgent(t *testing.T) {
if aClosed { if aClosed {
return return
} }
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
conn, muxedConn := connect(a, muxedA) conn, muxedConn := connect(agent, muxedA)
pair := muxedA.getSelectedPair() pair := muxedA.getSelectedPair()
require.NotNil(t, pair) require.NotNil(t, pair)

View File

@@ -13,13 +13,13 @@ const (
receiveMTU = 8192 receiveMTU = 8192
defaultLocalPreference = 65535 defaultLocalPreference = 65535
// ComponentRTP indicates that the candidate is used for RTP // ComponentRTP indicates that the candidate is used for RTP.
ComponentRTP uint16 = 1 ComponentRTP uint16 = 1
// ComponentRTCP indicates that the candidate is used for RTCP // ComponentRTCP indicates that the candidate is used for RTCP.
ComponentRTCP ComponentRTCP
) )
// Candidate represents an ICE candidate // Candidate represents an ICE candidate.
type Candidate interface { type Candidate interface {
// An arbitrary string used in the freezing algorithm to // An arbitrary string used in the freezing algorithm to
// group similar candidates. It is the same for two candidates that // group similar candidates. It is the same for two candidates that

View File

@@ -48,12 +48,12 @@ type candidateBase struct {
extensions []CandidateExtension extensions []CandidateExtension
} }
// Done implements context.Context // Done implements context.Context.
func (c *candidateBase) Done() <-chan struct{} { func (c *candidateBase) Done() <-chan struct{} {
return c.closeCh return c.closeCh
} }
// Err implements context.Context // Err implements context.Context.
func (c *candidateBase) Err() error { func (c *candidateBase) Err() error {
select { select {
case <-c.closedCh: case <-c.closedCh:
@@ -63,17 +63,17 @@ func (c *candidateBase) Err() error {
} }
} }
// Deadline implements context.Context // Deadline implements context.Context.
func (c *candidateBase) Deadline() (deadline time.Time, ok bool) { func (c *candidateBase) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false return time.Time{}, false
} }
// Value implements context.Context // Value implements context.Context.
func (c *candidateBase) Value(interface{}) interface{} { func (c *candidateBase) Value(interface{}) interface{} {
return nil return nil
} }
// ID returns Candidate ID // ID returns Candidate ID.
func (c *candidateBase) ID() string { func (c *candidateBase) ID() string {
return c.id return c.id
} }
@@ -86,27 +86,27 @@ func (c *candidateBase) Foundation() string {
return fmt.Sprintf("%d", crc32.ChecksumIEEE([]byte(c.Type().String()+c.address+c.networkType.String()))) return fmt.Sprintf("%d", crc32.ChecksumIEEE([]byte(c.Type().String()+c.address+c.networkType.String())))
} }
// Address returns Candidate Address // Address returns Candidate Address.
func (c *candidateBase) Address() string { func (c *candidateBase) Address() string {
return c.address return c.address
} }
// Port returns Candidate Port // Port returns Candidate Port.
func (c *candidateBase) Port() int { func (c *candidateBase) Port() int {
return c.port return c.port
} }
// Type returns candidate type // Type returns candidate type.
func (c *candidateBase) Type() CandidateType { func (c *candidateBase) Type() CandidateType {
return c.candidateType return c.candidateType
} }
// NetworkType returns candidate NetworkType // NetworkType returns candidate NetworkType.
func (c *candidateBase) NetworkType() NetworkType { func (c *candidateBase) NetworkType() NetworkType {
return c.networkType return c.networkType
} }
// Component returns candidate component // Component returns candidate component.
func (c *candidateBase) Component() uint16 { func (c *candidateBase) Component() uint16 {
return c.component return c.component
} }
@@ -115,8 +115,8 @@ func (c *candidateBase) SetComponent(component uint16) {
c.component = component c.component = component
} }
// LocalPreference returns the local preference for this candidate // LocalPreference returns the local preference for this candidate.
func (c *candidateBase) LocalPreference() uint16 { func (c *candidateBase) LocalPreference() uint16 { //nolint:cyclop
if c.NetworkType().IsTCP() { if c.NetworkType().IsTCP() {
// RFC 6544, section 4.2 // RFC 6544, section 4.2
// //
@@ -182,6 +182,7 @@ func (c *candidateBase) LocalPreference() uint16 {
case CandidateTypeUnspecified: case CandidateTypeUnspecified:
return 0 return 0
} }
return 0 return 0
}() }()
@@ -191,7 +192,7 @@ func (c *candidateBase) LocalPreference() uint16 {
return defaultLocalPreference return defaultLocalPreference
} }
// RelatedAddress returns *CandidateRelatedAddress // RelatedAddress returns *CandidateRelatedAddress.
func (c *candidateBase) RelatedAddress() *CandidateRelatedAddress { func (c *candidateBase) RelatedAddress() *CandidateRelatedAddress {
return c.relatedAddress return c.relatedAddress
} }
@@ -200,10 +201,11 @@ func (c *candidateBase) TCPType() TCPType {
return c.tcpType return c.tcpType
} }
// start runs the candidate using the provided connection // start runs the candidate using the provided connection.
func (c *candidateBase) start(a *Agent, conn net.PacketConn, initializedCh <-chan struct{}) { func (c *candidateBase) start(a *Agent, conn net.PacketConn, initializedCh <-chan struct{}) {
if c.conn != nil { if c.conn != nil {
c.agent().log.Warn("Can't start already started candidateBase") c.agent().log.Warn("Can't start already started candidateBase")
return return
} }
c.currAgent = a c.currAgent = a
@@ -221,7 +223,7 @@ var bufferPool = sync.Pool{ // nolint:gochecknoglobals
} }
func (c *candidateBase) recvLoop(initializedCh <-chan struct{}) { func (c *candidateBase) recvLoop(initializedCh <-chan struct{}) {
a := c.agent() agent := c.agent()
defer close(c.closedCh) defer close(c.closedCh)
@@ -242,8 +244,9 @@ func (c *candidateBase) recvLoop(initializedCh <-chan struct{}) {
n, srcAddr, err := c.conn.ReadFrom(buf) n, srcAddr, err := c.conn.ReadFrom(buf)
if err != nil { if err != nil {
if !(errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed)) { if !(errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed)) {
a.log.Warnf("Failed to read from candidate %s: %v", c, err) agent.log.Warnf("Failed to read from candidate %s: %v", c, err)
} }
return return
} }
@@ -254,8 +257,10 @@ func (c *candidateBase) recvLoop(initializedCh <-chan struct{}) {
func (c *candidateBase) validateSTUNTrafficCache(addr net.Addr) bool { func (c *candidateBase) validateSTUNTrafficCache(addr net.Addr) bool {
if candidate, ok := c.remoteCandidateCaches[toAddrPort(addr)]; ok { if candidate, ok := c.remoteCandidateCaches[toAddrPort(addr)]; ok {
candidate.seen(false) candidate.seen(false)
return true return true
} }
return false return false
} }
@@ -267,48 +272,51 @@ func (c *candidateBase) addRemoteCandidateCache(candidate Candidate, srcAddr net
} }
func (c *candidateBase) handleInboundPacket(buf []byte, srcAddr net.Addr) { func (c *candidateBase) handleInboundPacket(buf []byte, srcAddr net.Addr) {
a := c.agent() agent := c.agent()
if stun.IsMessage(buf) { if stun.IsMessage(buf) {
m := &stun.Message{ msg := &stun.Message{
Raw: make([]byte, len(buf)), Raw: make([]byte, len(buf)),
} }
// Explicitly copy raw buffer so Message can own the memory. // Explicitly copy raw buffer so Message can own the memory.
copy(m.Raw, buf) copy(msg.Raw, buf)
if err := msg.Decode(); err != nil {
agent.log.Warnf("Failed to handle decode ICE from %s to %s: %v", c.addr(), srcAddr, err)
if err := m.Decode(); err != nil {
a.log.Warnf("Failed to handle decode ICE from %s to %s: %v", c.addr(), srcAddr, err)
return return
} }
if err := a.loop.Run(c, func(_ context.Context) { if err := agent.loop.Run(c, func(_ context.Context) {
// nolint: contextcheck // nolint: contextcheck
a.handleInbound(m, c, srcAddr) agent.handleInbound(msg, c, srcAddr)
}); err != nil { }); err != nil {
a.log.Warnf("Failed to handle message: %v", err) agent.log.Warnf("Failed to handle message: %v", err)
} }
return return
} }
if !c.validateSTUNTrafficCache(srcAddr) { if !c.validateSTUNTrafficCache(srcAddr) {
remoteCandidate, valid := a.validateNonSTUNTraffic(c, srcAddr) //nolint:contextcheck remoteCandidate, valid := agent.validateNonSTUNTraffic(c, srcAddr) //nolint:contextcheck
if !valid { if !valid {
a.log.Warnf("Discarded message from %s, not a valid remote candidate", c.addr()) agent.log.Warnf("Discarded message from %s, not a valid remote candidate", c.addr())
return return
} }
c.addRemoteCandidateCache(remoteCandidate, srcAddr) c.addRemoteCandidateCache(remoteCandidate, srcAddr)
} }
// Note: This will return packetio.ErrFull if the buffer ever manages to fill up. // Note: This will return packetio.ErrFull if the buffer ever manages to fill up.
if _, err := a.buf.Write(buf); err != nil { if _, err := agent.buf.Write(buf); err != nil {
a.log.Warnf("Failed to write packet: %s", err) agent.log.Warnf("Failed to write packet: %s", err)
return return
} }
} }
// close stops the recvLoop // close stops the recvLoop.
func (c *candidateBase) close() error { func (c *candidateBase) close() error {
// If conn has never been started will be nil // If conn has never been started will be nil
if c.Done() == nil { if c.Done() == nil {
@@ -353,13 +361,15 @@ func (c *candidateBase) writeTo(raw []byte, dst Candidate) (int, error) {
return n, err return n, err
} }
c.agent().log.Infof("Failed to send packet: %v", err) c.agent().log.Infof("Failed to send packet: %v", err)
return n, nil return n, nil
} }
c.seen(true) c.seen(true)
return n, nil return n, nil
} }
// TypePreference returns the type preference for this candidate // TypePreference returns the type preference for this candidate.
func (c *candidateBase) TypePreference() uint16 { func (c *candidateBase) TypePreference() uint16 {
pref := c.Type().Preference() pref := c.Type().Preference()
if pref == 0 { if pref == 0 {
@@ -397,7 +407,7 @@ func (c *candidateBase) Priority() uint32 {
(1<<0)*uint32(256-c.Component()) (1<<0)*uint32(256-c.Component())
} }
// Equal is used to compare two candidateBases // Equal is used to compare two candidateBases.
func (c *candidateBase) Equal(other Candidate) bool { func (c *candidateBase) Equal(other Candidate) bool {
if c.addr() != other.addr() { if c.addr() != other.addr() {
if c.addr() == nil || other.addr() == nil { if c.addr() == nil || other.addr() == nil {
@@ -416,22 +426,30 @@ func (c *candidateBase) Equal(other Candidate) bool {
c.RelatedAddress().Equal(other.RelatedAddress()) c.RelatedAddress().Equal(other.RelatedAddress())
} }
// DeepEqual is same as Equal but also compares the extensions // DeepEqual is same as Equal but also compares the extensions.
func (c *candidateBase) DeepEqual(other Candidate) bool { func (c *candidateBase) DeepEqual(other Candidate) bool {
return c.Equal(other) && c.extensionsEqual(other.Extensions()) return c.Equal(other) && c.extensionsEqual(other.Extensions())
} }
// String makes the candidateBase printable // String makes the candidateBase printable.
func (c *candidateBase) String() string { func (c *candidateBase) String() string {
return fmt.Sprintf("%s %s %s%s (resolved: %v)", c.NetworkType(), c.Type(), net.JoinHostPort(c.Address(), strconv.Itoa(c.Port())), c.relatedAddress, c.resolvedAddr) return fmt.Sprintf(
"%s %s %s%s (resolved: %v)",
c.NetworkType(),
c.Type(),
net.JoinHostPort(c.Address(), strconv.Itoa(c.Port())),
c.relatedAddress,
c.resolvedAddr,
)
} }
// LastReceived returns a time.Time indicating the last time // LastReceived returns a time.Time indicating the last time
// this candidate was received // this candidate was received.
func (c *candidateBase) LastReceived() time.Time { func (c *candidateBase) LastReceived() time.Time {
if lastReceived, ok := c.lastReceived.Load().(time.Time); ok { if lastReceived, ok := c.lastReceived.Load().(time.Time); ok {
return lastReceived return lastReceived
} }
return time.Time{} return time.Time{}
} }
@@ -440,11 +458,12 @@ func (c *candidateBase) setLastReceived(t time.Time) {
} }
// LastSent returns a time.Time indicating the last time // LastSent returns a time.Time indicating the last time
// this candidate was sent // this candidate was sent.
func (c *candidateBase) LastSent() time.Time { func (c *candidateBase) LastSent() time.Time {
if lastSent, ok := c.lastSent.Load().(time.Time); ok { if lastSent, ok := c.lastSent.Load().(time.Time); ok {
return lastSent return lastSent
} }
return time.Time{} return time.Time{}
} }
@@ -484,10 +503,11 @@ func removeZoneIDFromAddress(addr string) string {
if i := strings.Index(addr, "%"); i != -1 { if i := strings.Index(addr, "%"); i != -1 {
return addr[:i] return addr[:i]
} }
return addr return addr
} }
// Marshal returns the string representation of the ICECandidate // Marshal returns the string representation of the ICECandidate.
func (c *candidateBase) Marshal() string { func (c *candidateBase) Marshal() string {
val := c.Foundation() val := c.Foundation()
if val == " " { if val == " " {
@@ -618,9 +638,7 @@ func (c *candidateBase) setExtensions(extensions []CandidateExtension) {
// UnmarshalCandidate Parses a candidate from a string // UnmarshalCandidate Parses a candidate from a string
// https://datatracker.ietf.org/doc/html/rfc5245#section-15.1 // https://datatracker.ietf.org/doc/html/rfc5245#section-15.1
func UnmarshalCandidate(raw string) (Candidate, error) { func UnmarshalCandidate(raw string) (Candidate, error) { //nolint:cyclop
// rfc5245
pos := 0 pos := 0
// foundation ( 1*32ice-char ) But we allow for empty foundation, // foundation ( 1*32ice-char ) But we allow for empty foundation,
@@ -805,15 +823,16 @@ func UnmarshalCandidate(raw string) (Candidate, error) {
// Read an ice-char token from the raw string // Read an ice-char token from the raw string
// ice-char = ALPHA / DIGIT / "+" / "/" // ice-char = ALPHA / DIGIT / "+" / "/"
// stop reading when a space is encountered or the end of the string // stop reading when a space is encountered or the end of the string.
func readCandidateCharToken(raw string, start int, limit int) (string, int, error) { func readCandidateCharToken(raw string, start int, limit int) (string, int, error) { //nolint:cyclop
for i, char := range raw[start:] { for i, char := range raw[start:] {
if char == 0x20 { // SP if char == 0x20 { // SP
return raw[start : start+i], start + i + 1, nil return raw[start : start+i], start + i + 1, nil
} }
if i == limit { if i == limit {
return "", 0, fmt.Errorf("token too long: %s expected 1x%d", raw[start:start+i], limit) //nolint: err113 // handled by caller //nolint: err113 // handled by caller
return "", 0, fmt.Errorf("token too long: %s expected 1x%d", raw[start:start+i], limit)
} }
if !(char >= 'A' && char <= 'Z' || if !(char >= 'A' && char <= 'Z' ||
@@ -828,7 +847,7 @@ func readCandidateCharToken(raw string, start int, limit int) (string, int, erro
} }
// Read an ice string token from the raw string until a space is encountered // Read an ice string token from the raw string until a space is encountered
// Or the end of the string, we imply that ice string are UTF-8 encoded // Or the end of the string, we imply that ice string are UTF-8 encoded.
func readCandidateStringToken(raw string, start int) (string, int) { func readCandidateStringToken(raw string, start int) (string, int) {
for i, char := range raw[start:] { for i, char := range raw[start:] {
if char == 0x20 { // SP if char == 0x20 { // SP
@@ -840,7 +859,7 @@ func readCandidateStringToken(raw string, start int) (string, int) {
} }
// Read a digit token from the raw string // Read a digit token from the raw string
// stop reading when a space is encountered or the end of the string // stop reading when a space is encountered or the end of the string.
func readCandidateDigitToken(raw string, start, limit int) (int, int, error) { func readCandidateDigitToken(raw string, start, limit int) (int, int, error) {
var val int var val int
for i, char := range raw[start:] { for i, char := range raw[start:] {
@@ -849,7 +868,8 @@ func readCandidateDigitToken(raw string, start, limit int) (int, int, error) {
} }
if i == limit { if i == limit {
return 0, 0, fmt.Errorf("token too long: %s expected 1x%d", raw[start:start+i], limit) //nolint: err113 // handled by caller //nolint: err113 // handled by caller
return 0, 0, fmt.Errorf("token too long: %s expected 1x%d", raw[start:start+i], limit)
} }
if !(char >= '0' && char <= '9') { if !(char >= '0' && char <= '9') {
@@ -862,7 +882,7 @@ func readCandidateDigitToken(raw string, start, limit int) (int, int, error) {
return val, len(raw), nil return val, len(raw), nil
} }
// Read and validate RFC 4566 port from the raw string // Read and validate RFC 4566 port from the raw string.
func readCandidatePort(raw string, start int) (int, int, error) { func readCandidatePort(raw string, start int) (int, int, error) {
port, pos, err := readCandidateDigitToken(raw, start, 5) port, pos, err := readCandidateDigitToken(raw, start, 5)
if err != nil { if err != nil {
@@ -878,7 +898,7 @@ func readCandidatePort(raw string, start int) (int, int, error) {
// Read a byte-string token from the raw string // Read a byte-string token from the raw string
// As defined in RFC 4566 1*(%x01-09/%x0B-0C/%x0E-FF) ;any byte except NUL, CR, or LF // As defined in RFC 4566 1*(%x01-09/%x0B-0C/%x0E-FF) ;any byte except NUL, CR, or LF
// we imply that extensions byte-string are UTF-8 encoded // we imply that extensions byte-string are UTF-8 encoded.
func readCandidateByteString(raw string, start int) (string, int, error) { func readCandidateByteString(raw string, start int) (string, int, error) {
for i, char := range raw[start:] { for i, char := range raw[start:] {
if char == 0x20 { // SP if char == 0x20 { // SP
@@ -952,17 +972,23 @@ func unmarshalCandidateExtensions(raw string) (extensions []CandidateExtension,
for i := 0; i < len(raw); { for i := 0; i < len(raw); {
key, next, err := readCandidateByteString(raw, i) key, next, err := readCandidateByteString(raw, i)
if err != nil { if err != nil {
return extensions, "", fmt.Errorf("%w: failed to read key %v", errParseExtension, err) //nolint: errorlint // we wrap the error return extensions, "", fmt.Errorf(
"%w: failed to read key %v", errParseExtension, err, //nolint: errorlint // we wrap the error
)
} }
i = next i = next
if i >= len(raw) { if i >= len(raw) {
return extensions, "", fmt.Errorf("%w: missing value for %s in %s", errParseExtension, key, raw) return extensions, "", fmt.Errorf(
"%w: missing value for %s in %s", errParseExtension, key, raw, //nolint: errorlint // we are wrapping the error
)
} }
value, next, err := readCandidateByteString(raw, i) value, next, err := readCandidateByteString(raw, i)
if err != nil { if err != nil {
return extensions, "", fmt.Errorf("%w: failed to read value %v", errParseExtension, err) //nolint: errorlint // we are wrapping the error return extensions, "", fmt.Errorf(
"%w: failed to read value %v", errParseExtension, err, //nolint: errorlint // we are wrapping the error
)
} }
i = next i = next
@@ -973,5 +999,5 @@ func unmarshalCandidateExtensions(raw string) (extensions []CandidateExtension,
extensions = append(extensions, CandidateExtension{key, value}) extensions = append(extensions, CandidateExtension{key, value})
} }
return return extensions, rawTCPTypeRaw, nil
} }

View File

@@ -8,14 +8,14 @@ import (
"strings" "strings"
) )
// CandidateHost is a candidate of type host // CandidateHost is a candidate of type host.
type CandidateHost struct { type CandidateHost struct {
candidateBase candidateBase
network string network string
} }
// CandidateHostConfig is the config required to create a new CandidateHost // CandidateHostConfig is the config required to create a new CandidateHost.
type CandidateHostConfig struct { type CandidateHostConfig struct {
CandidateID string CandidateID string
Network string Network string
@@ -28,7 +28,7 @@ type CandidateHostConfig struct {
IsLocationTracked bool IsLocationTracked bool
} }
// NewCandidateHost creates a new host candidate // NewCandidateHost creates a new host candidate.
func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) { func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) {
candidateID := config.CandidateID candidateID := config.CandidateID
@@ -36,7 +36,7 @@ func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) {
candidateID = globalCandidateIDGenerator.Generate() candidateID = globalCandidateIDGenerator.Generate()
} }
c := &CandidateHost{ candidateHost := &CandidateHost{
candidateBase: candidateBase{ candidateBase: candidateBase{
id: candidateID, id: candidateID,
address: config.Address, address: config.Address,
@@ -58,15 +58,15 @@ func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) {
return nil, err return nil, err
} }
if err := c.setIPAddr(ipAddr); err != nil { if err := candidateHost.setIPAddr(ipAddr); err != nil {
return nil, err return nil, err
} }
} else { } else {
// Until mDNS candidate is resolved assume it is UDPv4 // Until mDNS candidate is resolved assume it is UDPv4
c.candidateBase.networkType = NetworkTypeUDP4 candidateHost.candidateBase.networkType = NetworkTypeUDP4
} }
return c, nil return candidateHost, nil
} }
func (c *CandidateHost) setIPAddr(addr netip.Addr) error { func (c *CandidateHost) setIPAddr(addr netip.Addr) error {

View File

@@ -15,7 +15,7 @@ type CandidatePeerReflexive struct {
candidateBase candidateBase
} }
// CandidatePeerReflexiveConfig is the config required to create a new CandidatePeerReflexive // CandidatePeerReflexiveConfig is the config required to create a new CandidatePeerReflexive.
type CandidatePeerReflexiveConfig struct { type CandidatePeerReflexiveConfig struct {
CandidateID string CandidateID string
Network string Network string
@@ -28,7 +28,7 @@ type CandidatePeerReflexiveConfig struct {
RelPort int RelPort int
} }
// NewCandidatePeerReflexive creates a new peer reflective candidate // NewCandidatePeerReflexive creates a new peer reflective candidate.
func NewCandidatePeerReflexive(config *CandidatePeerReflexiveConfig) (*CandidatePeerReflexive, error) { func NewCandidatePeerReflexive(config *CandidatePeerReflexiveConfig) (*CandidatePeerReflexive, error) {
ipAddr, err := netip.ParseAddr(config.Address) ipAddr, err := netip.ParseAddr(config.Address)
if err != nil { if err != nil {

View File

@@ -16,7 +16,7 @@ type CandidateRelay struct {
onClose func() error onClose func() error
} }
// CandidateRelayConfig is the config required to create a new CandidateRelay // CandidateRelayConfig is the config required to create a new CandidateRelay.
type CandidateRelayConfig struct { type CandidateRelayConfig struct {
CandidateID string CandidateID string
Network string Network string
@@ -31,7 +31,7 @@ type CandidateRelayConfig struct {
OnClose func() error OnClose func() error
} }
// NewCandidateRelay creates a new relay candidate // NewCandidateRelay creates a new relay candidate.
func NewCandidateRelay(config *CandidateRelayConfig) (*CandidateRelay, error) { func NewCandidateRelay(config *CandidateRelayConfig) (*CandidateRelay, error) {
candidateID := config.CandidateID candidateID := config.CandidateID
@@ -75,7 +75,7 @@ func NewCandidateRelay(config *CandidateRelayConfig) (*CandidateRelay, error) {
}, nil }, nil
} }
// LocalPreference returns the local preference for this candidate // LocalPreference returns the local preference for this candidate.
func (c *CandidateRelay) LocalPreference() uint16 { func (c *CandidateRelay) LocalPreference() uint16 {
// These preference values come from libwebrtc // These preference values come from libwebrtc
// https://github.com/mozilla/libwebrtc/blob/1389c76d9c79839a2ca069df1db48aa3f2e6a1ac/p2p/base/turn_port.cc#L61 // https://github.com/mozilla/libwebrtc/blob/1389c76d9c79839a2ca069df1db48aa3f2e6a1ac/p2p/base/turn_port.cc#L61
@@ -103,6 +103,7 @@ func (c *CandidateRelay) close() error {
err = c.onClose() err = c.onClose()
c.onClose = nil c.onClose = nil
} }
return err return err
} }

View File

@@ -13,7 +13,7 @@ type CandidateServerReflexive struct {
candidateBase candidateBase
} }
// CandidateServerReflexiveConfig is the config required to create a new CandidateServerReflexive // CandidateServerReflexiveConfig is the config required to create a new CandidateServerReflexive.
type CandidateServerReflexiveConfig struct { type CandidateServerReflexiveConfig struct {
CandidateID string CandidateID string
Network string Network string
@@ -26,7 +26,7 @@ type CandidateServerReflexiveConfig struct {
RelPort int RelPort int
} }
// NewCandidateServerReflexive creates a new server reflective candidate // NewCandidateServerReflexive creates a new server reflective candidate.
func NewCandidateServerReflexive(config *CandidateServerReflexiveConfig) (*CandidateServerReflexive, error) { func NewCandidateServerReflexive(config *CandidateServerReflexiveConfig) (*CandidateServerReflexive, error) {
ipAddr, err := netip.ParseAddr(config.Address) ipAddr, err := netip.ParseAddr(config.Address)
if err != nil { if err != nil {

View File

@@ -16,7 +16,7 @@ import (
const localhostIPStr = "127.0.0.1" const localhostIPStr = "127.0.0.1"
func TestCandidateTypePreference(t *testing.T) { func TestCandidateTypePreference(t *testing.T) {
r := require.New(t) req := require.New(t)
hostDefaultPreference := uint16(126) hostDefaultPreference := uint16(126)
prflxDefaultPreference := uint16(110) prflxDefaultPreference := uint16(110)
@@ -53,16 +53,16 @@ func TestCandidateTypePreference(t *testing.T) {
} }
if networkType.IsTCP() { if networkType.IsTCP() {
r.Equal(hostDefaultPreference-tcpOffset, hostCandidate.TypePreference()) req.Equal(hostDefaultPreference-tcpOffset, hostCandidate.TypePreference())
r.Equal(prflxDefaultPreference-tcpOffset, prflxCandidate.TypePreference()) req.Equal(prflxDefaultPreference-tcpOffset, prflxCandidate.TypePreference())
r.Equal(srflxDefaultPreference-tcpOffset, srflxCandidate.TypePreference()) req.Equal(srflxDefaultPreference-tcpOffset, srflxCandidate.TypePreference())
} else { } else {
r.Equal(hostDefaultPreference, hostCandidate.TypePreference()) req.Equal(hostDefaultPreference, hostCandidate.TypePreference())
r.Equal(prflxDefaultPreference, prflxCandidate.TypePreference()) req.Equal(prflxDefaultPreference, prflxCandidate.TypePreference())
r.Equal(srflxDefaultPreference, srflxCandidate.TypePreference()) req.Equal(srflxDefaultPreference, srflxCandidate.TypePreference())
} }
r.Equal(relayDefaultPreference, relayCandidate.TypePreference()) req.Equal(relayDefaultPreference, relayCandidate.TypePreference())
} }
} }
} }
@@ -266,20 +266,27 @@ func TestCandidateFoundation(t *testing.T) {
}).Foundation()) }).Foundation())
} }
func mustCandidateHost(conf *CandidateHostConfig) Candidate { func mustCandidateHost(t *testing.T, conf *CandidateHostConfig) Candidate {
cand, err := NewCandidateHost(conf)
if err != nil {
panic(err)
}
return cand
}
func mustCandidateHostWithExtensions(t *testing.T, conf *CandidateHostConfig, extensions []CandidateExtension) Candidate {
t.Helper() t.Helper()
cand, err := NewCandidateHost(conf) cand, err := NewCandidateHost(conf)
if err != nil { if err != nil {
panic(err) t.Fatal(err)
}
return cand
}
func mustCandidateHostWithExtensions(
t *testing.T,
conf *CandidateHostConfig,
extensions []CandidateExtension,
) Candidate {
t.Helper()
cand, err := NewCandidateHost(conf)
if err != nil {
t.Fatal(err)
} }
cand.setExtensions(extensions) cand.setExtensions(extensions)
@@ -287,20 +294,27 @@ func mustCandidateHostWithExtensions(t *testing.T, conf *CandidateHostConfig, ex
return cand return cand
} }
func mustCandidateRelay(conf *CandidateRelayConfig) Candidate { func mustCandidateRelay(t *testing.T, conf *CandidateRelayConfig) Candidate {
cand, err := NewCandidateRelay(conf)
if err != nil {
panic(err)
}
return cand
}
func mustCandidateRelayWithExtensions(t *testing.T, conf *CandidateRelayConfig, extensions []CandidateExtension) Candidate {
t.Helper() t.Helper()
cand, err := NewCandidateRelay(conf) cand, err := NewCandidateRelay(conf)
if err != nil { if err != nil {
panic(err) t.Fatal(err)
}
return cand
}
func mustCandidateRelayWithExtensions(
t *testing.T,
conf *CandidateRelayConfig,
extensions []CandidateExtension,
) Candidate {
t.Helper()
cand, err := NewCandidateRelay(conf)
if err != nil {
t.Fatal(err)
} }
cand.setExtensions(extensions) cand.setExtensions(extensions)
@@ -308,20 +322,27 @@ func mustCandidateRelayWithExtensions(t *testing.T, conf *CandidateRelayConfig,
return cand return cand
} }
func mustCandidateServerReflexive(conf *CandidateServerReflexiveConfig) Candidate { func mustCandidateServerReflexive(t *testing.T, conf *CandidateServerReflexiveConfig) Candidate {
cand, err := NewCandidateServerReflexive(conf)
if err != nil {
panic(err)
}
return cand
}
func mustCandidateServerReflexiveWithExtensions(t *testing.T, conf *CandidateServerReflexiveConfig, extensions []CandidateExtension) Candidate {
t.Helper() t.Helper()
cand, err := NewCandidateServerReflexive(conf) cand, err := NewCandidateServerReflexive(conf)
if err != nil { if err != nil {
panic(err) t.Fatal(err)
}
return cand
}
func mustCandidateServerReflexiveWithExtensions(
t *testing.T,
conf *CandidateServerReflexiveConfig,
extensions []CandidateExtension,
) Candidate {
t.Helper()
cand, err := NewCandidateServerReflexive(conf)
if err != nil {
t.Fatal(err)
} }
cand.setExtensions(extensions) cand.setExtensions(extensions)
@@ -329,12 +350,16 @@ func mustCandidateServerReflexiveWithExtensions(t *testing.T, conf *CandidateSer
return cand return cand
} }
func mustCandidatePeerReflexiveWithExtensions(t *testing.T, conf *CandidatePeerReflexiveConfig, extensions []CandidateExtension) Candidate { func mustCandidatePeerReflexiveWithExtensions(
t *testing.T,
conf *CandidatePeerReflexiveConfig,
extensions []CandidateExtension,
) Candidate {
t.Helper() t.Helper()
cand, err := NewCandidatePeerReflexive(conf) cand, err := NewCandidatePeerReflexive(conf)
if err != nil { if err != nil {
panic(err) t.Fatal(err)
} }
cand.setExtensions(extensions) cand.setExtensions(extensions)
@@ -349,7 +374,7 @@ func TestCandidateMarshal(t *testing.T) {
expectError bool expectError bool
}{ }{
{ {
mustCandidateHost(&CandidateHostConfig{ mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeUDP6.String(), Network: NetworkTypeUDP6.String(),
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
Port: 53987, Port: 53987,
@@ -360,7 +385,7 @@ func TestCandidateMarshal(t *testing.T) {
false, false,
}, },
{ {
mustCandidateHost(&CandidateHostConfig{ mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeUDP4.String(), Network: NetworkTypeUDP4.String(),
Address: "10.0.75.1", Address: "10.0.75.1",
Port: 53634, Port: 53634,
@@ -369,7 +394,7 @@ func TestCandidateMarshal(t *testing.T) {
false, false,
}, },
{ {
mustCandidateServerReflexive(&CandidateServerReflexiveConfig{ mustCandidateServerReflexive(t, &CandidateServerReflexiveConfig{
Network: NetworkTypeUDP4.String(), Network: NetworkTypeUDP4.String(),
Address: "191.228.238.68", Address: "191.228.238.68",
Port: 53991, Port: 53991,
@@ -395,11 +420,12 @@ func TestCandidateMarshal(t *testing.T) {
{"network-cost", "10"}, {"network-cost", "10"},
}, },
), ),
//nolint: lll
"4207374052 1 tcp 1685790463 192.0.2.15 50000 typ prflx raddr 10.0.0.1 rport 12345 generation 0 network-id 2 network-cost 10", "4207374052 1 tcp 1685790463 192.0.2.15 50000 typ prflx raddr 10.0.0.1 rport 12345 generation 0 network-id 2 network-cost 10",
false, false,
}, },
{ {
mustCandidateRelay(&CandidateRelayConfig{ mustCandidateRelay(t, &CandidateRelayConfig{
Network: NetworkTypeUDP4.String(), Network: NetworkTypeUDP4.String(),
Address: "50.0.0.1", Address: "50.0.0.1",
Port: 5000, Port: 5000,
@@ -410,7 +436,7 @@ func TestCandidateMarshal(t *testing.T) {
false, false,
}, },
{ {
mustCandidateHost(&CandidateHostConfig{ mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeTCP4.String(), Network: NetworkTypeTCP4.String(),
Address: "192.168.0.196", Address: "192.168.0.196",
Port: 0, Port: 0,
@@ -420,7 +446,7 @@ func TestCandidateMarshal(t *testing.T) {
false, false,
}, },
{ {
mustCandidateHost(&CandidateHostConfig{ mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeUDP4.String(), Network: NetworkTypeUDP4.String(),
Address: "e2494022-4d9a-4c1e-a750-cc48d4f8d6ee.local", Address: "e2494022-4d9a-4c1e-a750-cc48d4f8d6ee.local",
Port: 60542, Port: 60542,
@@ -429,7 +455,7 @@ func TestCandidateMarshal(t *testing.T) {
}, },
// Missing Foundation // Missing Foundation
{ {
mustCandidateHost(&CandidateHostConfig{ mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeUDP4.String(), Network: NetworkTypeUDP4.String(),
Address: localhostIPStr, Address: localhostIPStr,
Port: 80, Port: 80,
@@ -440,7 +466,7 @@ func TestCandidateMarshal(t *testing.T) {
false, false,
}, },
{ {
mustCandidateHost(&CandidateHostConfig{ mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeUDP4.String(), Network: NetworkTypeUDP4.String(),
Address: localhostIPStr, Address: localhostIPStr,
Port: 80, Port: 80,
@@ -451,7 +477,7 @@ func TestCandidateMarshal(t *testing.T) {
false, false,
}, },
{ {
mustCandidateHost(&CandidateHostConfig{ mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeTCP4.String(), Network: NetworkTypeTCP4.String(),
Address: "172.28.142.173", Address: "172.28.142.173",
Port: 7686, Port: 7686,
@@ -467,8 +493,10 @@ func TestCandidateMarshal(t *testing.T) {
{nil, "1938809241", true}, {nil, "1938809241", true},
{nil, "1986380506 99999999 udp 2122063615 10.0.75.1 53634 typ host generation 0 network-id 2", true}, {nil, "1986380506 99999999 udp 2122063615 10.0.75.1 53634 typ host generation 0 network-id 2", true},
{nil, "1986380506 1 udp 99999999999 10.0.75.1 53634 typ host", true}, {nil, "1986380506 1 udp 99999999999 10.0.75.1 53634 typ host", true},
//nolint: lll
{nil, "4207374051 1 udp 1685790463 191.228.238.68 99999999 typ srflx raddr 192.168.0.278 rport 53991 generation 0 network-id 3", true}, {nil, "4207374051 1 udp 1685790463 191.228.238.68 99999999 typ srflx raddr 192.168.0.278 rport 53991 generation 0 network-id 3", true},
{nil, "4207374051 1 udp 1685790463 191.228.238.68 53991 typ srflx raddr", true}, {nil, "4207374051 1 udp 1685790463 191.228.238.68 53991 typ srflx raddr", true},
//nolint: lll
{nil, "4207374051 1 udp 1685790463 191.228.238.68 53991 typ srflx raddr 192.168.0.278 rport 99999999 generation 0 network-id 3", true}, {nil, "4207374051 1 udp 1685790463 191.228.238.68 53991 typ srflx raddr 192.168.0.278 rport 99999999 generation 0 network-id 3", true},
{nil, "4207374051 INVALID udp 2130706431 10.0.75.1 53634 typ host", true}, {nil, "4207374051 INVALID udp 2130706431 10.0.75.1 53634 typ host", true},
{nil, "4207374051 1 udp INVALID 10.0.75.1 53634 typ host", true}, {nil, "4207374051 1 udp INVALID 10.0.75.1 53634 typ host", true},
@@ -521,12 +549,19 @@ func TestCandidateMarshal(t *testing.T) {
actualCandidate, err := UnmarshalCandidate(test.marshaled) actualCandidate, err := UnmarshalCandidate(test.marshaled)
if test.expectError { if test.expectError {
require.Error(t, err, "expected error", test.marshaled) require.Error(t, err, "expected error", test.marshaled)
return return
} }
require.NoError(t, err) require.NoError(t, err)
require.Truef(t, test.candidate.Equal(actualCandidate), "%s != %s", test.candidate.String(), actualCandidate.String()) require.Truef(
t,
test.candidate.Equal(actualCandidate),
"%s != %s",
test.candidate.String(),
actualCandidate.String(),
)
require.Equal(t, test.marshaled, actualCandidate.Marshal()) require.Equal(t, test.marshaled, actualCandidate.Marshal())
}) })
} }
@@ -573,7 +608,7 @@ func TestCandidateWriteTo(t *testing.T) {
} }
func TestMarshalUnmarshalCandidateWithZoneID(t *testing.T) { func TestMarshalUnmarshalCandidateWithZoneID(t *testing.T) {
candidateWithZoneID := mustCandidateHost(&CandidateHostConfig{ candidateWithZoneID := mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeUDP6.String(), Network: NetworkTypeUDP6.String(),
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a%Local Connection", Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a%Local Connection",
Port: 53987, Port: 53987,
@@ -583,7 +618,7 @@ func TestMarshalUnmarshalCandidateWithZoneID(t *testing.T) {
candidateStr := "750 0 udp 500 fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a 53987 typ host" candidateStr := "750 0 udp 500 fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a 53987 typ host"
require.Equal(t, candidateStr, candidateWithZoneID.Marshal()) require.Equal(t, candidateStr, candidateWithZoneID.Marshal())
candidate := mustCandidateHost(&CandidateHostConfig{ candidate := mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeUDP6.String(), Network: NetworkTypeUDP6.String(),
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
Port: 53987, Port: 53987,
@@ -612,6 +647,7 @@ func TestCandidateExtensionsMarshal(t *testing.T) {
{"ufrag", "QNvE"}, {"ufrag", "QNvE"},
{"network-id", "4"}, {"network-id", "4"},
}, },
//nolint: lll
"1299692247 1 udp 2122134271 fdc8:cc8:c835:e400:343c:feb:32c8:17b9 58240 typ host generation 0 ufrag QNvE network-id 4", "1299692247 1 udp 2122134271 fdc8:cc8:c835:e400:343c:feb:32c8:17b9 58240 typ host generation 0 ufrag QNvE network-id 4",
}, },
{ {
@@ -620,6 +656,7 @@ func TestCandidateExtensionsMarshal(t *testing.T) {
{"network-id", "2"}, {"network-id", "2"},
{"network-cost", "50"}, {"network-cost", "50"},
}, },
//nolint:lll
"647372371 1 udp 1694498815 191.228.238.68 53991 typ srflx raddr 192.168.0.274 rport 53991 generation 1 network-id 2 network-cost 50", "647372371 1 udp 1694498815 191.228.238.68 53991 typ srflx raddr 192.168.0.274 rport 53991 generation 1 network-id 2 network-cost 50",
}, },
{ {
@@ -628,6 +665,7 @@ func TestCandidateExtensionsMarshal(t *testing.T) {
{"network-id", "2"}, {"network-id", "2"},
{"network-cost", "10"}, {"network-cost", "10"},
}, },
//nolint:lll
"4207374052 1 tcp 1685790463 192.0.2.15 50000 typ prflx raddr 10.0.0.1 rport 12345 generation 0 network-id 2 network-cost 10", "4207374052 1 tcp 1685790463 192.0.2.15 50000 typ prflx raddr 10.0.0.1 rport 12345 generation 0 network-id 2 network-cost 10",
}, },
{ {
@@ -638,6 +676,7 @@ func TestCandidateExtensionsMarshal(t *testing.T) {
{"ufrag", "frag42abcdef"}, {"ufrag", "frag42abcdef"},
{"password", "abc123exp123"}, {"password", "abc123exp123"},
}, },
//nolint: lll
"848194626 1 udp 16777215 50.0.0.1 5000 typ relay raddr 192.168.0.1 rport 5001 generation 0 network-id 1 network-cost 20 ufrag frag42abcdef password abc123exp123", "848194626 1 udp 16777215 50.0.0.1 5000 typ relay raddr 192.168.0.1 rport 5001 generation 0 network-id 1 network-cost 20 ufrag frag42abcdef password abc123exp123",
}, },
{ {
@@ -703,7 +742,7 @@ func TestCandidateExtensionsDeepEqual(t *testing.T) {
equal bool equal bool
}{ }{
{ {
mustCandidateHost(&CandidateHostConfig{ mustCandidateHost(t, &CandidateHostConfig{
Network: NetworkTypeUDP4.String(), Network: NetworkTypeUDP4.String(),
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
Port: 53987, Port: 53987,

View File

@@ -20,8 +20,7 @@ func newCandidatePair(local, remote Candidate, controlling bool) *CandidatePair
} }
} }
// CandidatePair is a combination of a // CandidatePair is a combination of a local and remote candidate.
// local and remote candidate
type CandidatePair struct { type CandidatePair struct {
iceRoleControlling bool iceRoleControlling bool
Remote Candidate Remote Candidate
@@ -42,8 +41,17 @@ func (p *CandidatePair) String() string {
return "" return ""
} }
return fmt.Sprintf("prio %d (local, prio %d) %s <-> %s (remote, prio %d), state: %s, nominated: %v, nominateOnBindingSuccess: %v", return fmt.Sprintf(
p.priority(), p.Local.Priority(), p.Local, p.Remote, p.Remote.Priority(), p.state, p.nominated, p.nominateOnBindingSuccess) "prio %d (local, prio %d) %s <-> %s (remote, prio %d), state: %s, nominated: %v, nominateOnBindingSuccess: %v",
p.priority(),
p.Local.Priority(),
p.Local,
p.Remote,
p.Remote.Priority(),
p.state,
p.nominated,
p.nominateOnBindingSuccess,
)
} }
func (p *CandidatePair) equal(other *CandidatePair) bool { func (p *CandidatePair) equal(other *CandidatePair) bool {
@@ -53,6 +61,7 @@ func (p *CandidatePair) equal(other *CandidatePair) bool {
if p == nil || other == nil { if p == nil || other == nil {
return false return false
} }
return p.Local.Equal(other.Local) && p.Remote.Equal(other.Remote) return p.Local.Equal(other.Local) && p.Remote.Equal(other.Remote)
} }
@@ -60,9 +69,9 @@ func (p *CandidatePair) equal(other *CandidatePair) bool {
// Let G be the priority for the candidate provided by the controlling // Let G be the priority for the candidate provided by the controlling
// agent. Let D be the priority for the candidate provided by the // agent. Let D be the priority for the candidate provided by the
// controlled agent. // controlled agent.
// pair priority = 2^32*MIN(G,D) + 2*MAX(G,D) + (G>D?1:0) // pair priority = 2^32*MIN(G,D) + 2*MAX(G,D) + (G>D?1:0).
func (p *CandidatePair) priority() uint64 { func (p *CandidatePair) priority() uint64 {
var g, d uint32 var g, d uint32 //nolint:varnamelen // clearer to use g and d here
if p.iceRoleControlling { if p.iceRoleControlling {
g = p.Local.Priority() g = p.Local.Priority()
d = p.Remote.Priority() d = p.Remote.Priority()
@@ -77,18 +86,21 @@ func (p *CandidatePair) priority() uint64 {
if x < y { if x < y {
return uint64(x) return uint64(x)
} }
return uint64(y) return uint64(y)
} }
localMax := func(x, y uint32) uint64 { localMax := func(x, y uint32) uint64 {
if x > y { if x > y {
return uint64(x) return uint64(x)
} }
return uint64(y) return uint64(y)
} }
cmp := func(x, y uint32) uint64 { cmp := func(x, y uint32) uint64 {
if x > y { if x > y {
return uint64(1) return uint64(1)
} }
return uint64(0) return uint64(0)
} }
@@ -109,7 +121,7 @@ func (a *Agent) sendSTUN(msg *stun.Message, local, remote Candidate) {
} }
// UpdateRoundTripTime sets the current round time of this pair and // UpdateRoundTripTime sets the current round time of this pair and
// accumulates total round trip time and responses received // accumulates total round trip time and responses received.
func (p *CandidatePair) UpdateRoundTripTime(rtt time.Duration) { func (p *CandidatePair) UpdateRoundTripTime(rtt time.Duration) {
rttNs := rtt.Nanoseconds() rttNs := rtt.Nanoseconds()
atomic.StoreInt64(&p.currentRoundTripTime, rttNs) atomic.StoreInt64(&p.currentRoundTripTime, rttNs)

View File

@@ -3,12 +3,12 @@
package ice package ice
// CandidatePairState represent the ICE candidate pair state // CandidatePairState represent the ICE candidate pair state.
type CandidatePairState int type CandidatePairState int
const ( const (
// CandidatePairStateWaiting means a check has not been performed for // CandidatePairStateWaiting means a check has not been performed for
// this pair // this pair.
CandidatePairStateWaiting CandidatePairState = iota + 1 CandidatePairStateWaiting CandidatePairState = iota + 1
// CandidatePairStateInProgress means a check has been sent for this pair, // CandidatePairStateInProgress means a check has been sent for this pair,
@@ -36,5 +36,6 @@ func (c CandidatePairState) String() string {
case CandidatePairStateSucceeded: case CandidatePairStateSucceeded:
return "succeeded" return "succeeded"
} }
return "Unknown candidate pair state" return "Unknown candidate pair state"
} }

View File

@@ -12,7 +12,7 @@ type CandidateRelatedAddress struct {
Port int Port int
} }
// String makes CandidateRelatedAddress printable // String makes CandidateRelatedAddress printable.
func (c *CandidateRelatedAddress) String() string { func (c *CandidateRelatedAddress) String() string {
if c == nil { if c == nil {
return "" return ""
@@ -27,6 +27,7 @@ func (c *CandidateRelatedAddress) Equal(other *CandidateRelatedAddress) bool {
if c == nil && other == nil { if c == nil && other == nil {
return true return true
} }
return c != nil && other != nil && return c != nil && other != nil &&
c.Address == other.Address && c.Address == other.Address &&
c.Port == other.Port c.Port == other.Port

View File

@@ -3,10 +3,10 @@
package ice package ice
// CandidateType represents the type of candidate // CandidateType represents the type of candidate.
type CandidateType byte type CandidateType byte
// CandidateType enum // CandidateType enum.
const ( const (
CandidateTypeUnspecified CandidateType = iota CandidateTypeUnspecified CandidateType = iota
CandidateTypeHost CandidateTypeHost
@@ -15,7 +15,7 @@ const (
CandidateTypeRelay CandidateTypeRelay
) )
// String makes CandidateType printable // String makes CandidateType printable.
func (c CandidateType) String() string { func (c CandidateType) String() string {
switch c { switch c {
case CandidateTypeHost: case CandidateTypeHost:
@@ -29,6 +29,7 @@ func (c CandidateType) String() string {
case CandidateTypeUnspecified: case CandidateTypeUnspecified:
return "Unknown candidate type" return "Unknown candidate type"
} }
return "Unknown candidate type" return "Unknown candidate type"
} }
@@ -49,6 +50,7 @@ func (c CandidateType) Preference() uint16 {
case CandidateTypeRelay, CandidateTypeUnspecified: case CandidateTypeRelay, CandidateTypeUnspecified:
return 0 return 0
} }
return 0 return 0
} }
@@ -61,5 +63,6 @@ func containsCandidateType(candidateType CandidateType, candidateTypeList []Cand
return true return true
} }
} }
return false return false
} }

View File

@@ -45,7 +45,7 @@ func (v *virtualNet) close() {
v.wan.Stop() //nolint:errcheck,gosec v.wan.Stop() //nolint:errcheck,gosec
} }
func buildVNet(natType0, natType1 *vnet.NATType) (*virtualNet, error) { func buildVNet(natType0, natType1 *vnet.NATType) (*virtualNet, error) { //nolint:cyclop
loggerFactory := logging.NewDefaultLoggerFactory() loggerFactory := logging.NewDefaultLoggerFactory()
// WAN // WAN
@@ -77,6 +77,7 @@ func buildVNet(natType0, natType1 *vnet.NATType) (*virtualNet, error) {
vnetGlobalIPA + "/" + vnetLocalIPA, vnetGlobalIPA + "/" + vnetLocalIPA,
} }
} }
return []string{ return []string{
vnetGlobalIPA, vnetGlobalIPA,
} }
@@ -114,6 +115,7 @@ func buildVNet(natType0, natType1 *vnet.NATType) (*virtualNet, error) {
vnetGlobalIPB + "/" + vnetLocalIPB, vnetGlobalIPB + "/" + vnetLocalIPB,
} }
} }
return []string{ return []string{
vnetGlobalIPB, vnetGlobalIPB,
} }
@@ -175,6 +177,7 @@ func addVNetSTUN(wanNet *vnet.Net, loggerFactory logging.LoggerFactory) (*turn.S
if pw, ok := credMap[username]; ok { if pw, ok := credMap[username]; ok {
return turn.GenerateAuthKey(username, realm, pw), true return turn.GenerateAuthKey(username, realm, pw), true
} }
return nil, false return nil, false
}, },
PacketConnConfigs: []turn.PacketConnConfig{ PacketConnConfigs: []turn.PacketConnConfig{
@@ -222,6 +225,7 @@ func connectWithVNet(aAgent, bAgent *Agent) (*Conn, *Conn) {
// Ensure accepted // Ensure accepted
<-accepted <-accepted
return aConn, bConn return aConn, bConn
} }
@@ -230,7 +234,7 @@ type agentTestConfig struct {
nat1To1IPCandidateType CandidateType nat1To1IPCandidateType CandidateType
} }
func pipeWithVNet(v *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig) (*Conn, *Conn) { func pipeWithVNet(vnet *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig) (*Conn, *Conn) {
aNotifier, aConnected := onConnected() aNotifier, aConnected := onConnected()
bNotifier, bConnected := onConnected() bNotifier, bConnected := onConnected()
@@ -247,7 +251,7 @@ func pipeWithVNet(v *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig) (*
MulticastDNSMode: MulticastDNSModeDisabled, MulticastDNSMode: MulticastDNSModeDisabled,
NAT1To1IPs: nat1To1IPs, NAT1To1IPs: nat1To1IPs,
NAT1To1IPCandidateType: a0TestConfig.nat1To1IPCandidateType, NAT1To1IPCandidateType: a0TestConfig.nat1To1IPCandidateType,
Net: v.net0, Net: vnet.net0,
} }
aAgent, err := NewAgent(cfg0) aAgent, err := NewAgent(cfg0)
@@ -270,7 +274,7 @@ func pipeWithVNet(v *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig) (*
MulticastDNSMode: MulticastDNSModeDisabled, MulticastDNSMode: MulticastDNSModeDisabled,
NAT1To1IPs: nat1To1IPs, NAT1To1IPs: nat1To1IPs,
NAT1To1IPCandidateType: a1TestConfig.nat1To1IPCandidateType, NAT1To1IPCandidateType: a1TestConfig.nat1To1IPCandidateType,
Net: v.net1, Net: vnet.net1,
} }
bAgent, err := NewAgent(cfg1) bAgent, err := NewAgent(cfg1)
@@ -293,6 +297,8 @@ func pipeWithVNet(v *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig) (*
} }
func closePipe(t *testing.T, ca *Conn, cb *Conn) { func closePipe(t *testing.T, ca *Conn, cb *Conn) {
t.Helper()
require.NoError(t, ca.Close()) require.NoError(t, ca.Close())
require.NoError(t, cb.Close()) require.NoError(t, cb.Close())
} }
@@ -325,10 +331,10 @@ func TestConnectivityVNet(t *testing.T) {
MappingBehavior: vnet.EndpointIndependent, MappingBehavior: vnet.EndpointIndependent,
FilteringBehavior: vnet.EndpointIndependent, FilteringBehavior: vnet.EndpointIndependent,
} }
v, err := buildVNet(natType, natType) vnet, err := buildVNet(natType, natType)
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
defer v.close() defer vnet.close()
log.Debug("Connecting...") log.Debug("Connecting...")
a0TestConfig := &agentTestConfig{ a0TestConfig := &agentTestConfig{
@@ -341,7 +347,7 @@ func TestConnectivityVNet(t *testing.T) {
stunServerURL, stunServerURL,
}, },
} }
ca, cb := pipeWithVNet(v, a0TestConfig, a1TestConfig) ca, cb := pipeWithVNet(vnet, a0TestConfig, a1TestConfig)
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
@@ -358,10 +364,10 @@ func TestConnectivityVNet(t *testing.T) {
MappingBehavior: vnet.EndpointAddrPortDependent, MappingBehavior: vnet.EndpointAddrPortDependent,
FilteringBehavior: vnet.EndpointAddrPortDependent, FilteringBehavior: vnet.EndpointAddrPortDependent,
} }
v, err := buildVNet(natType, natType) vnet, err := buildVNet(natType, natType)
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
defer v.close() defer vnet.close()
log.Debug("Connecting...") log.Debug("Connecting...")
a0TestConfig := &agentTestConfig{ a0TestConfig := &agentTestConfig{
@@ -375,7 +381,7 @@ func TestConnectivityVNet(t *testing.T) {
stunServerURL, stunServerURL,
}, },
} }
ca, cb := pipeWithVNet(v, a0TestConfig, a1TestConfig) ca, cb := pipeWithVNet(vnet, a0TestConfig, a1TestConfig)
log.Debug("Closing...") log.Debug("Closing...")
closePipe(t, ca, cb) closePipe(t, ca, cb)
@@ -394,10 +400,10 @@ func TestConnectivityVNet(t *testing.T) {
MappingBehavior: vnet.EndpointAddrPortDependent, MappingBehavior: vnet.EndpointAddrPortDependent,
FilteringBehavior: vnet.EndpointAddrPortDependent, FilteringBehavior: vnet.EndpointAddrPortDependent,
} }
v, err := buildVNet(natType0, natType1) vnet, err := buildVNet(natType0, natType1)
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
defer v.close() defer vnet.close()
log.Debug("Connecting...") log.Debug("Connecting...")
a0TestConfig := &agentTestConfig{ a0TestConfig := &agentTestConfig{
@@ -407,7 +413,7 @@ func TestConnectivityVNet(t *testing.T) {
a1TestConfig := &agentTestConfig{ a1TestConfig := &agentTestConfig{
urls: []*stun.URI{}, urls: []*stun.URI{},
} }
ca, cb := pipeWithVNet(v, a0TestConfig, a1TestConfig) ca, cb := pipeWithVNet(vnet, a0TestConfig, a1TestConfig)
log.Debug("Closing...") log.Debug("Closing...")
closePipe(t, ca, cb) closePipe(t, ca, cb)
@@ -426,10 +432,10 @@ func TestConnectivityVNet(t *testing.T) {
MappingBehavior: vnet.EndpointAddrPortDependent, MappingBehavior: vnet.EndpointAddrPortDependent,
FilteringBehavior: vnet.EndpointAddrPortDependent, FilteringBehavior: vnet.EndpointAddrPortDependent,
} }
v, err := buildVNet(natType0, natType1) vnet, err := buildVNet(natType0, natType1)
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
defer v.close() defer vnet.close()
log.Debug("Connecting...") log.Debug("Connecting...")
a0TestConfig := &agentTestConfig{ a0TestConfig := &agentTestConfig{
@@ -439,14 +445,15 @@ func TestConnectivityVNet(t *testing.T) {
a1TestConfig := &agentTestConfig{ a1TestConfig := &agentTestConfig{
urls: []*stun.URI{}, urls: []*stun.URI{},
} }
ca, cb := pipeWithVNet(v, a0TestConfig, a1TestConfig) ca, cb := pipeWithVNet(vnet, a0TestConfig, a1TestConfig)
log.Debug("Closing...") log.Debug("Closing...")
closePipe(t, ca, cb) closePipe(t, ca, cb)
}) })
} }
// TestDisconnectedToConnected requires that an agent can go to disconnected, and then return to connected successfully // TestDisconnectedToConnected requires that an agent can go to disconnected,
// and then return to connected successfully.
func TestDisconnectedToConnected(t *testing.T) { func TestDisconnectedToConnected(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
@@ -546,7 +553,7 @@ func TestDisconnectedToConnected(t *testing.T) {
require.NoError(t, wan.Stop()) require.NoError(t, wan.Stop())
} }
// Agent.Write should use the best valid pair if a selected pair is not yet available // Agent.Write should use the best valid pair if a selected pair is not yet available.
func TestWriteUseValidPair(t *testing.T) { func TestWriteUseValidPair(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()

View File

@@ -29,69 +29,71 @@ var (
ErrPort = errors.New("invalid port") ErrPort = errors.New("invalid port")
// ErrLocalUfragInsufficientBits indicates local username fragment insufficient bits are provided. // ErrLocalUfragInsufficientBits indicates local username fragment insufficient bits are provided.
// Have to be at least 24 bits long // Have to be at least 24 bits long.
ErrLocalUfragInsufficientBits = errors.New("local username fragment is less than 24 bits long") ErrLocalUfragInsufficientBits = errors.New("local username fragment is less than 24 bits long")
// ErrLocalPwdInsufficientBits indicates local password insufficient bits are provided. // ErrLocalPwdInsufficientBits indicates local password insufficient bits are provided.
// Have to be at least 128 bits long // Have to be at least 128 bits long.
ErrLocalPwdInsufficientBits = errors.New("local password is less than 128 bits long") ErrLocalPwdInsufficientBits = errors.New("local password is less than 128 bits long")
// ErrProtoType indicates an unsupported transport type was provided. // ErrProtoType indicates an unsupported transport type was provided.
ErrProtoType = errors.New("invalid transport protocol type") ErrProtoType = errors.New("invalid transport protocol type")
// ErrClosed indicates the agent is closed // ErrClosed indicates the agent is closed.
ErrClosed = taskloop.ErrClosed ErrClosed = taskloop.ErrClosed
// ErrNoCandidatePairs indicates agent does not have a valid candidate pair // ErrNoCandidatePairs indicates agent does not have a valid candidate pair.
ErrNoCandidatePairs = errors.New("no candidate pairs available") ErrNoCandidatePairs = errors.New("no candidate pairs available")
// ErrCanceledByCaller indicates agent connection was canceled by the caller // ErrCanceledByCaller indicates agent connection was canceled by the caller.
ErrCanceledByCaller = errors.New("connecting canceled by caller") ErrCanceledByCaller = errors.New("connecting canceled by caller")
// ErrMultipleStart indicates agent was started twice // ErrMultipleStart indicates agent was started twice.
ErrMultipleStart = errors.New("attempted to start agent twice") ErrMultipleStart = errors.New("attempted to start agent twice")
// ErrRemoteUfragEmpty indicates agent was started with an empty remote ufrag // ErrRemoteUfragEmpty indicates agent was started with an empty remote ufrag.
ErrRemoteUfragEmpty = errors.New("remote ufrag is empty") ErrRemoteUfragEmpty = errors.New("remote ufrag is empty")
// ErrRemotePwdEmpty indicates agent was started with an empty remote pwd // ErrRemotePwdEmpty indicates agent was started with an empty remote pwd.
ErrRemotePwdEmpty = errors.New("remote pwd is empty") ErrRemotePwdEmpty = errors.New("remote pwd is empty")
// ErrNoOnCandidateHandler indicates agent was started without OnCandidate // ErrNoOnCandidateHandler indicates agent was started without OnCandidate.
ErrNoOnCandidateHandler = errors.New("no OnCandidate provided") ErrNoOnCandidateHandler = errors.New("no OnCandidate provided")
// ErrMultipleGatherAttempted indicates GatherCandidates has been called multiple times // ErrMultipleGatherAttempted indicates GatherCandidates has been called multiple times.
ErrMultipleGatherAttempted = errors.New("attempting to gather candidates during gathering state") ErrMultipleGatherAttempted = errors.New("attempting to gather candidates during gathering state")
// ErrUsernameEmpty indicates agent was give TURN URL with an empty Username // ErrUsernameEmpty indicates agent was give TURN URL with an empty Username.
ErrUsernameEmpty = errors.New("username is empty") ErrUsernameEmpty = errors.New("username is empty")
// ErrPasswordEmpty indicates agent was give TURN URL with an empty Password // ErrPasswordEmpty indicates agent was give TURN URL with an empty Password.
ErrPasswordEmpty = errors.New("password is empty") ErrPasswordEmpty = errors.New("password is empty")
// ErrAddressParseFailed indicates we were unable to parse a candidate address // ErrAddressParseFailed indicates we were unable to parse a candidate address.
ErrAddressParseFailed = errors.New("failed to parse address") ErrAddressParseFailed = errors.New("failed to parse address")
// ErrLiteUsingNonHostCandidates indicates non host candidates were selected for a lite agent // ErrLiteUsingNonHostCandidates indicates non host candidates were selected for a lite agent.
ErrLiteUsingNonHostCandidates = errors.New("lite agents must only use host candidates") ErrLiteUsingNonHostCandidates = errors.New("lite agents must only use host candidates")
// ErrUselessUrlsProvided indicates that one or more URL was provided to the agent but no host // ErrUselessUrlsProvided indicates that one or more URL was provided to the agent but no host
// candidate required them // candidate required them.
ErrUselessUrlsProvided = errors.New("agent does not need URL with selected candidate types") ErrUselessUrlsProvided = errors.New("agent does not need URL with selected candidate types")
// ErrUnsupportedNAT1To1IPCandidateType indicates that the specified NAT1To1IPCandidateType is // ErrUnsupportedNAT1To1IPCandidateType indicates that the specified NAT1To1IPCandidateType is
// unsupported // unsupported.
ErrUnsupportedNAT1To1IPCandidateType = errors.New("unsupported 1:1 NAT IP candidate type") ErrUnsupportedNAT1To1IPCandidateType = errors.New("unsupported 1:1 NAT IP candidate type")
// ErrInvalidNAT1To1IPMapping indicates that the given 1:1 NAT IP mapping is invalid // ErrInvalidNAT1To1IPMapping indicates that the given 1:1 NAT IP mapping is invalid.
ErrInvalidNAT1To1IPMapping = errors.New("invalid 1:1 NAT IP mapping") ErrInvalidNAT1To1IPMapping = errors.New("invalid 1:1 NAT IP mapping")
// ErrExternalMappedIPNotFound in NAT1To1IPMapping // ErrExternalMappedIPNotFound in NAT1To1IPMapping.
ErrExternalMappedIPNotFound = errors.New("external mapped IP not found") ErrExternalMappedIPNotFound = errors.New("external mapped IP not found")
// ErrMulticastDNSWithNAT1To1IPMapping indicates that the mDNS gathering cannot be used along // ErrMulticastDNSWithNAT1To1IPMapping indicates that the mDNS gathering cannot be used along
// with 1:1 NAT IP mapping for host candidate. // with 1:1 NAT IP mapping for host candidate.
ErrMulticastDNSWithNAT1To1IPMapping = errors.New("mDNS gathering cannot be used with 1:1 NAT IP mapping for host candidate") ErrMulticastDNSWithNAT1To1IPMapping = errors.New(
"mDNS gathering cannot be used with 1:1 NAT IP mapping for host candidate",
)
// ErrIneffectiveNAT1To1IPMappingHost indicates that 1:1 NAT IP mapping for host candidate is // ErrIneffectiveNAT1To1IPMappingHost indicates that 1:1 NAT IP mapping for host candidate is
// requested, but the host candidate type is disabled. // requested, but the host candidate type is disabled.
@@ -101,10 +103,12 @@ var (
// requested, but the srflx candidate type is disabled. // requested, but the srflx candidate type is disabled.
ErrIneffectiveNAT1To1IPMappingSrflx = errors.New("1:1 NAT IP mapping for srflx candidate ineffective") ErrIneffectiveNAT1To1IPMappingSrflx = errors.New("1:1 NAT IP mapping for srflx candidate ineffective")
// ErrInvalidMulticastDNSHostName indicates an invalid MulticastDNSHostName // ErrInvalidMulticastDNSHostName indicates an invalid MulticastDNSHostName.
ErrInvalidMulticastDNSHostName = errors.New("invalid mDNS HostName, must end with .local and can only contain a single '.'") ErrInvalidMulticastDNSHostName = errors.New(
"invalid mDNS HostName, must end with .local and can only contain a single '.'",
)
// ErrRunCanceled indicates a run operation was canceled by its individual done // ErrRunCanceled indicates a run operation was canceled by its individual done.
ErrRunCanceled = errors.New("run was canceled by done") ErrRunCanceled = errors.New("run was canceled by done")
// ErrTCPRemoteAddrAlreadyExists indicates we already have the connection with same remote addr. // ErrTCPRemoteAddrAlreadyExists indicates we already have the connection with same remote addr.
@@ -113,7 +117,7 @@ var (
// ErrUnknownCandidateTyp indicates that a candidate had a unknown type value. // ErrUnknownCandidateTyp indicates that a candidate had a unknown type value.
ErrUnknownCandidateTyp = errors.New("unknown candidate typ") ErrUnknownCandidateTyp = errors.New("unknown candidate typ")
// ErrDetermineNetworkType indicates that the NetworkType was not able to be parsed // ErrDetermineNetworkType indicates that the NetworkType was not able to be parsed.
ErrDetermineNetworkType = errors.New("unable to determine networkType") ErrDetermineNetworkType = errors.New("unable to determine networkType")
errAttributeTooShortICECandidate = errors.New("attribute not long enough to be ICE candidate") errAttributeTooShortICECandidate = errors.New("attribute not long enough to be ICE candidate")
@@ -144,5 +148,5 @@ var (
// UDPMuxDefault should not listen on unspecified address, but to keep backward compatibility, don't return error now. // UDPMuxDefault should not listen on unspecified address, but to keep backward compatibility, don't return error now.
// will be used in the future. // will be used in the future.
// errListenUnspecified = errors.New("can't listen on unspecified address") // errListenUnspecified = errors.New("can't listen on unspecified address").
) )

View File

@@ -26,7 +26,7 @@ var (
localHTTPPort, remoteHTTPPort int localHTTPPort, remoteHTTPPort int
) )
// HTTP Listener to get ICE Credentials from remote Peer // HTTP Listener to get ICE Credentials from remote Peer.
func remoteAuth(_ http.ResponseWriter, r *http.Request) { func remoteAuth(_ http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
panic(err) panic(err)
@@ -36,7 +36,7 @@ func remoteAuth(_ http.ResponseWriter, r *http.Request) {
remoteAuthChannel <- r.PostForm["pwd"][0] remoteAuthChannel <- r.PostForm["pwd"][0]
} }
// HTTP Listener to get ICE Candidate from remote Peer // HTTP Listener to get ICE Candidate from remote Peer.
func remoteCandidate(_ http.ResponseWriter, r *http.Request) { func remoteCandidate(_ http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
panic(err) panic(err)

View File

@@ -13,10 +13,13 @@ func validateIPString(ipStr string) (net.IP, bool, error) {
if ip == nil { if ip == nil {
return nil, false, ErrInvalidNAT1To1IPMapping return nil, false, ErrInvalidNAT1To1IPMapping
} }
return ip, (ip.To4() != nil), nil return ip, (ip.To4() != nil), nil
} }
// ipMapping holds the mapping of local and external IP address for a particular IP family // ipMapping holds the mapping of local and external IP address
//
// for a particular IP family.
type ipMapping struct { type ipMapping struct {
ipSole net.IP // When non-nil, this is the sole external IP for one local IP assumed ipSole net.IP // When non-nil, this is the sole external IP for one local IP assumed
ipMap map[string]net.IP // Local-to-external IP mapping (k: local, v: external) ipMap map[string]net.IP // Local-to-external IP mapping (k: local, v: external)
@@ -75,7 +78,11 @@ type externalIPMapper struct {
candidateType CandidateType candidateType CandidateType
} }
func newExternalIPMapper(candidateType CandidateType, ips []string) (*externalIPMapper, error) { //nolint:gocognit //nolint:gocognit,cyclop
func newExternalIPMapper(
candidateType CandidateType,
ips []string,
) (*externalIPMapper, error) {
if len(ips) == 0 { if len(ips) == 0 {
return nil, nil //nolint:nilnil return nil, nil //nolint:nilnil
} }
@@ -85,7 +92,7 @@ func newExternalIPMapper(candidateType CandidateType, ips []string) (*externalIP
return nil, ErrUnsupportedNAT1To1IPCandidateType return nil, ErrUnsupportedNAT1To1IPCandidateType
} }
m := &externalIPMapper{ mapper := &externalIPMapper{
ipv4Mapping: ipMapping{ipMap: map[string]net.IP{}}, ipv4Mapping: ipMapping{ipMap: map[string]net.IP{}},
ipv6Mapping: ipMapping{ipMap: map[string]net.IP{}}, ipv6Mapping: ipMapping{ipMap: map[string]net.IP{}},
candidateType: candidateType, candidateType: candidateType,
@@ -101,13 +108,13 @@ func newExternalIPMapper(candidateType CandidateType, ips []string) (*externalIP
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(ipPair) == 1 { if len(ipPair) == 1 { //nolint:nestif
if isExtIPv4 { if isExtIPv4 {
if err := m.ipv4Mapping.setSoleIP(extIP); err != nil { if err := mapper.ipv4Mapping.setSoleIP(extIP); err != nil {
return nil, err return nil, err
} }
} else { } else {
if err := m.ipv6Mapping.setSoleIP(extIP); err != nil { if err := mapper.ipv6Mapping.setSoleIP(extIP); err != nil {
return nil, err return nil, err
} }
} }
@@ -121,7 +128,7 @@ func newExternalIPMapper(candidateType CandidateType, ips []string) (*externalIP
return nil, ErrInvalidNAT1To1IPMapping return nil, ErrInvalidNAT1To1IPMapping
} }
if err := m.ipv4Mapping.addIPMapping(locIP, extIP); err != nil { if err := mapper.ipv4Mapping.addIPMapping(locIP, extIP); err != nil {
return nil, err return nil, err
} }
} else { } else {
@@ -129,14 +136,14 @@ func newExternalIPMapper(candidateType CandidateType, ips []string) (*externalIP
return nil, ErrInvalidNAT1To1IPMapping return nil, ErrInvalidNAT1To1IPMapping
} }
if err := m.ipv6Mapping.addIPMapping(locIP, extIP); err != nil { if err := mapper.ipv6Mapping.addIPMapping(locIP, extIP); err != nil {
return nil, err return nil, err
} }
} }
} }
} }
return m, nil return mapper, nil
} }
func (m *externalIPMapper) findExternalIP(localIPStr string) (net.IP, error) { func (m *externalIPMapper) findExternalIP(localIPStr string) (net.IP, error) {

View File

@@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestExternalIPMapper(t *testing.T) { func TestExternalIPMapper(t *testing.T) { //nolint:maintidx
t.Run("validateIPString", func(t *testing.T) { t.Run("validateIPString", func(t *testing.T) {
var ip net.IP var ip net.IP
var isIPv4 bool var isIPv4 bool
@@ -31,165 +31,165 @@ func TestExternalIPMapper(t *testing.T) {
}) })
t.Run("newExternalIPMapper", func(t *testing.T) { t.Run("newExternalIPMapper", func(t *testing.T) {
var m *externalIPMapper var mapper *externalIPMapper
var err error var err error
// ips being nil should succeed but mapper will be nil also // ips being nil should succeed but mapper will be nil also
m, err = newExternalIPMapper(CandidateTypeUnspecified, nil) mapper, err = newExternalIPMapper(CandidateTypeUnspecified, nil)
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
require.Nil(t, m, "should be nil") require.Nil(t, mapper, "should be nil")
// ips being empty should succeed but mapper will still be nil // ips being empty should succeed but mapper will still be nil
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{}) mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{})
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
require.Nil(t, m, "should be nil") require.Nil(t, mapper, "should be nil")
// IPv4 with no explicit local IP, defaults to CandidateTypeHost // IPv4 with no explicit local IP, defaults to CandidateTypeHost
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{ mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"1.2.3.4", "1.2.3.4",
}) })
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
require.NotNil(t, m, "should not be nil") require.NotNil(t, mapper, "should not be nil")
require.Equal(t, CandidateTypeHost, m.candidateType, "should match") require.Equal(t, CandidateTypeHost, mapper.candidateType, "should match")
require.NotNil(t, m.ipv4Mapping.ipSole) require.NotNil(t, mapper.ipv4Mapping.ipSole)
require.Nil(t, m.ipv6Mapping.ipSole) require.Nil(t, mapper.ipv6Mapping.ipSole)
require.Equal(t, 0, len(m.ipv4Mapping.ipMap), "should match") require.Equal(t, 0, len(mapper.ipv4Mapping.ipMap), "should match")
require.Equal(t, 0, len(m.ipv6Mapping.ipMap), "should match") require.Equal(t, 0, len(mapper.ipv6Mapping.ipMap), "should match")
// IPv4 with no explicit local IP, using CandidateTypeServerReflexive // IPv4 with no explicit local IP, using CandidateTypeServerReflexive
m, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{ mapper, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
"1.2.3.4", "1.2.3.4",
}) })
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
require.NotNil(t, m, "should not be nil") require.NotNil(t, mapper, "should not be nil")
require.Equal(t, CandidateTypeServerReflexive, m.candidateType, "should match") require.Equal(t, CandidateTypeServerReflexive, mapper.candidateType, "should match")
require.NotNil(t, m.ipv4Mapping.ipSole) require.NotNil(t, mapper.ipv4Mapping.ipSole)
require.Nil(t, m.ipv6Mapping.ipSole) require.Nil(t, mapper.ipv6Mapping.ipSole)
require.Equal(t, 0, len(m.ipv4Mapping.ipMap), "should match") require.Equal(t, 0, len(mapper.ipv4Mapping.ipMap), "should match")
require.Equal(t, 0, len(m.ipv6Mapping.ipMap), "should match") require.Equal(t, 0, len(mapper.ipv6Mapping.ipMap), "should match")
// IPv4 with no explicit local IP, defaults to CandidateTypeHost // IPv4 with no explicit local IP, defaults to CandidateTypeHost
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{ mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"2601:4567::5678", "2601:4567::5678",
}) })
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
require.NotNil(t, m, "should not be nil") require.NotNil(t, mapper, "should not be nil")
require.Equal(t, CandidateTypeHost, m.candidateType, "should match") require.Equal(t, CandidateTypeHost, mapper.candidateType, "should match")
require.Nil(t, m.ipv4Mapping.ipSole) require.Nil(t, mapper.ipv4Mapping.ipSole)
require.NotNil(t, m.ipv6Mapping.ipSole) require.NotNil(t, mapper.ipv6Mapping.ipSole)
require.Equal(t, 0, len(m.ipv4Mapping.ipMap), "should match") require.Equal(t, 0, len(mapper.ipv4Mapping.ipMap), "should match")
require.Equal(t, 0, len(m.ipv6Mapping.ipMap), "should match") require.Equal(t, 0, len(mapper.ipv6Mapping.ipMap), "should match")
// IPv4 and IPv6 in the mix // IPv4 and IPv6 in the mix
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{ mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"1.2.3.4", "1.2.3.4",
"2601:4567::5678", "2601:4567::5678",
}) })
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
require.NotNil(t, m, "should not be nil") require.NotNil(t, mapper, "should not be nil")
require.Equal(t, CandidateTypeHost, m.candidateType, "should match") require.Equal(t, CandidateTypeHost, mapper.candidateType, "should match")
require.NotNil(t, m.ipv4Mapping.ipSole) require.NotNil(t, mapper.ipv4Mapping.ipSole)
require.NotNil(t, m.ipv6Mapping.ipSole) require.NotNil(t, mapper.ipv6Mapping.ipSole)
require.Equal(t, 0, len(m.ipv4Mapping.ipMap), "should match") require.Equal(t, 0, len(mapper.ipv4Mapping.ipMap), "should match")
require.Equal(t, 0, len(m.ipv6Mapping.ipMap), "should match") require.Equal(t, 0, len(mapper.ipv6Mapping.ipMap), "should match")
// Unsupported candidate type - CandidateTypePeerReflexive // Unsupported candidate type - CandidateTypePeerReflexive
m, err = newExternalIPMapper(CandidateTypePeerReflexive, []string{ mapper, err = newExternalIPMapper(CandidateTypePeerReflexive, []string{
"1.2.3.4", "1.2.3.4",
}) })
require.Error(t, err, "should fail") require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil") require.Nil(t, mapper, "should be nil")
// Unsupported candidate type - CandidateTypeRelay // Unsupported candidate type - CandidateTypeRelay
m, err = newExternalIPMapper(CandidateTypePeerReflexive, []string{ mapper, err = newExternalIPMapper(CandidateTypePeerReflexive, []string{
"1.2.3.4", "1.2.3.4",
}) })
require.Error(t, err, "should fail") require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil") require.Nil(t, mapper, "should be nil")
// Cannot duplicate mapping IPv4 family // Cannot duplicate mapping IPv4 family
m, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{ mapper, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
"1.2.3.4", "1.2.3.4",
"5.6.7.8", "5.6.7.8",
}) })
require.Error(t, err, "should fail") require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil") require.Nil(t, mapper, "should be nil")
// Cannot duplicate mapping IPv6 family // Cannot duplicate mapping IPv6 family
m, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{ mapper, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
"2201::1", "2201::1",
"2201::0002", "2201::0002",
}) })
require.Error(t, err, "should fail") require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil") require.Nil(t, mapper, "should be nil")
// Invalide external IP string // Invalide external IP string
m, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{ mapper, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
"bad.2.3.4", "bad.2.3.4",
}) })
require.Error(t, err, "should fail") require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil") require.Nil(t, mapper, "should be nil")
// Invalide local IP string // Invalide local IP string
m, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{ mapper, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
"1.2.3.4/10.0.0.bad", "1.2.3.4/10.0.0.bad",
}) })
require.Error(t, err, "should fail") require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil") require.Nil(t, mapper, "should be nil")
}) })
t.Run("newExternalIPMapper with explicit local IP", func(t *testing.T) { t.Run("newExternalIPMapper with explicit local IP", func(t *testing.T) {
var m *externalIPMapper var mapper *externalIPMapper
var err error var err error
// IPv4 with explicit local IP, defaults to CandidateTypeHost // IPv4 with explicit local IP, defaults to CandidateTypeHost
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{ mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"1.2.3.4/10.0.0.1", "1.2.3.4/10.0.0.1",
}) })
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
require.NotNil(t, m, "should not be nil") require.NotNil(t, mapper, "should not be nil")
require.Equal(t, CandidateTypeHost, m.candidateType, "should match") require.Equal(t, CandidateTypeHost, mapper.candidateType, "should match")
require.Nil(t, m.ipv4Mapping.ipSole) require.Nil(t, mapper.ipv4Mapping.ipSole)
require.Nil(t, m.ipv6Mapping.ipSole) require.Nil(t, mapper.ipv6Mapping.ipSole)
require.Equal(t, 1, len(m.ipv4Mapping.ipMap), "should match") require.Equal(t, 1, len(mapper.ipv4Mapping.ipMap), "should match")
require.Equal(t, 0, len(m.ipv6Mapping.ipMap), "should match") require.Equal(t, 0, len(mapper.ipv6Mapping.ipMap), "should match")
// Cannot assign two ext IPs for one local IPv4 // Cannot assign two ext IPs for one local IPv4
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{ mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"1.2.3.4/10.0.0.1", "1.2.3.4/10.0.0.1",
"1.2.3.5/10.0.0.1", "1.2.3.5/10.0.0.1",
}) })
require.Error(t, err, "should fail") require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil") require.Nil(t, mapper, "should be nil")
// Cannot assign two ext IPs for one local IPv6 // Cannot assign two ext IPs for one local IPv6
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{ mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"2200::1/fe80::1", "2200::1/fe80::1",
"2200::0002/fe80::1", "2200::0002/fe80::1",
}) })
require.Error(t, err, "should fail") require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil") require.Nil(t, mapper, "should be nil")
// Cannot mix different IP family in a pair (1) // Cannot mix different IP family in a pair (1)
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{ mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"2200::1/10.0.0.1", "2200::1/10.0.0.1",
}) })
require.Error(t, err, "should fail") require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil") require.Nil(t, mapper, "should be nil")
// Cannot mix different IP family in a pair (2) // Cannot mix different IP family in a pair (2)
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{ mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"1.2.3.4/fe80::1", "1.2.3.4/fe80::1",
}) })
require.Error(t, err, "should fail") require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil") require.Nil(t, mapper, "should be nil")
// Invalid pair // Invalid pair
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{ mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"1.2.3.4/192.168.0.2/10.0.0.1", "1.2.3.4/192.168.0.2/10.0.0.1",
}) })
require.Error(t, err, "should fail") require.Error(t, err, "should fail")
require.Nil(t, m, "should be nil") require.Nil(t, mapper, "should be nil")
}) })
t.Run("newExternalIPMapper with implicit and explicit local IP", func(t *testing.T) { t.Run("newExternalIPMapper with implicit and explicit local IP", func(t *testing.T) {
@@ -209,100 +209,100 @@ func TestExternalIPMapper(t *testing.T) {
}) })
t.Run("findExternalIP without explicit local IP", func(t *testing.T) { t.Run("findExternalIP without explicit local IP", func(t *testing.T) {
var m *externalIPMapper var mapper *externalIPMapper
var err error var err error
var extIP net.IP var extIP net.IP
// IPv4 with explicit local IP, defaults to CandidateTypeHost // IPv4 with explicit local IP, defaults to CandidateTypeHost
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{ mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"1.2.3.4", "1.2.3.4",
"2200::1", "2200::1",
}) })
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
require.NotNil(t, m, "should not be nil") require.NotNil(t, mapper, "should not be nil")
require.NotNil(t, m.ipv4Mapping.ipSole) require.NotNil(t, mapper.ipv4Mapping.ipSole)
require.NotNil(t, m.ipv6Mapping.ipSole) require.NotNil(t, mapper.ipv6Mapping.ipSole)
// Find external IPv4 // Find external IPv4
extIP, err = m.findExternalIP("10.0.0.1") extIP, err = mapper.findExternalIP("10.0.0.1")
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
require.Equal(t, "1.2.3.4", extIP.String(), "should match") require.Equal(t, "1.2.3.4", extIP.String(), "should match")
// Find external IPv6 // Find external IPv6
extIP, err = m.findExternalIP("fe80::0001") // Use '0001' instead of '1' on purpose extIP, err = mapper.findExternalIP("fe80::0001") // Use '0001' instead of '1' on purpose
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
require.Equal(t, "2200::1", extIP.String(), "should match") require.Equal(t, "2200::1", extIP.String(), "should match")
// Bad local IP string // Bad local IP string
_, err = m.findExternalIP("really.bad") _, err = mapper.findExternalIP("really.bad")
require.Error(t, err, "should fail") require.Error(t, err, "should fail")
}) })
t.Run("findExternalIP with explicit local IP", func(t *testing.T) { t.Run("findExternalIP with explicit local IP", func(t *testing.T) {
var m *externalIPMapper var mapper *externalIPMapper
var err error var err error
var extIP net.IP var extIP net.IP
// IPv4 with explicit local IP, defaults to CandidateTypeHost // IPv4 with explicit local IP, defaults to CandidateTypeHost
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{ mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"1.2.3.4/10.0.0.1", "1.2.3.4/10.0.0.1",
"1.2.3.5/10.0.0.2", "1.2.3.5/10.0.0.2",
"2200::1/fe80::1", "2200::1/fe80::1",
"2200::2/fe80::2", "2200::2/fe80::2",
}) })
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
require.NotNil(t, m, "should not be nil") require.NotNil(t, mapper, "should not be nil")
// Find external IPv4 // Find external IPv4
extIP, err = m.findExternalIP("10.0.0.1") extIP, err = mapper.findExternalIP("10.0.0.1")
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
require.Equal(t, "1.2.3.4", extIP.String(), "should match") require.Equal(t, "1.2.3.4", extIP.String(), "should match")
extIP, err = m.findExternalIP("10.0.0.2") extIP, err = mapper.findExternalIP("10.0.0.2")
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
require.Equal(t, "1.2.3.5", extIP.String(), "should match") require.Equal(t, "1.2.3.5", extIP.String(), "should match")
_, err = m.findExternalIP("10.0.0.3") _, err = mapper.findExternalIP("10.0.0.3")
require.Error(t, err, "should fail") require.Error(t, err, "should fail")
// Find external IPv6 // Find external IPv6
extIP, err = m.findExternalIP("fe80::0001") // Use '0001' instead of '1' on purpose extIP, err = mapper.findExternalIP("fe80::0001") // Use '0001' instead of '1' on purpose
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
require.Equal(t, "2200::1", extIP.String(), "should match") require.Equal(t, "2200::1", extIP.String(), "should match")
extIP, err = m.findExternalIP("fe80::0002") // Use '0002' instead of '2' on purpose extIP, err = mapper.findExternalIP("fe80::0002") // Use '0002' instead of '2' on purpose
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
require.Equal(t, "2200::2", extIP.String(), "should match") require.Equal(t, "2200::2", extIP.String(), "should match")
_, err = m.findExternalIP("fe80::3") _, err = mapper.findExternalIP("fe80::3")
require.Error(t, err, "should fail") require.Error(t, err, "should fail")
// Bad local IP string // Bad local IP string
_, err = m.findExternalIP("really.bad") _, err = mapper.findExternalIP("really.bad")
require.Error(t, err, "should fail") require.Error(t, err, "should fail")
}) })
t.Run("findExternalIP with empty map", func(t *testing.T) { t.Run("findExternalIP with empty map", func(t *testing.T) {
var m *externalIPMapper var mapper *externalIPMapper
var err error var err error
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{ mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"1.2.3.4", "1.2.3.4",
}) })
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
// Attempt to find IPv6 that does not exist in the map // Attempt to find IPv6 that does not exist in the map
extIP, err := m.findExternalIP("fe80::1") extIP, err := mapper.findExternalIP("fe80::1")
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
require.Equal(t, "fe80::1", extIP.String(), "should match") require.Equal(t, "fe80::1", extIP.String(), "should match")
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{ mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
"2200::1", "2200::1",
}) })
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
// Attempt to find IPv4 that does not exist in the map // Attempt to find IPv4 that does not exist in the map
extIP, err = m.findExternalIP("10.0.0.1") extIP, err = mapper.findExternalIP("10.0.0.1")
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
require.Equal(t, "10.0.0.1", extIP.String(), "should match") require.Equal(t, "10.0.0.1", extIP.String(), "should match")
}) })

119
gather.go
View File

@@ -21,10 +21,11 @@ import (
"github.com/pion/turn/v4" "github.com/pion/turn/v4"
) )
// Close a net.Conn and log if we have a failure // Close a net.Conn and log if we have a failure.
func closeConnAndLog(c io.Closer, log logging.LeveledLogger, msg string, args ...interface{}) { func closeConnAndLog(c io.Closer, log logging.LeveledLogger, msg string, args ...interface{}) {
if c == nil || (reflect.ValueOf(c).Kind() == reflect.Ptr && reflect.ValueOf(c).IsNil()) { if c == nil || (reflect.ValueOf(c).Kind() == reflect.Ptr && reflect.ValueOf(c).IsNil()) {
log.Warnf("Connection is not allocated: "+msg, args...) log.Warnf("Connection is not allocated: "+msg, args...)
return return
} }
@@ -41,9 +42,11 @@ func (a *Agent) GatherCandidates() error {
if runErr := a.loop.Run(a.loop, func(ctx context.Context) { if runErr := a.loop.Run(a.loop, func(ctx context.Context) {
if a.gatheringState != GatheringStateNew { if a.gatheringState != GatheringStateNew {
gatherErr = ErrMultipleGatherAttempted gatherErr = ErrMultipleGatherAttempted
return return
} else if a.onCandidateHdlr.Load() == nil { } else if a.onCandidateHdlr.Load() == nil {
gatherErr = ErrNoOnCandidateHandler gatherErr = ErrNoOnCandidateHandler
return return
} }
@@ -57,13 +60,15 @@ func (a *Agent) GatherCandidates() error {
}); runErr != nil { }); runErr != nil {
return runErr return runErr
} }
return gatherErr return gatherErr
} }
func (a *Agent) gatherCandidates(ctx context.Context, done chan struct{}) { func (a *Agent) gatherCandidates(ctx context.Context, done chan struct{}) { //nolint:cyclop
defer close(done) defer close(done)
if err := a.setGatheringState(GatheringStateGathering); err != nil { //nolint:contextcheck if err := a.setGatheringState(GatheringStateGathering); err != nil { //nolint:contextcheck
a.log.Warnf("Failed to set gatheringState to GatheringStateGathering: %v", err) a.log.Warnf("Failed to set gatheringState to GatheringStateGathering: %v", err)
return return
} }
@@ -111,7 +116,8 @@ func (a *Agent) gatherCandidates(ctx context.Context, done chan struct{}) {
} }
} }
func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []NetworkType) { //nolint:gocognit //nolint:gocognit,gocyclo,cyclop
func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []NetworkType) {
networks := map[string]struct{}{} networks := map[string]struct{}{}
for _, networkType := range networkTypes { for _, networkType := range networkTypes {
if networkType.IsTCP() { if networkType.IsTCP() {
@@ -132,16 +138,19 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
_, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback) _, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback)
if err != nil { if err != nil {
a.log.Warnf("Failed to iterate local interfaces, host candidates will not be gathered %s", err) a.log.Warnf("Failed to iterate local interfaces, host candidates will not be gathered %s", err)
return return
} }
for _, addr := range localAddrs { for _, addr := range localAddrs {
mappedIP := addr mappedIP := addr
if a.mDNSMode != MulticastDNSModeQueryAndGather && a.extIPMapper != nil && a.extIPMapper.candidateType == CandidateTypeHost { if a.mDNSMode != MulticastDNSModeQueryAndGather &&
a.extIPMapper != nil && a.extIPMapper.candidateType == CandidateTypeHost {
if _mappedIP, innerErr := a.extIPMapper.findExternalIP(addr.String()); innerErr == nil { if _mappedIP, innerErr := a.extIPMapper.findExternalIP(addr.String()); innerErr == nil {
conv, ok := netip.AddrFromSlice(_mappedIP) conv, ok := netip.AddrFromSlice(_mappedIP)
if !ok { if !ok {
a.log.Warnf("failed to convert mapped external IP to netip.Addr'%s'", addr.String()) a.log.Warnf("failed to convert mapped external IP to netip.Addr'%s'", addr.String())
continue continue
} }
// we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable // we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable
@@ -186,6 +195,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.Is6(), addr.AsSlice()) muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.Is6(), addr.AsSlice())
if err != nil { if err != nil {
a.log.Warnf("Failed to get all TCP connections by ufrag: %s %s %s", network, addr, a.localUfrag) a.log.Warnf("Failed to get all TCP connections by ufrag: %s %s %s", network, addr, a.localUfrag)
continue continue
} }
} else { } else {
@@ -194,6 +204,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.Is6(), addr.AsSlice()) conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.Is6(), addr.AsSlice())
if err != nil { if err != nil {
a.log.Warnf("Failed to get TCP connections by ufrag: %s %s %s", network, addr, a.localUfrag) a.log.Warnf("Failed to get TCP connections by ufrag: %s %s %s", network, addr, a.localUfrag)
continue continue
} }
muxConns = []net.PacketConn{conn} muxConns = []net.PacketConn{conn}
@@ -222,6 +233,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
}) })
if err != nil { if err != nil {
a.log.Warnf("Failed to listen %s %s", network, addr) a.log.Warnf("Failed to listen %s %s", network, addr)
continue continue
} }
@@ -229,6 +241,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
conns = append(conns, connAndPort{conn, udpConn.Port}) conns = append(conns, connAndPort{conn, udpConn.Port})
} else { } else {
a.log.Warnf("Failed to get port of UDPAddr from ListenUDPInPortRange: %s %s %s", network, addr, a.localUfrag) a.log.Warnf("Failed to get port of UDPAddr from ListenUDPInPortRange: %s %s %s", network, addr, a.localUfrag)
continue continue
} }
} }
@@ -245,21 +258,38 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
IsLocationTracked: isLocationTracked, IsLocationTracked: isLocationTracked,
} }
c, err := NewCandidateHost(&hostConfig) candidateHost, err := NewCandidateHost(&hostConfig)
if err != nil { if err != nil {
closeConnAndLog(connAndPort.conn, a.log, "failed to create host candidate: %s %s %d: %v", network, mappedIP, connAndPort.port, err) closeConnAndLog(
connAndPort.conn,
a.log,
"failed to create host candidate: %s %s %d: %v",
network, mappedIP,
connAndPort.port,
err,
)
continue continue
} }
if a.mDNSMode == MulticastDNSModeQueryAndGather { if a.mDNSMode == MulticastDNSModeQueryAndGather {
if err = c.setIPAddr(addr); err != nil { if err = candidateHost.setIPAddr(addr); err != nil {
closeConnAndLog(connAndPort.conn, a.log, "failed to create host candidate: %s %s %d: %v", network, mappedIP, connAndPort.port, err) closeConnAndLog(
connAndPort.conn,
a.log,
"failed to create host candidate: %s %s %d: %v",
network,
mappedIP,
connAndPort.port,
err,
)
continue continue
} }
} }
if err := a.addCandidate(ctx, c, connAndPort.conn); err != nil { if err := a.addCandidate(ctx, candidateHost, connAndPort.conn); err != nil {
if closeErr := c.close(); closeErr != nil { if closeErr := candidateHost.close(); closeErr != nil {
a.log.Warnf("Failed to close candidate: %v", closeErr) a.log.Warnf("Failed to close candidate: %v", closeErr)
} }
a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err) a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err)
@@ -287,10 +317,11 @@ func shouldFilterLocationTracked(candidateIP net.IP) bool {
if !ok { if !ok {
return false return false
} }
return shouldFilterLocationTrackedIP(addr) return shouldFilterLocationTrackedIP(addr)
} }
func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolint:gocognit func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolint:gocognit,cyclop
if a.udpMux == nil { if a.udpMux == nil {
return errUDPMuxDisabled return errUDPMuxDisabled
} }
@@ -317,6 +348,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin
mappedIP, err := a.extIPMapper.findExternalIP(candidateIP.String()) mappedIP, err := a.extIPMapper.findExternalIP(candidateIP.String())
if err != nil { if err != nil {
a.log.Warnf("1:1 NAT mapping is enabled but no external IP is found for %s", candidateIP.String()) a.log.Warnf("1:1 NAT mapping is enabled but no external IP is found for %s", candidateIP.String())
continue continue
} }
@@ -359,6 +391,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin
c, err := NewCandidateHost(&hostConfig) c, err := NewCandidateHost(&hostConfig)
if err != nil { if err != nil {
closeConnAndLog(conn, a.log, "failed to create host mux candidate: %s %d: %v", candidateIP, udpAddr.Port, err) closeConnAndLog(conn, a.log, "failed to create host mux candidate: %s %d: %v", candidateIP, udpAddr.Port, err)
continue continue
} }
@@ -368,6 +401,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin
} }
closeConnAndLog(conn, a.log, "failed to add candidate: %s %d: %v", candidateIP, udpAddr.Port, err) closeConnAndLog(conn, a.log, "failed to add candidate: %s %d: %v", candidateIP, udpAddr.Port, err)
continue continue
} }
@@ -391,26 +425,37 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes []
go func() { go func() {
defer wg.Done() defer wg.Done()
conn, err := listenUDPInPortRange(a.net, a.log, int(a.portMax), int(a.portMin), network, &net.UDPAddr{IP: nil, Port: 0}) conn, err := listenUDPInPortRange(
a.net,
a.log,
int(a.portMax),
int(a.portMin),
network,
&net.UDPAddr{IP: nil, Port: 0},
)
if err != nil { if err != nil {
a.log.Warnf("Failed to listen %s: %v", network, err) a.log.Warnf("Failed to listen %s: %v", network, err)
return return
} }
lAddr, ok := conn.LocalAddr().(*net.UDPAddr) lAddr, ok := conn.LocalAddr().(*net.UDPAddr)
if !ok { if !ok {
closeConnAndLog(conn, a.log, "1:1 NAT mapping is enabled but LocalAddr is not a UDPAddr") closeConnAndLog(conn, a.log, "1:1 NAT mapping is enabled but LocalAddr is not a UDPAddr")
return return
} }
mappedIP, err := a.extIPMapper.findExternalIP(lAddr.IP.String()) mappedIP, err := a.extIPMapper.findExternalIP(lAddr.IP.String())
if err != nil { if err != nil {
closeConnAndLog(conn, a.log, "1:1 NAT mapping is enabled but no external IP is found for %s", lAddr.IP.String()) closeConnAndLog(conn, a.log, "1:1 NAT mapping is enabled but no external IP is found for %s", lAddr.IP.String())
return return
} }
if shouldFilterLocationTracked(mappedIP) { if shouldFilterLocationTracked(mappedIP) {
closeConnAndLog(conn, a.log, "external IP is somehow filtered for location tracking reasons %s", mappedIP) closeConnAndLog(conn, a.log, "external IP is somehow filtered for location tracking reasons %s", mappedIP)
return return
} }
@@ -429,6 +474,7 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes []
mappedIP.String(), mappedIP.String(),
lAddr.Port, lAddr.Port,
err) err)
return return
} }
@@ -442,7 +488,8 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes []
} }
} }
func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.URI, networkTypes []NetworkType) { //nolint:gocognit //nolint:gocognit,cyclop
func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.URI, networkTypes []NetworkType) {
var wg sync.WaitGroup var wg sync.WaitGroup
defer wg.Wait() defer wg.Wait()
@@ -456,6 +503,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.UR
udpAddr, ok := listenAddr.(*net.UDPAddr) udpAddr, ok := listenAddr.(*net.UDPAddr)
if !ok { if !ok {
a.log.Warn("Failed to cast udpMuxSrflx listen address to UDPAddr") a.log.Warn("Failed to cast udpMuxSrflx listen address to UDPAddr")
continue continue
} }
wg.Add(1) wg.Add(1)
@@ -466,23 +514,27 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.UR
serverAddr, err := a.net.ResolveUDPAddr(network, hostPort) serverAddr, err := a.net.ResolveUDPAddr(network, hostPort)
if err != nil { if err != nil {
a.log.Debugf("Failed to resolve STUN host: %s %s: %v", network, hostPort, err) a.log.Debugf("Failed to resolve STUN host: %s %s: %v", network, hostPort, err)
return return
} }
if shouldFilterLocationTracked(serverAddr.IP) { if shouldFilterLocationTracked(serverAddr.IP) {
a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort) a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort)
return return
} }
xorAddr, err := a.udpMuxSrflx.GetXORMappedAddr(serverAddr, a.stunGatherTimeout) xorAddr, err := a.udpMuxSrflx.GetXORMappedAddr(serverAddr, a.stunGatherTimeout)
if err != nil { if err != nil {
a.log.Warnf("Failed get server reflexive address %s %s: %v", network, url, err) a.log.Warnf("Failed get server reflexive address %s %s: %v", network, url, err)
return return
} }
conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String(), localAddr) conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String(), localAddr)
if err != nil { if err != nil {
a.log.Warnf("Failed to find connection in UDPMuxSrflx %s %s: %v", network, url, err) a.log.Warnf("Failed to find connection in UDPMuxSrflx %s %s: %v", network, url, err)
return return
} }
@@ -500,6 +552,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.UR
c, err := NewCandidateServerReflexive(&srflxConfig) c, err := NewCandidateServerReflexive(&srflxConfig)
if err != nil { if err != nil {
closeConnAndLog(conn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, ip, port, err) closeConnAndLog(conn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, ip, port, err)
return return
} }
@@ -515,7 +568,8 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.UR
} }
} }
func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, networkTypes []NetworkType) { //nolint:gocognit //nolint:cyclop,gocognit
func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, networkTypes []NetworkType) {
var wg sync.WaitGroup var wg sync.WaitGroup
defer wg.Wait() defer wg.Wait()
@@ -533,17 +587,27 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
serverAddr, err := a.net.ResolveUDPAddr(network, hostPort) serverAddr, err := a.net.ResolveUDPAddr(network, hostPort)
if err != nil { if err != nil {
a.log.Debugf("Failed to resolve STUN host: %s %s: %v", network, hostPort, err) a.log.Debugf("Failed to resolve STUN host: %s %s: %v", network, hostPort, err)
return return
} }
if shouldFilterLocationTracked(serverAddr.IP) { if shouldFilterLocationTracked(serverAddr.IP) {
a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort) a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort)
return return
} }
conn, err := listenUDPInPortRange(a.net, a.log, int(a.portMax), int(a.portMin), network, &net.UDPAddr{IP: nil, Port: 0}) conn, err := listenUDPInPortRange(
a.net,
a.log,
int(a.portMax),
int(a.portMin),
network,
&net.UDPAddr{IP: nil, Port: 0},
)
if err != nil { if err != nil {
closeConnAndLog(conn, a.log, "failed to listen for %s: %v", serverAddr.String(), err) closeConnAndLog(conn, a.log, "failed to listen for %s: %v", serverAddr.String(), err)
return return
} }
// If the agent closes midway through the connection // If the agent closes midway through the connection
@@ -562,6 +626,7 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
xorAddr, err := stunx.GetXORMappedAddr(conn, serverAddr, a.stunGatherTimeout) xorAddr, err := stunx.GetXORMappedAddr(conn, serverAddr, a.stunGatherTimeout)
if err != nil { if err != nil {
closeConnAndLog(conn, a.log, "failed to get server reflexive address %s %s: %v", network, url, err) closeConnAndLog(conn, a.log, "failed to get server reflexive address %s %s: %v", network, url, err)
return return
} }
@@ -580,6 +645,7 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
c, err := NewCandidateServerReflexive(&srflxConfig) c, err := NewCandidateServerReflexive(&srflxConfig)
if err != nil { if err != nil {
closeConnAndLog(conn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, ip, port, err) closeConnAndLog(conn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, ip, port, err)
return return
} }
@@ -594,7 +660,8 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
} }
} }
func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { //nolint:gocognit //nolint:maintidx,gocognit,gocyclo,cyclop
func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) {
var wg sync.WaitGroup var wg sync.WaitGroup
defer wg.Wait() defer wg.Wait()
@@ -605,9 +672,11 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
continue continue
case urls[i].Username == "": case urls[i].Username == "":
a.log.Errorf("Failed to gather relay candidates: %v", ErrUsernameEmpty) a.log.Errorf("Failed to gather relay candidates: %v", ErrUsernameEmpty)
return return
case urls[i].Password == "": case urls[i].Password == "":
a.log.Errorf("Failed to gather relay candidates: %v", ErrPasswordEmpty) a.log.Errorf("Failed to gather relay candidates: %v", ErrPasswordEmpty)
return return
} }
@@ -627,6 +696,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
case url.Proto == stun.ProtoTypeUDP && url.Scheme == stun.SchemeTypeTURN: case url.Proto == stun.ProtoTypeUDP && url.Scheme == stun.SchemeTypeTURN:
if locConn, err = a.net.ListenPacket(network, "0.0.0.0:0"); err != nil { if locConn, err = a.net.ListenPacket(network, "0.0.0.0:0"); err != nil {
a.log.Warnf("Failed to listen %s: %v", network, err) a.log.Warnf("Failed to listen %s: %v", network, err)
return return
} }
@@ -638,6 +708,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
conn, connectErr := a.proxyDialer.Dial(NetworkTypeTCP4.String(), turnServerAddr) conn, connectErr := a.proxyDialer.Dial(NetworkTypeTCP4.String(), turnServerAddr)
if connectErr != nil { if connectErr != nil {
a.log.Warnf("Failed to dial TCP address %s via proxy dialer: %v", turnServerAddr, connectErr) a.log.Warnf("Failed to dial TCP address %s via proxy dialer: %v", turnServerAddr, connectErr)
return return
} }
@@ -654,12 +725,14 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
tcpAddr, connectErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), turnServerAddr) tcpAddr, connectErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), turnServerAddr)
if connectErr != nil { if connectErr != nil {
a.log.Warnf("Failed to resolve TCP address %s: %v", turnServerAddr, connectErr) a.log.Warnf("Failed to resolve TCP address %s: %v", turnServerAddr, connectErr)
return return
} }
conn, connectErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr) conn, connectErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr)
if connectErr != nil { if connectErr != nil {
a.log.Warnf("Failed to dial TCP address %s: %v", turnServerAddr, connectErr) a.log.Warnf("Failed to dial TCP address %s: %v", turnServerAddr, connectErr)
return return
} }
@@ -671,12 +744,14 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
udpAddr, connectErr := a.net.ResolveUDPAddr(network, turnServerAddr) udpAddr, connectErr := a.net.ResolveUDPAddr(network, turnServerAddr)
if connectErr != nil { if connectErr != nil {
a.log.Warnf("Failed to resolve UDP address %s: %v", turnServerAddr, connectErr) a.log.Warnf("Failed to resolve UDP address %s: %v", turnServerAddr, connectErr)
return return
} }
udpConn, dialErr := a.net.DialUDP("udp", nil, udpAddr) udpConn, dialErr := a.net.DialUDP("udp", nil, udpAddr)
if dialErr != nil { if dialErr != nil {
a.log.Warnf("Failed to dial DTLS address %s: %v", turnServerAddr, dialErr) a.log.Warnf("Failed to dial DTLS address %s: %v", turnServerAddr, dialErr)
return return
} }
@@ -686,11 +761,13 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
}) })
if connectErr != nil { if connectErr != nil {
a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr) a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr)
return return
} }
if connectErr = conn.HandshakeContext(ctx); connectErr != nil { if connectErr = conn.HandshakeContext(ctx); connectErr != nil {
a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr) a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr)
return return
} }
@@ -702,12 +779,14 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
tcpAddr, resolvErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), turnServerAddr) tcpAddr, resolvErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), turnServerAddr)
if resolvErr != nil { if resolvErr != nil {
a.log.Warnf("Failed to resolve relay address %s: %v", turnServerAddr, resolvErr) a.log.Warnf("Failed to resolve relay address %s: %v", turnServerAddr, resolvErr)
return return
} }
tcpConn, dialErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr) tcpConn, dialErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr)
if dialErr != nil { if dialErr != nil {
a.log.Warnf("Failed to connect to relay: %v", dialErr) a.log.Warnf("Failed to connect to relay: %v", dialErr)
return return
} }
@@ -721,6 +800,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
a.log.Errorf("Failed to close relay connection: %v", closeErr) a.log.Errorf("Failed to close relay connection: %v", closeErr)
} }
a.log.Warnf("Failed to connect to relay: %v", hsErr) a.log.Warnf("Failed to connect to relay: %v", hsErr)
return return
} }
@@ -730,6 +810,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
locConn = turn.NewSTUNConn(conn) locConn = turn.NewSTUNConn(conn)
default: default:
a.log.Warnf("Unable to handle URL in gatherCandidatesRelay %v", url) a.log.Warnf("Unable to handle URL in gatherCandidatesRelay %v", url)
return return
} }
@@ -743,12 +824,14 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
}) })
if err != nil { if err != nil {
closeConnAndLog(locConn, a.log, "failed to create new TURN client %s %s", turnServerAddr, err) closeConnAndLog(locConn, a.log, "failed to create new TURN client %s %s", turnServerAddr, err)
return return
} }
if err = client.Listen(); err != nil { if err = client.Listen(); err != nil {
client.Close() client.Close()
closeConnAndLog(locConn, a.log, "failed to listen on TURN client %s %s", turnServerAddr, err) closeConnAndLog(locConn, a.log, "failed to listen on TURN client %s %s", turnServerAddr, err)
return return
} }
@@ -756,6 +839,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
if err != nil { if err != nil {
client.Close() client.Close()
closeConnAndLog(locConn, a.log, "failed to allocate on TURN client %s %s", turnServerAddr, err) closeConnAndLog(locConn, a.log, "failed to allocate on TURN client %s %s", turnServerAddr, err)
return return
} }
@@ -763,6 +847,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
if shouldFilterLocationTracked(rAddr.IP) { if shouldFilterLocationTracked(rAddr.IP) {
a.log.Warnf("TURN address %s is somehow filtered for location tracking reasons", rAddr.IP) a.log.Warnf("TURN address %s is somehow filtered for location tracking reasons", rAddr.IP)
return return
} }
@@ -776,6 +861,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
RelayProtocol: relayProtocol, RelayProtocol: relayProtocol,
OnClose: func() error { OnClose: func() error {
client.Close() client.Close()
return locConn.Close() return locConn.Close()
}, },
} }
@@ -790,6 +876,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
client.Close() client.Close()
closeConnAndLog(locConn, a.log, "failed to create relay candidate: %s %s: %v", network, rAddr.String(), err) closeConnAndLog(locConn, a.log, "failed to create relay candidate: %s %s: %v", network, rAddr.String(), err)
return return
} }

View File

@@ -31,26 +31,32 @@ import (
) )
func TestListenUDP(t *testing.T) { func TestListenUDP(t *testing.T) {
a, err := NewAgent(&AgentConfig{}) agent, err := NewAgent(&AgentConfig{})
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
_, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) _, localAddrs, err := localInterfaces(
agent.net,
agent.interfaceFilter,
agent.ipFilter,
[]NetworkType{NetworkTypeUDP4},
false,
)
require.NotEqual(t, len(localAddrs), 0, "localInterfaces found no interfaces, unable to test") require.NotEqual(t, len(localAddrs), 0, "localInterfaces found no interfaces, unable to test")
require.NoError(t, err) require.NoError(t, err)
ip := localAddrs[0].AsSlice() ip := localAddrs[0].AsSlice()
conn, err := listenUDPInPortRange(a.net, a.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0}) conn, err := listenUDPInPortRange(agent.net, agent.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0})
require.NoError(t, err, "listenUDP error with no port restriction") require.NoError(t, err, "listenUDP error with no port restriction")
require.NotNil(t, conn, "listenUDP error with no port restriction return a nil conn") require.NotNil(t, conn, "listenUDP error with no port restriction return a nil conn")
_, err = listenUDPInPortRange(a.net, a.log, 4999, 5000, udp, &net.UDPAddr{IP: ip, Port: 0}) _, err = listenUDPInPortRange(agent.net, agent.log, 4999, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
require.Equal(t, err, ErrPort, "listenUDP with invalid port range did not return ErrPort") require.Equal(t, err, ErrPort, "listenUDP with invalid port range did not return ErrPort")
conn, err = listenUDPInPortRange(a.net, a.log, 5000, 5000, udp, &net.UDPAddr{IP: ip, Port: 0}) conn, err = listenUDPInPortRange(agent.net, agent.log, 5000, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
require.NoError(t, err, "listenUDP error with no port restriction") require.NoError(t, err, "listenUDP error with no port restriction")
require.NotNil(t, conn, "listenUDP error with no port restriction return a nil conn") require.NotNil(t, conn, "listenUDP error with no port restriction return a nil conn")
@@ -64,7 +70,7 @@ func TestListenUDP(t *testing.T) {
result := make([]int, 0, total) result := make([]int, 0, total)
portRange := make([]int, 0, total) portRange := make([]int, 0, total)
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
conn, err = listenUDPInPortRange(a.net, a.log, portMax, portMin, udp, &net.UDPAddr{IP: ip, Port: 0}) conn, err = listenUDPInPortRange(agent.net, agent.log, portMax, portMin, udp, &net.UDPAddr{IP: ip, Port: 0})
require.NoError(t, err, "listenUDP error with no port restriction") require.NoError(t, err, "listenUDP error with no port restriction")
require.NotNil(t, conn, "listenUDP error with no port restriction return a nil conn") require.NotNil(t, conn, "listenUDP error with no port restriction return a nil conn")
@@ -85,7 +91,7 @@ func TestListenUDP(t *testing.T) {
if !reflect.DeepEqual(result, portRange) { if !reflect.DeepEqual(result, portRange) {
t.Fatalf("listenUDP with port restriction [%d, %d], got:%v, want:%v", portMin, portMax, result, portRange) t.Fatalf("listenUDP with port restriction [%d, %d], got:%v, want:%v", portMin, portMax, result, portRange)
} }
_, err = listenUDPInPortRange(a.net, a.log, portMax, portMin, udp, &net.UDPAddr{IP: ip, Port: 0}) _, err = listenUDPInPortRange(agent.net, agent.log, portMax, portMin, udp, &net.UDPAddr{IP: ip, Port: 0})
require.Equal(t, err, ErrPort, "listenUDP with port restriction [%d, %d], did not return ErrPort", portMin, portMax) require.Equal(t, err, ErrPort, "listenUDP with port restriction [%d, %d], did not return ErrPort", portMin, portMax)
} }
@@ -94,23 +100,23 @@ func TestGatherConcurrency(t *testing.T) {
defer test.TimeOut(time.Second * 30).Stop() defer test.TimeOut(time.Second * 30).Stop()
a, err := NewAgent(&AgentConfig{ agent, err := NewAgent(&AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
IncludeLoopback: true, IncludeLoopback: true,
}) })
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
require.NoError(t, a.OnCandidate(func(Candidate) { require.NoError(t, agent.OnCandidate(func(Candidate) {
candidateGatheredFunc() candidateGatheredFunc()
})) }))
// Testing for panic // Testing for panic
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
_ = a.GatherCandidates() _ = agent.GatherCandidates()
} }
<-candidateGathered.Done() <-candidateGathered.Done()
@@ -194,26 +200,27 @@ func TestLoopbackCandidate(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
tcase := tc tcase := tc
t.Run(tcase.name, func(t *testing.T) { t.Run(tcase.name, func(t *testing.T) {
a, err := NewAgent(tc.agentConfig) agent, err := NewAgent(tc.agentConfig)
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
var loopback int32 var loopback int32
require.NoError(t, a.OnCandidate(func(c Candidate) { require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c != nil { if c != nil {
if net.ParseIP(c.Address()).IsLoopback() { if net.ParseIP(c.Address()).IsLoopback() {
atomic.StoreInt32(&loopback, 1) atomic.StoreInt32(&loopback, 1)
} }
} else { } else {
candidateGatheredFunc() candidateGatheredFunc()
return return
} }
t.Log(c.NetworkType(), c.Priority(), c) t.Log(c.NetworkType(), c.Priority(), c)
})) }))
require.NoError(t, a.GatherCandidates()) require.NoError(t, agent.GatherCandidates())
<-candidateGathered.Done() <-candidateGathered.Done()
@@ -226,7 +233,7 @@ func TestLoopbackCandidate(t *testing.T) {
require.NoError(t, muxUnspecDefault.Close()) require.NoError(t, muxUnspecDefault.Close())
} }
// Assert that STUN gathering is done concurrently // Assert that STUN gathering is done concurrently.
func TestSTUNConcurrency(t *testing.T) { func TestSTUNConcurrency(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
@@ -273,7 +280,7 @@ func TestSTUNConcurrency(t *testing.T) {
_ = listener.Close() _ = listener.Close()
}() }()
a, err := NewAgent(&AgentConfig{ agent, err := NewAgent(&AgentConfig{
NetworkTypes: supportedNetworkTypes(), NetworkTypes: supportedNetworkTypes(),
Urls: urls, Urls: urls,
CandidateTypes: []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive}, CandidateTypes: []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive},
@@ -287,29 +294,36 @@ func TestSTUNConcurrency(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
require.NoError(t, a.OnCandidate(func(c Candidate) { require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c == nil { if c == nil {
candidateGatheredFunc() candidateGatheredFunc()
return return
} }
t.Log(c.NetworkType(), c.Priority(), c) t.Log(c.NetworkType(), c.Priority(), c)
})) }))
require.NoError(t, a.GatherCandidates()) require.NoError(t, agent.GatherCandidates())
<-candidateGathered.Done() <-candidateGathered.Done()
} }
// Assert that TURN gathering is done concurrently // Assert that TURN gathering is done concurrently.
func TestTURNConcurrency(t *testing.T) { func TestTURNConcurrency(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
defer test.TimeOut(time.Second * 30).Stop() defer test.TimeOut(time.Second * 30).Stop()
runTest := func(protocol stun.ProtoType, scheme stun.SchemeType, packetConn net.PacketConn, listener net.Listener, serverPort int) { runTest := func(
protocol stun.ProtoType,
scheme stun.SchemeType,
packetConn net.PacketConn,
listener net.Listener,
serverPort int,
) {
packetConnConfigs := []turn.PacketConnConfig{} packetConnConfigs := []turn.PacketConnConfig{}
if packetConn != nil { if packetConn != nil {
packetConnConfigs = append(packetConnConfigs, turn.PacketConnConfig{ packetConnConfigs = append(packetConnConfigs, turn.PacketConnConfig{
@@ -357,7 +371,7 @@ func TestTURNConcurrency(t *testing.T) {
Port: serverPort, Port: serverPort,
}) })
a, err := NewAgent(&AgentConfig{ agent, err := NewAgent(&AgentConfig{
CandidateTypes: []CandidateType{CandidateTypeRelay}, CandidateTypes: []CandidateType{CandidateTypeRelay},
InsecureSkipVerify: true, InsecureSkipVerify: true,
NetworkTypes: supportedNetworkTypes(), NetworkTypes: supportedNetworkTypes(),
@@ -365,16 +379,16 @@ func TestTURNConcurrency(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
require.NoError(t, a.OnCandidate(func(c Candidate) { require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c != nil { if c != nil {
candidateGatheredFunc() candidateGatheredFunc()
} }
})) }))
require.NoError(t, a.GatherCandidates()) require.NoError(t, agent.GatherCandidates())
<-candidateGathered.Done() <-candidateGathered.Done()
} }
@@ -413,16 +427,20 @@ func TestTURNConcurrency(t *testing.T) {
require.NoError(t, genErr) require.NoError(t, genErr)
serverPort := randomPort(t) serverPort := randomPort(t)
serverListener, err := dtls.Listen("udp", &net.UDPAddr{IP: net.ParseIP(localhostIPStr), Port: serverPort}, &dtls.Config{ serverListener, err := dtls.Listen(
Certificates: []tls.Certificate{certificate}, "udp",
}) &net.UDPAddr{IP: net.ParseIP(localhostIPStr), Port: serverPort},
&dtls.Config{
Certificates: []tls.Certificate{certificate},
},
)
require.NoError(t, err) require.NoError(t, err)
runTest(stun.ProtoTypeUDP, stun.SchemeTypeTURNS, nil, serverListener, serverPort) runTest(stun.ProtoTypeUDP, stun.SchemeTypeTURNS, nil, serverListener, serverPort)
}) })
} }
// Assert that STUN and TURN gathering are done concurrently // Assert that STUN and TURN gathering are done concurrently.
func TestSTUNTURNConcurrency(t *testing.T) { func TestSTUNTURNConcurrency(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
@@ -464,25 +482,26 @@ func TestSTUNTURNConcurrency(t *testing.T) {
Password: "password", Password: "password",
}) })
a, err := NewAgent(&AgentConfig{ agent, err := NewAgent(&AgentConfig{
NetworkTypes: supportedNetworkTypes(), NetworkTypes: supportedNetworkTypes(),
Urls: urls, Urls: urls,
CandidateTypes: []CandidateType{CandidateTypeServerReflexive, CandidateTypeRelay}, CandidateTypes: []CandidateType{CandidateTypeServerReflexive, CandidateTypeRelay},
}) })
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
{ {
gatherLim := test.TimeOut(time.Second * 3) // As TURN and STUN should be checked in parallel, this should complete before the default STUN timeout (5s) // As TURN and STUN should be checked in parallel, this should complete before the default STUN timeout (5s)
gatherLim := test.TimeOut(time.Second * 3)
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
require.NoError(t, a.OnCandidate(func(c Candidate) { require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c != nil { if c != nil {
candidateGatheredFunc() candidateGatheredFunc()
} }
})) }))
require.NoError(t, a.GatherCandidates()) require.NoError(t, agent.GatherCandidates())
<-candidateGathered.Done() <-candidateGathered.Done()
gatherLim.Stop() gatherLim.Stop()
@@ -528,24 +547,24 @@ func TestTURNSrflx(t *testing.T) {
Password: "password", Password: "password",
}} }}
a, err := NewAgent(&AgentConfig{ agent, err := NewAgent(&AgentConfig{
NetworkTypes: supportedNetworkTypes(), NetworkTypes: supportedNetworkTypes(),
Urls: urls, Urls: urls,
CandidateTypes: []CandidateType{CandidateTypeServerReflexive, CandidateTypeRelay}, CandidateTypes: []CandidateType{CandidateTypeServerReflexive, CandidateTypeRelay},
}) })
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
require.NoError(t, a.OnCandidate(func(c Candidate) { require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c != nil && c.Type() == CandidateTypeServerReflexive { if c != nil && c.Type() == CandidateTypeServerReflexive {
candidateGatheredFunc() candidateGatheredFunc()
} }
})) }))
require.NoError(t, a.GatherCandidates()) require.NoError(t, agent.GatherCandidates())
<-candidateGathered.Done() <-candidateGathered.Done()
} }
@@ -580,6 +599,7 @@ func (m *mockConn) SetWriteDeadline(time.Time) error { return io.EOF }
func (m *mockProxy) Dial(string, string) (net.Conn, error) { func (m *mockProxy) Dial(string, string) (net.Conn, error) {
m.proxyWasDialed() m.proxyWasDialed()
return &mockConn{}, nil return &mockConn{}, nil
} }
@@ -599,7 +619,7 @@ func TestTURNProxyDialer(t *testing.T) {
proxyDialer, err := proxy.FromURL(tcpProxyURI, proxy.Direct) proxyDialer, err := proxy.FromURL(tcpProxyURI, proxy.Direct)
require.NoError(t, err) require.NoError(t, err)
a, err := NewAgent(&AgentConfig{ agent, err := NewAgent(&AgentConfig{
CandidateTypes: []CandidateType{CandidateTypeRelay}, CandidateTypes: []CandidateType{CandidateTypeRelay},
NetworkTypes: supportedNetworkTypes(), NetworkTypes: supportedNetworkTypes(),
Urls: []*stun.URI{ Urls: []*stun.URI{
@@ -616,17 +636,17 @@ func TestTURNProxyDialer(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
candidateGatherFinish, candidateGatherFinishFunc := context.WithCancel(context.Background()) candidateGatherFinish, candidateGatherFinishFunc := context.WithCancel(context.Background())
require.NoError(t, a.OnCandidate(func(c Candidate) { require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c == nil { if c == nil {
candidateGatherFinishFunc() candidateGatherFinishFunc()
} }
})) }))
require.NoError(t, a.GatherCandidates()) require.NoError(t, agent.GatherCandidates())
<-candidateGatherFinish.Done() <-candidateGatherFinish.Done()
<-proxyWasDialed.Done() <-proxyWasDialed.Done()
} }
@@ -651,31 +671,31 @@ func TestUDPMuxDefaultWithNAT1To1IPsUsage(t *testing.T) {
_ = mux.Close() _ = mux.Close()
}() }()
a, err := NewAgent(&AgentConfig{ agent, err := NewAgent(&AgentConfig{
NAT1To1IPs: []string{"1.2.3.4"}, NAT1To1IPs: []string{"1.2.3.4"},
NAT1To1IPCandidateType: CandidateTypeHost, NAT1To1IPCandidateType: CandidateTypeHost,
UDPMux: mux, UDPMux: mux,
}) })
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
gatherCandidateDone := make(chan struct{}) gatherCandidateDone := make(chan struct{})
require.NoError(t, a.OnCandidate(func(c Candidate) { require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c == nil { if c == nil {
close(gatherCandidateDone) close(gatherCandidateDone)
} else { } else {
require.Equal(t, "1.2.3.4", c.Address()) require.Equal(t, "1.2.3.4", c.Address())
} }
})) }))
require.NoError(t, a.GatherCandidates()) require.NoError(t, agent.GatherCandidates())
<-gatherCandidateDone <-gatherCandidateDone
require.NotEqual(t, 0, len(mux.connsIPv4)) require.NotEqual(t, 0, len(mux.connsIPv4))
} }
// Assert that candidates are given for each mux in a MultiUDPMux // Assert that candidates are given for each mux in a MultiUDPMux.
func TestMultiUDPMuxUsage(t *testing.T) { func TestMultiUDPMuxUsage(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
@@ -700,25 +720,26 @@ func TestMultiUDPMuxUsage(t *testing.T) {
}() }()
} }
a, err := NewAgent(&AgentConfig{ agent, err := NewAgent(&AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
CandidateTypes: []CandidateType{CandidateTypeHost}, CandidateTypes: []CandidateType{CandidateTypeHost},
UDPMux: NewMultiUDPMuxDefault(udpMuxInstances...), UDPMux: NewMultiUDPMuxDefault(udpMuxInstances...),
}) })
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
candidateCh := make(chan Candidate) candidateCh := make(chan Candidate)
require.NoError(t, a.OnCandidate(func(c Candidate) { require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c == nil { if c == nil {
close(candidateCh) close(candidateCh)
return return
} }
candidateCh <- c candidateCh <- c
})) }))
require.NoError(t, a.GatherCandidates()) require.NoError(t, agent.GatherCandidates())
portFound := make(map[int]bool) portFound := make(map[int]bool)
for c := range candidateCh { for c := range candidateCh {
@@ -731,7 +752,7 @@ func TestMultiUDPMuxUsage(t *testing.T) {
} }
} }
// Assert that candidates are given for each mux in a MultiTCPMux // Assert that candidates are given for each mux in a MultiTCPMux.
func TestMultiTCPMuxUsage(t *testing.T) { func TestMultiTCPMuxUsage(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
@@ -757,25 +778,26 @@ func TestMultiTCPMuxUsage(t *testing.T) {
})) }))
} }
a, err := NewAgent(&AgentConfig{ agent, err := NewAgent(&AgentConfig{
NetworkTypes: supportedNetworkTypes(), NetworkTypes: supportedNetworkTypes(),
CandidateTypes: []CandidateType{CandidateTypeHost}, CandidateTypes: []CandidateType{CandidateTypeHost},
TCPMux: NewMultiTCPMuxDefault(tcpMuxInstances...), TCPMux: NewMultiTCPMuxDefault(tcpMuxInstances...),
}) })
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
candidateCh := make(chan Candidate) candidateCh := make(chan Candidate)
require.NoError(t, a.OnCandidate(func(c Candidate) { require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c == nil { if c == nil {
close(candidateCh) close(candidateCh)
return return
} }
candidateCh <- c candidateCh <- c
})) }))
require.NoError(t, a.GatherCandidates()) require.NoError(t, agent.GatherCandidates())
portFound := make(map[int]bool) portFound := make(map[int]bool)
for c := range candidateCh { for c := range candidateCh {
@@ -790,7 +812,7 @@ func TestMultiTCPMuxUsage(t *testing.T) {
} }
} }
// Assert that UniversalUDPMux is used while gathering when configured in the Agent // Assert that UniversalUDPMux is used while gathering when configured in the Agent.
func TestUniversalUDPMuxUsage(t *testing.T) { func TestUniversalUDPMuxUsage(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
@@ -816,7 +838,7 @@ func TestUniversalUDPMuxUsage(t *testing.T) {
}) })
} }
a, err := NewAgent(&AgentConfig{ agent, err := NewAgent(&AgentConfig{
NetworkTypes: supportedNetworkTypes(), NetworkTypes: supportedNetworkTypes(),
Urls: urls, Urls: urls,
CandidateTypes: []CandidateType{CandidateTypeServerReflexive}, CandidateTypes: []CandidateType{CandidateTypeServerReflexive},
@@ -828,26 +850,32 @@ func TestUniversalUDPMuxUsage(t *testing.T) {
if aClosed { if aClosed {
return return
} }
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
require.NoError(t, a.OnCandidate(func(c Candidate) { require.NoError(t, agent.OnCandidate(func(c Candidate) {
if c == nil { if c == nil {
candidateGatheredFunc() candidateGatheredFunc()
return return
} }
t.Log(c.NetworkType(), c.Priority(), c) t.Log(c.NetworkType(), c.Priority(), c)
})) }))
require.NoError(t, a.GatherCandidates()) require.NoError(t, agent.GatherCandidates())
<-candidateGathered.Done() <-candidateGathered.Done()
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
aClosed = true aClosed = true
// Twice because of 2 STUN servers configured // Twice because of 2 STUN servers configured
require.Equal(t, numSTUNS, udpMuxSrflx.getXORMappedAddrUsedTimes, "expected times that GetXORMappedAddr should be called") require.Equal(
t,
numSTUNS,
udpMuxSrflx.getXORMappedAddrUsedTimes,
"expected times that GetXORMappedAddr should be called",
)
// One for Restart() when agent has been initialized and one time when Close() the agent // One for Restart() when agent has been initialized and one time when Close() the agent
require.Equal(t, 2, udpMuxSrflx.removeConnByUfragTimes, "expected times that RemoveConnByUfrag should be called") require.Equal(t, 2, udpMuxSrflx.removeConnByUfragTimes, "expected times that RemoveConnByUfrag should be called")
// Twice because of 2 STUN servers configured // Twice because of 2 STUN servers configured
@@ -871,6 +899,7 @@ func (m *universalUDPMuxMock) GetConnForURL(string, string, net.Addr) (net.Packe
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
m.getConnForURLTimes++ m.getConnForURLTimes++
return m.conn, nil return m.conn, nil
} }
@@ -878,6 +907,7 @@ func (m *universalUDPMuxMock) GetXORMappedAddr(net.Addr, time.Duration) (*stun.X
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
m.getXORMappedAddrUsedTimes++ m.getXORMappedAddrUsedTimes++
return &stun.XORMappedAddress{IP: net.IP{100, 64, 0, 1}, Port: 77878}, nil return &stun.XORMappedAddress{IP: net.IP{100, 64, 0, 1}, Port: 77878}, nil
} }

View File

@@ -20,7 +20,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestVNetGather(t *testing.T) { func TestVNetGather(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
loggerFactory := logging.NewDefaultLoggerFactory() loggerFactory := logging.NewDefaultLoggerFactory()
@@ -51,7 +51,7 @@ func TestVNetGather(t *testing.T) {
t.Fatalf("Failed to parse CIDR: %s", err) t.Fatalf("Failed to parse CIDR: %s", err)
} }
r, err := vnet.NewRouter(&vnet.RouterConfig{ router, err := vnet.NewRouter(&vnet.RouterConfig{
CIDR: cider, CIDR: cider,
LoggerFactory: loggerFactory, LoggerFactory: loggerFactory,
}) })
@@ -64,7 +64,7 @@ func TestVNetGather(t *testing.T) {
t.Fatalf("Failed to create a Net: %s", err) t.Fatalf("Failed to create a Net: %s", err)
} }
err = r.AddNet(nw) err = router.AddNet(nw)
if err != nil { if err != nil {
t.Fatalf("Failed to add a Net to the router: %s", err) t.Fatalf("Failed to add a Net to the router: %s", err)
} }
@@ -94,7 +94,7 @@ func TestVNetGather(t *testing.T) {
}) })
t.Run("listenUDP", func(t *testing.T) { t.Run("listenUDP", func(t *testing.T) {
r, err := vnet.NewRouter(&vnet.RouterConfig{ router, err := vnet.NewRouter(&vnet.RouterConfig{
CIDR: "1.2.3.0/24", CIDR: "1.2.3.0/24",
LoggerFactory: loggerFactory, LoggerFactory: loggerFactory,
}) })
@@ -107,20 +107,26 @@ func TestVNetGather(t *testing.T) {
t.Fatalf("Failed to create a Net: %s", err) t.Fatalf("Failed to create a Net: %s", err)
} }
err = r.AddNet(nw) err = router.AddNet(nw)
if err != nil { if err != nil {
t.Fatalf("Failed to add a Net to the router: %s", err) t.Fatalf("Failed to add a Net to the router: %s", err)
} }
a, err := NewAgent(&AgentConfig{Net: nw}) agent, err := NewAgent(&AgentConfig{Net: nw})
if err != nil { if err != nil {
t.Fatalf("Failed to create agent: %s", err) t.Fatalf("Failed to create agent: %s", err)
} }
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
_, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) _, localAddrs, err := localInterfaces(
agent.net,
agent.interfaceFilter,
agent.ipFilter,
[]NetworkType{NetworkTypeUDP4},
false,
)
if len(localAddrs) == 0 { if len(localAddrs) == 0 {
t.Fatal("localInterfaces found no interfaces, unable to test") t.Fatal("localInterfaces found no interfaces, unable to test")
} }
@@ -128,7 +134,7 @@ func TestVNetGather(t *testing.T) {
ip := localAddrs[0].AsSlice() ip := localAddrs[0].AsSlice()
conn, err := listenUDPInPortRange(a.net, a.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0}) conn, err := listenUDPInPortRange(agent.net, agent.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0})
if err != nil { if err != nil {
t.Fatalf("listenUDP error with no port restriction %v", err) t.Fatalf("listenUDP error with no port restriction %v", err)
} else if conn == nil { } else if conn == nil {
@@ -139,12 +145,12 @@ func TestVNetGather(t *testing.T) {
t.Fatalf("failed to close conn") t.Fatalf("failed to close conn")
} }
_, err = listenUDPInPortRange(a.net, a.log, 4999, 5000, udp, &net.UDPAddr{IP: ip, Port: 0}) _, err = listenUDPInPortRange(agent.net, agent.log, 4999, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
if !errors.Is(err, ErrPort) { if !errors.Is(err, ErrPort) {
t.Fatal("listenUDP with invalid port range did not return ErrPort") t.Fatal("listenUDP with invalid port range did not return ErrPort")
} }
conn, err = listenUDPInPortRange(a.net, a.log, 5000, 5000, udp, &net.UDPAddr{IP: ip, Port: 0}) conn, err = listenUDPInPortRange(agent.net, agent.log, 5000, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
if err != nil { if err != nil {
t.Fatalf("listenUDP error with no port restriction %v", err) t.Fatalf("listenUDP error with no port restriction %v", err)
} else if conn == nil { } else if conn == nil {
@@ -163,7 +169,7 @@ func TestVNetGather(t *testing.T) {
}) })
} }
func TestVNetGatherWithNAT1To1(t *testing.T) { func TestVNetGatherWithNAT1To1(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
loggerFactory := logging.NewDefaultLoggerFactory() loggerFactory := logging.NewDefaultLoggerFactory()
@@ -206,7 +212,7 @@ func TestVNetGatherWithNAT1To1(t *testing.T) {
err = lan.AddNet(nw) err = lan.AddNet(nw)
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
a, err := NewAgent(&AgentConfig{ agent, err := NewAgent(&AgentConfig{
NetworkTypes: []NetworkType{ NetworkTypes: []NetworkType{
NetworkTypeUDP4, NetworkTypeUDP4,
}, },
@@ -215,25 +221,25 @@ func TestVNetGatherWithNAT1To1(t *testing.T) {
}) })
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
done := make(chan struct{}) done := make(chan struct{})
err = a.OnCandidate(func(c Candidate) { err = agent.OnCandidate(func(c Candidate) {
if c == nil { if c == nil {
close(done) close(done)
} }
}) })
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
err = a.GatherCandidates() err = agent.GatherCandidates()
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
log.Debug("Wait until gathering is complete...") log.Debug("Wait until gathering is complete...")
<-done <-done
log.Debug("Gathering is done") log.Debug("Gathering is done")
candidates, err := a.GetLocalCandidates() candidates, err := agent.GetLocalCandidates()
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
if len(candidates) != 2 { if len(candidates) != 2 {
@@ -248,7 +254,7 @@ func TestVNetGatherWithNAT1To1(t *testing.T) {
} }
} }
if candidates[0].Address() == externalIP0 { if candidates[0].Address() == externalIP0 { //nolint:nestif
if candidates[1].Address() != externalIP1 { if candidates[1].Address() != externalIP1 {
t.Fatalf("Unexpected candidate IP: %s", candidates[1].Address()) t.Fatalf("Unexpected candidate IP: %s", candidates[1].Address())
} }
@@ -305,7 +311,7 @@ func TestVNetGatherWithNAT1To1(t *testing.T) {
err = lan.AddNet(nw) err = lan.AddNet(nw)
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
a, err := NewAgent(&AgentConfig{ agent, err := NewAgent(&AgentConfig{
NetworkTypes: []NetworkType{ NetworkTypes: []NetworkType{
NetworkTypeUDP4, NetworkTypeUDP4,
}, },
@@ -317,25 +323,25 @@ func TestVNetGatherWithNAT1To1(t *testing.T) {
}) })
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
done := make(chan struct{}) done := make(chan struct{})
err = a.OnCandidate(func(c Candidate) { err = agent.OnCandidate(func(c Candidate) {
if c == nil { if c == nil {
close(done) close(done)
} }
}) })
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
err = a.GatherCandidates() err = agent.GatherCandidates()
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
log.Debug("Wait until gathering is complete...") log.Debug("Wait until gathering is complete...")
<-done <-done
log.Debug("Gathering is done") log.Debug("Gathering is done")
candidates, err := a.GetLocalCandidates() candidates, err := agent.GetLocalCandidates()
require.NoError(t, err, "should succeed") require.NoError(t, err, "should succeed")
if len(candidates) != 2 { if len(candidates) != 2 {
@@ -367,7 +373,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
loggerFactory := logging.NewDefaultLoggerFactory() loggerFactory := logging.NewDefaultLoggerFactory()
r, err := vnet.NewRouter(&vnet.RouterConfig{ router, err := vnet.NewRouter(&vnet.RouterConfig{
CIDR: "1.2.3.0/24", CIDR: "1.2.3.0/24",
LoggerFactory: loggerFactory, LoggerFactory: loggerFactory,
}) })
@@ -380,24 +386,31 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
t.Fatalf("Failed to create a Net: %s", err) t.Fatalf("Failed to create a Net: %s", err)
} }
if err = r.AddNet(nw); err != nil { if err = router.AddNet(nw); err != nil {
t.Fatalf("Failed to add a Net to the router: %s", err) t.Fatalf("Failed to add a Net to the router: %s", err)
} }
t.Run("InterfaceFilter should exclude the interface", func(t *testing.T) { t.Run("InterfaceFilter should exclude the interface", func(t *testing.T) {
a, err := NewAgent(&AgentConfig{ agent, err := NewAgent(&AgentConfig{
Net: nw, Net: nw,
InterfaceFilter: func(interfaceName string) (keep bool) { InterfaceFilter: func(interfaceName string) (keep bool) {
require.Equal(t, "eth0", interfaceName) require.Equal(t, "eth0", interfaceName)
return false return false
}, },
}) })
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
_, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) _, localIPs, err := localInterfaces(
agent.net,
agent.interfaceFilter,
agent.ipFilter,
[]NetworkType{NetworkTypeUDP4},
false,
)
require.NoError(t, err) require.NoError(t, err)
if len(localIPs) != 0 { if len(localIPs) != 0 {
@@ -406,19 +419,26 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
}) })
t.Run("IPFilter should exclude the IP", func(t *testing.T) { t.Run("IPFilter should exclude the IP", func(t *testing.T) {
a, err := NewAgent(&AgentConfig{ agent, err := NewAgent(&AgentConfig{
Net: nw, Net: nw,
IPFilter: func(ip net.IP) (keep bool) { IPFilter: func(ip net.IP) (keep bool) {
require.Equal(t, net.IP{1, 2, 3, 1}, ip) require.Equal(t, net.IP{1, 2, 3, 1}, ip)
return false return false
}, },
}) })
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
_, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) _, localIPs, err := localInterfaces(
agent.net,
agent.interfaceFilter,
agent.ipFilter,
[]NetworkType{NetworkTypeUDP4},
false,
)
require.NoError(t, err) require.NoError(t, err)
if len(localIPs) != 0 { if len(localIPs) != 0 {
@@ -427,19 +447,26 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
}) })
t.Run("InterfaceFilter should not exclude the interface", func(t *testing.T) { t.Run("InterfaceFilter should not exclude the interface", func(t *testing.T) {
a, err := NewAgent(&AgentConfig{ agent, err := NewAgent(&AgentConfig{
Net: nw, Net: nw,
InterfaceFilter: func(interfaceName string) (keep bool) { InterfaceFilter: func(interfaceName string) (keep bool) {
require.Equal(t, "eth0", interfaceName) require.Equal(t, "eth0", interfaceName)
return true return true
}, },
}) })
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, a.Close()) require.NoError(t, agent.Close())
}() }()
_, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) _, localIPs, err := localInterfaces(
agent.net,
agent.interfaceFilter,
agent.ipFilter,
[]NetworkType{NetworkTypeUDP4},
false,
)
require.NoError(t, err) require.NoError(t, err)
if len(localIPs) == 0 { if len(localIPs) == 0 {

30
ice.go
View File

@@ -3,33 +3,33 @@
package ice package ice
// ConnectionState is an enum showing the state of a ICE Connection // ConnectionState is an enum showing the state of a ICE Connection.
type ConnectionState int type ConnectionState int
// List of supported States // List of supported States.
const ( const (
// ConnectionStateUnknown represents an unknown state // ConnectionStateUnknown represents an unknown state.
ConnectionStateUnknown ConnectionState = iota ConnectionStateUnknown ConnectionState = iota
// ConnectionStateNew ICE agent is gathering addresses // ConnectionStateNew ICE agent is gathering addresses.
ConnectionStateNew ConnectionStateNew
// ConnectionStateChecking ICE agent has been given local and remote candidates, and is attempting to find a match // ConnectionStateChecking ICE agent has been given local and remote candidates, and is attempting to find a match.
ConnectionStateChecking ConnectionStateChecking
// ConnectionStateConnected ICE agent has a pairing, but is still checking other pairs // ConnectionStateConnected ICE agent has a pairing, but is still checking other pairs.
ConnectionStateConnected ConnectionStateConnected
// ConnectionStateCompleted ICE agent has finished // ConnectionStateCompleted ICE agent has finished.
ConnectionStateCompleted ConnectionStateCompleted
// ConnectionStateFailed ICE agent never could successfully connect // ConnectionStateFailed ICE agent never could successfully connect.
ConnectionStateFailed ConnectionStateFailed
// ConnectionStateDisconnected ICE agent connected successfully, but has entered a failed state // ConnectionStateDisconnected ICE agent connected successfully, but has entered a failed state.
ConnectionStateDisconnected ConnectionStateDisconnected
// ConnectionStateClosed ICE agent has finished and is no longer handling requests // ConnectionStateClosed ICE agent has finished and is no longer handling requests.
ConnectionStateClosed ConnectionStateClosed
) )
@@ -54,20 +54,20 @@ func (c ConnectionState) String() string {
} }
} }
// GatheringState describes the state of the candidate gathering process // GatheringState describes the state of the candidate gathering process.
type GatheringState int type GatheringState int
const ( const (
// GatheringStateUnknown represents an unknown state // GatheringStateUnknown represents an unknown state.
GatheringStateUnknown GatheringState = iota GatheringStateUnknown GatheringState = iota
// GatheringStateNew indicates candidate gathering is not yet started // GatheringStateNew indicates candidate gathering is not yet started.
GatheringStateNew GatheringStateNew
// GatheringStateGathering indicates candidate gathering is ongoing // GatheringStateGathering indicates candidate gathering is ongoing.
GatheringStateGathering GatheringStateGathering
// GatheringStateComplete indicates candidate gathering has been completed // GatheringStateComplete indicates candidate gathering has been completed.
GatheringStateComplete GatheringStateComplete
) )

View File

@@ -20,6 +20,7 @@ func (a tiebreaker) AddToAs(m *stun.Message, t stun.AttrType) error {
v := make([]byte, tiebreakerSize) v := make([]byte, tiebreakerSize)
binary.BigEndian.PutUint64(v, uint64(a)) binary.BigEndian.PutUint64(v, uint64(a))
m.Add(t, v) m.Add(t, v)
return nil return nil
} }
@@ -33,6 +34,7 @@ func (a *tiebreaker) GetFromAs(m *stun.Message, t stun.AttrType) error {
return err return err
} }
*a = tiebreaker(binary.BigEndian.Uint64(v)) *a = tiebreaker(binary.BigEndian.Uint64(v))
return nil return nil
} }
@@ -73,6 +75,7 @@ func (c AttrControl) AddTo(m *stun.Message) error {
if c.Role == Controlling { if c.Role == Controlling {
return tiebreaker(c.Tiebreaker).AddToAs(m, stun.AttrICEControlling) return tiebreaker(c.Tiebreaker).AddToAs(m, stun.AttrICEControlling)
} }
return tiebreaker(c.Tiebreaker).AddToAs(m, stun.AttrICEControlled) return tiebreaker(c.Tiebreaker).AddToAs(m, stun.AttrICEControlled)
} }
@@ -80,11 +83,14 @@ func (c AttrControl) AddTo(m *stun.Message) error {
func (c *AttrControl) GetFrom(m *stun.Message) error { func (c *AttrControl) GetFrom(m *stun.Message) error {
if m.Contains(stun.AttrICEControlling) { if m.Contains(stun.AttrICEControlling) {
c.Role = Controlling c.Role = Controlling
return (*tiebreaker)(&c.Tiebreaker).GetFromAs(m, stun.AttrICEControlling) return (*tiebreaker)(&c.Tiebreaker).GetFromAs(m, stun.AttrICEControlling)
} }
if m.Contains(stun.AttrICEControlled) { if m.Contains(stun.AttrICEControlled) {
c.Role = Controlled c.Role = Controlled
return (*tiebreaker)(&c.Tiebreaker).GetFromAs(m, stun.AttrICEControlled) return (*tiebreaker)(&c.Tiebreaker).GetFromAs(m, stun.AttrICEControlled)
} }
return stun.ErrAttributeNotFound return stun.ErrAttributeNotFound
} }

View File

@@ -12,11 +12,11 @@ import (
func TestControlled_GetFrom(t *testing.T) { //nolint:dupl func TestControlled_GetFrom(t *testing.T) { //nolint:dupl
m := new(stun.Message) m := new(stun.Message)
var c AttrControlled var attrCtr AttrControlled
if err := c.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) { if err := attrCtr.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
t.Error("unexpected error") t.Error("unexpected error")
} }
if err := m.Build(stun.BindingRequest, &c); err != nil { if err := m.Build(stun.BindingRequest, &attrCtr); err != nil {
t.Error(err) t.Error(err)
} }
m1 := new(stun.Message) m1 := new(stun.Message)
@@ -27,7 +27,7 @@ func TestControlled_GetFrom(t *testing.T) { //nolint:dupl
if err := c1.GetFrom(m1); err != nil { if err := c1.GetFrom(m1); err != nil {
t.Error(err) t.Error(err)
} }
if c1 != c { if c1 != attrCtr {
t.Error("not equal") t.Error("not equal")
} }
t.Run("IncorrectSize", func(t *testing.T) { t.Run("IncorrectSize", func(t *testing.T) {
@@ -42,11 +42,11 @@ func TestControlled_GetFrom(t *testing.T) { //nolint:dupl
func TestControlling_GetFrom(t *testing.T) { //nolint:dupl func TestControlling_GetFrom(t *testing.T) { //nolint:dupl
m := new(stun.Message) m := new(stun.Message)
var c AttrControlling var attrCtr AttrControlling
if err := c.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) { if err := attrCtr.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
t.Error("unexpected error") t.Error("unexpected error")
} }
if err := m.Build(stun.BindingRequest, &c); err != nil { if err := m.Build(stun.BindingRequest, &attrCtr); err != nil {
t.Error(err) t.Error(err)
} }
m1 := new(stun.Message) m1 := new(stun.Message)
@@ -57,7 +57,7 @@ func TestControlling_GetFrom(t *testing.T) { //nolint:dupl
if err := c1.GetFrom(m1); err != nil { if err := c1.GetFrom(m1); err != nil {
t.Error(err) t.Error(err)
} }
if c1 != c { if c1 != attrCtr {
t.Error("not equal") t.Error("not equal")
} }
t.Run("IncorrectSize", func(t *testing.T) { t.Run("IncorrectSize", func(t *testing.T) {
@@ -70,7 +70,7 @@ func TestControlling_GetFrom(t *testing.T) { //nolint:dupl
}) })
} }
func TestControl_GetFrom(t *testing.T) { func TestControl_GetFrom(t *testing.T) { //nolint:cyclop
t.Run("Blank", func(t *testing.T) { t.Run("Blank", func(t *testing.T) {
m := new(stun.Message) m := new(stun.Message)
var c AttrControl var c AttrControl
@@ -80,13 +80,13 @@ func TestControl_GetFrom(t *testing.T) {
}) })
t.Run("Controlling", func(t *testing.T) { //nolint:dupl t.Run("Controlling", func(t *testing.T) { //nolint:dupl
m := new(stun.Message) m := new(stun.Message)
var c AttrControl var attCtr AttrControl
if err := c.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) { if err := attCtr.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
t.Error("unexpected error") t.Error("unexpected error")
} }
c.Role = Controlling attCtr.Role = Controlling
c.Tiebreaker = 4321 attCtr.Tiebreaker = 4321
if err := m.Build(stun.BindingRequest, &c); err != nil { if err := m.Build(stun.BindingRequest, &attCtr); err != nil {
t.Error(err) t.Error(err)
} }
m1 := new(stun.Message) m1 := new(stun.Message)
@@ -97,7 +97,7 @@ func TestControl_GetFrom(t *testing.T) {
if err := c1.GetFrom(m1); err != nil { if err := c1.GetFrom(m1); err != nil {
t.Error(err) t.Error(err)
} }
if c1 != c { if c1 != attCtr {
t.Error("not equal") t.Error("not equal")
} }
t.Run("IncorrectSize", func(t *testing.T) { t.Run("IncorrectSize", func(t *testing.T) {
@@ -111,13 +111,13 @@ func TestControl_GetFrom(t *testing.T) {
}) })
t.Run("Controlled", func(t *testing.T) { //nolint:dupl t.Run("Controlled", func(t *testing.T) { //nolint:dupl
m := new(stun.Message) m := new(stun.Message)
var c AttrControl var attrCtrl AttrControl
if err := c.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) { if err := attrCtrl.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
t.Error("unexpected error") t.Error("unexpected error")
} }
c.Role = Controlled attrCtrl.Role = Controlled
c.Tiebreaker = 1234 attrCtrl.Tiebreaker = 1234
if err := m.Build(stun.BindingRequest, &c); err != nil { if err := m.Build(stun.BindingRequest, &attrCtrl); err != nil {
t.Error(err) t.Error(err)
} }
m1 := new(stun.Message) m1 := new(stun.Message)
@@ -128,7 +128,7 @@ func TestControl_GetFrom(t *testing.T) {
if err := c1.GetFrom(m1); err != nil { if err := c1.GetFrom(m1); err != nil {
t.Error(err) t.Error(err)
} }
if c1 != c { if c1 != attrCtrl {
t.Error("not equal") t.Error("not equal")
} }
t.Run("IncorrectSize", func(t *testing.T) { t.Run("IncorrectSize", func(t *testing.T) {

View File

@@ -6,18 +6,19 @@ package atomic
import "sync/atomic" import "sync/atomic"
// Error is an atomic error // Error is an atomic error.
type Error struct { type Error struct {
v atomic.Value v atomic.Value
} }
// Store updates the value of the atomic variable // Store updates the value of the atomic variable.
func (a *Error) Store(err error) { func (a *Error) Store(err error) {
a.v.Store(struct{ error }{err}) a.v.Store(struct{ error }{err})
} }
// Load retrieves the current value of the atomic variable // Load retrieves the current value of the atomic variable.
func (a *Error) Load() error { func (a *Error) Load() error {
err, _ := a.v.Load().(struct{ error }) err, _ := a.v.Load().(struct{ error })
return err.error return err.error
} }

View File

@@ -11,7 +11,7 @@ import (
"time" "time"
) )
// MockPacketConn for tests // MockPacketConn for tests.
type MockPacketConn struct{} type MockPacketConn struct{}
func (m *MockPacketConn) ReadFrom([]byte) (n int, addr net.Addr, err error) { return 0, nil, nil } //nolint:revive func (m *MockPacketConn) ReadFrom([]byte) (n int, addr net.Addr, err error) { return 0, nil, nil } //nolint:revive

View File

@@ -8,18 +8,19 @@ import (
"net" "net"
) )
// Compile-time assertion // Compile-time assertion.
var _ net.PacketConn = (*PacketConn)(nil) var _ net.PacketConn = (*PacketConn)(nil)
// PacketConn wraps a net.Conn and emulates net.PacketConn // PacketConn wraps a net.Conn and emulates net.PacketConn.
type PacketConn struct { type PacketConn struct {
net.Conn net.Conn
} }
// ReadFrom reads a packet from the connection, // ReadFrom reads a packet from the connection.
func (f *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { func (f *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, err = f.Conn.Read(p) n, err = f.Conn.Read(p)
addr = f.Conn.RemoteAddr() addr = f.Conn.RemoteAddr()
return return
} }

View File

@@ -59,7 +59,7 @@ func GetXORMappedAddr(conn net.PacketConn, serverAddr net.Addr, timeout time.Dur
return &addr, nil return &addr, nil
} }
// AssertUsername checks that the given STUN message m has a USERNAME attribute with a given value // AssertUsername checks that the given STUN message m has a USERNAME attribute with a given value.
func AssertUsername(m *stun.Message, expectedUsername string) error { func AssertUsername(m *stun.Message, expectedUsername string) error {
var username stun.Username var username stun.Username
if err := username.GetFrom(m); err != nil { if err := username.GetFrom(m); err != nil {

View File

@@ -13,7 +13,7 @@ import (
atomicx "github.com/pion/ice/v4/internal/atomic" atomicx "github.com/pion/ice/v4/internal/atomic"
) )
// ErrClosed indicates that the loop has been stopped // ErrClosed indicates that the loop has been stopped.
var ErrClosed = errors.New("the agent is closed") var ErrClosed = errors.New("the agent is closed")
type task struct { type task struct {
@@ -21,7 +21,7 @@ type task struct {
done chan struct{} done chan struct{}
} }
// Loop runs submitted task serially in a dedicated Goroutine // Loop runs submitted task serially in a dedicated Goroutine.
type Loop struct { type Loop struct {
tasks chan task tasks chan task
@@ -31,7 +31,7 @@ type Loop struct {
err atomicx.Error err atomicx.Error
} }
// New creates and starts a new task loop // New creates and starts a new task loop.
func New(onClose func()) *Loop { func New(onClose func()) *Loop {
l := &Loop{ l := &Loop{
tasks: make(chan task), tasks: make(chan task),
@@ -40,6 +40,7 @@ func New(onClose func()) *Loop {
} }
go l.runLoop(onClose) go l.runLoop(onClose)
return l return l
} }
@@ -86,6 +87,7 @@ func (l *Loop) Run(ctx context.Context, t func(context.Context)) error {
return ctx.Err() return ctx.Err()
case l.tasks <- task{t, done}: case l.tasks <- task{t, done}:
<-done <-done
return nil return nil
} }
} }
@@ -113,7 +115,7 @@ func (l *Loop) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false return time.Time{}, false
} }
// Value is not supported for task loops // Value is not supported for task loops.
func (l *Loop) Value(interface{}) interface{} { func (l *Loop) Value(interface{}) interface{} {
return nil return nil
} }

28
mdns.go
View File

@@ -14,18 +14,19 @@ import (
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
) )
// MulticastDNSMode represents the different Multicast modes ICE can run in // MulticastDNSMode represents the different Multicast modes ICE can run in.
type MulticastDNSMode byte type MulticastDNSMode byte
// MulticastDNSMode enum // MulticastDNSMode enum.
const ( const (
// MulticastDNSModeDisabled means remote mDNS candidates will be discarded, and local host candidates will use IPs // MulticastDNSModeDisabled means remote mDNS candidates will be discarded, and local host candidates will use IPs.
MulticastDNSModeDisabled MulticastDNSMode = iota + 1 MulticastDNSModeDisabled MulticastDNSMode = iota + 1
// MulticastDNSModeQueryOnly means remote mDNS candidates will be accepted, and local host candidates will use IPs // MulticastDNSModeQueryOnly means remote mDNS candidates will be accepted, and local host candidates will use IPs.
MulticastDNSModeQueryOnly MulticastDNSModeQueryOnly
// MulticastDNSModeQueryAndGather means remote mDNS candidates will be accepted, and local host candidates will use mDNS // MulticastDNSModeQueryAndGather means remote mDNS candidates will be accepted,
// and local host candidates will use mDNS.
MulticastDNSModeQueryAndGather MulticastDNSModeQueryAndGather
) )
@@ -33,11 +34,13 @@ func generateMulticastDNSName() (string, error) {
// https://tools.ietf.org/id/draft-ietf-rtcweb-mdns-ice-candidates-02.html#gathering // https://tools.ietf.org/id/draft-ietf-rtcweb-mdns-ice-candidates-02.html#gathering
// The unique name MUST consist of a version 4 UUID as defined in [RFC4122], followed by “.local”. // The unique name MUST consist of a version 4 UUID as defined in [RFC4122], followed by “.local”.
u, err := uuid.NewRandom() u, err := uuid.NewRandom()
return u.String() + ".local", err return u.String() + ".local", err
} }
//nolint:cyclop
func createMulticastDNS( func createMulticastDNS(
n transport.Net, netTransport transport.Net,
networkTypes []NetworkType, networkTypes []NetworkType,
interfaces []*transport.Interface, interfaces []*transport.Interface,
includeLoopback bool, includeLoopback bool,
@@ -57,6 +60,7 @@ func createMulticastDNS(
for _, nt := range networkTypes { for _, nt := range networkTypes {
if nt.IsIPv4() { if nt.IsIPv4() {
useV4 = true useV4 = true
continue continue
} }
if nt.IsIPv6() { if nt.IsIPv6() {
@@ -65,11 +69,11 @@ func createMulticastDNS(
} }
} }
addr4, mdnsErr := n.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4) addr4, mdnsErr := netTransport.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4)
if mdnsErr != nil { if mdnsErr != nil {
return nil, mDNSMode, mdnsErr return nil, mDNSMode, mdnsErr
} }
addr6, mdnsErr := n.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6) addr6, mdnsErr := netTransport.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6)
if mdnsErr != nil { if mdnsErr != nil {
return nil, mDNSMode, mdnsErr return nil, mDNSMode, mdnsErr
} }
@@ -78,10 +82,11 @@ func createMulticastDNS(
var mdns4Err error var mdns4Err error
if useV4 { if useV4 {
var l transport.UDPConn var l transport.UDPConn
l, mdns4Err = n.ListenUDP("udp4", addr4) l, mdns4Err = netTransport.ListenUDP("udp4", addr4)
if mdns4Err != nil { if mdns4Err != nil {
// If ICE fails to start MulticastDNS server just warn the user and continue // If ICE fails to start MulticastDNS server just warn the user and continue
log.Errorf("Failed to enable mDNS over IPv4: (%s)", mdns4Err) log.Errorf("Failed to enable mDNS over IPv4: (%s)", mdns4Err)
return nil, MulticastDNSModeDisabled, nil return nil, MulticastDNSModeDisabled, nil
} }
pktConnV4 = ipv4.NewPacketConn(l) pktConnV4 = ipv4.NewPacketConn(l)
@@ -91,9 +96,10 @@ func createMulticastDNS(
var mdns6Err error var mdns6Err error
if useV6 { if useV6 {
var l transport.UDPConn var l transport.UDPConn
l, mdns6Err = n.ListenUDP("udp6", addr6) l, mdns6Err = netTransport.ListenUDP("udp6", addr6)
if mdns6Err != nil { if mdns6Err != nil {
log.Errorf("Failed to enable mDNS over IPv6: (%s)", mdns6Err) log.Errorf("Failed to enable mDNS over IPv6: (%s)", mdns6Err)
return nil, MulticastDNSModeDisabled, nil return nil, MulticastDNSModeDisabled, nil
} }
pktConnV6 = ipv6.NewPacketConn(l) pktConnV6 = ipv6.NewPacketConn(l)
@@ -119,6 +125,7 @@ func createMulticastDNS(
Interfaces: ifcs, Interfaces: ifcs,
IncludeLoopback: includeLoopback, IncludeLoopback: includeLoopback,
}) })
return conn, mDNSMode, err return conn, mDNSMode, err
case MulticastDNSModeQueryAndGather: case MulticastDNSModeQueryAndGather:
conn, err := mdns.Server(pktConnV4, pktConnV6, &mdns.Config{ conn, err := mdns.Server(pktConnV4, pktConnV6, &mdns.Config{
@@ -126,6 +133,7 @@ func createMulticastDNS(
IncludeLoopback: includeLoopback, IncludeLoopback: includeLoopback,
LocalNames: []string{mDNSName}, LocalNames: []string{mDNSName},
}) })
return conn, mDNSMode, err return conn, mDNSMode, err
default: default:
return nil, mDNSMode, nil return nil, mDNSMode, nil

41
net.go
View File

@@ -24,6 +24,7 @@ func isSupportedIPv6Partial(ip net.IP) bool {
ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 { // !(IPv6 site-local unicast) ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 { // !(IPv6 site-local unicast)
return false return false
} }
return true return true
} }
@@ -33,10 +34,11 @@ func isZeros(ip net.IP) bool {
return false return false
} }
} }
return true return true
} }
//nolint:gocognit //nolint:gocognit,cyclop
func localInterfaces( func localInterfaces(
n transport.Net, n transport.Net,
interfaceFilter func(string) (keep bool), interfaceFilter func(string) (keep bool),
@@ -114,27 +116,35 @@ func localInterfaces(
filteredIfaces = append(filteredIfaces, ifaceCopy) filteredIfaces = append(filteredIfaces, ifaceCopy)
} }
} }
return filteredIfaces, ipAddrs, nil return filteredIfaces, ipAddrs, nil
} }
func listenUDPInPortRange(n transport.Net, log logging.LeveledLogger, portMax, portMin int, network string, lAddr *net.UDPAddr) (transport.UDPConn, error) { //nolint:cyclop
func listenUDPInPortRange(
netTransport transport.Net,
log logging.LeveledLogger,
portMax, portMin int,
network string,
lAddr *net.UDPAddr,
) (transport.UDPConn, error) {
if (lAddr.Port != 0) || ((portMin == 0) && (portMax == 0)) { if (lAddr.Port != 0) || ((portMin == 0) && (portMax == 0)) {
return n.ListenUDP(network, lAddr) return netTransport.ListenUDP(network, lAddr)
} }
var i, j int
i = portMin if portMin == 0 {
if i == 0 { portMin = 1024 // Start at 1024 which is non-privileged
i = 1024 // Start at 1024 which is non-privileged
} }
j = portMax
if j == 0 { if portMax == 0 {
j = 0xFFFF portMax = 0xFFFF
} }
if i > j {
if portMin > portMax {
return nil, ErrPort return nil, ErrPort
} }
portStart := globalMathRandomGenerator.Intn(j-i+1) + i portStart := globalMathRandomGenerator.Intn(portMax-portMin+1) + portMin
portCurrent := portStart portCurrent := portStart
for { for {
addr := &net.UDPAddr{ addr := &net.UDPAddr{
@@ -143,18 +153,19 @@ func listenUDPInPortRange(n transport.Net, log logging.LeveledLogger, portMax, p
Port: portCurrent, Port: portCurrent,
} }
c, e := n.ListenUDP(network, addr) c, e := netTransport.ListenUDP(network, addr)
if e == nil { if e == nil {
return c, e //nolint:nilerr return c, e //nolint:nilerr
} }
log.Debugf("Failed to listen %s: %v", lAddr.String(), e) log.Debugf("Failed to listen %s: %v", lAddr.String(), e)
portCurrent++ portCurrent++
if portCurrent > j { if portCurrent > portMax {
portCurrent = i portCurrent = portMin
} }
if portCurrent == portStart { if portCurrent == portStart {
break break
} }
} }
return nil, ErrPort return nil, ErrPort
} }

View File

@@ -54,6 +54,7 @@ func problematicNetworkInterfaces(s string) (keep bool) {
appleWirelessDirectLink := strings.Contains(s, "awdl") appleWirelessDirectLink := strings.Contains(s, "awdl")
appleLowLatencyWLANInterface := strings.Contains(s, "llw") appleLowLatencyWLANInterface := strings.Contains(s, "llw")
appleTunnelingInterface := strings.Contains(s, "utun") appleTunnelingInterface := strings.Contains(s, "utun")
return !defaultDockerBridgeNetwork && return !defaultDockerBridgeNetwork &&
!customDockerBridgeNetwork && !customDockerBridgeNetwork &&
!accessPoint && !accessPoint &&
@@ -68,5 +69,6 @@ func mustAddr(t *testing.T, ip net.IP) netip.Addr {
if !ok { if !ok {
t.Fatal(ipConvertError{ip}) t.Fatal(ipConvertError{ip})
} }
return addr return addr
} }

View File

@@ -27,7 +27,7 @@ func supportedNetworkTypes() []NetworkType {
} }
} }
// NetworkType represents the type of network // NetworkType represents the type of network.
type NetworkType int type NetworkType int
const ( const (
@@ -69,7 +69,7 @@ func (t NetworkType) IsTCP() bool {
return t == NetworkTypeTCP4 || t == NetworkTypeTCP6 return t == NetworkTypeTCP4 || t == NetworkTypeTCP6
} }
// NetworkShort returns the short network description // NetworkShort returns the short network description.
func (t NetworkType) NetworkShort() string { func (t NetworkType) NetworkShort() string {
switch t { switch t {
case NetworkTypeUDP4, NetworkTypeUDP6: case NetworkTypeUDP4, NetworkTypeUDP6:
@@ -81,7 +81,7 @@ func (t NetworkType) NetworkShort() string {
} }
} }
// IsReliable returns true if the network is reliable // IsReliable returns true if the network is reliable.
func (t NetworkType) IsReliable() bool { func (t NetworkType) IsReliable() bool {
switch t { switch t {
case NetworkTypeUDP4, NetworkTypeUDP6: case NetworkTypeUDP4, NetworkTypeUDP6:
@@ -89,6 +89,7 @@ func (t NetworkType) IsReliable() bool {
case NetworkTypeTCP4, NetworkTypeTCP6: case NetworkTypeTCP4, NetworkTypeTCP6:
return true return true
} }
return false return false
} }
@@ -100,6 +101,7 @@ func (t NetworkType) IsIPv4() bool {
case NetworkTypeUDP6, NetworkTypeTCP6: case NetworkTypeUDP6, NetworkTypeTCP6:
return false return false
} }
return false return false
} }
@@ -111,6 +113,7 @@ func (t NetworkType) IsIPv6() bool {
case NetworkTypeUDP6, NetworkTypeTCP6: case NetworkTypeUDP6, NetworkTypeTCP6:
return true return true
} }
return false return false
} }
@@ -124,12 +127,14 @@ func determineNetworkType(network string, ip netip.Addr) (NetworkType, error) {
if ip.Is4() { if ip.Is4() {
return NetworkTypeUDP4, nil return NetworkTypeUDP4, nil
} }
return NetworkTypeUDP6, nil return NetworkTypeUDP6, nil
case strings.HasPrefix(strings.ToLower(network), tcp): case strings.HasPrefix(strings.ToLower(network), tcp):
if ip.Is4() { if ip.Is4() {
return NetworkTypeTCP4, nil return NetworkTypeTCP4, nil
} }
return NetworkTypeTCP6, nil return NetworkTypeTCP6, nil
} }

View File

@@ -19,6 +19,7 @@ func (p PriorityAttr) AddTo(m *stun.Message) error {
v := make([]byte, prioritySize) v := make([]byte, prioritySize)
binary.BigEndian.PutUint32(v, uint32(p)) binary.BigEndian.PutUint32(v, uint32(p))
m.Add(stun.AttrPriority, v) m.Add(stun.AttrPriority, v)
return nil return nil
} }
@@ -32,5 +33,6 @@ func (p *PriorityAttr) GetFrom(m *stun.Message) error {
return err return err
} }
*p = PriorityAttr(binary.BigEndian.Uint32(v)) *p = PriorityAttr(binary.BigEndian.Uint32(v))
return nil return nil
} }

View File

@@ -12,11 +12,11 @@ import (
func TestPriority_GetFrom(t *testing.T) { //nolint:dupl func TestPriority_GetFrom(t *testing.T) { //nolint:dupl
m := new(stun.Message) m := new(stun.Message)
var p PriorityAttr var priority PriorityAttr
if err := p.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) { if err := priority.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
t.Error("unexpected error") t.Error("unexpected error")
} }
if err := m.Build(stun.BindingRequest, &p); err != nil { if err := m.Build(stun.BindingRequest, &priority); err != nil {
t.Error(err) t.Error(err)
} }
m1 := new(stun.Message) m1 := new(stun.Message)
@@ -27,7 +27,7 @@ func TestPriority_GetFrom(t *testing.T) { //nolint:dupl
if err := p1.GetFrom(m1); err != nil { if err := p1.GetFrom(m1); err != nil {
t.Error(err) t.Error(err)
} }
if p1 != p { if p1 != priority {
t.Error("not equal") t.Error("not equal")
} }
t.Run("IncorrectSize", func(t *testing.T) { t.Run("IncorrectSize", func(t *testing.T) {

View File

@@ -23,21 +23,27 @@ func TestRandomGeneratorCollision(t *testing.T) {
}, },
"PWD": { "PWD": {
gen: func(t *testing.T) string { gen: func(t *testing.T) string {
t.Helper()
s, err := generatePwd() s, err := generatePwd()
require.NoError(t, err) require.NoError(t, err)
return s return s
}, },
}, },
"Ufrag": { "Ufrag": {
gen: func(t *testing.T) string { gen: func(t *testing.T) string {
t.Helper()
s, err := generateUFrag() s, err := generateUFrag()
require.NoError(t, err) require.NoError(t, err)
return s return s
}, },
}, },
} }
const N = 100 const num = 100
const iteration = 100 const iteration = 100
for name, testCase := range testCases { for name, testCase := range testCases {
@@ -47,9 +53,9 @@ func TestRandomGeneratorCollision(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
var mu sync.Mutex var mu sync.Mutex
rands := make([]string, 0, N) rands := make([]string, 0, num)
for i := 0; i < N; i++ { for i := 0; i < num; i++ {
wg.Add(1) wg.Add(1)
go func() { go func() {
r := testCase.gen(t) r := testCase.gen(t)
@@ -61,12 +67,12 @@ func TestRandomGeneratorCollision(t *testing.T) {
} }
wg.Wait() wg.Wait()
if len(rands) != N { if len(rands) != num {
t.Fatal("Failed to generate randoms") t.Fatal("Failed to generate randoms")
} }
for i := 0; i < N; i++ { for i := 0; i < num; i++ {
for j := i + 1; j < N; j++ { for j := i + 1; j < num; j++ {
if rands[i] == rands[j] { if rands[i] == rands[j] {
t.Fatalf("generateRandString caused collision: %s == %s", rands[i], rands[j]) t.Fatalf("generateRandString caused collision: %s == %s", rands[i], rands[j])
} }

View File

@@ -26,6 +26,7 @@ func (r *Role) UnmarshalText(text []byte) error {
default: default:
return fmt.Errorf("%w %q", errUnknownRole, text) return fmt.Errorf("%w %q", errUnknownRole, text)
} }
return nil return nil
} }

View File

@@ -44,6 +44,7 @@ func (s *controllingSelector) isNominatable(c Candidate) bool {
} }
s.log.Errorf("Invalid candidate type: %s", c.Type()) s.log.Errorf("Invalid candidate type: %s", c.Type())
return false return false
} }
@@ -63,6 +64,7 @@ func (s *controllingSelector) ContactCandidates() {
p.nominated = true p.nominated = true
s.nominatedPair = p s.nominatedPair = p
s.nominatePair(p) s.nominatePair(p)
return return
} }
s.agent.pingAllCandidates() s.agent.pingAllCandidates()
@@ -84,6 +86,7 @@ func (s *controllingSelector) nominatePair(pair *CandidatePair) {
) )
if err != nil { if err != nil {
s.log.Error(err.Error()) s.log.Error(err.Error())
return return
} }
@@ -91,30 +94,35 @@ func (s *controllingSelector) nominatePair(pair *CandidatePair) {
s.agent.sendBindingRequest(msg, pair.Local, pair.Remote) s.agent.sendBindingRequest(msg, pair.Local, pair.Remote)
} }
func (s *controllingSelector) HandleBindingRequest(m *stun.Message, local, remote Candidate) { func (s *controllingSelector) HandleBindingRequest(message *stun.Message, local, remote Candidate) { //nolint:cyclop
s.agent.sendBindingSuccess(m, local, remote) s.agent.sendBindingSuccess(message, local, remote)
p := s.agent.findPair(local, remote) pair := s.agent.findPair(local, remote)
if p == nil { if pair == nil {
s.agent.addPair(local, remote) s.agent.addPair(local, remote)
return return
} }
if p.state == CandidatePairStateSucceeded && s.nominatedPair == nil && s.agent.getSelectedPair() == nil { if pair.state == CandidatePairStateSucceeded && s.nominatedPair == nil && s.agent.getSelectedPair() == nil {
bestPair := s.agent.getBestAvailableCandidatePair() bestPair := s.agent.getBestAvailableCandidatePair()
if bestPair == nil { if bestPair == nil {
s.log.Tracef("No best pair available") s.log.Tracef("No best pair available")
} else if bestPair.equal(p) && s.isNominatable(p.Local) && s.isNominatable(p.Remote) { } else if bestPair.equal(pair) && s.isNominatable(pair.Local) && s.isNominatable(pair.Remote) {
s.log.Tracef("The candidate (%s, %s) is the best candidate available, marking it as nominated", p.Local, p.Remote) s.log.Tracef(
s.nominatedPair = p "The candidate (%s, %s) is the best candidate available, marking it as nominated",
s.nominatePair(p) pair.Local,
pair.Remote,
)
s.nominatedPair = pair
s.nominatePair(pair)
} }
} }
if s.agent.userBindingRequestHandler != nil { if s.agent.userBindingRequestHandler != nil {
if shouldSwitch := s.agent.userBindingRequestHandler(m, local, remote, p); shouldSwitch { if shouldSwitch := s.agent.userBindingRequestHandler(message, local, remote, pair); shouldSwitch {
s.agent.setSelectedPair(p) s.agent.setSelectedPair(pair)
} }
} }
} }
@@ -123,6 +131,7 @@ func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remo
ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(m.TransactionID) ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(m.TransactionID)
if !ok { if !ok {
s.log.Warnf("Discard message from (%s), unknown TransactionID 0x%x", remote, m.TransactionID) s.log.Warnf("Discard message from (%s), unknown TransactionID 0x%x", remote, m.TransactionID)
return return
} }
@@ -131,26 +140,32 @@ func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remo
// Assert that NAT is not symmetric // Assert that NAT is not symmetric
// https://tools.ietf.org/html/rfc8445#section-7.2.5.2.1 // https://tools.ietf.org/html/rfc8445#section-7.2.5.2.1
if !addrEqual(transactionAddr, remoteAddr) { if !addrEqual(transactionAddr, remoteAddr) {
s.log.Debugf("Discard message: transaction source and destination does not match expected(%s), actual(%s)", transactionAddr, remote) s.log.Debugf(
"Discard message: transaction source and destination does not match expected(%s), actual(%s)",
transactionAddr,
remote,
)
return return
} }
s.log.Tracef("Inbound STUN (SuccessResponse) from %s to %s", remote, local) s.log.Tracef("Inbound STUN (SuccessResponse) from %s to %s", remote, local)
p := s.agent.findPair(local, remote) pair := s.agent.findPair(local, remote)
if p == nil { if pair == nil {
// This shouldn't happen // This shouldn't happen
s.log.Error("Success response from invalid candidate pair") s.log.Error("Success response from invalid candidate pair")
return return
} }
p.state = CandidatePairStateSucceeded pair.state = CandidatePairStateSucceeded
s.log.Tracef("Found valid candidate pair: %s", p) s.log.Tracef("Found valid candidate pair: %s", pair)
if pendingRequest.isUseCandidate && s.agent.getSelectedPair() == nil { if pendingRequest.isUseCandidate && s.agent.getSelectedPair() == nil {
s.agent.setSelectedPair(p) s.agent.setSelectedPair(pair)
} }
p.UpdateRoundTripTime(rtt) pair.UpdateRoundTripTime(rtt)
} }
func (s *controllingSelector) PingCandidate(local, remote Candidate) { func (s *controllingSelector) PingCandidate(local, remote Candidate) {
@@ -163,6 +178,7 @@ func (s *controllingSelector) PingCandidate(local, remote Candidate) {
) )
if err != nil { if err != nil {
s.log.Error(err.Error()) s.log.Error(err.Error())
return return
} }
@@ -198,6 +214,7 @@ func (s *controlledSelector) PingCandidate(local, remote Candidate) {
) )
if err != nil { if err != nil {
s.log.Error(err.Error()) s.log.Error(err.Error())
return return
} }
@@ -216,6 +233,7 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot
ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(m.TransactionID) ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(m.TransactionID)
if !ok { if !ok {
s.log.Warnf("Discard message from (%s), unknown TransactionID 0x%x", remote, m.TransactionID) s.log.Warnf("Discard message from (%s), unknown TransactionID 0x%x", remote, m.TransactionID)
return return
} }
@@ -224,52 +242,62 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot
// Assert that NAT is not symmetric // Assert that NAT is not symmetric
// https://tools.ietf.org/html/rfc8445#section-7.2.5.2.1 // https://tools.ietf.org/html/rfc8445#section-7.2.5.2.1
if !addrEqual(transactionAddr, remoteAddr) { if !addrEqual(transactionAddr, remoteAddr) {
s.log.Debugf("Discard message: transaction source and destination does not match expected(%s), actual(%s)", transactionAddr, remote) s.log.Debugf(
"Discard message: transaction source and destination does not match expected(%s), actual(%s)",
transactionAddr,
remote,
)
return return
} }
s.log.Tracef("Inbound STUN (SuccessResponse) from %s to %s", remote, local) s.log.Tracef("Inbound STUN (SuccessResponse) from %s to %s", remote, local)
p := s.agent.findPair(local, remote) pair := s.agent.findPair(local, remote)
if p == nil { if pair == nil {
// This shouldn't happen // This shouldn't happen
s.log.Error("Success response from invalid candidate pair") s.log.Error("Success response from invalid candidate pair")
return return
} }
p.state = CandidatePairStateSucceeded pair.state = CandidatePairStateSucceeded
s.log.Tracef("Found valid candidate pair: %s", p) s.log.Tracef("Found valid candidate pair: %s", pair)
if p.nominateOnBindingSuccess { if pair.nominateOnBindingSuccess {
if selectedPair := s.agent.getSelectedPair(); selectedPair == nil || if selectedPair := s.agent.getSelectedPair(); selectedPair == nil ||
(selectedPair != p && (!s.agent.needsToCheckPriorityOnNominated() || selectedPair.priority() <= p.priority())) { (selectedPair != pair &&
s.agent.setSelectedPair(p) (!s.agent.needsToCheckPriorityOnNominated() || selectedPair.priority() <= pair.priority())) {
} else if selectedPair != p { s.agent.setSelectedPair(pair)
s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", p, selectedPair) } else if selectedPair != pair {
s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", pair, selectedPair)
} }
} }
p.UpdateRoundTripTime(rtt) pair.UpdateRoundTripTime(rtt)
} }
func (s *controlledSelector) HandleBindingRequest(m *stun.Message, local, remote Candidate) { func (s *controlledSelector) HandleBindingRequest(message *stun.Message, local, remote Candidate) { //nolint:cyclop
p := s.agent.findPair(local, remote) pair := s.agent.findPair(local, remote)
if p == nil { if pair == nil {
p = s.agent.addPair(local, remote) pair = s.agent.addPair(local, remote)
} }
if m.Contains(stun.AttrUseCandidate) { if message.Contains(stun.AttrUseCandidate) { //nolint:nestif
// https://tools.ietf.org/html/rfc8445#section-7.3.1.5 // https://tools.ietf.org/html/rfc8445#section-7.3.1.5
if p.state == CandidatePairStateSucceeded { if pair.state == CandidatePairStateSucceeded {
// If the state of this pair is Succeeded, it means that the check // If the state of this pair is Succeeded, it means that the check
// previously sent by this pair produced a successful response and // previously sent by this pair produced a successful response and
// generated a valid pair (Section 7.2.5.3.2). The agent sets the // generated a valid pair (Section 7.2.5.3.2). The agent sets the
// nominated flag value of the valid pair to true. // nominated flag value of the valid pair to true.
selectedPair := s.agent.getSelectedPair() selectedPair := s.agent.getSelectedPair()
if selectedPair == nil || (selectedPair != p && (!s.agent.needsToCheckPriorityOnNominated() || selectedPair.priority() <= p.priority())) { if selectedPair == nil ||
s.agent.setSelectedPair(p) (selectedPair != pair &&
} else if selectedPair != p { (!s.agent.needsToCheckPriorityOnNominated() ||
s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", p, selectedPair) selectedPair.priority() <= pair.priority())) {
s.agent.setSelectedPair(pair)
} else if selectedPair != pair {
s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", pair, selectedPair)
} }
} else { } else {
// If the received Binding request triggered a new check to be // If the received Binding request triggered a new check to be
@@ -280,16 +308,16 @@ func (s *controlledSelector) HandleBindingRequest(m *stun.Message, local, remote
// MUST remove the candidate pair from the valid list, set the // MUST remove the candidate pair from the valid list, set the
// candidate pair state to Failed, and set the checklist state to // candidate pair state to Failed, and set the checklist state to
// Failed. // Failed.
p.nominateOnBindingSuccess = true pair.nominateOnBindingSuccess = true
} }
} }
s.agent.sendBindingSuccess(m, local, remote) s.agent.sendBindingSuccess(message, local, remote)
s.PingCandidate(local, remote) s.PingCandidate(local, remote)
if s.agent.userBindingRequestHandler != nil { if s.agent.userBindingRequestHandler != nil {
if shouldSwitch := s.agent.userBindingRequestHandler(m, local, remote, p); shouldSwitch { if shouldSwitch := s.agent.userBindingRequestHandler(message, local, remote, pair); shouldSwitch {
s.agent.setSelectedPair(p) s.agent.setSelectedPair(pair)
} }
} }
} }
@@ -298,7 +326,7 @@ type liteSelector struct {
pairCandidateSelector pairCandidateSelector
} }
// A lite selector should not contact candidates // A lite selector should not contact candidates.
func (s *liteSelector) ContactCandidates() { func (s *liteSelector) ContactCandidates() {
if _, ok := s.pairCandidateSelector.(*controllingSelector); ok { if _, ok := s.pairCandidateSelector.(*controllingSelector); ok {
//nolint:godox //nolint:godox

View File

@@ -23,6 +23,8 @@ import (
) )
func sendUntilDone(t *testing.T, writingConn, readingConn net.Conn, maxAttempts int) bool { func sendUntilDone(t *testing.T, writingConn, readingConn net.Conn, maxAttempts int) bool {
t.Helper()
testMessage := []byte("Hello World") testMessage := []byte("Hello World")
testBuffer := make([]byte, len(testMessage)) testBuffer := make([]byte, len(testMessage))
@@ -73,6 +75,7 @@ func TestBindingRequestHandler(t *testing.T) {
CheckInterval: &oneHour, CheckInterval: &oneHour,
BindingRequestHandler: func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool { BindingRequestHandler: func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool {
controlledLoggingFired.Store(true) controlledLoggingFired.Store(true)
return false return false
}, },
}) })
@@ -87,6 +90,7 @@ func TestBindingRequestHandler(t *testing.T) {
BindingRequestHandler: func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool { BindingRequestHandler: func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool {
// Don't switch candidate pair until we are ready // Don't switch candidate pair until we are ready
val, ok := switchToNewCandidatePair.Load().(bool) val, ok := switchToNewCandidatePair.Load().(bool)
return ok && val return ok && val
}, },
}) })

View File

@@ -7,7 +7,7 @@ import (
"time" "time"
) )
// CandidatePairStats contains ICE candidate pair statistics // CandidatePairStats contains ICE candidate pair statistics.
type CandidatePairStats struct { type CandidatePairStats struct {
// Timestamp is the timestamp associated with this object. // Timestamp is the timestamp associated with this object.
Timestamp time.Time Timestamp time.Time

View File

@@ -79,20 +79,20 @@ func NewTCPMuxDefault(params TCPMuxParams) *TCPMuxDefault {
params.AliveDurationForConnFromStun = 30 * time.Second params.AliveDurationForConnFromStun = 30 * time.Second
} }
m := &TCPMuxDefault{ mux := &TCPMuxDefault{
params: &params, params: &params,
connsIPv4: map[string]map[ipAddr]*tcpPacketConn{}, connsIPv4: map[string]map[ipAddr]*tcpPacketConn{},
connsIPv6: map[string]map[ipAddr]*tcpPacketConn{}, connsIPv6: map[string]map[ipAddr]*tcpPacketConn{},
} }
m.wg.Add(1) mux.wg.Add(1)
go func() { go func() {
defer m.wg.Done() defer mux.wg.Done()
m.start() mux.start()
}() }()
return m return mux
} }
func (m *TCPMuxDefault) start() { func (m *TCPMuxDefault) start() {
@@ -101,6 +101,7 @@ func (m *TCPMuxDefault) start() {
conn, err := m.params.Listener.Accept() conn, err := m.params.Listener.Accept()
if err != nil { if err != nil {
m.params.Logger.Infof("Error accepting connection: %s", err) m.params.Logger.Infof("Error accepting connection: %s", err)
return return
} }
@@ -130,6 +131,7 @@ func (m *TCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP)
if conn, ok := m.getConn(ufrag, isIPv6, local); ok { if conn, ok := m.getConn(ufrag, isIPv6, local); ok {
conn.ClearAliveTimer() conn.ClearAliveTimer()
return conn, nil return conn, nil
} }
@@ -191,12 +193,17 @@ func (m *TCPMuxDefault) closeAndLogError(closer io.Closer) {
} }
} }
func (m *TCPMuxDefault) handleConn(conn net.Conn) { func (m *TCPMuxDefault) handleConn(conn net.Conn) { //nolint:cyclop
buf := make([]byte, 512) buf := make([]byte, 512)
if m.params.FirstStunBindTimeout > 0 { if m.params.FirstStunBindTimeout > 0 {
if err := conn.SetReadDeadline(time.Now().Add(m.params.FirstStunBindTimeout)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(m.params.FirstStunBindTimeout)); err != nil {
m.params.Logger.Warnf("Failed to set read deadline for first STUN message: %s to %s , err: %s", conn.RemoteAddr(), conn.LocalAddr(), err) m.params.Logger.Warnf(
"Failed to set read deadline for first STUN message: %s to %s , err: %s",
conn.RemoteAddr(),
conn.LocalAddr(),
err,
)
} }
} }
n, err := readStreamingPacket(conn, buf) n, err := readStreamingPacket(conn, buf)
@@ -207,6 +214,7 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
m.params.Logger.Warnf("Error reading first packet from %s: %s", conn.RemoteAddr(), err) m.params.Logger.Warnf("Error reading first packet from %s: %s", conn.RemoteAddr(), err)
} }
m.closeAndLogError(conn) m.closeAndLogError(conn)
return return
} }
if err = conn.SetReadDeadline(time.Time{}); err != nil { if err = conn.SetReadDeadline(time.Time{}); err != nil {
@@ -223,12 +231,14 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
if err = msg.Decode(); err != nil { if err = msg.Decode(); err != nil {
m.closeAndLogError(conn) m.closeAndLogError(conn)
m.params.Logger.Warnf("Failed to handle decode ICE from %s to %s: %v", conn.RemoteAddr(), conn.LocalAddr(), err) m.params.Logger.Warnf("Failed to handle decode ICE from %s to %s: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
return return
} }
if m == nil || msg.Type.Method != stun.MethodBinding { // Not a STUN if m == nil || msg.Type.Method != stun.MethodBinding { // Not a STUN
m.closeAndLogError(conn) m.closeAndLogError(conn)
m.params.Logger.Warnf("Not a STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr()) m.params.Logger.Warnf("Not a STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
return return
} }
@@ -239,7 +249,12 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
attr, err := msg.Get(stun.AttrUsername) attr, err := msg.Get(stun.AttrUsername)
if err != nil { if err != nil {
m.closeAndLogError(conn) m.closeAndLogError(conn)
m.params.Logger.Warnf("No Username attribute in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr()) m.params.Logger.Warnf(
"No Username attribute in STUN message from %s to %s",
conn.RemoteAddr(),
conn.LocalAddr(),
)
return return
} }
@@ -249,7 +264,12 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
host, _, err := net.SplitHostPort(conn.RemoteAddr().String()) host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil { if err != nil {
m.closeAndLogError(conn) m.closeAndLogError(conn)
m.params.Logger.Warnf("Failed to get host in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr()) m.params.Logger.Warnf(
"Failed to get host in STUN message from %s to %s",
conn.RemoteAddr(),
conn.LocalAddr(),
)
return return
} }
@@ -258,7 +278,12 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
localAddr, ok := conn.LocalAddr().(*net.TCPAddr) localAddr, ok := conn.LocalAddr().(*net.TCPAddr)
if !ok { if !ok {
m.closeAndLogError(conn) m.closeAndLogError(conn)
m.params.Logger.Warnf("Failed to get local tcp address in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr()) m.params.Logger.Warnf(
"Failed to get local tcp address in STUN message from %s to %s",
conn.RemoteAddr(),
conn.LocalAddr(),
)
return return
} }
m.mu.Lock() m.mu.Lock()
@@ -269,7 +294,12 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
if err != nil { if err != nil {
m.mu.Unlock() m.mu.Unlock()
m.closeAndLogError(conn) m.closeAndLogError(conn)
m.params.Logger.Warnf("Failed to create packetConn for STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr()) m.params.Logger.Warnf(
"Failed to create packetConn for STUN message from %s to %s",
conn.RemoteAddr(),
conn.LocalAddr(),
)
return return
} }
} }
@@ -277,7 +307,13 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
if err := packetConn.AddConn(conn, buf); err != nil { if err := packetConn.AddConn(conn, buf); err != nil {
m.closeAndLogError(conn) m.closeAndLogError(conn)
m.params.Logger.Warnf("Error adding conn to tcpPacketConn from %s to %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err) m.params.Logger.Warnf(
"Error adding conn to tcpPacketConn from %s to %s: %s",
conn.RemoteAddr(),
conn.LocalAddr(),
err,
)
return return
} }
} }
@@ -428,7 +464,7 @@ func readStreamingPacket(conn net.Conn, buf []byte) (int, error) {
func writeStreamingPacket(conn net.Conn, buf []byte) (int, error) { func writeStreamingPacket(conn net.Conn, buf []byte) (int, error) {
bufCopy := make([]byte, streamingPacketHeaderLen+len(buf)) bufCopy := make([]byte, streamingPacketHeaderLen+len(buf))
binary.BigEndian.PutUint16(bufCopy, uint16(len(buf))) binary.BigEndian.PutUint16(bufCopy, uint16(len(buf))) //nolint:gosec // G115
copy(bufCopy[2:], buf) copy(bufCopy[2:], buf)
n, err := conn.Write(bufCopy) n, err := conn.Write(bufCopy)

View File

@@ -40,6 +40,7 @@ func (m *MultiTCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local net
if len(m.muxes) == 0 { if len(m.muxes) == 0 {
return nil, errNoTCPMuxAvailable return nil, errNoTCPMuxAvailable
} }
return m.muxes[0].GetConnByUfrag(ufrag, isIPv6, local) return m.muxes[0].GetConnByUfrag(ufrag, isIPv6, local)
} }
@@ -51,7 +52,7 @@ func (m *MultiTCPMuxDefault) RemoveConnByUfrag(ufrag string) {
} }
} }
// GetAllConns returns a PacketConn for each underlying TCPMux // GetAllConns returns a PacketConn for each underlying TCPMux.
func (m *MultiTCPMuxDefault) GetAllConns(ufrag string, isIPv6 bool, local net.IP) ([]net.PacketConn, error) { func (m *MultiTCPMuxDefault) GetAllConns(ufrag string, isIPv6 bool, local net.IP) ([]net.PacketConn, error) {
if len(m.muxes) == 0 { if len(m.muxes) == 0 {
// Make sure that we either return at least one connection or an error. // Make sure that we either return at least one connection or an error.
@@ -68,10 +69,11 @@ func (m *MultiTCPMuxDefault) GetAllConns(ufrag string, isIPv6 bool, local net.IP
conns = append(conns, conn) conns = append(conns, conn)
} }
} }
return conns, nil return conns, nil
} }
// Close the multi mux, no further connections could be created // Close the multi mux, no further connections could be created.
func (m *MultiTCPMuxDefault) Close() error { func (m *MultiTCPMuxDefault) Close() error {
var err error var err error
for _, mux := range m.muxes { for _, mux := range m.muxes {
@@ -79,5 +81,6 @@ func (m *MultiTCPMuxDefault) Close() error {
err = e err = e
} }
} }
return err return err
} }

View File

@@ -36,6 +36,7 @@ func newBufferedConn(conn net.Conn, bufSize int, logger logging.LeveledLogger) n
} }
go bc.writeProcess() go bc.writeProcess()
return bc return bc
} }
@@ -44,6 +45,7 @@ func (bc *bufferedConn) Write(b []byte) (int, error) {
if err != nil { if err != nil {
return n, err return n, err
} }
return n, nil return n, nil
} }
@@ -57,11 +59,13 @@ func (bc *bufferedConn) writeProcess() {
if err != nil { if err != nil {
bc.logger.Warnf("Failed to read from buffer: %s", err) bc.logger.Warnf("Failed to read from buffer: %s", err)
continue continue
} }
if _, err := bc.Conn.Write(pktBuf[:n]); err != nil { if _, err := bc.Conn.Write(pktBuf[:n]); err != nil {
bc.logger.Warnf("Failed to write: %s", err) bc.logger.Warnf("Failed to write: %s", err)
continue continue
} }
} }
@@ -70,6 +74,7 @@ func (bc *bufferedConn) writeProcess() {
func (bc *bufferedConn) Close() error { func (bc *bufferedConn) Close() error {
atomic.StoreInt32(&bc.closed, 1) atomic.StoreInt32(&bc.closed, 1)
_ = bc.buf.Close() _ = bc.buf.Close()
return bc.Conn.Close() return bc.Conn.Close()
} }
@@ -103,7 +108,7 @@ type tcpPacketParams struct {
} }
func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn { func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn {
p := &tcpPacketConn{ packet := &tcpPacketConn{
params: &params, params: &params,
conns: map[string]net.Conn{}, conns: map[string]net.Conn{},
@@ -113,13 +118,13 @@ func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn {
} }
if params.AliveDuration > 0 { if params.AliveDuration > 0 {
p.aliveTimer = time.AfterFunc(params.AliveDuration, func() { packet.aliveTimer = time.AfterFunc(params.AliveDuration, func() {
p.params.Logger.Warn("close tcp packet conn by alive timeout") packet.params.Logger.Warn("close tcp packet conn by alive timeout")
_ = p.Close() _ = packet.Close()
}) })
} }
return p return packet
} }
func (t *tcpPacketConn) ClearAliveTimer() { func (t *tcpPacketConn) ClearAliveTimer() {
@@ -131,7 +136,12 @@ func (t *tcpPacketConn) ClearAliveTimer() {
} }
func (t *tcpPacketConn) AddConn(conn net.Conn, firstPacketData []byte) error { func (t *tcpPacketConn) AddConn(conn net.Conn, firstPacketData []byte) error {
t.params.Logger.Infof("Added connection: %s remote %s to local %s", conn.RemoteAddr().Network(), conn.RemoteAddr(), conn.LocalAddr()) t.params.Logger.Infof(
"Added connection: %s remote %s to local %s",
conn.RemoteAddr().Network(),
conn.RemoteAddr(),
conn.LocalAddr(),
)
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
@@ -183,6 +193,7 @@ func (t *tcpPacketConn) startReading(conn net.Conn) {
if last || !(errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed)) { if last || !(errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed)) {
t.handleRecv(streamingPacket{nil, conn.RemoteAddr(), err}) t.handleRecv(streamingPacket{nil, conn.RemoteAddr(), err})
} }
return return
} }
@@ -236,6 +247,7 @@ func (t *tcpPacketConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) {
n = len(pkt.Data) n = len(pkt.Data)
copy(b, pkt.Data[:n]) copy(b, pkt.Data[:n])
return n, pkt.RAddr, err return n, pkt.RAddr, err
} }
@@ -252,6 +264,7 @@ func (t *tcpPacketConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
n, err = writeStreamingPacket(conn, buf) n, err = writeStreamingPacket(conn, buf)
if err != nil { if err != nil {
t.params.Logger.Tracef("%w %s", errWrite, rAddr) t.params.Logger.Tracef("%w %s", errWrite, rAddr)
return n, err return n, err
} }
@@ -272,6 +285,7 @@ func (t *tcpPacketConn) removeConn(conn net.Conn) bool {
t.closeAndLogError(conn) t.closeAndLogError(conn)
delete(t.conns, conn.RemoteAddr().String()) delete(t.conns, conn.RemoteAddr().String())
return len(t.conns) == 0 return len(t.conns) == 0
} }

View File

@@ -32,12 +32,12 @@ type Conn struct {
agent *Agent agent *Agent
} }
// BytesSent returns the number of bytes sent // BytesSent returns the number of bytes sent.
func (c *Conn) BytesSent() uint64 { func (c *Conn) BytesSent() uint64 {
return atomic.LoadUint64(&c.bytesSent) return atomic.LoadUint64(&c.bytesSent)
} }
// BytesReceived returns the number of bytes received // BytesReceived returns the number of bytes received.
func (c *Conn) BytesReceived() uint64 { func (c *Conn) BytesReceived() uint64 {
return atomic.LoadUint64(&c.bytesReceived) return atomic.LoadUint64(&c.bytesReceived)
} }
@@ -74,18 +74,19 @@ func (c *Conn) Read(p []byte) (int, error) {
} }
n, err := c.agent.buf.Read(p) n, err := c.agent.buf.Read(p)
atomic.AddUint64(&c.bytesReceived, uint64(n)) atomic.AddUint64(&c.bytesReceived, uint64(n)) //nolint:gosec // G115
return n, err return n, err
} }
// Write implements the Conn Write method. // Write implements the Conn Write method.
func (c *Conn) Write(p []byte) (int, error) { func (c *Conn) Write(packet []byte) (int, error) {
err := c.agent.loop.Err() err := c.agent.loop.Err()
if err != nil { if err != nil {
return 0, err return 0, err
} }
if stun.IsMessage(p) { if stun.IsMessage(packet) {
return 0, errWriteSTUNMessageToIceConn return 0, errWriteSTUNMessageToIceConn
} }
@@ -102,8 +103,9 @@ func (c *Conn) Write(p []byte) (int, error) {
} }
} }
atomic.AddUint64(&c.bytesSent, uint64(len(p))) atomic.AddUint64(&c.bytesSent, uint64(len(packet)))
return pair.Write(p)
return pair.Write(packet)
} }
// Close implements the Conn Close method. It is used to close // Close implements the Conn Close method. It is used to close
@@ -132,17 +134,17 @@ func (c *Conn) RemoteAddr() net.Addr {
return pair.Remote.addr() return pair.Remote.addr()
} }
// SetDeadline is a stub // SetDeadline is a stub.
func (c *Conn) SetDeadline(time.Time) error { func (c *Conn) SetDeadline(time.Time) error {
return nil return nil
} }
// SetReadDeadline is a stub // SetReadDeadline is a stub.
func (c *Conn) SetReadDeadline(time.Time) error { func (c *Conn) SetReadDeadline(time.Time) error {
return nil return nil
} }
// SetWriteDeadline is a stub // SetWriteDeadline is a stub.
func (c *Conn) SetWriteDeadline(time.Time) error { func (c *Conn) SetWriteDeadline(time.Time) error {
return nil return nil
} }

View File

@@ -30,13 +30,15 @@ func TestStressDuplex(t *testing.T) {
stressDuplex(t) stressDuplex(t)
} }
func testTimeout(t *testing.T, c *Conn, timeout time.Duration) { func testTimeout(t *testing.T, conn *Conn, timeout time.Duration) {
t.Helper()
const pollRate = 100 * time.Millisecond const pollRate = 100 * time.Millisecond
const margin = 20 * time.Millisecond // Allow 20msec error in time const margin = 20 * time.Millisecond // Allow 20msec error in time
ticker := time.NewTicker(pollRate) ticker := time.NewTicker(pollRate)
defer func() { defer func() {
ticker.Stop() ticker.Stop()
err := c.Close() err := conn.Close()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@@ -49,8 +51,8 @@ func testTimeout(t *testing.T, c *Conn, timeout time.Duration) {
var cs ConnectionState var cs ConnectionState
err := c.agent.loop.Run(context.Background(), func(_ context.Context) { err := conn.agent.loop.Run(context.Background(), func(_ context.Context) {
cs = c.agent.connectionState cs = conn.agent.connectionState
}) })
if err != nil { if err != nil {
// We should never get here. // We should never get here.
@@ -63,6 +65,7 @@ func testTimeout(t *testing.T, c *Conn, timeout time.Duration) {
t.Fatalf("Connection timed out %f msec early", elapsed.Seconds()*1000) t.Fatalf("Connection timed out %f msec early", elapsed.Seconds()*1000)
} else { } else {
t.Logf("Connection timed out in %f msec", elapsed.Seconds()*1000) t.Logf("Connection timed out in %f msec", elapsed.Seconds()*1000)
return return
} }
} }
@@ -133,6 +136,8 @@ func TestReadClosed(t *testing.T) {
} }
func stressDuplex(t *testing.T) { func stressDuplex(t *testing.T) {
t.Helper()
ca, cb := pipe(nil) ca, cb := pipe(nil)
defer func() { defer func() {
@@ -219,6 +224,7 @@ func connect(aAgent, bAgent *Agent) (*Conn, *Conn) {
// Ensure accepted // Ensure accepted
<-accepted <-accepted
return aConn, bConn return aConn, bConn
} }
@@ -288,6 +294,7 @@ func pipeWithTimeout(disconnectTimeout time.Duration, iceKeepalive time.Duration
func onConnected() (func(ConnectionState), chan struct{}) { func onConnected() (func(ConnectionState), chan struct{}) {
done := make(chan struct{}) done := make(chan struct{})
return func(state ConnectionState) { return func(state ConnectionState) {
if state == ConnectionStateConnected { if state == ConnectionStateConnected {
close(done) close(done)
@@ -295,11 +302,11 @@ func onConnected() (func(ConnectionState), chan struct{}) {
}, done }, done
} }
func randomPort(t testing.TB) int { func randomPort(tb testing.TB) int {
t.Helper() tb.Helper()
conn, err := net.ListenPacket("udp4", "127.0.0.1:0") conn, err := net.ListenPacket("udp4", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("failed to pickPort: %v", err) tb.Fatalf("failed to pickPort: %v", err)
} }
defer func() { defer func() {
_ = conn.Close() _ = conn.Close()
@@ -308,7 +315,8 @@ func randomPort(t testing.TB) int {
case *net.UDPAddr: case *net.UDPAddr:
return addr.Port return addr.Port
default: default:
t.Fatalf("unknown addr type %T", addr) tb.Fatalf("unknown addr type %T", addr)
return 0 return 0
} }
} }

View File

@@ -30,9 +30,9 @@ func TestRemoteLocalAddr(t *testing.T) {
// Agent1 is behind 1:1 NAT // Agent1 is behind 1:1 NAT
natType1 := &vnet.NATType{Mode: vnet.NATModeNAT1To1} natType1 := &vnet.NATType{Mode: vnet.NATModeNAT1To1}
v, errVnet := buildVNet(natType0, natType1) builtVnet, errVnet := buildVNet(natType0, natType1)
require.NoError(t, errVnet, "should succeed") require.NoError(t, errVnet, "should succeed")
defer v.close() defer builtVnet.close()
stunServerURL := &stun.URI{ stunServerURL := &stun.URI{
Scheme: stun.SchemeTypeSTUN, Scheme: stun.SchemeTypeSTUN,
@@ -53,7 +53,7 @@ func TestRemoteLocalAddr(t *testing.T) {
}) })
t.Run("Remote/Local Pair Match between Agents", func(t *testing.T) { t.Run("Remote/Local Pair Match between Agents", func(t *testing.T) {
ca, cb := pipeWithVNet(v, ca, cb := pipeWithVNet(builtVnet,
&agentTestConfig{ &agentTestConfig{
urls: []*stun.URI{stunServerURL}, urls: []*stun.URI{stunServerURL},
}, },

View File

@@ -18,7 +18,7 @@ import (
"github.com/pion/transport/v3/stdnet" "github.com/pion/transport/v3/stdnet"
) )
// UDPMux allows multiple connections to go over a single UDP port // UDPMux allows multiple connections to go over a single UDP port.
type UDPMux interface { type UDPMux interface {
io.Closer io.Closer
GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error)
@@ -26,7 +26,7 @@ type UDPMux interface {
GetListenAddresses() []net.Addr GetListenAddresses() []net.Addr
} }
// UDPMuxDefault is an implementation of the interface // UDPMuxDefault is an implementation of the interface.
type UDPMuxDefault struct { type UDPMuxDefault struct {
params UDPMuxParams params UDPMuxParams
@@ -60,14 +60,14 @@ type UDPMuxParams struct {
Net transport.Net Net transport.Net
} }
// NewUDPMuxDefault creates an implementation of UDPMux // NewUDPMuxDefault creates an implementation of UDPMux.
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { //nolint:cyclop
if params.Logger == nil { if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
} }
var localAddrsForUnspecified []net.Addr var localAddrsForUnspecified []net.Addr
if udpAddr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { if udpAddr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { //nolint:nestif
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr()) params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
} else if ok && udpAddr.IP.IsUnspecified() { } else if ok && udpAddr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but // For unspecified addresses, the correct behavior is to return errListenUnspecified, but
@@ -109,7 +109,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
} }
params.UDPConnString = params.UDPConn.LocalAddr().String() params.UDPConnString = params.UDPConn.LocalAddr().String()
m := &UDPMuxDefault{ mux := &UDPMuxDefault{
addressMap: map[ipPort]*udpMuxedConn{}, addressMap: map[ipPort]*udpMuxedConn{},
params: params, params: params,
connsIPv4: make(map[string]*udpMuxedConn), connsIPv4: make(map[string]*udpMuxedConn),
@@ -124,17 +124,17 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
localAddrsForUnspecified: localAddrsForUnspecified, localAddrsForUnspecified: localAddrsForUnspecified,
} }
go m.connWorker() go mux.connWorker()
return m return mux
} }
// LocalAddr returns the listening address of this UDPMuxDefault // LocalAddr returns the listening address of this UDPMuxDefault.
func (m *UDPMuxDefault) LocalAddr() net.Addr { func (m *UDPMuxDefault) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr() return m.params.UDPConn.LocalAddr()
} }
// GetListenAddresses returns the list of addresses that this mux is listening on // GetListenAddresses returns the list of addresses that this mux is listening on.
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
if len(m.localAddrsForUnspecified) > 0 { if len(m.localAddrsForUnspecified) > 0 {
return m.localAddrsForUnspecified return m.localAddrsForUnspecified
@@ -143,8 +143,8 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
return []net.Addr{m.LocalAddr()} return []net.Addr{m.LocalAddr()}
} }
// GetConn returns a PacketConn given the connection's ufrag and network address // GetConn returns a PacketConn given the connection's ufrag and network address.
// creates the connection if an existing one can't be found // creates the connection if an existing one can't be found.
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
// don't check addr for mux using unspecified address // don't check addr for mux using unspecified address
if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConnString != addr.String() { if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConnString != addr.String() {
@@ -181,11 +181,11 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
return c, nil return c, nil
} }
// RemoveConnByUfrag stops and removes the muxed packet connection // RemoveConnByUfrag stops and removes the muxed packet connection.
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 2) removedConns := make([]*udpMuxedConn, 0, 2)
// Keep lock section small to avoid deadlock with conn lock // Keep lock section small to avoid deadlock with conn lock.
m.mu.Lock() m.mu.Lock()
if c, ok := m.connsIPv4[ufrag]; ok { if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag) delete(m.connsIPv4, ufrag)
@@ -198,7 +198,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
m.mu.Unlock() m.mu.Unlock()
if len(removedConns) == 0 { if len(removedConns) == 0 {
// No need to lock if no connection was found // No need to lock if no connection was found.
return return
} }
@@ -213,7 +213,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
} }
} }
// IsClosed returns true if the mux had been closed // IsClosed returns true if the mux had been closed.
func (m *UDPMuxDefault) IsClosed() bool { func (m *UDPMuxDefault) IsClosed() bool {
select { select {
case <-m.closedChan: case <-m.closedChan:
@@ -223,7 +223,7 @@ func (m *UDPMuxDefault) IsClosed() bool {
} }
} }
// Close the mux, no further connections could be created // Close the mux, no further connections could be created.
func (m *UDPMuxDefault) Close() error { func (m *UDPMuxDefault) Close() error {
var err error var err error
m.closeOnce.Do(func() { m.closeOnce.Do(func() {
@@ -244,6 +244,7 @@ func (m *UDPMuxDefault) Close() error {
_ = m.params.UDPConn.Close() _ = m.params.UDPConn.Close()
}) })
return err return err
} }
@@ -276,10 +277,11 @@ func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
LocalAddr: m.LocalAddr(), LocalAddr: m.LocalAddr(),
Logger: m.params.Logger, Logger: m.params.Logger,
}) })
return c return c
} }
func (m *UDPMuxDefault) connWorker() { func (m *UDPMuxDefault) connWorker() { //nolint:cyclop
logger := m.params.Logger logger := m.params.Logger
defer func() { defer func() {
@@ -304,11 +306,13 @@ func (m *UDPMuxDefault) connWorker() {
netUDPAddr, ok := addr.(*net.UDPAddr) netUDPAddr, ok := addr.(*net.UDPAddr)
if !ok { if !ok {
logger.Errorf("Underlying PacketConn did not return a UDPAddr") logger.Errorf("Underlying PacketConn did not return a UDPAddr")
return return
} }
udpAddr, err := newIPPort(netUDPAddr.IP, netUDPAddr.Zone, uint16(netUDPAddr.Port)) udpAddr, err := newIPPort(netUDPAddr.IP, netUDPAddr.Zone, uint16(netUDPAddr.Port)) //nolint:gosec
if err != nil { if err != nil {
logger.Errorf("Failed to create a new IP/Port host pair") logger.Errorf("Failed to create a new IP/Port host pair")
return return
} }
@@ -325,12 +329,14 @@ func (m *UDPMuxDefault) connWorker() {
if err = msg.Decode(); err != nil { if err = msg.Decode(); err != nil {
m.params.Logger.Warnf("Failed to handle decode ICE from %s: %v", addr.String(), err) m.params.Logger.Warnf("Failed to handle decode ICE from %s: %v", addr.String(), err)
continue continue
} }
attr, stunAttrErr := msg.Get(stun.AttrUsername) attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr != nil { if stunAttrErr != nil {
m.params.Logger.Warnf("No Username attribute in STUN message from %s", addr.String()) m.params.Logger.Warnf("No Username attribute in STUN message from %s", addr.String())
continue continue
} }
@@ -344,6 +350,7 @@ func (m *UDPMuxDefault) connWorker() {
if destinationConn == nil { if destinationConn == nil {
m.params.Logger.Tracef("Dropping packet from %s, addr: %s", udpAddr.addr, addr) m.params.Logger.Tracef("Dropping packet from %s, addr: %s", udpAddr.addr, addr)
continue continue
} }
@@ -359,6 +366,7 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o
} else { } else {
val, ok = m.connsIPv4[ufrag] val, ok = m.connsIPv4[ufrag]
} }
return return
} }
@@ -386,7 +394,7 @@ type ipPort struct {
// newIPPort create a custom type of address based on netip.Addr and // newIPPort create a custom type of address based on netip.Addr and
// port. The underlying ip address passed is converted to IPv6 format // port. The underlying ip address passed is converted to IPv6 format
// to simplify ip address handling // to simplify ip address handling.
func newIPPort(ip net.IP, zone string, port uint16) (ipPort, error) { func newIPPort(ip net.IP, zone string, port uint16) (ipPort, error) {
n, ok := netip.AddrFromSlice(ip.To16()) n, ok := netip.AddrFromSlice(ip.To16())
if !ok { if !ok {

View File

@@ -29,6 +29,7 @@ func NewMultiUDPMuxDefault(muxes ...UDPMux) *MultiUDPMuxDefault {
addrToMux[addr.String()] = mux addrToMux[addr.String()] = mux
} }
} }
return &MultiUDPMuxDefault{ return &MultiUDPMuxDefault{
muxes: muxes, muxes: muxes,
localAddrToMux: addrToMux, localAddrToMux: addrToMux,
@@ -42,6 +43,7 @@ func (m *MultiUDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketCon
if !ok { if !ok {
return nil, errNoUDPMuxAvailable return nil, errNoUDPMuxAvailable
} }
return mux.GetConn(ufrag, addr) return mux.GetConn(ufrag, addr)
} }
@@ -53,7 +55,7 @@ func (m *MultiUDPMuxDefault) RemoveConnByUfrag(ufrag string) {
} }
} }
// Close the multi mux, no further connections could be created // Close the multi mux, no further connections could be created.
func (m *MultiUDPMuxDefault) Close() error { func (m *MultiUDPMuxDefault) Close() error {
var err error var err error
for _, mux := range m.muxes { for _, mux := range m.muxes {
@@ -61,21 +63,23 @@ func (m *MultiUDPMuxDefault) Close() error {
err = e err = e
} }
} }
return err return err
} }
// GetListenAddresses returns the list of addresses that this mux is listening on // GetListenAddresses returns the list of addresses that this mux is listening on.
func (m *MultiUDPMuxDefault) GetListenAddresses() []net.Addr { func (m *MultiUDPMuxDefault) GetListenAddresses() []net.Addr {
addrs := make([]net.Addr, 0, len(m.localAddrToMux)) addrs := make([]net.Addr, 0, len(m.localAddrToMux))
for _, mux := range m.muxes { for _, mux := range m.muxes {
addrs = append(addrs, mux.GetListenAddresses()...) addrs = append(addrs, mux.GetListenAddresses()...)
} }
return addrs return addrs
} }
// NewMultiUDPMuxFromPort creates an instance of MultiUDPMuxDefault that // NewMultiUDPMuxFromPort creates an instance of MultiUDPMuxDefault that
// listen all interfaces on the provided port. // listen all interfaces on the provided port.
func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMuxDefault, error) { func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMuxDefault, error) { //nolint:cyclop
params := multiUDPMuxFromPortParam{ params := multiUDPMuxFromPortParam{
networks: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, networks: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
} }
@@ -104,6 +108,7 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu
}) })
if listenErr != nil { if listenErr != nil {
err = listenErr err = listenErr
break break
} }
if params.readBufferSize > 0 { if params.readBufferSize > 0 {
@@ -119,6 +124,7 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu
for _, conn := range conns { for _, conn := range conns {
_ = conn.Close() _ = conn.Close()
} }
return nil, err return nil, err
} }
@@ -135,7 +141,7 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu
return NewMultiUDPMuxDefault(muxes...), nil return NewMultiUDPMuxDefault(muxes...), nil
} }
// UDPMuxFromPortOption provide options for NewMultiUDPMuxFromPort // UDPMuxFromPortOption provide options for NewMultiUDPMuxFromPort.
type UDPMuxFromPortOption interface { type UDPMuxFromPortOption interface {
apply(*multiUDPMuxFromPortParam) apply(*multiUDPMuxFromPortParam)
} }
@@ -159,7 +165,7 @@ func (o *udpMuxFromPortOption) apply(p *multiUDPMuxFromPortParam) {
o.f(p) o.f(p)
} }
// UDPMuxFromPortWithInterfaceFilter set the filter to filter out interfaces that should not be used // UDPMuxFromPortWithInterfaceFilter set the filter to filter out interfaces that should not be used.
func UDPMuxFromPortWithInterfaceFilter(f func(string) (keep bool)) UDPMuxFromPortOption { func UDPMuxFromPortWithInterfaceFilter(f func(string) (keep bool)) UDPMuxFromPortOption {
return &udpMuxFromPortOption{ return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) { f: func(p *multiUDPMuxFromPortParam) {
@@ -168,7 +174,7 @@ func UDPMuxFromPortWithInterfaceFilter(f func(string) (keep bool)) UDPMuxFromPor
} }
} }
// UDPMuxFromPortWithIPFilter set the filter to filter out IP addresses that should not be used // UDPMuxFromPortWithIPFilter set the filter to filter out IP addresses that should not be used.
func UDPMuxFromPortWithIPFilter(f func(ip net.IP) (keep bool)) UDPMuxFromPortOption { func UDPMuxFromPortWithIPFilter(f func(ip net.IP) (keep bool)) UDPMuxFromPortOption {
return &udpMuxFromPortOption{ return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) { f: func(p *multiUDPMuxFromPortParam) {
@@ -177,7 +183,7 @@ func UDPMuxFromPortWithIPFilter(f func(ip net.IP) (keep bool)) UDPMuxFromPortOpt
} }
} }
// UDPMuxFromPortWithNetworks set the networks that should be used. default is both IPv4 and IPv6 // UDPMuxFromPortWithNetworks set the networks that should be used. default is both IPv4 and IPv6.
func UDPMuxFromPortWithNetworks(networks ...NetworkType) UDPMuxFromPortOption { func UDPMuxFromPortWithNetworks(networks ...NetworkType) UDPMuxFromPortOption {
return &udpMuxFromPortOption{ return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) { f: func(p *multiUDPMuxFromPortParam) {
@@ -186,7 +192,7 @@ func UDPMuxFromPortWithNetworks(networks ...NetworkType) UDPMuxFromPortOption {
} }
} }
// UDPMuxFromPortWithReadBufferSize set the UDP connection read buffer size // UDPMuxFromPortWithReadBufferSize set the UDP connection read buffer size.
func UDPMuxFromPortWithReadBufferSize(size int) UDPMuxFromPortOption { func UDPMuxFromPortWithReadBufferSize(size int) UDPMuxFromPortOption {
return &udpMuxFromPortOption{ return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) { f: func(p *multiUDPMuxFromPortParam) {
@@ -195,7 +201,7 @@ func UDPMuxFromPortWithReadBufferSize(size int) UDPMuxFromPortOption {
} }
} }
// UDPMuxFromPortWithWriteBufferSize set the UDP connection write buffer size // UDPMuxFromPortWithWriteBufferSize set the UDP connection write buffer size.
func UDPMuxFromPortWithWriteBufferSize(size int) UDPMuxFromPortOption { func UDPMuxFromPortWithWriteBufferSize(size int) UDPMuxFromPortOption {
return &udpMuxFromPortOption{ return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) { f: func(p *multiUDPMuxFromPortParam) {
@@ -204,7 +210,7 @@ func UDPMuxFromPortWithWriteBufferSize(size int) UDPMuxFromPortOption {
} }
} }
// UDPMuxFromPortWithLogger set the logger for the created UDPMux // UDPMuxFromPortWithLogger set the logger for the created UDPMux.
func UDPMuxFromPortWithLogger(logger logging.LeveledLogger) UDPMuxFromPortOption { func UDPMuxFromPortWithLogger(logger logging.LeveledLogger) UDPMuxFromPortOption {
return &udpMuxFromPortOption{ return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) { f: func(p *multiUDPMuxFromPortParam) {
@@ -213,7 +219,7 @@ func UDPMuxFromPortWithLogger(logger logging.LeveledLogger) UDPMuxFromPortOption
} }
} }
// UDPMuxFromPortWithLoopback set loopback interface should be included // UDPMuxFromPortWithLoopback set loopback interface should be included.
func UDPMuxFromPortWithLoopback() UDPMuxFromPortOption { func UDPMuxFromPortWithLoopback() UDPMuxFromPortOption {
return &udpMuxFromPortOption{ return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) { f: func(p *multiUDPMuxFromPortParam) {

View File

@@ -79,6 +79,8 @@ func TestMultiUDPMux(t *testing.T) {
} }
func testMultiUDPMuxConnections(t *testing.T, udpMuxMulti *MultiUDPMuxDefault, ufrag string, network string) { func testMultiUDPMuxConnections(t *testing.T, udpMuxMulti *MultiUDPMuxDefault, ufrag string, network string) {
t.Helper()
addrs := udpMuxMulti.GetListenAddresses() addrs := udpMuxMulti.GetListenAddresses()
pktConns := make([]net.PacketConn, 0, len(addrs)) pktConns := make([]net.PacketConn, 0, len(addrs))
for _, addr := range addrs { for _, addr := range addrs {

View File

@@ -20,7 +20,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestUDPMux(t *testing.T) { func TestUDPMux(t *testing.T) { //nolint:cyclop
defer test.CheckRoutines(t)() defer test.CheckRoutines(t)()
defer test.TimeOut(time.Second * 30).Stop() defer test.TimeOut(time.Second * 30).Stop()
@@ -127,6 +127,8 @@ func TestUDPMux(t *testing.T) {
} }
func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) { func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) {
t.Helper()
pktConn, err := udpMux.GetConn(ufrag, udpMux.LocalAddr()) pktConn, err := udpMux.GetConn(ufrag, udpMux.LocalAddr())
require.NoError(t, err, "error retrieving muxed connection for ufrag") require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() { defer func() {
@@ -145,6 +147,8 @@ func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, networ
} }
func testMuxConnectionPair(t *testing.T, pktConn net.PacketConn, remoteConn *net.UDPConn, ufrag string) { func testMuxConnectionPair(t *testing.T, pktConn net.PacketConn, remoteConn *net.UDPConn, ufrag string) {
t.Helper()
// Initial messages are dropped // Initial messages are dropped
_, err := remoteConn.Write([]byte("dropped bytes")) _, err := remoteConn.Write([]byte("dropped bytes"))
require.NoError(t, err) require.NoError(t, err)
@@ -222,7 +226,7 @@ func testMuxConnectionPair(t *testing.T, pktConn net.PacketConn, remoteConn *net
require.NoError(t, err) require.NoError(t, err)
h := sha256.Sum256(buf[36:]) h := sha256.Sum256(buf[36:])
copy(buf[4:36], h[:]) copy(buf[4:36], h[:])
binary.LittleEndian.PutUint32(buf[0:4], uint32(sequence)) binary.LittleEndian.PutUint32(buf[0:4], uint32(sequence)) //nolint:gosec // G115
_, err = remoteConn.Write(buf) _, err = remoteConn.Write(buf)
require.NoError(t, err) require.NoError(t, err)
@@ -238,6 +242,8 @@ func testMuxConnectionPair(t *testing.T, pktConn net.PacketConn, remoteConn *net
} }
func verifyPacket(t *testing.T, b []byte, nextSeq uint32) { func verifyPacket(t *testing.T, b []byte, nextSeq uint32) {
t.Helper()
readSeq := binary.LittleEndian.Uint32(b[0:4]) readSeq := binary.LittleEndian.Uint32(b[0:4])
require.Equal(t, nextSeq, readSeq) require.Equal(t, nextSeq, readSeq)
h := sha256.Sum256(b[36:]) h := sha256.Sum256(b[36:])

View File

@@ -29,7 +29,8 @@ type UniversalUDPMuxDefault struct {
*UDPMuxDefault *UDPMuxDefault
params UniversalUDPMuxParams params UniversalUDPMuxParams
// Since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents // Since we have a shared socket, for srflx candidates it makes sense
// to have a shared mapped address across all the agents
// stun.XORMappedAddress indexed by the STUN server addr // stun.XORMappedAddress indexed by the STUN server addr
xorMappedMap map[string]*xorMapped xorMappedMap map[string]*xorMapped
} }
@@ -42,7 +43,7 @@ type UniversalUDPMuxParams struct {
Net transport.Net Net transport.Net
} }
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux.
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault { func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
if params.Logger == nil { if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
@@ -51,31 +52,31 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
params.XORMappedAddrCacheTTL = time.Second * 25 params.XORMappedAddrCacheTTL = time.Second * 25
} }
m := &UniversalUDPMuxDefault{ mux := &UniversalUDPMuxDefault{
params: params, params: params,
xorMappedMap: make(map[string]*xorMapped), xorMappedMap: make(map[string]*xorMapped),
} }
// Wrap UDP connection, process server reflexive messages // Wrap UDP connection, process server reflexive messages
// before they are passed to the UDPMux connection handler (connWorker) // before they are passed to the UDPMux connection handler (connWorker)
m.params.UDPConn = &udpConn{ mux.params.UDPConn = &udpConn{
PacketConn: params.UDPConn, PacketConn: params.UDPConn,
mux: m, mux: mux,
logger: params.Logger, logger: params.Logger,
} }
// Embed UDPMux // Embed UDPMux
udpMuxParams := UDPMuxParams{ udpMuxParams := UDPMuxParams{
Logger: params.Logger, Logger: params.Logger,
UDPConn: m.params.UDPConn, UDPConn: mux.params.UDPConn,
Net: m.params.Net, Net: mux.params.Net,
} }
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams) mux.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
return m return mux
} }
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets // udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets.
type udpConn struct { type udpConn struct {
net.PacketConn net.PacketConn
mux *UniversalUDPMuxDefault mux *UniversalUDPMuxDefault
@@ -88,7 +89,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(net.Addr, time.Duration) (*net.A
return nil, errNotImplemented return nil, errNotImplemented
} }
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers // GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL
// (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server. // and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) { func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr) return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
@@ -99,24 +101,24 @@ func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr ne
func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, addr, err = c.PacketConn.ReadFrom(p) n, addr, err = c.PacketConn.ReadFrom(p)
if err != nil { if err != nil {
return return n, addr, err
} }
if stun.IsMessage(p[:n]) { if stun.IsMessage(p[:n]) { //nolint:nestif
msg := &stun.Message{ msg := &stun.Message{
Raw: append([]byte{}, p[:n]...), Raw: append([]byte{}, p[:n]...),
} }
if err = msg.Decode(); err != nil { if err = msg.Decode(); err != nil {
c.logger.Warnf("Failed to handle decode ICE from %s: %v", addr.String(), err) c.logger.Warnf("Failed to handle decode ICE from %s: %v", addr.String(), err)
err = nil
return return n, addr, nil
} }
udpAddr, ok := addr.(*net.UDPAddr) udpAddr, ok := addr.(*net.UDPAddr)
if !ok { if !ok {
// Message about this err will be logged in the UDPMux // Message about this err will be logged in the UDPMux
return return n, addr, err
} }
if c.mux.isXORMappedResponse(msg, udpAddr.String()) { if c.mux.isXORMappedResponse(msg, udpAddr.String()) {
@@ -125,9 +127,11 @@ func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
c.logger.Debugf("%w: %v", errGetXorMappedAddrResponse, err) c.logger.Debugf("%w: %v", errGetXorMappedAddrResponse, err)
err = nil err = nil
} }
return
return n, addr, err
} }
} }
return n, addr, err return n, addr, err
} }
@@ -135,14 +139,16 @@ func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool { func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
// Check first if it is a STUN server address because remote peer can also send similar messages but as a BindingSuccess // Check first if it is a STUN server address,
// because remote peer can also send similar messages but as a BindingSuccess.
_, ok := m.xorMappedMap[stunAddr] _, ok := m.xorMappedMap[stunAddr]
_, err := msg.Get(stun.AttrXORMappedAddress) _, err := msg.Get(stun.AttrXORMappedAddress)
return err == nil && ok return err == nil && ok
} }
// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute // handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute.
// and set the mapped address for the server // and set the mapped address for the server.
func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error { func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@@ -167,7 +173,10 @@ func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr,
// Makes a STUN binding request to discover mapped address otherwise. // Makes a STUN binding request to discover mapped address otherwise.
// Blocks until the stun.XORMappedAddress has been discovered or deadline. // Blocks until the stun.XORMappedAddress has been discovered or deadline.
// Method is safe for concurrent use. // Method is safe for concurrent use.
func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) { func (m *UniversalUDPMuxDefault) GetXORMappedAddr(
serverAddr net.Addr,
deadline time.Duration,
) (*stun.XORMappedAddress, error) {
m.mu.Lock() m.mu.Lock()
mappedAddr, ok := m.xorMappedMap[serverAddr.String()] mappedAddr, ok := m.xorMappedMap[serverAddr.String()]
// If we already have a mapping for this STUN server (address already received) // If we already have a mapping for this STUN server (address already received)
@@ -203,6 +212,7 @@ func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline
if mappedAddr.addr == nil { if mappedAddr.addr == nil {
return nil, errNoXorAddrMapping return nil, errNoXorAddrMapping
} }
return mappedAddr.addr, nil return mappedAddr.addr, nil
case <-time.After(deadline): case <-time.After(deadline):
return nil, errXORMappedAddrTimeout return nil, errXORMappedAddrTimeout

View File

@@ -43,6 +43,8 @@ func TestUniversalUDPMux(t *testing.T) {
} }
func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) { func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) {
t.Helper()
pktConn, err := udpMux.GetConn(ufrag, udpMux.LocalAddr()) pktConn, err := udpMux.GetConn(ufrag, udpMux.LocalAddr())
require.NoError(t, err, "error retrieving muxed connection for ufrag") require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() { defer func() {

View File

@@ -28,7 +28,7 @@ type udpMuxedConnParams struct {
Logger logging.LeveledLogger Logger logging.LeveledLogger
} }
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag // udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag.
type udpMuxedConn struct { type udpMuxedConn struct {
params *udpMuxedConnParams params *udpMuxedConnParams
// Remote addresses that we have sent to on this conn // Remote addresses that we have sent to on this conn
@@ -72,11 +72,12 @@ func (c *udpMuxedConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) {
pkt.reset() pkt.reset()
c.params.AddrPool.Put(pkt) c.params.AddrPool.Put(pkt)
return return n, rAddr, err
} }
if c.state == udpMuxedConnClosed { if c.state == udpMuxedConnClosed {
c.mu.Unlock() c.mu.Unlock()
return 0, nil, io.EOF return 0, nil, io.EOF
} }
@@ -101,6 +102,7 @@ func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
return 0, errFailedToCastUDPAddr return 0, errFailedToCastUDPAddr
} }
//nolint:gosec // TODO add port validation G115
ipAndPort, err := newIPPort(netUDPAddr.IP, netUDPAddr.Zone, uint16(netUDPAddr.Port)) ipAndPort, err := newIPPort(netUDPAddr.IP, netUDPAddr.Zone, uint16(netUDPAddr.Port))
if err != nil { if err != nil {
return 0, err return 0, err
@@ -150,12 +152,14 @@ func (c *udpMuxedConn) Close() error {
c.state = udpMuxedConnClosed c.state = udpMuxedConnClosed
close(c.closedChan) close(c.closedChan)
} }
return nil return nil
} }
func (c *udpMuxedConn) isClosed() bool { func (c *udpMuxedConn) isClosed() bool {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
return c.state == udpMuxedConnClosed return c.state == udpMuxedConnClosed
} }
@@ -164,6 +168,7 @@ func (c *udpMuxedConn) getAddresses() []ipPort {
defer c.mu.Unlock() defer c.mu.Unlock()
addresses := make([]ipPort, len(c.addresses)) addresses := make([]ipPort, len(c.addresses))
copy(addresses, c.addresses) copy(addresses, c.addresses)
return addresses return addresses
} }
@@ -198,6 +203,7 @@ func (c *udpMuxedConn) containsAddress(addr ipPort) bool {
return true return true
} }
} }
return false return false
} }
@@ -205,6 +211,7 @@ func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error {
pkt := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert pkt := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
if cap(pkt.buf) < len(data) { if cap(pkt.buf) < len(data) {
c.params.AddrPool.Put(pkt) c.params.AddrPool.Put(pkt)
return io.ErrShortBuffer return io.ErrShortBuffer
} }

36
url.go
View File

@@ -6,77 +6,77 @@ package ice
import "github.com/pion/stun/v3" import "github.com/pion/stun/v3"
type ( type (
// URL represents a STUN (rfc7064) or TURN (rfc7065) URI // URL represents a STUN (rfc7064) or TURN (rfc7065) URI.
// //
// Deprecated: Please use pion/stun.URI // Deprecated: Please use pion/stun.URI.
URL = stun.URI URL = stun.URI
// ProtoType indicates the transport protocol type that is used in the ice.URL // ProtoType indicates the transport protocol type that is used in the ice.URL
// structure. // structure.
// //
// Deprecated: TPlease use pion/stun.ProtoType // Deprecated: TPlease use pion/stun.ProtoType.
ProtoType = stun.ProtoType ProtoType = stun.ProtoType
// SchemeType indicates the type of server used in the ice.URL structure. // SchemeType indicates the type of server used in the ice.URL structure.
// //
// Deprecated: Please use pion/stun.SchemeType // Deprecated: Please use pion/stun.SchemeType.
SchemeType = stun.SchemeType SchemeType = stun.SchemeType
) )
const ( const (
// SchemeTypeSTUN indicates the URL represents a STUN server. // SchemeTypeSTUN indicates the URL represents a STUN server.
// //
// Deprecated: Please use pion/stun.SchemeTypeSTUN // Deprecated: Please use pion/stun.SchemeTypeSTUN.
SchemeTypeSTUN = stun.SchemeTypeSTUN SchemeTypeSTUN = stun.SchemeTypeSTUN
// SchemeTypeSTUNS indicates the URL represents a STUNS (secure) server. // SchemeTypeSTUNS indicates the URL represents a STUNS (secure) server.
// //
// Deprecated: Please use pion/stun.SchemeTypeSTUNS // Deprecated: Please use pion/stun.SchemeTypeSTUNS.
SchemeTypeSTUNS = stun.SchemeTypeSTUNS SchemeTypeSTUNS = stun.SchemeTypeSTUNS
// SchemeTypeTURN indicates the URL represents a TURN server. // SchemeTypeTURN indicates the URL represents a TURN server.
// //
// Deprecated: Please use pion/stun.SchemeTypeTURN // Deprecated: Please use pion/stun.SchemeTypeTURN.
SchemeTypeTURN = stun.SchemeTypeTURN SchemeTypeTURN = stun.SchemeTypeTURN
// SchemeTypeTURNS indicates the URL represents a TURNS (secure) server. // SchemeTypeTURNS indicates the URL represents a TURNS (secure) server.
// //
// Deprecated: Please use pion/stun.SchemeTypeTURNS // Deprecated: Please use pion/stun.SchemeTypeTURNS.
SchemeTypeTURNS = stun.SchemeTypeTURNS SchemeTypeTURNS = stun.SchemeTypeTURNS
) )
const ( const (
// ProtoTypeUDP indicates the URL uses a UDP transport. // ProtoTypeUDP indicates the URL uses a UDP transport.
// //
// Deprecated: Please use pion/stun.ProtoTypeUDP // Deprecated: Please use pion/stun.ProtoTypeUDP.
ProtoTypeUDP = stun.ProtoTypeUDP ProtoTypeUDP = stun.ProtoTypeUDP
// ProtoTypeTCP indicates the URL uses a TCP transport. // ProtoTypeTCP indicates the URL uses a TCP transport.
// //
// Deprecated: Please use pion/stun.ProtoTypeTCP // Deprecated: Please use pion/stun.ProtoTypeTCP.
ProtoTypeTCP = stun.ProtoTypeTCP ProtoTypeTCP = stun.ProtoTypeTCP
) )
// Unknown represents and unknown ProtoType or SchemeType // Unknown represents and unknown ProtoType or SchemeType.
// //
// Deprecated: Please use pion/stun.SchemeTypeUnknown or pion/stun.ProtoTypeUnknown // Deprecated: Please use pion/stun.SchemeTypeUnknown or pion/stun.ProtoTypeUnknown.
const Unknown = 0 const Unknown = 0
// ParseURL parses a STUN or TURN urls following the ABNF syntax described in // ParseURL parses a STUN or TURN urls following the ABNF syntax described in.
// https://tools.ietf.org/html/rfc7064 and https://tools.ietf.org/html/rfc7065 // https://tools.ietf.org/html/rfc7064 and https://tools.ietf.org/html/rfc7065
// respectively. // respectively.
// //
// Deprecated: Please use pion/stun.ParseURI // Deprecated: Please use pion/stun.ParseURI.
var ParseURL = stun.ParseURI //nolint:gochecknoglobals var ParseURL = stun.ParseURI //nolint:gochecknoglobals
// NewSchemeType defines a procedure for creating a new SchemeType from a raw // NewSchemeType defines a procedure for creating a new SchemeType from a raw.
// string naming the scheme type. // string naming the scheme type.
// //
// Deprecated: Please use pion/stun.NewSchemeType // Deprecated: Please use pion/stun.NewSchemeType.
var NewSchemeType = stun.NewSchemeType //nolint:gochecknoglobals var NewSchemeType = stun.NewSchemeType //nolint:gochecknoglobals
// NewProtoType defines a procedure for creating a new ProtoType from a raw // NewProtoType defines a procedure for creating a new ProtoType from a raw.
// string naming the transport protocol type. // string naming the transport protocol type.
// //
// Deprecated: Please use pion/stun.NewProtoType // Deprecated: Please use pion/stun.NewProtoType.
var NewProtoType = stun.NewProtoType //nolint:gochecknoglobals var NewProtoType = stun.NewProtoType //nolint:gochecknoglobals

View File

@@ -11,12 +11,14 @@ type UseCandidateAttr struct{}
// AddTo adds USE-CANDIDATE attribute to message. // AddTo adds USE-CANDIDATE attribute to message.
func (UseCandidateAttr) AddTo(m *stun.Message) error { func (UseCandidateAttr) AddTo(m *stun.Message) error {
m.Add(stun.AttrUseCandidate, nil) m.Add(stun.AttrUseCandidate, nil)
return nil return nil
} }
// IsSet returns true if USE-CANDIDATE attribute is set. // IsSet returns true if USE-CANDIDATE attribute is set.
func (UseCandidateAttr) IsSet(m *stun.Message) bool { func (UseCandidateAttr) IsSet(m *stun.Message) bool {
_, err := m.Get(stun.AttrUseCandidate) _, err := m.Get(stun.AttrUseCandidate)
return err == nil return err == nil
} }

View File

@@ -13,6 +13,8 @@ import (
) )
func newHostRemote(t *testing.T) *CandidateHost { func newHostRemote(t *testing.T) *CandidateHost {
t.Helper()
remoteHostConfig := &CandidateHostConfig{ remoteHostConfig := &CandidateHostConfig{
Network: "udp", Network: "udp",
Address: "1.2.3.5", Address: "1.2.3.5",
@@ -21,10 +23,13 @@ func newHostRemote(t *testing.T) *CandidateHost {
} }
hostRemote, err := NewCandidateHost(remoteHostConfig) hostRemote, err := NewCandidateHost(remoteHostConfig)
require.NoError(t, err) require.NoError(t, err)
return hostRemote return hostRemote
} }
func newPrflxRemote(t *testing.T) *CandidatePeerReflexive { func newPrflxRemote(t *testing.T) *CandidatePeerReflexive {
t.Helper()
prflxConfig := &CandidatePeerReflexiveConfig{ prflxConfig := &CandidatePeerReflexiveConfig{
Network: "udp", Network: "udp",
Address: "10.10.10.2", Address: "10.10.10.2",
@@ -35,10 +40,13 @@ func newPrflxRemote(t *testing.T) *CandidatePeerReflexive {
} }
prflxRemote, err := NewCandidatePeerReflexive(prflxConfig) prflxRemote, err := NewCandidatePeerReflexive(prflxConfig)
require.NoError(t, err) require.NoError(t, err)
return prflxRemote return prflxRemote
} }
func newSrflxRemote(t *testing.T) *CandidateServerReflexive { func newSrflxRemote(t *testing.T) *CandidateServerReflexive {
t.Helper()
srflxConfig := &CandidateServerReflexiveConfig{ srflxConfig := &CandidateServerReflexiveConfig{
Network: "udp", Network: "udp",
Address: "10.10.10.2", Address: "10.10.10.2",
@@ -49,10 +57,13 @@ func newSrflxRemote(t *testing.T) *CandidateServerReflexive {
} }
srflxRemote, err := NewCandidateServerReflexive(srflxConfig) srflxRemote, err := NewCandidateServerReflexive(srflxConfig)
require.NoError(t, err) require.NoError(t, err)
return srflxRemote return srflxRemote
} }
func newRelayRemote(t *testing.T) *CandidateRelay { func newRelayRemote(t *testing.T) *CandidateRelay {
t.Helper()
relayConfig := &CandidateRelayConfig{ relayConfig := &CandidateRelayConfig{
Network: "udp", Network: "udp",
Address: "1.2.3.4", Address: "1.2.3.4",
@@ -63,10 +74,13 @@ func newRelayRemote(t *testing.T) *CandidateRelay {
} }
relayRemote, err := NewCandidateRelay(relayConfig) relayRemote, err := NewCandidateRelay(relayConfig)
require.NoError(t, err) require.NoError(t, err)
return relayRemote return relayRemote
} }
func newHostLocal(t *testing.T) *CandidateHost { func newHostLocal(t *testing.T) *CandidateHost {
t.Helper()
localHostConfig := &CandidateHostConfig{ localHostConfig := &CandidateHostConfig{
Network: "udp", Network: "udp",
Address: "192.168.1.1", Address: "192.168.1.1",
@@ -75,5 +89,6 @@ func newHostLocal(t *testing.T) *CandidateHost {
} }
hostLocal, err := NewCandidateHost(localHostConfig) hostLocal, err := NewCandidateHost(localHostConfig)
require.NoError(t, err) require.NoError(t, err)
return hostLocal return hostLocal
} }