Finish agent options

This commit is contained in:
Joe Turki
2025-11-09 05:53:03 +02:00
parent 4c9f55e84f
commit 65bb60dec1
5 changed files with 367 additions and 134 deletions

231
agent.go
View File

@@ -111,8 +111,10 @@ type Agent struct {
selectedPair atomic.Value // *CandidatePair
urls []*stun.URI
networkTypes []NetworkType
urls []*stun.URI
networkTypes []NetworkType
natCandidateType CandidateType
natIPs []string
buf *packetio.Buffer
@@ -169,12 +171,126 @@ type Agent struct {
// NewAgent creates a new Agent.
func NewAgent(config *AgentConfig) (*Agent, error) {
return newAgentWithConfig(config)
return newAgentFromConfig(config)
}
// NewAgentWithOptions creates a new Agent with options only.
func NewAgentWithOptions(opts ...AgentOption) (*Agent, error) {
return newAgentWithConfig(&AgentConfig{}, opts...)
return newAgentFromConfig(&AgentConfig{}, opts...)
}
func newAgentFromConfig(config *AgentConfig, opts ...AgentOption) (*Agent, error) {
if config == nil {
config = &AgentConfig{}
}
agent, err := createAgentBase(config)
if err != nil {
return nil, err
}
agent.localUfrag = config.LocalUfrag
agent.localPwd = config.LocalPwd
agent.natCandidateType = config.NAT1To1IPCandidateType
agent.natIPs = config.NAT1To1IPs
return newAgentWithConfig(agent, opts...)
}
func createAgentBase(config *AgentConfig) (*Agent, error) {
if config.PortMax < config.PortMin {
return nil, ErrPort
}
mDNSName, mDNSMode, err := setupMDNSConfig(config)
if err != nil {
return nil, err
}
loggerFactory := config.LoggerFactory
if loggerFactory == nil {
loggerFactory = logging.NewDefaultLoggerFactory()
}
log := loggerFactory.NewLogger("ice")
startedCtx, startedFn := context.WithCancel(context.Background())
agent := &Agent{
tieBreaker: globalMathRandomGenerator.Uint64(),
lite: config.Lite,
gatheringState: GatheringStateNew,
connectionState: ConnectionStateNew,
localCandidates: make(map[NetworkType][]Candidate),
remoteCandidates: make(map[NetworkType][]Candidate),
urls: config.Urls,
networkTypes: config.NetworkTypes,
onConnected: make(chan struct{}),
buf: packetio.NewBuffer(),
startedCh: startedCtx.Done(),
startedFn: startedFn,
portMin: config.PortMin,
portMax: config.PortMax,
loggerFactory: loggerFactory,
log: log,
net: config.Net,
proxyDialer: config.ProxyDialer,
tcpMux: config.TCPMux,
udpMux: config.UDPMux,
udpMuxSrflx: config.UDPMuxSrflx,
mDNSMode: mDNSMode,
mDNSName: mDNSName,
gatherCandidateCancel: func() {},
forceCandidateContact: make(chan bool, 1),
interfaceFilter: config.InterfaceFilter,
ipFilter: config.IPFilter,
insecureSkipVerify: config.InsecureSkipVerify,
includeLoopback: config.IncludeLoopback,
disableActiveTCP: config.DisableActiveTCP,
userBindingRequestHandler: config.BindingRequestHandler,
enableUseCandidateCheckPriority: config.EnableUseCandidateCheckPriority,
enableRenomination: false,
nominationValueGenerator: nil,
nominationAttribute: stun.AttrType(0x0030), // Default value
continualGatheringPolicy: GatherOnce, // Default to GatherOnce
networkMonitorInterval: 2 * time.Second,
lastKnownInterfaces: make(map[string]netip.Addr),
automaticRenomination: false,
renominationInterval: 3 * time.Second, // Default matching libwebrtc
}
config.initWithDefaults(agent)
return agent, nil
}
func applyExternalIPMapping(agent *Agent, candidateType CandidateType, ips []string) error {
mapper, err := newExternalIPMapper(candidateType, ips)
if err != nil {
return err
}
agent.extIPMapper = mapper
if agent.extIPMapper == nil {
return nil
}
switch agent.extIPMapper.candidateType {
case CandidateTypeHost:
if agent.mDNSMode == MulticastDNSModeQueryAndGather {
return ErrMulticastDNSWithNAT1To1IPMapping
}
if !containsCandidateType(CandidateTypeHost, agent.candidateTypes) {
return ErrIneffectiveNAT1To1IPMappingHost
}
case CandidateTypeServerReflexive:
if !containsCandidateType(CandidateTypeServerReflexive, agent.candidateTypes) {
return ErrIneffectiveNAT1To1IPMappingSrflx
}
default:
return nil
}
return nil
}
// setupMDNSConfig validates and returns mDNS configuration.
@@ -199,82 +315,16 @@ func setupMDNSConfig(config *AgentConfig) (string, MulticastDNSMode, error) {
return mDNSName, mDNSMode, nil
}
// newAgentWithConfig is the internal function that creates an agent with config and options.
// newAgentWithConfig finalizes a pre-configured agent with optional overrides.
//
//nolint:gocognit
func newAgentWithConfig(config *AgentConfig, opts ...AgentOption) (*Agent, error) { //nolint:cyclop
//nolint:gocognit,cyclop
func newAgentWithConfig(agent *Agent, opts ...AgentOption) (*Agent, error) {
var err error
if config.PortMax < config.PortMin {
return nil, ErrPort
}
mDNSName, mDNSMode, err := setupMDNSConfig(config)
if err != nil {
return nil, err
}
loggerFactory := config.LoggerFactory
if loggerFactory == nil {
loggerFactory = logging.NewDefaultLoggerFactory()
}
log := loggerFactory.NewLogger("ice")
startedCtx, startedFn := context.WithCancel(context.Background())
agent := &Agent{
tieBreaker: globalMathRandomGenerator.Uint64(),
lite: config.Lite,
gatheringState: GatheringStateNew,
connectionState: ConnectionStateNew,
localCandidates: make(map[NetworkType][]Candidate),
remoteCandidates: make(map[NetworkType][]Candidate),
urls: config.Urls,
networkTypes: config.NetworkTypes,
onConnected: make(chan struct{}),
buf: packetio.NewBuffer(),
startedCh: startedCtx.Done(),
startedFn: startedFn,
portMin: config.PortMin,
portMax: config.PortMax,
loggerFactory: loggerFactory,
log: log,
net: config.Net,
proxyDialer: config.ProxyDialer,
tcpMux: config.TCPMux,
udpMux: config.UDPMux,
udpMuxSrflx: config.UDPMuxSrflx,
mDNSMode: mDNSMode,
mDNSName: mDNSName,
gatherCandidateCancel: func() {},
forceCandidateContact: make(chan bool, 1),
interfaceFilter: config.InterfaceFilter,
ipFilter: config.IPFilter,
insecureSkipVerify: config.InsecureSkipVerify,
includeLoopback: config.IncludeLoopback,
disableActiveTCP: config.DisableActiveTCP,
userBindingRequestHandler: config.BindingRequestHandler,
enableUseCandidateCheckPriority: config.EnableUseCandidateCheckPriority,
enableRenomination: false,
nominationValueGenerator: nil,
nominationAttribute: stun.AttrType(0x0030), // Default value
continualGatheringPolicy: GatherOnce, // Default to GatherOnce
networkMonitorInterval: 2 * time.Second,
lastKnownInterfaces: make(map[string]netip.Addr),
automaticRenomination: false,
renominationInterval: 3 * time.Second, // Default matching libwebrtc
for _, opt := range opts {
if err = opt(agent); err != nil {
return nil, err
}
}
agent.connectionStateNotifier = &handlerNotifier{
@@ -317,16 +367,14 @@ func newAgentWithConfig(config *AgentConfig, opts ...AgentOption) (*Agent, error
agent.networkTypes,
localIfcs,
agent.includeLoopback,
mDNSMode,
mDNSName,
log,
loggerFactory,
agent.mDNSMode,
agent.mDNSName,
agent.log,
agent.loggerFactory,
); err != nil {
log.Warnf("Failed to initialize mDNS %s: %v", mDNSName, err)
agent.log.Warnf("Failed to initialize mDNS %s: %v", agent.mDNSName, err)
}
config.initWithDefaults(agent)
// Make sure the buffer doesn't grow indefinitely.
// NOTE: We actually won't get anywhere close to this limit.
// SRTP will constantly read from the endpoint and drop packets if it's full.
@@ -338,7 +386,7 @@ func newAgentWithConfig(config *AgentConfig, opts ...AgentOption) (*Agent, error
return nil, ErrLiteUsingNonHostCandidates
}
if len(config.Urls) > 0 &&
if len(agent.urls) > 0 &&
!containsCandidateType(CandidateTypeServerReflexive, agent.candidateTypes) &&
!containsCandidateType(CandidateTypeRelay, agent.candidateTypes) {
agent.closeMulticastConn()
@@ -346,7 +394,7 @@ func newAgentWithConfig(config *AgentConfig, opts ...AgentOption) (*Agent, error
return nil, ErrUselessUrlsProvided
}
if err = config.initExtIPMapping(agent); err != nil {
if err = applyExternalIPMapping(agent, agent.natCandidateType, agent.natIPs); err != nil {
agent.closeMulticastConn()
return nil, err
@@ -371,22 +419,13 @@ func newAgentWithConfig(config *AgentConfig, opts ...AgentOption) (*Agent, error
})
// Restart is also used to initialize the agent for the first time
if err := agent.Restart(config.LocalUfrag, config.LocalPwd); err != nil {
if err := agent.Restart(agent.localUfrag, agent.localPwd); err != nil {
agent.closeMulticastConn()
_ = agent.Close()
return nil, err
}
for _, opt := range opts {
if err := opt(agent); err != nil {
agent.closeMulticastConn()
_ = agent.Close()
return nil, err
}
}
return agent, nil
}

View File

@@ -5,7 +5,6 @@ package ice
import (
"net"
"slices"
"time"
"github.com/pion/logging"
@@ -63,6 +62,14 @@ func defaultCandidateTypes() []CandidateType {
return []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive, CandidateTypeRelay}
}
func defaultRelayAcceptanceMinWaitFor(candidateTypes []CandidateType) time.Duration {
if len(candidateTypes) == 1 && candidateTypes[0] == CandidateTypeRelay {
return defaultRelayOnlyAcceptanceMinWait
}
return defaultRelayAcceptanceMinWait
}
// AgentConfig collects the arguments to ice.Agent construction into
// a single structure, for future-proofness of the interface.
type AgentConfig struct {
@@ -240,11 +247,7 @@ func (config *AgentConfig) initWithDefaults(agent *Agent) { //nolint:cyclop
}
if config.RelayAcceptanceMinWait == nil {
if len(config.CandidateTypes) == 1 && config.CandidateTypes[0] == CandidateTypeRelay {
agent.relayAcceptanceMinWait = defaultRelayOnlyAcceptanceMinWait
} else {
agent.relayAcceptanceMinWait = defaultRelayAcceptanceMinWait
}
agent.relayAcceptanceMinWait = defaultRelayAcceptanceMinWaitFor(config.CandidateTypes)
} else {
agent.relayAcceptanceMinWait = *config.RelayAcceptanceMinWait
}
@@ -291,33 +294,3 @@ func (config *AgentConfig) initWithDefaults(agent *Agent) { //nolint:cyclop
agent.candidateTypes = config.CandidateTypes
}
}
func (config *AgentConfig) initExtIPMapping(agent *Agent) error { //nolint:cyclop
var err error
agent.extIPMapper, err = newExternalIPMapper(config.NAT1To1IPCandidateType, config.NAT1To1IPs)
if err != nil {
return err
}
if agent.extIPMapper == nil {
return nil // This may happen when config.NAT1To1IPs is an empty array
}
switch agent.extIPMapper.candidateType {
case CandidateTypeHost:
if agent.mDNSMode == MulticastDNSModeQueryAndGather {
return ErrMulticastDNSWithNAT1To1IPMapping
}
candiHostEnabled := slices.Contains(agent.candidateTypes, CandidateTypeHost)
if !candiHostEnabled {
return ErrIneffectiveNAT1To1IPMappingHost
}
case CandidateTypeServerReflexive:
candiSrflxEnabled := slices.Contains(agent.candidateTypes, CandidateTypeServerReflexive)
if !candiSrflxEnabled {
return ErrIneffectiveNAT1To1IPMappingSrflx
}
default:
return nil
}
return nil
}

View File

@@ -70,3 +70,28 @@ func TestAgentConfig_initWithDefaults(t *testing.T) {
})
}
}
func TestDefaultRelayAcceptanceMinWaitForCandidates(t *testing.T) {
tests := []struct {
name string
candidateType []CandidateType
expectedWait time.Duration
}{
{
name: "relay only",
candidateType: []CandidateType{CandidateTypeRelay},
expectedWait: defaultRelayOnlyAcceptanceMinWait,
},
{
name: "mixed types",
candidateType: []CandidateType{CandidateTypeHost, CandidateTypeRelay},
expectedWait: defaultRelayAcceptanceMinWait,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expectedWait, defaultRelayAcceptanceMinWaitFor(tc.candidateType))
})
}
}

View File

@@ -4,11 +4,15 @@
package ice
import (
"fmt"
"net"
"testing"
"github.com/pion/logging"
"github.com/pion/stun/v3"
"github.com/pion/transport/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testBooleanOption is a helper function to test boolean agent options.
@@ -173,7 +177,6 @@ func TestWithTCPPriorityOffset(t *testing.T) {
assert.NoError(t, err)
defer agent.Close() //nolint:errcheck
// The default is set via initWithDefaults
assert.Equal(t, uint16(27), agent.tcpPriorityOffset)
})
@@ -375,3 +378,167 @@ func TestWithLoggerFactory(t *testing.T) {
assert.NotNil(t, agent.log)
})
}
func TestWithNetworkTypesAppliedBeforeRestart(t *testing.T) {
t.Run("ipv6 listen skipped when network types option restricts to ipv4", func(t *testing.T) {
stub := newStubNet(t)
agent, err := newAgentFromConfig(&AgentConfig{
Net: stub,
}, WithNetworkTypes([]NetworkType{NetworkTypeUDP4}))
require.NoError(t, err)
defer func() { require.NoError(t, agent.Close()) }()
assert.Zero(t, stub.udp6ListenCount, "unexpected ipv6 listen before restart")
})
}
func TestWithCandidateTypesAffectsURLValidation(t *testing.T) {
stunURL, err := stun.ParseURI("stun:example.com:3478")
require.NoError(t, err)
t.Run("default candidate types accept urls", func(t *testing.T) {
stub := newStubNet(t)
agent, err := newAgentFromConfig(&AgentConfig{
Urls: []*stun.URI{stunURL},
Net: stub,
})
require.NoError(t, err)
require.NoError(t, agent.Close())
})
t.Run("host only candidate types reject urls", func(t *testing.T) {
stub := newStubNet(t)
_, err := newAgentFromConfig(&AgentConfig{
Urls: []*stun.URI{stunURL},
Net: stub,
}, WithCandidateTypes([]CandidateType{CandidateTypeHost}))
require.ErrorIs(t, err, ErrUselessUrlsProvided)
})
}
func TestWithCandidateTypesNAT1To1Validation(t *testing.T) {
t.Run("host mapping requires host candidates", func(t *testing.T) {
stub := newStubNet(t)
_, err := newAgentFromConfig(&AgentConfig{
NAT1To1IPs: []string{"1.2.3.4"},
NAT1To1IPCandidateType: CandidateTypeHost,
Net: stub,
}, WithCandidateTypes([]CandidateType{CandidateTypeRelay}))
require.ErrorIs(t, err, ErrIneffectiveNAT1To1IPMappingHost)
})
t.Run("srflx mapping requires srflx candidates", func(t *testing.T) {
stub := newStubNet(t)
_, err := newAgentFromConfig(&AgentConfig{
NAT1To1IPs: []string{"1.2.3.4"},
NAT1To1IPCandidateType: CandidateTypeServerReflexive,
Net: stub,
}, WithCandidateTypes([]CandidateType{CandidateTypeHost}))
require.ErrorIs(t, err, ErrIneffectiveNAT1To1IPMappingSrflx)
})
}
type stubNet struct {
t *testing.T
udp6ListenCount int
}
func newStubNet(t *testing.T) *stubNet {
t.Helper()
return &stubNet{t: t}
}
func (n *stubNet) ListenPacket(network, address string) (net.PacketConn, error) {
return nil, transport.ErrNotSupported
}
func (n *stubNet) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) {
if network == "udp6" {
n.udp6ListenCount++
}
return nil, fmt.Errorf("stub net does not listen on %s", network) //nolint:err113
}
func (n *stubNet) ListenTCP(network string, laddr *net.TCPAddr) (transport.TCPListener, error) {
return nil, transport.ErrNotSupported
}
func (n *stubNet) Dial(network, address string) (net.Conn, error) {
return nil, transport.ErrNotSupported
}
func (n *stubNet) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) {
return nil, transport.ErrNotSupported
}
func (n *stubNet) DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) {
return nil, transport.ErrNotSupported
}
func (n *stubNet) ResolveIPAddr(network, address string) (*net.IPAddr, error) {
return net.ResolveIPAddr(network, address)
}
func (n *stubNet) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
return net.ResolveUDPAddr(network, address)
}
func (n *stubNet) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
return net.ResolveTCPAddr(network, address)
}
func (n *stubNet) Interfaces() ([]*transport.Interface, error) {
iface := transport.NewInterface(net.Interface{
Index: 1,
MTU: 1500,
Name: "stub0",
Flags: net.FlagUp,
})
iface.AddAddress(&net.IPNet{
IP: net.IPv4(192, 0, 2, 1),
Mask: net.CIDRMask(24, 32),
})
return []*transport.Interface{iface}, nil
}
func (n *stubNet) InterfaceByIndex(index int) (*transport.Interface, error) {
ifaces, err := n.Interfaces()
if err != nil {
return nil, err
}
for _, iface := range ifaces {
if iface.Index == index {
return iface, nil
}
}
return nil, transport.ErrInterfaceNotFound
}
func (n *stubNet) InterfaceByName(name string) (*transport.Interface, error) {
ifaces, err := n.Interfaces()
if err != nil {
return nil, err
}
for _, iface := range ifaces {
if iface.Name == name {
return iface, nil
}
}
return nil, transport.ErrInterfaceNotFound
}
func (n *stubNet) CreateDialer(dialer *net.Dialer) transport.Dialer {
return nil
}

