mirror of
https://github.com/pion/ice.git
synced 2025-09-27 03:45:54 +08:00
Add option to include loopback candidate
Add option to include loopback candidate
This commit is contained in:
3
agent.go
3
agent.go
@@ -130,6 +130,7 @@ type Agent struct {
|
||||
|
||||
interfaceFilter func(string) bool
|
||||
ipFilter func(net.IP) bool
|
||||
includeLoopback bool
|
||||
|
||||
insecureSkipVerify bool
|
||||
|
||||
@@ -317,6 +318,8 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
|
||||
ipFilter: config.IPFilter,
|
||||
|
||||
insecureSkipVerify: config.InsecureSkipVerify,
|
||||
|
||||
includeLoopback: config.IncludeLoopback,
|
||||
}
|
||||
|
||||
a.tcpMux = config.TCPMux
|
||||
|
@@ -165,8 +165,11 @@ type AgentConfig struct {
|
||||
// dial interface in order to support corporate proxies
|
||||
ProxyDialer proxy.Dialer
|
||||
|
||||
// Accept aggressive nomination in RFC 5245 for compatible with chrome and other browsers
|
||||
// Deprecated: AcceptAggressiveNomination always enabled.
|
||||
AcceptAggressiveNomination bool
|
||||
|
||||
// Include loopback addresses in the candidate list.
|
||||
IncludeLoopback bool
|
||||
}
|
||||
|
||||
// initWithDefaults populates an agent and falls back to defaults if fields are unset
|
||||
|
@@ -149,7 +149,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
|
||||
delete(networks, udp)
|
||||
}
|
||||
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes)
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback)
|
||||
if err != nil {
|
||||
a.log.Warnf("failed to iterate local interfaces, host candidates will not be gathered %s", err)
|
||||
return
|
||||
|
@@ -13,6 +13,7 @@ import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -31,7 +32,7 @@ func TestListenUDP(t *testing.T) {
|
||||
a, err := NewAgent(&AgentConfig{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||
assert.NotEqual(t, len(localIPs), 0, "localInterfaces found no interfaces, unable to test")
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -86,6 +87,88 @@ func TestListenUDP(t *testing.T) {
|
||||
assert.NoError(t, a.Close())
|
||||
}
|
||||
|
||||
func TestLoopbackCandidate(t *testing.T) {
|
||||
report := test.CheckRoutines(t)
|
||||
defer report()
|
||||
|
||||
lim := test.TimeOut(time.Second * 30)
|
||||
defer lim.Stop()
|
||||
type testCase struct {
|
||||
name string
|
||||
agentConfig *AgentConfig
|
||||
loExpected bool
|
||||
}
|
||||
mux, err := NewMultiUDPMuxFromPort(12500)
|
||||
assert.NoError(t, err)
|
||||
muxWithLo, errlo := NewMultiUDPMuxFromPort(12501, UDPMuxFromPortWithLoopback())
|
||||
assert.NoError(t, errlo)
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "mux should not have loopback candidate",
|
||||
agentConfig: &AgentConfig{
|
||||
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
|
||||
UDPMux: mux,
|
||||
},
|
||||
loExpected: false,
|
||||
},
|
||||
{
|
||||
name: "mux with loopback should not have loopback candidate",
|
||||
agentConfig: &AgentConfig{
|
||||
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
|
||||
UDPMux: muxWithLo,
|
||||
},
|
||||
loExpected: true,
|
||||
},
|
||||
{
|
||||
name: "includeloopback enabled",
|
||||
agentConfig: &AgentConfig{
|
||||
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
|
||||
IncludeLoopback: true,
|
||||
},
|
||||
loExpected: true,
|
||||
},
|
||||
{
|
||||
name: "includeloopback disabled",
|
||||
agentConfig: &AgentConfig{
|
||||
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
|
||||
IncludeLoopback: false,
|
||||
},
|
||||
loExpected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tcase := tc
|
||||
t.Run(tcase.name, func(t *testing.T) {
|
||||
a, err := NewAgent(tc.agentConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
|
||||
var loopback int32
|
||||
assert.NoError(t, a.OnCandidate(func(c Candidate) {
|
||||
if c != nil {
|
||||
if net.ParseIP(c.Address()).IsLoopback() {
|
||||
atomic.StoreInt32(&loopback, 1)
|
||||
}
|
||||
} else {
|
||||
candidateGatheredFunc()
|
||||
return
|
||||
}
|
||||
t.Log(c.NetworkType(), c.Priority(), c)
|
||||
}))
|
||||
assert.NoError(t, a.GatherCandidates())
|
||||
|
||||
<-candidateGathered.Done()
|
||||
|
||||
assert.NoError(t, a.Close())
|
||||
assert.Equal(t, tcase.loExpected, atomic.LoadInt32(&loopback) == 1)
|
||||
})
|
||||
}
|
||||
|
||||
assert.NoError(t, mux.Close())
|
||||
assert.NoError(t, muxWithLo.Close())
|
||||
}
|
||||
|
||||
// Assert that STUN gathering is done concurrently
|
||||
func TestSTUNConcurrency(t *testing.T) {
|
||||
report := test.CheckRoutines(t)
|
||||
|
@@ -29,7 +29,7 @@ func TestVNetGather(t *testing.T) {
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||
if len(localIPs) > 0 {
|
||||
t.Fatal("should return no local IP")
|
||||
} else if err != nil {
|
||||
@@ -69,7 +69,7 @@ func TestVNetGather(t *testing.T) {
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||
if len(localIPs) == 0 {
|
||||
t.Fatal("should have one local IP")
|
||||
} else if err != nil {
|
||||
@@ -112,7 +112,7 @@ func TestVNetGather(t *testing.T) {
|
||||
t.Fatalf("Failed to create agent: %s", err)
|
||||
}
|
||||
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||
if len(localIPs) == 0 {
|
||||
t.Fatal("localInterfaces found no interfaces, unable to test")
|
||||
} else if err != nil {
|
||||
@@ -385,7 +385,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(localIPs) != 0 {
|
||||
@@ -405,7 +405,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(localIPs) != 0 {
|
||||
@@ -425,7 +425,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
|
||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
} else if len(localIPs) == 0 {
|
||||
|
@@ -78,7 +78,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
}
|
||||
if len(networks) > 0 {
|
||||
muxNet := vnet.NewNet(nil)
|
||||
ips, err := localInterfaces(muxNet, nil, nil, networks)
|
||||
ips, err := localInterfaces(muxNet, nil, nil, networks, true)
|
||||
if err == nil {
|
||||
for _, ip := range ips {
|
||||
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
|
||||
|
@@ -81,7 +81,7 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu
|
||||
opt.apply(¶ms)
|
||||
}
|
||||
muxNet := vnet.NewNet(nil)
|
||||
ips, err := localInterfaces(muxNet, params.ifFilter, params.ipFilter, params.networks)
|
||||
ips, err := localInterfaces(muxNet, params.ifFilter, params.ipFilter, params.networks, params.includeLoopback)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -130,6 +130,7 @@ type multiUDPMuxFromPortParam struct {
|
||||
readBufferSize int
|
||||
writeBufferSize int
|
||||
logger logging.LeveledLogger
|
||||
includeLoopback bool
|
||||
}
|
||||
|
||||
type udpMuxFromPortOption struct {
|
||||
@@ -193,3 +194,12 @@ func UDPMuxFromPortWithLogger(logger logging.LeveledLogger) UDPMuxFromPortOption
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// UDPMuxFromPortWithLoopback set loopback interface should be included
|
||||
func UDPMuxFromPortWithLoopback() UDPMuxFromPortOption {
|
||||
return &udpMuxFromPortOption{
|
||||
f: func(p *multiUDPMuxFromPortParam) {
|
||||
p.includeLoopback = true
|
||||
},
|
||||
}
|
||||
}
|
||||
|
6
util.go
6
util.go
@@ -132,7 +132,7 @@ func stunRequest(read func([]byte) (int, error), write func([]byte) (int, error)
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []NetworkType) ([]net.IP, error) { //nolint:gocognit
|
||||
func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []NetworkType, includeLoopback bool) ([]net.IP, error) { //nolint:gocognit
|
||||
ips := []net.IP{}
|
||||
ifaces, err := vnet.Interfaces()
|
||||
if err != nil {
|
||||
@@ -154,7 +154,7 @@ func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter
|
||||
if iface.Flags&net.FlagUp == 0 {
|
||||
continue // interface down
|
||||
}
|
||||
if iface.Flags&net.FlagLoopback != 0 {
|
||||
if (iface.Flags&net.FlagLoopback != 0) && !includeLoopback {
|
||||
continue // loopback interface
|
||||
}
|
||||
|
||||
@@ -175,7 +175,7 @@ func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter
|
||||
case *net.IPAddr:
|
||||
ip = addr.IP
|
||||
}
|
||||
if ip == nil || ip.IsLoopback() {
|
||||
if ip == nil || (ip.IsLoopback() && !includeLoopback) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user