mirror of
https://github.com/pion/ice.git
synced 2025-10-01 22:02:07 +08:00
Add option to include loopback candidate
Add option to include loopback candidate
This commit is contained in:
3
agent.go
3
agent.go
@@ -130,6 +130,7 @@ type Agent struct {
|
|||||||
|
|
||||||
interfaceFilter func(string) bool
|
interfaceFilter func(string) bool
|
||||||
ipFilter func(net.IP) bool
|
ipFilter func(net.IP) bool
|
||||||
|
includeLoopback bool
|
||||||
|
|
||||||
insecureSkipVerify bool
|
insecureSkipVerify bool
|
||||||
|
|
||||||
@@ -317,6 +318,8 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
|
|||||||
ipFilter: config.IPFilter,
|
ipFilter: config.IPFilter,
|
||||||
|
|
||||||
insecureSkipVerify: config.InsecureSkipVerify,
|
insecureSkipVerify: config.InsecureSkipVerify,
|
||||||
|
|
||||||
|
includeLoopback: config.IncludeLoopback,
|
||||||
}
|
}
|
||||||
|
|
||||||
a.tcpMux = config.TCPMux
|
a.tcpMux = config.TCPMux
|
||||||
|
@@ -165,8 +165,11 @@ type AgentConfig struct {
|
|||||||
// dial interface in order to support corporate proxies
|
// dial interface in order to support corporate proxies
|
||||||
ProxyDialer proxy.Dialer
|
ProxyDialer proxy.Dialer
|
||||||
|
|
||||||
// Accept aggressive nomination in RFC 5245 for compatible with chrome and other browsers
|
// Deprecated: AcceptAggressiveNomination always enabled.
|
||||||
AcceptAggressiveNomination bool
|
AcceptAggressiveNomination bool
|
||||||
|
|
||||||
|
// Include loopback addresses in the candidate list.
|
||||||
|
IncludeLoopback bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// initWithDefaults populates an agent and falls back to defaults if fields are unset
|
// initWithDefaults populates an agent and falls back to defaults if fields are unset
|
||||||
|
@@ -149,7 +149,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
|
|||||||
delete(networks, udp)
|
delete(networks, udp)
|
||||||
}
|
}
|
||||||
|
|
||||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes)
|
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.log.Warnf("failed to iterate local interfaces, host candidates will not be gathered %s", err)
|
a.log.Warnf("failed to iterate local interfaces, host candidates will not be gathered %s", err)
|
||||||
return
|
return
|
||||||
|
@@ -13,6 +13,7 @@ import (
|
|||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -31,7 +32,7 @@ func TestListenUDP(t *testing.T) {
|
|||||||
a, err := NewAgent(&AgentConfig{})
|
a, err := NewAgent(&AgentConfig{})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
|
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||||
assert.NotEqual(t, len(localIPs), 0, "localInterfaces found no interfaces, unable to test")
|
assert.NotEqual(t, len(localIPs), 0, "localInterfaces found no interfaces, unable to test")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
@@ -86,6 +87,88 @@ func TestListenUDP(t *testing.T) {
|
|||||||
assert.NoError(t, a.Close())
|
assert.NoError(t, a.Close())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoopbackCandidate(t *testing.T) {
|
||||||
|
report := test.CheckRoutines(t)
|
||||||
|
defer report()
|
||||||
|
|
||||||
|
lim := test.TimeOut(time.Second * 30)
|
||||||
|
defer lim.Stop()
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
agentConfig *AgentConfig
|
||||||
|
loExpected bool
|
||||||
|
}
|
||||||
|
mux, err := NewMultiUDPMuxFromPort(12500)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
muxWithLo, errlo := NewMultiUDPMuxFromPort(12501, UDPMuxFromPortWithLoopback())
|
||||||
|
assert.NoError(t, errlo)
|
||||||
|
testCases := []testCase{
|
||||||
|
{
|
||||||
|
name: "mux should not have loopback candidate",
|
||||||
|
agentConfig: &AgentConfig{
|
||||||
|
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
|
||||||
|
UDPMux: mux,
|
||||||
|
},
|
||||||
|
loExpected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mux with loopback should not have loopback candidate",
|
||||||
|
agentConfig: &AgentConfig{
|
||||||
|
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
|
||||||
|
UDPMux: muxWithLo,
|
||||||
|
},
|
||||||
|
loExpected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "includeloopback enabled",
|
||||||
|
agentConfig: &AgentConfig{
|
||||||
|
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
|
||||||
|
IncludeLoopback: true,
|
||||||
|
},
|
||||||
|
loExpected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "includeloopback disabled",
|
||||||
|
agentConfig: &AgentConfig{
|
||||||
|
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
|
||||||
|
IncludeLoopback: false,
|
||||||
|
},
|
||||||
|
loExpected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tcase := tc
|
||||||
|
t.Run(tcase.name, func(t *testing.T) {
|
||||||
|
a, err := NewAgent(tc.agentConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
|
||||||
|
var loopback int32
|
||||||
|
assert.NoError(t, a.OnCandidate(func(c Candidate) {
|
||||||
|
if c != nil {
|
||||||
|
if net.ParseIP(c.Address()).IsLoopback() {
|
||||||
|
atomic.StoreInt32(&loopback, 1)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
candidateGatheredFunc()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Log(c.NetworkType(), c.Priority(), c)
|
||||||
|
}))
|
||||||
|
assert.NoError(t, a.GatherCandidates())
|
||||||
|
|
||||||
|
<-candidateGathered.Done()
|
||||||
|
|
||||||
|
assert.NoError(t, a.Close())
|
||||||
|
assert.Equal(t, tcase.loExpected, atomic.LoadInt32(&loopback) == 1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, mux.Close())
|
||||||
|
assert.NoError(t, muxWithLo.Close())
|
||||||
|
}
|
||||||
|
|
||||||
// Assert that STUN gathering is done concurrently
|
// Assert that STUN gathering is done concurrently
|
||||||
func TestSTUNConcurrency(t *testing.T) {
|
func TestSTUNConcurrency(t *testing.T) {
|
||||||
report := test.CheckRoutines(t)
|
report := test.CheckRoutines(t)
|
||||||
|
@@ -29,7 +29,7 @@ func TestVNetGather(t *testing.T) {
|
|||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
|
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||||
if len(localIPs) > 0 {
|
if len(localIPs) > 0 {
|
||||||
t.Fatal("should return no local IP")
|
t.Fatal("should return no local IP")
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
@@ -69,7 +69,7 @@ func TestVNetGather(t *testing.T) {
|
|||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
|
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||||
if len(localIPs) == 0 {
|
if len(localIPs) == 0 {
|
||||||
t.Fatal("should have one local IP")
|
t.Fatal("should have one local IP")
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
@@ -112,7 +112,7 @@ func TestVNetGather(t *testing.T) {
|
|||||||
t.Fatalf("Failed to create agent: %s", err)
|
t.Fatalf("Failed to create agent: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
|
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||||
if len(localIPs) == 0 {
|
if len(localIPs) == 0 {
|
||||||
t.Fatal("localInterfaces found no interfaces, unable to test")
|
t.Fatal("localInterfaces found no interfaces, unable to test")
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
@@ -385,7 +385,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
|
|||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
|
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
} else if len(localIPs) != 0 {
|
} else if len(localIPs) != 0 {
|
||||||
@@ -405,7 +405,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
|
|||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
|
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
} else if len(localIPs) != 0 {
|
} else if len(localIPs) != 0 {
|
||||||
@@ -425,7 +425,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
|
|||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
|
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
} else if len(localIPs) == 0 {
|
} else if len(localIPs) == 0 {
|
||||||
|
@@ -78,7 +78,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
|||||||
}
|
}
|
||||||
if len(networks) > 0 {
|
if len(networks) > 0 {
|
||||||
muxNet := vnet.NewNet(nil)
|
muxNet := vnet.NewNet(nil)
|
||||||
ips, err := localInterfaces(muxNet, nil, nil, networks)
|
ips, err := localInterfaces(muxNet, nil, nil, networks, true)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
|
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
|
||||||
|
@@ -81,7 +81,7 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu
|
|||||||
opt.apply(¶ms)
|
opt.apply(¶ms)
|
||||||
}
|
}
|
||||||
muxNet := vnet.NewNet(nil)
|
muxNet := vnet.NewNet(nil)
|
||||||
ips, err := localInterfaces(muxNet, params.ifFilter, params.ipFilter, params.networks)
|
ips, err := localInterfaces(muxNet, params.ifFilter, params.ipFilter, params.networks, params.includeLoopback)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -130,6 +130,7 @@ type multiUDPMuxFromPortParam struct {
|
|||||||
readBufferSize int
|
readBufferSize int
|
||||||
writeBufferSize int
|
writeBufferSize int
|
||||||
logger logging.LeveledLogger
|
logger logging.LeveledLogger
|
||||||
|
includeLoopback bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type udpMuxFromPortOption struct {
|
type udpMuxFromPortOption struct {
|
||||||
@@ -193,3 +194,12 @@ func UDPMuxFromPortWithLogger(logger logging.LeveledLogger) UDPMuxFromPortOption
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UDPMuxFromPortWithLoopback set loopback interface should be included
|
||||||
|
func UDPMuxFromPortWithLoopback() UDPMuxFromPortOption {
|
||||||
|
return &udpMuxFromPortOption{
|
||||||
|
f: func(p *multiUDPMuxFromPortParam) {
|
||||||
|
p.includeLoopback = true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
6
util.go
6
util.go
@@ -132,7 +132,7 @@ func stunRequest(read func([]byte) (int, error), write func([]byte) (int, error)
|
|||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []NetworkType) ([]net.IP, error) { //nolint:gocognit
|
func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []NetworkType, includeLoopback bool) ([]net.IP, error) { //nolint:gocognit
|
||||||
ips := []net.IP{}
|
ips := []net.IP{}
|
||||||
ifaces, err := vnet.Interfaces()
|
ifaces, err := vnet.Interfaces()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -154,7 +154,7 @@ func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter
|
|||||||
if iface.Flags&net.FlagUp == 0 {
|
if iface.Flags&net.FlagUp == 0 {
|
||||||
continue // interface down
|
continue // interface down
|
||||||
}
|
}
|
||||||
if iface.Flags&net.FlagLoopback != 0 {
|
if (iface.Flags&net.FlagLoopback != 0) && !includeLoopback {
|
||||||
continue // loopback interface
|
continue // loopback interface
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,7 +175,7 @@ func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter
|
|||||||
case *net.IPAddr:
|
case *net.IPAddr:
|
||||||
ip = addr.IP
|
ip = addr.IP
|
||||||
}
|
}
|
||||||
if ip == nil || ip.IsLoopback() {
|
if ip == nil || (ip.IsLoopback() && !includeLoopback) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user