Merge branch 'master' into uci-golangci-lint
.golangci.yml

@@ -12,6 +12,9 @@ linters:
     - revive
+    - unused
+    - prealloc
   disable:
     - errcheck
+    - staticcheck
config.go

@@ -33,6 +33,7 @@ import (
 	routed "github.com/libp2p/go-libp2p/p2p/host/routed"
 	"github.com/libp2p/go-libp2p/p2p/net/swarm"
 	tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader"
+	"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2"
 	circuitv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client"
 	relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
 	"github.com/libp2p/go-libp2p/p2p/protocol/holepunch"
@@ -379,8 +380,28 @@ func (cfg *Config) addTransports() ([]fx.Option, error) {
 		fxopts = append(fxopts, cfg.QUICReuse...)
 	} else {
 		fxopts = append(fxopts,
-			fx.Provide(func(key quic.StatelessResetKey, tokenGenerator quic.TokenGeneratorKey, lifecycle fx.Lifecycle) (*quicreuse.ConnManager, error) {
-				var opts []quicreuse.Option
+			fx.Provide(func(key quic.StatelessResetKey, tokenGenerator quic.TokenGeneratorKey, rcmgr network.ResourceManager, lifecycle fx.Lifecycle) (*quicreuse.ConnManager, error) {
+				opts := []quicreuse.Option{
+					quicreuse.ConnContext(func(ctx context.Context, clientInfo *quic.ClientInfo) (context.Context, error) {
+						// even if creating the quic multiaddr fails, let the rcmgr decide what to do with the connection
+						addr, err := quicreuse.ToQuicMultiaddr(clientInfo.RemoteAddr, quic.Version1)
+						if err != nil {
+							addr = nil
+						}
+						scope, err := rcmgr.OpenConnection(network.DirInbound, false, addr)
+						if err != nil {
+							return ctx, err
+						}
+						ctx = network.WithConnManagementScope(ctx, scope)
+						context.AfterFunc(ctx, func() {
+							scope.Done()
+						})
+						return ctx, nil
+					}),
+					quicreuse.VerifySourceAddress(func(addr net.Addr) bool {
+						return rcmgr.VerifySourceAddress(addr)
+					}),
+				}
 				if !cfg.DisableMetrics {
 					opts = append(opts, quicreuse.EnableMetrics(cfg.PrometheusRegisterer))
 				}
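The ConnContext callback above attaches a ConnManagementScope to each inbound QUIC connection before the handshake runs. A minimal sketch of the consuming side, assuming a hypothetical handler that receives a connection context carrying the values ConnContext installed:

	// gateInbound is a hypothetical consumer: it recovers the scope that
	// ConnContext stored in the connection's context.
	func gateInbound(ctx context.Context) error {
		scope, err := network.UnwrapConnManagementScope(ctx)
		if err != nil {
			return err // no scope attached; the connection was not gated
		}
		// The scope is released by the context.AfterFunc above when ctx ends;
		// until then it can be tagged with the remote peer, memory, etc.
		_ = scope
		return nil
	}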
@@ -413,15 +434,7 @@ func (cfg *Config) addTransports() ([]fx.Option, error) {
 	return fxopts, nil
 }

-func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.BasicHost, error) {
-	var autonatv2Dialer host.Host
-	if cfg.EnableAutoNATv2 {
-		ah, err := cfg.makeAutoNATV2Host()
-		if err != nil {
-			return nil, err
-		}
-		autonatv2Dialer = ah
-	}
+func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus, an *autonatv2.AutoNAT) (*bhost.BasicHost, error) {
 	h, err := bhost.NewHost(swrm, &bhost.HostOpts{
 		EventBus:    eventBus,
 		ConnManager: cfg.ConnManager,
@@ -437,8 +450,7 @@ func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.B
 		EnableMetrics:                   !cfg.DisableMetrics,
 		PrometheusRegisterer:            cfg.PrometheusRegisterer,
 		DisableIdentifyAddressDiscovery: cfg.DisableIdentifyAddressDiscovery,
-		EnableAutoNATv2:                 cfg.EnableAutoNATv2,
-		AutoNATv2Dialer:                 autonatv2Dialer,
+		AutoNATv2:                       an,
 	})
 	if err != nil {
 		return nil, err
@@ -517,6 +529,24 @@ func (cfg *Config) NewNode() (host.Host, error) {
 			})
 			return sw, nil
 		}),
+		fx.Provide(func() (*autonatv2.AutoNAT, error) {
+			if !cfg.EnableAutoNATv2 {
+				return nil, nil
+			}
+			ah, err := cfg.makeAutoNATV2Host()
+			if err != nil {
+				return nil, err
+			}
+			var mt autonatv2.MetricsTracer
+			if !cfg.DisableMetrics {
+				mt = autonatv2.NewMetricsTracer(cfg.PrometheusRegisterer)
+			}
+			autoNATv2, err := autonatv2.New(ah, autonatv2.WithMetricsTracer(mt))
+			if err != nil {
+				return nil, fmt.Errorf("failed to create autonatv2: %w", err)
+			}
+			return autoNATv2, nil
+		}),
 		fx.Provide(cfg.newBasicHost),
 		fx.Provide(func(bh *bhost.BasicHost) identify.IDService {
 			return bh.IDService()
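For callers nothing changes in how the service is switched on; the fx provider simply returns nil when AutoNATv2 is disabled. A minimal usage sketch, assuming the existing libp2p.EnableAutoNATv2 option:

	h, err := libp2p.New(
		libp2p.EnableAutoNATv2(), // sets cfg.EnableAutoNATv2; the provider above builds the *autonatv2.AutoNAT
	)
	if err != nil {
		panic(err)
	}
	defer h.Close()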
core/event/reachability.go

@@ -2,6 +2,7 @@ package event

 import (
 	"github.com/libp2p/go-libp2p/core/network"
+	ma "github.com/multiformats/go-multiaddr"
 )

 // EvtLocalReachabilityChanged is an event struct to be emitted when the local's
@@ -11,3 +12,13 @@ import (
 type EvtLocalReachabilityChanged struct {
 	Reachability network.Reachability
 }
+
+// EvtHostReachableAddrsChanged is sent when the host's reachable or unreachable addresses change.
+// Reachable, Unreachable, and Unknown only contain public IP or DNS addresses.
+//
+// Experimental: This API is unstable. Any changes to this event will be done without a deprecation notice.
+type EvtHostReachableAddrsChanged struct {
+	Reachable   []ma.Multiaddr
+	Unreachable []ma.Multiaddr
+	Unknown     []ma.Multiaddr
+}
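Consumers receive the new event through the host's event bus like any other event type; the test added later in this commit uses the same pattern. A minimal sketch, assuming a running host h:

	sub, err := h.EventBus().Subscribe(new(event.EvtHostReachableAddrsChanged))
	if err != nil {
		panic(err)
	}
	defer sub.Close()

	for e := range sub.Out() {
		evt := e.(event.EvtHostReachableAddrsChanged)
		// Only public IP or DNS addresses show up in these slices.
		fmt.Println("reachable:", evt.Reachable, "unreachable:", evt.Unreachable, "unknown:", evt.Unknown)
	}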
@@ -10,6 +10,7 @@
 package mocknetwork

 import (
+	net "net"
 	reflect "reflect"

 	network "github.com/libp2p/go-libp2p/core/network"
@@ -87,6 +88,20 @@ func (mr *MockResourceManagerMockRecorder) OpenStream(p, dir any) *gomock.Call {
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenStream", reflect.TypeOf((*MockResourceManager)(nil).OpenStream), p, dir)
 }

+// VerifySourceAddress mocks base method.
+func (m *MockResourceManager) VerifySourceAddress(addr net.Addr) bool {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "VerifySourceAddress", addr)
+	ret0, _ := ret[0].(bool)
+	return ret0
+}
+
+// VerifySourceAddress indicates an expected call of VerifySourceAddress.
+func (mr *MockResourceManagerMockRecorder) VerifySourceAddress(addr any) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifySourceAddress", reflect.TypeOf((*MockResourceManager)(nil).VerifySourceAddress), addr)
+}
+
 // ViewPeer mocks base method.
 func (m *MockResourceManager) ViewPeer(arg0 peer.ID, arg1 func(network.PeerScope) error) error {
 	m.ctrl.T.Helper()
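In tests, the regenerated mock lets callers stub the new method through the usual gomock recorder. A minimal sketch, assuming a *testing.T named t:

	ctrl := gomock.NewController(t)
	rcmgr := mocknetwork.NewMockResourceManager(ctrl)
	// Force source-address verification for every incoming connection in this test.
	rcmgr.EXPECT().VerifySourceAddress(gomock.Any()).Return(true).AnyTimes()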
core/network/rcmgr.go

@@ -1,6 +1,10 @@
 package network

 import (
+	"context"
+	"errors"
+	"net"
+
 	"github.com/libp2p/go-libp2p/core/peer"
 	"github.com/libp2p/go-libp2p/core/protocol"

 	"github.com/multiformats/go-multiaddr"
@@ -87,6 +91,10 @@ type ResourceManager interface {
 	// the end of the scope's span.
 	OpenConnection(dir Direction, usefd bool, endpoint multiaddr.Multiaddr) (ConnManagementScope, error)

+	// VerifySourceAddress tells the transport to verify the source address for an incoming connection
+	// before gating the connection with OpenConnection.
+	VerifySourceAddress(addr net.Addr) bool
+
 	// OpenStream creates a new stream scope, initially unnegotiated.
 	// An unnegotiated stream will be initially unattached to any protocol scope
 	// and constrained by the transient scope.
@@ -269,9 +277,30 @@ type ScopeStat struct {
 	Memory int64
 }

+// connManagementScopeKey is the key used to store a ConnManagementScope in a context.
+type connManagementScopeKey struct{}
+
+func WithConnManagementScope(ctx context.Context, scope ConnManagementScope) context.Context {
+	return context.WithValue(ctx, connManagementScopeKey{}, scope)
+}
+
+func UnwrapConnManagementScope(ctx context.Context) (ConnManagementScope, error) {
+	v := ctx.Value(connManagementScopeKey{})
+	if v == nil {
+		return nil, errors.New("context has no ConnManagementScope")
+	}
+	scope, ok := v.(ConnManagementScope)
+	if !ok {
+		return nil, errors.New("context has no ConnManagementScope")
+	}
+	return scope, nil
+}
+
 // NullResourceManager is a stub for tests and initialization of default values
 type NullResourceManager struct{}

 var _ ResourceManager = (*NullResourceManager)(nil)
 var _ ResourceScope = (*NullScope)(nil)
 var _ ResourceScopeSpan = (*NullScope)(nil)
 var _ ServiceScope = (*NullScope)(nil)
@@ -306,6 +335,10 @@ func (n *NullResourceManager) OpenConnection(_ Direction, _ bool, _ multiaddr.Mu
 func (n *NullResourceManager) OpenStream(_ peer.ID, _ Direction) (StreamManagementScope, error) {
 	return &NullScope{}, nil
 }
+func (*NullResourceManager) VerifySourceAddress(_ net.Addr) bool {
+	return false
+}

 func (n *NullResourceManager) Close() error {
 	return nil
 }
@@ -324,3 +357,4 @@ func (n *NullScope) ProtocolScope() ProtocolScope { return &NullScop
 func (n *NullScope) SetProtocol(_ protocol.ID) error { return nil }
 func (n *NullScope) ServiceScope() ServiceScope      { return &NullScope{} }
 func (n *NullScope) SetService(_ string) error       { return nil }
+func (n *NullScope) VerifySourceAddress(_ net.Addr) bool { return false }
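The two helpers are symmetric: WithConnManagementScope stores the scope under an unexported key, and UnwrapConnManagementScope retrieves it. A round-trip sketch, using NullScope as a stand-in (assuming it satisfies ConnManagementScope, as the other Null* interface assertions in this file suggest):

	var scope network.ConnManagementScope = &network.NullScope{}
	ctx := network.WithConnManagementScope(context.Background(), scope)

	got, err := network.UnwrapConnManagementScope(ctx)
	if err != nil {
		panic(err)
	}
	_ = got // the same scope that was stored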
23 go.mod
@@ -53,18 +53,18 @@ require (
 	github.com/pion/webrtc/v4 v4.0.14
 	github.com/prometheus/client_golang v1.21.0
 	github.com/prometheus/client_model v0.6.1
-	github.com/quic-go/quic-go v0.50.0
+	github.com/quic-go/quic-go v0.52.0
 	github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66
 	github.com/stretchr/testify v1.10.0
 	go.uber.org/fx v1.23.0
 	go.uber.org/goleak v1.3.0
-	go.uber.org/mock v0.5.0
+	go.uber.org/mock v0.5.2
 	go.uber.org/zap v1.27.0
-	golang.org/x/crypto v0.35.0
-	golang.org/x/sync v0.11.0
-	golang.org/x/sys v0.30.0
+	golang.org/x/crypto v0.37.0
+	golang.org/x/sync v0.14.0
+	golang.org/x/sys v0.33.0
 	golang.org/x/time v0.11.0
-	golang.org/x/tools v0.30.0
+	golang.org/x/tools v0.32.0
 	google.golang.org/protobuf v1.36.5
 )
@@ -74,7 +74,7 @@ require (
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/francoispqt/gojay v1.2.13 // indirect
 	github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
-	github.com/google/pprof v0.0.0-20250208200701-d0013a598941 // indirect
+	github.com/google/pprof v0.0.0-20250501235452-c0086092b71a // indirect
 	github.com/google/uuid v1.6.0 // indirect
 	github.com/klauspost/cpuid/v2 v2.2.10 // indirect
 	github.com/mattn/go-isatty v0.0.20 // indirect
@@ -83,7 +83,7 @@ require (
 	github.com/minio/sha256-simd v1.0.1 // indirect
 	github.com/multiformats/go-base36 v0.2.0 // indirect
 	github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
-	github.com/onsi/ginkgo/v2 v2.22.2 // indirect
+	github.com/onsi/ginkgo/v2 v2.23.4 // indirect
 	github.com/pion/dtls/v2 v2.2.12 // indirect
 	github.com/pion/dtls/v3 v3.0.4 // indirect
 	github.com/pion/interceptor v0.1.37 // indirect
@@ -103,12 +103,13 @@ require (
 	github.com/quic-go/qpack v0.5.1 // indirect
 	github.com/spaolacci/murmur3 v1.1.0 // indirect
 	github.com/wlynxg/anet v0.0.5 // indirect
+	go.uber.org/automaxprocs v1.6.0 // indirect
 	go.uber.org/dig v1.18.0 // indirect
 	go.uber.org/multierr v1.11.0 // indirect
 	golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
-	golang.org/x/mod v0.23.0 // indirect
-	golang.org/x/net v0.35.0 // indirect
-	golang.org/x/text v0.22.0 // indirect
+	golang.org/x/mod v0.24.0 // indirect
+	golang.org/x/net v0.39.0 // indirect
+	golang.org/x/text v0.24.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 	lukechampine.com/blake3 v1.4.0 // indirect
 )
56 go.sum
@@ -53,16 +53,16 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
 github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
 github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
-github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
-github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
+github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
 github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ=
 github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
 github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
 github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
 github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
 github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
-github.com/google/pprof v0.0.0-20250208200701-d0013a598941 h1:43XjGa6toxLpeksjcxs1jIoIyr+vUfOqY2c6HB4bpoc=
-github.com/google/pprof v0.0.0-20250208200701-d0013a598941/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
+github.com/google/pprof v0.0.0-20250501235452-c0086092b71a h1:rDA3FfmxwXR+BVKKdz55WwMJ1pD2hJQNW31d+l3mPk4=
+github.com/google/pprof v0.0.0-20250501235452-c0086092b71a/go.mod h1:5hDyRhoBCxViHszMt12TnOpEI4VVi+U8Gm9iphldiMA=
 github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
 github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY=
@@ -180,10 +180,10 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
 github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
 github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo=
 github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
-github.com/onsi/ginkgo/v2 v2.22.2 h1:/3X8Panh8/WwhU/3Ssa6rCKqPLuAkVY2I0RoyDLySlU=
-github.com/onsi/ginkgo/v2 v2.22.2/go.mod h1:oeMosUL+8LtarXBHu/c0bx2D/K9zyQ6uX3cTyztHwsk=
-github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8=
-github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY=
+github.com/onsi/ginkgo/v2 v2.23.4 h1:ktYTpKJAVZnDT4VjxSbiBenUjmlL/5QkBEocaWXiQus=
+github.com/onsi/ginkgo/v2 v2.23.4/go.mod h1:Bt66ApGPBFzHyR+JO10Zbt0Gsp4uWxu5mIOTusL46e8=
+github.com/onsi/gomega v1.36.3 h1:hID7cr8t3Wp26+cYnfcjR6HpJ00fdogN6dqZ1t6IylU=
+github.com/onsi/gomega v1.36.3/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0=
 github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8=
 github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0=
 github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y=
@@ -232,6 +232,8 @@ github.com/pion/webrtc/v4 v4.0.14/go.mod h1:R3+qTnQTS03UzwDarYecgioNf7DYgTsldxnC
 github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
+github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
 github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
 github.com/prometheus/client_golang v1.21.0 h1:DIsaGmiaBkSangBgMtWdNfxbMNdku5IK6iNhrEqWvdA=
 github.com/prometheus/client_golang v1.21.0/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
@@ -246,8 +248,8 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg
 github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
 github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
 github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
-github.com/quic-go/quic-go v0.50.0 h1:3H/ld1pa3CYhkcc20TPIyG1bNsdhn9qZBGN3b9/UyUo=
-github.com/quic-go/quic-go v0.50.0/go.mod h1:Vim6OmUvlYdwBhXP9ZVrtGmCMWa3wEqhq3NgYrI8b4E=
+github.com/quic-go/quic-go v0.52.0 h1:/SlHrCRElyaU6MaEPKqKr9z83sBg2v4FLLvWM+Z47pA=
+github.com/quic-go/quic-go v0.52.0/go.mod h1:MFlGGpcpJqRAfmYi6NC2cptDPSxRWTOGNuP4wqrWmzQ=
 github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66 h1:4WFk6u3sOT6pLa1kQ50ZVdm8BQFgJNA117cepZxtLIg=
 github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66/go.mod h1:Vp72IJajgeOL6ddqrAhmp7IM9zbTcgkQxD/YdxrVwMw=
 github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
@@ -302,6 +304,8 @@ github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1
 github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
 go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA=
 go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
+go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
+go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
 go.uber.org/dig v1.18.0 h1:imUL1UiY0Mg4bqbFfsRQO5G4CGRBec/ZujWTvSVp3pw=
 go.uber.org/dig v1.18.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE=
 go.uber.org/fx v1.23.0 h1:lIr/gYWQGfTwGcSXWXu4vP5Ws6iqnNEIY+F/aFzCKTg=
@@ -309,8 +313,8 @@ go.uber.org/fx v1.23.0/go.mod h1:o/D9n+2mLP6v1EG+qsdT1O8wKopYAsqZasju97SDFCU=
 go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
 go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
 go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
-go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
-go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
+go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
+go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
 go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
 go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
 go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
@@ -330,8 +334,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
 golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
 golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
 golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
-golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
-golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
+golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
+golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
 golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
@@ -344,8 +348,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
 golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
 golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
-golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
-golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
+golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
+golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
 golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -367,8 +371,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
 golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
 golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
 golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
-golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
-golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
+golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
+golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
 golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
 golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
 golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
@@ -382,8 +386,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
-golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
+golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
 golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -407,8 +411,8 @@ golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
-golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
+golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
 golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -425,8 +429,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
 golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
 golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
 golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
-golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
-golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
+golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
+golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
 golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
@@ -442,8 +446,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
 golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
 golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
 golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
-golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
-golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
+golang.org/x/tools v0.32.0 h1:Q7N1vhpkQv7ybVzLFtTjvQya2ewbwNDZzUgfXGqtMWU=
+golang.org/x/tools v0.32.0/go.mod h1:ZxrU41P/wAbZD8EDa6dDCa6XfpkhJ7HFMjHJXfBDu8s=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
p2p/host/basic/addrs_manager.go

@@ -2,6 +2,7 @@ package basichost

 import (
 	"context"
+	"errors"
 	"fmt"
 	"net"
 	"slices"
@@ -13,6 +14,7 @@ import (
 	"github.com/libp2p/go-libp2p/core/network"
 	"github.com/libp2p/go-libp2p/core/transport"
 	"github.com/libp2p/go-libp2p/p2p/host/basic/internal/backoff"
+	"github.com/libp2p/go-libp2p/p2p/host/eventbus"
 	libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
 	libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
 	"github.com/libp2p/go-netroute"
@@ -27,24 +29,37 @@ type observedAddrsManager interface {
 	ObservedAddrsFor(local ma.Multiaddr) []ma.Multiaddr
 }

+type hostAddrs struct {
+	addrs            []ma.Multiaddr
+	localAddrs       []ma.Multiaddr
+	reachableAddrs   []ma.Multiaddr
+	unreachableAddrs []ma.Multiaddr
+	unknownAddrs     []ma.Multiaddr
+	relayAddrs       []ma.Multiaddr
+}
+
 type addrsManager struct {
-	eventbus              event.Bus
+	bus                   event.Bus
 	natManager            NATManager
 	addrsFactory          AddrsFactory
 	listenAddrs           func() []ma.Multiaddr
 	transportForListening func(ma.Multiaddr) transport.Transport
 	observedAddrsManager  observedAddrsManager
 	interfaceAddrs        *interfaceAddrsCache
+	addrsReachabilityTracker *addrsReachabilityTracker

-	// addrsUpdatedChan is notified when addrs change. This is provided by the caller.
-	addrsUpdatedChan chan struct{}
-
 	// triggerAddrsUpdateChan is used to trigger an addresses update.
 	triggerAddrsUpdateChan chan struct{}
+	// addrsUpdatedChan is notified when addresses change. This is provided by the caller.
+	addrsUpdatedChan chan struct{}
+	// triggerReachabilityUpdate is notified when reachable addrs are updated.
+	triggerReachabilityUpdate chan struct{}

 	hostReachability atomic.Pointer[network.Reachability]

-	addrsMx    sync.RWMutex // protects fields below
-	localAddrs []ma.Multiaddr
-	relayAddrs []ma.Multiaddr
+	addrsMx      sync.RWMutex
+	currentAddrs hostAddrs

 	wg  sync.WaitGroup
 	ctx context.Context
@@ -52,23 +67,25 @@ type addrsManager struct {
 }

 func newAddrsManager(
-	eventbus event.Bus,
+	bus event.Bus,
 	natmgr NATManager,
 	addrsFactory AddrsFactory,
 	listenAddrs func() []ma.Multiaddr,
 	transportForListening func(ma.Multiaddr) transport.Transport,
 	observedAddrsManager observedAddrsManager,
 	addrsUpdatedChan chan struct{},
+	client autonatv2Client,
 ) (*addrsManager, error) {
 	ctx, cancel := context.WithCancel(context.Background())
 	as := &addrsManager{
-		eventbus:              eventbus,
+		bus:                   bus,
 		listenAddrs:           listenAddrs,
 		transportForListening: transportForListening,
 		observedAddrsManager:  observedAddrsManager,
 		natManager:            natmgr,
 		addrsFactory:          addrsFactory,
 		triggerAddrsUpdateChan: make(chan struct{}, 1),
+		triggerReachabilityUpdate: make(chan struct{}, 1),
 		addrsUpdatedChan:      addrsUpdatedChan,
 		interfaceAddrs:        &interfaceAddrsCache{},
 		ctx:                   ctx,
@@ -76,11 +93,23 @@ func newAddrsManager(
 	}
 	unknownReachability := network.ReachabilityUnknown
 	as.hostReachability.Store(&unknownReachability)

+	if client != nil {
+		as.addrsReachabilityTracker = newAddrsReachabilityTracker(client, as.triggerReachabilityUpdate, nil)
+	}
 	return as, nil
 }

 func (a *addrsManager) Start() error {
-	return a.background()
+	// TODO: add Start method to NATMgr
+	if a.addrsReachabilityTracker != nil {
+		err := a.addrsReachabilityTracker.Start()
+		if err != nil {
+			return fmt.Errorf("error starting addrs reachability tracker: %s", err)
+		}
+	}
+	return a.startBackgroundWorker()
 }

 func (a *addrsManager) Close() {
@@ -91,10 +120,18 @@ func (a *addrsManager) Close() {
 			log.Warnf("error closing natmgr: %s", err)
 		}
 	}
+	if a.addrsReachabilityTracker != nil {
+		err := a.addrsReachabilityTracker.Close()
+		if err != nil {
+			log.Warnf("error closing addrs reachability tracker: %s", err)
+		}
+	}
 	a.wg.Wait()
 }

 func (a *addrsManager) NetNotifee() network.Notifiee {
 	// Updating addrs in sync provides the nice property that
 	// host.Addrs() just after host.Network().Listen(x) will return x
 	return &network.NotifyBundle{
 		ListenF:      func(network.Network, ma.Multiaddr) { a.triggerAddrsUpdate() },
 		ListenCloseF: func(network.Network, ma.Multiaddr) { a.triggerAddrsUpdate() },
@@ -102,37 +139,53 @@ func (a *addrsManager) NetNotifee() network.Notifiee {
 }

 func (a *addrsManager) triggerAddrsUpdate() {
 	// This is ugly, we update here *and* in the background loop, but this ensures the nice property
 	// that host.Addrs after host.Network().Listen(...) will return the recently added listen address.
-	a.updateLocalAddrs()
+	a.updateAddrs(false, nil)
 	select {
 	case a.triggerAddrsUpdateChan <- struct{}{}:
 	default:
 	}
 }

-func (a *addrsManager) background() error {
-	autoRelayAddrsSub, err := a.eventbus.Subscribe(new(event.EvtAutoRelayAddrsUpdated))
+func (a *addrsManager) startBackgroundWorker() error {
+	autoRelayAddrsSub, err := a.bus.Subscribe(new(event.EvtAutoRelayAddrsUpdated), eventbus.Name("addrs-manager"))
 	if err != nil {
 		return fmt.Errorf("error subscribing to auto relay addrs: %s", err)
 	}

-	autonatReachabilitySub, err := a.eventbus.Subscribe(new(event.EvtLocalReachabilityChanged))
+	autonatReachabilitySub, err := a.bus.Subscribe(new(event.EvtLocalReachabilityChanged), eventbus.Name("addrs-manager"))
 	if err != nil {
-		return fmt.Errorf("error subscribing to autonat reachability: %s", err)
+		err1 := autoRelayAddrsSub.Close()
+		if err1 != nil {
+			err1 = fmt.Errorf("error closing autorelaysub: %w", err1)
+		}
+		err = fmt.Errorf("error subscribing to autonat reachability: %s", err)
+		return errors.Join(err, err1)
 	}

-	// ensure that we have the correct address after returning from Start()
-	// update local addrs
-	a.updateLocalAddrs()
+	emitter, err := a.bus.Emitter(new(event.EvtHostReachableAddrsChanged), eventbus.Stateful)
+	if err != nil {
+		err1 := autoRelayAddrsSub.Close()
+		if err1 != nil {
+			err1 = fmt.Errorf("error closing autorelaysub: %w", err1)
+		}
+		err2 := autonatReachabilitySub.Close()
+		if err2 != nil {
+			err2 = fmt.Errorf("error closing autonat reachability sub: %w", err2)
+		}
+		err = fmt.Errorf("error creating reachable addrs emitter: %s", err)
+		return errors.Join(err, err1, err2)
+	}

+	var relayAddrs []ma.Multiaddr
 	// update relay addrs in case we're private
 	select {
 	case e := <-autoRelayAddrsSub.Out():
 		if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok {
-			a.updateRelayAddrs(evt.RelayAddrs)
+			relayAddrs = slices.Clone(evt.RelayAddrs)
 		}
 	default:
 	}

 	select {
 	case e := <-autonatReachabilitySub.Out():
 		if evt, ok := e.(event.EvtLocalReachabilityChanged); ok {
@@ -140,18 +193,25 @@ func (a *addrsManager) background() error {
 		}
 	default:
 	}
+	// update addresses before starting the worker loop. This ensures that any address updates
+	// before calling addrsManager.Start are correctly reported after Start returns.
+	a.updateAddrs(true, relayAddrs)

 	a.wg.Add(1)
-	go func() {
+	go a.background(autoRelayAddrsSub, autonatReachabilitySub, emitter, relayAddrs)
+	return nil
+}
+
+func (a *addrsManager) background(autoRelayAddrsSub, autonatReachabilitySub event.Subscription,
+	emitter event.Emitter, relayAddrs []ma.Multiaddr,
+) {
+	defer a.wg.Done()
 	defer func() {
 		err := autoRelayAddrsSub.Close()
 		if err != nil {
 			log.Warnf("error closing auto relay addrs sub: %s", err)
 		}
-	}()
-	defer func() {
-		err := autonatReachabilitySub.Close()
+		err = autonatReachabilitySub.Close()
 		if err != nil {
 			log.Warnf("error closing autonat reachability sub: %s", err)
 		}
@@ -159,24 +219,18 @@ func (a *addrsManager) background() error {

 	ticker := time.NewTicker(addrChangeTickrInterval)
 	defer ticker.Stop()
-	var prev []ma.Multiaddr
+	var previousAddrs hostAddrs
 	for {
-		a.updateLocalAddrs()
-		curr := a.Addrs()
-		if a.areAddrsDifferent(prev, curr) {
-			log.Debugf("host addresses updated: %s", curr)
-			select {
-			case a.addrsUpdatedChan <- struct{}{}:
-			default:
-			}
-		}
-		prev = curr
+		currAddrs := a.updateAddrs(true, relayAddrs)
+		a.notifyAddrsChanged(emitter, previousAddrs, currAddrs)
+		previousAddrs = currAddrs
 		select {
 		case <-ticker.C:
 		case <-a.triggerAddrsUpdateChan:
+		case <-a.triggerReachabilityUpdate:
 		case e := <-autoRelayAddrsSub.Out():
 			if evt, ok := e.(event.EvtAutoRelayAddrsUpdated); ok {
-				a.updateRelayAddrs(evt.RelayAddrs)
+				relayAddrs = slices.Clone(evt.RelayAddrs)
 			}
 		case e := <-autonatReachabilitySub.Out():
 			if evt, ok := e.(event.EvtLocalReachabilityChanged); ok {
@@ -186,24 +240,106 @@ func (a *addrsManager) background() error {
 			}
 		case <-a.ctx.Done():
 			return
 		}
 	}
-	}()
-	return nil
 }

+// updateAddrs updates the addresses of the host and returns the new updated addrs
+func (a *addrsManager) updateAddrs(updateRelayAddrs bool, relayAddrs []ma.Multiaddr) hostAddrs {
+	// Must lock while doing both recompute and update as this method is called from
+	// multiple goroutines.
+	a.addrsMx.Lock()
+	defer a.addrsMx.Unlock()
+
+	localAddrs := a.getLocalAddrs()
+	var currReachableAddrs, currUnreachableAddrs, currUnknownAddrs []ma.Multiaddr
+	if a.addrsReachabilityTracker != nil {
+		currReachableAddrs, currUnreachableAddrs, currUnknownAddrs = a.getConfirmedAddrs(localAddrs)
+	}
+	if !updateRelayAddrs {
+		relayAddrs = a.currentAddrs.relayAddrs
+	} else {
+		// Copy the caller's slice
+		relayAddrs = slices.Clone(relayAddrs)
+	}
+	currAddrs := a.getAddrs(slices.Clone(localAddrs), relayAddrs)
+
+	a.currentAddrs = hostAddrs{
+		addrs:            append(a.currentAddrs.addrs[:0], currAddrs...),
+		localAddrs:       append(a.currentAddrs.localAddrs[:0], localAddrs...),
+		reachableAddrs:   append(a.currentAddrs.reachableAddrs[:0], currReachableAddrs...),
+		unreachableAddrs: append(a.currentAddrs.unreachableAddrs[:0], currUnreachableAddrs...),
+		unknownAddrs:     append(a.currentAddrs.unknownAddrs[:0], currUnknownAddrs...),
+		relayAddrs:       append(a.currentAddrs.relayAddrs[:0], relayAddrs...),
+	}
+
+	return hostAddrs{
+		localAddrs:       localAddrs,
+		addrs:            currAddrs,
+		reachableAddrs:   currReachableAddrs,
+		unreachableAddrs: currUnreachableAddrs,
+		unknownAddrs:     currUnknownAddrs,
+		relayAddrs:       relayAddrs,
+	}
+}
+
+func (a *addrsManager) notifyAddrsChanged(emitter event.Emitter, previous, current hostAddrs) {
+	if areAddrsDifferent(previous.localAddrs, current.localAddrs) {
+		log.Debugf("host local addresses updated: %s", current.localAddrs)
+		if a.addrsReachabilityTracker != nil {
+			a.addrsReachabilityTracker.UpdateAddrs(current.localAddrs)
+		}
+	}
+	if areAddrsDifferent(previous.addrs, current.addrs) {
+		log.Debugf("host addresses updated: %s", current.addrs)
+		select {
+		case a.addrsUpdatedChan <- struct{}{}:
+		default:
+		}
+	}
+
+	// We *must* send both reachability changed and addrs changed events from the
+	// same goroutine to ensure correct ordering.
+	// Consider the events:
+	//	- addr x discovered
+	//	- addr x is reachable
+	//	- addr x removed
+	// We must send these events in the same order. It'll be confusing for consumers
+	// if the reachable event is received after the addr removed event.
+	if areAddrsDifferent(previous.reachableAddrs, current.reachableAddrs) ||
+		areAddrsDifferent(previous.unreachableAddrs, current.unreachableAddrs) ||
+		areAddrsDifferent(previous.unknownAddrs, current.unknownAddrs) {
+		log.Debugf("host reachable addrs updated: %s", current.localAddrs)
+		if err := emitter.Emit(event.EvtHostReachableAddrsChanged{
+			Reachable:   slices.Clone(current.reachableAddrs),
+			Unreachable: slices.Clone(current.unreachableAddrs),
+			Unknown:     slices.Clone(current.unknownAddrs),
+		}); err != nil {
+			log.Errorf("error sending host reachable addrs changed event: %s", err)
+		}
+	}
+}

 // Addrs returns the node's dialable addresses both public and private.
 // If autorelay is enabled and node reachability is private, it returns
 // the node's relay addresses and private network addresses.
 func (a *addrsManager) Addrs() []ma.Multiaddr {
-	addrs := a.DirectAddrs()
+	a.addrsMx.RLock()
+	directAddrs := slices.Clone(a.currentAddrs.localAddrs)
+	relayAddrs := slices.Clone(a.currentAddrs.relayAddrs)
+	a.addrsMx.RUnlock()
+	return a.getAddrs(directAddrs, relayAddrs)
+}
+
+// getAddrs returns the node's dialable addresses. Mutates localAddrs.
+func (a *addrsManager) getAddrs(localAddrs []ma.Multiaddr, relayAddrs []ma.Multiaddr) []ma.Multiaddr {
+	addrs := localAddrs
 	rch := a.hostReachability.Load()
 	if rch != nil && *rch == network.ReachabilityPrivate {
-		a.addrsMx.RLock()
 		// Delete public addresses if the node's reachability is private, and we have relay addresses
-		if len(a.relayAddrs) > 0 {
+		if len(relayAddrs) > 0 {
 			addrs = slices.DeleteFunc(addrs, manet.IsPublicAddr)
-			addrs = append(addrs, a.relayAddrs...)
+			addrs = append(addrs, relayAddrs...)
 		}
-		a.addrsMx.RUnlock()
 	}
 	// Make a copy. Consumers can modify the slice elements
 	addrs = slices.Clone(a.addrsFactory(addrs))
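The append(dst[:0], src...) calls in updateAddrs reuse the previous snapshot's backing arrays, while the returned hostAddrs hands the caller freshly built slices, so readers holding addrsMx never observe a partially written snapshot. The idiom in isolation (a generic sketch, not code from this patch):

	buf := make([]ma.Multiaddr, 0, 8)
	next := []ma.Multiaddr{ma.StringCast("/ip4/1.2.3.4/tcp/1")}
	buf = append(buf[:0], next...) // truncate, then overwrite in place: no reallocation while cap suffices
	snapshot := slices.Clone(buf)  // callers get an independent copy they may mutate
	_ = snapshot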
@@ -213,7 +349,8 @@ func (a *addrsManager) Addrs() []ma.Multiaddr {
 	return addrs
 }

-// HolePunchAddrs returns the node's public direct listen addresses for hole punching.
+// HolePunchAddrs returns all the host's direct public addresses, reachable or unreachable,
+// suitable for hole punching.
 func (a *addrsManager) HolePunchAddrs() []ma.Multiaddr {
 	addrs := a.DirectAddrs()
 	addrs = slices.Clone(a.addrsFactory(addrs))
@@ -230,26 +367,23 @@ func (a *addrsManager) HolePunchAddrs() []ma.Multiaddr {
 func (a *addrsManager) DirectAddrs() []ma.Multiaddr {
 	a.addrsMx.RLock()
 	defer a.addrsMx.RUnlock()
-	return slices.Clone(a.localAddrs)
+	return slices.Clone(a.currentAddrs.localAddrs)
 }

-func (a *addrsManager) updateRelayAddrs(addrs []ma.Multiaddr) {
-	a.addrsMx.Lock()
-	defer a.addrsMx.Unlock()
-	a.relayAddrs = append(a.relayAddrs[:0], addrs...)
+// ConfirmedAddrs returns all addresses of the host that are reachable from the internet
+func (a *addrsManager) ConfirmedAddrs() (reachable []ma.Multiaddr, unreachable []ma.Multiaddr, unknown []ma.Multiaddr) {
+	a.addrsMx.RLock()
+	defer a.addrsMx.RUnlock()
+	return slices.Clone(a.currentAddrs.reachableAddrs), slices.Clone(a.currentAddrs.unreachableAddrs), slices.Clone(a.currentAddrs.unknownAddrs)
+}
+
+func (a *addrsManager) getConfirmedAddrs(localAddrs []ma.Multiaddr) (reachableAddrs, unreachableAddrs, unknownAddrs []ma.Multiaddr) {
+	reachableAddrs, unreachableAddrs, unknownAddrs = a.addrsReachabilityTracker.ConfirmedAddrs()
+	return removeNotInSource(reachableAddrs, localAddrs), removeNotInSource(unreachableAddrs, localAddrs), removeNotInSource(unknownAddrs, localAddrs)
 }

 var p2pCircuitAddr = ma.StringCast("/p2p-circuit")

-func (a *addrsManager) updateLocalAddrs() {
-	localAddrs := a.getLocalAddrs()
-	slices.SortFunc(localAddrs, func(a, b ma.Multiaddr) int { return a.Compare(b) })
-
-	a.addrsMx.Lock()
-	a.localAddrs = localAddrs
-	a.addrsMx.Unlock()
-}
-
 func (a *addrsManager) getLocalAddrs() []ma.Multiaddr {
 	listenAddrs := a.listenAddrs()
 	if len(listenAddrs) == 0 {
@@ -260,8 +394,6 @@ func (a *addrsManager) getLocalAddrs() []ma.Multiaddr {
 	finalAddrs = a.appendPrimaryInterfaceAddrs(finalAddrs, listenAddrs)
 	finalAddrs = a.appendNATAddrs(finalAddrs, listenAddrs, a.interfaceAddrs.All())

-	finalAddrs = ma.Unique(finalAddrs)
-
 	// Remove "/p2p-circuit" addresses from the list.
 	// The p2p-circuit listener reports its address as just /p2p-circuit. This is
 	// useless for dialing. Users need to manage their circuit addresses themselves,
@@ -278,6 +410,8 @@ func (a *addrsManager) getLocalAddrs() []ma.Multiaddr {
 	// Add certhashes for /webrtc-direct, /webtransport, etc addresses discovered
 	// using identify.
 	finalAddrs = a.addCertHashes(finalAddrs)
+	finalAddrs = ma.Unique(finalAddrs)
+	slices.SortFunc(finalAddrs, func(a, b ma.Multiaddr) int { return a.Compare(b) })
 	return finalAddrs
 }
@@ -408,7 +542,7 @@ func (a *addrsManager) addCertHashes(addrs []ma.Multiaddr) []ma.Multiaddr {
 	return addrs
 }

-func (a *addrsManager) areAddrsDifferent(prev, current []ma.Multiaddr) bool {
+func areAddrsDifferent(prev, current []ma.Multiaddr) bool {
 	// TODO: make the sorted nature of ma.Unique a guarantee in multiaddrs
 	prev = ma.Unique(prev)
 	current = ma.Unique(current)
@@ -547,3 +681,31 @@ func (i *interfaceAddrsCache) updateUnlocked() {
 		}
 	}
 }
+
+// removeNotInSource removes items from addrs that are not present in source.
+// Modifies the addrs slice in place.
+// addrs and source must be sorted using multiaddr.Compare.
+func removeNotInSource(addrs, source []ma.Multiaddr) []ma.Multiaddr {
+	j := 0
+	// mark entries not in source as nil
+	for i, a := range addrs {
+		// advance j while a > source[j]
+		for j < len(source) && a.Compare(source[j]) > 0 {
+			j++
+		}
+		// a is not in source if we've reached the end of source, or a is smaller than source[j]
+		if j == len(source) || a.Compare(source[j]) < 0 {
+			addrs[i] = nil
+		}
+		// otherwise a is in source; nothing to do
+	}
+	// compact the slice: j scans all elements, i is the lowest index holding a nil element
+	i := 0
+	for j := range len(addrs) {
+		if addrs[j] != nil {
+			addrs[i], addrs[j] = addrs[j], addrs[i]
+			i++
+		}
+	}
+	return addrs[:i]
+}
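Since both inputs must be sorted, the merge walk makes a single pass over each slice. A quick usage sketch:

	addrs := []ma.Multiaddr{
		ma.StringCast("/ip4/1.2.3.4/tcp/1"),
		ma.StringCast("/ip4/1.2.3.4/tcp/2"),
		ma.StringCast("/ip4/1.2.3.4/tcp/4"),
	}
	source := []ma.Multiaddr{
		ma.StringCast("/ip4/1.2.3.4/tcp/2"),
		ma.StringCast("/ip4/1.2.3.4/tcp/4"),
	}
	// /tcp/1 is absent from source and is dropped; the survivors keep their order.
	out := removeNotInSource(addrs, source) // [/ip4/1.2.3.4/tcp/2 /ip4/1.2.3.4/tcp/4]
	_ = out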
p2p/host/basic/addrs_manager_test.go

@@ -1,13 +1,17 @@
 package basichost

 import (
 	"context"
+	"errors"
 	"fmt"
 	"slices"
 	"testing"
 	"time"

 	"github.com/libp2p/go-libp2p/core/event"
 	"github.com/libp2p/go-libp2p/core/network"
+	"github.com/libp2p/go-libp2p/p2p/host/eventbus"
+	"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2"
 	ma "github.com/multiformats/go-multiaddr"
 	manet "github.com/multiformats/go-multiaddr/net"
 	"github.com/stretchr/testify/assert"
@@ -135,7 +139,7 @@ type mockNatManager struct {
 	GetMappingFunc func(addr ma.Multiaddr) ma.Multiaddr
 }

-func (m *mockNatManager) Close() error {
+func (*mockNatManager) Close() error {
 	return nil
 }

@@ -146,7 +150,7 @@ func (m *mockNatManager) GetMapping(addr ma.Multiaddr) ma.Multiaddr {
 	return m.GetMappingFunc(addr)
 }

-func (m *mockNatManager) HasDiscoveredNAT() bool {
+func (*mockNatManager) HasDiscoveredNAT() bool {
 	return true
 }
@@ -170,6 +174,8 @@ type addrsManagerArgs struct {
 	AddrsFactory         AddrsFactory
 	ObservedAddrsManager observedAddrsManager
 	ListenAddrs          func() []ma.Multiaddr
+	AutoNATClient        autonatv2Client
+	Bus                  event.Bus
 }

 type addrsManagerTestCase struct {
@@ -179,13 +185,16 @@ type addrsManagerTestCase struct {
 }

 func newAddrsManagerTestCase(t *testing.T, args addrsManagerArgs) addrsManagerTestCase {
-	eb := eventbus.NewBus()
+	eb := args.Bus
+	if eb == nil {
+		eb = eventbus.NewBus()
+	}
 	if args.AddrsFactory == nil {
 		args.AddrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr { return addrs }
 	}
 	addrsUpdatedChan := make(chan struct{}, 1)
 	am, err := newAddrsManager(
-		eb, args.NATManager, args.AddrsFactory, args.ListenAddrs, nil, args.ObservedAddrsManager, addrsUpdatedChan,
+		eb, args.NATManager, args.AddrsFactory, args.ListenAddrs, nil, args.ObservedAddrsManager, addrsUpdatedChan, args.AutoNATClient,
 	)
 	require.NoError(t, err)
@@ -196,6 +205,7 @@ func newAddrsManagerTestCase(t *testing.T, args addrsManagerArgs) addrsManagerTe
 	rchEm, err := eb.Emitter(new(event.EvtLocalReachabilityChanged), eventbus.Stateful)
 	require.NoError(t, err)

+	t.Cleanup(am.Close)
 	return addrsManagerTestCase{
 		addrsManager: am,
 		PushRelay: func(relayAddrs []ma.Multiaddr) {
@@ -425,17 +435,113 @@ func TestAddrsManager(t *testing.T) {
 	})
 }

+func TestAddrsManagerReachabilityEvent(t *testing.T) {
+	publicQUIC, _ := ma.NewMultiaddr("/ip4/1.2.3.4/udp/1234/quic-v1")
+	publicQUIC2, _ := ma.NewMultiaddr("/ip4/1.2.3.4/udp/1235/quic-v1")
+	publicTCP, _ := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/1234")
+
+	bus := eventbus.NewBus()
+
+	sub, err := bus.Subscribe(new(event.EvtHostReachableAddrsChanged))
+	require.NoError(t, err)
+	defer sub.Close()
+
+	am := newAddrsManagerTestCase(t, addrsManagerArgs{
+		Bus: bus,
+		// currently they aren't being passed to the reachability tracker
+		ListenAddrs: func() []ma.Multiaddr { return []ma.Multiaddr{publicQUIC, publicQUIC2, publicTCP} },
+		AutoNATClient: mockAutoNATClient{
+			F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
+				if reqs[0].Addr.Equal(publicQUIC) {
+					return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil
+				} else if reqs[0].Addr.Equal(publicTCP) || reqs[0].Addr.Equal(publicQUIC2) {
+					return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPrivate}, nil
+				}
+				return autonatv2.Result{}, errors.New("invalid")
+			},
+		},
+	})
+
+	initialUnknownAddrs := []ma.Multiaddr{publicQUIC, publicTCP, publicQUIC2}
+
+	// First event: all addresses are initially unknown
+	select {
+	case e := <-sub.Out():
+		evt := e.(event.EvtHostReachableAddrsChanged)
+		require.Empty(t, evt.Reachable)
+		require.Empty(t, evt.Unreachable)
+		require.ElementsMatch(t, initialUnknownAddrs, evt.Unknown)
+	case <-time.After(5 * time.Second):
+		t.Fatal("expected initial event for reachability change")
+	}
+
+	// Wait for probes to complete and addresses to be classified
+	reachableAddrs := []ma.Multiaddr{publicQUIC}
+	unreachableAddrs := []ma.Multiaddr{publicTCP, publicQUIC2}
+	select {
+	case e := <-sub.Out():
+		evt := e.(event.EvtHostReachableAddrsChanged)
+		require.ElementsMatch(t, reachableAddrs, evt.Reachable)
+		require.ElementsMatch(t, unreachableAddrs, evt.Unreachable)
+		require.Empty(t, evt.Unknown)
+		reachable, unreachable, unknown := am.ConfirmedAddrs()
+		require.ElementsMatch(t, reachable, reachableAddrs)
+		require.ElementsMatch(t, unreachable, unreachableAddrs)
+		require.Empty(t, unknown)
+	case <-time.After(5 * time.Second):
+		t.Fatal("expected final event for reachability change after probing")
+	}
+}
+
+func TestRemoveIfNotInSource(t *testing.T) {
+	var addrs []ma.Multiaddr
+	for i := 0; i < 10; i++ {
+		addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/%d", i)))
+	}
+	slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return a.Compare(b) })
+	cases := []struct {
+		addrs    []ma.Multiaddr
+		source   []ma.Multiaddr
+		expected []ma.Multiaddr
+	}{
+		{},
+		{addrs: slices.Clone(addrs[:5]), source: nil, expected: nil},
+		{addrs: nil, source: addrs, expected: nil},
+		{addrs: []ma.Multiaddr{addrs[0]}, source: []ma.Multiaddr{addrs[0]}, expected: []ma.Multiaddr{addrs[0]}},
+		{addrs: slices.Clone(addrs), source: []ma.Multiaddr{addrs[0]}, expected: []ma.Multiaddr{addrs[0]}},
+		{addrs: slices.Clone(addrs), source: slices.Clone(addrs[5:]), expected: slices.Clone(addrs[5:])},
+		{addrs: slices.Clone(addrs[:5]), source: []ma.Multiaddr{addrs[0], addrs[2], addrs[8]}, expected: []ma.Multiaddr{addrs[0], addrs[2]}},
+	}
+	for i, tc := range cases {
+		t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
+			addrs := removeNotInSource(tc.addrs, tc.source)
+			require.ElementsMatch(t, tc.expected, addrs, "%s\n%s", tc.expected, tc.addrs)
+		})
+	}
+}
+
 func BenchmarkAreAddrsDifferent(b *testing.B) {
 	var addrs [10]ma.Multiaddr
 	for i := 0; i < len(addrs); i++ {
 		addrs[i] = ma.StringCast(fmt.Sprintf("/ip4/1.1.1.%d/tcp/1", i))
 	}
-	am := &addrsManager{}
 	b.Run("areAddrsDifferent", func(b *testing.B) {
 		b.ReportAllocs()
 		b.ResetTimer()
 		for i := 0; i < b.N; i++ {
-			am.areAddrsDifferent(addrs[:], addrs[:])
+			areAddrsDifferent(addrs[:], addrs[:])
 		}
 	})
 }
+
+func BenchmarkRemoveIfNotInSource(b *testing.B) {
+	var addrs [10]ma.Multiaddr
+	for i := 0; i < len(addrs); i++ {
+		addrs[i] = ma.StringCast(fmt.Sprintf("/ip4/1.1.1.%d/tcp/1", i))
+	}
+	b.ReportAllocs()
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		removeNotInSource(slices.Clone(addrs[:5]), addrs[:])
+	}
+}
671 p2p/host/basic/addrs_reachability_tracker.go (new file)
@@ -0,0 +1,671 @@
+package basichost
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"math"
+	"slices"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/benbjohnson/clock"
+	"github.com/libp2p/go-libp2p/core/network"
+	"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2"
+	ma "github.com/multiformats/go-multiaddr"
+	manet "github.com/multiformats/go-multiaddr/net"
+)
+
+type autonatv2Client interface {
+	GetReachability(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error)
+}
+
+const (
+	// maxAddrsPerRequest is the maximum number of addresses to probe in a single request
+	maxAddrsPerRequest = 10
+	// maxTrackedAddrs is the maximum number of addresses to track
+	// 10 addrs per transport for 5 transports
+	maxTrackedAddrs = 50
+	// defaultMaxConcurrency is the default number of concurrent workers for reachability checks
+	defaultMaxConcurrency = 5
+	// newAddrsProbeDelay is the delay before probing a new address's reachability.
+	newAddrsProbeDelay = 1 * time.Second
+)
+
+// addrsReachabilityTracker tracks reachability for addresses.
+// Use UpdateAddrs to provide addresses for tracking reachability.
+// reachabilityUpdateCh is notified when reachability for any of the tracked addresses changes.
+type addrsReachabilityTracker struct {
+	ctx    context.Context
+	cancel context.CancelFunc
+	wg     sync.WaitGroup
+
+	client autonatv2Client
+	// reachabilityUpdateCh is used to notify when reachability may have changed
+	reachabilityUpdateCh chan struct{}
+	maxConcurrency       int
+	newAddrsProbeDelay   time.Duration
+	probeManager         *probeManager
+	newAddrs             chan []ma.Multiaddr
+	clock                clock.Clock
+
+	mx               sync.Mutex
+	reachableAddrs   []ma.Multiaddr
+	unreachableAddrs []ma.Multiaddr
+	unknownAddrs     []ma.Multiaddr
+}
+
+// newAddrsReachabilityTracker returns a new addrsReachabilityTracker.
+// reachabilityUpdateCh is notified when reachability for any of the tracked addresses changes.
+func newAddrsReachabilityTracker(client autonatv2Client, reachabilityUpdateCh chan struct{}, cl clock.Clock) *addrsReachabilityTracker {
+	ctx, cancel := context.WithCancel(context.Background())
+	if cl == nil {
+		cl = clock.New()
+	}
+	return &addrsReachabilityTracker{
+		ctx:                  ctx,
+		cancel:               cancel,
+		client:               client,
+		reachabilityUpdateCh: reachabilityUpdateCh,
+		probeManager:         newProbeManager(cl.Now),
+		newAddrsProbeDelay:   newAddrsProbeDelay,
+		maxConcurrency:       defaultMaxConcurrency,
+		newAddrs:             make(chan []ma.Multiaddr, 1),
+		clock:                cl,
+	}
+}
+
+func (r *addrsReachabilityTracker) UpdateAddrs(addrs []ma.Multiaddr) {
+	select {
+	case r.newAddrs <- slices.Clone(addrs):
+	case <-r.ctx.Done():
+	}
+}
+
+func (r *addrsReachabilityTracker) ConfirmedAddrs() (reachableAddrs, unreachableAddrs, unknownAddrs []ma.Multiaddr) {
+	r.mx.Lock()
+	defer r.mx.Unlock()
+	return slices.Clone(r.reachableAddrs), slices.Clone(r.unreachableAddrs), slices.Clone(r.unknownAddrs)
+}
+
+func (r *addrsReachabilityTracker) Start() error {
+	r.wg.Add(1)
+	go r.background()
+	return nil
+}
+
+func (r *addrsReachabilityTracker) Close() error {
+	r.cancel()
+	r.wg.Wait()
+	return nil
+}
const (
	// defaultReachabilityRefreshInterval is the default interval to refresh reachability.
	// In steady state, we check for any required probes every refresh interval.
	// This doesn't mean we'll probe for any particular address, only that we'll check
	// if any address needs to be probed.
	defaultReachabilityRefreshInterval = 5 * time.Minute
	// maxBackoffInterval is the maximum backoff in case we're unable to probe for reachability.
	// We may be unable to confirm addresses in case there are no valid peers with autonatv2
	// or the autonatv2 subsystem is consistently erroring.
	maxBackoffInterval = 5 * time.Minute
	// backoffStartInterval is the initial backoff in case we're unable to probe for reachability.
	backoffStartInterval = 5 * time.Second
)
func (r *addrsReachabilityTracker) background() {
	defer r.wg.Done()

	// probeTicker is used to trigger probes at regular intervals
	probeTicker := r.clock.Ticker(defaultReachabilityRefreshInterval)
	defer probeTicker.Stop()

	// probeTimer is used to trigger probes at specific times
	probeTimer := r.clock.Timer(time.Duration(math.MaxInt64))
	defer probeTimer.Stop()
	nextProbeTime := time.Time{}

	var task reachabilityTask
	var backoffInterval time.Duration
	var currReachable, currUnreachable, currUnknown, prevReachable, prevUnreachable, prevUnknown []ma.Multiaddr
	for {
		select {
		case <-probeTicker.C:
			// don't start a probe if we have a scheduled probe
			if task.BackoffCh == nil && nextProbeTime.IsZero() {
				task = r.refreshReachability()
			}
		case <-probeTimer.C:
			if task.BackoffCh == nil {
				task = r.refreshReachability()
			}
			nextProbeTime = time.Time{}
		case backoff := <-task.BackoffCh:
			task = reachabilityTask{}
			// On completion, start the next probe immediately, or wait for backoff.
			// In case there are no further probes, the reachability tracker will return an empty task,
			// which hangs forever. Eventually, we'll refresh again when the ticker fires.
			if backoff {
				backoffInterval = newBackoffInterval(backoffInterval)
			} else {
				backoffInterval = -1 * time.Second // negative to trigger next probe immediately
			}
			nextProbeTime = r.clock.Now().Add(backoffInterval)
		case addrs := <-r.newAddrs:
			if task.BackoffCh != nil { // cancel running task.
				task.Cancel()
				<-task.BackoffCh // ignore backoff from cancelled task
				task = reachabilityTask{}
			}
			r.updateTrackedAddrs(addrs)
			newAddrsNextTime := r.clock.Now().Add(r.newAddrsProbeDelay)
			if nextProbeTime.Before(newAddrsNextTime) {
				nextProbeTime = newAddrsNextTime
			}
		case <-r.ctx.Done():
			if task.BackoffCh != nil {
				task.Cancel()
				<-task.BackoffCh
				task = reachabilityTask{}
			}
			return
		}

		currReachable, currUnreachable, currUnknown = r.appendConfirmedAddrs(currReachable[:0], currUnreachable[:0], currUnknown[:0])
		if areAddrsDifferent(prevReachable, currReachable) || areAddrsDifferent(prevUnreachable, currUnreachable) || areAddrsDifferent(prevUnknown, currUnknown) {
			r.notify()
		}
		prevReachable = append(prevReachable[:0], currReachable...)
		prevUnreachable = append(prevUnreachable[:0], currUnreachable...)
		prevUnknown = append(prevUnknown[:0], currUnknown...)
		if !nextProbeTime.IsZero() {
			probeTimer.Reset(nextProbeTime.Sub(r.clock.Now()))
		}
	}
}
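// An aside on the negative interval above (editor's sketch, not part of the
// diff): resetting a timer with a zero or negative duration makes it fire
// immediately, so a nextProbeTime in the past schedules the next probe on the
// very next loop iteration:
//
//	t := time.NewTimer(time.Hour)
//	t.Reset(-time.Second) // fires immediately
//	<-t.C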
func newBackoffInterval(current time.Duration) time.Duration {
	if current <= 0 {
		return backoffStartInterval
	}
	current *= 2
	if current > maxBackoffInterval {
		return maxBackoffInterval
	}
	return current
}
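// A quick illustration of the schedule this produces (editor's sketch, not
// part of the diff): repeated failures double the interval from
// backoffStartInterval up to the maxBackoffInterval cap.
//
//	d := time.Duration(0)
//	for i := 0; i < 8; i++ {
//		d = newBackoffInterval(d)
//		fmt.Print(d, " ") // 5s 10s 20s 40s 1m20s 2m40s 5m0s 5m0s
//	}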
func (r *addrsReachabilityTracker) appendConfirmedAddrs(reachable, unreachable, unknown []ma.Multiaddr) (reachableAddrs, unreachableAddrs, unknownAddrs []ma.Multiaddr) {
	reachable, unreachable, unknown = r.probeManager.AppendConfirmedAddrs(reachable, unreachable, unknown)
	r.mx.Lock()
	r.reachableAddrs = append(r.reachableAddrs[:0], reachable...)
	r.unreachableAddrs = append(r.unreachableAddrs[:0], unreachable...)
	r.unknownAddrs = append(r.unknownAddrs[:0], unknown...)
	r.mx.Unlock()
	return reachable, unreachable, unknown
}
func (r *addrsReachabilityTracker) notify() {
	select {
	case r.reachabilityUpdateCh <- struct{}{}:
	default:
	}
}

func (r *addrsReachabilityTracker) updateTrackedAddrs(addrs []ma.Multiaddr) {
	addrs = slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool {
		return !manet.IsPublicAddr(a)
	})
	if len(addrs) > maxTrackedAddrs {
		log.Errorf("too many addresses (%d) for addrs reachability tracker; dropping %d", len(addrs), len(addrs)-maxTrackedAddrs)
		addrs = addrs[:maxTrackedAddrs]
	}
	r.probeManager.UpdateAddrs(addrs)
}
type probe = []autonatv2.Request

const probeTimeout = 30 * time.Second

// reachabilityTask is a task to refresh reachability.
// Waiting on the zero value blocks forever.
type reachabilityTask struct {
	Cancel context.CancelFunc
	// BackoffCh returns whether the caller should backoff before
	// refreshing reachability
	BackoffCh chan bool
}
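// One property worth spelling out (editor's sketch, not part of the diff):
// the zero value's BackoffCh is a nil channel, and a receive from a nil
// channel blocks forever, which is what lets the background loop select on
// task.BackoffCh even when no probe is running.
//
//	var task reachabilityTask // zero value: BackoffCh == nil
//	select {
//	case <-task.BackoffCh: // never taken: receive from a nil channel blocks
//	case <-time.After(10 * time.Millisecond):
//		// some other case always wins
//	}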
func (r *addrsReachabilityTracker) refreshReachability() reachabilityTask {
	if len(r.probeManager.GetProbe()) == 0 {
		return reachabilityTask{}
	}
	resCh := make(chan bool, 1)
	ctx, cancel := context.WithTimeout(r.ctx, 5*time.Minute)
	r.wg.Add(1)
	// We run probes provided by the probe manager. Probing stops when any
	// of the following happens:
	// - there are no more probes to run
	// - context is completed
	// - there are too many consecutive failures from the client
	// - the client has no valid peers to probe
	go func() {
		defer r.wg.Done()
		defer cancel()
		client := &errCountingClient{autonatv2Client: r.client, MaxConsecutiveErrors: maxConsecutiveErrors}
		var backoff atomic.Bool
		var wg sync.WaitGroup
		wg.Add(r.maxConcurrency)
		for range r.maxConcurrency {
			go func() {
				defer wg.Done()
				for {
					if ctx.Err() != nil {
						return
					}
					reqs := r.probeManager.GetProbe()
					if len(reqs) == 0 {
						return
					}
					r.probeManager.MarkProbeInProgress(reqs)
					rctx, cancel := context.WithTimeout(ctx, probeTimeout)
					res, err := client.GetReachability(rctx, reqs)
					cancel()
					r.probeManager.CompleteProbe(reqs, res, err)
					if isErrorPersistent(err) {
						backoff.Store(true)
						return
					}
				}
			}()
		}
		wg.Wait()
		resCh <- backoff.Load()
	}()
	return reachabilityTask{Cancel: cancel, BackoffCh: resCh}
}
var errTooManyConsecutiveFailures = errors.New("too many consecutive failures")

// errCountingClient counts errors from the wrapped autonatv2Client and wraps the returned error
// with errTooManyConsecutiveFailures in case of persistent failures from the autonatv2 module.
type errCountingClient struct {
	autonatv2Client
	MaxConsecutiveErrors int
	mx                   sync.Mutex
	consecutiveErrors    int
}

func (c *errCountingClient) GetReachability(ctx context.Context, reqs probe) (autonatv2.Result, error) {
	res, err := c.autonatv2Client.GetReachability(ctx, reqs)
	c.mx.Lock()
	defer c.mx.Unlock()
	if err != nil && !errors.Is(err, context.Canceled) { // ignore canceled errors, they're not errors from autonatv2
		c.consecutiveErrors++
		if c.consecutiveErrors > c.MaxConsecutiveErrors {
			err = fmt.Errorf("%w:%w", errTooManyConsecutiveFailures, err)
		}
		if errors.Is(err, autonatv2.ErrPrivateAddrs) {
			log.Errorf("private IP addr in autonatv2 request: %s", err)
		}
	} else {
		c.consecutiveErrors = 0
	}
	return res, err
}

const maxConsecutiveErrors = 20
// isErrorPersistent returns whether the error will repeat on future probes for a while
func isErrorPersistent(err error) bool {
	if err == nil {
		return false
	}
	return errors.Is(err, autonatv2.ErrPrivateAddrs) || errors.Is(err, autonatv2.ErrNoPeers) ||
		errors.Is(err, errTooManyConsecutiveFailures)
}
const (
	// recentProbeInterval is the interval to probe addresses that have been refused.
	// These are generally addresses with newer transports for which we don't have many peers
	// capable of dialing the transport.
	recentProbeInterval = 10 * time.Minute
	// maxConsecutiveRefusals is the maximum number of consecutive refusals for an address after which
	// we wait for `recentProbeInterval` before probing again
	maxConsecutiveRefusals = 5
	// maxRecentDialsPerAddr is the maximum number of dials on an address before we stop probing for the address.
	// This is used to prevent infinite probing of an address whose status is indeterminate for any reason.
	maxRecentDialsPerAddr = 10
	// targetConfidence is the confidence threshold for an address after which we wait for
	// `highConfidenceAddrProbeInterval` before probing again. Confidence is the absolute
	// difference between the number of successes and failures for an address.
	targetConfidence = 3
	// minConfidence is the confidence threshold for an address to be considered reachable or unreachable.
	minConfidence = 2
	// maxRecentDialsWindow is the maximum number of recent probe results to consider for a single address.
	//
	// +2 allows for 1 invalid probe result. Consider a string of successes, after which we have a single failure
	// and then a success (...S S S S F S). The confidence in the targetConfidence window will be equal to
	// targetConfidence, the last F and S cancel each other, and we won't probe again for highConfidenceAddrProbeInterval.
	maxRecentDialsWindow = targetConfidence + 2
	// highConfidenceAddrProbeInterval is the maximum interval between probes for an address
	highConfidenceAddrProbeInterval = 1 * time.Hour
	// maxProbeResultTTL is the maximum time to keep probe results for an address
	maxProbeResultTTL = maxRecentDialsWindow * highConfidenceAddrProbeInterval
)
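// To make the confidence arithmetic concrete, an illustrative sketch (editor's
// note, not part of the diff) of how a window of outcomes maps to a verdict,
// mirroring reachabilityAndCounts later in this file:
//
//	// Window "S S S S F S": successes=5, failures=1.
//	// successes-failures = 4 >= minConfidence (2), so the address counts as
//	// reachable, and targetConfidence (3) - |4| <= 0 means no further probe
//	// is required until highConfidenceAddrProbeInterval elapses.
//	outcomes := []bool{true, true, true, true, false, true}
//	successes, failures := 0, 0
//	for _, ok := range outcomes {
//		if ok {
//			successes++
//		} else {
//			failures++
//		}
//	}
//	fmt.Println(successes-failures >= minConfidence) // true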
// probeManager tracks reachability for a set of addresses by periodically probing reachability with autonatv2.
// A probe is a list of addresses which can be tested for reachability with autonatv2.
// This struct decides the priority order of addresses for testing reachability, and throttles in case there have
// been too many probes for an address within `recentProbeInterval`.
//
// Use `refreshReachability` to execute the probes with an autonatv2 client.
type probeManager struct {
	now func() time.Time

	mx                    sync.Mutex
	inProgressProbes      map[string]int // addr -> count
	inProgressProbesTotal int
	statuses              map[string]*addrStatus
	addrs                 []ma.Multiaddr
}

// newProbeManager creates a new probe manager.
func newProbeManager(now func() time.Time) *probeManager {
	return &probeManager{
		statuses:         make(map[string]*addrStatus),
		inProgressProbes: make(map[string]int),
		now:              now,
	}
}
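// The intended call pattern, sketched for illustration (editor's note, not
// part of the diff; it mirrors the nextProbe helper in the tests later in
// this diff, and assumes addrs, client, and ctx are in scope):
//
//	pm := newProbeManager(time.Now)
//	pm.UpdateAddrs(addrs)
//	for {
//		reqs := pm.GetProbe()
//		if len(reqs) == 0 {
//			break // nothing needs probing right now
//		}
//		pm.MarkProbeInProgress(reqs) // must be paired with CompleteProbe
//		res, err := client.GetReachability(ctx, reqs)
//		pm.CompleteProbe(reqs, res, err)
//	}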
// AppendConfirmedAddrs appends the current confirmed reachable and unreachable addresses.
func (m *probeManager) AppendConfirmedAddrs(reachable, unreachable, unknown []ma.Multiaddr) (reachableAddrs, unreachableAddrs, unknownAddrs []ma.Multiaddr) {
	m.mx.Lock()
	defer m.mx.Unlock()

	for _, a := range m.addrs {
		s := m.statuses[string(a.Bytes())]
		s.RemoveBefore(m.now().Add(-maxProbeResultTTL)) // cleanup stale results
		switch s.Reachability() {
		case network.ReachabilityPublic:
			reachable = append(reachable, a)
		case network.ReachabilityPrivate:
			unreachable = append(unreachable, a)
		case network.ReachabilityUnknown:
			unknown = append(unknown, a)
		}
	}
	return reachable, unreachable, unknown
}
// UpdateAddrs updates the tracked addrs
func (m *probeManager) UpdateAddrs(addrs []ma.Multiaddr) {
	m.mx.Lock()
	defer m.mx.Unlock()

	slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return a.Compare(b) })
	statuses := make(map[string]*addrStatus, len(addrs))
	for _, addr := range addrs {
		k := string(addr.Bytes())
		if _, ok := m.statuses[k]; !ok {
			statuses[k] = &addrStatus{Addr: addr}
		} else {
			statuses[k] = m.statuses[k]
		}
	}
	m.addrs = addrs
	m.statuses = statuses
}
// GetProbe returns the next probe. Returns the zero value in case there are no more probes.
// Probes that are run against an autonatv2 client should be marked in progress with
// `MarkProbeInProgress` before running.
func (m *probeManager) GetProbe() probe {
	m.mx.Lock()
	defer m.mx.Unlock()

	now := m.now()
	for i, a := range m.addrs {
		ab := a.Bytes()
		pc := m.statuses[string(ab)].RequiredProbeCount(now)
		if m.inProgressProbes[string(ab)] >= pc {
			continue
		}
		reqs := make(probe, 0, maxAddrsPerRequest)
		reqs = append(reqs, autonatv2.Request{Addr: a, SendDialData: true})
		// We have the first (primary) address. Append other addresses, ignoring in-progress probes
		// on secondary addresses. The expectation is that the primary address will
		// be dialed.
		for j := 1; j < len(m.addrs); j++ {
			k := (i + j) % len(m.addrs)
			ab := m.addrs[k].Bytes()
			pc := m.statuses[string(ab)].RequiredProbeCount(now)
			if pc == 0 {
				continue
			}
			reqs = append(reqs, autonatv2.Request{Addr: m.addrs[k], SendDialData: true})
			if len(reqs) >= maxAddrsPerRequest {
				break
			}
		}
		return reqs
	}
	return nil
}
// MarkProbeInProgress should be called when a probe is started.
// All in-progress probes *MUST* be completed with `CompleteProbe`
func (m *probeManager) MarkProbeInProgress(reqs probe) {
	if len(reqs) == 0 {
		return
	}
	m.mx.Lock()
	defer m.mx.Unlock()
	m.inProgressProbes[string(reqs[0].Addr.Bytes())]++
	m.inProgressProbesTotal++
}

// InProgressProbes returns the number of probes that are currently in progress.
func (m *probeManager) InProgressProbes() int {
	m.mx.Lock()
	defer m.mx.Unlock()
	return m.inProgressProbesTotal
}
// CompleteProbe should be called when a probe completes.
func (m *probeManager) CompleteProbe(reqs probe, res autonatv2.Result, err error) {
	now := m.now()

	if len(reqs) == 0 {
		// should never happen
		return
	}

	m.mx.Lock()
	defer m.mx.Unlock()

	// decrement in-progress count for the first address
	primaryAddrKey := string(reqs[0].Addr.Bytes())
	m.inProgressProbes[primaryAddrKey]--
	if m.inProgressProbes[primaryAddrKey] <= 0 {
		delete(m.inProgressProbes, primaryAddrKey)
	}
	m.inProgressProbesTotal--

	// nothing to do if the request errored.
	if err != nil {
		return
	}

	// Consider only the primary address as refused. This increases the number of
	// refused probes, but refused probes are cheap for a server as no dials are made.
	if res.AllAddrsRefused {
		if s, ok := m.statuses[primaryAddrKey]; ok {
			s.AddRefusal(now)
		}
		return
	}
	dialAddrKey := string(res.Addr.Bytes())
	if dialAddrKey != primaryAddrKey {
		if s, ok := m.statuses[primaryAddrKey]; ok {
			s.AddRefusal(now)
		}
	}

	// record the result for the dialed address
	if s, ok := m.statuses[dialAddrKey]; ok {
		s.AddOutcome(now, res.Reachability, maxRecentDialsWindow)
	}
}
type dialOutcome struct {
	Success bool
	At      time.Time
}

type addrStatus struct {
	Addr                ma.Multiaddr
	lastRefusalTime     time.Time
	consecutiveRefusals int
	dialTimes           []time.Time
	outcomes            []dialOutcome
}
func (s *addrStatus) Reachability() network.Reachability {
	rch, _, _ := s.reachabilityAndCounts()
	return rch
}

func (s *addrStatus) RequiredProbeCount(now time.Time) int {
	if s.consecutiveRefusals >= maxConsecutiveRefusals {
		if now.Sub(s.lastRefusalTime) < recentProbeInterval {
			return 0
		}
		// reset every `recentProbeInterval`
		s.lastRefusalTime = time.Time{}
		s.consecutiveRefusals = 0
	}

	// Don't probe if we have probed too many times recently
	rd := s.recentDialCount(now)
	if rd >= maxRecentDialsPerAddr {
		return 0
	}

	return s.requiredProbeCountForConfirmation(now)
}
func (s *addrStatus) requiredProbeCountForConfirmation(now time.Time) int {
	reachability, successes, failures := s.reachabilityAndCounts()
	confidence := successes - failures
	if confidence < 0 {
		confidence = -confidence
	}
	cnt := targetConfidence - confidence
	if cnt > 0 {
		return cnt
	}
	// we have enough confirmations; check if we should refresh

	// Should never happen. The confidence logic above should require a few probes.
	if len(s.outcomes) == 0 {
		return 0
	}
	lastOutcome := s.outcomes[len(s.outcomes)-1]
	// If the last probe result is old, we need to retest
	if now.Sub(lastOutcome.At) > highConfidenceAddrProbeInterval {
		return 1
	}
	// if the last probe result was different from reachability, probe again.
	switch reachability {
	case network.ReachabilityPublic:
		if !lastOutcome.Success {
			return 1
		}
	case network.ReachabilityPrivate:
		if lastOutcome.Success {
			return 1
		}
	default:
		// this should never happen
		return 1
	}
	return 0
}
func (s *addrStatus) AddRefusal(now time.Time) {
	s.lastRefusalTime = now
	s.consecutiveRefusals++
}
func (s *addrStatus) AddOutcome(at time.Time, rch network.Reachability, windowSize int) {
	s.lastRefusalTime = time.Time{}
	s.consecutiveRefusals = 0

	s.dialTimes = append(s.dialTimes, at)
	for i, t := range s.dialTimes {
		if at.Sub(t) < recentProbeInterval {
			s.dialTimes = slices.Delete(s.dialTimes, 0, i)
			break
		}
	}

	s.RemoveBefore(at.Add(-maxProbeResultTTL)) // remove old outcomes
	success := false
	switch rch {
	case network.ReachabilityPublic:
		success = true
	case network.ReachabilityPrivate:
		success = false
	default:
		return // don't store the outcome if reachability is unknown
	}
	s.outcomes = append(s.outcomes, dialOutcome{At: at, Success: success})
	if len(s.outcomes) > windowSize {
		s.outcomes = slices.Delete(s.outcomes, 0, len(s.outcomes)-windowSize)
	}
}
// RemoveBefore removes outcomes before t
func (s *addrStatus) RemoveBefore(t time.Time) {
	end := 0
	for ; end < len(s.outcomes); end++ {
		if !s.outcomes[end].At.Before(t) {
			break
		}
	}
	s.outcomes = slices.Delete(s.outcomes, 0, end)
}
func (s *addrStatus) recentDialCount(now time.Time) int {
	cnt := 0
	for _, t := range slices.Backward(s.dialTimes) {
		if now.Sub(t) > recentProbeInterval {
			break
		}
		cnt++
	}
	return cnt
}
func (s *addrStatus) reachabilityAndCounts() (rch network.Reachability, successes int, failures int) {
	for _, r := range s.outcomes {
		if r.Success {
			successes++
		} else {
			failures++
		}
	}
	if successes-failures >= minConfidence {
		return network.ReachabilityPublic, successes, failures
	}
	if failures-successes >= minConfidence {
		return network.ReachabilityPrivate, successes, failures
	}
	return network.ReachabilityUnknown, successes, failures
}
p2p/host/basic/addrs_reachability_tracker_test.go (new file, 942 lines)
@@ -0,0 +1,942 @@
package basichost

import (
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"math/rand"
	"net"
	"net/netip"
	"slices"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/benbjohnson/clock"
	"github.com/libp2p/go-libp2p/core/network"
	"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2"
	ma "github.com/multiformats/go-multiaddr"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
func TestProbeManager(t *testing.T) {
	pub1 := ma.StringCast("/ip4/1.1.1.1/tcp/1")
	pub2 := ma.StringCast("/ip4/1.1.1.2/tcp/1")
	pub3 := ma.StringCast("/ip4/1.1.1.3/tcp/1")

	cl := clock.NewMock()

	nextProbe := func(pm *probeManager) []autonatv2.Request {
		reqs := pm.GetProbe()
		if len(reqs) != 0 {
			pm.MarkProbeInProgress(reqs)
		}
		return reqs
	}

	makeNewProbeManager := func(addrs []ma.Multiaddr) *probeManager {
		pm := newProbeManager(cl.Now)
		pm.UpdateAddrs(addrs)
		return pm
	}

	t.Run("addrs updates", func(t *testing.T) {
		pm := newProbeManager(cl.Now)
		pm.UpdateAddrs([]ma.Multiaddr{pub1, pub2})
		for {
			reqs := nextProbe(pm)
			if len(reqs) == 0 {
				break
			}
			pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil)
		}
		reachable, _, _ := pm.AppendConfirmedAddrs(nil, nil, nil)
		require.Equal(t, reachable, []ma.Multiaddr{pub1, pub2})
		pm.UpdateAddrs([]ma.Multiaddr{pub3})

		reachable, _, _ = pm.AppendConfirmedAddrs(nil, nil, nil)
		require.Empty(t, reachable)
		require.Len(t, pm.statuses, 1)
	})

	t.Run("inprogress", func(t *testing.T) {
		pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2})
		reqs1 := pm.GetProbe()
		reqs2 := pm.GetProbe()
		require.Equal(t, reqs1, reqs2)
		for range targetConfidence {
			reqs := nextProbe(pm)
			require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}})
		}
		for range targetConfidence {
			reqs := nextProbe(pm)
			require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}, {Addr: pub1, SendDialData: true}})
		}
		reqs := pm.GetProbe()
		require.Empty(t, reqs)
	})

	t.Run("refusals", func(t *testing.T) {
		pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2})
		var probes [][]autonatv2.Request
		for range targetConfidence {
			reqs := nextProbe(pm)
			require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}})
			probes = append(probes, reqs)
		}
		// first one refused, second one successful
		for _, p := range probes {
			pm.CompleteProbe(p, autonatv2.Result{Addr: pub2, Idx: 1, Reachability: network.ReachabilityPublic}, nil)
		}
		// the second address is validated!
		probes = nil
		for range targetConfidence {
			reqs := nextProbe(pm)
			require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}})
			probes = append(probes, reqs)
		}
		reqs := pm.GetProbe()
		require.Empty(t, reqs)
		for _, p := range probes {
			pm.CompleteProbe(p, autonatv2.Result{AllAddrsRefused: true}, nil)
		}
		// all requests refused; no more probes for too many refusals
		reqs = pm.GetProbe()
		require.Empty(t, reqs)

		cl.Add(recentProbeInterval)
		reqs = pm.GetProbe()
		require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}})
	})

	t.Run("successes", func(t *testing.T) {
		pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2})
		for j := 0; j < 2; j++ {
			for i := 0; i < targetConfidence; i++ {
				reqs := nextProbe(pm)
				pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil)
			}
		}
		// all addrs confirmed
		reqs := pm.GetProbe()
		require.Empty(t, reqs)

		cl.Add(highConfidenceAddrProbeInterval + time.Millisecond)
		reqs = nextProbe(pm)
		require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}})
		reqs = nextProbe(pm)
		require.Equal(t, reqs, []autonatv2.Request{{Addr: pub2, SendDialData: true}, {Addr: pub1, SendDialData: true}})
	})

	t.Run("throttling on indeterminate reachability", func(t *testing.T) {
		pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2})
		reachability := network.ReachabilityPublic
		nextReachability := func() network.Reachability {
			if reachability == network.ReachabilityPublic {
				reachability = network.ReachabilityPrivate
			} else {
				reachability = network.ReachabilityPublic
			}
			return reachability
		}
		// both addresses are indeterminate
		for range 2 * maxRecentDialsPerAddr {
			reqs := nextProbe(pm)
			pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: nextReachability()}, nil)
		}
		reqs := pm.GetProbe()
		require.Empty(t, reqs)

		cl.Add(recentProbeInterval + time.Millisecond)
		reqs = pm.GetProbe()
		require.Equal(t, reqs, []autonatv2.Request{{Addr: pub1, SendDialData: true}, {Addr: pub2, SendDialData: true}})
		for range 2 * maxRecentDialsPerAddr {
			reqs := nextProbe(pm)
			pm.CompleteProbe(reqs, autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: nextReachability()}, nil)
		}
		reqs = pm.GetProbe()
		require.Empty(t, reqs)
	})

	t.Run("reachabilityUpdate", func(t *testing.T) {
		pm := makeNewProbeManager([]ma.Multiaddr{pub1, pub2})
		for range 2 * targetConfidence {
			reqs := nextProbe(pm)
			if reqs[0].Addr.Equal(pub1) {
				pm.CompleteProbe(reqs, autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil)
			} else {
				pm.CompleteProbe(reqs, autonatv2.Result{Addr: pub2, Idx: 0, Reachability: network.ReachabilityPrivate}, nil)
			}
		}

		reachable, unreachable, _ := pm.AppendConfirmedAddrs(nil, nil, nil)
		require.Equal(t, reachable, []ma.Multiaddr{pub1})
		require.Equal(t, unreachable, []ma.Multiaddr{pub2})
	})
	t.Run("expiry", func(t *testing.T) {
		pm := makeNewProbeManager([]ma.Multiaddr{pub1})
		for range 2 * targetConfidence {
			reqs := nextProbe(pm)
			pm.CompleteProbe(reqs, autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil)
		}

		reachable, unreachable, _ := pm.AppendConfirmedAddrs(nil, nil, nil)
		require.Equal(t, reachable, []ma.Multiaddr{pub1})
		require.Empty(t, unreachable)

		cl.Add(maxProbeResultTTL + 1*time.Second)
		reachable, unreachable, _ = pm.AppendConfirmedAddrs(nil, nil, nil)
		require.Empty(t, reachable)
		require.Empty(t, unreachable)
	})
}
type mockAutoNATClient struct {
	F func(context.Context, []autonatv2.Request) (autonatv2.Result, error)
}

func (m mockAutoNATClient) GetReachability(ctx context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
	return m.F(ctx, reqs)
}

var _ autonatv2Client = mockAutoNATClient{}
func TestAddrsReachabilityTracker(t *testing.T) {
	pub1 := ma.StringCast("/ip4/1.1.1.1/tcp/1")
	pub2 := ma.StringCast("/ip4/1.1.1.2/tcp/1")
	pub3 := ma.StringCast("/ip4/1.1.1.3/tcp/1")
	pri := ma.StringCast("/ip4/192.168.1.1/tcp/1")

	assertFirstEvent := func(t *testing.T, tr *addrsReachabilityTracker, addrs []ma.Multiaddr) {
		select {
		case <-tr.reachabilityUpdateCh:
		case <-time.After(200 * time.Millisecond):
			t.Fatal("expected first event quickly")
		}
		reachable, unreachable, unknown := tr.ConfirmedAddrs()
		require.Empty(t, reachable)
		require.Empty(t, unreachable)
		require.ElementsMatch(t, unknown, addrs, "%s %s", unknown, addrs)
	}

	newTracker := func(cli mockAutoNATClient, cl clock.Clock) *addrsReachabilityTracker {
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		if cl == nil {
			cl = clock.New()
		}
		tr := &addrsReachabilityTracker{
			ctx:                  ctx,
			cancel:               cancel,
			client:               cli,
			newAddrs:             make(chan []ma.Multiaddr, 1),
			reachabilityUpdateCh: make(chan struct{}, 1),
			maxConcurrency:       3,
			newAddrsProbeDelay:   0 * time.Second,
			probeManager:         newProbeManager(cl.Now),
			clock:                cl,
		}
		err := tr.Start()
		require.NoError(t, err)
		t.Cleanup(func() {
			err := tr.Close()
			assert.NoError(t, err)
		})
		return tr
	}

	t.Run("simple", func(t *testing.T) {
		// pub1 reachable, pub2 unreachable, pub3 ignored
		mockClient := mockAutoNATClient{
			F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
				for i, req := range reqs {
					if req.Addr.Equal(pub1) {
						return autonatv2.Result{Addr: pub1, Idx: i, Reachability: network.ReachabilityPublic}, nil
					} else if req.Addr.Equal(pub2) {
						return autonatv2.Result{Addr: pub2, Idx: i, Reachability: network.ReachabilityPrivate}, nil
					}
				}
				return autonatv2.Result{AllAddrsRefused: true}, nil
			},
		}
		tr := newTracker(mockClient, nil)
		tr.UpdateAddrs([]ma.Multiaddr{pub2, pub1, pri})
		assertFirstEvent(t, tr, []ma.Multiaddr{pub1, pub2})

		select {
		case <-tr.reachabilityUpdateCh:
		case <-time.After(2 * time.Second):
			t.Fatal("expected reachability update")
		}
		reachable, unreachable, unknown := tr.ConfirmedAddrs()
		require.Equal(t, reachable, []ma.Multiaddr{pub1}, "%s %s", reachable, pub1)
		require.Equal(t, unreachable, []ma.Multiaddr{pub2}, "%s %s", unreachable, pub2)
		require.Empty(t, unknown)

		tr.UpdateAddrs([]ma.Multiaddr{pub3, pub1, pub2, pri})
		select {
		case <-tr.reachabilityUpdateCh:
		case <-time.After(2 * time.Second):
			t.Fatal("expected reachability update")
		}
		reachable, unreachable, unknown = tr.ConfirmedAddrs()
		t.Logf("Second probe - Reachable: %v, Unreachable: %v, Unknown: %v", reachable, unreachable, unknown)
		require.Equal(t, reachable, []ma.Multiaddr{pub1}, "%s %s", reachable, pub1)
		require.Equal(t, unreachable, []ma.Multiaddr{pub2}, "%s %s", unreachable, pub2)
		require.Equal(t, unknown, []ma.Multiaddr{pub3}, "%s %s", unknown, pub3)
	})

	t.Run("confirmed addrs ordering", func(t *testing.T) {
		mockClient := mockAutoNATClient{
			F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
				return autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil
			},
		}
		tr := newTracker(mockClient, nil)
		var addrs []ma.Multiaddr
		for i := 0; i < 10; i++ {
			addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/1.1.1.1/tcp/%d", i)))
		}
		slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return -a.Compare(b) }) // sort in reverse order
		tr.UpdateAddrs(addrs)
		assertFirstEvent(t, tr, addrs)

		select {
		case <-tr.reachabilityUpdateCh:
		case <-time.After(2 * time.Second):
			t.Fatal("expected reachability update")
		}
		reachable, unreachable, _ := tr.ConfirmedAddrs()
		require.Empty(t, unreachable)

		orderedAddrs := slices.Clone(addrs)
		slices.Reverse(orderedAddrs)
		require.Equal(t, reachable, orderedAddrs, "%s %s", reachable, addrs)
	})

	t.Run("backoff", func(t *testing.T) {
		notify := make(chan struct{}, 1)
		drainNotify := func() bool {
			found := false
			for {
				select {
				case <-notify:
					found = true
				default:
					return found
				}
			}
		}

		var allow atomic.Bool
		mockClient := mockAutoNATClient{
			F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
				select {
				case notify <- struct{}{}:
				default:
				}
				if !allow.Load() {
					return autonatv2.Result{}, autonatv2.ErrNoPeers
				}
				if reqs[0].Addr.Equal(pub1) {
					return autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil
				}
				return autonatv2.Result{AllAddrsRefused: true}, nil
			},
		}

		cl := clock.NewMock()
		tr := newTracker(mockClient, cl)

		// update addrs and wait for initial checks
		tr.UpdateAddrs([]ma.Multiaddr{pub1})
		assertFirstEvent(t, tr, []ma.Multiaddr{pub1})
		// need to update clock after the background goroutine processes the new addrs
		time.Sleep(100 * time.Millisecond)
		cl.Add(1)
		time.Sleep(100 * time.Millisecond)
		require.True(t, drainNotify()) // check that we did receive probes

		backoffInterval := backoffStartInterval
		for i := 0; i < 4; i++ {
			drainNotify()
			cl.Add(backoffInterval / 2)
			select {
			case <-notify:
				t.Fatal("unexpected call")
			case <-time.After(50 * time.Millisecond):
			}
			cl.Add(backoffInterval/2 + 1) // +1 to push it slightly over the backoff interval
			backoffInterval *= 2
			select {
			case <-notify:
			case <-time.After(1 * time.Second):
				t.Fatal("expected probe")
			}
			reachable, unreachable, _ := tr.ConfirmedAddrs()
			require.Empty(t, reachable)
			require.Empty(t, unreachable)
		}
		allow.Store(true)
		drainNotify()
		cl.Add(backoffInterval + 1)
		select {
		case <-tr.reachabilityUpdateCh:
		case <-time.After(1 * time.Second):
			t.Fatal("expected reachability update")
		}
		reachable, unreachable, _ := tr.ConfirmedAddrs()
		require.Equal(t, reachable, []ma.Multiaddr{pub1})
		require.Empty(t, unreachable)
	})

	t.Run("event update", func(t *testing.T) {
		// allow minConfidence probes to pass
		called := make(chan struct{}, minConfidence)
		notify := make(chan struct{})
		mockClient := mockAutoNATClient{
			F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) {
				select {
				case called <- struct{}{}:
					notify <- struct{}{}
					return autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil
				default:
					return autonatv2.Result{AllAddrsRefused: true}, nil
				}
			},
		}

		tr := newTracker(mockClient, nil)
		tr.UpdateAddrs([]ma.Multiaddr{pub1})
		assertFirstEvent(t, tr, []ma.Multiaddr{pub1})

		for i := 0; i < minConfidence; i++ {
			select {
			case <-notify:
			case <-time.After(1 * time.Second):
				t.Fatal("expected call to autonat client")
			}
		}
		select {
		case <-tr.reachabilityUpdateCh:
			reachable, unreachable, _ := tr.ConfirmedAddrs()
			require.Equal(t, reachable, []ma.Multiaddr{pub1})
			require.Empty(t, unreachable)
		case <-time.After(1 * time.Second):
			t.Fatal("expected reachability update")
		}
		tr.UpdateAddrs([]ma.Multiaddr{pub1}) // same addrs shouldn't get update
		select {
		case <-tr.reachabilityUpdateCh:
			t.Fatal("didn't expect reachability update")
		case <-time.After(100 * time.Millisecond):
		}
		tr.UpdateAddrs([]ma.Multiaddr{pub2})
		select {
		case <-tr.reachabilityUpdateCh:
			reachable, unreachable, _ := tr.ConfirmedAddrs()
			require.Empty(t, reachable)
			require.Empty(t, unreachable)
		case <-time.After(1 * time.Second):
			t.Fatal("expected reachability update")
		}
	})

	t.Run("refresh after reset interval", func(t *testing.T) {
		notify := make(chan struct{}, 1)
		drainNotify := func() bool {
			found := false
			for {
				select {
				case <-notify:
					found = true
				default:
					return found
				}
			}
		}

		mockClient := mockAutoNATClient{
			F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
				select {
				case notify <- struct{}{}:
				default:
				}
				if reqs[0].Addr.Equal(pub1) {
					return autonatv2.Result{Addr: pub1, Idx: 0, Reachability: network.ReachabilityPublic}, nil
				}
				return autonatv2.Result{AllAddrsRefused: true}, nil
			},
		}

		cl := clock.NewMock()
		tr := newTracker(mockClient, cl)

		// update addrs and wait for initial checks
		tr.UpdateAddrs([]ma.Multiaddr{pub1})
		assertFirstEvent(t, tr, []ma.Multiaddr{pub1})
		// need to update clock after the background goroutine processes the new addrs
		time.Sleep(100 * time.Millisecond)
		cl.Add(1)
		time.Sleep(100 * time.Millisecond)
		require.True(t, drainNotify()) // check that we did receive probes
		cl.Add(highConfidenceAddrProbeInterval / 2)
		select {
		case <-notify:
			t.Fatal("unexpected call")
		case <-time.After(50 * time.Millisecond):
		}

		cl.Add(highConfidenceAddrProbeInterval/2 + defaultReachabilityRefreshInterval) // defaultReachabilityRefreshInterval for the next probe time
		select {
		case <-notify:
		case <-time.After(1 * time.Second):
			t.Fatal("expected probe")
		}
	})
}
func TestRefreshReachability(t *testing.T) {
	pub1 := ma.StringCast("/ip4/1.1.1.1/tcp/1")
	pub2 := ma.StringCast("/ip4/1.1.1.1/tcp/2")
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	newTracker := func(client autonatv2Client, pm *probeManager) *addrsReachabilityTracker {
		return &addrsReachabilityTracker{
			probeManager:   pm,
			client:         client,
			clock:          clock.New(),
			maxConcurrency: 3,
			ctx:            ctx,
			cancel:         cancel,
		}
	}
	t.Run("backoff on ErrNoPeers", func(t *testing.T) {
		mockClient := mockAutoNATClient{
			F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) {
				return autonatv2.Result{}, autonatv2.ErrNoPeers
			},
		}

		addrTracker := newProbeManager(time.Now)
		addrTracker.UpdateAddrs([]ma.Multiaddr{pub1})
		r := newTracker(mockClient, addrTracker)
		res := r.refreshReachability()
		require.True(t, <-res.BackoffCh)
		require.Equal(t, addrTracker.InProgressProbes(), 0)
	})

	t.Run("returns backoff on errTooManyConsecutiveFailures", func(t *testing.T) {
		// Create a client that always errors
		mockClient := mockAutoNATClient{
			F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) {
				return autonatv2.Result{}, errors.New("test error")
			},
		}

		pm := newProbeManager(time.Now)
		pm.UpdateAddrs([]ma.Multiaddr{pub1})
		r := newTracker(mockClient, pm)
		result := r.refreshReachability()
		require.True(t, <-result.BackoffCh)
		require.Equal(t, pm.InProgressProbes(), 0)
	})

	t.Run("quits on cancellation", func(t *testing.T) {
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		block := make(chan struct{})
		mockClient := mockAutoNATClient{
			F: func(_ context.Context, _ []autonatv2.Request) (autonatv2.Result, error) {
				block <- struct{}{}
				return autonatv2.Result{}, nil
			},
		}

		pm := newProbeManager(time.Now)
		pm.UpdateAddrs([]ma.Multiaddr{pub1})
		r := &addrsReachabilityTracker{
			ctx:          ctx,
			cancel:       cancel,
			client:       mockClient,
			probeManager: pm,
			clock:        clock.New(),
		}
		var wg sync.WaitGroup
		wg.Add(1)
		go func() {
			defer wg.Done()
			result := r.refreshReachability()
			assert.False(t, <-result.BackoffCh)
			assert.Equal(t, pm.InProgressProbes(), 0)
		}()

		cancel()
		time.Sleep(50 * time.Millisecond) // wait for the cancellation to be processed

	outer:
		for i := 0; i < defaultMaxConcurrency; i++ {
			select {
			case <-block:
			default:
				break outer
			}
		}
		select {
		case <-block:
			t.Fatal("expected no more requests")
		case <-time.After(50 * time.Millisecond):
		}
		wg.Wait()
	})

	t.Run("handles refusals", func(t *testing.T) {
		pub1, _ := ma.NewMultiaddr("/ip4/1.1.1.1/tcp/1")

		mockClient := mockAutoNATClient{
			F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
				for i, req := range reqs {
					if req.Addr.Equal(pub1) {
						return autonatv2.Result{Addr: pub1, Idx: i, Reachability: network.ReachabilityPublic}, nil
					}
				}
				return autonatv2.Result{AllAddrsRefused: true}, nil
			},
		}

		pm := newProbeManager(time.Now)
		pm.UpdateAddrs([]ma.Multiaddr{pub2, pub1})
		r := newTracker(mockClient, pm)

		result := r.refreshReachability()
		require.False(t, <-result.BackoffCh)

		reachable, unreachable, _ := pm.AppendConfirmedAddrs(nil, nil, nil)
		require.Equal(t, reachable, []ma.Multiaddr{pub1})
		require.Empty(t, unreachable)
		require.Equal(t, pm.InProgressProbes(), 0)
	})

	t.Run("handles completions", func(t *testing.T) {
		mockClient := mockAutoNATClient{
			F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
				for i, req := range reqs {
					if req.Addr.Equal(pub1) {
						return autonatv2.Result{Addr: pub1, Idx: i, Reachability: network.ReachabilityPublic}, nil
					}
					if req.Addr.Equal(pub2) {
						return autonatv2.Result{Addr: pub2, Idx: i, Reachability: network.ReachabilityPrivate}, nil
					}
				}
				return autonatv2.Result{AllAddrsRefused: true}, nil
			},
		}
		pm := newProbeManager(time.Now)
		pm.UpdateAddrs([]ma.Multiaddr{pub2, pub1})
		r := newTracker(mockClient, pm)
		result := r.refreshReachability()
		require.False(t, <-result.BackoffCh)

		reachable, unreachable, _ := pm.AppendConfirmedAddrs(nil, nil, nil)
		require.Equal(t, reachable, []ma.Multiaddr{pub1})
		require.Equal(t, unreachable, []ma.Multiaddr{pub2})
		require.Equal(t, pm.InProgressProbes(), 0)
	})
}
func TestAddrStatusProbeCount(t *testing.T) {
	cases := []struct {
		inputs             string
		wantRequiredProbes int
		wantReachability   network.Reachability
	}{
		{
			inputs:             "",
			wantRequiredProbes: 3,
			wantReachability:   network.ReachabilityUnknown,
		},
		{
			inputs:             "S",
			wantRequiredProbes: 2,
			wantReachability:   network.ReachabilityUnknown,
		},
		{
			inputs:             "SS",
			wantRequiredProbes: 1,
			wantReachability:   network.ReachabilityPublic,
		},
		{
			inputs:             "SSS",
			wantRequiredProbes: 0,
			wantReachability:   network.ReachabilityPublic,
		},
		{
			inputs:             "SSSSSSSF",
			wantRequiredProbes: 1,
			wantReachability:   network.ReachabilityPublic,
		},
		{
			inputs:             "SFSFSSSS",
			wantRequiredProbes: 0,
			wantReachability:   network.ReachabilityPublic,
		},
		{
			inputs:             "SSSSSFSF",
			wantRequiredProbes: 2,
			wantReachability:   network.ReachabilityUnknown,
		},
		{
			inputs:             "FF",
			wantRequiredProbes: 1,
			wantReachability:   network.ReachabilityPrivate,
		},
	}
	for _, c := range cases {
		t.Run(c.inputs, func(t *testing.T) {
			now := time.Time{}.Add(1 * time.Second)
			ao := addrStatus{}
			for _, r := range c.inputs {
				if r == 'S' {
					ao.AddOutcome(now, network.ReachabilityPublic, 5)
				} else {
					ao.AddOutcome(now, network.ReachabilityPrivate, 5)
				}
				now = now.Add(1 * time.Second)
			}
			require.Equal(t, ao.RequiredProbeCount(now), c.wantRequiredProbes)
			require.Equal(t, ao.Reachability(), c.wantReachability)
			if c.wantRequiredProbes == 0 {
				now = now.Add(highConfidenceAddrProbeInterval + 10*time.Microsecond)
				require.Equal(t, ao.RequiredProbeCount(now), 1)
			}

			now = now.Add(1 * time.Second)
			ao.RemoveBefore(now)
			require.Len(t, ao.outcomes, 0)
		})
	}
}
func BenchmarkAddrTracker(b *testing.B) {
	cl := clock.NewMock()
	t := newProbeManager(cl.Now)

	addrs := make([]ma.Multiaddr, 20)
	for i := range addrs {
		addrs[i] = ma.StringCast(fmt.Sprintf("/ip4/1.1.1.1/tcp/%d", rand.Intn(1000)))
	}
	t.UpdateAddrs(addrs)
	b.ReportAllocs()
	b.ResetTimer()
	p := t.GetProbe()
	for i := 0; i < b.N; i++ {
		pp := t.GetProbe()
		if len(pp) == 0 {
			pp = p
		}
		t.MarkProbeInProgress(pp)
		t.CompleteProbe(pp, autonatv2.Result{Addr: pp[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}, nil)
	}
}
func FuzzAddrsReachabilityTracker(f *testing.F) {
	type autonatv2Response struct {
		Result autonatv2.Result
		Err    error
	}

	newMockClient := func(b []byte) mockAutoNATClient {
		count := 0
		return mockAutoNATClient{
			F: func(_ context.Context, reqs []autonatv2.Request) (autonatv2.Result, error) {
				if len(b) == 0 {
					return autonatv2.Result{}, nil
				}
				count = (count + 1) % len(b)
				if b[count]%3 == 0 {
					// some address confirmed
					c1 := (count + 1) % len(b)
					c2 := (count + 2) % len(b)
					rch := network.Reachability(b[c1] % 3)
					n := int(b[c2]) % len(reqs)
					return autonatv2.Result{
						Addr:         reqs[n].Addr,
						Idx:          n,
						Reachability: rch,
					}, nil
				}
				outcomes := []autonatv2Response{
					{Result: autonatv2.Result{AllAddrsRefused: true}},
					{Err: errors.New("test error")},
					{Err: autonatv2.ErrPrivateAddrs},
					{Err: autonatv2.ErrNoPeers},
					{Result: autonatv2.Result{}, Err: nil},
					{Result: autonatv2.Result{Addr: reqs[0].Addr, Idx: 0, Reachability: network.ReachabilityPublic}},
					{Result: autonatv2.Result{
						Addr:            reqs[0].Addr,
						Idx:             0,
						Reachability:    network.ReachabilityPublic,
						AllAddrsRefused: true,
					}},
					{Result: autonatv2.Result{
						Addr:            reqs[0].Addr,
						Idx:             len(reqs) - 1, // invalid idx
						Reachability:    network.ReachabilityPublic,
						AllAddrsRefused: false,
					}},
				}
				outcome := outcomes[int(b[count])%len(outcomes)]
				return outcome.Result, outcome.Err
			},
		}
	}

	// TODO: Move this to go-multiaddrs
	getProto := func(protos []byte) ma.Multiaddr {
		protoType := 0
		if len(protos) > 0 {
			protoType = int(protos[0])
		}

		port1, port2 := 0, 0
		if len(protos) > 1 {
			port1 = int(protos[1])
		}
		if len(protos) > 2 {
			port2 = int(protos[2])
		}
		protoTemplates := []string{
			"/tcp/%d/",
			"/udp/%d/",
			"/udp/%d/quic-v1/",
			"/udp/%d/quic-v1/tcp/%d",
			"/udp/%d/quic-v1/webtransport/",
			"/udp/%d/webrtc/",
			"/udp/%d/webrtc-direct/",
			"/unix/hello/",
		}
		s := protoTemplates[protoType%len(protoTemplates)]
		port1 %= (1 << 16)
		if strings.Count(s, "%d") == 1 {
			return ma.StringCast(fmt.Sprintf(s, port1))
		}
		port2 %= (1 << 16)
		return ma.StringCast(fmt.Sprintf(s, port1, port2))
	}

	getIP := func(ips []byte) ma.Multiaddr {
		ipType := 0
		if len(ips) > 0 {
			ipType = int(ips[0])
		}
		ips = ips[1:]
		var x, y int64
		split := 128 / 8
		if len(ips) < split {
			split = len(ips)
		}
		var b [8]byte
		copy(b[:], ips[:split])
		x = int64(binary.LittleEndian.Uint64(b[:]))
		clear(b[:])
		copy(b[:], ips[split:])
		y = int64(binary.LittleEndian.Uint64(b[:]))

		var ip netip.Addr
		switch ipType % 3 {
		case 0:
			ip = netip.AddrFrom4([4]byte{byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24)})
			return ma.StringCast(fmt.Sprintf("/ip4/%s/", ip))
		case 1:
			pubIP := net.ParseIP("2005::") // Public IP address
			x := int64(binary.LittleEndian.Uint64(pubIP[0:8]))
			ip = netip.AddrFrom16([16]byte{
				byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24),
				byte(x >> 32), byte(x >> 40), byte(x >> 48), byte(x >> 56),
				byte(y), byte(y >> 8), byte(y >> 16), byte(y >> 24),
				byte(y >> 32), byte(y >> 40), byte(y >> 48), byte(y >> 56),
			})
			return ma.StringCast(fmt.Sprintf("/ip6/%s/", ip))
		default:
			ip := netip.AddrFrom16([16]byte{
				byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24),
				byte(x >> 32), byte(x >> 40), byte(x >> 48), byte(x >> 56),
				byte(y), byte(y >> 8), byte(y >> 16), byte(y >> 24),
				byte(y >> 32), byte(y >> 40), byte(y >> 48), byte(y >> 56),
			})
			return ma.StringCast(fmt.Sprintf("/ip6/%s/", ip))
		}
	}

	getAddr := func(addrType int, ips, protos []byte) ma.Multiaddr {
		switch addrType % 4 {
		case 0:
			return getIP(ips).Encapsulate(getProto(protos))
		case 1:
			return getProto(protos)
		case 2:
			return nil
		default:
			return getIP(ips).Encapsulate(getProto(protos))
		}
	}

	getDNSAddr := func(hostNameBytes, protos []byte) ma.Multiaddr {
		hostName := strings.ReplaceAll(string(hostNameBytes), "\\", "")
		hostName = strings.ReplaceAll(hostName, "/", "")
		if hostName == "" {
			hostName = "localhost"
		}
		dnsType := 0
		if len(hostNameBytes) > 0 {
			dnsType = int(hostNameBytes[0])
		}
		dnsProtos := []string{"dns", "dns4", "dns6", "dnsaddr"}
		da := ma.StringCast(fmt.Sprintf("/%s/%s/", dnsProtos[dnsType%len(dnsProtos)], hostName))
		return da.Encapsulate(getProto(protos))
	}

	const maxAddrs = 1000
	getAddrs := func(numAddrs int, ips, protos, hostNames []byte) []ma.Multiaddr {
		if len(ips) == 0 || len(protos) == 0 || len(hostNames) == 0 {
			return nil
		}
		numAddrs = ((numAddrs % maxAddrs) + maxAddrs) % maxAddrs
		addrs := make([]ma.Multiaddr, numAddrs)
		ipIdx := 0
		protoIdx := 0
		for i := range numAddrs {
			addrs[i] = getAddr(i, ips[ipIdx:], protos[protoIdx:])
			ipIdx = (ipIdx + 1) % len(ips)
			protoIdx = (protoIdx + 1) % len(protos)
		}
		maxDNSAddrs := 10
		protoIdx = 0
		for i := 0; i < len(hostNames) && i < maxDNSAddrs; i += 2 {
			ed := min(i+2, len(hostNames))
			addrs = append(addrs, getDNSAddr(hostNames[i:ed], protos[protoIdx:]))
			protoIdx = (protoIdx + 1) % len(protos)
		}
		return addrs
	}

	cl := clock.NewMock()
	f.Fuzz(func(t *testing.T, numAddrs int, ips, protos, hostNames, autonatResponses []byte) {
		tr := newAddrsReachabilityTracker(newMockClient(autonatResponses), nil, cl)
		require.NoError(t, tr.Start())
		tr.UpdateAddrs(getAddrs(numAddrs, ips, protos, hostNames))

		// fuzz tests need to finish in 10 seconds for some reason
		// https://github.com/golang/go/issues/48157
		// https://github.com/golang/go/commit/5d24203c394e6b64c42a9f69b990d94cb6c8aad4#diff-4e3b9481b8794eb058998e2bec389d3db7a23c54e67ac0f7259a3a5d2c79fd04R474-R483
		const maxIters = 20
		for range maxIters {
			cl.Add(5 * time.Minute)
			time.Sleep(100 * time.Millisecond)
		}
		require.NoError(t, tr.Close())
	})
}
@@ -156,8 +156,8 @@ type HostOpts struct {

// DisableIdentifyAddressDiscovery disables address discovery using peer provided observed addresses in identify
DisableIdentifyAddressDiscovery bool
EnableAutoNATv2 bool
AutoNATv2Dialer host.Host

AutoNATv2 *autonatv2.AutoNAT
}

// NewHost constructs a new *BasicHost and activates it by attaching its stream and connection handlers to the given inet.Network.
@@ -236,7 +236,16 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
}); ok {
tfl = s.TransportForListening
}
h.addressManager, err = newAddrsManager(h.eventbus, natmgr, addrFactory, h.Network().ListenAddresses, tfl, h.ids, h.addrsUpdatedChan)

if opts.AutoNATv2 != nil {
h.autonatv2 = opts.AutoNATv2
}

var autonatv2Client autonatv2Client // avoid typed nil errors
if h.autonatv2 != nil {
autonatv2Client = h.autonatv2
}
h.addressManager, err = newAddrsManager(h.eventbus, natmgr, addrFactory, h.Network().ListenAddresses, tfl, h.ids, h.addrsUpdatedChan, autonatv2Client)
if err != nil {
return nil, fmt.Errorf("failed to create address service: %w", err)
}
@@ -283,17 +292,6 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
h.pings = ping.NewPingService(h)
}

if opts.EnableAutoNATv2 {
var mt autonatv2.MetricsTracer
if opts.EnableMetrics {
mt = autonatv2.NewMetricsTracer(opts.PrometheusRegisterer)
}
h.autonatv2, err = autonatv2.New(h, opts.AutoNATv2Dialer, autonatv2.WithMetricsTracer(mt))
if err != nil {
return nil, fmt.Errorf("failed to create autonatv2: %w", err)
}
}

if !h.disableSignedPeerRecord {
h.signKey = h.Peerstore().PrivKey(h.ID())
cab, ok := peerstore.GetCertifiedAddrBook(h.Peerstore())
@@ -320,7 +318,7 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
func (h *BasicHost) Start() {
h.psManager.Start()
if h.autonatv2 != nil {
err := h.autonatv2.Start()
err := h.autonatv2.Start(h)
if err != nil {
log.Errorf("autonat v2 failed to start: %s", err)
}
@@ -754,6 +752,16 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr {
return h.addressManager.DirectAddrs()
}

// ConfirmedAddrs returns all addresses of the host grouped by their reachability
// as verified by autonatv2.
//
// Experimental: This API may change in the future without deprecation.
//
// Requires AutoNATv2 to be enabled.
func (h *BasicHost) ConfirmedAddrs() (reachable []ma.Multiaddr, unreachable []ma.Multiaddr, unknown []ma.Multiaddr) {
return h.addressManager.ConfirmedAddrs()
}
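A minimal sketch of consuming the new API on the caller side; the host variable and logging calls are illustrative, only ConfirmedAddrs itself comes from the change above:

// Sketch: inspect autonatv2-verified addresses on a BasicHost `h`
// built with AutoNATv2 enabled (illustrative usage, not from this diff).
reachable, unreachable, unknown := h.ConfirmedAddrs()
for _, a := range reachable {
	log.Infof("publicly reachable: %s", a)
}
log.Infof("%d unreachable, %d still unverified", len(unreachable), len(unknown))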

func trimHostAddrList(addrs []ma.Multiaddr, maxSize int) []ma.Multiaddr {
totalSize := 0
for _, a := range addrs {
@@ -836,7 +844,6 @@ func (h *BasicHost) Close() error {
if h.cmgr != nil {
h.cmgr.Close()
}

h.addressManager.Close()

if h.ids != nil {

@@ -47,6 +47,7 @@ func TestHostSimple(t *testing.T) {
h1.Start()
h2, err := NewHost(swarmt.GenSwarm(t), nil)
require.NoError(t, err)

defer h2.Close()
h2.Start()

@@ -211,6 +212,7 @@ func TestAllAddrs(t *testing.T) {
// no listen addrs
h, err := NewHost(swarmt.GenSwarm(t, swarmt.OptDialOnly), nil)
require.NoError(t, err)
h.Start()
defer h.Close()
require.Nil(t, h.AllAddrs())

@@ -5,6 +5,9 @@ import (
"net/netip"
"slices"
"sync"
"time"

"github.com/libp2p/go-libp2p/x/rate"
)

type ConnLimitPerSubnet struct {
@@ -283,3 +286,58 @@ func (cl *connLimiter) rmConn(ip netip.Addr) {
}
}
}

// handshakeDuration is a higher end estimate of QUIC handshake time
const handshakeDuration = 5 * time.Second

// sourceAddressRPS is the refill rate for the source address verification rate limiter.
// If not verified, a spoofed address will hold a connLimiter token for handshakeDuration.
// The slow refill rate here favours increased latency (because of address verification) in
// exchange for reducing the chances that successful spoofing causes a DoS.
const sourceAddressRPS = float64(1.0*time.Second) / (2 * float64(handshakeDuration))

// newVerifySourceAddressRateLimiter returns a rate limiter for verifying source addresses.
// The returned limiter allows maxAllowedConns / 2 unverified addresses to begin handshake.
// This ensures that in the event someone is spoofing IPs, 1/2 the maximum allowed connections
// will be able to connect, although they will have increased latency because of address
// verification.
func newVerifySourceAddressRateLimiter(cl *connLimiter) *rate.Limiter {
networkPrefixLimits := make([]rate.PrefixLimit, 0, len(cl.networkPrefixLimitV4)+len(cl.networkPrefixLimitV6))
for _, l := range cl.networkPrefixLimitV4 {
networkPrefixLimits = append(networkPrefixLimits, rate.PrefixLimit{
Prefix: l.Network,
Limit: rate.Limit{RPS: sourceAddressRPS, Burst: l.ConnCount / 2},
})
}
for _, l := range cl.networkPrefixLimitV6 {
networkPrefixLimits = append(networkPrefixLimits, rate.PrefixLimit{
Prefix: l.Network,
Limit: rate.Limit{RPS: sourceAddressRPS, Burst: l.ConnCount / 2},
})
}

ipv4SubnetLimits := make([]rate.SubnetLimit, 0, len(cl.connLimitPerSubnetV4))
for _, l := range cl.connLimitPerSubnetV4 {
ipv4SubnetLimits = append(ipv4SubnetLimits, rate.SubnetLimit{
PrefixLength: l.PrefixLength,
Limit: rate.Limit{RPS: sourceAddressRPS, Burst: l.ConnCount / 2},
})
}

ipv6SubnetLimits := make([]rate.SubnetLimit, 0, len(cl.connLimitPerSubnetV6))
for _, l := range cl.connLimitPerSubnetV6 {
ipv6SubnetLimits = append(ipv6SubnetLimits, rate.SubnetLimit{
PrefixLength: l.PrefixLength,
Limit: rate.Limit{RPS: sourceAddressRPS, Burst: l.ConnCount / 2},
})
}

return &rate.Limiter{
NetworkPrefixLimits: networkPrefixLimits,
SubnetRateLimiter: rate.SubnetLimiter{
IPv4SubnetLimits: ipv4SubnetLimits,
IPv6SubnetLimits: ipv6SubnetLimits,
GracePeriod: 1 * time.Minute,
},
}
}

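To make the constant concrete: with handshakeDuration = 5s, sourceAddressRPS = 1s / (2 * 5s) = 0.1, i.e. once the per-prefix burst of ConnCount/2 is spent, one new unverified handshake is admitted every 10 seconds. A self-contained sketch of just the arithmetic:

package main

import (
	"fmt"
	"time"
)

// Same formula as above: tokens refill at one per 2*handshakeDuration.
const handshakeDuration = 5 * time.Second
const sourceAddressRPS = float64(1.0*time.Second) / (2 * float64(handshakeDuration))

func main() {
	// A spoofed-but-unverified address can hold a connLimiter token for at
	// most handshakeDuration, so refilling at half that rate bounds how much
	// of the connection limit spoofed traffic can consume.
	fmt.Println(sourceAddressRPS)            // 0.1 tokens/sec
	fmt.Println(1/sourceAddressRPS, "sec")   // 10 sec per unverified handshake
}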
@@ -5,7 +5,9 @@ import (
"net"
"net/netip"
"testing"
"time"

"github.com/libp2p/go-libp2p/x/rate"
"github.com/stretchr/testify/require"
)

@@ -228,3 +230,163 @@ func TestSortedNetworkPrefixLimits(t *testing.T) {
}
require.EqualValues(t, sorted, npLimits)
}

func TestNewVerifySourceAddressRateLimiter(t *testing.T) {
testCases := []struct {
name string
cl *connLimiter
expected *rate.Limiter
}{
{
name: "basic configuration",
cl: &connLimiter{
networkPrefixLimitV4: []NetworkPrefixLimit{
{
Network: netip.MustParsePrefix("192.168.0.0/16"),
ConnCount: 10,
},
},
networkPrefixLimitV6: []NetworkPrefixLimit{
{
Network: netip.MustParsePrefix("2001:db8::/32"),
ConnCount: 20,
},
},
connLimitPerSubnetV4: []ConnLimitPerSubnet{
{
PrefixLength: 24,
ConnCount: 5,
},
},
connLimitPerSubnetV6: []ConnLimitPerSubnet{
{
PrefixLength: 56,
ConnCount: 8,
},
},
},
expected: &rate.Limiter{
NetworkPrefixLimits: []rate.PrefixLimit{
{
Prefix: netip.MustParsePrefix("192.168.0.0/16"),
Limit: rate.Limit{RPS: sourceAddressRPS, Burst: 5},
},
{
Prefix: netip.MustParsePrefix("2001:db8::/32"),
Limit: rate.Limit{RPS: sourceAddressRPS, Burst: 10},
},
},
SubnetRateLimiter: rate.SubnetLimiter{
IPv4SubnetLimits: []rate.SubnetLimit{
{
PrefixLength: 24,
Limit: rate.Limit{RPS: sourceAddressRPS, Burst: 2},
},
},
IPv6SubnetLimits: []rate.SubnetLimit{
{
PrefixLength: 56,
Limit: rate.Limit{RPS: sourceAddressRPS, Burst: 4},
},
},
GracePeriod: 1 * time.Minute,
},
},
},
{
name: "empty configuration",
cl: &connLimiter{},
expected: &rate.Limiter{
NetworkPrefixLimits: []rate.PrefixLimit{},
SubnetRateLimiter: rate.SubnetLimiter{
IPv4SubnetLimits: []rate.SubnetLimit{},
IPv6SubnetLimits: []rate.SubnetLimit{},
GracePeriod: 1 * time.Minute,
},
},
},
{
name: "multiple network prefixes",
cl: &connLimiter{
networkPrefixLimitV4: []NetworkPrefixLimit{
{
Network: netip.MustParsePrefix("192.168.0.0/16"),
ConnCount: 10,
},
{
Network: netip.MustParsePrefix("10.0.0.0/8"),
ConnCount: 20,
},
},
connLimitPerSubnetV4: []ConnLimitPerSubnet{
{
PrefixLength: 24,
ConnCount: 5,
},
{
PrefixLength: 16,
ConnCount: 10,
},
},
},
expected: &rate.Limiter{
NetworkPrefixLimits: []rate.PrefixLimit{
{
Prefix: netip.MustParsePrefix("192.168.0.0/16"),
Limit: rate.Limit{RPS: sourceAddressRPS, Burst: 5},
},
{
Prefix: netip.MustParsePrefix("10.0.0.0/8"),
Limit: rate.Limit{RPS: sourceAddressRPS, Burst: 10},
},
},
SubnetRateLimiter: rate.SubnetLimiter{
IPv4SubnetLimits: []rate.SubnetLimit{
{
PrefixLength: 24,
Limit: rate.Limit{RPS: sourceAddressRPS, Burst: 2},
},
{
PrefixLength: 16,
Limit: rate.Limit{RPS: sourceAddressRPS, Burst: 5},
},
},
IPv6SubnetLimits: []rate.SubnetLimit{},
GracePeriod: 1 * time.Minute,
},
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual := newVerifySourceAddressRateLimiter(tc.cl)

require.Equal(t, len(tc.expected.NetworkPrefixLimits), len(actual.NetworkPrefixLimits))
for i, expected := range tc.expected.NetworkPrefixLimits {
actual := actual.NetworkPrefixLimits[i]
require.Equal(t, expected.Prefix, actual.Prefix)
require.Equal(t, expected.RPS, actual.RPS)
require.Equal(t, expected.Burst, actual.Burst)
}

require.Equal(t, len(tc.expected.SubnetRateLimiter.IPv4SubnetLimits), len(actual.SubnetRateLimiter.IPv4SubnetLimits))
for i, expected := range tc.expected.SubnetRateLimiter.IPv4SubnetLimits {
actual := actual.SubnetRateLimiter.IPv4SubnetLimits[i]
require.Equal(t, expected.PrefixLength, actual.PrefixLength)
require.Equal(t, expected.RPS, actual.RPS)
require.Equal(t, expected.Burst, actual.Burst)
}

require.Equal(t, len(tc.expected.SubnetRateLimiter.IPv6SubnetLimits), len(actual.SubnetRateLimiter.IPv6SubnetLimits))
for i, expected := range tc.expected.SubnetRateLimiter.IPv6SubnetLimits {
actual := actual.SubnetRateLimiter.IPv6SubnetLimits[i]
require.Equal(t, expected.PrefixLength, actual.PrefixLength)
require.Equal(t, expected.RPS, actual.RPS)
require.Equal(t, expected.Burst, actual.Burst)
}

require.Equal(t, tc.expected.SubnetRateLimiter.GracePeriod, actual.SubnetRateLimiter.GracePeriod)
})
}
}

59 p2p/host/resource-manager/conn_rate_limiter.go Normal file
@@ -0,0 +1,59 @@
package rcmgr

import (
"net/netip"
"time"

"github.com/libp2p/go-libp2p/x/rate"
)

var defaultIPv4SubnetLimits = []rate.SubnetLimit{
{
PrefixLength: 32,
Limit: rate.Limit{RPS: 0.2, Burst: 2 * defaultMaxConcurrentConns},
},
}

var defaultIPv6SubnetLimits = []rate.SubnetLimit{
{
PrefixLength: 56,
Limit: rate.Limit{RPS: 0.2, Burst: 2 * defaultMaxConcurrentConns},
},
{
PrefixLength: 48,
Limit: rate.Limit{RPS: 0.5, Burst: 10 * defaultMaxConcurrentConns},
},
}

// defaultNetworkPrefixLimits ensures that all connections on localhost always succeed
var defaultNetworkPrefixLimits = []rate.PrefixLimit{
{
Prefix: netip.MustParsePrefix("127.0.0.0/8"),
Limit: rate.Limit{},
},
{
Prefix: netip.MustParsePrefix("::1/128"),
Limit: rate.Limit{},
},
}

// WithConnRateLimiters sets a custom rate limiter for new connections.
// connRateLimiter is used for OpenConnection calls
func WithConnRateLimiters(connRateLimiter *rate.Limiter) Option {
return func(rm *resourceManager) error {
rm.connRateLimiter = connRateLimiter
return nil
}
}

func newConnRateLimiter() *rate.Limiter {
return &rate.Limiter{
NetworkPrefixLimits: defaultNetworkPrefixLimits,
GlobalLimit: rate.Limit{},
SubnetRateLimiter: rate.SubnetLimiter{
IPv4SubnetLimits: defaultIPv4SubnetLimits,
IPv6SubnetLimits: defaultIPv6SubnetLimits,
GracePeriod: 1 * time.Minute,
},
}
}
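A sketch of overriding these defaults through the new option; NewResourceManager, NewFixedLimiter, and DefaultLimits appear in this diff, while the limit values below are purely illustrative:

// Sketch: resource manager with a custom connection rate limiter.
limiter := &rate.Limiter{
	GlobalLimit: rate.Limit{RPS: 50, Burst: 100}, // illustrative values
	SubnetRateLimiter: rate.SubnetLimiter{
		IPv4SubnetLimits: []rate.SubnetLimit{
			{PrefixLength: 24, Limit: rate.Limit{RPS: 5, Burst: 10}},
		},
		GracePeriod: time.Minute,
	},
}
rcm, err := rcmgr.NewResourceManager(
	rcmgr.NewFixedLimiter(rcmgr.DefaultLimits.AutoScale()),
	rcmgr.WithConnRateLimiters(limiter),
)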
@@ -2,7 +2,9 @@ package rcmgr

import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"sync"
@@ -11,6 +13,7 @@ import (
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/x/rate"

logging "github.com/ipfs/go-log/v2"
"github.com/multiformats/go-multiaddr"
@@ -23,6 +26,8 @@ type resourceManager struct {
limits Limiter

connLimiter *connLimiter
connRateLimiter *rate.Limiter
verifySourceAddressRateLimiter *rate.Limiter

trace *trace
metrics *metrics
@@ -140,6 +145,7 @@ func NewResourceManager(limits Limiter, opts ...Option) (network.ResourceManager
svc: make(map[string]*serviceScope),
proto: make(map[protocol.ID]*protocolScope),
peer: make(map[peer.ID]*peerScope),
connRateLimiter: newConnRateLimiter(),
}

for _, opt := range opts {
@@ -169,6 +175,7 @@ func NewResourceManager(limits Limiter, opts ...Option) (network.ResourceManager
})
}
}
r.verifySourceAddressRateLimiter = newVerifySourceAddressRateLimiter(r.connLimiter)

if !r.disableMetrics {
var sr TraceReporter
@@ -338,7 +345,22 @@ func (r *resourceManager) nextStreamId() int64 {
return r.streamId
}

// VerifySourceAddress tells the transport to verify the peer's IP address before
// initiating a handshake.
func (r *resourceManager) VerifySourceAddress(addr net.Addr) bool {
if r.verifySourceAddressRateLimiter == nil {
return false
}
ipPort, err := netip.ParseAddrPort(addr.String())
if err != nil {
return true
}
return !r.verifySourceAddressRateLimiter.Allow(ipPort.Addr())
}

// OpenConnectionNoIP is deprecated and will be removed in the next release
//
// Deprecated: Use OpenConnection instead
func (r *resourceManager) OpenConnectionNoIP(dir network.Direction, usefd bool, endpoint multiaddr.Multiaddr) (network.ConnManagementScope, error) {
return r.openConnection(dir, usefd, endpoint, netip.Addr{})
}
@@ -358,6 +380,10 @@ func (r *resourceManager) OpenConnection(dir network.Direction, usefd bool, endp
}

func (r *resourceManager) openConnection(dir network.Direction, usefd bool, endpoint multiaddr.Multiaddr, ip netip.Addr) (network.ConnManagementScope, error) {
if !r.connRateLimiter.Allow(ip) {
return nil, errors.New("rate limit exceeded")
}

if ip.IsValid() {
if ok := r.connLimiter.addConn(ip); !ok {
return nil, fmt.Errorf("connections per ip limit exceeded for %s", endpoint)

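The polarity is easy to misread: VerifySourceAddress answers "should the transport verify this address before handshaking?", so it returns true exactly when the limiter will no longer admit the address unverified. A short sketch mirroring TestVerifySourceAddressRateLimiter below (the IP is illustrative):

// With a per-/32 limit of 2 conns, the verification burst is 2/2 = 1:
// the first probe is admitted unverified, the second must be verified.
na := &net.UDPAddr{IP: net.ParseIP("203.0.113.7"), Port: 4001}
fmt.Println(rcmgr.VerifySourceAddress(na)) // false: within burst, skip verification
fmt.Println(rcmgr.VerifySourceAddress(na)) // true: verify before handshake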
@@ -1,6 +1,7 @@
package rcmgr

import (
"net"
"net/netip"
"testing"

@@ -8,6 +9,7 @@ import (
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/test"
"github.com/libp2p/go-libp2p/x/rate"
"github.com/stretchr/testify/require"

"github.com/multiformats/go-multiaddr"
@@ -1111,3 +1113,63 @@ func TestAllowlistAndConnLimiterPlayNice(t *testing.T) {
require.Equal(t, 1, rcmgr.(*resourceManager).connLimiter.networkPrefixLimitV4[0].ConnCount)
})
}

func TestResourceManagerRateLimiting(t *testing.T) {
// Create a resource manager with very low rate limits
limits := DefaultLimits.AutoScale()
limits.system.Conns = 100 // High enough to not be the limiting factor
limits.transient.Conns = 100

// Create limiters with very low RPS
limiter := &rate.Limiter{
GlobalLimit: rate.Limit{RPS: 0.00001, Burst: 2},
}

rcmgr, err := NewResourceManager(NewFixedLimiter(limits), WithConnRateLimiters(limiter))
if err != nil {
t.Fatal(err)
}
defer rcmgr.Close()

addr := multiaddr.StringCast("/ip4/1.2.3.4")

connScope, err := rcmgr.OpenConnection(network.DirInbound, true, addr)
require.NoError(t, err)
connScope.Done()

connScope, err = rcmgr.OpenConnection(network.DirInbound, true, addr)
require.NoError(t, err)
connScope.Done()

_, err = rcmgr.OpenConnection(network.DirInbound, true, addr)
require.ErrorContains(t, err, "rate limit exceeded")
}

func TestVerifySourceAddressRateLimiter(t *testing.T) {
limits := DefaultLimits.AutoScale()
limits.allowlistedSystem.Conns = 100
limits.allowlistedSystem.ConnsInbound = 100
limits.allowlistedSystem.ConnsOutbound = 100

rcmgr, err := NewResourceManager(NewFixedLimiter(limits), WithLimitPerSubnet([]ConnLimitPerSubnet{
{PrefixLength: 32, ConnCount: 2},
}, []ConnLimitPerSubnet{}))
if err != nil {
t.Fatal(err)
}
defer rcmgr.Close()

na1 := &net.UDPAddr{
IP: net.ParseIP("1.2.3.4"),
Port: 1234,
}
require.False(t, rcmgr.VerifySourceAddress(na1))
require.True(t, rcmgr.VerifySourceAddress(na1))

na2 := &net.UDPAddr{
IP: net.ParseIP("1.2.3.5"),
Port: 1234,
}
require.False(t, rcmgr.VerifySourceAddress(na2))
require.True(t, rcmgr.VerifySourceAddress(na2))
}

@@ -4,18 +4,17 @@ import (
"context"
"errors"
"fmt"
"iter"
"math/rand/v2"
"slices"
"sync"
"time"

"math/rand/v2"

logging "github.com/ipfs/go-log/v2"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
@@ -35,11 +34,15 @@ const (
// maxPeerAddresses is the number of addresses in a dial request the server
// will inspect; the rest are ignored.
maxPeerAddresses = 50

defaultThrottlePeerDuration = 2 * time.Minute
)

var (
ErrNoValidPeers = errors.New("no valid peers for autonat v2")
ErrDialRefused = errors.New("dial refused")
// ErrNoPeers is returned when the client knows no autonatv2 servers.
ErrNoPeers = errors.New("no peers for autonat v2")
// ErrPrivateAddrs is returned when the request has private IP addresses.
ErrPrivateAddrs = errors.New("private addresses cannot be verified with autonatv2")

log = logging.Logger("autonatv2")
)
@@ -56,10 +59,12 @@ type Request struct {
type Result struct {
// Addr is the dialed address
Addr ma.Multiaddr
// Reachability of the dialed address
// Idx is the index of the address that was dialed
Idx int
// Reachability is the reachability for `Addr`
Reachability network.Reachability
// Status is the outcome of the dialback
Status pb.DialStatus
// AllAddrsRefused is true when the server refused to dial all the addresses in the request.
AllAddrsRefused bool
}

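A short sketch of consuming the reshaped Result: Idx indexes into the original request slice even when private addresses were filtered out, and AllAddrsRefused replaces the old ErrDialRefused error path. Variable names are illustrative:

res, err := an.GetReachability(ctx, reqs)
if err != nil {
	return err // includes ErrNoPeers / ErrPrivateAddrs
}
if res.AllAddrsRefused {
	// the server declined every address; retry with another server later
	return nil
}
switch res.Reachability {
case network.ReachabilityPublic:
	log.Infof("addr %s (reqs[%d]) confirmed reachable", res.Addr, res.Idx)
case network.ReachabilityPrivate:
	log.Infof("addr %s (reqs[%d]) not reachable", res.Addr, res.Idx)
}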
// AutoNAT implements the AutoNAT v2 client and server.
@@ -78,6 +83,10 @@ type AutoNAT struct {

mx sync.Mutex
peers *peersMap
throttlePeer map[peer.ID]time.Time
// throttlePeerDuration is the duration to wait before making another dial request to the
// same server.
throttlePeerDuration time.Duration
// allowPrivateAddrs enables using private and localhost addresses for reachability checks.
// This is only useful for testing.
allowPrivateAddrs bool
@@ -86,7 +95,7 @@ type AutoNAT struct {
// New returns a new AutoNAT instance.
// host and dialerHost should have the same dialing capabilities. In case the host doesn't support
// a transport, dial back requests for addresses of that transport will be ignored.
func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, error) {
func New(dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, error) {
s := defaultSettings()
for _, o := range opts {
if err := o(s); err != nil {
@@ -96,18 +105,20 @@ func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT,

ctx, cancel := context.WithCancel(context.Background())
an := &AutoNAT{
host: host,
ctx: ctx,
cancel: cancel,
srv: newServer(host, dialerHost, s),
cli: newClient(host),
srv: newServer(dialerHost, s),
cli: newClient(),
allowPrivateAddrs: s.allowPrivateAddrs,
peers: newPeersMap(),
throttlePeer: make(map[peer.ID]time.Time),
throttlePeerDuration: s.throttlePeerDuration,
}
return an, nil
}

func (an *AutoNAT) background(sub event.Subscription) {
ticker := time.NewTicker(10 * time.Minute)
for {
select {
case <-an.ctx.Done():
@@ -122,12 +133,24 @@ func (an *AutoNAT) background(sub event.Subscription) {
an.updatePeer(evt.Peer)
case event.EvtPeerIdentificationCompleted:
an.updatePeer(evt.Peer)
default:
log.Errorf("unexpected event: %T", e)
}
case <-ticker.C:
now := time.Now()
an.mx.Lock()
for p, t := range an.throttlePeer {
if t.Before(now) {
delete(an.throttlePeer, p)
}
}
an.mx.Unlock()
}
}
}

func (an *AutoNAT) Start() error {
func (an *AutoNAT) Start(h host.Host) error {
an.host = h
// Listen on event.EvtPeerProtocolsUpdated, event.EvtPeerConnectednessChanged
// event.EvtPeerIdentificationCompleted to maintain our set of autonat supporting peers.
sub, err := an.host.EventBus().Subscribe([]interface{}{
@@ -138,8 +161,8 @@ func (an *AutoNAT) Start() error {
if err != nil {
return fmt.Errorf("event subscription failed: %w", err)
}
an.cli.Start()
an.srv.Start()
an.cli.Start(h)
an.srv.Start(h)

an.wg.Add(1)
go an.background(sub)
@@ -156,24 +179,48 @@ func (an *AutoNAT) Close() {

// GetReachability makes a single dial request for checking reachability for requested addresses
func (an *AutoNAT) GetReachability(ctx context.Context, reqs []Request) (Result, error) {
var filteredReqs []Request
if !an.allowPrivateAddrs {
filteredReqs = make([]Request, 0, len(reqs))
for _, r := range reqs {
if !manet.IsPublicAddr(r.Addr) {
return Result{}, fmt.Errorf("private address cannot be verified by autonatv2: %s", r.Addr)
if manet.IsPublicAddr(r.Addr) {
filteredReqs = append(filteredReqs, r)
} else {
log.Errorf("private address in reachability check: %s", r.Addr)
}
}
if len(filteredReqs) == 0 {
return Result{}, ErrPrivateAddrs
}
} else {
filteredReqs = reqs
}
an.mx.Lock()
p := an.peers.GetRand()
now := time.Now()
var p peer.ID
for pr := range an.peers.Shuffled() {
if t := an.throttlePeer[pr]; t.After(now) {
continue
}
p = pr
an.throttlePeer[p] = time.Now().Add(an.throttlePeerDuration)
break
}
an.mx.Unlock()
if p == "" {
return Result{}, ErrNoValidPeers
return Result{}, ErrNoPeers
}

res, err := an.cli.GetReachability(ctx, p, reqs)
res, err := an.cli.GetReachability(ctx, p, filteredReqs)
if err != nil {
log.Debugf("reachability check with %s failed, err: %s", p, err)
return Result{}, fmt.Errorf("reachability check with %s failed: %w", p, err)
return res, fmt.Errorf("reachability check with %s failed: %w", p, err)
}
// restore the correct index in case we'd filtered private addresses
for i, r := range reqs {
if r.Addr.Equal(res.Addr) {
res.Idx = i
break
}
}
log.Debugf("reachability check with %s successful", p)
return res, nil
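The behavioral change is that mixed private/public requests are now filtered rather than rejected outright; a sketch grounded in TestAutoNATPrivateAndPublicAddrs below:

// Sketch: the private address is dropped, the public one is dialed.
reqs := []Request{
	{Addr: ma.StringCast("/ip4/192.168.0.1/udp/1/quic-v1")}, // filtered out
	{Addr: ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1")},     // sent to the server
}
res, err := an.GetReachability(ctx, reqs)
// err == nil, and res.Idx == 1: Idx is restored against the original,
// unfiltered slice, not the filtered one sent over the wire.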
@@ -187,7 +234,7 @@ func (an *AutoNAT) updatePeer(p peer.ID) {
// and swarm for the current state
protos, err := an.host.Peerstore().SupportsProtocols(p, DialProtocol)
connectedness := an.host.Network().Connectedness(p)
if err == nil && slices.Contains(protos, DialProtocol) && connectedness == network.Connected {
if err == nil && connectedness == network.Connected && slices.Contains(protos, DialProtocol) {
an.peers.Put(p)
} else {
an.peers.Delete(p)
@@ -208,28 +255,40 @@ func newPeersMap() *peersMap {
}
}

func (p *peersMap) GetRand() peer.ID {
if len(p.peers) == 0 {
return ""
// Shuffled iterates over the peers, starting at a random offset
func (p *peersMap) Shuffled() iter.Seq[peer.ID] {
n := len(p.peers)
start := 0
if n > 0 {
start = rand.IntN(n)
}
return p.peers[rand.IntN(len(p.peers))]
}

func (p *peersMap) Put(pid peer.ID) {
if _, ok := p.peerIdx[pid]; ok {
return func(yield func(peer.ID) bool) {
for i := range n {
if !yield(p.peers[(i+start)%n]) {
return
}
p.peers = append(p.peers, pid)
p.peerIdx[pid] = len(p.peers) - 1
}
}
}

func (p *peersMap) Delete(pid peer.ID) {
idx, ok := p.peerIdx[pid]
func (p *peersMap) Put(id peer.ID) {
if _, ok := p.peerIdx[id]; ok {
return
}
p.peers = append(p.peers, id)
p.peerIdx[id] = len(p.peers) - 1
}

func (p *peersMap) Delete(id peer.ID) {
idx, ok := p.peerIdx[id]
if !ok {
return
}
p.peers[idx] = p.peers[len(p.peers)-1]
p.peerIdx[p.peers[idx]] = idx
p.peers = p.peers[:len(p.peers)-1]
delete(p.peerIdx, pid)
n := len(p.peers)
lastPeer := p.peers[n-1]
p.peers[idx] = lastPeer
p.peerIdx[lastPeer] = idx
p.peers[n-1] = ""
p.peers = p.peers[:n-1]
delete(p.peerIdx, id)
}

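Shuffled returns a Go 1.23 iter.Seq, so callers can range over it directly and stop early without materializing a slice; a minimal sketch (the throttled set is a hypothetical stand-in for the throttlePeer map above):

// Sketch: pick the first non-throttled peer, ending the iteration early.
var chosen peer.ID
for pid := range pm.Shuffled() {
	if throttled[pid] { // hypothetical throttle check
		continue
	}
	chosen = pid
	break // stopping here makes yield return false inside Shuffled
}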
@@ -2,8 +2,13 @@ package autonatv2

import (
"context"
"encoding/binary"
"errors"
"fmt"
"math"
"net"
"net/netip"
"strings"
"sync/atomic"
"testing"
"time"
@@ -36,11 +41,12 @@ func newAutoNAT(t testing.TB, dialer host.Host, opts ...AutoNATOption) *AutoNAT
swarm.WithUDPBlackHoleSuccessCounter(nil),
swarm.WithIPv6BlackHoleSuccessCounter(nil))))
}
an, err := New(h, dialer, opts...)
opts = append([]AutoNATOption{withThrottlePeerDuration(0)}, opts...)
an, err := New(dialer, opts...)
if err != nil {
t.Error(err)
}
an.Start()
require.NoError(t, an.Start(h))
t.Cleanup(an.Close)
return an
}
@@ -74,7 +80,7 @@ func waitForPeer(t testing.TB, a *AutoNAT) {
require.Eventually(t, func() bool {
a.mx.Lock()
defer a.mx.Unlock()
return a.peers.GetRand() != ""
return len(a.peers.peers) != 0
}, 5*time.Second, 100*time.Millisecond)
}

@@ -88,7 +94,7 @@ func TestAutoNATPrivateAddr(t *testing.T) {
an := newAutoNAT(t, nil)
res, err := an.GetReachability(context.Background(), []Request{{Addr: ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")}})
require.Equal(t, res, Result{})
require.Contains(t, err.Error(), "private address cannot be verified by autonatv2")
require.ErrorIs(t, err, ErrPrivateAddrs)
}

func TestClientRequest(t *testing.T) {
@@ -154,19 +160,6 @@ func TestClientServerError(t *testing.T) {
},
errorStr: "invalid msg type",
},
{
handler: func(s network.Stream) {
w := pbio.NewDelimitedWriter(s)
assert.NoError(t, w.WriteMsg(
&pb.Message{Msg: &pb.Message_DialResponse{
DialResponse: &pb.DialResponse{
Status: pb.DialResponse_E_DIAL_REFUSED,
},
}},
))
},
errorStr: ErrDialRefused.Error(),
},
}

for i, tc := range tests {
@@ -298,6 +291,49 @@ func TestClientDataRequest(t *testing.T) {
}
}

func TestAutoNATPrivateAndPublicAddrs(t *testing.T) {
an := newAutoNAT(t, nil)
defer an.Close()
defer an.host.Close()

b := bhost.NewBlankHost(swarmt.GenSwarm(t))
defer b.Close()
idAndConnect(t, an.host, b)
waitForPeer(t, an)

dialerHost := bhost.NewBlankHost(swarmt.GenSwarm(t))
defer dialerHost.Close()
handler := func(s network.Stream) {
w := pbio.NewDelimitedWriter(s)
r := pbio.NewDelimitedReader(s, maxMsgSize)
var msg pb.Message
assert.NoError(t, r.ReadMsg(&msg))
w.WriteMsg(&pb.Message{
Msg: &pb.Message_DialResponse{
DialResponse: &pb.DialResponse{
Status: pb.DialResponse_OK,
DialStatus: pb.DialStatus_E_DIAL_ERROR,
AddrIdx: 0,
},
},
})
s.Close()
}

b.SetStreamHandler(DialProtocol, handler)
privateAddr := ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")
publicAddr := ma.StringCast("/ip4/1.2.3.4/udp/10/quic-v1")
res, err := an.GetReachability(context.Background(),
[]Request{
{Addr: privateAddr},
{Addr: publicAddr},
})
require.NoError(t, err)
require.Equal(t, res.Addr, publicAddr, "%s\n%s", res.Addr, publicAddr)
require.Equal(t, res.Idx, 1)
require.Equal(t, res.Reachability, network.ReachabilityPrivate)
}

func TestClientDialBacks(t *testing.T) {
an := newAutoNAT(t, nil, allowPrivateAddrs)
defer an.Close()
@@ -507,7 +543,6 @@ func TestClientDialBacks(t *testing.T) {
} else {
require.NoError(t, err)
require.Equal(t, res.Reachability, network.ReachabilityPublic)
require.Equal(t, res.Status, pb.DialStatus_OK)
}
})
}
@@ -551,46 +586,6 @@ func TestEventSubscription(t *testing.T) {
}, 5*time.Second, 100*time.Millisecond)
}

func TestPeersMap(t *testing.T) {
emptyPeerID := peer.ID("")

t.Run("single_item", func(t *testing.T) {
p := newPeersMap()
p.Put("peer1")
p.Delete("peer1")
p.Put("peer1")
require.Equal(t, peer.ID("peer1"), p.GetRand())
p.Delete("peer1")
require.Equal(t, emptyPeerID, p.GetRand())
})

t.Run("multiple_items", func(t *testing.T) {
p := newPeersMap()
require.Equal(t, emptyPeerID, p.GetRand())

allPeers := make(map[peer.ID]bool)
for i := 0; i < 20; i++ {
pid := peer.ID(fmt.Sprintf("peer-%d", i))
allPeers[pid] = true
p.Put(pid)
}
foundPeers := make(map[peer.ID]bool)
for i := 0; i < 1000; i++ {
pid := p.GetRand()
require.NotEqual(t, emptyPeerID, p)
require.True(t, allPeers[pid])
foundPeers[pid] = true
if len(foundPeers) == len(allPeers) {
break
}
}
for pid := range allPeers {
p.Delete(pid)
}
require.Equal(t, emptyPeerID, p.GetRand())
})
}

func TestAreAddrsConsistency(t *testing.T) {
c := &client{
normalizeMultiaddr: func(a ma.Multiaddr) ma.Multiaddr {
@@ -645,6 +640,12 @@ func TestAreAddrsConsistency(t *testing.T) {
dialAddr: ma.StringCast("/ip6/1::1/udp/123/quic-v1/"),
success: false,
},
{
name: "dns6",
localAddr: ma.StringCast("/dns6/lib.p2p/udp/12345/quic-v1"),
dialAddr: ma.StringCast("/ip4/1.2.3.4/udp/123/quic-v1/"),
success: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
@@ -658,3 +659,173 @@ func TestAreAddrsConsistency(t *testing.T) {
})
}
}

func TestPeerMap(t *testing.T) {
pm := newPeersMap()
// Add 1, 2, 3
pm.Put(peer.ID("1"))
pm.Put(peer.ID("2"))
pm.Put(peer.ID("3"))

// Remove 3, 2
pm.Delete(peer.ID("3"))
pm.Delete(peer.ID("2"))

// Add 4
pm.Put(peer.ID("4"))

// Remove 3, 2 again. Should be a no-op
pm.Delete(peer.ID("3"))
pm.Delete(peer.ID("2"))

contains := []peer.ID{"1", "4"}
elems := make([]peer.ID, 0)
for p := range pm.Shuffled() {
elems = append(elems, p)
}
require.ElementsMatch(t, contains, elems)
}

func FuzzClient(f *testing.F) {
a := newAutoNAT(f, nil, allowPrivateAddrs, WithServerRateLimit(math.MaxInt32, math.MaxInt32, math.MaxInt32, 2))
c := newAutoNAT(f, nil)
idAndWait(f, c, a)

// TODO: Move this to go-multiaddrs
getProto := func(protos []byte) ma.Multiaddr {
protoType := 0
if len(protos) > 0 {
protoType = int(protos[0])
}

port1, port2 := 0, 0
if len(protos) > 1 {
port1 = int(protos[1])
}
if len(protos) > 2 {
port2 = int(protos[2])
}
protoTemplates := []string{
"/tcp/%d/",
"/udp/%d/",
"/udp/%d/quic-v1/",
"/udp/%d/quic-v1/tcp/%d",
"/udp/%d/quic-v1/webtransport/",
"/udp/%d/webrtc/",
"/udp/%d/webrtc-direct/",
"/unix/hello/",
}
s := protoTemplates[protoType%len(protoTemplates)]
port1 %= (1 << 16)
if strings.Count(s, "%d") == 1 {
return ma.StringCast(fmt.Sprintf(s, port1))
}
port2 %= (1 << 16)
return ma.StringCast(fmt.Sprintf(s, port1, port2))
}

getIP := func(ips []byte) ma.Multiaddr {
ipType := 0
if len(ips) > 0 {
ipType = int(ips[0])
}
ips = ips[1:]
var x, y int64
split := 128 / 8
if len(ips) < split {
split = len(ips)
}
var b [8]byte
copy(b[:], ips[:split])
x = int64(binary.LittleEndian.Uint64(b[:]))
clear(b[:])
copy(b[:], ips[split:])
y = int64(binary.LittleEndian.Uint64(b[:]))

var ip netip.Addr
switch ipType % 3 {
case 0:
ip = netip.AddrFrom4([4]byte{byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24)})
return ma.StringCast(fmt.Sprintf("/ip4/%s/", ip))
case 1:
pubIP := net.ParseIP("2005::") // Public IP address
x := int64(binary.LittleEndian.Uint64(pubIP[0:8]))
ip = netip.AddrFrom16([16]byte{
byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24),
byte(x >> 32), byte(x >> 40), byte(x >> 48), byte(x >> 56),
byte(y), byte(y >> 8), byte(y >> 16), byte(y >> 24),
byte(y >> 32), byte(y >> 40), byte(y >> 48), byte(y >> 56),
})
return ma.StringCast(fmt.Sprintf("/ip6/%s/", ip))
default:
ip := netip.AddrFrom16([16]byte{
byte(x), byte(x >> 8), byte(x >> 16), byte(x >> 24),
byte(x >> 32), byte(x >> 40), byte(x >> 48), byte(x >> 56),
byte(y), byte(y >> 8), byte(y >> 16), byte(y >> 24),
byte(y >> 32), byte(y >> 40), byte(y >> 48), byte(y >> 56),
})
return ma.StringCast(fmt.Sprintf("/ip6/%s/", ip))
}
}

getAddr := func(addrType int, ips, protos []byte) ma.Multiaddr {
switch addrType % 4 {
case 0:
return getIP(ips).Encapsulate(getProto(protos))
case 1:
return getProto(protos)
case 2:
return nil
default:
return getIP(ips).Encapsulate(getProto(protos))
}
}

getDNSAddr := func(hostNameBytes, protos []byte) ma.Multiaddr {
hostName := strings.ReplaceAll(string(hostNameBytes), "\\", "")
hostName = strings.ReplaceAll(hostName, "/", "")
if hostName == "" {
hostName = "localhost"
}
dnsType := 0
if len(hostNameBytes) > 0 {
dnsType = int(hostNameBytes[0])
}
dnsProtos := []string{"dns", "dns4", "dns6", "dnsaddr"}
da := ma.StringCast(fmt.Sprintf("/%s/%s/", dnsProtos[dnsType%len(dnsProtos)], hostName))
return da.Encapsulate(getProto(protos))
}

const maxAddrs = 100
getAddrs := func(numAddrs int, ips, protos, hostNames []byte) []ma.Multiaddr {
if len(ips) == 0 || len(protos) == 0 || len(hostNames) == 0 {
return nil
}
numAddrs = ((numAddrs % maxAddrs) + maxAddrs) % maxAddrs
addrs := make([]ma.Multiaddr, numAddrs)
ipIdx := 0
protoIdx := 0
for i := range numAddrs {
addrs[i] = getAddr(i, ips[ipIdx:], protos[protoIdx:])
ipIdx = (ipIdx + 1) % len(ips)
protoIdx = (protoIdx + 1) % len(protos)
}
maxDNSAddrs := 10
protoIdx = 0
for i := 0; i < len(hostNames) && i < maxDNSAddrs; i += 2 {
ed := min(i+2, len(hostNames))
addrs = append(addrs, getDNSAddr(hostNames[i:ed], protos[protoIdx:]))
protoIdx = (protoIdx + 1) % len(protos)
}
return addrs
}
// reduce the streamTimeout before running this. TODO: fix this
f.Fuzz(func(_ *testing.T, numAddrs int, ips, protos, hostNames []byte) {
addrs := getAddrs(numAddrs, ips, protos, hostNames)
reqs := make([]Request, len(addrs))
for i, addr := range addrs {
reqs[i] = Request{Addr: addr, SendDialData: true}
}
c.GetReachability(context.Background(), reqs)
})
}

@@ -35,20 +35,20 @@ type normalizeMultiaddrer interface {
NormalizeMultiaddr(ma.Multiaddr) ma.Multiaddr
}

func newClient(h host.Host) *client {
normalizeMultiaddr := func(a ma.Multiaddr) ma.Multiaddr { return a }
if hn, ok := h.(normalizeMultiaddrer); ok {
normalizeMultiaddr = hn.NormalizeMultiaddr
}
func newClient() *client {
return &client{
host: h,
dialData: make([]byte, 4000),
normalizeMultiaddr: normalizeMultiaddr,
dialBackQueues: make(map[uint64]chan ma.Multiaddr),
}
}

func (ac *client) Start() {
func (ac *client) Start(h host.Host) {
normalizeMultiaddr := func(a ma.Multiaddr) ma.Multiaddr { return a }
if hn, ok := h.(normalizeMultiaddrer); ok {
normalizeMultiaddr = hn.NormalizeMultiaddr
}
ac.host = h
ac.normalizeMultiaddr = normalizeMultiaddr
ac.host.SetStreamHandler(DialBackProtocol, ac.handleDialBack)
}

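The refactor throughout this commit moves host binding from construction to Start, so an AutoNAT (and its internal client and server) can be created before the host exists and wired in later. A minimal sketch of the new lifecycle, with h and dialer as placeholder hosts:

// Sketch: construct first, attach the host at Start (per this diff).
an, err := autonatv2.New(dialer) // dialer-only constructor now
if err != nil {
	return err
}
if err := an.Start(h); err != nil { // host bound here, handlers registered
	return err
}
defer an.Close()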
@@ -109,9 +109,9 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request
break
// provide dial data if appropriate
case msg.GetDialDataRequest() != nil:
if err := ac.validateDialDataRequest(reqs, &msg); err != nil {
if err := validateDialDataRequest(reqs, &msg); err != nil {
s.Reset()
return Result{}, fmt.Errorf("invalid dial data request: %w", err)
return Result{}, fmt.Errorf("invalid dial data request: %s %w", s.Conn().RemoteMultiaddr(), err)
}
// dial data request is valid and we want to send data
if err := sendDialData(ac.dialData, int(msg.GetDialDataRequest().GetNumBytes()), w, &msg); err != nil {
@@ -136,7 +136,7 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request
// E_DIAL_REFUSED has implications for deciding future address verification priorities
// wrap a distinct error for convenient errors.Is usage
if resp.GetStatus() == pb.DialResponse_E_DIAL_REFUSED {
return Result{}, fmt.Errorf("dial request failed: %w", ErrDialRefused)
return Result{AllAddrsRefused: true}, nil
}
return Result{}, fmt.Errorf("dial request failed: response status %d %s", resp.GetStatus(),
pb.DialResponse_ResponseStatus_name[int32(resp.GetStatus())])
@@ -147,7 +147,6 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request
if int(resp.AddrIdx) >= len(reqs) {
return Result{}, fmt.Errorf("invalid response: addr index out of range: %d [0-%d)", resp.AddrIdx, len(reqs))
}

// wait for nonce from the server
var dialBackAddr ma.Multiaddr
if resp.GetDialStatus() == pb.DialStatus_OK {
@@ -163,7 +162,7 @@ func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request
return ac.newResult(resp, reqs, dialBackAddr)
}

func (ac *client) validateDialDataRequest(reqs []Request, msg *pb.Message) error {
func validateDialDataRequest(reqs []Request, msg *pb.Message) error {
idx := int(msg.GetDialDataRequest().AddrIdx)
if idx >= len(reqs) { // invalid address index
return fmt.Errorf("addr index out of range: %d [0-%d)", idx, len(reqs))
@@ -179,9 +178,13 @@ func (ac *client) validateDialDataRequest(reqs []Request, msg *pb.Message) error

func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr ma.Multiaddr) (Result, error) {
idx := int(resp.AddrIdx)
if idx >= len(reqs) {
// This should have been validated by this point, but checking this is cheap.
return Result{}, fmt.Errorf("addrs index(%d) greater than len(reqs)(%d)", idx, len(reqs))
}
addr := reqs[idx].Addr

var rch network.Reachability
rch := network.ReachabilityUnknown //nolint:ineffassign
switch resp.DialStatus {
case pb.DialStatus_OK:
if !ac.areAddrsConsistent(dialBackAddr, addr) {
@@ -191,17 +194,16 @@ func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr
return Result{}, fmt.Errorf("invalid response: dialBackAddr: %s, respAddr: %s", dialBackAddr, addr)
}
rch = network.ReachabilityPublic
case pb.DialStatus_E_DIAL_ERROR:
rch = network.ReachabilityPrivate
case pb.DialStatus_E_DIAL_BACK_ERROR:
if ac.areAddrsConsistent(dialBackAddr, addr) {
if !ac.areAddrsConsistent(dialBackAddr, addr) {
return Result{}, fmt.Errorf("dial-back stream error: dialBackAddr: %s, respAddr: %s", dialBackAddr, addr)
}
// We received the dial back but the server claims the dial back errored.
// As long as we received the correct nonce in dial back it is safe to assume
// that we are public.
rch = network.ReachabilityPublic
} else {
rch = network.ReachabilityUnknown
}
case pb.DialStatus_E_DIAL_ERROR:
rch = network.ReachabilityPrivate
default:
// Unexpected response code. Discard the response and fail.
log.Warnf("invalid status code received in response for addr %s: %d", addr, resp.DialStatus)
@@ -210,8 +212,8 @@ func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr

return Result{
Addr: addr,
Idx: idx,
Reachability: rch,
Status: resp.DialStatus,
}, nil
}

@@ -307,7 +309,7 @@ func (ac *client) handleDialBack(s network.Stream) {
}

func (ac *client) areAddrsConsistent(connLocalAddr, dialedAddr ma.Multiaddr) bool {
if connLocalAddr == nil || dialedAddr == nil {
if len(connLocalAddr) == 0 || len(dialedAddr) == 0 {
return false
}
connLocalAddr = ac.normalizeMultiaddr(connLocalAddr)
@@ -318,33 +320,32 @@ func (ac *client) areAddrsConsistent(connLocalAddr, dialedAddr ma.Multiaddr) boo
if len(localProtos) != len(externalProtos) {
return false
}
for i := 0; i < len(localProtos); i++ {
for i, lp := range localProtos {
ep := externalProtos[i]
if i == 0 {
switch externalProtos[i].Code {
switch ep.Code {
case ma.P_DNS, ma.P_DNSADDR:
if localProtos[i].Code == ma.P_IP4 || localProtos[i].Code == ma.P_IP6 {
if lp.Code == ma.P_IP4 || lp.Code == ma.P_IP6 {
continue
}
return false
case ma.P_DNS4:
if localProtos[i].Code == ma.P_IP4 {
if lp.Code == ma.P_IP4 {
continue
}
return false
case ma.P_DNS6:
if localProtos[i].Code == ma.P_IP6 {
if lp.Code == ma.P_IP6 {
continue
}
return false
}
if localProtos[i].Code != externalProtos[i].Code {
if lp.Code != ep.Code {
return false
}
} else {
if localProtos[i].Code != externalProtos[i].Code {
} else if lp.Code != ep.Code {
return false
}
}
}
return true
}

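Concretely, only the first component may differ, and only in the DNS-to-IP direction: a dialed DNS name may resolve to a matching IP version locally, but never the reverse (the dns6 test case above encodes the mismatched pairing). Illustrative pairs, assuming identical trailing protocol components:

// local conn addr                vs dialed addr                       -> result per the rules above
// /ip4/1.2.3.4/udp/1/quic-v1     vs /dns4/example.com/udp/1/quic-v1   -> consistent
// /ip4/1.2.3.4/udp/1/quic-v1     vs /dns6/example.com/udp/1/quic-v1   -> inconsistent
// /ip6/1::1/udp/1/quic-v1        vs /dns/example.com/udp/1/quic-v1    -> consistent
// /ip4/1.2.3.4/udp/1/quic-v1     vs /ip4/1.2.3.4/tcp/1                -> inconsistent (component codes differ)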
@@ -13,6 +13,7 @@ type autoNATSettings struct {
now func() time.Time
amplificatonAttackPreventionDialWait time.Duration
metricsTracer MetricsTracer
throttlePeerDuration time.Duration
}

func defaultSettings() *autoNATSettings {
@@ -25,6 +26,7 @@ func defaultSettings() *autoNATSettings {
dataRequestPolicy: amplificationAttackPrevention,
amplificatonAttackPreventionDialWait: 3 * time.Second,
now: time.Now,
throttlePeerDuration: defaultThrottlePeerDuration,
}
}

@@ -65,3 +67,10 @@ func withAmplificationAttackPreventionDialWait(d time.Duration) AutoNATOption {
return nil
}
}

func withThrottlePeerDuration(d time.Duration) AutoNATOption {
return func(s *autoNATSettings) error {
s.throttlePeerDuration = d
return nil
}
}

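Since withThrottlePeerDuration is unexported, external callers always get the 2-minute defaultThrottlePeerDuration; inside the package, the tests disable throttling entirely, as newAutoNAT above does:

// Sketch: tests opt out of per-server throttling so repeated dial
// requests to the same server are not delayed between fuzz iterations.
an, err := New(dialer, withThrottlePeerDuration(0))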
@@ -59,10 +59,9 @@ type server struct {
allowPrivateAddrs bool
}

func newServer(host, dialer host.Host, s *autoNATSettings) *server {
func newServer(dialer host.Host, s *autoNATSettings) *server {
return &server{
dialerHost: dialer,
host: host,
dialDataRequestPolicy: s.dataRequestPolicy,
amplificatonAttackPreventionDialWait: s.amplificatonAttackPreventionDialWait,
allowPrivateAddrs: s.allowPrivateAddrs,
@@ -79,7 +78,8 @@ func newServer(host, dialer host.Host, s *autoNATSettings) *server {
}

// Start attaches the stream handler to the host.
func (as *server) Start() {
func (as *server) Start(h host.Host) {
as.host = h
as.host.SetStreamHandler(DialProtocol, as.handleDialRequest)
}

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"math"
"strings"
"sync"
"sync/atomic"
"testing"
@@ -46,8 +47,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) {
idAndWait(t, c, an)

res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), true))
require.ErrorIs(t, err, ErrDialRefused)
require.Equal(t, Result{}, res)
require.NoError(t, err)
require.Equal(t, Result{AllAddrsRefused: true}, res)
})

t.Run("black holed addr", func(t *testing.T) {
@@ -64,8 +65,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) {
Addr: ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1"),
SendDialData: true,
}})
require.ErrorIs(t, err, ErrDialRefused)
require.Equal(t, Result{}, res)
require.NoError(t, err)
require.Equal(t, Result{AllAddrsRefused: true}, res)
})

t.Run("private addrs", func(t *testing.T) {
@@ -76,8 +77,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) {
idAndWait(t, c, an)

res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), true))
require.ErrorIs(t, err, ErrDialRefused)
require.Equal(t, Result{}, res)
require.NoError(t, err)
require.Equal(t, Result{AllAddrsRefused: true}, res)
})

t.Run("relay addrs", func(t *testing.T) {
@@ -89,8 +90,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) {

res, err := c.GetReachability(context.Background(), newTestRequests(
[]ma.Multiaddr{ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/1/p2p/%s/p2p-circuit/p2p/%s", c.host.ID(), c.srv.dialerHost.ID()))}, true))
require.ErrorIs(t, err, ErrDialRefused)
require.Equal(t, Result{}, res)
require.NoError(t, err)
require.Equal(t, Result{AllAddrsRefused: true}, res)
})

t.Run("no addr", func(t *testing.T) {
@@ -113,8 +114,8 @@ func TestServerInvalidAddrsRejected(t *testing.T) {
idAndWait(t, c, an)

res, err := c.GetReachability(context.Background(), newTestRequests(addrs, true))
require.ErrorIs(t, err, ErrDialRefused)
require.Equal(t, Result{}, res)
require.NoError(t, err)
require.Equal(t, Result{AllAddrsRefused: true}, res)
})

t.Run("msg too large", func(t *testing.T) {
@@ -135,7 +136,6 @@ func TestServerInvalidAddrsRejected(t *testing.T) {
require.ErrorIs(t, err, network.ErrReset)
require.Equal(t, Result{}, res)
})

}

func TestServerDataRequest(t *testing.T) {
@@ -178,8 +178,8 @@ func TestServerDataRequest(t *testing.T) {

require.Equal(t, Result{
Addr: quicAddr,
Idx: 0,
Reachability: network.ReachabilityPublic,
Status: pb.DialStatus_OK,
}, res)

// Small messages should be rejected for dial data
@@ -191,14 +191,11 @@ func TestServerDataRequest(t *testing.T) {
func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) {
const concurrentRequests = 5

// server will skip all tcp addresses
dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP))

doneChan := make(chan struct{})
an := newAutoNAT(t, dialer, allowPrivateAddrs, withDataRequestPolicy(
stallChan := make(chan struct{})
an := newAutoNAT(t, nil, allowPrivateAddrs, withDataRequestPolicy(
// stall all allowed requests
func(_, _ ma.Multiaddr) bool {
<-doneChan
<-stallChan
return true
}),
WithServerRateLimit(10, 10, 10, concurrentRequests),
@@ -207,16 +204,18 @@ func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) {
defer an.Close()
defer an.host.Close()

c := newAutoNAT(t, nil, allowPrivateAddrs)
// server will skip all tcp addresses
dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP))
c := newAutoNAT(t, dialer, allowPrivateAddrs)
defer c.Close()
defer c.host.Close()

idAndWait(t, c, an)

errChan := make(chan error)
const N = 10
// num concurrentRequests will stall and N will fail
for i := 0; i < concurrentRequests+N; i++ {
const n = 10
// num concurrentRequests will stall and n will fail
for i := 0; i < concurrentRequests+n; i++ {
go func() {
_, err := c.GetReachability(context.Background(), []Request{{Addr: c.host.Addrs()[0], SendDialData: false}})
errChan <- err
@@ -224,17 +223,20 @@ func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) {
}

// check N failures
for i := 0; i < N; i++ {
for i := 0; i < n; i++ {
select {
case err := <-errChan:
require.Error(t, err)
if !strings.Contains(err.Error(), "stream reset") && !strings.Contains(err.Error(), "E_REQUEST_REJECTED") {
t.Fatalf("invalid error: %s expected: stream reset or E_REQUEST_REJECTED", err)
}
case <-time.After(10 * time.Second):
t.Fatalf("expected %d errors: got: %d", N, i)
t.Fatalf("expected %d errors: got: %d", n, i)
}
}

close(stallChan) // complete stalled requests
// check concurrentRequests failures, as we won't send dial data
close(doneChan)
for i := 0; i < concurrentRequests; i++ {
select {
case err := <-errChan:
@@ -290,8 +292,8 @@ func TestServerDataRequestJitter(t *testing.T) {

require.Equal(t, Result{
Addr: quicAddr,
Idx: 0,
Reachability: network.ReachabilityPublic,
Status: pb.DialStatus_OK,
}, res)
if took > 500*time.Millisecond {
return
@@ -320,8 +322,8 @@ func TestServerDial(t *testing.T) {
require.NoError(t, err)
require.Equal(t, Result{
Addr: unreachableAddr,
Idx: 0,
Reachability: network.ReachabilityPrivate,
Status: pb.DialStatus_E_DIAL_ERROR,
}, res)
})

@@ -330,16 +332,16 @@ func TestServerDial(t *testing.T) {
require.NoError(t, err)
require.Equal(t, Result{
Addr: hostAddrs[0],
Idx: 0,
Reachability: network.ReachabilityPublic,
Status: pb.DialStatus_OK,
}, res)
for _, addr := range c.host.Addrs() {
res, err := c.GetReachability(context.Background(), newTestRequests([]ma.Multiaddr{addr}, false))
require.NoError(t, err)
require.Equal(t, Result{
Addr: addr,
Idx: 0,
Reachability: network.ReachabilityPublic,
Status: pb.DialStatus_OK,
}, res)
}
})
@@ -347,12 +349,8 @@ func TestServerDial(t *testing.T) {
t.Run("dialback error", func(t *testing.T) {
c.host.RemoveStreamHandler(DialBackProtocol)
res, err := c.GetReachability(context.Background(), newTestRequests(c.host.Addrs(), false))
require.NoError(t, err)
require.Equal(t, Result{
Addr: hostAddrs[0],
Reachability: network.ReachabilityUnknown,
Status: pb.DialStatus_E_DIAL_BACK_ERROR,
}, res)
require.ErrorContains(t, err, "dial-back stream error")
require.Equal(t, Result{}, res)
})
}

@@ -396,7 +394,6 @@ func TestRateLimiter(t *testing.T) {

cl.AdvanceBy(10 * time.Second)
require.True(t, r.Accept("peer3"))

}

func TestRateLimiterConcurrentRequests(t *testing.T) {
@@ -558,22 +555,23 @@ func TestServerDataRequestWithAmplificationAttackPrevention(t *testing.T) {
require.NoError(t, err)
require.Equal(t, Result{
Addr: quicv4Addr,
Idx: 0,
Reachability: network.ReachabilityPublic,
Status: pb.DialStatus_OK,
}, res)

// ipv6 address should require dial data
_, err = c.GetReachability(context.Background(), []Request{{Addr: quicv6Addr, SendDialData: false}})
require.Error(t, err)
require.ErrorContains(t, err, "invalid dial data request: low priority addr")
require.ErrorContains(t, err, "invalid dial data request")
require.ErrorContains(t, err, "low priority addr")

// ipv6 address should work fine with dial data
res, err = c.GetReachability(context.Background(), []Request{{Addr: quicv6Addr, SendDialData: true}})
require.NoError(t, err)
require.Equal(t, Result{
Addr: quicv6Addr,
Idx: 0,
Reachability: network.ReachabilityPublic,
Status: pb.DialStatus_OK,
}, res)
}

@@ -20,9 +20,9 @@ import (
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/record"
"github.com/libp2p/go-libp2p/p2p/host/eventbus"
"github.com/libp2p/go-libp2p/p2p/internal/rate"
useragent "github.com/libp2p/go-libp2p/p2p/protocol/identify/internal/user-agent"
"github.com/libp2p/go-libp2p/p2p/protocol/identify/pb"
"github.com/libp2p/go-libp2p/x/rate"

logging "github.com/ipfs/go-log/v2"
"github.com/libp2p/go-msgio/pbio"

@@ -3,6 +3,7 @@ package transport_integration

import (
"context"
"fmt"
"regexp"
"strings"
"sync"
"sync/atomic"
@@ -56,11 +57,6 @@ func TestResourceManagerIsUsed(t *testing.T) {
expectedAddr = gomock.Any()
}

expectFd := true
if strings.Contains(tc.Name, "QUIC") || strings.Contains(tc.Name, "WebTransport") || strings.Contains(tc.Name, "WebRTC") {
expectFd = false
}

peerScope := mocknetwork.NewMockPeerScope(ctrl)
peerScope.EXPECT().ReserveMemory(gomock.Any(), gomock.Any()).AnyTimes().Do(func(amount int, _ uint8) {
reservedMemory.Add(int32(amount))
@@ -93,10 +89,16 @@ func TestResourceManagerIsUsed(t *testing.T) {
connScope.EXPECT().ReserveMemory(gomock.Any(), gomock.Any())
}
connScope.EXPECT().Done().MinTimes(1)
// udp transports won't have FD
udpTransportRegex := regexp.MustCompile(`QUIC|WebTransport|WebRTC`)
expectFd := !udpTransportRegex.MatchString(tc.Name)

if !testDialer && (strings.Contains(tc.Name, "QUIC") || strings.Contains(tc.Name, "WebTransport")) {
rcmgr.EXPECT().VerifySourceAddress(gomock.Any()).Return(false)
}
rcmgr.EXPECT().OpenConnection(expectedDir, expectFd, expectedAddr).Return(connScope, nil)

var allStreamsDone sync.WaitGroup

rcmgr.EXPECT().OpenConnection(expectedDir, expectFd, expectedAddr).Return(connScope, nil)
rcmgr.EXPECT().OpenStream(expectedPeer, gomock.Any()).AnyTimes().DoAndReturn(func(_ peer.ID, _ network.Direction) (network.StreamManagementScope, error) {
allStreamsDone.Add(1)
streamScope := mocknetwork.NewMockStreamManagementScope(ctrl)

@@ -14,6 +14,7 @@ import (
	"io"
	"math/big"
	"net"
	"regexp"
	"runtime"
	"strings"
	"sync"
@@ -38,6 +39,7 @@ import (
	"github.com/libp2p/go-libp2p/p2p/net/swarm"
	"github.com/libp2p/go-libp2p/p2p/protocol/ping"
	"github.com/libp2p/go-libp2p/p2p/security/noise"
	"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
	"github.com/libp2p/go-libp2p/p2p/transport/tcp"
	libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
	"github.com/libp2p/go-libp2p/p2p/transport/websocket"
@@ -275,6 +277,29 @@ var transportsToTest = []TransportTestCase{
			return h
		},
	},
	{
		Name: "QUIC-CustomReuse",
		HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
			libp2pOpts := transformOpts(opts)
			if opts.NoListen {
				libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs, libp2p.QUICReuse(quicreuse.NewConnManager))
			} else {
				qr := libp2p.QUICReuse(quicreuse.NewConnManager)
				if !opts.NoRcmgr && opts.ResourceManager != nil {
					qr = libp2p.QUICReuse(
						quicreuse.NewConnManager,
						quicreuse.VerifySourceAddress(opts.ResourceManager.VerifySourceAddress))
				}
				libp2pOpts = append(libp2pOpts,
					qr,
					libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1"),
				)
			}
			h, err := libp2p.New(libp2pOpts...)
			require.NoError(t, err)
			return h
		},
	},
	{
		Name: "WebTransport",
		HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
@@ -289,6 +314,30 @@ var transportsToTest = []TransportTestCase{
			return h
		},
	},
	{
		Name: "WebTransport-CustomReuse",
		HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
			libp2pOpts := transformOpts(opts)
			if opts.NoListen {
				libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs, libp2p.QUICReuse(quicreuse.NewConnManager))
			} else {
				qr := libp2p.QUICReuse(quicreuse.NewConnManager)
				if !opts.NoRcmgr && opts.ResourceManager != nil {
					qr = libp2p.QUICReuse(
						quicreuse.NewConnManager,
						quicreuse.VerifySourceAddress(opts.ResourceManager.VerifySourceAddress),
					)
				}
				libp2pOpts = append(libp2pOpts,
					qr,
					libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"),
				)
			}
			h, err := libp2p.New(libp2pOpts...)
			require.NoError(t, err)
			return h
		},
	},
	{
		Name: "WebRTC",
		HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
@@ -844,17 +893,23 @@ func TestDiscoverPeerIDFromSecurityNegotiation(t *testing.T) {
// TestCloseConnWhenBlocked tests that the server closes the connection when the rcmgr blocks it.
func TestCloseConnWhenBlocked(t *testing.T) {
	for _, tc := range transportsToTest {
		// WebRTC doesn't have a connection when rcmgr blocks it, so there's nothing to close.
		if tc.Name == "WebRTC" {
			continue // WebRTC doesn't have a connection when we block so there's nothing to close
			continue
		}
		t.Run(tc.Name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()
			mockRcmgr := mocknetwork.NewMockResourceManager(ctrl)
			mockRcmgr.EXPECT().OpenConnection(network.DirInbound, gomock.Any(), gomock.Any()).DoAndReturn(func(network.Direction, bool, ma.Multiaddr) (network.ConnManagementScope, error) {
				// Block the connection
				return nil, fmt.Errorf("connections blocked")
			})
			if matched, _ := regexp.MatchString(`^(QUIC|WebTransport)`, tc.Name); matched {
				mockRcmgr.EXPECT().VerifySourceAddress(gomock.Any()).AnyTimes().Return(false)
				// If the initial TLS ClientHello is split into two, quic-go might call the transport multiple times to open a
				// connection. It will only be called multiple times if the connection is rejected; if we were to accept
				// the connection, this would have been called only once.
				mockRcmgr.EXPECT().OpenConnection(network.DirInbound, gomock.Any(), gomock.Any()).Return(nil, errors.New("connection blocked")).AnyTimes()
			} else {
				mockRcmgr.EXPECT().OpenConnection(network.DirInbound, gomock.Any(), gomock.Any()).Return(nil, errors.New("connection blocked"))
			}
			mockRcmgr.EXPECT().Close().AnyTimes()

			server := tc.HostGenerator(t, TransportTestCaseOpts{ResourceManager: mockRcmgr})
@@ -958,6 +1013,10 @@ func TestErrorCodes(t *testing.T) {
	}

	for _, tc := range transportsToTest {
		if strings.HasPrefix(tc.Name, "WebTransport") {
			t.Skipf("skipping: %s, not implemented", tc.Name)
			continue
		}
		t.Run(tc.Name, func(t *testing.T) {
			server := tc.HostGenerator(t, TransportTestCaseOpts{})
			client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
@@ -993,10 +1052,6 @@ func TestErrorCodes(t *testing.T) {
			}

			t.Run("StreamResetWithError", func(t *testing.T) {
				if tc.Name == "WebTransport" {
					t.Skipf("skipping: %s, not implemented", tc.Name)
					return
				}
				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
				defer cancel()
				s, err := client.NewStream(ctx, server.ID(), "/test")
@@ -1019,10 +1074,6 @@ func TestErrorCodes(t *testing.T) {
				})
			})
			t.Run("StreamResetWithErrorByRemote", func(t *testing.T) {
				if tc.Name == "WebTransport" {
					t.Skipf("skipping: %s, not implemented", tc.Name)
					return
				}
				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
				defer cancel()
				s, err := client.NewStream(ctx, server.ID(), "/test")
@@ -1046,7 +1097,7 @@ func TestErrorCodes(t *testing.T) {
			})

			t.Run("StreamResetByConnCloseWithError", func(t *testing.T) {
				if tc.Name == "WebTransport" || tc.Name == "WebRTC" {
				if tc.Name == "WebRTC" {
					t.Skipf("skipping: %s, not implemented", tc.Name)
					return
				}
@@ -1074,7 +1125,7 @@ func TestErrorCodes(t *testing.T) {
			})

			t.Run("NewStreamErrorByConnCloseWithError", func(t *testing.T) {
				if tc.Name == "WebTransport" || tc.Name == "WebRTC" {
				if tc.Name == "WebRTC" {
					t.Skipf("skipping: %s, not implemented", tc.Name)
					return
				}

@@ -580,7 +580,7 @@ func testStatelessReset(t *testing.T, tc *connTestCase) {
	ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1")

	var drop uint32
	dropCallback := func(quicproxy.Direction, []byte) bool { return atomic.LoadUint32(&drop) > 0 }
	dropCallback := func(quicproxy.Direction, net.Addr, net.Addr, []byte) bool { return atomic.LoadUint32(&drop) > 0 }
	proxyConn, cleanup := newUDPConnLocalhost(t, 0)
	proxy := quicproxy.Proxy{
		Conn: proxyConn,

@@ -88,12 +88,21 @@ func (l *listener) wrapConn(qconn quic.Connection) (*conn, error) {
	if err != nil {
		return nil, err
	}

	connScope, err := l.rcmgr.OpenConnection(network.DirInbound, false, remoteMultiaddr)
	connScope, err := network.UnwrapConnManagementScope(qconn.Context())
	if err != nil {
		connScope = nil
		// Don't error here.
		// Set up a scope if we don't have one from quicreuse.
		// This is better than failing, so that users who don't use the quicreuse.ConnContext option
		// with the resource manager still work correctly.
	}
	if connScope == nil {
		connScope, err = l.rcmgr.OpenConnection(network.DirInbound, false, remoteMultiaddr)
		if err != nil {
			log.Debugw("resource manager blocked incoming connection", "addr", qconn.RemoteAddr(), "error", err)
			return nil, err
		}
	}
	c, err := l.wrapConnWithScope(qconn, connScope, remoteMultiaddr)
	if err != nil {
		connScope.Done()

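What makes this fallback work is a small contract between quicreuse and the transports: the quicreuse.ConnContext hook runs at accept time and can attach a ConnManagementScope to the QUIC connection's context, which the listener later unwraps instead of opening a second scope. A hedged sketch of both sides, using the helpers these hunks rely on (toQuicMultiaddr is a hypothetical stand-in for quicreuse.ToQuicMultiaddr):

    // Accept side, installed via quicreuse.ConnContext:
    connCtx := func(ctx context.Context, ci *quic.ClientInfo) (context.Context, error) {
        scope, err := rcmgr.OpenConnection(network.DirInbound, false, toQuicMultiaddr(ci.RemoteAddr))
        if err != nil {
            return ctx, err // returning an error refuses the connection
        }
        return network.WithConnManagementScope(ctx, scope), nil
    }

    // Transport side, when wrapping the accepted connection:
    scope, err := network.UnwrapConnManagementScope(qconn.Context())
    if err != nil || scope == nil {
        // No scope was attached: open one directly, as the hunk above does.
    }
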
@@ -20,6 +20,7 @@ import (
	"github.com/quic-go/quic-go"
	quiclogging "github.com/quic-go/quic-go/logging"
	quicmetrics "github.com/quic-go/quic-go/metrics"
	"golang.org/x/time/rate"
)

type QUICListener interface {
@@ -38,7 +39,7 @@ type QUICTransport interface {
	io.Closer
}

// ConnManager implements using the same listen address for both QUIC & WebTransport, reusing
// ConnManager enables QUIC and WebTransport transports to listen on the same port, reusing
// listen addresses for dialing, and provides a PacketConn for sharing the listen address
// with other protocols like WebRTC.
// Reusing the listen address for dialing helps with address discovery and hole punching. For details
@@ -64,6 +65,9 @@ type ConnManager struct {

	srk         quic.StatelessResetKey
	tokenKey    quic.TokenGeneratorKey
	connContext connContextFunc

	verifySourceAddress func(addr net.Addr) bool
}

type quicListenerEntry struct {
@@ -80,6 +84,11 @@ func defaultSourceIPSelectorFn() (SourceIPSelector, error) {
	return &netrouteSourceIPSelector{routes: r}, err
}

const (
	unverifiedAddressNewConnectionRPS   = 1000
	unverifiedAddressNewConnectionBurst = 1000
)

// NewConnManager returns a new ConnManager
func NewConnManager(statelessResetKey quic.StatelessResetKey, tokenKey quic.TokenGeneratorKey, opts ...Option) (*ConnManager, error) {
	cm := &ConnManager{
@@ -103,9 +112,24 @@ func NewConnManager(statelessResetKey quic.StatelessResetKey, tokenKey quic.Toke

	cm.clientConfig = quicConf
	cm.serverConfig = serverConfig

	// Verify source addresses when under high load.
	// This ensures that the number of spoofed/unverified addresses passed to downstream rate limiters
	// is limited, which enables IP-address-based rate limiting.
	sourceAddrRateLimiter := rate.NewLimiter(unverifiedAddressNewConnectionRPS, unverifiedAddressNewConnectionBurst)
	vsa := cm.verifySourceAddress
	cm.verifySourceAddress = func(addr net.Addr) bool {
		if sourceAddrRateLimiter.Allow() {
			if vsa != nil {
				return vsa(addr)
			}
			return false
		}
		return true
	}
	if cm.enableReuseport {
		cm.reuseUDP4 = newReuse(&statelessResetKey, &tokenKey, cm.listenUDP, cm.sourceIPSelectorFn)
		cm.reuseUDP6 = newReuse(&statelessResetKey, &tokenKey, cm.listenUDP, cm.sourceIPSelectorFn)
		cm.reuseUDP4 = newReuse(&statelessResetKey, &tokenKey, cm.listenUDP, cm.sourceIPSelectorFn, cm.connContext, cm.verifySourceAddress)
		cm.reuseUDP6 = newReuse(&statelessResetKey, &tokenKey, cm.listenUDP, cm.sourceIPSelectorFn, cm.connContext, cm.verifySourceAddress)
	}
	return cm, nil
}
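Read closely, the wrapper implements "verify only under load": while the limiter still has tokens (up to roughly 1000 new unverified connections per second here), the user-supplied policy decides, defaulting to no verification; once the budget is spent, every new connection is forced through QUIC's Retry-based source address verification. A standalone sketch of the same decision, assuming golang.org/x/time/rate as used in this hunk:

    package quicreusesketch

    import (
        "net"

        "golang.org/x/time/rate"
    )

    // verifyUnderLoad returns a VerifySourceAddress-style policy: while the
    // budget of new unverified connections lasts, skip verification (saving a
    // round trip); once it is exhausted, force QUIC's Retry-based verification
    // so spoofed source addresses can't flood downstream per-IP rate limiters.
    func verifyUnderLoad(rps rate.Limit, burst int) func(net.Addr) bool {
        limiter := rate.NewLimiter(rps, burst)
        return func(net.Addr) bool {
            // Allow() consumes one token per new connection attempt.
            return !limiter.Allow()
        }
    }
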
@@ -290,16 +314,7 @@ func (c *ConnManager) transportForListen(association any, network string, laddr
	if err != nil {
		return nil, err
	}
	return &singleOwnerTransport{
		packetConn: conn,
		Transport: &wrappedQUICTransport{
			&quic.Transport{
				Conn:              conn,
				StatelessResetKey: &c.srk,
				TokenGeneratorKey: &c.tokenKey,
			},
		},
	}, nil
	return c.newSingleOwnerTransport(conn), nil
}

type associationKey struct{}
@@ -378,11 +393,24 @@ func (c *ConnManager) TransportWithAssociationForDial(association any, network s
		laddr = &net.UDPAddr{IP: net.IPv6zero, Port: 0}
	}
	conn, err := c.listenUDP(network, laddr)

	if err != nil {
		return nil, err
	}
	return &singleOwnerTransport{Transport: &wrappedQUICTransport{&quic.Transport{Conn: conn, StatelessResetKey: &c.srk}}, packetConn: conn}, nil
	return c.newSingleOwnerTransport(conn), nil
}

func (c *ConnManager) newSingleOwnerTransport(conn net.PacketConn) *singleOwnerTransport {
	return &singleOwnerTransport{
		Transport: &wrappedQUICTransport{
			Transport: newQUICTransport(
				conn,
				&c.tokenKey,
				&c.srk,
				c.connContext,
				c.verifySourceAddress,
			),
		},
		packetConn: conn}
}

// Protocols returns the supported QUIC protocols. The only supported protocol at the moment is /quic-v1.
@@ -414,3 +442,19 @@ var _ QUICTransport = (*wrappedQUICTransport)(nil)
func (t *wrappedQUICTransport) Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error) {
	return t.Transport.Listen(tlsConf, conf)
}

func newQUICTransport(
	conn net.PacketConn,
	tokenGeneratorKey *quic.TokenGeneratorKey,
	statelessResetKey *quic.StatelessResetKey,
	connContext connContextFunc,
	verifySourceAddress func(addr net.Addr) bool,
) *quic.Transport {
	return &quic.Transport{
		Conn:                conn,
		TokenGeneratorKey:   tokenGeneratorKey,
		StatelessResetKey:   statelessResetKey,
		ConnContext:         connContext,
		VerifySourceAddress: verifySourceAddress,
	}
}

@@ -4,6 +4,7 @@ import (
	"context"
	"crypto/rand"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"runtime"
@@ -63,7 +64,7 @@ func testListenOnSameProto(t *testing.T, enableReuseport bool) {

	ln1, err := cm.ListenQUIC(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), &tls.Config{NextProtos: []string{alpn}}, nil)
	require.NoError(t, err)
	defer ln1.Close()
	defer func() { _ = ln1.Close() }()

	addr := ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/%d/quic-v1", ln1.Addr().(*net.UDPAddr).Port))
	_, err = cm.ListenQUIC(addr, &tls.Config{NextProtos: []string{alpn}}, nil)
@@ -72,7 +73,7 @@ func testListenOnSameProto(t *testing.T, enableReuseport bool) {
	// listening on a different address works
	ln2, err := cm.ListenQUIC(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), &tls.Config{NextProtos: []string{alpn}}, nil)
	require.NoError(t, err)
	defer ln2.Close()
	defer func() { _ = ln2.Close() }()
}

// The conn passed to quic-go should be a conn that quic-go can be
@@ -206,7 +207,9 @@ func connectWithProtocol(t *testing.T, addr net.Addr, alpn string) (peer.ID, err
	cconn, err := net.ListenUDP("udp4", nil)
	tlsConf.NextProtos = []string{alpn}
	require.NoError(t, err)
	c, err := quic.Dial(context.Background(), cconn, addr, tlsConf, nil)
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	c, err := quic.Dial(ctx, cconn, addr, tlsConf, nil)
	cancel()
	if err != nil {
		return "", err
	}
@@ -387,3 +390,102 @@ func TestAssociate(t *testing.T) {
		)
	})
}

func TestConnContext(t *testing.T) {
	for _, reuse := range []bool{true, false} {
		t.Run(fmt.Sprintf("reuseport:%t_error", reuse), func(t *testing.T) {
			opts := []Option{
				ConnContext(func(ctx context.Context, _ *quic.ClientInfo) (context.Context, error) {
					return ctx, errors.New("test error")
				})}
			if !reuse {
				opts = append(opts, DisableReuseport())
			}
			cm, err := NewConnManager(
				quic.StatelessResetKey{},
				quic.TokenGeneratorKey{},
				opts...,
			)
			require.NoError(t, err)
			defer func() { _ = cm.Close() }()

			proto1 := "proto1"
			_, proto1TLS := getTLSConfForProto(t, proto1)
			ln1, err := cm.ListenQUIC(
				ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"),
				proto1TLS,
				nil,
			)
			require.NoError(t, err)
			defer ln1.Close()
			proto2 := "proto2"
			_, proto2TLS := getTLSConfForProto(t, proto2)
			ln2, err := cm.ListenQUIC(
				ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/%d/quic-v1", ln1.Addr().(*net.UDPAddr).Port)),
				proto2TLS,
				nil,
			)
			require.NoError(t, err)
			defer ln2.Close()

			_, err = connectWithProtocol(t, ln1.Addr(), proto1)
			require.ErrorContains(t, err, "CONNECTION_REFUSED")

			_, err = connectWithProtocol(t, ln1.Addr(), proto2)
			require.ErrorContains(t, err, "CONNECTION_REFUSED")
		})
		t.Run(fmt.Sprintf("reuseport:%t_success", reuse), func(t *testing.T) {
			type ctxKey struct{}
			opts := []Option{
				ConnContext(func(ctx context.Context, _ *quic.ClientInfo) (context.Context, error) {
					return context.WithValue(ctx, ctxKey{}, "success"), nil
				})}
			if !reuse {
				opts = append(opts, DisableReuseport())
			}
			cm, err := NewConnManager(
				quic.StatelessResetKey{},
				quic.TokenGeneratorKey{},
				opts...,
			)
			require.NoError(t, err)
			defer func() { _ = cm.Close() }()

			proto1 := "proto1"
			_, proto1TLS := getTLSConfForProto(t, proto1)
			ln1, err := cm.ListenQUIC(
				ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"),
				proto1TLS,
				nil,
			)
			require.NoError(t, err)
			defer ln1.Close()

			clientKey, _, err := crypto.GenerateEd25519Key(rand.Reader)
			require.NoError(t, err)
			clientIdentity, err := libp2ptls.NewIdentity(clientKey)
			require.NoError(t, err)
			tlsConf, peerChan := clientIdentity.ConfigForPeer("")
			cconn, err := net.ListenUDP("udp4", nil)
			tlsConf.NextProtos = []string{proto1}
			require.NoError(t, err)
			ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
			conn, err := quic.Dial(ctx, cconn, ln1.Addr(), tlsConf, nil)
			cancel()
			require.NoError(t, err)
			defer conn.CloseWithError(0, "")

			require.Equal(t, proto1, conn.ConnectionState().TLS.NegotiatedProtocol)
			_, err = peer.IDFromPublicKey(<-peerChan)
			require.NoError(t, err)

			acceptCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
			c, err := ln1.Accept(acceptCtx)
			cancel()
			require.NoError(t, err)
			defer c.CloseWithError(0, "")

			require.Equal(t, "success", c.Context().Value(ctxKey{}))
		})
	}
}

@@ -1,9 +1,12 @@
package quicreuse

import (
	"context"
	"errors"
	"net"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/quic-go/quic-go"
)

type Option func(*ConnManager) error
@@ -31,6 +34,27 @@ func DisableReuseport() Option {
	}
}

// ConnContext sets the context for all connections accepted by listeners. This doesn't affect the
// context for dialed connections. To reject a connection, return a non-nil error.
func ConnContext(f func(ctx context.Context, clientInfo *quic.ClientInfo) (context.Context, error)) Option {
	return func(m *ConnManager) error {
		if m.connContext != nil {
			return errors.New("cannot set ConnContext more than once")
		}
		m.connContext = f
		return nil
	}
}

// VerifySourceAddress sets a callback that reports whether to verify the source address of an
// incoming connection request. For more details see: `quic.Transport.VerifySourceAddress`
func VerifySourceAddress(f func(addr net.Addr) bool) Option {
	return func(m *ConnManager) error {
		m.verifySourceAddress = f
		return nil
	}
}

// EnableMetrics enables Prometheus metrics collection. If reg is nil,
// prometheus.DefaultRegisterer will be used as the registerer.
func EnableMetrics(reg prometheus.Registerer) Option {

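Together, the two options give an embedder both hooks on the accept path. A hedged usage sketch (option names as defined above; `blocklisted` and `underAttack` are hypothetical stand-ins for application policy, and the two keys are assumed to be constructed elsewhere):

    cm, err := quicreuse.NewConnManager(
        statelessResetKey,
        tokenGeneratorKey,
        // Attach per-connection state, or reject, before the handshake completes.
        quicreuse.ConnContext(func(ctx context.Context, ci *quic.ClientInfo) (context.Context, error) {
            if blocklisted(ci.RemoteAddr) { // hypothetical policy check
                return ctx, errors.New("connection refused")
            }
            return ctx, nil
        }),
        // Decide per address whether QUIC should verify the source first.
        quicreuse.VerifySourceAddress(func(addr net.Addr) bool {
            return underAttack() // hypothetical load signal
        }),
    )
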
@@ -91,6 +91,8 @@ type refcountedTransport struct {
	assocations map[any]struct{}
}

type connContextFunc = func(context.Context, *quic.ClientInfo) (context.Context, error)

// associate an arbitrary value with this transport.
// This lets us "tag" the refcountedTransport when listening so we can use it
// later for dialing. Necessary for holepunching and learning about our own
@@ -183,9 +185,12 @@ type reuse struct {

	statelessResetKey   *quic.StatelessResetKey
	tokenGeneratorKey   *quic.TokenGeneratorKey
	connContext         connContextFunc
	verifySourceAddress func(addr net.Addr) bool
}

func newReuse(srk *quic.StatelessResetKey, tokenKey *quic.TokenGeneratorKey, listenUDP listenUDP, sourceIPSelectorFn func() (SourceIPSelector, error)) *reuse {
func newReuse(srk *quic.StatelessResetKey, tokenKey *quic.TokenGeneratorKey, listenUDP listenUDP, sourceIPSelectorFn func() (SourceIPSelector, error),
	connContext connContextFunc, verifySourceAddress func(addr net.Addr) bool) *reuse {
	r := &reuse{
		unicast:         make(map[string]map[int]*refcountedTransport),
		globalListeners: make(map[int]*refcountedTransport),
@@ -196,6 +201,8 @@ func newReuse(srk *quic.StatelessResetKey, tokenKey *quic.TokenGeneratorKey, lis
		sourceIPSelectorFn:  sourceIPSelectorFn,
		statelessResetKey:   srk,
		tokenGeneratorKey:   tokenKey,
		connContext:         connContext,
		verifySourceAddress: verifySourceAddress,
	}
	go r.gc()
	return r
@@ -341,16 +348,7 @@ func (r *reuse) transportForDialLocked(association any, network string, source *
	if err != nil {
		return nil, err
	}
	tr := &refcountedTransport{
		QUICTransport: &wrappedQUICTransport{
			Transport: &quic.Transport{
				Conn:              conn,
				StatelessResetKey: r.statelessResetKey,
				TokenGeneratorKey: r.tokenGeneratorKey,
			},
		},
		packetConn: conn,
	}
	tr := r.newTransport(conn)
	r.globalDialers[conn.LocalAddr().(*net.UDPAddr).Port] = tr
	return tr, nil
}
@@ -434,18 +432,10 @@ func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcoun
	if err != nil {
		return nil, err
	}
	localAddr := conn.LocalAddr().(*net.UDPAddr)
	tr := &refcountedTransport{
		QUICTransport: &wrappedQUICTransport{
			Transport: &quic.Transport{
				Conn:              conn,
				StatelessResetKey: r.statelessResetKey,
			},
		},
		packetConn: conn,
	}
	tr := r.newTransport(conn)
	tr.IncreaseCount()

	localAddr := conn.LocalAddr().(*net.UDPAddr)
	// Deal with listen on a global address
	if localAddr.IP.IsUnspecified() {
		// The kernel already checked that the laddr is not already listen
@@ -468,6 +458,21 @@ func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcoun
	return tr, nil
}

func (r *reuse) newTransport(conn net.PacketConn) *refcountedTransport {
	return &refcountedTransport{
		QUICTransport: &wrappedQUICTransport{
			Transport: newQUICTransport(
				conn,
				r.tokenGeneratorKey,
				r.statelessResetKey,
				r.connContext,
				r.verifySourceAddress,
			),
		},
		packetConn: conn,
	}
}

func (r *reuse) Close() error {
	close(r.closeChan)
	<-r.gcStopChan

@@ -61,7 +61,7 @@ func cleanup(t *testing.T, reuse *reuse) {
}

func TestReuseListenOnAllIPv4(t *testing.T) {
	reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
	reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn, nil, nil)
	require.Eventually(t, isGarbageCollectorRunning, 500*time.Millisecond, 50*time.Millisecond, "expected garbage collector to be running")
	cleanup(t, reuse)

@@ -73,7 +73,7 @@ func TestReuseListenOnAllIPv4(t *testing.T) {
}

func TestReuseListenOnAllIPv6(t *testing.T) {
	reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
	reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn, nil, nil)
	require.Eventually(t, isGarbageCollectorRunning, 500*time.Millisecond, 50*time.Millisecond, "expected garbage collector to be running")
	cleanup(t, reuse)

@@ -86,7 +86,7 @@ func TestReuseListenOnAllIPv6(t *testing.T) {
}

func TestReuseCreateNewGlobalConnOnDial(t *testing.T) {
	reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
	reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn, nil, nil)
	cleanup(t, reuse)

	addr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
@@ -100,7 +100,7 @@ func TestReuseCreateNewGlobalConnOnDial(t *testing.T) {
}

func TestReuseConnectionWhenDialing(t *testing.T) {
	reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
	reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn, nil, nil)
	cleanup(t, reuse)

	addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0")
@@ -117,7 +117,7 @@ func TestReuseConnectionWhenDialing(t *testing.T) {
}

func TestReuseConnectionWhenListening(t *testing.T) {
	reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
	reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn, nil, nil)
	cleanup(t, reuse)

	raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
@@ -132,7 +132,7 @@ func TestReuseConnectionWhenListening(t *testing.T) {
}

func TestReuseConnectionWhenDialBeforeListen(t *testing.T) {
	reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
	reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn, nil, nil)
	cleanup(t, reuse)

	// dial any address
@@ -166,7 +166,7 @@ func TestReuseListenOnSpecificInterface(t *testing.T) {
	if platformHasRoutingTables() {
		t.Skip("this test only works on platforms that support routing tables")
	}
	reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
	reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn, nil, nil)
	cleanup(t, reuse)

	router, err := netroute.New()
@@ -203,7 +203,7 @@ func TestReuseGarbageCollect(t *testing.T) {
		maxUnusedDuration = 10 * maxUnusedDuration
	}

	reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
	reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn, nil, nil)
	cleanup(t, reuse)

	numGlobals := func() int {

@@ -799,7 +799,7 @@ func TestConnectionTimeoutOnListener(t *testing.T) {
	proxy := quicproxy.Proxy{
		Conn:       newUDPConnLocalhost(t),
		ServerAddr: ln.Addr().(*net.UDPAddr),
		DropPacket: func(quicproxy.Direction, []byte) bool { return drop.Load() },
		DropPacket: func(_ quicproxy.Direction, _, _ net.Addr, _ []byte) bool { return drop.Load() },
	}
	require.NoError(t, proxy.Start())
	defer proxy.Close()

@@ -72,7 +72,7 @@ func (c *conn) allowWindowIncrease(size uint64) bool {
// garbage collection to properly work in this package.
func (c *conn) Close() error {
	defer c.scope.Done()
	c.transport.removeConn(c.session)
	c.transport.removeConn(c.qconn)
	err := c.session.CloseWithError(0, "")
	_ = c.qconn.CloseWithError(1, "")
	return err

@@ -149,13 +149,22 @@ func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusForbidden)
		return
	}

	connScope, err := l.transport.rcmgr.OpenConnection(network.DirInbound, false, remoteMultiaddr)
	connScope, err := network.UnwrapConnManagementScope(r.Context())
	if err != nil {
		connScope = nil
		// Don't error here.
		// Set up a scope if we don't have one from quicreuse.
		// This is better than failing, so that users who don't use the quicreuse.ConnContext option
		// with the resource manager still work correctly.
	}
	if connScope == nil {
		connScope, err = l.transport.rcmgr.OpenConnection(network.DirInbound, false, remoteMultiaddr)
		if err != nil {
			log.Debugw("resource manager blocked incoming connection", "addr", r.RemoteAddr, "error", err)
			w.WriteHeader(http.StatusServiceUnavailable)
			return
		}
	}
	err = l.httpHandlerWithConnScope(w, r, connScope)
	if err != nil {
		connScope.Done()
@@ -212,7 +221,7 @@ func (l *listener) httpHandlerWithConnScope(w http.ResponseWriter, r *http.Reque
	}

	conn := newConn(l.transport, sess, sconn, connScope, qconn)
	l.transport.addConn(sess, conn)
	l.transport.addConn(qconn, conn)
	select {
	case l.queue <- conn:
	default:

@@ -86,7 +86,7 @@ type transport struct {
	noise *noise.Transport

	connMx sync.Mutex
	conns  map[quic.ConnectionTracingID]*conn // using quic-go's ConnectionTracingKey as map key
	conns  map[quic.Connection]*conn // quic connection -> *conn
	handshakeTimeout time.Duration
}

@@ -113,7 +113,7 @@ func New(key ic.PrivKey, psk pnet.PSK, connManager *quicreuse.ConnManager, gater
		gater:            gater,
		clock:            clock.New(),
		connManager:      connManager,
		conns:            map[quic.ConnectionTracingID]*conn{},
		conns:            map[quic.Connection]*conn{},
		handshakeTimeout: handshakeTimeout,
	}
	for _, opt := range opts {
@@ -184,7 +184,7 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee
		return nil, fmt.Errorf("secured connection gated")
	}
	conn := newConn(t, sess, sconn, scope, qconn)
	t.addConn(sess, conn)
	t.addConn(qconn, conn)
	return conn, nil
}

@@ -361,22 +361,22 @@ func (t *transport) allowWindowIncrease(conn quic.Connection, size uint64) bool
	t.connMx.Lock()
	defer t.connMx.Unlock()

	c, ok := t.conns[conn.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)]
	c, ok := t.conns[conn]
	if !ok {
		return false
	}
	return c.allowWindowIncrease(size)
}

func (t *transport) addConn(sess *webtransport.Session, c *conn) {
func (t *transport) addConn(conn quic.Connection, c *conn) {
	t.connMx.Lock()
	t.conns[sess.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)] = c
	t.conns[conn] = c
	t.connMx.Unlock()
}

func (t *transport) removeConn(sess *webtransport.Session) {
func (t *transport) removeConn(conn quic.Connection) {
	t.connMx.Lock()
	delete(t.conns, sess.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID))
	delete(t.conns, conn)
	t.connMx.Unlock()
}

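Keying the map directly by the quic.Connection interface value works because Go compares interface values by dynamic type and value, and quic-go hands out pointer-typed connections, so every live connection is a distinct, stable map key. A self-contained illustration of that property (simplified types, not the real quic-go API):

    package main

    import "fmt"

    type connIface interface{ ID() int }

    type connImpl struct{ id int }

    func (c *connImpl) ID() int { return c.id }

    func main() {
        m := map[connIface]string{}
        a, b := &connImpl{1}, &connImpl{1}
        m[a] = "a"
        m[b] = "b"
        fmt.Println(len(m)) // 2: distinct pointers are distinct keys
    }
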
@@ -584,7 +584,7 @@ func TestFlowControlWindowIncrease(t *testing.T) {
	proxy := quicproxy.Proxy{
		Conn:        newUDPConnLocalhost(t),
		ServerAddr:  ln.Addr().(*net.UDPAddr),
		DelayPacket: func(quicproxy.Direction, []byte) time.Duration { return rtt / 2 },
		DelayPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) time.Duration { return rtt / 2 },
	}
	require.NoError(t, proxy.Start())
	defer proxy.Close()

@@ -19,7 +19,7 @@ require (
	github.com/francoispqt/gojay v1.2.13 // indirect
	github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
	github.com/google/gopacket v1.1.19 // indirect
	github.com/google/pprof v0.0.0-20250208200701-d0013a598941 // indirect
	github.com/google/pprof v0.0.0-20250501235452-c0086092b71a // indirect
	github.com/google/uuid v1.6.0 // indirect
	github.com/gorilla/websocket v1.5.3 // indirect
	github.com/huin/goupnp v1.3.0 // indirect
@@ -54,7 +54,7 @@ require (
	github.com/multiformats/go-multistream v0.6.0 // indirect
	github.com/multiformats/go-varint v0.0.7 // indirect
	github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
	github.com/onsi/ginkgo/v2 v2.22.2 // indirect
	github.com/onsi/ginkgo/v2 v2.23.4 // indirect
	github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect
	github.com/pion/datachannel v1.5.10 // indirect
	github.com/pion/dtls/v2 v2.2.12 // indirect
@@ -80,24 +80,25 @@ require (
	github.com/prometheus/common v0.62.0 // indirect
	github.com/prometheus/procfs v0.15.1 // indirect
	github.com/quic-go/qpack v0.5.1 // indirect
	github.com/quic-go/quic-go v0.50.0 // indirect
	github.com/quic-go/quic-go v0.52.0 // indirect
	github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66 // indirect
	github.com/spaolacci/murmur3 v1.1.0 // indirect
	github.com/wlynxg/anet v0.0.5 // indirect
	go.uber.org/automaxprocs v1.6.0 // indirect
	go.uber.org/dig v1.18.0 // indirect
	go.uber.org/fx v1.23.0 // indirect
	go.uber.org/mock v0.5.0 // indirect
	go.uber.org/mock v0.5.2 // indirect
	go.uber.org/multierr v1.11.0 // indirect
	go.uber.org/zap v1.27.0 // indirect
	golang.org/x/crypto v0.35.0 // indirect
	golang.org/x/crypto v0.37.0 // indirect
	golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
	golang.org/x/mod v0.23.0 // indirect
	golang.org/x/net v0.35.0 // indirect
	golang.org/x/sync v0.11.0 // indirect
	golang.org/x/sys v0.30.0 // indirect
	golang.org/x/text v0.22.0 // indirect
	golang.org/x/mod v0.24.0 // indirect
	golang.org/x/net v0.39.0 // indirect
	golang.org/x/sync v0.14.0 // indirect
	golang.org/x/sys v0.33.0 // indirect
	golang.org/x/text v0.24.0 // indirect
	golang.org/x/time v0.11.0 // indirect
	golang.org/x/tools v0.30.0 // indirect
	golang.org/x/tools v0.32.0 // indirect
	google.golang.org/protobuf v1.36.5 // indirect
	lukechampine.com/blake3 v1.4.0 // indirect
)

@@ -59,16 +59,16 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ=
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20250208200701-d0013a598941 h1:43XjGa6toxLpeksjcxs1jIoIyr+vUfOqY2c6HB4bpoc=
github.com/google/pprof v0.0.0-20250208200701-d0013a598941/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144=
github.com/google/pprof v0.0.0-20250501235452-c0086092b71a h1:rDA3FfmxwXR+BVKKdz55WwMJ1pD2hJQNW31d+l3mPk4=
github.com/google/pprof v0.0.0-20250501235452-c0086092b71a/go.mod h1:5hDyRhoBCxViHszMt12TnOpEI4VVi+U8Gm9iphldiMA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY=
@@ -179,10 +179,10 @@ github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/ginkgo/v2 v2.22.2 h1:/3X8Panh8/WwhU/3Ssa6rCKqPLuAkVY2I0RoyDLySlU=
github.com/onsi/ginkgo/v2 v2.22.2/go.mod h1:oeMosUL+8LtarXBHu/c0bx2D/K9zyQ6uX3cTyztHwsk=
github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8=
github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY=
github.com/onsi/ginkgo/v2 v2.23.4 h1:ktYTpKJAVZnDT4VjxSbiBenUjmlL/5QkBEocaWXiQus=
github.com/onsi/ginkgo/v2 v2.23.4/go.mod h1:Bt66ApGPBFzHyR+JO10Zbt0Gsp4uWxu5mIOTusL46e8=
github.com/onsi/gomega v1.36.3 h1:hID7cr8t3Wp26+cYnfcjR6HpJ00fdogN6dqZ1t6IylU=
github.com/onsi/gomega v1.36.3/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0=
github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8=
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0=
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y=
@@ -231,6 +231,8 @@ github.com/pion/webrtc/v4 v4.0.14/go.mod h1:R3+qTnQTS03UzwDarYecgioNf7DYgTsldxnC
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v1.21.0 h1:DIsaGmiaBkSangBgMtWdNfxbMNdku5IK6iNhrEqWvdA=
github.com/prometheus/client_golang v1.21.0/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
@@ -245,8 +247,8 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
github.com/quic-go/quic-go v0.50.0 h1:3H/ld1pa3CYhkcc20TPIyG1bNsdhn9qZBGN3b9/UyUo=
github.com/quic-go/quic-go v0.50.0/go.mod h1:Vim6OmUvlYdwBhXP9ZVrtGmCMWa3wEqhq3NgYrI8b4E=
github.com/quic-go/quic-go v0.52.0 h1:/SlHrCRElyaU6MaEPKqKr9z83sBg2v4FLLvWM+Z47pA=
github.com/quic-go/quic-go v0.52.0/go.mod h1:MFlGGpcpJqRAfmYi6NC2cptDPSxRWTOGNuP4wqrWmzQ=
github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66 h1:4WFk6u3sOT6pLa1kQ50ZVdm8BQFgJNA117cepZxtLIg=
github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66/go.mod h1:Vp72IJajgeOL6ddqrAhmp7IM9zbTcgkQxD/YdxrVwMw=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
@@ -301,6 +303,8 @@ github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
go.uber.org/dig v1.18.0 h1:imUL1UiY0Mg4bqbFfsRQO5G4CGRBec/ZujWTvSVp3pw=
go.uber.org/dig v1.18.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE=
go.uber.org/fx v1.23.0 h1:lIr/gYWQGfTwGcSXWXu4vP5Ws6iqnNEIY+F/aFzCKTg=
@@ -308,8 +312,8 @@ go.uber.org/fx v1.23.0/go.mod h1:o/D9n+2mLP6v1EG+qsdT1O8wKopYAsqZasju97SDFCU=
go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
@@ -329,8 +333,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
@@ -343,8 +347,8 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -365,8 +369,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
@@ -380,8 +384,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -402,8 +406,8 @@ golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -419,8 +423,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
@@ -436,8 +440,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
golang.org/x/tools v0.32.0 h1:Q7N1vhpkQv7ybVzLFtTjvQya2ewbwNDZzUgfXGqtMWU=
golang.org/x/tools v0.32.0/go.mod h1:ZxrU41P/wAbZD8EDa6dDCa6XfpkhJ7HFMjHJXfBDu8s=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

@@ -9,7 +9,6 @@ import (
	"time"

	"github.com/libp2p/go-libp2p/core/network"
	ma "github.com/multiformats/go-multiaddr"
	manet "github.com/multiformats/go-multiaddr/net"
	"golang.org/x/time/rate"
)
@@ -62,6 +61,8 @@ func (r *Limiter) init() {
	} else {
		r.globalBucket = rate.NewLimiter(rate.Limit(r.GlobalLimit.RPS), r.GlobalLimit.Burst)
	}
	// clone the slice in case it's shared with other limiters
	r.NetworkPrefixLimits = slices.Clone(r.NetworkPrefixLimits)
	// sort such that the widest prefix (smallest bit count) is last.
	slices.SortFunc(r.NetworkPrefixLimits, func(a, b PrefixLimit) int { return b.Prefix.Bits() - a.Prefix.Bits() })
	r.networkPrefixBuckets = make([]*rate.Limiter, 0, len(r.NetworkPrefixLimits))
@@ -79,7 +80,16 @@ func (r *Limiter) init() {
func (r *Limiter) Limit(f func(s network.Stream)) func(s network.Stream) {
	r.init()
	return func(s network.Stream) {
		if !r.allow(s.Conn().RemoteMultiaddr()) {
		addr := s.Conn().RemoteMultiaddr()
		ip, err := manet.ToIP(addr)
		if err != nil {
			ip = nil
		}
		ipAddr, ok := netip.AddrFromSlice(ip)
		if !ok {
			ipAddr = netip.Addr{}
		}
		if !r.Allow(ipAddr) {
			_ = s.ResetWithError(network.StreamRateLimited)
			return
		}
@@ -87,7 +97,8 @@ func (r *Limiter) Limit(f func(s network.Stream)) func(s network.Stream) {
	}
}

func (r *Limiter) allow(addr ma.Multiaddr) bool {
// Allow returns true if requests for `ipAddr` are within the specified rate limits
func (r *Limiter) Allow(ipAddr netip.Addr) bool {
	r.init()
	// Check buckets from the most specific to the least.
	//
@@ -97,14 +108,6 @@ func (r *Limiter) allow(addr ma.Multiaddr) bool {
	// bucket before the specific bucket, and the specific bucket rejected the
	// request, there's no way to return the token to the global bucket. So all
	// rejected requests from the specific bucket would take up tokens from the global bucket.
	ip, err := manet.ToIP(addr)
	if err != nil {
		return r.globalBucket.Allow()
	}
	ipAddr, ok := netip.AddrFromSlice(ip)
	if !ok {
		return r.globalBucket.Allow()
	}

	// prefixes have been sorted from most to least specific so rejected requests for more
	// specific prefixes don't take up tokens from the less specific prefixes.
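Exporting Allow turns the limiter into a reusable building block for any IP-keyed rate limiting, not just the stream-handler wrapper above. A minimal usage sketch, assuming this package is imported as rate (github.com/libp2p/go-libp2p/x/rate) alongside net/netip:

    limiter := &rate.Limiter{
        GlobalLimit: rate.Limit{RPS: 100, Burst: 200},
        NetworkPrefixLimits: []rate.PrefixLimit{
            // Give loopback its own budget; a zero Limit means no rate limiting.
            {Prefix: netip.MustParsePrefix("127.0.0.0/8"), Limit: rate.Limit{}},
        },
    }
    if !limiter.Allow(netip.MustParseAddr("1.1.1.1")) {
        // Over budget for this address: reject the request.
    }
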
@@ -6,7 +6,6 @@ import (
	"testing"
	"time"

	ma "github.com/multiformats/go-multiaddr"
	"github.com/stretchr/testify/require"
	"golang.org/x/time/rate"
)
@@ -25,19 +24,19 @@ func getSleepDurationAndRequestCount(rps float64) (time.Duration, int) {
	return sleepDuration, requestCount
}

func assertLimiter(t *testing.T, rl *Limiter, addr ma.Multiaddr, allowed, errorMargin int) {
func assertLimiter(t *testing.T, rl *Limiter, ipAddr netip.Addr, allowed, errorMargin int) {
	t.Helper()
	for i := 0; i < allowed; i++ {
		require.True(t, rl.allow(addr))
		require.True(t, rl.Allow(ipAddr))
	}
	for i := 0; i < errorMargin; i++ {
		rl.allow(addr)
		rl.Allow(ipAddr)
	}
	require.False(t, rl.allow(addr))
	require.False(t, rl.Allow(ipAddr))
}

func TestLimiterGlobal(t *testing.T) {
	addr := ma.StringCast("/ip4/127.0.0.1/udp/123/quic-v1")
	addr := netip.MustParseAddr("127.0.0.1")
	limits := []Limit{
		{RPS: 0.0, Burst: 1},
		{RPS: 0.8, Burst: 1},
@@ -53,7 +52,7 @@ func TestLimiterGlobal(t *testing.T) {
			if limit.RPS == 0 {
				// 0 implies no rate limiting, any large number would do
				for i := 0; i < 1000; i++ {
					require.True(t, rl.allow(addr))
					require.True(t, rl.Allow(addr))
				}
				return
			}
@@ -66,8 +65,8 @@ func TestLimiterGlobal(t *testing.T) {
}

func TestLimiterNetworkPrefix(t *testing.T) {
	local := ma.StringCast("/ip4/127.0.0.1/udp/123/quic-v1")
	public := ma.StringCast("/ip4/1.1.1.1/udp/123/quic-v1")
	local := netip.MustParseAddr("127.0.0.1")
	public := netip.MustParseAddr("1.1.1.1")
	rl := &Limiter{
		NetworkPrefixLimits: []PrefixLimit{
			{Prefix: netip.MustParsePrefix("127.0.0.0/24"), Limit: Limit{}},
@@ -76,22 +75,22 @@ func TestLimiterNetworkPrefix(t *testing.T) {
	}
	// element within prefix is allowed even over the limit
	for range rl.GlobalLimit.Burst + 100 {
		require.True(t, rl.allow(local))
		require.True(t, rl.Allow(local))
	}
	// rate limit public ips
	assertLimiter(t, rl, public, rl.GlobalLimit.Burst, int(rl.GlobalLimit.RPS*rateLimitErrorTolerance))

	// public ip rejected
	require.False(t, rl.allow(public))
	require.False(t, rl.Allow(public))
	// local ip accepted
	for range 100 {
		require.True(t, rl.allow(local))
		require.True(t, rl.Allow(local))
	}
}

func TestLimiterNetworkPrefixWidth(t *testing.T) {
	a1 := ma.StringCast("/ip4/1.1.1.1/udp/123/quic-v1")
	a2 := ma.StringCast("/ip4/1.1.0.1/udp/123/quic-v1")
	a1 := netip.MustParseAddr("1.1.1.1")
	a2 := netip.MustParseAddr("1.1.0.1")

	wideLimit := 20
	narrowLimit := 10
@@ -102,13 +101,13 @@ func TestLimiterNetworkPrefixWidth(t *testing.T) {
		},
	}
	for range 2 * wideLimit {
		rl.allow(a1)
		rl.Allow(a1)
	}
	// a1 rejected
	require.False(t, rl.allow(a1))
	require.False(t, rl.Allow(a1))
	// a2 accepted
	for range wideLimit - narrowLimit {
		require.True(t, rl.allow(a2))
		require.True(t, rl.Allow(a2))
	}
}