diff --git a/errors.go b/errors.go index acece24d..15535f6d 100644 --- a/errors.go +++ b/errors.go @@ -199,6 +199,7 @@ var ( errICERoleUnknown = errors.New("unknown ICE Role") errICEProtocolUnknown = errors.New("unknown protocol") errICEGathererNotStarted = errors.New("gatherer not started") + errAddressRewriteWithNAT1To1 = errors.New("address rewrite rules cannot be combined with NAT1To1IPs") errNetworkTypeUnknown = errors.New("unknown network type") diff --git a/icecandidatetype.go b/icecandidatetype.go index 7ac62db1..8a5fe933 100644 --- a/icecandidatetype.go +++ b/icecandidatetype.go @@ -101,7 +101,7 @@ func getCandidateType(candidateType ice.CandidateType) (ICECandidateType, error) } // MarshalText implements the encoding.TextMarshaler interface. -func (t ICECandidateType) MarshalText() ([]byte, error) { +func (t ICECandidateType) MarshalText() ([]byte, error) { //nolint:staticcheck return []byte(t.String()), nil } @@ -112,3 +112,7 @@ func (t *ICECandidateType) UnmarshalText(b []byte) error { return err } + +func (r ICECandidateType) toICE() ice.CandidateType { + return ice.CandidateType(r) +} diff --git a/icegatherer.go b/icegatherer.go index 236d9c7f..96dcb871 100644 --- a/icegatherer.go +++ b/icegatherer.go @@ -46,6 +46,48 @@ type ICEGatherer struct { sdpMLineIndex atomic.Uint32 // uint16 } +// ICEAddressRewriteMode controls whether a rule replaces or appends candidates. +type ICEAddressRewriteMode byte + +const ( + ICEAddressRewriteModeUnspecified ICEAddressRewriteMode = iota + ICEAddressRewriteReplace + ICEAddressRewriteAppend +) + +func (r ICEAddressRewriteMode) toICE() ice.AddressRewriteMode { + return ice.AddressRewriteMode(r) +} + +// ICEAddressRewriteRule represents a rule for remapping candidate addresses. +type ICEAddressRewriteRule struct { + External []string + Local string + Iface string + CIDR string + AsCandidateType ICECandidateType + Mode ICEAddressRewriteMode + Networks []NetworkType +} + +func (r ICEAddressRewriteRule) toICE() ice.AddressRewriteRule { + candidateType := r.AsCandidateType.toICE() + mode := r.Mode.toICE() + networks := toICENetworkTypes(r.Networks) + + rule := ice.AddressRewriteRule{ + External: append([]string(nil), r.External...), + Local: r.Local, + Iface: r.Iface, + CIDR: r.CIDR, + AsCandidateType: candidateType, + Mode: mode, + Networks: networks, + } + + return rule +} + // NewICEGatherer creates a new NewICEGatherer. // This constructor is part of the ORTC API. It is not // meant to be used together with the basic WebRTC API. @@ -80,7 +122,10 @@ func (g *ICEGatherer) createAgent() error { return nil } - options := g.buildAgentOptions() + options, err := g.buildAgentOptions() + if err != nil { + return err + } agent, err := ice.NewAgentWithOptions(options...) if err != nil { @@ -92,7 +137,7 @@ func (g *ICEGatherer) createAgent() error { return nil } -func (g *ICEGatherer) buildAgentOptions() []ice.AgentOption { +func (g *ICEGatherer) buildAgentOptions() ([]ice.AgentOption, error) { candidateTypes := g.resolveCandidateTypes() nat1To1CandiTyp := g.resolveNAT1To1CandidateType() mDNSMode := g.sanitizedMDNSMode() @@ -103,7 +148,12 @@ func (g *ICEGatherer) buildAgentOptions() []ice.AgentOption { } options = append(options, g.credentialOptions()...) - options = append(options, g.natRewriteOptions(nat1To1CandiTyp)...) + + rewriteOptions, err := g.addressRewriteOptions(nat1To1CandiTyp) + if err != nil { + return nil, err + } + options = append(options, rewriteOptions...) options = append(options, g.timeoutOptions()...) options = append(options, g.miscOptions()...) options = append(options, g.renominationOptions()...) @@ -113,12 +163,7 @@ func (g *ICEGatherer) buildAgentOptions() []ice.AgentOption { requestedNetworkTypes = supportedNetworkTypes() } - var networkTypes []ice.NetworkType - for _, typ := range requestedNetworkTypes { - networkTypes = append(networkTypes, ice.NetworkType(typ)) - } - - return append(options, ice.WithNetworkTypes(networkTypes)) + return append(options, ice.WithNetworkTypes(toICENetworkTypes(requestedNetworkTypes))), nil } func (g *ICEGatherer) resolveCandidateTypes() []ice.CandidateType { @@ -181,19 +226,29 @@ func (g *ICEGatherer) credentialOptions() []ice.AgentOption { } } -func (g *ICEGatherer) natRewriteOptions(candidateType ice.CandidateType) []ice.AgentOption { - if len(g.api.settingEngine.candidates.NAT1To1IPs) == 0 { - return nil +func (g *ICEGatherer) addressRewriteOptions(candidateType ice.CandidateType) ([]ice.AgentOption, error) { + rules := g.api.settingEngine.candidates.addressRewriteRules + nat1To1IPs := g.api.settingEngine.candidates.NAT1To1IPs + if len(rules) > 0 && len(nat1To1IPs) > 0 { + return nil, errAddressRewriteWithNAT1To1 + } + + if len(rules) > 0 { + return []ice.AgentOption{ice.WithAddressRewriteRules(rules...)}, nil + } + + if len(nat1To1IPs) == 0 { + return nil, nil } return []ice.AgentOption{ ice.WithAddressRewriteRules( legacyNAT1To1AddressRewriteRules( - g.api.settingEngine.candidates.NAT1To1IPs, + nat1To1IPs, candidateType, )..., ), - } + }, nil } func (g *ICEGatherer) timeoutOptions() []ice.AgentOption { diff --git a/icegatherer_test.go b/icegatherer_test.go index 2d8f987d..2972f457 100644 --- a/icegatherer_test.go +++ b/icegatherer_test.go @@ -23,6 +23,7 @@ import ( "github.com/pion/transport/v3/vnet" "github.com/pion/turn/v4" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewICEGatherer_Success(t *testing.T) { @@ -250,6 +251,893 @@ func TestLegacyNAT1To1AddressRewriteRulesVNet(t *testing.T) { //nolint:cyclop }) } +func TestICEAddressRewriteRulesWithNAT1To1Conflict(t *testing.T) { + lim := test.TimeOut(time.Second * 5) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + t.Run("SetterError", func(t *testing.T) { + se := SettingEngine{} + se.SetNAT1To1IPs([]string{"203.0.113.1"}, ICECandidateTypeHost) + + err := se.SetICEAddressRewriteRules(ICEAddressRewriteRule{ + External: []string{"198.51.100.1"}, + AsCandidateType: ICECandidateTypeHost, + Mode: ICEAddressRewriteReplace, + }) + assert.ErrorIs(t, err, errAddressRewriteWithNAT1To1) + }) + + t.Run("RuntimeError", func(t *testing.T) { + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.0.0/24", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + require.NoError(t, err) + + nw, err := vnet.NewNet(&vnet.NetConfig{ + StaticIP: "10.0.0.1", + }) + require.NoError(t, err) + require.NoError(t, router.AddNet(nw)) + require.NoError(t, router.Start()) + defer func() { + assert.NoError(t, router.Stop()) + }() + + se := SettingEngine{} + se.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + se.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + se.SetNet(nw) + require.NoError(t, se.SetICEAddressRewriteRules(ICEAddressRewriteRule{ + External: []string{"198.51.100.2"}, + AsCandidateType: ICECandidateTypeHost, + Mode: ICEAddressRewriteReplace, + })) + se.SetNAT1To1IPs([]string{"203.0.113.2"}, ICECandidateTypeHost) + + gatherer, err := NewAPI(WithSettingEngine(se)).NewICEGatherer(ICEGatherOptions{}) + require.NoError(t, err) + + err = gatherer.Gather() + assert.ErrorIs(t, err, errAddressRewriteWithNAT1To1) + assert.NoError(t, gatherer.Close()) + }) +} + +func gatherCandidatesWithSettingEngine(t *testing.T, se SettingEngine, opts ICEGatherOptions) []ICECandidate { + t.Helper() + + gatherer, err := NewAPI(WithSettingEngine(se)).NewICEGatherer(opts) + require.NoError(t, err) + + done := make(chan struct{}) + var candidates []ICECandidate + gatherer.OnLocalCandidate(func(c *ICECandidate) { + if c == nil { + close(done) + + return + } + candidates = append(candidates, *c) + }) + + require.NoError(t, gatherer.Gather()) + select { + case <-done: + case <-time.After(5 * time.Second): + assert.Fail(t, "gather did not complete") + } + + assert.NoError(t, gatherer.Close()) + + return candidates +} + +func TestICEGatherer_AddressRewriteRulesVNet(t *testing.T) { //nolint:cyclop + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + const ( + externalIP = "203.0.113.10" + localIP = "10.0.0.1" + ) + + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.0.0/24", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + require.NoError(t, err) + + nw, err := vnet.NewNet(&vnet.NetConfig{ + StaticIP: localIP, + }) + require.NoError(t, err) + require.NoError(t, router.AddNet(nw)) + require.NoError(t, router.Start()) + defer func() { + assert.NoError(t, router.Stop()) + }() + + run := func(rule ICEAddressRewriteRule) []ICECandidate { + se := SettingEngine{} + se.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + se.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + se.SetNet(nw) + require.NoError(t, se.SetICEAddressRewriteRules(rule)) + + return gatherCandidatesWithSettingEngine(t, se, ICEGatherOptions{}) + } + + t.Run("HostReplace", func(t *testing.T) { + candidates := run(ICEAddressRewriteRule{ + External: []string{externalIP}, + Local: localIP, + AsCandidateType: ICECandidateTypeHost, + Mode: ICEAddressRewriteReplace, + }) + assert.NotEmpty(t, candidates) + + var hostAddrs []string + for _, c := range candidates { + if c.Typ == ICECandidateTypeHost { + hostAddrs = append(hostAddrs, c.Address) + } + } + + assert.NotEmpty(t, hostAddrs, "expected host candidates") + assert.Subset(t, hostAddrs, []string{externalIP}) + for _, addr := range hostAddrs { + assert.NotEqual(t, localIP, addr) + } + }) + + t.Run("SrflxAppend", func(t *testing.T) { + candidates := run(ICEAddressRewriteRule{ + External: []string{externalIP}, + AsCandidateType: ICECandidateTypeSrflx, + }) + assert.NotEmpty(t, candidates) + + var hostAddrs []string + var srflx ICECandidate + var haveSrflx bool + for _, c := range candidates { + switch c.Typ { + case ICECandidateTypeHost: + hostAddrs = append(hostAddrs, c.Address) + case ICECandidateTypeSrflx: + srflx = c + haveSrflx = true + default: + } + } + + assert.NotEmpty(t, hostAddrs, "expected host candidates") + assert.Contains(t, hostAddrs, localIP) + assert.True(t, haveSrflx, "expected srflx candidate") + assert.Equal(t, externalIP, srflx.Address) + }) +} + +func TestICEGatherer_AddressRewriteRuleFilters(t *testing.T) { //nolint:cyclop + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + t.Run("CIDR", func(t *testing.T) { + const ( + firstIP = "10.0.0.2" + secondIP = "10.0.1.2" + externalIP = "203.0.113.20" + ) + + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.0.0/16", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + require.NoError(t, err) + + nw, err := vnet.NewNet(&vnet.NetConfig{ + StaticIPs: []string{firstIP, secondIP}, + }) + require.NoError(t, err) + require.NoError(t, router.AddNet(nw)) + require.NoError(t, router.Start()) + defer func() { + assert.NoError(t, router.Stop()) + }() + + se := SettingEngine{} + se.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + se.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + se.SetNet(nw) + require.NoError(t, se.SetICEAddressRewriteRules(ICEAddressRewriteRule{ + External: []string{externalIP}, + CIDR: "10.0.0.0/24", + AsCandidateType: ICECandidateTypeHost, + Mode: ICEAddressRewriteReplace, + })) + + candidates := gatherCandidatesWithSettingEngine(t, se, ICEGatherOptions{}) + + var hostAddrs []string + for _, c := range candidates { + if c.Typ == ICECandidateTypeHost { + hostAddrs = append(hostAddrs, c.Address) + } + } + + assert.Contains(t, hostAddrs, externalIP) + assert.Contains(t, hostAddrs, secondIP) + assert.NotContains(t, hostAddrs, firstIP) + }) + + t.Run("NetworkTypes", func(t *testing.T) { + const ( + localIP = "10.0.0.50" + externalIP = "203.0.113.50" + ) + + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.0.0/24", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + require.NoError(t, err) + + nw, err := vnet.NewNet(&vnet.NetConfig{ + StaticIP: localIP, + }) + require.NoError(t, err) + require.NoError(t, router.AddNet(nw)) + require.NoError(t, router.Start()) + defer func() { + assert.NoError(t, router.Stop()) + }() + + se := SettingEngine{} + se.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + se.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + se.SetNet(nw) + require.NoError(t, se.SetICEAddressRewriteRules(ICEAddressRewriteRule{ + External: []string{externalIP}, + AsCandidateType: ICECandidateTypeHost, + Mode: ICEAddressRewriteReplace, + Networks: []NetworkType{NetworkTypeUDP6}, + })) + + candidates := gatherCandidatesWithSettingEngine(t, se, ICEGatherOptions{}) + + var hostAddrs []string + for _, c := range candidates { + if c.Typ == ICECandidateTypeHost { + hostAddrs = append(hostAddrs, c.Address) + } + } + + assert.Contains(t, hostAddrs, localIP) + assert.NotContains(t, hostAddrs, externalIP) + }) +} + +func TestICEGatherer_AddressRewriteHostAppendAndReplace(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + const ( + firstLocal = "10.0.0.2" + secondLocal = "10.0.0.3" + firstExternal = "203.0.113.30" + secondExternal = "203.0.113.31" + ) + + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.0.0/24", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + require.NoError(t, err) + + nw, err := vnet.NewNet(&vnet.NetConfig{ + StaticIPs: []string{firstLocal, secondLocal}, + }) + require.NoError(t, err) + require.NoError(t, router.AddNet(nw)) + require.NoError(t, router.Start()) + defer func() { + assert.NoError(t, router.Stop()) + }() + + se := SettingEngine{} + se.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + se.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + se.SetNet(nw) + require.NoError(t, se.SetICEAddressRewriteRules( + ICEAddressRewriteRule{ + Local: firstLocal, + External: []string{firstExternal}, + AsCandidateType: ICECandidateTypeHost, + Mode: ICEAddressRewriteReplace, + }, + ICEAddressRewriteRule{ + Local: secondLocal, + External: []string{secondExternal}, + AsCandidateType: ICECandidateTypeHost, + Mode: ICEAddressRewriteAppend, + }, + )) + + candidates := gatherCandidatesWithSettingEngine(t, se, ICEGatherOptions{}) + + var hostAddrs []string + for _, c := range candidates { + if c.Typ == ICECandidateTypeHost { + hostAddrs = append(hostAddrs, c.Address) + } + } + + assert.Contains(t, hostAddrs, firstExternal) + assert.NotContains(t, hostAddrs, firstLocal) + assert.Contains(t, hostAddrs, secondLocal) + assert.Contains(t, hostAddrs, secondExternal) +} + +func TestICEGatherer_AddressRewriteSrflxReplace(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + const ( + localIP = "10.0.0.60" + externalIP = "203.0.113.60" + ) + + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.0.0/24", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + require.NoError(t, err) + + nw, err := vnet.NewNet(&vnet.NetConfig{ + StaticIP: localIP, + }) + require.NoError(t, err) + require.NoError(t, router.AddNet(nw)) + require.NoError(t, router.Start()) + defer func() { + assert.NoError(t, router.Stop()) + }() + + se := SettingEngine{} + se.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + se.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + se.SetNet(nw) + require.NoError(t, se.SetICEAddressRewriteRules(ICEAddressRewriteRule{ + External: []string{externalIP}, + AsCandidateType: ICECandidateTypeSrflx, + Mode: ICEAddressRewriteReplace, + })) + + candidates := gatherCandidatesWithSettingEngine(t, se, ICEGatherOptions{}) + + var hostAddrs []string + var srflxAddrs []string + for _, c := range candidates { + switch c.Typ { + case ICECandidateTypeHost: + hostAddrs = append(hostAddrs, c.Address) + case ICECandidateTypeSrflx: + srflxAddrs = append(srflxAddrs, c.Address) + default: + t.Logf("unexpected candidate type: %s", c.Typ) + } + } + + assert.Contains(t, hostAddrs, localIP) + assert.Contains(t, srflxAddrs, externalIP) + assert.NotContains(t, srflxAddrs, localIP) +} + +func TestICEGatherer_AddressRewriteSrflxAppendWithCatchAll(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + const ( + localIP = "10.0.0.80" + appendIP = "203.0.113.81" + replaceIP = "203.0.113.80" + ) + + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.0.0/24", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + require.NoError(t, err) + + nw, err := vnet.NewNet(&vnet.NetConfig{ + StaticIP: localIP, + }) + require.NoError(t, err) + require.NoError(t, router.AddNet(nw)) + require.NoError(t, router.Start()) + defer func() { + assert.NoError(t, router.Stop()) + }() + + se := SettingEngine{} + se.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + se.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + se.SetNet(nw) + require.NoError(t, se.SetICEAddressRewriteRules( + ICEAddressRewriteRule{ + External: []string{appendIP}, + AsCandidateType: ICECandidateTypeSrflx, + Mode: ICEAddressRewriteAppend, + }, + ICEAddressRewriteRule{ + External: []string{replaceIP}, + AsCandidateType: ICECandidateTypeSrflx, + Mode: ICEAddressRewriteReplace, + }, + )) + + candidates := gatherCandidatesWithSettingEngine(t, se, ICEGatherOptions{}) + + var srflxAddrs []string + for _, c := range candidates { + if c.Typ == ICECandidateTypeSrflx { + srflxAddrs = append(srflxAddrs, c.Address) + } + } + + assert.Contains(t, srflxAddrs, appendIP) + assert.NotContains(t, srflxAddrs, replaceIP) + assert.NotContains(t, srflxAddrs, localIP) +} + +func TestICEGatherer_AddressRewriteMultipleRulesOrdering(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + const ( + localIP = "10.0.0.70" + otherLocalIP = "10.0.0.71" + externalIP = "203.0.113.70" + ) + + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.0.0/24", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + require.NoError(t, err) + + nw, err := vnet.NewNet(&vnet.NetConfig{ + StaticIPs: []string{localIP, otherLocalIP}, + }) + require.NoError(t, err) + require.NoError(t, router.AddNet(nw)) + require.NoError(t, router.Start()) + defer func() { + assert.NoError(t, router.Stop()) + }() + + se := SettingEngine{} + se.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + se.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + se.SetNet(nw) + require.NoError(t, se.SetICEAddressRewriteRules( + ICEAddressRewriteRule{ + CIDR: "10.0.0.0/24", + External: []string{externalIP}, + AsCandidateType: ICECandidateTypeHost, + Mode: ICEAddressRewriteReplace, + }, + ICEAddressRewriteRule{ + Local: otherLocalIP, + External: []string{otherLocalIP}, + AsCandidateType: ICECandidateTypeHost, + Mode: ICEAddressRewriteAppend, + }, + )) + + candidates := gatherCandidatesWithSettingEngine(t, se, ICEGatherOptions{}) + + var hostAddrs []string + for _, c := range candidates { + if c.Typ == ICECandidateTypeHost { + hostAddrs = append(hostAddrs, c.Address) + } + } + + assert.Contains(t, hostAddrs, externalIP) + assert.NotContains(t, hostAddrs, localIP) + assert.Contains(t, hostAddrs, otherLocalIP) +} + +func TestICEGatherer_AddressRewriteIfaceScope(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + const ( + localIP = "10.0.0.90" + externalIP = "203.0.113.90" + ) + + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.0.0/24", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + require.NoError(t, err) + + nw, err := vnet.NewNet(&vnet.NetConfig{ + StaticIP: localIP, + }) + require.NoError(t, err) + require.NoError(t, router.AddNet(nw)) + require.NoError(t, router.Start()) + defer func() { + assert.NoError(t, router.Stop()) + }() + + se := SettingEngine{} + se.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + se.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + se.SetNet(nw) + require.NoError(t, se.SetICEAddressRewriteRules( + ICEAddressRewriteRule{ + Iface: "bad0", + External: []string{"198.51.100.90"}, + AsCandidateType: ICECandidateTypeHost, + Mode: ICEAddressRewriteReplace, + }, + ICEAddressRewriteRule{ + Iface: "eth0", + External: []string{externalIP}, + AsCandidateType: ICECandidateTypeHost, + Mode: ICEAddressRewriteReplace, + }, + )) + + candidates := gatherCandidatesWithSettingEngine(t, se, ICEGatherOptions{}) + + var hostAddrs []string + for _, c := range candidates { + if c.Typ == ICECandidateTypeHost { + hostAddrs = append(hostAddrs, c.Address) + } + } + + assert.Contains(t, hostAddrs, externalIP) + assert.NotContains(t, hostAddrs, localIP) + assert.NotContains(t, hostAddrs, "198.51.100.90") +} + +func TestICEConnection_AddressRewriteAppend(t *testing.T) { //nolint:cyclop + lim := test.TimeOut(time.Second * 15) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + const ( + offerIP = "1.2.3.4" + answerIP = "1.2.3.5" + offerExternal = "203.0.113.200" + ) + + wan, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "1.2.3.0/24", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + require.NoError(t, err) + + offerNet, err := vnet.NewNet(&vnet.NetConfig{ + StaticIPs: []string{offerIP}, + }) + require.NoError(t, err) + answerNet, err := vnet.NewNet(&vnet.NetConfig{ + StaticIPs: []string{answerIP}, + }) + require.NoError(t, err) + + require.NoError(t, wan.AddNet(offerNet)) + require.NoError(t, wan.AddNet(answerNet)) + require.NoError(t, wan.Start()) + defer func() { + assert.NoError(t, wan.Stop()) + }() + + offerSE := SettingEngine{} + offerSE.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + offerSE.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + offerSE.SetNet(offerNet) + require.NoError(t, offerSE.SetICEAddressRewriteRules(ICEAddressRewriteRule{ + External: []string{offerExternal}, + AsCandidateType: ICECandidateTypeHost, + Mode: ICEAddressRewriteAppend, + })) + + answerSE := SettingEngine{} + answerSE.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + answerSE.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + answerSE.SetNet(answerNet) + + offerPC, err := NewAPI(WithSettingEngine(offerSE)).NewPeerConnection(Configuration{}) + require.NoError(t, err) + answerPC, err := NewAPI(WithSettingEngine(answerSE)).NewPeerConnection(Configuration{}) + require.NoError(t, err) + defer closePairNow(t, offerPC, answerPC) + + var offerCandidates []ICECandidate + offerPC.OnICECandidate(func(c *ICECandidate) { + if c != nil { + offerCandidates = append(offerCandidates, *c) + } + }) + + assert.NoError(t, signalPair(offerPC, answerPC)) + + connected := untilConnectionState(PeerConnectionStateConnected, offerPC, answerPC) + connected.Wait() + + var hostAddrs []string + for _, c := range offerCandidates { + if c.Typ == ICECandidateTypeHost { + hostAddrs = append(hostAddrs, c.Address) + } + } + + assert.Contains(t, hostAddrs, offerIP) + assert.Contains(t, hostAddrs, offerExternal) +} + +func TestICEAddressRewriteDropRule(t *testing.T) { + se := SettingEngine{} + + se.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + se.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + + err := se.SetICEAddressRewriteRules(ICEAddressRewriteRule{ + External: nil, + AsCandidateType: ICECandidateTypeHost, + Mode: ICEAddressRewriteReplace, + }) + assert.NoError(t, err, "rule is allowed to be configured, validation happens in ice") + + gatherer, gErr := NewAPI(WithSettingEngine(se)).NewICEGatherer(ICEGatherOptions{}) + require.NoError(t, gErr) + defer func() { + assert.NoError(t, gatherer.Close()) + }() + + assert.ErrorIs(t, gatherer.Gather(), ice.ErrInvalidAddressRewriteMapping) +} + +func TestICEGatherer_AddressRewriteRelayVNet(t *testing.T) { //nolint:cyclop + lim := test.TimeOut(time.Second * 15) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + const ( + turnIP = "10.0.0.2" + clientIP = "10.0.0.3" + relayExternal = "203.0.113.77" + turnListenPort = "3478" + ) + + loggerFactory := logging.NewDefaultLoggerFactory() + + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.0.0/24", + LoggerFactory: loggerFactory, + }) + require.NoError(t, err) + + turnNet, err := vnet.NewNet(&vnet.NetConfig{ + StaticIP: turnIP, + }) + require.NoError(t, err) + clientNet, err := vnet.NewNet(&vnet.NetConfig{ + StaticIP: clientIP, + }) + require.NoError(t, err) + + require.NoError(t, router.AddNet(turnNet)) + require.NoError(t, router.AddNet(clientNet)) + require.NoError(t, router.Start()) + defer func() { + assert.NoError(t, router.Stop()) + }() + + turnListener, err := turnNet.ListenPacket("udp4", net.JoinHostPort(turnIP, turnListenPort)) + require.NoError(t, err) + + authKey := turn.GenerateAuthKey("user", "pion.ly", "pass") + turnServer, err := turn.NewServer(turn.ServerConfig{ + Realm: "pion.ly", + AuthHandler: func(u, r string, _ net.Addr) ([]byte, bool) { + if u == "user" && r == "pion.ly" { + return authKey, true + } + + return nil, false + }, + PacketConnConfigs: []turn.PacketConnConfig{ + { + PacketConn: turnListener, + RelayAddressGenerator: &turn.RelayAddressGeneratorStatic{ + RelayAddress: net.ParseIP(turnIP), + Address: "0.0.0.0", + Net: turnNet, + }, + }, + }, + LoggerFactory: loggerFactory, + }) + require.NoError(t, err) + defer func() { + assert.NoError(t, turnServer.Close()) + }() + + se := SettingEngine{} + se.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + se.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + se.SetNet(clientNet) + require.NoError(t, se.SetICEAddressRewriteRules(ICEAddressRewriteRule{ + External: []string{relayExternal}, + AsCandidateType: ICECandidateTypeRelay, + Mode: ICEAddressRewriteReplace, + })) + + candidates := gatherCandidatesWithSettingEngine(t, se, ICEGatherOptions{ + ICEServers: []ICEServer{ + { + URLs: []string{fmt.Sprintf("turn:%s:%s?transport=udp", turnIP, turnListenPort)}, + Username: "user", + Credential: "pass", + }, + }, + ICEGatherPolicy: ICETransportPolicyRelay, + }) + + var relayAddrs []string + for _, c := range candidates { + if c.Typ == ICECandidateTypeRelay { + relayAddrs = append(relayAddrs, c.Address) + } + } + + assert.NotEmpty(t, relayAddrs, "expected relay candidates") + assert.Subset(t, relayAddrs, []string{relayExternal}) + assert.NotContains(t, relayAddrs, turnIP) +} + +func TestICEGatherer_AddressRewriteRelayAppendVNet(t *testing.T) { //nolint:cyclop + lim := test.TimeOut(time.Second * 15) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + const ( + turnIP = "10.0.0.4" + clientIP = "10.0.0.5" + relayExternal = "203.0.113.78" + turnListenPort = "3478" + ) + + loggerFactory := logging.NewDefaultLoggerFactory() + + router, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "10.0.0.0/24", + LoggerFactory: loggerFactory, + }) + require.NoError(t, err) + + turnNet, err := vnet.NewNet(&vnet.NetConfig{ + StaticIP: turnIP, + }) + require.NoError(t, err) + clientNet, err := vnet.NewNet(&vnet.NetConfig{ + StaticIP: clientIP, + }) + require.NoError(t, err) + + require.NoError(t, router.AddNet(turnNet)) + require.NoError(t, router.AddNet(clientNet)) + require.NoError(t, router.Start()) + defer func() { + assert.NoError(t, router.Stop()) + }() + + turnListener, err := turnNet.ListenPacket("udp4", net.JoinHostPort(turnIP, turnListenPort)) + require.NoError(t, err) + + authKey := turn.GenerateAuthKey("user", "pion.ly", "pass") + turnServer, err := turn.NewServer(turn.ServerConfig{ + Realm: "pion.ly", + AuthHandler: func(u, r string, _ net.Addr) ([]byte, bool) { + if u == "user" && r == "pion.ly" { + return authKey, true + } + + return nil, false + }, + PacketConnConfigs: []turn.PacketConnConfig{ + { + PacketConn: turnListener, + RelayAddressGenerator: &turn.RelayAddressGeneratorStatic{ + RelayAddress: net.ParseIP(turnIP), + Address: "0.0.0.0", + Net: turnNet, + }, + }, + }, + LoggerFactory: loggerFactory, + }) + require.NoError(t, err) + + se := SettingEngine{} + se.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + se.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + se.SetNet(clientNet) + require.NoError(t, se.SetICEAddressRewriteRules(ICEAddressRewriteRule{ + External: []string{relayExternal}, + AsCandidateType: ICECandidateTypeRelay, + Mode: ICEAddressRewriteAppend, + })) + + candidates := gatherCandidatesWithSettingEngine(t, se, ICEGatherOptions{ + ICEServers: []ICEServer{ + { + URLs: []string{fmt.Sprintf("turn:%s:%s?transport=udp", turnIP, turnListenPort)}, + Username: "user", + Credential: "pass", + }, + }, + ICEGatherPolicy: ICETransportPolicyRelay, + }) + + var relayAddrs []string + for _, c := range candidates { + if c.Typ == ICECandidateTypeRelay { + relayAddrs = append(relayAddrs, c.Address) + } + } + + assert.Contains(t, relayAddrs, turnIP) + assert.Contains(t, relayAddrs, relayExternal) + + if err := turnServer.Close(); err != nil { + t.Logf("turn server close: %v", err) + } + if err := turnListener.Close(); err != nil { + t.Logf("turn listener close: %v", err) + } +} + func TestICEGatherer_StaticLocalCredentialsVNet(t *testing.T) { //nolint:cyclop lim := test.TimeOut(time.Second * 20) defer lim.Stop() diff --git a/networktype.go b/networktype.go index a7ee773c..1cd6fe32 100644 --- a/networktype.go +++ b/networktype.go @@ -62,7 +62,7 @@ func (t NetworkType) String() string { } // Protocol returns udp or tcp. -func (t NetworkType) Protocol() string { +func (t NetworkType) Protocol() string { //nolint:staticcheck switch t { case NetworkTypeUDP4: return "udp" @@ -108,3 +108,20 @@ func getNetworkType(iceNetworkType ice.NetworkType) (NetworkType, error) { return NetworkTypeUnknown, fmt.Errorf("%w: %s", errNetworkTypeUnknown, iceNetworkType.String()) } } + +func toICENetworkTypes(networkTypes []NetworkType) []ice.NetworkType { + if len(networkTypes) == 0 { + return nil + } + + converted := make([]ice.NetworkType, 0, len(networkTypes)) + for _, networkType := range networkTypes { + converted = append(converted, networkType.toICE()) + } + + return converted +} + +func (networkType NetworkType) toICE() ice.NetworkType { + return ice.NetworkType(networkType) +} diff --git a/settingengine.go b/settingengine.go index 3111f6d9..f513cfed 100644 --- a/settingengine.go +++ b/settingengine.go @@ -54,6 +54,7 @@ type SettingEngine struct { IPFilter func(net.IP) (keep bool) NAT1To1IPs []string NAT1To1IPCandidateType ICECandidateType + addressRewriteRules []ice.AddressRewriteRule MulticastDNSMode ice.MulticastDNSMode MulticastDNSHostName string UsernameFragment string @@ -329,11 +330,43 @@ func (e *SettingEngine) SetIPFilter(filter func(net.IP) (keep bool)) { // with the public IP. The host candidate is still available along with mDNS // capabilities unaffected. Also, you cannot give STUN server URL at the same time. // It will result in an error otherwise. +// +// Deprecated: Use SetICEAddressRewriteRules instead. To mirror the legacy +// behavior, supply ICEAddressRewriteRule with External set to ips, AsCandidateType +// set to candidateType, and Mode set to ICEAddressRewriteReplace for host +// candidates or ICEAddressRewriteAppend for server reflexive candidates. +// Or leave Mode unspecified to use the default behavior; +// replace for host candidates and append for server reflexive candidates. func (e *SettingEngine) SetNAT1To1IPs(ips []string, candidateType ICECandidateType) { e.candidates.NAT1To1IPs = ips e.candidates.NAT1To1IPCandidateType = candidateType } +// SetICEAddressRewriteRules configures address rewrite rules for candidate publication. +// These rules provide fine-grained control over which local addresses are replaced or +// supplemented with external IPs. +// This replaces the legacy NAT1To1 settings, which will be deprecated in the future. +func (e *SettingEngine) SetICEAddressRewriteRules(rules ...ICEAddressRewriteRule) error { + if len(rules) == 0 { + e.candidates.addressRewriteRules = nil + + return nil + } + + if len(e.candidates.NAT1To1IPs) > 0 { + return errAddressRewriteWithNAT1To1 + } + + converted := make([]ice.AddressRewriteRule, 0, len(rules)) + for _, rule := range rules { + converted = append(converted, rule.toICE()) + } + + e.candidates.addressRewriteRules = converted + + return nil +} + // SetIncludeLoopbackCandidate enable pion to gather loopback candidates, it is useful // for some VM have public IP mapped to loopback interface. func (e *SettingEngine) SetIncludeLoopbackCandidate(include bool) { diff --git a/settingengine_test.go b/settingengine_test.go index c8cefc66..0d5299a7 100644 --- a/settingengine_test.go +++ b/settingengine_test.go @@ -126,6 +126,58 @@ func TestSetNAT1To1IPs(t *testing.T) { assert.Equal(t, typ, settingEngine.candidates.NAT1To1IPCandidateType, "Failed to set NAT1To1IPCandidateType") } +func TestSettingEngine_SetICEAddressRewriteRules_EmptyClears(t *testing.T) { + se := SettingEngine{} + assert.Nil(t, se.candidates.addressRewriteRules) + + assert.NoError(t, se.SetICEAddressRewriteRules(ICEAddressRewriteRule{ + External: []string{"198.51.100.1"}, + AsCandidateType: ICECandidateTypeHost, + Mode: ICEAddressRewriteReplace, + })) + assert.NotNil(t, se.candidates.addressRewriteRules) + assert.Len(t, se.candidates.addressRewriteRules, 1) + + se.SetNAT1To1IPs([]string{"203.0.113.1"}, ICECandidateTypeHost) + assert.NoError(t, se.SetICEAddressRewriteRules()) + assert.Nil(t, se.candidates.addressRewriteRules) + + assert.ErrorIs(t, se.SetICEAddressRewriteRules(ICEAddressRewriteRule{ + External: []string{"198.51.100.2"}, + AsCandidateType: ICECandidateTypeHost, + Mode: ICEAddressRewriteReplace, + }), errAddressRewriteWithNAT1To1) +} + +// ExampleSettingEngine_SetICEAddressRewriteRules_replaceHost demonstrates +// replacing host candidates with a fixed public address using the rewrite API. +func ExampleSettingEngine_SetICEAddressRewriteRules_replaceHost() { + var se SettingEngine + + _ = se.SetICEAddressRewriteRules( + ICEAddressRewriteRule{ + External: []string{"198.51.100.1"}, + AsCandidateType: ICECandidateTypeHost, + Mode: ICEAddressRewriteReplace, + }, + ) +} + +// ExampleSettingEngine_SetICEAddressRewriteRules_appendSrflx demonstrates +// appending a server reflexive candidate that advertises a public address while +// still keeping the host candidate. +func ExampleSettingEngine_SetICEAddressRewriteRules_appendSrflx() { + var se SettingEngine + + _ = se.SetICEAddressRewriteRules( + ICEAddressRewriteRule{ + External: []string{"198.51.100.2"}, + AsCandidateType: ICECandidateTypeSrflx, + Mode: ICEAddressRewriteAppend, + }, + ) +} + func TestSetAnsweringDTLSRole(t *testing.T) { s := SettingEngine{} assert.Error(