mirror of
https://github.com/libp2p/go-libp2p.git
synced 2025-09-26 20:21:26 +08:00
transport: add GatedMaListener type (#3186)
This introduces a new GatedMaListener type which gates conns accepted from a manet.Listener with a gater and creates the rcmgr scope for it. Explicitly passing the scope allows for many guardrails that the previous interface assertion didn't. This breaks the previous responsibility of the upgradeListener method into two, one gating the connection initially, and the other upgrading the connection with a security and muxer selection. This split makes it easy to gate the connection with the resource manager as early as possible. This is especially true for websocket because we want to gate the connection just after the TCP connection is established, and not after the tls handshake + websocket upgrade is completed.
This commit is contained in:
@@ -292,11 +292,11 @@ func (cfg *Config) addTransports() ([]fx.Option, error) {
|
||||
fx.Provide(func() connmgr.ConnectionGater { return cfg.ConnectionGater }),
|
||||
fx.Provide(func() pnet.PSK { return cfg.PSK }),
|
||||
fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }),
|
||||
fx.Provide(func(gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *tcpreuse.ConnMgr {
|
||||
fx.Provide(func(upgrader transport.Upgrader) *tcpreuse.ConnMgr {
|
||||
if !cfg.ShareTCPListener {
|
||||
return nil
|
||||
}
|
||||
return tcpreuse.NewConnMgr(tcpreuse.EnvReuseportVal, gater, rcmgr)
|
||||
return tcpreuse.NewConnMgr(tcpreuse.EnvReuseportVal, upgrader)
|
||||
}),
|
||||
fx.Provide(func(cm *quicreuse.ConnManager, sw *swarm.Swarm) libp2pwebrtc.ListenUDPFn {
|
||||
hasQuicAddrPortFor := func(network string, laddr *net.UDPAddr) bool {
|
||||
|
@@ -129,11 +129,41 @@ type TransportNetwork interface {
|
||||
AddTransport(t Transport) error
|
||||
}
|
||||
|
||||
// GatedMaListener is listener that listens for raw(unsecured and non-multiplexed) incoming connections,
|
||||
// gates them with a `connmgr.ConnGater`and creates a resource management scope for them.
|
||||
// It can be upgraded to a full libp2p transport listener by the Upgrader.
|
||||
//
|
||||
// Compared to manet.Listener, this listener creates the resource management scope for the accepted connection.
|
||||
type GatedMaListener interface {
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
Accept() (manet.Conn, network.ConnManagementScope, error)
|
||||
|
||||
// Close closes the listener.
|
||||
// Any blocked Accept operations will be unblocked and return errors.
|
||||
Close() error
|
||||
|
||||
// Multiaddr returns the listener's (local) Multiaddr.
|
||||
Multiaddr() ma.Multiaddr
|
||||
|
||||
// Addr returns the net.Listener's network address.
|
||||
Addr() net.Addr
|
||||
}
|
||||
|
||||
// Upgrader is a multistream upgrader that can upgrade an underlying connection
|
||||
// to a full transport connection (secure and multiplexed).
|
||||
type Upgrader interface {
|
||||
// UpgradeListener upgrades the passed multiaddr-net listener into a full libp2p-transport listener.
|
||||
//
|
||||
// Deprecated: Use UpgradeGatedMaListener(upgrader.GateMaListener(manet.Listener)) instead.
|
||||
UpgradeListener(Transport, manet.Listener) Listener
|
||||
|
||||
// GateMaListener creates a GatedMaListener from a manet.Listener. It gates the accepted connection
|
||||
// and creates a resource scope for it.
|
||||
GateMaListener(manet.Listener) GatedMaListener
|
||||
|
||||
// UpgradeGatedMaListener upgrades the passed GatedMaListener into a full libp2p-transport listener.
|
||||
UpgradeGatedMaListener(Transport, GatedMaListener) Listener
|
||||
|
||||
// Upgrade upgrades the multiaddr/net connection into a full libp2p-transport connection.
|
||||
Upgrade(ctx context.Context, t Transport, maconn manet.Conn, dir network.Direction, p peer.ID, scope network.ConnManagementScope) (CapableConn, error)
|
||||
}
|
||||
|
@@ -6,6 +6,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/connmgr"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/transport"
|
||||
|
||||
@@ -17,7 +18,7 @@ import (
|
||||
var log = logging.Logger("upgrader")
|
||||
|
||||
type listener struct {
|
||||
manet.Listener
|
||||
transport.GatedMaListener
|
||||
|
||||
transport transport.Transport
|
||||
upgrader *upgrader
|
||||
@@ -35,10 +36,12 @@ type listener struct {
|
||||
cancel func()
|
||||
}
|
||||
|
||||
var _ transport.Listener = (*listener)(nil)
|
||||
|
||||
// Close closes the listener.
|
||||
func (l *listener) Close() error {
|
||||
// Do this first to try to get any relevant errors.
|
||||
err := l.Listener.Close()
|
||||
err := l.GatedMaListener.Close()
|
||||
|
||||
l.cancel()
|
||||
// Drain and wait.
|
||||
@@ -61,7 +64,7 @@ func (l *listener) handleIncoming() {
|
||||
var wg sync.WaitGroup
|
||||
defer func() {
|
||||
// make sure we're closed
|
||||
l.Listener.Close()
|
||||
l.GatedMaListener.Close()
|
||||
if l.err == nil {
|
||||
l.err = fmt.Errorf("listener closed")
|
||||
}
|
||||
@@ -72,7 +75,7 @@ func (l *listener) handleIncoming() {
|
||||
|
||||
var catcher tec.TempErrCatcher
|
||||
for l.ctx.Err() == nil {
|
||||
maconn, err := l.Listener.Accept()
|
||||
maconn, connScope, err := l.GatedMaListener.Accept()
|
||||
if err != nil {
|
||||
// Note: function may pause the accept loop.
|
||||
if catcher.IsTemporary(err) {
|
||||
@@ -84,35 +87,12 @@ func (l *listener) handleIncoming() {
|
||||
}
|
||||
catcher.Reset()
|
||||
|
||||
// Check if we already have a connection scope. See the comment in tcpreuse/listener.go for an explanation.
|
||||
var connScope network.ConnManagementScope
|
||||
if sc, ok := maconn.(interface {
|
||||
Scope() network.ConnManagementScope
|
||||
}); ok {
|
||||
connScope = sc.Scope()
|
||||
}
|
||||
if connScope == nil {
|
||||
// gate the connection if applicable
|
||||
if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) {
|
||||
log.Debugf("gater blocked incoming connection on local addr %s from %s",
|
||||
maconn.LocalMultiaddr(), maconn.RemoteMultiaddr())
|
||||
if err := maconn.Close(); err != nil {
|
||||
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
|
||||
}
|
||||
log.Errorf("BUG: got nil connScope for incoming connection from %s", maconn.RemoteMultiaddr())
|
||||
maconn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
var err error
|
||||
connScope, err = l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr())
|
||||
if err != nil {
|
||||
log.Debugw("resource manager blocked accept of new connection", "error", err)
|
||||
if err := maconn.Close(); err != nil {
|
||||
log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// The go routine below calls Release when the context is
|
||||
// canceled so there's no need to wait on it here.
|
||||
l.threshold.Wait()
|
||||
@@ -154,14 +134,10 @@ func (l *listener) handleIncoming() {
|
||||
select {
|
||||
case l.incoming <- conn:
|
||||
case <-ctx.Done():
|
||||
// Listener not closed but the accept timeout expired.
|
||||
if l.ctx.Err() == nil {
|
||||
// Listener *not* closed but the accept timeout expired.
|
||||
log.Warn("listener dropped connection due to slow accept")
|
||||
log.Warnf("listener dropped connection due to slow accept. remote addr: %s peer: %s", maconn.RemoteMultiaddr(), conn.RemotePeer())
|
||||
}
|
||||
// Wait on the context with a timeout. This way,
|
||||
// if we stop accepting connections for some reason,
|
||||
// we'll eventually close all the open ones
|
||||
// instead of hanging onto them.
|
||||
conn.CloseWithError(network.ConnRateLimited)
|
||||
}
|
||||
}()
|
||||
@@ -189,4 +165,38 @@ func (l *listener) String() string {
|
||||
return fmt.Sprintf("<stream.Listener %s>", l.Multiaddr())
|
||||
}
|
||||
|
||||
var _ transport.Listener = (*listener)(nil)
|
||||
type gatedMaListener struct {
|
||||
manet.Listener
|
||||
rcmgr network.ResourceManager
|
||||
connGater connmgr.ConnectionGater
|
||||
}
|
||||
|
||||
var _ transport.GatedMaListener = &gatedMaListener{}
|
||||
|
||||
func (l *gatedMaListener) Accept() (manet.Conn, network.ConnManagementScope, error) {
|
||||
for {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
// gate the connection if applicable
|
||||
if l.connGater != nil && !l.connGater.InterceptAccept(conn) {
|
||||
log.Debugf("gater blocked incoming connection on local addr %s from %s",
|
||||
conn.LocalMultiaddr(), conn.RemoteMultiaddr())
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
connScope, err := l.rcmgr.OpenConnection(network.DirInbound, true, conn.RemoteMultiaddr())
|
||||
if err != nil {
|
||||
log.Debugw("resource manager blocked accept of new connection", "error", err)
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
return conn, connScope, nil
|
||||
}
|
||||
}
|
||||
|
@@ -30,7 +30,7 @@ func createListener(t *testing.T, u transport.Upgrader) transport.Listener {
|
||||
require.NoError(t, err)
|
||||
ln, err := manet.Listen(addr)
|
||||
require.NoError(t, err)
|
||||
return u.UpgradeListener(nil, ln)
|
||||
return u.UpgradeGatedMaListener(nil, u.GateMaListener(ln))
|
||||
}
|
||||
|
||||
func TestAcceptSingleConn(t *testing.T) {
|
||||
|
@@ -105,9 +105,22 @@ func New(security []sec.SecureTransport, muxers []StreamMuxer, psk ipnet.PSK, rc
|
||||
|
||||
// UpgradeListener upgrades the passed multiaddr-net listener into a full libp2p-transport listener.
|
||||
func (u *upgrader) UpgradeListener(t transport.Transport, list manet.Listener) transport.Listener {
|
||||
return u.UpgradeGatedMaListener(t, u.GateMaListener(list))
|
||||
}
|
||||
|
||||
func (u *upgrader) GateMaListener(l manet.Listener) transport.GatedMaListener {
|
||||
return &gatedMaListener{
|
||||
Listener: l,
|
||||
rcmgr: u.rcmgr,
|
||||
connGater: u.connGater,
|
||||
}
|
||||
}
|
||||
|
||||
// UpgradeGatedMaListener upgrades the passed multiaddr-net listener into a full libp2p-transport listener.
|
||||
func (u *upgrader) UpgradeGatedMaListener(t transport.Transport, l transport.GatedMaListener) transport.Listener {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
l := &listener{
|
||||
Listener: list,
|
||||
list := &listener{
|
||||
GatedMaListener: l,
|
||||
upgrader: u,
|
||||
transport: t,
|
||||
rcmgr: u.rcmgr,
|
||||
@@ -116,8 +129,8 @@ func (u *upgrader) UpgradeListener(t transport.Transport, list manet.Listener) t
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
}
|
||||
go l.handleIncoming()
|
||||
return l
|
||||
go list.handleIncoming()
|
||||
return list
|
||||
}
|
||||
|
||||
// Upgrade upgrades the multiaddr/net connection into a full libp2p-transport connection.
|
||||
|
@@ -101,7 +101,7 @@ func (c *Client) Listen(addr ma.Multiaddr) (transport.Listener, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.upgrader.UpgradeListener(c, c.Listener()), nil
|
||||
return c.upgrader.UpgradeGatedMaListener(c, c.upgrader.GateMaListener(c.Listener())), nil
|
||||
}
|
||||
|
||||
func (c *Client) Protocols() []int {
|
||||
|
@@ -203,7 +203,7 @@ func TestInterceptAccept(t *testing.T) {
|
||||
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
|
||||
require.Equal(t, normalize(h2.Addrs()[0]), normalize(addrs.LocalMultiaddr()))
|
||||
}).AnyTimes()
|
||||
} else if strings.Contains(tc.Name, "WebSocket-Shared") || strings.Contains(tc.Name, "WebSocket-Secured-Shared") {
|
||||
} else if strings.Contains(tc.Name, "WebSocket") {
|
||||
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
|
||||
require.Equal(t, addrPort(h2.Addrs()[0]), addrPort(addrs.LocalMultiaddr()))
|
||||
})
|
||||
|
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/transport"
|
||||
"github.com/marten-seemann/tcp"
|
||||
"github.com/mikioh/tcpinfo"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
@@ -253,16 +254,6 @@ func (c *tracingConn) Close() error {
|
||||
return c.closeErr
|
||||
}
|
||||
|
||||
func (c *tracingConn) Scope() network.ConnManagementScope {
|
||||
if cs, ok := c.Conn.(interface {
|
||||
Scope() network.ConnManagementScope
|
||||
}); ok {
|
||||
return cs.Scope()
|
||||
}
|
||||
// upgrader is expected to handle this
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *tracingConn) getTCPInfo() (*tcpinfo.Info, error) {
|
||||
var o tcpinfo.Info
|
||||
var b [256]byte
|
||||
@@ -275,19 +266,31 @@ func (c *tracingConn) getTCPInfo() (*tcpinfo.Info, error) {
|
||||
}
|
||||
|
||||
type tracingListener struct {
|
||||
manet.Listener
|
||||
transport.GatedMaListener
|
||||
collector *aggregatingCollector
|
||||
}
|
||||
|
||||
// newTracingListener wraps a manet.Listener with a tracingListener. A nil collector will use the default collector.
|
||||
func newTracingListener(l manet.Listener, collector *aggregatingCollector) *tracingListener {
|
||||
return &tracingListener{Listener: l, collector: collector}
|
||||
func newTracingListener(l transport.GatedMaListener, collector *aggregatingCollector) *tracingListener {
|
||||
return &tracingListener{GatedMaListener: l, collector: collector}
|
||||
}
|
||||
|
||||
func (l *tracingListener) Accept() (manet.Conn, error) {
|
||||
conn, err := l.Listener.Accept()
|
||||
func (l *tracingListener) Accept() (manet.Conn, network.ConnManagementScope, error) {
|
||||
conn, scope, err := l.GatedMaListener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if scope != nil {
|
||||
scope.Done()
|
||||
log.Errorf("BUG: got non-nil scope but also an error: %s", err)
|
||||
}
|
||||
return newTracingConn(conn, l.collector, false)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
tc, err := newTracingConn(conn, l.collector, false)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create tracingConn from %T: %s", conn, err)
|
||||
conn.Close()
|
||||
scope.Done()
|
||||
return nil, nil, err
|
||||
}
|
||||
return tc, scope, nil
|
||||
}
|
||||
|
@@ -4,11 +4,16 @@
|
||||
|
||||
package tcp
|
||||
|
||||
import manet "github.com/multiformats/go-multiaddr/net"
|
||||
import (
|
||||
"github.com/libp2p/go-libp2p/core/transport"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
)
|
||||
|
||||
type aggregatingCollector struct{}
|
||||
|
||||
func newTracingConn(c manet.Conn, collector *aggregatingCollector, isClient bool) (manet.Conn, error) {
|
||||
return c, nil
|
||||
}
|
||||
func newTracingListener(l manet.Listener, collector *aggregatingCollector) manet.Listener { return l }
|
||||
func newTracingListener(l transport.GatedMaListener, collector *aggregatingCollector) transport.GatedMaListener {
|
||||
return l
|
||||
}
|
||||
|
@@ -15,8 +15,10 @@ func TestTcpTransportCollectsMetricsWithSharedTcpSocket(t *testing.T) {
|
||||
peerA, ia := makeInsecureMuxer(t)
|
||||
_, ib := makeInsecureMuxer(t)
|
||||
|
||||
sharedTCPSocketA := tcpreuse.NewConnMgr(false, nil, nil)
|
||||
sharedTCPSocketB := tcpreuse.NewConnMgr(false, nil, nil)
|
||||
upg, err := tptu.New(ia, muxers, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
sharedTCPSocketA := tcpreuse.NewConnMgr(false, upg)
|
||||
sharedTCPSocketB := tcpreuse.NewConnMgr(false, upg)
|
||||
|
||||
ua, err := tptu.New(ia, muxers, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
@@ -41,7 +41,7 @@ var ReuseportIsAvailable = tcpreuse.ReuseportIsAvailable
|
||||
func tryKeepAlive(conn net.Conn, keepAlive bool) {
|
||||
keepAliveConn, ok := conn.(canKeepAlive)
|
||||
if !ok {
|
||||
log.Errorf("Can't set TCP keepalives.")
|
||||
log.Errorf("can't set TCP keepalives. net.Conn of type %T doesn't support SetKeepAlive", conn)
|
||||
return
|
||||
}
|
||||
if err := keepAliveConn.SetKeepAlive(keepAlive); err != nil {
|
||||
@@ -76,23 +76,23 @@ func tryLinger(conn net.Conn, sec int) {
|
||||
}
|
||||
}
|
||||
|
||||
type tcpListener struct {
|
||||
manet.Listener
|
||||
type tcpGatedMaListener struct {
|
||||
transport.GatedMaListener
|
||||
sec int
|
||||
}
|
||||
|
||||
func (ll *tcpListener) Accept() (manet.Conn, error) {
|
||||
c, err := ll.Listener.Accept()
|
||||
func (ll *tcpGatedMaListener) Accept() (manet.Conn, network.ConnManagementScope, error) {
|
||||
c, scope, err := ll.GatedMaListener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if scope != nil {
|
||||
log.Errorf("BUG: got non-nil scope but also an error: %s", err)
|
||||
scope.Done()
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
tryLinger(c, ll.sec)
|
||||
tryKeepAlive(c, true)
|
||||
// We're not calling OpenConnection in the resource manager here,
|
||||
// since the manet.Conn doesn't allow us to save the scope.
|
||||
// It's the caller's (usually the p2p/net/upgrader) responsibility
|
||||
// to call the resource manager.
|
||||
return c, nil
|
||||
return c, scope, nil
|
||||
}
|
||||
|
||||
type Option func(*TcpTransport) error
|
||||
@@ -316,22 +316,26 @@ func (t *TcpTransport) unsharedMAListen(laddr ma.Multiaddr) (manet.Listener, err
|
||||
|
||||
// Listen listens on the given multiaddr.
|
||||
func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) {
|
||||
var list manet.Listener
|
||||
var list transport.GatedMaListener
|
||||
var err error
|
||||
|
||||
if t.sharedTcp == nil {
|
||||
list, err = t.unsharedMAListen(laddr)
|
||||
} else {
|
||||
if t.sharedTcp != nil {
|
||||
list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.DemultiplexedConnType_MultistreamSelect)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
mal, err := t.unsharedMAListen(laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = t.upgrader.GateMaListener(mal)
|
||||
}
|
||||
|
||||
if t.enableMetrics {
|
||||
list = newTracingListener(&tcpListener{list, 0}, t.metricsCollector)
|
||||
// TODO: Fix this: The tcpListener wrapping should happen on both enableMetrics and disabledMetrics path
|
||||
list = newTracingListener(&tcpGatedMaListener{list, 0}, t.metricsCollector)
|
||||
}
|
||||
return t.upgrader.UpgradeListener(t, list), nil
|
||||
return t.upgrader.UpgradeGatedMaListener(t, list), nil
|
||||
}
|
||||
|
||||
// Protocols returns the list of terminal protocols this transport can dial.
|
||||
|
@@ -10,19 +10,15 @@ import (
|
||||
|
||||
type connWithScope struct {
|
||||
sampledconn.ManetTCPConnInterface
|
||||
scope network.ConnManagementScope
|
||||
}
|
||||
|
||||
func (c connWithScope) Scope() network.ConnManagementScope {
|
||||
return c.scope
|
||||
ConnScope network.ConnManagementScope
|
||||
}
|
||||
|
||||
func (c *connWithScope) Close() error {
|
||||
c.scope.Done()
|
||||
defer c.ConnScope.Done()
|
||||
return c.ManetTCPConnInterface.Close()
|
||||
}
|
||||
|
||||
func manetConnWithScope(c manet.Conn, scope network.ConnManagementScope) (manet.Conn, error) {
|
||||
func manetConnWithScope(c manet.Conn, scope network.ConnManagementScope) (*connWithScope, error) {
|
||||
if tcpconn, ok := c.(sampledconn.ManetTCPConnInterface); ok {
|
||||
return &connWithScope{tcpconn, scope}, nil
|
||||
}
|
||||
|
@@ -9,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
logging "github.com/ipfs/go-log/v2"
|
||||
"github.com/libp2p/go-libp2p/core/connmgr"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/transport"
|
||||
"github.com/libp2p/go-libp2p/p2p/net/reuseport"
|
||||
@@ -28,32 +27,36 @@ var log = logging.Logger("tcp-demultiplex")
|
||||
type ConnMgr struct {
|
||||
enableReuseport bool
|
||||
reuse reuseport.Transport
|
||||
connGater connmgr.ConnectionGater
|
||||
rcmgr network.ResourceManager
|
||||
upgrader transport.Upgrader
|
||||
|
||||
mx sync.Mutex
|
||||
listeners map[string]*multiplexedListener
|
||||
}
|
||||
|
||||
func NewConnMgr(enableReuseport bool, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *ConnMgr {
|
||||
if rcmgr == nil {
|
||||
rcmgr = &network.NullResourceManager{}
|
||||
}
|
||||
func NewConnMgr(enableReuseport bool, upgrader transport.Upgrader) *ConnMgr {
|
||||
return &ConnMgr{
|
||||
enableReuseport: enableReuseport,
|
||||
reuse: reuseport.Transport{},
|
||||
connGater: gater,
|
||||
rcmgr: rcmgr,
|
||||
upgrader: upgrader,
|
||||
listeners: make(map[string]*multiplexedListener),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ConnMgr) maListen(listenAddr ma.Multiaddr) (manet.Listener, error) {
|
||||
func (t *ConnMgr) gatedMaListen(listenAddr ma.Multiaddr) (transport.GatedMaListener, error) {
|
||||
var mal manet.Listener
|
||||
var err error
|
||||
if t.useReuseport() {
|
||||
return t.reuse.Listen(listenAddr)
|
||||
} else {
|
||||
return manet.Listen(listenAddr)
|
||||
mal, err = t.reuse.Listen(listenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
mal, err = manet.Listen(listenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return t.upgrader.GateMaListener(mal), nil
|
||||
}
|
||||
|
||||
func (t *ConnMgr) useReuseport() bool {
|
||||
@@ -80,7 +83,7 @@ func getTCPAddr(listenAddr ma.Multiaddr) (ma.Multiaddr, error) {
|
||||
// DemultiplexedListen returns a listener for laddr listening for `connType` connections. The connections
|
||||
// accepted from returned listeners need to be upgraded with a `transport.Upgrader`.
|
||||
// NOTE: All listeners for port 0 share the same underlying socket, so they have the same specific port.
|
||||
func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType DemultiplexedConnType) (manet.Listener, error) {
|
||||
func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType DemultiplexedConnType) (transport.GatedMaListener, error) {
|
||||
if !connType.IsKnown() {
|
||||
return nil, fmt.Errorf("unknown connection type: %s", connType)
|
||||
}
|
||||
@@ -100,7 +103,7 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed
|
||||
return dl, nil
|
||||
}
|
||||
|
||||
l, err := t.maListen(laddr)
|
||||
gmal, err := t.gatedMaListen(laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -111,19 +114,17 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed
|
||||
t.mx.Lock()
|
||||
defer t.mx.Unlock()
|
||||
delete(t.listeners, laddr.String())
|
||||
delete(t.listeners, l.Multiaddr().String())
|
||||
return l.Close()
|
||||
delete(t.listeners, gmal.Multiaddr().String())
|
||||
return gmal.Close()
|
||||
}
|
||||
ml = &multiplexedListener{
|
||||
Listener: l,
|
||||
GatedMaListener: gmal,
|
||||
listeners: make(map[DemultiplexedConnType]*demultiplexedListener),
|
||||
ctx: ctx,
|
||||
closeFn: cancelFunc,
|
||||
connGater: t.connGater,
|
||||
rcmgr: t.rcmgr,
|
||||
}
|
||||
t.listeners[laddr.String()] = ml
|
||||
t.listeners[l.Multiaddr().String()] = ml
|
||||
t.listeners[gmal.Multiaddr().String()] = ml
|
||||
|
||||
dl, err := ml.DemultiplexedListen(connType)
|
||||
if err != nil {
|
||||
@@ -137,15 +138,13 @@ func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType Demultiplexed
|
||||
return dl, nil
|
||||
}
|
||||
|
||||
var _ manet.Listener = &demultiplexedListener{}
|
||||
var _ transport.GatedMaListener = &demultiplexedListener{}
|
||||
|
||||
type multiplexedListener struct {
|
||||
manet.Listener
|
||||
transport.GatedMaListener
|
||||
listeners map[DemultiplexedConnType]*demultiplexedListener
|
||||
mx sync.RWMutex
|
||||
|
||||
connGater connmgr.ConnectionGater
|
||||
rcmgr network.ResourceManager
|
||||
ctx context.Context
|
||||
closeFn func() error
|
||||
wg sync.WaitGroup
|
||||
@@ -153,7 +152,7 @@ type multiplexedListener struct {
|
||||
|
||||
var ErrListenerExists = errors.New("listener already exists for this conn type on this address")
|
||||
|
||||
func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (manet.Listener, error) {
|
||||
func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (transport.GatedMaListener, error) {
|
||||
if !connType.IsKnown() {
|
||||
return nil, fmt.Errorf("unknown connection type: %s", connType)
|
||||
}
|
||||
@@ -166,8 +165,8 @@ func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType
|
||||
|
||||
ctx, cancel := context.WithCancel(m.ctx)
|
||||
l := &demultiplexedListener{
|
||||
buffer: make(chan manet.Conn),
|
||||
inner: m.Listener,
|
||||
buffer: make(chan *connWithScope),
|
||||
inner: m.GatedMaListener,
|
||||
ctx: ctx,
|
||||
cancelFunc: cancel,
|
||||
closeFn: func() error { m.removeDemultiplexedListener(connType); return nil },
|
||||
@@ -183,53 +182,35 @@ func (m *multiplexedListener) run() error {
|
||||
defer m.wg.Done()
|
||||
acceptQueue := make(chan struct{}, acceptQueueSize)
|
||||
for {
|
||||
c, err := m.Listener.Accept()
|
||||
c, connScope, err := m.GatedMaListener.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Gate and resource limit the connection here.
|
||||
// If done after sampling the connection, we'll be vulnerable to DOS attacks by a single peer
|
||||
// which clogs up our entire connection queue.
|
||||
// This duplicates the responsibility of gating and resource limiting between here and the upgrader. The
|
||||
// alternative without duplication requires moving the process of upgrading the connection here, which forces
|
||||
// us to establish the websocket connection here. That is more duplication, or a significant breaking change.
|
||||
//
|
||||
// Bugs around multiple calls to OpenConnection or InterceptAccept are prevented by the transport
|
||||
// integration tests.
|
||||
if m.connGater != nil && !m.connGater.InterceptAccept(c) {
|
||||
log.Debugf("gater blocked incoming connection on local addr %s from %s",
|
||||
c.LocalMultiaddr(), c.RemoteMultiaddr())
|
||||
if err := c.Close(); err != nil {
|
||||
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
connScope, err := m.rcmgr.OpenConnection(network.DirInbound, true, c.RemoteMultiaddr())
|
||||
if err != nil {
|
||||
log.Debugw("resource manager blocked accept of new connection", "error", err)
|
||||
if err := c.Close(); err != nil {
|
||||
log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
ctx, cancelCtx := context.WithTimeout(m.ctx, acceptTimeout)
|
||||
select {
|
||||
case acceptQueue <- struct{}{}:
|
||||
// NOTE: We can drop the connection, but this is similar to the behaviour in the upgrader.
|
||||
case <-m.ctx.Done():
|
||||
case <-ctx.Done():
|
||||
cancelCtx()
|
||||
connScope.Done()
|
||||
c.Close()
|
||||
log.Debugf("accept queue full, dropping connection: %s", c.RemoteMultiaddr())
|
||||
continue
|
||||
case <-m.ctx.Done():
|
||||
cancelCtx()
|
||||
connScope.Done()
|
||||
c.Close()
|
||||
log.Debugf("listener closed; dropping connection from: %s", c.RemoteMultiaddr())
|
||||
continue
|
||||
}
|
||||
|
||||
m.wg.Add(1)
|
||||
go func() {
|
||||
defer func() { <-acceptQueue }()
|
||||
defer m.wg.Done()
|
||||
ctx, cancelCtx := context.WithTimeout(m.ctx, acceptTimeout)
|
||||
defer cancelCtx()
|
||||
t, c, err := identifyConnType(c)
|
||||
if err != nil {
|
||||
// conn closed by identifyConnType
|
||||
connScope.Done()
|
||||
log.Debugf("error demultiplexing connection: %s", err.Error())
|
||||
return
|
||||
@@ -279,7 +260,7 @@ func (m *multiplexedListener) Close() error {
|
||||
}
|
||||
|
||||
func (m *multiplexedListener) closeListener() error {
|
||||
lerr := m.Listener.Close()
|
||||
lerr := m.GatedMaListener.Close()
|
||||
cerr := m.closeFn()
|
||||
return errors.Join(lerr, cerr)
|
||||
}
|
||||
@@ -298,19 +279,19 @@ func (m *multiplexedListener) removeDemultiplexedListener(c DemultiplexedConnTyp
|
||||
}
|
||||
|
||||
type demultiplexedListener struct {
|
||||
buffer chan manet.Conn
|
||||
inner manet.Listener
|
||||
buffer chan *connWithScope
|
||||
inner transport.GatedMaListener
|
||||
ctx context.Context
|
||||
cancelFunc context.CancelFunc
|
||||
closeFn func() error
|
||||
}
|
||||
|
||||
func (m *demultiplexedListener) Accept() (manet.Conn, error) {
|
||||
func (m *demultiplexedListener) Accept() (manet.Conn, network.ConnManagementScope, error) {
|
||||
select {
|
||||
case c := <-m.buffer:
|
||||
return c, nil
|
||||
return c.ManetTCPConnInterface, c.ConnScope, nil
|
||||
case <-m.ctx.Done():
|
||||
return nil, transport.ErrListenerClosed
|
||||
return nil, nil, transport.ErrListenerClosed
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -17,6 +17,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/transport"
|
||||
tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
"github.com/multiformats/go-multistream"
|
||||
@@ -53,6 +56,17 @@ func selfSignedTLSConfig(t *testing.T) *tls.Config {
|
||||
return tlsConfig
|
||||
}
|
||||
|
||||
type maListener struct {
|
||||
transport.GatedMaListener
|
||||
}
|
||||
|
||||
var _ manet.Listener = &maListener{}
|
||||
|
||||
func (ml *maListener) Accept() (manet.Conn, error) {
|
||||
c, _, err := ml.GatedMaListener.Accept()
|
||||
return c, err
|
||||
}
|
||||
|
||||
type wsHandler struct{ conns chan *websocket.Conn }
|
||||
|
||||
func (wh wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -61,12 +75,19 @@ func (wh wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
wh.conns <- c
|
||||
}
|
||||
|
||||
func upgrader(t *testing.T) transport.Upgrader {
|
||||
t.Helper()
|
||||
upd, err := tptu.New(nil, nil, nil, &network.NullResourceManager{}, nil)
|
||||
require.NoError(t, err)
|
||||
return upd
|
||||
}
|
||||
|
||||
func TestListenerSingle(t *testing.T) {
|
||||
listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0")
|
||||
const N = 64
|
||||
for _, enableReuseport := range []bool{true, false} {
|
||||
t.Run(fmt.Sprintf("multistream-reuseport:%v", enableReuseport), func(t *testing.T) {
|
||||
cm := NewConnMgr(enableReuseport, nil, nil)
|
||||
cm := NewConnMgr(enableReuseport, upgrader(t))
|
||||
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
|
||||
require.NoError(t, err)
|
||||
go func() {
|
||||
@@ -96,7 +117,7 @@ func TestListenerSingle(t *testing.T) {
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < N; i++ {
|
||||
c, err := l.Accept()
|
||||
c, _, err := l.Accept()
|
||||
require.NoError(t, err)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -117,12 +138,12 @@ func TestListenerSingle(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run(fmt.Sprintf("WebSocket-reuseport:%v", enableReuseport), func(t *testing.T) {
|
||||
cm := NewConnMgr(enableReuseport, nil, nil)
|
||||
cm := NewConnMgr(enableReuseport, upgrader(t))
|
||||
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
|
||||
require.NoError(t, err)
|
||||
wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
|
||||
go func() {
|
||||
http.Serve(manet.NetListener(l), wh)
|
||||
http.Serve(manet.NetListener(&maListener{GatedMaListener: l}), wh)
|
||||
}()
|
||||
go func() {
|
||||
d := websocket.Dialer{}
|
||||
@@ -169,14 +190,14 @@ func TestListenerSingle(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run(fmt.Sprintf("WebSocketTLS-reuseport:%v", enableReuseport), func(t *testing.T) {
|
||||
cm := NewConnMgr(enableReuseport, nil, nil)
|
||||
cm := NewConnMgr(enableReuseport, upgrader(t))
|
||||
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
|
||||
go func() {
|
||||
s := http.Server{Handler: wh, TLSConfig: selfSignedTLSConfig(t)}
|
||||
s.ServeTLS(manet.NetListener(l), "", "")
|
||||
s.ServeTLS(manet.NetListener(&maListener{GatedMaListener: l}), "", "")
|
||||
}()
|
||||
go func() {
|
||||
d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
@@ -228,7 +249,7 @@ func TestListenerMultiplexed(t *testing.T) {
|
||||
listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0")
|
||||
const N = 20
|
||||
for _, enableReuseport := range []bool{true, false} {
|
||||
cm := NewConnMgr(enableReuseport, nil, nil)
|
||||
cm := NewConnMgr(enableReuseport, upgrader(t))
|
||||
msl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
|
||||
require.NoError(t, err)
|
||||
defer msl.Close()
|
||||
@@ -239,7 +260,7 @@ func TestListenerMultiplexed(t *testing.T) {
|
||||
require.Equal(t, wsl.Multiaddr(), msl.Multiaddr())
|
||||
wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
|
||||
go func() {
|
||||
http.Serve(manet.NetListener(wsl), wh)
|
||||
http.Serve(manet.NetListener(&maListener{GatedMaListener: wsl}), wh)
|
||||
}()
|
||||
|
||||
wssl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS)
|
||||
@@ -249,7 +270,7 @@ func TestListenerMultiplexed(t *testing.T) {
|
||||
whs := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
|
||||
go func() {
|
||||
s := http.Server{Handler: whs, TLSConfig: selfSignedTLSConfig(t)}
|
||||
s.ServeTLS(manet.NetListener(wssl), "", "")
|
||||
s.ServeTLS(manet.NetListener(&maListener{GatedMaListener: wssl}), "", "")
|
||||
}()
|
||||
|
||||
// multistream connections
|
||||
@@ -331,7 +352,7 @@ func TestListenerMultiplexed(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < N; i++ {
|
||||
c, err := msl.Accept()
|
||||
c, _, err := msl.Accept()
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
@@ -404,7 +425,7 @@ func TestListenerMultiplexed(t *testing.T) {
|
||||
func TestListenerClose(t *testing.T) {
|
||||
testClose := func(listenAddr ma.Multiaddr) {
|
||||
// listen on port 0
|
||||
cm := NewConnMgr(false, nil, nil)
|
||||
cm := NewConnMgr(false, upgrader(t))
|
||||
ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
|
||||
require.NoError(t, err)
|
||||
wl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
|
||||
@@ -459,7 +480,7 @@ func setDeferReset[T any](t testing.TB, ptr *T, val T) {
|
||||
func TestHitTimeout(t *testing.T) {
|
||||
setDeferReset(t, &identifyConnTimeout, 100*time.Millisecond)
|
||||
// listen on port 0
|
||||
cm := NewConnMgr(false, nil, nil)
|
||||
cm := NewConnMgr(false, upgrader(t))
|
||||
|
||||
listenAddr := ma.StringCast("/ip4/127.0.0.1/tcp/0")
|
||||
ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
|
||||
|
@@ -4,8 +4,6 @@ import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
)
|
||||
|
||||
@@ -67,15 +65,3 @@ func TestConvertWebsocketMultiaddrToNetAddr(t *testing.T) {
|
||||
t.Fatalf("expected network: \"websocket\", got \"%s\"", wsaddr.Network())
|
||||
}
|
||||
}
|
||||
|
||||
func TestListeningOnDNSAddr(t *testing.T) {
|
||||
ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil, nil)
|
||||
require.NoError(t, err)
|
||||
addr := ln.Multiaddr()
|
||||
first, rest := ma.SplitFirst(addr)
|
||||
require.Equal(t, ma.P_DNS, first.Protocol().Code)
|
||||
require.Equal(t, "localhost", first.Value())
|
||||
next, _ := ma.SplitFirst(rest)
|
||||
require.Equal(t, ma.P_TCP, next.Protocol().Code)
|
||||
require.NotEqual(t, 0, next.Value())
|
||||
}
|
||||
|
@@ -1,7 +1,6 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
@@ -9,7 +8,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/transport"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
|
||||
@@ -23,6 +21,7 @@ var GracefulCloseTimeout = 100 * time.Millisecond
|
||||
// Conn implements net.Conn interface for gorilla/websocket.
|
||||
type Conn struct {
|
||||
*ws.Conn
|
||||
Scope network.ConnManagementScope
|
||||
secure bool
|
||||
DefaultMessageType int
|
||||
reader io.Reader
|
||||
@@ -36,10 +35,8 @@ type Conn struct {
|
||||
var _ net.Conn = (*Conn)(nil)
|
||||
var _ manet.Conn = (*Conn)(nil)
|
||||
|
||||
// NewConn creates a Conn given a regular gorilla/websocket Conn.
|
||||
//
|
||||
// Deprecated: There's no reason to use this method externally. It'll be unexported in a future release.
|
||||
func NewConn(raw *ws.Conn, secure bool) *Conn {
|
||||
// newConn creates a Conn given a regular gorilla/websocket Conn.
|
||||
func newConn(raw *ws.Conn, secure bool, scope network.ConnManagementScope) *Conn {
|
||||
lna := NewAddrWithScheme(raw.LocalAddr().String(), secure)
|
||||
laddr, err := manet.FromNetAddr(lna)
|
||||
if err != nil {
|
||||
@@ -56,6 +53,7 @@ func NewConn(raw *ws.Conn, secure bool) *Conn {
|
||||
|
||||
c := &Conn{
|
||||
Conn: raw,
|
||||
Scope: scope,
|
||||
secure: secure,
|
||||
DefaultMessageType: ws.BinaryMessage,
|
||||
laddr: laddr,
|
||||
@@ -136,23 +134,6 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *Conn) Scope() network.ConnManagementScope {
|
||||
nc := c.NetConn()
|
||||
if sc, ok := nc.(interface {
|
||||
Scope() network.ConnManagementScope
|
||||
}); ok {
|
||||
return sc.Scope()
|
||||
}
|
||||
if nc, ok := nc.(*tls.Conn); ok {
|
||||
if sc, ok := nc.NetConn().(interface {
|
||||
Scope() network.ConnManagementScope
|
||||
}); ok {
|
||||
return sc.Scope()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
// subsequent and concurrent calls will return the same error value.
|
||||
// This method is thread-safe.
|
||||
@@ -201,13 +182,3 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
|
||||
return c.Conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
type capableConn struct {
|
||||
transport.CapableConn
|
||||
}
|
||||
|
||||
func (c *capableConn) ConnState() network.ConnectionState {
|
||||
cs := c.CapableConn.ConnState()
|
||||
cs.Transport = "websocket"
|
||||
return cs
|
||||
}
|
||||
|
@@ -1,17 +1,22 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
ws "github.com/gorilla/websocket"
|
||||
logging "github.com/ipfs/go-log/v2"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/transport"
|
||||
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
|
||||
|
||||
@@ -23,8 +28,9 @@ var log = logging.Logger("websocket-transport")
|
||||
var stdLog = zap.NewStdLog(log.Desugar())
|
||||
|
||||
type listener struct {
|
||||
nl net.Listener
|
||||
netListener *httpNetListener
|
||||
server http.Server
|
||||
wsUpgrader ws.Upgrader
|
||||
// The Go standard library sets the http.Server.TLSConfig no matter if this is a WS or WSS,
|
||||
// so we can't rely on checking if server.TLSConfig is set.
|
||||
isWss bool
|
||||
@@ -36,8 +42,11 @@ type listener struct {
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
closed chan struct{}
|
||||
wsurl *url.URL
|
||||
}
|
||||
|
||||
var _ transport.GatedMaListener = &listener{}
|
||||
|
||||
func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr {
|
||||
if !pwma.isWSS {
|
||||
return pwma.restMultiaddr.AppendComponent(wsComponent)
|
||||
@@ -52,7 +61,7 @@ func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr {
|
||||
|
||||
// newListener creates a new listener from a raw net.Listener.
|
||||
// tlsConf may be nil (for unencrypted websockets).
|
||||
func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr) (*listener, error) {
|
||||
func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr, upgrader transport.Upgrader, handshakeTimeout time.Duration) (*listener, error) {
|
||||
parsed, err := parseWebsocketMultiaddr(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -62,17 +71,13 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMg
|
||||
return nil, fmt.Errorf("cannot listen on wss address %s without a tls.Config", a)
|
||||
}
|
||||
|
||||
var nl net.Listener
|
||||
|
||||
var gmal transport.GatedMaListener
|
||||
if sharedTcp == nil {
|
||||
lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nl, err = net.Listen(lnet, lnaddr)
|
||||
mal, err := manet.Listen(parsed.restMultiaddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gmal = upgrader.GateMaListener(mal)
|
||||
} else {
|
||||
var connType tcpreuse.DemultiplexedConnType
|
||||
if parsed.isWSS {
|
||||
@@ -80,89 +85,146 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMg
|
||||
} else {
|
||||
connType = tcpreuse.DemultiplexedConnType_HTTP
|
||||
}
|
||||
mal, err := sharedTcp.DemultiplexedListen(parsed.restMultiaddr, connType)
|
||||
gmal, err = sharedTcp.DemultiplexedListen(parsed.restMultiaddr, connType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nl = manet.NetListener(mal)
|
||||
}
|
||||
|
||||
laddr, err := manet.FromNetAddr(nl.Addr())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// laddr has the correct port in case we listened on port 0
|
||||
laddr := gmal.Multiaddr()
|
||||
|
||||
first, _ := ma.SplitFirst(a)
|
||||
// Don't resolve dns addresses.
|
||||
// We want to be able to announce domain names, so the peer can validate the TLS certificate.
|
||||
first, _ := ma.SplitFirst(a)
|
||||
if c := first.Protocol().Code; c == ma.P_DNS || c == ma.P_DNS4 || c == ma.P_DNS6 || c == ma.P_DNSADDR {
|
||||
_, last := ma.SplitFirst(laddr)
|
||||
laddr = first.Encapsulate(last)
|
||||
}
|
||||
parsed.restMultiaddr = laddr
|
||||
|
||||
listenAddr := parsed.toMultiaddr()
|
||||
wsurl, err := parseMultiaddr(listenAddr)
|
||||
if err != nil {
|
||||
gmal.Close()
|
||||
return nil, fmt.Errorf("failed to parse multiaddr to URL: %v: %w", listenAddr, err)
|
||||
}
|
||||
ln := &listener{
|
||||
nl: nl,
|
||||
netListener: &httpNetListener{
|
||||
GatedMaListener: gmal,
|
||||
handshakeTimeout: handshakeTimeout,
|
||||
},
|
||||
laddr: parsed.toMultiaddr(),
|
||||
incoming: make(chan *Conn),
|
||||
closed: make(chan struct{}),
|
||||
isWss: parsed.isWSS,
|
||||
wsurl: wsurl,
|
||||
wsUpgrader: ws.Upgrader{
|
||||
// Allow requests from *all* origins.
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
HandshakeTimeout: handshakeTimeout,
|
||||
},
|
||||
}
|
||||
ln.server = http.Server{Handler: ln, ErrorLog: stdLog}
|
||||
if parsed.isWSS {
|
||||
ln.isWss = true
|
||||
ln.server.TLSConfig = tlsConf
|
||||
}
|
||||
ln.server = http.Server{Handler: ln, ErrorLog: stdLog, ConnContext: ln.ConnContext, TLSConfig: tlsConf}
|
||||
return ln, nil
|
||||
}
|
||||
|
||||
func (l *listener) serve() {
|
||||
defer close(l.closed)
|
||||
if !l.isWss {
|
||||
l.server.Serve(l.nl)
|
||||
l.server.Serve(l.netListener)
|
||||
} else {
|
||||
l.server.ServeTLS(l.nl, "", "")
|
||||
l.server.ServeTLS(l.netListener, "", "")
|
||||
}
|
||||
}
|
||||
|
||||
type connKey struct{}
|
||||
|
||||
func (l *listener) ConnContext(ctx context.Context, c net.Conn) context.Context {
|
||||
// prefer `*tls.Conn` over `(interface{NetConn() net.Conn})` in case `manet.Conn` is extended
|
||||
// to support a `NetConn() net.Conn` method.
|
||||
if tc, ok := c.(*tls.Conn); ok {
|
||||
c = tc.NetConn()
|
||||
}
|
||||
if nc, ok := c.(*negotiatingConn); ok {
|
||||
return context.WithValue(ctx, connKey{}, nc)
|
||||
}
|
||||
log.Errorf("BUG: expected net.Conn of type *websocket.negotiatingConn: got %T", c)
|
||||
// might as well close the connection as there's no way to proceed now.
|
||||
c.Close()
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (l *listener) extractConnFromContext(ctx context.Context) (*negotiatingConn, error) {
|
||||
c := ctx.Value(connKey{})
|
||||
if c == nil {
|
||||
return nil, fmt.Errorf("expected *websocket.negotiatingConn in context: got nil")
|
||||
}
|
||||
nc, ok := c.(*negotiatingConn)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected *websocket.negotiatingConn in context: got %T", c)
|
||||
}
|
||||
return nc, nil
|
||||
}
|
||||
|
||||
func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
c, err := l.wsUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
// The upgrader writes a response for us.
|
||||
return
|
||||
}
|
||||
nc := NewConn(c, l.isWss)
|
||||
if nc == nil {
|
||||
nc, err := l.extractConnFromContext(r.Context())
|
||||
if err != nil {
|
||||
c.Close()
|
||||
w.WriteHeader(500)
|
||||
log.Errorf("BUG: failed to extract conn from context: RemoteAddr: %s: err: %s", r.RemoteAddr, err)
|
||||
return
|
||||
}
|
||||
|
||||
cs, err := nc.Unwrap()
|
||||
if err != nil {
|
||||
c.Close()
|
||||
w.WriteHeader(500)
|
||||
log.Debugf("connection timed out from: %s", r.RemoteAddr)
|
||||
return
|
||||
}
|
||||
|
||||
conn := newConn(c, l.isWss, cs.Scope)
|
||||
if conn == nil {
|
||||
c.Close()
|
||||
w.WriteHeader(500)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case l.incoming <- nc:
|
||||
case l.incoming <- conn:
|
||||
case <-l.closed:
|
||||
nc.Close()
|
||||
conn.Close()
|
||||
}
|
||||
// The connection has been hijacked, it's safe to return.
|
||||
}
|
||||
|
||||
func (l *listener) Accept() (manet.Conn, error) {
|
||||
func (l *listener) Accept() (manet.Conn, network.ConnManagementScope, error) {
|
||||
select {
|
||||
case c, ok := <-l.incoming:
|
||||
if !ok {
|
||||
return nil, transport.ErrListenerClosed
|
||||
return nil, nil, transport.ErrListenerClosed
|
||||
}
|
||||
return c, nil
|
||||
return c, c.Scope, nil
|
||||
case <-l.closed:
|
||||
return nil, transport.ErrListenerClosed
|
||||
return nil, nil, transport.ErrListenerClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (l *listener) Addr() net.Addr {
|
||||
return l.nl.Addr()
|
||||
return &Addr{URL: l.wsurl}
|
||||
}
|
||||
|
||||
func (l *listener) Close() error {
|
||||
l.closeOnce.Do(func() {
|
||||
err1 := l.nl.Close()
|
||||
err1 := l.netListener.Close()
|
||||
err2 := l.server.Close()
|
||||
<-l.closed
|
||||
l.closeErr = errors.Join(err1, err2)
|
||||
@@ -174,14 +236,74 @@ func (l *listener) Multiaddr() ma.Multiaddr {
|
||||
return l.laddr
|
||||
}
|
||||
|
||||
type transportListener struct {
|
||||
transport.Listener
|
||||
// httpNetListener is a net.Listener that adapts a transport.GatedMaListener to a net.Listener.
|
||||
// It wraps the manet.Conn, and the Scope from the underlying gated listener in a connWithScope.
|
||||
type httpNetListener struct {
|
||||
transport.GatedMaListener
|
||||
handshakeTimeout time.Duration
|
||||
}
|
||||
|
||||
func (l *transportListener) Accept() (transport.CapableConn, error) {
|
||||
conn, err := l.Listener.Accept()
|
||||
var _ net.Listener = &httpNetListener{}
|
||||
|
||||
func (l *httpNetListener) Accept() (net.Conn, error) {
|
||||
conn, scope, err := l.GatedMaListener.Accept()
|
||||
if err != nil {
|
||||
if scope != nil {
|
||||
log.Errorf("BUG: scope non-nil when err is non nil: %v", err)
|
||||
scope.Done()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &capableConn{CapableConn: conn}, nil
|
||||
connWithScope := connWithScope{
|
||||
Conn: conn,
|
||||
Scope: scope,
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), l.handshakeTimeout)
|
||||
return &negotiatingConn{
|
||||
connWithScope: connWithScope,
|
||||
ctx: ctx,
|
||||
cancelCtx: cancel,
|
||||
stopClose: context.AfterFunc(ctx, func() {
|
||||
connWithScope.Close()
|
||||
log.Debugf("handshake timeout for conn from: %s", conn.RemoteAddr())
|
||||
}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type connWithScope struct {
|
||||
net.Conn
|
||||
Scope network.ConnManagementScope
|
||||
}
|
||||
|
||||
func (c connWithScope) Close() error {
|
||||
c.Scope.Done()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
type negotiatingConn struct {
|
||||
connWithScope
|
||||
ctx context.Context
|
||||
cancelCtx context.CancelFunc
|
||||
stopClose func() bool
|
||||
}
|
||||
|
||||
// Close closes the negotiating conn and the underlying connWithScope
|
||||
// This will be called in case the tls handshake or websocket upgrade fails.
|
||||
func (c *negotiatingConn) Close() error {
|
||||
defer c.cancelCtx()
|
||||
if c.stopClose != nil {
|
||||
c.stopClose()
|
||||
}
|
||||
return c.connWithScope.Close()
|
||||
}
|
||||
|
||||
func (c *negotiatingConn) Unwrap() (connWithScope, error) {
|
||||
defer c.cancelCtx()
|
||||
if c.stopClose != nil {
|
||||
if !c.stopClose() {
|
||||
return connWithScope{}, errors.New("timed out")
|
||||
}
|
||||
c.stopClose = nil
|
||||
}
|
||||
return c.connWithScope, nil
|
||||
}
|
||||
|
@@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
@@ -51,14 +50,6 @@ func init() {
|
||||
manet.RegisterToNetAddr(ConvertWebsocketMultiaddrToNetAddr, "wss")
|
||||
}
|
||||
|
||||
// Default gorilla upgrader
|
||||
var upgrader = ws.Upgrader{
|
||||
// Allow requests from *all* origins.
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
type Option func(*WebsocketTransport) error
|
||||
|
||||
// WithTLSClientConfig sets a TLS client configuration on the WebSocket Dialer. Only
|
||||
@@ -81,15 +72,24 @@ func WithTLSConfig(conf *tls.Config) Option {
|
||||
}
|
||||
}
|
||||
|
||||
var defaultHandshakeTimeout = 15 * time.Second
|
||||
|
||||
// WithHandshakeTimeout sets a timeout for the websocket upgrade.
|
||||
func WithHandshakeTimeout(timeout time.Duration) Option {
|
||||
return func(t *WebsocketTransport) error {
|
||||
t.handshakeTimeout = timeout
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WebsocketTransport is the actual go-libp2p transport
|
||||
type WebsocketTransport struct {
|
||||
upgrader transport.Upgrader
|
||||
rcmgr network.ResourceManager
|
||||
|
||||
tlsClientConf *tls.Config
|
||||
tlsConf *tls.Config
|
||||
|
||||
sharedTcp *tcpreuse.ConnMgr
|
||||
handshakeTimeout time.Duration
|
||||
}
|
||||
|
||||
var _ transport.Transport = (*WebsocketTransport)(nil)
|
||||
@@ -103,6 +103,7 @@ func New(u transport.Upgrader, rcmgr network.ResourceManager, sharedTCP *tcpreus
|
||||
rcmgr: rcmgr,
|
||||
tlsClientConf: &tls.Config{},
|
||||
sharedTcp: sharedTCP,
|
||||
handshakeTimeout: defaultHandshakeTimeout,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
if err := opt(t); err != nil {
|
||||
@@ -176,7 +177,7 @@ func (t *WebsocketTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p pee
|
||||
}
|
||||
|
||||
func (t *WebsocketTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, connScope network.ConnManagementScope) (transport.CapableConn, error) {
|
||||
macon, err := t.maDial(ctx, raddr)
|
||||
macon, err := t.maDial(ctx, raddr, connScope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -187,14 +188,14 @@ func (t *WebsocketTransport) dialWithScope(ctx context.Context, raddr ma.Multiad
|
||||
return &capableConn{CapableConn: conn}, nil
|
||||
}
|
||||
|
||||
func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) {
|
||||
func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr, scope network.ConnManagementScope) (manet.Conn, error) {
|
||||
wsurl, err := parseMultiaddr(raddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
isWss := wsurl.Scheme == "wss"
|
||||
dialer := ws.Dialer{
|
||||
HandshakeTimeout: 30 * time.Second,
|
||||
HandshakeTimeout: t.handshakeTimeout,
|
||||
// Inherit the default proxy behavior
|
||||
Proxy: ws.DefaultDialer.Proxy,
|
||||
}
|
||||
@@ -236,7 +237,7 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mnc, err := manet.WrapNetConn(NewConn(wscon, isWss))
|
||||
mnc, err := manet.WrapNetConn(newConn(wscon, isWss, scope))
|
||||
if err != nil {
|
||||
wscon.Close()
|
||||
return nil, err
|
||||
@@ -244,12 +245,12 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma
|
||||
return mnc, nil
|
||||
}
|
||||
|
||||
func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) {
|
||||
func (t *WebsocketTransport) gatedMaListen(a ma.Multiaddr) (transport.GatedMaListener, error) {
|
||||
var tlsConf *tls.Config
|
||||
if t.tlsConf != nil {
|
||||
tlsConf = t.tlsConf.Clone()
|
||||
}
|
||||
l, err := newListener(a, tlsConf, t.sharedTcp)
|
||||
l, err := newListener(a, tlsConf, t.sharedTcp, t.upgrader, t.handshakeTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -258,9 +259,32 @@ func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) {
|
||||
}
|
||||
|
||||
func (t *WebsocketTransport) Listen(a ma.Multiaddr) (transport.Listener, error) {
|
||||
malist, err := t.maListen(a)
|
||||
gmal, err := t.gatedMaListen(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &transportListener{Listener: t.upgrader.UpgradeListener(t, malist)}, nil
|
||||
return &transportListener{Listener: t.upgrader.UpgradeGatedMaListener(t, gmal)}, nil
|
||||
}
|
||||
|
||||
// transportListener wraps a transport.Listener to provide connections with a `ConnState() network.ConnectionState` method.
|
||||
type transportListener struct {
|
||||
transport.Listener
|
||||
}
|
||||
|
||||
type capableConn struct {
|
||||
transport.CapableConn
|
||||
}
|
||||
|
||||
func (c *capableConn) ConnState() network.ConnectionState {
|
||||
cs := c.CapableConn.ConnState()
|
||||
cs.Transport = "websocket"
|
||||
return cs
|
||||
}
|
||||
|
||||
func (l *transportListener) Accept() (transport.CapableConn, error) {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &capableConn{CapableConn: conn}, nil
|
||||
}
|
||||
|
@@ -35,6 +35,7 @@ import (
|
||||
ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite"
|
||||
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -296,18 +297,39 @@ func TestDialWssNoClientCert(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWebsocketTransport(t *testing.T) {
|
||||
t.Run("/ws", func(t *testing.T) {
|
||||
peerA, ua := newUpgrader(t)
|
||||
ta, err := New(ua, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, ub := newUpgrader(t)
|
||||
peerB, ub := newUpgrader(t)
|
||||
tb, err := New(ub, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/ws", peerA)
|
||||
ttransport.SubtestTransport(t, tb, ta, "/ip4/127.0.0.1/tcp/0/ws", peerB)
|
||||
|
||||
})
|
||||
t.Run("/wss", func(t *testing.T) {
|
||||
peerA, ua := newUpgrader(t)
|
||||
tca := generateTLSConfig(t)
|
||||
ta, err := New(ua, nil, nil, WithTLSConfig(tca), WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
peerB, ub := newUpgrader(t)
|
||||
tcb := generateTLSConfig(t)
|
||||
tb, err := New(ub, nil, nil, WithTLSConfig(tcb), WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/wss", peerA)
|
||||
ttransport.SubtestTransport(t, tb, ta, "/ip4/127.0.0.1/tcp/0/ws", peerB)
|
||||
})
|
||||
}
|
||||
|
||||
func isWSS(addr ma.Multiaddr) bool {
|
||||
@@ -441,7 +463,7 @@ func TestConcurrentClose(t *testing.T) {
|
||||
_, u := newUpgrader(t)
|
||||
tpt, err := New(u, &network.NullResourceManager{}, nil)
|
||||
require.NoError(t, err)
|
||||
l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
|
||||
l, err := tpt.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -451,7 +473,7 @@ func TestConcurrentClose(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
c, err := tpt.maDial(context.Background(), l.Multiaddr())
|
||||
c, err := tpt.maDial(context.Background(), l.Multiaddr(), &network.NullScope{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -467,7 +489,7 @@ func TestConcurrentClose(t *testing.T) {
|
||||
}()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
c, err := l.Accept()
|
||||
c, _, err := l.Accept()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -481,7 +503,7 @@ func TestWriteZero(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
|
||||
l, err := tpt.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -490,7 +512,7 @@ func TestWriteZero(t *testing.T) {
|
||||
msg := []byte(nil)
|
||||
|
||||
go func() {
|
||||
c, err := tpt.maDial(context.Background(), l.Multiaddr())
|
||||
c, err := tpt.maDial(context.Background(), l.Multiaddr(), &network.NullScope{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -509,7 +531,7 @@ func TestWriteZero(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
c, err := l.Accept()
|
||||
c, _, err := l.Accept()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -623,3 +645,98 @@ func TestSocksProxy(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenerAddr(t *testing.T) {
|
||||
_, upgrader := newUpgrader(t)
|
||||
transport, err := New(upgrader, &network.NullResourceManager{}, nil, WithTLSConfig(generateTLSConfig(t)))
|
||||
require.NoError(t, err)
|
||||
l1, err := transport.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
|
||||
require.NoError(t, err)
|
||||
defer l1.Close()
|
||||
require.Regexp(t, `^ws://127\.0\.0\.1:[\d]+$`, l1.Addr().String())
|
||||
l2, err := transport.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/wss"))
|
||||
require.NoError(t, err)
|
||||
defer l2.Close()
|
||||
require.Regexp(t, `^wss://127\.0\.0\.1:[\d]+$`, l2.Addr().String())
|
||||
}
|
||||
func TestHandshakeTimeout(t *testing.T) {
|
||||
handshakeTimeout := 200 * time.Millisecond
|
||||
_, upgrader := newUpgrader(t)
|
||||
tlsconf := generateTLSConfig(t)
|
||||
transport, err := New(upgrader, &network.NullResourceManager{}, nil, WithHandshakeTimeout(handshakeTimeout), WithTLSConfig(tlsconf))
|
||||
require.NoError(t, err)
|
||||
|
||||
fastWSDialer := gws.Dialer{
|
||||
HandshakeTimeout: 10 * handshakeTimeout,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
NetDial: func(network, addr string) (net.Conn, error) {
|
||||
tcpConn, err := net.Dial("tcp", addr)
|
||||
if !assert.NoError(t, err) {
|
||||
return nil, err
|
||||
}
|
||||
return tcpConn, nil
|
||||
},
|
||||
}
|
||||
|
||||
slowWSDialer := gws.Dialer{
|
||||
HandshakeTimeout: 10 * handshakeTimeout,
|
||||
NetDial: func(network, addr string) (net.Conn, error) {
|
||||
tcpConn, err := net.Dial("tcp", addr)
|
||||
if !assert.NoError(t, err) {
|
||||
return nil, err
|
||||
}
|
||||
// wait to simulate a slow handshake
|
||||
time.Sleep(2 * handshakeTimeout)
|
||||
return tcpConn, nil
|
||||
},
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
t.Run("ws", func(t *testing.T) {
|
||||
// test the gatedMaListener as we're interested in the websocket handshake timeout and not the upgrader steps.
|
||||
wsListener, err := transport.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
|
||||
require.NoError(t, err)
|
||||
defer wsListener.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*handshakeTimeout)
|
||||
defer cancel()
|
||||
conn, resp, err := fastWSDialer.DialContext(ctx, wsListener.Addr().String(), nil)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
conn.Close()
|
||||
resp.Body.Close()
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 10*handshakeTimeout)
|
||||
defer cancel()
|
||||
conn, resp, err = slowWSDialer.DialContext(ctx, wsListener.Addr().String(), nil)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
resp.Body.Close()
|
||||
t.Fatal("should error as the handshake will time out")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wss", func(t *testing.T) {
|
||||
// test the gatedMaListener as we're interested in the websocket handshake timeout and not the upgrader steps.
|
||||
wsListener, err := transport.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/wss"))
|
||||
require.NoError(t, err)
|
||||
defer wsListener.Close()
|
||||
|
||||
// Test that the normal dial works fine
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*handshakeTimeout)
|
||||
defer cancel()
|
||||
wsConn, resp, err := fastWSDialer.DialContext(ctx, wsListener.Addr().String(), nil)
|
||||
require.NoError(t, err)
|
||||
wsConn.Close()
|
||||
resp.Body.Close()
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 10*handshakeTimeout)
|
||||
defer cancel()
|
||||
wsConn, resp, err = slowWSDialer.DialContext(ctx, wsListener.Addr().String(), nil)
|
||||
if err == nil {
|
||||
wsConn.Close()
|
||||
resp.Body.Close()
|
||||
t.Fatal("websocket handshake should have timed out")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user