Add option to include loopback candidate

Add option to include loopback candidate
This commit is contained in:
cnderrauber
2022-11-22 11:15:14 +08:00
committed by cnderrauber
parent 7f13fd1947
commit e90a58e51a
8 changed files with 113 additions and 14 deletions

View File

@@ -130,6 +130,7 @@ type Agent struct {
interfaceFilter func(string) bool interfaceFilter func(string) bool
ipFilter func(net.IP) bool ipFilter func(net.IP) bool
includeLoopback bool
insecureSkipVerify bool insecureSkipVerify bool
@@ -317,6 +318,8 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
ipFilter: config.IPFilter, ipFilter: config.IPFilter,
insecureSkipVerify: config.InsecureSkipVerify, insecureSkipVerify: config.InsecureSkipVerify,
includeLoopback: config.IncludeLoopback,
} }
a.tcpMux = config.TCPMux a.tcpMux = config.TCPMux

View File

@@ -165,8 +165,11 @@ type AgentConfig struct {
// dial interface in order to support corporate proxies // dial interface in order to support corporate proxies
ProxyDialer proxy.Dialer ProxyDialer proxy.Dialer
// Accept aggressive nomination in RFC 5245 for compatible with chrome and other browsers // Deprecated: AcceptAggressiveNomination always enabled.
AcceptAggressiveNomination bool AcceptAggressiveNomination bool
// Include loopback addresses in the candidate list.
IncludeLoopback bool
} }
// initWithDefaults populates an agent and falls back to defaults if fields are unset // initWithDefaults populates an agent and falls back to defaults if fields are unset

View File

