transport: add GatedMaListener type (#3186)

This introduces a new GatedMaListener type which gates conns
accepted from a manet.Listener with a gater and creates the rcmgr
scope for it. Explicitly passing the scope allows for many guardrails
that the previous interface assertion didn't.

This breaks the previous responsibility of the upgradeListener method
into two, one gating the connection initially, and the other upgrading
the connection with a security and muxer selection.

This split makes it easy to gate the connection with the resource
manager as early as possible. This is especially true for websocket
because we want to gate the connection just after the TCP connection is
established, and not after the tls handshake + websocket upgrade is
completed.
This commit is contained in:
sukun
2025-03-25 22:09:57 +05:30
committed by GitHub
parent 8430ad3e2f
commit 6249e685e9
19 changed files with 602 additions and 317 deletions

View File

@@ -292,11 +292,11 @@ func (cfg *Config) addTransports() ([]fx.Option, error) {
fx.Provide(func() connmgr.ConnectionGater { return cfg.ConnectionGater }),
fx.Provide(func() pnet.PSK { return cfg.PSK }),
fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }),
fx.Provide(func(gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *tcpreuse.ConnMgr {
fx.Provide(func(upgrader transport.Upgrader) *tcpreuse.ConnMgr {
if !cfg.ShareTCPListener {
return nil
}
return tcpreuse.NewConnMgr(tcpreuse.EnvReuseportVal, gater, rcmgr)
return tcpreuse.NewConnMgr(tcpreuse.EnvReuseportVal, upgrader)
}),
fx.Provide(func(cm *quicreuse.ConnManager, sw *swarm.Swarm) libp2pwebrtc.ListenUDPFn {
hasQuicAddrPortFor := func(network string, laddr *net.UDPAddr) bool {

View File

@@ -129,11 +129,41 @@ type TransportNetwork interface {
AddTransport(t Transport) error
}
// GatedMaListener is listener that listens for raw(unsecured and non-multiplexed) incoming connections,
// gates them with a `connmgr.ConnGater`and creates a resource management scope for them.
// It can be upgraded to a full libp2p transport listener by the Upgrader.
//
// Compared to manet.Listener, this listener creates the resource management scope for the accepted connection.
type GatedMaListener interface {
// Accept waits for and returns the next connection to the listener.
Accept() (manet.Conn, network.ConnManagementScope, error)
// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
Close() error
// Multiaddr returns the listener's (local) Multiaddr.
Multiaddr() ma.Multiaddr
// Addr returns the net.Listener's network address.
Addr() net.Addr
}
// Upgrader is a multistream upgrader that can upgrade an underlying connection
// to a full transport connection (secure and multiplexed).
type Upgrader interface {
// UpgradeListener upgrades the passed multiaddr-net listener into a full libp2p-transport listener.
//
// Deprecated: Use UpgradeGatedMaListener(upgrader.GateMaListener(manet.Listener)) instead.
UpgradeListener(Transport, manet.Listener) Listener
// GateMaListener creates a GatedMaListener from a manet.Listener. It gates the accepted connection
// and creates a resource scope for it.
GateMaListener(manet.Listener) GatedMaListener
// UpgradeGatedMaListener upgrades the passed GatedMaListener into a full libp2p-transport listener.
UpgradeGatedMaListener(Transport, GatedMaListener) Listener
// Upgrade upgrades the multiaddr/net connection into a full libp2p-transport connection.
Upgrade(ctx context.Context, t Transport, maconn manet.Conn, dir network.Direction, p peer.ID, scope network.ConnManagementScope) (CapableConn, error)
}

View File

@@ -6,6 +6,7 @@ import (
"strings"
"sync"
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport"
@@ -17,7 +18,7 @@ import (
var log = logging.Logger("upgrader")
type listener struct {
manet.Listener
transport.GatedMaListener
transport transport.Transport
upgrader *upgrader
@@ -35,10 +36,12 @@ type listener struct {
cancel func()
}
var _ transport.Listener = (*listener)(nil)
// Close closes the listener.
func (l *listener) Close() error {
// Do this first to try to get any relevant errors.
err := l.Listener.Close()
err := l.GatedMaListener.Close()
l.cancel()
// Drain and wait.
@@ -61,7 +64,7 @@ func (l *listener) handleIncoming() {
var wg sync.WaitGroup
defer func() {
// make sure we're closed
l.Listener.Close()
l.GatedMaListener.Close()
if l.err == nil {
l.err = fmt.Errorf("listener closed")
}
@@ -72,7 +75,7 @@ func (l *listener) handleIncoming() {
var catcher tec.TempErrCatcher
for l.ctx.Err() == nil {
maconn, err := l.Listener.Accept()
maconn, connScope, err := l.GatedMaListener.Accept()
if err != nil {
// Note: function may pause the accept loop.
if catcher.IsTemporary(err) {
@@ -84,35 +87,12 @@ func (l *listener) handleIncoming() {
}
catcher.Reset()
// Check if we already have a connection scope. See the comment in tcpreuse/listener.go for an explanation.
var connScope network.ConnManagementScope
if sc, ok := maconn.(interface {
Scope() network.ConnManagementScope
}); ok {
connScope = sc.Scope()
}
if connScope == nil {
// gate the connection if applicable
if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) {
log.Debugf("gater blocked incoming connection on local addr %s from %s",
maconn.LocalMultiaddr(), maconn.RemoteMultiaddr())
if err := maconn.Close(); err != nil {
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
}
log.Errorf("BUG: got nil connScope for incoming connection from %s", maconn.RemoteMultiaddr())
maconn.Close()
continue
}
var err error
connScope, err = l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr())
if err != nil {
log.Debugw("resource manager blocked accept of new connection", "error", err)
if err := maconn.Close(); err != nil {
log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err)
}
continue
}
}
// The go routine below calls Release when the context is
// canceled so there's no need to wait on it here.
l.threshold.Wait()
@@ -154,14 +134,10 @@ func (l *listener) handleIncoming() {
select {
case l.incoming <- conn:
case <-ctx.Done():
// Listener not closed but the accept timeout expired.
if l.ctx.Err() == nil {
// Listener *not* closed but the accept timeout expired.
log.Warn("listener dropped connection due to slow accept")
log.Warnf("listener dropped connection due to slow accept. remote addr: %s peer: %s", maconn.RemoteMultiaddr(), conn.RemotePeer())
}
// Wait on the context with a timeout. This way,
// if we stop accepting connections for some reason,
// we'll eventually close all the open ones
// instead of hanging onto them.
conn.CloseWithError(network.ConnRateLimited)
}
}()
@@ -189,4 +165,38 @@ func (l *listener) String() string {
return fmt.Sprintf("<stream.Listener %s>", l.Multiaddr())
}
var _ transport.Listener = (*listener)(nil)
type gatedMaListener struct {
manet.Listener
rcmgr network.ResourceManager
connGater connmgr.ConnectionGater
}
var _ transport.GatedMaListener = &gatedMaListener{}
func (l *gatedMaListener) Accept() (manet.Conn, network.ConnManagementScope, error) {
for {
conn, err := l.Listener.Accept()
if err != nil {
return nil, nil, err
}
// gate the connection if applicable
if l.connGater != nil && !l.connGater.InterceptAccept(conn) {
log.Debugf("gater blocked incoming connection on local addr %s from %s",
conn.LocalMultiaddr(), conn.RemoteMultiaddr())
if err := conn.Close(); err != nil {
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
}
continue
}
connScope, err := l.rcmgr.OpenConnection(network.DirInbound, true, conn.RemoteMultiaddr())
if err != nil {
log.Debugw("resource manager blocked accept of new connection", "error", err)
if err := conn.Close(); err != nil {
log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err)
}
continue
}
return conn, connScope, nil
}
}

View File

@@ -30,7 +30,7 @@ func createListener(t *testing.T, u transport.Upgrader) transport.Listener {
require.NoError(t, err)
ln, err := manet.Listen(addr)
require.NoError(t, err)
return u.UpgradeListener(nil, ln)
return u.UpgradeGatedMaListener(nil, u.GateMaListener(ln))
}
func TestAcceptSingleConn(t *testing.T) {

View File

@@ -105,9 +105,22 @@ func New(security []sec.SecureTransport, muxers []StreamMuxer, psk ipnet.PSK, rc
// UpgradeListener upgrades the passed multiaddr-net listener into a full libp2p-transport listener.
func (u *upgrader) UpgradeListener(t transport.Transport, list manet.Listener) transport.Listener {
return u.UpgradeGatedMaListener(t, u.GateMaListener(list))
}
func (u *upgrader) GateMaListener(l manet.Listener) transport.GatedMaListener {
return &gatedMaListener{
Listener: l,
rcmgr: u.rcmgr,
connGater: u.connGater,
}
}
// UpgradeGatedMaListener upgrades the passed multiaddr-net listener into a full libp2p-transport listener.
func (u *upgrader) UpgradeGatedMaListener(t transport.Transport, l transport.GatedMaListener) transport.Listener {
ctx, cancel := context.WithCancel(context.Background())
l := &listener{
Listener: list,
list := &listener{
GatedMaListener: l,
upgrader: u,
transport: t,
rcmgr: u.rcmgr,
@@ -116,8 +129,8 @@ func (u *upgrader) UpgradeListener(t transport.Transport, list manet.Listener) t
cancel: cancel,
ctx: ctx,
}
go l.handleIncoming()
return l
go list.handleIncoming()
return list
}
// Upgrade upgrades the multiaddr/net connection into a full libp2p-transport connection.

View File

@@ -101,7 +101,7 @@ func (c *Client) Listen(addr ma.Multiaddr) (transport.Listener, error) {
return nil, err
}
return c.upgrader.UpgradeListener(c, c.Listener()), nil
return c.upgrader.UpgradeGatedMaListener(c, c.upgrader.GateMaListener(c.Listener())), nil
}
func (c *Client) Protocols() []int {

View File

@@ -203,7 +203,7 @@ func TestInterceptAccept(t *testing.T) {
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr()))
}).AnyTimes()
} else if strings.Contains(tc.Name, "WebSocket-Shared") || strings.Contains(tc.Name, "WebSocket-Secured-Shared") {
} else if strings.Contains(tc.Name, "WebSocket") {
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
require.Equal(t, addrPort(h2.Addrs()[0]), addrPort(addrs.LocalMultiaddr()))
})

View File

@@ -8,6 +8,7 @@ import (
"time"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport"
"github.com/marten-seemann/tcp"
"github.com/mikioh/tcpinfo"
manet "github.com/multiformats/go-multiaddr/net"
@@ -253,16 +254,6 @@ func (c *tracingConn) Close() error {
return c.closeErr
}
func (c *tracingConn) Scope() network.ConnManagementScope {
if cs, ok := c.Conn.(interface {
Scope() network.ConnManagementScope
}); ok {
return cs.Scope()
}
// upgrader is expected to handle this
return nil
}
func (c *tracingConn) getTCPInfo() (*tcpinfo.Info, error) {
var o tcpinfo.Info
var b [256]byte
@@ -275,19 +266,31 @@ func (c *tracingConn) getTCPInfo() (*tcpinfo.Info, error) {
}
type tracingListener struct {
manet.Listener
transport.GatedMaListener
collector *aggregatingCollector
}
// newTracingListener wraps a manet.Listener with a tracingListener. A nil collector will use the default collector.
func newTracingListener(l manet.Listener, collector *aggregatingCollector) *tracingListener {
return &tracingListener{Listener: l, collector: collector}
func newTracingListener(l transport.GatedMaListener, collector *aggregatingCollector) *tracingListener {
return &tracingListener{GatedMaListener: l, collector: collector}
}
func (l *tracingListener) Accept() (manet.Conn, error) {
conn, err := l.Listener.Accept()
func (l *tracingListener) Accept() (manet.Conn, network.ConnManagementScope, error) {
conn, scope, err := l.GatedMaListener.Accept()
if err != nil {
return nil, err
if scope != nil {
scope.Done()
log.Errorf("BUG: got non-nil scope but also an error: %s", err)
}
return newTracingConn(conn, l.collector, false)
return nil, nil, err
}
tc, err := newTracingConn(conn, l.collector, false)
if err != nil {
log.Errorf("failed to create tracingConn from %T: %s", conn, err)
conn.Close()
scope.Done()
return nil, nil, err
}
return tc, scope, nil
}

View File

@@ -4,11 +4,16 @@
package tcp
import manet "github.com/multiformats/go-multiaddr/net"
import (
"github.com/libp2p/go-libp2p/core/transport"
manet "github.com/multiformats/go-multiaddr/net"
)
type aggregatingCollector struct{}
func newTracingConn(c manet.Conn, collector *aggregatingCollector, isClient bool) (manet.Conn, error) {
return c, nil
}
func newTracingListener(l manet.Listener, collector *aggregatingCollector) manet.Listener { return l }
func newTracingListener(l transport.GatedMaListener, collector *aggregatingCollector) transport.GatedMaListener {
return l
}

View File

@@ -15,8 +15,10 @@ func TestTcpTransportCollectsMetricsWithSharedTcpSocket(t *testing.T) {
peerA, ia := makeInsecureMuxer(t)
_, ib := makeInsecureMuxer(t)
sharedTCPSocketA := tcpreuse.NewConnMgr(false, nil, nil)
sharedTCPSocketB := tcpreuse.NewConnMgr(false, nil, nil)
upg, err := tptu.New(ia, muxers, nil, nil, nil)
require.NoError(t, err)
sharedTCPSocketA := tcpreuse.NewConnMgr(false, upg)
sharedTCPSocketB := tcpreuse.NewConnMgr(false, upg)
ua, err := tptu.New(ia, muxers, nil, nil, nil)
require.NoError(t, err)

View File

@@ -41,7 +41,7 @@ var ReuseportIsAvailable = tcpreuse.ReuseportIsAvailable
func tryKeepAlive(conn net.Conn, keepAlive bool) {
keepAliveConn, ok := conn.(canKeepAlive)
if !ok {
log.Errorf("Can't set TCP keepalives.")
log.Errorf("can't set TCP keepalives. net.Conn of type %T doesn't support SetKeepAlive", conn)
return
}
if err := keepAliveConn.SetKeepAlive(keepAlive); err != nil {
@@ -76,23 +76,23 @@ func tryLinger(conn net.Conn, sec int) {
}
}
type tcpListener struct {
manet.Listener
type tcpGatedMaListener struct {
transport.GatedMaListener
sec int
}
func (ll *tcpListener) Accept() (manet.Conn, error) {
c, err := ll.Listener.Accept()
func (ll *tcpGatedMaListener) Accept() (manet.Conn, network.ConnManagementScope, error) {
c, scope, err := ll.GatedMaListener.Accept()
if err != nil {
return nil, err
if scope != nil {
log.Errorf("BUG: got non-nil scope but also an error: %s", err)
scope.Done()
}
return nil, nil, err
}
tryLinger(c, ll.sec)
tryKeepAlive(c, true)
// We're not calling OpenConnection in the resource manager here,
// since the manet.Conn doesn't allow us to save the scope.
// It's the caller's (usually the p2p/net/upgrader) responsibility
// to call the resource manager.
return c, nil
return c, scope, nil
}
type Option func(*TcpTransport) error
@@ -316,22 +316,26 @@ func (t *TcpTransport) unsharedMAListen(laddr ma.Multiaddr) (manet.Listener, err
// Listen listens on the given multiaddr.
func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) {
var list manet.Listener
var list transport.GatedMaListener
var err error
if t.sharedTcp == nil {
list, err = t.unsharedMAListen(laddr)
} else {
if t.sharedTcp != nil {
list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.DemultiplexedConnType_MultistreamSelect)
}
if err != nil {
return nil, err
}
} else {
mal, err := t.unsharedMAListen(laddr)
if err != nil {
return nil, err
}
list = t.upgrader.GateMaListener(mal)
}
if t.enableMetrics {
list = newTracingListener(&tcpListener{list, 0}, t.metricsCollector)
// TODO: Fix this: The tcpListener wrapping should happen on both enableMetrics and disabledMetrics path
list = newTracingListener(&tcpGatedMaListener{list, 0}, t.metricsCollector)
}
return t.upgrader.UpgradeListener(t, list), nil
return t.upgrader.UpgradeGatedMaListener(t, list), nil
}
// Protocols returns the list of terminal protocols this transport can dial.

View File

@@ -10,19 +10,15 @@ import (
type connWithScope struct {
sampledconn.ManetTCPConnInterface
scope network.ConnManagementScope
}
func (c connWithScope) Scope() network.ConnManagementScope {
return c.scope
ConnScope network.ConnManagementScope
}
func (c *connWithScope) Close() error {
c.scope.Done()
defer c.ConnScope.Done()
return c.ManetTCPConnInterface.Close()
}
func manetConnWithScope(c manet.Conn, scope network.ConnManagementScope) (manet.Conn, error) {
func manetConnWithScope(c manet.Conn, scope network.ConnManagementScope) (*connWithScope, error) {
if tcpconn, ok := c.(sampledconn.ManetTCPConnInterface); ok {
return &connWithScope{tcpconn, scope}, nil
}

View File

@@ -9,7 +9,6 @@ import (
"time"
logging "github.com/ipfs/go-log/v2"
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/net/reuseport"
@@ -28,32 +27,36 @@ var log = logging.Logger("tcp-demultiplex")
type ConnMgr struct {
enableReuseport bool
reuse reuseport.Transport
connGater connmgr.ConnectionGater
rcmgr network.ResourceManager
upgrader transport.Upgrader
mx sync.Mutex
listeners map[string]*multiplexedListener
}
func NewConnMgr(enableReuseport bool, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *ConnMgr {
if rcmgr == nil {
rcmgr = &network.NullResourceManager{}
}
func NewConnMgr(enableReuseport bool, upgrader transport.Upgrader) *ConnMgr {
return &ConnMgr{
enableReuseport: enableReuseport,
reuse: reuseport.Transport{},
connGater: gater,
rcmgr: rcmgr,
upgrader: upgrader,
listeners: make(map[string]*multiplexedListener),
}
}
func (t *ConnMgr) maListen(listenAddr ma.Multiaddr) (manet.Listener, error) {
func (t *ConnMgr) gatedMaListen(listenAddr ma.Multiaddr) (transport.GatedMaListener, error) {
var mal manet.Listener
var err error
if t.useReuseport() {
return t.reuse.Listen(listenAddr)
} else {
return manet.Listen(listenAddr)
mal, err = t.reuse.Listen(listenAddr)
if err != nil {
return nil, err
}
} else {
mal, err = manet.Listen(listenAddr)
if err != nil {
return nil, err
}
}
return t.upgrader.GateMaListener(mal), nil
}
func (t *ConnMgr) useReuseport() bool {
@@ -80,7 +83,7 @@ func getTCPAddr(listenAddr ma.Multiaddr) (ma.Multiaddr, error) {
// DemultiplexedListen returns a listener for laddr listening for `connType` connections. The connections
// accepted from returned listeners need to be upgraded with a `transport.Upgrader`.
// NOTE: All listeners for port 0 share the same underlying socket, so they have the same specific port.
func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType DemultiplexedConnType) (manet.Listener, error) {
func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType DemultiplexedConnType) (transport.GatedMaListener, error) {
if !connType.IsKnown() {
return nil, fmt.Errorf("unknown connection type: %s", connType)
}
@@ -100,7 +103,7 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed
return dl, nil
}
l, err := t.maListen(laddr)
gmal, err := t.gatedMaListen(laddr)
if err != nil {
return nil, err
}
@@ -111,19 +114,17 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed
t.mx.Lock()
defer t.mx.Unlock()
delete(t.listeners, laddr.String())
delete(t.listeners, l.Multiaddr().String())
return l.Close()
delete(t.listeners, gmal.Multiaddr().String())
return gmal.Close()
}
ml = &multiplexedListener{
Listener: l,
GatedMaListener: gmal,
listeners: make(map[DemultiplexedConnType]*demultiplexedListener),
ctx: ctx,
closeFn: cancelFunc,
connGater: t.connGater,
rcmgr: t.rcmgr,
}
t.listeners[laddr.String()] = ml
t.listeners[l.Multiaddr().String()] = ml
t.listeners[gmal.Multiaddr().String()] = ml
dl, err := ml.DemultiplexedListen(connType)
if err != nil {
@@ -137,15 +138,13 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed
return dl, nil
}
var _ manet.Listener = &demultiplexedListener{}
var _ transport.GatedMaListener = &demultiplexedListener{}
type multiplexedListener struct {
manet.Listener
transport.GatedMaListener
listeners map[DemultiplexedConnType]*demultiplexedListener
mx sync.RWMutex
connGater connmgr.ConnectionGater
rcmgr network.ResourceManager
ctx context.Context
closeFn func() error
wg sync.WaitGroup
@@ -153,7 +152,7 @@ type multiplexedListener struct {
var ErrListenerExists = errors.New("listener already exists for this conn type on this address")
func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (manet.Listener, error) {
func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (transport.GatedMaListener, error) {
if !connType.IsKnown() {
return nil, fmt.Errorf("unknown connection type: %s", connType)
}
@@ -166,8 +165,8 @@ func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType
ctx, cancel := context.WithCancel(m.ctx)
l := &demultiplexedListener{
buffer: make(chan manet.Conn),
inner: m.Listener,
buffer: make(chan *connWithScope),
inner: m.GatedMaListener,
ctx: ctx,
cancelFunc: cancel,
closeFn: func() error { m.removeDemultiplexedListener(connType); return nil },
@@ -183,53 +182,35 @@ func (m *multiplexedListener) run() error {
defer m.wg.Done()
acceptQueue := make(chan struct{}, acceptQueueSize)
for {
c, err := m.Listener.Accept()
c, connScope, err := m.GatedMaListener.Accept()
if err != nil {
return err
}
// Gate and resource limit the connection here.
// If done after sampling the connection, we'll be vulnerable to DOS attacks by a single peer
// which clogs up our entire connection queue.
// This duplicates the responsibility of gating and resource limiting between here and the upgrader. The
// alternative without duplication requires moving the process of upgrading the connection here, which forces
// us to establish the websocket connection here. That is more duplication, or a significant breaking change.
//
// Bugs around multiple calls to OpenConnection or InterceptAccept are prevented by the transport
// integration tests.
if m.connGater != nil && !m.connGater.InterceptAccept(c) {
log.Debugf("gater blocked incoming connection on local addr %s from %s",
c.LocalMultiaddr(), c.RemoteMultiaddr())
if err := c.Close(); err != nil {
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
}
continue
}
connScope, err := m.rcmgr.OpenConnection(network.DirInbound, true, c.RemoteMultiaddr())
if err != nil {
log.Debugw("resource manager blocked accept of new connection", "error", err)
if err := c.Close(); err != nil {
log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err)
}
continue
}
ctx, cancelCtx := context.WithTimeout(m.ctx, acceptTimeout)
select {
case acceptQueue <- struct{}{}:
// NOTE: We can drop the connection, but this is similar to the behaviour in the upgrader.
case <-m.ctx.Done():
case <-ctx.Done():
cancelCtx()
connScope.Done()
c.Close()
log.Debugf("accept queue full, dropping connection: %s", c.RemoteMultiaddr())
continue
case <-m.ctx.Done():
cancelCtx()
connScope.Done()
c.Close()
log.Debugf("listener closed; dropping connection from: %s", c.RemoteMultiaddr())
continue
}
m.wg.Add(1)
go func() {
defer func() { <-acceptQueue }()
defer m.wg.Done()
ctx, cancelCtx := context.WithTimeout(m.ctx, acceptTimeout)
defer cancelCtx()
t, c, err := identifyConnType(c)
if err != nil {
// conn closed by identifyConnType
connScope.Done()
log.Debugf("error demultiplexing connection: %s", err.Error())
return
@@ -279,7 +260,7 @@ func (m *multiplexedListener) Close() error {
}
func (m *multiplexedListener) closeListener() error {
lerr := m.Listener.Close()
lerr := m.GatedMaListener.Close()
cerr := m.closeFn()
return errors.Join(lerr, cerr)
}
@@ -298,19 +279,19 @@ func (m *multiplexedListener) removeDemultiplexedListener(c DemultiplexedConnTyp
}
type demultiplexedListener struct {
buffer chan manet.Conn
inner manet.Listener
buffer chan *connWithScope
inner transport.GatedMaListener
ctx context.Context
cancelFunc context.CancelFunc
closeFn func() error
}
func (m *demultiplexedListener) Accept() (manet.Conn, error) {
func (m *demultiplexedListener) Accept() (manet.Conn, network.ConnManagementScope, error) {
select {
case c := <-m.buffer:
return c, nil
return c.ManetTCPConnInterface, c.ConnScope, nil
case <-m.ctx.Done():
return nil, transport.ErrListenerClosed
return nil, nil, transport.ErrListenerClosed
}
}

View File

@@ -17,6 +17,9 @@ import (
"time"
"github.com/gorilla/websocket"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport"
tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multistream"
@@ -53,6 +56,17 @@ func selfSignedTLSConfig(t *testing.T) *tls.Config {
return tlsConfig
}
type maListener struct {
transport.GatedMaListener
}
var _ manet.Listener = &maListener{}
func (ml *maListener) Accept() (manet.Conn, error) {
c, _, err := ml.GatedMaListener.Accept()
return c, err
}
type wsHandler struct{ conns chan *websocket.Conn }
func (wh wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -61,12 +75,19 @@ func (wh wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
wh.conns <- c
}
func upgrader(t *testing.T) transport.Upgrader {
t.Helper()
upd, err := tptu.New(nil, nil, nil, &network.NullResourceManager{}, nil)
require.NoError(t, err)
return upd
}
func TestListenerSingle(t *testing.T) {
listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0")
const N = 64
for _, enableReuseport := range []bool{true, false} {
t.Run(fmt.Sprintf("multistream-reuseport:%v", enableReuseport), func(t *testing.T) {
cm := NewConnMgr(enableReuseport, nil, nil)
cm := NewConnMgr(enableReuseport, upgrader(t))
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
require.NoError(t, err)
go func() {
@@ -96,7 +117,7 @@ func TestListenerSingle(t *testing.T) {
var wg sync.WaitGroup
for i := 0; i < N; i++ {
c, err := l.Accept()
c, _, err := l.Accept()
require.NoError(t, err)
wg.Add(1)
go func() {
@@ -117,12 +138,12 @@ func TestListenerSingle(t *testing.T) {
})
t.Run(fmt.Sprintf("WebSocket-reuseport:%v", enableReuseport), func(t *testing.T) {
cm := NewConnMgr(enableReuseport, nil, nil)
cm := NewConnMgr(enableReuseport, upgrader(t))
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
require.NoError(t, err)
wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
go func() {
http.Serve(manet.NetListener(l), wh)
http.Serve(manet.NetListener(&maListener{GatedMaListener: l}), wh)
}()
go func() {
d := websocket.Dialer{}
@@ -169,14 +190,14 @@ func TestListenerSingle(t *testing.T) {
})
t.Run(fmt.Sprintf("WebSocketTLS-reuseport:%v", enableReuseport), func(t *testing.T) {
cm := NewConnMgr(enableReuseport, nil, nil)
cm := NewConnMgr(enableReuseport, upgrader(t))
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS)
require.NoError(t, err)
defer l.Close()
wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
go func() {
s := http.Server{Handler: wh, TLSConfig: selfSignedTLSConfig(t)}
s.ServeTLS(manet.NetListener(l), "", "")
s.ServeTLS(manet.NetListener(&maListener{GatedMaListener: l}), "", "")
}()
go func() {
d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
@@ -228,7 +249,7 @@ func TestListenerMultiplexed(t *testing.T) {
listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0")
const N = 20
for _, enableReuseport := range []bool{true, false} {
cm := NewConnMgr(enableReuseport, nil, nil)
cm := NewConnMgr(enableReuseport, upgrader(t))
msl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
require.NoError(t, err)
defer msl.Close()
@@ -239,7 +260,7 @@ func TestListenerMultiplexed(t *testing.T) {
require.Equal(t, wsl.Multiaddr(), msl.Multiaddr())
wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
go func() {
http.Serve(manet.NetListener(wsl), wh)
http.Serve(manet.NetListener(&maListener{GatedMaListener: wsl}), wh)
}()
wssl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS)
@@ -249,7 +270,7 @@ func TestListenerMultiplexed(t *testing.T) {
whs := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
go func() {
s := http.Server{Handler: whs, TLSConfig: selfSignedTLSConfig(t)}
s.ServeTLS(manet.NetListener(wssl), "", "")
s.ServeTLS(manet.NetListener(&maListener{GatedMaListener: wssl}), "", "")
}()
// multistream connections
@@ -331,7 +352,7 @@ func TestListenerMultiplexed(t *testing.T) {
go func() {
defer wg.Done()
for i := 0; i < N; i++ {
c, err := msl.Accept()
c, _, err := msl.Accept()
if !assert.NoError(t, err) {
return
}
@@ -404,7 +425,7 @@ func TestListenerMultiplexed(t *testing.T) {
func TestListenerClose(t *testing.T) {
testClose := func(listenAddr ma.Multiaddr) {
// listen on port 0
cm := NewConnMgr(false, nil, nil)
cm := NewConnMgr(false, upgrader(t))
ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
require.NoError(t, err)
wl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
@@ -459,7 +480,7 @@ func setDeferReset[T any](t testing.TB, ptr *T, val T) {
func TestHitTimeout(t *testing.T) {
setDeferReset(t, &identifyConnTimeout, 100*time.Millisecond)
// listen on port 0
cm := NewConnMgr(false, nil, nil)
cm := NewConnMgr(false, upgrader(t))
listenAddr := ma.StringCast("/ip4/127.0.0.1/tcp/0")
ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)

View File

@@ -4,8 +4,6 @@ import (
"net/url"
"testing"
"github.com/stretchr/testify/require"
ma "github.com/multiformats/go-multiaddr"
)
@@ -67,15 +65,3 @@ func TestConvertWebsocketMultiaddrToNetAddr(t *testing.T) {
t.Fatalf("expected network: \"websocket\", got \"%s\"", wsaddr.Network())
}
}
func TestListeningOnDNSAddr(t *testing.T) {
ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil, nil)
require.NoError(t, err)
addr := ln.Multiaddr()
first, rest := ma.SplitFirst(addr)
require.Equal(t, ma.P_DNS, first.Protocol().Code)
require.Equal(t, "localhost", first.Value())
next, _ := ma.SplitFirst(rest)
require.Equal(t, ma.P_TCP, next.Protocol().Code)
require.NotEqual(t, 0, next.Value())
}

View File

@@ -1,7 +1,6 @@
package websocket
import (
"crypto/tls"
"errors"
"io"
"net"
@@ -9,7 +8,6 @@ import (
"time"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
@@ -23,6 +21,7 @@ var GracefulCloseTimeout = 100 * time.Millisecond
// Conn implements net.Conn interface for gorilla/websocket.
type Conn struct {
*ws.Conn
Scope network.ConnManagementScope
secure bool
DefaultMessageType int
reader io.Reader
@@ -36,10 +35,8 @@ type Conn struct {
var _ net.Conn = (*Conn)(nil)
var _ manet.Conn = (*Conn)(nil)
// NewConn creates a Conn given a regular gorilla/websocket Conn.
//
// Deprecated: There's no reason to use this method externally. It'll be unexported in a future release.
func NewConn(raw *ws.Conn, secure bool) *Conn {
// newConn creates a Conn given a regular gorilla/websocket Conn.
func newConn(raw *ws.Conn, secure bool, scope network.ConnManagementScope) *Conn {
lna := NewAddrWithScheme(raw.LocalAddr().String(), secure)
laddr, err := manet.FromNetAddr(lna)
if err != nil {
@@ -56,6 +53,7 @@ func NewConn(raw *ws.Conn, secure bool) *Conn {
c := &Conn{
Conn: raw,
Scope: scope,
secure: secure,
DefaultMessageType: ws.BinaryMessage,
laddr: laddr,
@@ -136,23 +134,6 @@ func (c *Conn) Write(b []byte) (n int, err error) {
return len(b), nil
}
func (c *Conn) Scope() network.ConnManagementScope {
nc := c.NetConn()
if sc, ok := nc.(interface {
Scope() network.ConnManagementScope
}); ok {
return sc.Scope()
}
if nc, ok := nc.(*tls.Conn); ok {
if sc, ok := nc.NetConn().(interface {
Scope() network.ConnManagementScope
}); ok {
return sc.Scope()
}
}
return nil
}
// Close closes the connection.
// subsequent and concurrent calls will return the same error value.
// This method is thread-safe.
@@ -201,13 +182,3 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
return c.Conn.SetWriteDeadline(t)
}
type capableConn struct {
transport.CapableConn
}
func (c *capableConn) ConnState() network.ConnectionState {
cs := c.CapableConn.ConnState()
cs.Transport = "websocket"
return cs
}

View File

@@ -1,17 +1,22 @@
package websocket
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"sync"
"time"
"go.uber.org/zap"
ws "github.com/gorilla/websocket"
logging "github.com/ipfs/go-log/v2"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
@@ -23,8 +28,9 @@ var log = logging.Logger("websocket-transport")
var stdLog = zap.NewStdLog(log.Desugar())
type listener struct {
nl net.Listener
netListener *httpNetListener
server http.Server
wsUpgrader ws.Upgrader
// The Go standard library sets the http.Server.TLSConfig no matter if this is a WS or WSS,
// so we can't rely on checking if server.TLSConfig is set.
isWss bool
@@ -36,8 +42,11 @@ type listener struct {
closeOnce sync.Once
closeErr error
closed chan struct{}
wsurl *url.URL
}
var _ transport.GatedMaListener = &listener{}
func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr {
if !pwma.isWSS {
return pwma.restMultiaddr.AppendComponent(wsComponent)
@@ -52,7 +61,7 @@ func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr {
// newListener creates a new listener from a raw net.Listener.
// tlsConf may be nil (for unencrypted websockets).
func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr) (*listener, error) {
func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr, upgrader transport.Upgrader, handshakeTimeout time.Duration) (*listener, error) {
parsed, err := parseWebsocketMultiaddr(a)
if err != nil {
return nil, err
@@ -62,17 +71,13 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMg
return nil, fmt.Errorf("cannot listen on wss address %s without a tls.Config", a)
}
var nl net.Listener
var gmal transport.GatedMaListener
if sharedTcp == nil {
lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr)
if err != nil {
return nil, err
}
nl, err = net.Listen(lnet, lnaddr)
mal, err := manet.Listen(parsed.restMultiaddr)
if err != nil {
return nil, err
}
gmal = upgrader.GateMaListener(mal)
} else {
var connType tcpreuse.DemultiplexedConnType
if parsed.isWSS {
@@ -80,89 +85,146 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMg
} else {
connType = tcpreuse.DemultiplexedConnType_HTTP
}
mal, err := sharedTcp.DemultiplexedListen(parsed.restMultiaddr, connType)
gmal, err = sharedTcp.DemultiplexedListen(parsed.restMultiaddr, connType)
if err != nil {
return nil, err
}
nl = manet.NetListener(mal)
}
laddr, err := manet.FromNetAddr(nl.Addr())
if err != nil {
return nil, err
}
// laddr has the correct port in case we listened on port 0
laddr := gmal.Multiaddr()
first, _ := ma.SplitFirst(a)
// Don't resolve dns addresses.
// We want to be able to announce domain names, so the peer can validate the TLS certificate.
first, _ := ma.SplitFirst(a)
if c := first.Protocol().Code; c == ma.P_DNS || c == ma.P_DNS4 || c == ma.P_DNS6 || c == ma.P_DNSADDR {
_, last := ma.SplitFirst(laddr)
laddr = first.Encapsulate(last)
}
parsed.restMultiaddr = laddr
listenAddr := parsed.toMultiaddr()
wsurl, err := parseMultiaddr(listenAddr)
if err != nil {
gmal.Close()
return nil, fmt.Errorf("failed to parse multiaddr to URL: %v: %w", listenAddr, err)
}
ln := &listener{
nl: nl,
netListener: &httpNetListener{
GatedMaListener: gmal,
handshakeTimeout: handshakeTimeout,
},
laddr: parsed.toMultiaddr(),
incoming: make(chan *Conn),
closed: make(chan struct{}),
isWss: parsed.isWSS,
wsurl: wsurl,
wsUpgrader: ws.Upgrader{
// Allow requests from *all* origins.
CheckOrigin: func(r *http.Request) bool {
return true
},
HandshakeTimeout: handshakeTimeout,
},
}
ln.server = http.Server{Handler: ln, ErrorLog: stdLog}
if parsed.isWSS {
ln.isWss = true
ln.server.TLSConfig = tlsConf
}
ln.server = http.Server{Handler: ln, ErrorLog: stdLog, ConnContext: ln.ConnContext, TLSConfig: tlsConf}
return ln, nil
}
func (l *listener) serve() {
defer close(l.closed)
if !l.isWss {
l.server.Serve(l.nl)
l.server.Serve(l.netListener)
} else {
l.server.ServeTLS(l.nl, "", "")
l.server.ServeTLS(l.netListener, "", "")
}
}
type connKey struct{}
func (l *listener) ConnContext(ctx context.Context, c net.Conn) context.Context {
// prefer `*tls.Conn` over `(interface{NetConn() net.Conn})` in case `manet.Conn` is extended
// to support a `NetConn() net.Conn` method.
if tc, ok := c.(*tls.Conn); ok {
c = tc.NetConn()
}
if nc, ok := c.(*negotiatingConn); ok {
return context.WithValue(ctx, connKey{}, nc)
}
log.Errorf("BUG: expected net.Conn of type *websocket.negotiatingConn: got %T", c)
// might as well close the connection as there's no way to proceed now.
c.Close()
return ctx
}
func (l *listener) extractConnFromContext(ctx context.Context) (*negotiatingConn, error) {
c := ctx.Value(connKey{})
if c == nil {
return nil, fmt.Errorf("expected *websocket.negotiatingConn in context: got nil")
}
nc, ok := c.(*negotiatingConn)
if !ok {
return nil, fmt.Errorf("expected *websocket.negotiatingConn in context: got %T", c)
}
return nc, nil
}
func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
c, err := l.wsUpgrader.Upgrade(w, r, nil)
if err != nil {
// The upgrader writes a response for us.
return
}
nc := NewConn(c, l.isWss)
if nc == nil {
nc, err := l.extractConnFromContext(r.Context())
if err != nil {
c.Close()
w.WriteHeader(500)
log.Errorf("BUG: failed to extract conn from context: RemoteAddr: %s: err: %s", r.RemoteAddr, err)
return
}
cs, err := nc.Unwrap()
if err != nil {
c.Close()
w.WriteHeader(500)
log.Debugf("connection timed out from: %s", r.RemoteAddr)
return
}
conn := newConn(c, l.isWss, cs.Scope)
if conn == nil {
c.Close()
w.WriteHeader(500)
return
}
select {
case l.incoming <- nc:
case l.incoming <- conn:
case <-l.closed:
nc.Close()
conn.Close()
}
// The connection has been hijacked, it's safe to return.
}
func (l *listener) Accept() (manet.Conn, error) {
func (l *listener) Accept() (manet.Conn, network.ConnManagementScope, error) {
select {
case c, ok := <-l.incoming:
if !ok {
return nil, transport.ErrListenerClosed
return nil, nil, transport.ErrListenerClosed
}
return c, nil
return c, c.Scope, nil
case <-l.closed:
return nil, transport.ErrListenerClosed
return nil, nil, transport.ErrListenerClosed
}
}
func (l *listener) Addr() net.Addr {
return l.nl.Addr()
return &Addr{URL: l.wsurl}
}
func (l *listener) Close() error {
l.closeOnce.Do(func() {
err1 := l.nl.Close()
err1 := l.netListener.Close()
err2 := l.server.Close()
<-l.closed
l.closeErr = errors.Join(err1, err2)
@@ -174,14 +236,74 @@ func (l *listener) Multiaddr() ma.Multiaddr {
return l.laddr
}
type transportListener struct {
transport.Listener
// httpNetListener is a net.Listener that adapts a transport.GatedMaListener to a net.Listener.
// It wraps the manet.Conn, and the Scope from the underlying gated listener in a connWithScope.
type httpNetListener struct {
transport.GatedMaListener
handshakeTimeout time.Duration
}
func (l *transportListener) Accept() (transport.CapableConn, error) {
conn, err := l.Listener.Accept()
var _ net.Listener = &httpNetListener{}
func (l *httpNetListener) Accept() (net.Conn, error) {
conn, scope, err := l.GatedMaListener.Accept()
if err != nil {
if scope != nil {
log.Errorf("BUG: scope non-nil when err is non nil: %v", err)
scope.Done()
}
return nil, err
}
return &capableConn{CapableConn: conn}, nil
connWithScope := connWithScope{
Conn: conn,
Scope: scope,
}
ctx, cancel := context.WithTimeout(context.Background(), l.handshakeTimeout)
return &negotiatingConn{
connWithScope: connWithScope,
ctx: ctx,
cancelCtx: cancel,
stopClose: context.AfterFunc(ctx, func() {
connWithScope.Close()
log.Debugf("handshake timeout for conn from: %s", conn.RemoteAddr())
}),
}, nil
}
type connWithScope struct {
net.Conn
Scope network.ConnManagementScope
}
func (c connWithScope) Close() error {
c.Scope.Done()
return c.Conn.Close()
}
type negotiatingConn struct {
connWithScope
ctx context.Context
cancelCtx context.CancelFunc
stopClose func() bool
}
// Close closes the negotiating conn and the underlying connWithScope
// This will be called in case the tls handshake or websocket upgrade fails.
func (c *negotiatingConn) Close() error {
defer c.cancelCtx()
if c.stopClose != nil {
c.stopClose()
}
return c.connWithScope.Close()
}
func (c *negotiatingConn) Unwrap() (connWithScope, error) {
defer c.cancelCtx()
if c.stopClose != nil {
if !c.stopClose() {
return connWithScope{}, errors.New("timed out")
}
c.stopClose = nil
}
return c.connWithScope, nil
}

View File

@@ -5,7 +5,6 @@ import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
"github.com/libp2p/go-libp2p/core/network"
@@ -51,14 +50,6 @@ func init() {
manet.RegisterToNetAddr(ConvertWebsocketMultiaddrToNetAddr, "wss")
}
// Default gorilla upgrader
var upgrader = ws.Upgrader{
// Allow requests from *all* origins.
CheckOrigin: func(r *http.Request) bool {
return true
},
}
type Option func(*WebsocketTransport) error
// WithTLSClientConfig sets a TLS client configuration on the WebSocket Dialer. Only
@@ -81,15 +72,24 @@ func WithTLSConfig(conf *tls.Config) Option {
}
}
var defaultHandshakeTimeout = 15 * time.Second
// WithHandshakeTimeout sets a timeout for the websocket upgrade.
func WithHandshakeTimeout(timeout time.Duration) Option {
return func(t *WebsocketTransport) error {
t.handshakeTimeout = timeout
return nil
}
}
// WebsocketTransport is the actual go-libp2p transport
type WebsocketTransport struct {
upgrader transport.Upgrader
rcmgr network.ResourceManager
tlsClientConf *tls.Config
tlsConf *tls.Config
sharedTcp *tcpreuse.ConnMgr
handshakeTimeout time.Duration
}
var _ transport.Transport = (*WebsocketTransport)(nil)
@@ -103,6 +103,7 @@ func New(u transport.Upgrader, rcmgr network.ResourceManager, sharedTCP *tcpreus
rcmgr: rcmgr,
tlsClientConf: &tls.Config{},
sharedTcp: sharedTCP,
handshakeTimeout: defaultHandshakeTimeout,
}
for _, opt := range opts {
if err := opt(t); err != nil {
@@ -176,7 +177,7 @@ func (t *WebsocketTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p pee
}
func (t *WebsocketTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, connScope network.ConnManagementScope) (transport.CapableConn, error) {
macon, err := t.maDial(ctx, raddr)
macon, err := t.maDial(ctx, raddr, connScope)
if err != nil {
return nil, err
}
@@ -187,14 +188,14 @@ func (t *WebsocketTransport) dialWithScope(ctx context.Context, raddr ma.Multiad
return &capableConn{CapableConn: conn}, nil
}
func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) {
func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr, scope network.ConnManagementScope) (manet.Conn, error) {
wsurl, err := parseMultiaddr(raddr)
if err != nil {
return nil, err
}
isWss := wsurl.Scheme == "wss"
dialer := ws.Dialer{
HandshakeTimeout: 30 * time.Second,
HandshakeTimeout: t.handshakeTimeout,
// Inherit the default proxy behavior
Proxy: ws.DefaultDialer.Proxy,
}
@@ -236,7 +237,7 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma
return nil, err
}
mnc, err := manet.WrapNetConn(NewConn(wscon, isWss))
mnc, err := manet.WrapNetConn(newConn(wscon, isWss, scope))
if err != nil {
wscon.Close()
return nil, err
@@ -244,12 +245,12 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma
return mnc, nil
}
func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) {
func (t *WebsocketTransport) gatedMaListen(a ma.Multiaddr) (transport.GatedMaListener, error) {
var tlsConf *tls.Config
if t.tlsConf != nil {
tlsConf = t.tlsConf.Clone()
}
l, err := newListener(a, tlsConf, t.sharedTcp)
l, err := newListener(a, tlsConf, t.sharedTcp, t.upgrader, t.handshakeTimeout)
if err != nil {
return nil, err
}
@@ -258,9 +259,32 @@ func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) {
}
func (t *WebsocketTransport) Listen(a ma.Multiaddr) (transport.Listener, error) {
malist, err := t.maListen(a)
gmal, err := t.gatedMaListen(a)
if err != nil {
return nil, err
}
return &transportListener{Listener: t.upgrader.UpgradeListener(t, malist)}, nil
return &transportListener{Listener: t.upgrader.UpgradeGatedMaListener(t, gmal)}, nil
}
// transportListener wraps a transport.Listener to provide connections with a `ConnState() network.ConnectionState` method.
type transportListener struct {
transport.Listener
}
type capableConn struct {
transport.CapableConn
}
func (c *capableConn) ConnState() network.ConnectionState {
cs := c.CapableConn.ConnState()
cs.Transport = "websocket"
return cs
}
func (l *transportListener) Accept() (transport.CapableConn, error) {
conn, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return &capableConn{CapableConn: conn}, nil
}

View File

@@ -35,6 +35,7 @@ import (
ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -296,18 +297,39 @@ func TestDialWssNoClientCert(t *testing.T) {
}
func TestWebsocketTransport(t *testing.T) {
t.Run("/ws", func(t *testing.T) {
peerA, ua := newUpgrader(t)
ta, err := New(ua, nil, nil)
if err != nil {
t.Fatal(err)
}
_, ub := newUpgrader(t)
peerB, ub := newUpgrader(t)
tb, err := New(ub, nil, nil)
if err != nil {
t.Fatal(err)
}
ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/ws", peerA)
ttransport.SubtestTransport(t, tb, ta, "/ip4/127.0.0.1/tcp/0/ws", peerB)
})
t.Run("/wss", func(t *testing.T) {
peerA, ua := newUpgrader(t)
tca := generateTLSConfig(t)
ta, err := New(ua, nil, nil, WithTLSConfig(tca), WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
if err != nil {
t.Fatal(err)
}
peerB, ub := newUpgrader(t)
tcb := generateTLSConfig(t)
tb, err := New(ub, nil, nil, WithTLSConfig(tcb), WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
if err != nil {
t.Fatal(err)
}
ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/wss", peerA)
ttransport.SubtestTransport(t, tb, ta, "/ip4/127.0.0.1/tcp/0/ws", peerB)
})
}
func isWSS(addr ma.Multiaddr) bool {
@@ -441,7 +463,7 @@ func TestConcurrentClose(t *testing.T) {
_, u := newUpgrader(t)
tpt, err := New(u, &network.NullResourceManager{}, nil)
require.NoError(t, err)
l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
l, err := tpt.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
if err != nil {
t.Fatal(err)
}
@@ -451,7 +473,7 @@ func TestConcurrentClose(t *testing.T) {
go func() {
for i := 0; i < 100; i++ {
c, err := tpt.maDial(context.Background(), l.Multiaddr())
c, err := tpt.maDial(context.Background(), l.Multiaddr(), &network.NullScope{})
if err != nil {
t.Error(err)
return
@@ -467,7 +489,7 @@ func TestConcurrentClose(t *testing.T) {
}()
for i := 0; i < 100; i++ {
c, err := l.Accept()
c, _, err := l.Accept()
if err != nil {
t.Fatal(err)
}
@@ -481,7 +503,7 @@ func TestWriteZero(t *testing.T) {
if err != nil {
t.Fatal(err)
}
l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
l, err := tpt.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
if err != nil {
t.Fatal(err)
}
@@ -490,7 +512,7 @@ func TestWriteZero(t *testing.T) {
msg := []byte(nil)
go func() {
c, err := tpt.maDial(context.Background(), l.Multiaddr())
c, err := tpt.maDial(context.Background(), l.Multiaddr(), &network.NullScope{})
if err != nil {
t.Error(err)
return
@@ -509,7 +531,7 @@ func TestWriteZero(t *testing.T) {
}
}()
c, err := l.Accept()
c, _, err := l.Accept()
if err != nil {
t.Fatal(err)
}
@@ -623,3 +645,98 @@ func TestSocksProxy(t *testing.T) {
})
}
}
func TestListenerAddr(t *testing.T) {
_, upgrader := newUpgrader(t)
transport, err := New(upgrader, &network.NullResourceManager{}, nil, WithTLSConfig(generateTLSConfig(t)))
require.NoError(t, err)
l1, err := transport.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
require.NoError(t, err)
defer l1.Close()
require.Regexp(t, `^ws://127\.0\.0\.1:[\d]+$`, l1.Addr().String())
l2, err := transport.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/wss"))
require.NoError(t, err)
defer l2.Close()
require.Regexp(t, `^wss://127\.0\.0\.1:[\d]+$`, l2.Addr().String())
}
func TestHandshakeTimeout(t *testing.T) {
handshakeTimeout := 200 * time.Millisecond
_, upgrader := newUpgrader(t)
tlsconf := generateTLSConfig(t)
transport, err := New(upgrader, &network.NullResourceManager{}, nil, WithHandshakeTimeout(handshakeTimeout), WithTLSConfig(tlsconf))
require.NoError(t, err)
fastWSDialer := gws.Dialer{
HandshakeTimeout: 10 * handshakeTimeout,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
NetDial: func(network, addr string) (net.Conn, error) {
tcpConn, err := net.Dial("tcp", addr)
if !assert.NoError(t, err) {
return nil, err
}
return tcpConn, nil
},
}
slowWSDialer := gws.Dialer{
HandshakeTimeout: 10 * handshakeTimeout,
NetDial: func(network, addr string) (net.Conn, error) {
tcpConn, err := net.Dial("tcp", addr)
if !assert.NoError(t, err) {
return nil, err
}
// wait to simulate a slow handshake
time.Sleep(2 * handshakeTimeout)
return tcpConn, nil
},
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
t.Run("ws", func(t *testing.T) {
// test the gatedMaListener as we're interested in the websocket handshake timeout and not the upgrader steps.
wsListener, err := transport.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
require.NoError(t, err)
defer wsListener.Close()
ctx, cancel := context.WithTimeout(context.Background(), 10*handshakeTimeout)
defer cancel()
conn, resp, err := fastWSDialer.DialContext(ctx, wsListener.Addr().String(), nil)
if !assert.NoError(t, err) {
return
}
conn.Close()
resp.Body.Close()
ctx, cancel = context.WithTimeout(context.Background(), 10*handshakeTimeout)
defer cancel()
conn, resp, err = slowWSDialer.DialContext(ctx, wsListener.Addr().String(), nil)
if err == nil {
conn.Close()
resp.Body.Close()
t.Fatal("should error as the handshake will time out")
}
})
t.Run("wss", func(t *testing.T) {
// test the gatedMaListener as we're interested in the websocket handshake timeout and not the upgrader steps.
wsListener, err := transport.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/wss"))
require.NoError(t, err)
defer wsListener.Close()
// Test that the normal dial works fine
ctx, cancel := context.WithTimeout(context.Background(), 10*handshakeTimeout)
defer cancel()
wsConn, resp, err := fastWSDialer.DialContext(ctx, wsListener.Addr().String(), nil)
require.NoError(t, err)
wsConn.Close()
resp.Body.Close()
ctx, cancel = context.WithTimeout(context.Background(), 10*handshakeTimeout)
defer cancel()
wsConn, resp, err = slowWSDialer.DialContext(ctx, wsListener.Addr().String(), nil)
if err == nil {
wsConn.Close()
resp.Body.Close()
t.Fatal("websocket handshake should have timed out")
}
})
}