identify: refactor observed address manager to do address mapping at thin waist(IP+TCP/UDP) layer (#2793)

* refactor observed address manager to do mapping at thin waist layer

---------

Co-authored-by: Marco Munizaga <git@marcopolo.io>

* restrict output message size, add top level option to disable address
discovery

* Comment nit

* Increase maxPeerRecordSize

---------

Co-authored-by: Marco Munizaga <git@marcopolo.io>
This commit is contained in:
sukun
2024-05-16 08:47:25 +05:30
committed by GitHub
parent b76d639690
commit 6861cecb3c
13 changed files with 1660 additions and 1112 deletions

View File

@@ -129,6 +129,8 @@ type Config struct {
DialRanker network.DialRanker
SwarmOpts []swarm.Option
DisableIdentifyAddressDiscovery bool
}
func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swarm, error) {
@@ -290,19 +292,20 @@ func (cfg *Config) addTransports() ([]fx.Option, error) {
func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.BasicHost, error) {
h, err := bhost.NewHost(swrm, &bhost.HostOpts{
EventBus: eventBus,
ConnManager: cfg.ConnManager,
AddrsFactory: cfg.AddrsFactory,
NATManager: cfg.NATManager,
EnablePing: !cfg.DisablePing,
UserAgent: cfg.UserAgent,
ProtocolVersion: cfg.ProtocolVersion,
EnableHolePunching: cfg.EnableHolePunching,
HolePunchingOptions: cfg.HolePunchingOptions,
EnableRelayService: cfg.EnableRelayService,
RelayServiceOpts: cfg.RelayServiceOpts,
EnableMetrics: !cfg.DisableMetrics,
PrometheusRegisterer: cfg.PrometheusRegisterer,
EventBus: eventBus,
ConnManager: cfg.ConnManager,
AddrsFactory: cfg.AddrsFactory,
NATManager: cfg.NATManager,
EnablePing: !cfg.DisablePing,
UserAgent: cfg.UserAgent,
ProtocolVersion: cfg.ProtocolVersion,
EnableHolePunching: cfg.EnableHolePunching,
HolePunchingOptions: cfg.HolePunchingOptions,
EnableRelayService: cfg.EnableRelayService,
RelayServiceOpts: cfg.RelayServiceOpts,
EnableMetrics: !cfg.DisableMetrics,
PrometheusRegisterer: cfg.PrometheusRegisterer,
DisableIdentifyAddressDiscovery: cfg.DisableIdentifyAddressDiscovery,
})
if err != nil {
return nil, err

View File

@@ -28,9 +28,10 @@ var (
// RecentlyConnectedAddrTTL is used when we recently connected to a peer.
// It means that we are reasonably certain of the peer's address.
RecentlyConnectedAddrTTL = time.Minute * 30
RecentlyConnectedAddrTTL = time.Minute * 15
// OwnObservedAddrTTL is used for our own external addresses observed by peers.
// Deprecated: observed addresses are maintained till we disconnect from the peer which provided it
OwnObservedAddrTTL = time.Minute * 30
)

View File

@@ -376,6 +376,12 @@ func TestAutoNATService(t *testing.T) {
h.Close()
}
func TestDisableIdentifyAddressDiscovery(t *testing.T) {
h, err := New(DisableIdentifyAddressDiscovery())
require.NoError(t, err)
h.Close()
}
func TestMain(m *testing.M) {
goleak.VerifyTestMain(
m,

View File

@@ -598,3 +598,14 @@ func SwarmOpts(opts ...swarm.Option) Option {
return nil
}
}
// DisableIdentifyAddressDiscovery disables address discovery using peer provided observed addresses
// in identify. If you know your public addresses upfront, the recommended way is to use
// AddressFactory to provide the external adddress to the host and use this option to disable
// discovery from identify.
func DisableIdentifyAddressDiscovery() Option {
return func(cfg *Config) error {
cfg.DisableIdentifyAddressDiscovery = true
return nil
}
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
"slices"
"sync"
"time"
@@ -53,6 +54,8 @@ var (
DefaultAddrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr { return addrs }
)
const maxPeerRecordSize = 8 * 1024 // 8k to be compatible with identify's limit
// AddrsFactory functions can be passed to New in order to override
// addresses returned by Addrs.
type AddrsFactory func([]ma.Multiaddr) []ma.Multiaddr
@@ -161,6 +164,9 @@ type HostOpts struct {
EnableMetrics bool
// PrometheusRegisterer is the PrometheusRegisterer used for metrics
PrometheusRegisterer prometheus.Registerer
// DisableIdentifyAddressDiscovery disables address discovery using peer provided observed addresses in identify
DisableIdentifyAddressDiscovery bool
}
// NewHost constructs a new *BasicHost and activates it by attaching its stream and connection handlers to the given inet.Network.
@@ -244,6 +250,9 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
identify.WithMetricsTracer(
identify.NewMetricsTracer(identify.WithRegisterer(opts.PrometheusRegisterer))))
}
if opts.DisableIdentifyAddressDiscovery {
idOpts = append(idOpts, identify.DisableObservedAddrManager())
}
h.ids, err = identify.NewIDService(h, idOpts...)
if err != nil {
@@ -482,15 +491,18 @@ func makeUpdatedAddrEvent(prev, current []ma.Multiaddr) *event.EvtLocalAddresses
return &evt
}
func (h *BasicHost) makeSignedPeerRecord(evt *event.EvtLocalAddressesUpdated) (*record.Envelope, error) {
current := make([]ma.Multiaddr, 0, len(evt.Current))
for _, a := range evt.Current {
current = append(current, a.Address)
func (h *BasicHost) makeSignedPeerRecord(addrs []ma.Multiaddr) (*record.Envelope, error) {
// Limit the length of currentAddrs to ensure that our signed peer records aren't rejected
peerRecordSize := 64 // HostID
k, err := h.signKey.Raw()
if err != nil {
peerRecordSize += 2 * len(k) // 1 for signature, 1 for public key
}
// we want the final address list to be small for keeping the signed peer record in size
addrs = trimHostAddrList(addrs, maxPeerRecordSize-peerRecordSize-256) // 256 B of buffer
rec := peer.PeerRecordFromAddrInfo(peer.AddrInfo{
ID: h.ID(),
Addrs: current,
Addrs: addrs,
})
return record.Seal(rec, h.signKey)
}
@@ -513,7 +525,7 @@ func (h *BasicHost) background() {
if !h.disableSignedPeerRecord {
// add signed peer record to the event
sr, err := h.makeSignedPeerRecord(changeEvt)
sr, err := h.makeSignedPeerRecord(currentAddrs)
if err != nil {
log.Errorf("error creating a signed peer record from the set of current addresses, err=%s", err)
return
@@ -805,6 +817,7 @@ func (h *BasicHost) Addrs() []ma.Multiaddr {
addrs[i] = addrWithCerthash
}
}
return addrs
}
@@ -997,6 +1010,58 @@ func inferWebtransportAddrsFromQuic(in []ma.Multiaddr) []ma.Multiaddr {
return out
}
func trimHostAddrList(addrs []ma.Multiaddr, maxSize int) []ma.Multiaddr {
totalSize := 0
for _, a := range addrs {
totalSize += len(a.Bytes())
}
if totalSize <= maxSize {
return addrs
}
score := func(addr ma.Multiaddr) int {
var res int
if manet.IsPublicAddr(addr) {
res |= 1 << 12
} else if !manet.IsIPLoopback(addr) {
res |= 1 << 11
}
var protocolWeight int
ma.ForEach(addr, func(c ma.Component) bool {
switch c.Protocol().Code {
case ma.P_QUIC_V1:
protocolWeight = 5
case ma.P_TCP:
protocolWeight = 4
case ma.P_WSS:
protocolWeight = 3
case ma.P_WEBTRANSPORT:
protocolWeight = 2
case ma.P_WEBRTC_DIRECT:
protocolWeight = 1
case ma.P_P2P:
return false
}
return true
})
res |= 1 << protocolWeight
return res
}
slices.SortStableFunc(addrs, func(a, b ma.Multiaddr) int {
return score(b) - score(a) // b-a for reverse order
})
totalSize = 0
for i, a := range addrs {
totalSize += len(a.Bytes())
if totalSize > maxSize {
addrs = addrs[:i]
break
}
}
return addrs
}
// SetAutoNat sets the autonat service for the host.
func (h *BasicHost) SetAutoNat(a autonat.AutoNAT) {
h.addrMu.Lock()

View File

@@ -896,3 +896,55 @@ func TestInferWebtransportAddrsFromQuic(t *testing.T) {
}
}
func TestTrimHostAddrList(t *testing.T) {
type testCase struct {
name string
in []ma.Multiaddr
threshold int
out []ma.Multiaddr
}
tcpPublic := ma.StringCast("/ip4/1.1.1.1/tcp/1")
quicPublic := ma.StringCast("/ip4/1.1.1.1/udp/1/quic-v1")
tcpPrivate := ma.StringCast("/ip4/192.168.1.1/tcp/1")
quicPrivate := ma.StringCast("/ip4/192.168.1.1/udp/1/quic-v1")
tcpLocal := ma.StringCast("/ip4/127.0.0.1/tcp/1")
quicLocal := ma.StringCast("/ip4/127.0.0.1/udp/1/quic-v1")
testCases := []testCase{
{
name: "Public preferred over private",
in: []ma.Multiaddr{tcpPublic, quicPrivate},
threshold: len(tcpLocal.Bytes()),
out: []ma.Multiaddr{tcpPublic},
},
{
name: "Public and private preffered over local",
in: []ma.Multiaddr{tcpPublic, tcpPrivate, quicLocal},
threshold: len(tcpPublic.Bytes()) + len(tcpPrivate.Bytes()),
out: []ma.Multiaddr{tcpPublic, tcpPrivate},
},
{
name: "quic preferred over tcp",
in: []ma.Multiaddr{tcpPublic, quicPublic},
threshold: len(quicPublic.Bytes()),
out: []ma.Multiaddr{quicPublic},
},
{
name: "no filtering on large threshold",
in: []ma.Multiaddr{tcpPublic, quicPublic, quicLocal, tcpLocal, tcpPrivate},
threshold: 10000,
out: []ma.Multiaddr{tcpPublic, quicPublic, quicLocal, tcpLocal, tcpPrivate},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got := trimHostAddrList(tc.in, tc.threshold)
require.ElementsMatch(t, got, tc.out)
})
}
}

View File

@@ -34,24 +34,27 @@ import (
var log = logging.Logger("net/identify")
var Timeout = 30 * time.Second // timeout on all incoming Identify interactions
const (
// ID is the protocol.ID of version 1.0.0 of the identify service.
ID = "/ipfs/id/1.0.0"
// IDPush is the protocol.ID of the Identify push protocol.
// It sends full identify messages containing the current state of the peer.
IDPush = "/ipfs/id/push/1.0.0"
)
const ServiceName = "libp2p.identify"
ServiceName = "libp2p.identify"
const maxPushConcurrency = 32
var Timeout = 60 * time.Second // timeout on all incoming Identify interactions
const (
legacyIDSize = 2 * 1024 // 2k Bytes
signedIDSize = 8 * 1024 // 8K
maxMessages = 10
legacyIDSize = 2 * 1024
signedIDSize = 8 * 1024
maxOwnIdentifyMsgSize = 4 * 1024 // smaller than what we accept. This is 4k to be compatible with rust-libp2p
maxMessages = 10
maxPushConcurrency = 32
// number of addresses to keep for peers we have disconnected from for peerstore.RecentlyConnectedTTL time
// This number can be small as we already filter peer addresses based on whether the peer is connected to us over
// localhost, private IP or public IP address
recentlyConnectedPeerMaxAddrs = 20
connectedPeerMaxAddrs = 500
)
var defaultUserAgent = "github.com/libp2p/go-libp2p"
@@ -159,7 +162,8 @@ type idService struct {
addrMu sync.Mutex
// our own observed addresses.
observedAddrs *ObservedAddrManager
observedAddrMgr *ObservedAddrManager
disableObservedAddrManager bool
emitters struct {
evtPeerProtocolsUpdated event.Emitter
@@ -171,6 +175,12 @@ type idService struct {
sync.Mutex
snapshot identifySnapshot
}
natEmitter *natEmitter
}
type normalizer interface {
NormalizeMultiaddr(ma.Multiaddr) ma.Multiaddr
}
// NewIDService constructs a new *idService and activates it by
@@ -199,11 +209,27 @@ func NewIDService(h host.Host, opts ...Option) (*idService, error) {
metricsTracer: cfg.metricsTracer,
}
observedAddrs, err := NewObservedAddrManager(h)
if err != nil {
return nil, fmt.Errorf("failed to create observed address manager: %s", err)
var normalize func(ma.Multiaddr) ma.Multiaddr
if hn, ok := h.(normalizer); ok {
normalize = hn.NormalizeMultiaddr
}
var err error
if cfg.disableObservedAddrManager {
s.disableObservedAddrManager = true
} else {
observedAddrs, err := NewObservedAddrManager(h.Network().ListenAddresses,
h.Addrs, h.Network().InterfaceListenAddresses, normalize)
if err != nil {
return nil, fmt.Errorf("failed to create observed address manager: %s", err)
}
natEmitter, err := newNATEmitter(h, observedAddrs, time.Minute)
if err != nil {
return nil, fmt.Errorf("failed to create nat emitter: %s", err)
}
s.natEmitter = natEmitter
s.observedAddrMgr = observedAddrs
}
s.observedAddrs = observedAddrs
s.emitters.evtPeerProtocolsUpdated, err = h.EventBus().Emitter(&event.EvtPeerProtocolsUpdated{})
if err != nil {
@@ -341,17 +367,26 @@ func (ids *idService) sendPushes(ctx context.Context) {
// Close shuts down the idService
func (ids *idService) Close() error {
ids.ctxCancel()
ids.observedAddrs.Close()
if !ids.disableObservedAddrManager {
ids.observedAddrMgr.Close()
ids.natEmitter.Close()
}
ids.refCount.Wait()
return nil
}
func (ids *idService) OwnObservedAddrs() []ma.Multiaddr {
return ids.observedAddrs.Addrs()
if ids.disableObservedAddrManager {
return nil
}
return ids.observedAddrMgr.Addrs()
}
func (ids *idService) ObservedAddrsFor(local ma.Multiaddr) []ma.Multiaddr {
return ids.observedAddrs.AddrsFor(local)
if ids.disableObservedAddrManager {
return nil
}
return ids.observedAddrMgr.AddrsFor(local)
}
// IdentifyConn runs the Identify protocol on a connection.
@@ -553,10 +588,18 @@ func readAllIDMessages(r pbio.Reader, finalMsg proto.Message) error {
}
func (ids *idService) updateSnapshot() (updated bool) {
addrs := ids.Host.Addrs()
slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return bytes.Compare(a.Bytes(), b.Bytes()) })
protos := ids.Host.Mux().Protocols()
slices.Sort(protos)
addrs := ids.Host.Addrs()
slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return bytes.Compare(a.Bytes(), b.Bytes()) })
usedSpace := len(ids.ProtocolVersion) + len(ids.UserAgent)
for i := 0; i < len(protos); i++ {
usedSpace += len(protos[i])
}
addrs = trimHostAddrList(addrs, maxOwnIdentifyMsgSize-usedSpace-256) // 256 bytes of buffer
snapshot := identifySnapshot{
addrs: addrs,
protocols: protos,
@@ -715,9 +758,9 @@ func (ids *idService) consumeMessage(mes *pb.Identify, c network.Conn, isPush bo
obsAddr = nil
}
if obsAddr != nil {
if obsAddr != nil && !ids.disableObservedAddrManager {
// TODO refactor this to use the emitted events instead of having this func call explicitly.
ids.observedAddrs.Record(c, obsAddr)
ids.observedAddrMgr.Record(c, obsAddr)
}
// mes.ListenAddrs
@@ -777,7 +820,12 @@ func (ids *idService) consumeMessage(mes *pb.Identify, c network.Conn, isPush bo
} else {
addrs = lmaddrs
}
ids.Host.Peerstore().AddAddrs(p, filterAddrs(addrs, c.RemoteMultiaddr()), ttl)
addrs = filterAddrs(addrs, c.RemoteMultiaddr())
if len(addrs) > connectedPeerMaxAddrs {
addrs = addrs[:connectedPeerMaxAddrs]
}
ids.Host.Peerstore().AddAddrs(p, addrs, ttl)
// Finally, expire all temporary addrs.
ids.Host.Peerstore().UpdateAddrs(p, peerstore.TempAddrTTL, 0)
@@ -981,15 +1029,36 @@ func (nn *netNotifiee) Disconnected(_ network.Network, c network.Conn) {
delete(ids.conns, c)
ids.connsMu.Unlock()
switch ids.Host.Network().Connectedness(c.RemotePeer()) {
case network.Connected, network.Limited:
return
if !ids.disableObservedAddrManager {
ids.observedAddrMgr.removeConn(c)
}
// Last disconnect.
// Undo the setting of addresses to peer.ConnectedAddrTTL we did
ids.addrMu.Lock()
defer ids.addrMu.Unlock()
ids.Host.Peerstore().UpdateAddrs(c.RemotePeer(), peerstore.ConnectedAddrTTL, peerstore.RecentlyConnectedAddrTTL)
// This check MUST happen after acquiring the Lock as identify on a different connection
// might be trying to add addresses.
switch ids.Host.Network().Connectedness(c.RemotePeer()) {
case network.Connected, network.Limited:
return
}
// peerstore returns the elements in a random order as it uses a map to store the addresses
addrs := ids.Host.Peerstore().Addrs(c.RemotePeer())
n := len(addrs)
if n > recentlyConnectedPeerMaxAddrs {
// We want to always save the address we are connected to
for i, a := range addrs {
if a.Equal(c.RemoteMultiaddr()) {
addrs[i], addrs[0] = addrs[0], addrs[i]
}
}
n = recentlyConnectedPeerMaxAddrs
}
ids.Host.Peerstore().UpdateAddrs(c.RemotePeer(), peerstore.ConnectedAddrTTL, peerstore.TempAddrTTL)
ids.Host.Peerstore().AddAddrs(c.RemotePeer(), addrs[:n], peerstore.RecentlyConnectedAddrTTL)
ids.Host.Peerstore().UpdateAddrs(c.RemotePeer(), peerstore.TempAddrTTL, 0)
}
func (nn *netNotifiee) Listen(n network.Network, a ma.Multiaddr) {}
@@ -1008,3 +1077,55 @@ func filterAddrs(addrs []ma.Multiaddr, remote ma.Multiaddr) []ma.Multiaddr {
}
return ma.FilterAddrs(addrs, manet.IsPublicAddr)
}
func trimHostAddrList(addrs []ma.Multiaddr, maxSize int) []ma.Multiaddr {
totalSize := 0
for _, a := range addrs {
totalSize += len(a.Bytes())
}
if totalSize <= maxSize {
return addrs
}
score := func(addr ma.Multiaddr) int {
var res int
if manet.IsPublicAddr(addr) {
res |= 1 << 12
} else if !manet.IsIPLoopback(addr) {
res |= 1 << 11
}
var protocolWeight int
ma.ForEach(addr, func(c ma.Component) bool {
switch c.Protocol().Code {
case ma.P_QUIC_V1:
protocolWeight = 5
case ma.P_TCP:
protocolWeight = 4
case ma.P_WSS:
protocolWeight = 3
case ma.P_WEBTRANSPORT:
protocolWeight = 2
case ma.P_WEBRTC_DIRECT:
protocolWeight = 1
case ma.P_P2P:
return false
}
return true
})
res |= 1 << protocolWeight
return res
}
slices.SortStableFunc(addrs, func(a, b ma.Multiaddr) int {
return score(b) - score(a) // b-a for reverse order
})
totalSize = 0
for i, a := range addrs {
totalSize += len(a.Bytes())
if totalSize > maxSize {
addrs = addrs[:i]
break
}
}
return addrs
}

View File

@@ -107,104 +107,121 @@ func emitAddrChangeEvt(t *testing.T, h host.Host) {
// this is because it used to be concurrent. Now, Dial wait till the
// id service is done.
func TestIDService(t *testing.T) {
if race.WithRace() {
t.Skip("This test modifies peerstore.RecentlyConnectedAddrTTL, which is racy.")
}
// This test is highly timing dependent, waiting on timeouts/expiration.
oldTTL := peerstore.RecentlyConnectedAddrTTL
peerstore.RecentlyConnectedAddrTTL = 500 * time.Millisecond
t.Cleanup(func() { peerstore.RecentlyConnectedAddrTTL = oldTTL })
for _, withObsAddrManager := range []bool{false, true} {
t.Run(fmt.Sprintf("withObsAddrManager=%t", withObsAddrManager), func(t *testing.T) {
if race.WithRace() {
t.Skip("This test modifies peerstore.RecentlyConnectedAddrTTL, which is racy.")
}
// This test is highly timing dependent, waiting on timeouts/expiration.
oldTTL := peerstore.RecentlyConnectedAddrTTL
oldTempTTL := peerstore.TempAddrTTL
peerstore.RecentlyConnectedAddrTTL = 500 * time.Millisecond
peerstore.TempAddrTTL = 50 * time.Millisecond
t.Cleanup(func() {
peerstore.RecentlyConnectedAddrTTL = oldTTL
peerstore.TempAddrTTL = oldTempTTL
})
clk := mockClock.NewMock()
swarm1 := swarmt.GenSwarm(t, swarmt.WithClock(clk))
swarm2 := swarmt.GenSwarm(t, swarmt.WithClock(clk))
h1 := blhost.NewBlankHost(swarm1)
h2 := blhost.NewBlankHost(swarm2)
clk := mockClock.NewMock()
swarm1 := swarmt.GenSwarm(t, swarmt.WithClock(clk))
swarm2 := swarmt.GenSwarm(t, swarmt.WithClock(clk))
h1 := blhost.NewBlankHost(swarm1)
h2 := blhost.NewBlankHost(swarm2)
h1p := h1.ID()
h2p := h2.ID()
h1p := h1.ID()
h2p := h2.ID()
ids1, err := identify.NewIDService(h1)
require.NoError(t, err)
defer ids1.Close()
ids1.Start()
opts := []identify.Option{}
if !withObsAddrManager {
opts = append(opts, identify.DisableObservedAddrManager())
}
ids1, err := identify.NewIDService(h1, opts...)
require.NoError(t, err)
defer ids1.Close()
ids1.Start()
ids2, err := identify.NewIDService(h2)
require.NoError(t, err)
defer ids2.Close()
ids2.Start()
opts = []identify.Option{}
if !withObsAddrManager {
opts = append(opts, identify.DisableObservedAddrManager())
}
ids2, err := identify.NewIDService(h2, opts...)
require.NoError(t, err)
defer ids2.Close()
ids2.Start()
sub, err := ids1.Host.EventBus().Subscribe(new(event.EvtPeerIdentificationCompleted))
if err != nil {
t.Fatal(err)
}
sub, err := ids1.Host.EventBus().Subscribe(new(event.EvtPeerIdentificationCompleted))
if err != nil {
t.Fatal(err)
}
testKnowsAddrs(t, h1, h2p, []ma.Multiaddr{}) // nothing
testKnowsAddrs(t, h2, h1p, []ma.Multiaddr{}) // nothing
testKnowsAddrs(t, h1, h2p, []ma.Multiaddr{}) // nothing
testKnowsAddrs(t, h2, h1p, []ma.Multiaddr{}) // nothing
// the forgetMe addr represents an address for h1 that h2 has learned out of band
// (not via identify protocol). During the identify exchange, it will be
// forgotten and replaced by the addrs h1 sends.
forgetMe, _ := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/1234")
// the forgetMe addr represents an address for h1 that h2 has learned out of band
// (not via identify protocol). During the identify exchange, it will be
// forgotten and replaced by the addrs h1 sends.
forgetMe, _ := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/1234")
h2.Peerstore().AddAddr(h1p, forgetMe, peerstore.RecentlyConnectedAddrTTL)
h2pi := h2.Peerstore().PeerInfo(h2p)
require.NoError(t, h1.Connect(context.Background(), h2pi))
h2.Peerstore().AddAddr(h1p, forgetMe, peerstore.RecentlyConnectedAddrTTL)
h2pi := h2.Peerstore().PeerInfo(h2p)
require.NoError(t, h1.Connect(context.Background(), h2pi))
h1t2c := h1.Network().ConnsToPeer(h2p)
require.NotEmpty(t, h1t2c, "should have a conn here")
h1t2c := h1.Network().ConnsToPeer(h2p)
require.NotEmpty(t, h1t2c, "should have a conn here")
ids1.IdentifyConn(h1t2c[0])
ids1.IdentifyConn(h1t2c[0])
// the idService should be opened automatically, by the network.
// what we should see now is that both peers know about each others listen addresses.
t.Log("test peer1 has peer2 addrs correctly")
testKnowsAddrs(t, h1, h2p, h2.Addrs()) // has them
testHasAgentVersion(t, h1, h2p)
testHasPublicKey(t, h1, h2p, h2.Peerstore().PubKey(h2p)) // h1 should have h2's public key
// the idService should be opened automatically, by the network.
// what we should see now is that both peers know about each others listen addresses.
t.Log("test peer1 has peer2 addrs correctly")
testKnowsAddrs(t, h1, h2p, h2.Addrs()) // has them
testHasAgentVersion(t, h1, h2p)
testHasPublicKey(t, h1, h2p, h2.Peerstore().PubKey(h2p)) // h1 should have h2's public key
// now, this wait we do have to do. it's the wait for the Listening side
// to be done identifying the connection.
c := h2.Network().ConnsToPeer(h1.ID())
require.NotEmpty(t, c, "should have connection by now at least.")
ids2.IdentifyConn(c[0])
// now, this wait we do have to do. it's the wait for the Listening side
// to be done identifying the connection.
c := h2.Network().ConnsToPeer(h1.ID())
require.NotEmpty(t, c, "should have connection by now at least.")
ids2.IdentifyConn(c[0])
// and the protocol versions.
t.Log("test peer2 has peer1 addrs correctly")
testKnowsAddrs(t, h2, h1p, h1.Addrs()) // has them
testHasAgentVersion(t, h2, h1p)
testHasPublicKey(t, h2, h1p, h1.Peerstore().PubKey(h1p)) // h1 should have h2's public key
// and the protocol versions.
t.Log("test peer2 has peer1 addrs correctly")
testKnowsAddrs(t, h2, h1p, h1.Addrs()) // has them
testHasAgentVersion(t, h2, h1p)
testHasPublicKey(t, h2, h1p, h1.Peerstore().PubKey(h1p)) // h1 should have h2's public key
// Need both sides to actually notice that the connection has been closed.
sentDisconnect1 := waitForDisconnectNotification(swarm1)
sentDisconnect2 := waitForDisconnectNotification(swarm2)
h1.Network().ClosePeer(h2p)
h2.Network().ClosePeer(h1p)
if len(h2.Network().ConnsToPeer(h1.ID())) != 0 || len(h1.Network().ConnsToPeer(h2.ID())) != 0 {
t.Fatal("should have no connections")
}
// Need both sides to actually notice that the connection has been closed.
sentDisconnect1 := waitForDisconnectNotification(swarm1)
sentDisconnect2 := waitForDisconnectNotification(swarm2)
h1.Network().ClosePeer(h2p)
h2.Network().ClosePeer(h1p)
if len(h2.Network().ConnsToPeer(h1.ID())) != 0 || len(h1.Network().ConnsToPeer(h2.ID())) != 0 {
t.Fatal("should have no connections")
}
t.Log("testing addrs just after disconnect")
// addresses don't immediately expire on disconnect, so we should still have them
testKnowsAddrs(t, h2, h1p, h1.Addrs())
testKnowsAddrs(t, h1, h2p, h2.Addrs())
t.Log("testing addrs just after disconnect")
// addresses don't immediately expire on disconnect, so we should still have them
testKnowsAddrs(t, h2, h1p, h1.Addrs())
testKnowsAddrs(t, h1, h2p, h2.Addrs())
<-sentDisconnect1
<-sentDisconnect2
<-sentDisconnect1
<-sentDisconnect2
// the addrs had their TTLs reduced on disconnect, and
// will be forgotten soon after
t.Log("testing addrs after TTL expiration")
clk.Add(time.Second)
testKnowsAddrs(t, h1, h2p, []ma.Multiaddr{})
testKnowsAddrs(t, h2, h1p, []ma.Multiaddr{})
// the addrs had their TTLs reduced on disconnect, and
// will be forgotten soon after
t.Log("testing addrs after TTL expiration")
clk.Add(time.Second)
testKnowsAddrs(t, h1, h2p, []ma.Multiaddr{})
testKnowsAddrs(t, h2, h1p, []ma.Multiaddr{})
// test that we received the "identify completed" event.
select {
case evtAny := <-sub.Out():
assertCorrectEvtPeerIdentificationCompleted(t, evtAny, h2)
case <-time.After(3 * time.Second):
t.Fatalf("expected EvtPeerIdentificationCompleted event within 10 seconds; none received")
// test that we received the "identify completed" event.
select {
case evtAny := <-sub.Out():
assertCorrectEvtPeerIdentificationCompleted(t, evtAny, h2)
case <-time.After(3 * time.Second):
t.Fatalf("expected EvtPeerIdentificationCompleted event within 10 seconds; none received")
}
})
}
}
@@ -603,8 +620,13 @@ func TestLargeIdentifyMessage(t *testing.T) {
t.Skip("setting peerstore.RecentlyConnectedAddrTTL is racy")
}
oldTTL := peerstore.RecentlyConnectedAddrTTL
oldTempTTL := peerstore.TempAddrTTL
peerstore.RecentlyConnectedAddrTTL = 500 * time.Millisecond
t.Cleanup(func() { peerstore.RecentlyConnectedAddrTTL = oldTTL })
peerstore.TempAddrTTL = 50 * time.Millisecond
t.Cleanup(func() {
peerstore.RecentlyConnectedAddrTTL = oldTTL
peerstore.TempAddrTTL = oldTempTTL
})
clk := mockClock.NewMock()
swarm1 := swarmt.GenSwarm(t, swarmt.WithClock(clk))

View File

@@ -0,0 +1,119 @@
package identify
import (
"context"
"fmt"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
)
type natEmitter struct {
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
reachabilitySub event.Subscription
reachability network.Reachability
eventInterval time.Duration
currentUDPNATDeviceType network.NATDeviceType
currentTCPNATDeviceType network.NATDeviceType
emitNATDeviceTypeChanged event.Emitter
observedAddrMgr *ObservedAddrManager
}
func newNATEmitter(h host.Host, o *ObservedAddrManager, eventInterval time.Duration) (*natEmitter, error) {
ctx, cancel := context.WithCancel(context.Background())
n := &natEmitter{
observedAddrMgr: o,
ctx: ctx,
cancel: cancel,
eventInterval: eventInterval,
}
reachabilitySub, err := h.EventBus().Subscribe(new(event.EvtLocalReachabilityChanged), eventbus.Name("identify (nat emitter)"))
if err != nil {
return nil, fmt.Errorf("failed to subscribe to reachability event: %s", err)
}
n.reachabilitySub = reachabilitySub
emitter, err := h.EventBus().Emitter(new(event.EvtNATDeviceTypeChanged), eventbus.Stateful)
if err != nil {
return nil, fmt.Errorf("failed to create emitter for NATDeviceType: %s", err)
}
n.emitNATDeviceTypeChanged = emitter
n.wg.Add(1)
go n.worker()
return n, nil
}
func (n *natEmitter) worker() {
defer n.wg.Done()
subCh := n.reachabilitySub.Out()
ticker := time.NewTicker(n.eventInterval)
pendingUpdate := false
enoughTimeSinceLastUpdate := true
for {
select {
case evt, ok := <-subCh:
if !ok {
subCh = nil
continue
}
ev, ok := evt.(event.EvtLocalReachabilityChanged)
if !ok {
log.Error("invalid event: %v", evt)
continue
}
n.reachability = ev.Reachability
case <-ticker.C:
enoughTimeSinceLastUpdate = true
if pendingUpdate {
n.maybeNotify()
pendingUpdate = false
enoughTimeSinceLastUpdate = false
}
case <-n.observedAddrMgr.addrRecordedNotif:
pendingUpdate = true
if enoughTimeSinceLastUpdate {
n.maybeNotify()
pendingUpdate = false
enoughTimeSinceLastUpdate = false
}
case <-n.ctx.Done():
return
}
}
}
func (n *natEmitter) maybeNotify() {
if n.reachability == network.ReachabilityPrivate {
tcpNATType, udpNATType := n.observedAddrMgr.getNATType()
if tcpNATType != n.currentTCPNATDeviceType {
n.currentTCPNATDeviceType = tcpNATType
n.emitNATDeviceTypeChanged.Emit(event.EvtNATDeviceTypeChanged{
TransportProtocol: network.NATTransportTCP,
NatDeviceType: n.currentTCPNATDeviceType,
})
}
if udpNATType != n.currentUDPNATDeviceType {
n.currentUDPNATDeviceType = udpNATType
n.emitNATDeviceTypeChanged.Emit(event.EvtNATDeviceTypeChanged{
TransportProtocol: network.NATTransportUDP,
NatDeviceType: n.currentUDPNATDeviceType,
})
}
}
}
func (n *natEmitter) Close() {
n.cancel()
n.wg.Wait()
n.reachabilitySub.Close()
n.emitNATDeviceTypeChanged.Close()
}

File diff suppressed because it is too large Load Diff

View File

@@ -5,77 +5,16 @@ package identify
import (
"fmt"
"sync/atomic"
"testing"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/require"
)
func TestObservedAddrGroupKey(t *testing.T) {
oa1 := &observedAddr{addr: ma.StringCast("/ip4/1.2.3.4/tcp/2345")}
oa2 := &observedAddr{addr: ma.StringCast("/ip4/1.2.3.4/tcp/1231")}
oa3 := &observedAddr{addr: ma.StringCast("/ip4/1.2.3.5/tcp/1231")}
oa4 := &observedAddr{addr: ma.StringCast("/ip4/1.2.3.4/udp/1231")}
oa5 := &observedAddr{addr: ma.StringCast("/ip4/1.2.3.4/udp/1531")}
oa6 := &observedAddr{addr: ma.StringCast("/ip4/1.2.3.4/udp/1531/quic-v1")}
oa7 := &observedAddr{addr: ma.StringCast("/ip4/1.2.3.4/udp/1111/quic-v1")}
oa8 := &observedAddr{addr: ma.StringCast("/ip4/1.2.3.5/udp/1111/quic-v1")}
// different ports, same IP => same key
require.Equal(t, oa1.groupKey(), oa2.groupKey())
// different IPs => different key
require.NotEqual(t, oa2.groupKey(), oa3.groupKey())
// same port, different protos => different keys
require.NotEqual(t, oa3.groupKey(), oa4.groupKey())
// same port, same address, different protos => different keys
require.NotEqual(t, oa2.groupKey(), oa4.groupKey())
// udp works as well
require.Equal(t, oa4.groupKey(), oa5.groupKey())
// udp and quic are different
require.NotEqual(t, oa5.groupKey(), oa6.groupKey())
// quic works as well
require.Equal(t, oa6.groupKey(), oa7.groupKey())
require.NotEqual(t, oa7.groupKey(), oa8.groupKey())
}
type mockHost struct {
addrs []ma.Multiaddr
listenAddrs []ma.Multiaddr
ifaceListenAddrs []ma.Multiaddr
}
// InterfaceListenAddresses implements listenAddrsProvider
func (h *mockHost) InterfaceListenAddresses() ([]ma.Multiaddr, error) {
return h.ifaceListenAddrs, nil
}
// ListenAddresses implements listenAddrsProvider
func (h *mockHost) ListenAddresses() []ma.Multiaddr {
return h.listenAddrs
}
// Addrs implements addrsProvider
func (h *mockHost) Addrs() []ma.Multiaddr {
return h.addrs
}
// NormalizeMultiaddr implements normalizeMultiaddrer
func (h *mockHost) NormalizeMultiaddr(m ma.Multiaddr) ma.Multiaddr {
original := m
for {
rest, tail := ma.SplitLast(m)
if rest == nil {
return original
}
if tail.Protocol().Code == ma.P_WEBTRANSPORT {
return m
}
m = rest
}
}
type mockConn struct {
local, remote ma.Multiaddr
isClosed atomic.Bool
}
// LocalMultiaddr implements connMultiaddrProvider
@@ -88,21 +27,30 @@ func (c *mockConn) RemoteMultiaddr() ma.Multiaddr {
return c.remote
}
func (c *mockConn) Close() {
c.isClosed.Store(true)
}
func (c *mockConn) IsClosed() bool {
return c.isClosed.Load()
}
func TestShouldRecordObservationWithWebTransport(t *testing.T) {
listenAddr := ma.StringCast("/ip4/0.0.0.0/udp/0/quic-v1/webtransport/certhash/uEgNmb28")
ifaceAddr := ma.StringCast("/ip4/10.0.0.2/udp/9999/quic-v1/webtransport/certhash/uEgNmb28")
h := &mockHost{
listenAddrs: []ma.Multiaddr{listenAddr},
ifaceListenAddrs: []ma.Multiaddr{ifaceAddr},
addrs: []ma.Multiaddr{listenAddr},
}
listenAddrs := func() []ma.Multiaddr { return []ma.Multiaddr{listenAddr} }
ifaceListenAddrs := func() ([]ma.Multiaddr, error) { return []ma.Multiaddr{ifaceAddr}, nil }
addrs := func() []ma.Multiaddr { return []ma.Multiaddr{listenAddr} }
c := &mockConn{
local: listenAddr,
remote: ma.StringCast("/ip4/1.2.3.6/udp/1236/quic-v1/webtransport"),
}
observedAddr := ma.StringCast("/ip4/1.2.3.4/udp/1231/quic-v1/webtransport")
require.True(t, shouldRecordObservation(h, h, c, observedAddr))
o, err := NewObservedAddrManager(listenAddrs, addrs, ifaceListenAddrs, normalize)
require.NoError(t, err)
shouldRecord, _, _ := o.shouldRecordObservation(c, observedAddr)
require.True(t, shouldRecord)
}
func TestShouldRecordObservationWithNAT64Addr(t *testing.T) {
@@ -111,11 +59,11 @@ func TestShouldRecordObservationWithNAT64Addr(t *testing.T) {
listenAddr2 := ma.StringCast("/ip6/::/tcp/1234")
ifaceAddr2 := ma.StringCast("/ip6/1::1/tcp/4321")
h := &mockHost{
listenAddrs: []ma.Multiaddr{listenAddr1, listenAddr2},
ifaceListenAddrs: []ma.Multiaddr{ifaceAddr1, ifaceAddr2},
addrs: []ma.Multiaddr{listenAddr1, listenAddr2},
}
var (
listenAddrs = func() []ma.Multiaddr { return []ma.Multiaddr{listenAddr1, listenAddr2} }
ifaceListenAddrs = func() ([]ma.Multiaddr, error) { return []ma.Multiaddr{ifaceAddr1, ifaceAddr2}, nil }
addrs = func() []ma.Multiaddr { return []ma.Multiaddr{listenAddr1, listenAddr2} }
)
c := &mockConn{
local: listenAddr1,
remote: ma.StringCast("/ip4/1.2.3.6/tcp/4321"),
@@ -142,12 +90,70 @@ func TestShouldRecordObservationWithNAT64Addr(t *testing.T) {
failureReason: "NAT64 IPv6 address shouldn't be observed",
},
}
o, err := NewObservedAddrManager(listenAddrs, addrs, ifaceListenAddrs, normalize)
require.NoError(t, err)
for i, tc := range cases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
if shouldRecordObservation(h, h, c, tc.addr) != tc.want {
if shouldRecord, _, _ := o.shouldRecordObservation(c, tc.addr); shouldRecord != tc.want {
t.Fatalf("%s %s", tc.addr, tc.failureReason)
}
})
}
}
func TestThinWaistForm(t *testing.T) {
tc := []struct {
input string
tw string
rest string
err bool
}{{
input: "/ip4/1.2.3.4/tcp/1",
tw: "/ip4/1.2.3.4/tcp/1",
rest: "",
}, {
input: "/ip4/1.2.3.4/tcp/1/ws",
tw: "/ip4/1.2.3.4/tcp/1",
rest: "/ws",
}, {
input: "/ip4/127.0.0.1/udp/1/quic-v1",
tw: "/ip4/127.0.0.1/udp/1",
rest: "/quic-v1",
}, {
input: "/ip4/1.2.3.4/udp/1/quic-v1/webtransport",
tw: "/ip4/1.2.3.4/udp/1",
rest: "/quic-v1/webtransport",
}, {
input: "/ip4/1.2.3.4/",
err: true,
}, {
input: "/tcp/1",
err: true,
}, {
input: "/ip6/::1/tcp/1",
tw: "/ip6/::1/tcp/1",
rest: "",
}}
for i, tt := range tc {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
inputAddr := ma.StringCast(tt.input)
tw, err := thinWaistForm(inputAddr)
if tt.err {
require.Equal(t, tw, thinWaist{})
require.Error(t, err)
return
}
wantTW := ma.StringCast(tt.tw)
var restTW ma.Multiaddr
if tt.rest != "" {
restTW = ma.StringCast(tt.rest)
}
require.Equal(t, tw.Addr, inputAddr, "%s %s", tw.Addr, inputAddr)
require.Equal(t, wantTW, tw.TW, "%s %s", tw.TW, wantTW)
require.Equal(t, restTW, tw.Rest, "%s %s", restTW, tw.Rest)
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,10 +1,11 @@
package identify
type config struct {
protocolVersion string
userAgent string
disableSignedPeerRecord bool
metricsTracer MetricsTracer
protocolVersion string
userAgent string
disableSignedPeerRecord bool
metricsTracer MetricsTracer
disableObservedAddrManager bool
}
// Option is an option function for identify.
@@ -38,3 +39,11 @@ func WithMetricsTracer(tr MetricsTracer) Option {
cfg.metricsTracer = tr
}
}
// DisableObservedAddrManager disables the observed address manager. It also
// effectively disables the nat emitter and EvtNATDeviceTypeChanged
func DisableObservedAddrManager() Option {
return func(cfg *config) {
cfg.disableObservedAddrManager = true
}
}