@@ -149,7 +149,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
delete(networks, udp) delete(networks, udp)
} }
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes) localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback)
if err != nil { if err != nil {
a.log.Warnf("failed to iterate local interfaces, host candidates will not be gathered %s", err) a.log.Warnf("failed to iterate local interfaces, host candidates will not be gathered %s", err)
return return

View File

@@ -13,6 +13,7 @@ import (
"sort" "sort"
"strconv" "strconv"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
@@ -31,7 +32,7 @@ func TestListenUDP(t *testing.T) {
a, err := NewAgent(&AgentConfig{}) a, err := NewAgent(&AgentConfig{})
assert.NoError(t, err) assert.NoError(t, err)
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}) localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
assert.NotEqual(t, len(localIPs), 0, "localInterfaces found no interfaces, unable to test") assert.NotEqual(t, len(localIPs), 0, "localInterfaces found no interfaces, unable to test")
assert.NoError(t, err) assert.NoError(t, err)
@@ -86,6 +87,88 @@ func TestListenUDP(t *testing.T) {
assert.NoError(t, a.Close()) assert.NoError(t, a.Close())
} }
func TestLoopbackCandidate(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()
type testCase struct {
name string
agentConfig *AgentConfig
loExpected bool
}
mux, err := NewMultiUDPMuxFromPort(12500)
assert.NoError(t, err)
muxWithLo, errlo := NewMultiUDPMuxFromPort(12501, UDPMuxFromPortWithLoopback())
assert.NoError(t, errlo)
testCases := []testCase{
{
name: "mux should not have loopback candidate",
agentConfig: &AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
UDPMux: mux,
},
loExpected: false,
},
{
name: "mux with loopback should not have loopback candidate",
agentConfig: &AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
UDPMux: muxWithLo,
},
loExpected: true,
},
{
name: "includeloopback enabled",
agentConfig: &AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
IncludeLoopback: true,
},
loExpected: true,
},
{
name: "includeloopback disabled",
agentConfig: &AgentConfig{
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
IncludeLoopback: false,
},
loExpected: false,
},
}
for _, tc := range testCases {
tcase := tc
t.Run(tcase.name, func(t *testing.T) {
a, err := NewAgent(tc.agentConfig)
assert.NoError(t, err)
candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background())
var loopback int32
assert.NoError(t, a.OnCandidate(func(c Candidate) {
if c != nil {
if net.ParseIP(c.Address()).IsLoopback() {
atomic.StoreInt32(&loopback, 1)
}
} else {
candidateGatheredFunc()
return
}
t.Log(c.NetworkType(), c.Priority(), c)
}))
assert.NoError(t, a.GatherCandidates())
<-candidateGathered.Done()
assert.NoError(t, a.Close())
assert.Equal(t, tcase.loExpected, atomic.LoadInt32(&loopback) == 1)
})
}
assert.NoError(t, mux.Close())
assert.NoError(t, muxWithLo.Close())
}
// Assert that STUN gathering is done concurrently // Assert that STUN gathering is done concurrently
func TestSTUNConcurrency(t *testing.T) { func TestSTUNConcurrency(t *testing.T) {
report := test.CheckRoutines(t) report := test.CheckRoutines(t)

View File

@@ -29,7 +29,7 @@ func TestVNetGather(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}) localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if len(localIPs) > 0 { if len(localIPs) > 0 {
t.Fatal("should return no local IP") t.Fatal("should return no local IP")
} else if err != nil { } else if err != nil {
@@ -69,7 +69,7 @@ func TestVNetGather(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}) localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if len(localIPs) == 0 { if len(localIPs) == 0 {
t.Fatal("should have one local IP") t.Fatal("should have one local IP")
} else if err != nil { } else if err != nil {
@@ -112,7 +112,7 @@ func TestVNetGather(t *testing.T) {
t.Fatalf("Failed to create agent: %s", err) t.Fatalf("Failed to create agent: %s", err)
} }
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}) localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if len(localIPs) == 0 { if len(localIPs) == 0 {
t.Fatal("localInterfaces found no interfaces, unable to test") t.Fatal("localInterfaces found no interfaces, unable to test")
} else if err != nil { } else if err != nil {
@@ -385,7 +385,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}) localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} else if len(localIPs) != 0 { } else if len(localIPs) != 0 {
@@ -405,7 +405,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}) localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} else if len(localIPs) != 0 { } else if len(localIPs) != 0 {
@@ -425,7 +425,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}) localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} else if len(localIPs) == 0 { } else if len(localIPs) == 0 {

View File

@@ -78,7 +78,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
} }
if len(networks) > 0 { if len(networks) > 0 {
muxNet := vnet.NewNet(nil) muxNet := vnet.NewNet(nil)
ips, err := localInterfaces(muxNet, nil, nil, networks) ips, err := localInterfaces(muxNet, nil, nil, networks, true)
if err == nil { if err == nil {
for _, ip := range ips { for _, ip := range ips {
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port}) localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})

View File

@@ -81,7 +81,7 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu
opt.apply(&params) opt.apply(&params)
} }
muxNet := vnet.NewNet(nil) muxNet := vnet.NewNet(nil)
ips, err := localInterfaces(muxNet, params.ifFilter, params.ipFilter, params.networks) ips, err := localInterfaces(muxNet, params.ifFilter, params.ipFilter, params.networks, params.includeLoopback)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -130,6 +130,7 @@ type multiUDPMuxFromPortParam struct {
readBufferSize int readBufferSize int
writeBufferSize int writeBufferSize int
logger logging.LeveledLogger logger logging.LeveledLogger
includeLoopback bool
} }
type udpMuxFromPortOption struct { type udpMuxFromPortOption struct {
@@ -193,3 +194,12 @@ func UDPMuxFromPortWithLogger(logger logging.LeveledLogger) UDPMuxFromPortOption
}, },
} }
} }
// UDPMuxFromPortWithLoopback set loopback interface should be included
func UDPMuxFromPortWithLoopback() UDPMuxFromPortOption {
return &udpMuxFromPortOption{
f: func(p *multiUDPMuxFromPortParam) {
p.includeLoopback = true
},
}
}

View File

@@ -132,7 +132,7 @@ func stunRequest(read func([]byte) (int, error), write func([]byte) (int, error)
return res, nil return res, nil
} }
func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []NetworkType) ([]net.IP, error) { //nolint:gocognit func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []NetworkType, includeLoopback bool) ([]net.IP, error) { //nolint:gocognit
ips := []net.IP{} ips := []net.IP{}
ifaces, err := vnet.Interfaces() ifaces, err := vnet.Interfaces()
if err != nil { if err != nil {
@@ -154,7 +154,7 @@ func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter
if iface.Flags&net.FlagUp == 0 { if iface.Flags&net.FlagUp == 0 {
continue // interface down continue // interface down
} }
if iface.Flags&net.FlagLoopback != 0 { if (iface.Flags&net.FlagLoopback != 0) && !includeLoopback {
continue // loopback interface continue // loopback interface
} }
@@ -175,7 +175,7 @@ func localInterfaces(vnet *vnet.Net, interfaceFilter func(string) bool, ipFilter
case *net.IPAddr: case *net.IPAddr:
ip = addr.IP ip = addr.IP
} }
if ip == nil || ip.IsLoopback() { if ip == nil || (ip.IsLoopback() && !includeLoopback) {
continue continue
} }