Mirror of https://github.com/libp2p/go-libp2p.git
basichost: use autonatv2 to verify reachability (#3231)
This introduces addrsReachabilityTracker, which tracks reachability for a set of addresses. It probes the addresses' reachability periodically and backs off exponentially when there are too many errors or when no valid autonatv2 peers are available. There is no smartness in the address selection logic yet; we simply test all provided addresses. It also doesn't use the addresses provided by `AddrsFactory`, so there is currently no way to get a user-provided address tested for reachability, which is a problem for DNS addresses. I intend to introduce an alternative to `AddrsFactory`, something like `AnnounceAddrs(addrs []ma.Multiaddr)`, whose addresses are appended to the set of addresses we already have and checked for reachability. Only one new method is exposed on BasicHost for now: `ReachableAddrs() []ma.Multiaddr`, which returns the host's reachable addresses. Users can also subscribe to the `EvtHostReachableAddrsChanged` event to be notified when the reachability of any address changes.
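As a rough illustration of the new surface (a minimal sketch, assuming `h` is a `*basichost.BasicHost` on a node constructed with AutoNAT v2 enabled):

	// print the addrs that autonatv2 has confirmed as publicly reachable
	for _, addr := range h.ReachableAddrs() {
		fmt.Println("reachable:", addr)
	}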
@@ -8,6 +8,8 @@ linters:
    - revive
    - unused
    - prealloc
  disable:
    - errcheck

  settings:
    revive:
@@ -33,6 +33,7 @@ import (
	routed "github.com/libp2p/go-libp2p/p2p/host/routed"
	"github.com/libp2p/go-libp2p/p2p/net/swarm"
	tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader"
	"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2"
	circuitv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client"
	relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
	"github.com/libp2p/go-libp2p/p2p/protocol/holepunch"
@@ -413,15 +414,7 @@ func (cfg *Config) addTransports() ([]fx.Option, error) {
	return fxopts, nil
}

func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.BasicHost, error) {
	var autonatv2Dialer host.Host
	if cfg.EnableAutoNATv2 {
		ah, err := cfg.makeAutoNATV2Host()
		if err != nil {
			return nil, err
		}
		autonatv2Dialer = ah
	}
func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus, an *autonatv2.AutoNAT) (*bhost.BasicHost, error) {
	h, err := bhost.NewHost(swrm, &bhost.HostOpts{
		EventBus:    eventBus,
		ConnManager: cfg.ConnManager,
@@ -437,8 +430,7 @@ func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.B
		EnableMetrics:                   !cfg.DisableMetrics,
		PrometheusRegisterer:            cfg.PrometheusRegisterer,
		DisableIdentifyAddressDiscovery: cfg.DisableIdentifyAddressDiscovery,
		EnableAutoNATv2:                 cfg.EnableAutoNATv2,
		AutoNATv2Dialer:                 autonatv2Dialer,
		AutoNATv2:                       an,
	})
	if err != nil {
		return nil, err
@@ -517,6 +509,24 @@ func (cfg *Config) NewNode() (host.Host, error) {
			})
			return sw, nil
		}),
		fx.Provide(func() (*autonatv2.AutoNAT, error) {
			if !cfg.EnableAutoNATv2 {
				return nil, nil
			}
			ah, err := cfg.makeAutoNATV2Host()
			if err != nil {
				return nil, err
			}
			var mt autonatv2.MetricsTracer
			if !cfg.DisableMetrics {
				mt = autonatv2.NewMetricsTracer(cfg.PrometheusRegisterer)
			}
			autoNATv2, err := autonatv2.New(ah, autonatv2.WithMetricsTracer(mt))
			if err != nil {
				return nil, fmt.Errorf("failed to create autonatv2: %w", err)
			}
			return autoNATv2, nil
		}),
		fx.Provide(cfg.newBasicHost),
		fx.Provide(func(bh *bhost.BasicHost) identify.IDService {
			return bh.IDService()
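For context, `cfg.EnableAutoNATv2` is the flag set by the existing `EnableAutoNATv2` option; a minimal sketch (assuming that option) of a node that gets the shared AutoNAT v2 instance wired in by the provider above:

	// the fx provider returns a nil *autonatv2.AutoNAT when the option is not set
	h, err := libp2p.New(libp2p.EnableAutoNATv2())
	if err != nil {
		panic(err)
	}
	defer h.Close()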
@@ -2,6 +2,7 @@ package event

import (
	"github.com/libp2p/go-libp2p/core/network"
	ma "github.com/multiformats/go-multiaddr"
)

// EvtLocalReachabilityChanged is an event struct to be emitted when the local's
@@ -11,3 +12,12 @@ import (
type EvtLocalReachabilityChanged struct {
	Reachability network.Reachability
}

// EvtHostReachableAddrsChanged is sent when the host's reachable or unreachable addresses change.
// Reachable and Unreachable both contain only public IP or DNS addresses.
//
// Experimental: This API is unstable. Any changes to this event will be done without a deprecation notice.
type EvtHostReachableAddrsChanged struct {
	Reachable   []ma.Multiaddr
	Unreachable []ma.Multiaddr
}
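A minimal sketch of consuming the new event (assuming `h` is a `host.Host`; the emitter below is registered as stateful, so a new subscriber receives the latest value immediately):

	sub, err := h.EventBus().Subscribe(new(event.EvtHostReachableAddrsChanged))
	if err != nil {
		panic(err)
	}
	defer sub.Close()
	for e := range sub.Out() {
		evt := e.(event.EvtHostReachableAddrsChanged)
		fmt.Println("reachable:", evt.Reachable, "unreachable:", evt.Unreachable)
	}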
@@ -2,6 +2,7 @@ package basichost

import (
	"context"
	"errors"
	"fmt"
	"net"
	"slices"
@@ -13,6 +14,7 @@ import (
	"github.com/libp2p/go-libp2p/core/network"
	"github.com/libp2p/go-libp2p/core/transport"
	"github.com/libp2p/go-libp2p/p2p/host/basic/internal/backoff"
	"github.com/libp2p/go-libp2p/p2p/host/eventbus"
	libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
	libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
	"github.com/libp2p/go-netroute"
@@ -27,24 +29,36 @@ type observedAddrsManager interface {
	ObservedAddrsFor(local ma.Multiaddr) []ma.Multiaddr
}

type hostAddrs struct {
	addrs []ma.Multiaddr
	localAddrs []ma.Multiaddr
	reachableAddrs []ma.Multiaddr
	unreachableAddrs []ma.Multiaddr
	relayAddrs []ma.Multiaddr
}

type addrsManager struct {
	eventbus event.Bus
	bus event.Bus
	natManager NATManager
	addrsFactory AddrsFactory
	listenAddrs func() []ma.Multiaddr
	transportForListening func(ma.Multiaddr) transport.Transport
	observedAddrsManager observedAddrsManager
	interfaceAddrs *interfaceAddrsCache
	addrsReachabilityTracker *addrsReachabilityTracker

	// addrsUpdatedChan is notified when addrs change. This is provided by the caller.
	addrsUpdatedChan chan struct{}

	// triggerAddrsUpdateChan is used to trigger an addresses update.
	triggerAddrsUpdateChan chan struct{}
	// addrsUpdatedChan is notified when addresses change.
	addrsUpdatedChan chan struct{}
	// triggerReachabilityUpdate is notified when reachable addrs are updated.
	triggerReachabilityUpdate chan struct{}

	hostReachability atomic.Pointer[network.Reachability]

	addrsMx sync.RWMutex // protects fields below
	localAddrs []ma.Multiaddr
	relayAddrs []ma.Multiaddr
	addrsMx sync.RWMutex
	currentAddrs hostAddrs

	wg sync.WaitGroup
	ctx context.Context
@@ -52,23 +66,25 @@ type addrsManager struct {
}

func newAddrsManager(
	eventbus event.Bus,
	bus event.Bus,
	natmgr NATManager,
	addrsFactory AddrsFactory,
	listenAddrs func() []ma.Multiaddr,
	transportForListening func(ma.Multiaddr) transport.Transport,
	observedAddrsManager observedAddrsManager,
	addrsUpdatedChan chan struct{},
	client autonatv2Client,
) (*addrsManager, error) {
	ctx, cancel := context.WithCancel(context.Background())
	as := &addrsManager{
		eventbus: eventbus,
		bus: bus,
		listenAddrs: listenAddrs,
		transportForListening: transportForListening,
		observedAddrsManager: observedAddrsManager,
		natManager: natmgr,
		addrsFactory: addrsFactory,
		triggerAddrsUpdateChan: make(chan struct{}, 1),
		triggerReachabilityUpdate: make(chan struct{}, 1),
		addrsUpdatedChan: addrsUpdatedChan,
		interfaceAddrs: &interfaceAddrsCache{},
		ctx: ctx,
@@ -76,11 +92,23 @@ func newAddrsManager(
	}
	unknownReachability := network.ReachabilityUnknown
	as.hostReachability.Store(&unknownReachability)

	if client != nil {
		as.addrsReachabilityTracker = newAddrsReachabilityTracker(client, as.triggerReachabilityUpdate, nil)
	}
	return as, nil
}

func (a *addrsManager) Start() error {
	return a.background()
	// TODO: add Start method to NATMgr
	if a.addrsReachabilityTracker != nil {
		err := a.addrsReachabilityTracker.Start()
		if err != nil {
			return fmt.Errorf("error starting addrs reachability tracker: %s", err)
		}
	}

	return a.startBackgroundWorker()
}

func (a *addrsManager) Close() {
@@ -91,10 +119,18 @@ func (a *addrsManager) Close() {
			log.Warnf("error closing natmgr: %s", err)
		}
	}
	if a.addrsReachabilityTracker != nil {
		err := a.addrsReachabilityTracker.Close()
		if err != nil {
			log.Warnf("error closing addrs reachability tracker: %s", err)
		}
	}
	a.wg.Wait()
}

func (a *addrsManager) NetNotifee() network.Notifiee {
	// Updating addrs in sync provides the nice property that
	// host.Addrs() just after host.Network().Listen(x) will return x
	return &network.NotifyBundle{
		ListenF: func(network.Network, ma.Multiaddr) { a.triggerAddrsUpdate() },
		ListenCloseF: func(network.Network, ma.Multiaddr) { a.triggerAddrsUpdate() },
@@ -102,37 +138,53 @@ func (a *addrsManager) NetNotifee() network.Notifiee {
|
||||
}
|
||||
|
||||
func (a *addrsManager) triggerAddrsUpdate() {
|
||||
// This is ugly, we update here *and* in the background loop, but this ensures the nice property
|
||||
// that host.Addrs after host.Network().Listen(...) will return the recently added listen address.
|
||||
a.updateLocalAddrs()
|
||||
a.updateAddrs(false, nil)
|
||||
select {
|
||||
case a.triggerAddrsUpdateChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (a *addrsManager) background() error {
|
||||
autoRelayAddrsSub, err := a.eventbus.Subscribe(new(event.EvtAutoRelayAddrsUpdated))
|
||||
func (a *addrsManager) startBackgroundWorker() error {
|
||||
autoRelayAddrsSub, err := a.bus.Subscribe(new(event.EvtAutoRelayAddrsUpdated), eventbus.Name("addrs-manager"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error subscribing to auto relay addrs: %s", err)
|
||||
}
|
||||
|
||||
autonatReachabilitySub, err := a.eventbus.Subscribe(new(event.EvtLocalReachabilityChanged))
|
||||
autonatReachabilitySub, err := a.bus.Subscribe(new(event.EvtLocalReachabilityChanged), eventbus.Name("addrs-manager"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error subscribing to autonat reachability: %s", err)
|
||||
err1 := autoRelayAddrsSub.Close()
|
||||
if err1 != nil {
|
||||
err1 = fmt.Errorf("error closign autorelaysub: %w", err1)
|
||||
}
|
||||
err = fmt.Errorf("error subscribing to autonat reachability: %s", err)
|
||||
return errors.Join(err, err1)
|
||||
}
|
||||
|
||||
// ensure that we have the correct address after returning from Start()
|
||||
// update local addrs
|
||||
a.updateLocalAddrs()
|
||||
emitter, err := a.bus.Emitter(new(event.EvtHostReachableAddrsChanged), eventbus.Stateful)
|
||||
if err != nil {
|
||||
err1 := autoRelayAddrsSub.Close()
|
||||
if err1 != nil {
|
||||
err1 = fmt.Errorf("error closing autorelaysub: %w", err1)
|
||||
}
|
||||
err2 := autonatReachabilitySub.Close()
|
||||
if err2 != nil {
|
||||
err2 = fmt.Errorf("error closing autonat reachability: %w", err1)
|
||||
}
|
||||
err = fmt.Errorf("error subscribing to autonat reachability: %s", err)
|
||||
return errors.Join(err, err1, err2)
|
||||
}
|
||||
|
||||
var relayAddrs []ma.Multiaddr
|
||||
// update relay addrs in case we're private
|
||||
select {
|
||||
case e := <-autoRelayAddrsSub.Out():
|
||||
if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok {
|
||||
a.updateRelayAddrs(evt.RelayAddrs)
|
||||
relayAddrs = slices.Clone(evt.RelayAddrs)
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
select {
|
||||
case e := <-autonatReachabilitySub.Out():
|
||||
if evt, ok := e.(event.EvtLocalReachabilityChanged); ok {
|
||||
@@ -140,18 +192,25 @@ func (a *addrsManager) background() error {
|
||||
}
|
||||
default:
|
||||
}
|
||||
// update addresses before starting the worker loop. This ensures that any address updates
|
||||
// before calling addrsManager.Start are correctly reported after Start returns.
|
||||
a.updateAddrs(true, relayAddrs)
|
||||
|
||||
a.wg.Add(1)
|
||||
go func() {
|
||||
go a.background(autoRelayAddrsSub, autonatReachabilitySub, emitter, relayAddrs)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *addrsManager) background(autoRelayAddrsSub, autonatReachabilitySub event.Subscription,
|
||||
emitter event.Emitter, relayAddrs []ma.Multiaddr,
|
||||
) {
|
||||
defer a.wg.Done()
|
||||
defer func() {
|
||||
err := autoRelayAddrsSub.Close()
|
||||
if err != nil {
|
||||
log.Warnf("error closing auto relay addrs sub: %s", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
err := autonatReachabilitySub.Close()
|
||||
err = autonatReachabilitySub.Close()
|
||||
if err != nil {
|
||||
log.Warnf("error closing autonat reachability sub: %s", err)
|
||||
}
|
||||
@@ -159,24 +218,18 @@ func (a *addrsManager) background() error {
|
||||
|
||||
ticker := time.NewTicker(addrChangeTickrInterval)
|
||||
defer ticker.Stop()
|
||||
var prev []ma.Multiaddr
|
||||
var previousAddrs hostAddrs
|
||||
for {
|
||||
a.updateLocalAddrs()
|
||||
curr := a.Addrs()
|
||||
if a.areAddrsDifferent(prev, curr) {
|
||||
log.Debugf("host addresses updated: %s", curr)
|
||||
select {
|
||||
case a.addrsUpdatedChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
prev = curr
|
||||
currAddrs := a.updateAddrs(true, relayAddrs)
|
||||
a.notifyAddrsChanged(emitter, previousAddrs, currAddrs)
|
||||
previousAddrs = currAddrs
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-a.triggerAddrsUpdateChan:
|
||||
case <-a.triggerReachabilityUpdate:
|
||||
case e := <-autoRelayAddrsSub.Out():
|
||||
if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok {
|
||||
a.updateRelayAddrs(evt.RelayAddrs)
|
||||
relayAddrs = slices.Clone(evt.RelayAddrs)
|
||||
}
|
||||
case e := <-autonatReachabilitySub.Out():
|
||||
if evt, ok := e.(event.EvtLocalReachabilityChanged); ok {
|
||||
@@ -186,24 +239,102 @@ func (a *addrsManager) background() error {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateAddrs updates the addresses of the host and returns the new updated
|
||||
// addrs
|
||||
func (a *addrsManager) updateAddrs(updateRelayAddrs bool, relayAddrs []ma.Multiaddr) hostAddrs {
|
||||
// Must lock while doing both recompute and update as this method is called from
|
||||
// multiple goroutines.
|
||||
a.addrsMx.Lock()
|
||||
defer a.addrsMx.Unlock()
|
||||
|
||||
localAddrs := a.getLocalAddrs()
|
||||
var currReachableAddrs, currUnreachableAddrs []ma.Multiaddr
|
||||
if a.addrsReachabilityTracker != nil {
|
||||
currReachableAddrs, currUnreachableAddrs = a.getConfirmedAddrs(localAddrs)
|
||||
}
|
||||
if !updateRelayAddrs {
|
||||
relayAddrs = a.currentAddrs.relayAddrs
|
||||
} else {
|
||||
// Copy the callers slice
|
||||
relayAddrs = slices.Clone(relayAddrs)
|
||||
}
|
||||
currAddrs := a.getAddrs(slices.Clone(localAddrs), relayAddrs)
|
||||
|
||||
a.currentAddrs = hostAddrs{
|
||||
addrs: append(a.currentAddrs.addrs[:0], currAddrs...),
|
||||
localAddrs: append(a.currentAddrs.localAddrs[:0], localAddrs...),
|
||||
reachableAddrs: append(a.currentAddrs.reachableAddrs[:0], currReachableAddrs...),
|
||||
unreachableAddrs: append(a.currentAddrs.unreachableAddrs[:0], currUnreachableAddrs...),
|
||||
relayAddrs: append(a.currentAddrs.relayAddrs[:0], relayAddrs...),
|
||||
}
|
||||
|
||||
return hostAddrs{
|
||||
localAddrs: localAddrs,
|
||||
addrs: currAddrs,
|
||||
reachableAddrs: currReachableAddrs,
|
||||
unreachableAddrs: currUnreachableAddrs,
|
||||
relayAddrs: relayAddrs,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *addrsManager) notifyAddrsChanged(emitter event.Emitter, previous, current hostAddrs) {
|
||||
if areAddrsDifferent(previous.localAddrs, current.localAddrs) {
|
||||
log.Debugf("host local addresses updated: %s", current.localAddrs)
|
||||
if a.addrsReachabilityTracker != nil {
|
||||
a.addrsReachabilityTracker.UpdateAddrs(current.localAddrs)
|
||||
}
|
||||
}
|
||||
if areAddrsDifferent(previous.addrs, current.addrs) {
|
||||
log.Debugf("host addresses updated: %s", current.localAddrs)
|
||||
select {
|
||||
case a.addrsUpdatedChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// We *must* send both reachability changed and addrs changed events from the
|
||||
// same goroutine to ensure correct ordering
|
||||
// Consider the events:
|
||||
// - addr x discovered
|
||||
// - addr x is reachable
|
||||
// - addr x removed
|
||||
// We must send these events in the same order. It'll be confusing for consumers
|
||||
// if the reachable event is received after the addr removed event.
|
||||
if areAddrsDifferent(previous.reachableAddrs, current.reachableAddrs) ||
|
||||
areAddrsDifferent(previous.unreachableAddrs, current.unreachableAddrs) {
|
||||
log.Debugf("host reachable addrs updated: %s", current.localAddrs)
|
||||
if err := emitter.Emit(event.EvtHostReachableAddrsChanged{
|
||||
Reachable: slices.Clone(current.reachableAddrs),
|
||||
Unreachable: slices.Clone(current.unreachableAddrs),
|
||||
}); err != nil {
|
||||
log.Errorf("error sending host reachable addrs changed event: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Addrs returns the node's dialable addresses both public and private.
|
||||
// If autorelay is enabled and node reachability is private, it returns
|
||||
// the node's relay addresses and private network addresses.
|
||||
func (a *addrsManager) Addrs() []ma.Multiaddr {
|
||||
addrs := a.DirectAddrs()
|
||||
a.addrsMx.RLock()
|
||||
directAddrs := slices.Clone(a.currentAddrs.localAddrs)
|
||||
relayAddrs := slices.Clone(a.currentAddrs.relayAddrs)
|
||||
a.addrsMx.RUnlock()
|
||||
return a.getAddrs(directAddrs, relayAddrs)
|
||||
}
|
||||
|
||||
// getAddrs returns the node's dialable addresses. Mutates localAddrs
|
||||
func (a *addrsManager) getAddrs(localAddrs []ma.Multiaddr, relayAddrs []ma.Multiaddr) []ma.Multiaddr {
|
||||
addrs := localAddrs
|
||||
rch := a.hostReachability.Load()
|
||||
if rch != nil && *rch == network.ReachabilityPrivate {
|
||||
a.addrsMx.RLock()
|
||||
// Delete public addresses if the node's reachability is private, and we have relay addresses
|
||||
if len(a.relayAddrs) > 0 {
|
||||
if len(relayAddrs) > 0 {
|
||||
addrs = slices.DeleteFunc(addrs, manet.IsPublicAddr)
|
||||
addrs = append(addrs, a.relayAddrs...)
|
||||
addrs = append(addrs, relayAddrs...)
|
||||
}
|
||||
a.addrsMx.RUnlock()
|
||||
}
|
||||
// Make a copy. Consumers can modify the slice elements
|
||||
addrs = slices.Clone(a.addrsFactory(addrs))
|
||||
@@ -213,7 +344,8 @@ func (a *addrsManager) Addrs() []ma.Multiaddr {
|
||||
return addrs
|
||||
}
|
||||
|
||||
// HolePunchAddrs returns the node's public direct listen addresses for hole punching.
|
||||
// HolePunchAddrs returns all the host's direct public addresses, reachable or unreachable,
|
||||
// suitable for hole punching.
|
||||
func (a *addrsManager) HolePunchAddrs() []ma.Multiaddr {
|
||||
addrs := a.DirectAddrs()
|
||||
addrs = slices.Clone(a.addrsFactory(addrs))
|
||||
@@ -230,26 +362,23 @@ func (a *addrsManager) HolePunchAddrs() []ma.Multiaddr {
|
||||
func (a *addrsManager) DirectAddrs() []ma.Multiaddr {
|
||||
a.addrsMx.RLock()
|
||||
defer a.addrsMx.RUnlock()
|
||||
return slices.Clone(a.localAddrs)
|
||||
return slices.Clone(a.currentAddrs.localAddrs)
|
||||
}
|
||||
|
||||
func (a *addrsManager) updateRelayAddrs(addrs []ma.Multiaddr) {
|
||||
a.addrsMx.Lock()
|
||||
defer a.addrsMx.Unlock()
|
||||
a.relayAddrs = append(a.relayAddrs[:0], addrs...)
|
||||
// ReachableAddrs returns all addresses of the host that are reachable from the internet
|
||||
func (a *addrsManager) ReachableAddrs() []ma.Multiaddr {
|
||||
a.addrsMx.RLock()
|
||||
defer a.addrsMx.RUnlock()
|
||||
return slices.Clone(a.currentAddrs.reachableAddrs)
|
||||
}
|
||||
|
||||
func (a *addrsManager) getConfirmedAddrs(localAddrs []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) {
|
||||
reachableAddrs, unreachableAddrs = a.addrsReachabilityTracker.ConfirmedAddrs()
|
||||
return removeNotInSource(reachableAddrs, localAddrs), removeNotInSource(unreachableAddrs, localAddrs)
|
||||
}
|
||||
|
||||
var p2pCircuitAddr = ma.StringCast("/p2p-circuit")
|
||||
|
||||
func (a *addrsManager) updateLocalAddrs() {
|
||||
localAddrs := a.getLocalAddrs()
|
||||
slices.SortFunc(localAddrs, func(a, b ma.Multiaddr) int { return a.Compare(b) })
|
||||
|
||||
a.addrsMx.Lock()
|
||||
a.localAddrs = localAddrs
|
||||
a.addrsMx.Unlock()
|
||||
}
|
||||
|
||||
func (a *addrsManager) getLocalAddrs() []ma.Multiaddr {
|
||||
listenAddrs := a.listenAddrs()
|
||||
if len(listenAddrs) == 0 {
|
||||
@@ -260,8 +389,6 @@ func (a *addrsManager) getLocalAddrs() []ma.Multiaddr {
|
||||
finalAddrs = a.appendPrimaryInterfaceAddrs(finalAddrs, listenAddrs)
|
||||
finalAddrs = a.appendNATAddrs(finalAddrs, listenAddrs, a.interfaceAddrs.All())
|
||||
|
||||
finalAddrs = ma.Unique(finalAddrs)
|
||||
|
||||
// Remove "/p2p-circuit" addresses from the list.
|
||||
// The p2p-circuit listener reports its address as just /p2p-circuit. This is
|
||||
// useless for dialing. Users need to manage their circuit addresses themselves,
|
||||
@@ -278,6 +405,8 @@ func (a *addrsManager) getLocalAddrs() []ma.Multiaddr {
|
||||
// Add certhashes for /webrtc-direct, /webtransport, etc addresses discovered
|
||||
// using identify.
|
||||
finalAddrs = a.addCertHashes(finalAddrs)
|
||||
finalAddrs = ma.Unique(finalAddrs)
|
||||
slices.SortFunc(finalAddrs, func(a, b ma.Multiaddr) int { return a.Compare(b) })
|
||||
return finalAddrs
|
||||
}
|
||||
|
||||
@@ -408,7 +537,7 @@ func (a *addrsManager) addCertHashes(addrs []ma.Multiaddr) []ma.Multiaddr {
|
||||
return addrs
|
||||
}
|
||||
|
||||
func (a *addrsManager) areAddrsDifferent(prev, current []ma.Multiaddr) bool {
|
||||
func areAddrsDifferent(prev, current []ma.Multiaddr) bool {
|
||||
// TODO: make the sorted nature of ma.Unique a guarantee in multiaddrs
|
||||
prev = ma.Unique(prev)
|
||||
current = ma.Unique(current)
|
||||
@@ -547,3 +676,31 @@ func (i *interfaceAddrsCache) updateUnlocked() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// removeNotInSource removes items from addrs that are not present in source.
// Modifies the addrs slice in place
// addrs and source must be sorted using multiaddr.Compare.
func removeNotInSource(addrs, source []ma.Multiaddr) []ma.Multiaddr {
	j := 0
	// mark entries not in source as nil
	for i, a := range addrs {
		// move right as long as a > source[j]
		for j < len(source) && a.Compare(source[j]) > 0 {
			j++
		}
		// a is not in source if we've reached the end, or a is lesser
		if j == len(source) || a.Compare(source[j]) < 0 {
			addrs[i] = nil
		}
		// a is in source, nothing to do
	}
	// j is the current element, i is the lowest index nil element
	i := 0
	for j := range len(addrs) {
		if addrs[j] != nil {
			addrs[i], addrs[j] = addrs[j], addrs[i]
			i++
		}
	}
	return addrs[:i]
}
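For illustration, with both slices sorted by `Multiaddr.Compare` the filter behaves like an in-place set intersection keyed on `source` (a hypothetical sketch):

	addrs := []ma.Multiaddr{ma.StringCast("/ip4/1.1.1.1/tcp/1"), ma.StringCast("/ip4/1.1.1.2/tcp/1")}
	source := []ma.Multiaddr{ma.StringCast("/ip4/1.1.1.2/tcp/1")}
	out := removeNotInSource(addrs, source) // out == [/ip4/1.1.1.2/tcp/1]; addrs is reused as backing storage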
|
||||
|
@@ -1,13 +1,17 @@
|
||||
package basichost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/event"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
|
||||
"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -30,7 +34,7 @@ func TestAppendNATAddrs(t *testing.T) {
|
||||
// nat mapping success, obsaddress ignored
|
||||
Listen: ma.StringCast("/ip4/0.0.0.0/udp/1/quic-v1"),
|
||||
Nat: ma.StringCast("/ip4/1.1.1.1/udp/10/quic-v1"),
|
||||
ObsAddrFunc: func(m ma.Multiaddr) []ma.Multiaddr {
|
||||
ObsAddrFunc: func(_ ma.Multiaddr) []ma.Multiaddr {
|
||||
return []ma.Multiaddr{ma.StringCast("/ip4/2.2.2.2/udp/100/quic-v1")}
|
||||
},
|
||||
Expected: []ma.Multiaddr{ma.StringCast("/ip4/1.1.1.1/udp/10/quic-v1")},
|
||||
@@ -116,7 +120,7 @@ func TestAppendNATAddrs(t *testing.T) {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
as := &addrsManager{
|
||||
natManager: &mockNatManager{
|
||||
GetMappingFunc: func(addr ma.Multiaddr) ma.Multiaddr {
|
||||
GetMappingFunc: func(_ ma.Multiaddr) ma.Multiaddr {
|
||||
return tc.Nat
|
||||
},
|
||||
},
|
||||
@@ -135,7 +139,7 @@ type mockNatManager struct {
|
||||
GetMappingFunc func(addr ma.Multiaddr) ma.Multiaddr
|
||||
}
|
||||
|
||||
func (m *mockNatManager) Close() error {
|
||||
func (*mockNatManager) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -146,7 +150,7 @@ func (m *mockNatManager) GetMapping(addr ma.Multiaddr) ma.Multiaddr {
|
||||
return m.GetMappingFunc(addr)
|
||||
}
|
||||
|
||||
func (m *mockNatManager) HasDiscoveredNAT() bool {
|
||||
func (*mockNatManager) HasDiscoveredNAT() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -170,6 +174,8 @@ type addrsManagerArgs struct {
|
||||
AddrsFactory AddrsFactory
|
||||
ObservedAddrsManager observedAddrsManager
|
||||
ListenAddrs func() []ma.Multiaddr
|
||||
AutoNATClient autonatv2Client
|
||||
Bus event.Bus
|
||||
}
|
||||
|
||||
type addrsManagerTestCase struct {
|
||||
@@ -179,13 +185,16 @@ type addrsManagerTestCase struct {
|
||||
}
|
||||
|
||||
func newAddrsManagerTestCase(t *testing.T, args addrsManagerArgs) addrsManagerTestCase {
|
||||
eb := eventbus.NewBus()
|
||||
eb := args.Bus
|
||||
if eb == nil {
|
||||
eb = eventbus.NewBus()
|
||||
}
|
||||
if args.AddrsFactory == nil {
|
||||
args.AddrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr { return addrs }
|
||||
}
|
||||
addrsUpdatedChan := make(chan struct{}, 1)
|
||||
am, err := newAddrsManager(
|
||||
eb, args.NATManager, args.AddrsFactory, args.ListenAddrs, nil, args.ObservedAddrsManager, addrsUpdatedChan,
|
||||
eb, args.NATManager, args.AddrsFactory, args.ListenAddrs, nil, args.ObservedAddrsManager, addrsUpdatedChan, args.AutoNATClient,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -196,6 +205,7 @@ func newAddrsManagerTestCase(t *testing.T, args addrsManagerArgs) addrsManagerTe
|
||||
rchEm, err := eb.Emitter(new(event.EvtLocalReachabilityChanged), eventbus.Stateful)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(am.Close)
|
||||
return addrsManagerTestCase{
|
||||
addrsManager: am,
|
||||
PushRelay: func(relayAddrs []ma.Multiaddr) {
|
||||
@@ -326,7 +336,7 @@ func TestAddrsManager(t *testing.T) {
|
||||
}
|
||||
am := newAddrsManagerTestCase(t, addrsManagerArgs{
|
||||
ObservedAddrsManager: &mockObservedAddrs{
|
||||
ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr {
|
||||
ObservedAddrsForFunc: func(_ ma.Multiaddr) []ma.Multiaddr {
|
||||
return quicAddrs
|
||||
},
|
||||
},
|
||||
@@ -342,7 +352,7 @@ func TestAddrsManager(t *testing.T) {
|
||||
t.Run("public addrs removed when private", func(t *testing.T) {
|
||||
am := newAddrsManagerTestCase(t, addrsManagerArgs{
|
||||
ObservedAddrsManager: &mockObservedAddrs{
|
||||
ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr {
|
||||
ObservedAddrsForFunc: func(_ ma.Multiaddr) []ma.Multiaddr {
|
||||
return []ma.Multiaddr{publicQUIC}
|
||||
},
|
||||
},
|
||||
@@ -384,7 +394,7 @@ func TestAddrsManager(t *testing.T) {
|
||||
return nil
|
||||
},
|
||||
ObservedAddrsManager: &mockObservedAddrs{
|
||||
ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr {
|
||||
ObservedAddrsForFunc: func(_ ma.Multiaddr) []ma.Multiaddr {
|
||||
return []ma.Multiaddr{publicQUIC}
|
||||
},
|
||||
},
|
||||
@@ -404,7 +414,7 @@ func TestAddrsManager(t *testing.T) {
|
||||
t.Run("updates addresses on signaling", func(t *testing.T) {
|
||||
updateChan := make(chan struct{})
|
||||
am := newAddrsManagerTestCase(t, addrsManagerArgs{
|
||||
AddrsFactory: func(addrs []ma.Multiaddr) []ma.Multiaddr {
|
||||
AddrsFactory: func(_ []ma.Multiaddr) []ma.Multiaddr {
|
||||
select {
|
||||
case <-updateChan:
|
||||
return []ma.Multiaddr{publicQUIC}
|
||||
@@ -425,17 +435,95 @@ func TestAddrsManager(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddrsManagerReachabilityEvent(t *testing.T) {
|
||||
publicQUIC, _ := ma.NewMultiaddr("/ip4/1.2.3.4/udp/1234/quic-v1")
|
||||
publicQUIC2, _ := ma.NewMultiaddr("/ip4/1.2.3.4/udp/1235/quic-v1")
|
||||
publicTCP, _ := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/1234")
|
||||
|
||||
bus := eventbus.NewBus()
|
||||
|
||||
sub, err := bus.Subscribe(new(event.EvtHostReachableAddrsChanged))
|
||||
require.NoError(t, err)
|
||||
defer sub.Close()
|
||||
|
||||
am := newAddrsManagerTestCase(t, addrsManagerArgs{
|
||||
Bus: bus,
|
||||
// currently they aren't being passed to the reachability tracker
|
||||
ListenAddrs: func() []ma.Multiaddr { return []ma.Multiaddr{publicQUIC, publicQUIC2, publicTCP} },
|
||||
AutoNATClient: mockAutoNATClient{
|
||||
F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
|
||||
if reqs[0].Addr.Equal(publicQUIC) {
|
||||
return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil
|
||||
} else if reqs[0].Addr.Equal(publicTCP) || reqs[0].Addr.Equal(publicQUIC2) {
|
||||
return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPrivate}, nil
|
||||
}
|
||||
return autonatv2.Result{}, errors.New("invalid")
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
reachableAddrs := []ma.Multiaddr{publicQUIC}
|
||||
unreachableAddrs := []ma.Multiaddr{publicTCP, publicQUIC2}
|
||||
select {
|
||||
case e := <-sub.Out():
|
||||
evt := e.(event.EvtHostReachableAddrsChanged)
|
||||
require.ElementsMatch(t, reachableAddrs, evt.Reachable)
|
||||
require.ElementsMatch(t, unreachableAddrs, evt.Unreachable)
|
||||
require.ElementsMatch(t, reachableAddrs, am.ReachableAddrs())
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("expected event for reachability change")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveIfNotInSource(t *testing.T) {
|
||||
var addrs []ma.Multiaddr
|
||||
for i := 0; i < 10; i++ {
|
||||
addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/%d", i)))
|
||||
}
|
||||
slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return a.Compare(b) })
|
||||
cases := []struct {
|
||||
addrs []ma.Multiaddr
|
||||
source []ma.Multiaddr
|
||||
expected []ma.Multiaddr
|
||||
}{
|
||||
{},
|
||||
{addrs: slices.Clone(addrs[:5]), source: nil, expected: nil},
|
||||
{addrs: nil, source: addrs, expected: nil},
|
||||
{addrs: []ma.Multiaddr{addrs[0]}, source: []ma.Multiaddr{addrs[0]}, expected: []ma.Multiaddr{addrs[0]}},
|
||||
{addrs: slices.Clone(addrs), source: []ma.Multiaddr{addrs[0]}, expected: []ma.Multiaddr{addrs[0]}},
|
||||
{addrs: slices.Clone(addrs), source: slices.Clone(addrs[5:]), expected: slices.Clone(addrs[5:])},
|
||||
{addrs: slices.Clone(addrs[:5]), source: []ma.Multiaddr{addrs[0], addrs[2], addrs[8]}, expected: []ma.Multiaddr{addrs[0], addrs[2]}},
|
||||
}
|
||||
for i, tc := range cases {
|
||||
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
|
||||
addrs := removeNotInSource(tc.addrs, tc.source)
|
||||
require.ElementsMatch(t, tc.expected, addrs, "%s\n%s", tc.expected, tc.addrs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAreAddrsDifferent(b *testing.B) {
|
||||
var addrs [10]ma.Multiaddr
|
||||
for i := 0; i < len(addrs); i++ {
|
||||
addrs[i] = ma.StringCast(fmt.Sprintf("/ip4/1.1.1.%d/tcp/1", i))
|
||||
}
|
||||
am := &addrsManager{}
|
||||
b.Run("areAddrsDifferent", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
am.areAddrsDifferent(addrs[:], addrs[:])
|
||||
areAddrsDifferent(addrs[:], addrs[:])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkRemoveIfNotInSource(b *testing.B) {
|
||||
var addrs [10]ma.Multiaddr
|
||||
for i := 0; i < len(addrs); i++ {
|
||||
addrs[i] = ma.StringCast(fmt.Sprintf("/ip4/1.1.1.%d/tcp/1", i))
|
||||
}
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
removeNotInSource(slices.Clone(addrs[:5]), addrs[:])
|
||||
}
|
||||
}
|
||||
|
p2p/host/basic/addrs_reachability_tracker.go (new file, 666 lines)
@@ -0,0 +1,666 @@
|
||||
package basichost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/benbjohnson/clock"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
)
|
||||
|
||||
type autonatv2Client interface {
|
||||
GetReachability(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error)
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
// maxAddrsPerRequest is the maximum number of addresses to probe in a single request
|
||||
maxAddrsPerRequest = 10
|
||||
// maxTrackedAddrs is the maximum number of addresses to track
|
||||
// 10 addrs per transport for 5 transports
|
||||
maxTrackedAddrs = 50
|
||||
// defaultMaxConcurrency is the default number of concurrent workers for reachability checks
|
||||
defaultMaxConcurrency = 5
|
||||
// newAddrsProbeDelay is the delay before probing new addr's reachability.
|
||||
newAddrsProbeDelay = 1 * time.Second
|
||||
)
|
||||
|
||||
// addrsReachabilityTracker tracks reachability for addresses.
|
||||
// Use UpdateAddrs to provide addresses for tracking reachability.
|
||||
// reachabilityUpdateCh is notified when reachability for any of the tracked address changes.
|
||||
type addrsReachabilityTracker struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
|
||||
client autonatv2Client
|
||||
// reachabilityUpdateCh is used to notify when reachability may have changed
|
||||
reachabilityUpdateCh chan struct{}
|
||||
maxConcurrency int
|
||||
newAddrsProbeDelay time.Duration
|
||||
probeManager *probeManager
|
||||
newAddrs chan []ma.Multiaddr
|
||||
clock clock.Clock
|
||||
|
||||
mx sync.Mutex
|
||||
reachableAddrs []ma.Multiaddr
|
||||
unreachableAddrs []ma.Multiaddr
|
||||
}
|
||||
|
||||
// newAddrsReachabilityTracker returns a new addrsReachabilityTracker.
// reachabilityUpdateCh is notified when reachability for any of the tracked address changes.
func newAddrsReachabilityTracker(client autonatv2Client, reachabilityUpdateCh chan struct{}, cl clock.Clock) *addrsReachabilityTracker {
	ctx, cancel := context.WithCancel(context.Background())
	if cl == nil {
		cl = clock.New()
	}
	return &addrsReachabilityTracker{
		ctx: ctx,
		cancel: cancel,
		client: client,
		reachabilityUpdateCh: reachabilityUpdateCh,
		probeManager: newProbeManager(cl.Now),
		newAddrsProbeDelay: newAddrsProbeDelay,
		maxConcurrency: defaultMaxConcurrency,
		newAddrs: make(chan []ma.Multiaddr, 1),
		clock: cl,
	}
}
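A minimal internal-usage sketch (hypothetical; `client` is anything implementing autonatv2Client and `localAddrs` holds the host's public listen addrs):

	updateCh := make(chan struct{}, 1)
	rt := newAddrsReachabilityTracker(client, updateCh, nil) // nil clock selects the real clock
	_ = rt.Start()
	rt.UpdateAddrs(localAddrs) // only public addrs are tracked
	<-updateCh                 // fires when confirmed reachability changes
	reachable, unreachable := rt.ConfirmedAddrs()
	_, _ = reachable, unreachable
	_ = rt.Close()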
|
||||
|
||||
func (r *addrsReachabilityTracker) UpdateAddrs(addrs []ma.Multiaddr) {
|
||||
select {
|
||||
case r.newAddrs <- slices.Clone(addrs):
|
||||
case <-r.ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
func (r *addrsReachabilityTracker) ConfirmedAddrs() (reachableAddrs, unreachableAddrs []ma.Multiaddr) {
|
||||
r.mx.Lock()
|
||||
defer r.mx.Unlock()
|
||||
return slices.Clone(r.reachableAddrs), slices.Clone(r.unreachableAddrs)
|
||||
}
|
||||
|
||||
func (r *addrsReachabilityTracker) Start() error {
|
||||
r.wg.Add(1)
|
||||
go r.background()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *addrsReachabilityTracker) Close() error {
|
||||
r.cancel()
|
||||
r.wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
// defaultReachabilityRefreshInterval is the default interval to refresh reachability.
|
||||
// In steady state, we check for any required probes every refresh interval.
|
||||
// This doesn't mean we'll probe for any particular address, only that we'll check
|
||||
// if any address needs to be probed.
|
||||
defaultReachabilityRefreshInterval = 5 * time.Minute
|
||||
// maxBackoffInterval is the maximum back off in case we're unable to probe for reachability.
|
||||
// We may be unable to confirm addresses in case there are no valid peers with autonatv2
|
||||
// or the autonatv2 subsystem is consistently erroring.
|
||||
maxBackoffInterval = 5 * time.Minute
|
||||
// backoffStartInterval is the initial back off in case we're unable to probe for reachability.
|
||||
backoffStartInterval = 5 * time.Second
|
||||
)
|
||||
|
||||
func (r *addrsReachabilityTracker) background() {
|
||||
defer r.wg.Done()
|
||||
|
||||
// probeTicker is used to trigger probes at regular intervals
|
||||
probeTicker := r.clock.Ticker(defaultReachabilityRefreshInterval)
|
||||
defer probeTicker.Stop()
|
||||
|
||||
// probeTimer is used to trigger probes at specific times
|
||||
probeTimer := r.clock.Timer(time.Duration(math.MaxInt64))
|
||||
defer probeTimer.Stop()
|
||||
nextProbeTime := time.Time{}
|
||||
|
||||
var task reachabilityTask
|
||||
var backoffInterval time.Duration
|
||||
var currReachable, currUnreachable, prevReachable, prevUnreachable []ma.Multiaddr
|
||||
for {
|
||||
select {
|
||||
case <-probeTicker.C:
|
||||
// don't start a probe if we have a scheduled probe
|
||||
if task.BackoffCh == nil && nextProbeTime.IsZero() {
|
||||
task = r.refreshReachability()
|
||||
}
|
||||
case <-probeTimer.C:
|
||||
if task.BackoffCh == nil {
|
||||
task = r.refreshReachability()
|
||||
}
|
||||
nextProbeTime = time.Time{}
|
||||
case backoff := <-task.BackoffCh:
|
||||
task = reachabilityTask{}
|
||||
// On completion, start the next probe immediately, or wait for backoff.
|
||||
// In case there are no further probes, the reachability tracker will return an empty task,
|
||||
// which hangs forever. Eventually, we'll refresh again when the ticker fires.
|
||||
if backoff {
|
||||
backoffInterval = newBackoffInterval(backoffInterval)
|
||||
} else {
|
||||
backoffInterval = -1 * time.Second // negative to trigger next probe immediately
|
||||
}
|
||||
nextProbeTime = r.clock.Now().Add(backoffInterval)
|
||||
case addrs := <-r.newAddrs:
|
||||
if task.BackoffCh != nil { // cancel running task.
|
||||
task.Cancel()
|
||||
<-task.BackoffCh // ignore backoff from cancelled task
|
||||
task = reachabilityTask{}
|
||||
}
|
||||
r.updateTrackedAddrs(addrs)
|
||||
newAddrsNextTime := r.clock.Now().Add(r.newAddrsProbeDelay)
|
||||
if nextProbeTime.Before(newAddrsNextTime) {
|
||||
nextProbeTime = newAddrsNextTime
|
||||
}
|
||||
case <-r.ctx.Done():
|
||||
if task.BackoffCh != nil {
|
||||
task.Cancel()
|
||||
<-task.BackoffCh
|
||||
task = reachabilityTask{}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
currReachable, currUnreachable = r.appendConfirmedAddrs(currReachable[:0], currUnreachable[:0])
|
||||
if areAddrsDifferent(prevReachable, currReachable) || areAddrsDifferent(prevUnreachable, currUnreachable) {
|
||||
r.notify()
|
||||
}
|
||||
prevReachable = append(prevReachable[:0], currReachable...)
|
||||
prevUnreachable = append(prevUnreachable[:0], currUnreachable...)
|
||||
if !nextProbeTime.IsZero() {
|
||||
probeTimer.Reset(nextProbeTime.Sub(r.clock.Now()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newBackoffInterval(current time.Duration) time.Duration {
	if current <= 0 {
		return backoffStartInterval
	}
	current *= 2
	if current > maxBackoffInterval {
		return maxBackoffInterval
	}
	return current
}
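A quick sketch of the resulting schedule when probes keep failing:

	// prints 5s, 10s, 20s, 40s, 1m20s, 2m40s, 5m0s: doubling from backoffStartInterval, capped at maxBackoffInterval
	d := time.Duration(0)
	for i := 0; i < 7; i++ {
		d = newBackoffInterval(d)
		fmt.Println(d)
	}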
|
||||
|
||||
func (r *addrsReachabilityTracker) appendConfirmedAddrs(reachable, unreachable []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) {
|
||||
reachable, unreachable = r.probeManager.AppendConfirmedAddrs(reachable, unreachable)
|
||||
r.mx.Lock()
|
||||
r.reachableAddrs = append(r.reachableAddrs[:0], reachable...)
|
||||
r.unreachableAddrs = append(r.unreachableAddrs[:0], unreachable...)
|
||||
r.mx.Unlock()
|
||||
return reachable, unreachable
|
||||
}
|
||||
|
||||
func (r *addrsReachabilityTracker) notify() {
|
||||
select {
|
||||
case r.reachabilityUpdateCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (r *addrsReachabilityTracker) updateTrackedAddrs(addrs []ma.Multiaddr) {
|
||||
addrs = slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool {
|
||||
return !manet.IsPublicAddr(a)
|
||||
})
|
||||
if len(addrs) > maxTrackedAddrs {
|
||||
log.Errorf("too many addresses (%d) for addrs reachability tracker; dropping %d", len(addrs), len(addrs)-maxTrackedAddrs)
|
||||
addrs = addrs[:maxTrackedAddrs]
|
||||
}
|
||||
r.probeManager.UpdateAddrs(addrs)
|
||||
}
|
||||
|
||||
type probe = []autonatv2.Request
|
||||
|
||||
const probeTimeout = 30 * time.Second
|
||||
|
||||
// reachabilityTask is a task to refresh reachability.
|
||||
// Waiting on the zero value blocks forever.
|
||||
type reachabilityTask struct {
|
||||
Cancel context.CancelFunc
|
||||
// BackoffCh returns whether the caller should backoff before
|
||||
// refreshing reachability
|
||||
BackoffCh chan bool
|
||||
}
|
||||
|
||||
func (r *addrsReachabilityTracker) refreshReachability() reachabilityTask {
|
||||
if len(r.probeManager.GetProbe()) == 0 {
|
||||
return reachabilityTask{}
|
||||
}
|
||||
resCh := make(chan bool, 1)
|
||||
ctx, cancel := context.WithTimeout(r.ctx, 5*time.Minute)
|
||||
r.wg.Add(1)
|
||||
// We run probes provided by addrsTracker. It stops probing when any
|
||||
// of the following happens:
|
||||
// - there are no more probes to run
|
||||
// - context is completed
|
||||
// - there are too many consecutive failures from the client
|
||||
// - the client has no valid peers to probe
|
||||
go func() {
|
||||
defer r.wg.Done()
|
||||
defer cancel()
|
||||
client := &errCountingClient{autonatv2Client: r.client, MaxConsecutiveErrors: maxConsecutiveErrors}
|
||||
var backoff atomic.Bool
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(r.maxConcurrency)
|
||||
for range r.maxConcurrency {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
reqs := r.probeManager.GetProbe()
|
||||
if len(reqs) == 0 {
|
||||
return
|
||||
}
|
||||
r.probeManager.MarkProbeInProgress(reqs)
|
||||
rctx, cancel := context.WithTimeout(ctx, probeTimeout)
|
||||
res, err := client.GetReachability(rctx, reqs)
|
||||
cancel()
|
||||
r.probeManager.CompleteProbe(reqs, res, err)
|
||||
if isErrorPersistent(err) {
|
||||
backoff.Store(true)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
resCh <- backoff.Load()
|
||||
}()
|
||||
return reachabilityTask{Cancel: cancel, BackoffCh: resCh}
|
||||
}
|
||||
|
||||
var errTooManyConsecutiveFailures = errors.New("too many consecutive failures")
|
||||
|
||||
// errCountingClient counts errors from autonatv2Client and wraps the errors in response with a
|
||||
// errTooManyConsecutiveFailures in case of persistent failures from autonatv2 module.
|
||||
type errCountingClient struct {
|
||||
autonatv2Client
|
||||
MaxConsecutiveErrors int
|
||||
mx sync.Mutex
|
||||
consecutiveErrors int
|
||||
}
|
||||
|
||||
func (c *errCountingClient) GetReachability(ctx context.Context, reqs probe) (autonatv2.Result, error) {
|
||||
res, err := c.autonatv2Client.GetReachability(ctx, reqs)
|
||||
c.mx.Lock()
|
||||
defer c.mx.Unlock()
|
||||
if err != nil && !errors.Is(err, context.Canceled) { // ignore canceled errors, they're not errors from autonatv2
|
||||
c.consecutiveErrors++
|
||||
if c.consecutiveErrors > c.MaxConsecutiveErrors {
|
||||
err = fmt.Errorf("%w:%w", errTooManyConsecutiveFailures, err)
|
||||
}
|
||||
if errors.Is(err, autonatv2.ErrPrivateAddrs) {
|
||||
log.Errorf("private IP addr in autonatv2 request: %s", err)
|
||||
}
|
||||
} else {
|
||||
c.consecutiveErrors = 0
|
||||
}
|
||||
return res, err
|
||||
}
|
||||
|
||||
const maxConsecutiveErrors = 20
|
||||
|
||||
// isErrorPersistent returns whether the error will repeat on future probes for a while
|
||||
func isErrorPersistent(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return errors.Is(err, autonatv2.ErrPrivateAddrs) || errors.Is(err, autonatv2.ErrNoPeers) ||
|
||||
errors.Is(err, errTooManyConsecutiveFailures)
|
||||
}
|
||||
|
||||
const (
	// recentProbeInterval is the interval to probe addresses that have been refused
	// these are generally addresses with newer transports for which we don't have many peers
	// capable of dialing the transport
	recentProbeInterval = 10 * time.Minute
	// maxConsecutiveRefusals is the maximum number of consecutive refusals for an address after which
	// we wait for `recentProbeInterval` before probing again
	maxConsecutiveRefusals = 5
	// maxRecentDialsPerAddr is the maximum number of dials on an address before we stop probing for the address.
	// This is used to prevent infinite probing of an address whose status is indeterminate for any reason.
	maxRecentDialsPerAddr = 10
	// confidence is the absolute difference between the number of successes and failures for an address
	// targetConfidence is the confidence threshold for an address after which we wait for `maxProbeInterval`
	// before probing again.
	targetConfidence = 3
	// minConfidence is the confidence threshold for an address to be considered reachable or unreachable
	// confidence is the absolute difference between the number of successes and failures for an address
	minConfidence = 2
	// maxRecentDialsWindow is the maximum number of recent probe results to consider for a single address
	//
	// +2 allows for 1 invalid probe result. Consider a string of successes, after which we have a single failure
	// and then a success (...S S S S F S). The confidence in the targetConfidence window will be equal to
	// targetConfidence, the last F and S cancel each other, and we won't probe again for maxProbeInterval.
	maxRecentDialsWindow = targetConfidence + 2
	// highConfidenceAddrProbeInterval is the maximum interval between probes for an address
	highConfidenceAddrProbeInterval = 1 * time.Hour
	// maxProbeResultTTL is the maximum time to keep probe results for an address
	maxProbeResultTTL = maxRecentDialsWindow * highConfidenceAddrProbeInterval
)
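To make the thresholds concrete: confidence is the absolute difference between successes and failures over the recent outcomes window. An address with 2 successes and 1 failure has confidence 1, so up to targetConfidence - 1 = 2 more probes are requested; with 3 successes and 0 failures it has confidence 3 and is left alone until highConfidenceAddrProbeInterval elapses. It is reported reachable once successes - failures >= minConfidence (2), and unreachable once failures - successes >= 2.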
|
||||
|
||||
// probeManager tracks reachability for a set of addresses by periodically probing reachability with autonatv2.
|
||||
// A Probe is a list of addresses which can be tested for reachability with autonatv2.
|
||||
// This struct decides the priority order of addresses for testing reachability, and throttles in case there have
|
||||
// been too many probes for an address in the `ProbeInterval`.
|
||||
//
|
||||
// Use the `runProbes` function to execute the probes with an autonatv2 client.
|
||||
type probeManager struct {
|
||||
now func() time.Time
|
||||
|
||||
mx sync.Mutex
|
||||
inProgressProbes map[string]int // addr -> count
|
||||
inProgressProbesTotal int
|
||||
statuses map[string]*addrStatus
|
||||
addrs []ma.Multiaddr
|
||||
}
|
||||
|
||||
// newProbeManager creates a new probe manager.
|
||||
func newProbeManager(now func() time.Time) *probeManager {
|
||||
return &probeManager{
|
||||
statuses: make(map[string]*addrStatus),
|
||||
inProgressProbes: make(map[string]int),
|
||||
now: now,
|
||||
}
|
||||
}
|
||||
|
||||
// AppendConfirmedAddrs appends the current confirmed reachable and unreachable addresses.
|
||||
func (m *probeManager) AppendConfirmedAddrs(reachable, unreachable []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) {
|
||||
m.mx.Lock()
|
||||
defer m.mx.Unlock()
|
||||
|
||||
for _, a := range m.addrs {
|
||||
s := m.statuses[string(a.Bytes())]
|
||||
s.RemoveBefore(m.now().Add(-maxProbeResultTTL)) // cleanup stale results
|
||||
switch s.Reachability() {
|
||||
case network.ReachabilityPublic:
|
||||
reachable = append(reachable, a)
|
||||
case network.ReachabilityPrivate:
|
||||
unreachable = append(unreachable, a)
|
||||
}
|
||||
}
|
||||
return reachable, unreachable
|
||||
}
|
||||
|
||||
// UpdateAddrs updates the tracked addrs
|
||||
func (m *probeManager) UpdateAddrs(addrs []ma.Multiaddr) {
|
||||
m.mx.Lock()
|
||||
defer m.mx.Unlock()
|
||||
|
||||
slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return a.Compare(b) })
|
||||
statuses := make(map[string]*addrStatus, len(addrs))
|
||||
for _, addr := range addrs {
|
||||
k := string(addr.Bytes())
|
||||
if _, ok := m.statuses[k]; !ok {
|
||||
statuses[k] = &addrStatus{Addr: addr}
|
||||
} else {
|
||||
statuses[k] = m.statuses[k]
|
||||
}
|
||||
}
|
||||
m.addrs = addrs
|
||||
m.statuses = statuses
|
||||
}
|
||||
|
||||
// GetProbe returns the next probe. Returns zero value in case there are no more probes.
|
||||
// Probes that are run against an autonatv2 client should be marked in progress with
|
||||
// `MarkProbeInProgress` before running.
|
||||
func (m *probeManager) GetProbe() probe {
|
||||
m.mx.Lock()
|
||||
defer m.mx.Unlock()
|
||||
|
||||
now := m.now()
|
||||
for i, a := range m.addrs {
|
||||
ab := a.Bytes()
|
||||
pc := m.statuses[string(ab)].RequiredProbeCount(now)
|
||||
if m.inProgressProbes[string(ab)] >= pc {
|
||||
continue
|
||||
}
|
||||
reqs := make(probe, 0, maxAddrsPerRequest)
|
||||
reqs = append(reqs, autonatv2.Request{Addr: a, SendDialData: true})
|
||||
// We have the first(primary) address. Append other addresses, ignoring inprogress probes
|
||||
// on secondary addresses. The expectation is that the primary address will
|
||||
// be dialed.
|
||||
for j := 1; j < len(m.addrs); j++ {
|
||||
k := (i + j) % len(m.addrs)
|
||||
ab := m.addrs[k].Bytes()
|
||||
pc := m.statuses[string(ab)].RequiredProbeCount(now)
|
||||
if pc == 0 {
|
||||
continue
|
||||
}
|
||||
reqs = append(reqs, autonatv2.Request{Addr: m.addrs[k], SendDialData: true})
|
||||
if len(reqs) >= maxAddrsPerRequest {
|
||||
break
|
||||
}
|
||||
}
|
||||
return reqs
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkProbeInProgress should be called when a probe is started.
|
||||
// All in progress probes *MUST* be completed with `CompleteProbe`
|
||||
func (m *probeManager) MarkProbeInProgress(reqs probe) {
|
||||
if len(reqs) == 0 {
|
||||
return
|
||||
}
|
||||
m.mx.Lock()
|
||||
defer m.mx.Unlock()
|
||||
m.inProgressProbes[string(reqs[0].Addr.Bytes())]++
|
||||
m.inProgressProbesTotal++
|
||||
}
|
||||
|
||||
// InProgressProbes returns the number of probes that are currently in progress.
|
||||
func (m *probeManager) InProgressProbes() int {
|
||||
m.mx.Lock()
|
||||
defer m.mx.Unlock()
|
||||
return m.inProgressProbesTotal
|
||||
}
|
||||
|
||||
// CompleteProbe should be called when a probe completes.
|
||||
func (m *probeManager) CompleteProbe(reqs probe, res autonatv2.Result, err error) {
|
||||
now := m.now()
|
||||
|
||||
if len(reqs) == 0 {
|
||||
// should never happen
|
||||
return
|
||||
}
|
||||
|
||||
m.mx.Lock()
|
||||
defer m.mx.Unlock()
|
||||
|
||||
// decrement in-progress count for the first address
|
||||
primaryAddrKey := string(reqs[0].Addr.Bytes())
|
||||
m.inProgressProbes[primaryAddrKey]--
|
||||
if m.inProgressProbes[primaryAddrKey] <= 0 {
|
||||
delete(m.inProgressProbes, primaryAddrKey)
|
||||
}
|
||||
m.inProgressProbesTotal--
|
||||
|
||||
// nothing to do if the request errored.
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Consider only primary address as refused. This increases the number of
|
||||
// refused probes, but refused probes are cheap for a server as no dials are made.
|
||||
if res.AllAddrsRefused {
|
||||
if s, ok := m.statuses[primaryAddrKey]; ok {
|
||||
s.AddRefusal(now)
|
||||
}
|
||||
return
|
||||
}
|
||||
dialAddrKey := string(res.Addr.Bytes())
|
||||
if dialAddrKey != primaryAddrKey {
|
||||
if s, ok := m.statuses[primaryAddrKey]; ok {
|
||||
s.AddRefusal(now)
|
||||
}
|
||||
}
|
||||
|
||||
// record the result for the dialed address
|
||||
if s, ok := m.statuses[dialAddrKey]; ok {
|
||||
s.AddOutcome(now, res.Reachability, maxRecentDialsWindow)
|
||||
}
|
||||
}
|
||||
|
||||
type dialOutcome struct {
|
||||
Success bool
|
||||
At time.Time
|
||||
}
|
||||
|
||||
type addrStatus struct {
|
||||
Addr ma.Multiaddr
|
||||
lastRefusalTime time.Time
|
||||
consecutiveRefusals int
|
||||
dialTimes []time.Time
|
||||
outcomes []dialOutcome
|
||||
}
|
||||
|
||||
func (s *addrStatus) Reachability() network.Reachability {
|
||||
rch, _, _ := s.reachabilityAndCounts()
|
||||
return rch
|
||||
}
|
||||
|
||||
func (s *addrStatus) RequiredProbeCount(now time.Time) int {
|
||||
if s.consecutiveRefusals >= maxConsecutiveRefusals {
|
||||
if now.Sub(s.lastRefusalTime) < recentProbeInterval {
|
||||
return 0
|
||||
}
|
||||
// reset every `recentProbeInterval`
|
||||
s.lastRefusalTime = time.Time{}
|
||||
s.consecutiveRefusals = 0
|
||||
}
|
||||
|
||||
// Don't probe if we have probed too many times recently
|
||||
rd := s.recentDialCount(now)
|
||||
if rd >= maxRecentDialsPerAddr {
|
||||
return 0
|
||||
}
|
||||
|
||||
return s.requiredProbeCountForConfirmation(now)
|
||||
}
|
||||
|
||||
func (s *addrStatus) requiredProbeCountForConfirmation(now time.Time) int {
|
||||
reachability, successes, failures := s.reachabilityAndCounts()
|
||||
confidence := successes - failures
|
||||
if confidence < 0 {
|
||||
confidence = -confidence
|
||||
}
|
||||
cnt := targetConfidence - confidence
|
||||
if cnt > 0 {
|
||||
return cnt
|
||||
}
|
||||
// we have enough confirmations; check if we should refresh
|
||||
|
||||
// Should never happen. The confidence logic above should require a few probes.
|
||||
if len(s.outcomes) == 0 {
|
||||
return 0
|
||||
}
|
||||
lastOutcome := s.outcomes[len(s.outcomes)-1]
|
||||
// If the last probe result is old, we need to retest
|
||||
if now.Sub(lastOutcome.At) > highConfidenceAddrProbeInterval {
|
||||
return 1
|
||||
}
|
||||
// if the last probe result was different from reachability, probe again.
|
||||
switch reachability {
|
||||
case network.ReachabilityPublic:
|
||||
if !lastOutcome.Success {
|
||||
return 1
|
||||
}
|
||||
case network.ReachabilityPrivate:
|
||||
if lastOutcome.Success {
|
||||
return 1
|
||||
}
|
||||
default:
|
||||
// this should never happen
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *addrStatus) AddRefusal(now time.Time) {
|
||||
s.lastRefusalTime = now
|
||||
s.consecutiveRefusals++
|
||||
}
|
||||
|
||||
func (s *addrStatus) AddOutcome(at time.Time, rch network.Reachability, windowSize int) {
|
||||
s.lastRefusalTime = time.Time{}
|
||||
s.consecutiveRefusals = 0
|
||||
|
||||
s.dialTimes = append(s.dialTimes, at)
|
||||
for i, t := range s.dialTimes {
|
||||
if at.Sub(t) < recentProbeInterval {
|
||||
s.dialTimes = slices.Delete(s.dialTimes, 0, i)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
s.RemoveBefore(at.Add(-maxProbeResultTTL)) // remove old outcomes
|
||||
success := false
|
||||
switch rch {
|
||||
case network.ReachabilityPublic:
|
||||
success = true
|
||||
case network.ReachabilityPrivate:
|
||||
success = false
|
||||
default:
|
||||
return // don't store the outcome if reachability is unknown
|
||||
}
|
||||
s.outcomes = append(s.outcomes, dialOutcome{At: at, Success: success})
|
||||
if len(s.outcomes) > windowSize {
|
||||
s.outcomes = slices.Delete(s.outcomes, 0, len(s.outcomes)-windowSize)
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveBefore removes outcomes before t
|
||||
func (s *addrStatus) RemoveBefore(t time.Time) {
|
||||
end := 0
|
||||
for ; end < len(s.outcomes); end++ {
|
||||
if !s.outcomes[end].At.Before(t) {
|
||||
break
|
||||
}
|
||||
}
|
||||
s.outcomes = slices.Delete(s.outcomes, 0, end)
|
||||
}
|
||||
|
||||
func (s *addrStatus) recentDialCount(now time.Time) int {
|
||||
cnt := 0
|
||||
for _, t := range slices.Backward(s.dialTimes) {
|
||||
if now.Sub(t) > recentProbeInterval {
|
||||
break
|
||||
}
|
||||
cnt++
|
||||
}
|
||||
return cnt
|
||||
}
|
||||
|
||||
func (s *addrStatus) reachabilityAndCounts() (rch network.Reachability, successes int, failures int) {
|
||||
for _, r := range s.outcomes {
|
||||
if r.Success {
|
||||
successes++
|
||||
} else {
|
||||
failures++
|
||||
}
|
||||
}
|
||||
if successes-failures >= minConfidence {
|
||||
return network.ReachabilityPublic, successes, failures
|
||||
}
|
||||
if failures-successes >= minConfidence {
|
||||
return network.ReachabilityPrivate, successes, failures
|
||||
}
|
||||
return network.ReachabilityUnknown, successes, failures
|
||||
}
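To make the confidence rule above concrete, here is a minimal, self-contained sketch of the same counting logic; the threshold value below is an assumption for illustration, the real constant lives elsewhere in this file.

package main

import "fmt"

// Assumed value for illustration only; the real minConfidence is defined
// alongside the probe manager.
const minConfidence = 2

// reachabilityFromOutcomes mirrors reachabilityAndCounts above: an address is
// treated as public (or private) only once successes (or failures) lead by at
// least minConfidence; anything tighter stays unknown.
func reachabilityFromOutcomes(outcomes []bool) string {
	successes, failures := 0, 0
	for _, ok := range outcomes {
		if ok {
			successes++
		} else {
			failures++
		}
	}
	switch {
	case successes-failures >= minConfidence:
		return "public"
	case failures-successes >= minConfidence:
		return "private"
	default:
		return "unknown"
	}
}

func main() {
	fmt.Println(reachabilityFromOutcomes([]bool{true, true}))        // public
	fmt.Println(reachabilityFromOutcomes([]bool{true, false, true})) // unknown
	fmt.Println(reachabilityFromOutcomes([]bool{false, false}))      // private
}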
|
919
p2p/host/basic/addrs_reachability_tracker_test.go
Normal file
@@ -0,0 +1,919 @@
|
||||
package basichost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/benbjohnson/clock"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestProbeManager(t *testing.T) {
|
||||
pub1 := ma.StringCast("/ip4/1.1.1.1/tcp/1")
|
||||
pub2 := ma.StringCast("/ip4/1.1.1.2/tcp/1")
|
||||
pub3 := ma.StringCast("/ip4/1.1.1.3/tcp/1")
|
||||
|
||||
cl := clock.NewMock()
|
||||
|
||||
nextProbe := func(pm *probeManager) []autonatv2.Request {
|
||||
reqs := pm.GetProbe()
|
||||
if len(reqs) != 0 {
|
||||
pm.MarkProbeInProgress(reqs)
|
||||
}
|
||||
return reqs
|
||||
}
|
||||
|
||||
makeNewProbeManager := func(addrs []ma.Multiaddr) *probeManager {
|
||||
pm := newProbeManager(cl.Now)
|
||||
pm.UpdateAddrs(addrs)
|
||||
return pm
|
||||
}
|
||||
|
||||
t.Run("addrs updates", func(t *testing.T) {
|
||||
pm := newProbeManager(cl.Now)
|
||||
pm.UpdateAddrs([]ma.Multiaddr{pub1, pub2})
|
||||
for {
|
||||
reqs := nextProbe(pm)
|
||||
if len(reqs) == 0 {
|
||||
break
|
||||
}
|
||||
pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil)
|
||||
}
|
||||
reachable, _ := pm.AppendConfirmedAddrs(nil, nil)
|
||||
require.Equal(t, reachable, []ma.Multiaddr{pub1, pub2})
|
||||
pm.UpdateAddrs([]ma.Multiaddr{pub3})
|
||||
|
||||
reachable, _ = pm.AppendConfirmedAddrs(nil, nil)
|
||||
require.Empty(t, reachable)
|
||||
require.Len(t, pm.statuses, 1)
|
||||
})
|
||||
|
||||
t.Run("inprogress", func(t *testing.T) {
|
||||
pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2})
|
||||
reqs1 := pm.GetProbe()
|
||||
reqs2 := pm.GetProbe()
|
||||
require.Equal(t, reqs1, reqs2)
|
||||
for range targetConfidence {
|
||||
reqs := nextProbe(pm)
|
||||
require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}})
|
||||
}
|
||||
for range targetConfidence {
|
||||
reqs := nextProbe(pm)
|
||||
require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}, {Addr: pub1, SendDialData: true}})
|
||||
}
|
||||
reqs := pm.GetProbe()
|
||||
require.Empty(t, reqs)
|
||||
})
|
||||
|
||||
t.Run("refusals", func(t *testing.T) {
|
||||
pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2})
|
||||
var probes [][]autonatv2.Request
|
||||
for range targetConfidence {
|
||||
reqs := nextProbe(pm)
|
||||
require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}})
|
||||
probes = append(probes, reqs)
|
||||
}
|
||||
// first address refused, second one successful
|
||||
for _, p := range probes {
|
||||
pm.CompleteProbe(p, autonatv2.Result{Addr: pub2, Idx: 1, Reachability: network.ReachabilityPublic}, nil)
|
||||
}
|
||||
// the second address is validated!
|
||||
probes = nil
|
||||
for range targetConfidence {
|
||||
reqs := nextProbe(pm)
|
||||
require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}})
|
||||
probes = append(probes, reqs)
|
||||
}
|
||||
reqs := pm.GetProbe()
|
||||
require.Empty(t, reqs)
|
||||
for _, p := range probes {
|
||||
pm.CompleteProbe(p, autonatv2.Result{AllAddrsRefused: true}, nil)
|
||||
}
|
||||
// all requests refused; no more probes after too many consecutive refusals
|
||||
reqs = pm.GetProbe()
|
||||
require.Empty(t, reqs)
|
||||
|
||||
cl.Add(recentProbeInterval)
|
||||
reqs = pm.GetProbe()
|
||||
require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}})
|
||||
})
|
||||
|
||||
t.Run("successes", func(t *testing.T) {
|
||||
pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2})
|
||||
for j := 0; j < 2; j++ {
|
||||
for i := 0; i < targetConfidence; i++ {
|
||||
reqs := nextProbe(pm)
|
||||
pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil)
|
||||
}
|
||||
}
|
||||
// all addrs confirmed
|
||||
reqs := pm.GetProbe()
|
||||
require.Empty(t, reqs)
|
||||
|
||||
cl.Add(highConfidenceAddrProbeInterval + time.Millisecond)
|
||||
reqs = nextProbe(pm)
|
||||
require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}})
|
||||
reqs = nextProbe(pm)
|
||||
require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}, {Addr: pub1, SendDialData: true}})
|
||||
})
|
||||
|
||||
t.Run("throttling on indeterminate reachability", func(t *testing.T) {
|
||||
pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2})
|
||||
reachability := network.ReachabilityPublic
|
||||
nextReachability := func() network.Reachability {
|
||||
if reachability == network.ReachabilityPublic {
|
||||
reachability = network.ReachabilityPrivate
|
||||
} else {
|
||||
reachability = network.ReachabilityPublic
|
||||
}
|
||||
return reachability
|
||||
}
|
||||
// both addresses are indeterminate
|
||||
for range 2 * maxRecentDialsPerAddr {
|
||||
reqs := nextProbe(pm)
|
||||
pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: nextReachability()}, nil)
|
||||
}
|
||||
reqs := pm.GetProbe()
|
||||
require.Empty(t, reqs)
|
||||
|
||||
cl.Add(recentProbeInterval + time.Millisecond)
|
||||
reqs = pm.GetProbe()
|
||||
require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}})
|
||||
for range 2 * maxRecentDialsPerAddr {
|
||||
reqs := nextProbe(pm)
|
||||
pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: nextReachability()}, nil)
|
||||
}
|
||||
reqs = pm.GetProbe()
|
||||
require.Empty(t, reqs)
|
||||
})
|
||||
|
||||
t.Run("reachabilityUpdate", func(t *testing.T) {
|
||||
pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2})
|
||||
for range 2 * targetConfidence {
|
||||
reqs := nextProbe(pm)
|
||||
if reqs[0].Addr.Equal(pub1) {
|
||||
pm.CompleteProbe(reqs, autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil)
|
||||
} else {
|
||||
pm.CompleteProbe(reqs, autonatv2.Result{Addr: pub2, Idx: 0, Reachability: network.ReachabilityPrivate}, nil)
|
||||
}
|
||||
}
|
||||
|
||||
reachable, unreachable := pm.AppendConfirmedAddrs(nil, nil)
|
||||
require.Equal(t, reachable, []ma.Multiaddr{pub1})
|
||||
require.Equal(t, unreachable, []ma.Multiaddr{pub2})
|
||||
})
|
||||
t.Run("expiry", func(t *testing.T) {
|
||||
pm := makeNewProbeManager([]ma.Multiaddr{pub1})
|
||||
for range 2 * targetConfidence {
|
||||
reqs := nextProbe(pm)
|
||||
pm.CompleteProbe(reqs, autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil)
|
||||
}
|
||||
|
||||
reachable, unreachable := pm.AppendConfirmedAddrs(nil, nil)
|
||||
require.Equal(t, reachable, []ma.Multiaddr{pub1})
|
||||
require.Empty(t, unreachable)
|
||||
|
||||
cl.Add(maxProbeResultTTL + 1*time.Second)
|
||||
reachable, unreachable = pm.AppendConfirmedAddrs(nil, nil)
|
||||
require.Empty(t, reachable)
|
||||
require.Empty(t, unreachable)
|
||||
})
|
||||
}
|
||||
|
||||
type mockAutoNATClient struct {
|
||||
F func(context.Context, []autonatv2.Request) (autonatv2.Result, error)
|
||||
}
|
||||
|
||||
func (m mockAutoNATClient) GetReachability(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
|
||||
return m.F(ctx, reqs)
|
||||
}
|
||||
|
||||
var _ autonatv2Client = mockAutoNATClient{}
|
||||
|
||||
func TestAddrsReachabilityTracker(t *testing.T) {
|
||||
pub1 := ma.StringCast("/ip4/1.1.1.1/tcp/1")
|
||||
pub2 := ma.StringCast("/ip4/1.1.1.2/tcp/1")
|
||||
pub3 := ma.StringCast("/ip4/1.1.1.3/tcp/1")
|
||||
pri := ma.StringCast("/ip4/192.168.1.1/tcp/1")
|
||||
|
||||
newTracker := func(cli mockAutoNATClient, cl clock.Clock) *addrsReachabilityTracker {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
if cl == nil {
|
||||
cl = clock.New()
|
||||
}
|
||||
tr := &addrsReachabilityTracker{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
client: cli,
|
||||
newAddrs: make(chan []ma.Multiaddr, 1),
|
||||
reachabilityUpdateCh: make(chan struct{}, 1),
|
||||
maxConcurrency: 3,
|
||||
newAddrsProbeDelay: 0 * time.Second,
|
||||
probeManager: newProbeManager(cl.Now),
|
||||
clock: cl,
|
||||
}
|
||||
err := tr.Start()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := tr.Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
return tr
|
||||
}
|
||||
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
// pub1 reachable, pub2 unreachable, pub3 ignored
|
||||
mockClient := mockAutoNATClient{
|
||||
F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
|
||||
for i, req := range reqs {
|
||||
if req.Addr.Equal(pub1) {
|
||||
return autonatv2.Result{Addr: pub1, Idx: i, Reachability: network.ReachabilityPublic}, nil
|
||||
} else if req.Addr.Equal(pub2) {
|
||||
return autonatv2.Result{Addr: pub2, Idx: i, Reachability: network.ReachabilityPrivate}, nil
|
||||
}
|
||||
}
|
||||
return autonatv2.Result{}, autonatv2.ErrNoPeers
|
||||
},
|
||||
}
|
||||
tr := newTracker(mockClient, nil)
|
||||
tr.UpdateAddrs([]ma.Multiaddr{pub2, pub1, pri})
|
||||
select {
|
||||
case <-tr.reachabilityUpdateCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("expected reachability update")
|
||||
}
|
||||
reachable, unreachable := tr.ConfirmedAddrs()
|
||||
require.Equal(t, reachable, []ma.Multiaddr{pub1}, "%s %s", reachable, pub1)
|
||||
require.Equal(t, unreachable, []ma.Multiaddr{pub2}, "%s %s", unreachable, pub2)
|
||||
|
||||
tr.UpdateAddrs([]ma.Multiaddr{pub3, pub1, pri})
|
||||
select {
|
||||
case <-tr.reachabilityUpdateCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("expected reachability update")
|
||||
}
|
||||
reachable, unreachable = tr.ConfirmedAddrs()
|
||||
require.Equal(t, reachable, []ma.Multiaddr{pub1}, "%s %s", reachable, pub1)
|
||||
require.Empty(t, unreachable)
|
||||
})
|
||||
|
||||
t.Run("confirmed addrs ordering", func(t *testing.T) {
|
||||
mockClient := mockAutoNATClient{
|
||||
F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
|
||||
return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil
|
||||
},
|
||||
}
|
||||
tr := newTracker(mockClient, nil)
|
||||
var addrs []ma.Multiaddr
|
||||
for i := 0; i < 10; i++ {
|
||||
addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/1.1.1.1/tcp/%d", i)))
|
||||
}
|
||||
slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return -a.Compare(b) }) // sort in reverse order
|
||||
tr.UpdateAddrs(addrs)
|
||||
select {
|
||||
case <-tr.reachabilityUpdateCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("expected reachability update")
|
||||
}
|
||||
reachable, unreachable := tr.ConfirmedAddrs()
|
||||
require.Empty(t, unreachable)
|
||||
|
||||
orderedAddrs := slices.Clone(addrs)
|
||||
slices.Reverse(orderedAddrs)
|
||||
require.Equal(t, reachable, orderedAddrs, "%s %s", reachable, addrs)
|
||||
})
|
||||
|
||||
t.Run("backoff", func(t *testing.T) {
|
||||
notify := make(chan struct{}, 1)
|
||||
drainNotify := func() bool {
|
||||
found := false
|
||||
for {
|
||||
select {
|
||||
case <-notify:
|
||||
found = true
|
||||
default:
|
||||
return found
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var allow atomic.Bool
|
||||
mockClient := mockAutoNATClient{
|
||||
F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
|
||||
select {
|
||||
case notify <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
if !allow.Load() {
|
||||
return autonatv2.Result{}, autonatv2.ErrNoPeers
|
||||
}
|
||||
if reqs[0].Addr.Equal(pub1) {
|
||||
return autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil
|
||||
}
|
||||
return autonatv2.Result{AllAddrsRefused: true}, nil
|
||||
},
|
||||
}
|
||||
|
||||
cl := clock.NewMock()
|
||||
tr := newTracker(mockClient, cl)
|
||||
|
||||
// update addrs and wait for initial checks
|
||||
tr.UpdateAddrs([]ma.Multiaddr{pub1})
|
||||
// need to update clock after the background goroutine processes the new addrs
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cl.Add(1)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
require.True(t, drainNotify()) // check that we did receive probes
|
||||
|
||||
backoffInterval := backoffStartInterval
|
||||
for i := 0; i < 4; i++ {
|
||||
drainNotify()
|
||||
cl.Add(backoffInterval / 2)
|
||||
select {
|
||||
case <-notify:
|
||||
t.Fatal("unexpected call")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
cl.Add(backoffInterval/2 + 1) // +1 to push it slightly over the backoff interval
|
||||
backoffInterval *= 2
|
||||
select {
|
||||
case <-notify:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("expected probe")
|
||||
}
|
||||
reachable, unreachable := tr.ConfirmedAddrs()
|
||||
require.Empty(t, reachable)
|
||||
require.Empty(t, unreachable)
|
||||
}
|
||||
allow.Store(true)
|
||||
drainNotify()
|
||||
cl.Add(backoffInterval + 1)
|
||||
select {
|
||||
case <-tr.reachabilityUpdateCh:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("unexpected reachability update")
|
||||
}
|
||||
reachable, unreachable := tr.ConfirmedAddrs()
|
||||
require.Equal(t, reachable, []ma.Multiaddr{pub1})
|
||||
require.Empty(t, unreachable)
|
||||
})
|
||||
|
||||
t.Run("event update", func(t *testing.T) {
|
||||
// allow minConfidence probes to pass
|
||||
called := make(chan struct{}, minConfidence)
|
||||
notify := make(chan struct{})
|
||||
mockClient := mockAutoNATClient{
|
||||
F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) {
|
||||
select {
|
||||
case called <- struct{}{}:
|
||||
notify <- struct{}{}
|
||||
return autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil
|
||||
default:
|
||||
return autonatv2.Result{AllAddrsRefused: true}, nil
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
tr := newTracker(mockClient, nil)
|
||||
tr.UpdateAddrs([]ma.Multiaddr{pub1})
|
||||
for i := 0; i < minConfidence; i++ {
|
||||
select {
|
||||
case <-notify:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("expected call to autonat client")
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-tr.reachabilityUpdateCh:
|
||||
reachable, unreachable := tr.ConfirmedAddrs()
|
||||
require.Equal(t, reachable, []ma.Multiaddr{pub1})
|
||||
require.Empty(t, unreachable)
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("expected reachability update")
|
||||
}
|
||||
tr.UpdateAddrs([]ma.Multiaddr{pub1}) // same addrs shouldn't trigger an update
|
||||
select {
|
||||
case <-tr.reachabilityUpdateCh:
|
||||
t.Fatal("didn't expect reachability update")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
tr.UpdateAddrs([]ma.Multiaddr{pub2})
|
||||
select {
|
||||
case <-tr.reachabilityUpdateCh:
|
||||
reachable, unreachable := tr.ConfirmedAddrs()
|
||||
require.Empty(t, reachable)
|
||||
require.Empty(t, unreachable)
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("expected reachability update")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("refresh after reset interval", func(t *testing.T) {
|
||||
notify := make(chan struct{}, 1)
|
||||
drainNotify := func() bool {
|
||||
found := false
|
||||
for {
|
||||
select {
|
||||
case <-notify:
|
||||
found = true
|
||||
default:
|
||||
return found
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mockClient := mockAutoNATClient{
|
||||
F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
|
||||
select {
|
||||
case notify <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
if reqs[0].Addr.Equal(pub1) {
|
||||
return autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil
|
||||
}
|
||||
return autonatv2.Result{AllAddrsRefused: true}, nil
|
||||
},
|
||||
}
|
||||
|
||||
cl := clock.NewMock()
|
||||
tr := newTracker(mockClient, cl)
|
||||
|
||||
// update addrs and wait for initial checks
|
||||
tr.UpdateAddrs([]ma.Multiaddr{pub1})
|
||||
// need to update clock after the background goroutine processes the new addrs
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cl.Add(1)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
require.True(t, drainNotify()) // check that we did receive probes
|
||||
cl.Add(highConfidenceAddrProbeInterval / 2)
|
||||
select {
|
||||
case <-notify:
|
||||
t.Fatal("unexpected call")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
cl.Add(highConfidenceAddrProbeInterval/2 + defaultReachabilityRefreshInterval) // cross defaultReachabilityRefreshInterval so the next probe is scheduled
|
||||
select {
|
||||
case <-notify:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("expected probe")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshReachability(t *testing.T) {
|
||||
pub1 := ma.StringCast("/ip4/1.1.1.1/tcp/1")
|
||||
pub2 := ma.StringCast("/ip4/1.1.1.1/tcp/2")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
newTracker := func(client autonatv2Client, pm *probeManager) *addrsReachabilityTracker {
|
||||
return &addrsReachabilityTracker{
|
||||
probeManager: pm,
|
||||
client: client,
|
||||
clock: clock.New(),
|
||||
maxConcurrency: 3,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
t.Run("backoff on ErrNoValidPeers", func(t *testing.T) {
|
||||
mockClient := mockAutoNATClient{
|
||||
F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) {
|
||||
return autonatv2.Result{}, autonatv2.ErrNoPeers
|
||||
},
|
||||
}
|
||||
|
||||
addrTracker := newProbeManager(time.Now)
|
||||
addrTracker.UpdateAddrs([]ma.Multiaddr{pub1})
|
||||
r := newTracker(mockClient, addrTracker)
|
||||
res := r.refreshReachability()
|
||||
require.True(t, <-res.BackoffCh)
|
||||
require.Equal(t, addrTracker.InProgressProbes(), 0)
|
||||
})
|
||||
|
||||
t.Run("returns backoff on errTooManyConsecutiveFailures", func(t *testing.T) {
|
||||
// Create a client that always returns an error
|
||||
mockClient := mockAutoNATClient{
|
||||
F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) {
|
||||
return autonatv2.Result{}, errors.New("test error")
|
||||
},
|
||||
}
|
||||
|
||||
pm := newProbeManager(time.Now)
|
||||
pm.UpdateAddrs([]ma.Multiaddr{pub1})
|
||||
r := newTracker(mockClient, pm)
|
||||
result := r.refreshReachability()
|
||||
require.True(t, <-result.BackoffCh)
|
||||
require.Equal(t, pm.InProgressProbes(), 0)
|
||||
})
|
||||
|
||||
t.Run("quits on cancellation", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
block := make(chan struct{})
|
||||
mockClient := mockAutoNATClient{
|
||||
F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) {
|
||||
block <- struct{}{}
|
||||
return autonatv2.Result{}, nil
|
||||
},
|
||||
}
|
||||
|
||||
pm := newProbeManager(time.Now)
|
||||
pm.UpdateAddrs([]ma.Multiaddr{pub1})
|
||||
r := &addrsReachabilityTracker{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
client: mockClient,
|
||||
probeManager: pm,
|
||||
clock: clock.New(),
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
result := r.refreshReachability()
|
||||
assert.False(t, <-result.BackoffCh)
|
||||
assert.Equal(t, pm.InProgressProbes(), 0)
|
||||
}()
|
||||
|
||||
cancel()
|
||||
time.Sleep(50 * time.Millisecond) // wait for the cancellation to be processed
|
||||
|
||||
outer:
|
||||
for i := 0; i < defaultMaxConcurrency; i++ {
|
||||
select {
|
||||
case <-block:
|
||||
default:
|
||||
break outer
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-block:
|
||||
t.Fatal("expected no more requests")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
t.Run("handles refusals", func(t *testing.T) {
|
||||
pub1, _ := ma.NewMultiaddr("/ip4/1.1.1.1/tcp/1")
|
||||
|
||||
mockClient := mockAutoNATClient{
|
||||
F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
|
||||
for i, req := range reqs {
|
||||
if req.Addr.Equal(pub1) {
|
||||
return autonatv2.Result{Addr: pub1, Idx: i, Reachability: network.ReachabilityPublic}, nil
|
||||
}
|
||||
}
|
||||
return autonatv2.Result{AllAddrsRefused: true}, nil
|
||||
},
|
||||
}
|
||||
|
||||
pm := newProbeManager(time.Now)
|
||||
pm.UpdateAddrs([]ma.Multiaddr{pub2, pub1})
|
||||
r := newTracker(mockClient, pm)
|
||||
|
||||
result := r.refreshReachability()
|
||||
require.False(t, <-result.BackoffCh)
|
||||
|
||||
reachable, unreachable := pm.AppendConfirmedAddrs(nil, nil)
|
||||
require.Equal(t, reachable, []ma.Multiaddr{pub1})
|
||||
require.Empty(t, unreachable)
|
||||
require.Equal(t, pm.InProgressProbes(), 0)
|
||||
})
|
||||
|
||||
t.Run("handles completions", func(t *testing.T) {
|
||||
mockClient := mockAutoNATClient{
|
||||
F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
|
||||
for i, req := range reqs {
|
||||
if req.Addr.Equal(pub1) {
|
||||
return autonatv2.Result{Addr: pub1, Idx: i, Reachability: network.ReachabilityPublic}, nil
|
||||
}
|
||||
if req.Addr.Equal(pub2) {
|
||||
return autonatv2.Result{Addr: pub2, Idx: i, Reachability: network.ReachabilityPrivate}, nil
|
||||
}
|
||||
}
|
||||
return autonatv2.Result{AllAddrsRefused: true}, nil
|
||||
},
|
||||
}
|
||||
pm := newProbeManager(time.Now)
|
||||
pm.UpdateAddrs([]ma.Multiaddr{pub2, pub1})
|
||||
r := newTracker(mockClient, pm)
|
||||
result := r.refreshReachability()
|
||||
require.False(t, <-result.BackoffCh)
|
||||
|
||||
reachable, unreachable := pm.AppendConfirmedAddrs(nil, nil)
|
||||
require.Equal(t, reachable, []ma.Multiaddr{pub1})
|
||||
require.Equal(t, unreachable, []ma.Multiaddr{pub2})
|
||||
require.Equal(t, pm.InProgressProbes(), 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddrStatusProbeCount(t *testing.T) {
|
||||
cases := []struct {
|
||||
inputs string
|
||||
wantRequiredProbes int
|
||||
wantReachability network.Reachability
|
||||
}{
|
||||
{
|
||||
inputs: "",
|
||||
wantRequiredProbes: 3,
|
||||
wantReachability: network.ReachabilityUnknown,
|
||||
},
|
||||
{
|
||||
inputs: "S",
|
||||
wantRequiredProbes: 2,
|
||||
wantReachability: network.ReachabilityUnknown,
|
||||
},
|
||||
{
|
||||
inputs: "SS",
|
||||
wantRequiredProbes: 1,
|
||||
wantReachability: network.ReachabilityPublic,
|
||||
},
|
||||
{
|
||||
inputs: "SSS",
|
||||
wantRequiredProbes: 0,
|
||||
wantReachability: network.ReachabilityPublic,
|
||||
},
|
||||
{
|
||||
inputs: "SSSSSSSF",
|
||||
wantRequiredProbes: 1,
|
||||
wantReachability: network.ReachabilityPublic,
|
||||
},
|
||||
{
|
||||
inputs: "SFSFSSSS",
|
||||
wantRequiredProbes: 0,
|
||||
wantReachability: network.ReachabilityPublic,
|
||||
},
|
||||
{
|
||||
inputs: "SSSSSFSF",
|
||||
wantRequiredProbes: 2,
|
||||
wantReachability: network.ReachabilityUnknown,
|
||||
},
|
||||
{
|
||||
inputs: "FF",
|
||||
wantRequiredProbes: 1,
|
||||
wantReachability: network.ReachabilityPrivate,
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.inputs, func(t *testing.T) {
|
||||
now := time.Time{}.Add(1 * time.Second)
|
||||
ao := addrStatus{}
|
||||
for _, r := range c.inputs {
|
||||
if r == 'S' {
|
||||
ao.AddOutcome(now, network.ReachabilityPublic, 5)
|
||||
} else {
|
||||
ao.AddOutcome(now, network.ReachabilityPrivate, 5)
|
||||
}
|
||||
now = now.Add(1 * time.Second)
|
||||
}
|
||||
require.Equal(t, ao.RequiredProbeCount(now), c.wantRequiredProbes)
|
||||
require.Equal(t, ao.Reachability(), c.wantReachability)
|
||||
if c.wantRequiredProbes == 0 {
|
||||
now = now.Add(highConfidenceAddrProbeInterval + 10*time.Microsecond)
|
||||
require.Equal(t, ao.RequiredProbeCount(now), 1)
|
||||
}
|
||||
|
||||
now = now.Add(1 * time.Second)
|
||||
ao.RemoveBefore(now)
|
||||
require.Len(t, ao.outcomes, 0)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAddrTracker(b *testing.B) {
|
||||
cl := clock.NewMock()
|
||||
t := newProbeManager(cl.Now)
|
||||
|
||||
addrs := make([]ma.Multiaddr, 20)
|
||||
for i := range addrs {
|
||||
addrs[i] = ma.StringCast(fmt.Sprintf("/ip4/1.1.1.1/tcp/%d", rand.Intn(1000)))
|
||||
}
|
||||
t.UpdateAddrs(addrs)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
p := t.GetProbe()
|
||||
for i := 0; i < b.N; i++ {
|
||||
pp := t.GetProbe()
|
||||
if len(pp) == 0 {
|
||||
pp = p
|
||||
}
|
||||
t.MarkProbeInProgress(pp)
|
||||
t.CompleteProbe(pp, autonatv2.Result{Addr: pp[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzAddrsReachabilityTracker(f *testing.F) {
|
||||
type autonatv2Response struct {
|
||||
Result autonatv2.Result
|
||||
Err error
|
||||
}
|
||||
|
||||
newMockClient := func(b []byte) mockAutoNATClient {
|
||||
count := 0
|
||||
return mockAutoNATClient{
|
||||
F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
|
||||
if len(b) == 0 {
|
||||
return autonatv2.Result{}, nil
|
||||
}
|
||||
count = (count + 1) % len(b)
|
||||
if b[count]%3 == 0 {
|
||||
// some address confirmed
|
||||
c1 := (count + 1) % len(b)
|
||||
c2 := (count + 2) % len(b)
|
||||
rch := network.Reachability(b[c1] % 3)
|
||||
n := int(b[c2]) % len(reqs)
|
||||
return autonatv2.Result{
|
||||
Addr: reqs[n].Addr,
|
||||
Idx: n,
|
||||
Reachability: rch,
|
||||
}, nil
|
||||
}
|
||||
outcomes := []autonatv2Response{
|
||||
{Result: autonatv2.Result{AllAddrsRefused: true}},
|
||||
{Err: errors.New("test error")},
|
||||
{Err: autonatv2.ErrPrivateAddrs},
|
||||
{Err: autonatv2.ErrNoPeers},
|
||||
{Result: autonatv2.Result{}, Err: nil},
|
||||
{Result: autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}},
|
||||
{Result: autonatv2.Result{
|
||||
Addr: reqs[0].Addr,
|
||||
Idx: 0,
|
||||
Reachability: network.ReachabilityPublic,
|
||||
AllAddrsRefused: true,
|
||||
}},
|
||||
{Result: autonatv2.Result{
|
||||
Addr: reqs[0].Addr,
|
||||
Idx: len(reqs) - 1, // invalid idx
|
||||
Reachability: network.ReachabilityPublic,
|
||||
AllAddrsRefused: false,
|
||||
}},
|
||||
}
|
||||
outcome := outcomes[int(b[count])%len(outcomes)]
|
||||
return outcome.Result, outcome.Err
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Move this to go-multiaddrs
|
||||
getProto := func(protos []byte) ma.Multiaddr {
|
||||
protoType := 0
|
||||
if len(protos) > 0 {
|
||||
protoType = int(protos[0])
|
||||
}
|
||||
|
||||
port1, port2 := 0, 0
|
||||
if len(protos) > 1 {
|
||||
port1 = int(protos[1])
|
||||
}
|
||||
if len(protos) > 2 {
|
||||
port2 = int(protos[2])
|
||||
}
|
||||
protoTemplates := []string{
|
||||
"/tcp/%d/",
|
||||
"/udp/%d/",
|
||||
"/udp/%d/quic-v1/",
|
||||
"/udp/%d/quic-v1/tcp/%d",
|
||||
"/udp/%d/quic-v1/webtransport/",
|
||||
"/udp/%d/webrtc/",
|
||||
"/udp/%d/webrtc-direct/",
|
||||
"/unix/hello/",
|
||||
}
|
||||
s := protoTemplates[protoType%len(protoTemplates)]
|
||||
port1 %= (1 << 16)
|
||||
if strings.Count(s, "%d") == 1 {
|
||||
return ma.StringCast(fmt.Sprintf(s, port1))
|
||||
}
|
||||
port2 %= (1 << 16)
|
||||
return ma.StringCast(fmt.Sprintf(s, port1, port2))
|
||||
}
|
||||
|
||||
getIP := func(ips []byte) ma.Multiaddr {
|
||||
ipType := 0
|
||||
if len(ips) > 0 {
|
||||
ipType = int(ips[0])
|
||||
}
|
||||
ips = ips[1:]
|
||||
var x, y int64
|
||||
split := 128 / 8
|
||||
if len(ips) < split {
|
||||
split = len(ips)
|
||||
}
|
||||
var b [8]byte
|
||||
copy(b[:], ips[:split])
|
||||
x = int64(binary.LittleEndian.Uint64(b[:]))
|
||||
clear(b[:])
|
||||
copy(b[:], ips[split:])
|
||||
y = int64(binary.LittleEndian.Uint64(b[:]))
|
||||
|
||||
var ip netip.Addr
|
||||
switch ipType % 3 {
|
||||
case 0:
|
||||
ip = netip.AddrFrom4([4]byte{byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24)})
|
||||
return ma.StringCast(fmt.Sprintf("/ip4/%s/", ip))
|
||||
case 1:
|
||||
pubIP := net.ParseIP("2005::") // Public IP address
|
||||
x := int64(binary.LittleEndian.Uint64(pubIP[0:8]))
|
||||
ip = netip.AddrFrom16([16]byte{
|
||||
byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24),
|
||||
byte(x >> 32), byte(x >> 40), byte(x >> 48), byte(x >> 56),
|
||||
byte(y), byte(y >> 8), byte(y >> 16), byte(y >> 24),
|
||||
byte(y >> 32), byte(y >> 40), byte(y >> 48), byte(y >> 56),
|
||||
})
|
||||
return ma.StringCast(fmt.Sprintf("/ip6/%s/", ip))
|
||||
default:
|
||||
ip := netip.AddrFrom16([16]byte{
|
||||
byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24),
|
||||
byte(x >> 32), byte(x >> 40), byte(x >> 48), byte(x >> 56),
|
||||
byte(y), byte(y >> 8), byte(y >> 16), byte(y >> 24),
|
||||
byte(y >> 32), byte(y >> 40), byte(y >> 48), byte(y >> 56),
|
||||
})
|
||||
return ma.StringCast(fmt.Sprintf("/ip6/%s/", ip))
|
||||
}
|
||||
}
|
||||
|
||||
getAddr := func(addrType int, ips, protos []byte) ma.Multiaddr {
|
||||
switch addrType % 4 {
|
||||
case 0:
|
||||
return getIP(ips).Encapsulate(getProto(protos))
|
||||
case 1:
|
||||
return getProto(protos)
|
||||
case 2:
|
||||
return nil
|
||||
default:
|
||||
return getIP(ips).Encapsulate(getProto(protos))
|
||||
}
|
||||
}
|
||||
|
||||
getDNSAddr := func(hostNameBytes, protos []byte) ma.Multiaddr {
|
||||
hostName := strings.ReplaceAll(string(hostNameBytes), "\\", "")
|
||||
hostName = strings.ReplaceAll(hostName, "/", "")
|
||||
if hostName == "" {
|
||||
hostName = "localhost"
|
||||
}
|
||||
dnsType := 0
|
||||
if len(hostNameBytes) > 0 {
|
||||
dnsType = int(hostNameBytes[0])
|
||||
}
|
||||
dnsProtos := []string{"dns", "dns4", "dns6", "dnsaddr"}
|
||||
da := ma.StringCast(fmt.Sprintf("/%s/%s/", dnsProtos[dnsType%len(dnsProtos)], hostName))
|
||||
return da.Encapsulate(getProto(protos))
|
||||
}
|
||||
|
||||
const maxAddrs = 1000
|
||||
getAddrs := func(numAddrs int, ips, protos, hostNames []byte) []ma.Multiaddr {
|
||||
if len(ips) == 0 || len(protos) == 0 || len(hostNames) == 0 {
|
||||
return nil
|
||||
}
|
||||
numAddrs = ((numAddrs % maxAddrs) + maxAddrs) % maxAddrs
|
||||
addrs := make([]ma.Multiaddr, numAddrs)
|
||||
ipIdx := 0
|
||||
protoIdx := 0
|
||||
for i := range numAddrs {
|
||||
addrs[i] = getAddr(i, ips[ipIdx:], protos[protoIdx:])
|
||||
ipIdx = (ipIdx + 1) % len(ips)
|
||||
protoIdx = (protoIdx + 1) % len(protos)
|
||||
}
|
||||
maxDNSAddrs := 10
|
||||
protoIdx = 0
|
||||
for i := 0; i < len(hostNames) && i < maxDNSAddrs; i += 2 {
|
||||
ed := min(i+2, len(hostNames))
|
||||
addrs = append(addrs, getDNSAddr(hostNames[i:ed], protos[protoIdx:]))
|
||||
protoIdx = (protoIdx + 1) % len(protos)
|
||||
}
|
||||
return addrs
|
||||
}
|
||||
|
||||
cl := clock.NewMock()
|
||||
f.Fuzz(func(t *testing.T, numAddrs int, ips, protos, hostNames, autonatResponses []byte) {
|
||||
tr := newAddrsReachabilityTracker(newMockClient(autonatResponses), nil, cl)
|
||||
require.NoError(t, tr.Start())
|
||||
tr.UpdateAddrs(getAddrs(numAddrs, ips, protos, hostNames))
|
||||
|
||||
// fuzz test iterations need to finish within 10 seconds, see:
|
||||
// https://github.com/golang/go/issues/48157
|
||||
// https://github.com/golang/go/commit/5d24203c394e6b64c42a9f69b990d94cb6c8aad4#diff-4e3b9481b8794eb058998e2bec389d3db7a23c54e67ac0f7259a3a5d2c79fd04R474-R483
|
||||
const maxIters = 20
|
||||
for range maxIters {
|
||||
cl.Add(5 * time.Minute)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
require.NoError(t, tr.Close())
|
||||
})
|
||||
}
|
@@ -156,8 +156,8 @@ type HostOpts struct {
|
||||
|
||||
// DisableIdentifyAddressDiscovery disables address discovery using peer provided observed addresses in identify
|
||||
DisableIdentifyAddressDiscovery bool
|
||||
EnableAutoNATv2 bool
|
||||
AutoNATv2Dialer host.Host
|
||||
|
||||
AutoNATv2 *autonatv2.AutoNAT
|
||||
}
|
||||
|
||||
// NewHost constructs a new *BasicHost and activates it by attaching its stream and connection handlers to the given inet.Network.
|
||||
@@ -236,7 +236,16 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
|
||||
}); ok {
|
||||
tfl = s.TransportForListening
|
||||
}
|
||||
h.addressManager, err = newAddrsManager(h.eventbus, natmgr, addrFactory, h.Network().ListenAddresses, tfl, h.ids, h.addrsUpdatedChan)
|
||||
|
||||
if opts.AutoNATv2 != nil {
|
||||
h.autonatv2 = opts.AutoNATv2
|
||||
}
|
||||
|
||||
var autonatv2Client autonatv2Client // avoid typed nil errors
|
||||
if h.autonatv2 != nil {
|
||||
autonatv2Client = h.autonatv2
|
||||
}
|
||||
h.addressManager, err = newAddrsManager(h.eventbus, natmgr, addrFactory, h.Network().ListenAddresses, tfl, h.ids, h.addrsUpdatedChan, autonatv2Client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create address service: %w", err)
|
||||
}
|
||||
@@ -283,17 +292,6 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
|
||||
h.pings = ping.NewPingService(h)
|
||||
}
|
||||
|
||||
if opts.EnableAutoNATv2 {
|
||||
var mt autonatv2.MetricsTracer
|
||||
if opts.EnableMetrics {
|
||||
mt = autonatv2.NewMetricsTracer(opts.PrometheusRegisterer)
|
||||
}
|
||||
h.autonatv2, err = autonatv2.New(h, opts.AutoNATv2Dialer, autonatv2.WithMetricsTracer(mt))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create autonatv2: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !h.disableSignedPeerRecord {
|
||||
h.signKey = h.Peerstore().PrivKey(h.ID())
|
||||
cab, ok := peerstore.GetCertifiedAddrBook(h.Peerstore())
|
||||
@@ -320,7 +318,7 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
|
||||
func (h *BasicHost) Start() {
|
||||
h.psManager.Start()
|
||||
if h.autonatv2 != nil {
|
||||
err := h.autonatv2.Start()
|
||||
err := h.autonatv2.Start(h)
|
||||
if err != nil {
|
||||
log.Errorf("autonat v2 failed to start: %s", err)
|
||||
}
|
||||
@@ -754,6 +752,16 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr {
|
||||
return h.addressManager.DirectAddrs()
|
||||
}
|
||||
|
||||
// ReachableAddrs returns all addresses of the host that are reachable from the internet
|
||||
// as verified by autonatv2.
|
||||
//
|
||||
// Experimental: This API may change in the future without deprecation.
|
||||
//
|
||||
// Requires AutoNATv2 to be enabled.
|
||||
func (h *BasicHost) ReachableAddrs() []ma.Multiaddr {
|
||||
return h.addressManager.ReachableAddrs()
|
||||
}
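A minimal usage sketch (illustrative only, not part of this change): it assumes a *BasicHost h created with AutoNATv2 enabled and that the event type below is the one emitted when the set of confirmed addresses changes.

sub, err := h.EventBus().Subscribe(new(event.EvtHostReachableAddrsChanged))
if err != nil {
	log.Errorf("failed to subscribe to reachability events: %s", err)
	return
}
defer sub.Close()
for range sub.Out() {
	// Re-read the currently verified addresses whenever reachability changes.
	log.Infof("reachable addrs: %v", h.ReachableAddrs())
}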
|
||||
|
||||
func trimHostAddrList(addrs []ma.Multiaddr, maxSize int) []ma.Multiaddr {
|
||||
totalSize := 0
|
||||
for _, a := range addrs {
|
||||
@@ -836,7 +844,6 @@ func (h *BasicHost) Close() error {
|
||||
if h.cmgr != nil {
|
||||
h.cmgr.Close()
|
||||
}
|
||||
|
||||
h.addressManager.Close()
|
||||
|
||||
if h.ids != nil {
|
||||
|
@@ -47,6 +47,7 @@ func TestHostSimple(t *testing.T) {
|
||||
h1.Start()
|
||||
h2, err := NewHost(swarmt.GenSwarm(t), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer h2.Close()
|
||||
h2.Start()
|
||||
|
||||
@@ -211,6 +212,7 @@ func TestAllAddrs(t *testing.T) {
|
||||
// no listen addrs
|
||||
h, err := NewHost(swarmt.GenSwarm(t, swarmt.OptDialOnly), nil)
|
||||
require.NoError(t, err)
|
||||
h.Start()
|
||||
defer h.Close()
|
||||
require.Nil(t, h.AllAddrs())
|
||||
|
||||
|
@@ -4,18 +4,17 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"iter"
|
||||
"math/rand/v2"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"math/rand/v2"
|
||||
|
||||
logging "github.com/ipfs/go-log/v2"
|
||||
"github.com/libp2p/go-libp2p/core/event"
|
||||
"github.com/libp2p/go-libp2p/core/host"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
)
|
||||
@@ -35,11 +34,15 @@ const (
|
||||
// maxPeerAddresses is the number of addresses in a dial request the server
|
||||
// will inspect; the rest are ignored.
|
||||
maxPeerAddresses = 50
|
||||
|
||||
defaultThrottlePeerDuration = 2 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoValidPeers = errors.New("no valid peers for autonat v2")
|
||||
ErrDialRefused = errors.New("dial refused")
|
||||
// ErrNoPeers is returned when the client knows no autonatv2 servers.
|
||||
ErrNoPeers = errors.New("no peers for autonat v2")
|
||||
// ErrPrivateAddrs is returned when the request has private IP addresses.
|
||||
ErrPrivateAddrs = errors.New("private addresses cannot be verified with autonatv2")
|
||||
|
||||
log = logging.Logger("autonatv2")
|
||||
)
|
||||
@@ -56,10 +59,12 @@ type Request struct {
|
||||
type Result struct {
|
||||
// Addr is the dialed address
|
||||
Addr ma.Multiaddr
|
||||
// Reachability of the dialed address
|
||||
// Idx is the index of the address that was dialed
|
||||
Idx int
|
||||
// Reachability is the reachability for `Addr`
|
||||
Reachability network.Reachability
|
||||
// Status is the outcome of the dialback
|
||||
Status pb.DialStatus
|
||||
// AllAddrsRefused is true when the server refused to dial all the addresses in the request.
|
||||
AllAddrsRefused bool
|
||||
}
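A short, hypothetical sketch of how a caller might interpret a Result; it assumes an existing *AutoNAT an, a context ctx, and a request slice reqs.

res, err := an.GetReachability(ctx, reqs)
switch {
case err != nil:
	// no autonatv2 server available, only private addresses, or a protocol error
	log.Debugf("reachability check failed: %s", err)
case res.AllAddrsRefused:
	// the server declined to dial any of the requested addresses
case res.Reachability == network.ReachabilityPublic:
	// reqs[res.Idx].Addr (== res.Addr) was dialed back successfully
case res.Reachability == network.ReachabilityPrivate:
	// the server attempted res.Addr but could not complete the dial back
}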
|
||||
|
||||
// AutoNAT implements the AutoNAT v2 client and server.
|
||||
@@ -78,6 +83,10 @@ type AutoNAT struct {
|
||||
|
||||
mx sync.Mutex
|
||||
peers *peersMap
|
||||
throttlePeer map[peer.ID]time.Time
|
||||
// throttlePeerDuration is the duration to wait before making another dial request to the
|
||||
// same server.
|
||||
throttlePeerDuration time.Duration
|
||||
// allowPrivateAddrs enables using private and localhost addresses for reachability checks.
|
||||
// This is only useful for testing.
|
||||
allowPrivateAddrs bool
|
||||
@@ -86,7 +95,7 @@ type AutoNAT struct {
|
||||
// New returns a new AutoNAT instance.
|
||||
// host and dialerHost should have the same dialing capabilities. In case the host doesn't support
|
||||
// a transport, dial back requests for addresses of that transport will be ignored.
|
||||
func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, error) {
|
||||
func New(dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, error) {
|
||||
s := defaultSettings()
|
||||
for _, o := range opts {
|
||||
if err := o(s); err != nil {
|
||||
@@ -96,18 +105,20 @@ func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT,
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
an := &AutoNAT{
|
||||
host: host,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
srv: newServer(host, dialerHost, s),
|
||||
cli: newClient(host),
|
||||
srv: newServer(dialerHost, s),
|
||||
cli: newClient(),
|
||||
allowPrivateAddrs: s.allowPrivateAddrs,
|
||||
peers: newPeersMap(),
|
||||
throttlePeer: make(map[peer.ID]time.Time),
|
||||
throttlePeerDuration: s.throttlePeerDuration,
|
||||
}
|
||||
return an, nil
|
||||
}
|
||||
|
||||
func (an *AutoNAT) background(sub event.Subscription) {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
for {
|
||||
select {
|
||||
case <-an.ctx.Done():
|
||||
@@ -122,12 +133,24 @@ func (an *AutoNAT) background(sub event.Subscription) {
|
||||
an.updatePeer(evt.Peer)
|
||||
case event.EvtPeerIdentificationCompleted:
|
||||
an.updatePeer(evt.Peer)
|
||||
default:
|
||||
log.Errorf("unexpected event: %T", e)
|
||||
}
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
an.mx.Lock()
|
||||
for p, t := range an.throttlePeer {
|
||||
if t.Before(now) {
|
||||
delete(an.throttlePeer, p)
|
||||
}
|
||||
}
|
||||
an.mx.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (an *AutoNAT) Start() error {
|
||||
func (an *AutoNAT) Start(h host.Host) error {
|
||||
an.host = h
|
||||
// Listen on event.EvtPeerProtocolsUpdated, event.EvtPeerConnectednessChanged
|
||||
// and event.EvtPeerIdentificationCompleted to maintain our set of autonat-supporting peers.
|
||||
sub, err := an.host.EventBus().Subscribe([]interface{}{
|
||||
@@ -138,8 +161,8 @@ func (an *AutoNAT) Start() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("event subscription failed: %w", err)
|
||||
}
|
||||
an.cli.Start()
|
||||
an.srv.Start()
|
||||
an.cli.Start(h)
|
||||
an.srv.Start(h)
|
||||
|
||||
an.wg.Add(1)
|
||||
go an.background(sub)
|
||||
@@ -156,24 +179,48 @@ func (an *AutoNAT) Close() {
|
||||
|
||||
// GetReachability makes a single dial request for checking reachability for requested addresses
|
||||
func (an *AutoNAT) GetReachability(ctx context.Context, reqs []Request) (Result, error) {
|
||||
var filteredReqs []Request
|
||||
if !an.allowPrivateAddrs {
|
||||
filteredReqs = make([]Request, 0, len(reqs))
|
||||
for _, r := range reqs {
|
||||
if !manet.IsPublicAddr(r.Addr) {
|
||||
return Result{}, fmt.Errorf("private address cannot be verified by autonatv2: %s", r.Addr)
|
||||
if manet.IsPublicAddr(r.Addr) {
|
||||
filteredReqs = append(filteredReqs, r)
|
||||
} else {
|
||||
log.Errorf("private address in reachability check: %s", r.Addr)
|
||||
}
|
||||
}
|
||||
if len(filteredReqs) == 0 {
|
||||
return Result{}, ErrPrivateAddrs
|
||||
}
|
||||
} else {
|
||||
filteredReqs = reqs
|
||||
}
|
||||
an.mx.Lock()
|
||||
p := an.peers.GetRand()
|
||||
now := time.Now()
|
||||
var p peer.ID
|
||||
for pr := range an.peers.Shuffled() {
|
||||
if t := an.throttlePeer[pr]; t.After(now) {
|
||||
continue
|
||||
}
|
||||
p = pr
|
||||
an.throttlePeer[p] = time.Now().Add(an.throttlePeerDuration)
|
||||
break
|
||||
}
|
||||
an.mx.Unlock()
|
||||
if p == "" {
|
||||
return Result{}, ErrNoValidPeers
|
||||
return Result{}, ErrNoPeers
|
||||
}
|
||||
|
||||
res, err := an.cli.GetReachability(ctx, p, reqs)
|
||||
res, err := an.cli.GetReachability(ctx, p, filteredReqs)
|
||||
if err != nil {
|
||||
log.Debugf("reachability check with %s failed, err: %s", p, err)
|
||||
return Result{}, fmt.Errorf("reachability check with %s failed: %w", p, err)
|
||||
return res, fmt.Errorf("reachability check with %s failed: %w", p, err)
|
||||
}
|
||||
// restore the correct index in case we filtered out private addresses
|
||||
for i, r := range reqs {
|
||||
if r.Addr.Equal(res.Addr) {
|
||||
res.Idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
log.Debugf("reachability check with %s successful", p)
|
||||
return res, nil
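A hypothetical call from an external caller's perspective, mixing a private and a public address to illustrate the filtering and index restoration above; the addresses and variables are made up.

reqs := []autonatv2.Request{
	{Addr: ma.StringCast("/ip4/192.168.1.10/tcp/4001")},                // filtered out before sending
	{Addr: ma.StringCast("/ip4/1.2.3.4/tcp/4001"), SendDialData: true}, // forwarded to the server
}
res, err := an.GetReachability(ctx, reqs)
if err == nil && res.Addr != nil {
	// res.Idx refers to the caller's reqs slice, so it is 1 here even though
	// the server only ever saw the single public address.
	_ = reqs[res.Idx]
}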
|
||||
@@ -187,7 +234,7 @@ func (an *AutoNAT) updatePeer(p peer.ID) {
|
||||
// and swarm for the current state
|
||||
protos, err := an.host.Peerstore().SupportsProtocols(p, DialProtocol)
|
||||
connectedness := an.host.Network().Connectedness(p)
|
||||
if err == nil && slices.Contains(protos, DialProtocol) && connectedness == network.Connected {
|
||||
if err == nil && connectedness == network.Connected && slices.Contains(protos, DialProtocol) {
|
||||
an.peers.Put(p)
|
||||
} else {
|
||||
an.peers.Delete(p)
|
||||
@@ -208,28 +255,40 @@ func newPeersMap() *peersMap {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *peersMap) GetRand() peer.ID {
|
||||
if len(p.peers) == 0 {
|
||||
return ""
|
||||
// Shuffled iterates over the peers starting at a random offset
|
||||
func (p *peersMap) Shuffled() iter.Seq[peer.ID] {
|
||||
n := len(p.peers)
|
||||
start := 0
|
||||
if n > 0 {
|
||||
start = rand.IntN(n)
|
||||
}
|
||||
return p.peers[rand.IntN(len(p.peers))]
|
||||
}
|
||||
|
||||
func (p *peersMap) Put(pid peer.ID) {
|
||||
if _, ok := p.peerIdx[pid]; ok {
|
||||
return func(yield func(peer.ID) bool) {
|
||||
for i := range n {
|
||||
if !yield(p.peers[(i+start)%n]) {
|
||||
return
|
||||
}
|
||||
p.peers = append(p.peers, pid)
|
||||
p.peerIdx[pid] = len(p.peers) - 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *peersMap) Delete(pid peer.ID) {
|
||||
idx, ok := p.peerIdx[pid]
|
||||
func (p *peersMap) Put(id peer.ID) {
|
||||
if _, ok := p.peerIdx[id]; ok {
|
||||
return
|
||||
}
|
||||
p.peers = append(p.peers, id)
|
||||
p.peerIdx[id] = len(p.peers) - 1
|
||||
}
|
||||
|
||||
func (p *peersMap) Delete(id peer.ID) {
|
||||
idx, ok := p.peerIdx[id]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
p.peers[idx] = p.peers[len(p.peers)-1]
|
||||
p.peerIdx[p.peers[idx]] = idx
|
||||
p.peers = p.peers[:len(p.peers)-1]
|
||||
delete(p.peerIdx, pid)
|
||||
n := len(p.peers)
|
||||
lastPeer := p.peers[n-1]
|
||||
p.peers[idx] = lastPeer
|
||||
p.peerIdx[lastPeer] = idx
|
||||
p.peers[n-1] = ""
|
||||
p.peers = p.peers[:n-1]
|
||||
delete(p.peerIdx, id)
|
||||
}
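A small sketch of consuming the Shuffled iterator (the peer IDs are placeholders); this mirrors the server-selection loop in GetReachability.

pm := newPeersMap()
pm.Put("peerA")
pm.Put("peerB")
pm.Put("peerC")
// Ranging over the iter.Seq starts at a random offset, so repeated calls
// tend to pick a different peer first.
for pid := range pm.Shuffled() {
	fmt.Println(pid)
	break // take the first eligible peer, as GetReachability does
}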
|
||||
|
@@ -2,8 +2,13 @@ package autonatv2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -36,11 +41,12 @@ func newAutoNAT(t testing.TB, dialer host.Host, opts ...AutoNATOption) *AutoNAT
|
||||
swarm.WithUDPBlackHoleSuccessCounter(nil),
|
||||
swarm.WithIPv6BlackHoleSuccessCounter(nil))))
|
||||
}
|
||||
an, err := New(h, dialer, opts...)
|
||||
opts = append([]AutoNATOption{withThrottlePeerDuration(0)}, opts...)
|
||||
an, err := New(dialer, opts...)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
an.Start()
|
||||
require.NoError(t, an.Start(h))
|
||||
t.Cleanup(an.Close)
|
||||
return an
|
||||
}
|
||||
@@ -74,7 +80,7 @@ func waitForPeer(t testing.TB, a *AutoNAT) {
|
||||
require.Eventually(t, func() bool {
|
||||
a.mx.Lock()
|
||||
defer a.mx.Unlock()
|
||||
return a.peers.GetRand() != ""
|
||||
return len(a.peers.peers) != 0
|
||||
}, 5*time.Second, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -88,7 +94,7 @@ func TestAutoNATPrivateAddr(t *testing.T) {
|
||||
an := newAutoNAT(t, nil)
|
||||
res, err := an.GetReachability(context.Background(), []Request{{Addr: ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")}})
|
||||
require.Equal(t, res, Result{})
|
||||
require.Contains(t, err.Error(), "private address cannot be verified by autonatv2")
|
||||
require.ErrorIs(t, err, ErrPrivateAddrs)
|
||||
}
|
||||
|
||||
func TestClientRequest(t *testing.T) {
|
||||
@@ -154,19 +160,6 @@ func TestClientServerError(t *testing.T) {
|
||||
},
|
||||
errorStr: "invalid msg type",
|
||||
},
|
||||
{
|
||||
handler: func(s network.Stream) {
|
||||
w := pbio.NewDelimitedWriter(s)
|
||||
assert.NoError(t, w.WriteMsg(
|
||||
&pb.Message{Msg: &pb.Message_DialResponse{
|
||||
DialResponse: &pb.DialResponse{
|
||||
Status: pb.DialResponse_E_DIAL_REFUSED,
|
||||
},
|
||||
}},
|
||||
))
|
||||
},
|
||||
errorStr: ErrDialRefused.Error(),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range tests {
|
||||
@@ -298,6 +291,49 @@ func TestClientDataRequest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoNATPrivateAndPublicAddrs(t *testing.T) {
|
||||
an := newAutoNAT(t, nil)
|
||||
defer an.Close()
|
||||
defer an.host.Close()
|
||||
|
||||
b := bhost.NewBlankHost(swarmt.GenSwarm(t))
|
||||
defer b.Close()
|
||||
idAndConnect(t, an.host, b)
|
||||
waitForPeer(t, an)
|
||||
|
||||
dialerHost := bhost.NewBlankHost(swarmt.GenSwarm(t))
|
||||
defer dialerHost.Close()
|
||||
handler := func(s network.Stream) {
|
||||
w := pbio.NewDelimitedWriter(s)
|
||||
r := pbio.NewDelimitedReader(s, maxMsgSize)
|
||||
var msg pb.Message
|
||||
assert.NoError(t, r.ReadMsg(&msg))
|
||||
w.WriteMsg(&pb.Message{
|
||||
Msg: &pb.Message_DialResponse{
|
||||
DialResponse: &pb.DialResponse{
|
||||
Status: pb.DialResponse_OK,
|
||||
DialStatus: pb.DialStatus_E_DIAL_ERROR,
|
||||
AddrIdx: 0,
|
||||
},
|
||||
},
|
||||
})
|
||||
s.Close()
|
||||
}
|
||||
|
||||
b.SetStreamHandler(DialProtocol, handler)
|
||||
privateAddr := ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")
|
||||
publicAddr := ma.StringCast("/ip4/1.2.3.4/udp/10/quic-v1")
|
||||
res, err := an.GetReachability(context.Background(),
|
||||
[]Request{
|
||||
{Addr: privateAddr},
|
||||
{Addr: publicAddr},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, res.Addr, publicAddr, "%s\n%s", res.Addr, publicAddr)
|
||||
require.Equal(t, res.Idx, 1)
|
||||
require.Equal(t, res.Reachability, network.ReachabilityPrivate)
|
||||
}
|
||||
|
||||
func TestClientDialBacks(t *testing.T) {
|
||||
an := newAutoNAT(t, nil, allowPrivateAddrs)
|
||||
defer an.Close()
|
||||
@@ -507,7 +543,6 @@ func TestClientDialBacks(t *testing.T) {
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, res.Reachability, network.ReachabilityPublic)
|
||||
require.Equal(t, res.Status, pb.DialStatus_OK)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -551,46 +586,6 @@ func TestEventSubscription(t *testing.T) {
|
||||
}, 5*time.Second, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestPeersMap(t *testing.T) {
|
||||
emptyPeerID := peer.ID("")
|
||||
|
||||
t.Run("single_item", func(t *testing.T) {
|
||||
p := newPeersMap()
|
||||
p.Put("peer1")
|
||||
p.Delete("peer1")
|
||||
p.Put("peer1")
|
||||
require.Equal(t, peer.ID("peer1"), p.GetRand())
|
||||
p.Delete("peer1")
|
||||
require.Equal(t, emptyPeerID, p.GetRand())
|
||||
})
|
||||
|
||||
t.Run("multiple_items", func(t *testing.T) {
|
||||
p := newPeersMap()
|
||||
require.Equal(t, emptyPeerID, p.GetRand())
|
||||
|
||||
allPeers := make(map[peer.ID]bool)
|
||||
for i := 0; i < 20; i++ {
|
||||
pid := peer.ID(fmt.Sprintf("peer-%d", i))
|
||||
allPeers[pid] = true
|
||||
p.Put(pid)
|
||||
}
|
||||
foundPeers := make(map[peer.ID]bool)
|
||||
for i := 0; i < 1000; i++ {
|
||||
pid := p.GetRand()
|
||||
require.NotEqual(t, emptyPeerID, p)
|
||||
require.True(t, allPeers[pid])
|
||||
foundPeers[pid] = true
|
||||
if len(foundPeers) == len(allPeers) {
|
||||
break
|
||||
}
|
||||
}
|
||||
for pid := range allPeers {
|
||||
p.Delete(pid)
|
||||
}
|
||||
require.Equal(t, emptyPeerID, p.GetRand())
|
||||
})
|
||||
}
|
||||
|
||||
func TestAreAddrsConsistency(t *testing.T) {
|
||||
c := &client{
|
||||
normalizeMultiaddr: func(a ma.Multiaddr) ma.Multiaddr {
|
||||
@@ -645,6 +640,12 @@ func TestAreAddrsConsistency(t *testing.T) {
|
||||
dialAddr: ma.StringCast("/ip6/1::1/udp/123/quic-v1/"),
|
||||
success: false,
|
||||
},
|
||||
{
|
||||
name: "dns6",
|
||||
localAddr: ma.StringCast("/dns6/lib.p2p/udp/12345/quic-v1"),
|
||||
dialAddr: ma.StringCast("/ip4/1.2.3.4/udp/123/quic-v1/"),
|
||||
success: false,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -658,3 +659,173 @@ func TestAreAddrsConsistency(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerMap(t *testing.T) {
|
||||
pm := newPeersMap()
|
||||
// Add 1, 2, 3
|
||||
pm.Put(peer.ID("1"))
|
||||
pm.Put(peer.ID("2"))
|
||||
pm.Put(peer.ID("3"))
|
||||
|
||||
// Remove 3, 2
|
||||
pm.Delete(peer.ID("3"))
|
||||
pm.Delete(peer.ID("2"))
|
||||
|
||||
// Add 4
|
||||
pm.Put(peer.ID("4"))
|
||||
|
||||
// Remove 3, 2 again. Should be no op
|
||||
pm.Delete(peer.ID("3"))
|
||||
pm.Delete(peer.ID("2"))
|
||||
|
||||
contains := []peer.ID{"1", "4"}
|
||||
elems := make([]peer.ID, 0)
|
||||
for p := range pm.Shuffled() {
|
||||
elems = append(elems, p)
|
||||
}
|
||||
require.ElementsMatch(t, contains, elems)
|
||||
}
|
||||
|
||||
func FuzzClient(f *testing.F) {
|
||||
a := newAutoNAT(f, nil, allowPrivateAddrs, WithServerRateLimit(math.MaxInt32, math.MaxInt32, math.MaxInt32, 2))
|
||||
c := newAutoNAT(f, nil)
|
||||
idAndWait(f, c, a)
|
||||
|
||||
// TODO: Move this to go-multiaddrs
|
||||
getProto := func(protos []byte) ma.Multiaddr {
|
||||
protoType := 0
|
||||
if len(protos) > 0 {
|
||||
protoType = int(protos[0])
|
||||
}
|
||||
|
||||
port1, port2 := 0, 0
|
||||
if len(protos) > 1 {
|
||||
port1 = int(protos[1])
|
||||
}
|
||||
if len(protos) > 2 {
|
||||
port2 = int(protos[2])
|
||||
}
|
||||
protoTemplates := []string{
|
||||
"/tcp/%d/",
|
||||
"/udp/%d/",
|
||||
"/udp/%d/quic-v1/",
|
||||
"/udp/%d/quic-v1/tcp/%d",
|
||||
"/udp/%d/quic-v1/webtransport/",
|
||||
"/udp/%d/webrtc/",
|
||||
"/udp/%d/webrtc-direct/",
|
||||
"/unix/hello/",
|
||||
}
|
||||
s := protoTemplates[protoType%len(protoTemplates)]
|
||||
port1 %= (1 << 16)
|
||||
if strings.Count(s, "%d") == 1 {
|
||||
return ma.StringCast(fmt.Sprintf(s, port1))
|
||||
}
|
||||
port2 %= (1 << 16)
|
||||
return ma.StringCast(fmt.Sprintf(s, port1, port2))
|
||||
}
|
||||
|
||||
getIP := func(ips []byte) ma.Multiaddr {
|
||||
ipType := 0
|
||||
if len(ips) > 0 {
|
||||
ipType = int(ips[0])
|
||||
}
|
||||
ips = ips[1:]
|
||||
var x, y int64
|
||||
split := 128 / 8
|
||||
if len(ips) < split {
|
||||
split = len(ips)
|
||||
}
|
||||
var b [8]byte
|
||||
copy(b[:], ips[:split])
|
||||
x = int64(binary.LittleEndian.Uint64(b[:]))
|
||||
clear(b[:])
|
||||
copy(b[:], ips[split:])
|
||||
y = int64(binary.LittleEndian.Uint64(b[:]))
|
||||
|
||||
var ip netip.Addr
|
||||
switch ipType % 3 {
|
||||
case 0:
|
||||
ip = netip.AddrFrom4([4]byte{byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24)})
|
||||
return ma.StringCast(fmt.Sprintf("/ip4/%s/", ip))
|
||||
case 1:
|
||||
pubIP := net.ParseIP("2005::") // Public IP address
|
||||
x := int64(binary.LittleEndian.Uint64(pubIP[0:8]))
|
||||
ip = netip.AddrFrom16([16]byte{
|
||||
byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24),
|
||||
byte(x >> 32), byte(x >> 40), byte(x >> 48), byte(x >> 56),
|
||||
byte(y), byte(y >> 8), byte(y >> 16), byte(y >> 24),
|
||||
byte(y >> 32), byte(y >> 40), byte(y >> 48), byte(y >> 56),
|
||||
})
|
||||
return ma.StringCast(fmt.Sprintf("/ip6/%s/", ip))
|
||||
default:
|
||||
ip := netip.AddrFrom16([16]byte{
|
||||
byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24),
|
||||
byte(x >> 32), byte(x >> 40), byte(x >> 48), byte(x >> 56),
|
||||
byte(y), byte(y >> 8), byte(y >> 16), byte(y >> 24),
|
||||
byte(y >> 32), byte(y >> 40), byte(y >> 48), byte(y >> 56),
|
||||
})
|
||||
return ma.StringCast(fmt.Sprintf("/ip6/%s/", ip))
|
||||
}
|
||||
}
|
||||
|
||||
getAddr := func(addrType int, ips, protos []byte) ma.Multiaddr {
|
||||
switch addrType % 4 {
|
||||
case 0:
|
||||
return getIP(ips).Encapsulate(getProto(protos))
|
||||
case 1:
|
||||
return getProto(protos)
|
||||
case 2:
|
||||
return nil
|
||||
default:
|
||||
return getIP(ips).Encapsulate(getProto(protos))
|
||||
}
|
||||
}
|
||||
|
||||
getDNSAddr := func(hostNameBytes, protos []byte) ma.Multiaddr {
|
||||
hostName := strings.ReplaceAll(string(hostNameBytes), "\\", "")
|
||||
hostName = strings.ReplaceAll(hostName, "/", "")
|
||||
if hostName == "" {
|
||||
hostName = "localhost"
|
||||
}
|
||||
dnsType := 0
|
||||
if len(hostNameBytes) > 0 {
|
||||
dnsType = int(hostNameBytes[0])
|
||||
}
|
||||
dnsProtos := []string{"dns", "dns4", "dns6", "dnsaddr"}
|
||||
da := ma.StringCast(fmt.Sprintf("/%s/%s/", dnsProtos[dnsType%len(dnsProtos)], hostName))
|
||||
return da.Encapsulate(getProto(protos))
|
||||
}
|
||||
|
||||
const maxAddrs = 100
|
||||
getAddrs := func(numAddrs int, ips, protos, hostNames []byte) []ma.Multiaddr {
|
||||
if len(ips) == 0 || len(protos) == 0 || len(hostNames) == 0 {
|
||||
return nil
|
||||
}
|
||||
numAddrs = ((numAddrs % maxAddrs) + maxAddrs) % maxAddrs
|
||||
addrs := make([]ma.Multiaddr, numAddrs)
|
||||
ipIdx := 0
|
||||
protoIdx := 0
|
||||
for i := range numAddrs {
|
||||
addrs[i] = getAddr(i, ips[ipIdx:], protos[protoIdx:])
|
||||
ipIdx = (ipIdx + 1) % len(ips)
|
||||
protoIdx = (protoIdx + 1) % len(protos)
|
||||
}
|
||||
maxDNSAddrs := 10
|
||||
protoIdx = 0
|
||||
for i := 0; i < len(hostNames) && i < maxDNSAddrs; i += 2 {
|
||||
ed := min(i+2, len(hostNames))
|
||||
addrs = append(addrs, getDNSAddr(hostNames[i:ed], protos[protoIdx:]))
|
||||
protoIdx = (protoIdx + 1) % len(protos)
|
||||
}
|
||||
return addrs
|
||||
}
|
||||
// reduce the streamTimeout before running this. TODO: fix this
|
||||
f.Fuzz(func(_ *testing.T, numAddrs int, ips, protos, hostNames []byte) {
|
||||
addrs := getAddrs(numAddrs, ips, protos, hostNames)
|
||||
reqs := make([]Request, len(addrs))
|
||||
for i, addr := range addrs {
|
||||
reqs[i] = Request{Addr: addr, SendDialData: true}
|
||||
}
|
||||
c.GetReachability(context.Background(), reqs)
|
||||
})
|
||||
}
|
||||
|
@@ -35,20 +35,20 @@ type normalizeMultiaddrer interface {
	NormalizeMultiaddr(ma.Multiaddr) ma.Multiaddr
}

func newClient(h host.Host) *client {
	normalizeMultiaddr := func(a ma.Multiaddr) ma.Multiaddr { return a }
	if hn, ok := h.(normalizeMultiaddrer); ok {
		normalizeMultiaddr = hn.NormalizeMultiaddr
	}
func newClient() *client {
	return &client{
		host: h,
		dialData: make([]byte, 4000),
		normalizeMultiaddr: normalizeMultiaddr,
		dialBackQueues: make(map[uint64]chan ma.Multiaddr),
	}
}

func (ac *client) Start() {
func (ac *client) Start(h host.Host) {
	normalizeMultiaddr := func(a ma.Multiaddr) ma.Multiaddr { return a }
	if hn, ok := h.(normalizeMultiaddrer); ok {
		normalizeMultiaddr = hn.NormalizeMultiaddr
	}
	ac.host = h
	ac.normalizeMultiaddr = normalizeMultiaddr
	ac.host.SetStreamHandler(DialBackProtocol, ac.handleDialBack)
}

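The hunk above splits construction from host wiring: the client is now created without a host and only binds it (and registers the dial-back stream handler) when Start is called, presumably so the AutoNAT components can be built before the host they attach to exists. A generic sketch of that deferred-wiring pattern, with hypothetical names not taken from the package:

package example

// handlerFunc stands in for a stream handler; the real client registers
// handleDialBack for DialBackProtocol when Start is called.
type handlerFunc func(data []byte)

// node is a stand-in for host.Host: the only capability this sketch needs
// is registering a handler for a protocol ID.
type node interface {
	SetHandler(proto string, h handlerFunc)
}

// service is constructed without a node, like the refactored client,
// and picks one up later in Start.
type service struct {
	n node
}

func newService() *service { return &service{} }

// Start attaches the service to a node, mirroring client.Start(h host.Host).
func (s *service) Start(n node) {
	s.n = n
	s.n.SetHandler("/example/1.0.0", func(data []byte) {
		// handle incoming data
	})
}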
@@ -109,9 +109,9 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request
		break
	// provide dial data if appropriate
	case msg.GetDialDataRequest() != nil:
		if err := ac.validateDialDataRequest(reqs, &msg); err != nil {
		if err := validateDialDataRequest(reqs, &msg); err != nil {
			s.Reset()
			return Result{}, fmt.Errorf("invalid dial data request: %w", err)
			return Result{}, fmt.Errorf("invalid dial data request: %s %w", s.Conn().RemoteMultiaddr(), err)
		}
		// dial data request is valid and we want to send data
		if err := sendDialData(ac.dialData, int(msg.GetDialDataRequest().GetNumBytes()), w, &msg); err != nil {
@@ -136,7 +136,7 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request
	// E_DIAL_REFUSED has implications for deciding future address verification priorities
	// wrap a distinct error for convenient errors.Is usage
	if resp.GetStatus() == pb.DialResponse_E_DIAL_REFUSED {
		return Result{}, fmt.Errorf("dial request failed: %w", ErrDialRefused)
		return Result{AllAddrsRefused: true}, nil
	}
	return Result{}, fmt.Errorf("dial request failed: response status %d %s", resp.GetStatus(),
		pb.DialResponse_ResponseStatus_name[int32(resp.GetStatus())])
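The removed branch wrapped the ErrDialRefused sentinel with %w so callers could detect refusal via errors.Is; the replacement reports refusal through Result.AllAddrsRefused instead. For reference, a minimal illustration of the %w / errors.Is pattern (ErrRefused is a local stand-in, not the package's ErrDialRefused):

package main

import (
	"errors"
	"fmt"
)

var ErrRefused = errors.New("refused")

func request() error {
	// Wrap the sentinel with %w so the error chain is preserved.
	return fmt.Errorf("dial request failed: %w", ErrRefused)
}

func main() {
	err := request()
	fmt.Println(errors.Is(err, ErrRefused)) // true, even though the message differs
}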
@@ -147,7 +147,6 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request
	if int(resp.AddrIdx) >= len(reqs) {
		return Result{}, fmt.Errorf("invalid response: addr index out of range: %d [0-%d)", resp.AddrIdx, len(reqs))
	}

	// wait for nonce from the server
	var dialBackAddr ma.Multiaddr
	if resp.GetDialStatus() == pb.DialStatus_OK {
@@ -163,7 +162,7 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request
	return ac.newResult(resp, reqs, dialBackAddr)
}

func (ac *client) validateDialDataRequest(reqs []Request, msg *pb.Message) error {
func validateDialDataRequest(reqs []Request, msg *pb.Message) error {
	idx := int(msg.GetDialDataRequest().AddrIdx)
	if idx >= len(reqs) { // invalid address index
		return fmt.Errorf("addr index out of range: %d [0-%d)", idx, len(reqs))
@@ -179,9 +178,13 @@ func (ac *client) validateDialDataRequest(reqs []Request, msg *pb.Message) error

func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr ma.Multiaddr) (Result, error) {
	idx := int(resp.AddrIdx)
	if idx >= len(reqs) {
		// This should have been validated by this point, but checking this is cheap.
		return Result{}, fmt.Errorf("addrs index(%d) greater than len(reqs)(%d)", idx, len(reqs))
	}
	addr := reqs[idx].Addr

	var rch network.Reachability
	rch := network.ReachabilityUnknown //nolint:ineffassign
	switch resp.DialStatus {
	case pb.DialStatus_OK:
		if !ac.areAddrsConsistent(dialBackAddr, addr) {
@@ -191,17 +194,16 @@ func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr
			return Result{}, fmt.Errorf("invalid response: dialBackAddr: %s, respAddr: %s", dialBackAddr, addr)
		}
		rch = network.ReachabilityPublic
	case pb.DialStatus_E_DIAL_ERROR:
		rch = network.ReachabilityPrivate
	case pb.DialStatus_E_DIAL_BACK_ERROR:
		if ac.areAddrsConsistent(dialBackAddr, addr) {
		if !ac.areAddrsConsistent(dialBackAddr, addr) {
			return Result{}, fmt.Errorf("dial-back stream error: dialBackAddr: %s, respAddr: %s", dialBackAddr, addr)
		}
		// We received the dial back but the server claims the dial back errored.
		// As long as we received the correct nonce in dial back it is safe to assume
		// that we are public.
		rch = network.ReachabilityPublic
		} else {
			rch = network.ReachabilityUnknown
		}
	case pb.DialStatus_E_DIAL_ERROR:
		rch = network.ReachabilityPrivate
	default:
		// Unexpected response code. Discard the response and fail.
		log.Warnf("invalid status code received in response for addr %s: %d", addr, resp.DialStatus)
@@ -210,8 +212,8 @@ func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr

	return Result{
		Addr: addr,
		Idx: idx,
		Reachability: rch,
		Status: resp.DialStatus,
	}, nil
}

@@ -307,7 +309,7 @@ func (ac *client) handleDialBack(s network.Stream) {
}

func (ac *client) areAddrsConsistent(connLocalAddr, dialedAddr ma.Multiaddr) bool {
	if connLocalAddr == nil || dialedAddr == nil {
	if len(connLocalAddr) == 0 || len(dialedAddr) == 0 {
		return false
	}
	connLocalAddr = ac.normalizeMultiaddr(connLocalAddr)
@@ -318,33 +320,32 @@ func (ac *client) areAddrsConsistent(connLocalAddr, dialedAddr ma.Multiaddr) boo
	if len(localProtos) != len(externalProtos) {
		return false
	}
	for i := 0; i < len(localProtos); i++ {
	for i, lp := range localProtos {
		ep := externalProtos[i]
		if i == 0 {
			switch externalProtos[i].Code {
			switch ep.Code {
			case ma.P_DNS, ma.P_DNSADDR:
				if localProtos[i].Code == ma.P_IP4 || localProtos[i].Code == ma.P_IP6 {
				if lp.Code == ma.P_IP4 || lp.Code == ma.P_IP6 {
					continue
				}
				return false
			case ma.P_DNS4:
				if localProtos[i].Code == ma.P_IP4 {
				if lp.Code == ma.P_IP4 {
					continue
				}
				return false
			case ma.P_DNS6:
				if localProtos[i].Code == ma.P_IP6 {
				if lp.Code == ma.P_IP6 {
					continue
				}
				return false
			}
			if localProtos[i].Code != externalProtos[i].Code {
			if lp.Code != ep.Code {
				return false
			}
		} else {
			if localProtos[i].Code != externalProtos[i].Code {
		} else if lp.Code != ep.Code {
			return false
			}
		}
	}
	return true
}

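areAddrsConsistent compares the two addresses component by component, treating a DNS name in the dialed address as compatible with the resolved IP seen on the connection. A standalone sketch of that first-component check using go-multiaddr protocol codes (the addresses here are made up; this is not the package's function):

package main

import (
	"fmt"

	ma "github.com/multiformats/go-multiaddr"
)

// compatibleFirstComponent reports whether the dialed address's first
// component could have produced the connection's first component:
// dns4 may resolve to ip4, dns6 to ip6, and dns/dnsaddr to either.
func compatibleFirstComponent(conn, dialed ma.Multiaddr) bool {
	lp := conn.Protocols()[0]
	ep := dialed.Protocols()[0]
	switch ep.Code {
	case ma.P_DNS, ma.P_DNSADDR:
		return lp.Code == ma.P_IP4 || lp.Code == ma.P_IP6
	case ma.P_DNS4:
		return lp.Code == ma.P_IP4
	case ma.P_DNS6:
		return lp.Code == ma.P_IP6
	default:
		return lp.Code == ep.Code
	}
}

func main() {
	conn := ma.StringCast("/ip4/1.2.3.4/tcp/443")
	dialed := ma.StringCast("/dns4/example.com/tcp/443")
	fmt.Println(compatibleFirstComponent(conn, dialed)) // true
}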
@@ -13,6 +13,7 @@ type autoNATSettings struct {
	now func() time.Time
	amplificatonAttackPreventionDialWait time.Duration
	metricsTracer MetricsTracer
	throttlePeerDuration time.Duration
}

func defaultSettings() *autoNATSettings {
@@ -25,6 +26,7 @@ func defaultSettings() *autoNATSettings {
		dataRequestPolicy: amplificationAttackPrevention,
		amplificatonAttackPreventionDialWait: 3 * time.Second,
		now: time.Now,
		throttlePeerDuration: defaultThrottlePeerDuration,
	}
}

@@ -65,3 +67,10 @@ func withAmplificationAttackPreventionDialWait(d time.Duration) AutoNATOption {
		return nil
	}
}

func withThrottlePeerDuration(d time.Duration) AutoNATOption {
	return func(s *autoNATSettings) error {
		s.throttlePeerDuration = d
		return nil
	}
}
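withThrottlePeerDuration follows the package's functional-option shape: an AutoNATOption is a func(*autoNATSettings) error applied over defaults. A generic, self-contained sketch of that pattern (the names here are illustrative, not the package's):

package example

import (
	"errors"
	"time"
)

type settings struct {
	throttle time.Duration
}

// option mutates settings and may reject invalid values.
type option func(*settings) error

func withThrottle(d time.Duration) option {
	return func(s *settings) error {
		if d <= 0 {
			return errors.New("throttle must be positive")
		}
		s.throttle = d
		return nil
	}
}

// newSettings applies options over defaults, the same shape as
// defaultSettings plus AutoNATOption application in the diff above.
func newSettings(opts ...option) (*settings, error) {
	s := &settings{throttle: 30 * time.Second}
	for _, o := range opts {
		if err := o(s); err != nil {
			return nil, err
		}
	}
	return s, nil
}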
@@ -59,10 +59,9 @@ type server struct {
	allowPrivateAddrs bool
}

func newServer(host, dialer host.Host, s *autoNATSettings) *server {
func newServer(dialer host.Host, s *autoNATSettings) *server {
	return &server{
		dialerHost: dialer,
		host: host,
		dialDataRequestPolicy: s.dataRequestPolicy,
		amplificatonAttackPreventionDialWait: s.amplificatonAttackPreventionDialWait,
		allowPrivateAddrs: s.allowPrivateAddrs,
@@ -79,7 +78,8 @@ func newServer(host, dialer host.Host, s *autoNATSettings) *server {
}

// Enable attaches the stream handler to the host.
func (as *server) Start() {
func (as *server) Start(h host.Host) {
	as.host = h
	as.host.SetStreamHandler(DialProtocol, as.handleDialRequest)
}

@@ -6,6 +6,7 @@ import (
	"fmt"
	"io"
	"math"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
@@ -46,8 +47,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) {
		idAndWait(t, c, an)

		res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), true))
		require.ErrorIs(t, err, ErrDialRefused)
		require.Equal(t, Result{}, res)
		require.NoError(t, err)
		require.Equal(t, Result{AllAddrsRefused: true}, res)
	})

	t.Run("black holed addr", func(t *testing.T) {
@@ -64,8 +65,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) {
			Addr: ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1"),
			SendDialData: true,
		}})
		require.ErrorIs(t, err, ErrDialRefused)
		require.Equal(t, Result{}, res)
		require.NoError(t, err)
		require.Equal(t, Result{AllAddrsRefused: true}, res)
	})

	t.Run("private addrs", func(t *testing.T) {
@@ -76,8 +77,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) {
		idAndWait(t, c, an)

		res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), true))
		require.ErrorIs(t, err, ErrDialRefused)
		require.Equal(t, Result{}, res)
		require.NoError(t, err)
		require.Equal(t, Result{AllAddrsRefused: true}, res)
	})

	t.Run("relay addrs", func(t *testing.T) {
@@ -89,8 +90,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) {

		res, err := c.GetReachability(context.Background(), newTestRequests(
			[]ma.Multiaddr{ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/1/p2p/%s/p2p-circuit/p2p/%s", c.host.ID(), c.srv.dialerHost.ID()))}, true))
		require.ErrorIs(t, err, ErrDialRefused)
		require.Equal(t, Result{}, res)
		require.NoError(t, err)
		require.Equal(t, Result{AllAddrsRefused: true}, res)
	})

	t.Run("no addr", func(t *testing.T) {
@@ -113,8 +114,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) {
		idAndWait(t, c, an)

		res, err := c.GetReachability(context.Background(), newTestRequests(addrs, true))
		require.ErrorIs(t, err, ErrDialRefused)
		require.Equal(t, Result{}, res)
		require.NoError(t, err)
		require.Equal(t, Result{AllAddrsRefused: true}, res)
	})

	t.Run("msg too large", func(t *testing.T) {
@@ -135,7 +136,6 @@ func TestServerInvalidAddrsRejected(t *testing.T) {
		require.ErrorIs(t, err, network.ErrReset)
		require.Equal(t, Result{}, res)
	})

}

func TestServerDataRequest(t *testing.T) {
@@ -178,8 +178,8 @@ func TestServerDataRequest(t *testing.T) {

	require.Equal(t, Result{
		Addr: quicAddr,
		Idx: 0,
		Reachability: network.ReachabilityPublic,
		Status: pb.DialStatus_OK,
	}, res)

	// Small messages should be rejected for dial data
@@ -191,14 +191,11 @@ func TestServerDataRequest(t *testing.T) {
func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) {
	const concurrentRequests = 5

	// server will skip all tcp addresses
	dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP))

	doneChan := make(chan struct{})
	an := newAutoNAT(t, dialer, allowPrivateAddrs, withDataRequestPolicy(
	stallChan := make(chan struct{})
	an := newAutoNAT(t, nil, allowPrivateAddrs, withDataRequestPolicy(
		// stall all allowed requests
		func(_, dialAddr ma.Multiaddr) bool {
			<-doneChan
			<-stallChan
			return true
		}),
		WithServerRateLimit(10, 10, 10, concurrentRequests),
@@ -207,16 +204,18 @@ func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) {
	defer an.Close()
	defer an.host.Close()

	c := newAutoNAT(t, nil, allowPrivateAddrs)
	// server will skip all tcp addresses
	dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP))
	c := newAutoNAT(t, dialer, allowPrivateAddrs)
	defer c.Close()
	defer c.host.Close()

	idAndWait(t, c, an)

	errChan := make(chan error)
	const N = 10
	// num concurrentRequests will stall and N will fail
	for i := 0; i < concurrentRequests+N; i++ {
	const n = 10
	// num concurrentRequests will stall and n will fail
	for i := 0; i < concurrentRequests+n; i++ {
		go func() {
			_, err := c.GetReachability(context.Background(), []Request{{Addr: c.host.Addrs()[0], SendDialData: false}})
			errChan <- err
@@ -224,17 +223,20 @@ func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) {
	}

	// check N failures
	for i := 0; i < N; i++ {
	for i := 0; i < n; i++ {
		select {
		case err := <-errChan:
			require.Error(t, err)
			if !strings.Contains(err.Error(), "stream reset") && !strings.Contains(err.Error(), "E_REQUEST_REJECTED") {
				t.Fatalf("invalid error: %s expected: stream reset or E_REQUEST_REJECTED", err)
			}
		case <-time.After(10 * time.Second):
			t.Fatalf("expected %d errors: got: %d", N, i)
			t.Fatalf("expected %d errors: got: %d", n, i)
		}
	}

	close(stallChan) // complete stalled requests
	// check concurrentRequests failures, as we won't send dial data
	close(doneChan)
	for i := 0; i < concurrentRequests; i++ {
		select {
		case err := <-errChan:
@@ -290,8 +292,8 @@ func TestServerDataRequestJitter(t *testing.T) {

		require.Equal(t, Result{
			Addr: quicAddr,
			Idx: 0,
			Reachability: network.ReachabilityPublic,
			Status: pb.DialStatus_OK,
		}, res)
		if took > 500*time.Millisecond {
			return
@@ -320,8 +322,8 @@ func TestServerDial(t *testing.T) {
		require.NoError(t, err)
		require.Equal(t, Result{
			Addr: unreachableAddr,
			Idx: 0,
			Reachability: network.ReachabilityPrivate,
			Status: pb.DialStatus_E_DIAL_ERROR,
		}, res)
	})

@@ -330,16 +332,16 @@ func TestServerDial(t *testing.T) {
		require.NoError(t, err)
		require.Equal(t, Result{
			Addr: hostAddrs[0],
			Idx: 0,
			Reachability: network.ReachabilityPublic,
			Status: pb.DialStatus_OK,
		}, res)
		for _, addr := range c.host.Addrs() {
			res, err := c.GetReachability(context.Background(), newTestRequests([]ma.Multiaddr{addr}, false))
			require.NoError(t, err)
			require.Equal(t, Result{
				Addr: addr,
				Idx: 0,
				Reachability: network.ReachabilityPublic,
				Status: pb.DialStatus_OK,
			}, res)
		}
	})
@@ -347,12 +349,8 @@ func TestServerDial(t *testing.T) {
	t.Run("dialback error", func(t *testing.T) {
		c.host.RemoveStreamHandler(DialBackProtocol)
		res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), false))
		require.NoError(t, err)
		require.Equal(t, Result{
			Addr: hostAddrs[0],
			Reachability: network.ReachabilityUnknown,
			Status: pb.DialStatus_E_DIAL_BACK_ERROR,
		}, res)
		require.ErrorContains(t, err, "dial-back stream error")
		require.Equal(t, Result{}, res)
	})
}

@@ -396,7 +394,6 @@ func TestRateLimiter(t *testing.T) {

	cl.AdvanceBy(10 * time.Second)
	require.True(t, r.Accept("peer3"))

}

func TestRateLimiterConcurrentRequests(t *testing.T) {
@@ -558,22 +555,23 @@ func TestServerDataRequestWithAmplificationAttackPrevention(t *testing.T) {
	require.NoError(t, err)
	require.Equal(t, Result{
		Addr: quicv4Addr,
		Idx: 0,
		Reachability: network.ReachabilityPublic,
		Status: pb.DialStatus_OK,
	}, res)

	// ipv6 address should require dial data
	_, err = c.GetReachability(context.Background(), []Request{{Addr: quicv6Addr, SendDialData: false}})
	require.Error(t, err)
	require.ErrorContains(t, err, "invalid dial data request: low priority addr")
	require.ErrorContains(t, err, "invalid dial data request")
	require.ErrorContains(t, err, "low priority addr")

	// ipv6 address should work fine with dial data
	res, err = c.GetReachability(context.Background(), []Request{{Addr: quicv6Addr, SendDialData: true}})
	require.NoError(t, err)
	require.Equal(t, Result{
		Addr: quicv6Addr,
		Idx: 0,
		Reachability: network.ReachabilityPublic,
		Status: pb.DialStatus_OK,
	}, res)
}
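The last test exercises the amplification-attack-prevention policy: the server asks the client to upload dial data before dialing an address whose IP differs from the one the request arrived on, so a spoofed request cannot turn the server into a traffic amplifier. A rough standalone sketch of such a policy check (an illustration of the idea, not the package's dataRequestPolicy implementation):

package example

import "net/netip"

// requiresDialData reports whether the server should demand dial data
// before dialing dialIP on behalf of a client whose request arrived from
// observedIP. Dialing the same IP the request came from cannot amplify
// traffic toward a third party, so no dial data is needed in that case.
func requiresDialData(observedIP, dialIP netip.Addr) bool {
	if !observedIP.IsValid() || !dialIP.IsValid() {
		return true // be conservative on parse failures
	}
	return observedIP != dialIP
}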