mirror of
https://github.com/pion/ice.git
synced 2025-09-26 19:41:11 +08:00
Upgrade golangci-lint, more linters
Introduces new linters, upgrade golangci-lint to version (v1.63.4)
This commit is contained in:
@@ -25,17 +25,32 @@ linters-settings:
|
||||
- ^os.Exit$
|
||||
- ^panic$
|
||||
- ^print(ln)?$
|
||||
varnamelen:
|
||||
max-distance: 12
|
||||
min-name-length: 2
|
||||
ignore-type-assert-ok: true
|
||||
ignore-map-index-ok: true
|
||||
ignore-chan-recv-ok: true
|
||||
ignore-decls:
|
||||
- i int
|
||||
- n int
|
||||
- w io.Writer
|
||||
- r io.Reader
|
||||
- b []byte
|
||||
|
||||
linters:
|
||||
enable:
|
||||
- asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
|
||||
- bidichk # Checks for dangerous unicode character sequences
|
||||
- bodyclose # checks whether HTTP response body is closed successfully
|
||||
- containedctx # containedctx is a linter that detects struct contained context.Context field
|
||||
- contextcheck # check the function whether use a non-inherited context
|
||||
- cyclop # checks function and package cyclomatic complexity
|
||||
- decorder # check declaration order and count of types, constants, variables and functions
|
||||
- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
|
||||
- dupl # Tool for code clone detection
|
||||
- durationcheck # check for two durations multiplied together
|
||||
- err113 # Golang linter to check the errors handling expressions
|
||||
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
|
||||
- errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted.
|
||||
- errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`.
|
||||
@@ -46,18 +61,17 @@ linters:
|
||||
- forcetypeassert # finds forced type assertions
|
||||
- gci # Gci control golang package import order and make it always deterministic.
|
||||
- gochecknoglobals # Checks that no globals are present in Go code
|
||||
- gochecknoinits # Checks that no init functions are present in Go code
|
||||
- gocognit # Computes and checks the cognitive complexity of functions
|
||||
- goconst # Finds repeated strings that could be replaced by a constant
|
||||
- gocritic # The most opinionated Go source code linter
|
||||
- gocyclo # Computes and checks the cyclomatic complexity of functions
|
||||
- godot # Check if comments end in a period
|
||||
- godox # Tool for detection of FIXME, TODO and other comment keywords
|
||||
- err113 # Golang linter to check the errors handling expressions
|
||||
- gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
|
||||
- gofumpt # Gofumpt checks whether code was gofumpt-ed.
|
||||
- goheader # Checks is file header matches to pattern
|
||||
- goimports # Goimports does everything that gofmt does. Additionally it checks unused imports
|
||||
- gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
|
||||
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
|
||||
- goprintffuncname # Checks that printf-like functions are named with `f` at the end
|
||||
- gosec # Inspects source code for security problems
|
||||
- gosimple # Linter for Go source code that specializes in simplifying a code
|
||||
@@ -65,9 +79,15 @@ linters:
|
||||
- grouper # An analyzer to analyze expression groups.
|
||||
- importas # Enforces consistent import aliases
|
||||
- ineffassign # Detects when assignments to existing variables are not used
|
||||
- lll # Reports long lines
|
||||
- maintidx # maintidx measures the maintainability index of each function.
|
||||
- makezero # Finds slice declarations with non-zero initial length
|
||||
- misspell # Finds commonly misspelled English words in comments
|
||||
- nakedret # Finds naked returns in functions greater than a specified function length
|
||||
- nestif # Reports deeply nested if statements
|
||||
- nilerr # Finds the code that returns nil even if it checks that the error is not nil.
|
||||
- nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value.
|
||||
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
|
||||
- noctx # noctx finds sending http request without context.Context
|
||||
- predeclared # find code that shadows one of Go's predeclared identifiers
|
||||
- revive # golint replacement, finds style mistakes
|
||||
@@ -75,28 +95,22 @@ linters:
|
||||
- stylecheck # Stylecheck is a replacement for golint
|
||||
- tagliatelle # Checks the struct tags.
|
||||
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17
|
||||
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
|
||||
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
|
||||
- typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
|
||||
- unconvert # Remove unnecessary type conversions
|
||||
- unparam # Reports unused function parameters
|
||||
- unused # Checks Go code for unused constants, variables, functions and types
|
||||
- varnamelen # checks that the length of a variable's name matches its scope
|
||||
- wastedassign # wastedassign finds wasted assignment statements
|
||||
- whitespace # Tool for detection of leading and trailing whitespace
|
||||
disable:
|
||||
- depguard # Go linter that checks if package imports are in a list of acceptable packages
|
||||
- containedctx # containedctx is a linter that detects struct contained context.Context field
|
||||
- cyclop # checks function and package cyclomatic complexity
|
||||
- funlen # Tool for detection of long functions
|
||||
- gocyclo # Computes and checks the cyclomatic complexity of functions
|
||||
- godot # Check if comments end in a period
|
||||
- gomnd # An analyzer to detect magic numbers.
|
||||
- gochecknoinits # Checks that no init functions are present in Go code
|
||||
- gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
|
||||
- interfacebloat # A linter that checks length of interface.
|
||||
- ireturn # Accept Interfaces, Return Concrete Types
|
||||
- lll # Reports long lines
|
||||
- maintidx # maintidx measures the maintainability index of each function.
|
||||
- makezero # Finds slice declarations with non-zero initial length
|
||||
- nakedret # Finds naked returns in functions greater than a specified function length
|
||||
- nestif # Reports deeply nested if statements
|
||||
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity
|
||||
- mnd # An analyzer to detect magic numbers
|
||||
- nolintlint # Reports ill-formed or insufficient nolint directives
|
||||
- paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test
|
||||
- prealloc # Finds slice declarations that could potentially be preallocated
|
||||
@@ -104,8 +118,7 @@ linters:
|
||||
- rowserrcheck # checks whether Err of rows is checked successfully
|
||||
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
|
||||
- testpackage # linter that makes you use a separate _test package
|
||||
- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
|
||||
- varnamelen # checks that the length of a variable's name matches its scope
|
||||
- tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
|
||||
- wrapcheck # Checks that errors returned from external packages are wrapped
|
||||
- wsl # Whitespace Linter - Forces you to use empty lines!
|
||||
|
||||
|
@@ -21,7 +21,12 @@ type activeTCPConn struct {
|
||||
closed int32
|
||||
}
|
||||
|
||||
func newActiveTCPConn(ctx context.Context, localAddress string, remoteAddress netip.AddrPort, log logging.LeveledLogger) (a *activeTCPConn) {
|
||||
func newActiveTCPConn(
|
||||
ctx context.Context,
|
||||
localAddress string,
|
||||
remoteAddress netip.AddrPort,
|
||||
log logging.LeveledLogger,
|
||||
) (a *activeTCPConn) {
|
||||
a = &activeTCPConn{
|
||||
readBuffer: packetio.NewBuffer(),
|
||||
writeBuffer: packetio.NewBuffer(),
|
||||
@@ -31,7 +36,8 @@ func newActiveTCPConn(ctx context.Context, localAddress string, remoteAddress ne
|
||||
if err != nil {
|
||||
atomic.StoreInt32(&a.closed, 1)
|
||||
log.Infof("Failed to dial TCP address %s: %v", remoteAddress, err)
|
||||
return
|
||||
|
||||
return a
|
||||
}
|
||||
a.localAddr.Store(laddr)
|
||||
|
||||
@@ -46,6 +52,7 @@ func newActiveTCPConn(ctx context.Context, localAddress string, remoteAddress ne
|
||||
conn, err := dialer.DialContext(ctx, "tcp", remoteAddress.String())
|
||||
if err != nil {
|
||||
log.Infof("Failed to dial TCP address %s: %v", remoteAddress, err)
|
||||
|
||||
return
|
||||
}
|
||||
a.remoteAddr.Store(conn.RemoteAddr())
|
||||
@@ -57,11 +64,13 @@ func newActiveTCPConn(ctx context.Context, localAddress string, remoteAddress ne
|
||||
n, err := readStreamingPacket(conn, buff)
|
||||
if err != nil {
|
||||
log.Infof("Failed to read streaming packet: %s", err)
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if _, err := a.readBuffer.Write(buff[:n]); err != nil {
|
||||
log.Infof("Failed to write to buffer: %s", err)
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -73,11 +82,13 @@ func newActiveTCPConn(ctx context.Context, localAddress string, remoteAddress ne
|
||||
n, err := a.writeBuffer.Read(buff)
|
||||
if err != nil {
|
||||
log.Infof("Failed to read from buffer: %s", err)
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if _, err = writeStreamingPacket(conn, buff[:n]); err != nil {
|
||||
log.Infof("Failed to write streaming packet: %s", err)
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -98,6 +109,7 @@ func (a *activeTCPConn) ReadFrom(buff []byte) (n int, srcAddr net.Addr, err erro
|
||||
n, err = a.readBuffer.Read(buff)
|
||||
// RemoteAddr is assuredly set *after* we can read from the buffer
|
||||
srcAddr = a.RemoteAddr()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -113,6 +125,7 @@ func (a *activeTCPConn) Close() error {
|
||||
atomic.StoreInt32(&a.closed, 1)
|
||||
_ = a.readBuffer.Close()
|
||||
_ = a.writeBuffer.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@@ -21,19 +21,25 @@ import (
|
||||
)
|
||||
|
||||
func getLocalIPAddress(t *testing.T, networkType NetworkType) netip.Addr {
|
||||
t.Helper()
|
||||
|
||||
net, err := stdnet.NewNet()
|
||||
require.NoError(t, err)
|
||||
_, localAddrs, err := localInterfaces(net, problematicNetworkInterfaces, nil, []NetworkType{networkType}, false)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, localAddrs)
|
||||
|
||||
return localAddrs[0]
|
||||
}
|
||||
|
||||
func ipv6Available(t *testing.T) bool {
|
||||
t.Helper()
|
||||
|
||||
net, err := stdnet.NewNet()
|
||||
require.NoError(t, err)
|
||||
_, localAddrs, err := localInterfaces(net, problematicNetworkInterfaces, nil, []NetworkType{NetworkTypeTCP6}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
return len(localAddrs) > 0
|
||||
}
|
||||
|
||||
@@ -89,14 +95,14 @@ func TestActiveTCP(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
r := require.New(t)
|
||||
req := require.New(t)
|
||||
|
||||
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||
IP: testCase.listenIPAddress.AsSlice(),
|
||||
Port: listenPort,
|
||||
Zone: testCase.listenIPAddress.Zone(),
|
||||
})
|
||||
r.NoError(err)
|
||||
req.NoError(err)
|
||||
defer func() {
|
||||
_ = listener.Close()
|
||||
}()
|
||||
@@ -113,7 +119,7 @@ func TestActiveTCP(t *testing.T) {
|
||||
_ = tcpMux.Close()
|
||||
}()
|
||||
|
||||
r.NotNil(tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
|
||||
req.NotNil(tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
|
||||
|
||||
hostAcceptanceMinWait := 100 * time.Millisecond
|
||||
cfg := &AgentConfig{
|
||||
@@ -128,8 +134,8 @@ func TestActiveTCP(t *testing.T) {
|
||||
cfg.MulticastDNSMode = MulticastDNSModeQueryAndGather
|
||||
}
|
||||
passiveAgent, err := NewAgent(cfg)
|
||||
r.NoError(err)
|
||||
r.NotNil(passiveAgent)
|
||||
req.NoError(err)
|
||||
req.NotNil(passiveAgent)
|
||||
|
||||
activeAgent, err := NewAgent(&AgentConfig{
|
||||
CandidateTypes: []CandidateType{CandidateTypeHost},
|
||||
@@ -138,44 +144,44 @@ func TestActiveTCP(t *testing.T) {
|
||||
HostAcceptanceMinWait: &hostAcceptanceMinWait,
|
||||
InterfaceFilter: problematicNetworkInterfaces,
|
||||
})
|
||||
r.NoError(err)
|
||||
r.NotNil(activeAgent)
|
||||
req.NoError(err)
|
||||
req.NotNil(activeAgent)
|
||||
|
||||
passiveAgentConn, activeAgenConn := connect(passiveAgent, activeAgent)
|
||||
r.NotNil(passiveAgentConn)
|
||||
r.NotNil(activeAgenConn)
|
||||
req.NotNil(passiveAgentConn)
|
||||
req.NotNil(activeAgenConn)
|
||||
|
||||
defer func() {
|
||||
r.NoError(activeAgenConn.Close())
|
||||
r.NoError(passiveAgentConn.Close())
|
||||
req.NoError(activeAgenConn.Close())
|
||||
req.NoError(passiveAgentConn.Close())
|
||||
}()
|
||||
|
||||
pair := passiveAgent.getSelectedPair()
|
||||
r.NotNil(pair)
|
||||
r.Equal(testCase.selectedPairNetworkType, pair.Local.NetworkType().NetworkShort())
|
||||
req.NotNil(pair)
|
||||
req.Equal(testCase.selectedPairNetworkType, pair.Local.NetworkType().NetworkShort())
|
||||
|
||||
foo := []byte("foo")
|
||||
_, err = passiveAgentConn.Write(foo)
|
||||
r.NoError(err)
|
||||
req.NoError(err)
|
||||
|
||||
buffer := make([]byte, 1024)
|
||||
n, err := activeAgenConn.Read(buffer)
|
||||
r.NoError(err)
|
||||
r.Equal(foo, buffer[:n])
|
||||
req.NoError(err)
|
||||
req.Equal(foo, buffer[:n])
|
||||
|
||||
bar := []byte("bar")
|
||||
_, err = activeAgenConn.Write(bar)
|
||||
r.NoError(err)
|
||||
req.NoError(err)
|
||||
|
||||
n, err = passiveAgentConn.Read(buffer)
|
||||
r.NoError(err)
|
||||
r.Equal(bar, buffer[:n])
|
||||
req.NoError(err)
|
||||
req.Equal(bar, buffer[:n])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Assert that Active TCP connectivity isn't established inside
|
||||
// the main thread of the Agent
|
||||
// Assert that Active TCP connectivity isn't established inside.
|
||||
// the main thread of the Agent.
|
||||
func TestActiveTCP_NonBlocking(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
@@ -219,7 +225,7 @@ func TestActiveTCP_NonBlocking(t *testing.T) {
|
||||
<-isConnected
|
||||
}
|
||||
|
||||
// Assert that we ignore remote TCP candidates when running a UDP Only Agent
|
||||
// Assert that we ignore remote TCP candidates when running a UDP Only Agent.
|
||||
func TestActiveTCP_Respect_NetworkTypes(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
defer test.TimeOut(time.Second * 5).Stop()
|
||||
@@ -271,7 +277,9 @@ func TestActiveTCP_Respect_NetworkTypes(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
invalidCandidate, err := UnmarshalCandidate(fmt.Sprintf("1052353102 1 tcp 1675624447 127.0.0.1 %s typ host tcptype passive", port))
|
||||
invalidCandidate, err := UnmarshalCandidate(
|
||||
fmt.Sprintf("1052353102 1 tcp 1675624447 127.0.0.1 %s typ host tcptype passive", port),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, aAgent.AddRemoteCandidate(invalidCandidate))
|
||||
require.NoError(t, bAgent.AddRemoteCandidate(invalidCandidate))
|
||||
|
18
addr.go
18
addr.go
@@ -16,6 +16,7 @@ func addrWithOptionalZone(addr netip.Addr, zone string) netip.Addr {
|
||||
if addr.Is6() && (addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast()) {
|
||||
return addr.WithZone(zone)
|
||||
}
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
@@ -30,22 +31,25 @@ func parseAddrFromIface(in net.Addr, ifcName string) (netip.Addr, int, NetworkTy
|
||||
// net.IPNet does not have a Zone but we provide it from the interface
|
||||
addr = addrWithOptionalZone(addr, ifcName)
|
||||
}
|
||||
|
||||
return addr, port, nt, nil
|
||||
}
|
||||
|
||||
func parseAddr(in net.Addr) (netip.Addr, int, NetworkType, error) {
|
||||
func parseAddr(in net.Addr) (netip.Addr, int, NetworkType, error) { //nolint:cyclop
|
||||
switch addr := in.(type) {
|
||||
case *net.IPNet:
|
||||
ipAddr, err := ipAddrToNetIP(addr.IP, "")
|
||||
if err != nil {
|
||||
return netip.Addr{}, 0, 0, err
|
||||
}
|
||||
|
||||
return ipAddr, 0, 0, nil
|
||||
case *net.IPAddr:
|
||||
ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone)
|
||||
if err != nil {
|
||||
return netip.Addr{}, 0, 0, err
|
||||
}
|
||||
|
||||
return ipAddr, 0, 0, nil
|
||||
case *net.UDPAddr:
|
||||
ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone)
|
||||
@@ -58,6 +62,7 @@ func parseAddr(in net.Addr) (netip.Addr, int, NetworkType, error) {
|
||||
} else {
|
||||
nt = NetworkTypeUDP6
|
||||
}
|
||||
|
||||
return ipAddr, addr.Port, nt, nil
|
||||
case *net.TCPAddr:
|
||||
ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone)
|
||||
@@ -70,6 +75,7 @@ func parseAddr(in net.Addr) (netip.Addr, int, NetworkType, error) {
|
||||
} else {
|
||||
nt = NetworkTypeTCP6
|
||||
}
|
||||
|
||||
return ipAddr, addr.Port, nt, nil
|
||||
default:
|
||||
return netip.Addr{}, 0, 0, addrParseError{in}
|
||||
@@ -100,6 +106,7 @@ func ipAddrToNetIP(ip []byte, zone string) (netip.Addr, error) {
|
||||
// we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable.
|
||||
netIPAddr = netIPAddr.Unmap()
|
||||
netIPAddr = addrWithOptionalZone(netIPAddr, zone)
|
||||
|
||||
return netIPAddr, nil
|
||||
}
|
||||
|
||||
@@ -134,12 +141,13 @@ func toAddrPort(addr net.Addr) AddrPort {
|
||||
switch addr := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
copy(ap[:16], addr.IP.To16())
|
||||
ap[16] = uint8(addr.Port >> 8)
|
||||
ap[17] = uint8(addr.Port)
|
||||
ap[16] = uint8(addr.Port >> 8) //nolint:gosec // G115 false positive
|
||||
ap[17] = uint8(addr.Port) //nolint:gosec // G115 false positive
|
||||
case *net.TCPAddr:
|
||||
copy(ap[:16], addr.IP.To16())
|
||||
ap[16] = uint8(addr.Port >> 8)
|
||||
ap[17] = uint8(addr.Port)
|
||||
ap[16] = uint8(addr.Port >> 8) //nolint:gosec // G115 false positive
|
||||
ap[17] = uint8(addr.Port) //nolint:gosec // G115 false positive
|
||||
}
|
||||
|
||||
return ap
|
||||
}
|
||||
|
330
agent.go
330
agent.go
@@ -35,7 +35,7 @@ type bindingRequest struct {
|
||||
isUseCandidate bool
|
||||
}
|
||||
|
||||
// Agent represents the ICE agent
|
||||
// Agent represents the ICE agent.
|
||||
type Agent struct {
|
||||
loop *taskloop.Loop
|
||||
|
||||
@@ -149,8 +149,8 @@ type Agent struct {
|
||||
enableUseCandidateCheckPriority bool
|
||||
}
|
||||
|
||||
// NewAgent creates a new Agent
|
||||
func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
|
||||
// NewAgent creates a new Agent.
|
||||
func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit,cyclop
|
||||
var err error
|
||||
if config.PortMax < config.PortMin {
|
||||
return nil, ErrPort
|
||||
@@ -180,7 +180,7 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
|
||||
|
||||
startedCtx, startedFn := context.WithCancel(context.Background())
|
||||
|
||||
a := &Agent{
|
||||
agent := &Agent{
|
||||
tieBreaker: globalMathRandomGenerator.Uint64(),
|
||||
lite: config.Lite,
|
||||
gatheringState: GatheringStateNew,
|
||||
@@ -224,34 +224,46 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
|
||||
|
||||
enableUseCandidateCheckPriority: config.EnableUseCandidateCheckPriority,
|
||||
}
|
||||
a.connectionStateNotifier = &handlerNotifier{connectionStateFunc: a.onConnectionStateChange, done: make(chan struct{})}
|
||||
a.candidateNotifier = &handlerNotifier{candidateFunc: a.onCandidate, done: make(chan struct{})}
|
||||
a.selectedCandidatePairNotifier = &handlerNotifier{candidatePairFunc: a.onSelectedCandidatePairChange, done: make(chan struct{})}
|
||||
agent.connectionStateNotifier = &handlerNotifier{
|
||||
connectionStateFunc: agent.onConnectionStateChange,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
agent.candidateNotifier = &handlerNotifier{candidateFunc: agent.onCandidate, done: make(chan struct{})}
|
||||
agent.selectedCandidatePairNotifier = &handlerNotifier{
|
||||
candidatePairFunc: agent.onSelectedCandidatePairChange,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
if a.net == nil {
|
||||
a.net, err = stdnet.NewNet()
|
||||
if agent.net == nil {
|
||||
agent.net, err = stdnet.NewNet()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create network: %w", err)
|
||||
}
|
||||
} else if _, isVirtual := a.net.(*vnet.Net); isVirtual {
|
||||
a.log.Warn("Virtual network is enabled")
|
||||
if a.mDNSMode != MulticastDNSModeDisabled {
|
||||
a.log.Warn("Virtual network does not support mDNS yet")
|
||||
} else if _, isVirtual := agent.net.(*vnet.Net); isVirtual {
|
||||
agent.log.Warn("Virtual network is enabled")
|
||||
if agent.mDNSMode != MulticastDNSModeDisabled {
|
||||
agent.log.Warn("Virtual network does not support mDNS yet")
|
||||
}
|
||||
}
|
||||
|
||||
localIfcs, _, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, a.networkTypes, a.includeLoopback)
|
||||
localIfcs, _, err := localInterfaces(
|
||||
agent.net,
|
||||
agent.interfaceFilter,
|
||||
agent.ipFilter,
|
||||
agent.networkTypes,
|
||||
agent.includeLoopback,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting local interfaces: %w", err)
|
||||
}
|
||||
|
||||
// Opportunistic mDNS: If we can't open the connection, that's ok: we
|
||||
// can continue without it.
|
||||
if a.mDNSConn, a.mDNSMode, err = createMulticastDNS(
|
||||
a.net,
|
||||
a.networkTypes,
|
||||
if agent.mDNSConn, agent.mDNSMode, err = createMulticastDNS(
|
||||
agent.net,
|
||||
agent.networkTypes,
|
||||
localIfcs,
|
||||
a.includeLoopback,
|
||||
agent.includeLoopback,
|
||||
mDNSMode,
|
||||
mDNSName,
|
||||
log,
|
||||
@@ -259,54 +271,60 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
|
||||
log.Warnf("Failed to initialize mDNS %s: %v", mDNSName, err)
|
||||
}
|
||||
|
||||
config.initWithDefaults(a)
|
||||
config.initWithDefaults(agent)
|
||||
|
||||
// Make sure the buffer doesn't grow indefinitely.
|
||||
// NOTE: We actually won't get anywhere close to this limit.
|
||||
// SRTP will constantly read from the endpoint and drop packets if it's full.
|
||||
a.buf.SetLimitSize(maxBufferSize)
|
||||
agent.buf.SetLimitSize(maxBufferSize)
|
||||
|
||||
if agent.lite && (len(agent.candidateTypes) != 1 || agent.candidateTypes[0] != CandidateTypeHost) {
|
||||
agent.closeMulticastConn()
|
||||
|
||||
if a.lite && (len(a.candidateTypes) != 1 || a.candidateTypes[0] != CandidateTypeHost) {
|
||||
a.closeMulticastConn()
|
||||
return nil, ErrLiteUsingNonHostCandidates
|
||||
}
|
||||
|
||||
if len(config.Urls) > 0 && !containsCandidateType(CandidateTypeServerReflexive, a.candidateTypes) && !containsCandidateType(CandidateTypeRelay, a.candidateTypes) {
|
||||
a.closeMulticastConn()
|
||||
if len(config.Urls) > 0 &&
|
||||
!containsCandidateType(CandidateTypeServerReflexive, agent.candidateTypes) &&
|
||||
!containsCandidateType(CandidateTypeRelay, agent.candidateTypes) {
|
||||
agent.closeMulticastConn()
|
||||
|
||||
return nil, ErrUselessUrlsProvided
|
||||
}
|
||||
|
||||
if err = config.initExtIPMapping(a); err != nil {
|
||||
a.closeMulticastConn()
|
||||
if err = config.initExtIPMapping(agent); err != nil {
|
||||
agent.closeMulticastConn()
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a.loop = taskloop.New(func() {
|
||||
a.removeUfragFromMux()
|
||||
a.deleteAllCandidates()
|
||||
a.startedFn()
|
||||
agent.loop = taskloop.New(func() {
|
||||
agent.removeUfragFromMux()
|
||||
agent.deleteAllCandidates()
|
||||
agent.startedFn()
|
||||
|
||||
if err := a.buf.Close(); err != nil {
|
||||
a.log.Warnf("Failed to close buffer: %v", err)
|
||||
if err := agent.buf.Close(); err != nil {
|
||||
agent.log.Warnf("Failed to close buffer: %v", err)
|
||||
}
|
||||
|
||||
a.closeMulticastConn()
|
||||
a.updateConnectionState(ConnectionStateClosed)
|
||||
agent.closeMulticastConn()
|
||||
agent.updateConnectionState(ConnectionStateClosed)
|
||||
|
||||
a.gatherCandidateCancel()
|
||||
if a.gatherCandidateDone != nil {
|
||||
<-a.gatherCandidateDone
|
||||
agent.gatherCandidateCancel()
|
||||
if agent.gatherCandidateDone != nil {
|
||||
<-agent.gatherCandidateDone
|
||||
}
|
||||
})
|
||||
|
||||
// Restart is also used to initialize the agent for the first time
|
||||
if err := a.Restart(config.LocalUfrag, config.LocalPwd); err != nil {
|
||||
a.closeMulticastConn()
|
||||
_ = a.Close()
|
||||
if err := agent.Restart(config.LocalUfrag, config.LocalPwd); err != nil {
|
||||
agent.closeMulticastConn()
|
||||
_ = agent.Close()
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return a, nil
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remotePwd string) error {
|
||||
@@ -348,7 +366,7 @@ func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remoteP
|
||||
})
|
||||
}
|
||||
|
||||
func (a *Agent) connectivityChecks() {
|
||||
func (a *Agent) connectivityChecks() { //nolint:cyclop
|
||||
lastConnectionState := ConnectionState(0)
|
||||
checkingDuration := time.Time{}
|
||||
|
||||
@@ -372,6 +390,7 @@ func (a *Agent) connectivityChecks() {
|
||||
// We have been in checking longer then Disconnect+Failed timeout, set the connection to Failed
|
||||
if time.Since(checkingDuration) > a.disconnectedTimeout+a.failedTimeout {
|
||||
a.updateConnectionState(ConnectionStateFailed)
|
||||
|
||||
return
|
||||
}
|
||||
default:
|
||||
@@ -383,8 +402,8 @@ func (a *Agent) connectivityChecks() {
|
||||
}
|
||||
}
|
||||
|
||||
t := time.NewTimer(math.MaxInt64)
|
||||
t.Stop()
|
||||
timer := time.NewTimer(math.MaxInt64)
|
||||
timer.Stop()
|
||||
|
||||
for {
|
||||
interval := defaultKeepaliveInterval
|
||||
@@ -406,18 +425,19 @@ func (a *Agent) connectivityChecks() {
|
||||
updateInterval(a.disconnectedTimeout)
|
||||
updateInterval(a.failedTimeout)
|
||||
|
||||
t.Reset(interval)
|
||||
timer.Reset(interval)
|
||||
|
||||
select {
|
||||
case <-a.forceCandidateContact:
|
||||
if !t.Stop() {
|
||||
<-t.C
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
contact()
|
||||
case <-t.C:
|
||||
case <-timer.C:
|
||||
contact()
|
||||
case <-a.loop.Done():
|
||||
t.Stop()
|
||||
timer.Stop()
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -440,22 +460,23 @@ func (a *Agent) updateConnectionState(newState ConnectionState) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) setSelectedPair(p *CandidatePair) {
|
||||
if p == nil {
|
||||
func (a *Agent) setSelectedPair(pair *CandidatePair) {
|
||||
if pair == nil {
|
||||
var nilPair *CandidatePair
|
||||
a.selectedPair.Store(nilPair)
|
||||
a.log.Tracef("Unset selected candidate pair")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
p.nominated = true
|
||||
a.selectedPair.Store(p)
|
||||
a.log.Tracef("Set selected candidate pair: %s", p)
|
||||
pair.nominated = true
|
||||
a.selectedPair.Store(pair)
|
||||
a.log.Tracef("Set selected candidate pair: %s", pair)
|
||||
|
||||
a.updateConnectionState(ConnectionStateConnected)
|
||||
|
||||
// Notify when the selected pair changes
|
||||
a.selectedCandidatePairNotifier.EnqueueSelectedCandidatePair(p)
|
||||
a.selectedCandidatePairNotifier.EnqueueSelectedCandidatePair(pair)
|
||||
|
||||
// Signal connected
|
||||
a.onConnectedOnce.Do(func() { close(a.onConnected) })
|
||||
@@ -498,6 +519,7 @@ func (a *Agent) getBestAvailableCandidatePair() *CandidatePair {
|
||||
best = p
|
||||
}
|
||||
}
|
||||
|
||||
return best
|
||||
}
|
||||
|
||||
@@ -514,12 +536,14 @@ func (a *Agent) getBestValidCandidatePair() *CandidatePair {
|
||||
best = p
|
||||
}
|
||||
}
|
||||
|
||||
return best
|
||||
}
|
||||
|
||||
func (a *Agent) addPair(local, remote Candidate) *CandidatePair {
|
||||
p := newCandidatePair(local, remote, a.isControlling)
|
||||
a.checklist = append(a.checklist, p)
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
@@ -529,6 +553,7 @@ func (a *Agent) findPair(local, remote Candidate) *CandidatePair {
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -578,68 +603,76 @@ func (a *Agent) checkKeepalive() {
|
||||
}
|
||||
}
|
||||
|
||||
// AddRemoteCandidate adds a new remote candidate
|
||||
func (a *Agent) AddRemoteCandidate(c Candidate) error {
|
||||
if c == nil {
|
||||
// AddRemoteCandidate adds a new remote candidate.
|
||||
func (a *Agent) AddRemoteCandidate(cand Candidate) error {
|
||||
if cand == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TCP Candidates with TCP type active will probe server passive ones, so
|
||||
// no need to do anything with them.
|
||||
if c.TCPType() == TCPTypeActive {
|
||||
a.log.Infof("Ignoring remote candidate with tcpType active: %s", c)
|
||||
if cand.TCPType() == TCPTypeActive {
|
||||
a.log.Infof("Ignoring remote candidate with tcpType active: %s", cand)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we have a mDNS Candidate lets fully resolve it before adding it locally
|
||||
if c.Type() == CandidateTypeHost && strings.HasSuffix(c.Address(), ".local") {
|
||||
if cand.Type() == CandidateTypeHost && strings.HasSuffix(cand.Address(), ".local") {
|
||||
if a.mDNSMode == MulticastDNSModeDisabled {
|
||||
a.log.Warnf("Remote mDNS candidate added, but mDNS is disabled: (%s)", c.Address())
|
||||
a.log.Warnf("Remote mDNS candidate added, but mDNS is disabled: (%s)", cand.Address())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
hostCandidate, ok := c.(*CandidateHost)
|
||||
hostCandidate, ok := cand.(*CandidateHost)
|
||||
if !ok {
|
||||
return ErrAddressParseFailed
|
||||
}
|
||||
|
||||
go a.resolveAndAddMulticastCandidate(hostCandidate)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := a.loop.Run(a.loop, func(_ context.Context) {
|
||||
// nolint: contextcheck
|
||||
a.addRemoteCandidate(c)
|
||||
a.addRemoteCandidate(cand)
|
||||
}); err != nil {
|
||||
a.log.Warnf("Failed to add remote candidate %s: %v", c.Address(), err)
|
||||
a.log.Warnf("Failed to add remote candidate %s: %v", cand.Address(), err)
|
||||
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Agent) resolveAndAddMulticastCandidate(c *CandidateHost) {
|
||||
func (a *Agent) resolveAndAddMulticastCandidate(cand *CandidateHost) {
|
||||
if a.mDNSConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, src, err := a.mDNSConn.QueryAddr(c.context(), c.Address())
|
||||
_, src, err := a.mDNSConn.QueryAddr(cand.context(), cand.Address())
|
||||
if err != nil {
|
||||
a.log.Warnf("Failed to discover mDNS candidate %s: %v", c.Address(), err)
|
||||
a.log.Warnf("Failed to discover mDNS candidate %s: %v", cand.Address(), err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err = c.setIPAddr(src); err != nil {
|
||||
a.log.Warnf("Failed to discover mDNS candidate %s: %v", c.Address(), err)
|
||||
if err = cand.setIPAddr(src); err != nil {
|
||||
a.log.Warnf("Failed to discover mDNS candidate %s: %v", cand.Address(), err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err = a.loop.Run(a.loop, func(_ context.Context) {
|
||||
// nolint: contextcheck
|
||||
a.addRemoteCandidate(c)
|
||||
a.addRemoteCandidate(cand)
|
||||
}); err != nil {
|
||||
a.log.Warnf("Failed to add mDNS candidate %s: %v", c.Address(), err)
|
||||
a.log.Warnf("Failed to add mDNS candidate %s: %v", cand.Address(), err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -652,9 +685,16 @@ func (a *Agent) requestConnectivityCheck() {
|
||||
}
|
||||
|
||||
func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) {
|
||||
_, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{remoteCandidate.NetworkType()}, a.includeLoopback)
|
||||
_, localIPs, err := localInterfaces(
|
||||
a.net,
|
||||
a.interfaceFilter,
|
||||
a.ipFilter,
|
||||
[]NetworkType{remoteCandidate.NetworkType()},
|
||||
a.includeLoopback,
|
||||
)
|
||||
if err != nil {
|
||||
a.log.Warnf("Failed to iterate local interfaces, host candidates will not be gathered %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -662,19 +702,21 @@ func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) {
|
||||
ip, _, _, err := parseAddr(remoteCandidate.addr())
|
||||
if err != nil {
|
||||
a.log.Warnf("Failed to parse address: %s; error: %s", remoteCandidate.addr(), err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
conn := newActiveTCPConn(
|
||||
a.loop,
|
||||
net.JoinHostPort(localIPs[i].String(), "0"),
|
||||
netip.AddrPortFrom(ip, uint16(remoteCandidate.Port())),
|
||||
netip.AddrPortFrom(ip, uint16(remoteCandidate.Port())), //nolint:gosec // G115, no overflow, a port
|
||||
a.log,
|
||||
)
|
||||
|
||||
tcpAddr, ok := conn.LocalAddr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
closeConnAndLog(conn, a.log, "Failed to create Active ICE-TCP Candidate: %v", errInvalidAddress)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -687,48 +729,52 @@ func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) {
|
||||
})
|
||||
if err != nil {
|
||||
closeConnAndLog(conn, a.log, "Failed to create Active ICE-TCP Candidate: %v", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
localCandidate.start(a, conn, a.startedCh)
|
||||
a.localCandidates[localCandidate.NetworkType()] = append(a.localCandidates[localCandidate.NetworkType()], localCandidate)
|
||||
a.localCandidates[localCandidate.NetworkType()] = append(
|
||||
a.localCandidates[localCandidate.NetworkType()],
|
||||
localCandidate,
|
||||
)
|
||||
a.candidateNotifier.EnqueueCandidate(localCandidate)
|
||||
|
||||
a.addPair(localCandidate, remoteCandidate)
|
||||
}
|
||||
}
|
||||
|
||||
// addRemoteCandidate assumes you are holding the lock (must be execute using a.run)
|
||||
func (a *Agent) addRemoteCandidate(c Candidate) {
|
||||
set := a.remoteCandidates[c.NetworkType()]
|
||||
// addRemoteCandidate assumes you are holding the lock (must be execute using a.run).
|
||||
func (a *Agent) addRemoteCandidate(cand Candidate) { //nolint:cyclop
|
||||
set := a.remoteCandidates[cand.NetworkType()]
|
||||
|
||||
for _, candidate := range set {
|
||||
if candidate.Equal(c) {
|
||||
if candidate.Equal(cand) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
acceptRemotePassiveTCPCandidate := false
|
||||
// Assert that TCP4 or TCP6 is a enabled NetworkType locally
|
||||
if !a.disableActiveTCP && c.TCPType() == TCPTypePassive {
|
||||
if !a.disableActiveTCP && cand.TCPType() == TCPTypePassive {
|
||||
for _, networkType := range a.networkTypes {
|
||||
if c.NetworkType() == networkType {
|
||||
if cand.NetworkType() == networkType {
|
||||
acceptRemotePassiveTCPCandidate = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if acceptRemotePassiveTCPCandidate {
|
||||
a.addRemotePassiveTCPCandidate(c)
|
||||
a.addRemotePassiveTCPCandidate(cand)
|
||||
}
|
||||
|
||||
set = append(set, c)
|
||||
a.remoteCandidates[c.NetworkType()] = set
|
||||
set = append(set, cand)
|
||||
a.remoteCandidates[cand.NetworkType()] = set
|
||||
|
||||
if c.TCPType() != TCPTypePassive {
|
||||
if localCandidates, ok := a.localCandidates[c.NetworkType()]; ok {
|
||||
if cand.TCPType() != TCPTypePassive {
|
||||
if localCandidates, ok := a.localCandidates[cand.NetworkType()]; ok {
|
||||
for _, localCandidate := range localCandidates {
|
||||
a.addPair(localCandidate, c)
|
||||
a.addPair(localCandidate, cand)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -736,42 +782,43 @@ func (a *Agent) addRemoteCandidate(c Candidate) {
|
||||
a.requestConnectivityCheck()
|
||||
}
|
||||
|
||||
func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net.PacketConn) error {
|
||||
func (a *Agent) addCandidate(ctx context.Context, cand Candidate, candidateConn net.PacketConn) error {
|
||||
return a.loop.Run(ctx, func(context.Context) {
|
||||
set := a.localCandidates[c.NetworkType()]
|
||||
set := a.localCandidates[cand.NetworkType()]
|
||||
for _, candidate := range set {
|
||||
if candidate.Equal(c) {
|
||||
a.log.Debugf("Ignore duplicate candidate: %s", c)
|
||||
if err := c.close(); err != nil {
|
||||
if candidate.Equal(cand) {
|
||||
a.log.Debugf("Ignore duplicate candidate: %s", cand)
|
||||
if err := cand.close(); err != nil {
|
||||
a.log.Warnf("Failed to close duplicate candidate: %v", err)
|
||||
}
|
||||
if err := candidateConn.Close(); err != nil {
|
||||
a.log.Warnf("Failed to close duplicate candidate connection: %v", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.start(a, candidateConn, a.startedCh)
|
||||
cand.start(a, candidateConn, a.startedCh)
|
||||
|
||||
set = append(set, c)
|
||||
a.localCandidates[c.NetworkType()] = set
|
||||
set = append(set, cand)
|
||||
a.localCandidates[cand.NetworkType()] = set
|
||||
|
||||
if remoteCandidates, ok := a.remoteCandidates[c.NetworkType()]; ok {
|
||||
if remoteCandidates, ok := a.remoteCandidates[cand.NetworkType()]; ok {
|
||||
for _, remoteCandidate := range remoteCandidates {
|
||||
a.addPair(c, remoteCandidate)
|
||||
a.addPair(cand, remoteCandidate)
|
||||
}
|
||||
}
|
||||
|
||||
a.requestConnectivityCheck()
|
||||
|
||||
if !c.filterForLocationTracking() {
|
||||
a.candidateNotifier.EnqueueCandidate(c)
|
||||
if !cand.filterForLocationTracking() {
|
||||
a.candidateNotifier.EnqueueCandidate(cand)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// GetRemoteCandidates returns the remote candidates
|
||||
// GetRemoteCandidates returns the remote candidates.
|
||||
func (a *Agent) GetRemoteCandidates() ([]Candidate, error) {
|
||||
var res []Candidate
|
||||
|
||||
@@ -789,7 +836,7 @@ func (a *Agent) GetRemoteCandidates() ([]Candidate, error) {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// GetLocalCandidates returns the local candidates
|
||||
// GetLocalCandidates returns the local candidates.
|
||||
func (a *Agent) GetLocalCandidates() ([]Candidate, error) {
|
||||
var res []Candidate
|
||||
|
||||
@@ -812,7 +859,7 @@ func (a *Agent) GetLocalCandidates() ([]Candidate, error) {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// GetLocalUserCredentials returns the local user credentials
|
||||
// GetLocalUserCredentials returns the local user credentials.
|
||||
func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) {
|
||||
valSet := make(chan struct{})
|
||||
err = a.loop.Run(a.loop, func(_ context.Context) {
|
||||
@@ -824,10 +871,11 @@ func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) {
|
||||
if err == nil {
|
||||
<-valSet
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetRemoteUserCredentials returns the remote user credentials
|
||||
// GetRemoteUserCredentials returns the remote user credentials.
|
||||
func (a *Agent) GetRemoteUserCredentials() (frag string, pwd string, err error) {
|
||||
valSet := make(chan struct{})
|
||||
err = a.loop.Run(a.loop, func(_ context.Context) {
|
||||
@@ -839,6 +887,7 @@ func (a *Agent) GetRemoteUserCredentials() (frag string, pwd string, err error)
|
||||
if err == nil {
|
||||
<-valSet
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -854,7 +903,7 @@ func (a *Agent) removeUfragFromMux() {
|
||||
}
|
||||
}
|
||||
|
||||
// Close cleans up the Agent
|
||||
// Close cleans up the Agent.
|
||||
func (a *Agent) Close() error {
|
||||
return a.close(false)
|
||||
}
|
||||
@@ -875,13 +924,14 @@ func (a *Agent) close(graceful bool) error {
|
||||
a.connectionStateNotifier.Close(graceful)
|
||||
a.candidateNotifier.Close(graceful)
|
||||
a.selectedCandidatePairNotifier.Close(graceful)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove all candidates. This closes any listening sockets
|
||||
// and removes both the local and remote candidate lists.
|
||||
//
|
||||
// This is used for restarts, failures and on close
|
||||
// This is used for restarts, failures and on close.
|
||||
func (a *Agent) deleteAllCandidates() {
|
||||
for net, cs := range a.localCandidates {
|
||||
for _, c := range cs {
|
||||
@@ -905,6 +955,7 @@ func (a *Agent) findRemoteCandidate(networkType NetworkType, addr net.Addr) Cand
|
||||
ip, port, _, err := parseAddr(addr)
|
||||
if err != nil {
|
||||
a.log.Warnf("Failed to parse address: %s; error: %s", addr, err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -914,6 +965,7 @@ func (a *Agent) findRemoteCandidate(networkType NetworkType, addr net.Addr) Cand
|
||||
return c
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -937,6 +989,7 @@ func (a *Agent) sendBindingSuccess(m *stun.Message, local, remote Candidate) {
|
||||
ip, port, _, err := parseAddr(base.addr())
|
||||
if err != nil {
|
||||
a.log.Warnf("Failed to parse address: %s; error: %s", base.addr(), err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -976,70 +1029,85 @@ func (a *Agent) invalidatePendingBindingRequests(filterTime time.Time) {
|
||||
}
|
||||
|
||||
// Assert that the passed TransactionID is in our pendingBindingRequests and returns the destination
|
||||
// If the bindingRequest was valid remove it from our pending cache
|
||||
// If the bindingRequest was valid remove it from our pending cache.
|
||||
func (a *Agent) handleInboundBindingSuccess(id [stun.TransactionIDSize]byte) (bool, *bindingRequest, time.Duration) {
|
||||
a.invalidatePendingBindingRequests(time.Now())
|
||||
for i := range a.pendingBindingRequests {
|
||||
if a.pendingBindingRequests[i].transactionID == id {
|
||||
validBindingRequest := a.pendingBindingRequests[i]
|
||||
a.pendingBindingRequests = append(a.pendingBindingRequests[:i], a.pendingBindingRequests[i+1:]...)
|
||||
|
||||
return true, &validBindingRequest, time.Since(validBindingRequest.timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil, 0
|
||||
}
|
||||
|
||||
// handleInbound processes STUN traffic from a remote candidate
|
||||
func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr) { //nolint:gocognit
|
||||
// handleInbound processes STUN traffic from a remote candidate.
|
||||
func (a *Agent) handleInbound(msg *stun.Message, local Candidate, remote net.Addr) { //nolint:gocognit,cyclop
|
||||
var err error
|
||||
if m == nil || local == nil {
|
||||
if msg == nil || local == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if m.Type.Method != stun.MethodBinding ||
|
||||
!(m.Type.Class == stun.ClassSuccessResponse ||
|
||||
m.Type.Class == stun.ClassRequest ||
|
||||
m.Type.Class == stun.ClassIndication) {
|
||||
a.log.Tracef("Unhandled STUN from %s to %s class(%s) method(%s)", remote, local, m.Type.Class, m.Type.Method)
|
||||
if msg.Type.Method != stun.MethodBinding ||
|
||||
!(msg.Type.Class == stun.ClassSuccessResponse ||
|
||||
msg.Type.Class == stun.ClassRequest ||
|
||||
msg.Type.Class == stun.ClassIndication) {
|
||||
a.log.Tracef("Unhandled STUN from %s to %s class(%s) method(%s)", remote, local, msg.Type.Class, msg.Type.Method)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if a.isControlling {
|
||||
if m.Contains(stun.AttrICEControlling) {
|
||||
if msg.Contains(stun.AttrICEControlling) {
|
||||
a.log.Debug("Inbound STUN message: isControlling && a.isControlling == true")
|
||||
|
||||
return
|
||||
} else if m.Contains(stun.AttrUseCandidate) {
|
||||
} else if msg.Contains(stun.AttrUseCandidate) {
|
||||
a.log.Debug("Inbound STUN message: useCandidate && a.isControlling == true")
|
||||
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if m.Contains(stun.AttrICEControlled) {
|
||||
if msg.Contains(stun.AttrICEControlled) {
|
||||
a.log.Debug("Inbound STUN message: isControlled && a.isControlling == false")
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
remoteCandidate := a.findRemoteCandidate(local.NetworkType(), remote)
|
||||
if m.Type.Class == stun.ClassSuccessResponse {
|
||||
if err = stun.MessageIntegrity([]byte(a.remotePwd)).Check(m); err != nil {
|
||||
if msg.Type.Class == stun.ClassSuccessResponse { //nolint:nestif
|
||||
if err = stun.MessageIntegrity([]byte(a.remotePwd)).Check(msg); err != nil {
|
||||
a.log.Warnf("Discard message from (%s), %v", remote, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if remoteCandidate == nil {
|
||||
a.log.Warnf("Discard success message from (%s), no such remote", remote)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
a.selector.HandleSuccessResponse(m, local, remoteCandidate, remote)
|
||||
} else if m.Type.Class == stun.ClassRequest {
|
||||
a.log.Tracef("Inbound STUN (Request) from %s to %s, useCandidate: %v", remote, local, m.Contains(stun.AttrUseCandidate))
|
||||
a.selector.HandleSuccessResponse(msg, local, remoteCandidate, remote)
|
||||
} else if msg.Type.Class == stun.ClassRequest {
|
||||
a.log.Tracef(
|
||||
"Inbound STUN (Request) from %s to %s, useCandidate: %v",
|
||||
remote,
|
||||
local,
|
||||
msg.Contains(stun.AttrUseCandidate),
|
||||
)
|
||||
|
||||
if err = stunx.AssertUsername(m, a.localUfrag+":"+a.remoteUfrag); err != nil {
|
||||
if err = stunx.AssertUsername(msg, a.localUfrag+":"+a.remoteUfrag); err != nil {
|
||||
a.log.Warnf("Discard message from (%s), %v", remote, err)
|
||||
|
||||
return
|
||||
} else if err = stun.MessageIntegrity([]byte(a.localPwd)).Check(m); err != nil {
|
||||
} else if err = stun.MessageIntegrity([]byte(a.localPwd)).Check(msg); err != nil {
|
||||
a.log.Warnf("Discard message from (%s), %v", remote, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1047,6 +1115,7 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
|
||||
ip, port, networkType, err := parseAddr(remote)
|
||||
if err != nil {
|
||||
a.log.Errorf("Failed to create parse remote net.Addr when creating remote prflx candidate: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1062,6 +1131,7 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
|
||||
prflxCandidate, err := NewCandidatePeerReflexive(&prflxCandidateConfig)
|
||||
if err != nil {
|
||||
a.log.Errorf("Failed to create new remote prflx candidate (%s)", err)
|
||||
|
||||
return
|
||||
}
|
||||
remoteCandidate = prflxCandidate
|
||||
@@ -1070,7 +1140,7 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
|
||||
a.addRemoteCandidate(remoteCandidate)
|
||||
}
|
||||
|
||||
a.selector.HandleBindingRequest(m, local, remoteCandidate)
|
||||
a.selector.HandleBindingRequest(msg, local, remoteCandidate)
|
||||
}
|
||||
|
||||
if remoteCandidate != nil {
|
||||
@@ -1079,7 +1149,7 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
|
||||
}
|
||||
|
||||
// validateNonSTUNTraffic processes non STUN traffic from a remote candidate,
|
||||
// and returns true if it is an actual remote candidate
|
||||
// and returns true if it is an actual remote candidate.
|
||||
func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) (Candidate, bool) {
|
||||
var remoteCandidate Candidate
|
||||
if err := a.loop.Run(local.context(), func(context.Context) {
|
||||
@@ -1094,7 +1164,7 @@ func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) (Candid
|
||||
return remoteCandidate, remoteCandidate != nil
|
||||
}
|
||||
|
||||
// GetSelectedCandidatePair returns the selected pair or nil if there is none
|
||||
// GetSelectedCandidatePair returns the selected pair or nil if there is none.
|
||||
func (a *Agent) GetSelectedCandidatePair() (*CandidatePair, error) {
|
||||
selectedPair := a.getSelectedPair()
|
||||
if selectedPair == nil {
|
||||
@@ -1130,7 +1200,7 @@ func (a *Agent) closeMulticastConn() {
|
||||
}
|
||||
}
|
||||
|
||||
// SetRemoteCredentials sets the credentials of the remote agent
|
||||
// SetRemoteCredentials sets the credentials of the remote agent.
|
||||
func (a *Agent) SetRemoteCredentials(remoteUfrag, remotePwd string) error {
|
||||
switch {
|
||||
case remoteUfrag == "":
|
||||
@@ -1152,7 +1222,7 @@ func (a *Agent) SetRemoteCredentials(remoteUfrag, remotePwd string) error {
|
||||
// cancel it.
|
||||
// After a Restart, the user must then call GatherCandidates explicitly
|
||||
// to start generating new ones.
|
||||
func (a *Agent) Restart(ufrag, pwd string) error {
|
||||
func (a *Agent) Restart(ufrag, pwd string) error { //nolint:cyclop
|
||||
if ufrag == "" {
|
||||
var err error
|
||||
ufrag, err = generateUFrag()
|
||||
@@ -1204,6 +1274,7 @@ func (a *Agent) Restart(ufrag, pwd string) error {
|
||||
}); runErr != nil {
|
||||
return runErr
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1221,6 +1292,7 @@ func (a *Agent) setGatheringState(newState GatheringState) error {
|
||||
}
|
||||
|
||||
<-done
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@@ -14,44 +14,44 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultCheckInterval is the interval at which the agent performs candidate checks in the connecting phase
|
||||
// defaultCheckInterval is the interval at which the agent performs candidate checks in the connecting phase.
|
||||
defaultCheckInterval = 200 * time.Millisecond
|
||||
|
||||
// keepaliveInterval used to keep candidates alive
|
||||
// keepaliveInterval used to keep candidates alive.
|
||||
defaultKeepaliveInterval = 2 * time.Second
|
||||
|
||||
// defaultDisconnectedTimeout is the default time till an Agent transitions disconnected
|
||||
// defaultDisconnectedTimeout is the default time till an Agent transitions disconnected.
|
||||
defaultDisconnectedTimeout = 5 * time.Second
|
||||
|
||||
// defaultFailedTimeout is the default time till an Agent transitions to failed after disconnected
|
||||
// defaultFailedTimeout is the default time till an Agent transitions to failed after disconnected.
|
||||
defaultFailedTimeout = 25 * time.Second
|
||||
|
||||
// defaultHostAcceptanceMinWait is the wait time before nominating a host candidate
|
||||
// defaultHostAcceptanceMinWait is the wait time before nominating a host candidate.
|
||||
defaultHostAcceptanceMinWait = 0
|
||||
|
||||
// defaultSrflxAcceptanceMinWait is the wait time before nominating a srflx candidate
|
||||
// defaultSrflxAcceptanceMinWait is the wait time before nominating a srflx candidate.
|
||||
defaultSrflxAcceptanceMinWait = 500 * time.Millisecond
|
||||
|
||||
// defaultPrflxAcceptanceMinWait is the wait time before nominating a prflx candidate
|
||||
// defaultPrflxAcceptanceMinWait is the wait time before nominating a prflx candidate.
|
||||
defaultPrflxAcceptanceMinWait = 1000 * time.Millisecond
|
||||
|
||||
// defaultRelayAcceptanceMinWait is the wait time before nominating a relay candidate
|
||||
// defaultRelayAcceptanceMinWait is the wait time before nominating a relay candidate.
|
||||
defaultRelayAcceptanceMinWait = 2000 * time.Millisecond
|
||||
|
||||
// defaultSTUNGatherTimeout is the wait time for STUN responses
|
||||
// defaultSTUNGatherTimeout is the wait time for STUN responses.
|
||||
defaultSTUNGatherTimeout = 5 * time.Second
|
||||
|
||||
// defaultMaxBindingRequests is the maximum number of binding requests before considering a pair failed
|
||||
// defaultMaxBindingRequests is the maximum number of binding requests before considering a pair failed.
|
||||
defaultMaxBindingRequests = 7
|
||||
|
||||
// TCPPriorityOffset is a number which is subtracted from the default (UDP) candidate type preference
|
||||
// for host, srflx and prfx candidate types.
|
||||
defaultTCPPriorityOffset = 27
|
||||
|
||||
// maxBufferSize is the number of bytes that can be buffered before we start to error
|
||||
// maxBufferSize is the number of bytes that can be buffered before we start to error.
|
||||
maxBufferSize = 1000 * 1000 // 1MB
|
||||
|
||||
// maxBindingRequestTimeout is the wait time before binding requests can be deleted
|
||||
// maxBindingRequestTimeout is the wait time before binding requests can be deleted.
|
||||
maxBindingRequestTimeout = 4000 * time.Millisecond
|
||||
)
|
||||
|
||||
@@ -60,7 +60,7 @@ func defaultCandidateTypes() []CandidateType {
|
||||
}
|
||||
|
||||
// AgentConfig collects the arguments to ice.Agent construction into
|
||||
// a single structure, for future-proofness of the interface
|
||||
// a single structure, for future-proofness of the interface.
|
||||
type AgentConfig struct {
|
||||
Urls []*stun.URI
|
||||
|
||||
@@ -209,109 +209,111 @@ type AgentConfig struct {
|
||||
EnableUseCandidateCheckPriority bool
|
||||
}
|
||||
|
||||
// initWithDefaults populates an agent and falls back to defaults if fields are unset
|
||||
func (config *AgentConfig) initWithDefaults(a *Agent) {
|
||||
// initWithDefaults populates an agent and falls back to defaults if fields are unset.
|
||||
func (config *AgentConfig) initWithDefaults(agent *Agent) { //nolint:cyclop
|
||||
if config.MaxBindingRequests == nil {
|
||||
a.maxBindingRequests = defaultMaxBindingRequests
|
||||
agent.maxBindingRequests = defaultMaxBindingRequests
|
||||
} else {
|
||||
a.maxBindingRequests = *config.MaxBindingRequests
|
||||
agent.maxBindingRequests = *config.MaxBindingRequests
|
||||
}
|
||||
|
||||
if config.HostAcceptanceMinWait == nil {
|
||||
a.hostAcceptanceMinWait = defaultHostAcceptanceMinWait
|
||||
agent.hostAcceptanceMinWait = defaultHostAcceptanceMinWait
|
||||
} else {
|
||||
a.hostAcceptanceMinWait = *config.HostAcceptanceMinWait
|
||||
agent.hostAcceptanceMinWait = *config.HostAcceptanceMinWait
|
||||
}
|
||||
|
||||
if config.SrflxAcceptanceMinWait == nil {
|
||||
a.srflxAcceptanceMinWait = defaultSrflxAcceptanceMinWait
|
||||
agent.srflxAcceptanceMinWait = defaultSrflxAcceptanceMinWait
|
||||
} else {
|
||||
a.srflxAcceptanceMinWait = *config.SrflxAcceptanceMinWait
|
||||
agent.srflxAcceptanceMinWait = *config.SrflxAcceptanceMinWait
|
||||
}
|
||||
|
||||
if config.PrflxAcceptanceMinWait == nil {
|
||||
a.prflxAcceptanceMinWait = defaultPrflxAcceptanceMinWait
|
||||
agent.prflxAcceptanceMinWait = defaultPrflxAcceptanceMinWait
|
||||
} else {
|
||||
a.prflxAcceptanceMinWait = *config.PrflxAcceptanceMinWait
|
||||
agent.prflxAcceptanceMinWait = *config.PrflxAcceptanceMinWait
|
||||
}
|
||||
|
||||
if config.RelayAcceptanceMinWait == nil {
|
||||
a.relayAcceptanceMinWait = defaultRelayAcceptanceMinWait
|
||||
agent.relayAcceptanceMinWait = defaultRelayAcceptanceMinWait
|
||||
} else {
|
||||
a.relayAcceptanceMinWait = *config.RelayAcceptanceMinWait
|
||||
agent.relayAcceptanceMinWait = *config.RelayAcceptanceMinWait
|
||||
}
|
||||
|
||||
if config.STUNGatherTimeout == nil {
|
||||
a.stunGatherTimeout = defaultSTUNGatherTimeout
|
||||
agent.stunGatherTimeout = defaultSTUNGatherTimeout
|
||||
} else {
|
||||
a.stunGatherTimeout = *config.STUNGatherTimeout
|
||||
agent.stunGatherTimeout = *config.STUNGatherTimeout
|
||||
}
|
||||
|
||||
if config.TCPPriorityOffset == nil {
|
||||
a.tcpPriorityOffset = defaultTCPPriorityOffset
|
||||
agent.tcpPriorityOffset = defaultTCPPriorityOffset
|
||||
} else {
|
||||
a.tcpPriorityOffset = *config.TCPPriorityOffset
|
||||
agent.tcpPriorityOffset = *config.TCPPriorityOffset
|
||||
}
|
||||
|
||||
if config.DisconnectedTimeout == nil {
|
||||
a.disconnectedTimeout = defaultDisconnectedTimeout
|
||||
agent.disconnectedTimeout = defaultDisconnectedTimeout
|
||||
} else {
|
||||
a.disconnectedTimeout = *config.DisconnectedTimeout
|
||||
agent.disconnectedTimeout = *config.DisconnectedTimeout
|
||||
}
|
||||
|
||||
if config.FailedTimeout == nil {
|
||||
a.failedTimeout = defaultFailedTimeout
|
||||
agent.failedTimeout = defaultFailedTimeout
|
||||
} else {
|
||||
a.failedTimeout = *config.FailedTimeout
|
||||
agent.failedTimeout = *config.FailedTimeout
|
||||
}
|
||||
|
||||
if config.KeepaliveInterval == nil {
|
||||
a.keepaliveInterval = defaultKeepaliveInterval
|
||||
agent.keepaliveInterval = defaultKeepaliveInterval
|
||||
} else {
|
||||
a.keepaliveInterval = *config.KeepaliveInterval
|
||||
agent.keepaliveInterval = *config.KeepaliveInterval
|
||||
}
|
||||
|
||||
if config.CheckInterval == nil {
|
||||
a.checkInterval = defaultCheckInterval
|
||||
agent.checkInterval = defaultCheckInterval
|
||||
} else {
|
||||
a.checkInterval = *config.CheckInterval
|
||||
agent.checkInterval = *config.CheckInterval
|
||||
}
|
||||
|
||||
if len(config.CandidateTypes) == 0 {
|
||||
a.candidateTypes = defaultCandidateTypes()
|
||||
agent.candidateTypes = defaultCandidateTypes()
|
||||
} else {
|
||||
a.candidateTypes = config.CandidateTypes
|
||||
agent.candidateTypes = config.CandidateTypes
|
||||
}
|
||||
}
|
||||
|
||||
func (config *AgentConfig) initExtIPMapping(a *Agent) error {
|
||||
func (config *AgentConfig) initExtIPMapping(agent *Agent) error { //nolint:cyclop
|
||||
var err error
|
||||
a.extIPMapper, err = newExternalIPMapper(config.NAT1To1IPCandidateType, config.NAT1To1IPs)
|
||||
agent.extIPMapper, err = newExternalIPMapper(config.NAT1To1IPCandidateType, config.NAT1To1IPs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if a.extIPMapper == nil {
|
||||
if agent.extIPMapper == nil {
|
||||
return nil // This may happen when config.NAT1To1IPs is an empty array
|
||||
}
|
||||
if a.extIPMapper.candidateType == CandidateTypeHost {
|
||||
if a.mDNSMode == MulticastDNSModeQueryAndGather {
|
||||
if agent.extIPMapper.candidateType == CandidateTypeHost { //nolint:nestif
|
||||
if agent.mDNSMode == MulticastDNSModeQueryAndGather {
|
||||
return ErrMulticastDNSWithNAT1To1IPMapping
|
||||
}
|
||||
candiHostEnabled := false
|
||||
for _, candiType := range a.candidateTypes {
|
||||
for _, candiType := range agent.candidateTypes {
|
||||
if candiType == CandidateTypeHost {
|
||||
candiHostEnabled = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
if !candiHostEnabled {
|
||||
return ErrIneffectiveNAT1To1IPMappingHost
|
||||
}
|
||||
} else if a.extIPMapper.candidateType == CandidateTypeServerReflexive {
|
||||
} else if agent.extIPMapper.candidateType == CandidateTypeServerReflexive {
|
||||
candiSrflxEnabled := false
|
||||
for _, candiType := range a.candidateTypes {
|
||||
for _, candiType := range agent.candidateTypes {
|
||||
if candiType == CandidateTypeServerReflexive {
|
||||
candiSrflxEnabled = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -319,5 +321,6 @@ func (config *AgentConfig) initExtIPMapping(a *Agent) error {
|
||||
return ErrIneffectiveNAT1To1IPMappingSrflx
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -32,6 +32,8 @@ func TestAgentGetBestValidCandidatePair(t *testing.T) {
|
||||
}
|
||||
|
||||
func setupTestAgentGetBestValidCandidatePair(t *testing.T) *TestAgentGetBestValidCandidatePairFixture {
|
||||
t.Helper()
|
||||
|
||||
fixture := new(TestAgentGetBestValidCandidatePairFixture)
|
||||
fixture.hostLocal = newHostLocal(t)
|
||||
fixture.relayRemote = newRelayRemote(t)
|
||||
|
@@ -5,16 +5,18 @@ package ice
|
||||
|
||||
import "sync"
|
||||
|
||||
// OnConnectionStateChange sets a handler that is fired when the connection state changes
|
||||
// OnConnectionStateChange sets a handler that is fired when the connection state changes.
|
||||
func (a *Agent) OnConnectionStateChange(f func(ConnectionState)) error {
|
||||
a.onConnectionStateChangeHdlr.Store(f)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnSelectedCandidatePairChange sets a handler that is fired when the final candidate
|
||||
// pair is selected
|
||||
// OnSelectedCandidatePairChange sets a handler that is fired when the final candidate.
|
||||
// pair is selected.
|
||||
func (a *Agent) OnSelectedCandidatePairChange(f func(Candidate, Candidate)) error {
|
||||
a.onSelectedCandidatePairChangeHdlr.Store(f)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -22,6 +24,7 @@ func (a *Agent) OnSelectedCandidatePairChange(f func(Candidate, Candidate)) erro
|
||||
// the gathering process complete the last candidate is nil.
|
||||
func (a *Agent) OnCandidate(f func(Candidate)) error {
|
||||
a.onCandidateHdlr.Store(f)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -73,6 +76,7 @@ func (h *handlerNotifier) Close(graceful bool) {
|
||||
select {
|
||||
case <-h.done:
|
||||
h.Unlock()
|
||||
|
||||
return
|
||||
default:
|
||||
}
|
||||
@@ -80,7 +84,7 @@ func (h *handlerNotifier) Close(graceful bool) {
|
||||
h.Unlock()
|
||||
}
|
||||
|
||||
func (h *handlerNotifier) EnqueueConnectionState(s ConnectionState) {
|
||||
func (h *handlerNotifier) EnqueueConnectionState(state ConnectionState) {
|
||||
h.Lock()
|
||||
defer h.Unlock()
|
||||
|
||||
@@ -97,6 +101,7 @@ func (h *handlerNotifier) EnqueueConnectionState(s ConnectionState) {
|
||||
if len(h.connectionStates) == 0 {
|
||||
h.running = false
|
||||
h.Unlock()
|
||||
|
||||
return
|
||||
}
|
||||
notification := h.connectionStates[0]
|
||||
@@ -106,7 +111,7 @@ func (h *handlerNotifier) EnqueueConnectionState(s ConnectionState) {
|
||||
}
|
||||
}
|
||||
|
||||
h.connectionStates = append(h.connectionStates, s)
|
||||
h.connectionStates = append(h.connectionStates, state)
|
||||
if !h.running {
|
||||
h.running = true
|
||||
h.notifiers.Add(1)
|
||||
@@ -114,7 +119,7 @@ func (h *handlerNotifier) EnqueueConnectionState(s ConnectionState) {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handlerNotifier) EnqueueCandidate(c Candidate) {
|
||||
func (h *handlerNotifier) EnqueueCandidate(cand Candidate) {
|
||||
h.Lock()
|
||||
defer h.Unlock()
|
||||
|
||||
@@ -131,6 +136,7 @@ func (h *handlerNotifier) EnqueueCandidate(c Candidate) {
|
||||
if len(h.candidates) == 0 {
|
||||
h.running = false
|
||||
h.Unlock()
|
||||
|
||||
return
|
||||
}
|
||||
notification := h.candidates[0]
|
||||
@@ -140,7 +146,7 @@ func (h *handlerNotifier) EnqueueCandidate(c Candidate) {
|
||||
}
|
||||
}
|
||||
|
||||
h.candidates = append(h.candidates, c)
|
||||
h.candidates = append(h.candidates, cand)
|
||||
if !h.running {
|
||||
h.running = true
|
||||
h.notifiers.Add(1)
|
||||
@@ -148,7 +154,7 @@ func (h *handlerNotifier) EnqueueCandidate(c Candidate) {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handlerNotifier) EnqueueSelectedCandidatePair(p *CandidatePair) {
|
||||
func (h *handlerNotifier) EnqueueSelectedCandidatePair(pair *CandidatePair) {
|
||||
h.Lock()
|
||||
defer h.Unlock()
|
||||
|
||||
@@ -165,6 +171,7 @@ func (h *handlerNotifier) EnqueueSelectedCandidatePair(p *CandidatePair) {
|
||||
if len(h.selectedCandidatePairs) == 0 {
|
||||
h.running = false
|
||||
h.Unlock()
|
||||
|
||||
return
|
||||
}
|
||||
notification := h.selectedCandidatePairs[0]
|
||||
@@ -174,7 +181,7 @@ func (h *handlerNotifier) EnqueueSelectedCandidatePair(p *CandidatePair) {
|
||||
}
|
||||
}
|
||||
|
||||
h.selectedCandidatePairs = append(h.selectedCandidatePairs, p)
|
||||
h.selectedCandidatePairs = append(h.selectedCandidatePairs, pair)
|
||||
if !h.running {
|
||||
h.running = true
|
||||
h.notifiers.Add(1)
|
||||
|
@@ -15,7 +15,7 @@ func TestConnectionStateNotifier(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
updates := make(chan struct{}, 1)
|
||||
c := &handlerNotifier{
|
||||
notifier := &handlerNotifier{
|
||||
connectionStateFunc: func(_ ConnectionState) {
|
||||
updates <- struct{}{}
|
||||
},
|
||||
@@ -24,7 +24,7 @@ func TestConnectionStateNotifier(t *testing.T) {
|
||||
// Enqueue all updates upfront to ensure that it
|
||||
// doesn't block
|
||||
for i := 0; i < 10000; i++ {
|
||||
c.EnqueueConnectionState(ConnectionStateNew)
|
||||
notifier.EnqueueConnectionState(ConnectionStateNew)
|
||||
}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
@@ -39,12 +39,12 @@ func TestConnectionStateNotifier(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
<-done
|
||||
c.Close(true)
|
||||
notifier.Close(true)
|
||||
})
|
||||
t.Run("TestUpdateOrdering", func(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
updates := make(chan ConnectionState)
|
||||
c := &handlerNotifier{
|
||||
notifer := &handlerNotifier{
|
||||
connectionStateFunc: func(cs ConnectionState) {
|
||||
updates <- cs
|
||||
},
|
||||
@@ -66,9 +66,9 @@ func TestConnectionStateNotifier(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
for i := 0; i < 10000; i++ {
|
||||
c.EnqueueConnectionState(ConnectionState(i))
|
||||
notifer.EnqueueConnectionState(ConnectionState(i))
|
||||
}
|
||||
<-done
|
||||
c.Close(true)
|
||||
notifer.Close(true)
|
||||
})
|
||||
}
|
||||
|
@@ -34,17 +34,23 @@ func TestOnSelectedCandidatePairChange(t *testing.T) {
|
||||
}
|
||||
|
||||
func fixtureTestOnSelectedCandidatePairChange(t *testing.T) (*Agent, *CandidatePair) {
|
||||
t.Helper()
|
||||
|
||||
agent, err := NewAgent(&AgentConfig{})
|
||||
require.NoError(t, err)
|
||||
|
||||
candidatePair := makeCandidatePair(t)
|
||||
|
||||
return agent, candidatePair
|
||||
}
|
||||
|
||||
func makeCandidatePair(t *testing.T) *CandidatePair {
|
||||
t.Helper()
|
||||
|
||||
hostLocal := newHostLocal(t)
|
||||
relayRemote := newRelayRemote(t)
|
||||
|
||||
candidatePair := newCandidatePair(hostLocal, relayRemote, false)
|
||||
|
||||
return candidatePair
|
||||
}
|
||||
|
@@ -8,7 +8,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// GetCandidatePairsStats returns a list of candidate pair stats
|
||||
// GetCandidatePairsStats returns a list of candidate pair stats.
|
||||
func (a *Agent) GetCandidatePairsStats() []CandidatePairStats {
|
||||
var res []CandidatePairStats
|
||||
err := a.loop.Run(a.loop, func(_ context.Context) {
|
||||
@@ -49,13 +49,15 @@ func (a *Agent) GetCandidatePairsStats() []CandidatePairStats {
|
||||
})
|
||||
if err != nil {
|
||||
a.log.Errorf("Failed to get candidate pairs stats: %v", err)
|
||||
|
||||
return []CandidatePairStats{}
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// GetSelectedCandidatePairStats returns a candidate pair stats for selected candidate pair.
|
||||
// Returns false if there is no selected pair
|
||||
// Returns false if there is no selected pair.
|
||||
func (a *Agent) GetSelectedCandidatePairStats() (CandidatePairStats, bool) {
|
||||
isAvailable := false
|
||||
var res CandidatePairStats
|
||||
@@ -98,33 +100,34 @@ func (a *Agent) GetSelectedCandidatePairStats() (CandidatePairStats, bool) {
|
||||
})
|
||||
if err != nil {
|
||||
a.log.Errorf("Failed to get selected candidate pair stats: %v", err)
|
||||
|
||||
return CandidatePairStats{}, false
|
||||
}
|
||||
|
||||
return res, isAvailable
|
||||
}
|
||||
|
||||
// GetLocalCandidatesStats returns a list of local candidates stats
|
||||
// GetLocalCandidatesStats returns a list of local candidates stats.
|
||||
func (a *Agent) GetLocalCandidatesStats() []CandidateStats {
|
||||
var res []CandidateStats
|
||||
err := a.loop.Run(a.loop, func(_ context.Context) {
|
||||
result := make([]CandidateStats, 0, len(a.localCandidates))
|
||||
for networkType, localCandidates := range a.localCandidates {
|
||||
for _, c := range localCandidates {
|
||||
for _, cand := range localCandidates {
|
||||
relayProtocol := ""
|
||||
if c.Type() == CandidateTypeRelay {
|
||||
if cRelay, ok := c.(*CandidateRelay); ok {
|
||||
if cand.Type() == CandidateTypeRelay {
|
||||
if cRelay, ok := cand.(*CandidateRelay); ok {
|
||||
relayProtocol = cRelay.RelayProtocol()
|
||||
}
|
||||
}
|
||||
stat := CandidateStats{
|
||||
Timestamp: time.Now(),
|
||||
ID: c.ID(),
|
||||
ID: cand.ID(),
|
||||
NetworkType: networkType,
|
||||
IP: c.Address(),
|
||||
Port: c.Port(),
|
||||
CandidateType: c.Type(),
|
||||
Priority: c.Priority(),
|
||||
IP: cand.Address(),
|
||||
Port: cand.Port(),
|
||||
CandidateType: cand.Type(),
|
||||
Priority: cand.Priority(),
|
||||
// URL string
|
||||
RelayProtocol: relayProtocol,
|
||||
// Deleted bool
|
||||
@@ -136,12 +139,14 @@ func (a *Agent) GetLocalCandidatesStats() []CandidateStats {
|
||||
})
|
||||
if err != nil {
|
||||
a.log.Errorf("Failed to get candidate pair stats: %v", err)
|
||||
|
||||
return []CandidateStats{}
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// GetRemoteCandidatesStats returns a list of remote candidates stats
|
||||
// GetRemoteCandidatesStats returns a list of remote candidates stats.
|
||||
func (a *Agent) GetRemoteCandidatesStats() []CandidateStats {
|
||||
var res []CandidateStats
|
||||
err := a.loop.Run(a.loop, func(_ context.Context) {
|
||||
@@ -166,7 +171,9 @@ func (a *Agent) GetRemoteCandidatesStats() []CandidateStats {
|
||||
})
|
||||
if err != nil {
|
||||
a.log.Errorf("Failed to get candidate pair stats: %v", err)
|
||||
|
||||
return []CandidateStats{}
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
274
agent_test.go
274
agent_test.go
@@ -33,21 +33,21 @@ func (ba *BadAddr) String() string {
|
||||
return "yyy"
|
||||
}
|
||||
|
||||
func TestHandlePeerReflexive(t *testing.T) {
|
||||
func TestHandlePeerReflexive(t *testing.T) { //nolint:cyclop
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
// Limit runtime in case of deadlocks
|
||||
defer test.TimeOut(time.Second * 2).Stop()
|
||||
|
||||
t.Run("UDP prflx candidate from handleInbound()", func(t *testing.T) {
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
agent, err := NewAgent(&AgentConfig{})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
require.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
|
||||
a.selector = &controllingSelector{agent: a, log: a.log}
|
||||
require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
|
||||
agent.selector = &controllingSelector{agent: agent, log: agent.log}
|
||||
|
||||
hostConfig := CandidateHostConfig{
|
||||
Network: "udp",
|
||||
@@ -64,25 +64,25 @@ func TestHandlePeerReflexive(t *testing.T) {
|
||||
remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}
|
||||
|
||||
msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
|
||||
stun.NewUsername(a.localUfrag+":"+a.remoteUfrag),
|
||||
stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
|
||||
UseCandidate(),
|
||||
AttrControlling(a.tieBreaker),
|
||||
AttrControlling(agent.tieBreaker),
|
||||
PriorityAttr(local.Priority()),
|
||||
stun.NewShortTermIntegrity(a.localPwd),
|
||||
stun.NewShortTermIntegrity(agent.localPwd),
|
||||
stun.Fingerprint,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// nolint: contextcheck
|
||||
a.handleInbound(msg, local, remote)
|
||||
agent.handleInbound(msg, local, remote)
|
||||
|
||||
// Length of remote candidate list must be one now
|
||||
if len(a.remoteCandidates) != 1 {
|
||||
if len(agent.remoteCandidates) != 1 {
|
||||
t.Fatal("failed to add a network type to the remote candidate list")
|
||||
}
|
||||
|
||||
// Length of remote candidate list for a network type must be 1
|
||||
set := a.remoteCandidates[local.NetworkType()]
|
||||
set := agent.remoteCandidates[local.NetworkType()]
|
||||
if len(set) != 1 {
|
||||
t.Fatal("failed to add prflx candidate to remote candidate list")
|
||||
}
|
||||
@@ -104,14 +104,14 @@ func TestHandlePeerReflexive(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Bad network type with handleInbound()", func(t *testing.T) {
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
agent, err := NewAgent(&AgentConfig{})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
require.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
|
||||
a.selector = &controllingSelector{agent: a, log: a.log}
|
||||
require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
|
||||
agent.selector = &controllingSelector{agent: agent, log: agent.log}
|
||||
|
||||
hostConfig := CandidateHostConfig{
|
||||
Network: "tcp",
|
||||
@@ -127,26 +127,26 @@ func TestHandlePeerReflexive(t *testing.T) {
|
||||
remote := &BadAddr{}
|
||||
|
||||
// nolint: contextcheck
|
||||
a.handleInbound(nil, local, remote)
|
||||
agent.handleInbound(nil, local, remote)
|
||||
|
||||
if len(a.remoteCandidates) != 0 {
|
||||
if len(agent.remoteCandidates) != 0 {
|
||||
t.Fatal("bad address should not be added to the remote candidate list")
|
||||
}
|
||||
}))
|
||||
})
|
||||
|
||||
t.Run("Success from unknown remote, prflx candidate MUST only be created via Binding Request", func(t *testing.T) {
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
agent, err := NewAgent(&AgentConfig{})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
require.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
|
||||
a.selector = &controllingSelector{agent: a, log: a.log}
|
||||
require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
|
||||
agent.selector = &controllingSelector{agent: agent, log: agent.log}
|
||||
tID := [stun.TransactionIDSize]byte{}
|
||||
copy(tID[:], "ABC")
|
||||
a.pendingBindingRequests = []bindingRequest{
|
||||
agent.pendingBindingRequests = []bindingRequest{
|
||||
{time.Now(), tID, &net.UDPAddr{}, false},
|
||||
}
|
||||
|
||||
@@ -164,14 +164,14 @@ func TestHandlePeerReflexive(t *testing.T) {
|
||||
|
||||
remote := &net.UDPAddr{IP: net.ParseIP("172.17.0.3"), Port: 999}
|
||||
msg, err := stun.Build(stun.BindingSuccess, stun.NewTransactionIDSetter(tID),
|
||||
stun.NewShortTermIntegrity(a.remotePwd),
|
||||
stun.NewShortTermIntegrity(agent.remotePwd),
|
||||
stun.Fingerprint,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// nolint: contextcheck
|
||||
a.handleInbound(msg, local, remote)
|
||||
if len(a.remoteCandidates) != 0 {
|
||||
agent.handleInbound(msg, local, remote)
|
||||
if len(agent.remoteCandidates) != 0 {
|
||||
t.Fatal("unknown remote was able to create a candidate")
|
||||
}
|
||||
}))
|
||||
@@ -281,6 +281,7 @@ func TestConnectivityOnStartup(t *testing.T) {
|
||||
|
||||
// Ensure accepted
|
||||
<-accepted
|
||||
|
||||
return aConn, bConn
|
||||
}(aAgent, bAgent)
|
||||
|
||||
@@ -308,9 +309,9 @@ func TestConnectivityLite(t *testing.T) {
|
||||
MappingBehavior: vnet.EndpointIndependent,
|
||||
FilteringBehavior: vnet.EndpointIndependent,
|
||||
}
|
||||
v, err := buildVNet(natType, natType)
|
||||
vent, err := buildVNet(natType, natType)
|
||||
require.NoError(t, err, "should succeed")
|
||||
defer v.close()
|
||||
defer vent.close()
|
||||
|
||||
aNotifier, aConnected := onConnected()
|
||||
bNotifier, bConnected := onConnected()
|
||||
@@ -319,7 +320,7 @@ func TestConnectivityLite(t *testing.T) {
|
||||
Urls: []*stun.URI{stunServerURL},
|
||||
NetworkTypes: supportedNetworkTypes(),
|
||||
MulticastDNSMode: MulticastDNSModeDisabled,
|
||||
Net: v.net0,
|
||||
Net: vent.net0,
|
||||
}
|
||||
|
||||
aAgent, err := NewAgent(cfg0)
|
||||
@@ -335,7 +336,7 @@ func TestConnectivityLite(t *testing.T) {
|
||||
CandidateTypes: []CandidateType{CandidateTypeHost},
|
||||
NetworkTypes: supportedNetworkTypes(),
|
||||
MulticastDNSMode: MulticastDNSModeDisabled,
|
||||
Net: v.net1,
|
||||
Net: vent.net1,
|
||||
}
|
||||
|
||||
bAgent, err := NewAgent(cfg1)
|
||||
@@ -353,7 +354,7 @@ func TestConnectivityLite(t *testing.T) {
|
||||
<-bConnected
|
||||
}
|
||||
|
||||
func TestInboundValidity(t *testing.T) {
|
||||
func TestInboundValidity(t *testing.T) { //nolint:cyclop
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
buildMsg := func(class stun.MessageClass, username, key string) *stun.Message {
|
||||
@@ -381,21 +382,21 @@ func TestInboundValidity(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("Invalid Binding requests should be discarded", func(t *testing.T) {
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
agent, err := NewAgent(&AgentConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("Error constructing ice.Agent")
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
a.handleInbound(buildMsg(stun.ClassRequest, "invalid", a.localPwd), local, remote)
|
||||
if len(a.remoteCandidates) == 1 {
|
||||
agent.handleInbound(buildMsg(stun.ClassRequest, "invalid", agent.localPwd), local, remote)
|
||||
if len(agent.remoteCandidates) == 1 {
|
||||
t.Fatal("Binding with invalid Username was able to create prflx candidate")
|
||||
}
|
||||
|
||||
a.handleInbound(buildMsg(stun.ClassRequest, a.localUfrag+":"+a.remoteUfrag, "Invalid"), local, remote)
|
||||
if len(a.remoteCandidates) == 1 {
|
||||
agent.handleInbound(buildMsg(stun.ClassRequest, agent.localUfrag+":"+agent.remoteUfrag, "Invalid"), local, remote)
|
||||
if len(agent.remoteCandidates) == 1 {
|
||||
t.Fatal("Binding with invalid MessageIntegrity was able to create prflx candidate")
|
||||
}
|
||||
})
|
||||
@@ -452,35 +453,35 @@ func TestInboundValidity(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Valid bind without fingerprint", func(t *testing.T) {
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
agent, err := NewAgent(&AgentConfig{})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
require.NoError(t, a.loop.Run(a.loop, func(_ context.Context) {
|
||||
a.selector = &controllingSelector{agent: a, log: a.log}
|
||||
require.NoError(t, agent.loop.Run(agent.loop, func(_ context.Context) {
|
||||
agent.selector = &controllingSelector{agent: agent, log: agent.log}
|
||||
msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
|
||||
stun.NewUsername(a.localUfrag+":"+a.remoteUfrag),
|
||||
stun.NewShortTermIntegrity(a.localPwd),
|
||||
stun.NewUsername(agent.localUfrag+":"+agent.remoteUfrag),
|
||||
stun.NewShortTermIntegrity(agent.localPwd),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// nolint: contextcheck
|
||||
a.handleInbound(msg, local, remote)
|
||||
if len(a.remoteCandidates) != 1 {
|
||||
agent.handleInbound(msg, local, remote)
|
||||
if len(agent.remoteCandidates) != 1 {
|
||||
t.Fatal("Binding with valid values (but no fingerprint) was unable to create prflx candidate")
|
||||
}
|
||||
}))
|
||||
})
|
||||
|
||||
t.Run("Success with invalid TransactionID", func(t *testing.T) {
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
agent, err := NewAgent(&AgentConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("Error constructing ice.Agent")
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
hostConfig := CandidateHostConfig{
|
||||
@@ -499,13 +500,13 @@ func TestInboundValidity(t *testing.T) {
|
||||
tID := [stun.TransactionIDSize]byte{}
|
||||
copy(tID[:], "ABC")
|
||||
msg, err := stun.Build(stun.BindingSuccess, stun.NewTransactionIDSetter(tID),
|
||||
stun.NewShortTermIntegrity(a.remotePwd),
|
||||
stun.NewShortTermIntegrity(agent.remotePwd),
|
||||
stun.Fingerprint,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
a.handleInbound(msg, local, remote)
|
||||
if len(a.remoteCandidates) != 0 {
|
||||
agent.handleInbound(msg, local, remote)
|
||||
if len(agent.remoteCandidates) != 0 {
|
||||
t.Fatal("unknown remote was able to create a candidate")
|
||||
}
|
||||
})
|
||||
@@ -514,35 +515,35 @@ func TestInboundValidity(t *testing.T) {
|
||||
func TestInvalidAgentStarts(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
agent, err := NewAgent(&AgentConfig{})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
if _, err = a.Dial(ctx, "", "bar"); err != nil && !errors.Is(err, ErrRemoteUfragEmpty) {
|
||||
if _, err = agent.Dial(ctx, "", "bar"); err != nil && !errors.Is(err, ErrRemoteUfragEmpty) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err = a.Dial(ctx, "foo", ""); err != nil && !errors.Is(err, ErrRemotePwdEmpty) {
|
||||
if _, err = agent.Dial(ctx, "foo", ""); err != nil && !errors.Is(err, ErrRemotePwdEmpty) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err = a.Dial(ctx, "foo", "bar"); err != nil && !errors.Is(err, ErrCanceledByCaller) {
|
||||
if _, err = agent.Dial(ctx, "foo", "bar"); err != nil && !errors.Is(err, ErrCanceledByCaller) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err = a.Dial(context.TODO(), "foo", "bar"); err != nil && !errors.Is(err, ErrMultipleStart) {
|
||||
if _, err = agent.Dial(context.TODO(), "foo", "bar"); err != nil && !errors.Is(err, ErrMultipleStart) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Assert that Agent emits Connecting/Connected/Disconnected/Failed/Closed messages
|
||||
func TestConnectionStateCallback(t *testing.T) {
|
||||
// Assert that Agent emits Connecting/Connected/Disconnected/Failed/Closed messages.
|
||||
func TestConnectionStateCallback(t *testing.T) { //nolint:cyclop
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
defer test.TimeOut(time.Second * 5).Stop()
|
||||
@@ -635,18 +636,18 @@ func TestInvalidGather(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestCandidatePairsStats(t *testing.T) {
|
||||
func TestCandidatePairsStats(t *testing.T) { //nolint:cyclop
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
// Avoid deadlocks?
|
||||
defer test.TimeOut(1 * time.Second).Stop()
|
||||
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
agent, err := NewAgent(&AgentConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create agent: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
hostConfig := &CandidateHostConfig{
|
||||
@@ -711,21 +712,21 @@ func TestCandidatePairsStats(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, remote := range []Candidate{relayRemote, srflxRemote, prflxRemote, hostRemote} {
|
||||
p := a.findPair(hostLocal, remote)
|
||||
p := agent.findPair(hostLocal, remote)
|
||||
|
||||
if p == nil {
|
||||
a.addPair(hostLocal, remote)
|
||||
agent.addPair(hostLocal, remote)
|
||||
}
|
||||
}
|
||||
|
||||
p := a.findPair(hostLocal, prflxRemote)
|
||||
p := agent.findPair(hostLocal, prflxRemote)
|
||||
p.state = CandidatePairStateFailed
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
p.UpdateRoundTripTime(time.Duration(i+1) * time.Second)
|
||||
}
|
||||
|
||||
stats := a.GetCandidatePairsStats()
|
||||
stats := agent.GetCandidatePairsStats()
|
||||
if len(stats) != 4 {
|
||||
t.Fatal("expected 4 candidate pairs stats")
|
||||
}
|
||||
@@ -789,18 +790,18 @@ func TestCandidatePairsStats(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectedCandidatePairStats(t *testing.T) {
|
||||
func TestSelectedCandidatePairStats(t *testing.T) { //nolint:cyclop
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
// Avoid deadlocks?
|
||||
defer test.TimeOut(1 * time.Second).Stop()
|
||||
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
agent, err := NewAgent(&AgentConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create agent: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
hostConfig := &CandidateHostConfig{
|
||||
@@ -828,23 +829,23 @@ func TestSelectedCandidatePairStats(t *testing.T) {
|
||||
}
|
||||
|
||||
// no selected pair, should return not available
|
||||
_, ok := a.GetSelectedCandidatePairStats()
|
||||
_, ok := agent.GetSelectedCandidatePairStats()
|
||||
require.False(t, ok)
|
||||
|
||||
// add pair and populate some RTT stats
|
||||
p := a.findPair(hostLocal, srflxRemote)
|
||||
p := agent.findPair(hostLocal, srflxRemote)
|
||||
if p == nil {
|
||||
a.addPair(hostLocal, srflxRemote)
|
||||
p = a.findPair(hostLocal, srflxRemote)
|
||||
agent.addPair(hostLocal, srflxRemote)
|
||||
p = agent.findPair(hostLocal, srflxRemote)
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
p.UpdateRoundTripTime(time.Duration(i+1) * time.Second)
|
||||
}
|
||||
|
||||
// set the pair as selected
|
||||
a.setSelectedPair(p)
|
||||
agent.setSelectedPair(p)
|
||||
|
||||
stats, ok := a.GetSelectedCandidatePairStats()
|
||||
stats, ok := agent.GetSelectedCandidatePairStats()
|
||||
require.True(t, ok)
|
||||
|
||||
if stats.LocalCandidateID != hostLocal.ID() {
|
||||
@@ -872,18 +873,18 @@ func TestSelectedCandidatePairStats(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalCandidateStats(t *testing.T) {
|
||||
func TestLocalCandidateStats(t *testing.T) { //nolint:cyclop
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
// Avoid deadlocks?
|
||||
defer test.TimeOut(1 * time.Second).Stop()
|
||||
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
agent, err := NewAgent(&AgentConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create agent: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
hostConfig := &CandidateHostConfig{
|
||||
@@ -910,9 +911,9 @@ func TestLocalCandidateStats(t *testing.T) {
|
||||
t.Fatalf("Failed to construct local srflx candidate: %s", err)
|
||||
}
|
||||
|
||||
a.localCandidates[NetworkTypeUDP4] = []Candidate{hostLocal, srflxLocal}
|
||||
agent.localCandidates[NetworkTypeUDP4] = []Candidate{hostLocal, srflxLocal}
|
||||
|
||||
localStats := a.GetLocalCandidatesStats()
|
||||
localStats := agent.GetLocalCandidatesStats()
|
||||
if len(localStats) != 2 {
|
||||
t.Fatalf("expected 2 local candidates stats, got %d instead", len(localStats))
|
||||
}
|
||||
@@ -953,18 +954,18 @@ func TestLocalCandidateStats(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteCandidateStats(t *testing.T) {
|
||||
func TestRemoteCandidateStats(t *testing.T) { //nolint:cyclop
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
// Avoid deadlocks?
|
||||
defer test.TimeOut(1 * time.Second).Stop()
|
||||
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
agent, err := NewAgent(&AgentConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create agent: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
relayConfig := &CandidateRelayConfig{
|
||||
@@ -1017,9 +1018,9 @@ func TestRemoteCandidateStats(t *testing.T) {
|
||||
t.Fatalf("Failed to construct remote host candidate: %s", err)
|
||||
}
|
||||
|
||||
a.remoteCandidates[NetworkTypeUDP4] = []Candidate{relayRemote, srflxRemote, prflxRemote, hostRemote}
|
||||
agent.remoteCandidates[NetworkTypeUDP4] = []Candidate{relayRemote, srflxRemote, prflxRemote, hostRemote}
|
||||
|
||||
remoteStats := a.GetRemoteCandidatesStats()
|
||||
remoteStats := agent.GetRemoteCandidatesStats()
|
||||
if len(remoteStats) != 4 {
|
||||
t.Fatalf("expected 4 remote candidates stats, got %d instead", len(remoteStats))
|
||||
}
|
||||
@@ -1076,31 +1077,31 @@ func TestRemoteCandidateStats(t *testing.T) {
|
||||
func TestInitExtIPMapping(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
// a.extIPMapper should be nil by default
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
// agent.extIPMapper should be nil by default
|
||||
agent, err := NewAgent(&AgentConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create agent: %v", err)
|
||||
}
|
||||
if a.extIPMapper != nil {
|
||||
require.NoError(t, a.Close())
|
||||
if agent.extIPMapper != nil {
|
||||
require.NoError(t, agent.Close())
|
||||
t.Fatal("a.extIPMapper should be nil by default")
|
||||
}
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
|
||||
// a.extIPMapper should be nil when NAT1To1IPs is a non-nil empty array
|
||||
a, err = NewAgent(&AgentConfig{
|
||||
agent, err = NewAgent(&AgentConfig{
|
||||
NAT1To1IPs: []string{},
|
||||
NAT1To1IPCandidateType: CandidateTypeHost,
|
||||
})
|
||||
if err != nil {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
t.Fatalf("Failed to create agent: %v", err)
|
||||
}
|
||||
if a.extIPMapper != nil {
|
||||
require.NoError(t, a.Close())
|
||||
if agent.extIPMapper != nil {
|
||||
require.NoError(t, agent.Close())
|
||||
t.Fatal("a.extIPMapper should be nil by default")
|
||||
}
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
|
||||
// NewAgent should return an error when 1:1 NAT for host candidate is enabled
|
||||
// but the candidate type does not appear in the CandidateTypes.
|
||||
@@ -1150,32 +1151,38 @@ func TestBindingRequestTimeout(t *testing.T) {
|
||||
|
||||
const expectedRemovalCount = 2
|
||||
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
agent, err := NewAgent(&AgentConfig{})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
now := time.Now()
|
||||
a.pendingBindingRequests = append(a.pendingBindingRequests, bindingRequest{
|
||||
agent.pendingBindingRequests = append(agent.pendingBindingRequests, bindingRequest{
|
||||
timestamp: now, // Valid
|
||||
})
|
||||
a.pendingBindingRequests = append(a.pendingBindingRequests, bindingRequest{
|
||||
agent.pendingBindingRequests = append(agent.pendingBindingRequests, bindingRequest{
|
||||
timestamp: now.Add(-3900 * time.Millisecond), // Valid
|
||||
})
|
||||
a.pendingBindingRequests = append(a.pendingBindingRequests, bindingRequest{
|
||||
agent.pendingBindingRequests = append(agent.pendingBindingRequests, bindingRequest{
|
||||
timestamp: now.Add(-4100 * time.Millisecond), // Invalid
|
||||
})
|
||||
a.pendingBindingRequests = append(a.pendingBindingRequests, bindingRequest{
|
||||
agent.pendingBindingRequests = append(agent.pendingBindingRequests, bindingRequest{
|
||||
timestamp: now.Add(-75 * time.Hour), // Invalid
|
||||
})
|
||||
|
||||
a.invalidatePendingBindingRequests(now)
|
||||
require.Equal(t, expectedRemovalCount, len(a.pendingBindingRequests), "Binding invalidation due to timeout did not remove the correct number of binding requests")
|
||||
agent.invalidatePendingBindingRequests(now)
|
||||
|
||||
require.Equal(
|
||||
t,
|
||||
expectedRemovalCount,
|
||||
len(agent.pendingBindingRequests),
|
||||
"Binding invalidation due to timeout did not remove the correct number of binding requests",
|
||||
)
|
||||
}
|
||||
|
||||
// TestAgentCredentials checks if local username fragments and passwords (if set) meet RFC standard
|
||||
// and ensure it's backwards compatible with previous versions of the pion/ice
|
||||
// and ensure it's backwards compatible with previous versions of the pion/ice.
|
||||
func TestAgentCredentials(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
@@ -1207,7 +1214,7 @@ func TestAgentCredentials(t *testing.T) {
|
||||
}
|
||||
|
||||
// Assert that Agent on Failure deletes all existing candidates
|
||||
// User can then do an ICE Restart to bring agent back
|
||||
// User can then do an ICE Restart to bring agent back.
|
||||
func TestConnectionStateFailedDeleteAllCandidates(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
@@ -1254,7 +1261,7 @@ func TestConnectionStateFailedDeleteAllCandidates(t *testing.T) {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Assert that the ICE Agent can go directly from Connecting -> Failed on both sides
|
||||
// Assert that the ICE Agent can go directly from Connecting -> Failed on both sides.
|
||||
func TestConnectionStateConnectingToFailed(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
@@ -1378,6 +1385,7 @@ func TestAgentRestart(t *testing.T) {
|
||||
out += c.Address() + ":"
|
||||
out += strconv.Itoa(c.Port())
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1423,33 +1431,33 @@ func TestAgentRestart(t *testing.T) {
|
||||
|
||||
func TestGetRemoteCredentials(t *testing.T) {
|
||||
var config AgentConfig
|
||||
a, err := NewAgent(&config)
|
||||
agent, err := NewAgent(&config)
|
||||
if err != nil {
|
||||
t.Fatalf("Error constructing ice.Agent: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
a.remoteUfrag = "remoteUfrag"
|
||||
a.remotePwd = "remotePwd"
|
||||
agent.remoteUfrag = "remoteUfrag"
|
||||
agent.remotePwd = "remotePwd"
|
||||
|
||||
actualUfrag, actualPwd, err := a.GetRemoteUserCredentials()
|
||||
actualUfrag, actualPwd, err := agent.GetRemoteUserCredentials()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, actualUfrag, a.remoteUfrag)
|
||||
require.Equal(t, actualPwd, a.remotePwd)
|
||||
require.Equal(t, actualUfrag, agent.remoteUfrag)
|
||||
require.Equal(t, actualPwd, agent.remotePwd)
|
||||
}
|
||||
|
||||
func TestGetRemoteCandidates(t *testing.T) {
|
||||
var config AgentConfig
|
||||
|
||||
a, err := NewAgent(&config)
|
||||
agent, err := NewAgent(&config)
|
||||
if err != nil {
|
||||
t.Fatalf("Error constructing ice.Agent: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
expectedCandidates := []Candidate{}
|
||||
@@ -1467,10 +1475,10 @@ func TestGetRemoteCandidates(t *testing.T) {
|
||||
|
||||
expectedCandidates = append(expectedCandidates, cand)
|
||||
|
||||
a.addRemoteCandidate(cand)
|
||||
agent.addRemoteCandidate(cand)
|
||||
}
|
||||
|
||||
actualCandidates, err := a.GetRemoteCandidates()
|
||||
actualCandidates, err := agent.GetRemoteCandidates()
|
||||
require.NoError(t, err)
|
||||
require.ElementsMatch(t, expectedCandidates, actualCandidates)
|
||||
}
|
||||
@@ -1478,12 +1486,12 @@ func TestGetRemoteCandidates(t *testing.T) {
|
||||
func TestGetLocalCandidates(t *testing.T) {
|
||||
var config AgentConfig
|
||||
|
||||
a, err := NewAgent(&config)
|
||||
agent, err := NewAgent(&config)
|
||||
if err != nil {
|
||||
t.Fatalf("Error constructing ice.Agent: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
dummyConn := &net.UDPConn{}
|
||||
@@ -1502,11 +1510,11 @@ func TestGetLocalCandidates(t *testing.T) {
|
||||
|
||||
expectedCandidates = append(expectedCandidates, cand)
|
||||
|
||||
err = a.addCandidate(context.Background(), cand, dummyConn)
|
||||
err = agent.addCandidate(context.Background(), cand, dummyConn)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
actualCandidates, err := a.GetLocalCandidates()
|
||||
actualCandidates, err := agent.GetLocalCandidates()
|
||||
require.NoError(t, err)
|
||||
require.ElementsMatch(t, expectedCandidates, actualCandidates)
|
||||
}
|
||||
@@ -1666,7 +1674,7 @@ func TestRunTaskInSelectedCandidatePairChangeCallback(t *testing.T) {
|
||||
<-isTested
|
||||
}
|
||||
|
||||
// Assert that a Lite agent goes to disconnected and failed
|
||||
// Assert that a Lite agent goes to disconnected and failed.
|
||||
func TestLiteLifecycle(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
@@ -1818,7 +1826,7 @@ func TestGetSelectedCandidatePair(t *testing.T) {
|
||||
require.NoError(t, wan.Stop())
|
||||
}
|
||||
|
||||
func TestAcceptAggressiveNomination(t *testing.T) {
|
||||
func TestAcceptAggressiveNomination(t *testing.T) { //nolint:cyclop
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
defer test.TimeOut(time.Second * 30).Stop()
|
||||
@@ -1932,24 +1940,25 @@ func TestAcceptAggressiveNomination(t *testing.T) {
|
||||
bcandidates, err = bAgent.GetLocalCandidates()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, c := range bcandidates {
|
||||
if c != bAgent.getSelectedPair().Local {
|
||||
for _, cand := range bcandidates {
|
||||
if cand != bAgent.getSelectedPair().Local { //nolint:nestif
|
||||
if expectNewSelectedCandidate == nil {
|
||||
expected_change_priority:
|
||||
for _, candidates := range aAgent.remoteCandidates {
|
||||
for _, candidate := range candidates {
|
||||
if candidate.Equal(c) {
|
||||
if candidate.Equal(cand) {
|
||||
if tc.useHigherPriority {
|
||||
candidate.(*CandidateHost).priorityOverride += 1000 //nolint:forcetypeassert
|
||||
} else {
|
||||
candidate.(*CandidateHost).priorityOverride -= 1000 //nolint:forcetypeassert
|
||||
}
|
||||
|
||||
break expected_change_priority
|
||||
}
|
||||
}
|
||||
}
|
||||
if tc.isExpectedToSwitch {
|
||||
expectNewSelectedCandidate = c
|
||||
expectNewSelectedCandidate = cand
|
||||
} else {
|
||||
expectNewSelectedCandidate = aAgent.getSelectedPair().Remote
|
||||
}
|
||||
@@ -1958,18 +1967,27 @@ func TestAcceptAggressiveNomination(t *testing.T) {
|
||||
change_priority:
|
||||
for _, candidates := range aAgent.remoteCandidates {
|
||||
for _, candidate := range candidates {
|
||||
if candidate.Equal(c) {
|
||||
if candidate.Equal(cand) {
|
||||
if tc.useHigherPriority {
|
||||
candidate.(*CandidateHost).priorityOverride += 500 //nolint:forcetypeassert
|
||||
} else {
|
||||
candidate.(*CandidateHost).priorityOverride -= 500 //nolint:forcetypeassert
|
||||
}
|
||||
|
||||
break change_priority
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_, err = c.writeTo(buildMsg(stun.ClassRequest, aAgent.localUfrag+":"+aAgent.remoteUfrag, aAgent.localPwd, c.Priority()).Raw, bAgent.getSelectedPair().Remote)
|
||||
_, err = cand.writeTo(
|
||||
buildMsg(
|
||||
stun.ClassRequest,
|
||||
aAgent.localUfrag+":"+aAgent.remoteUfrag,
|
||||
aAgent.localPwd,
|
||||
cand.Priority(),
|
||||
).Raw,
|
||||
bAgent.getSelectedPair().Remote,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
@@ -1991,7 +2009,7 @@ func TestAcceptAggressiveNomination(t *testing.T) {
|
||||
require.NoError(t, wan.Stop())
|
||||
}
|
||||
|
||||
// Close can deadlock but GracefulClose must not
|
||||
// Close can deadlock but GracefulClose must not.
|
||||
func TestAgentGracefulCloseDeadlock(t *testing.T) {
|
||||
defer test.CheckRoutinesStrict(t)()
|
||||
defer test.TimeOut(time.Second * 5).Stop()
|
||||
|
@@ -16,7 +16,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestMuxAgent is an end to end test over UDP mux, ensuring two agents could connect over mux
|
||||
// TestMuxAgent is an end to end test over UDP mux, ensuring two agents could connect over mux.
|
||||
func TestMuxAgent(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
@@ -58,7 +58,7 @@ func TestMuxAgent(t *testing.T) {
|
||||
require.NoError(t, muxedA.Close())
|
||||
}()
|
||||
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
agent, err := NewAgent(&AgentConfig{
|
||||
CandidateTypes: []CandidateType{CandidateTypeHost},
|
||||
NetworkTypes: supportedNetworkTypes(),
|
||||
})
|
||||
@@ -68,10 +68,10 @@ func TestMuxAgent(t *testing.T) {
|
||||
if aClosed {
|
||||
return
|
||||
}
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
conn, muxedConn := connect(a, muxedA)
|
||||
conn, muxedConn := connect(agent, muxedA)
|
||||
|
||||
pair := muxedA.getSelectedPair()
|
||||
require.NotNil(t, pair)
|
||||
|
@@ -13,13 +13,13 @@ const (
|
||||
receiveMTU = 8192
|
||||
defaultLocalPreference = 65535
|
||||
|
||||
// ComponentRTP indicates that the candidate is used for RTP
|
||||
// ComponentRTP indicates that the candidate is used for RTP.
|
||||
ComponentRTP uint16 = 1
|
||||
// ComponentRTCP indicates that the candidate is used for RTCP
|
||||
// ComponentRTCP indicates that the candidate is used for RTCP.
|
||||
ComponentRTCP
|
||||
)
|
||||
|
||||
// Candidate represents an ICE candidate
|
||||
// Candidate represents an ICE candidate.
|
||||
type Candidate interface {
|
||||
// An arbitrary string used in the freezing algorithm to
|
||||
// group similar candidates. It is the same for two candidates that
|
||||
|
@@ -48,12 +48,12 @@ type candidateBase struct {
|
||||
extensions []CandidateExtension
|
||||
}
|
||||
|
||||
// Done implements context.Context
|
||||
// Done implements context.Context.
|
||||
func (c *candidateBase) Done() <-chan struct{} {
|
||||
return c.closeCh
|
||||
}
|
||||
|
||||
// Err implements context.Context
|
||||
// Err implements context.Context.
|
||||
func (c *candidateBase) Err() error {
|
||||
select {
|
||||
case <-c.closedCh:
|
||||
@@ -63,17 +63,17 @@ func (c *candidateBase) Err() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Deadline implements context.Context
|
||||
// Deadline implements context.Context.
|
||||
func (c *candidateBase) Deadline() (deadline time.Time, ok bool) {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
// Value implements context.Context
|
||||
// Value implements context.Context.
|
||||
func (c *candidateBase) Value(interface{}) interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ID returns Candidate ID
|
||||
// ID returns Candidate ID.
|
||||
func (c *candidateBase) ID() string {
|
||||
return c.id
|
||||
}
|
||||
@@ -86,27 +86,27 @@ func (c *candidateBase) Foundation() string {
|
||||
return fmt.Sprintf("%d", crc32.ChecksumIEEE([]byte(c.Type().String()+c.address+c.networkType.String())))
|
||||
}
|
||||
|
||||
// Address returns Candidate Address
|
||||
// Address returns Candidate Address.
|
||||
func (c *candidateBase) Address() string {
|
||||
return c.address
|
||||
}
|
||||
|
||||
// Port returns Candidate Port
|
||||
// Port returns Candidate Port.
|
||||
func (c *candidateBase) Port() int {
|
||||
return c.port
|
||||
}
|
||||
|
||||
// Type returns candidate type
|
||||
// Type returns candidate type.
|
||||
func (c *candidateBase) Type() CandidateType {
|
||||
return c.candidateType
|
||||
}
|
||||
|
||||
// NetworkType returns candidate NetworkType
|
||||
// NetworkType returns candidate NetworkType.
|
||||
func (c *candidateBase) NetworkType() NetworkType {
|
||||
return c.networkType
|
||||
}
|
||||
|
||||
// Component returns candidate component
|
||||
// Component returns candidate component.
|
||||
func (c *candidateBase) Component() uint16 {
|
||||
return c.component
|
||||
}
|
||||
@@ -115,8 +115,8 @@ func (c *candidateBase) SetComponent(component uint16) {
|
||||
c.component = component
|
||||
}
|
||||
|
||||
// LocalPreference returns the local preference for this candidate
|
||||
func (c *candidateBase) LocalPreference() uint16 {
|
||||
// LocalPreference returns the local preference for this candidate.
|
||||
func (c *candidateBase) LocalPreference() uint16 { //nolint:cyclop
|
||||
if c.NetworkType().IsTCP() {
|
||||
// RFC 6544, section 4.2
|
||||
//
|
||||
@@ -182,6 +182,7 @@ func (c *candidateBase) LocalPreference() uint16 {
|
||||
case CandidateTypeUnspecified:
|
||||
return 0
|
||||
}
|
||||
|
||||
return 0
|
||||
}()
|
||||
|
||||
@@ -191,7 +192,7 @@ func (c *candidateBase) LocalPreference() uint16 {
|
||||
return defaultLocalPreference
|
||||
}
|
||||
|
||||
// RelatedAddress returns *CandidateRelatedAddress
|
||||
// RelatedAddress returns *CandidateRelatedAddress.
|
||||
func (c *candidateBase) RelatedAddress() *CandidateRelatedAddress {
|
||||
return c.relatedAddress
|
||||
}
|
||||
@@ -200,10 +201,11 @@ func (c *candidateBase) TCPType() TCPType {
|
||||
return c.tcpType
|
||||
}
|
||||
|
||||
// start runs the candidate using the provided connection
|
||||
// start runs the candidate using the provided connection.
|
||||
func (c *candidateBase) start(a *Agent, conn net.PacketConn, initializedCh <-chan struct{}) {
|
||||
if c.conn != nil {
|
||||
c.agent().log.Warn("Can't start already started candidateBase")
|
||||
|
||||
return
|
||||
}
|
||||
c.currAgent = a
|
||||
@@ -221,7 +223,7 @@ var bufferPool = sync.Pool{ // nolint:gochecknoglobals
|
||||
}
|
||||
|
||||
func (c *candidateBase) recvLoop(initializedCh <-chan struct{}) {
|
||||
a := c.agent()
|
||||
agent := c.agent()
|
||||
|
||||
defer close(c.closedCh)
|
||||
|
||||
@@ -242,8 +244,9 @@ func (c *candidateBase) recvLoop(initializedCh <-chan struct{}) {
|
||||
n, srcAddr, err := c.conn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
if !(errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed)) {
|
||||
a.log.Warnf("Failed to read from candidate %s: %v", c, err)
|
||||
agent.log.Warnf("Failed to read from candidate %s: %v", c, err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -254,8 +257,10 @@ func (c *candidateBase) recvLoop(initializedCh <-chan struct{}) {
|
||||
func (c *candidateBase) validateSTUNTrafficCache(addr net.Addr) bool {
|
||||
if candidate, ok := c.remoteCandidateCaches[toAddrPort(addr)]; ok {
|
||||
candidate.seen(false)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -267,48 +272,51 @@ func (c *candidateBase) addRemoteCandidateCache(candidate Candidate, srcAddr net
|
||||
}
|
||||
|
||||
func (c *candidateBase) handleInboundPacket(buf []byte, srcAddr net.Addr) {
|
||||
a := c.agent()
|
||||
agent := c.agent()
|
||||
|
||||
if stun.IsMessage(buf) {
|
||||
m := &stun.Message{
|
||||
msg := &stun.Message{
|
||||
Raw: make([]byte, len(buf)),
|
||||
}
|
||||
|
||||
// Explicitly copy raw buffer so Message can own the memory.
|
||||
copy(m.Raw, buf)
|
||||
copy(msg.Raw, buf)
|
||||
|
||||
if err := msg.Decode(); err != nil {
|
||||
agent.log.Warnf("Failed to handle decode ICE from %s to %s: %v", c.addr(), srcAddr, err)
|
||||
|
||||
if err := m.Decode(); err != nil {
|
||||
a.log.Warnf("Failed to handle decode ICE from %s to %s: %v", c.addr(), srcAddr, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := a.loop.Run(c, func(_ context.Context) {
|
||||
if err := agent.loop.Run(c, func(_ context.Context) {
|
||||
// nolint: contextcheck
|
||||
a.handleInbound(m, c, srcAddr)
|
||||
agent.handleInbound(msg, c, srcAddr)
|
||||
}); err != nil {
|
||||
a.log.Warnf("Failed to handle message: %v", err)
|
||||
agent.log.Warnf("Failed to handle message: %v", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !c.validateSTUNTrafficCache(srcAddr) {
|
||||
remoteCandidate, valid := a.validateNonSTUNTraffic(c, srcAddr) //nolint:contextcheck
|
||||
remoteCandidate, valid := agent.validateNonSTUNTraffic(c, srcAddr) //nolint:contextcheck
|
||||
if !valid {
|
||||
a.log.Warnf("Discarded message from %s, not a valid remote candidate", c.addr())
|
||||
agent.log.Warnf("Discarded message from %s, not a valid remote candidate", c.addr())
|
||||
|
||||
return
|
||||
}
|
||||
c.addRemoteCandidateCache(remoteCandidate, srcAddr)
|
||||
}
|
||||
|
||||
// Note: This will return packetio.ErrFull if the buffer ever manages to fill up.
|
||||
if _, err := a.buf.Write(buf); err != nil {
|
||||
a.log.Warnf("Failed to write packet: %s", err)
|
||||
if _, err := agent.buf.Write(buf); err != nil {
|
||||
agent.log.Warnf("Failed to write packet: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// close stops the recvLoop
|
||||
// close stops the recvLoop.
|
||||
func (c *candidateBase) close() error {
|
||||
// If conn has never been started will be nil
|
||||
if c.Done() == nil {
|
||||
@@ -353,13 +361,15 @@ func (c *candidateBase) writeTo(raw []byte, dst Candidate) (int, error) {
|
||||
return n, err
|
||||
}
|
||||
c.agent().log.Infof("Failed to send packet: %v", err)
|
||||
|
||||
return n, nil
|
||||
}
|
||||
c.seen(true)
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// TypePreference returns the type preference for this candidate
|
||||
// TypePreference returns the type preference for this candidate.
|
||||
func (c *candidateBase) TypePreference() uint16 {
|
||||
pref := c.Type().Preference()
|
||||
if pref == 0 {
|
||||
@@ -397,7 +407,7 @@ func (c *candidateBase) Priority() uint32 {
|
||||
(1<<0)*uint32(256-c.Component())
|
||||
}
|
||||
|
||||
// Equal is used to compare two candidateBases
|
||||
// Equal is used to compare two candidateBases.
|
||||
func (c *candidateBase) Equal(other Candidate) bool {
|
||||
if c.addr() != other.addr() {
|
||||
if c.addr() == nil || other.addr() == nil {
|
||||
@@ -416,22 +426,30 @@ func (c *candidateBase) Equal(other Candidate) bool {
|
||||
c.RelatedAddress().Equal(other.RelatedAddress())
|
||||
}
|
||||
|
||||
// DeepEqual is same as Equal but also compares the extensions
|
||||
// DeepEqual is same as Equal but also compares the extensions.
|
||||
func (c *candidateBase) DeepEqual(other Candidate) bool {
|
||||
return c.Equal(other) && c.extensionsEqual(other.Extensions())
|
||||
}
|
||||
|
||||
// String makes the candidateBase printable
|
||||
// String makes the candidateBase printable.
|
||||
func (c *candidateBase) String() string {
|
||||
return fmt.Sprintf("%s %s %s%s (resolved: %v)", c.NetworkType(), c.Type(), net.JoinHostPort(c.Address(), strconv.Itoa(c.Port())), c.relatedAddress, c.resolvedAddr)
|
||||
return fmt.Sprintf(
|
||||
"%s %s %s%s (resolved: %v)",
|
||||
c.NetworkType(),
|
||||
c.Type(),
|
||||
net.JoinHostPort(c.Address(), strconv.Itoa(c.Port())),
|
||||
c.relatedAddress,
|
||||
c.resolvedAddr,
|
||||
)
|
||||
}
|
||||
|
||||
// LastReceived returns a time.Time indicating the last time
|
||||
// this candidate was received
|
||||
// this candidate was received.
|
||||
func (c *candidateBase) LastReceived() time.Time {
|
||||
if lastReceived, ok := c.lastReceived.Load().(time.Time); ok {
|
||||
return lastReceived
|
||||
}
|
||||
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
@@ -440,11 +458,12 @@ func (c *candidateBase) setLastReceived(t time.Time) {
|
||||
}
|
||||
|
||||
// LastSent returns a time.Time indicating the last time
|
||||
// this candidate was sent
|
||||
// this candidate was sent.
|
||||
func (c *candidateBase) LastSent() time.Time {
|
||||
if lastSent, ok := c.lastSent.Load().(time.Time); ok {
|
||||
return lastSent
|
||||
}
|
||||
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
@@ -484,10 +503,11 @@ func removeZoneIDFromAddress(addr string) string {
|
||||
if i := strings.Index(addr, "%"); i != -1 {
|
||||
return addr[:i]
|
||||
}
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
// Marshal returns the string representation of the ICECandidate
|
||||
// Marshal returns the string representation of the ICECandidate.
|
||||
func (c *candidateBase) Marshal() string {
|
||||
val := c.Foundation()
|
||||
if val == " " {
|
||||
@@ -618,9 +638,7 @@ func (c *candidateBase) setExtensions(extensions []CandidateExtension) {
|
||||
|
||||
// UnmarshalCandidate Parses a candidate from a string
|
||||
// https://datatracker.ietf.org/doc/html/rfc5245#section-15.1
|
||||
func UnmarshalCandidate(raw string) (Candidate, error) {
|
||||
// rfc5245
|
||||
|
||||
func UnmarshalCandidate(raw string) (Candidate, error) { //nolint:cyclop
|
||||
pos := 0
|
||||
|
||||
// foundation ( 1*32ice-char ) But we allow for empty foundation,
|
||||
@@ -805,15 +823,16 @@ func UnmarshalCandidate(raw string) (Candidate, error) {
|
||||
|
||||
// Read an ice-char token from the raw string
|
||||
// ice-char = ALPHA / DIGIT / "+" / "/"
|
||||
// stop reading when a space is encountered or the end of the string
|
||||
func readCandidateCharToken(raw string, start int, limit int) (string, int, error) {
|
||||
// stop reading when a space is encountered or the end of the string.
|
||||
func readCandidateCharToken(raw string, start int, limit int) (string, int, error) { //nolint:cyclop
|
||||
for i, char := range raw[start:] {
|
||||
if char == 0x20 { // SP
|
||||
return raw[start : start+i], start + i + 1, nil
|
||||
}
|
||||
|
||||
if i == limit {
|
||||
return "", 0, fmt.Errorf("token too long: %s expected 1x%d", raw[start:start+i], limit) //nolint: err113 // handled by caller
|
||||
//nolint: err113 // handled by caller
|
||||
return "", 0, fmt.Errorf("token too long: %s expected 1x%d", raw[start:start+i], limit)
|
||||
}
|
||||
|
||||
if !(char >= 'A' && char <= 'Z' ||
|
||||
@@ -828,7 +847,7 @@ func readCandidateCharToken(raw string, start int, limit int) (string, int, erro
|
||||
}
|
||||
|
||||
// Read an ice string token from the raw string until a space is encountered
|
||||
// Or the end of the string, we imply that ice string are UTF-8 encoded
|
||||
// Or the end of the string, we imply that ice string are UTF-8 encoded.
|
||||
func readCandidateStringToken(raw string, start int) (string, int) {
|
||||
for i, char := range raw[start:] {
|
||||
if char == 0x20 { // SP
|
||||
@@ -840,7 +859,7 @@ func readCandidateStringToken(raw string, start int) (string, int) {
|
||||
}
|
||||
|
||||
// Read a digit token from the raw string
|
||||
// stop reading when a space is encountered or the end of the string
|
||||
// stop reading when a space is encountered or the end of the string.
|
||||
func readCandidateDigitToken(raw string, start, limit int) (int, int, error) {
|
||||
var val int
|
||||
for i, char := range raw[start:] {
|
||||
@@ -849,7 +868,8 @@ func readCandidateDigitToken(raw string, start, limit int) (int, int, error) {
|
||||
}
|
||||
|
||||
if i == limit {
|
||||
return 0, 0, fmt.Errorf("token too long: %s expected 1x%d", raw[start:start+i], limit) //nolint: err113 // handled by caller
|
||||
//nolint: err113 // handled by caller
|
||||
return 0, 0, fmt.Errorf("token too long: %s expected 1x%d", raw[start:start+i], limit)
|
||||
}
|
||||
|
||||
if !(char >= '0' && char <= '9') {
|
||||
@@ -862,7 +882,7 @@ func readCandidateDigitToken(raw string, start, limit int) (int, int, error) {
|
||||
return val, len(raw), nil
|
||||
}
|
||||
|
||||
// Read and validate RFC 4566 port from the raw string
|
||||
// Read and validate RFC 4566 port from the raw string.
|
||||
func readCandidatePort(raw string, start int) (int, int, error) {
|
||||
port, pos, err := readCandidateDigitToken(raw, start, 5)
|
||||
if err != nil {
|
||||
@@ -878,7 +898,7 @@ func readCandidatePort(raw string, start int) (int, int, error) {
|
||||
|
||||
// Read a byte-string token from the raw string
|
||||
// As defined in RFC 4566 1*(%x01-09/%x0B-0C/%x0E-FF) ;any byte except NUL, CR, or LF
|
||||
// we imply that extensions byte-string are UTF-8 encoded
|
||||
// we imply that extensions byte-string are UTF-8 encoded.
|
||||
func readCandidateByteString(raw string, start int) (string, int, error) {
|
||||
for i, char := range raw[start:] {
|
||||
if char == 0x20 { // SP
|
||||
@@ -952,17 +972,23 @@ func unmarshalCandidateExtensions(raw string) (extensions []CandidateExtension,
|
||||
for i := 0; i < len(raw); {
|
||||
key, next, err := readCandidateByteString(raw, i)
|
||||
if err != nil {
|
||||
return extensions, "", fmt.Errorf("%w: failed to read key %v", errParseExtension, err) //nolint: errorlint // we wrap the error
|
||||
return extensions, "", fmt.Errorf(
|
||||
"%w: failed to read key %v", errParseExtension, err, //nolint: errorlint // we wrap the error
|
||||
)
|
||||
}
|
||||
i = next
|
||||
|
||||
if i >= len(raw) {
|
||||
return extensions, "", fmt.Errorf("%w: missing value for %s in %s", errParseExtension, key, raw)
|
||||
return extensions, "", fmt.Errorf(
|
||||
"%w: missing value for %s in %s", errParseExtension, key, raw, //nolint: errorlint // we are wrapping the error
|
||||
)
|
||||
}
|
||||
|
||||
value, next, err := readCandidateByteString(raw, i)
|
||||
if err != nil {
|
||||
return extensions, "", fmt.Errorf("%w: failed to read value %v", errParseExtension, err) //nolint: errorlint // we are wrapping the error
|
||||
return extensions, "", fmt.Errorf(
|
||||
"%w: failed to read value %v", errParseExtension, err, //nolint: errorlint // we are wrapping the error
|
||||
)
|
||||
}
|
||||
i = next
|
||||
|
||||
@@ -973,5 +999,5 @@ func unmarshalCandidateExtensions(raw string) (extensions []CandidateExtension,
|
||||
extensions = append(extensions, CandidateExtension{key, value})
|
||||
}
|
||||
|
||||
return
|
||||
return extensions, rawTCPTypeRaw, nil
|
||||
}
|
||||
|
@@ -8,14 +8,14 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CandidateHost is a candidate of type host
|
||||
// CandidateHost is a candidate of type host.
|
||||
type CandidateHost struct {
|
||||
candidateBase
|
||||
|
||||
network string
|
||||
}
|
||||
|
||||
// CandidateHostConfig is the config required to create a new CandidateHost
|
||||
// CandidateHostConfig is the config required to create a new CandidateHost.
|
||||
type CandidateHostConfig struct {
|
||||
CandidateID string
|
||||
Network string
|
||||
@@ -28,7 +28,7 @@ type CandidateHostConfig struct {
|
||||
IsLocationTracked bool
|
||||
}
|
||||
|
||||
// NewCandidateHost creates a new host candidate
|
||||
// NewCandidateHost creates a new host candidate.
|
||||
func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) {
|
||||
candidateID := config.CandidateID
|
||||
|
||||
@@ -36,7 +36,7 @@ func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) {
|
||||
candidateID = globalCandidateIDGenerator.Generate()
|
||||
}
|
||||
|
||||
c := &CandidateHost{
|
||||
candidateHost := &CandidateHost{
|
||||
candidateBase: candidateBase{
|
||||
id: candidateID,
|
||||
address: config.Address,
|
||||
@@ -58,15 +58,15 @@ func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.setIPAddr(ipAddr); err != nil {
|
||||
if err := candidateHost.setIPAddr(ipAddr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// Until mDNS candidate is resolved assume it is UDPv4
|
||||
c.candidateBase.networkType = NetworkTypeUDP4
|
||||
candidateHost.candidateBase.networkType = NetworkTypeUDP4
|
||||
}
|
||||
|
||||
return c, nil
|
||||
return candidateHost, nil
|
||||
}
|
||||
|
||||
func (c *CandidateHost) setIPAddr(addr netip.Addr) error {
|
||||
|
@@ -15,7 +15,7 @@ type CandidatePeerReflexive struct {
|
||||
candidateBase
|
||||
}
|
||||
|
||||
// CandidatePeerReflexiveConfig is the config required to create a new CandidatePeerReflexive
|
||||
// CandidatePeerReflexiveConfig is the config required to create a new CandidatePeerReflexive.
|
||||
type CandidatePeerReflexiveConfig struct {
|
||||
CandidateID string
|
||||
Network string
|
||||
@@ -28,7 +28,7 @@ type CandidatePeerReflexiveConfig struct {
|
||||
RelPort int
|
||||
}
|
||||
|
||||
// NewCandidatePeerReflexive creates a new peer reflective candidate
|
||||
// NewCandidatePeerReflexive creates a new peer reflective candidate.
|
||||
func NewCandidatePeerReflexive(config *CandidatePeerReflexiveConfig) (*CandidatePeerReflexive, error) {
|
||||
ipAddr, err := netip.ParseAddr(config.Address)
|
||||
if err != nil {
|
||||
|
@@ -16,7 +16,7 @@ type CandidateRelay struct {
|
||||
onClose func() error
|
||||
}
|
||||
|
||||
// CandidateRelayConfig is the config required to create a new CandidateRelay
|
||||
// CandidateRelayConfig is the config required to create a new CandidateRelay.
|
||||
type CandidateRelayConfig struct {
|
||||
CandidateID string
|
||||
Network string
|
||||
@@ -31,7 +31,7 @@ type CandidateRelayConfig struct {
|
||||
OnClose func() error
|
||||
}
|
||||
|
||||
// NewCandidateRelay creates a new relay candidate
|
||||
// NewCandidateRelay creates a new relay candidate.
|
||||
func NewCandidateRelay(config *CandidateRelayConfig) (*CandidateRelay, error) {
|
||||
candidateID := config.CandidateID
|
||||
|
||||
@@ -75,7 +75,7 @@ func NewCandidateRelay(config *CandidateRelayConfig) (*CandidateRelay, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LocalPreference returns the local preference for this candidate
|
||||
// LocalPreference returns the local preference for this candidate.
|
||||
func (c *CandidateRelay) LocalPreference() uint16 {
|
||||
// These preference values come from libwebrtc
|
||||
// https://github.com/mozilla/libwebrtc/blob/1389c76d9c79839a2ca069df1db48aa3f2e6a1ac/p2p/base/turn_port.cc#L61
|
||||
@@ -103,6 +103,7 @@ func (c *CandidateRelay) close() error {
|
||||
err = c.onClose()
|
||||
c.onClose = nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
@@ -13,7 +13,7 @@ type CandidateServerReflexive struct {
|
||||
candidateBase
|
||||
}
|
||||
|
||||
// CandidateServerReflexiveConfig is the config required to create a new CandidateServerReflexive
|
||||
// CandidateServerReflexiveConfig is the config required to create a new CandidateServerReflexive.
|
||||
type CandidateServerReflexiveConfig struct {
|
||||
CandidateID string
|
||||
Network string
|
||||
@@ -26,7 +26,7 @@ type CandidateServerReflexiveConfig struct {
|
||||
RelPort int
|
||||
}
|
||||
|
||||
// NewCandidateServerReflexive creates a new server reflective candidate
|
||||
// NewCandidateServerReflexive creates a new server reflective candidate.
|
||||
func NewCandidateServerReflexive(config *CandidateServerReflexiveConfig) (*CandidateServerReflexive, error) {
|
||||
ipAddr, err := netip.ParseAddr(config.Address)
|
||||
if err != nil {
|
||||
|
@@ -16,7 +16,7 @@ import (
|
||||
const localhostIPStr = "127.0.0.1"
|
||||
|
||||
func TestCandidateTypePreference(t *testing.T) {
|
||||
r := require.New(t)
|
||||
req := require.New(t)
|
||||
|
||||
hostDefaultPreference := uint16(126)
|
||||
prflxDefaultPreference := uint16(110)
|
||||
@@ -53,16 +53,16 @@ func TestCandidateTypePreference(t *testing.T) {
|
||||
}
|
||||
|
||||
if networkType.IsTCP() {
|
||||
r.Equal(hostDefaultPreference-tcpOffset, hostCandidate.TypePreference())
|
||||
r.Equal(prflxDefaultPreference-tcpOffset, prflxCandidate.TypePreference())
|
||||
r.Equal(srflxDefaultPreference-tcpOffset, srflxCandidate.TypePreference())
|
||||
req.Equal(hostDefaultPreference-tcpOffset, hostCandidate.TypePreference())
|
||||
req.Equal(prflxDefaultPreference-tcpOffset, prflxCandidate.TypePreference())
|
||||
req.Equal(srflxDefaultPreference-tcpOffset, srflxCandidate.TypePreference())
|
||||
} else {
|
||||
r.Equal(hostDefaultPreference, hostCandidate.TypePreference())
|
||||
r.Equal(prflxDefaultPreference, prflxCandidate.TypePreference())
|
||||
r.Equal(srflxDefaultPreference, srflxCandidate.TypePreference())
|
||||
req.Equal(hostDefaultPreference, hostCandidate.TypePreference())
|
||||
req.Equal(prflxDefaultPreference, prflxCandidate.TypePreference())
|
||||
req.Equal(srflxDefaultPreference, srflxCandidate.TypePreference())
|
||||
}
|
||||
|
||||
r.Equal(relayDefaultPreference, relayCandidate.TypePreference())
|
||||
req.Equal(relayDefaultPreference, relayCandidate.TypePreference())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -266,20 +266,27 @@ func TestCandidateFoundation(t *testing.T) {
|
||||
}).Foundation())
|
||||
}
|
||||
|
||||
func mustCandidateHost(conf *CandidateHostConfig) Candidate {
|
||||
cand, err := NewCandidateHost(conf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return cand
|
||||
}
|
||||
|
||||
func mustCandidateHostWithExtensions(t *testing.T, conf *CandidateHostConfig, extensions []CandidateExtension) Candidate {
|
||||
func mustCandidateHost(t *testing.T, conf *CandidateHostConfig) Candidate {
|
||||
t.Helper()
|
||||
|
||||
cand, err := NewCandidateHost(conf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return cand
|
||||
}
|
||||
|
||||
func mustCandidateHostWithExtensions(
|
||||
t *testing.T,
|
||||
conf *CandidateHostConfig,
|
||||
extensions []CandidateExtension,
|
||||
) Candidate {
|
||||
t.Helper()
|
||||
|
||||
cand, err := NewCandidateHost(conf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cand.setExtensions(extensions)
|
||||
@@ -287,20 +294,27 @@ func mustCandidateHostWithExtensions(t *testing.T, conf *CandidateHostConfig, ex
|
||||
return cand
|
||||
}
|
||||
|
||||
func mustCandidateRelay(conf *CandidateRelayConfig) Candidate {
|
||||
cand, err := NewCandidateRelay(conf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return cand
|
||||
}
|
||||
|
||||
func mustCandidateRelayWithExtensions(t *testing.T, conf *CandidateRelayConfig, extensions []CandidateExtension) Candidate {
|
||||
func mustCandidateRelay(t *testing.T, conf *CandidateRelayConfig) Candidate {
|
||||
t.Helper()
|
||||
|
||||
cand, err := NewCandidateRelay(conf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return cand
|
||||
}
|
||||
|
||||
func mustCandidateRelayWithExtensions(
|
||||
t *testing.T,
|
||||
conf *CandidateRelayConfig,
|
||||
extensions []CandidateExtension,
|
||||
) Candidate {
|
||||
t.Helper()
|
||||
|
||||
cand, err := NewCandidateRelay(conf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cand.setExtensions(extensions)
|
||||
@@ -308,20 +322,27 @@ func mustCandidateRelayWithExtensions(t *testing.T, conf *CandidateRelayConfig,
|
||||
return cand
|
||||
}
|
||||
|
||||
func mustCandidateServerReflexive(conf *CandidateServerReflexiveConfig) Candidate {
|
||||
cand, err := NewCandidateServerReflexive(conf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return cand
|
||||
}
|
||||
|
||||
func mustCandidateServerReflexiveWithExtensions(t *testing.T, conf *CandidateServerReflexiveConfig, extensions []CandidateExtension) Candidate {
|
||||
func mustCandidateServerReflexive(t *testing.T, conf *CandidateServerReflexiveConfig) Candidate {
|
||||
t.Helper()
|
||||
|
||||
cand, err := NewCandidateServerReflexive(conf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return cand
|
||||
}
|
||||
|
||||
func mustCandidateServerReflexiveWithExtensions(
|
||||
t *testing.T,
|
||||
conf *CandidateServerReflexiveConfig,
|
||||
extensions []CandidateExtension,
|
||||
) Candidate {
|
||||
t.Helper()
|
||||
|
||||
cand, err := NewCandidateServerReflexive(conf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cand.setExtensions(extensions)
|
||||
@@ -329,12 +350,16 @@ func mustCandidateServerReflexiveWithExtensions(t *testing.T, conf *CandidateSer
|
||||
return cand
|
||||
}
|
||||
|
||||
func mustCandidatePeerReflexiveWithExtensions(t *testing.T, conf *CandidatePeerReflexiveConfig, extensions []CandidateExtension) Candidate {
|
||||
func mustCandidatePeerReflexiveWithExtensions(
|
||||
t *testing.T,
|
||||
conf *CandidatePeerReflexiveConfig,
|
||||
extensions []CandidateExtension,
|
||||
) Candidate {
|
||||
t.Helper()
|
||||
|
||||
cand, err := NewCandidatePeerReflexive(conf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cand.setExtensions(extensions)
|
||||
@@ -349,7 +374,7 @@ func TestCandidateMarshal(t *testing.T) {
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
mustCandidateHost(&CandidateHostConfig{
|
||||
mustCandidateHost(t, &CandidateHostConfig{
|
||||
Network: NetworkTypeUDP6.String(),
|
||||
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
|
||||
Port: 53987,
|
||||
@@ -360,7 +385,7 @@ func TestCandidateMarshal(t *testing.T) {
|
||||
false,
|
||||
},
|
||||
{
|
||||
mustCandidateHost(&CandidateHostConfig{
|
||||
mustCandidateHost(t, &CandidateHostConfig{
|
||||
Network: NetworkTypeUDP4.String(),
|
||||
Address: "10.0.75.1",
|
||||
Port: 53634,
|
||||
@@ -369,7 +394,7 @@ func TestCandidateMarshal(t *testing.T) {
|
||||
false,
|
||||
},
|
||||
{
|
||||
mustCandidateServerReflexive(&CandidateServerReflexiveConfig{
|
||||
mustCandidateServerReflexive(t, &CandidateServerReflexiveConfig{
|
||||
Network: NetworkTypeUDP4.String(),
|
||||
Address: "191.228.238.68",
|
||||
Port: 53991,
|
||||
@@ -395,11 +420,12 @@ func TestCandidateMarshal(t *testing.T) {
|
||||
{"network-cost", "10"},
|
||||
},
|
||||
),
|
||||
//nolint: lll
|
||||
"4207374052 1 tcp 1685790463 192.0.2.15 50000 typ prflx raddr 10.0.0.1 rport 12345 generation 0 network-id 2 network-cost 10",
|
||||
false,
|
||||
},
|
||||
{
|
||||
mustCandidateRelay(&CandidateRelayConfig{
|
||||
mustCandidateRelay(t, &CandidateRelayConfig{
|
||||
Network: NetworkTypeUDP4.String(),
|
||||
Address: "50.0.0.1",
|
||||
Port: 5000,
|
||||
@@ -410,7 +436,7 @@ func TestCandidateMarshal(t *testing.T) {
|
||||
false,
|
||||
},
|
||||
{
|
||||
mustCandidateHost(&CandidateHostConfig{
|
||||
mustCandidateHost(t, &CandidateHostConfig{
|
||||
Network: NetworkTypeTCP4.String(),
|
||||
Address: "192.168.0.196",
|
||||
Port: 0,
|
||||
@@ -420,7 +446,7 @@ func TestCandidateMarshal(t *testing.T) {
|
||||
false,
|
||||
},
|
||||
{
|
||||
mustCandidateHost(&CandidateHostConfig{
|
||||
mustCandidateHost(t, &CandidateHostConfig{
|
||||
Network: NetworkTypeUDP4.String(),
|
||||
Address: "e2494022-4d9a-4c1e-a750-cc48d4f8d6ee.local",
|
||||
Port: 60542,
|
||||
@@ -429,7 +455,7 @@ func TestCandidateMarshal(t *testing.T) {
|
||||
},
|
||||
// Missing Foundation
|
||||
{
|
||||
mustCandidateHost(&CandidateHostConfig{
|
||||
mustCandidateHost(t, &CandidateHostConfig{
|
||||
Network: NetworkTypeUDP4.String(),
|
||||
Address: localhostIPStr,
|
||||
Port: 80,
|
||||
@@ -440,7 +466,7 @@ func TestCandidateMarshal(t *testing.T) {
|
||||
false,
|
||||
},
|
||||
{
|
||||
mustCandidateHost(&CandidateHostConfig{
|
||||
mustCandidateHost(t, &CandidateHostConfig{
|
||||
Network: NetworkTypeUDP4.String(),
|
||||
Address: localhostIPStr,
|
||||
Port: 80,
|
||||
@@ -451,7 +477,7 @@ func TestCandidateMarshal(t *testing.T) {
|
||||
false,
|
||||
},
|
||||
{
|
||||
mustCandidateHost(&CandidateHostConfig{
|
||||
mustCandidateHost(t, &CandidateHostConfig{
|
||||
Network: NetworkTypeTCP4.String(),
|
||||
Address: "172.28.142.173",
|
||||
Port: 7686,
|
||||
@@ -467,8 +493,10 @@ func TestCandidateMarshal(t *testing.T) {
|
||||
{nil, "1938809241", true},
|
||||
{nil, "1986380506 99999999 udp 2122063615 10.0.75.1 53634 typ host generation 0 network-id 2", true},
|
||||
{nil, "1986380506 1 udp 99999999999 10.0.75.1 53634 typ host", true},
|
||||
//nolint: lll
|
||||
{nil, "4207374051 1 udp 1685790463 191.228.238.68 99999999 typ srflx raddr 192.168.0.278 rport 53991 generation 0 network-id 3", true},
|
||||
{nil, "4207374051 1 udp 1685790463 191.228.238.68 53991 typ srflx raddr", true},
|
||||
//nolint: lll
|
||||
{nil, "4207374051 1 udp 1685790463 191.228.238.68 53991 typ srflx raddr 192.168.0.278 rport 99999999 generation 0 network-id 3", true},
|
||||
{nil, "4207374051 INVALID udp 2130706431 10.0.75.1 53634 typ host", true},
|
||||
{nil, "4207374051 1 udp INVALID 10.0.75.1 53634 typ host", true},
|
||||
@@ -521,12 +549,19 @@ func TestCandidateMarshal(t *testing.T) {
|
||||
actualCandidate, err := UnmarshalCandidate(test.marshaled)
|
||||
if test.expectError {
|
||||
require.Error(t, err, "expected error", test.marshaled)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Truef(t, test.candidate.Equal(actualCandidate), "%s != %s", test.candidate.String(), actualCandidate.String())
|
||||
require.Truef(
|
||||
t,
|
||||
test.candidate.Equal(actualCandidate),
|
||||
"%s != %s",
|
||||
test.candidate.String(),
|
||||
actualCandidate.String(),
|
||||
)
|
||||
require.Equal(t, test.marshaled, actualCandidate.Marshal())
|
||||
})
|
||||
}
|
||||
@@ -573,7 +608,7 @@ func TestCandidateWriteTo(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMarshalUnmarshalCandidateWithZoneID(t *testing.T) {
|
||||
candidateWithZoneID := mustCandidateHost(&CandidateHostConfig{
|
||||
candidateWithZoneID := mustCandidateHost(t, &CandidateHostConfig{
|
||||
Network: NetworkTypeUDP6.String(),
|
||||
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a%Local Connection",
|
||||
Port: 53987,
|
||||
@@ -583,7 +618,7 @@ func TestMarshalUnmarshalCandidateWithZoneID(t *testing.T) {
|
||||
candidateStr := "750 0 udp 500 fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a 53987 typ host"
|
||||
require.Equal(t, candidateStr, candidateWithZoneID.Marshal())
|
||||
|
||||
candidate := mustCandidateHost(&CandidateHostConfig{
|
||||
candidate := mustCandidateHost(t, &CandidateHostConfig{
|
||||
Network: NetworkTypeUDP6.String(),
|
||||
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
|
||||
Port: 53987,
|
||||
@@ -612,6 +647,7 @@ func TestCandidateExtensionsMarshal(t *testing.T) {
|
||||
{"ufrag", "QNvE"},
|
||||
{"network-id", "4"},
|
||||
},
|
||||
//nolint: lll
|
||||
"1299692247 1 udp 2122134271 fdc8:cc8:c835:e400:343c:feb:32c8:17b9 58240 typ host generation 0 ufrag QNvE network-id 4",
|
||||
},
|
||||
{
|
||||
@@ -620,6 +656,7 @@ func TestCandidateExtensionsMarshal(t *testing.T) {
|
||||
{"network-id", "2"},
|
||||
{"network-cost", "50"},
|
||||
},
|
||||
//nolint:lll
|
||||
"647372371 1 udp 1694498815 191.228.238.68 53991 typ srflx raddr 192.168.0.274 rport 53991 generation 1 network-id 2 network-cost 50",
|
||||
},
|
||||
{
|
||||
@@ -628,6 +665,7 @@ func TestCandidateExtensionsMarshal(t *testing.T) {
|
||||
{"network-id", "2"},
|
||||
{"network-cost", "10"},
|
||||
},
|
||||
//nolint:lll
|
||||
"4207374052 1 tcp 1685790463 192.0.2.15 50000 typ prflx raddr 10.0.0.1 rport 12345 generation 0 network-id 2 network-cost 10",
|
||||
},
|
||||
{
|
||||
@@ -638,6 +676,7 @@ func TestCandidateExtensionsMarshal(t *testing.T) {
|
||||
{"ufrag", "frag42abcdef"},
|
||||
{"password", "abc123exp123"},
|
||||
},
|
||||
//nolint: lll
|
||||
"848194626 1 udp 16777215 50.0.0.1 5000 typ relay raddr 192.168.0.1 rport 5001 generation 0 network-id 1 network-cost 20 ufrag frag42abcdef password abc123exp123",
|
||||
},
|
||||
{
|
||||
@@ -703,7 +742,7 @@ func TestCandidateExtensionsDeepEqual(t *testing.T) {
|
||||
equal bool
|
||||
}{
|
||||
{
|
||||
mustCandidateHost(&CandidateHostConfig{
|
||||
mustCandidateHost(t, &CandidateHostConfig{
|
||||
Network: NetworkTypeUDP4.String(),
|
||||
Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a",
|
||||
Port: 53987,
|
||||
|
@@ -20,8 +20,7 @@ func newCandidatePair(local, remote Candidate, controlling bool) *CandidatePair
|
||||
}
|
||||
}
|
||||
|
||||
// CandidatePair is a combination of a
|
||||
// local and remote candidate
|
||||
// CandidatePair is a combination of a local and remote candidate.
|
||||
type CandidatePair struct {
|
||||
iceRoleControlling bool
|
||||
Remote Candidate
|
||||
@@ -42,8 +41,17 @@ func (p *CandidatePair) String() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf("prio %d (local, prio %d) %s <-> %s (remote, prio %d), state: %s, nominated: %v, nominateOnBindingSuccess: %v",
|
||||
p.priority(), p.Local.Priority(), p.Local, p.Remote, p.Remote.Priority(), p.state, p.nominated, p.nominateOnBindingSuccess)
|
||||
return fmt.Sprintf(
|
||||
"prio %d (local, prio %d) %s <-> %s (remote, prio %d), state: %s, nominated: %v, nominateOnBindingSuccess: %v",
|
||||
p.priority(),
|
||||
p.Local.Priority(),
|
||||
p.Local,
|
||||
p.Remote,
|
||||
p.Remote.Priority(),
|
||||
p.state,
|
||||
p.nominated,
|
||||
p.nominateOnBindingSuccess,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *CandidatePair) equal(other *CandidatePair) bool {
|
||||
@@ -53,6 +61,7 @@ func (p *CandidatePair) equal(other *CandidatePair) bool {
|
||||
if p == nil || other == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return p.Local.Equal(other.Local) && p.Remote.Equal(other.Remote)
|
||||
}
|
||||
|
||||
@@ -60,9 +69,9 @@ func (p *CandidatePair) equal(other *CandidatePair) bool {
|
||||
// Let G be the priority for the candidate provided by the controlling
|
||||
// agent. Let D be the priority for the candidate provided by the
|
||||
// controlled agent.
|
||||
// pair priority = 2^32*MIN(G,D) + 2*MAX(G,D) + (G>D?1:0)
|
||||
// pair priority = 2^32*MIN(G,D) + 2*MAX(G,D) + (G>D?1:0).
|
||||
func (p *CandidatePair) priority() uint64 {
|
||||
var g, d uint32
|
||||
var g, d uint32 //nolint:varnamelen // clearer to use g and d here
|
||||
if p.iceRoleControlling {
|
||||
g = p.Local.Priority()
|
||||
d = p.Remote.Priority()
|
||||
@@ -77,18 +86,21 @@ func (p *CandidatePair) priority() uint64 {
|
||||
if x < y {
|
||||
return uint64(x)
|
||||
}
|
||||
|
||||
return uint64(y)
|
||||
}
|
||||
localMax := func(x, y uint32) uint64 {
|
||||
if x > y {
|
||||
return uint64(x)
|
||||
}
|
||||
|
||||
return uint64(y)
|
||||
}
|
||||
cmp := func(x, y uint32) uint64 {
|
||||
if x > y {
|
||||
return uint64(1)
|
||||
}
|
||||
|
||||
return uint64(0)
|
||||
}
|
||||
|
||||
@@ -109,7 +121,7 @@ func (a *Agent) sendSTUN(msg *stun.Message, local, remote Candidate) {
|
||||
}
|
||||
|
||||
// UpdateRoundTripTime sets the current round time of this pair and
|
||||
// accumulates total round trip time and responses received
|
||||
// accumulates total round trip time and responses received.
|
||||
func (p *CandidatePair) UpdateRoundTripTime(rtt time.Duration) {
|
||||
rttNs := rtt.Nanoseconds()
|
||||
atomic.StoreInt64(&p.currentRoundTripTime, rttNs)
|
||||
|
@@ -3,12 +3,12 @@
|
||||
|
||||
package ice
|
||||
|
||||
// CandidatePairState represent the ICE candidate pair state
|
||||
// CandidatePairState represent the ICE candidate pair state.
|
||||
type CandidatePairState int
|
||||
|
||||
const (
|
||||
// CandidatePairStateWaiting means a check has not been performed for
|
||||
// this pair
|
||||
// this pair.
|
||||
CandidatePairStateWaiting CandidatePairState = iota + 1
|
||||
|
||||
// CandidatePairStateInProgress means a check has been sent for this pair,
|
||||
@@ -36,5 +36,6 @@ func (c CandidatePairState) String() string {
|
||||
case CandidatePairStateSucceeded:
|
||||
return "succeeded"
|
||||
}
|
||||
|
||||
return "Unknown candidate pair state"
|
||||
}
|
||||
|
@@ -12,7 +12,7 @@ type CandidateRelatedAddress struct {
|
||||
Port int
|
||||
}
|
||||
|
||||
// String makes CandidateRelatedAddress printable
|
||||
// String makes CandidateRelatedAddress printable.
|
||||
func (c *CandidateRelatedAddress) String() string {
|
||||
if c == nil {
|
||||
return ""
|
||||
@@ -27,6 +27,7 @@ func (c *CandidateRelatedAddress) Equal(other *CandidateRelatedAddress) bool {
|
||||
if c == nil && other == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
return c != nil && other != nil &&
|
||||
c.Address == other.Address &&
|
||||
c.Port == other.Port
|
||||
|
@@ -3,10 +3,10 @@
|
||||
|
||||
package ice
|
||||
|
||||
// CandidateType represents the type of candidate
|
||||
// CandidateType represents the type of candidate.
|
||||
type CandidateType byte
|
||||
|
||||
// CandidateType enum
|
||||
// CandidateType enum.
|
||||
const (
|
||||
CandidateTypeUnspecified CandidateType = iota
|
||||
CandidateTypeHost
|
||||
@@ -15,7 +15,7 @@ const (
|
||||
CandidateTypeRelay
|
||||
)
|
||||
|
||||
// String makes CandidateType printable
|
||||
// String makes CandidateType printable.
|
||||
func (c CandidateType) String() string {
|
||||
switch c {
|
||||
case CandidateTypeHost:
|
||||
@@ -29,6 +29,7 @@ func (c CandidateType) String() string {
|
||||
case CandidateTypeUnspecified:
|
||||
return "Unknown candidate type"
|
||||
}
|
||||
|
||||
return "Unknown candidate type"
|
||||
}
|
||||
|
||||
@@ -49,6 +50,7 @@ func (c CandidateType) Preference() uint16 {
|
||||
case CandidateTypeRelay, CandidateTypeUnspecified:
|
||||
return 0
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -61,5 +63,6 @@ func containsCandidateType(candidateType CandidateType, candidateTypeList []Cand
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
@@ -45,7 +45,7 @@ func (v *virtualNet) close() {
|
||||
v.wan.Stop() //nolint:errcheck,gosec
|
||||
}
|
||||
|
||||
func buildVNet(natType0, natType1 *vnet.NATType) (*virtualNet, error) {
|
||||
func buildVNet(natType0, natType1 *vnet.NATType) (*virtualNet, error) { //nolint:cyclop
|
||||
loggerFactory := logging.NewDefaultLoggerFactory()
|
||||
|
||||
// WAN
|
||||
@@ -77,6 +77,7 @@ func buildVNet(natType0, natType1 *vnet.NATType) (*virtualNet, error) {
|
||||
vnetGlobalIPA + "/" + vnetLocalIPA,
|
||||
}
|
||||
}
|
||||
|
||||
return []string{
|
||||
vnetGlobalIPA,
|
||||
}
|
||||
@@ -114,6 +115,7 @@ func buildVNet(natType0, natType1 *vnet.NATType) (*virtualNet, error) {
|
||||
vnetGlobalIPB + "/" + vnetLocalIPB,
|
||||
}
|
||||
}
|
||||
|
||||
return []string{
|
||||
vnetGlobalIPB,
|
||||
}
|
||||
@@ -175,6 +177,7 @@ func addVNetSTUN(wanNet *vnet.Net, loggerFactory logging.LoggerFactory) (*turn.S
|
||||
if pw, ok := credMap[username]; ok {
|
||||
return turn.GenerateAuthKey(username, realm, pw), true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
},
|
||||
PacketConnConfigs: []turn.PacketConnConfig{
|
||||
@@ -222,6 +225,7 @@ func connectWithVNet(aAgent, bAgent *Agent) (*Conn, *Conn) {
|
||||
|
||||
// Ensure accepted
|
||||
<-accepted
|
||||
|
||||
return aConn, bConn
|
||||
}
|
||||
|
||||
@@ -230,7 +234,7 @@ type agentTestConfig struct {
|
||||
nat1To1IPCandidateType CandidateType
|
||||
}
|
||||
|
||||
func pipeWithVNet(v *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig) (*Conn, *Conn) {
|
||||
func pipeWithVNet(vnet *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig) (*Conn, *Conn) {
|
||||
aNotifier, aConnected := onConnected()
|
||||
bNotifier, bConnected := onConnected()
|
||||
|
||||
@@ -247,7 +251,7 @@ func pipeWithVNet(v *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig) (*
|
||||
MulticastDNSMode: MulticastDNSModeDisabled,
|
||||
NAT1To1IPs: nat1To1IPs,
|
||||
NAT1To1IPCandidateType: a0TestConfig.nat1To1IPCandidateType,
|
||||
Net: v.net0,
|
||||
Net: vnet.net0,
|
||||
}
|
||||
|
||||
aAgent, err := NewAgent(cfg0)
|
||||
@@ -270,7 +274,7 @@ func pipeWithVNet(v *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig) (*
|
||||
MulticastDNSMode: MulticastDNSModeDisabled,
|
||||
NAT1To1IPs: nat1To1IPs,
|
||||
NAT1To1IPCandidateType: a1TestConfig.nat1To1IPCandidateType,
|
||||
Net: v.net1,
|
||||
Net: vnet.net1,
|
||||
}
|
||||
|
||||
bAgent, err := NewAgent(cfg1)
|
||||
@@ -293,6 +297,8 @@ func pipeWithVNet(v *virtualNet, a0TestConfig, a1TestConfig *agentTestConfig) (*
|
||||
}
|
||||
|
||||
func closePipe(t *testing.T, ca *Conn, cb *Conn) {
|
||||
t.Helper()
|
||||
|
||||
require.NoError(t, ca.Close())
|
||||
require.NoError(t, cb.Close())
|
||||
}
|
||||
@@ -325,10 +331,10 @@ func TestConnectivityVNet(t *testing.T) {
|
||||
MappingBehavior: vnet.EndpointIndependent,
|
||||
FilteringBehavior: vnet.EndpointIndependent,
|
||||
}
|
||||
v, err := buildVNet(natType, natType)
|
||||
vnet, err := buildVNet(natType, natType)
|
||||
|
||||
require.NoError(t, err, "should succeed")
|
||||
defer v.close()
|
||||
defer vnet.close()
|
||||
|
||||
log.Debug("Connecting...")
|
||||
a0TestConfig := &agentTestConfig{
|
||||
@@ -341,7 +347,7 @@ func TestConnectivityVNet(t *testing.T) {
|
||||
stunServerURL,
|
||||
},
|
||||
}
|
||||
ca, cb := pipeWithVNet(v, a0TestConfig, a1TestConfig)
|
||||
ca, cb := pipeWithVNet(vnet, a0TestConfig, a1TestConfig)
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
@@ -358,10 +364,10 @@ func TestConnectivityVNet(t *testing.T) {
|
||||
MappingBehavior: vnet.EndpointAddrPortDependent,
|
||||
FilteringBehavior: vnet.EndpointAddrPortDependent,
|
||||
}
|
||||
v, err := buildVNet(natType, natType)
|
||||
vnet, err := buildVNet(natType, natType)
|
||||
|
||||
require.NoError(t, err, "should succeed")
|
||||
defer v.close()
|
||||
defer vnet.close()
|
||||
|
||||
log.Debug("Connecting...")
|
||||
a0TestConfig := &agentTestConfig{
|
||||
@@ -375,7 +381,7 @@ func TestConnectivityVNet(t *testing.T) {
|
||||
stunServerURL,
|
||||
},
|
||||
}
|
||||
ca, cb := pipeWithVNet(v, a0TestConfig, a1TestConfig)
|
||||
ca, cb := pipeWithVNet(vnet, a0TestConfig, a1TestConfig)
|
||||
|
||||
log.Debug("Closing...")
|
||||
closePipe(t, ca, cb)
|
||||
@@ -394,10 +400,10 @@ func TestConnectivityVNet(t *testing.T) {
|
||||
MappingBehavior: vnet.EndpointAddrPortDependent,
|
||||
FilteringBehavior: vnet.EndpointAddrPortDependent,
|
||||
}
|
||||
v, err := buildVNet(natType0, natType1)
|
||||
vnet, err := buildVNet(natType0, natType1)
|
||||
|
||||
require.NoError(t, err, "should succeed")
|
||||
defer v.close()
|
||||
defer vnet.close()
|
||||
|
||||
log.Debug("Connecting...")
|
||||
a0TestConfig := &agentTestConfig{
|
||||
@@ -407,7 +413,7 @@ func TestConnectivityVNet(t *testing.T) {
|
||||
a1TestConfig := &agentTestConfig{
|
||||
urls: []*stun.URI{},
|
||||
}
|
||||
ca, cb := pipeWithVNet(v, a0TestConfig, a1TestConfig)
|
||||
ca, cb := pipeWithVNet(vnet, a0TestConfig, a1TestConfig)
|
||||
|
||||
log.Debug("Closing...")
|
||||
closePipe(t, ca, cb)
|
||||
@@ -426,10 +432,10 @@ func TestConnectivityVNet(t *testing.T) {
|
||||
MappingBehavior: vnet.EndpointAddrPortDependent,
|
||||
FilteringBehavior: vnet.EndpointAddrPortDependent,
|
||||
}
|
||||
v, err := buildVNet(natType0, natType1)
|
||||
vnet, err := buildVNet(natType0, natType1)
|
||||
|
||||
require.NoError(t, err, "should succeed")
|
||||
defer v.close()
|
||||
defer vnet.close()
|
||||
|
||||
log.Debug("Connecting...")
|
||||
a0TestConfig := &agentTestConfig{
|
||||
@@ -439,14 +445,15 @@ func TestConnectivityVNet(t *testing.T) {
|
||||
a1TestConfig := &agentTestConfig{
|
||||
urls: []*stun.URI{},
|
||||
}
|
||||
ca, cb := pipeWithVNet(v, a0TestConfig, a1TestConfig)
|
||||
ca, cb := pipeWithVNet(vnet, a0TestConfig, a1TestConfig)
|
||||
|
||||
log.Debug("Closing...")
|
||||
closePipe(t, ca, cb)
|
||||
})
|
||||
}
|
||||
|
||||
// TestDisconnectedToConnected requires that an agent can go to disconnected, and then return to connected successfully
|
||||
// TestDisconnectedToConnected requires that an agent can go to disconnected,
|
||||
// and then return to connected successfully.
|
||||
func TestDisconnectedToConnected(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
@@ -546,7 +553,7 @@ func TestDisconnectedToConnected(t *testing.T) {
|
||||
require.NoError(t, wan.Stop())
|
||||
}
|
||||
|
||||
// Agent.Write should use the best valid pair if a selected pair is not yet available
|
||||
// Agent.Write should use the best valid pair if a selected pair is not yet available.
|
||||
func TestWriteUseValidPair(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
|
52
errors.go
52
errors.go
@@ -29,69 +29,71 @@ var (
|
||||
ErrPort = errors.New("invalid port")
|
||||
|
||||
// ErrLocalUfragInsufficientBits indicates local username fragment insufficient bits are provided.
|
||||
// Have to be at least 24 bits long
|
||||
// Have to be at least 24 bits long.
|
||||
ErrLocalUfragInsufficientBits = errors.New("local username fragment is less than 24 bits long")
|
||||
|
||||
// ErrLocalPwdInsufficientBits indicates local password insufficient bits are provided.
|
||||
// Have to be at least 128 bits long
|
||||
// Have to be at least 128 bits long.
|
||||
ErrLocalPwdInsufficientBits = errors.New("local password is less than 128 bits long")
|
||||
|
||||
// ErrProtoType indicates an unsupported transport type was provided.
|
||||
ErrProtoType = errors.New("invalid transport protocol type")
|
||||
|
||||
// ErrClosed indicates the agent is closed
|
||||
// ErrClosed indicates the agent is closed.
|
||||
ErrClosed = taskloop.ErrClosed
|
||||
|
||||
// ErrNoCandidatePairs indicates agent does not have a valid candidate pair
|
||||
// ErrNoCandidatePairs indicates agent does not have a valid candidate pair.
|
||||
ErrNoCandidatePairs = errors.New("no candidate pairs available")
|
||||
|
||||
// ErrCanceledByCaller indicates agent connection was canceled by the caller
|
||||
// ErrCanceledByCaller indicates agent connection was canceled by the caller.
|
||||
ErrCanceledByCaller = errors.New("connecting canceled by caller")
|
||||
|
||||
// ErrMultipleStart indicates agent was started twice
|
||||
// ErrMultipleStart indicates agent was started twice.
|
||||
ErrMultipleStart = errors.New("attempted to start agent twice")
|
||||
|
||||
// ErrRemoteUfragEmpty indicates agent was started with an empty remote ufrag
|
||||
// ErrRemoteUfragEmpty indicates agent was started with an empty remote ufrag.
|
||||
ErrRemoteUfragEmpty = errors.New("remote ufrag is empty")
|
||||
|
||||
// ErrRemotePwdEmpty indicates agent was started with an empty remote pwd
|
||||
// ErrRemotePwdEmpty indicates agent was started with an empty remote pwd.
|
||||
ErrRemotePwdEmpty = errors.New("remote pwd is empty")
|
||||
|
||||
// ErrNoOnCandidateHandler indicates agent was started without OnCandidate
|
||||
// ErrNoOnCandidateHandler indicates agent was started without OnCandidate.
|
||||
ErrNoOnCandidateHandler = errors.New("no OnCandidate provided")
|
||||
|
||||
// ErrMultipleGatherAttempted indicates GatherCandidates has been called multiple times
|
||||
// ErrMultipleGatherAttempted indicates GatherCandidates has been called multiple times.
|
||||
ErrMultipleGatherAttempted = errors.New("attempting to gather candidates during gathering state")
|
||||
|
||||
// ErrUsernameEmpty indicates agent was give TURN URL with an empty Username
|
||||
// ErrUsernameEmpty indicates agent was give TURN URL with an empty Username.
|
||||
ErrUsernameEmpty = errors.New("username is empty")
|
||||
|
||||
// ErrPasswordEmpty indicates agent was give TURN URL with an empty Password
|
||||
// ErrPasswordEmpty indicates agent was give TURN URL with an empty Password.
|
||||
ErrPasswordEmpty = errors.New("password is empty")
|
||||
|
||||
// ErrAddressParseFailed indicates we were unable to parse a candidate address
|
||||
// ErrAddressParseFailed indicates we were unable to parse a candidate address.
|
||||
ErrAddressParseFailed = errors.New("failed to parse address")
|
||||
|
||||
// ErrLiteUsingNonHostCandidates indicates non host candidates were selected for a lite agent
|
||||
// ErrLiteUsingNonHostCandidates indicates non host candidates were selected for a lite agent.
|
||||
ErrLiteUsingNonHostCandidates = errors.New("lite agents must only use host candidates")
|
||||
|
||||
// ErrUselessUrlsProvided indicates that one or more URL was provided to the agent but no host
|
||||
// candidate required them
|
||||
// candidate required them.
|
||||
ErrUselessUrlsProvided = errors.New("agent does not need URL with selected candidate types")
|
||||
|
||||
// ErrUnsupportedNAT1To1IPCandidateType indicates that the specified NAT1To1IPCandidateType is
|
||||
// unsupported
|
||||
// unsupported.
|
||||
ErrUnsupportedNAT1To1IPCandidateType = errors.New("unsupported 1:1 NAT IP candidate type")
|
||||
|
||||
// ErrInvalidNAT1To1IPMapping indicates that the given 1:1 NAT IP mapping is invalid
|
||||
// ErrInvalidNAT1To1IPMapping indicates that the given 1:1 NAT IP mapping is invalid.
|
||||
ErrInvalidNAT1To1IPMapping = errors.New("invalid 1:1 NAT IP mapping")
|
||||
|
||||
// ErrExternalMappedIPNotFound in NAT1To1IPMapping
|
||||
// ErrExternalMappedIPNotFound in NAT1To1IPMapping.
|
||||
ErrExternalMappedIPNotFound = errors.New("external mapped IP not found")
|
||||
|
||||
// ErrMulticastDNSWithNAT1To1IPMapping indicates that the mDNS gathering cannot be used along
|
||||
// with 1:1 NAT IP mapping for host candidate.
|
||||
ErrMulticastDNSWithNAT1To1IPMapping = errors.New("mDNS gathering cannot be used with 1:1 NAT IP mapping for host candidate")
|
||||
ErrMulticastDNSWithNAT1To1IPMapping = errors.New(
|
||||
"mDNS gathering cannot be used with 1:1 NAT IP mapping for host candidate",
|
||||
)
|
||||
|
||||
// ErrIneffectiveNAT1To1IPMappingHost indicates that 1:1 NAT IP mapping for host candidate is
|
||||
// requested, but the host candidate type is disabled.
|
||||
@@ -101,10 +103,12 @@ var (
|
||||
// requested, but the srflx candidate type is disabled.
|
||||
ErrIneffectiveNAT1To1IPMappingSrflx = errors.New("1:1 NAT IP mapping for srflx candidate ineffective")
|
||||
|
||||
// ErrInvalidMulticastDNSHostName indicates an invalid MulticastDNSHostName
|
||||
ErrInvalidMulticastDNSHostName = errors.New("invalid mDNS HostName, must end with .local and can only contain a single '.'")
|
||||
// ErrInvalidMulticastDNSHostName indicates an invalid MulticastDNSHostName.
|
||||
ErrInvalidMulticastDNSHostName = errors.New(
|
||||
"invalid mDNS HostName, must end with .local and can only contain a single '.'",
|
||||
)
|
||||
|
||||
// ErrRunCanceled indicates a run operation was canceled by its individual done
|
||||
// ErrRunCanceled indicates a run operation was canceled by its individual done.
|
||||
ErrRunCanceled = errors.New("run was canceled by done")
|
||||
|
||||
// ErrTCPRemoteAddrAlreadyExists indicates we already have the connection with same remote addr.
|
||||
@@ -113,7 +117,7 @@ var (
|
||||
// ErrUnknownCandidateTyp indicates that a candidate had a unknown type value.
|
||||
ErrUnknownCandidateTyp = errors.New("unknown candidate typ")
|
||||
|
||||
// ErrDetermineNetworkType indicates that the NetworkType was not able to be parsed
|
||||
// ErrDetermineNetworkType indicates that the NetworkType was not able to be parsed.
|
||||
ErrDetermineNetworkType = errors.New("unable to determine networkType")
|
||||
|
||||
errAttributeTooShortICECandidate = errors.New("attribute not long enough to be ICE candidate")
|
||||
@@ -144,5 +148,5 @@ var (
|
||||
|
||||
// UDPMuxDefault should not listen on unspecified address, but to keep backward compatibility, don't return error now.
|
||||
// will be used in the future.
|
||||
// errListenUnspecified = errors.New("can't listen on unspecified address")
|
||||
// errListenUnspecified = errors.New("can't listen on unspecified address").
|
||||
)
|
||||
|
@@ -26,7 +26,7 @@ var (
|
||||
localHTTPPort, remoteHTTPPort int
|
||||
)
|
||||
|
||||
// HTTP Listener to get ICE Credentials from remote Peer
|
||||
// HTTP Listener to get ICE Credentials from remote Peer.
|
||||
func remoteAuth(_ http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
panic(err)
|
||||
@@ -36,7 +36,7 @@ func remoteAuth(_ http.ResponseWriter, r *http.Request) {
|
||||
remoteAuthChannel <- r.PostForm["pwd"][0]
|
||||
}
|
||||
|
||||
// HTTP Listener to get ICE Candidate from remote Peer
|
||||
// HTTP Listener to get ICE Candidate from remote Peer.
|
||||
func remoteCandidate(_ http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
panic(err)
|
||||
|
@@ -13,10 +13,13 @@ func validateIPString(ipStr string) (net.IP, bool, error) {
|
||||
if ip == nil {
|
||||
return nil, false, ErrInvalidNAT1To1IPMapping
|
||||
}
|
||||
|
||||
return ip, (ip.To4() != nil), nil
|
||||
}
|
||||
|
||||
// ipMapping holds the mapping of local and external IP address for a particular IP family
|
||||
// ipMapping holds the mapping of local and external IP address
|
||||
//
|
||||
// for a particular IP family.
|
||||
type ipMapping struct {
|
||||
ipSole net.IP // When non-nil, this is the sole external IP for one local IP assumed
|
||||
ipMap map[string]net.IP // Local-to-external IP mapping (k: local, v: external)
|
||||
@@ -75,7 +78,11 @@ type externalIPMapper struct {
|
||||
candidateType CandidateType
|
||||
}
|
||||
|
||||
func newExternalIPMapper(candidateType CandidateType, ips []string) (*externalIPMapper, error) { //nolint:gocognit
|
||||
//nolint:gocognit,cyclop
|
||||
func newExternalIPMapper(
|
||||
candidateType CandidateType,
|
||||
ips []string,
|
||||
) (*externalIPMapper, error) {
|
||||
if len(ips) == 0 {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
@@ -85,7 +92,7 @@ func newExternalIPMapper(candidateType CandidateType, ips []string) (*externalIP
|
||||
return nil, ErrUnsupportedNAT1To1IPCandidateType
|
||||
}
|
||||
|
||||
m := &externalIPMapper{
|
||||
mapper := &externalIPMapper{
|
||||
ipv4Mapping: ipMapping{ipMap: map[string]net.IP{}},
|
||||
ipv6Mapping: ipMapping{ipMap: map[string]net.IP{}},
|
||||
candidateType: candidateType,
|
||||
@@ -101,13 +108,13 @@ func newExternalIPMapper(candidateType CandidateType, ips []string) (*externalIP
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(ipPair) == 1 {
|
||||
if len(ipPair) == 1 { //nolint:nestif
|
||||
if isExtIPv4 {
|
||||
if err := m.ipv4Mapping.setSoleIP(extIP); err != nil {
|
||||
if err := mapper.ipv4Mapping.setSoleIP(extIP); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := m.ipv6Mapping.setSoleIP(extIP); err != nil {
|
||||
if err := mapper.ipv6Mapping.setSoleIP(extIP); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -121,7 +128,7 @@ func newExternalIPMapper(candidateType CandidateType, ips []string) (*externalIP
|
||||
return nil, ErrInvalidNAT1To1IPMapping
|
||||
}
|
||||
|
||||
if err := m.ipv4Mapping.addIPMapping(locIP, extIP); err != nil {
|
||||
if err := mapper.ipv4Mapping.addIPMapping(locIP, extIP); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
@@ -129,14 +136,14 @@ func newExternalIPMapper(candidateType CandidateType, ips []string) (*externalIP
|
||||
return nil, ErrInvalidNAT1To1IPMapping
|
||||
}
|
||||
|
||||
if err := m.ipv6Mapping.addIPMapping(locIP, extIP); err != nil {
|
||||
if err := mapper.ipv6Mapping.addIPMapping(locIP, extIP); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
return mapper, nil
|
||||
}
|
||||
|
||||
func (m *externalIPMapper) findExternalIP(localIPStr string) (net.IP, error) {
|
||||
|
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExternalIPMapper(t *testing.T) {
|
||||
func TestExternalIPMapper(t *testing.T) { //nolint:maintidx
|
||||
t.Run("validateIPString", func(t *testing.T) {
|
||||
var ip net.IP
|
||||
var isIPv4 bool
|
||||
@@ -31,165 +31,165 @@ func TestExternalIPMapper(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("newExternalIPMapper", func(t *testing.T) {
|
||||
var m *externalIPMapper
|
||||
var mapper *externalIPMapper
|
||||
var err error
|
||||
|
||||
// ips being nil should succeed but mapper will be nil also
|
||||
m, err = newExternalIPMapper(CandidateTypeUnspecified, nil)
|
||||
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, nil)
|
||||
require.NoError(t, err, "should succeed")
|
||||
require.Nil(t, m, "should be nil")
|
||||
require.Nil(t, mapper, "should be nil")
|
||||
|
||||
// ips being empty should succeed but mapper will still be nil
|
||||
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{})
|
||||
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{})
|
||||
require.NoError(t, err, "should succeed")
|
||||
require.Nil(t, m, "should be nil")
|
||||
require.Nil(t, mapper, "should be nil")
|
||||
|
||||
// IPv4 with no explicit local IP, defaults to CandidateTypeHost
|
||||
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
"1.2.3.4",
|
||||
})
|
||||
require.NoError(t, err, "should succeed")
|
||||
require.NotNil(t, m, "should not be nil")
|
||||
require.Equal(t, CandidateTypeHost, m.candidateType, "should match")
|
||||
require.NotNil(t, m.ipv4Mapping.ipSole)
|
||||
require.Nil(t, m.ipv6Mapping.ipSole)
|
||||
require.Equal(t, 0, len(m.ipv4Mapping.ipMap), "should match")
|
||||
require.Equal(t, 0, len(m.ipv6Mapping.ipMap), "should match")
|
||||
require.NotNil(t, mapper, "should not be nil")
|
||||
require.Equal(t, CandidateTypeHost, mapper.candidateType, "should match")
|
||||
require.NotNil(t, mapper.ipv4Mapping.ipSole)
|
||||
require.Nil(t, mapper.ipv6Mapping.ipSole)
|
||||
require.Equal(t, 0, len(mapper.ipv4Mapping.ipMap), "should match")
|
||||
require.Equal(t, 0, len(mapper.ipv6Mapping.ipMap), "should match")
|
||||
|
||||
// IPv4 with no explicit local IP, using CandidateTypeServerReflexive
|
||||
m, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
|
||||
"1.2.3.4",
|
||||
})
|
||||
require.NoError(t, err, "should succeed")
|
||||
require.NotNil(t, m, "should not be nil")
|
||||
require.Equal(t, CandidateTypeServerReflexive, m.candidateType, "should match")
|
||||
require.NotNil(t, m.ipv4Mapping.ipSole)
|
||||
require.Nil(t, m.ipv6Mapping.ipSole)
|
||||
require.Equal(t, 0, len(m.ipv4Mapping.ipMap), "should match")
|
||||
require.Equal(t, 0, len(m.ipv6Mapping.ipMap), "should match")
|
||||
require.NotNil(t, mapper, "should not be nil")
|
||||
require.Equal(t, CandidateTypeServerReflexive, mapper.candidateType, "should match")
|
||||
require.NotNil(t, mapper.ipv4Mapping.ipSole)
|
||||
require.Nil(t, mapper.ipv6Mapping.ipSole)
|
||||
require.Equal(t, 0, len(mapper.ipv4Mapping.ipMap), "should match")
|
||||
require.Equal(t, 0, len(mapper.ipv6Mapping.ipMap), "should match")
|
||||
|
||||
// IPv4 with no explicit local IP, defaults to CandidateTypeHost
|
||||
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
"2601:4567::5678",
|
||||
})
|
||||
require.NoError(t, err, "should succeed")
|
||||
require.NotNil(t, m, "should not be nil")
|
||||
require.Equal(t, CandidateTypeHost, m.candidateType, "should match")
|
||||
require.Nil(t, m.ipv4Mapping.ipSole)
|
||||
require.NotNil(t, m.ipv6Mapping.ipSole)
|
||||
require.Equal(t, 0, len(m.ipv4Mapping.ipMap), "should match")
|
||||
require.Equal(t, 0, len(m.ipv6Mapping.ipMap), "should match")
|
||||
require.NotNil(t, mapper, "should not be nil")
|
||||
require.Equal(t, CandidateTypeHost, mapper.candidateType, "should match")
|
||||
require.Nil(t, mapper.ipv4Mapping.ipSole)
|
||||
require.NotNil(t, mapper.ipv6Mapping.ipSole)
|
||||
require.Equal(t, 0, len(mapper.ipv4Mapping.ipMap), "should match")
|
||||
require.Equal(t, 0, len(mapper.ipv6Mapping.ipMap), "should match")
|
||||
|
||||
// IPv4 and IPv6 in the mix
|
||||
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
"1.2.3.4",
|
||||
"2601:4567::5678",
|
||||
})
|
||||
require.NoError(t, err, "should succeed")
|
||||
require.NotNil(t, m, "should not be nil")
|
||||
require.Equal(t, CandidateTypeHost, m.candidateType, "should match")
|
||||
require.NotNil(t, m.ipv4Mapping.ipSole)
|
||||
require.NotNil(t, m.ipv6Mapping.ipSole)
|
||||
require.Equal(t, 0, len(m.ipv4Mapping.ipMap), "should match")
|
||||
require.Equal(t, 0, len(m.ipv6Mapping.ipMap), "should match")
|
||||
require.NotNil(t, mapper, "should not be nil")
|
||||
require.Equal(t, CandidateTypeHost, mapper.candidateType, "should match")
|
||||
require.NotNil(t, mapper.ipv4Mapping.ipSole)
|
||||
require.NotNil(t, mapper.ipv6Mapping.ipSole)
|
||||
require.Equal(t, 0, len(mapper.ipv4Mapping.ipMap), "should match")
|
||||
require.Equal(t, 0, len(mapper.ipv6Mapping.ipMap), "should match")
|
||||
|
||||
// Unsupported candidate type - CandidateTypePeerReflexive
|
||||
m, err = newExternalIPMapper(CandidateTypePeerReflexive, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypePeerReflexive, []string{
|
||||
"1.2.3.4",
|
||||
})
|
||||
require.Error(t, err, "should fail")
|
||||
require.Nil(t, m, "should be nil")
|
||||
require.Nil(t, mapper, "should be nil")
|
||||
|
||||
// Unsupported candidate type - CandidateTypeRelay
|
||||
m, err = newExternalIPMapper(CandidateTypePeerReflexive, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypePeerReflexive, []string{
|
||||
"1.2.3.4",
|
||||
})
|
||||
require.Error(t, err, "should fail")
|
||||
require.Nil(t, m, "should be nil")
|
||||
require.Nil(t, mapper, "should be nil")
|
||||
|
||||
// Cannot duplicate mapping IPv4 family
|
||||
m, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
|
||||
"1.2.3.4",
|
||||
"5.6.7.8",
|
||||
})
|
||||
require.Error(t, err, "should fail")
|
||||
require.Nil(t, m, "should be nil")
|
||||
require.Nil(t, mapper, "should be nil")
|
||||
|
||||
// Cannot duplicate mapping IPv6 family
|
||||
m, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
|
||||
"2201::1",
|
||||
"2201::0002",
|
||||
})
|
||||
require.Error(t, err, "should fail")
|
||||
require.Nil(t, m, "should be nil")
|
||||
require.Nil(t, mapper, "should be nil")
|
||||
|
||||
// Invalide external IP string
|
||||
m, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
|
||||
"bad.2.3.4",
|
||||
})
|
||||
require.Error(t, err, "should fail")
|
||||
require.Nil(t, m, "should be nil")
|
||||
require.Nil(t, mapper, "should be nil")
|
||||
|
||||
// Invalide local IP string
|
||||
m, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypeServerReflexive, []string{
|
||||
"1.2.3.4/10.0.0.bad",
|
||||
})
|
||||
require.Error(t, err, "should fail")
|
||||
require.Nil(t, m, "should be nil")
|
||||
require.Nil(t, mapper, "should be nil")
|
||||
})
|
||||
|
||||
t.Run("newExternalIPMapper with explicit local IP", func(t *testing.T) {
|
||||
var m *externalIPMapper
|
||||
var mapper *externalIPMapper
|
||||
var err error
|
||||
|
||||
// IPv4 with explicit local IP, defaults to CandidateTypeHost
|
||||
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
"1.2.3.4/10.0.0.1",
|
||||
})
|
||||
require.NoError(t, err, "should succeed")
|
||||
require.NotNil(t, m, "should not be nil")
|
||||
require.Equal(t, CandidateTypeHost, m.candidateType, "should match")
|
||||
require.Nil(t, m.ipv4Mapping.ipSole)
|
||||
require.Nil(t, m.ipv6Mapping.ipSole)
|
||||
require.Equal(t, 1, len(m.ipv4Mapping.ipMap), "should match")
|
||||
require.Equal(t, 0, len(m.ipv6Mapping.ipMap), "should match")
|
||||
require.NotNil(t, mapper, "should not be nil")
|
||||
require.Equal(t, CandidateTypeHost, mapper.candidateType, "should match")
|
||||
require.Nil(t, mapper.ipv4Mapping.ipSole)
|
||||
require.Nil(t, mapper.ipv6Mapping.ipSole)
|
||||
require.Equal(t, 1, len(mapper.ipv4Mapping.ipMap), "should match")
|
||||
require.Equal(t, 0, len(mapper.ipv6Mapping.ipMap), "should match")
|
||||
|
||||
// Cannot assign two ext IPs for one local IPv4
|
||||
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
"1.2.3.4/10.0.0.1",
|
||||
"1.2.3.5/10.0.0.1",
|
||||
})
|
||||
require.Error(t, err, "should fail")
|
||||
require.Nil(t, m, "should be nil")
|
||||
require.Nil(t, mapper, "should be nil")
|
||||
|
||||
// Cannot assign two ext IPs for one local IPv6
|
||||
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
"2200::1/fe80::1",
|
||||
"2200::0002/fe80::1",
|
||||
})
|
||||
require.Error(t, err, "should fail")
|
||||
require.Nil(t, m, "should be nil")
|
||||
require.Nil(t, mapper, "should be nil")
|
||||
|
||||
// Cannot mix different IP family in a pair (1)
|
||||
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
"2200::1/10.0.0.1",
|
||||
})
|
||||
require.Error(t, err, "should fail")
|
||||
require.Nil(t, m, "should be nil")
|
||||
require.Nil(t, mapper, "should be nil")
|
||||
|
||||
// Cannot mix different IP family in a pair (2)
|
||||
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
"1.2.3.4/fe80::1",
|
||||
})
|
||||
require.Error(t, err, "should fail")
|
||||
require.Nil(t, m, "should be nil")
|
||||
require.Nil(t, mapper, "should be nil")
|
||||
|
||||
// Invalid pair
|
||||
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
"1.2.3.4/192.168.0.2/10.0.0.1",
|
||||
})
|
||||
require.Error(t, err, "should fail")
|
||||
require.Nil(t, m, "should be nil")
|
||||
require.Nil(t, mapper, "should be nil")
|
||||
})
|
||||
|
||||
t.Run("newExternalIPMapper with implicit and explicit local IP", func(t *testing.T) {
|
||||
@@ -209,100 +209,100 @@ func TestExternalIPMapper(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("findExternalIP without explicit local IP", func(t *testing.T) {
|
||||
var m *externalIPMapper
|
||||
var mapper *externalIPMapper
|
||||
var err error
|
||||
var extIP net.IP
|
||||
|
||||
// IPv4 with explicit local IP, defaults to CandidateTypeHost
|
||||
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
"1.2.3.4",
|
||||
"2200::1",
|
||||
})
|
||||
require.NoError(t, err, "should succeed")
|
||||
require.NotNil(t, m, "should not be nil")
|
||||
require.NotNil(t, m.ipv4Mapping.ipSole)
|
||||
require.NotNil(t, m.ipv6Mapping.ipSole)
|
||||
require.NotNil(t, mapper, "should not be nil")
|
||||
require.NotNil(t, mapper.ipv4Mapping.ipSole)
|
||||
require.NotNil(t, mapper.ipv6Mapping.ipSole)
|
||||
|
||||
// Find external IPv4
|
||||
extIP, err = m.findExternalIP("10.0.0.1")
|
||||
extIP, err = mapper.findExternalIP("10.0.0.1")
|
||||
require.NoError(t, err, "should succeed")
|
||||
require.Equal(t, "1.2.3.4", extIP.String(), "should match")
|
||||
|
||||
// Find external IPv6
|
||||
extIP, err = m.findExternalIP("fe80::0001") // Use '0001' instead of '1' on purpose
|
||||
extIP, err = mapper.findExternalIP("fe80::0001") // Use '0001' instead of '1' on purpose
|
||||
require.NoError(t, err, "should succeed")
|
||||
require.Equal(t, "2200::1", extIP.String(), "should match")
|
||||
|
||||
// Bad local IP string
|
||||
_, err = m.findExternalIP("really.bad")
|
||||
_, err = mapper.findExternalIP("really.bad")
|
||||
require.Error(t, err, "should fail")
|
||||
})
|
||||
|
||||
t.Run("findExternalIP with explicit local IP", func(t *testing.T) {
|
||||
var m *externalIPMapper
|
||||
var mapper *externalIPMapper
|
||||
var err error
|
||||
var extIP net.IP
|
||||
|
||||
// IPv4 with explicit local IP, defaults to CandidateTypeHost
|
||||
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
"1.2.3.4/10.0.0.1",
|
||||
"1.2.3.5/10.0.0.2",
|
||||
"2200::1/fe80::1",
|
||||
"2200::2/fe80::2",
|
||||
})
|
||||
require.NoError(t, err, "should succeed")
|
||||
require.NotNil(t, m, "should not be nil")
|
||||
require.NotNil(t, mapper, "should not be nil")
|
||||
|
||||
// Find external IPv4
|
||||
extIP, err = m.findExternalIP("10.0.0.1")
|
||||
extIP, err = mapper.findExternalIP("10.0.0.1")
|
||||
require.NoError(t, err, "should succeed")
|
||||
require.Equal(t, "1.2.3.4", extIP.String(), "should match")
|
||||
|
||||
extIP, err = m.findExternalIP("10.0.0.2")
|
||||
extIP, err = mapper.findExternalIP("10.0.0.2")
|
||||
require.NoError(t, err, "should succeed")
|
||||
require.Equal(t, "1.2.3.5", extIP.String(), "should match")
|
||||
|
||||
_, err = m.findExternalIP("10.0.0.3")
|
||||
_, err = mapper.findExternalIP("10.0.0.3")
|
||||
require.Error(t, err, "should fail")
|
||||
|
||||
// Find external IPv6
|
||||
extIP, err = m.findExternalIP("fe80::0001") // Use '0001' instead of '1' on purpose
|
||||
extIP, err = mapper.findExternalIP("fe80::0001") // Use '0001' instead of '1' on purpose
|
||||
require.NoError(t, err, "should succeed")
|
||||
require.Equal(t, "2200::1", extIP.String(), "should match")
|
||||
|
||||
extIP, err = m.findExternalIP("fe80::0002") // Use '0002' instead of '2' on purpose
|
||||
extIP, err = mapper.findExternalIP("fe80::0002") // Use '0002' instead of '2' on purpose
|
||||
require.NoError(t, err, "should succeed")
|
||||
require.Equal(t, "2200::2", extIP.String(), "should match")
|
||||
|
||||
_, err = m.findExternalIP("fe80::3")
|
||||
_, err = mapper.findExternalIP("fe80::3")
|
||||
require.Error(t, err, "should fail")
|
||||
|
||||
// Bad local IP string
|
||||
_, err = m.findExternalIP("really.bad")
|
||||
_, err = mapper.findExternalIP("really.bad")
|
||||
require.Error(t, err, "should fail")
|
||||
})
|
||||
|
||||
t.Run("findExternalIP with empty map", func(t *testing.T) {
|
||||
var m *externalIPMapper
|
||||
var mapper *externalIPMapper
|
||||
var err error
|
||||
|
||||
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
"1.2.3.4",
|
||||
})
|
||||
require.NoError(t, err, "should succeed")
|
||||
|
||||
// Attempt to find IPv6 that does not exist in the map
|
||||
extIP, err := m.findExternalIP("fe80::1")
|
||||
extIP, err := mapper.findExternalIP("fe80::1")
|
||||
require.NoError(t, err, "should succeed")
|
||||
require.Equal(t, "fe80::1", extIP.String(), "should match")
|
||||
|
||||
m, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
mapper, err = newExternalIPMapper(CandidateTypeUnspecified, []string{
|
||||
"2200::1",
|
||||
})
|
||||
require.NoError(t, err, "should succeed")
|
||||
|
||||
// Attempt to find IPv4 that does not exist in the map
|
||||
extIP, err = m.findExternalIP("10.0.0.1")
|
||||
extIP, err = mapper.findExternalIP("10.0.0.1")
|
||||
require.NoError(t, err, "should succeed")
|
||||
require.Equal(t, "10.0.0.1", extIP.String(), "should match")
|
||||
})
|
||||
|
119
gather.go
119
gather.go
@@ -21,10 +21,11 @@ import (
|
||||
"github.com/pion/turn/v4"
|
||||
)
|
||||
|
||||
// Close a net.Conn and log if we have a failure
|
||||
// Close a net.Conn and log if we have a failure.
|
||||
func closeConnAndLog(c io.Closer, log logging.LeveledLogger, msg string, args ...interface{}) {
|
||||
if c == nil || (reflect.ValueOf(c).Kind() == reflect.Ptr && reflect.ValueOf(c).IsNil()) {
|
||||
log.Warnf("Connection is not allocated: "+msg, args...)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -41,9 +42,11 @@ func (a *Agent) GatherCandidates() error {
|
||||
if runErr := a.loop.Run(a.loop, func(ctx context.Context) {
|
||||
if a.gatheringState != GatheringStateNew {
|
||||
gatherErr = ErrMultipleGatherAttempted
|
||||
|
||||
return
|
||||
} else if a.onCandidateHdlr.Load() == nil {
|
||||
gatherErr = ErrNoOnCandidateHandler
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -57,13 +60,15 @@ func (a *Agent) GatherCandidates() error {
|
||||
}); runErr != nil {
|
||||
return runErr
|
||||
}
|
||||
|
||||
return gatherErr
|
||||
}
|
||||
|
||||
func (a *Agent) gatherCandidates(ctx context.Context, done chan struct{}) {
|
||||
func (a *Agent) gatherCandidates(ctx context.Context, done chan struct{}) { //nolint:cyclop
|
||||
defer close(done)
|
||||
if err := a.setGatheringState(GatheringStateGathering); err != nil { //nolint:contextcheck
|
||||
a.log.Warnf("Failed to set gatheringState to GatheringStateGathering: %v", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -111,7 +116,8 @@ func (a *Agent) gatherCandidates(ctx context.Context, done chan struct{}) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []NetworkType) { //nolint:gocognit
|
||||
//nolint:gocognit,gocyclo,cyclop
|
||||
func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []NetworkType) {
|
||||
networks := map[string]struct{}{}
|
||||
for _, networkType := range networkTypes {
|
||||
if networkType.IsTCP() {
|
||||
@@ -132,16 +138,19 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
|
||||
_, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback)
|
||||
if err != nil {
|
||||
a.log.Warnf("Failed to iterate local interfaces, host candidates will not be gathered %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for _, addr := range localAddrs {
|
||||
mappedIP := addr
|
||||
if a.mDNSMode != MulticastDNSModeQueryAndGather && a.extIPMapper != nil && a.extIPMapper.candidateType == CandidateTypeHost {
|
||||
if a.mDNSMode != MulticastDNSModeQueryAndGather &&
|
||||
a.extIPMapper != nil && a.extIPMapper.candidateType == CandidateTypeHost {
|
||||
if _mappedIP, innerErr := a.extIPMapper.findExternalIP(addr.String()); innerErr == nil {
|
||||
conv, ok := netip.AddrFromSlice(_mappedIP)
|
||||
if !ok {
|
||||
a.log.Warnf("failed to convert mapped external IP to netip.Addr'%s'", addr.String())
|
||||
|
||||
continue
|
||||
}
|
||||
// we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable
|
||||
@@ -186,6 +195,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
|
||||
muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.Is6(), addr.AsSlice())
|
||||
if err != nil {
|
||||
a.log.Warnf("Failed to get all TCP connections by ufrag: %s %s %s", network, addr, a.localUfrag)
|
||||
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
@@ -194,6 +204,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
|
||||
conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.Is6(), addr.AsSlice())
|
||||
if err != nil {
|
||||
a.log.Warnf("Failed to get TCP connections by ufrag: %s %s %s", network, addr, a.localUfrag)
|
||||
|
||||
continue
|
||||
}
|
||||
muxConns = []net.PacketConn{conn}
|
||||
@@ -222,6 +233,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
|
||||
})
|
||||
if err != nil {
|
||||
a.log.Warnf("Failed to listen %s %s", network, addr)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -229,6 +241,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
|
||||
conns = append(conns, connAndPort{conn, udpConn.Port})
|
||||
} else {
|
||||
a.log.Warnf("Failed to get port of UDPAddr from ListenUDPInPortRange: %s %s %s", network, addr, a.localUfrag)
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -245,21 +258,38 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
|
||||
IsLocationTracked: isLocationTracked,
|
||||
}
|
||||
|
||||
c, err := NewCandidateHost(&hostConfig)
|
||||
candidateHost, err := NewCandidateHost(&hostConfig)
|
||||
if err != nil {
|
||||
closeConnAndLog(connAndPort.conn, a.log, "failed to create host candidate: %s %s %d: %v", network, mappedIP, connAndPort.port, err)
|
||||
closeConnAndLog(
|
||||
connAndPort.conn,
|
||||
a.log,
|
||||
"failed to create host candidate: %s %s %d: %v",
|
||||
network, mappedIP,
|
||||
connAndPort.port,
|
||||
err,
|
||||
)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if a.mDNSMode == MulticastDNSModeQueryAndGather {
|
||||
if err = c.setIPAddr(addr); err != nil {
|
||||
closeConnAndLog(connAndPort.conn, a.log, "failed to create host candidate: %s %s %d: %v", network, mappedIP, connAndPort.port, err)
|
||||
if err = candidateHost.setIPAddr(addr); err != nil {
|
||||
closeConnAndLog(
|
||||
connAndPort.conn,
|
||||
a.log,
|
||||
"failed to create host candidate: %s %s %d: %v",
|
||||
network,
|
||||
mappedIP,
|
||||
connAndPort.port,
|
||||
err,
|
||||
)
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err := a.addCandidate(ctx, c, connAndPort.conn); err != nil {
|
||||
if closeErr := c.close(); closeErr != nil {
|
||||
if err := a.addCandidate(ctx, candidateHost, connAndPort.conn); err != nil {
|
||||
if closeErr := candidateHost.close(); closeErr != nil {
|
||||
a.log.Warnf("Failed to close candidate: %v", closeErr)
|
||||
}
|
||||
a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err)
|
||||
@@ -287,10 +317,11 @@ func shouldFilterLocationTracked(candidateIP net.IP) bool {
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return shouldFilterLocationTrackedIP(addr)
|
||||
}
|
||||
|
||||
func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolint:gocognit
|
||||
func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolint:gocognit,cyclop
|
||||
if a.udpMux == nil {
|
||||
return errUDPMuxDisabled
|
||||
}
|
||||
@@ -317,6 +348,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin
|
||||
mappedIP, err := a.extIPMapper.findExternalIP(candidateIP.String())
|
||||
if err != nil {
|
||||
a.log.Warnf("1:1 NAT mapping is enabled but no external IP is found for %s", candidateIP.String())
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -359,6 +391,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin
|
||||
c, err := NewCandidateHost(&hostConfig)
|
||||
if err != nil {
|
||||
closeConnAndLog(conn, a.log, "failed to create host mux candidate: %s %d: %v", candidateIP, udpAddr.Port, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -368,6 +401,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin
|
||||
}
|
||||
|
||||
closeConnAndLog(conn, a.log, "failed to add candidate: %s %d: %v", candidateIP, udpAddr.Port, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -391,26 +425,37 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes []
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
conn, err := listenUDPInPortRange(a.net, a.log, int(a.portMax), int(a.portMin), network, &net.UDPAddr{IP: nil, Port: 0})
|
||||
conn, err := listenUDPInPortRange(
|
||||
a.net,
|
||||
a.log,
|
||||
int(a.portMax),
|
||||
int(a.portMin),
|
||||
network,
|
||||
&net.UDPAddr{IP: nil, Port: 0},
|
||||
)
|
||||
if err != nil {
|
||||
a.log.Warnf("Failed to listen %s: %v", network, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
lAddr, ok := conn.LocalAddr().(*net.UDPAddr)
|
||||
if !ok {
|
||||
closeConnAndLog(conn, a.log, "1:1 NAT mapping is enabled but LocalAddr is not a UDPAddr")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
mappedIP, err := a.extIPMapper.findExternalIP(lAddr.IP.String())
|
||||
if err != nil {
|
||||
closeConnAndLog(conn, a.log, "1:1 NAT mapping is enabled but no external IP is found for %s", lAddr.IP.String())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if shouldFilterLocationTracked(mappedIP) {
|
||||
closeConnAndLog(conn, a.log, "external IP is somehow filtered for location tracking reasons %s", mappedIP)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -429,6 +474,7 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes []
|
||||
mappedIP.String(),
|
||||
lAddr.Port,
|
||||
err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -442,7 +488,8 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes []
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.URI, networkTypes []NetworkType) { //nolint:gocognit
|
||||
//nolint:gocognit,cyclop
|
||||
func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.URI, networkTypes []NetworkType) {
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
@@ -456,6 +503,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.UR
|
||||
udpAddr, ok := listenAddr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
a.log.Warn("Failed to cast udpMuxSrflx listen address to UDPAddr")
|
||||
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
@@ -466,23 +514,27 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.UR
|
||||
serverAddr, err := a.net.ResolveUDPAddr(network, hostPort)
|
||||
if err != nil {
|
||||
a.log.Debugf("Failed to resolve STUN host: %s %s: %v", network, hostPort, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if shouldFilterLocationTracked(serverAddr.IP) {
|
||||
a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
xorAddr, err := a.udpMuxSrflx.GetXORMappedAddr(serverAddr, a.stunGatherTimeout)
|
||||
if err != nil {
|
||||
a.log.Warnf("Failed get server reflexive address %s %s: %v", network, url, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String(), localAddr)
|
||||
if err != nil {
|
||||
a.log.Warnf("Failed to find connection in UDPMuxSrflx %s %s: %v", network, url, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -500,6 +552,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.UR
|
||||
c, err := NewCandidateServerReflexive(&srflxConfig)
|
||||
if err != nil {
|
||||
closeConnAndLog(conn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, ip, port, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -515,7 +568,8 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.UR
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, networkTypes []NetworkType) { //nolint:gocognit
|
||||
//nolint:cyclop,gocognit
|
||||
func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, networkTypes []NetworkType) {
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
@@ -533,17 +587,27 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
|
||||
serverAddr, err := a.net.ResolveUDPAddr(network, hostPort)
|
||||
if err != nil {
|
||||
a.log.Debugf("Failed to resolve STUN host: %s %s: %v", network, hostPort, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if shouldFilterLocationTracked(serverAddr.IP) {
|
||||
a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := listenUDPInPortRange(a.net, a.log, int(a.portMax), int(a.portMin), network, &net.UDPAddr{IP: nil, Port: 0})
|
||||
conn, err := listenUDPInPortRange(
|
||||
a.net,
|
||||
a.log,
|
||||
int(a.portMax),
|
||||
int(a.portMin),
|
||||
network,
|
||||
&net.UDPAddr{IP: nil, Port: 0},
|
||||
)
|
||||
if err != nil {
|
||||
closeConnAndLog(conn, a.log, "failed to listen for %s: %v", serverAddr.String(), err)
|
||||
|
||||
return
|
||||
}
|
||||
// If the agent closes midway through the connection
|
||||
@@ -562,6 +626,7 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
|
||||
xorAddr, err := stunx.GetXORMappedAddr(conn, serverAddr, a.stunGatherTimeout)
|
||||
if err != nil {
|
||||
closeConnAndLog(conn, a.log, "failed to get server reflexive address %s %s: %v", network, url, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -580,6 +645,7 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
|
||||
c, err := NewCandidateServerReflexive(&srflxConfig)
|
||||
if err != nil {
|
||||
closeConnAndLog(conn, a.log, "failed to create server reflexive candidate: %s %s %d: %v", network, ip, port, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -594,7 +660,8 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { //nolint:gocognit
|
||||
//nolint:maintidx,gocognit,gocyclo,cyclop
|
||||
func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) {
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
@@ -605,9 +672,11 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
|
||||
continue
|
||||
case urls[i].Username == "":
|
||||
a.log.Errorf("Failed to gather relay candidates: %v", ErrUsernameEmpty)
|
||||
|
||||
return
|
||||
case urls[i].Password == "":
|
||||
a.log.Errorf("Failed to gather relay candidates: %v", ErrPasswordEmpty)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -627,6 +696,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
|
||||
case url.Proto == stun.ProtoTypeUDP && url.Scheme == stun.SchemeTypeTURN:
|
||||
if locConn, err = a.net.ListenPacket(network, "0.0.0.0:0"); err != nil {
|
||||
a.log.Warnf("Failed to listen %s: %v", network, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -638,6 +708,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
|
||||
conn, connectErr := a.proxyDialer.Dial(NetworkTypeTCP4.String(), turnServerAddr)
|
||||
if connectErr != nil {
|
||||
a.log.Warnf("Failed to dial TCP address %s via proxy dialer: %v", turnServerAddr, connectErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -654,12 +725,14 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
|
||||
tcpAddr, connectErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), turnServerAddr)
|
||||
if connectErr != nil {
|
||||
a.log.Warnf("Failed to resolve TCP address %s: %v", turnServerAddr, connectErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
conn, connectErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr)
|
||||
if connectErr != nil {
|
||||
a.log.Warnf("Failed to dial TCP address %s: %v", turnServerAddr, connectErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -671,12 +744,14 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
|
||||
udpAddr, connectErr := a.net.ResolveUDPAddr(network, turnServerAddr)
|
||||
if connectErr != nil {
|
||||
a.log.Warnf("Failed to resolve UDP address %s: %v", turnServerAddr, connectErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
udpConn, dialErr := a.net.DialUDP("udp", nil, udpAddr)
|
||||
if dialErr != nil {
|
||||
a.log.Warnf("Failed to dial DTLS address %s: %v", turnServerAddr, dialErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -686,11 +761,13 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
|
||||
})
|
||||
if connectErr != nil {
|
||||
a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if connectErr = conn.HandshakeContext(ctx); connectErr != nil {
|
||||
a.log.Warnf("Failed to create DTLS client: %v", turnServerAddr, connectErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -702,12 +779,14 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
|
||||
tcpAddr, resolvErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), turnServerAddr)
|
||||
if resolvErr != nil {
|
||||
a.log.Warnf("Failed to resolve relay address %s: %v", turnServerAddr, resolvErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
tcpConn, dialErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr)
|
||||
if dialErr != nil {
|
||||
a.log.Warnf("Failed to connect to relay: %v", dialErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -721,6 +800,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
|
||||
a.log.Errorf("Failed to close relay connection: %v", closeErr)
|
||||
}
|
||||
a.log.Warnf("Failed to connect to relay: %v", hsErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -730,6 +810,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
|
||||
locConn = turn.NewSTUNConn(conn)
|
||||
default:
|
||||
a.log.Warnf("Unable to handle URL in gatherCandidatesRelay %v", url)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -743,12 +824,14 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
|
||||
})
|
||||
if err != nil {
|
||||
closeConnAndLog(locConn, a.log, "failed to create new TURN client %s %s", turnServerAddr, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err = client.Listen(); err != nil {
|
||||
client.Close()
|
||||
closeConnAndLog(locConn, a.log, "failed to listen on TURN client %s %s", turnServerAddr, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -756,6 +839,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
|
||||
if err != nil {
|
||||
client.Close()
|
||||
closeConnAndLog(locConn, a.log, "failed to allocate on TURN client %s %s", turnServerAddr, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -763,6 +847,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
|
||||
|
||||
if shouldFilterLocationTracked(rAddr.IP) {
|
||||
a.log.Warnf("TURN address %s is somehow filtered for location tracking reasons", rAddr.IP)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -776,6 +861,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
|
||||
RelayProtocol: relayProtocol,
|
||||
OnClose: func() error {
|
||||
client.Close()
|
||||
|
||||
return locConn.Close()
|
||||
},
|
||||
}
|
||||
@@ -790,6 +876,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { /
|
||||
|
||||
client.Close()
|
||||
closeConnAndLog(locConn, a.log, "failed to create relay candidate: %s %s: %v", network, rAddr.String(), err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
160
gather_test.go
160
gather_test.go
@@ -31,26 +31,32 @@ import (
|
||||
)
|
||||
|
||||
func TestListenUDP(t *testing.T) {
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
agent, err := NewAgent(&AgentConfig{})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
_, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||
_, localAddrs, err := localInterfaces(
|
||||
agent.net,
|
||||
agent.interfaceFilter,
|
||||
agent.ipFilter,
|
||||
[]NetworkType{NetworkTypeUDP4},
|
||||
false,
|
||||
)
|
||||
require.NotEqual(t, len(localAddrs), 0, "localInterfaces found no interfaces, unable to test")
|
||||
require.NoError(t, err)
|
||||
|
||||
ip := localAddrs[0].AsSlice()
|
||||
|
||||
conn, err := listenUDPInPortRange(a.net, a.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0})
|
||||
conn, err := listenUDPInPortRange(agent.net, agent.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0})
|
||||
require.NoError(t, err, "listenUDP error with no port restriction")
|
||||
require.NotNil(t, conn, "listenUDP error with no port restriction return a nil conn")
|
||||
|
||||
_, err = listenUDPInPortRange(a.net, a.log, 4999, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
|
||||
_, err = listenUDPInPortRange(agent.net, agent.log, 4999, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
|
||||
require.Equal(t, err, ErrPort, "listenUDP with invalid port range did not return ErrPort")
|
||||
|
||||
conn, err = listenUDPInPortRange(a.net, a.log, 5000, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
|
||||
conn, err = listenUDPInPortRange(agent.net, agent.log, 5000, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
|
||||
require.NoError(t, err, "listenUDP error with no port restriction")
|
||||
require.NotNil(t, conn, "listenUDP error with no port restriction return a nil conn")
|
||||
|
||||
@@ -64,7 +70,7 @@ func TestListenUDP(t *testing.T) {
|
||||
result := make([]int, 0, total)
|
||||
portRange := make([]int, 0, total)
|
||||
for i := 0; i < total; i++ {
|
||||
conn, err = listenUDPInPortRange(a.net, a.log, portMax, portMin, udp, &net.UDPAddr{IP: ip, Port: 0})
|
||||
conn, err = listenUDPInPortRange(agent.net, agent.log, portMax, portMin, udp, &net.UDPAddr{IP: ip, Port: 0})
|
||||
require.NoError(t, err, "listenUDP error with no port restriction")
|
||||
require.NotNil(t, conn, "listenUDP error with no port restriction return a nil conn")
|
||||
|
||||
@@ -85,7 +91,7 @@ func TestListenUDP(t *testing.T) {
|
||||
if !reflect.DeepEqual(result, portRange) {
|
||||
t.Fatalf("listenUDP with port restriction [%d, %d], got:%v, want:%v", portMin, portMax, result, portRange)
|
||||
}
|
||||
_, err = listenUDPInPortRange(a.net, a.log, portMax, portMin, udp, &net.UDPAddr{IP: ip, Port: 0})
|
||||
_, err = listenUDPInPortRange(agent.net, agent.log, portMax, portMin, udp, &net.UDPAddr{IP: ip, Port: 0})
|
||||
require.Equal(t, err, ErrPort, "listenUDP with port restriction [%d, %d], did not return ErrPort", portMin, portMax)
|
||||
}
|
||||
|
||||
@@ -94,23 +100,23 @@ func TestGatherConcurrency(t *testing.T) {
|
||||
|
||||
defer test.TimeOut(time.Second * 30).Stop()
|
||||
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
agent, err := NewAgent(&AgentConfig{
|
||||
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
|
||||
IncludeLoopback: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
|
||||
require.NoError(t, a.OnCandidate(func(Candidate) {
|
||||
require.NoError(t, agent.OnCandidate(func(Candidate) {
|
||||
candidateGatheredFunc()
|
||||
}))
|
||||
|
||||
// Testing for panic
|
||||
for i := 0; i < 10; i++ {
|
||||
_ = a.GatherCandidates()
|
||||
_ = agent.GatherCandidates()
|
||||
}
|
||||
|
||||
<-candidateGathered.Done()
|
||||
@@ -194,26 +200,27 @@ func TestLoopbackCandidate(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
tcase := tc
|
||||
t.Run(tcase.name, func(t *testing.T) {
|
||||
a, err := NewAgent(tc.agentConfig)
|
||||
agent, err := NewAgent(tc.agentConfig)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
|
||||
var loopback int32
|
||||
require.NoError(t, a.OnCandidate(func(c Candidate) {
|
||||
require.NoError(t, agent.OnCandidate(func(c Candidate) {
|
||||
if c != nil {
|
||||
if net.ParseIP(c.Address()).IsLoopback() {
|
||||
atomic.StoreInt32(&loopback, 1)
|
||||
}
|
||||
} else {
|
||||
candidateGatheredFunc()
|
||||
|
||||
return
|
||||
}
|
||||
t.Log(c.NetworkType(), c.Priority(), c)
|
||||
}))
|
||||
require.NoError(t, a.GatherCandidates())
|
||||
require.NoError(t, agent.GatherCandidates())
|
||||
|
||||
<-candidateGathered.Done()
|
||||
|
||||
@@ -226,7 +233,7 @@ func TestLoopbackCandidate(t *testing.T) {
|
||||
require.NoError(t, muxUnspecDefault.Close())
|
||||
}
|
||||
|
||||
// Assert that STUN gathering is done concurrently
|
||||
// Assert that STUN gathering is done concurrently.
|
||||
func TestSTUNConcurrency(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
@@ -273,7 +280,7 @@ func TestSTUNConcurrency(t *testing.T) {
|
||||
_ = listener.Close()
|
||||
}()
|
||||
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
agent, err := NewAgent(&AgentConfig{
|
||||
NetworkTypes: supportedNetworkTypes(),
|
||||
Urls: urls,
|
||||
CandidateTypes: []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive},
|
||||
@@ -287,29 +294,36 @@ func TestSTUNConcurrency(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
|
||||
require.NoError(t, a.OnCandidate(func(c Candidate) {
|
||||
require.NoError(t, agent.OnCandidate(func(c Candidate) {
|
||||
if c == nil {
|
||||
candidateGatheredFunc()
|
||||
|
||||
return
|
||||
}
|
||||
t.Log(c.NetworkType(), c.Priority(), c)
|
||||
}))
|
||||
require.NoError(t, a.GatherCandidates())
|
||||
require.NoError(t, agent.GatherCandidates())
|
||||
|
||||
<-candidateGathered.Done()
|
||||
}
|
||||
|
||||
// Assert that TURN gathering is done concurrently
|
||||
// Assert that TURN gathering is done concurrently.
|
||||
func TestTURNConcurrency(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
defer test.TimeOut(time.Second * 30).Stop()
|
||||
|
||||
runTest := func(protocol stun.ProtoType, scheme stun.SchemeType, packetConn net.PacketConn, listener net.Listener, serverPort int) {
|
||||
runTest := func(
|
||||
protocol stun.ProtoType,
|
||||
scheme stun.SchemeType,
|
||||
packetConn net.PacketConn,
|
||||
listener net.Listener,
|
||||
serverPort int,
|
||||
) {
|
||||
packetConnConfigs := []turn.PacketConnConfig{}
|
||||
if packetConn != nil {
|
||||
packetConnConfigs = append(packetConnConfigs, turn.PacketConnConfig{
|
||||
@@ -357,7 +371,7 @@ func TestTURNConcurrency(t *testing.T) {
|
||||
Port: serverPort,
|
||||
})
|
||||
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
agent, err := NewAgent(&AgentConfig{
|
||||
CandidateTypes: []CandidateType{CandidateTypeRelay},
|
||||
InsecureSkipVerify: true,
|
||||
NetworkTypes: supportedNetworkTypes(),
|
||||
@@ -365,16 +379,16 @@ func TestTURNConcurrency(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
|
||||
require.NoError(t, a.OnCandidate(func(c Candidate) {
|
||||
require.NoError(t, agent.OnCandidate(func(c Candidate) {
|
||||
if c != nil {
|
||||
candidateGatheredFunc()
|
||||
}
|
||||
}))
|
||||
require.NoError(t, a.GatherCandidates())
|
||||
require.NoError(t, agent.GatherCandidates())
|
||||
|
||||
<-candidateGathered.Done()
|
||||
}
|
||||
@@ -413,16 +427,20 @@ func TestTURNConcurrency(t *testing.T) {
|
||||
require.NoError(t, genErr)
|
||||
|
||||
serverPort := randomPort(t)
|
||||
serverListener, err := dtls.Listen("udp", &net.UDPAddr{IP: net.ParseIP(localhostIPStr), Port: serverPort}, &dtls.Config{
|
||||
Certificates: []tls.Certificate{certificate},
|
||||
})
|
||||
serverListener, err := dtls.Listen(
|
||||
"udp",
|
||||
&net.UDPAddr{IP: net.ParseIP(localhostIPStr), Port: serverPort},
|
||||
&dtls.Config{
|
||||
Certificates: []tls.Certificate{certificate},
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
runTest(stun.ProtoTypeUDP, stun.SchemeTypeTURNS, nil, serverListener, serverPort)
|
||||
})
|
||||
}
|
||||
|
||||
// Assert that STUN and TURN gathering are done concurrently
|
||||
// Assert that STUN and TURN gathering are done concurrently.
|
||||
func TestSTUNTURNConcurrency(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
@@ -464,25 +482,26 @@ func TestSTUNTURNConcurrency(t *testing.T) {
|
||||
Password: "password",
|
||||
})
|
||||
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
agent, err := NewAgent(&AgentConfig{
|
||||
NetworkTypes: supportedNetworkTypes(),
|
||||
Urls: urls,
|
||||
CandidateTypes: []CandidateType{CandidateTypeServerReflexive, CandidateTypeRelay},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
{
|
||||
gatherLim := test.TimeOut(time.Second * 3) // As TURN and STUN should be checked in parallel, this should complete before the default STUN timeout (5s)
|
||||
// As TURN and STUN should be checked in parallel, this should complete before the default STUN timeout (5s)
|
||||
gatherLim := test.TimeOut(time.Second * 3)
|
||||
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
|
||||
require.NoError(t, a.OnCandidate(func(c Candidate) {
|
||||
require.NoError(t, agent.OnCandidate(func(c Candidate) {
|
||||
if c != nil {
|
||||
candidateGatheredFunc()
|
||||
}
|
||||
}))
|
||||
require.NoError(t, a.GatherCandidates())
|
||||
require.NoError(t, agent.GatherCandidates())
|
||||
|
||||
<-candidateGathered.Done()
|
||||
gatherLim.Stop()
|
||||
@@ -528,24 +547,24 @@ func TestTURNSrflx(t *testing.T) {
|
||||
Password: "password",
|
||||
}}
|
||||
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
agent, err := NewAgent(&AgentConfig{
|
||||
NetworkTypes: supportedNetworkTypes(),
|
||||
Urls: urls,
|
||||
CandidateTypes: []CandidateType{CandidateTypeServerReflexive, CandidateTypeRelay},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
|
||||
require.NoError(t, a.OnCandidate(func(c Candidate) {
|
||||
require.NoError(t, agent.OnCandidate(func(c Candidate) {
|
||||
if c != nil && c.Type() == CandidateTypeServerReflexive {
|
||||
candidateGatheredFunc()
|
||||
}
|
||||
}))
|
||||
|
||||
require.NoError(t, a.GatherCandidates())
|
||||
require.NoError(t, agent.GatherCandidates())
|
||||
|
||||
<-candidateGathered.Done()
|
||||
}
|
||||
@@ -580,6 +599,7 @@ func (m *mockConn) SetWriteDeadline(time.Time) error { return io.EOF }
|
||||
|
||||
func (m *mockProxy) Dial(string, string) (net.Conn, error) {
|
||||
m.proxyWasDialed()
|
||||
|
||||
return &mockConn{}, nil
|
||||
}
|
||||
|
||||
@@ -599,7 +619,7 @@ func TestTURNProxyDialer(t *testing.T) {
|
||||
proxyDialer, err := proxy.FromURL(tcpProxyURI, proxy.Direct)
|
||||
require.NoError(t, err)
|
||||
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
agent, err := NewAgent(&AgentConfig{
|
||||
CandidateTypes: []CandidateType{CandidateTypeRelay},
|
||||
NetworkTypes: supportedNetworkTypes(),
|
||||
Urls: []*stun.URI{
|
||||
@@ -616,17 +636,17 @@ func TestTURNProxyDialer(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
candidateGatherFinish, candidateGatherFinishFunc := context.WithCancel(context.Background())
|
||||
require.NoError(t, a.OnCandidate(func(c Candidate) {
|
||||
require.NoError(t, agent.OnCandidate(func(c Candidate) {
|
||||
if c == nil {
|
||||
candidateGatherFinishFunc()
|
||||
}
|
||||
}))
|
||||
|
||||
require.NoError(t, a.GatherCandidates())
|
||||
require.NoError(t, agent.GatherCandidates())
|
||||
<-candidateGatherFinish.Done()
|
||||
<-proxyWasDialed.Done()
|
||||
}
|
||||
@@ -651,31 +671,31 @@ func TestUDPMuxDefaultWithNAT1To1IPsUsage(t *testing.T) {
|
||||
_ = mux.Close()
|
||||
}()
|
||||
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
agent, err := NewAgent(&AgentConfig{
|
||||
NAT1To1IPs: []string{"1.2.3.4"},
|
||||
NAT1To1IPCandidateType: CandidateTypeHost,
|
||||
UDPMux: mux,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
gatherCandidateDone := make(chan struct{})
|
||||
require.NoError(t, a.OnCandidate(func(c Candidate) {
|
||||
require.NoError(t, agent.OnCandidate(func(c Candidate) {
|
||||
if c == nil {
|
||||
close(gatherCandidateDone)
|
||||
} else {
|
||||
require.Equal(t, "1.2.3.4", c.Address())
|
||||
}
|
||||
}))
|
||||
require.NoError(t, a.GatherCandidates())
|
||||
require.NoError(t, agent.GatherCandidates())
|
||||
<-gatherCandidateDone
|
||||
|
||||
require.NotEqual(t, 0, len(mux.connsIPv4))
|
||||
}
|
||||
|
||||
// Assert that candidates are given for each mux in a MultiUDPMux
|
||||
// Assert that candidates are given for each mux in a MultiUDPMux.
|
||||
func TestMultiUDPMuxUsage(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
@@ -700,25 +720,26 @@ func TestMultiUDPMuxUsage(t *testing.T) {
|
||||
}()
|
||||
}
|
||||
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
agent, err := NewAgent(&AgentConfig{
|
||||
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
|
||||
CandidateTypes: []CandidateType{CandidateTypeHost},
|
||||
UDPMux: NewMultiUDPMuxDefault(udpMuxInstances...),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
candidateCh := make(chan Candidate)
|
||||
require.NoError(t, a.OnCandidate(func(c Candidate) {
|
||||
require.NoError(t, agent.OnCandidate(func(c Candidate) {
|
||||
if c == nil {
|
||||
close(candidateCh)
|
||||
|
||||
return
|
||||
}
|
||||
candidateCh <- c
|
||||
}))
|
||||
require.NoError(t, a.GatherCandidates())
|
||||
require.NoError(t, agent.GatherCandidates())
|
||||
|
||||
portFound := make(map[int]bool)
|
||||
for c := range candidateCh {
|
||||
@@ -731,7 +752,7 @@ func TestMultiUDPMuxUsage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Assert that candidates are given for each mux in a MultiTCPMux
|
||||
// Assert that candidates are given for each mux in a MultiTCPMux.
|
||||
func TestMultiTCPMuxUsage(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
@@ -757,25 +778,26 @@ func TestMultiTCPMuxUsage(t *testing.T) {
|
||||
}))
|
||||
}
|
||||
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
agent, err := NewAgent(&AgentConfig{
|
||||
NetworkTypes: supportedNetworkTypes(),
|
||||
CandidateTypes: []CandidateType{CandidateTypeHost},
|
||||
TCPMux: NewMultiTCPMuxDefault(tcpMuxInstances...),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
candidateCh := make(chan Candidate)
|
||||
require.NoError(t, a.OnCandidate(func(c Candidate) {
|
||||
require.NoError(t, agent.OnCandidate(func(c Candidate) {
|
||||
if c == nil {
|
||||
close(candidateCh)
|
||||
|
||||
return
|
||||
}
|
||||
candidateCh <- c
|
||||
}))
|
||||
require.NoError(t, a.GatherCandidates())
|
||||
require.NoError(t, agent.GatherCandidates())
|
||||
|
||||
portFound := make(map[int]bool)
|
||||
for c := range candidateCh {
|
||||
@@ -790,7 +812,7 @@ func TestMultiTCPMuxUsage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Assert that UniversalUDPMux is used while gathering when configured in the Agent
|
||||
// Assert that UniversalUDPMux is used while gathering when configured in the Agent.
|
||||
func TestUniversalUDPMuxUsage(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
@@ -816,7 +838,7 @@ func TestUniversalUDPMuxUsage(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
agent, err := NewAgent(&AgentConfig{
|
||||
NetworkTypes: supportedNetworkTypes(),
|
||||
Urls: urls,
|
||||
CandidateTypes: []CandidateType{CandidateTypeServerReflexive},
|
||||
@@ -828,26 +850,32 @@ func TestUniversalUDPMuxUsage(t *testing.T) {
|
||||
if aClosed {
|
||||
return
|
||||
}
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
|
||||
require.NoError(t, a.OnCandidate(func(c Candidate) {
|
||||
require.NoError(t, agent.OnCandidate(func(c Candidate) {
|
||||
if c == nil {
|
||||
candidateGatheredFunc()
|
||||
|
||||
return
|
||||
}
|
||||
t.Log(c.NetworkType(), c.Priority(), c)
|
||||
}))
|
||||
require.NoError(t, a.GatherCandidates())
|
||||
require.NoError(t, agent.GatherCandidates())
|
||||
|
||||
<-candidateGathered.Done()
|
||||
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
aClosed = true
|
||||
|
||||
// Twice because of 2 STUN servers configured
|
||||
require.Equal(t, numSTUNS, udpMuxSrflx.getXORMappedAddrUsedTimes, "expected times that GetXORMappedAddr should be called")
|
||||
require.Equal(
|
||||
t,
|
||||
numSTUNS,
|
||||
udpMuxSrflx.getXORMappedAddrUsedTimes,
|
||||
"expected times that GetXORMappedAddr should be called",
|
||||
)
|
||||
// One for Restart() when agent has been initialized and one time when Close() the agent
|
||||
require.Equal(t, 2, udpMuxSrflx.removeConnByUfragTimes, "expected times that RemoveConnByUfrag should be called")
|
||||
// Twice because of 2 STUN servers configured
|
||||
@@ -871,6 +899,7 @@ func (m *universalUDPMuxMock) GetConnForURL(string, string, net.Addr) (net.Packe
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.getConnForURLTimes++
|
||||
|
||||
return m.conn, nil
|
||||
}
|
||||
|
||||
@@ -878,6 +907,7 @@ func (m *universalUDPMuxMock) GetXORMappedAddr(net.Addr, time.Duration) (*stun.X
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.getXORMappedAddrUsedTimes++
|
||||
|
||||
return &stun.XORMappedAddress{IP: net.IP{100, 64, 0, 1}, Port: 77878}, nil
|
||||
}
|
||||
|
||||
|
@@ -20,7 +20,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVNetGather(t *testing.T) {
|
||||
func TestVNetGather(t *testing.T) { //nolint:cyclop
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
loggerFactory := logging.NewDefaultLoggerFactory()
|
||||
@@ -51,7 +51,7 @@ func TestVNetGather(t *testing.T) {
|
||||
t.Fatalf("Failed to parse CIDR: %s", err)
|
||||
}
|
||||
|
||||
r, err := vnet.NewRouter(&vnet.RouterConfig{
|
||||
router, err := vnet.NewRouter(&vnet.RouterConfig{
|
||||
CIDR: cider,
|
||||
LoggerFactory: loggerFactory,
|
||||
})
|
||||
@@ -64,7 +64,7 @@ func TestVNetGather(t *testing.T) {
|
||||
t.Fatalf("Failed to create a Net: %s", err)
|
||||
}
|
||||
|
||||
err = r.AddNet(nw)
|
||||
err = router.AddNet(nw)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add a Net to the router: %s", err)
|
||||
}
|
||||
@@ -94,7 +94,7 @@ func TestVNetGather(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("listenUDP", func(t *testing.T) {
|
||||
r, err := vnet.NewRouter(&vnet.RouterConfig{
|
||||
router, err := vnet.NewRouter(&vnet.RouterConfig{
|
||||
CIDR: "1.2.3.0/24",
|
||||
LoggerFactory: loggerFactory,
|
||||
})
|
||||
@@ -107,20 +107,26 @@ func TestVNetGather(t *testing.T) {
|
||||
t.Fatalf("Failed to create a Net: %s", err)
|
||||
}
|
||||
|
||||
err = r.AddNet(nw)
|
||||
err = router.AddNet(nw)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add a Net to the router: %s", err)
|
||||
}
|
||||
|
||||
a, err := NewAgent(&AgentConfig{Net: nw})
|
||||
agent, err := NewAgent(&AgentConfig{Net: nw})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create agent: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
_, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||
_, localAddrs, err := localInterfaces(
|
||||
agent.net,
|
||||
agent.interfaceFilter,
|
||||
agent.ipFilter,
|
||||
[]NetworkType{NetworkTypeUDP4},
|
||||
false,
|
||||
)
|
||||
if len(localAddrs) == 0 {
|
||||
t.Fatal("localInterfaces found no interfaces, unable to test")
|
||||
}
|
||||
@@ -128,7 +134,7 @@ func TestVNetGather(t *testing.T) {
|
||||
|
||||
ip := localAddrs[0].AsSlice()
|
||||
|
||||
conn, err := listenUDPInPortRange(a.net, a.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0})
|
||||
conn, err := listenUDPInPortRange(agent.net, agent.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("listenUDP error with no port restriction %v", err)
|
||||
} else if conn == nil {
|
||||
@@ -139,12 +145,12 @@ func TestVNetGather(t *testing.T) {
|
||||
t.Fatalf("failed to close conn")
|
||||
}
|
||||
|
||||
_, err = listenUDPInPortRange(a.net, a.log, 4999, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
|
||||
_, err = listenUDPInPortRange(agent.net, agent.log, 4999, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
|
||||
if !errors.Is(err, ErrPort) {
|
||||
t.Fatal("listenUDP with invalid port range did not return ErrPort")
|
||||
}
|
||||
|
||||
conn, err = listenUDPInPortRange(a.net, a.log, 5000, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
|
||||
conn, err = listenUDPInPortRange(agent.net, agent.log, 5000, 5000, udp, &net.UDPAddr{IP: ip, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("listenUDP error with no port restriction %v", err)
|
||||
} else if conn == nil {
|
||||
@@ -163,7 +169,7 @@ func TestVNetGather(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestVNetGatherWithNAT1To1(t *testing.T) {
|
||||
func TestVNetGatherWithNAT1To1(t *testing.T) { //nolint:cyclop
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
loggerFactory := logging.NewDefaultLoggerFactory()
|
||||
@@ -206,7 +212,7 @@ func TestVNetGatherWithNAT1To1(t *testing.T) {
|
||||
err = lan.AddNet(nw)
|
||||
require.NoError(t, err, "should succeed")
|
||||
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
agent, err := NewAgent(&AgentConfig{
|
||||
NetworkTypes: []NetworkType{
|
||||
NetworkTypeUDP4,
|
||||
},
|
||||
@@ -215,25 +221,25 @@ func TestVNetGatherWithNAT1To1(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err, "should succeed")
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
done := make(chan struct{})
|
||||
err = a.OnCandidate(func(c Candidate) {
|
||||
err = agent.OnCandidate(func(c Candidate) {
|
||||
if c == nil {
|
||||
close(done)
|
||||
}
|
||||
})
|
||||
require.NoError(t, err, "should succeed")
|
||||
|
||||
err = a.GatherCandidates()
|
||||
err = agent.GatherCandidates()
|
||||
require.NoError(t, err, "should succeed")
|
||||
|
||||
log.Debug("Wait until gathering is complete...")
|
||||
<-done
|
||||
log.Debug("Gathering is done")
|
||||
|
||||
candidates, err := a.GetLocalCandidates()
|
||||
candidates, err := agent.GetLocalCandidates()
|
||||
require.NoError(t, err, "should succeed")
|
||||
|
||||
if len(candidates) != 2 {
|
||||
@@ -248,7 +254,7 @@ func TestVNetGatherWithNAT1To1(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if candidates[0].Address() == externalIP0 {
|
||||
if candidates[0].Address() == externalIP0 { //nolint:nestif
|
||||
if candidates[1].Address() != externalIP1 {
|
||||
t.Fatalf("Unexpected candidate IP: %s", candidates[1].Address())
|
||||
}
|
||||
@@ -305,7 +311,7 @@ func TestVNetGatherWithNAT1To1(t *testing.T) {
|
||||
err = lan.AddNet(nw)
|
||||
require.NoError(t, err, "should succeed")
|
||||
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
agent, err := NewAgent(&AgentConfig{
|
||||
NetworkTypes: []NetworkType{
|
||||
NetworkTypeUDP4,
|
||||
},
|
||||
@@ -317,25 +323,25 @@ func TestVNetGatherWithNAT1To1(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err, "should succeed")
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
done := make(chan struct{})
|
||||
err = a.OnCandidate(func(c Candidate) {
|
||||
err = agent.OnCandidate(func(c Candidate) {
|
||||
if c == nil {
|
||||
close(done)
|
||||
}
|
||||
})
|
||||
require.NoError(t, err, "should succeed")
|
||||
|
||||
err = a.GatherCandidates()
|
||||
err = agent.GatherCandidates()
|
||||
require.NoError(t, err, "should succeed")
|
||||
|
||||
log.Debug("Wait until gathering is complete...")
|
||||
<-done
|
||||
log.Debug("Gathering is done")
|
||||
|
||||
candidates, err := a.GetLocalCandidates()
|
||||
candidates, err := agent.GetLocalCandidates()
|
||||
require.NoError(t, err, "should succeed")
|
||||
|
||||
if len(candidates) != 2 {
|
||||
@@ -367,7 +373,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
loggerFactory := logging.NewDefaultLoggerFactory()
|
||||
r, err := vnet.NewRouter(&vnet.RouterConfig{
|
||||
router, err := vnet.NewRouter(&vnet.RouterConfig{
|
||||
CIDR: "1.2.3.0/24",
|
||||
LoggerFactory: loggerFactory,
|
||||
})
|
||||
@@ -380,24 +386,31 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
|
||||
t.Fatalf("Failed to create a Net: %s", err)
|
||||
}
|
||||
|
||||
if err = r.AddNet(nw); err != nil {
|
||||
if err = router.AddNet(nw); err != nil {
|
||||
t.Fatalf("Failed to add a Net to the router: %s", err)
|
||||
}
|
||||
|
||||
t.Run("InterfaceFilter should exclude the interface", func(t *testing.T) {
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
agent, err := NewAgent(&AgentConfig{
|
||||
Net: nw,
|
||||
InterfaceFilter: func(interfaceName string) (keep bool) {
|
||||
require.Equal(t, "eth0", interfaceName)
|
||||
|
||||
return false
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
_, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||
_, localIPs, err := localInterfaces(
|
||||
agent.net,
|
||||
agent.interfaceFilter,
|
||||
agent.ipFilter,
|
||||
[]NetworkType{NetworkTypeUDP4},
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(localIPs) != 0 {
|
||||
@@ -406,19 +419,26 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("IPFilter should exclude the IP", func(t *testing.T) {
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
agent, err := NewAgent(&AgentConfig{
|
||||
Net: nw,
|
||||
IPFilter: func(ip net.IP) (keep bool) {
|
||||
require.Equal(t, net.IP{1, 2, 3, 1}, ip)
|
||||
|
||||
return false
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
_, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||
_, localIPs, err := localInterfaces(
|
||||
agent.net,
|
||||
agent.interfaceFilter,
|
||||
agent.ipFilter,
|
||||
[]NetworkType{NetworkTypeUDP4},
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(localIPs) != 0 {
|
||||
@@ -427,19 +447,26 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("InterfaceFilter should not exclude the interface", func(t *testing.T) {
|
||||
a, err := NewAgent(&AgentConfig{
|
||||
agent, err := NewAgent(&AgentConfig{
|
||||
Net: nw,
|
||||
InterfaceFilter: func(interfaceName string) (keep bool) {
|
||||
require.Equal(t, "eth0", interfaceName)
|
||||
|
||||
return true
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, a.Close())
|
||||
require.NoError(t, agent.Close())
|
||||
}()
|
||||
|
||||
_, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||
_, localIPs, err := localInterfaces(
|
||||
agent.net,
|
||||
agent.interfaceFilter,
|
||||
agent.ipFilter,
|
||||
[]NetworkType{NetworkTypeUDP4},
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(localIPs) == 0 {
|
||||
|
30
ice.go
30
ice.go
@@ -3,33 +3,33 @@
|
||||
|
||||
package ice
|
||||
|
||||
// ConnectionState is an enum showing the state of a ICE Connection
|
||||
// ConnectionState is an enum showing the state of a ICE Connection.
|
||||
type ConnectionState int
|
||||
|
||||
// List of supported States
|
||||
// List of supported States.
|
||||
const (
|
||||
// ConnectionStateUnknown represents an unknown state
|
||||
// ConnectionStateUnknown represents an unknown state.
|
||||
ConnectionStateUnknown ConnectionState = iota
|
||||
|
||||
// ConnectionStateNew ICE agent is gathering addresses
|
||||
// ConnectionStateNew ICE agent is gathering addresses.
|
||||
ConnectionStateNew
|
||||
|
||||
// ConnectionStateChecking ICE agent has been given local and remote candidates, and is attempting to find a match
|
||||
// ConnectionStateChecking ICE agent has been given local and remote candidates, and is attempting to find a match.
|
||||
ConnectionStateChecking
|
||||
|
||||
// ConnectionStateConnected ICE agent has a pairing, but is still checking other pairs
|
||||
// ConnectionStateConnected ICE agent has a pairing, but is still checking other pairs.
|
||||
ConnectionStateConnected
|
||||
|
||||
// ConnectionStateCompleted ICE agent has finished
|
||||
// ConnectionStateCompleted ICE agent has finished.
|
||||
ConnectionStateCompleted
|
||||
|
||||
// ConnectionStateFailed ICE agent never could successfully connect
|
||||
// ConnectionStateFailed ICE agent never could successfully connect.
|
||||
ConnectionStateFailed
|
||||
|
||||
// ConnectionStateDisconnected ICE agent connected successfully, but has entered a failed state
|
||||
// ConnectionStateDisconnected ICE agent connected successfully, but has entered a failed state.
|
||||
ConnectionStateDisconnected
|
||||
|
||||
// ConnectionStateClosed ICE agent has finished and is no longer handling requests
|
||||
// ConnectionStateClosed ICE agent has finished and is no longer handling requests.
|
||||
ConnectionStateClosed
|
||||
)
|
||||
|
||||
@@ -54,20 +54,20 @@ func (c ConnectionState) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
// GatheringState describes the state of the candidate gathering process
|
||||
// GatheringState describes the state of the candidate gathering process.
|
||||
type GatheringState int
|
||||
|
||||
const (
|
||||
// GatheringStateUnknown represents an unknown state
|
||||
// GatheringStateUnknown represents an unknown state.
|
||||
GatheringStateUnknown GatheringState = iota
|
||||
|
||||
// GatheringStateNew indicates candidate gathering is not yet started
|
||||
// GatheringStateNew indicates candidate gathering is not yet started.
|
||||
GatheringStateNew
|
||||
|
||||
// GatheringStateGathering indicates candidate gathering is ongoing
|
||||
// GatheringStateGathering indicates candidate gathering is ongoing.
|
||||
GatheringStateGathering
|
||||
|
||||
// GatheringStateComplete indicates candidate gathering has been completed
|
||||
// GatheringStateComplete indicates candidate gathering has been completed.
|
||||
GatheringStateComplete
|
||||
)
|
||||
|
||||
|
@@ -20,6 +20,7 @@ func (a tiebreaker) AddToAs(m *stun.Message, t stun.AttrType) error {
|
||||
v := make([]byte, tiebreakerSize)
|
||||
binary.BigEndian.PutUint64(v, uint64(a))
|
||||
m.Add(t, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -33,6 +34,7 @@ func (a *tiebreaker) GetFromAs(m *stun.Message, t stun.AttrType) error {
|
||||
return err
|
||||
}
|
||||
*a = tiebreaker(binary.BigEndian.Uint64(v))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -73,6 +75,7 @@ func (c AttrControl) AddTo(m *stun.Message) error {
|
||||
if c.Role == Controlling {
|
||||
return tiebreaker(c.Tiebreaker).AddToAs(m, stun.AttrICEControlling)
|
||||
}
|
||||
|
||||
return tiebreaker(c.Tiebreaker).AddToAs(m, stun.AttrICEControlled)
|
||||
}
|
||||
|
||||
@@ -80,11 +83,14 @@ func (c AttrControl) AddTo(m *stun.Message) error {
|
||||
func (c *AttrControl) GetFrom(m *stun.Message) error {
|
||||
if m.Contains(stun.AttrICEControlling) {
|
||||
c.Role = Controlling
|
||||
|
||||
return (*tiebreaker)(&c.Tiebreaker).GetFromAs(m, stun.AttrICEControlling)
|
||||
}
|
||||
if m.Contains(stun.AttrICEControlled) {
|
||||
c.Role = Controlled
|
||||
|
||||
return (*tiebreaker)(&c.Tiebreaker).GetFromAs(m, stun.AttrICEControlled)
|
||||
}
|
||||
|
||||
return stun.ErrAttributeNotFound
|
||||
}
|
||||
|
@@ -12,11 +12,11 @@ import (
|
||||
|
||||
func TestControlled_GetFrom(t *testing.T) { //nolint:dupl
|
||||
m := new(stun.Message)
|
||||
var c AttrControlled
|
||||
if err := c.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
|
||||
var attrCtr AttrControlled
|
||||
if err := attrCtr.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
|
||||
t.Error("unexpected error")
|
||||
}
|
||||
if err := m.Build(stun.BindingRequest, &c); err != nil {
|
||||
if err := m.Build(stun.BindingRequest, &attrCtr); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
m1 := new(stun.Message)
|
||||
@@ -27,7 +27,7 @@ func TestControlled_GetFrom(t *testing.T) { //nolint:dupl
|
||||
if err := c1.GetFrom(m1); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if c1 != c {
|
||||
if c1 != attrCtr {
|
||||
t.Error("not equal")
|
||||
}
|
||||
t.Run("IncorrectSize", func(t *testing.T) {
|
||||
@@ -42,11 +42,11 @@ func TestControlled_GetFrom(t *testing.T) { //nolint:dupl
|
||||
|
||||
func TestControlling_GetFrom(t *testing.T) { //nolint:dupl
|
||||
m := new(stun.Message)
|
||||
var c AttrControlling
|
||||
if err := c.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
|
||||
var attrCtr AttrControlling
|
||||
if err := attrCtr.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
|
||||
t.Error("unexpected error")
|
||||
}
|
||||
if err := m.Build(stun.BindingRequest, &c); err != nil {
|
||||
if err := m.Build(stun.BindingRequest, &attrCtr); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
m1 := new(stun.Message)
|
||||
@@ -57,7 +57,7 @@ func TestControlling_GetFrom(t *testing.T) { //nolint:dupl
|
||||
if err := c1.GetFrom(m1); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if c1 != c {
|
||||
if c1 != attrCtr {
|
||||
t.Error("not equal")
|
||||
}
|
||||
t.Run("IncorrectSize", func(t *testing.T) {
|
||||
@@ -70,7 +70,7 @@ func TestControlling_GetFrom(t *testing.T) { //nolint:dupl
|
||||
})
|
||||
}
|
||||
|
||||
func TestControl_GetFrom(t *testing.T) {
|
||||
func TestControl_GetFrom(t *testing.T) { //nolint:cyclop
|
||||
t.Run("Blank", func(t *testing.T) {
|
||||
m := new(stun.Message)
|
||||
var c AttrControl
|
||||
@@ -80,13 +80,13 @@ func TestControl_GetFrom(t *testing.T) {
|
||||
})
|
||||
t.Run("Controlling", func(t *testing.T) { //nolint:dupl
|
||||
m := new(stun.Message)
|
||||
var c AttrControl
|
||||
if err := c.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
|
||||
var attCtr AttrControl
|
||||
if err := attCtr.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
|
||||
t.Error("unexpected error")
|
||||
}
|
||||
c.Role = Controlling
|
||||
c.Tiebreaker = 4321
|
||||
if err := m.Build(stun.BindingRequest, &c); err != nil {
|
||||
attCtr.Role = Controlling
|
||||
attCtr.Tiebreaker = 4321
|
||||
if err := m.Build(stun.BindingRequest, &attCtr); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
m1 := new(stun.Message)
|
||||
@@ -97,7 +97,7 @@ func TestControl_GetFrom(t *testing.T) {
|
||||
if err := c1.GetFrom(m1); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if c1 != c {
|
||||
if c1 != attCtr {
|
||||
t.Error("not equal")
|
||||
}
|
||||
t.Run("IncorrectSize", func(t *testing.T) {
|
||||
@@ -111,13 +111,13 @@ func TestControl_GetFrom(t *testing.T) {
|
||||
})
|
||||
t.Run("Controlled", func(t *testing.T) { //nolint:dupl
|
||||
m := new(stun.Message)
|
||||
var c AttrControl
|
||||
if err := c.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
|
||||
var attrCtrl AttrControl
|
||||
if err := attrCtrl.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
|
||||
t.Error("unexpected error")
|
||||
}
|
||||
c.Role = Controlled
|
||||
c.Tiebreaker = 1234
|
||||
if err := m.Build(stun.BindingRequest, &c); err != nil {
|
||||
attrCtrl.Role = Controlled
|
||||
attrCtrl.Tiebreaker = 1234
|
||||
if err := m.Build(stun.BindingRequest, &attrCtrl); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
m1 := new(stun.Message)
|
||||
@@ -128,7 +128,7 @@ func TestControl_GetFrom(t *testing.T) {
|
||||
if err := c1.GetFrom(m1); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if c1 != c {
|
||||
if c1 != attrCtrl {
|
||||
t.Error("not equal")
|
||||
}
|
||||
t.Run("IncorrectSize", func(t *testing.T) {
|
||||
|
@@ -6,18 +6,19 @@ package atomic
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// Error is an atomic error
|
||||
// Error is an atomic error.
|
||||
type Error struct {
|
||||
v atomic.Value
|
||||
}
|
||||
|
||||
// Store updates the value of the atomic variable
|
||||
// Store updates the value of the atomic variable.
|
||||
func (a *Error) Store(err error) {
|
||||
a.v.Store(struct{ error }{err})
|
||||
}
|
||||
|
||||
// Load retrieves the current value of the atomic variable
|
||||
// Load retrieves the current value of the atomic variable.
|
||||
func (a *Error) Load() error {
|
||||
err, _ := a.v.Load().(struct{ error })
|
||||
|
||||
return err.error
|
||||
}
|
||||
|
@@ -11,7 +11,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// MockPacketConn for tests
|
||||
// MockPacketConn for tests.
|
||||
type MockPacketConn struct{}
|
||||
|
||||
func (m *MockPacketConn) ReadFrom([]byte) (n int, addr net.Addr, err error) { return 0, nil, nil } //nolint:revive
|
||||
|
@@ -8,18 +8,19 @@ import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// Compile-time assertion
|
||||
// Compile-time assertion.
|
||||
var _ net.PacketConn = (*PacketConn)(nil)
|
||||
|
||||
// PacketConn wraps a net.Conn and emulates net.PacketConn
|
||||
// PacketConn wraps a net.Conn and emulates net.PacketConn.
|
||||
type PacketConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
// ReadFrom reads a packet from the connection,
|
||||
// ReadFrom reads a packet from the connection.
|
||||
func (f *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, err = f.Conn.Read(p)
|
||||
addr = f.Conn.RemoteAddr()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@@ -59,7 +59,7 @@ func GetXORMappedAddr(conn net.PacketConn, serverAddr net.Addr, timeout time.Dur
|
||||
return &addr, nil
|
||||
}
|
||||
|
||||
// AssertUsername checks that the given STUN message m has a USERNAME attribute with a given value
|
||||
// AssertUsername checks that the given STUN message m has a USERNAME attribute with a given value.
|
||||
func AssertUsername(m *stun.Message, expectedUsername string) error {
|
||||
var username stun.Username
|
||||
if err := username.GetFrom(m); err != nil {
|
||||
|
@@ -13,7 +13,7 @@ import (
|
||||
atomicx "github.com/pion/ice/v4/internal/atomic"
|
||||
)
|
||||
|
||||
// ErrClosed indicates that the loop has been stopped
|
||||
// ErrClosed indicates that the loop has been stopped.
|
||||
var ErrClosed = errors.New("the agent is closed")
|
||||
|
||||
type task struct {
|
||||
@@ -21,7 +21,7 @@ type task struct {
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// Loop runs submitted task serially in a dedicated Goroutine
|
||||
// Loop runs submitted task serially in a dedicated Goroutine.
|
||||
type Loop struct {
|
||||
tasks chan task
|
||||
|
||||
@@ -31,7 +31,7 @@ type Loop struct {
|
||||
err atomicx.Error
|
||||
}
|
||||
|
||||
// New creates and starts a new task loop
|
||||
// New creates and starts a new task loop.
|
||||
func New(onClose func()) *Loop {
|
||||
l := &Loop{
|
||||
tasks: make(chan task),
|
||||
@@ -40,6 +40,7 @@ func New(onClose func()) *Loop {
|
||||
}
|
||||
|
||||
go l.runLoop(onClose)
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
@@ -86,6 +87,7 @@ func (l *Loop) Run(ctx context.Context, t func(context.Context)) error {
|
||||
return ctx.Err()
|
||||
case l.tasks <- task{t, done}:
|
||||
<-done
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -113,7 +115,7 @@ func (l *Loop) Deadline() (deadline time.Time, ok bool) {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
// Value is not supported for task loops
|
||||
// Value is not supported for task loops.
|
||||
func (l *Loop) Value(interface{}) interface{} {
|
||||
return nil
|
||||
}
|
||||
|
28
mdns.go
28
mdns.go
@@ -14,18 +14,19 @@ import (
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
// MulticastDNSMode represents the different Multicast modes ICE can run in
|
||||
// MulticastDNSMode represents the different Multicast modes ICE can run in.
|
||||
type MulticastDNSMode byte
|
||||
|
||||
// MulticastDNSMode enum
|
||||
// MulticastDNSMode enum.
|
||||
const (
|
||||
// MulticastDNSModeDisabled means remote mDNS candidates will be discarded, and local host candidates will use IPs
|
||||
// MulticastDNSModeDisabled means remote mDNS candidates will be discarded, and local host candidates will use IPs.
|
||||
MulticastDNSModeDisabled MulticastDNSMode = iota + 1
|
||||
|
||||
// MulticastDNSModeQueryOnly means remote mDNS candidates will be accepted, and local host candidates will use IPs
|
||||
// MulticastDNSModeQueryOnly means remote mDNS candidates will be accepted, and local host candidates will use IPs.
|
||||
MulticastDNSModeQueryOnly
|
||||
|
||||
// MulticastDNSModeQueryAndGather means remote mDNS candidates will be accepted, and local host candidates will use mDNS
|
||||
// MulticastDNSModeQueryAndGather means remote mDNS candidates will be accepted,
|
||||
// and local host candidates will use mDNS.
|
||||
MulticastDNSModeQueryAndGather
|
||||
)
|
||||
|
||||
@@ -33,11 +34,13 @@ func generateMulticastDNSName() (string, error) {
|
||||
// https://tools.ietf.org/id/draft-ietf-rtcweb-mdns-ice-candidates-02.html#gathering
|
||||
// The unique name MUST consist of a version 4 UUID as defined in [RFC4122], followed by “.local”.
|
||||
u, err := uuid.NewRandom()
|
||||
|
||||
return u.String() + ".local", err
|
||||
}
|
||||
|
||||
//nolint:cyclop
|
||||
func createMulticastDNS(
|
||||
n transport.Net,
|
||||
netTransport transport.Net,
|
||||
networkTypes []NetworkType,
|
||||
interfaces []*transport.Interface,
|
||||
includeLoopback bool,
|
||||
@@ -57,6 +60,7 @@ func createMulticastDNS(
|
||||
for _, nt := range networkTypes {
|
||||
if nt.IsIPv4() {
|
||||
useV4 = true
|
||||
|
||||
continue
|
||||
}
|
||||
if nt.IsIPv6() {
|
||||
@@ -65,11 +69,11 @@ func createMulticastDNS(
|
||||
}
|
||||
}
|
||||
|
||||
addr4, mdnsErr := n.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4)
|
||||
addr4, mdnsErr := netTransport.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4)
|
||||
if mdnsErr != nil {
|
||||
return nil, mDNSMode, mdnsErr
|
||||
}
|
||||
addr6, mdnsErr := n.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6)
|
||||
addr6, mdnsErr := netTransport.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6)
|
||||
if mdnsErr != nil {
|
||||
return nil, mDNSMode, mdnsErr
|
||||
}
|
||||
@@ -78,10 +82,11 @@ func createMulticastDNS(
|
||||
var mdns4Err error
|
||||
if useV4 {
|
||||
var l transport.UDPConn
|
||||
l, mdns4Err = n.ListenUDP("udp4", addr4)
|
||||
l, mdns4Err = netTransport.ListenUDP("udp4", addr4)
|
||||
if mdns4Err != nil {
|
||||
// If ICE fails to start MulticastDNS server just warn the user and continue
|
||||
log.Errorf("Failed to enable mDNS over IPv4: (%s)", mdns4Err)
|
||||
|
||||
return nil, MulticastDNSModeDisabled, nil
|
||||
}
|
||||
pktConnV4 = ipv4.NewPacketConn(l)
|
||||
@@ -91,9 +96,10 @@ func createMulticastDNS(
|
||||
var mdns6Err error
|
||||
if useV6 {
|
||||
var l transport.UDPConn
|
||||
l, mdns6Err = n.ListenUDP("udp6", addr6)
|
||||
l, mdns6Err = netTransport.ListenUDP("udp6", addr6)
|
||||
if mdns6Err != nil {
|
||||
log.Errorf("Failed to enable mDNS over IPv6: (%s)", mdns6Err)
|
||||
|
||||
return nil, MulticastDNSModeDisabled, nil
|
||||
}
|
||||
pktConnV6 = ipv6.NewPacketConn(l)
|
||||
@@ -119,6 +125,7 @@ func createMulticastDNS(
|
||||
Interfaces: ifcs,
|
||||
IncludeLoopback: includeLoopback,
|
||||
})
|
||||
|
||||
return conn, mDNSMode, err
|
||||
case MulticastDNSModeQueryAndGather:
|
||||
conn, err := mdns.Server(pktConnV4, pktConnV6, &mdns.Config{
|
||||
@@ -126,6 +133,7 @@ func createMulticastDNS(
|
||||
IncludeLoopback: includeLoopback,
|
||||
LocalNames: []string{mDNSName},
|
||||
})
|
||||
|
||||
return conn, mDNSMode, err
|
||||
default:
|
||||
return nil, mDNSMode, nil
|
||||
|
41
net.go
41
net.go
@@ -24,6 +24,7 @@ func isSupportedIPv6Partial(ip net.IP) bool {
|
||||
ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 { // !(IPv6 site-local unicast)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -33,10 +34,11 @@ func isZeros(ip net.IP) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
//nolint:gocognit,cyclop
|
||||
func localInterfaces(
|
||||
n transport.Net,
|
||||
interfaceFilter func(string) (keep bool),
|
||||
@@ -114,27 +116,35 @@ func localInterfaces(
|
||||
filteredIfaces = append(filteredIfaces, ifaceCopy)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredIfaces, ipAddrs, nil
|
||||
}
|
||||
|
||||
func listenUDPInPortRange(n transport.Net, log logging.LeveledLogger, portMax, portMin int, network string, lAddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||
//nolint:cyclop
|
||||
func listenUDPInPortRange(
|
||||
netTransport transport.Net,
|
||||
log logging.LeveledLogger,
|
||||
portMax, portMin int,
|
||||
network string,
|
||||
lAddr *net.UDPAddr,
|
||||
) (transport.UDPConn, error) {
|
||||
if (lAddr.Port != 0) || ((portMin == 0) && (portMax == 0)) {
|
||||
return n.ListenUDP(network, lAddr)
|
||||
return netTransport.ListenUDP(network, lAddr)
|
||||
}
|
||||
var i, j int
|
||||
i = portMin
|
||||
if i == 0 {
|
||||
i = 1024 // Start at 1024 which is non-privileged
|
||||
|
||||
if portMin == 0 {
|
||||
portMin = 1024 // Start at 1024 which is non-privileged
|
||||
}
|
||||
j = portMax
|
||||
if j == 0 {
|
||||
j = 0xFFFF
|
||||
|
||||
if portMax == 0 {
|
||||
portMax = 0xFFFF
|
||||
}
|
||||
if i > j {
|
||||
|
||||
if portMin > portMax {
|
||||
return nil, ErrPort
|
||||
}
|
||||
|
||||
portStart := globalMathRandomGenerator.Intn(j-i+1) + i
|
||||
portStart := globalMathRandomGenerator.Intn(portMax-portMin+1) + portMin
|
||||
portCurrent := portStart
|
||||
for {
|
||||
addr := &net.UDPAddr{
|
||||
@@ -143,18 +153,19 @@ func listenUDPInPortRange(n transport.Net, log logging.LeveledLogger, portMax, p
|
||||
Port: portCurrent,
|
||||
}
|
||||
|
||||
c, e := n.ListenUDP(network, addr)
|
||||
c, e := netTransport.ListenUDP(network, addr)
|
||||
if e == nil {
|
||||
return c, e //nolint:nilerr
|
||||
}
|
||||
log.Debugf("Failed to listen %s: %v", lAddr.String(), e)
|
||||
portCurrent++
|
||||
if portCurrent > j {
|
||||
portCurrent = i
|
||||
if portCurrent > portMax {
|
||||
portCurrent = portMin
|
||||
}
|
||||
if portCurrent == portStart {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrPort
|
||||
}
|
||||
|
@@ -54,6 +54,7 @@ func problematicNetworkInterfaces(s string) (keep bool) {
|
||||
appleWirelessDirectLink := strings.Contains(s, "awdl")
|
||||
appleLowLatencyWLANInterface := strings.Contains(s, "llw")
|
||||
appleTunnelingInterface := strings.Contains(s, "utun")
|
||||
|
||||
return !defaultDockerBridgeNetwork &&
|
||||
!customDockerBridgeNetwork &&
|
||||
!accessPoint &&
|
||||
@@ -68,5 +69,6 @@ func mustAddr(t *testing.T, ip net.IP) netip.Addr {
|
||||
if !ok {
|
||||
t.Fatal(ipConvertError{ip})
|
||||
}
|
||||
|
||||
return addr
|
||||
}
|
||||
|
@@ -27,7 +27,7 @@ func supportedNetworkTypes() []NetworkType {
|
||||
}
|
||||
}
|
||||
|
||||
// NetworkType represents the type of network
|
||||
// NetworkType represents the type of network.
|
||||
type NetworkType int
|
||||
|
||||
const (
|
||||
@@ -69,7 +69,7 @@ func (t NetworkType) IsTCP() bool {
|
||||
return t == NetworkTypeTCP4 || t == NetworkTypeTCP6
|
||||
}
|
||||
|
||||
// NetworkShort returns the short network description
|
||||
// NetworkShort returns the short network description.
|
||||
func (t NetworkType) NetworkShort() string {
|
||||
switch t {
|
||||
case NetworkTypeUDP4, NetworkTypeUDP6:
|
||||
@@ -81,7 +81,7 @@ func (t NetworkType) NetworkShort() string {
|
||||
}
|
||||
}
|
||||
|
||||
// IsReliable returns true if the network is reliable
|
||||
// IsReliable returns true if the network is reliable.
|
||||
func (t NetworkType) IsReliable() bool {
|
||||
switch t {
|
||||
case NetworkTypeUDP4, NetworkTypeUDP6:
|
||||
@@ -89,6 +89,7 @@ func (t NetworkType) IsReliable() bool {
|
||||
case NetworkTypeTCP4, NetworkTypeTCP6:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -100,6 +101,7 @@ func (t NetworkType) IsIPv4() bool {
|
||||
case NetworkTypeUDP6, NetworkTypeTCP6:
|
||||
return false
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -111,6 +113,7 @@ func (t NetworkType) IsIPv6() bool {
|
||||
case NetworkTypeUDP6, NetworkTypeTCP6:
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -124,12 +127,14 @@ func determineNetworkType(network string, ip netip.Addr) (NetworkType, error) {
|
||||
if ip.Is4() {
|
||||
return NetworkTypeUDP4, nil
|
||||
}
|
||||
|
||||
return NetworkTypeUDP6, nil
|
||||
|
||||
case strings.HasPrefix(strings.ToLower(network), tcp):
|
||||
if ip.Is4() {
|
||||
return NetworkTypeTCP4, nil
|
||||
}
|
||||
|
||||
return NetworkTypeTCP6, nil
|
||||
}
|
||||
|
||||
|
@@ -19,6 +19,7 @@ func (p PriorityAttr) AddTo(m *stun.Message) error {
|
||||
v := make([]byte, prioritySize)
|
||||
binary.BigEndian.PutUint32(v, uint32(p))
|
||||
m.Add(stun.AttrPriority, v)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -32,5 +33,6 @@ func (p *PriorityAttr) GetFrom(m *stun.Message) error {
|
||||
return err
|
||||
}
|
||||
*p = PriorityAttr(binary.BigEndian.Uint32(v))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -12,11 +12,11 @@ import (
|
||||
|
||||
func TestPriority_GetFrom(t *testing.T) { //nolint:dupl
|
||||
m := new(stun.Message)
|
||||
var p PriorityAttr
|
||||
if err := p.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
|
||||
var priority PriorityAttr
|
||||
if err := priority.GetFrom(m); !errors.Is(err, stun.ErrAttributeNotFound) {
|
||||
t.Error("unexpected error")
|
||||
}
|
||||
if err := m.Build(stun.BindingRequest, &p); err != nil {
|
||||
if err := m.Build(stun.BindingRequest, &priority); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
m1 := new(stun.Message)
|
||||
@@ -27,7 +27,7 @@ func TestPriority_GetFrom(t *testing.T) { //nolint:dupl
|
||||
if err := p1.GetFrom(m1); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if p1 != p {
|
||||
if p1 != priority {
|
||||
t.Error("not equal")
|
||||
}
|
||||
t.Run("IncorrectSize", func(t *testing.T) {
|
||||
|
18
rand_test.go
18
rand_test.go
@@ -23,21 +23,27 @@ func TestRandomGeneratorCollision(t *testing.T) {
|
||||
},
|
||||
"PWD": {
|
||||
gen: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
s, err := generatePwd()
|
||||
require.NoError(t, err)
|
||||
|
||||
return s
|
||||
},
|
||||
},
|
||||
"Ufrag": {
|
||||
gen: func(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
s, err := generateUFrag()
|
||||
require.NoError(t, err)
|
||||
|
||||
return s
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const N = 100
|
||||
const num = 100
|
||||
const iteration = 100
|
||||
|
||||
for name, testCase := range testCases {
|
||||
@@ -47,9 +53,9 @@ func TestRandomGeneratorCollision(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
|
||||
rands := make([]string, 0, N)
|
||||
rands := make([]string, 0, num)
|
||||
|
||||
for i := 0; i < N; i++ {
|
||||
for i := 0; i < num; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
r := testCase.gen(t)
|
||||
@@ -61,12 +67,12 @@ func TestRandomGeneratorCollision(t *testing.T) {
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if len(rands) != N {
|
||||
if len(rands) != num {
|
||||
t.Fatal("Failed to generate randoms")
|
||||
}
|
||||
|
||||
for i := 0; i < N; i++ {
|
||||
for j := i + 1; j < N; j++ {
|
||||
for i := 0; i < num; i++ {
|
||||
for j := i + 1; j < num; j++ {
|
||||
if rands[i] == rands[j] {
|
||||
t.Fatalf("generateRandString caused collision: %s == %s", rands[i], rands[j])
|
||||
}
|
||||
|
1
role.go
1
role.go
@@ -26,6 +26,7 @@ func (r *Role) UnmarshalText(text []byte) error {
|
||||
default:
|
||||
return fmt.Errorf("%w %q", errUnknownRole, text)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
116
selection.go
116
selection.go
@@ -44,6 +44,7 @@ func (s *controllingSelector) isNominatable(c Candidate) bool {
|
||||
}
|
||||
|
||||
s.log.Errorf("Invalid candidate type: %s", c.Type())
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -63,6 +64,7 @@ func (s *controllingSelector) ContactCandidates() {
|
||||
p.nominated = true
|
||||
s.nominatedPair = p
|
||||
s.nominatePair(p)
|
||||
|
||||
return
|
||||
}
|
||||
s.agent.pingAllCandidates()
|
||||
@@ -84,6 +86,7 @@ func (s *controllingSelector) nominatePair(pair *CandidatePair) {
|
||||
)
|
||||
if err != nil {
|
||||
s.log.Error(err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -91,30 +94,35 @@ func (s *controllingSelector) nominatePair(pair *CandidatePair) {
|
||||
s.agent.sendBindingRequest(msg, pair.Local, pair.Remote)
|
||||
}
|
||||
|
||||
func (s *controllingSelector) HandleBindingRequest(m *stun.Message, local, remote Candidate) {
|
||||
s.agent.sendBindingSuccess(m, local, remote)
|
||||
func (s *controllingSelector) HandleBindingRequest(message *stun.Message, local, remote Candidate) { //nolint:cyclop
|
||||
s.agent.sendBindingSuccess(message, local, remote)
|
||||
|
||||
p := s.agent.findPair(local, remote)
|
||||
pair := s.agent.findPair(local, remote)
|
||||
|
||||
if p == nil {
|
||||
if pair == nil {
|
||||
s.agent.addPair(local, remote)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if p.state == CandidatePairStateSucceeded && s.nominatedPair == nil && s.agent.getSelectedPair() == nil {
|
||||
if pair.state == CandidatePairStateSucceeded && s.nominatedPair == nil && s.agent.getSelectedPair() == nil {
|
||||
bestPair := s.agent.getBestAvailableCandidatePair()
|
||||
if bestPair == nil {
|
||||
s.log.Tracef("No best pair available")
|
||||
} else if bestPair.equal(p) && s.isNominatable(p.Local) && s.isNominatable(p.Remote) {
|
||||
s.log.Tracef("The candidate (%s, %s) is the best candidate available, marking it as nominated", p.Local, p.Remote)
|
||||
s.nominatedPair = p
|
||||
s.nominatePair(p)
|
||||
} else if bestPair.equal(pair) && s.isNominatable(pair.Local) && s.isNominatable(pair.Remote) {
|
||||
s.log.Tracef(
|
||||
"The candidate (%s, %s) is the best candidate available, marking it as nominated",
|
||||
pair.Local,
|
||||
pair.Remote,
|
||||
)
|
||||
s.nominatedPair = pair
|
||||
s.nominatePair(pair)
|
||||
}
|
||||
}
|
||||
|
||||
if s.agent.userBindingRequestHandler != nil {
|
||||
if shouldSwitch := s.agent.userBindingRequestHandler(m, local, remote, p); shouldSwitch {
|
||||
s.agent.setSelectedPair(p)
|
||||
if shouldSwitch := s.agent.userBindingRequestHandler(message, local, remote, pair); shouldSwitch {
|
||||
s.agent.setSelectedPair(pair)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -123,6 +131,7 @@ func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remo
|
||||
ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(m.TransactionID)
|
||||
if !ok {
|
||||
s.log.Warnf("Discard message from (%s), unknown TransactionID 0x%x", remote, m.TransactionID)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -131,26 +140,32 @@ func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remo
|
||||
// Assert that NAT is not symmetric
|
||||
// https://tools.ietf.org/html/rfc8445#section-7.2.5.2.1
|
||||
if !addrEqual(transactionAddr, remoteAddr) {
|
||||
s.log.Debugf("Discard message: transaction source and destination does not match expected(%s), actual(%s)", transactionAddr, remote)
|
||||
s.log.Debugf(
|
||||
"Discard message: transaction source and destination does not match expected(%s), actual(%s)",
|
||||
transactionAddr,
|
||||
remote,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.log.Tracef("Inbound STUN (SuccessResponse) from %s to %s", remote, local)
|
||||
p := s.agent.findPair(local, remote)
|
||||
pair := s.agent.findPair(local, remote)
|
||||
|
||||
if p == nil {
|
||||
if pair == nil {
|
||||
// This shouldn't happen
|
||||
s.log.Error("Success response from invalid candidate pair")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
p.state = CandidatePairStateSucceeded
|
||||
s.log.Tracef("Found valid candidate pair: %s", p)
|
||||
pair.state = CandidatePairStateSucceeded
|
||||
s.log.Tracef("Found valid candidate pair: %s", pair)
|
||||
if pendingRequest.isUseCandidate && s.agent.getSelectedPair() == nil {
|
||||
s.agent.setSelectedPair(p)
|
||||
s.agent.setSelectedPair(pair)
|
||||
}
|
||||
|
||||
p.UpdateRoundTripTime(rtt)
|
||||
pair.UpdateRoundTripTime(rtt)
|
||||
}
|
||||
|
||||
func (s *controllingSelector) PingCandidate(local, remote Candidate) {
|
||||
@@ -163,6 +178,7 @@ func (s *controllingSelector) PingCandidate(local, remote Candidate) {
|
||||
)
|
||||
if err != nil {
|
||||
s.log.Error(err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -198,6 +214,7 @@ func (s *controlledSelector) PingCandidate(local, remote Candidate) {
|
||||
)
|
||||
if err != nil {
|
||||
s.log.Error(err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -216,6 +233,7 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot
|
||||
ok, pendingRequest, rtt := s.agent.handleInboundBindingSuccess(m.TransactionID)
|
||||
if !ok {
|
||||
s.log.Warnf("Discard message from (%s), unknown TransactionID 0x%x", remote, m.TransactionID)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -224,52 +242,62 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot
|
||||
// Assert that NAT is not symmetric
|
||||
// https://tools.ietf.org/html/rfc8445#section-7.2.5.2.1
|
||||
if !addrEqual(transactionAddr, remoteAddr) {
|
||||
s.log.Debugf("Discard message: transaction source and destination does not match expected(%s), actual(%s)", transactionAddr, remote)
|
||||
s.log.Debugf(
|
||||
"Discard message: transaction source and destination does not match expected(%s), actual(%s)",
|
||||
transactionAddr,
|
||||
remote,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.log.Tracef("Inbound STUN (SuccessResponse) from %s to %s", remote, local)
|
||||
|
||||
p := s.agent.findPair(local, remote)
|
||||
if p == nil {
|
||||
pair := s.agent.findPair(local, remote)
|
||||
if pair == nil {
|
||||
// This shouldn't happen
|
||||
s.log.Error("Success response from invalid candidate pair")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
p.state = CandidatePairStateSucceeded
|
||||
s.log.Tracef("Found valid candidate pair: %s", p)
|
||||
if p.nominateOnBindingSuccess {
|
||||
pair.state = CandidatePairStateSucceeded
|
||||
s.log.Tracef("Found valid candidate pair: %s", pair)
|
||||
if pair.nominateOnBindingSuccess {
|
||||
if selectedPair := s.agent.getSelectedPair(); selectedPair == nil ||
|
||||
(selectedPair != p && (!s.agent.needsToCheckPriorityOnNominated() || selectedPair.priority() <= p.priority())) {
|
||||
s.agent.setSelectedPair(p)
|
||||
} else if selectedPair != p {
|
||||
s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", p, selectedPair)
|
||||
(selectedPair != pair &&
|
||||
(!s.agent.needsToCheckPriorityOnNominated() || selectedPair.priority() <= pair.priority())) {
|
||||
s.agent.setSelectedPair(pair)
|
||||
} else if selectedPair != pair {
|
||||
s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", pair, selectedPair)
|
||||
}
|
||||
}
|
||||
|
||||
p.UpdateRoundTripTime(rtt)
|
||||
pair.UpdateRoundTripTime(rtt)
|
||||
}
|
||||
|
||||
func (s *controlledSelector) HandleBindingRequest(m *stun.Message, local, remote Candidate) {
|
||||
p := s.agent.findPair(local, remote)
|
||||
if p == nil {
|
||||
p = s.agent.addPair(local, remote)
|
||||
func (s *controlledSelector) HandleBindingRequest(message *stun.Message, local, remote Candidate) { //nolint:cyclop
|
||||
pair := s.agent.findPair(local, remote)
|
||||
if pair == nil {
|
||||
pair = s.agent.addPair(local, remote)
|
||||
}
|
||||
|
||||
if m.Contains(stun.AttrUseCandidate) {
|
||||
if message.Contains(stun.AttrUseCandidate) { //nolint:nestif
|
||||
// https://tools.ietf.org/html/rfc8445#section-7.3.1.5
|
||||
|
||||
if p.state == CandidatePairStateSucceeded {
|
||||
if pair.state == CandidatePairStateSucceeded {
|
||||
// If the state of this pair is Succeeded, it means that the check
|
||||
// previously sent by this pair produced a successful response and
|
||||
// generated a valid pair (Section 7.2.5.3.2). The agent sets the
|
||||
// nominated flag value of the valid pair to true.
|
||||
selectedPair := s.agent.getSelectedPair()
|
||||
if selectedPair == nil || (selectedPair != p && (!s.agent.needsToCheckPriorityOnNominated() || selectedPair.priority() <= p.priority())) {
|
||||
s.agent.setSelectedPair(p)
|
||||
} else if selectedPair != p {
|
||||
s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", p, selectedPair)
|
||||
if selectedPair == nil ||
|
||||
(selectedPair != pair &&
|
||||
(!s.agent.needsToCheckPriorityOnNominated() ||
|
||||
selectedPair.priority() <= pair.priority())) {
|
||||
s.agent.setSelectedPair(pair)
|
||||
} else if selectedPair != pair {
|
||||
s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", pair, selectedPair)
|
||||
}
|
||||
} else {
|
||||
// If the received Binding request triggered a new check to be
|
||||
@@ -280,16 +308,16 @@ func (s *controlledSelector) HandleBindingRequest(m *stun.Message, local, remote
|
||||
// MUST remove the candidate pair from the valid list, set the
|
||||
// candidate pair state to Failed, and set the checklist state to
|
||||
// Failed.
|
||||
p.nominateOnBindingSuccess = true
|
||||
pair.nominateOnBindingSuccess = true
|
||||
}
|
||||
}
|
||||
|
||||
s.agent.sendBindingSuccess(m, local, remote)
|
||||
s.agent.sendBindingSuccess(message, local, remote)
|
||||
s.PingCandidate(local, remote)
|
||||
|
||||
if s.agent.userBindingRequestHandler != nil {
|
||||
if shouldSwitch := s.agent.userBindingRequestHandler(m, local, remote, p); shouldSwitch {
|
||||
s.agent.setSelectedPair(p)
|
||||
if shouldSwitch := s.agent.userBindingRequestHandler(message, local, remote, pair); shouldSwitch {
|
||||
s.agent.setSelectedPair(pair)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -298,7 +326,7 @@ type liteSelector struct {
|
||||
pairCandidateSelector
|
||||
}
|
||||
|
||||
// A lite selector should not contact candidates
|
||||
// A lite selector should not contact candidates.
|
||||
func (s *liteSelector) ContactCandidates() {
|
||||
if _, ok := s.pairCandidateSelector.(*controllingSelector); ok {
|
||||
//nolint:godox
|
||||
|
@@ -23,6 +23,8 @@ import (
|
||||
)
|
||||
|
||||
func sendUntilDone(t *testing.T, writingConn, readingConn net.Conn, maxAttempts int) bool {
|
||||
t.Helper()
|
||||
|
||||
testMessage := []byte("Hello World")
|
||||
testBuffer := make([]byte, len(testMessage))
|
||||
|
||||
@@ -73,6 +75,7 @@ func TestBindingRequestHandler(t *testing.T) {
|
||||
CheckInterval: &oneHour,
|
||||
BindingRequestHandler: func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool {
|
||||
controlledLoggingFired.Store(true)
|
||||
|
||||
return false
|
||||
},
|
||||
})
|
||||
@@ -87,6 +90,7 @@ func TestBindingRequestHandler(t *testing.T) {
|
||||
BindingRequestHandler: func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool {
|
||||
// Don't switch candidate pair until we are ready
|
||||
val, ok := switchToNewCandidatePair.Load().(bool)
|
||||
|
||||
return ok && val
|
||||
},
|
||||
})
|
||||
|
2
stats.go
2
stats.go
@@ -7,7 +7,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// CandidatePairStats contains ICE candidate pair statistics
|
||||
// CandidatePairStats contains ICE candidate pair statistics.
|
||||
type CandidatePairStats struct {
|
||||
// Timestamp is the timestamp associated with this object.
|
||||
Timestamp time.Time
|
||||
|
62
tcp_mux.go
62
tcp_mux.go
@@ -79,20 +79,20 @@ func NewTCPMuxDefault(params TCPMuxParams) *TCPMuxDefault {
|
||||
params.AliveDurationForConnFromStun = 30 * time.Second
|
||||
}
|
||||
|
||||
m := &TCPMuxDefault{
|
||||
mux := &TCPMuxDefault{
|
||||
params: ¶ms,
|
||||
|
||||
connsIPv4: map[string]map[ipAddr]*tcpPacketConn{},
|
||||
connsIPv6: map[string]map[ipAddr]*tcpPacketConn{},
|
||||
}
|
||||
|
||||
m.wg.Add(1)
|
||||
mux.wg.Add(1)
|
||||
go func() {
|
||||
defer m.wg.Done()
|
||||
m.start()
|
||||
defer mux.wg.Done()
|
||||
mux.start()
|
||||
}()
|
||||
|
||||
return m
|
||||
return mux
|
||||
}
|
||||
|
||||
func (m *TCPMuxDefault) start() {
|
||||
@@ -101,6 +101,7 @@ func (m *TCPMuxDefault) start() {
|
||||
conn, err := m.params.Listener.Accept()
|
||||
if err != nil {
|
||||
m.params.Logger.Infof("Error accepting connection: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -130,6 +131,7 @@ func (m *TCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP)
|
||||
|
||||
if conn, ok := m.getConn(ufrag, isIPv6, local); ok {
|
||||
conn.ClearAliveTimer()
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
@@ -191,12 +193,17 @@ func (m *TCPMuxDefault) closeAndLogError(closer io.Closer) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TCPMuxDefault) handleConn(conn net.Conn) {
|
||||
func (m *TCPMuxDefault) handleConn(conn net.Conn) { //nolint:cyclop
|
||||
buf := make([]byte, 512)
|
||||
|
||||
if m.params.FirstStunBindTimeout > 0 {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(m.params.FirstStunBindTimeout)); err != nil {
|
||||
m.params.Logger.Warnf("Failed to set read deadline for first STUN message: %s to %s , err: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
|
||||
m.params.Logger.Warnf(
|
||||
"Failed to set read deadline for first STUN message: %s to %s , err: %s",
|
||||
conn.RemoteAddr(),
|
||||
conn.LocalAddr(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
n, err := readStreamingPacket(conn, buf)
|
||||
@@ -207,6 +214,7 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
|
||||
m.params.Logger.Warnf("Error reading first packet from %s: %s", conn.RemoteAddr(), err)
|
||||
}
|
||||
m.closeAndLogError(conn)
|
||||
|
||||
return
|
||||
}
|
||||
if err = conn.SetReadDeadline(time.Time{}); err != nil {
|
||||
@@ -223,12 +231,14 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
|
||||
if err = msg.Decode(); err != nil {
|
||||
m.closeAndLogError(conn)
|
||||
m.params.Logger.Warnf("Failed to handle decode ICE from %s to %s: %v", conn.RemoteAddr(), conn.LocalAddr(), err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if m == nil || msg.Type.Method != stun.MethodBinding { // Not a STUN
|
||||
m.closeAndLogError(conn)
|
||||
m.params.Logger.Warnf("Not a STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -239,7 +249,12 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
|
||||
attr, err := msg.Get(stun.AttrUsername)
|
||||
if err != nil {
|
||||
m.closeAndLogError(conn)
|
||||
m.params.Logger.Warnf("No Username attribute in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
m.params.Logger.Warnf(
|
||||
"No Username attribute in STUN message from %s to %s",
|
||||
conn.RemoteAddr(),
|
||||
conn.LocalAddr(),
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -249,7 +264,12 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
|
||||
host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
m.closeAndLogError(conn)
|
||||
m.params.Logger.Warnf("Failed to get host in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
m.params.Logger.Warnf(
|
||||
"Failed to get host in STUN message from %s to %s",
|
||||
conn.RemoteAddr(),
|
||||
conn.LocalAddr(),
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -258,7 +278,12 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
|
||||
localAddr, ok := conn.LocalAddr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
m.closeAndLogError(conn)
|
||||
m.params.Logger.Warnf("Failed to get local tcp address in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
m.params.Logger.Warnf(
|
||||
"Failed to get local tcp address in STUN message from %s to %s",
|
||||
conn.RemoteAddr(),
|
||||
conn.LocalAddr(),
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
@@ -269,7 +294,12 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
|
||||
if err != nil {
|
||||
m.mu.Unlock()
|
||||
m.closeAndLogError(conn)
|
||||
m.params.Logger.Warnf("Failed to create packetConn for STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
m.params.Logger.Warnf(
|
||||
"Failed to create packetConn for STUN message from %s to %s",
|
||||
conn.RemoteAddr(),
|
||||
conn.LocalAddr(),
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -277,7 +307,13 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) {
|
||||
|
||||
if err := packetConn.AddConn(conn, buf); err != nil {
|
||||
m.closeAndLogError(conn)
|
||||
m.params.Logger.Warnf("Error adding conn to tcpPacketConn from %s to %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
|
||||
m.params.Logger.Warnf(
|
||||
"Error adding conn to tcpPacketConn from %s to %s: %s",
|
||||
conn.RemoteAddr(),
|
||||
conn.LocalAddr(),
|
||||
err,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -428,7 +464,7 @@ func readStreamingPacket(conn net.Conn, buf []byte) (int, error) {
|
||||
|
||||
func writeStreamingPacket(conn net.Conn, buf []byte) (int, error) {
|
||||
bufCopy := make([]byte, streamingPacketHeaderLen+len(buf))
|
||||
binary.BigEndian.PutUint16(bufCopy, uint16(len(buf)))
|
||||
binary.BigEndian.PutUint16(bufCopy, uint16(len(buf))) //nolint:gosec // G115
|
||||
copy(bufCopy[2:], buf)
|
||||
|
||||
n, err := conn.Write(bufCopy)
|
||||
|
@@ -40,6 +40,7 @@ func (m *MultiTCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local net
|
||||
if len(m.muxes) == 0 {
|
||||
return nil, errNoTCPMuxAvailable
|
||||
}
|
||||
|
||||
return m.muxes[0].GetConnByUfrag(ufrag, isIPv6, local)
|
||||
}
|
||||
|
||||
@@ -51,7 +52,7 @@ func (m *MultiTCPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllConns returns a PacketConn for each underlying TCPMux
|
||||
// GetAllConns returns a PacketConn for each underlying TCPMux.
|
||||
func (m *MultiTCPMuxDefault) GetAllConns(ufrag string, isIPv6 bool, local net.IP) ([]net.PacketConn, error) {
|
||||
if len(m.muxes) == 0 {
|
||||
// Make sure that we either return at least one connection or an error.
|
||||
@@ -68,10 +69,11 @@ func (m *MultiTCPMuxDefault) GetAllConns(ufrag string, isIPv6 bool, local net.IP
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
}
|
||||
|
||||
return conns, nil
|
||||
}
|
||||
|
||||
// Close the multi mux, no further connections could be created
|
||||
// Close the multi mux, no further connections could be created.
|
||||
func (m *MultiTCPMuxDefault) Close() error {
|
||||
var err error
|
||||
for _, mux := range m.muxes {
|
||||
@@ -79,5 +81,6 @@ func (m *MultiTCPMuxDefault) Close() error {
|
||||
err = e
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
@@ -36,6 +36,7 @@ func newBufferedConn(conn net.Conn, bufSize int, logger logging.LeveledLogger) n
|
||||
}
|
||||
|
||||
go bc.writeProcess()
|
||||
|
||||
return bc
|
||||
}
|
||||
|
||||
@@ -44,6 +45,7 @@ func (bc *bufferedConn) Write(b []byte) (int, error) {
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
@@ -57,11 +59,13 @@ func (bc *bufferedConn) writeProcess() {
|
||||
|
||||
if err != nil {
|
||||
bc.logger.Warnf("Failed to read from buffer: %s", err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := bc.Conn.Write(pktBuf[:n]); err != nil {
|
||||
bc.logger.Warnf("Failed to write: %s", err)
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -70,6 +74,7 @@ func (bc *bufferedConn) writeProcess() {
|
||||
func (bc *bufferedConn) Close() error {
|
||||
atomic.StoreInt32(&bc.closed, 1)
|
||||
_ = bc.buf.Close()
|
||||
|
||||
return bc.Conn.Close()
|
||||
}
|
||||
|
||||
@@ -103,7 +108,7 @@ type tcpPacketParams struct {
|
||||
}
|
||||
|
||||
func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn {
|
||||
p := &tcpPacketConn{
|
||||
packet := &tcpPacketConn{
|
||||
params: ¶ms,
|
||||
|
||||
conns: map[string]net.Conn{},
|
||||
@@ -113,13 +118,13 @@ func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn {
|
||||
}
|
||||
|
||||
if params.AliveDuration > 0 {
|
||||
p.aliveTimer = time.AfterFunc(params.AliveDuration, func() {
|
||||
p.params.Logger.Warn("close tcp packet conn by alive timeout")
|
||||
_ = p.Close()
|
||||
packet.aliveTimer = time.AfterFunc(params.AliveDuration, func() {
|
||||
packet.params.Logger.Warn("close tcp packet conn by alive timeout")
|
||||
_ = packet.Close()
|
||||
})
|
||||
}
|
||||
|
||||
return p
|
||||
return packet
|
||||
}
|
||||
|
||||
func (t *tcpPacketConn) ClearAliveTimer() {
|
||||
@@ -131,7 +136,12 @@ func (t *tcpPacketConn) ClearAliveTimer() {
|
||||
}
|
||||
|
||||
func (t *tcpPacketConn) AddConn(conn net.Conn, firstPacketData []byte) error {
|
||||
t.params.Logger.Infof("Added connection: %s remote %s to local %s", conn.RemoteAddr().Network(), conn.RemoteAddr(), conn.LocalAddr())
|
||||
t.params.Logger.Infof(
|
||||
"Added connection: %s remote %s to local %s",
|
||||
conn.RemoteAddr().Network(),
|
||||
conn.RemoteAddr(),
|
||||
conn.LocalAddr(),
|
||||
)
|
||||
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
@@ -183,6 +193,7 @@ func (t *tcpPacketConn) startReading(conn net.Conn) {
|
||||
if last || !(errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed)) {
|
||||
t.handleRecv(streamingPacket{nil, conn.RemoteAddr(), err})
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -236,6 +247,7 @@ func (t *tcpPacketConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) {
|
||||
|
||||
n = len(pkt.Data)
|
||||
copy(b, pkt.Data[:n])
|
||||
|
||||
return n, pkt.RAddr, err
|
||||
}
|
||||
|
||||
@@ -252,6 +264,7 @@ func (t *tcpPacketConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
||||
n, err = writeStreamingPacket(conn, buf)
|
||||
if err != nil {
|
||||
t.params.Logger.Tracef("%w %s", errWrite, rAddr)
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
@@ -272,6 +285,7 @@ func (t *tcpPacketConn) removeConn(conn net.Conn) bool {
|
||||
t.closeAndLogError(conn)
|
||||
|
||||
delete(t.conns, conn.RemoteAddr().String())
|
||||
|
||||
return len(t.conns) == 0
|
||||
}
|
||||
|
||||
|
22
transport.go
22
transport.go
@@ -32,12 +32,12 @@ type Conn struct {
|
||||
agent *Agent
|
||||
}
|
||||
|
||||
// BytesSent returns the number of bytes sent
|
||||
// BytesSent returns the number of bytes sent.
|
||||
func (c *Conn) BytesSent() uint64 {
|
||||
return atomic.LoadUint64(&c.bytesSent)
|
||||
}
|
||||
|
||||
// BytesReceived returns the number of bytes received
|
||||
// BytesReceived returns the number of bytes received.
|
||||
func (c *Conn) BytesReceived() uint64 {
|
||||
return atomic.LoadUint64(&c.bytesReceived)
|
||||
}
|
||||
@@ -74,18 +74,19 @@ func (c *Conn) Read(p []byte) (int, error) {
|
||||
}
|
||||
|
||||
n, err := c.agent.buf.Read(p)
|
||||
atomic.AddUint64(&c.bytesReceived, uint64(n))
|
||||
atomic.AddUint64(&c.bytesReceived, uint64(n)) //nolint:gosec // G115
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Write implements the Conn Write method.
|
||||
func (c *Conn) Write(p []byte) (int, error) {
|
||||
func (c *Conn) Write(packet []byte) (int, error) {
|
||||
err := c.agent.loop.Err()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if stun.IsMessage(p) {
|
||||
if stun.IsMessage(packet) {
|
||||
return 0, errWriteSTUNMessageToIceConn
|
||||
}
|
||||
|
||||
@@ -102,8 +103,9 @@ func (c *Conn) Write(p []byte) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
atomic.AddUint64(&c.bytesSent, uint64(len(p)))
|
||||
return pair.Write(p)
|
||||
atomic.AddUint64(&c.bytesSent, uint64(len(packet)))
|
||||
|
||||
return pair.Write(packet)
|
||||
}
|
||||
|
||||
// Close implements the Conn Close method. It is used to close
|
||||
@@ -132,17 +134,17 @@ func (c *Conn) RemoteAddr() net.Addr {
|
||||
return pair.Remote.addr()
|
||||
}
|
||||
|
||||
// SetDeadline is a stub
|
||||
// SetDeadline is a stub.
|
||||
func (c *Conn) SetDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline is a stub
|
||||
// SetReadDeadline is a stub.
|
||||
func (c *Conn) SetReadDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline is a stub
|
||||
// SetWriteDeadline is a stub.
|
||||
func (c *Conn) SetWriteDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
@@ -30,13 +30,15 @@ func TestStressDuplex(t *testing.T) {
|
||||
stressDuplex(t)
|
||||
}
|
||||
|
||||
func testTimeout(t *testing.T, c *Conn, timeout time.Duration) {
|
||||
func testTimeout(t *testing.T, conn *Conn, timeout time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
const pollRate = 100 * time.Millisecond
|
||||
const margin = 20 * time.Millisecond // Allow 20msec error in time
|
||||
ticker := time.NewTicker(pollRate)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
err := c.Close()
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -49,8 +51,8 @@ func testTimeout(t *testing.T, c *Conn, timeout time.Duration) {
|
||||
|
||||
var cs ConnectionState
|
||||
|
||||
err := c.agent.loop.Run(context.Background(), func(_ context.Context) {
|
||||
cs = c.agent.connectionState
|
||||
err := conn.agent.loop.Run(context.Background(), func(_ context.Context) {
|
||||
cs = conn.agent.connectionState
|
||||
})
|
||||
if err != nil {
|
||||
// We should never get here.
|
||||
@@ -63,6 +65,7 @@ func testTimeout(t *testing.T, c *Conn, timeout time.Duration) {
|
||||
t.Fatalf("Connection timed out %f msec early", elapsed.Seconds()*1000)
|
||||
} else {
|
||||
t.Logf("Connection timed out in %f msec", elapsed.Seconds()*1000)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -133,6 +136,8 @@ func TestReadClosed(t *testing.T) {
|
||||
}
|
||||
|
||||
func stressDuplex(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
ca, cb := pipe(nil)
|
||||
|
||||
defer func() {
|
||||
@@ -219,6 +224,7 @@ func connect(aAgent, bAgent *Agent) (*Conn, *Conn) {
|
||||
|
||||
// Ensure accepted
|
||||
<-accepted
|
||||
|
||||
return aConn, bConn
|
||||
}
|
||||
|
||||
@@ -288,6 +294,7 @@ func pipeWithTimeout(disconnectTimeout time.Duration, iceKeepalive time.Duration
|
||||
|
||||
func onConnected() (func(ConnectionState), chan struct{}) {
|
||||
done := make(chan struct{})
|
||||
|
||||
return func(state ConnectionState) {
|
||||
if state == ConnectionStateConnected {
|
||||
close(done)
|
||||
@@ -295,11 +302,11 @@ func onConnected() (func(ConnectionState), chan struct{}) {
|
||||
}, done
|
||||
}
|
||||
|
||||
func randomPort(t testing.TB) int {
|
||||
t.Helper()
|
||||
func randomPort(tb testing.TB) int {
|
||||
tb.Helper()
|
||||
conn, err := net.ListenPacket("udp4", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to pickPort: %v", err)
|
||||
tb.Fatalf("failed to pickPort: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
@@ -308,7 +315,8 @@ func randomPort(t testing.TB) int {
|
||||
case *net.UDPAddr:
|
||||
return addr.Port
|
||||
default:
|
||||
t.Fatalf("unknown addr type %T", addr)
|
||||
tb.Fatalf("unknown addr type %T", addr)
|
||||
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
@@ -30,9 +30,9 @@ func TestRemoteLocalAddr(t *testing.T) {
|
||||
// Agent1 is behind 1:1 NAT
|
||||
natType1 := &vnet.NATType{Mode: vnet.NATModeNAT1To1}
|
||||
|
||||
v, errVnet := buildVNet(natType0, natType1)
|
||||
builtVnet, errVnet := buildVNet(natType0, natType1)
|
||||
require.NoError(t, errVnet, "should succeed")
|
||||
defer v.close()
|
||||
defer builtVnet.close()
|
||||
|
||||
stunServerURL := &stun.URI{
|
||||
Scheme: stun.SchemeTypeSTUN,
|
||||
@@ -53,7 +53,7 @@ func TestRemoteLocalAddr(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Remote/Local Pair Match between Agents", func(t *testing.T) {
|
||||
ca, cb := pipeWithVNet(v,
|
||||
ca, cb := pipeWithVNet(builtVnet,
|
||||
&agentTestConfig{
|
||||
urls: []*stun.URI{stunServerURL},
|
||||
},
|
||||
|
48
udp_mux.go
48
udp_mux.go
@@ -18,7 +18,7 @@ import (
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
)
|
||||
|
||||
// UDPMux allows multiple connections to go over a single UDP port
|
||||
// UDPMux allows multiple connections to go over a single UDP port.
|
||||
type UDPMux interface {
|
||||
io.Closer
|
||||
GetConn(ufrag string, addr net.Addr) (net.PacketConn, error)
|
||||
@@ -26,7 +26,7 @@ type UDPMux interface {
|
||||
GetListenAddresses() []net.Addr
|
||||
}
|
||||
|
||||
// UDPMuxDefault is an implementation of the interface
|
||||
// UDPMuxDefault is an implementation of the interface.
|
||||
type UDPMuxDefault struct {
|
||||
params UDPMuxParams
|
||||
|
||||
@@ -60,14 +60,14 @@ type UDPMuxParams struct {
|
||||
Net transport.Net
|
||||
}
|
||||
|
||||
// NewUDPMuxDefault creates an implementation of UDPMux
|
||||
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
// NewUDPMuxDefault creates an implementation of UDPMux.
|
||||
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { //nolint:cyclop
|
||||
if params.Logger == nil {
|
||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
||||
}
|
||||
|
||||
var localAddrsForUnspecified []net.Addr
|
||||
if udpAddr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
|
||||
if udpAddr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { //nolint:nestif
|
||||
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
|
||||
} else if ok && udpAddr.IP.IsUnspecified() {
|
||||
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
|
||||
@@ -109,7 +109,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
}
|
||||
params.UDPConnString = params.UDPConn.LocalAddr().String()
|
||||
|
||||
m := &UDPMuxDefault{
|
||||
mux := &UDPMuxDefault{
|
||||
addressMap: map[ipPort]*udpMuxedConn{},
|
||||
params: params,
|
||||
connsIPv4: make(map[string]*udpMuxedConn),
|
||||
@@ -124,17 +124,17 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
localAddrsForUnspecified: localAddrsForUnspecified,
|
||||
}
|
||||
|
||||
go m.connWorker()
|
||||
go mux.connWorker()
|
||||
|
||||
return m
|
||||
return mux
|
||||
}
|
||||
|
||||
// LocalAddr returns the listening address of this UDPMuxDefault
|
||||
// LocalAddr returns the listening address of this UDPMuxDefault.
|
||||
func (m *UDPMuxDefault) LocalAddr() net.Addr {
|
||||
return m.params.UDPConn.LocalAddr()
|
||||
}
|
||||
|
||||
// GetListenAddresses returns the list of addresses that this mux is listening on
|
||||
// GetListenAddresses returns the list of addresses that this mux is listening on.
|
||||
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
||||
if len(m.localAddrsForUnspecified) > 0 {
|
||||
return m.localAddrsForUnspecified
|
||||
@@ -143,8 +143,8 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
||||
return []net.Addr{m.LocalAddr()}
|
||||
}
|
||||
|
||||
// GetConn returns a PacketConn given the connection's ufrag and network address
|
||||
// creates the connection if an existing one can't be found
|
||||
// GetConn returns a PacketConn given the connection's ufrag and network address.
|
||||
// creates the connection if an existing one can't be found.
|
||||
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
|
||||
// don't check addr for mux using unspecified address
|
||||
if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConnString != addr.String() {
|
||||
@@ -181,11 +181,11 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// RemoveConnByUfrag stops and removes the muxed packet connection
|
||||
// RemoveConnByUfrag stops and removes the muxed packet connection.
|
||||
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
||||
removedConns := make([]*udpMuxedConn, 0, 2)
|
||||
|
||||
// Keep lock section small to avoid deadlock with conn lock
|
||||
// Keep lock section small to avoid deadlock with conn lock.
|
||||
m.mu.Lock()
|
||||
if c, ok := m.connsIPv4[ufrag]; ok {
|
||||
delete(m.connsIPv4, ufrag)
|
||||
@@ -198,7 +198,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
||||
m.mu.Unlock()
|
||||
|
||||
if len(removedConns) == 0 {
|
||||
// No need to lock if no connection was found
|
||||
// No need to lock if no connection was found.
|
||||
return
|
||||
}
|
||||
|
||||
@@ -213,7 +213,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
||||
}
|
||||
}
|
||||
|
||||
// IsClosed returns true if the mux had been closed
|
||||
// IsClosed returns true if the mux had been closed.
|
||||
func (m *UDPMuxDefault) IsClosed() bool {
|
||||
select {
|
||||
case <-m.closedChan:
|
||||
@@ -223,7 +223,7 @@ func (m *UDPMuxDefault) IsClosed() bool {
|
||||
}
|
||||
}
|
||||
|
||||
// Close the mux, no further connections could be created
|
||||
// Close the mux, no further connections could be created.
|
||||
func (m *UDPMuxDefault) Close() error {
|
||||
var err error
|
||||
m.closeOnce.Do(func() {
|
||||
@@ -244,6 +244,7 @@ func (m *UDPMuxDefault) Close() error {
|
||||
|
||||
_ = m.params.UDPConn.Close()
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -276,10 +277,11 @@ func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
|
||||
LocalAddr: m.LocalAddr(),
|
||||
Logger: m.params.Logger,
|
||||
})
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (m *UDPMuxDefault) connWorker() {
|
||||
func (m *UDPMuxDefault) connWorker() { //nolint:cyclop
|
||||
logger := m.params.Logger
|
||||
|
||||
defer func() {
|
||||
@@ -304,11 +306,13 @@ func (m *UDPMuxDefault) connWorker() {
|
||||
netUDPAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
logger.Errorf("Underlying PacketConn did not return a UDPAddr")
|
||||
|
||||
return
|
||||
}
|
||||
udpAddr, err := newIPPort(netUDPAddr.IP, netUDPAddr.Zone, uint16(netUDPAddr.Port))
|
||||
udpAddr, err := newIPPort(netUDPAddr.IP, netUDPAddr.Zone, uint16(netUDPAddr.Port)) //nolint:gosec
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to create a new IP/Port host pair")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -325,12 +329,14 @@ func (m *UDPMuxDefault) connWorker() {
|
||||
|
||||
if err = msg.Decode(); err != nil {
|
||||
m.params.Logger.Warnf("Failed to handle decode ICE from %s: %v", addr.String(), err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
attr, stunAttrErr := msg.Get(stun.AttrUsername)
|
||||
if stunAttrErr != nil {
|
||||
m.params.Logger.Warnf("No Username attribute in STUN message from %s", addr.String())
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -344,6 +350,7 @@ func (m *UDPMuxDefault) connWorker() {
|
||||
|
||||
if destinationConn == nil {
|
||||
m.params.Logger.Tracef("Dropping packet from %s, addr: %s", udpAddr.addr, addr)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -359,6 +366,7 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o
|
||||
} else {
|
||||
val, ok = m.connsIPv4[ufrag]
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -386,7 +394,7 @@ type ipPort struct {
|
||||
|
||||
// newIPPort create a custom type of address based on netip.Addr and
|
||||
// port. The underlying ip address passed is converted to IPv6 format
|
||||
// to simplify ip address handling
|
||||
// to simplify ip address handling.
|
||||
func newIPPort(ip net.IP, zone string, port uint16) (ipPort, error) {
|
||||
n, ok := netip.AddrFromSlice(ip.To16())
|
||||
if !ok {
|
||||
|
@@ -29,6 +29,7 @@ func NewMultiUDPMuxDefault(muxes ...UDPMux) *MultiUDPMuxDefault {
|
||||
addrToMux[addr.String()] = mux
|
||||
}
|
||||
}
|
||||
|
||||
return &MultiUDPMuxDefault{
|
||||
muxes: muxes,
|
||||
localAddrToMux: addrToMux,
|
||||
@@ -42,6 +43,7 @@ func (m *MultiUDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketCon
|
||||
if !ok {
|
||||
return nil, errNoUDPMuxAvailable
|
||||
}
|
||||
|
||||
return mux.GetConn(ufrag, addr)
|
||||
}
|
||||
|
||||
@@ -53,7 +55,7 @@ func (m *MultiUDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
||||
}
|
||||
}
|
||||
|
||||
// Close the multi mux, no further connections could be created
|
||||
// Close the multi mux, no further connections could be created.
|
||||
func (m *MultiUDPMuxDefault) Close() error {
|
||||
var err error
|
||||
for _, mux := range m.muxes {
|
||||
@@ -61,21 +63,23 @@ func (m *MultiUDPMuxDefault) Close() error {
|
||||
err = e
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// GetListenAddresses returns the list of addresses that this mux is listening on
|
||||
// GetListenAddresses returns the list of addresses that this mux is listening on.
|
||||
func (m *MultiUDPMuxDefault) GetListenAddresses() []net.Addr {
|
||||
addrs := make([]net.Addr, 0, len(m.localAddrToMux))
|
||||
for _, mux := range m.muxes {
|
||||
addrs = append(addrs, mux.GetListenAddresses()...)
|
||||
}
|
||||
|
||||
return addrs
|
||||
}
|
||||
|
||||
// NewMultiUDPMuxFromPort creates an instance of MultiUDPMuxDefault that
|
||||
// listen all interfaces on the provided port.
|
||||
func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMuxDefault, error) {
|
||||
func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMuxDefault, error) { //nolint:cyclop
|
||||
params := multiUDPMuxFromPortParam{
|
||||
networks: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
|
||||
}
|
||||
@@ -104,6 +108,7 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu
|
||||
})
|
||||
if listenErr != nil {
|
||||
err = listenErr
|
||||
|
||||
break
|
||||
}
|
||||
if params.readBufferSize > 0 {
|
||||
@@ -119,6 +124,7 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu
|
||||
for _, conn := range conns {
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -135,7 +141,7 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu
|
||||
return NewMultiUDPMuxDefault(muxes...), nil
|
||||
}
|
||||
|
||||
// UDPMuxFromPortOption provide options for NewMultiUDPMuxFromPort
|
||||
// UDPMuxFromPortOption provide options for NewMultiUDPMuxFromPort.
|
||||
type UDPMuxFromPortOption interface {
|
||||
apply(*multiUDPMuxFromPortParam)
|
||||
}
|
||||
@@ -159,7 +165,7 @@ func (o *udpMuxFromPortOption) apply(p *multiUDPMuxFromPortParam) {
|
||||
o.f(p)
|
||||
}
|
||||
|
||||
// UDPMuxFromPortWithInterfaceFilter set the filter to filter out interfaces that should not be used
|
||||
// UDPMuxFromPortWithInterfaceFilter set the filter to filter out interfaces that should not be used.
|
||||
func UDPMuxFromPortWithInterfaceFilter(f func(string) (keep bool)) UDPMuxFromPortOption {
|
||||
return &udpMuxFromPortOption{
|
||||
f: func(p *multiUDPMuxFromPortParam) {
|
||||
@@ -168,7 +174,7 @@ func UDPMuxFromPortWithInterfaceFilter(f func(string) (keep bool)) UDPMuxFromPor
|
||||
}
|
||||
}
|
||||
|
||||
// UDPMuxFromPortWithIPFilter set the filter to filter out IP addresses that should not be used
|
||||
// UDPMuxFromPortWithIPFilter set the filter to filter out IP addresses that should not be used.
|
||||
func UDPMuxFromPortWithIPFilter(f func(ip net.IP) (keep bool)) UDPMuxFromPortOption {
|
||||
return &udpMuxFromPortOption{
|
||||
f: func(p *multiUDPMuxFromPortParam) {
|
||||
@@ -177,7 +183,7 @@ func UDPMuxFromPortWithIPFilter(f func(ip net.IP) (keep bool)) UDPMuxFromPortOpt
|
||||
}
|
||||
}
|
||||
|
||||
// UDPMuxFromPortWithNetworks set the networks that should be used. default is both IPv4 and IPv6
|
||||
// UDPMuxFromPortWithNetworks set the networks that should be used. default is both IPv4 and IPv6.
|
||||
func UDPMuxFromPortWithNetworks(networks ...NetworkType) UDPMuxFromPortOption {
|
||||
return &udpMuxFromPortOption{
|
||||
f: func(p *multiUDPMuxFromPortParam) {
|
||||
@@ -186,7 +192,7 @@ func UDPMuxFromPortWithNetworks(networks ...NetworkType) UDPMuxFromPortOption {
|
||||
}
|
||||
}
|
||||
|
||||
// UDPMuxFromPortWithReadBufferSize set the UDP connection read buffer size
|
||||
// UDPMuxFromPortWithReadBufferSize set the UDP connection read buffer size.
|
||||
func UDPMuxFromPortWithReadBufferSize(size int) UDPMuxFromPortOption {
|
||||
return &udpMuxFromPortOption{
|
||||
f: func(p *multiUDPMuxFromPortParam) {
|
||||
@@ -195,7 +201,7 @@ func UDPMuxFromPortWithReadBufferSize(size int) UDPMuxFromPortOption {
|
||||
}
|
||||
}
|
||||
|
||||
// UDPMuxFromPortWithWriteBufferSize set the UDP connection write buffer size
|
||||
// UDPMuxFromPortWithWriteBufferSize set the UDP connection write buffer size.
|
||||
func UDPMuxFromPortWithWriteBufferSize(size int) UDPMuxFromPortOption {
|
||||
return &udpMuxFromPortOption{
|
||||
f: func(p *multiUDPMuxFromPortParam) {
|
||||
@@ -204,7 +210,7 @@ func UDPMuxFromPortWithWriteBufferSize(size int) UDPMuxFromPortOption {
|
||||
}
|
||||
}
|
||||
|
||||
// UDPMuxFromPortWithLogger set the logger for the created UDPMux
|
||||
// UDPMuxFromPortWithLogger set the logger for the created UDPMux.
|
||||
func UDPMuxFromPortWithLogger(logger logging.LeveledLogger) UDPMuxFromPortOption {
|
||||
return &udpMuxFromPortOption{
|
||||
f: func(p *multiUDPMuxFromPortParam) {
|
||||
@@ -213,7 +219,7 @@ func UDPMuxFromPortWithLogger(logger logging.LeveledLogger) UDPMuxFromPortOption
|
||||
}
|
||||
}
|
||||
|
||||
// UDPMuxFromPortWithLoopback set loopback interface should be included
|
||||
// UDPMuxFromPortWithLoopback set loopback interface should be included.
|
||||
func UDPMuxFromPortWithLoopback() UDPMuxFromPortOption {
|
||||
return &udpMuxFromPortOption{
|
||||
f: func(p *multiUDPMuxFromPortParam) {
|
||||
|
@@ -79,6 +79,8 @@ func TestMultiUDPMux(t *testing.T) {
|
||||
}
|
||||
|
||||
func testMultiUDPMuxConnections(t *testing.T, udpMuxMulti *MultiUDPMuxDefault, ufrag string, network string) {
|
||||
t.Helper()
|
||||
|
||||
addrs := udpMuxMulti.GetListenAddresses()
|
||||
pktConns := make([]net.PacketConn, 0, len(addrs))
|
||||
for _, addr := range addrs {
|
||||
|
@@ -20,7 +20,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUDPMux(t *testing.T) {
|
||||
func TestUDPMux(t *testing.T) { //nolint:cyclop
|
||||
defer test.CheckRoutines(t)()
|
||||
|
||||
defer test.TimeOut(time.Second * 30).Stop()
|
||||
@@ -127,6 +127,8 @@ func TestUDPMux(t *testing.T) {
|
||||
}
|
||||
|
||||
func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) {
|
||||
t.Helper()
|
||||
|
||||
pktConn, err := udpMux.GetConn(ufrag, udpMux.LocalAddr())
|
||||
require.NoError(t, err, "error retrieving muxed connection for ufrag")
|
||||
defer func() {
|
||||
@@ -145,6 +147,8 @@ func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, networ
|
||||
}
|
||||
|
||||
func testMuxConnectionPair(t *testing.T, pktConn net.PacketConn, remoteConn *net.UDPConn, ufrag string) {
|
||||
t.Helper()
|
||||
|
||||
// Initial messages are dropped
|
||||
_, err := remoteConn.Write([]byte("dropped bytes"))
|
||||
require.NoError(t, err)
|
||||
@@ -222,7 +226,7 @@ func testMuxConnectionPair(t *testing.T, pktConn net.PacketConn, remoteConn *net
|
||||
require.NoError(t, err)
|
||||
h := sha256.Sum256(buf[36:])
|
||||
copy(buf[4:36], h[:])
|
||||
binary.LittleEndian.PutUint32(buf[0:4], uint32(sequence))
|
||||
binary.LittleEndian.PutUint32(buf[0:4], uint32(sequence)) //nolint:gosec // G115
|
||||
|
||||
_, err = remoteConn.Write(buf)
|
||||
require.NoError(t, err)
|
||||
@@ -238,6 +242,8 @@ func testMuxConnectionPair(t *testing.T, pktConn net.PacketConn, remoteConn *net
|
||||
}
|
||||
|
||||
func verifyPacket(t *testing.T, b []byte, nextSeq uint32) {
|
||||
t.Helper()
|
||||
|
||||
readSeq := binary.LittleEndian.Uint32(b[0:4])
|
||||
require.Equal(t, nextSeq, readSeq)
|
||||
h := sha256.Sum256(b[36:])
|
||||
|
@@ -29,7 +29,8 @@ type UniversalUDPMuxDefault struct {
|
||||
*UDPMuxDefault
|
||||
params UniversalUDPMuxParams
|
||||
|
||||
// Since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
|
||||
// Since we have a shared socket, for srflx candidates it makes sense
|
||||
// to have a shared mapped address across all the agents
|
||||
// stun.XORMappedAddress indexed by the STUN server addr
|
||||
xorMappedMap map[string]*xorMapped
|
||||
}
|
||||
@@ -42,7 +43,7 @@ type UniversalUDPMuxParams struct {
|
||||
Net transport.Net
|
||||
}
|
||||
|
||||
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
||||
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux.
|
||||
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
|
||||
if params.Logger == nil {
|
||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
||||
@@ -51,31 +52,31 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
||||
params.XORMappedAddrCacheTTL = time.Second * 25
|
||||
}
|
||||
|
||||
m := &UniversalUDPMuxDefault{
|
||||
mux := &UniversalUDPMuxDefault{
|
||||
params: params,
|
||||
xorMappedMap: make(map[string]*xorMapped),
|
||||
}
|
||||
|
||||
// Wrap UDP connection, process server reflexive messages
|
||||
// before they are passed to the UDPMux connection handler (connWorker)
|
||||
m.params.UDPConn = &udpConn{
|
||||
mux.params.UDPConn = &udpConn{
|
||||
PacketConn: params.UDPConn,
|
||||
mux: m,
|
||||
mux: mux,
|
||||
logger: params.Logger,
|
||||
}
|
||||
|
||||
// Embed UDPMux
|
||||
udpMuxParams := UDPMuxParams{
|
||||
Logger: params.Logger,
|
||||
UDPConn: m.params.UDPConn,
|
||||
Net: m.params.Net,
|
||||
UDPConn: mux.params.UDPConn,
|
||||
Net: mux.params.Net,
|
||||
}
|
||||
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
|
||||
mux.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
|
||||
|
||||
return m
|
||||
return mux
|
||||
}
|
||||
|
||||
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
||||
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets.
|
||||
type udpConn struct {
|
||||
net.PacketConn
|
||||
mux *UniversalUDPMuxDefault
|
||||
@@ -88,7 +89,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(net.Addr, time.Duration) (*net.A
|
||||
return nil, errNotImplemented
|
||||
}
|
||||
|
||||
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
|
||||
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL
|
||||
// (e.g. STUN URL) to be able to support multiple STUN/TURN servers
|
||||
// and return a unique connection per server.
|
||||
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
|
||||
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
|
||||
@@ -99,24 +101,24 @@ func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr ne
|
||||
func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, addr, err = c.PacketConn.ReadFrom(p)
|
||||
if err != nil {
|
||||
return
|
||||
return n, addr, err
|
||||
}
|
||||
|
||||
if stun.IsMessage(p[:n]) {
|
||||
if stun.IsMessage(p[:n]) { //nolint:nestif
|
||||
msg := &stun.Message{
|
||||
Raw: append([]byte{}, p[:n]...),
|
||||
}
|
||||
|
||||
if err = msg.Decode(); err != nil {
|
||||
c.logger.Warnf("Failed to handle decode ICE from %s: %v", addr.String(), err)
|
||||
err = nil
|
||||
return
|
||||
|
||||
return n, addr, nil
|
||||
}
|
||||
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
// Message about this err will be logged in the UDPMux
|
||||
return
|
||||
return n, addr, err
|
||||
}
|
||||
|
||||
if c.mux.isXORMappedResponse(msg, udpAddr.String()) {
|
||||
@@ -125,9 +127,11 @@ func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
c.logger.Debugf("%w: %v", errGetXorMappedAddrResponse, err)
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
|
||||
return n, addr, err
|
||||
}
|
||||
}
|
||||
|
||||
return n, addr, err
|
||||
}
|
||||
|
||||
@@ -135,14 +139,16 @@ func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
// Check first if it is a STUN server address because remote peer can also send similar messages but as a BindingSuccess
|
||||
// Check first if it is a STUN server address,
|
||||
// because remote peer can also send similar messages but as a BindingSuccess.
|
||||
_, ok := m.xorMappedMap[stunAddr]
|
||||
_, err := msg.Get(stun.AttrXORMappedAddress)
|
||||
|
||||
return err == nil && ok
|
||||
}
|
||||
|
||||
// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute
|
||||
// and set the mapped address for the server
|
||||
// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute.
|
||||
// and set the mapped address for the server.
|
||||
func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -167,7 +173,10 @@ func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr,
|
||||
// Makes a STUN binding request to discover mapped address otherwise.
|
||||
// Blocks until the stun.XORMappedAddress has been discovered or deadline.
|
||||
// Method is safe for concurrent use.
|
||||
func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) {
|
||||
func (m *UniversalUDPMuxDefault) GetXORMappedAddr(
|
||||
serverAddr net.Addr,
|
||||
deadline time.Duration,
|
||||
) (*stun.XORMappedAddress, error) {
|
||||
m.mu.Lock()
|
||||
mappedAddr, ok := m.xorMappedMap[serverAddr.String()]
|
||||
// If we already have a mapping for this STUN server (address already received)
|
||||
@@ -203,6 +212,7 @@ func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline
|
||||
if mappedAddr.addr == nil {
|
||||
return nil, errNoXorAddrMapping
|
||||
}
|
||||
|
||||
return mappedAddr.addr, nil
|
||||
case <-time.After(deadline):
|
||||
return nil, errXORMappedAddrTimeout
|
||||
|
@@ -43,6 +43,8 @@ func TestUniversalUDPMux(t *testing.T) {
|
||||
}
|
||||
|
||||
func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) {
|
||||
t.Helper()
|
||||
|
||||
pktConn, err := udpMux.GetConn(ufrag, udpMux.LocalAddr())
|
||||
require.NoError(t, err, "error retrieving muxed connection for ufrag")
|
||||
defer func() {
|
||||
|
@@ -28,7 +28,7 @@ type udpMuxedConnParams struct {
|
||||
Logger logging.LeveledLogger
|
||||
}
|
||||
|
||||
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
|
||||
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag.
|
||||
type udpMuxedConn struct {
|
||||
params *udpMuxedConnParams
|
||||
// Remote addresses that we have sent to on this conn
|
||||
@@ -72,11 +72,12 @@ func (c *udpMuxedConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) {
|
||||
pkt.reset()
|
||||
c.params.AddrPool.Put(pkt)
|
||||
|
||||
return
|
||||
return n, rAddr, err
|
||||
}
|
||||
|
||||
if c.state == udpMuxedConnClosed {
|
||||
c.mu.Unlock()
|
||||
|
||||
return 0, nil, io.EOF
|
||||
}
|
||||
|
||||
@@ -101,6 +102,7 @@ func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
||||
return 0, errFailedToCastUDPAddr
|
||||
}
|
||||
|
||||
//nolint:gosec // TODO add port validation G115
|
||||
ipAndPort, err := newIPPort(netUDPAddr.IP, netUDPAddr.Zone, uint16(netUDPAddr.Port))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -150,12 +152,14 @@ func (c *udpMuxedConn) Close() error {
|
||||
c.state = udpMuxedConnClosed
|
||||
close(c.closedChan)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *udpMuxedConn) isClosed() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
return c.state == udpMuxedConnClosed
|
||||
}
|
||||
|
||||
@@ -164,6 +168,7 @@ func (c *udpMuxedConn) getAddresses() []ipPort {
|
||||
defer c.mu.Unlock()
|
||||
addresses := make([]ipPort, len(c.addresses))
|
||||
copy(addresses, c.addresses)
|
||||
|
||||
return addresses
|
||||
}
|
||||
|
||||
@@ -198,6 +203,7 @@ func (c *udpMuxedConn) containsAddress(addr ipPort) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -205,6 +211,7 @@ func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error {
|
||||
pkt := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
|
||||
if cap(pkt.buf) < len(data) {
|
||||
c.params.AddrPool.Put(pkt)
|
||||
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
|
||||
|
36
url.go
36
url.go
@@ -6,77 +6,77 @@ package ice
|
||||
import "github.com/pion/stun/v3"
|
||||
|
||||
type (
|
||||
// URL represents a STUN (rfc7064) or TURN (rfc7065) URI
|
||||
// URL represents a STUN (rfc7064) or TURN (rfc7065) URI.
|
||||
//
|
||||
// Deprecated: Please use pion/stun.URI
|
||||
// Deprecated: Please use pion/stun.URI.
|
||||
URL = stun.URI
|
||||
|
||||
// ProtoType indicates the transport protocol type that is used in the ice.URL
|
||||
// structure.
|
||||
//
|
||||
// Deprecated: TPlease use pion/stun.ProtoType
|
||||
// Deprecated: TPlease use pion/stun.ProtoType.
|
||||
ProtoType = stun.ProtoType
|
||||
|
||||
// SchemeType indicates the type of server used in the ice.URL structure.
|
||||
//
|
||||
// Deprecated: Please use pion/stun.SchemeType
|
||||
// Deprecated: Please use pion/stun.SchemeType.
|
||||
SchemeType = stun.SchemeType
|
||||
)
|
||||
|
||||
const (
|
||||
// SchemeTypeSTUN indicates the URL represents a STUN server.
|
||||
//
|
||||
// Deprecated: Please use pion/stun.SchemeTypeSTUN
|
||||
// Deprecated: Please use pion/stun.SchemeTypeSTUN.
|
||||
SchemeTypeSTUN = stun.SchemeTypeSTUN
|
||||
|
||||
// SchemeTypeSTUNS indicates the URL represents a STUNS (secure) server.
|
||||
//
|
||||
// Deprecated: Please use pion/stun.SchemeTypeSTUNS
|
||||
// Deprecated: Please use pion/stun.SchemeTypeSTUNS.
|
||||
SchemeTypeSTUNS = stun.SchemeTypeSTUNS
|
||||
|
||||
// SchemeTypeTURN indicates the URL represents a TURN server.
|
||||
//
|
||||
// Deprecated: Please use pion/stun.SchemeTypeTURN
|
||||
// Deprecated: Please use pion/stun.SchemeTypeTURN.
|
||||
SchemeTypeTURN = stun.SchemeTypeTURN
|
||||
|
||||
// SchemeTypeTURNS indicates the URL represents a TURNS (secure) server.
|
||||
//
|
||||
// Deprecated: Please use pion/stun.SchemeTypeTURNS
|
||||
// Deprecated: Please use pion/stun.SchemeTypeTURNS.
|
||||
SchemeTypeTURNS = stun.SchemeTypeTURNS
|
||||
)
|
||||
|
||||
const (
|
||||
// ProtoTypeUDP indicates the URL uses a UDP transport.
|
||||
//
|
||||
// Deprecated: Please use pion/stun.ProtoTypeUDP
|
||||
// Deprecated: Please use pion/stun.ProtoTypeUDP.
|
||||
ProtoTypeUDP = stun.ProtoTypeUDP
|
||||
|
||||
// ProtoTypeTCP indicates the URL uses a TCP transport.
|
||||
//
|
||||
// Deprecated: Please use pion/stun.ProtoTypeTCP
|
||||
// Deprecated: Please use pion/stun.ProtoTypeTCP.
|
||||
ProtoTypeTCP = stun.ProtoTypeTCP
|
||||
)
|
||||
|
||||
// Unknown represents and unknown ProtoType or SchemeType
|
||||
// Unknown represents and unknown ProtoType or SchemeType.
|
||||
//
|
||||
// Deprecated: Please use pion/stun.SchemeTypeUnknown or pion/stun.ProtoTypeUnknown
|
||||
// Deprecated: Please use pion/stun.SchemeTypeUnknown or pion/stun.ProtoTypeUnknown.
|
||||
const Unknown = 0
|
||||
|
||||
// ParseURL parses a STUN or TURN urls following the ABNF syntax described in
|
||||
// ParseURL parses a STUN or TURN urls following the ABNF syntax described in.
|
||||
// https://tools.ietf.org/html/rfc7064 and https://tools.ietf.org/html/rfc7065
|
||||
// respectively.
|
||||
//
|
||||
// Deprecated: Please use pion/stun.ParseURI
|
||||
// Deprecated: Please use pion/stun.ParseURI.
|
||||
var ParseURL = stun.ParseURI //nolint:gochecknoglobals
|
||||
|
||||
// NewSchemeType defines a procedure for creating a new SchemeType from a raw
|
||||
// NewSchemeType defines a procedure for creating a new SchemeType from a raw.
|
||||
// string naming the scheme type.
|
||||
//
|
||||
// Deprecated: Please use pion/stun.NewSchemeType
|
||||
// Deprecated: Please use pion/stun.NewSchemeType.
|
||||
var NewSchemeType = stun.NewSchemeType //nolint:gochecknoglobals
|
||||
|
||||
// NewProtoType defines a procedure for creating a new ProtoType from a raw
|
||||
// NewProtoType defines a procedure for creating a new ProtoType from a raw.
|
||||
// string naming the transport protocol type.
|
||||
//
|
||||
// Deprecated: Please use pion/stun.NewProtoType
|
||||
// Deprecated: Please use pion/stun.NewProtoType.
|
||||
var NewProtoType = stun.NewProtoType //nolint:gochecknoglobals
|
||||
|
@@ -11,12 +11,14 @@ type UseCandidateAttr struct{}
|
||||
// AddTo adds USE-CANDIDATE attribute to message.
|
||||
func (UseCandidateAttr) AddTo(m *stun.Message) error {
|
||||
m.Add(stun.AttrUseCandidate, nil)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsSet returns true if USE-CANDIDATE attribute is set.
|
||||
func (UseCandidateAttr) IsSet(m *stun.Message) bool {
|
||||
_, err := m.Get(stun.AttrUseCandidate)
|
||||
|
||||
return err == nil
|
||||
}
|
||||
|
||||
|
@@ -13,6 +13,8 @@ import (
|
||||
)
|
||||
|
||||
func newHostRemote(t *testing.T) *CandidateHost {
|
||||
t.Helper()
|
||||
|
||||
remoteHostConfig := &CandidateHostConfig{
|
||||
Network: "udp",
|
||||
Address: "1.2.3.5",
|
||||
@@ -21,10 +23,13 @@ func newHostRemote(t *testing.T) *CandidateHost {
|
||||
}
|
||||
hostRemote, err := NewCandidateHost(remoteHostConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
return hostRemote
|
||||
}
|
||||
|
||||
func newPrflxRemote(t *testing.T) *CandidatePeerReflexive {
|
||||
t.Helper()
|
||||
|
||||
prflxConfig := &CandidatePeerReflexiveConfig{
|
||||
Network: "udp",
|
||||
Address: "10.10.10.2",
|
||||
@@ -35,10 +40,13 @@ func newPrflxRemote(t *testing.T) *CandidatePeerReflexive {
|
||||
}
|
||||
prflxRemote, err := NewCandidatePeerReflexive(prflxConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
return prflxRemote
|
||||
}
|
||||
|
||||
func newSrflxRemote(t *testing.T) *CandidateServerReflexive {
|
||||
t.Helper()
|
||||
|
||||
srflxConfig := &CandidateServerReflexiveConfig{
|
||||
Network: "udp",
|
||||
Address: "10.10.10.2",
|
||||
@@ -49,10 +57,13 @@ func newSrflxRemote(t *testing.T) *CandidateServerReflexive {
|
||||
}
|
||||
srflxRemote, err := NewCandidateServerReflexive(srflxConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
return srflxRemote
|
||||
}
|
||||
|
||||
func newRelayRemote(t *testing.T) *CandidateRelay {
|
||||
t.Helper()
|
||||
|
||||
relayConfig := &CandidateRelayConfig{
|
||||
Network: "udp",
|
||||
Address: "1.2.3.4",
|
||||
@@ -63,10 +74,13 @@ func newRelayRemote(t *testing.T) *CandidateRelay {
|
||||
}
|
||||
relayRemote, err := NewCandidateRelay(relayConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
return relayRemote
|
||||
}
|
||||
|
||||
func newHostLocal(t *testing.T) *CandidateHost {
|
||||
t.Helper()
|
||||
|
||||
localHostConfig := &CandidateHostConfig{
|
||||
Network: "udp",
|
||||
Address: "192.168.1.1",
|
||||
@@ -75,5 +89,6 @@ func newHostLocal(t *testing.T) *CandidateHost {
|
||||
}
|
||||
hostLocal, err := NewCandidateHost(localHostConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
return hostLocal
|
||||
}
|
||||
|
Reference in New Issue
Block a user