basichost: use autonatv2 to verify reachability (#3231)

This introduces an addrsReachabilityTracker that tracks reachability for
a set of addresses. It probes the addresses' reachability periodically
and backs off exponentially when there are too many errors or when no
valid autonatv2 peer is available.

There's no smartness in the address selection logic yet; we simply
test all provided addresses. The tracker also doesn't use the addresses
provided by `AddrsFactory`, so there's currently no way to have a
user-provided address tested for reachability, which is a problem for
DNS addresses. I intend to introduce an alternative to `AddrsFactory`,
something like `AnnounceAddrs(addrs []ma.Multiaddr)`, whose addresses
are simply appended to the set of addresses we already have and then
checked for reachability.

Only one method is exposed on BasicHost right now,
`ReachableAddrs() []ma.Multiaddr`, which returns the host's reachable
addrs. Users can also subscribe to the `EvtHostReachableAddrsChanged`
event to be notified when the reachability of any address changes.
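
For example, a consumer could watch reachability updates like this (a minimal
sketch, imports elided; it assumes only the event and method described above):

	sub, err := h.EventBus().Subscribe(new(event.EvtHostReachableAddrsChanged))
	if err != nil {
		panic(err)
	}
	defer sub.Close()
	for e := range sub.Out() {
		evt := e.(event.EvtHostReachableAddrsChanged)
		fmt.Println("reachable:", evt.Reachable, "unreachable:", evt.Unreachable)
	}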
sukun
2025-06-03 17:13:56 +05:30
committed by GitHub
parent 31c8c83308
commit fb1d9512e8
15 changed files with 2420 additions and 321 deletions

View File

@@ -8,6 +8,8 @@ linters:
     - revive
     - unused
     - prealloc
+  disable:
+    - errcheck
   settings:
     revive:

View File

@@ -33,6 +33,7 @@ import (
 	routed "github.com/libp2p/go-libp2p/p2p/host/routed"
 	"github.com/libp2p/go-libp2p/p2p/net/swarm"
 	tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader"
+	"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2"
 	circuitv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client"
 	relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
 	"github.com/libp2p/go-libp2p/p2p/protocol/holepunch"
@@ -413,15 +414,7 @@ func (cfg *Config) addTransports() ([]fx.Option, error) {
 	return fxopts, nil
 }
-func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.BasicHost, error) {
-	var autonatv2Dialer host.Host
-	if cfg.EnableAutoNATv2 {
-		ah, err := cfg.makeAutoNATV2Host()
-		if err != nil {
-			return nil, err
-		}
-		autonatv2Dialer = ah
-	}
+func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus, an *autonatv2.AutoNAT) (*bhost.BasicHost, error) {
 	h, err := bhost.NewHost(swrm, &bhost.HostOpts{
 		EventBus:    eventBus,
 		ConnManager: cfg.ConnManager,
@@ -437,8 +430,7 @@ func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.B
 		EnableMetrics:                   !cfg.DisableMetrics,
 		PrometheusRegisterer:            cfg.PrometheusRegisterer,
 		DisableIdentifyAddressDiscovery: cfg.DisableIdentifyAddressDiscovery,
-		EnableAutoNATv2:                 cfg.EnableAutoNATv2,
-		AutoNATv2Dialer:                 autonatv2Dialer,
+		AutoNATv2:                       an,
 	})
 	if err != nil {
 		return nil, err
@@ -517,6 +509,24 @@ func (cfg *Config) NewNode() (host.Host, error) {
 		})
 		return sw, nil
 	}),
+	fx.Provide(func() (*autonatv2.AutoNAT, error) {
+		if !cfg.EnableAutoNATv2 {
+			return nil, nil
+		}
+		ah, err := cfg.makeAutoNATV2Host()
+		if err != nil {
+			return nil, err
+		}
+		var mt autonatv2.MetricsTracer
+		if !cfg.DisableMetrics {
+			mt = autonatv2.NewMetricsTracer(cfg.PrometheusRegisterer)
+		}
+		autoNATv2, err := autonatv2.New(ah, autonatv2.WithMetricsTracer(mt))
+		if err != nil {
+			return nil, fmt.Errorf("failed to create autonatv2: %w", err)
+		}
+		return autoNATv2, nil
+	}),
 	fx.Provide(cfg.newBasicHost),
 	fx.Provide(func(bh *bhost.BasicHost) identify.IDService {
 		return bh.IDService()
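
For reference, the provider above only constructs the AutoNAT client when the
config flag is set. A minimal sketch of opting in from the application side
(assuming the existing libp2p.EnableAutoNATv2 option is what sets
cfg.EnableAutoNATv2; not part of this diff):

	h, err := libp2p.New(libp2p.EnableAutoNATv2())
	if err != nil {
		panic(err)
	}
	defer h.Close()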

View File

@@ -2,6 +2,7 @@ package event
 import (
 	"github.com/libp2p/go-libp2p/core/network"
+	ma "github.com/multiformats/go-multiaddr"
 )
 // EvtLocalReachabilityChanged is an event struct to be emitted when the local's
@@ -11,3 +12,12 @@ import (
 type EvtLocalReachabilityChanged struct {
 	Reachability network.Reachability
 }
+// EvtHostReachableAddrsChanged is sent when the host's reachable or unreachable addresses change.
+// Reachable and Unreachable both contain only public IP or DNS addresses.
+//
+// Experimental: This API is unstable. Any changes to this event will be done without a deprecation notice.
+type EvtHostReachableAddrsChanged struct {
+	Reachable   []ma.Multiaddr
+	Unreachable []ma.Multiaddr
+}

View File

@@ -2,6 +2,7 @@ package basichost
 import (
 	"context"
+	"errors"
 	"fmt"
 	"net"
 	"slices"
@@ -13,6 +14,7 @@ import (
 	"github.com/libp2p/go-libp2p/core/network"
 	"github.com/libp2p/go-libp2p/core/transport"
 	"github.com/libp2p/go-libp2p/p2p/host/basic/internal/backoff"
+	"github.com/libp2p/go-libp2p/p2p/host/eventbus"
 	libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
 	libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
 	"github.com/libp2p/go-netroute"
@@ -27,24 +29,36 @@ type observedAddrsManager interface {
 	ObservedAddrsFor(local ma.Multiaddr) []ma.Multiaddr
 }
+type hostAddrs struct {
+	addrs            []ma.Multiaddr
+	localAddrs       []ma.Multiaddr
+	reachableAddrs   []ma.Multiaddr
+	unreachableAddrs []ma.Multiaddr
+	relayAddrs       []ma.Multiaddr
+}
 type addrsManager struct {
-	eventbus              event.Bus
+	bus                   event.Bus
 	natManager            NATManager
 	addrsFactory          AddrsFactory
 	listenAddrs           func() []ma.Multiaddr
 	transportForListening func(ma.Multiaddr) transport.Transport
 	observedAddrsManager  observedAddrsManager
 	interfaceAddrs        *interfaceAddrsCache
+	addrsReachabilityTracker *addrsReachabilityTracker
+	// addrsUpdatedChan is notified when addrs change. This is provided by the caller.
+	addrsUpdatedChan chan struct{}
 	// triggerAddrsUpdateChan is used to trigger an addresses update.
 	triggerAddrsUpdateChan chan struct{}
-	// addrsUpdatedChan is notified when addresses change.
-	addrsUpdatedChan chan struct{}
+	// triggerReachabilityUpdate is notified when reachable addrs are updated.
+	triggerReachabilityUpdate chan struct{}
 	hostReachability atomic.Pointer[network.Reachability]
-	addrsMx    sync.RWMutex // protects fields below
-	localAddrs []ma.Multiaddr
-	relayAddrs []ma.Multiaddr
+	addrsMx      sync.RWMutex
+	currentAddrs hostAddrs
 	wg  sync.WaitGroup
 	ctx context.Context
@@ -52,35 +66,49 @@ type addrsManager struct {
 	ctxCancel context.CancelFunc
 }
 func newAddrsManager(
-	eventbus event.Bus,
+	bus event.Bus,
 	natmgr NATManager,
 	addrsFactory AddrsFactory,
 	listenAddrs func() []ma.Multiaddr,
 	transportForListening func(ma.Multiaddr) transport.Transport,
 	observedAddrsManager observedAddrsManager,
 	addrsUpdatedChan chan struct{},
+	client autonatv2Client,
 ) (*addrsManager, error) {
 	ctx, cancel := context.WithCancel(context.Background())
 	as := &addrsManager{
-		eventbus:               eventbus,
+		bus:                    bus,
 		listenAddrs:            listenAddrs,
 		transportForListening:  transportForListening,
 		observedAddrsManager:   observedAddrsManager,
 		natManager:             natmgr,
 		addrsFactory:           addrsFactory,
 		triggerAddrsUpdateChan: make(chan struct{}, 1),
+		triggerReachabilityUpdate: make(chan struct{}, 1),
 		addrsUpdatedChan:       addrsUpdatedChan,
 		interfaceAddrs:         &interfaceAddrsCache{},
 		ctx:                    ctx,
 		ctxCancel:              cancel,
 	}
 	unknownReachability := network.ReachabilityUnknown
 	as.hostReachability.Store(&unknownReachability)
+	if client != nil {
+		as.addrsReachabilityTracker = newAddrsReachabilityTracker(client, as.triggerReachabilityUpdate, nil)
+	}
 	return as, nil
 }
 func (a *addrsManager) Start() error {
-	return a.background()
+	// TODO: add Start method to NATMgr
+	if a.addrsReachabilityTracker != nil {
+		err := a.addrsReachabilityTracker.Start()
+		if err != nil {
+			return fmt.Errorf("error starting addrs reachability tracker: %s", err)
+		}
+	}
+	return a.startBackgroundWorker()
 }
 func (a *addrsManager) Close() {
@@ -91,10 +119,18 @@ func (a *addrsManager) Close() {
 			log.Warnf("error closing natmgr: %s", err)
 		}
 	}
+	if a.addrsReachabilityTracker != nil {
+		err := a.addrsReachabilityTracker.Close()
+		if err != nil {
+			log.Warnf("error closing addrs reachability tracker: %s", err)
+		}
+	}
 	a.wg.Wait()
 }
 func (a *addrsManager) NetNotifee() network.Notifiee {
+	// Updating addrs in sync provides the nice property that
+	// host.Addrs() just after host.Network().Listen(x) will return x
 	return &network.NotifyBundle{
 		ListenF:      func(network.Network, ma.Multiaddr) { a.triggerAddrsUpdate() },
 		ListenCloseF: func(network.Network, ma.Multiaddr) { a.triggerAddrsUpdate() },
@@ -102,37 +138,53 @@ func (a *addrsManager) NetNotifee() network.Notifiee {
 	}
 }
 func (a *addrsManager) triggerAddrsUpdate() {
-	// This is ugly, we update here *and* in the background loop, but this ensures the nice property
-	// that host.Addrs after host.Network().Listen(...) will return the recently added listen address.
-	a.updateLocalAddrs()
+	a.updateAddrs(false, nil)
 	select {
 	case a.triggerAddrsUpdateChan <- struct{}{}:
 	default:
 	}
 }
-func (a *addrsManager) background() error {
-	autoRelayAddrsSub, err := a.eventbus.Subscribe(new(event.EvtAutoRelayAddrsUpdated))
+func (a *addrsManager) startBackgroundWorker() error {
+	autoRelayAddrsSub, err := a.bus.Subscribe(new(event.EvtAutoRelayAddrsUpdated), eventbus.Name("addrs-manager"))
 	if err != nil {
 		return fmt.Errorf("error subscribing to auto relay addrs: %s", err)
 	}
-	autonatReachabilitySub, err := a.eventbus.Subscribe(new(event.EvtLocalReachabilityChanged))
+	autonatReachabilitySub, err := a.bus.Subscribe(new(event.EvtLocalReachabilityChanged), eventbus.Name("addrs-manager"))
 	if err != nil {
-		return fmt.Errorf("error subscribing to autonat reachability: %s", err)
+		err1 := autoRelayAddrsSub.Close()
+		if err1 != nil {
+			err1 = fmt.Errorf("error closing autorelay sub: %w", err1)
+		}
+		err = fmt.Errorf("error subscribing to autonat reachability: %s", err)
+		return errors.Join(err, err1)
 	}
-	// ensure that we have the correct address after returning from Start()
-	// update local addrs
-	a.updateLocalAddrs()
+	emitter, err := a.bus.Emitter(new(event.EvtHostReachableAddrsChanged), eventbus.Stateful)
+	if err != nil {
+		err1 := autoRelayAddrsSub.Close()
+		if err1 != nil {
+			err1 = fmt.Errorf("error closing autorelay sub: %w", err1)
+		}
+		err2 := autonatReachabilitySub.Close()
+		if err2 != nil {
+			err2 = fmt.Errorf("error closing autonat reachability sub: %w", err2)
+		}
+		err = fmt.Errorf("error creating reachable addrs emitter: %s", err)
+		return errors.Join(err, err1, err2)
+	}
+	var relayAddrs []ma.Multiaddr
 	// update relay addrs in case we're private
 	select {
 	case e := <-autoRelayAddrsSub.Out():
 		if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok {
-			a.updateRelayAddrs(evt.RelayAddrs)
+			relayAddrs = slices.Clone(evt.RelayAddrs)
 		}
 	default:
 	}
 	select {
 	case e := <-autonatReachabilitySub.Out():
 		if evt, ok := e.(event.EvtLocalReachabilityChanged); ok {
@@ -140,70 +192,149 @@ func (a *addrsManager) startBackgroundWorker() error {
 			a.hostReachability.Store(&evt.Reachability)
 		}
 	default:
 	}
+	// update addresses before starting the worker loop. This ensures that any address updates
+	// before calling addrsManager.Start are correctly reported after Start returns.
+	a.updateAddrs(true, relayAddrs)
 	a.wg.Add(1)
-	go func() {
-		defer a.wg.Done()
-		defer func() {
-			err := autoRelayAddrsSub.Close()
-			if err != nil {
-				log.Warnf("error closing auto relay addrs sub: %s", err)
-			}
-		}()
-		defer func() {
-			err := autonatReachabilitySub.Close()
-			if err != nil {
-				log.Warnf("error closing autonat reachability sub: %s", err)
-			}
-		}()
-		ticker := time.NewTicker(addrChangeTickrInterval)
-		defer ticker.Stop()
-		var prev []ma.Multiaddr
-		for {
-			a.updateLocalAddrs()
-			curr := a.Addrs()
-			if a.areAddrsDifferent(prev, curr) {
-				log.Debugf("host addresses updated: %s", curr)
-				select {
-				case a.addrsUpdatedChan <- struct{}{}:
-				default:
-				}
-			}
-			prev = curr
-			select {
-			case <-ticker.C:
-			case <-a.triggerAddrsUpdateChan:
-			case e := <-autoRelayAddrsSub.Out():
-				if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok {
-					a.updateRelayAddrs(evt.RelayAddrs)
-				}
-			case e := <-autonatReachabilitySub.Out():
-				if evt, ok := e.(event.EvtLocalReachabilityChanged); ok {
-					a.hostReachability.Store(&evt.Reachability)
-				}
-			case <-a.ctx.Done():
-				return
-			}
-		}
-	}()
+	go a.background(autoRelayAddrsSub, autonatReachabilitySub, emitter, relayAddrs)
 	return nil
 }
+func (a *addrsManager) background(autoRelayAddrsSub, autonatReachabilitySub event.Subscription,
+	emitter event.Emitter, relayAddrs []ma.Multiaddr,
+) {
+	defer a.wg.Done()
+	defer func() {
+		err := autoRelayAddrsSub.Close()
+		if err != nil {
+			log.Warnf("error closing auto relay addrs sub: %s", err)
+		}
+		err = autonatReachabilitySub.Close()
+		if err != nil {
+			log.Warnf("error closing autonat reachability sub: %s", err)
+		}
+	}()
+	ticker := time.NewTicker(addrChangeTickrInterval)
+	defer ticker.Stop()
+	var previousAddrs hostAddrs
+	for {
+		currAddrs := a.updateAddrs(true, relayAddrs)
+		a.notifyAddrsChanged(emitter, previousAddrs, currAddrs)
+		previousAddrs = currAddrs
+		select {
+		case <-ticker.C:
+		case <-a.triggerAddrsUpdateChan:
+		case <-a.triggerReachabilityUpdate:
+		case e := <-autoRelayAddrsSub.Out():
+			if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok {
+				relayAddrs = slices.Clone(evt.RelayAddrs)
+			}
+		case e := <-autonatReachabilitySub.Out():
+			if evt, ok := e.(event.EvtLocalReachabilityChanged); ok {
+				a.hostReachability.Store(&evt.Reachability)
+			}
+		case <-a.ctx.Done():
+			return
+		}
+	}
+}
+// updateAddrs updates the addresses of the host and returns the new updated
+// addrs
+func (a *addrsManager) updateAddrs(updateRelayAddrs bool, relayAddrs []ma.Multiaddr) hostAddrs {
+	// Must lock while doing both recompute and update as this method is called from
+	// multiple goroutines.
+	a.addrsMx.Lock()
+	defer a.addrsMx.Unlock()
+	localAddrs := a.getLocalAddrs()
+	var currReachableAddrs, currUnreachableAddrs []ma.Multiaddr
+	if a.addrsReachabilityTracker != nil {
+		currReachableAddrs, currUnreachableAddrs = a.getConfirmedAddrs(localAddrs)
+	}
+	if !updateRelayAddrs {
+		relayAddrs = a.currentAddrs.relayAddrs
+	} else {
+		// Copy the caller's slice
+		relayAddrs = slices.Clone(relayAddrs)
+	}
+	currAddrs := a.getAddrs(slices.Clone(localAddrs), relayAddrs)
+	a.currentAddrs = hostAddrs{
+		addrs:            append(a.currentAddrs.addrs[:0], currAddrs...),
+		localAddrs:       append(a.currentAddrs.localAddrs[:0], localAddrs...),
+		reachableAddrs:   append(a.currentAddrs.reachableAddrs[:0], currReachableAddrs...),
+		unreachableAddrs: append(a.currentAddrs.unreachableAddrs[:0], currUnreachableAddrs...),
+		relayAddrs:       append(a.currentAddrs.relayAddrs[:0], relayAddrs...),
+	}
+	return hostAddrs{
+		localAddrs:       localAddrs,
+		addrs:            currAddrs,
+		reachableAddrs:   currReachableAddrs,
+		unreachableAddrs: currUnreachableAddrs,
+		relayAddrs:       relayAddrs,
+	}
+}
+func (a *addrsManager) notifyAddrsChanged(emitter event.Emitter, previous, current hostAddrs) {
+	if areAddrsDifferent(previous.localAddrs, current.localAddrs) {
+		log.Debugf("host local addresses updated: %s", current.localAddrs)
+		if a.addrsReachabilityTracker != nil {
+			a.addrsReachabilityTracker.UpdateAddrs(current.localAddrs)
+		}
+	}
+	if areAddrsDifferent(previous.addrs, current.addrs) {
+		log.Debugf("host addresses updated: %s", current.addrs)
+		select {
+		case a.addrsUpdatedChan <- struct{}{}:
+		default:
+		}
+	}
+	// We *must* send both reachability changed and addrs changed events from the
+	// same goroutine to ensure correct ordering.
+	// Consider the events:
+	// - addr x discovered
+	// - addr x is reachable
+	// - addr x removed
+	// We must send these events in the same order. It'll be confusing for consumers
+	// if the reachable event is received after the addr removed event.
+	if areAddrsDifferent(previous.reachableAddrs, current.reachableAddrs) ||
+		areAddrsDifferent(previous.unreachableAddrs, current.unreachableAddrs) {
+		log.Debugf("host reachable addrs updated: %s", current.reachableAddrs)
+		if err := emitter.Emit(event.EvtHostReachableAddrsChanged{
+			Reachable:   slices.Clone(current.reachableAddrs),
+			Unreachable: slices.Clone(current.unreachableAddrs),
+		}); err != nil {
+			log.Errorf("error sending host reachable addrs changed event: %s", err)
+		}
+	}
+}
 // Addrs returns the node's dialable addresses both public and private.
 // If autorelay is enabled and node reachability is private, it returns
 // the node's relay addresses and private network addresses.
 func (a *addrsManager) Addrs() []ma.Multiaddr {
-	addrs := a.DirectAddrs()
+	a.addrsMx.RLock()
+	directAddrs := slices.Clone(a.currentAddrs.localAddrs)
+	relayAddrs := slices.Clone(a.currentAddrs.relayAddrs)
+	a.addrsMx.RUnlock()
+	return a.getAddrs(directAddrs, relayAddrs)
+}
+// getAddrs returns the node's dialable addresses. Mutates localAddrs
+func (a *addrsManager) getAddrs(localAddrs []ma.Multiaddr, relayAddrs []ma.Multiaddr) []ma.Multiaddr {
+	addrs := localAddrs
 	rch := a.hostReachability.Load()
 	if rch != nil && *rch == network.ReachabilityPrivate {
-		a.addrsMx.RLock()
 		// Delete public addresses if the node's reachability is private, and we have relay addresses
-		if len(a.relayAddrs) > 0 {
+		if len(relayAddrs) > 0 {
 			addrs = slices.DeleteFunc(addrs, manet.IsPublicAddr)
-			addrs = append(addrs, a.relayAddrs...)
+			addrs = append(addrs, relayAddrs...)
 		}
-		a.addrsMx.RUnlock()
 	}
 	// Make a copy. Consumers can modify the slice elements
 	addrs = slices.Clone(a.addrsFactory(addrs))
@@ -213,7 +344,8 @@ func (a *addrsManager) Addrs() []ma.Multiaddr {
 	return addrs
 }
-// HolePunchAddrs returns the node's public direct listen addresses for hole punching.
+// HolePunchAddrs returns all the host's direct public addresses, reachable or unreachable,
+// suitable for hole punching.
 func (a *addrsManager) HolePunchAddrs() []ma.Multiaddr {
 	addrs := a.DirectAddrs()
 	addrs = slices.Clone(a.addrsFactory(addrs))
@@ -230,26 +362,23 @@ func (a *addrsManager) HolePunchAddrs() []ma.Multiaddr {
 func (a *addrsManager) DirectAddrs() []ma.Multiaddr {
 	a.addrsMx.RLock()
 	defer a.addrsMx.RUnlock()
-	return slices.Clone(a.localAddrs)
+	return slices.Clone(a.currentAddrs.localAddrs)
 }
-func (a *addrsManager) updateRelayAddrs(addrs []ma.Multiaddr) {
-	a.addrsMx.Lock()
-	defer a.addrsMx.Unlock()
-	a.relayAddrs = append(a.relayAddrs[:0], addrs...)
+// ReachableAddrs returns all addresses of the host that are reachable from the internet
+func (a *addrsManager) ReachableAddrs() []ma.Multiaddr {
+	a.addrsMx.RLock()
+	defer a.addrsMx.RUnlock()
+	return slices.Clone(a.currentAddrs.reachableAddrs)
+}
+func (a *addrsManager) getConfirmedAddrs(localAddrs []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) {
+	reachableAddrs, unreachableAddrs = a.addrsReachabilityTracker.ConfirmedAddrs()
+	return removeNotInSource(reachableAddrs, localAddrs), removeNotInSource(unreachableAddrs, localAddrs)
 }
 var p2pCircuitAddr = ma.StringCast("/p2p-circuit")
-func (a *addrsManager) updateLocalAddrs() {
-	localAddrs := a.getLocalAddrs()
-	slices.SortFunc(localAddrs, func(a, b ma.Multiaddr) int { return a.Compare(b) })
-	a.addrsMx.Lock()
-	a.localAddrs = localAddrs
-	a.addrsMx.Unlock()
-}
 func (a *addrsManager) getLocalAddrs() []ma.Multiaddr {
 	listenAddrs := a.listenAddrs()
 	if len(listenAddrs) == 0 {
@@ -260,8 +389,6 @@ func (a *addrsManager) getLocalAddrs() []ma.Multiaddr {
 	finalAddrs = a.appendPrimaryInterfaceAddrs(finalAddrs, listenAddrs)
 	finalAddrs = a.appendNATAddrs(finalAddrs, listenAddrs, a.interfaceAddrs.All())
-	finalAddrs = ma.Unique(finalAddrs)
 	// Remove "/p2p-circuit" addresses from the list.
 	// The p2p-circuit listener reports its address as just /p2p-circuit. This is
 	// useless for dialing. Users need to manage their circuit addresses themselves,
@@ -278,6 +405,8 @@ func (a *addrsManager) getLocalAddrs() []ma.Multiaddr {
 	// Add certhashes for /webrtc-direct, /webtransport, etc addresses discovered
 	// using identify.
 	finalAddrs = a.addCertHashes(finalAddrs)
+	finalAddrs = ma.Unique(finalAddrs)
+	slices.SortFunc(finalAddrs, func(a, b ma.Multiaddr) int { return a.Compare(b) })
 	return finalAddrs
 }
@@ -408,7 +537,7 @@ func (a *addrsManager) addCertHashes(addrs []ma.Multiaddr) []ma.Multiaddr {
 	return addrs
 }
-func (a *addrsManager) areAddrsDifferent(prev, current []ma.Multiaddr) bool {
+func areAddrsDifferent(prev, current []ma.Multiaddr) bool {
 	// TODO: make the sorted nature of ma.Unique a guarantee in multiaddrs
 	prev = ma.Unique(prev)
 	current = ma.Unique(current)
@@ -547,3 +676,31 @@ func (i *interfaceAddrsCache) updateUnlocked() {
 		}
 	}
 }
+// removeNotInSource removes items from addrs that are not present in source.
+// It modifies the addrs slice in place.
+// addrs and source must be sorted using multiaddr.Compare.
+func removeNotInSource(addrs, source []ma.Multiaddr) []ma.Multiaddr {
+	j := 0
+	// mark entries not in source as nil
+	for i, a := range addrs {
+		// move right in source as long as a > source[j]
+		for j < len(source) && a.Compare(source[j]) > 0 {
+			j++
+		}
+		// a is not in source if we've reached the end of source, or a is smaller
+		if j == len(source) || a.Compare(source[j]) < 0 {
+			addrs[i] = nil
+		}
+		// otherwise a is in source; nothing to do
+	}
+	// compact the slice: i is the lowest-index nil element, j scans all elements
+	i := 0
+	for j := range len(addrs) {
+		if addrs[j] != nil {
+			addrs[i], addrs[j] = addrs[j], addrs[i]
+			i++
+		}
+	}
+	return addrs[:i]
+}
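
A quick illustration of removeNotInSource's contract (hypothetical addresses;
both slices sorted with Compare, as the function requires):

	a1 := ma.StringCast("/ip4/1.2.3.4/tcp/1")
	a2 := ma.StringCast("/ip4/1.2.3.4/tcp/2")
	a3 := ma.StringCast("/ip4/1.2.3.4/tcp/3")
	got := removeNotInSource([]ma.Multiaddr{a1, a2, a3}, []ma.Multiaddr{a2})
	// got == []ma.Multiaddr{a2}; the input slice is reused in place, not reallocated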

View File

@@ -1,13 +1,17 @@
 package basichost
 import (
+	"context"
+	"errors"
 	"fmt"
+	"slices"
 	"testing"
 	"time"
 	"github.com/libp2p/go-libp2p/core/event"
 	"github.com/libp2p/go-libp2p/core/network"
 	"github.com/libp2p/go-libp2p/p2p/host/eventbus"
+	"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2"
 	ma "github.com/multiformats/go-multiaddr"
 	manet "github.com/multiformats/go-multiaddr/net"
 	"github.com/stretchr/testify/assert"
@@ -30,7 +34,7 @@ func TestAppendNATAddrs(t *testing.T) {
 			// nat mapping success, obsaddress ignored
 			Listen: ma.StringCast("/ip4/0.0.0.0/udp/1/quic-v1"),
 			Nat:    ma.StringCast("/ip4/1.1.1.1/udp/10/quic-v1"),
-			ObsAddrFunc: func(m ma.Multiaddr) []ma.Multiaddr {
+			ObsAddrFunc: func(_ ma.Multiaddr) []ma.Multiaddr {
 				return []ma.Multiaddr{ma.StringCast("/ip4/2.2.2.2/udp/100/quic-v1")}
 			},
 			Expected: []ma.Multiaddr{ma.StringCast("/ip4/1.1.1.1/udp/10/quic-v1")},
@@ -116,7 +120,7 @@ func TestAppendNATAddrs(t *testing.T) {
 		t.Run(tc.Name, func(t *testing.T) {
 			as := &addrsManager{
 				natManager: &mockNatManager{
-					GetMappingFunc: func(addr ma.Multiaddr) ma.Multiaddr {
+					GetMappingFunc: func(_ ma.Multiaddr) ma.Multiaddr {
 						return tc.Nat
 					},
 				},
@@ -135,7 +139,7 @@ type mockNatManager struct {
 	GetMappingFunc func(addr ma.Multiaddr) ma.Multiaddr
 }
-func (m *mockNatManager) Close() error {
+func (*mockNatManager) Close() error {
 	return nil
 }
@@ -146,7 +150,7 @@ func (m *mockNatManager) GetMapping(addr ma.Multiaddr) ma.Multiaddr {
 	return m.GetMappingFunc(addr)
 }
-func (m *mockNatManager) HasDiscoveredNAT() bool {
+func (*mockNatManager) HasDiscoveredNAT() bool {
 	return true
 }
@@ -170,6 +174,8 @@ type addrsManagerArgs struct {
 	AddrsFactory         AddrsFactory
 	ObservedAddrsManager observedAddrsManager
 	ListenAddrs          func() []ma.Multiaddr
+	AutoNATClient        autonatv2Client
+	Bus                  event.Bus
 }
 type addrsManagerTestCase struct {
@@ -179,13 +185,16 @@ type addrsManagerTestCase struct {
 }
 func newAddrsManagerTestCase(t *testing.T, args addrsManagerArgs) addrsManagerTestCase {
-	eb := eventbus.NewBus()
+	eb := args.Bus
+	if eb == nil {
+		eb = eventbus.NewBus()
+	}
 	if args.AddrsFactory == nil {
 		args.AddrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr { return addrs }
 	}
 	addrsUpdatedChan := make(chan struct{}, 1)
 	am, err := newAddrsManager(
-		eb, args.NATManager, args.AddrsFactory, args.ListenAddrs, nil, args.ObservedAddrsManager, addrsUpdatedChan,
+		eb, args.NATManager, args.AddrsFactory, args.ListenAddrs, nil, args.ObservedAddrsManager, addrsUpdatedChan, args.AutoNATClient,
 	)
 	require.NoError(t, err)
@@ -196,6 +205,7 @@ func newAddrsManagerTestCase(t *testing.T, args addrsManagerArgs) addrsManagerTestCase {
 	rchEm, err := eb.Emitter(new(event.EvtLocalReachabilityChanged), eventbus.Stateful)
 	require.NoError(t, err)
+	t.Cleanup(am.Close)
 	return addrsManagerTestCase{
 		addrsManager: am,
 		PushRelay: func(relayAddrs []ma.Multiaddr) {
@@ -326,7 +336,7 @@ func TestAddrsManager(t *testing.T) {
 	}
 	am := newAddrsManagerTestCase(t, addrsManagerArgs{
 		ObservedAddrsManager: &mockObservedAddrs{
-			ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr {
+			ObservedAddrsForFunc: func(_ ma.Multiaddr) []ma.Multiaddr {
 				return quicAddrs
 			},
 		},
@@ -342,7 +352,7 @@ func TestAddrsManager(t *testing.T) {
 	t.Run("public addrs removed when private", func(t *testing.T) {
 		am := newAddrsManagerTestCase(t, addrsManagerArgs{
 			ObservedAddrsManager: &mockObservedAddrs{
-				ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr {
+				ObservedAddrsForFunc: func(_ ma.Multiaddr) []ma.Multiaddr {
 					return []ma.Multiaddr{publicQUIC}
 				},
 			},
@@ -384,7 +394,7 @@ func TestAddrsManager(t *testing.T) {
 			return nil
 		},
 		ObservedAddrsManager: &mockObservedAddrs{
-			ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr {
+			ObservedAddrsForFunc: func(_ ma.Multiaddr) []ma.Multiaddr {
 				return []ma.Multiaddr{publicQUIC}
 			},
 		},
@@ -404,7 +414,7 @@ func TestAddrsManager(t *testing.T) {
 	t.Run("updates addresses on signaling", func(t *testing.T) {
 		updateChan := make(chan struct{})
 		am := newAddrsManagerTestCase(t, addrsManagerArgs{
-			AddrsFactory: func(addrs []ma.Multiaddr) []ma.Multiaddr {
+			AddrsFactory: func(_ []ma.Multiaddr) []ma.Multiaddr {
 				select {
 				case <-updateChan:
 					return []ma.Multiaddr{publicQUIC}
@@ -425,17 +435,95 @@ func TestAddrsManager(t *testing.T) {
 	})
 }
+func TestAddrsManagerReachabilityEvent(t *testing.T) {
+	publicQUIC, _ := ma.NewMultiaddr("/ip4/1.2.3.4/udp/1234/quic-v1")
+	publicQUIC2, _ := ma.NewMultiaddr("/ip4/1.2.3.4/udp/1235/quic-v1")
+	publicTCP, _ := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/1234")
+	bus := eventbus.NewBus()
+	sub, err := bus.Subscribe(new(event.EvtHostReachableAddrsChanged))
+	require.NoError(t, err)
+	defer sub.Close()
+	am := newAddrsManagerTestCase(t, addrsManagerArgs{
+		Bus: bus,
+		// currently they aren't being passed to the reachability tracker
+		ListenAddrs: func() []ma.Multiaddr { return []ma.Multiaddr{publicQUIC, publicQUIC2, publicTCP} },
+		AutoNATClient: mockAutoNATClient{
+			F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
+				if reqs[0].Addr.Equal(publicQUIC) {
+					return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil
+				} else if reqs[0].Addr.Equal(publicTCP) || reqs[0].Addr.Equal(publicQUIC2) {
+					return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPrivate}, nil
+				}
+				return autonatv2.Result{}, errors.New("invalid")
+			},
+		},
+	})
+	reachableAddrs := []ma.Multiaddr{publicQUIC}
+	unreachableAddrs := []ma.Multiaddr{publicTCP, publicQUIC2}
+	select {
+	case e := <-sub.Out():
+		evt := e.(event.EvtHostReachableAddrsChanged)
+		require.ElementsMatch(t, reachableAddrs, evt.Reachable)
+		require.ElementsMatch(t, unreachableAddrs, evt.Unreachable)
+		require.ElementsMatch(t, reachableAddrs, am.ReachableAddrs())
+	case <-time.After(5 * time.Second):
+		t.Fatal("expected event for reachability change")
+	}
+}
+func TestRemoveIfNotInSource(t *testing.T) {
+	var addrs []ma.Multiaddr
+	for i := 0; i < 10; i++ {
+		addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/%d", i)))
+	}
+	slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return a.Compare(b) })
+	cases := []struct {
+		addrs    []ma.Multiaddr
+		source   []ma.Multiaddr
+		expected []ma.Multiaddr
+	}{
+		{},
+		{addrs: slices.Clone(addrs[:5]), source: nil, expected: nil},
+		{addrs: nil, source: addrs, expected: nil},
+		{addrs: []ma.Multiaddr{addrs[0]}, source: []ma.Multiaddr{addrs[0]}, expected: []ma.Multiaddr{addrs[0]}},
+		{addrs: slices.Clone(addrs), source: []ma.Multiaddr{addrs[0]}, expected: []ma.Multiaddr{addrs[0]}},
+		{addrs: slices.Clone(addrs), source: slices.Clone(addrs[5:]), expected: slices.Clone(addrs[5:])},
+		{addrs: slices.Clone(addrs[:5]), source: []ma.Multiaddr{addrs[0], addrs[2], addrs[8]}, expected: []ma.Multiaddr{addrs[0], addrs[2]}},
+	}
+	for i, tc := range cases {
+		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
+			addrs := removeNotInSource(tc.addrs, tc.source)
+			require.ElementsMatch(t, tc.expected, addrs, "%s\n%s", tc.expected, tc.addrs)
+		})
+	}
+}
 func BenchmarkAreAddrsDifferent(b *testing.B) {
 	var addrs [10]ma.Multiaddr
 	for i := 0; i < len(addrs); i++ {
 		addrs[i] = ma.StringCast(fmt.Sprintf("/ip4/1.1.1.%d/tcp/1", i))
 	}
-	am := &addrsManager{}
 	b.Run("areAddrsDifferent", func(b *testing.B) {
 		b.ReportAllocs()
 		b.ResetTimer()
 		for i := 0; i < b.N; i++ {
-			am.areAddrsDifferent(addrs[:], addrs[:])
+			areAddrsDifferent(addrs[:], addrs[:])
 		}
 	})
 }
+func BenchmarkRemoveIfNotInSource(b *testing.B) {
+	var addrs [10]ma.Multiaddr
+	for i := 0; i < len(addrs); i++ {
+		addrs[i] = ma.StringCast(fmt.Sprintf("/ip4/1.1.1.%d/tcp/1", i))
+	}
+	b.ReportAllocs()
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		removeNotInSource(slices.Clone(addrs[:5]), addrs[:])
+	}
+}

View File

@@ -0,0 +1,666 @@
package basichost
import (
"context"
"errors"
"fmt"
"math"
"slices"
"sync"
"sync/atomic"
"time"
"github.com/benbjohnson/clock"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
type autonatv2Client interface {
GetReachability(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error)
}
const (
// maxAddrsPerRequest is the maximum number of addresses to probe in a single request
maxAddrsPerRequest = 10
// maxTrackedAddrs is the maximum number of addresses to track
// 10 addrs per transport for 5 transports
maxTrackedAddrs = 50
// defaultMaxConcurrency is the default number of concurrent workers for reachability checks
defaultMaxConcurrency = 5
// newAddrsProbeDelay is the delay before probing new addr's reachability.
newAddrsProbeDelay = 1 * time.Second
)
// addrsReachabilityTracker tracks reachability for addresses.
// Use UpdateAddrs to provide addresses for tracking reachability.
// reachabilityUpdateCh is notified when reachability for any of the tracked addresses changes.
type addrsReachabilityTracker struct {
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
client autonatv2Client
// reachabilityUpdateCh is used to notify when reachability may have changed
reachabilityUpdateCh chan struct{}
maxConcurrency int
newAddrsProbeDelay time.Duration
probeManager *probeManager
newAddrs chan []ma.Multiaddr
clock clock.Clock
mx sync.Mutex
reachableAddrs []ma.Multiaddr
unreachableAddrs []ma.Multiaddr
}
// newAddrsReachabilityTracker returns a new addrsReachabilityTracker.
// reachabilityUpdateCh is notified when reachability for any of the tracked addresses changes.
func newAddrsReachabilityTracker(client autonatv2Client, reachabilityUpdateCh chan struct{}, cl clock.Clock) *addrsReachabilityTracker {
ctx, cancel := context.WithCancel(context.Background())
if cl == nil {
cl = clock.New()
}
return &addrsReachabilityTracker{
ctx: ctx,
cancel: cancel,
client: client,
reachabilityUpdateCh: reachabilityUpdateCh,
probeManager: newProbeManager(cl.Now),
newAddrsProbeDelay: newAddrsProbeDelay,
maxConcurrency: defaultMaxConcurrency,
newAddrs: make(chan []ma.Multiaddr, 1),
clock: cl,
}
}
func (r *addrsReachabilityTracker) UpdateAddrs(addrs []ma.Multiaddr) {
select {
case r.newAddrs <- slices.Clone(addrs):
case <-r.ctx.Done():
}
}
func (r *addrsReachabilityTracker) ConfirmedAddrs() (reachableAddrs, unreachableAddrs []ma.Multiaddr) {
r.mx.Lock()
defer r.mx.Unlock()
return slices.Clone(r.reachableAddrs), slices.Clone(r.unreachableAddrs)
}
func (r *addrsReachabilityTracker) Start() error {
r.wg.Add(1)
go r.background()
return nil
}
func (r *addrsReachabilityTracker) Close() error {
r.cancel()
r.wg.Wait()
return nil
}
const (
// defaultReachabilityRefreshInterval is the default interval to refresh reachability.
// In steady state, we check for any required probes every refresh interval.
// This doesn't mean we'll probe for any particular address, only that we'll check
// if any address needs to be probed.
defaultReachabilityRefreshInterval = 5 * time.Minute
// maxBackoffInterval is the maximum back off in case we're unable to probe for reachability.
// We may be unable to confirm addresses in case there are no valid peers with autonatv2
// or the autonatv2 subsystem is consistently erroring.
maxBackoffInterval = 5 * time.Minute
// backoffStartInterval is the initial back off in case we're unable to probe for reachability.
backoffStartInterval = 5 * time.Second
)
func (r *addrsReachabilityTracker) background() {
defer r.wg.Done()
// probeTicker is used to trigger probes at regular intervals
probeTicker := r.clock.Ticker(defaultReachabilityRefreshInterval)
defer probeTicker.Stop()
// probeTimer is used to trigger probes at specific times
probeTimer := r.clock.Timer(time.Duration(math.MaxInt64))
defer probeTimer.Stop()
nextProbeTime := time.Time{}
var task reachabilityTask
var backoffInterval time.Duration
var currReachable, currUnreachable, prevReachable, prevUnreachable []ma.Multiaddr
for {
select {
case <-probeTicker.C:
// don't start a probe if we have a scheduled probe
if task.BackoffCh == nil && nextProbeTime.IsZero() {
task = r.refreshReachability()
}
case <-probeTimer.C:
if task.BackoffCh == nil {
task = r.refreshReachability()
}
nextProbeTime = time.Time{}
case backoff := <-task.BackoffCh:
task = reachabilityTask{}
// On completion, start the next probe immediately, or wait for backoff.
// In case there are no further probes, refreshReachability returns an empty task,
// whose BackoffCh blocks forever. Eventually, we'll refresh again when the ticker fires.
if backoff {
backoffInterval = newBackoffInterval(backoffInterval)
} else {
backoffInterval = -1 * time.Second // negative to trigger next probe immediately
}
nextProbeTime = r.clock.Now().Add(backoffInterval)
case addrs := <-r.newAddrs:
if task.BackoffCh != nil { // cancel running task.
task.Cancel()
<-task.BackoffCh // ignore backoff from cancelled task
task = reachabilityTask{}
}
r.updateTrackedAddrs(addrs)
newAddrsNextTime := r.clock.Now().Add(r.newAddrsProbeDelay)
if nextProbeTime.Before(newAddrsNextTime) {
nextProbeTime = newAddrsNextTime
}
case <-r.ctx.Done():
if task.BackoffCh != nil {
task.Cancel()
<-task.BackoffCh
task = reachabilityTask{}
}
return
}
currReachable, currUnreachable = r.appendConfirmedAddrs(currReachable[:0], currUnreachable[:0])
if areAddrsDifferent(prevReachable, currReachable) || areAddrsDifferent(prevUnreachable, currUnreachable) {
r.notify()
}
prevReachable = append(prevReachable[:0], currReachable...)
prevUnreachable = append(prevUnreachable[:0], currUnreachable...)
if !nextProbeTime.IsZero() {
probeTimer.Reset(nextProbeTime.Sub(r.clock.Now()))
}
}
}
func newBackoffInterval(current time.Duration) time.Duration {
if current <= 0 {
return backoffStartInterval
}
current *= 2
if current > maxBackoffInterval {
return maxBackoffInterval
}
return current
}
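// Illustration (not part of the original source): starting from zero, repeated
// backoff produces 5s, 10s, 20s, 40s, 80s, 160s, and then caps at 5m.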
func (r *addrsReachabilityTracker) appendConfirmedAddrs(reachable, unreachable []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) {
reachable, unreachable = r.probeManager.AppendConfirmedAddrs(reachable, unreachable)
r.mx.Lock()
r.reachableAddrs = append(r.reachableAddrs[:0], reachable...)
r.unreachableAddrs = append(r.unreachableAddrs[:0], unreachable...)
r.mx.Unlock()
return reachable, unreachable
}
func (r *addrsReachabilityTracker) notify() {
select {
case r.reachabilityUpdateCh <- struct{}{}:
default:
}
}
func (r *addrsReachabilityTracker) updateTrackedAddrs(addrs []ma.Multiaddr) {
addrs = slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool {
return !manet.IsPublicAddr(a)
})
if len(addrs) > maxTrackedAddrs {
log.Errorf("too many addresses (%d) for addrs reachability tracker; dropping %d", len(addrs), len(addrs)-maxTrackedAddrs)
addrs = addrs[:maxTrackedAddrs]
}
r.probeManager.UpdateAddrs(addrs)
}
type probe = []autonatv2.Request
const probeTimeout = 30 * time.Second
// reachabilityTask is a task to refresh reachability.
// Waiting on the zero value blocks forever.
type reachabilityTask struct {
Cancel context.CancelFunc
// BackoffCh returns whether the caller should backoff before
// refreshing reachability
BackoffCh chan bool
}
func (r *addrsReachabilityTracker) refreshReachability() reachabilityTask {
if len(r.probeManager.GetProbe()) == 0 {
return reachabilityTask{}
}
resCh := make(chan bool, 1)
ctx, cancel := context.WithTimeout(r.ctx, 5*time.Minute)
r.wg.Add(1)
	// We run probes provided by the probeManager. Probing stops when any
// of the following happens:
// - there are no more probes to run
// - context is completed
// - there are too many consecutive failures from the client
// - the client has no valid peers to probe
go func() {
defer r.wg.Done()
defer cancel()
client := &errCountingClient{autonatv2Client: r.client, MaxConsecutiveErrors: maxConsecutiveErrors}
var backoff atomic.Bool
var wg sync.WaitGroup
wg.Add(r.maxConcurrency)
for range r.maxConcurrency {
go func() {
defer wg.Done()
for {
if ctx.Err() != nil {
return
}
reqs := r.probeManager.GetProbe()
if len(reqs) == 0 {
return
}
r.probeManager.MarkProbeInProgress(reqs)
rctx, cancel := context.WithTimeout(ctx, probeTimeout)
res, err := client.GetReachability(rctx, reqs)
cancel()
r.probeManager.CompleteProbe(reqs, res, err)
if isErrorPersistent(err) {
backoff.Store(true)
return
}
}
}()
}
wg.Wait()
resCh <- backoff.Load()
}()
return reachabilityTask{Cancel: cancel, BackoffCh: resCh}
}
var errTooManyConsecutiveFailures = errors.New("too many consecutive failures")
// errCountingClient counts errors from autonatv2Client and wraps the errors in response with a
// errTooManyConsecutiveFailures in case of persistent failures from autonatv2 module.
type errCountingClient struct {
autonatv2Client
MaxConsecutiveErrors int
mx sync.Mutex
consecutiveErrors int
}
func (c *errCountingClient) GetReachability(ctx context.Context, reqs probe) (autonatv2.Result, error) {
res, err := c.autonatv2Client.GetReachability(ctx, reqs)
c.mx.Lock()
defer c.mx.Unlock()
if err != nil && !errors.Is(err, context.Canceled) { // ignore canceled errors, they're not errors from autonatv2
c.consecutiveErrors++
if c.consecutiveErrors > c.MaxConsecutiveErrors {
err = fmt.Errorf("%w:%w", errTooManyConsecutiveFailures, err)
}
if errors.Is(err, autonatv2.ErrPrivateAddrs) {
log.Errorf("private IP addr in autonatv2 request: %s", err)
}
} else {
c.consecutiveErrors = 0
}
return res, err
}
const maxConsecutiveErrors = 20
// isErrorPersistent returns whether the error will repeat on future probes for a while
func isErrorPersistent(err error) bool {
if err == nil {
return false
}
return errors.Is(err, autonatv2.ErrPrivateAddrs) || errors.Is(err, autonatv2.ErrNoPeers) ||
errors.Is(err, errTooManyConsecutiveFailures)
}
const (
// recentProbeInterval is the interval to probe addresses that have been refused
// these are generally addresses with newer transports for which we don't have many peers
// capable of dialing the transport
recentProbeInterval = 10 * time.Minute
// maxConsecutiveRefusals is the maximum number of consecutive refusals for an address after which
// we wait for `recentProbeInterval` before probing again
maxConsecutiveRefusals = 5
// maxRecentDialsPerAddr is the maximum number of dials on an address before we stop probing for the address.
// This is used to prevent infinite probing of an address whose status is indeterminate for any reason.
maxRecentDialsPerAddr = 10
// confidence is the absolute difference between the number of successes and failures for an address.
// targetConfidence is the confidence threshold for an address after which we wait for
// `highConfidenceAddrProbeInterval` before probing again.
targetConfidence = 3
// minConfidence is the confidence threshold for an address to be considered reachable or unreachable
// confidence is the absolute difference between the number of successes and failures for an address
minConfidence = 2
// maxRecentDialsWindow is the maximum number of recent probe results to consider for a single address
//
// +2 allows for 1 invalid probe result. Consider a string of successes, after which we have a single failure
// and then a success (...S S S S F S). The confidence in the targetConfidence window will be equal to
// targetConfidence, the last F and S cancel each other, and we won't probe again for highConfidenceAddrProbeInterval.
maxRecentDialsWindow = targetConfidence + 2
// highConfidenceAddrProbeInterval is the maximum interval between probes for an address
highConfidenceAddrProbeInterval = 1 * time.Hour
// maxProbeResultTTL is the maximum time to keep probe results for an address
maxProbeResultTTL = maxRecentDialsWindow * highConfidenceAddrProbeInterval
)
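// Putting the defaults above together (illustration, not in the original source):
// maxRecentDialsWindow = 3 + 2 = 5 probe results kept per address, and
// maxProbeResultTTL = 5 * 1h = 5h before an old result is discarded.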
// probeManager tracks reachability for a set of addresses by periodically probing reachability with autonatv2.
// A Probe is a list of addresses which can be tested for reachability with autonatv2.
// This struct decides the priority order of addresses for testing reachability, and throttles in case there have
// been too many probes for an address in the `ProbeInterval`.
//
// Use `refreshReachability` (above) to execute the probes with an autonatv2 client.
type probeManager struct {
now func() time.Time
mx sync.Mutex
inProgressProbes map[string]int // addr -> count
inProgressProbesTotal int
statuses map[string]*addrStatus
addrs []ma.Multiaddr
}
// newProbeManager creates a new probe manager.
func newProbeManager(now func() time.Time) *probeManager {
return &probeManager{
statuses: make(map[string]*addrStatus),
inProgressProbes: make(map[string]int),
now: now,
}
}
// AppendConfirmedAddrs appends the current confirmed reachable and unreachable addresses.
func (m *probeManager) AppendConfirmedAddrs(reachable, unreachable []ma.Multiaddr) (reachableAddrs, unreachableAddrs []ma.Multiaddr) {
m.mx.Lock()
defer m.mx.Unlock()
for _, a := range m.addrs {
s := m.statuses[string(a.Bytes())]
s.RemoveBefore(m.now().Add(-maxProbeResultTTL)) // cleanup stale results
switch s.Reachability() {
case network.ReachabilityPublic:
reachable = append(reachable, a)
case network.ReachabilityPrivate:
unreachable = append(unreachable, a)
}
}
return reachable, unreachable
}
// UpdateAddrs updates the tracked addrs
func (m *probeManager) UpdateAddrs(addrs []ma.Multiaddr) {
m.mx.Lock()
defer m.mx.Unlock()
slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return a.Compare(b) })
statuses := make(map[string]*addrStatus, len(addrs))
for _, addr := range addrs {
k := string(addr.Bytes())
if _, ok := m.statuses[k]; !ok {
statuses[k] = &addrStatus{Addr: addr}
} else {
statuses[k] = m.statuses[k]
}
}
m.addrs = addrs
m.statuses = statuses
}
// GetProbe returns the next probe. Returns zero value in case there are no more probes.
// Probes that are run against an autonatv2 client should be marked in progress with
// `MarkProbeInProgress` before running.
func (m *probeManager) GetProbe() probe {
m.mx.Lock()
defer m.mx.Unlock()
now := m.now()
for i, a := range m.addrs {
ab := a.Bytes()
pc := m.statuses[string(ab)].RequiredProbeCount(now)
if m.inProgressProbes[string(ab)] >= pc {
continue
}
reqs := make(probe, 0, maxAddrsPerRequest)
reqs = append(reqs, autonatv2.Request{Addr: a, SendDialData: true})
// We have the first (primary) address. Append other addresses, ignoring in-progress probes
// on secondary addresses. The expectation is that the primary address will
// be dialed.
for j := 1; j < len(m.addrs); j++ {
k := (i + j) % len(m.addrs)
ab := m.addrs[k].Bytes()
pc := m.statuses[string(ab)].RequiredProbeCount(now)
if pc == 0 {
continue
}
reqs = append(reqs, autonatv2.Request{Addr: m.addrs[k], SendDialData: true})
if len(reqs) >= maxAddrsPerRequest {
break
}
}
return reqs
}
return nil
}
// MarkProbeInProgress should be called when a probe is started.
// All in progress probes *MUST* be completed with `CompleteProbe`
func (m *probeManager) MarkProbeInProgress(reqs probe) {
if len(reqs) == 0 {
return
}
m.mx.Lock()
defer m.mx.Unlock()
m.inProgressProbes[string(reqs[0].Addr.Bytes())]++
m.inProgressProbesTotal++
}
// InProgressProbes returns the number of probes that are currently in progress.
func (m *probeManager) InProgressProbes() int {
m.mx.Lock()
defer m.mx.Unlock()
return m.inProgressProbesTotal
}
// CompleteProbe should be called when a probe completes.
func (m *probeManager) CompleteProbe(reqs probe, res autonatv2.Result, err error) {
now := m.now()
if len(reqs) == 0 {
// should never happen
return
}
m.mx.Lock()
defer m.mx.Unlock()
// decrement in-progress count for the first address
primaryAddrKey := string(reqs[0].Addr.Bytes())
m.inProgressProbes[primaryAddrKey]--
if m.inProgressProbes[primaryAddrKey] <= 0 {
delete(m.inProgressProbes, primaryAddrKey)
}
m.inProgressProbesTotal--
// nothing to do if the request errored.
if err != nil {
return
}
// Consider only primary address as refused. This increases the number of
// refused probes, but refused probes are cheap for a server as no dials are made.
if res.AllAddrsRefused {
if s, ok := m.statuses[primaryAddrKey]; ok {
s.AddRefusal(now)
}
return
}
dialAddrKey := string(res.Addr.Bytes())
if dialAddrKey != primaryAddrKey {
if s, ok := m.statuses[primaryAddrKey]; ok {
s.AddRefusal(now)
}
}
// record the result for the dialed address
if s, ok := m.statuses[dialAddrKey]; ok {
s.AddOutcome(now, res.Reachability, maxRecentDialsWindow)
}
}
type dialOutcome struct {
Success bool
At time.Time
}
type addrStatus struct {
Addr ma.Multiaddr
lastRefusalTime time.Time
consecutiveRefusals int
dialTimes []time.Time
outcomes []dialOutcome
}
func (s *addrStatus) Reachability() network.Reachability {
rch, _, _ := s.reachabilityAndCounts()
return rch
}
func (s *addrStatus) RequiredProbeCount(now time.Time) int {
if s.consecutiveRefusals >= maxConsecutiveRefusals {
if now.Sub(s.lastRefusalTime) < recentProbeInterval {
return 0
}
// reset every `recentProbeInterval`
s.lastRefusalTime = time.Time{}
s.consecutiveRefusals = 0
}
// Don't probe if we have probed too many times recently
rd := s.recentDialCount(now)
if rd >= maxRecentDialsPerAddr {
return 0
}
return s.requiredProbeCountForConfirmation(now)
}
func (s *addrStatus) requiredProbeCountForConfirmation(now time.Time) int {
reachability, successes, failures := s.reachabilityAndCounts()
confidence := successes - failures
if confidence < 0 {
confidence = -confidence
}
cnt := targetConfidence - confidence
if cnt > 0 {
return cnt
}
// we have enough confirmations; check if we should refresh
// Should never happen. The confidence logic above should require a few probes.
if len(s.outcomes) == 0 {
return 0
}
lastOutcome := s.outcomes[len(s.outcomes)-1]
// If the last probe result is old, we need to retest
if now.Sub(lastOutcome.At) > highConfidenceAddrProbeInterval {
return 1
}
// if the last probe result was different from reachability, probe again.
switch reachability {
case network.ReachabilityPublic:
if !lastOutcome.Success {
return 1
}
case network.ReachabilityPrivate:
if lastOutcome.Success {
return 1
}
default:
// this should never happen
return 1
}
return 0
}
func (s *addrStatus) AddRefusal(now time.Time) {
s.lastRefusalTime = now
s.consecutiveRefusals++
}
func (s *addrStatus) AddOutcome(at time.Time, rch network.Reachability, windowSize int) {
s.lastRefusalTime = time.Time{}
s.consecutiveRefusals = 0
s.dialTimes = append(s.dialTimes, at)
for i, t := range s.dialTimes {
if at.Sub(t) < recentProbeInterval {
s.dialTimes = slices.Delete(s.dialTimes, 0, i)
break
}
}
s.RemoveBefore(at.Add(-maxProbeResultTTL)) // remove old outcomes
success := false
switch rch {
case network.ReachabilityPublic:
success = true
case network.ReachabilityPrivate:
success = false
default:
return // don't store the outcome if reachability is unknown
}
s.outcomes = append(s.outcomes, dialOutcome{At: at, Success: success})
if len(s.outcomes) > windowSize {
s.outcomes = slices.Delete(s.outcomes, 0, len(s.outcomes)-windowSize)
}
}
// RemoveBefore removes outcomes before t
func (s *addrStatus) RemoveBefore(t time.Time) {
end := 0
for ; end < len(s.outcomes); end++ {
if !s.outcomes[end].At.Before(t) {
break
}
}
s.outcomes = slices.Delete(s.outcomes, 0, end)
}
func (s *addrStatus) recentDialCount(now time.Time) int {
cnt := 0
for _, t := range slices.Backward(s.dialTimes) {
if now.Sub(t) > recentProbeInterval {
break
}
cnt++
}
return cnt
}
func (s *addrStatus) reachabilityAndCounts() (rch network.Reachability, successes int, failures int) {
for _, r := range s.outcomes {
if r.Success {
successes++
} else {
failures++
}
}
if successes-failures >= minConfidence {
return network.ReachabilityPublic, successes, failures
}
if failures-successes >= minConfidence {
return network.ReachabilityPrivate, successes, failures
}
return network.ReachabilityUnknown, successes, failures
}

View File

@@ -0,0 +1,919 @@
package basichost
import (
"context"
"encoding/binary"
"errors"
"fmt"
"math/rand"
"net"
"net/netip"
"slices"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/benbjohnson/clock"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestProbeManager(t *testing.T) {
pub1 := ma.StringCast("/ip4/1.1.1.1/tcp/1")
pub2 := ma.StringCast("/ip4/1.1.1.2/tcp/1")
pub3 := ma.StringCast("/ip4/1.1.1.3/tcp/1")
cl := clock.NewMock()
nextProbe := func(pm *probeManager) []autonatv2.Request {
reqs := pm.GetProbe()
if len(reqs) != 0 {
pm.MarkProbeInProgress(reqs)
}
return reqs
}
makeNewProbeManager := func(addrs []ma.Multiaddr) *probeManager {
pm := newProbeManager(cl.Now)
pm.UpdateAddrs(addrs)
return pm
}
t.Run("addrs updates", func(t *testing.T) {
pm := newProbeManager(cl.Now)
pm.UpdateAddrs([]ma.Multiaddr{pub1, pub2})
for {
reqs := nextProbe(pm)
if len(reqs) == 0 {
break
}
pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil)
}
reachable, _ := pm.AppendConfirmedAddrs(nil, nil)
require.Equal(t, reachable, []ma.Multiaddr{pub1, pub2})
pm.UpdateAddrs([]ma.Multiaddr{pub3})
reachable, _ = pm.AppendConfirmedAddrs(nil, nil)
require.Empty(t, reachable)
require.Len(t, pm.statuses, 1)
})
t.Run("inprogress", func(t *testing.T) {
pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2})
reqs1 := pm.GetProbe()
reqs2 := pm.GetProbe()
require.Equal(t, reqs1, reqs2)
for range targetConfidence {
reqs := nextProbe(pm)
require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}})
}
for range targetConfidence {
reqs := nextProbe(pm)
require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}, {Addr: pub1, SendDialData: true}})
}
reqs := pm.GetProbe()
require.Empty(t, reqs)
})
t.Run("refusals", func(t *testing.T) {
pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2})
var probes [][]autonatv2.Request
for range targetConfidence {
reqs := nextProbe(pm)
require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}})
probes = append(probes, reqs)
}
// first one refused second one successful
for _, p := range probes {
pm.CompleteProbe(p, autonatv2.Result{Addr: pub2, Idx: 1, Reachability: network.ReachabilityPublic}, nil)
}
// the second address is validated!
probes = nil
for range targetConfidence {
reqs := nextProbe(pm)
require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}})
probes = append(probes, reqs)
}
reqs := pm.GetProbe()
require.Empty(t, reqs)
for _, p := range probes {
pm.CompleteProbe(p, autonatv2.Result{AllAddrsRefused: true}, nil)
}
// all requests refused; no more probes for too many refusals
reqs = pm.GetProbe()
require.Empty(t, reqs)
cl.Add(recentProbeInterval)
reqs = pm.GetProbe()
require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}})
})
t.Run("successes", func(t *testing.T) {
pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2})
for j := 0; j < 2; j++ {
for i := 0; i < targetConfidence; i++ {
reqs := nextProbe(pm)
pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil)
}
}
// all addrs confirmed
reqs := pm.GetProbe()
require.Empty(t, reqs)
cl.Add(highConfidenceAddrProbeInterval + time.Millisecond)
reqs = nextProbe(pm)
require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}})
reqs = nextProbe(pm)
require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}, {Addr: pub1, SendDialData: true}})
})
t.Run("throttling on indeterminate reachability", func(t *testing.T) {
pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2})
reachability := network.ReachabilityPublic
nextReachability := func() network.Reachability {
if reachability == network.ReachabilityPublic {
reachability = network.ReachabilityPrivate
} else {
reachability = network.ReachabilityPublic
}
return reachability
}
// both addresses are indeterminate
for range 2 * maxRecentDialsPerAddr {
reqs := nextProbe(pm)
pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: nextReachability()}, nil)
}
reqs := pm.GetProbe()
require.Empty(t, reqs)
cl.Add(recentProbeInterval + time.Millisecond)
reqs = pm.GetProbe()
require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}})
for range 2 * maxRecentDialsPerAddr {
reqs := nextProbe(pm)
pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: nextReachability()}, nil)
}
reqs = pm.GetProbe()
require.Empty(t, reqs)
})
t.Run("reachabilityUpdate", func(t *testing.T) {
pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2})
for range 2 * targetConfidence {
reqs := nextProbe(pm)
if reqs[0].Addr.Equal(pub1) {
pm.CompleteProbe(reqs, autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil)
} else {
pm.CompleteProbe(reqs, autonatv2.Result{Addr: pub2, Idx: 0, Reachability: network.ReachabilityPrivate}, nil)
}
}
reachable, unreachable := pm.AppendConfirmedAddrs(nil, nil)
require.Equal(t, reachable, []ma.Multiaddr{pub1})
require.Equal(t, unreachable, []ma.Multiaddr{pub2})
})
t.Run("expiry", func(t *testing.T) {
pm := makeNewProbeManager([]ma.Multiaddr{pub1})
for range 2 * targetConfidence {
reqs := nextProbe(pm)
pm.CompleteProbe(reqs, autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil)
}
reachable, unreachable := pm.AppendConfirmedAddrs(nil, nil)
require.Equal(t, reachable, []ma.Multiaddr{pub1})
require.Empty(t, unreachable)
cl.Add(maxProbeResultTTL + 1*time.Second)
reachable, unreachable = pm.AppendConfirmedAddrs(nil, nil)
require.Empty(t, reachable)
require.Empty(t, unreachable)
})
}
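// mockAutoNATClient stubs the autonatv2 client in these tests; F supplies the
// canned response for each GetReachability call.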
type mockAutoNATClient struct {
F func(context.Context, []autonatv2.Request) (autonatv2.Result, error)
}
func (m mockAutoNATClient) GetReachability(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
return m.F(ctx, reqs)
}
var _ autonatv2Client = mockAutoNATClient{}
func TestAddrsReachabilityTracker(t *testing.T) {
pub1 := ma.StringCast("/ip4/1.1.1.1/tcp/1")
pub2 := ma.StringCast("/ip4/1.1.1.2/tcp/1")
pub3 := ma.StringCast("/ip4/1.1.1.3/tcp/1")
pri := ma.StringCast("/ip4/192.168.1.1/tcp/1")
newTracker := func(cli mockAutoNATClient, cl clock.Clock) *addrsReachabilityTracker {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
if cl == nil {
cl = clock.New()
}
tr := &addrsReachabilityTracker{
ctx: ctx,
cancel: cancel,
client: cli,
newAddrs: make(chan []ma.Multiaddr, 1),
reachabilityUpdateCh: make(chan struct{}, 1),
maxConcurrency: 3,
newAddrsProbeDelay: 0 * time.Second,
probeManager: newProbeManager(cl.Now),
clock: cl,
}
err := tr.Start()
require.NoError(t, err)
t.Cleanup(func() {
err := tr.Close()
assert.NoError(t, err)
})
return tr
}
t.Run("simple", func(t *testing.T) {
// pub1 reachable, pub2 unreachable, pub3 ignored
mockClient := mockAutoNATClient{
F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
for i, req := range reqs {
if req.Addr.Equal(pub1) {
return autonatv2.Result{Addr: pub1, Idx: i, Reachability: network.ReachabilityPublic}, nil
} else if req.Addr.Equal(pub2) {
return autonatv2.Result{Addr: pub2, Idx: i, Reachability: network.ReachabilityPrivate}, nil
}
}
return autonatv2.Result{}, autonatv2.ErrNoPeers
},
}
tr := newTracker(mockClient, nil)
tr.UpdateAddrs([]ma.Multiaddr{pub2, pub1, pri})
select {
case <-tr.reachabilityUpdateCh:
case <-time.After(2 * time.Second):
t.Fatal("expected reachability update")
}
reachable, unreachable := tr.ConfirmedAddrs()
require.Equal(t, reachable, []ma.Multiaddr{pub1}, "%s %s", reachable, pub1)
require.Equal(t, unreachable, []ma.Multiaddr{pub2}, "%s %s", unreachable, pub2)
tr.UpdateAddrs([]ma.Multiaddr{pub3, pub1, pri})
select {
case <-tr.reachabilityUpdateCh:
case <-time.After(2 * time.Second):
t.Fatal("expected reachability update")
}
reachable, unreachable = tr.ConfirmedAddrs()
require.Equal(t, reachable, []ma.Multiaddr{pub1}, "%s %s", reachable, pub1)
require.Empty(t, unreachable)
})
t.Run("confirmed addrs ordering", func(t *testing.T) {
mockClient := mockAutoNATClient{
F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil
},
}
tr := newTracker(mockClient, nil)
var addrs []ma.Multiaddr
for i := 0; i < 10; i++ {
addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/1.1.1.1/tcp/%d", i)))
}
slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return -a.Compare(b) }) // sort in reverse order
tr.UpdateAddrs(addrs)
select {
case <-tr.reachabilityUpdateCh:
case <-time.After(2 * time.Second):
t.Fatal("expected reachability update")
}
reachable, unreachable := tr.ConfirmedAddrs()
require.Empty(t, unreachable)
orderedAddrs := slices.Clone(addrs)
slices.Reverse(orderedAddrs)
require.Equal(t, reachable, orderedAddrs, "%s %s", reachable, addrs)
})
t.Run("backoff", func(t *testing.T) {
notify := make(chan struct{}, 1)
drainNotify := func() bool {
found := false
for {
select {
case <-notify:
found = true
default:
return found
}
}
}
var allow atomic.Bool
mockClient := mockAutoNATClient{
F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
select {
case notify <- struct{}{}:
default:
}
if !allow.Load() {
return autonatv2.Result{}, autonatv2.ErrNoPeers
}
if reqs[0].Addr.Equal(pub1) {
return autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil
}
return autonatv2.Result{AllAddrsRefused: true}, nil
},
}
cl := clock.NewMock()
tr := newTracker(mockClient, cl)
// update addrs and wait for initial checks
tr.UpdateAddrs([]ma.Multiaddr{pub1})
// need to update clock after the background goroutine processes the new addrs
time.Sleep(100 * time.Millisecond)
cl.Add(1)
time.Sleep(100 * time.Millisecond)
require.True(t, drainNotify()) // check that we did receive probes
backoffInterval := backoffStartInterval
for i := 0; i < 4; i++ {
drainNotify()
cl.Add(backoffInterval / 2)
select {
case <-notify:
t.Fatal("unexpected call")
case <-time.After(50 * time.Millisecond):
}
cl.Add(backoffInterval/2 + 1) // +1 to push it slightly over the backoff interval
backoffInterval *= 2
select {
case <-notify:
case <-time.After(1 * time.Second):
t.Fatal("expected probe")
}
reachable, unreachable := tr.ConfirmedAddrs()
require.Empty(t, reachable)
require.Empty(t, unreachable)
}
allow.Store(true)
drainNotify()
cl.Add(backoffInterval + 1)
select {
case <-tr.reachabilityUpdateCh:
case <-time.After(1 * time.Second):
t.Fatal("unexpected reachability update")
}
reachable, unreachable := tr.ConfirmedAddrs()
require.Equal(t, reachable, []ma.Multiaddr{pub1})
require.Empty(t, unreachable)
})
t.Run("event update", func(t *testing.T) {
// allow minConfidence probes to pass
called := make(chan struct{}, minConfidence)
notify := make(chan struct{})
mockClient := mockAutoNATClient{
F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) {
select {
case called <- struct{}{}:
notify <- struct{}{}
return autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil
default:
return autonatv2.Result{AllAddrsRefused: true}, nil
}
},
}
tr := newTracker(mockClient, nil)
tr.UpdateAddrs([]ma.Multiaddr{pub1})
for i := 0; i < minConfidence; i++ {
select {
case <-notify:
case <-time.After(1 * time.Second):
t.Fatal("expected call to autonat client")
}
}
select {
case <-tr.reachabilityUpdateCh:
reachable, unreachable := tr.ConfirmedAddrs()
require.Equal(t, reachable, []ma.Multiaddr{pub1})
require.Empty(t, unreachable)
case <-time.After(1 * time.Second):
t.Fatal("expected reachability update")
}
tr.UpdateAddrs([]ma.Multiaddr{pub1}) // same addrs shouldn't get update
select {
case <-tr.reachabilityUpdateCh:
t.Fatal("didn't expect reachability update")
case <-time.After(100 * time.Millisecond):
}
tr.UpdateAddrs([]ma.Multiaddr{pub2})
select {
case <-tr.reachabilityUpdateCh:
reachable, unreachable := tr.ConfirmedAddrs()
require.Empty(t, reachable)
require.Empty(t, unreachable)
case <-time.After(1 * time.Second):
t.Fatal("expected reachability update")
}
})
t.Run("refresh after reset interval", func(t *testing.T) {
notify := make(chan struct{}, 1)
drainNotify := func() bool {
found := false
for {
select {
case <-notify:
found = true
default:
return found
}
}
}
mockClient := mockAutoNATClient{
F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
select {
case notify <- struct{}{}:
default:
}
if reqs[0].Addr.Equal(pub1) {
return autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil
}
return autonatv2.Result{AllAddrsRefused: true}, nil
},
}
cl := clock.NewMock()
tr := newTracker(mockClient, cl)
// update addrs and wait for initial checks
tr.UpdateAddrs([]ma.Multiaddr{pub1})
// need to update clock after the background goroutine processes the new addrs
time.Sleep(100 * time.Millisecond)
cl.Add(1)
time.Sleep(100 * time.Millisecond)
require.True(t, drainNotify()) // check that we did receive probes
cl.Add(highConfidenceAddrProbeInterval / 2)
select {
case <-notify:
t.Fatal("unexpected call")
case <-time.After(50 * time.Millisecond):
}
cl.Add(highConfidenceAddrProbeInterval/2 + defaultReachabilityRefreshInterval) // defaultReachabilityRefreshInterval determines the next probe time
select {
case <-notify:
case <-time.After(1 * time.Second):
t.Fatal("expected probe")
}
})
}
func TestRefreshReachability(t *testing.T) {
pub1 := ma.StringCast("/ip4/1.1.1.1/tcp/1")
pub2 := ma.StringCast("/ip4/1.1.1.1/tcp/2")
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
newTracker := func(client autonatv2Client, pm *probeManager) *addrsReachabilityTracker {
return &addrsReachabilityTracker{
probeManager: pm,
client: client,
clock: clock.New(),
maxConcurrency: 3,
ctx: ctx,
cancel: cancel,
}
}
t.Run("backoff on ErrNoValidPeers", func(t *testing.T) {
mockClient := mockAutoNATClient{
F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) {
return autonatv2.Result{}, autonatv2.ErrNoPeers
},
}
addrTracker := newProbeManager(time.Now)
addrTracker.UpdateAddrs([]ma.Multiaddr{pub1})
r := newTracker(mockClient, addrTracker)
res := r.refreshReachability()
require.True(t, <-res.BackoffCh)
require.Equal(t, addrTracker.InProgressProbes(), 0)
})
t.Run("returns backoff on errTooManyConsecutiveFailures", func(t *testing.T) {
// Create a client that always returns an error
mockClient := mockAutoNATClient{
F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) {
return autonatv2.Result{}, errors.New("test error")
},
}
pm := newProbeManager(time.Now)
pm.UpdateAddrs([]ma.Multiaddr{pub1})
r := newTracker(mockClient, pm)
result := r.refreshReachability()
require.True(t, <-result.BackoffCh)
require.Equal(t, pm.InProgressProbes(), 0)
})
t.Run("quits on cancellation", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
block := make(chan struct{})
mockClient := mockAutoNATClient{
F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) {
block <- struct{}{}
return autonatv2.Result{}, nil
},
}
pm := newProbeManager(time.Now)
pm.UpdateAddrs([]ma.Multiaddr{pub1})
r := &addrsReachabilityTracker{
ctx: ctx,
cancel: cancel,
client: mockClient,
probeManager: pm,
clock: clock.New(),
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
result := r.refreshReachability()
assert.False(t, <-result.BackoffCh)
assert.Equal(t, pm.InProgressProbes(), 0)
}()
cancel()
time.Sleep(50 * time.Millisecond) // wait for the cancellation to be processed
outer:
for i := 0; i < defaultMaxConcurrency; i++ {
select {
case <-block:
default:
break outer
}
}
select {
case <-block:
t.Fatal("expected no more requests")
case <-time.After(50 * time.Millisecond):
}
wg.Wait()
})
t.Run("handles refusals", func(t *testing.T) {
pub1, _ := ma.NewMultiaddr("/ip4/1.1.1.1/tcp/1")
mockClient := mockAutoNATClient{
F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
for i, req := range reqs {
if req.Addr.Equal(pub1) {
return autonatv2.Result{Addr: pub1, Idx: i, Reachability: network.ReachabilityPublic}, nil
}
}
return autonatv2.Result{AllAddrsRefused: true}, nil
},
}
pm := newProbeManager(time.Now)
pm.UpdateAddrs([]ma.Multiaddr{pub2, pub1})
r := newTracker(mockClient, pm)
result := r.refreshReachability()
require.False(t, <-result.BackoffCh)
reachable, unreachable := pm.AppendConfirmedAddrs(nil, nil)
require.Equal(t, reachable, []ma.Multiaddr{pub1})
require.Empty(t, unreachable)
require.Equal(t, pm.InProgressProbes(), 0)
})
t.Run("handles completions", func(t *testing.T) {
mockClient := mockAutoNATClient{
F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
for i, req := range reqs {
if req.Addr.Equal(pub1) {
return autonatv2.Result{Addr: pub1, Idx: i, Reachability: network.ReachabilityPublic}, nil
}
if req.Addr.Equal(pub2) {
return autonatv2.Result{Addr: pub2, Idx: i, Reachability: network.ReachabilityPrivate}, nil
}
}
return autonatv2.Result{AllAddrsRefused: true}, nil
},
}
pm := newProbeManager(time.Now)
pm.UpdateAddrs([]ma.Multiaddr{pub2, pub1})
r := newTracker(mockClient, pm)
result := r.refreshReachability()
require.False(t, <-result.BackoffCh)
reachable, unreachable := pm.AppendConfirmedAddrs(nil, nil)
require.Equal(t, reachable, []ma.Multiaddr{pub1})
require.Equal(t, unreachable, []ma.Multiaddr{pub2})
require.Equal(t, pm.InProgressProbes(), 0)
})
}
func TestAddrStatusProbeCount(t *testing.T) {
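// each case encodes a probe history for one address: 'S' is a successful
// (public) outcome, 'F' a failed (private) one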
cases := []struct {
inputs string
wantRequiredProbes int
wantReachability network.Reachability
}{
{
inputs: "",
wantRequiredProbes: 3,
wantReachability: network.ReachabilityUnknown,
},
{
inputs: "S",
wantRequiredProbes: 2,
wantReachability: network.ReachabilityUnknown,
},
{
inputs: "SS",
wantRequiredProbes: 1,
wantReachability: network.ReachabilityPublic,
},
{
inputs: "SSS",
wantRequiredProbes: 0,
wantReachability: network.ReachabilityPublic,
},
{
inputs: "SSSSSSSF",
wantRequiredProbes: 1,
wantReachability: network.ReachabilityPublic,
},
{
inputs: "SFSFSSSS",
wantRequiredProbes: 0,
wantReachability: network.ReachabilityPublic,
},
{
inputs: "SSSSSFSF",
wantRequiredProbes: 2,
wantReachability: network.ReachabilityUnknown,
},
{
inputs: "FF",
wantRequiredProbes: 1,
wantReachability: network.ReachabilityPrivate,
},
}
for _, c := range cases {
t.Run(c.inputs, func(t *testing.T) {
now := time.Time{}.Add(1 * time.Second)
ao := addrStatus{}
for _, r := range c.inputs {
if r == 'S' {
ao.AddOutcome(now, network.ReachabilityPublic, 5)
} else {
ao.AddOutcome(now, network.ReachabilityPrivate, 5)
}
now = now.Add(1 * time.Second)
}
require.Equal(t, ao.RequiredProbeCount(now), c.wantRequiredProbes)
require.Equal(t, ao.Reachability(), c.wantReachability)
if c.wantRequiredProbes == 0 {
now = now.Add(highConfidenceAddrProbeInterval + 10*time.Microsecond)
require.Equal(t, ao.RequiredProbeCount(now), 1)
}
now = now.Add(1 * time.Second)
ao.RemoveBefore(now)
require.Len(t, ao.outcomes, 0)
})
}
}
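// BenchmarkAddrTracker exercises a full GetProbe / MarkProbeInProgress /
// CompleteProbe cycle on a probe manager tracking 20 addresses.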
func BenchmarkAddrTracker(b *testing.B) {
cl := clock.NewMock()
t := newProbeManager(cl.Now)
addrs := make([]ma.Multiaddr, 20)
for i := range addrs {
addrs[i] = ma.StringCast(fmt.Sprintf("/ip4/1.1.1.1/tcp/%d", rand.Intn(1000)))
}
t.UpdateAddrs(addrs)
b.ReportAllocs()
b.ResetTimer()
p := t.GetProbe()
for i := 0; i < b.N; i++ {
pp := t.GetProbe()
if len(pp) == 0 {
pp = p
}
t.MarkProbeInProgress(pp)
t.CompleteProbe(pp, autonatv2.Result{Addr: pp[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil)
}
}
func FuzzAddrsReachabilityTracker(f *testing.F) {
type autonatv2Response struct {
Result autonatv2.Result
Err error
}
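// newMockClient derives a deterministic client from the fuzz input: each byte
// selects either a confirmed address or one of the canned outcomes below.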
newMockClient := func(b []byte) mockAutoNATClient {
count := 0
return mockAutoNATClient{
F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
if len(b) == 0 {
return autonatv2.Result{}, nil
}
count = (count + 1) % len(b)
if b[count]%3 == 0 {
// some address confirmed
c1 := (count + 1) % len(b)
c2 := (count + 2) % len(b)
rch := network.Reachability(b[c1] % 3)
n := int(b[c2]) % len(reqs)
return autonatv2.Result{
Addr: reqs[n].Addr,
Idx: n,
Reachability: rch,
}, nil
}
outcomes := []autonatv2Response{
{Result: autonatv2.Result{AllAddrsRefused: true}},
{Err: errors.New("test error")},
{Err: autonatv2.ErrPrivateAddrs},
{Err: autonatv2.ErrNoPeers},
{Result: autonatv2.Result{}, Err: nil},
{Result: autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}},
{Result: autonatv2.Result{
Addr: reqs[0].Addr,
Idx: 0,
Reachability: network.ReachabilityPublic,
AllAddrsRefused: true,
}},
{Result: autonatv2.Result{
Addr: reqs[0].Addr,
Idx: len(reqs) - 1, // invalid idx
Reachability: network.ReachabilityPublic,
AllAddrsRefused: false,
}},
}
outcome := outcomes[int(b[count])%len(outcomes)]
return outcome.Result, outcome.Err
},
}
}
// TODO: Move this to go-multiaddrs
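// getProto assembles a transport suffix (e.g. /tcp/<port> or
// /udp/<port>/quic-v1) from fuzz-provided protocol and port bytes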
getProto := func(protos []byte) ma.Multiaddr {
protoType := 0
if len(protos) > 0 {
protoType = int(protos[0])
}
port1, port2 := 0, 0
if len(protos) > 1 {
port1 = int(protos[1])
}
if len(protos) > 2 {
port2 = int(protos[2])
}
protoTemplates := []string{
"/tcp/%d/",
"/udp/%d/",
"/udp/%d/quic-v1/",
"/udp/%d/quic-v1/tcp/%d",
"/udp/%d/quic-v1/webtransport/",
"/udp/%d/webrtc/",
"/udp/%d/webrtc-direct/",
"/unix/hello/",
}
s := protoTemplates[protoType%len(protoTemplates)]
port1 %= (1 << 16)
if strings.Count(s, "%d") == 1 {
return ma.StringCast(fmt.Sprintf(s, port1))
}
port2 %= (1 << 16)
return ma.StringCast(fmt.Sprintf(s, port1, port2))
}
getIP := func(ips []byte) ma.Multiaddr {
ipType := 0
if len(ips) > 0 {
ipType = int(ips[0])
}
ips = ips[1:]
var x, y int64
split := 128 / 8
if len(ips) < split {
split = len(ips)
}
var b [8]byte
copy(b[:], ips[:split])
x = int64(binary.LittleEndian.Uint64(b[:]))
clear(b[:])
copy(b[:], ips[split:])
y = int64(binary.LittleEndian.Uint64(b[:]))
var ip netip.Addr
switch ipType % 3 {
case 0:
ip = netip.AddrFrom4([4]byte{byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24)})
return ma.StringCast(fmt.Sprintf("/ip4/%s/", ip))
case 1:
pubIP := net.ParseIP("2005::") // Public IP address
x := int64(binary.LittleEndian.Uint64(pubIP[0:8]))
ip = netip.AddrFrom16([16]byte{
byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24),
byte(x >> 32), byte(x >> 40), byte(x >> 48), byte(x >> 56),
byte(y), byte(y >> 8), byte(y >> 16), byte(y >> 24),
byte(y >> 32), byte(y >> 40), byte(y >> 48), byte(y >> 56),
})
return ma.StringCast(fmt.Sprintf("/ip6/%s/", ip))
default:
ip := netip.AddrFrom16([16]byte{
byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24),
byte(x >> 32), byte(x >> 40), byte(x >> 48), byte(x >> 56),
byte(y), byte(y >> 8), byte(y >> 16), byte(y >> 24),
byte(y >> 32), byte(y >> 40), byte(y >> 48), byte(y >> 56),
})
return ma.StringCast(fmt.Sprintf("/ip6/%s/", ip))
}
}
getAddr := func(addrType int, ips, protos []byte) ma.Multiaddr {
switch addrType % 4 {
case 0:
return getIP(ips).Encapsulate(getProto(protos))
case 1:
return getProto(protos)
case 2:
return nil
default:
return getIP(ips).Encapsulate(getProto(protos))
}
}
getDNSAddr := func(hostNameBytes, protos []byte) ma.Multiaddr {
hostName := strings.ReplaceAll(string(hostNameBytes), "\\", "")
hostName = strings.ReplaceAll(hostName, "/", "")
if hostName == "" {
hostName = "localhost"
}
dnsType := 0
if len(hostNameBytes) > 0 {
dnsType = int(hostNameBytes[0])
}
dnsProtos := []string{"dns", "dns4", "dns6", "dnsaddr"}
da := ma.StringCast(fmt.Sprintf("/%s/%s/", dnsProtos[dnsType%len(dnsProtos)], hostName))
return da.Encapsulate(getProto(protos))
}
const maxAddrs = 1000
getAddrs := func(numAddrs int, ips, protos, hostNames []byte) []ma.Multiaddr {
if len(ips) == 0 || len(protos) == 0 || len(hostNames) == 0 {
return nil
}
numAddrs = ((numAddrs % maxAddrs) + maxAddrs) % maxAddrs
addrs := make([]ma.Multiaddr, numAddrs)
ipIdx := 0
protoIdx := 0
for i := range numAddrs {
addrs[i] = getAddr(i, ips[ipIdx:], protos[protoIdx:])
ipIdx = (ipIdx + 1) % len(ips)
protoIdx = (protoIdx + 1) % len(protos)
}
maxDNSAddrs := 10
protoIdx = 0
for i := 0; i < len(hostNames) && i < maxDNSAddrs; i += 2 {
ed := min(i+2, len(hostNames))
addrs = append(addrs, getDNSAddr(hostNames[i:ed], protos[protoIdx:]))
protoIdx = (protoIdx + 1) % len(protos)
}
return addrs
}
cl := clock.NewMock()
f.Fuzz(func(t *testing.T, numAddrs int, ips, protos, hostNames, autonatResponses []byte) {
tr := newAddrsReachabilityTracker(newMockClient(autonatResponses), nil, cl)
require.NoError(t, tr.Start())
tr.UpdateAddrs(getAddrs(numAddrs, ips, protos, hostNames))
// fuzz test executions must finish within Go's 10 second limit; see:
// https://github.com/golang/go/issues/48157
// https://github.com/golang/go/commit/5d24203c394e6b64c42a9f69b990d94cb6c8aad4#diff-4e3b9481b8794eb058998e2bec389d3db7a23c54e67ac0f7259a3a5d2c79fd04R474-R483
const maxIters = 20
for range maxIters {
cl.Add(5 * time.Minute)
time.Sleep(100 * time.Millisecond)
}
require.NoError(t, tr.Close())
})
}

View File

@@ -156,8 +156,8 @@ type HostOpts struct {
 	// DisableIdentifyAddressDiscovery disables address discovery using peer provided observed addresses in identify
 	DisableIdentifyAddressDiscovery bool
-	EnableAutoNATv2                 bool
-	AutoNATv2Dialer                 host.Host
+	AutoNATv2                       *autonatv2.AutoNAT
 }

 // NewHost constructs a new *BasicHost and activates it by attaching its stream and connection handlers to the given inet.Network.
@@ -236,7 +236,16 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
 	}); ok {
 		tfl = s.TransportForListening
 	}
-	h.addressManager, err = newAddrsManager(h.eventbus, natmgr, addrFactory, h.Network().ListenAddresses, tfl, h.ids, h.addrsUpdatedChan)
+
+	if opts.AutoNATv2 != nil {
+		h.autonatv2 = opts.AutoNATv2
+	}
+	var autonatv2Client autonatv2Client // avoid typed nil errors
+	if h.autonatv2 != nil {
+		autonatv2Client = h.autonatv2
+	}
+	h.addressManager, err = newAddrsManager(h.eventbus, natmgr, addrFactory, h.Network().ListenAddresses, tfl, h.ids, h.addrsUpdatedChan, autonatv2Client)
 	if err != nil {
 		return nil, fmt.Errorf("failed to create address service: %w", err)
 	}
@@ -283,17 +292,6 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
 		h.pings = ping.NewPingService(h)
 	}

-	if opts.EnableAutoNATv2 {
-		var mt autonatv2.MetricsTracer
-		if opts.EnableMetrics {
-			mt = autonatv2.NewMetricsTracer(opts.PrometheusRegisterer)
-		}
-		h.autonatv2, err = autonatv2.New(h, opts.AutoNATv2Dialer, autonatv2.WithMetricsTracer(mt))
-		if err != nil {
-			return nil, fmt.Errorf("failed to create autonatv2: %w", err)
-		}
-	}
-
 	if !h.disableSignedPeerRecord {
 		h.signKey = h.Peerstore().PrivKey(h.ID())
 		cab, ok := peerstore.GetCertifiedAddrBook(h.Peerstore())
@@ -320,7 +318,7 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
 func (h *BasicHost) Start() {
 	h.psManager.Start()
 	if h.autonatv2 != nil {
-		err := h.autonatv2.Start()
+		err := h.autonatv2.Start(h)
 		if err != nil {
 			log.Errorf("autonat v2 failed to start: %s", err)
 		}
@@ -754,6 +752,16 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr {
 	return h.addressManager.DirectAddrs()
 }

+// ReachableAddrs returns all addresses of the host that are reachable from the internet
+// as verified by autonatv2.
+//
+// Experimental: This API may change in the future without deprecation.
+//
+// Requires AutoNATv2 to be enabled.
+func (h *BasicHost) ReachableAddrs() []ma.Multiaddr {
+	return h.addressManager.ReachableAddrs()
+}
+
 func trimHostAddrList(addrs []ma.Multiaddr, maxSize int) []ma.Multiaddr {
 	totalSize := 0
 	for _, a := range addrs {
@@ -836,7 +844,6 @@ func (h *BasicHost) Close() error {
 	if h.cmgr != nil {
 		h.cmgr.Close()
 	}
-
 	h.addressManager.Close()
 	if h.ids != nil {
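
A minimal sketch of how a caller could consume the new surface once AutoNATv2 is enabled. The subscription plumbing is illustrative only, under the assumption that EvtHostReachableAddrsChanged is published on the host's event bus whenever address reachability changes:

	sub, err := h.EventBus().Subscribe(new(event.EvtHostReachableAddrsChanged))
	if err != nil {
		log.Errorf("subscribe failed: %s", err)
		return
	}
	defer sub.Close()
	for range sub.Out() {
		// ReachableAddrs returns only the addresses autonatv2 has confirmed as publicly reachable.
		for _, addr := range h.ReachableAddrs() {
			fmt.Println("reachable:", addr)
		}
	}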

View File

@@ -47,6 +47,7 @@ func TestHostSimple(t *testing.T) {
 	h1.Start()
+
 	h2, err := NewHost(swarmt.GenSwarm(t), nil)
 	require.NoError(t, err)
 	defer h2.Close()
 	h2.Start()
@@ -211,6 +212,7 @@ func TestAllAddrs(t *testing.T) {
 	// no listen addrs
 	h, err := NewHost(swarmt.GenSwarm(t, swarmt.OptDialOnly), nil)
 	require.NoError(t, err)
+	h.Start()
 	defer h.Close()
 	require.Nil(t, h.AllAddrs())

View File

@@ -4,18 +4,17 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"iter"
"math/rand/v2"
"slices" "slices"
"sync" "sync"
"time" "time"
"math/rand/v2"
logging "github.com/ipfs/go-log/v2" logging "github.com/ipfs/go-log/v2"
"github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net" manet "github.com/multiformats/go-multiaddr/net"
) )
@@ -35,11 +34,15 @@ const (
 	// maxPeerAddresses is the number of addresses in a dial request the server
 	// will inspect, rest are ignored.
 	maxPeerAddresses = 50
+
+	defaultThrottlePeerDuration = 2 * time.Minute
 )

 var (
-	ErrNoValidPeers = errors.New("no valid peers for autonat v2")
-	ErrDialRefused  = errors.New("dial refused")
+	// ErrNoPeers is returned when the client knows no autonatv2 servers.
+	ErrNoPeers = errors.New("no peers for autonat v2")
+	// ErrPrivateAddrs is returned when the request has private IP addresses.
+	ErrPrivateAddrs = errors.New("private addresses cannot be verified with autonatv2")

 	log = logging.Logger("autonatv2")
 )
@@ -56,10 +59,12 @@ type Request struct {
 type Result struct {
 	// Addr is the dialed address
 	Addr ma.Multiaddr
-	// Reachability of the dialed address
+	// Idx is the index of the address that was dialed
+	Idx int
+	// Reachability is the reachability for `Addr`
 	Reachability network.Reachability
-	// Status is the outcome of the dialback
-	Status pb.DialStatus
+	// AllAddrsRefused is true when the server refused to dial all the addresses in the request.
+	AllAddrsRefused bool
 }

 // AutoNAT implements the AutoNAT v2 client and server.
@@ -76,8 +81,12 @@ type AutoNAT struct {
 	srv *server
 	cli *client

 	mx    sync.Mutex
 	peers *peersMap
+	throttlePeer map[peer.ID]time.Time
+	// throttlePeerDuration is the duration to wait before making another dial request to the
+	// same server.
+	throttlePeerDuration time.Duration

 	// allowPrivateAddrs enables using private and localhost addresses for reachability checks.
 	// This is only useful for testing.
 	allowPrivateAddrs bool
@@ -86,7 +95,7 @@ type AutoNAT struct {
 // New returns a new AutoNAT instance.
 // host and dialerHost should have the same dialing capabilities. In case the host doesn't support
 // a transport, dial back requests for address for that transport will be ignored.
-func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, error) {
+func New(dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, error) {
 	s := defaultSettings()
 	for _, o := range opts {
 		if err := o(s); err != nil {
@@ -96,18 +105,20 @@ func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT,
 	ctx, cancel := context.WithCancel(context.Background())
 	an := &AutoNAT{
-		host:              host,
-		ctx:               ctx,
-		cancel:            cancel,
-		srv:               newServer(host, dialerHost, s),
-		cli:               newClient(host),
-		allowPrivateAddrs: s.allowPrivateAddrs,
-		peers:             newPeersMap(),
+		ctx:                  ctx,
+		cancel:               cancel,
+		srv:                  newServer(dialerHost, s),
+		cli:                  newClient(),
+		allowPrivateAddrs:    s.allowPrivateAddrs,
+		peers:                newPeersMap(),
+		throttlePeer:         make(map[peer.ID]time.Time),
+		throttlePeerDuration: s.throttlePeerDuration,
 	}
 	return an, nil
 }

 func (an *AutoNAT) background(sub event.Subscription) {
+	ticker := time.NewTicker(10 * time.Minute)
 	for {
 		select {
 		case <-an.ctx.Done():
@@ -122,12 +133,24 @@ func (an *AutoNAT) background(sub event.Subscription) {
 				an.updatePeer(evt.Peer)
 			case event.EvtPeerIdentificationCompleted:
 				an.updatePeer(evt.Peer)
+			default:
+				log.Errorf("unexpected event: %T", e)
 			}
+		case <-ticker.C:
+			now := time.Now()
+			an.mx.Lock()
+			for p, t := range an.throttlePeer {
+				if t.Before(now) {
+					delete(an.throttlePeer, p)
+				}
+			}
+			an.mx.Unlock()
 		}
 	}
 }

-func (an *AutoNAT) Start() error {
+func (an *AutoNAT) Start(h host.Host) error {
+	an.host = h
 	// Listen on event.EvtPeerProtocolsUpdated, event.EvtPeerConnectednessChanged
 	// event.EvtPeerIdentificationCompleted to maintain our set of autonat supporting peers.
 	sub, err := an.host.EventBus().Subscribe([]interface{}{
@@ -138,8 +161,8 @@ func (an *AutoNAT) Start() error {
 	if err != nil {
 		return fmt.Errorf("event subscription failed: %w", err)
 	}
-	an.cli.Start()
-	an.srv.Start()
+	an.cli.Start(h)
+	an.srv.Start(h)

 	an.wg.Add(1)
 	go an.background(sub)
@@ -156,24 +179,48 @@ func (an *AutoNAT) Close() {
 // GetReachability makes a single dial request for checking reachability for requested addresses
 func (an *AutoNAT) GetReachability(ctx context.Context, reqs []Request) (Result, error) {
+	var filteredReqs []Request
 	if !an.allowPrivateAddrs {
+		filteredReqs = make([]Request, 0, len(reqs))
 		for _, r := range reqs {
-			if !manet.IsPublicAddr(r.Addr) {
-				return Result{}, fmt.Errorf("private address cannot be verified by autonatv2: %s", r.Addr)
+			if manet.IsPublicAddr(r.Addr) {
+				filteredReqs = append(filteredReqs, r)
+			} else {
+				log.Errorf("private address in reachability check: %s", r.Addr)
 			}
 		}
+		if len(filteredReqs) == 0 {
+			return Result{}, ErrPrivateAddrs
+		}
+	} else {
+		filteredReqs = reqs
 	}
 	an.mx.Lock()
-	p := an.peers.GetRand()
+	now := time.Now()
+	var p peer.ID
+	for pr := range an.peers.Shuffled() {
+		if t := an.throttlePeer[pr]; t.After(now) {
+			continue
+		}
+		p = pr
+		an.throttlePeer[p] = time.Now().Add(an.throttlePeerDuration)
+		break
+	}
 	an.mx.Unlock()
 	if p == "" {
-		return Result{}, ErrNoValidPeers
+		return Result{}, ErrNoPeers
 	}
-	res, err := an.cli.GetReachability(ctx, p, reqs)
+	res, err := an.cli.GetReachability(ctx, p, filteredReqs)
 	if err != nil {
 		log.Debugf("reachability check with %s failed, err: %s", p, err)
-		return Result{}, fmt.Errorf("reachability check with %s failed: %w", p, err)
+		return res, fmt.Errorf("reachability check with %s failed: %w", p, err)
+	}
+	// restore the correct index in case we'd filtered private addresses
+	for i, r := range reqs {
+		if r.Addr.Equal(res.Addr) {
+			res.Idx = i
+			break
+		}
 	}
 	log.Debugf("reachability check with %s successful", p)
 	return res, nil
@@ -187,7 +234,7 @@ func (an *AutoNAT) updatePeer(p peer.ID) {
 	// and swarm for the current state
 	protos, err := an.host.Peerstore().SupportsProtocols(p, DialProtocol)
 	connectedness := an.host.Network().Connectedness(p)
-	if err == nil && slices.Contains(protos, DialProtocol) && connectedness == network.Connected {
+	if err == nil && connectedness == network.Connected && slices.Contains(protos, DialProtocol) {
 		an.peers.Put(p)
 	} else {
 		an.peers.Delete(p)
@@ -208,28 +255,40 @@ func newPeersMap() *peersMap {
 	}
 }

-func (p *peersMap) GetRand() peer.ID {
-	if len(p.peers) == 0 {
-		return ""
-	}
-	return p.peers[rand.IntN(len(p.peers))]
-}
-
-func (p *peersMap) Put(pid peer.ID) {
-	if _, ok := p.peerIdx[pid]; ok {
-		return
-	}
-	p.peers = append(p.peers, pid)
-	p.peerIdx[pid] = len(p.peers) - 1
-}
-
-func (p *peersMap) Delete(pid peer.ID) {
-	idx, ok := p.peerIdx[pid]
-	if !ok {
-		return
-	}
-	p.peers[idx] = p.peers[len(p.peers)-1]
-	p.peerIdx[p.peers[idx]] = idx
-	p.peers = p.peers[:len(p.peers)-1]
-	delete(p.peerIdx, pid)
-}
+// Shuffled iterates over the map in random order
+func (p *peersMap) Shuffled() iter.Seq[peer.ID] {
+	n := len(p.peers)
+	start := 0
+	if n > 0 {
+		start = rand.IntN(n)
+	}
+	return func(yield func(peer.ID) bool) {
+		for i := range n {
+			if !yield(p.peers[(i+start)%n]) {
+				return
+			}
+		}
+	}
+}
+
+func (p *peersMap) Put(id peer.ID) {
+	if _, ok := p.peerIdx[id]; ok {
+		return
+	}
+	p.peers = append(p.peers, id)
+	p.peerIdx[id] = len(p.peers) - 1
+}
+
+func (p *peersMap) Delete(id peer.ID) {
+	idx, ok := p.peerIdx[id]
+	if !ok {
+		return
+	}
+	n := len(p.peers)
+	lastPeer := p.peers[n-1]
+	p.peers[idx] = lastPeer
+	p.peerIdx[lastPeer] = idx
+	p.peers[n-1] = ""
+	p.peers = p.peers[:n-1]
+	delete(p.peerIdx, id)
+}
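
A quick sketch of the new iterator shape (Go 1.23 range-over-func; the peer IDs are hypothetical):

	pm := newPeersMap()
	pm.Put(peer.ID("a"))
	pm.Put(peer.ID("b"))
	for pid := range pm.Shuffled() { // visits every peer, starting at a random offset
		fmt.Println(pid)
		break // breaking early makes yield return false and stops the iteration
	}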

View File

@@ -2,8 +2,13 @@ package autonatv2
 import (
 	"context"
+	"encoding/binary"
 	"errors"
 	"fmt"
+	"math"
+	"net"
+	"net/netip"
+	"strings"
 	"sync/atomic"
 	"testing"
 	"time"
@@ -36,11 +41,12 @@ func newAutoNAT(t testing.TB, dialer host.Host, opts ...AutoNATOption) *AutoNAT
 			swarm.WithUDPBlackHoleSuccessCounter(nil),
 			swarm.WithIPv6BlackHoleSuccessCounter(nil))))
 	}
-	an, err := New(h, dialer, opts...)
+	opts = append([]AutoNATOption{withThrottlePeerDuration(0)}, opts...)
+	an, err := New(dialer, opts...)
 	if err != nil {
 		t.Error(err)
 	}
-	an.Start()
+	require.NoError(t, an.Start(h))
 	t.Cleanup(an.Close)
 	return an
 }
@@ -74,7 +80,7 @@ func waitForPeer(t testing.TB, a *AutoNAT) {
 	require.Eventually(t, func() bool {
 		a.mx.Lock()
 		defer a.mx.Unlock()
-		return a.peers.GetRand() != ""
+		return len(a.peers.peers) != 0
 	}, 5*time.Second, 100*time.Millisecond)
 }
@@ -88,7 +94,7 @@ func TestAutoNATPrivateAddr(t *testing.T) {
 	an := newAutoNAT(t, nil)
 	res, err := an.GetReachability(context.Background(), []Request{{Addr: ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")}})
 	require.Equal(t, res, Result{})
-	require.Contains(t, err.Error(), "private address cannot be verified by autonatv2")
+	require.ErrorIs(t, err, ErrPrivateAddrs)
 }

 func TestClientRequest(t *testing.T) {
@@ -154,19 +160,6 @@ func TestClientServerError(t *testing.T) {
 			},
 			errorStr: "invalid msg type",
 		},
-		{
-			handler: func(s network.Stream) {
-				w := pbio.NewDelimitedWriter(s)
-				assert.NoError(t, w.WriteMsg(
-					&pb.Message{Msg: &pb.Message_DialResponse{
-						DialResponse: &pb.DialResponse{
-							Status: pb.DialResponse_E_DIAL_REFUSED,
-						},
-					}},
-				))
-			},
-			errorStr: ErrDialRefused.Error(),
-		},
 	}

 	for i, tc := range tests {
@@ -298,6 +291,49 @@ func TestClientDataRequest(t *testing.T) {
	}
}

func TestAutoNATPrivateAndPublicAddrs(t *testing.T) {
an := newAutoNAT(t, nil)
defer an.Close()
defer an.host.Close()
b := bhost.NewBlankHost(swarmt.GenSwarm(t))
defer b.Close()
idAndConnect(t, an.host, b)
waitForPeer(t, an)
dialerHost := bhost.NewBlankHost(swarmt.GenSwarm(t))
defer dialerHost.Close()
handler := func(s network.Stream) {
w := pbio.NewDelimitedWriter(s)
r := pbio.NewDelimitedReader(s, maxMsgSize)
var msg pb.Message
assert.NoError(t, r.ReadMsg(&msg))
w.WriteMsg(&pb.Message{
Msg: &pb.Message_DialResponse{
DialResponse: &pb.DialResponse{
Status: pb.DialResponse_OK,
DialStatus: pb.DialStatus_E_DIAL_ERROR,
AddrIdx: 0,
},
},
})
s.Close()
}
b.SetStreamHandler(DialProtocol, handler)
privateAddr := ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")
publicAddr := ma.StringCast("/ip4/1.2.3.4/udp/10/quic-v1")
res, err := an.GetReachability(context.Background(),
[]Request{
{Addr: privateAddr},
{Addr: publicAddr},
})
require.NoError(t, err)
require.Equal(t, res.Addr, publicAddr, "%s\n%s", res.Addr, publicAddr)
require.Equal(t, res.Idx, 1)
require.Equal(t, res.Reachability, network.ReachabilityPrivate)
}

func TestClientDialBacks(t *testing.T) {
	an := newAutoNAT(t, nil, allowPrivateAddrs)
	defer an.Close()
@@ -507,7 +543,6 @@ func TestClientDialBacks(t *testing.T) {
 		} else {
 			require.NoError(t, err)
 			require.Equal(t, res.Reachability, network.ReachabilityPublic)
-			require.Equal(t, res.Status, pb.DialStatus_OK)
 		}
 	})
 }
@@ -551,46 +586,6 @@ func TestEventSubscription(t *testing.T) {
 	}, 5*time.Second, 100*time.Millisecond)
 }

-func TestPeersMap(t *testing.T) {
-	emptyPeerID := peer.ID("")
-	t.Run("single_item", func(t *testing.T) {
-		p := newPeersMap()
-		p.Put("peer1")
-		p.Delete("peer1")
-		p.Put("peer1")
-		require.Equal(t, peer.ID("peer1"), p.GetRand())
-		p.Delete("peer1")
-		require.Equal(t, emptyPeerID, p.GetRand())
-	})
-	t.Run("multiple_items", func(t *testing.T) {
-		p := newPeersMap()
-		require.Equal(t, emptyPeerID, p.GetRand())
-		allPeers := make(map[peer.ID]bool)
-		for i := 0; i < 20; i++ {
-			pid := peer.ID(fmt.Sprintf("peer-%d", i))
-			allPeers[pid] = true
-			p.Put(pid)
-		}
-		foundPeers := make(map[peer.ID]bool)
-		for i := 0; i < 1000; i++ {
-			pid := p.GetRand()
-			require.NotEqual(t, emptyPeerID, p)
-			require.True(t, allPeers[pid])
-			foundPeers[pid] = true
-			if len(foundPeers) == len(allPeers) {
-				break
-			}
-		}
-		for pid := range allPeers {
-			p.Delete(pid)
-		}
-		require.Equal(t, emptyPeerID, p.GetRand())
-	})
-}
-
 func TestAreAddrsConsistency(t *testing.T) {
 	c := &client{
 		normalizeMultiaddr: func(a ma.Multiaddr) ma.Multiaddr {
@@ -645,6 +640,12 @@ func TestAreAddrsConsistency(t *testing.T) {
 			dialAddr:  ma.StringCast("/ip6/1::1/udp/123/quic-v1/"),
 			success:   false,
 		},
+		{
+			name:      "dns6",
+			localAddr: ma.StringCast("/dns6/lib.p2p/udp/12345/quic-v1"),
+			dialAddr:  ma.StringCast("/ip4/1.2.3.4/udp/123/quic-v1/"),
+			success:   false,
+		},
 	}
 	for _, tc := range tests {
 		t.Run(tc.name, func(t *testing.T) {
@@ -658,3 +659,173 @@ func TestAreAddrsConsistency(t *testing.T) {
		})
	}
}

func TestPeerMap(t *testing.T) {
pm := newPeersMap()
// Add 1, 2, 3
pm.Put(peer.ID("1"))
pm.Put(peer.ID("2"))
pm.Put(peer.ID("3"))
// Remove 3, 2
pm.Delete(peer.ID("3"))
pm.Delete(peer.ID("2"))
// Add 4
pm.Put(peer.ID("4"))
// Remove 3, 2 again. Should be no op
pm.Delete(peer.ID("3"))
pm.Delete(peer.ID("2"))
contains := []peer.ID{"1", "4"}
elems := make([]peer.ID, 0)
for p := range pm.Shuffled() {
elems = append(elems, p)
}
require.ElementsMatch(t, contains, elems)
}
func FuzzClient(f *testing.F) {
a := newAutoNAT(f, nil, allowPrivateAddrs, WithServerRateLimit(math.MaxInt32, math.MaxInt32, math.MaxInt32, 2))
c := newAutoNAT(f, nil)
idAndWait(f, c, a)
// TODO: Move this to go-multiaddrs
getProto := func(protos []byte) ma.Multiaddr {
protoType := 0
if len(protos) > 0 {
protoType = int(protos[0])
}
port1, port2 := 0, 0
if len(protos) > 1 {
port1 = int(protos[1])
}
if len(protos) > 2 {
port2 = int(protos[2])
}
protoTemplates := []string{
"/tcp/%d/",
"/udp/%d/",
"/udp/%d/quic-v1/",
"/udp/%d/quic-v1/tcp/%d",
"/udp/%d/quic-v1/webtransport/",
"/udp/%d/webrtc/",
"/udp/%d/webrtc-direct/",
"/unix/hello/",
}
s := protoTemplates[protoType%len(protoTemplates)]
port1 %= (1 << 16)
if strings.Count(s, "%d") == 1 {
return ma.StringCast(fmt.Sprintf(s, port1))
}
port2 %= (1 << 16)
return ma.StringCast(fmt.Sprintf(s, port1, port2))
}
getIP := func(ips []byte) ma.Multiaddr {
ipType := 0
if len(ips) > 0 {
ipType = int(ips[0])
}
ips = ips[1:]
var x, y int64
split := 128 / 8
if len(ips) < split {
split = len(ips)
}
var b [8]byte
copy(b[:], ips[:split])
x = int64(binary.LittleEndian.Uint64(b[:]))
clear(b[:])
copy(b[:], ips[split:])
y = int64(binary.LittleEndian.Uint64(b[:]))
var ip netip.Addr
switch ipType % 3 {
case 0:
ip = netip.AddrFrom4([4]byte{byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24)})
return ma.StringCast(fmt.Sprintf("/ip4/%s/", ip))
case 1:
pubIP := net.ParseIP("2005::") // Public IP address
x := int64(binary.LittleEndian.Uint64(pubIP[0:8]))
ip = netip.AddrFrom16([16]byte{
byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24),
byte(x >> 32), byte(x >> 40), byte(x >> 48), byte(x >> 56),
byte(y), byte(y >> 8), byte(y >> 16), byte(y >> 24),
byte(y >> 32), byte(y >> 40), byte(y >> 48), byte(y >> 56),
})
return ma.StringCast(fmt.Sprintf("/ip6/%s/", ip))
default:
ip := netip.AddrFrom16([16]byte{
byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24),
byte(x >> 32), byte(x >> 40), byte(x >> 48), byte(x >> 56),
byte(y), byte(y >> 8), byte(y >> 16), byte(y >> 24),
byte(y >> 32), byte(y >> 40), byte(y >> 48), byte(y >> 56),
})
return ma.StringCast(fmt.Sprintf("/ip6/%s/", ip))
}
}
getAddr := func(addrType int, ips, protos []byte) ma.Multiaddr {
switch addrType % 4 {
case 0:
return getIP(ips).Encapsulate(getProto(protos))
case 1:
return getProto(protos)
case 2:
return nil
default:
return getIP(ips).Encapsulate(getProto(protos))
}
}
getDNSAddr := func(hostNameBytes, protos []byte) ma.Multiaddr {
hostName := strings.ReplaceAll(string(hostNameBytes), "\\", "")
hostName = strings.ReplaceAll(hostName, "/", "")
if hostName == "" {
hostName = "localhost"
}
dnsType := 0
if len(hostNameBytes) > 0 {
dnsType = int(hostNameBytes[0])
}
dnsProtos := []string{"dns", "dns4", "dns6", "dnsaddr"}
da := ma.StringCast(fmt.Sprintf("/%s/%s/", dnsProtos[dnsType%len(dnsProtos)], hostName))
return da.Encapsulate(getProto(protos))
}
const maxAddrs = 100
getAddrs := func(numAddrs int, ips, protos, hostNames []byte) []ma.Multiaddr {
if len(ips) == 0 || len(protos) == 0 || len(hostNames) == 0 {
return nil
}
numAddrs = ((numAddrs % maxAddrs) + maxAddrs) % maxAddrs
addrs := make([]ma.Multiaddr, numAddrs)
ipIdx := 0
protoIdx := 0
for i := range numAddrs {
addrs[i] = getAddr(i, ips[ipIdx:], protos[protoIdx:])
ipIdx = (ipIdx + 1) % len(ips)
protoIdx = (protoIdx + 1) % len(protos)
}
maxDNSAddrs := 10
protoIdx = 0
for i := 0; i < len(hostNames) && i < maxDNSAddrs; i += 2 {
ed := min(i+2, len(hostNames))
addrs = append(addrs, getDNSAddr(hostNames[i:ed], protos[protoIdx:]))
protoIdx = (protoIdx + 1) % len(protos)
}
return addrs
}
// reduce the streamTimeout before running this. TODO: fix this
f.Fuzz(func(_ *testing.T, numAddrs int, ips, protos, hostNames []byte) {
addrs := getAddrs(numAddrs, ips, protos, hostNames)
reqs := make([]Request, len(addrs))
for i, addr := range addrs {
reqs[i] = Request{Addr: addr, SendDialData: true}
}
c.GetReachability(context.Background(), reqs)
})
}

View File

@@ -35,20 +35,20 @@ type normalizeMultiaddrer interface {
 	NormalizeMultiaddr(ma.Multiaddr) ma.Multiaddr
 }

-func newClient(h host.Host) *client {
+func newClient() *client {
+	return &client{
+		dialData:       make([]byte, 4000),
+		dialBackQueues: make(map[uint64]chan ma.Multiaddr),
+	}
+}
+
+func (ac *client) Start(h host.Host) {
 	normalizeMultiaddr := func(a ma.Multiaddr) ma.Multiaddr { return a }
 	if hn, ok := h.(normalizeMultiaddrer); ok {
 		normalizeMultiaddr = hn.NormalizeMultiaddr
 	}
-	return &client{
-		host:               h,
-		dialData:           make([]byte, 4000),
-		normalizeMultiaddr: normalizeMultiaddr,
-		dialBackQueues:     make(map[uint64]chan ma.Multiaddr),
-	}
-}
-
-func (ac *client) Start() {
+	ac.host = h
+	ac.normalizeMultiaddr = normalizeMultiaddr
 	ac.host.SetStreamHandler(DialBackProtocol, ac.handleDialBack)
 }
@@ -109,9 +109,9 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request
 			break
 		// provide dial data if appropriate
 		case msg.GetDialDataRequest() != nil:
-			if err := ac.validateDialDataRequest(reqs, &msg); err != nil {
+			if err := validateDialDataRequest(reqs, &msg); err != nil {
 				s.Reset()
-				return Result{}, fmt.Errorf("invalid dial data request: %w", err)
+				return Result{}, fmt.Errorf("invalid dial data request: %s %w", s.Conn().RemoteMultiaddr(), err)
 			}
 			// dial data request is valid and we want to send data
 			if err := sendDialData(ac.dialData, int(msg.GetDialDataRequest().GetNumBytes()), w, &msg); err != nil {
@@ -136,7 +136,7 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request
 	// E_DIAL_REFUSED has implication for deciding future address verification priorities
 	// wrap a distinct error for convenient errors.Is usage
 	if resp.GetStatus() == pb.DialResponse_E_DIAL_REFUSED {
-		return Result{}, fmt.Errorf("dial request failed: %w", ErrDialRefused)
+		return Result{AllAddrsRefused: true}, nil
 	}
 	return Result{}, fmt.Errorf("dial request failed: response status %d %s", resp.GetStatus(),
 		pb.DialResponse_ResponseStatus_name[int32(resp.GetStatus())])
pb.DialResponse_ResponseStatus_name[int32(resp.GetStatus())]) pb.DialResponse_ResponseStatus_name[int32(resp.GetStatus())])
@@ -147,7 +147,6 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request
 	if int(resp.AddrIdx) >= len(reqs) {
 		return Result{}, fmt.Errorf("invalid response: addr index out of range: %d [0-%d)", resp.AddrIdx, len(reqs))
 	}
-
 	// wait for nonce from the server
 	var dialBackAddr ma.Multiaddr
 	if resp.GetDialStatus() == pb.DialStatus_OK {
@@ -163,7 +162,7 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request
 	return ac.newResult(resp, reqs, dialBackAddr)
 }

-func (ac *client) validateDialDataRequest(reqs []Request, msg *pb.Message) error {
+func validateDialDataRequest(reqs []Request, msg *pb.Message) error {
 	idx := int(msg.GetDialDataRequest().AddrIdx)
 	if idx >= len(reqs) { // invalid address index
 		return fmt.Errorf("addr index out of range: %d [0-%d)", idx, len(reqs))
@@ -179,9 +178,13 @@ func (ac *client) validateDialDataRequest(reqs []Request, msg *pb.Message) error
 func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr ma.Multiaddr) (Result, error) {
 	idx := int(resp.AddrIdx)
+	if idx >= len(reqs) {
+		// This should have been validated by this point, but checking this is cheap.
+		return Result{}, fmt.Errorf("addrs index(%d) greater than len(reqs)(%d)", idx, len(reqs))
+	}
 	addr := reqs[idx].Addr
-	var rch network.Reachability
+	rch := network.ReachabilityUnknown //nolint:ineffassign
 	switch resp.DialStatus {
 	case pb.DialStatus_OK:
 		if !ac.areAddrsConsistent(dialBackAddr, addr) {
@@ -191,17 +194,16 @@ func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr
return Result{}, fmt.Errorf("invalid response: dialBackAddr: %s, respAddr: %s", dialBackAddr, addr) return Result{}, fmt.Errorf("invalid response: dialBackAddr: %s, respAddr: %s", dialBackAddr, addr)
} }
rch = network.ReachabilityPublic rch = network.ReachabilityPublic
case pb.DialStatus_E_DIAL_BACK_ERROR:
if !ac.areAddrsConsistent(dialBackAddr, addr) {
return Result{}, fmt.Errorf("dial-back stream error: dialBackAddr: %s, respAddr: %s", dialBackAddr, addr)
}
// We received the dial back but the server claims the dial back errored.
// As long as we received the correct nonce in dial back it is safe to assume
// that we are public.
rch = network.ReachabilityPublic
case pb.DialStatus_E_DIAL_ERROR: case pb.DialStatus_E_DIAL_ERROR:
rch = network.ReachabilityPrivate rch = network.ReachabilityPrivate
case pb.DialStatus_E_DIAL_BACK_ERROR:
if ac.areAddrsConsistent(dialBackAddr, addr) {
// We received the dial back but the server claims the dial back errored.
// As long as we received the correct nonce in dial back it is safe to assume
// that we are public.
rch = network.ReachabilityPublic
} else {
rch = network.ReachabilityUnknown
}
default: default:
// Unexpected response code. Discard the response and fail. // Unexpected response code. Discard the response and fail.
log.Warnf("invalid status code received in response for addr %s: %d", addr, resp.DialStatus) log.Warnf("invalid status code received in response for addr %s: %d", addr, resp.DialStatus)
@@ -210,8 +212,8 @@ func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr
 	return Result{
 		Addr:         addr,
+		Idx:          idx,
 		Reachability: rch,
-		Status:       resp.DialStatus,
 	}, nil
 }
@@ -307,7 +309,7 @@ func (ac *client) handleDialBack(s network.Stream) {
 }

 func (ac *client) areAddrsConsistent(connLocalAddr, dialedAddr ma.Multiaddr) bool {
-	if connLocalAddr == nil || dialedAddr == nil {
+	if len(connLocalAddr) == 0 || len(dialedAddr) == 0 {
 		return false
 	}
 	connLocalAddr = ac.normalizeMultiaddr(connLocalAddr)
@@ -318,32 +320,31 @@ func (ac *client) areAddrsConsistent(connLocalAddr, dialedAddr ma.Multiaddr) boo
 	if len(localProtos) != len(externalProtos) {
 		return false
 	}
-	for i := 0; i < len(localProtos); i++ {
+	for i, lp := range localProtos {
+		ep := externalProtos[i]
 		if i == 0 {
-			switch externalProtos[i].Code {
+			switch ep.Code {
 			case ma.P_DNS, ma.P_DNSADDR:
-				if localProtos[i].Code == ma.P_IP4 || localProtos[i].Code == ma.P_IP6 {
+				if lp.Code == ma.P_IP4 || lp.Code == ma.P_IP6 {
 					continue
 				}
 				return false
 			case ma.P_DNS4:
-				if localProtos[i].Code == ma.P_IP4 {
+				if lp.Code == ma.P_IP4 {
 					continue
 				}
 				return false
 			case ma.P_DNS6:
-				if localProtos[i].Code == ma.P_IP6 {
+				if lp.Code == ma.P_IP6 {
 					continue
 				}
 				return false
 			}
-			if localProtos[i].Code != externalProtos[i].Code {
-				return false
-			}
-		} else {
-			if localProtos[i].Code != externalProtos[i].Code {
+			if lp.Code != ep.Code {
 				return false
 			}
+		} else if lp.Code != ep.Code {
+			return false
 		}
 	}
 	return true

View File

@@ -13,6 +13,7 @@ type autoNATSettings struct {
 	now                                  func() time.Time
 	amplificatonAttackPreventionDialWait time.Duration
 	metricsTracer                        MetricsTracer
+	throttlePeerDuration                 time.Duration
 }

 func defaultSettings() *autoNATSettings {
@@ -25,6 +26,7 @@ func defaultSettings() *autoNATSettings {
 		dataRequestPolicy:                    amplificationAttackPrevention,
 		amplificatonAttackPreventionDialWait: 3 * time.Second,
 		now:                                  time.Now,
+		throttlePeerDuration:                 defaultThrottlePeerDuration,
 	}
 }
@@ -65,3 +67,10 @@ func withAmplificationAttackPreventionDialWait(d time.Duration) AutoNATOption {
return nil return nil
} }
} }
func withThrottlePeerDuration(d time.Duration) AutoNATOption {
return func(s *autoNATSettings) error {
s.throttlePeerDuration = d
return nil
}
}
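The new option composes like the existing functional options. Inside the package's tests it would be applied through the newAutoNAT helper seen in the hunks below (the duration here is illustrative):

    // Throttle repeat requests from the same peer for the given window.
    an := newAutoNAT(t, nil, withThrottlePeerDuration(2*time.Minute))
    defer an.Close()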

View File

@@ -59,10 +59,9 @@ type server struct {
 	allowPrivateAddrs bool
 }

-func newServer(host, dialer host.Host, s *autoNATSettings) *server {
+func newServer(dialer host.Host, s *autoNATSettings) *server {
 	return &server{
 		dialerHost:                           dialer,
-		host:                                 host,
 		dialDataRequestPolicy:                s.dataRequestPolicy,
 		amplificatonAttackPreventionDialWait: s.amplificatonAttackPreventionDialWait,
 		allowPrivateAddrs:                    s.allowPrivateAddrs,
@@ -79,7 +78,8 @@ func newServer(host, dialer host.Host, s *autoNATSettings) *server {
 }

 // Enable attaches the stream handler to the host.
-func (as *server) Start() {
+func (as *server) Start(h host.Host) {
+	as.host = h
 	as.host.SetStreamHandler(DialProtocol, as.handleDialRequest)
 }
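Dropping the host argument from newServer makes construction two-phase: the server is built knowing only its dialer, and the listening host is attached when Start runs. A sketch of the intended wiring (identifiers from this diff; h stands for the host constructed later):

    srv := newServer(dialerHost, defaultSettings())
    // ... construct the libp2p host h ...
    srv.Start(h) // attaches the DialProtocol stream handler to h

This ordering is what allows the AutoNAT service to exist before the host it will eventually serve on.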

View File

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"io"
 	"math"
+	"strings"
 	"sync"
 	"sync/atomic"
 	"testing"
@@ -46,8 +47,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) {
 		idAndWait(t, c, an)

 		res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), true))
-		require.ErrorIs(t, err, ErrDialRefused)
-		require.Equal(t, Result{}, res)
+		require.NoError(t, err)
+		require.Equal(t, Result{AllAddrsRefused: true}, res)
 	})

 	t.Run("black holed addr", func(t *testing.T) {
@@ -64,8 +65,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) {
 			Addr:         ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1"),
 			SendDialData: true,
 		}})
-		require.ErrorIs(t, err, ErrDialRefused)
-		require.Equal(t, Result{}, res)
+		require.NoError(t, err)
+		require.Equal(t, Result{AllAddrsRefused: true}, res)
 	})

 	t.Run("private addrs", func(t *testing.T) {
@@ -76,8 +77,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) {
 		idAndWait(t, c, an)

 		res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), true))
-		require.ErrorIs(t, err, ErrDialRefused)
-		require.Equal(t, Result{}, res)
+		require.NoError(t, err)
+		require.Equal(t, Result{AllAddrsRefused: true}, res)
 	})

 	t.Run("relay addrs", func(t *testing.T) {
@@ -89,8 +90,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) {
 		res, err := c.GetReachability(context.Background(), newTestRequests(
 			[]ma.Multiaddr{ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/1/p2p/%s/p2p-circuit/p2p/%s", c.host.ID(), c.srv.dialerHost.ID()))}, true))
-		require.ErrorIs(t, err, ErrDialRefused)
-		require.Equal(t, Result{}, res)
+		require.NoError(t, err)
+		require.Equal(t, Result{AllAddrsRefused: true}, res)
 	})

 	t.Run("no addr", func(t *testing.T) {
@@ -113,8 +114,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) {
 		idAndWait(t, c, an)

 		res, err := c.GetReachability(context.Background(), newTestRequests(addrs, true))
-		require.ErrorIs(t, err, ErrDialRefused)
-		require.Equal(t, Result{}, res)
+		require.NoError(t, err)
+		require.Equal(t, Result{AllAddrsRefused: true}, res)
 	})

 	t.Run("msg too large", func(t *testing.T) {
@@ -135,7 +136,6 @@ func TestServerInvalidAddrsRejected(t *testing.T) {
 		require.ErrorIs(t, err, network.ErrReset)
 		require.Equal(t, Result{}, res)
 	})
-
 }
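Every sub-test above encodes the same API shift: a server that refuses to dial all requested addresses no longer surfaces ErrDialRefused; the call now succeeds and the refusal is reported in the result. A sketch of the new caller-side check (reqs is illustrative):

    res, err := c.GetReachability(context.Background(), reqs)
    if err != nil {
        // transport-level failure, e.g. a reset stream
        return err
    }
    if res.AllAddrsRefused {
        // the server declined every address; try another autonatv2
        // server rather than treating the probe itself as failed
    }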
 func TestServerDataRequest(t *testing.T) {
@@ -178,8 +178,8 @@ func TestServerDataRequest(t *testing.T) {
 	require.Equal(t, Result{
 		Addr:         quicAddr,
-		Idx:          0,
 		Reachability: network.ReachabilityPublic,
-		Status:       pb.DialStatus_OK,
 	}, res)

 	// Small messages should be rejected for dial data
@@ -191,14 +191,11 @@ func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) {
 func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) {
 	const concurrentRequests = 5

-	// server will skip all tcp addresses
-	dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP))
-	doneChan := make(chan struct{})
-	an := newAutoNAT(t, dialer, allowPrivateAddrs, withDataRequestPolicy(
+	stallChan := make(chan struct{})
+	an := newAutoNAT(t, nil, allowPrivateAddrs, withDataRequestPolicy(
 		// stall all allowed requests
 		func(_, dialAddr ma.Multiaddr) bool {
-			<-doneChan
+			<-stallChan
 			return true
 		}),
 		WithServerRateLimit(10, 10, 10, concurrentRequests),
@@ -207,16 +204,18 @@ func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) {
 	defer an.Close()
 	defer an.host.Close()

-	c := newAutoNAT(t, nil, allowPrivateAddrs)
+	// server will skip all tcp addresses
+	dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP))
+	c := newAutoNAT(t, dialer, allowPrivateAddrs)
 	defer c.Close()
 	defer c.host.Close()

 	idAndWait(t, c, an)

 	errChan := make(chan error)
-	const N = 10
-	// num concurrentRequests will stall and N will fail
-	for i := 0; i < concurrentRequests+N; i++ {
+	const n = 10
+	// num concurrentRequests will stall and n will fail
+	for i := 0; i < concurrentRequests+n; i++ {
 		go func() {
 			_, err := c.GetReachability(context.Background(), []Request{{Addr: c.host.Addrs()[0], SendDialData: false}})
 			errChan <- err
@@ -224,17 +223,20 @@ func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) {
 	}

 	// check N failures
-	for i := 0; i < N; i++ {
+	for i := 0; i < n; i++ {
 		select {
 		case err := <-errChan:
 			require.Error(t, err)
+			if !strings.Contains(err.Error(), "stream reset") && !strings.Contains(err.Error(), "E_REQUEST_REJECTED") {
+				t.Fatalf("invalid error: %s expected: stream reset or E_REQUEST_REJECTED", err)
+			}
 		case <-time.After(10 * time.Second):
-			t.Fatalf("expected %d errors: got: %d", N, i)
+			t.Fatalf("expected %d errors: got: %d", n, i)
 		}
 	}

+	close(stallChan) // complete stalled requests
 	// check concurrentRequests failures, as we won't send dial data
-	close(doneChan)
 	for i := 0; i < concurrentRequests; i++ {
 		select {
 		case err := <-errChan:
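The reworked test pins the per-peer concurrency cap: the stalling data-request policy holds concurrentRequests streams open, so each additional request from the same peer must fail fast with a reset or E_REQUEST_REJECTED. A hedged reading of the option as exercised here (the first three arguments are the server's rate limits; their exact semantics are not shown in this hunk):

    an := newAutoNAT(t, nil, allowPrivateAddrs,
        // last arg: max in-flight requests per peer; first three: rate limits
        WithServerRateLimit(10, 10, 10, concurrentRequests),
    )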
@@ -290,8 +292,8 @@ func TestServerDataRequestJitter(t *testing.T) {
 		require.Equal(t, Result{
 			Addr:         quicAddr,
-			Idx:          0,
 			Reachability: network.ReachabilityPublic,
-			Status:       pb.DialStatus_OK,
 		}, res)
 		if took > 500*time.Millisecond {
 			return
@@ -320,8 +322,8 @@ func TestServerDial(t *testing.T) {
 		require.NoError(t, err)
 		require.Equal(t, Result{
 			Addr:         unreachableAddr,
-			Idx:          0,
 			Reachability: network.ReachabilityPrivate,
-			Status:       pb.DialStatus_E_DIAL_ERROR,
 		}, res)
 	})

@@ -330,16 +332,16 @@ func TestServerDial(t *testing.T) {
 		require.NoError(t, err)
 		require.Equal(t, Result{
 			Addr:         hostAddrs[0],
-			Idx:          0,
 			Reachability: network.ReachabilityPublic,
-			Status:       pb.DialStatus_OK,
 		}, res)
 		for _, addr := range c.host.Addrs() {
 			res, err := c.GetReachability(context.Background(), newTestRequests([]ma.Multiaddr{addr}, false))
 			require.NoError(t, err)
 			require.Equal(t, Result{
 				Addr:         addr,
-				Idx:          0,
 				Reachability: network.ReachabilityPublic,
-				Status:       pb.DialStatus_OK,
 			}, res)
 		}
 	})

@@ -347,12 +349,8 @@ func TestServerDial(t *testing.T) {
 	t.Run("dialback error", func(t *testing.T) {
 		c.host.RemoveStreamHandler(DialBackProtocol)
 		res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), false))
-		require.NoError(t, err)
-		require.Equal(t, Result{
-			Addr:         hostAddrs[0],
-			Reachability: network.ReachabilityUnknown,
-			Status:       pb.DialStatus_E_DIAL_BACK_ERROR,
-		}, res)
+		require.ErrorContains(t, err, "dial-back stream error")
+		require.Equal(t, Result{}, res)
 	})
 }
@@ -396,7 +394,6 @@ func TestRateLimiter(t *testing.T) {
cl.AdvanceBy(10 * time.Second) cl.AdvanceBy(10 * time.Second)
require.True(t, r.Accept("peer3")) require.True(t, r.Accept("peer3"))
} }
func TestRateLimiterConcurrentRequests(t *testing.T) { func TestRateLimiterConcurrentRequests(t *testing.T) {
@@ -558,22 +555,23 @@ func TestServerDataRequestWithAmplificationAttackPrevention(t *testing.T) {
 	require.NoError(t, err)
 	require.Equal(t, Result{
 		Addr:         quicv4Addr,
-		Idx:          0,
 		Reachability: network.ReachabilityPublic,
-		Status:       pb.DialStatus_OK,
 	}, res)

 	// ipv6 address should require dial data
 	_, err = c.GetReachability(context.Background(), []Request{{Addr: quicv6Addr, SendDialData: false}})
 	require.Error(t, err)
-	require.ErrorContains(t, err, "invalid dial data request: low priority addr")
+	require.ErrorContains(t, err, "invalid dial data request")
+	require.ErrorContains(t, err, "low priority addr")

 	// ipv6 address should work fine with dial data
 	res, err = c.GetReachability(context.Background(), []Request{{Addr: quicv6Addr, SendDialData: true}})
 	require.NoError(t, err)
 	require.Equal(t, Result{
 		Addr:         quicv6Addr,
-		Idx:          0,
 		Reachability: network.ReachabilityPublic,
-		Status:       pb.DialStatus_OK,
 	}, res)
 }
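The split assertion mirrors the amplification-attack guard: probing an address of a different IP version than the connection (here, presumably an IPv6 address over an IPv4 connection) is treated as a "low priority addr", so the request is rejected unless the client offers dial data. A sketch of the two request shapes (addresses as in the test; ctx is illustrative):

    // Rejected: cross-version probe without dial data.
    _, err := c.GetReachability(ctx, []Request{{Addr: quicv6Addr, SendDialData: false}})
    // err mentions "invalid dial data request" and "low priority addr"

    // Accepted: the client agrees to upload dial data before the dial back.
    res, err := c.GetReachability(ctx, []Request{{Addr: quicv6Addr, SendDialData: true}})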