Add option to include loopback candidate

Add option to include loopback candidate
This commit is contained in:
cnderrauber
2022-11-22 11:15:14 +08:00
committed by cnderrauber
parent 7f13fd1947
commit e90a58e51a
8 changed files with 113 additions and 14 deletions

View File

@@ -130,6 +130,7 @@ type Agent struct {
interfaceFilter func(string) bool
ipFilter func(net.IP) bool
includeLoopback bool
insecureSkipVerify bool
@@ -317,6 +318,8 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
ipFilter: config.IPFilter,
insecureSkipVerify: config.InsecureSkipVerify,
includeLoopback: config.IncludeLoopback,
}
a.tcpMux = config.TCPMux

View File

@@ -165,8 +165,11 @@ type AgentConfig struct {
// dial interface in order to support corporate proxies
ProxyDialer proxy.Dialer
// Accept aggressive nomination in RFC 5245 for compatible with chrome and other browsers
// Deprecated: AcceptAggressiveNomination always enabled.
AcceptAggressiveNomination bool
// Include loopback addresses in the candidate list.
IncludeLoopback bool
}
// initWithDefaults populates an agent and falls back to defaults if fields are unset

View File

@@ -149,7 +149,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
delete(networks, udp)
}
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes)
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback)
if err != nil {
a.log.Warnf("failed to iterate local interfaces, host candidates will not be gathered %s", err)
return

View File

@@ -13,6 +13,7 @@ import (
"sort"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"
@@ -31,7 +32,7 @@ func TestListenUDP(t *testing.T) {
a, err := NewAgent(&AgentConfig{})
assert.NoError(t, err)
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
assert.NotEqual(t, len(localIPs), 0, "localInterfaces found no interfaces, unable to test")
assert.NoError(t, err)
@@ -86,6 +87,88 @@ func TestListenUDP(t *testing.T) {
assert.NoError(t, a.Close())
}
func TestLoopbackCandidate(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()
type testCase struct {
name string
agentConfig *AgentConfig
loExpected bool
}
mux, err := NewMultiUDPMuxFromPort(12500)
assert.NoError(t, err)
muxWithLo, errlo := NewMultiUDPMuxFromPort(12501, UDPMuxFromPortWithLoopback())
assert.NoError(t, errlo)
testCases := []testCase{
{
name: "mux should not have loopback candidate",
agentConfig: &AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
UDPMux: mux,
},
loExpected: false,
},
{
name: "mux with loopback should not have loopback candidate",
agentConfig: &AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
UDPMux: muxWithLo,
},
loExpected: true,
},
{
name: "includeloopback enabled",
agentConfig: &AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
IncludeLoopback: true,
},
loExpected: true,
},
{
name: "includeloopback disabled",
agentConfig: &AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
IncludeLoopback: false,
},
loExpected: false,
},
}
for _, tc := range testCases {
tcase := tc
t.Run(tcase.name, func(t *testing.T) {
a, err := NewAgent(tc.agentConfig)
assert.NoError(t, err)
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
var loopback int32
assert.NoError(t, a.OnCandidate(func(c Candidate) {
if c != nil {
if net.ParseIP(c.Address()).IsLoopback() {
atomic.StoreInt32(&loopback, 1)
}
} else {
candidateGatheredFunc()
return
}
t.Log(c.NetworkType(), c.Priority(), c)
}))
assert.NoError(t, a.GatherCandidates())
<-candidateGathered.Done()
assert.NoError(t, a.Close())
assert.Equal(t, tcase.loExpected, atomic.LoadInt32(&loopback) == 1)
})
}
assert.NoError(t, mux.Close())
assert.NoError(t, muxWithLo.Close())
}
// Assert that STUN gathering is done concurrently
func TestSTUNConcurrency(t *testing.T) {
report := test.CheckRoutines(t)

View File

@@ -29,7 +29,7 @@ func TestVNetGather(t *testing.T) {
})
assert.NoError(t, err)
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if len(localIPs) > 0 {
t.Fatal("should return no local IP")
} else if err != nil {
@@ -69,7 +69,7 @@ func TestVNetGather(t *testing.T) {
})
assert.NoError(t, err)
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if len(localIPs) == 0 {
t.Fatal("should have one local IP")
} else if err != nil {
@@ -112,7 +112,7 @@ func TestVNetGather(t *testing.T) {
t.Fatalf("Failed to create agent: %s", err)
}
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if len(localIPs) == 0 {
t.Fatal("localInterfaces found no interfaces, unable to test")
} else if err != nil {
@@ -385,7 +385,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
})
assert.NoError(t, err)
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if err != nil {
t.Fatal(err)
} else if len(localIPs) != 0 {
@@ -405,7 +405,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
})
assert.NoError(t, err)
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if err != nil {
t.Fatal(err)
} else if len(localIPs) != 0 {
@@ -425,7 +425,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
})
assert.NoError(t, err)
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4})
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if err != nil {
t.Fatal(err)
} else if len(localIPs) == 0 {

View File

@@ -78,7 +78,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
}
if len(networks) > 0 {
muxNet := vnet.NewNet(nil)
ips, err := localInterfaces(muxNet, nil, nil, networks)
ips, err := localInterfaces(muxNet, nil, nil, networks, true)
if err == nil {
for _, ip := range ips {
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})

View File

@@ -81,7 +81,7 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu
opt.apply(&params)
}
muxNet := vnet.NewNet(nil)
ips, err := localInterfaces(muxNet, params.ifFilter, params.ipFilter, params.networks)
ips, err := localInterfaces(muxNet, params.ifFilter, params.ipFilter, params.networks, params.includeLoopback)
if err != nil {
return nil, err
}
@@ -130,6 +130,7 @@ type multiUDPMuxFromPortParam struct {
readBufferSize int
writeBufferSize int
logger logging.LeveledLogger
includeLoopback bool
}
type udpMuxFromPortOption struct {
@@ -193,3 +194,12 @@ func UDPMuxFromPortWithLogger(logger logging.LeveledLogger) UDPMuxFromPortOption
},
}
}
// UDPMuxFromPortWithLoopback set loopback interface should be included
func UDPMuxFromPortWithLoopback() UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
p.includeLoopback = true
},
}
}

View File

@@ -132,7 +132,7 @@ func stunRequest(read func([]byte) (int, error), write func([]byte) (int, error)
return res, nil
}
func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []NetworkType) ([]net.IP, error) { //nolint:gocognit
func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []NetworkType, includeLoopback bool) ([]net.IP, error) { //nolint:gocognit
ips := []net.IP{}
ifaces, err := vnet.Interfaces()
if err != nil {
@@ -154,7 +154,7 @@ func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
if iface.Flags&net.FlagLoopback != 0 {
if (iface.Flags&net.FlagLoopback != 0) && !includeLoopback {
continue // loopback interface
}
@@ -175,7 +175,7 @@ func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter
case *net.IPAddr:
ip = addr.IP
}
if ip == nil || ip.IsLoopback() {
if ip == nil || (ip.IsLoopback() && !includeLoopback) {
continue
}