Files
go-libp2p/p2p/host/basic/addrs_manager.go
2025-08-25 12:24:40 -07:00

677 lines
22 KiB
Go

package basichost
import (
"context"
"errors"
"fmt"
"io"
"net"
"slices"
"sync"
"sync/atomic"
"time"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/p2p/host/basic/internal/backoff"
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
"github.com/libp2p/go-netroute"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/prometheus/client_golang/prometheus"
)
// maxObservedAddrsPerListenAddr caps how many observed addresses
// appendObservedAddrs takes for any single listen (or resolved) address.
const maxObservedAddrsPerListenAddr = 3
// addrChangeTickrInterval is the interval at which the background loop
// recomputes the host addrs when no other event triggers an update.
// NOTE(review): declared as a var rather than a const, presumably so tests can
// shorten it — confirm.
var addrChangeTickrInterval = 5 * time.Second
// ObservedAddrsManager maps our local listen addrs to externally observed addrs.
type ObservedAddrsManager interface {
// Addrs returns observed addresses. HolePunchAddrs calls it with 1 to include
// addresses that have only a single observer, so minObservers is presumably
// the minimum observer count required — confirm against the implementation.
Addrs(minObservers int) []ma.Multiaddr
// AddrsFor returns the observed addresses associated with the given local
// address.
AddrsFor(local ma.Multiaddr) []ma.Multiaddr
}
// hostAddrs is a snapshot of the address sets computed by updateAddrs.
type hostAddrs struct {
// addrs is the result of getAddrs: the addresses returned to users after the
// address factory and cert hashes have been applied.
addrs []ma.Multiaddr
// localAddrs is the result of getLocalAddrs: the direct (non-circuit) addresses.
localAddrs []ma.Multiaddr
// reachableAddrs, unreachableAddrs and unknownAddrs are the reachability
// tracker's view of localAddrs. They stay empty when no tracker is configured.
reachableAddrs []ma.Multiaddr
unreachableAddrs []ma.Multiaddr
unknownAddrs []ma.Multiaddr
// relayAddrs is a copy of the relay addresses passed to updateAddrs.
relayAddrs []ma.Multiaddr
}
// addrsManager computes the host's addresses, caches the latest snapshot and
// notifies interested parties when the addresses or their reachability change.
type addrsManager struct {
// bus is used to subscribe to auto relay and reachability events and to emit
// EvtHostReachableAddrsChanged.
bus event.Bus
// natManager is optional. When set it is queried for NAT mappings of the
// listen addrs and is closed by Close.
natManager NATManager
// addrsFactory is applied to the candidate address list before it is returned
// to users.
addrsFactory AddrsFactory
// listenAddrs returns the addresses the host is currently listening on.
listenAddrs func() []ma.Multiaddr
// addCertHashes adds certhashes to the addresses that need them.
addCertHashes func([]ma.Multiaddr) []ma.Multiaddr
// observedAddrsManager is optional. When set, observed addresses are added to
// the local addrs.
observedAddrsManager ObservedAddrsManager
// interfaceAddrs caches the network interface addresses.
interfaceAddrs *interfaceAddrsCache
// addrsReachabilityTracker is nil when no autonatv2 client was provided to
// newAddrsManager.
addrsReachabilityTracker *addrsReachabilityTracker
// addrsUpdatedChan is notified when addrs change. This is provided by the caller.
addrsUpdatedChan chan struct{}
// triggerAddrsUpdateChan is used to trigger an addresses update. The channel
// sent on it is closed by the background loop once the update has been applied.
triggerAddrsUpdateChan chan chan struct{}
// started is used to check whether the addrsManager has started.
started atomic.Bool
// triggerReachabilityUpdate is notified when reachable addrs are updated.
triggerReachabilityUpdate chan struct{}
// hostReachability holds the most recent reachability received from
// EvtLocalReachabilityChanged; it starts as ReachabilityUnknown.
hostReachability atomic.Pointer[network.Reachability]
// addrsMx guards currentAddrs.
addrsMx sync.RWMutex
currentAddrs hostAddrs
// wg tracks the background goroutine; ctx is cancelled by Close to stop it.
wg sync.WaitGroup
ctx context.Context
ctxCancel context.CancelFunc
}
// newAddrsManager builds an addrsManager from its dependencies. The manager does
// nothing until Start is called.
//
// The host reachability starts out as ReachabilityUnknown. An address
// reachability tracker is created only when client is non-nil; it gets a metrics
// tracker only when enableMetrics is set. The returned error is currently
// always nil.
func newAddrsManager(
	bus event.Bus,
	natmgr NATManager,
	addrsFactory AddrsFactory,
	listenAddrs func() []ma.Multiaddr,
	addCertHashes func([]ma.Multiaddr) []ma.Multiaddr,
	observedAddrsManager ObservedAddrsManager,
	addrsUpdatedChan chan struct{},
	client autonatv2Client,
	enableMetrics bool,
	registerer prometheus.Registerer,
) (*addrsManager, error) {
	mgrCtx, stop := context.WithCancel(context.Background())
	mgr := &addrsManager{
		ctx:                       mgrCtx,
		ctxCancel:                 stop,
		bus:                       bus,
		natManager:                natmgr,
		addrsFactory:              addrsFactory,
		listenAddrs:               listenAddrs,
		addCertHashes:             addCertHashes,
		observedAddrsManager:      observedAddrsManager,
		interfaceAddrs:            &interfaceAddrsCache{},
		addrsUpdatedChan:          addrsUpdatedChan,
		triggerAddrsUpdateChan:    make(chan chan struct{}, 1),
		triggerReachabilityUpdate: make(chan struct{}, 1),
	}

	initial := network.ReachabilityUnknown
	mgr.hostReachability.Store(&initial)

	// Without an autonatv2 client there is nothing to track.
	if client == nil {
		return mgr, nil
	}

	var tracker MetricsTracker
	if enableMetrics {
		tracker = newMetricsTracker(withRegisterer(registerer))
	}
	mgr.addrsReachabilityTracker = newAddrsReachabilityTracker(client, mgr.triggerReachabilityUpdate, nil, tracker)
	return mgr, nil
}
// Start starts the address reachability tracker, if one is configured, and then
// sets up and launches the background worker.
//
// It returns an error if the tracker fails to start or the worker cannot be set
// up. The tracker error is wrapped with %w so callers can inspect it with
// errors.Is / errors.As.
func (a *addrsManager) Start() error {
	if a.addrsReachabilityTracker != nil {
		if err := a.addrsReachabilityTracker.Start(); err != nil {
			return fmt.Errorf("error starting addrs reachability tracker: %w", err)
		}
	}
	return a.startBackgroundWorker()
}
// Close cancels the manager's context, closes the NAT manager and the
// reachability tracker when they are configured, and then waits for the
// background goroutine to exit. Errors from closing are logged, not returned.
func (a *addrsManager) Close() {
	a.ctxCancel()

	if a.natManager != nil {
		if err := a.natManager.Close(); err != nil {
			log.Warn("error closing natmgr", "err", err)
		}
	}

	if tracker := a.addrsReachabilityTracker; tracker != nil {
		if err := tracker.Close(); err != nil {
			log.Warn("error closing addrs reachability tracker", "err", err)
		}
	}

	a.wg.Wait()
}
// NetNotifee returns a network.Notifiee that triggers a synchronous address
// update whenever a listener is opened or closed.
func (a *addrsManager) NetNotifee() network.Notifiee {
	onListenersChanged := func(network.Network, ma.Multiaddr) {
		a.updateAddrsSync()
	}
	return &network.NotifyBundle{
		ListenF:      onListenersChanged,
		ListenCloseF: onListenersChanged,
	}
}
// updateAddrsSync asks the background loop to recompute the addresses and blocks
// until the loop signals completion or the manager is closed.
func (a *addrsManager) updateAddrsSync() {
	// Before Start nobody reads from triggerAddrsUpdateChan, so waiting here
	// would deadlock. Start recomputes the addresses itself, so skipping is safe.
	if !a.started.Load() {
		return
	}

	done := make(chan struct{})

	// Hand the completion channel to the background loop...
	select {
	case a.triggerAddrsUpdateChan <- done:
	case <-a.ctx.Done():
		return
	}

	// ...and wait until it closes the channel after the update.
	select {
	case <-done:
	case <-a.ctx.Done():
	}
}
// startBackgroundWorker subscribes to the events the manager depends on,
// computes the initial address set and launches the background goroutine.
//
// On failure every subscription created so far is closed, and the returned
// error wraps the cause with %w (joined with any close errors), so callers can
// inspect it with errors.Is / errors.As.
func (a *addrsManager) startBackgroundWorker() error {
	autoRelayAddrsSub, err := a.bus.Subscribe(new(event.EvtAutoRelayAddrsUpdated), eventbus.Name("addrs-manager autorelay sub"))
	if err != nil {
		return fmt.Errorf("error subscribing to auto relay addrs: %w", err)
	}
	// mc collects the subscriptions that must be released if a later step fails.
	mc := multiCloser{autoRelayAddrsSub}

	autonatReachabilitySub, err := a.bus.Subscribe(new(event.EvtLocalReachabilityChanged), eventbus.Name("addrs-manager autonatv1 sub"))
	if err != nil {
		return errors.Join(
			fmt.Errorf("error subscribing to autonat reachability: %w", err),
			mc.Close(),
		)
	}
	mc = append(mc, autonatReachabilitySub)

	emitter, err := a.bus.Emitter(new(event.EvtHostReachableAddrsChanged), eventbus.Stateful)
	if err != nil {
		return errors.Join(
			fmt.Errorf("error creating reachability subscriber: %w", err),
			mc.Close(),
		)
	}

	var relayAddrs []ma.Multiaddr
	// update relay addrs in case we're private
	// Non-blocking: only consume an event that is already pending.
	select {
	case e := <-autoRelayAddrsSub.Out():
		if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok {
			relayAddrs = slices.Clone(evt.RelayAddrs)
		}
	default:
	}

	// Likewise pick up an already pending reachability event, if any.
	select {
	case e := <-autonatReachabilitySub.Out():
		if evt, ok := e.(event.EvtLocalReachabilityChanged); ok {
			a.hostReachability.Store(&evt.Reachability)
		}
	default:
	}

	// this ensures that listens concurrent with Start are reflected correctly after Start exits.
	a.started.Store(true)
	// update addresses before starting the worker loop. This ensures that any address updates
	// before calling addrsManager.Start are correctly reported after Start returns.
	a.updateAddrs(relayAddrs)

	a.wg.Add(1)
	go a.background(autoRelayAddrsSub, autonatReachabilitySub, emitter, relayAddrs)
	return nil
}
// background is the manager's event loop. Each iteration recomputes the
// addresses, releases any caller waiting in updateAddrsSync, emits change
// notifications, and then blocks until the next trigger: the periodic ticker,
// an explicit update request, a reachability tracker update, an auto relay or
// reachability event, or cancellation of the manager's context.
//
// On exit it closes both subscriptions and the emitter and marks the wait group
// as done.
func (a *addrsManager) background(autoRelayAddrsSub, autonatReachabilitySub event.Subscription,
emitter event.Emitter, relayAddrs []ma.Multiaddr,
) {
defer a.wg.Done()
defer func() {
err := autoRelayAddrsSub.Close()
if err != nil {
log.Warn("error closing auto relay addrs sub", "err", err)
}
err = autonatReachabilitySub.Close()
if err != nil {
log.Warn("error closing autonat reachability sub", "err", err)
}
err = emitter.Close()
if err != nil {
log.Warn("error closing host reachability emitter", "err", err)
}
}()
ticker := time.NewTicker(addrChangeTickrInterval)
defer ticker.Stop()
// previousAddrs starts as the zero value, so the first iteration compares the
// freshly computed addrs against empty sets.
var previousAddrs hostAddrs
// notifCh is non-nil only when the last wake-up came from updateAddrsSync.
var notifCh chan struct{}
for {
currAddrs := a.updateAddrs(relayAddrs)
// Unblock the updateAddrsSync caller now that the new addrs are stored.
if notifCh != nil {
close(notifCh)
notifCh = nil
}
a.notifyAddrsChanged(emitter, previousAddrs, currAddrs)
previousAddrs = currAddrs
select {
case <-ticker.C:
case notifCh = <-a.triggerAddrsUpdateChan:
case <-a.triggerReachabilityUpdate:
case e := <-autoRelayAddrsSub.Out():
if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok {
relayAddrs = slices.Clone(evt.RelayAddrs)
}
case e := <-autonatReachabilitySub.Out():
if evt, ok := e.(event.EvtLocalReachabilityChanged); ok {
a.hostReachability.Store(&evt.Reachability)
}
case <-a.ctx.Done():
return
}
}
}
// updateAddrs recomputes every address set of the host, stores a copy in
// a.currentAddrs under addrsMx and returns the freshly computed snapshot.
//
// Call it only from the background goroutine or from Start; concurrent callers
// could overwrite each other and leave stale addrs behind.
func (a *addrsManager) updateAddrs(relayAddrs []ma.Multiaddr) hostAddrs {
	local := a.getLocalAddrs()

	var reachable, unreachable, unknown []ma.Multiaddr
	if a.addrsReachabilityTracker != nil {
		reachable, unreachable, unknown = a.getConfirmedAddrs(local)
	}

	relays := slices.Clone(relayAddrs)
	// getAddrs mutates its first argument, so pass it a copy of the local addrs.
	snapshot := hostAddrs{
		addrs:            a.getAddrs(slices.Clone(local), relays),
		localAddrs:       local,
		reachableAddrs:   reachable,
		unreachableAddrs: unreachable,
		unknownAddrs:     unknown,
		relayAddrs:       relays,
	}

	// Copy into the cached slices, reusing their backing arrays, so the cache
	// never aliases the snapshot handed back to the caller.
	a.addrsMx.Lock()
	cached := &a.currentAddrs
	cached.addrs = append(cached.addrs[:0], snapshot.addrs...)
	cached.localAddrs = append(cached.localAddrs[:0], snapshot.localAddrs...)
	cached.reachableAddrs = append(cached.reachableAddrs[:0], snapshot.reachableAddrs...)
	cached.unreachableAddrs = append(cached.unreachableAddrs[:0], snapshot.unreachableAddrs...)
	cached.unknownAddrs = append(cached.unknownAddrs[:0], snapshot.unknownAddrs...)
	cached.relayAddrs = append(cached.relayAddrs[:0], snapshot.relayAddrs...)
	a.addrsMx.Unlock()

	return snapshot
}
// notifyAddrsChanged compares two address snapshots and sends the notifications
// that follow from their differences:
//   - changed local addrs are forwarded to the reachability tracker, if any;
//   - changed host addrs trigger a non-blocking send on addrsUpdatedChan;
//   - changed reachable/unreachable/unknown sets are emitted as
//     EvtHostReachableAddrsChanged.
//
// Emit errors are logged, not returned.
func (a *addrsManager) notifyAddrsChanged(emitter event.Emitter, previous, current hostAddrs) {
	if areAddrsDifferent(previous.localAddrs, current.localAddrs) {
		log.Debug("host local addresses updated", "addrs", current.localAddrs)
		if a.addrsReachabilityTracker != nil {
			a.addrsReachabilityTracker.UpdateAddrs(current.localAddrs)
		}
	}

	if areAddrsDifferent(previous.addrs, current.addrs) {
		// Log the set that actually changed (addrs, not localAddrs).
		log.Debug("host addresses updated", "addrs", current.addrs)
		// Non-blocking: skip the send if the receiver is not ready for it.
		select {
		case a.addrsUpdatedChan <- struct{}{}:
		default:
		}
	}

	// We *must* send both reachability changed and addrs changed events from the
	// same goroutine to ensure correct ordering
	// Consider the events:
	// 	- addr x discovered
	// 	- addr x is reachable
	// 	- addr x removed
	// We must send these events in the same order. It'll be confusing for consumers
	// if the reachable event is received after the addr removed event.
	if areAddrsDifferent(previous.reachableAddrs, current.reachableAddrs) ||
		areAddrsDifferent(previous.unreachableAddrs, current.unreachableAddrs) ||
		areAddrsDifferent(previous.unknownAddrs, current.unknownAddrs) {
		log.Debug("host reachable addrs updated",
			"reachable", current.reachableAddrs,
			"unreachable", current.unreachableAddrs,
			"unknown", current.unknownAddrs)
		// Clone so event consumers cannot alias the manager's slices.
		if err := emitter.Emit(event.EvtHostReachableAddrsChanged{
			Reachable:   slices.Clone(current.reachableAddrs),
			Unreachable: slices.Clone(current.unreachableAddrs),
			Unknown:     slices.Clone(current.unknownAddrs),
		}); err != nil {
			log.Error("error sending host reachable addrs changed event", "err", err)
		}
	}
}
// Addrs returns the node's dialable addresses, both public and private.
// When the node's reachability is private and relay addresses are available,
// the public addresses are replaced by the relay addresses (see getAddrs).
func (a *addrsManager) Addrs() []ma.Multiaddr {
	// Copy under the read lock; getAddrs mutates the slice it is given.
	direct, relays := func() ([]ma.Multiaddr, []ma.Multiaddr) {
		a.addrsMx.RLock()
		defer a.addrsMx.RUnlock()
		return slices.Clone(a.currentAddrs.localAddrs), slices.Clone(a.currentAddrs.relayAddrs)
	}()
	return a.getAddrs(direct, relays)
}
// getAddrs derives the node's dialable addresses from the given local and relay
// addresses. It mutates localAddrs, so callers must pass a slice they own.
//
// The result has been passed through the address factory, deduplicated, given
// cert hashes and sorted with Multiaddr.Compare.
func (a *addrsManager) getAddrs(localAddrs []ma.Multiaddr, relayAddrs []ma.Multiaddr) []ma.Multiaddr {
	candidates := localAddrs

	// A private node with relay addresses swaps its public addresses for them.
	if rch := a.hostReachability.Load(); rch != nil && *rch == network.ReachabilityPrivate && len(relayAddrs) > 0 {
		candidates = slices.DeleteFunc(candidates, manet.IsPublicAddr)
		candidates = append(candidates, relayAddrs...)
	}

	// Copy the factory's output before mutating it; consumers can modify the
	// slice elements.
	fromFactory := slices.Clone(a.addrsFactory(candidates))

	// The user-supplied factory may have added addresses that still need
	// cert hashes.
	result := a.addCertHashes(ma.Unique(fromFactory))
	slices.SortFunc(result, func(x, y ma.Multiaddr) int { return x.Compare(y) })
	return result
}
// HolePunchAddrs returns the host's direct public addresses, whether reachable
// or not, for use in hole punching. Observed addresses are included when an
// observed addrs manager is configured.
func (a *addrsManager) HolePunchAddrs() []ma.Multiaddr {
	candidates := slices.Clone(a.addrsFactory(a.DirectAddrs()))

	// The direct addrs may favour NAT mappings over observed addresses, so add
	// the observed ones explicitly, even those with a single observer.
	if a.observedAddrsManager != nil {
		candidates = append(candidates, a.observedAddrsManager.Addrs(1)...)
	}

	notPublic := func(addr ma.Multiaddr) bool { return !manet.IsPublicAddr(addr) }
	return slices.DeleteFunc(ma.Unique(candidates), notPublic)
}
// DirectAddrs returns a copy of the cached local addresses, i.e. every address
// the host is listening on except circuit addresses.
func (a *addrsManager) DirectAddrs() []ma.Multiaddr {
	a.addrsMx.RLock()
	direct := slices.Clone(a.currentAddrs.localAddrs)
	a.addrsMx.RUnlock()
	return direct
}
// ConfirmedAddrs returns copies of the cached reachable, unreachable and unknown
// address sets, in that order.
func (a *addrsManager) ConfirmedAddrs() (reachable []ma.Multiaddr, unreachable []ma.Multiaddr, unknown []ma.Multiaddr) {
	a.addrsMx.RLock()
	defer a.addrsMx.RUnlock()

	cached := &a.currentAddrs
	reachable = slices.Clone(cached.reachableAddrs)
	unreachable = slices.Clone(cached.unreachableAddrs)
	unknown = slices.Clone(cached.unknownAddrs)
	return reachable, unreachable, unknown
}
// getConfirmedAddrs returns the reachability tracker's reachable, unreachable
// and unknown address sets, each restricted to addresses present in localAddrs.
// It must only be called when addrsReachabilityTracker is non-nil.
func (a *addrsManager) getConfirmedAddrs(localAddrs []ma.Multiaddr) (reachableAddrs, unreachableAddrs, unknownAddrs []ma.Multiaddr) {
	reachableAddrs, unreachableAddrs, unknownAddrs = a.addrsReachabilityTracker.ConfirmedAddrs()

	// Drop anything that is no longer one of our local addrs.
	reachableAddrs = removeNotInSource(reachableAddrs, localAddrs)
	unreachableAddrs = removeNotInSource(unreachableAddrs, localAddrs)
	unknownAddrs = removeNotInSource(unknownAddrs, localAddrs)
	return reachableAddrs, unreachableAddrs, unknownAddrs
}
// p2pCircuitAddr is the bare /p2p-circuit multiaddr. getLocalAddrs removes
// addresses equal to it from the local address list.
var p2pCircuitAddr = ma.StringCast("/p2p-circuit")
// getLocalAddrs computes the host's direct addresses from the current listen
// addresses: resolved primary interface addresses, NAT mappings (when a NAT
// manager is set) and observed addresses (when an observed addrs manager is
// set). The result is deduplicated and sorted with Multiaddr.Compare. It
// returns nil when the host is not listening on anything.
func (a *addrsManager) getLocalAddrs() []ma.Multiaddr {
	listening := a.listenAddrs()
	if len(listening) == 0 {
		return nil
	}

	collected := a.appendPrimaryInterfaceAddrs(make([]ma.Multiaddr, 0, 8), listening)
	if a.natManager != nil {
		collected = a.appendNATAddrs(collected, listening)
	}
	if a.observedAddrsManager != nil {
		collected = a.appendObservedAddrs(collected, listening, a.interfaceAddrs.All())
	}

	// Drop two kinds of addresses that are useless for dialing:
	//   - the bare "/p2p-circuit" address reported by the circuit listener; users
	//     need to manage their circuit addresses themselves, or use AutoRelay;
	//   - any address whose IP is unspecified.
	collected = slices.DeleteFunc(collected, func(addr ma.Multiaddr) bool {
		return addr.Equal(p2pCircuitAddr) || manet.IsIPUnspecified(addr)
	})

	// Add certhashes for /webrtc-direct, /webtransport, etc addresses discovered
	// using identify.
	collected = ma.Unique(a.addCertHashes(collected))
	slices.SortFunc(collected, func(x, y ma.Multiaddr) int { return x.Compare(y) })
	return collected
}
// appendPrimaryInterfaceAddrs resolves listenAddrs against the filtered
// interface addresses and appends the result to dst. Unspecified listen
// addresses are resolved against the filtered set only, to avoid advertising
// too many addresses. On a resolution error it logs a warning and returns dst
// unchanged.
func (a *addrsManager) appendPrimaryInterfaceAddrs(dst []ma.Multiaddr, listenAddrs []ma.Multiaddr) []ma.Multiaddr {
	resolved, err := manet.ResolveUnspecifiedAddresses(listenAddrs, a.interfaceAddrs.Filtered())
	if err != nil {
		log.Warn("failed to resolve listen addrs", "err", err)
		return dst
	}
	return append(dst, resolved...)
}
// appendNATAddrs asks the NAT manager for a mapping of every listen address and
// appends each mapping it returns to dst. Listen addresses without a mapping are
// skipped. It must only be called when natManager is non-nil.
func (a *addrsManager) appendNATAddrs(dst []ma.Multiaddr, listenAddrs []ma.Multiaddr) []ma.Multiaddr {
	for _, la := range listenAddrs {
		if mapped := a.natManager.GetMapping(la); mapped != nil {
			dst = append(dst, mapped)
		}
	}
	return dst
}
// appendObservedAddrs appends to dst the observed addresses for every listen
// address and for every address the listen addresses resolve to on ifaceAddrs.
// At most maxObservedAddrsPerListenAddr observed addresses are taken per local
// address. It must only be called when observedAddrsManager is non-nil.
func (a *addrsManager) appendObservedAddrs(dst []ma.Multiaddr, listenAddrs, ifaceAddrs []ma.Multiaddr) []ma.Multiaddr {
	// appendCapped adds the (capped) observed addresses of one local address.
	appendCapped := func(local ma.Multiaddr) {
		observed := a.observedAddrsManager.AddrsFor(local)
		if len(observed) > maxObservedAddrsPerListenAddr {
			observed = observed[:maxObservedAddrsPerListenAddr]
		}
		dst = append(dst, observed...)
	}

	// Start with the listen addresses themselves. They may be unspecified, which
	// is fine: connections on UDP transports will have the unspecified address
	// as the local address.
	for _, la := range listenAddrs {
		appendCapped(la)
	}

	// Then cover the concrete addresses they resolve to, if resolution works.
	resolved, err := manet.ResolveUnspecifiedAddresses(listenAddrs, ifaceAddrs)
	if err != nil {
		log.Warn("failed to resolve listen addr", "listen_addr", listenAddrs, "iface_addrs", ifaceAddrs, "err", err)
		return dst
	}
	for _, ra := range resolved {
		appendCapped(ra)
	}
	return dst
}
// areAddrsDifferent reports whether prev and current contain different sets of
// addresses. Order and duplicates are ignored.
//
// The comparison works on copies. ma.Unique deduplicates in place and
// slices.SortFunc reorders in place, so running them on the arguments directly
// would silently reorder (and, when duplicates are present, truncate) slices
// the caller still uses.
func areAddrsDifferent(prev, current []ma.Multiaddr) bool {
	// TODO: make the sorted nature of ma.Unique a guarantee in multiaddrs
	prev = ma.Unique(slices.Clone(prev))
	current = ma.Unique(slices.Clone(current))
	if len(prev) != len(current) {
		return true
	}

	byCompare := func(x, y ma.Multiaddr) int { return x.Compare(y) }
	slices.SortFunc(prev, byCompare)
	slices.SortFunc(current, byCompare)

	return !slices.EqualFunc(prev, current, func(x, y ma.Multiaddr) bool { return x.Equal(y) })
}
// interfaceAddrsCacheTTL is how long interfaceAddrsCache serves its cached
// addresses before recomputing them.
const interfaceAddrsCacheTTL = time.Minute
// interfaceAddrsCache caches the machine's interface addresses and refreshes
// them once they are older than interfaceAddrsCacheTTL. The zero value is ready
// to use: its zero lastUpdated makes the first read trigger a refresh.
type interfaceAddrsCache struct {
// mx guards all fields below.
mx sync.RWMutex
// filtered is the reduced address set built by updateUnlocked: the default
// route's local IPs plus loopback addresses, or all of `all` when no default
// route IP could be determined.
filtered []ma.Multiaddr
// all holds the interface addresses with IPv6 link-local addresses removed.
all []ma.Multiaddr
// The backoffs wrap the default route lookups for IPv4 and IPv6.
updateLocalIPv4Backoff backoff.ExpBackoff
updateLocalIPv6Backoff backoff.ExpBackoff
// lastUpdated is the time of the last refresh.
lastUpdated time.Time
}
// Filtered returns the cached filtered interface addresses, refreshing the cache
// first when it is older than interfaceAddrsCacheTTL. The returned slice is the
// cache's own; callers must not modify it.
func (i *interfaceAddrsCache) Filtered() []ma.Multiaddr {
	i.mx.RLock()
	stale := time.Now().After(i.lastUpdated.Add(interfaceAddrsCacheTTL))
	cached := i.filtered
	i.mx.RUnlock()

	if stale {
		return i.update(true)
	}
	return cached
}
// All returns the cached unfiltered interface addresses, refreshing the cache
// first when it is older than interfaceAddrsCacheTTL. The returned slice is the
// cache's own; callers must not modify it.
func (i *interfaceAddrsCache) All() []ma.Multiaddr {
	i.mx.RLock()
	stale := time.Now().After(i.lastUpdated.Add(interfaceAddrsCacheTTL))
	cached := i.all
	i.mx.RUnlock()

	if stale {
		return i.update(false)
	}
	return cached
}
// update refreshes the cache under the write lock if it is still stale and
// returns the filtered addresses when filtered is true, otherwise all of them.
func (i *interfaceAddrsCache) update(filtered bool) []ma.Multiaddr {
	i.mx.Lock()
	defer i.mx.Unlock()

	// Re-check staleness: another goroutine may have refreshed the cache between
	// the caller's check and our acquiring the write lock.
	if time.Now().After(i.lastUpdated.Add(interfaceAddrsCacheTTL)) {
		i.updateUnlocked()
		i.lastUpdated = time.Now()
	}

	if filtered {
		return i.filtered
	}
	return i.all
}
// updateUnlocked recomputes i.filtered and i.all from scratch. It does not take
// i.mx itself; its caller, update, invokes it while holding the write lock.
//
// i.all becomes the interface addresses minus IPv6 link-local ones. i.filtered
// becomes the default route's global unicast IPv4/IPv6 addresses plus the
// loopback addresses, or all of i.all when no default route address was found.
// If the interface addresses cannot be listed, both sets fall back to whatever
// was found so far plus the IPv4 and IPv6 loopback addresses.
func (i *interfaceAddrsCache) updateUnlocked() {
i.filtered = nil
i.all = nil
// Try to use the default ipv4/6 addresses.
// TODO: Remove this. We should advertise all interface addresses.
if r, err := netroute.New(); err != nil {
log.Debug("failed to build Router for kernel's routing table", "err", err)
} else {
var localIPv4 net.IP
var ran bool
// The closure assigns to the outer err and localIPv4. ran reports whether the
// backoff actually invoked the closure — presumably false while backing off
// after earlier failures; verify against backoff.ExpBackoff.
err, ran = i.updateLocalIPv4Backoff.Run(func() error {
_, _, localIPv4, err = r.Route(net.IPv4zero)
return err
})
if ran && err != nil {
log.Debug("failed to fetch local IPv4 address", "err", err)
} else if ran && localIPv4.IsGlobalUnicast() {
maddr, err := manet.FromIP(localIPv4)
if err == nil {
i.filtered = append(i.filtered, maddr)
}
}
// Same lookup for the IPv6 default route.
var localIPv6 net.IP
err, ran = i.updateLocalIPv6Backoff.Run(func() error {
_, _, localIPv6, err = r.Route(net.IPv6unspecified)
return err
})
if ran && err != nil {
log.Debug("failed to fetch local IPv6 address", "err", err)
} else if ran && localIPv6.IsGlobalUnicast() {
maddr, err := manet.FromIP(localIPv6)
if err == nil {
i.filtered = append(i.filtered, maddr)
}
}
}
// Resolve the interface addresses
ifaceAddrs, err := manet.InterfaceMultiaddrs()
if err != nil {
// This usually shouldn't happen, but we could be in some kind
// of funky restricted environment.
log.Error("failed to resolve local interface addresses", "err", err)
// Add the loopback addresses to the filtered addrs and use them as the non-filtered addrs.
// Then bail. There's nothing else we can do here.
i.filtered = append(i.filtered, manet.IP4Loopback, manet.IP6Loopback)
i.all = i.filtered
return
}
// remove link local ipv6 addresses
i.all = slices.DeleteFunc(ifaceAddrs, manet.IsIP6LinkLocal)
// If netroute failed to get us any interface addresses, use all of
// them.
if len(i.filtered) == 0 {
// Add all addresses.
// Note: filtered and all share the same backing array in this case.
i.filtered = i.all
} else {
// Only add loopback addresses. Filter these because we might
// not _have_ an IPv6 loopback address.
for _, addr := range i.all {
if manet.IsIPLoopback(addr) {
i.filtered = append(i.filtered, addr)
}
}
}
}
// removeNotInSource keeps only those entries of addrs that also occur in source.
// It filters in place: the kept entries are moved to the front of addrs in their
// original order, the remaining tail is set to nil, and the shortened slice is
// returned.
// addrs and source must be sorted using multiaddr.Compare.
func removeNotInSource(addrs, source []ma.Multiaddr) []ma.Multiaddr {
	kept := 0 // number of entries retained so far, and the next write position
	pos := 0  // cursor into source; both inputs are sorted so it only advances

	for _, addr := range addrs {
		// Skip the source entries that sort before addr.
		for pos < len(source) && addr.Compare(source[pos]) > 0 {
			pos++
		}
		// Here source[pos], if it exists, is >= addr. Equality means addr is in
		// source.
		if pos < len(source) && addr.Compare(source[pos]) == 0 {
			addrs[kept] = addr
			kept++
		}
	}

	// Blank the unused tail so no dropped entry lingers in the backing array.
	clear(addrs[kept:])
	return addrs[:kept]
}
// multiCloser is a list of closers that can be closed as a unit.
type multiCloser []io.Closer

// Close closes every closer in order, continuing after failures. It returns the
// failures joined into one error, or nil if every closer succeeded. Each failure
// is labelled with the closer's Name() when it has one, and with its Go type
// otherwise.
func (mc *multiCloser) Close() error {
	var failures []error
	for _, c := range *mc {
		err := c.Close()
		if err == nil {
			continue
		}

		label := fmt.Sprintf("%T", c)
		if named, ok := c.(interface{ Name() string }); ok {
			label = named.Name()
		}
		failures = append(failures, fmt.Errorf("error closing %s: %w", label, err))
	}
	return errors.Join(failures...)
}