View File

@@ -2054,6 +2054,35 @@ func TestRoleConflict(t *testing.T) {
})
}
func TestDefaultCandidateTypes(t *testing.T) {
expected := []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive, CandidateTypeRelay}
first := defaultCandidateTypes()
require.Equal(t, expected, first)
first[0] = CandidateTypeRelay
second := defaultCandidateTypes()
require.Equal(t, expected, second)
}
func TestDefaultRelayAcceptanceMinWaitFor(t *testing.T) {
t.Run("relay only defaults to zero wait", func(t *testing.T) {
wait := defaultRelayAcceptanceMinWaitFor([]CandidateType{CandidateTypeRelay})
require.Equal(t, defaultRelayOnlyAcceptanceMinWait, wait)
})
t.Run("empty candidate types uses general relay wait", func(t *testing.T) {
wait := defaultRelayAcceptanceMinWaitFor(nil)
require.Equal(t, defaultRelayAcceptanceMinWait, wait)
})
t.Run("mixed candidate types uses general relay wait", func(t *testing.T) {
wait := defaultRelayAcceptanceMinWaitFor([]CandidateType{CandidateTypeHost, CandidateTypeRelay})
require.Equal(t, defaultRelayAcceptanceMinWait, wait)
})
}
func TestAgentConfig_initWithDefaults_UsesProvidedValues(t *testing.T) {
valMaxBindingReq := uint16(0)
valSrflxWait := 111 * time.Millisecond
@@ -2113,7 +2142,7 @@ func TestAutomaticRenominationWithVNet(t *testing.T) {
checkInterval := 50 * time.Millisecond
renominationInterval := 200 * time.Millisecond
agent1, err := newAgentWithConfig(&AgentConfig{
agent1, err := newAgentFromConfig(&AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4},
MulticastDNSMode: MulticastDNSModeDisabled,
Net: net0,