diff --git a/icegatherer.go b/icegatherer.go index 96dcb871..0fe7e876 100644 --- a/icegatherer.go +++ b/icegatherer.go @@ -170,8 +170,13 @@ func (g *ICEGatherer) resolveCandidateTypes() []ice.CandidateType { if g.api.settingEngine.candidates.ICELite { return []ice.CandidateType{ice.CandidateTypeHost} } - if g.gatherPolicy == ICETransportPolicyRelay { + + switch g.gatherPolicy { + case ICETransportPolicyRelay: return []ice.CandidateType{ice.CandidateTypeRelay} + case ICETransportPolicyNoHost: + return []ice.CandidateType{ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay} + default: } return nil diff --git a/icegatherer_test.go b/icegatherer_test.go index 2972f457..fe636105 100644 --- a/icegatherer_test.go +++ b/icegatherer_test.go @@ -336,6 +336,155 @@ func gatherCandidatesWithSettingEngine(t *testing.T, se SettingEngine, opts ICEG return candidates } +func TestICEGatherer_NoHostPolicyVNet(t *testing.T) { //nolint:cyclop + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + const ( + stunIP = "1.2.3.4" + stunPort = 3478 + externalIP = "1.2.3.10" + localIP = "10.0.0.1" + realm = "pion.ly" + timeout = 3 * time.Second + ) + + loggerFactory := logging.NewDefaultLoggerFactory() + + wan, err := vnet.NewRouter(&vnet.RouterConfig{ + CIDR: "1.2.3.0/24", + LoggerFactory: loggerFactory, + }) + assert.NoError(t, err) + + stunNet, err := vnet.NewNet(&vnet.NetConfig{ + StaticIP: stunIP, + }) + assert.NoError(t, err) + assert.NoError(t, wan.AddNet(stunNet)) + + clientLAN, err := vnet.NewRouter(&vnet.RouterConfig{ + StaticIPs: []string{fmt.Sprintf("%s/%s", externalIP, localIP)}, + CIDR: "10.0.0.0/24", + NATType: &vnet.NATType{ + Mode: vnet.NATModeNAT1To1, + }, + LoggerFactory: loggerFactory, + }) + assert.NoError(t, err) + + clientNet, err := vnet.NewNet(&vnet.NetConfig{ + StaticIP: localIP, + }) + assert.NoError(t, err) + assert.NoError(t, clientLAN.AddNet(clientNet)) + assert.NoError(t, wan.AddRouter(clientLAN)) + assert.NoError(t, wan.Start()) + defer func() { + assert.NoError(t, wan.Stop()) + }() + + stunListener, err := stunNet.ListenPacket("udp4", net.JoinHostPort(stunIP, fmt.Sprintf("%d", stunPort))) + assert.NoError(t, err) + + turnServer, err := turn.NewServer(turn.ServerConfig{ + Realm: realm, + LoggerFactory: loggerFactory, + PacketConnConfigs: []turn.PacketConnConfig{ + { + PacketConn: stunListener, + RelayAddressGenerator: &turn.RelayAddressGeneratorStatic{ + RelayAddress: net.ParseIP(stunIP), + Address: "0.0.0.0", + Net: stunNet, + }, + }, + }, + }) + assert.NoError(t, err) + defer func() { + assert.NoError(t, turnServer.Close()) + }() + + iceServer := ICEServer{ + URLs: []string{fmt.Sprintf("stun:%s:%d", stunIP, stunPort)}, + } + + collect := func(t *testing.T, policy ICETransportPolicy) []ICECandidate { + t.Helper() + + se := SettingEngine{} + se.SetICEMulticastDNSMode(ice.MulticastDNSModeDisabled) + se.SetNetworkTypes([]NetworkType{NetworkTypeUDP4}) + se.SetNet(clientNet) + + gatherer, err := NewAPI(WithSettingEngine(se)).NewICEGatherer(ICEGatherOptions{ + ICEServers: []ICEServer{iceServer}, + ICEGatherPolicy: policy, + }) + assert.NoError(t, err) + defer func() { + assert.NoError(t, gatherer.Close()) + }() + + done := make(chan struct{}) + var candidates []ICECandidate + gatherer.OnLocalCandidate(func(c *ICECandidate) { + if c == nil { + close(done) + } else { + candidates = append(candidates, *c) + } + }) + + assert.NoError(t, gatherer.Gather()) + + select { + case <-done: + case <-time.After(timeout): + assert.Fail(t, "gathering did not complete") + } + + return candidates + } + + t.Run("All", func(t *testing.T) { + candidates := collect(t, ICETransportPolicyAll) + assert.NotEmpty(t, candidates) + + var haveHost, haveSrflx bool + for _, c := range candidates { + switch c.Typ { + case ICECandidateTypeHost: + haveHost = true + case ICECandidateTypeSrflx: + haveSrflx = true + assert.Equal(t, externalIP, c.Address) + default: + } + } + + assert.True(t, haveHost, "expected host candidate") + assert.True(t, haveSrflx, "expected srflx candidate") + }) + + t.Run("NoHost", func(t *testing.T) { + candidates := collect(t, ICETransportPolicyNoHost) + if assert.NotEmpty(t, candidates) { + for _, c := range candidates { + assert.Equal(t, ICECandidateTypeSrflx, c.Typ) + assert.Equal(t, externalIP, c.Address) + } + for _, c := range candidates { + assert.NotEqual(t, ICECandidateTypeHost, c.Typ) + } + } + }) +} + func TestICEGatherer_AddressRewriteRulesVNet(t *testing.T) { //nolint:cyclop lim := test.TimeOut(time.Second * 10) defer lim.Stop() diff --git a/icetransportpolicy.go b/icetransportpolicy.go index 39a1fa36..9aa8cb0f 100644 --- a/icetransportpolicy.go +++ b/icetransportpolicy.go @@ -21,17 +21,23 @@ const ( // ICETransportPolicyRelay indicates only media relay candidates such // as candidates passing through a TURN server are used. ICETransportPolicyRelay + + // ICETransportPolicyNoHost indicates only non-host candidates are used. + ICETransportPolicyNoHost ) // This is done this way because of a linter. const ( - iceTransportPolicyRelayStr = "relay" - iceTransportPolicyAllStr = "all" + iceTransportPolicyRelayStr = "relay" + iceTransportPolicyNoHostStr = "nohost" + iceTransportPolicyAllStr = "all" ) // NewICETransportPolicy takes a string and converts it to ICETransportPolicy. func NewICETransportPolicy(raw string) ICETransportPolicy { switch raw { + case iceTransportPolicyNoHostStr: + return ICETransportPolicyNoHost case iceTransportPolicyRelayStr: return ICETransportPolicyRelay default: @@ -41,6 +47,8 @@ func NewICETransportPolicy(raw string) ICETransportPolicy { func (t ICETransportPolicy) String() string { switch t { + case ICETransportPolicyNoHost: + return iceTransportPolicyNoHostStr case ICETransportPolicyRelay: return iceTransportPolicyRelayStr case ICETransportPolicyAll: diff --git a/icetransportpolicy_test.go b/icetransportpolicy_test.go index 0f0dbf56..e5387fcf 100644 --- a/icetransportpolicy_test.go +++ b/icetransportpolicy_test.go @@ -14,6 +14,7 @@ func TestNewICETransportPolicy(t *testing.T) { policyString string expectedPolicy ICETransportPolicy }{ + {"nohost", ICETransportPolicyNoHost}, {"relay", ICETransportPolicyRelay}, {"all", ICETransportPolicyAll}, } @@ -32,6 +33,7 @@ func TestICETransportPolicy_String(t *testing.T) { policy ICETransportPolicy expectedString string }{ + {ICETransportPolicyNoHost, "nohost"}, {ICETransportPolicyRelay, "relay"}, {ICETransportPolicyAll, "all"}, }