transport: add GatedMaListener type (#3186)

This introduces a new GatedMaListener type which gates conns
accepted from a manet.Listener with a gater and creates the rcmgr
scope for it. Explicitly passing the scope allows for many guardrails
that the previous interface assertion didn't.

This breaks the previous responsibility of the upgradeListener method
into two, one gating the connection initially, and the other upgrading
the connection with a security and muxer selection.

This split makes it easy to gate the connection with the resource
manager as early as possible. This is especially true for websocket
because we want to gate the connection just after the TCP connection is
established, and not after the tls handshake + websocket upgrade is
completed.
This commit is contained in:
sukun
2025-03-25 22:09:57 +05:30
committed by GitHub
parent 8430ad3e2f
commit 6249e685e9
19 changed files with 602 additions and 317 deletions

View File

@@ -292,11 +292,11 @@ func (cfg *Config) addTransports() ([]fx.Option, error) {
fx.Provide(func() connmgr.ConnectionGater { return cfg.ConnectionGater }), fx.Provide(func() connmgr.ConnectionGater { return cfg.ConnectionGater }),
fx.Provide(func() pnet.PSK { return cfg.PSK }), fx.Provide(func() pnet.PSK { return cfg.PSK }),
fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }), fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }),
fx.Provide(func(gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *tcpreuse.ConnMgr { fx.Provide(func(upgrader transport.Upgrader) *tcpreuse.ConnMgr {
if !cfg.ShareTCPListener { if !cfg.ShareTCPListener {
return nil return nil
} }
return tcpreuse.NewConnMgr(tcpreuse.EnvReuseportVal, gater, rcmgr) return tcpreuse.NewConnMgr(tcpreuse.EnvReuseportVal, upgrader)
}), }),
fx.Provide(func(cm *quicreuse.ConnManager, sw *swarm.Swarm) libp2pwebrtc.ListenUDPFn { fx.Provide(func(cm *quicreuse.ConnManager, sw *swarm.Swarm) libp2pwebrtc.ListenUDPFn {
hasQuicAddrPortFor := func(network string, laddr *net.UDPAddr) bool { hasQuicAddrPortFor := func(network string, laddr *net.UDPAddr) bool {

View File

@@ -129,11 +129,41 @@ type TransportNetwork interface {
AddTransport(t Transport) error AddTransport(t Transport) error
} }
// GatedMaListener is listener that listens for raw(unsecured and non-multiplexed) incoming connections,
// gates them with a `connmgr.ConnGater`and creates a resource management scope for them.
// It can be upgraded to a full libp2p transport listener by the Upgrader.
//
// Compared to manet.Listener, this listener creates the resource management scope for the accepted connection.
type GatedMaListener interface {
// Accept waits for and returns the next connection to the listener.
Accept() (manet.Conn, network.ConnManagementScope, error)
// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
Close() error
// Multiaddr returns the listener's (local) Multiaddr.
Multiaddr() ma.Multiaddr
// Addr returns the net.Listener's network address.
Addr() net.Addr
}
// Upgrader is a multistream upgrader that can upgrade an underlying connection // Upgrader is a multistream upgrader that can upgrade an underlying connection
// to a full transport connection (secure and multiplexed). // to a full transport connection (secure and multiplexed).
type Upgrader interface { type Upgrader interface {
// UpgradeListener upgrades the passed multiaddr-net listener into a full libp2p-transport listener. // UpgradeListener upgrades the passed multiaddr-net listener into a full libp2p-transport listener.
//
// Deprecated: Use UpgradeGatedMaListener(upgrader.GateMaListener(manet.Listener)) instead.
UpgradeListener(Transport, manet.Listener) Listener UpgradeListener(Transport, manet.Listener) Listener
// GateMaListener creates a GatedMaListener from a manet.Listener. It gates the accepted connection
// and creates a resource scope for it.
GateMaListener(manet.Listener) GatedMaListener
// UpgradeGatedMaListener upgrades the passed GatedMaListener into a full libp2p-transport listener.
UpgradeGatedMaListener(Transport, GatedMaListener) Listener
// Upgrade upgrades the multiaddr/net connection into a full libp2p-transport connection. // Upgrade upgrades the multiaddr/net connection into a full libp2p-transport connection.
Upgrade(ctx context.Context, t Transport, maconn manet.Conn, dir network.Direction, p peer.ID, scope network.ConnManagementScope) (CapableConn, error) Upgrade(ctx context.Context, t Transport, maconn manet.Conn, dir network.Direction, p peer.ID, scope network.ConnManagementScope) (CapableConn, error)
} }

View File

@@ -6,6 +6,7 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/core/transport"
@@ -17,7 +18,7 @@ import (
var log = logging.Logger("upgrader") var log = logging.Logger("upgrader")
type listener struct { type listener struct {
manet.Listener transport.GatedMaListener
transport transport.Transport transport transport.Transport
upgrader *upgrader upgrader *upgrader
@@ -35,10 +36,12 @@ type listener struct {
cancel func() cancel func()
} }
var _ transport.Listener = (*listener)(nil)
// Close closes the listener. // Close closes the listener.
func (l *listener) Close() error { func (l *listener) Close() error {
// Do this first to try to get any relevant errors. // Do this first to try to get any relevant errors.
err := l.Listener.Close() err := l.GatedMaListener.Close()
l.cancel() l.cancel()
// Drain and wait. // Drain and wait.
@@ -61,7 +64,7 @@ func (l *listener) handleIncoming() {
var wg sync.WaitGroup var wg sync.WaitGroup
defer func() { defer func() {
// make sure we're closed // make sure we're closed
l.Listener.Close() l.GatedMaListener.Close()
if l.err == nil { if l.err == nil {
l.err = fmt.Errorf("listener closed") l.err = fmt.Errorf("listener closed")
} }
@@ -72,7 +75,7 @@ func (l *listener) handleIncoming() {
var catcher tec.TempErrCatcher var catcher tec.TempErrCatcher
for l.ctx.Err() == nil { for l.ctx.Err() == nil {
maconn, err := l.Listener.Accept() maconn, connScope, err := l.GatedMaListener.Accept()
if err != nil { if err != nil {
// Note: function may pause the accept loop. // Note: function may pause the accept loop.
if catcher.IsTemporary(err) { if catcher.IsTemporary(err) {
@@ -84,33 +87,10 @@ func (l *listener) handleIncoming() {
} }
catcher.Reset() catcher.Reset()
// Check if we already have a connection scope. See the comment in tcpreuse/listener.go for an explanation.
var connScope network.ConnManagementScope
if sc, ok := maconn.(interface {
Scope() network.ConnManagementScope
}); ok {
connScope = sc.Scope()
}
if connScope == nil { if connScope == nil {
// gate the connection if applicable log.Errorf("BUG: got nil connScope for incoming connection from %s", maconn.RemoteMultiaddr())
if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) { maconn.Close()
log.Debugf("gater blocked incoming connection on local addr %s from %s", continue
maconn.LocalMultiaddr(), maconn.RemoteMultiaddr())
if err := maconn.Close(); err != nil {
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
}
continue
}
var err error
connScope, err = l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr())
if err != nil {
log.Debugw("resource manager blocked accept of new connection", "error", err)
if err := maconn.Close(); err != nil {
log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err)
}
continue
}
} }
// The go routine below calls Release when the context is // The go routine below calls Release when the context is
@@ -154,14 +134,10 @@ func (l *listener) handleIncoming() {
select { select {
case l.incoming <- conn: case l.incoming <- conn:
case <-ctx.Done(): case <-ctx.Done():
// Listener not closed but the accept timeout expired.
if l.ctx.Err() == nil { if l.ctx.Err() == nil {
// Listener *not* closed but the accept timeout expired. log.Warnf("listener dropped connection due to slow accept. remote addr: %s peer: %s", maconn.RemoteMultiaddr(), conn.RemotePeer())
log.Warn("listener dropped connection due to slow accept")
} }
// Wait on the context with a timeout. This way,
// if we stop accepting connections for some reason,
// we'll eventually close all the open ones
// instead of hanging onto them.
conn.CloseWithError(network.ConnRateLimited) conn.CloseWithError(network.ConnRateLimited)
} }
}() }()
@@ -189,4 +165,38 @@ func (l *listener) String() string {
return fmt.Sprintf("<stream.Listener %s>", l.Multiaddr()) return fmt.Sprintf("<stream.Listener %s>", l.Multiaddr())
} }
var _ transport.Listener = (*listener)(nil) type gatedMaListener struct {
manet.Listener
rcmgr network.ResourceManager
connGater connmgr.ConnectionGater
}
var _ transport.GatedMaListener = &gatedMaListener{}
func (l *gatedMaListener) Accept() (manet.Conn, network.ConnManagementScope, error) {
for {
conn, err := l.Listener.Accept()
if err != nil {
return nil, nil, err
}
// gate the connection if applicable
if l.connGater != nil && !l.connGater.InterceptAccept(conn) {
log.Debugf("gater blocked incoming connection on local addr %s from %s",
conn.LocalMultiaddr(), conn.RemoteMultiaddr())
if err := conn.Close(); err != nil {
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
}
continue
}
connScope, err := l.rcmgr.OpenConnection(network.DirInbound, true, conn.RemoteMultiaddr())
if err != nil {
log.Debugw("resource manager blocked accept of new connection", "error", err)
if err := conn.Close(); err != nil {
log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err)
}
continue
}
return conn, connScope, nil
}
}

View File

@@ -30,7 +30,7 @@ func createListener(t *testing.T, u transport.Upgrader) transport.Listener {
require.NoError(t, err) require.NoError(t, err)
ln, err := manet.Listen(addr) ln, err := manet.Listen(addr)
require.NoError(t, err) require.NoError(t, err)
return u.UpgradeListener(nil, ln) return u.UpgradeGatedMaListener(nil, u.GateMaListener(ln))
} }
func TestAcceptSingleConn(t *testing.T) { func TestAcceptSingleConn(t *testing.T) {

View File

@@ -105,19 +105,32 @@ func New(security []sec.SecureTransport, muxers []StreamMuxer, psk ipnet.PSK, rc
// UpgradeListener upgrades the passed multiaddr-net listener into a full libp2p-transport listener. // UpgradeListener upgrades the passed multiaddr-net listener into a full libp2p-transport listener.
func (u *upgrader) UpgradeListener(t transport.Transport, list manet.Listener) transport.Listener { func (u *upgrader) UpgradeListener(t transport.Transport, list manet.Listener) transport.Listener {
ctx, cancel := context.WithCancel(context.Background()) return u.UpgradeGatedMaListener(t, u.GateMaListener(list))
l := &listener{ }
Listener: list,
upgrader: u, func (u *upgrader) GateMaListener(l manet.Listener) transport.GatedMaListener {
transport: t, return &gatedMaListener{
Listener: l,
rcmgr: u.rcmgr, rcmgr: u.rcmgr,
threshold: newThreshold(AcceptQueueLength), connGater: u.connGater,
incoming: make(chan transport.CapableConn),
cancel: cancel,
ctx: ctx,
} }
go l.handleIncoming() }
return l
// UpgradeGatedMaListener upgrades the passed multiaddr-net listener into a full libp2p-transport listener.
func (u *upgrader) UpgradeGatedMaListener(t transport.Transport, l transport.GatedMaListener) transport.Listener {
ctx, cancel := context.WithCancel(context.Background())
list := &listener{
GatedMaListener: l,
upgrader: u,
transport: t,
rcmgr: u.rcmgr,
threshold: newThreshold(AcceptQueueLength),
incoming: make(chan transport.CapableConn),
cancel: cancel,
ctx: ctx,
}
go list.handleIncoming()
return list
} }
// Upgrade upgrades the multiaddr/net connection into a full libp2p-transport connection. // Upgrade upgrades the multiaddr/net connection into a full libp2p-transport connection.

View File

@@ -101,7 +101,7 @@ func (c *Client) Listen(addr ma.Multiaddr) (transport.Listener, error) {
return nil, err return nil, err
} }
return c.upgrader.UpgradeListener(c, c.Listener()), nil return c.upgrader.UpgradeGatedMaListener(c, c.upgrader.GateMaListener(c.Listener())), nil
} }
func (c *Client) Protocols() []int { func (c *Client) Protocols() []int {

View File

@@ -203,7 +203,7 @@ func TestInterceptAccept(t *testing.T) {
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr())) require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr()))
}).AnyTimes() }).AnyTimes()
} else if strings.Contains(tc.Name, "WebSocket-Shared") || strings.Contains(tc.Name, "WebSocket-Secured-Shared") { } else if strings.Contains(tc.Name, "WebSocket") {
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
require.Equal(t, addrPort(h2.Addrs()[0]), addrPort(addrs.LocalMultiaddr())) require.Equal(t, addrPort(h2.Addrs()[0]), addrPort(addrs.LocalMultiaddr()))
}) })

View File

@@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport"
"github.com/marten-seemann/tcp" "github.com/marten-seemann/tcp"
"github.com/mikioh/tcpinfo" "github.com/mikioh/tcpinfo"
manet "github.com/multiformats/go-multiaddr/net" manet "github.com/multiformats/go-multiaddr/net"
@@ -253,16 +254,6 @@ func (c *tracingConn) Close() error {
return c.closeErr return c.closeErr
} }
func (c *tracingConn) Scope() network.ConnManagementScope {
if cs, ok := c.Conn.(interface {
Scope() network.ConnManagementScope
}); ok {
return cs.Scope()
}
// upgrader is expected to handle this
return nil
}
func (c *tracingConn) getTCPInfo() (*tcpinfo.Info, error) { func (c *tracingConn) getTCPInfo() (*tcpinfo.Info, error) {
var o tcpinfo.Info var o tcpinfo.Info
var b [256]byte var b [256]byte
@@ -275,19 +266,31 @@ func (c *tracingConn) getTCPInfo() (*tcpinfo.Info, error) {
} }
type tracingListener struct { type tracingListener struct {
manet.Listener transport.GatedMaListener
collector *aggregatingCollector collector *aggregatingCollector
} }
// newTracingListener wraps a manet.Listener with a tracingListener. A nil collector will use the default collector. // newTracingListener wraps a manet.Listener with a tracingListener. A nil collector will use the default collector.
func newTracingListener(l manet.Listener, collector *aggregatingCollector) *tracingListener { func newTracingListener(l transport.GatedMaListener, collector *aggregatingCollector) *tracingListener {
return &tracingListener{Listener: l, collector: collector} return &tracingListener{GatedMaListener: l, collector: collector}
} }
func (l *tracingListener) Accept() (manet.Conn, error) { func (l *tracingListener) Accept() (manet.Conn, network.ConnManagementScope, error) {
conn, err := l.Listener.Accept() conn, scope, err := l.GatedMaListener.Accept()
if err != nil { if err != nil {
return nil, err if scope != nil {
scope.Done()
log.Errorf("BUG: got non-nil scope but also an error: %s", err)
}
return nil, nil, err
} }
return newTracingConn(conn, l.collector, false)
tc, err := newTracingConn(conn, l.collector, false)
if err != nil {
log.Errorf("failed to create tracingConn from %T: %s", conn, err)
conn.Close()
scope.Done()
return nil, nil, err
}
return tc, scope, nil
} }

View File

@@ -4,11 +4,16 @@
package tcp package tcp
import manet "github.com/multiformats/go-multiaddr/net" import (
"github.com/libp2p/go-libp2p/core/transport"
manet "github.com/multiformats/go-multiaddr/net"
)
type aggregatingCollector struct{} type aggregatingCollector struct{}
func newTracingConn(c manet.Conn, collector *aggregatingCollector, isClient bool) (manet.Conn, error) { func newTracingConn(c manet.Conn, collector *aggregatingCollector, isClient bool) (manet.Conn, error) {
return c, nil return c, nil
} }
func newTracingListener(l manet.Listener, collector *aggregatingCollector) manet.Listener { return l } func newTracingListener(l transport.GatedMaListener, collector *aggregatingCollector) transport.GatedMaListener {
return l
}

View File

@@ -15,8 +15,10 @@ func TestTcpTransportCollectsMetricsWithSharedTcpSocket(t *testing.T) {
peerA, ia := makeInsecureMuxer(t) peerA, ia := makeInsecureMuxer(t)
_, ib := makeInsecureMuxer(t) _, ib := makeInsecureMuxer(t)
sharedTCPSocketA := tcpreuse.NewConnMgr(false, nil, nil) upg, err := tptu.New(ia, muxers, nil, nil, nil)
sharedTCPSocketB := tcpreuse.NewConnMgr(false, nil, nil) require.NoError(t, err)
sharedTCPSocketA := tcpreuse.NewConnMgr(false, upg)
sharedTCPSocketB := tcpreuse.NewConnMgr(false, upg)
ua, err := tptu.New(ia, muxers, nil, nil, nil) ua, err := tptu.New(ia, muxers, nil, nil, nil)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -41,7 +41,7 @@ var ReuseportIsAvailable = tcpreuse.ReuseportIsAvailable
func tryKeepAlive(conn net.Conn, keepAlive bool) { func tryKeepAlive(conn net.Conn, keepAlive bool) {
keepAliveConn, ok := conn.(canKeepAlive) keepAliveConn, ok := conn.(canKeepAlive)
if !ok { if !ok {
log.Errorf("Can't set TCP keepalives.") log.Errorf("can't set TCP keepalives. net.Conn of type %T doesn't support SetKeepAlive", conn)
return return
} }
if err := keepAliveConn.SetKeepAlive(keepAlive); err != nil { if err := keepAliveConn.SetKeepAlive(keepAlive); err != nil {
@@ -76,23 +76,23 @@ func tryLinger(conn net.Conn, sec int) {
} }
} }
type tcpListener struct { type tcpGatedMaListener struct {
manet.Listener transport.GatedMaListener
sec int sec int
} }
func (ll *tcpListener) Accept() (manet.Conn, error) { func (ll *tcpGatedMaListener) Accept() (manet.Conn, network.ConnManagementScope, error) {
c, err := ll.Listener.Accept() c, scope, err := ll.GatedMaListener.Accept()
if err != nil { if err != nil {
return nil, err if scope != nil {
log.Errorf("BUG: got non-nil scope but also an error: %s", err)
scope.Done()
}
return nil, nil, err
} }
tryLinger(c, ll.sec) tryLinger(c, ll.sec)
tryKeepAlive(c, true) tryKeepAlive(c, true)
// We're not calling OpenConnection in the resource manager here, return c, scope, nil
// since the manet.Conn doesn't allow us to save the scope.
// It's the caller's (usually the p2p/net/upgrader) responsibility
// to call the resource manager.
return c, nil
} }
type Option func(*TcpTransport) error type Option func(*TcpTransport) error
@@ -316,22 +316,26 @@ func (t *TcpTransport) unsharedMAListen(laddr ma.Multiaddr) (manet.Listener, err
// Listen listens on the given multiaddr. // Listen listens on the given multiaddr.
func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) { func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) {
var list manet.Listener var list transport.GatedMaListener
var err error var err error
if t.sharedTcp != nil {
if t.sharedTcp == nil {
list, err = t.unsharedMAListen(laddr)
} else {
list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.DemultiplexedConnType_MultistreamSelect) list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.DemultiplexedConnType_MultistreamSelect)
} if err != nil {
if err != nil { return nil, err
return nil, err }
} else {
mal, err := t.unsharedMAListen(laddr)
if err != nil {
return nil, err
}
list = t.upgrader.GateMaListener(mal)
} }
if t.enableMetrics { if t.enableMetrics {
list = newTracingListener(&tcpListener{list, 0}, t.metricsCollector) // TODO: Fix this: The tcpListener wrapping should happen on both enableMetrics and disabledMetrics path
list = newTracingListener(&tcpGatedMaListener{list, 0}, t.metricsCollector)
} }
return t.upgrader.UpgradeListener(t, list), nil return t.upgrader.UpgradeGatedMaListener(t, list), nil
} }
// Protocols returns the list of terminal protocols this transport can dial. // Protocols returns the list of terminal protocols this transport can dial.

View File

@@ -10,19 +10,15 @@ import (
type connWithScope struct { type connWithScope struct {
sampledconn.ManetTCPConnInterface sampledconn.ManetTCPConnInterface
scope network.ConnManagementScope ConnScope network.ConnManagementScope
}
func (c connWithScope) Scope() network.ConnManagementScope {
return c.scope
} }
func (c *connWithScope) Close() error { func (c *connWithScope) Close() error {
c.scope.Done() defer c.ConnScope.Done()
return c.ManetTCPConnInterface.Close() return c.ManetTCPConnInterface.Close()
} }
func manetConnWithScope(c manet.Conn, scope network.ConnManagementScope) (manet.Conn, error) { func manetConnWithScope(c manet.Conn, scope network.ConnManagementScope) (*connWithScope, error) {
if tcpconn, ok := c.(sampledconn.ManetTCPConnInterface); ok { if tcpconn, ok := c.(sampledconn.ManetTCPConnInterface); ok {
return &connWithScope{tcpconn, scope}, nil return &connWithScope{tcpconn, scope}, nil
} }

View File

@@ -9,7 +9,6 @@ import (
"time" "time"
logging "github.com/ipfs/go-log/v2" logging "github.com/ipfs/go-log/v2"
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/net/reuseport" "github.com/libp2p/go-libp2p/p2p/net/reuseport"
@@ -28,32 +27,36 @@ var log = logging.Logger("tcp-demultiplex")
type ConnMgr struct { type ConnMgr struct {
enableReuseport bool enableReuseport bool
reuse reuseport.Transport reuse reuseport.Transport
connGater connmgr.ConnectionGater upgrader transport.Upgrader
rcmgr network.ResourceManager
mx sync.Mutex mx sync.Mutex
listeners map[string]*multiplexedListener listeners map[string]*multiplexedListener
} }
func NewConnMgr(enableReuseport bool, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *ConnMgr { func NewConnMgr(enableReuseport bool, upgrader transport.Upgrader) *ConnMgr {
if rcmgr == nil {
rcmgr = &network.NullResourceManager{}
}
return &ConnMgr{ return &ConnMgr{
enableReuseport: enableReuseport, enableReuseport: enableReuseport,
reuse: reuseport.Transport{}, reuse: reuseport.Transport{},
connGater: gater, upgrader: upgrader,
rcmgr: rcmgr,
listeners: make(map[string]*multiplexedListener), listeners: make(map[string]*multiplexedListener),
} }
} }
func (t *ConnMgr) maListen(listenAddr ma.Multiaddr) (manet.Listener, error) { func (t *ConnMgr) gatedMaListen(listenAddr ma.Multiaddr) (transport.GatedMaListener, error) {
var mal manet.Listener
var err error
if t.useReuseport() { if t.useReuseport() {
return t.reuse.Listen(listenAddr) mal, err = t.reuse.Listen(listenAddr)
if err != nil {
return nil, err
}
} else { } else {
return manet.Listen(listenAddr) mal, err = manet.Listen(listenAddr)
if err != nil {
return nil, err
}
} }
return t.upgrader.GateMaListener(mal), nil
} }
func (t *ConnMgr) useReuseport() bool { func (t *ConnMgr) useReuseport() bool {
@@ -80,7 +83,7 @@ func getTCPAddr(listenAddr ma.Multiaddr) (ma.Multiaddr, error) {
// DemultiplexedListen returns a listener for laddr listening for `connType` connections. The connections // DemultiplexedListen returns a listener for laddr listening for `connType` connections. The connections
// accepted from returned listeners need to be upgraded with a `transport.Upgrader`. // accepted from returned listeners need to be upgraded with a `transport.Upgrader`.
// NOTE: All listeners for port 0 share the same underlying socket, so they have the same specific port. // NOTE: All listeners for port 0 share the same underlying socket, so they have the same specific port.
func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType DemultiplexedConnType) (manet.Listener, error) { func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType DemultiplexedConnType) (transport.GatedMaListener, error) {
if !connType.IsKnown() { if !connType.IsKnown() {
return nil, fmt.Errorf("unknown connection type: %s", connType) return nil, fmt.Errorf("unknown connection type: %s", connType)
} }
@@ -100,7 +103,7 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed
return dl, nil return dl, nil
} }
l, err := t.maListen(laddr) gmal, err := t.gatedMaListen(laddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -111,19 +114,17 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed
t.mx.Lock() t.mx.Lock()
defer t.mx.Unlock() defer t.mx.Unlock()
delete(t.listeners, laddr.String()) delete(t.listeners, laddr.String())
delete(t.listeners, l.Multiaddr().String()) delete(t.listeners, gmal.Multiaddr().String())
return l.Close() return gmal.Close()
} }
ml = &multiplexedListener{ ml = &multiplexedListener{
Listener: l, GatedMaListener: gmal,
listeners: make(map[DemultiplexedConnType]*demultiplexedListener), listeners: make(map[DemultiplexedConnType]*demultiplexedListener),
ctx: ctx, ctx: ctx,
closeFn: cancelFunc, closeFn: cancelFunc,
connGater: t.connGater,
rcmgr: t.rcmgr,
} }
t.listeners[laddr.String()] = ml t.listeners[laddr.String()] = ml
t.listeners[l.Multiaddr().String()] = ml t.listeners[gmal.Multiaddr().String()] = ml
dl, err := ml.DemultiplexedListen(connType) dl, err := ml.DemultiplexedListen(connType)
if err != nil { if err != nil {
@@ -137,23 +138,21 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed
return dl, nil return dl, nil
} }
var _ manet.Listener = &demultiplexedListener{} var _ transport.GatedMaListener = &demultiplexedListener{}
type multiplexedListener struct { type multiplexedListener struct {
manet.Listener transport.GatedMaListener
listeners map[DemultiplexedConnType]*demultiplexedListener listeners map[DemultiplexedConnType]*demultiplexedListener
mx sync.RWMutex mx sync.RWMutex
connGater connmgr.ConnectionGater ctx context.Context
rcmgr network.ResourceManager closeFn func() error
ctx context.Context wg sync.WaitGroup
closeFn func() error
wg sync.WaitGroup
} }
var ErrListenerExists = errors.New("listener already exists for this conn type on this address") var ErrListenerExists = errors.New("listener already exists for this conn type on this address")
func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (manet.Listener, error) { func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (transport.GatedMaListener, error) {
if !connType.IsKnown() { if !connType.IsKnown() {
return nil, fmt.Errorf("unknown connection type: %s", connType) return nil, fmt.Errorf("unknown connection type: %s", connType)
} }
@@ -166,8 +165,8 @@ func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType
ctx, cancel := context.WithCancel(m.ctx) ctx, cancel := context.WithCancel(m.ctx)
l := &demultiplexedListener{ l := &demultiplexedListener{
buffer: make(chan manet.Conn), buffer: make(chan *connWithScope),
inner: m.Listener, inner: m.GatedMaListener,
ctx: ctx, ctx: ctx,
cancelFunc: cancel, cancelFunc: cancel,
closeFn: func() error { m.removeDemultiplexedListener(connType); return nil }, closeFn: func() error { m.removeDemultiplexedListener(connType); return nil },
@@ -183,53 +182,35 @@ func (m *multiplexedListener) run() error {
defer m.wg.Done() defer m.wg.Done()
acceptQueue := make(chan struct{}, acceptQueueSize) acceptQueue := make(chan struct{}, acceptQueueSize)
for { for {
c, err := m.Listener.Accept() c, connScope, err := m.GatedMaListener.Accept()
if err != nil { if err != nil {
return err return err
} }
ctx, cancelCtx := context.WithTimeout(m.ctx, acceptTimeout)
// Gate and resource limit the connection here.
// If done after sampling the connection, we'll be vulnerable to DOS attacks by a single peer
// which clogs up our entire connection queue.
// This duplicates the responsibility of gating and resource limiting between here and the upgrader. The
// alternative without duplication requires moving the process of upgrading the connection here, which forces
// us to establish the websocket connection here. That is more duplication, or a significant breaking change.
//
// Bugs around multiple calls to OpenConnection or InterceptAccept are prevented by the transport
// integration tests.
if m.connGater != nil && !m.connGater.InterceptAccept(c) {
log.Debugf("gater blocked incoming connection on local addr %s from %s",
c.LocalMultiaddr(), c.RemoteMultiaddr())
if err := c.Close(); err != nil {
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
}
continue
}
connScope, err := m.rcmgr.OpenConnection(network.DirInbound, true, c.RemoteMultiaddr())
if err != nil {
log.Debugw("resource manager blocked accept of new connection", "error", err)
if err := c.Close(); err != nil {
log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err)
}
continue
}
select { select {
case acceptQueue <- struct{}{}: case acceptQueue <- struct{}{}:
// NOTE: We can drop the connection, but this is similar to the behaviour in the upgrader. case <-ctx.Done():
case <-m.ctx.Done(): cancelCtx()
connScope.Done()
c.Close() c.Close()
log.Debugf("accept queue full, dropping connection: %s", c.RemoteMultiaddr()) log.Debugf("accept queue full, dropping connection: %s", c.RemoteMultiaddr())
continue
case <-m.ctx.Done():
cancelCtx()
connScope.Done()
c.Close()
log.Debugf("listener closed; dropping connection from: %s", c.RemoteMultiaddr())
continue
} }
m.wg.Add(1) m.wg.Add(1)
go func() { go func() {
defer func() { <-acceptQueue }() defer func() { <-acceptQueue }()
defer m.wg.Done() defer m.wg.Done()
ctx, cancelCtx := context.WithTimeout(m.ctx, acceptTimeout)
defer cancelCtx() defer cancelCtx()
t, c, err := identifyConnType(c) t, c, err := identifyConnType(c)
if err != nil { if err != nil {
// conn closed by identifyConnType
connScope.Done() connScope.Done()
log.Debugf("error demultiplexing connection: %s", err.Error()) log.Debugf("error demultiplexing connection: %s", err.Error())
return return
@@ -279,7 +260,7 @@ func (m *multiplexedListener) Close() error {
} }
func (m *multiplexedListener) closeListener() error { func (m *multiplexedListener) closeListener() error {
lerr := m.Listener.Close() lerr := m.GatedMaListener.Close()
cerr := m.closeFn() cerr := m.closeFn()
return errors.Join(lerr, cerr) return errors.Join(lerr, cerr)
} }
@@ -298,19 +279,19 @@ func (m *multiplexedListener) removeDemultiplexedListener(c DemultiplexedConnTyp
} }
type demultiplexedListener struct { type demultiplexedListener struct {
buffer chan manet.Conn buffer chan *connWithScope
inner manet.Listener inner transport.GatedMaListener
ctx context.Context ctx context.Context
cancelFunc context.CancelFunc cancelFunc context.CancelFunc
closeFn func() error closeFn func() error
} }
func (m *demultiplexedListener) Accept() (manet.Conn, error) { func (m *demultiplexedListener) Accept() (manet.Conn, network.ConnManagementScope, error) {
select { select {
case c := <-m.buffer: case c := <-m.buffer:
return c, nil return c.ManetTCPConnInterface, c.ConnScope, nil
case <-m.ctx.Done(): case <-m.ctx.Done():
return nil, transport.ErrListenerClosed return nil, nil, transport.ErrListenerClosed
} }
} }

View File

@@ -17,6 +17,9 @@ import (
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport"
tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net" manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multistream" "github.com/multiformats/go-multistream"
@@ -53,6 +56,17 @@ func selfSignedTLSConfig(t *testing.T) *tls.Config {
return tlsConfig return tlsConfig
} }
type maListener struct {
transport.GatedMaListener
}
var _ manet.Listener = &maListener{}
func (ml *maListener) Accept() (manet.Conn, error) {
c, _, err := ml.GatedMaListener.Accept()
return c, err
}
type wsHandler struct{ conns chan *websocket.Conn } type wsHandler struct{ conns chan *websocket.Conn }
func (wh wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (wh wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -61,12 +75,19 @@ func (wh wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
wh.conns <- c wh.conns <- c
} }
func upgrader(t *testing.T) transport.Upgrader {
t.Helper()
upd, err := tptu.New(nil, nil, nil, &network.NullResourceManager{}, nil)
require.NoError(t, err)
return upd
}
func TestListenerSingle(t *testing.T) { func TestListenerSingle(t *testing.T) {
listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0") listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0")
const N = 64 const N = 64
for _, enableReuseport := range []bool{true, false} { for _, enableReuseport := range []bool{true, false} {
t.Run(fmt.Sprintf("multistream-reuseport:%v", enableReuseport), func(t *testing.T) { t.Run(fmt.Sprintf("multistream-reuseport:%v", enableReuseport), func(t *testing.T) {
cm := NewConnMgr(enableReuseport, nil, nil) cm := NewConnMgr(enableReuseport, upgrader(t))
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
require.NoError(t, err) require.NoError(t, err)
go func() { go func() {
@@ -96,7 +117,7 @@ func TestListenerSingle(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
c, err := l.Accept() c, _, err := l.Accept()
require.NoError(t, err) require.NoError(t, err)
wg.Add(1) wg.Add(1)
go func() { go func() {
@@ -117,12 +138,12 @@ func TestListenerSingle(t *testing.T) {
}) })
t.Run(fmt.Sprintf("WebSocket-reuseport:%v", enableReuseport), func(t *testing.T) { t.Run(fmt.Sprintf("WebSocket-reuseport:%v", enableReuseport), func(t *testing.T) {
cm := NewConnMgr(enableReuseport, nil, nil) cm := NewConnMgr(enableReuseport, upgrader(t))
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
require.NoError(t, err) require.NoError(t, err)
wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
go func() { go func() {
http.Serve(manet.NetListener(l), wh) http.Serve(manet.NetListener(&maListener{GatedMaListener: l}), wh)
}() }()
go func() { go func() {
d := websocket.Dialer{} d := websocket.Dialer{}
@@ -169,14 +190,14 @@ func TestListenerSingle(t *testing.T) {
}) })
t.Run(fmt.Sprintf("WebSocketTLS-reuseport:%v", enableReuseport), func(t *testing.T) { t.Run(fmt.Sprintf("WebSocketTLS-reuseport:%v", enableReuseport), func(t *testing.T) {
cm := NewConnMgr(enableReuseport, nil, nil) cm := NewConnMgr(enableReuseport, upgrader(t))
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS) l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS)
require.NoError(t, err) require.NoError(t, err)
defer l.Close() defer l.Close()
wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
go func() { go func() {
s := http.Server{Handler: wh, TLSConfig: selfSignedTLSConfig(t)} s := http.Server{Handler: wh, TLSConfig: selfSignedTLSConfig(t)}
s.ServeTLS(manet.NetListener(l), "", "") s.ServeTLS(manet.NetListener(&maListener{GatedMaListener: l}), "", "")
}() }()
go func() { go func() {
d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
@@ -228,7 +249,7 @@ func TestListenerMultiplexed(t *testing.T) {
listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0") listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0")
const N = 20 const N = 20
for _, enableReuseport := range []bool{true, false} { for _, enableReuseport := range []bool{true, false} {
cm := NewConnMgr(enableReuseport, nil, nil) cm := NewConnMgr(enableReuseport, upgrader(t))
msl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) msl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
require.NoError(t, err) require.NoError(t, err)
defer msl.Close() defer msl.Close()
@@ -239,7 +260,7 @@ func TestListenerMultiplexed(t *testing.T) {
require.Equal(t, wsl.Multiaddr(), msl.Multiaddr()) require.Equal(t, wsl.Multiaddr(), msl.Multiaddr())
wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
go func() { go func() {
http.Serve(manet.NetListener(wsl), wh) http.Serve(manet.NetListener(&maListener{GatedMaListener: wsl}), wh)
}() }()
wssl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS) wssl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS)
@@ -249,7 +270,7 @@ func TestListenerMultiplexed(t *testing.T) {
whs := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} whs := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
go func() { go func() {
s := http.Server{Handler: whs, TLSConfig: selfSignedTLSConfig(t)} s := http.Server{Handler: whs, TLSConfig: selfSignedTLSConfig(t)}
s.ServeTLS(manet.NetListener(wssl), "", "") s.ServeTLS(manet.NetListener(&maListener{GatedMaListener: wssl}), "", "")
}() }()
// multistream connections // multistream connections
@@ -331,7 +352,7 @@ func TestListenerMultiplexed(t *testing.T) {
go func() { go func() {
defer wg.Done() defer wg.Done()
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
c, err := msl.Accept() c, _, err := msl.Accept()
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return return
} }
@@ -404,7 +425,7 @@ func TestListenerMultiplexed(t *testing.T) {
func TestListenerClose(t *testing.T) { func TestListenerClose(t *testing.T) {
testClose := func(listenAddr ma.Multiaddr) { testClose := func(listenAddr ma.Multiaddr) {
// listen on port 0 // listen on port 0
cm := NewConnMgr(false, nil, nil) cm := NewConnMgr(false, upgrader(t))
ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
require.NoError(t, err) require.NoError(t, err)
wl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) wl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
@@ -459,7 +480,7 @@ func setDeferReset[T any](t testing.TB, ptr *T, val T) {
func TestHitTimeout(t *testing.T) { func TestHitTimeout(t *testing.T) {
setDeferReset(t, &identifyConnTimeout, 100*time.Millisecond) setDeferReset(t, &identifyConnTimeout, 100*time.Millisecond)
// listen on port 0 // listen on port 0
cm := NewConnMgr(false, nil, nil) cm := NewConnMgr(false, upgrader(t))
listenAddr := ma.StringCast("/ip4/127.0.0.1/tcp/0") listenAddr := ma.StringCast("/ip4/127.0.0.1/tcp/0")
ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)

View File

@@ -4,8 +4,6 @@ import (
"net/url" "net/url"
"testing" "testing"
"github.com/stretchr/testify/require"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
) )
@@ -67,15 +65,3 @@ func TestConvertWebsocketMultiaddrToNetAddr(t *testing.T) {
t.Fatalf("expected network: \"websocket\", got \"%s\"", wsaddr.Network()) t.Fatalf("expected network: \"websocket\", got \"%s\"", wsaddr.Network())
} }
} }
func TestListeningOnDNSAddr(t *testing.T) {
ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil, nil)
require.NoError(t, err)
addr := ln.Multiaddr()
first, rest := ma.SplitFirst(addr)
require.Equal(t, ma.P_DNS, first.Protocol().Code)
require.Equal(t, "localhost", first.Value())
next, _ := ma.SplitFirst(rest)
require.Equal(t, ma.P_TCP, next.Protocol().Code)
require.NotEqual(t, 0, next.Value())
}

View File

@@ -1,7 +1,6 @@
package websocket package websocket
import ( import (
"crypto/tls"
"errors" "errors"
"io" "io"
"net" "net"
@@ -9,7 +8,6 @@ import (
"time" "time"
"github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net" manet "github.com/multiformats/go-multiaddr/net"
@@ -23,6 +21,7 @@ var GracefulCloseTimeout = 100 * time.Millisecond
// Conn implements net.Conn interface for gorilla/websocket. // Conn implements net.Conn interface for gorilla/websocket.
type Conn struct { type Conn struct {
*ws.Conn *ws.Conn
Scope network.ConnManagementScope
secure bool secure bool
DefaultMessageType int DefaultMessageType int
reader io.Reader reader io.Reader
@@ -36,10 +35,8 @@ type Conn struct {
var _ net.Conn = (*Conn)(nil) var _ net.Conn = (*Conn)(nil)
var _ manet.Conn = (*Conn)(nil) var _ manet.Conn = (*Conn)(nil)
// NewConn creates a Conn given a regular gorilla/websocket Conn. // newConn creates a Conn given a regular gorilla/websocket Conn.
// func newConn(raw *ws.Conn, secure bool, scope network.ConnManagementScope) *Conn {
// Deprecated: There's no reason to use this method externally. It'll be unexported in a future release.
func NewConn(raw *ws.Conn, secure bool) *Conn {
lna := NewAddrWithScheme(raw.LocalAddr().String(), secure) lna := NewAddrWithScheme(raw.LocalAddr().String(), secure)
laddr, err := manet.FromNetAddr(lna) laddr, err := manet.FromNetAddr(lna)
if err != nil { if err != nil {
@@ -56,6 +53,7 @@ func NewConn(raw *ws.Conn, secure bool) *Conn {
c := &Conn{ c := &Conn{
Conn: raw, Conn: raw,
Scope: scope,
secure: secure, secure: secure,
DefaultMessageType: ws.BinaryMessage, DefaultMessageType: ws.BinaryMessage,
laddr: laddr, laddr: laddr,
@@ -136,23 +134,6 @@ func (c *Conn) Write(b []byte) (n int, err error) {
return len(b), nil return len(b), nil
} }
func (c *Conn) Scope() network.ConnManagementScope {
nc := c.NetConn()
if sc, ok := nc.(interface {
Scope() network.ConnManagementScope
}); ok {
return sc.Scope()
}
if nc, ok := nc.(*tls.Conn); ok {
if sc, ok := nc.NetConn().(interface {
Scope() network.ConnManagementScope
}); ok {
return sc.Scope()
}
}
return nil
}
// Close closes the connection. // Close closes the connection.
// subsequent and concurrent calls will return the same error value. // subsequent and concurrent calls will return the same error value.
// This method is thread-safe. // This method is thread-safe.
@@ -201,13 +182,3 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.Conn.SetWriteDeadline(t) return c.Conn.SetWriteDeadline(t)
} }
type capableConn struct {
transport.CapableConn
}
func (c *capableConn) ConnState() network.ConnectionState {
cs := c.CapableConn.ConnState()
cs.Transport = "websocket"
return cs
}

View File

@@ -1,17 +1,22 @@
package websocket package websocket
import ( import (
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"net/url"
"sync" "sync"
"time"
"go.uber.org/zap" "go.uber.org/zap"
ws "github.com/gorilla/websocket"
logging "github.com/ipfs/go-log/v2" logging "github.com/ipfs/go-log/v2"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse" "github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
@@ -23,8 +28,9 @@ var log = logging.Logger("websocket-transport")
var stdLog = zap.NewStdLog(log.Desugar()) var stdLog = zap.NewStdLog(log.Desugar())
type listener struct { type listener struct {
nl net.Listener netListener *httpNetListener
server http.Server server http.Server
wsUpgrader ws.Upgrader
// The Go standard library sets the http.Server.TLSConfig no matter if this is a WS or WSS, // The Go standard library sets the http.Server.TLSConfig no matter if this is a WS or WSS,
// so we can't rely on checking if server.TLSConfig is set. // so we can't rely on checking if server.TLSConfig is set.
isWss bool isWss bool
@@ -36,8 +42,11 @@ type listener struct {
closeOnce sync.Once closeOnce sync.Once
closeErr error closeErr error
closed chan struct{} closed chan struct{}
wsurl *url.URL
} }
var _ transport.GatedMaListener = &listener{}
func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr { func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr {
if !pwma.isWSS { if !pwma.isWSS {
return pwma.restMultiaddr.AppendComponent(wsComponent) return pwma.restMultiaddr.AppendComponent(wsComponent)
@@ -52,7 +61,7 @@ func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr {
// newListener creates a new listener from a raw net.Listener. // newListener creates a new listener from a raw net.Listener.
// tlsConf may be nil (for unencrypted websockets). // tlsConf may be nil (for unencrypted websockets).
func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr) (*listener, error) { func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr, upgrader transport.Upgrader, handshakeTimeout time.Duration) (*listener, error) {
parsed, err := parseWebsocketMultiaddr(a) parsed, err := parseWebsocketMultiaddr(a)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -62,17 +71,13 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMg
return nil, fmt.Errorf("cannot listen on wss address %s without a tls.Config", a) return nil, fmt.Errorf("cannot listen on wss address %s without a tls.Config", a)
} }
var nl net.Listener var gmal transport.GatedMaListener
if sharedTcp == nil { if sharedTcp == nil {
lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr) mal, err := manet.Listen(parsed.restMultiaddr)
if err != nil {
return nil, err
}
nl, err = net.Listen(lnet, lnaddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
gmal = upgrader.GateMaListener(mal)
} else { } else {
var connType tcpreuse.DemultiplexedConnType var connType tcpreuse.DemultiplexedConnType
if parsed.isWSS { if parsed.isWSS {
@@ -80,89 +85,146 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMg
} else { } else {
connType = tcpreuse.DemultiplexedConnType_HTTP connType = tcpreuse.DemultiplexedConnType_HTTP
} }
mal, err := sharedTcp.DemultiplexedListen(parsed.restMultiaddr, connType) gmal, err = sharedTcp.DemultiplexedListen(parsed.restMultiaddr, connType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
nl = manet.NetListener(mal)
} }
laddr, err := manet.FromNetAddr(nl.Addr()) // laddr has the correct port in case we listened on port 0
if err != nil { laddr := gmal.Multiaddr()
return nil, err
}
first, _ := ma.SplitFirst(a)
// Don't resolve dns addresses. // Don't resolve dns addresses.
// We want to be able to announce domain names, so the peer can validate the TLS certificate. // We want to be able to announce domain names, so the peer can validate the TLS certificate.
first, _ := ma.SplitFirst(a)
if c := first.Protocol().Code; c == ma.P_DNS || c == ma.P_DNS4 || c == ma.P_DNS6 || c == ma.P_DNSADDR { if c := first.Protocol().Code; c == ma.P_DNS || c == ma.P_DNS4 || c == ma.P_DNS6 || c == ma.P_DNSADDR {
_, last := ma.SplitFirst(laddr) _, last := ma.SplitFirst(laddr)
laddr = first.Encapsulate(last) laddr = first.Encapsulate(last)
} }
parsed.restMultiaddr = laddr parsed.restMultiaddr = laddr
listenAddr := parsed.toMultiaddr()
wsurl, err := parseMultiaddr(listenAddr)
if err != nil {
gmal.Close()
return nil, fmt.Errorf("failed to parse multiaddr to URL: %v: %w", listenAddr, err)
}
ln := &listener{ ln := &listener{
nl: nl, netListener: &httpNetListener{
GatedMaListener: gmal,
handshakeTimeout: handshakeTimeout,
},
laddr: parsed.toMultiaddr(), laddr: parsed.toMultiaddr(),
incoming: make(chan *Conn), incoming: make(chan *Conn),
closed: make(chan struct{}), closed: make(chan struct{}),
isWss: parsed.isWSS,
wsurl: wsurl,
wsUpgrader: ws.Upgrader{
// Allow requests from *all* origins.
CheckOrigin: func(r *http.Request) bool {
return true
},
HandshakeTimeout: handshakeTimeout,
},
} }
ln.server = http.Server{Handler: ln, ErrorLog: stdLog} ln.server = http.Server{Handler: ln, ErrorLog: stdLog, ConnContext: ln.ConnContext, TLSConfig: tlsConf}
if parsed.isWSS {
ln.isWss = true
ln.server.TLSConfig = tlsConf
}
return ln, nil return ln, nil
} }
func (l *listener) serve() { func (l *listener) serve() {
defer close(l.closed) defer close(l.closed)
if !l.isWss { if !l.isWss {
l.server.Serve(l.nl) l.server.Serve(l.netListener)
} else { } else {
l.server.ServeTLS(l.nl, "", "") l.server.ServeTLS(l.netListener, "", "")
} }
} }
type connKey struct{}
func (l *listener) ConnContext(ctx context.Context, c net.Conn) context.Context {
// prefer `*tls.Conn` over `(interface{NetConn() net.Conn})` in case `manet.Conn` is extended
// to support a `NetConn() net.Conn` method.
if tc, ok := c.(*tls.Conn); ok {
c = tc.NetConn()
}
if nc, ok := c.(*negotiatingConn); ok {
return context.WithValue(ctx, connKey{}, nc)
}
log.Errorf("BUG: expected net.Conn of type *websocket.negotiatingConn: got %T", c)
// might as well close the connection as there's no way to proceed now.
c.Close()
return ctx
}
func (l *listener) extractConnFromContext(ctx context.Context) (*negotiatingConn, error) {
c := ctx.Value(connKey{})
if c == nil {
return nil, fmt.Errorf("expected *websocket.negotiatingConn in context: got nil")
}
nc, ok := c.(*negotiatingConn)
if !ok {
return nil, fmt.Errorf("expected *websocket.negotiatingConn in context: got %T", c)
}
return nc, nil
}
func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil) c, err := l.wsUpgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
// The upgrader writes a response for us. // The upgrader writes a response for us.
return return
} }
nc := NewConn(c, l.isWss) nc, err := l.extractConnFromContext(r.Context())
if nc == nil { if err != nil {
c.Close()
w.WriteHeader(500)
log.Errorf("BUG: failed to extract conn from context: RemoteAddr: %s: err: %s", r.RemoteAddr, err)
return
}
cs, err := nc.Unwrap()
if err != nil {
c.Close()
w.WriteHeader(500)
log.Debugf("connection timed out from: %s", r.RemoteAddr)
return
}
conn := newConn(c, l.isWss, cs.Scope)
if conn == nil {
c.Close() c.Close()
w.WriteHeader(500) w.WriteHeader(500)
return return
} }
select { select {
case l.incoming <- nc: case l.incoming <- conn:
case <-l.closed: case <-l.closed:
nc.Close() conn.Close()
} }
// The connection has been hijacked, it's safe to return. // The connection has been hijacked, it's safe to return.
} }
func (l *listener) Accept() (manet.Conn, error) { func (l *listener) Accept() (manet.Conn, network.ConnManagementScope, error) {
select { select {
case c, ok := <-l.incoming: case c, ok := <-l.incoming:
if !ok { if !ok {
return nil, transport.ErrListenerClosed return nil, nil, transport.ErrListenerClosed
} }
return c, nil return c, c.Scope, nil
case <-l.closed: case <-l.closed:
return nil, transport.ErrListenerClosed return nil, nil, transport.ErrListenerClosed
} }
} }
func (l *listener) Addr() net.Addr { func (l *listener) Addr() net.Addr {
return l.nl.Addr() return &Addr{URL: l.wsurl}
} }
func (l *listener) Close() error { func (l *listener) Close() error {
l.closeOnce.Do(func() { l.closeOnce.Do(func() {
err1 := l.nl.Close() err1 := l.netListener.Close()
err2 := l.server.Close() err2 := l.server.Close()
<-l.closed <-l.closed
l.closeErr = errors.Join(err1, err2) l.closeErr = errors.Join(err1, err2)
@@ -174,14 +236,74 @@ func (l *listener) Multiaddr() ma.Multiaddr {
return l.laddr return l.laddr
} }
type transportListener struct { // httpNetListener is a net.Listener that adapts a transport.GatedMaListener to a net.Listener.
transport.Listener // It wraps the manet.Conn, and the Scope from the underlying gated listener in a connWithScope.
type httpNetListener struct {
transport.GatedMaListener
handshakeTimeout time.Duration
} }
func (l *transportListener) Accept() (transport.CapableConn, error) { var _ net.Listener = &httpNetListener{}
conn, err := l.Listener.Accept()
func (l *httpNetListener) Accept() (net.Conn, error) {
conn, scope, err := l.GatedMaListener.Accept()
if err != nil { if err != nil {
if scope != nil {
log.Errorf("BUG: scope non-nil when err is non nil: %v", err)
scope.Done()
}
return nil, err return nil, err
} }
return &capableConn{CapableConn: conn}, nil connWithScope := connWithScope{
Conn: conn,
Scope: scope,
}
ctx, cancel := context.WithTimeout(context.Background(), l.handshakeTimeout)
return &negotiatingConn{
connWithScope: connWithScope,
ctx: ctx,
cancelCtx: cancel,
stopClose: context.AfterFunc(ctx, func() {
connWithScope.Close()
log.Debugf("handshake timeout for conn from: %s", conn.RemoteAddr())
}),
}, nil
}
type connWithScope struct {
net.Conn
Scope network.ConnManagementScope
}
func (c connWithScope) Close() error {
c.Scope.Done()
return c.Conn.Close()
}
type negotiatingConn struct {
connWithScope
ctx context.Context
cancelCtx context.CancelFunc
stopClose func() bool
}
// Close closes the negotiating conn and the underlying connWithScope
// This will be called in case the tls handshake or websocket upgrade fails.
func (c *negotiatingConn) Close() error {
defer c.cancelCtx()
if c.stopClose != nil {
c.stopClose()
}
return c.connWithScope.Close()
}
func (c *negotiatingConn) Unwrap() (connWithScope, error) {
defer c.cancelCtx()
if c.stopClose != nil {
if !c.stopClose() {
return connWithScope{}, errors.New("timed out")
}
c.stopClose = nil
}
return c.connWithScope, nil
} }

View File

@@ -5,7 +5,6 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"net" "net"
"net/http"
"time" "time"
"github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/network"
@@ -51,14 +50,6 @@ func init() {
manet.RegisterToNetAddr(ConvertWebsocketMultiaddrToNetAddr, "wss") manet.RegisterToNetAddr(ConvertWebsocketMultiaddrToNetAddr, "wss")
} }
// Default gorilla upgrader
var upgrader = ws.Upgrader{
// Allow requests from *all* origins.
CheckOrigin: func(r *http.Request) bool {
return true
},
}
type Option func(*WebsocketTransport) error type Option func(*WebsocketTransport) error
// WithTLSClientConfig sets a TLS client configuration on the WebSocket Dialer. Only // WithTLSClientConfig sets a TLS client configuration on the WebSocket Dialer. Only
@@ -81,15 +72,24 @@ func WithTLSConfig(conf *tls.Config) Option {
} }
} }
var defaultHandshakeTimeout = 15 * time.Second
// WithHandshakeTimeout sets a timeout for the websocket upgrade.
func WithHandshakeTimeout(timeout time.Duration) Option {
return func(t *WebsocketTransport) error {
t.handshakeTimeout = timeout
return nil
}
}
// WebsocketTransport is the actual go-libp2p transport // WebsocketTransport is the actual go-libp2p transport
type WebsocketTransport struct { type WebsocketTransport struct {
upgrader transport.Upgrader upgrader transport.Upgrader
rcmgr network.ResourceManager rcmgr network.ResourceManager
tlsClientConf *tls.Config
tlsClientConf *tls.Config tlsConf *tls.Config
tlsConf *tls.Config sharedTcp *tcpreuse.ConnMgr
handshakeTimeout time.Duration
sharedTcp *tcpreuse.ConnMgr
} }
var _ transport.Transport = (*WebsocketTransport)(nil) var _ transport.Transport = (*WebsocketTransport)(nil)
@@ -99,10 +99,11 @@ func New(u transport.Upgrader, rcmgr network.ResourceManager, sharedTCP *tcpreus
rcmgr = &network.NullResourceManager{} rcmgr = &network.NullResourceManager{}
} }
t := &WebsocketTransport{ t := &WebsocketTransport{
upgrader: u, upgrader: u,
rcmgr: rcmgr, rcmgr: rcmgr,
tlsClientConf: &tls.Config{}, tlsClientConf: &tls.Config{},
sharedTcp: sharedTCP, sharedTcp: sharedTCP,
handshakeTimeout: defaultHandshakeTimeout,
} }
for _, opt := range opts { for _, opt := range opts {
if err := opt(t); err != nil { if err := opt(t); err != nil {
@@ -176,7 +177,7 @@ func (t *WebsocketTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p pee
} }
func (t *WebsocketTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, connScope network.ConnManagementScope) (transport.CapableConn, error) { func (t *WebsocketTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, connScope network.ConnManagementScope) (transport.CapableConn, error) {
macon, err := t.maDial(ctx, raddr) macon, err := t.maDial(ctx, raddr, connScope)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -187,14 +188,14 @@ func (t *WebsocketTransport) dialWithScope(ctx context.Context, raddr ma.Multiad
return &capableConn{CapableConn: conn}, nil return &capableConn{CapableConn: conn}, nil
} }
func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) { func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr, scope network.ConnManagementScope) (manet.Conn, error) {
wsurl, err := parseMultiaddr(raddr) wsurl, err := parseMultiaddr(raddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
isWss := wsurl.Scheme == "wss" isWss := wsurl.Scheme == "wss"
dialer := ws.Dialer{ dialer := ws.Dialer{
HandshakeTimeout: 30 * time.Second, HandshakeTimeout: t.handshakeTimeout,
// Inherit the default proxy behavior // Inherit the default proxy behavior
Proxy: ws.DefaultDialer.Proxy, Proxy: ws.DefaultDialer.Proxy,
} }
@@ -236,7 +237,7 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma
return nil, err return nil, err
} }
mnc, err := manet.WrapNetConn(NewConn(wscon, isWss)) mnc, err := manet.WrapNetConn(newConn(wscon, isWss, scope))
if err != nil { if err != nil {
wscon.Close() wscon.Close()
return nil, err return nil, err
@@ -244,12 +245,12 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma
return mnc, nil return mnc, nil
} }
func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) { func (t *WebsocketTransport) gatedMaListen(a ma.Multiaddr) (transport.GatedMaListener, error) {
var tlsConf *tls.Config var tlsConf *tls.Config
if t.tlsConf != nil { if t.tlsConf != nil {
tlsConf = t.tlsConf.Clone() tlsConf = t.tlsConf.Clone()
} }
l, err := newListener(a, tlsConf, t.sharedTcp) l, err := newListener(a, tlsConf, t.sharedTcp, t.upgrader, t.handshakeTimeout)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -258,9 +259,32 @@ func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) {
} }
func (t *WebsocketTransport) Listen(a ma.Multiaddr) (transport.Listener, error) { func (t *WebsocketTransport) Listen(a ma.Multiaddr) (transport.Listener, error) {
malist, err := t.maListen(a) gmal, err := t.gatedMaListen(a)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &transportListener{Listener: t.upgrader.UpgradeListener(t, malist)}, nil return &transportListener{Listener: t.upgrader.UpgradeGatedMaListener(t, gmal)}, nil
}
// transportListener wraps a transport.Listener to provide connections with a `ConnState() network.ConnectionState` method.
type transportListener struct {
transport.Listener
}
type capableConn struct {
transport.CapableConn
}
func (c *capableConn) ConnState() network.ConnectionState {
cs := c.CapableConn.ConnState()
cs.Transport = "websocket"
return cs
}
func (l *transportListener) Accept() (transport.CapableConn, error) {
conn, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return &capableConn{CapableConn: conn}, nil
} }

View File

@@ -35,6 +35,7 @@ import (
ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite" ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -296,18 +297,39 @@ func TestDialWssNoClientCert(t *testing.T) {
} }
func TestWebsocketTransport(t *testing.T) { func TestWebsocketTransport(t *testing.T) {
peerA, ua := newUpgrader(t) t.Run("/ws", func(t *testing.T) {
ta, err := New(ua, nil, nil) peerA, ua := newUpgrader(t)
if err != nil { ta, err := New(ua, nil, nil)
t.Fatal(err) if err != nil {
} t.Fatal(err)
_, ub := newUpgrader(t) }
tb, err := New(ub, nil, nil) peerB, ub := newUpgrader(t)
if err != nil { tb, err := New(ub, nil, nil)
t.Fatal(err) if err != nil {
} t.Fatal(err)
}
ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/ws", peerA) ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/ws", peerA)
ttransport.SubtestTransport(t, tb, ta, "/ip4/127.0.0.1/tcp/0/ws", peerB)
})
t.Run("/wss", func(t *testing.T) {
peerA, ua := newUpgrader(t)
tca := generateTLSConfig(t)
ta, err := New(ua, nil, nil, WithTLSConfig(tca), WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
if err != nil {
t.Fatal(err)
}
peerB, ub := newUpgrader(t)
tcb := generateTLSConfig(t)
tb, err := New(ub, nil, nil, WithTLSConfig(tcb), WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
if err != nil {
t.Fatal(err)
}
ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/wss", peerA)
ttransport.SubtestTransport(t, tb, ta, "/ip4/127.0.0.1/tcp/0/ws", peerB)
})
} }
func isWSS(addr ma.Multiaddr) bool { func isWSS(addr ma.Multiaddr) bool {
@@ -441,7 +463,7 @@ func TestConcurrentClose(t *testing.T) {
_, u := newUpgrader(t) _, u := newUpgrader(t)
tpt, err := New(u, &network.NullResourceManager{}, nil) tpt, err := New(u, &network.NullResourceManager{}, nil)
require.NoError(t, err) require.NoError(t, err)
l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) l, err := tpt.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -451,7 +473,7 @@ func TestConcurrentClose(t *testing.T) {
go func() { go func() {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
c, err := tpt.maDial(context.Background(), l.Multiaddr()) c, err := tpt.maDial(context.Background(), l.Multiaddr(), &network.NullScope{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@@ -467,7 +489,7 @@ func TestConcurrentClose(t *testing.T) {
}() }()
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
c, err := l.Accept() c, _, err := l.Accept()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -481,7 +503,7 @@ func TestWriteZero(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) l, err := tpt.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -490,7 +512,7 @@ func TestWriteZero(t *testing.T) {
msg := []byte(nil) msg := []byte(nil)
go func() { go func() {
c, err := tpt.maDial(context.Background(), l.Multiaddr()) c, err := tpt.maDial(context.Background(), l.Multiaddr(), &network.NullScope{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@@ -509,7 +531,7 @@ func TestWriteZero(t *testing.T) {
} }
}() }()
c, err := l.Accept() c, _, err := l.Accept()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -623,3 +645,98 @@ func TestSocksProxy(t *testing.T) {
}) })
} }
} }
func TestListenerAddr(t *testing.T) {
_, upgrader := newUpgrader(t)
transport, err := New(upgrader, &network.NullResourceManager{}, nil, WithTLSConfig(generateTLSConfig(t)))
require.NoError(t, err)
l1, err := transport.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
require.NoError(t, err)
defer l1.Close()
require.Regexp(t, `^ws://127\.0\.0\.1:[\d]+$`, l1.Addr().String())
l2, err := transport.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/wss"))
require.NoError(t, err)
defer l2.Close()
require.Regexp(t, `^wss://127\.0\.0\.1:[\d]+$`, l2.Addr().String())
}
func TestHandshakeTimeout(t *testing.T) {
handshakeTimeout := 200 * time.Millisecond
_, upgrader := newUpgrader(t)
tlsconf := generateTLSConfig(t)
transport, err := New(upgrader, &network.NullResourceManager{}, nil, WithHandshakeTimeout(handshakeTimeout), WithTLSConfig(tlsconf))
require.NoError(t, err)
fastWSDialer := gws.Dialer{
HandshakeTimeout: 10 * handshakeTimeout,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
NetDial: func(network, addr string) (net.Conn, error) {
tcpConn, err := net.Dial("tcp", addr)
if !assert.NoError(t, err) {
return nil, err
}
return tcpConn, nil
},
}
slowWSDialer := gws.Dialer{
HandshakeTimeout: 10 * handshakeTimeout,
NetDial: func(network, addr string) (net.Conn, error) {
tcpConn, err := net.Dial("tcp", addr)
if !assert.NoError(t, err) {
return nil, err
}
// wait to simulate a slow handshake
time.Sleep(2 * handshakeTimeout)
return tcpConn, nil
},
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
t.Run("ws", func(t *testing.T) {
// test the gatedMaListener as we're interested in the websocket handshake timeout and not the upgrader steps.
wsListener, err := transport.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
require.NoError(t, err)
defer wsListener.Close()
ctx, cancel := context.WithTimeout(context.Background(), 10*handshakeTimeout)
defer cancel()
conn, resp, err := fastWSDialer.DialContext(ctx, wsListener.Addr().String(), nil)
if !assert.NoError(t, err) {
return
}
conn.Close()
resp.Body.Close()
ctx, cancel = context.WithTimeout(context.Background(), 10*handshakeTimeout)
defer cancel()
conn, resp, err = slowWSDialer.DialContext(ctx, wsListener.Addr().String(), nil)
if err == nil {
conn.Close()
resp.Body.Close()
t.Fatal("should error as the handshake will time out")
}
})
t.Run("wss", func(t *testing.T) {
// test the gatedMaListener as we're interested in the websocket handshake timeout and not the upgrader steps.
wsListener, err := transport.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/wss"))
require.NoError(t, err)
defer wsListener.Close()
// Test that the normal dial works fine
ctx, cancel := context.WithTimeout(context.Background(), 10*handshakeTimeout)
defer cancel()
wsConn, resp, err := fastWSDialer.DialContext(ctx, wsListener.Addr().String(), nil)
require.NoError(t, err)
wsConn.Close()
resp.Body.Close()
ctx, cancel = context.WithTimeout(context.Background(), 10*handshakeTimeout)
defer cancel()
wsConn, resp, err = slowWSDialer.DialContext(ctx, wsListener.Addr().String(), nil)
if err == nil {
wsConn.Close()
resp.Body.Close()
t.Fatal("websocket handshake should have timed out")
}
})
}