feat(tcpreuse): add options for sharing TCP listeners amongst TCP, WS and WSS transports (#2984)

Allows the same socket to be shared amongst TCP,WS,WSS transports.

---------

Co-authored-by: sukun <sukunrt@gmail.com>
Co-authored-by: Marco Munizaga <git@marcopolo.io>
This commit is contained in:
Adin Schmahmann
2024-11-04 12:41:32 -05:00
committed by GitHub
parent 362e5836f1
commit 5a47a90938
32 changed files with 1598 additions and 107 deletions

View File

@@ -38,6 +38,7 @@ import (
"github.com/libp2p/go-libp2p/p2p/protocol/holepunch"
"github.com/libp2p/go-libp2p/p2p/protocol/identify"
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
"github.com/prometheus/client_golang/prometheus"
@@ -145,6 +146,8 @@ type Config struct {
CustomIPv6BlackHoleSuccessCounter bool
UserFxOptions []fx.Option
ShareTCPListener bool
}
func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swarm, error) {
@@ -289,6 +292,12 @@ func (cfg *Config) addTransports() ([]fx.Option, error) {
fx.Provide(func() connmgr.ConnectionGater { return cfg.ConnectionGater }),
fx.Provide(func() pnet.PSK { return cfg.PSK }),
fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }),
fx.Provide(func(gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *tcpreuse.ConnMgr {
if !cfg.ShareTCPListener {
return nil
}
return tcpreuse.NewConnMgr(tcpreuse.EnvReuseportVal, gater, rcmgr)
}),
fx.Provide(func(cm *quicreuse.ConnManager, sw *swarm.Swarm) libp2pwebrtc.ListenUDPFn {
hasQuicAddrPortFor := func(network string, laddr *net.UDPAddr) bool {
quicAddrPorts := map[string]struct{}{}

View File

@@ -59,7 +59,7 @@ func TestTransportConstructor(t *testing.T) {
_ connmgr.ConnectionGater,
upgrader transport.Upgrader,
) transport.Transport {
tpt, err := tcp.NewTCPTransport(upgrader, nil)
tpt, err := tcp.NewTCPTransport(upgrader, nil, nil)
require.NoError(t, err)
return tpt
}
@@ -751,3 +751,27 @@ func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config {
}},
}
}
func TestSharedTCPAddr(t *testing.T) {
h, err := New(
ShareTCPListener(),
Transport(tcp.NewTCPTransport),
Transport(websocket.New),
ListenAddrStrings("/ip4/0.0.0.0/tcp/8888"),
ListenAddrStrings("/ip4/0.0.0.0/tcp/8888/ws"),
)
require.NoError(t, err)
sawTCP := false
sawWS := false
for _, addr := range h.Addrs() {
if strings.HasSuffix(addr.String(), "/tcp/8888") {
sawTCP = true
}
if strings.HasSuffix(addr.String(), "/tcp/8888/ws") {
sawWS = true
}
}
require.True(t, sawTCP)
require.True(t, sawWS)
h.Close()
}

View File

@@ -643,3 +643,15 @@ func WithFxOption(opts ...fx.Option) Option {
return nil
}
}
// ShareTCPListener shares the same listen address between TCP and Websocket
// transports. This lets both transports use the same TCP port.
//
// Currently this behavior is Opt-in. In a future release this will be the
// default, and this option will be removed.
func ShareTCPListener() Option {
return func(cfg *Config) error {
cfg.ShareTCPListener = true
return nil
}
}

View File

@@ -84,7 +84,7 @@ func makeSwarmWithNoListenAddrs(t *testing.T, opts ...Option) *Swarm {
upgrader := makeUpgrader(t, s)
var tcpOpts []tcp.Option
tcpOpts = append(tcpOpts, tcp.DisableReuseport())
tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...)
tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, nil, tcpOpts...)
require.NoError(t, err)
if err := s.AddTransport(tcpTransport); err != nil {
t.Fatal(err)

View File

@@ -79,7 +79,7 @@ func TestDialAddressSelection(t *testing.T) {
s, err := swarm.NewSwarm("local", nil, eventbus.NewBus())
require.NoError(t, err)
tcpTr, err := tcp.NewTCPTransport(nil, nil)
tcpTr, err := tcp.NewTCPTransport(nil, nil, nil)
require.NoError(t, err)
require.NoError(t, s.AddTransport(tcpTr))
reuse, err := quicreuse.NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{})

View File

@@ -53,7 +53,7 @@ func TestAddrsForDial(t *testing.T) {
ps.AddPrivKey(id, priv)
t.Cleanup(func() { ps.Close() })
tpt, err := websocket.New(nil, &network.NullResourceManager{})
tpt, err := websocket.New(nil, &network.NullResourceManager{}, nil)
require.NoError(t, err)
s, err := NewSwarm(id, ps, eventbus.NewBus(), WithMultiaddrResolver(ResolverFromMaDNS{resolver}))
require.NoError(t, err)
@@ -100,7 +100,7 @@ func TestDedupAddrsForDial(t *testing.T) {
require.NoError(t, err)
defer s.Close()
tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{})
tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}, nil)
require.NoError(t, err)
err = s.AddTransport(tpt)
require.NoError(t, err)
@@ -134,7 +134,7 @@ func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm {
})
// Add a tcp transport so that we know we can dial a tcp multiaddr and we don't filter it out.
tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{})
tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}, nil)
require.NoError(t, err)
err = s.AddTransport(tpt)
require.NoError(t, err)
@@ -151,7 +151,7 @@ func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm {
err = s.AddTransport(wtTpt)
require.NoError(t, err)
wsTpt, err := websocket.New(nil, &network.NullResourceManager{})
wsTpt, err := websocket.New(nil, &network.NullResourceManager{}, nil)
require.NoError(t, err)
err = s.AddTransport(wsTpt)
require.NoError(t, err)

View File

@@ -164,7 +164,7 @@ func GenSwarm(t testing.TB, opts ...Option) *swarm.Swarm {
if cfg.disableReuseport {
tcpOpts = append(tcpOpts, tcp.DisableReuseport())
}
tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...)
tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, nil, tcpOpts...)
require.NoError(t, err)
if err := s.AddTransport(tcpTransport); err != nil {
t.Fatal(err)

View File

@@ -84,23 +84,33 @@ func (l *listener) handleIncoming() {
}
catcher.Reset()
// gate the connection if applicable
if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) {
log.Debugf("gater blocked incoming connection on local addr %s from %s",
maconn.LocalMultiaddr(), maconn.RemoteMultiaddr())
if err := maconn.Close(); err != nil {
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
}
continue
// Check if we already have a connection scope. See the comment in tcpreuse/listener.go for an explanation.
var connScope network.ConnManagementScope
if sc, ok := maconn.(interface {
Scope() network.ConnManagementScope
}); ok {
connScope = sc.Scope()
}
connScope, err := l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr())
if err != nil {
log.Debugw("resource manager blocked accept of new connection", "error", err)
if err := maconn.Close(); err != nil {
log.Warnf("failed to incoming connection rejected by resource manager: %s", err)
if connScope == nil {
// gate the connection if applicable
if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) {
log.Debugf("gater blocked incoming connection on local addr %s from %s",
maconn.LocalMultiaddr(), maconn.RemoteMultiaddr())
if err := maconn.Close(); err != nil {
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
}
continue
}
var err error
connScope, err = l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr())
if err != nil {
log.Debugw("resource manager blocked accept of new connection", "error", err)
if err := maconn.Close(); err != nil {
log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err)
}
continue
}
continue
}
// The go routine below calls Release when the context is

View File

@@ -60,7 +60,7 @@ func getNetHosts(t *testing.T, ctx context.Context, n int) (hosts []host.Host, u
upgrader := swarmt.GenUpgrader(t, netw, nil)
upgraders = append(upgraders, upgrader)
tpt, err := tcp.NewTCPTransport(upgrader, nil)
tpt, err := tcp.NewTCPTransport(upgrader, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -2,6 +2,8 @@ package transport_integration
import (
"context"
"encoding/binary"
"net/netip"
"strings"
"testing"
"time"
@@ -30,6 +32,23 @@ func stripCertHash(addr ma.Multiaddr) ma.Multiaddr {
return addr
}
func addrPort(addr ma.Multiaddr) netip.AddrPort {
a := netip.Addr{}
p := uint16(0)
ma.ForEach(addr, func(c ma.Component) bool {
if c.Protocol().Code == ma.P_IP4 || c.Protocol().Code == ma.P_IP6 {
a, _ = netip.AddrFromSlice(c.RawValue())
return false
}
if c.Protocol().Code == ma.P_UDP || c.Protocol().Code == ma.P_TCP {
p = binary.BigEndian.Uint16(c.RawValue())
return true
}
return false
})
return netip.AddrPortFrom(a, p)
}
func TestInterceptPeerDial(t *testing.T) {
if race.WithRace() {
t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.")
@@ -173,10 +192,14 @@ func TestInterceptAccept(t *testing.T) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
}).AnyTimes()
} else if strings.Contains(tc.Name, "WebSocket-Shared") {
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
require.Equal(t, addrPort(h2.Addrs()[0]), addrPort(addrs.LocalMultiaddr()))
})
} else {
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
// remove the certhash component from WebTransport addresses
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr(), "%s\n%s", h2.Addrs()[0], addrs.LocalMultiaddr())
})
}

View File

@@ -99,6 +99,38 @@ var transportsToTest = []TransportTestCase{
return h
},
},
{
Name: "TCP-Shared / TLS / Yamux",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
if opts.NoListen {
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
} else {
libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
}
h, err := libp2p.New(libp2pOpts...)
require.NoError(t, err)
return h
},
},
{
Name: "WebSocket-Shared",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
libp2pOpts := transformOpts(opts)
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
if opts.NoListen {
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
} else {
libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/ws"))
}
h, err := libp2p.New(libp2pOpts...)
require.NoError(t, err)
return h
},
},
{
Name: "WebSocket",
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {

View File

@@ -0,0 +1,54 @@
// go:build: unix
package tcp
import (
"testing"
tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader"
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite"
"github.com/stretchr/testify/require"
)
func TestTcpTransportCollectsMetricsWithSharedTcpSocket(t *testing.T) {
peerA, ia := makeInsecureMuxer(t)
_, ib := makeInsecureMuxer(t)
sharedTCPSocketA := tcpreuse.NewConnMgr(false, nil, nil)
sharedTCPSocketB := tcpreuse.NewConnMgr(false, nil, nil)
ua, err := tptu.New(ia, muxers, nil, nil, nil)
require.NoError(t, err)
ta, err := NewTCPTransport(ua, nil, sharedTCPSocketA, WithMetrics())
require.NoError(t, err)
ub, err := tptu.New(ib, muxers, nil, nil, nil)
require.NoError(t, err)
tb, err := NewTCPTransport(ub, nil, sharedTCPSocketB, WithMetrics())
require.NoError(t, err)
zero := "/ip4/127.0.0.1/tcp/0"
// Not running any test that needs more than 1 conn because the testsuite
// opens multiple conns via multiple listeners, which is not expected to work
// with the shared TCP socket.
subtestsToRun := []ttransport.TransportSubTestFn{
ttransport.SubtestProtocols,
ttransport.SubtestBasic,
ttransport.SubtestCancel,
ttransport.SubtestPingPong,
// Stolen from the stream muxer test suite.
ttransport.SubtestStress1Conn1Stream1Msg,
ttransport.SubtestStress1Conn1Stream100Msg,
ttransport.SubtestStress1Conn100Stream100Msg,
ttransport.SubtestStress1Conn1000Stream10Msg,
ttransport.SubtestStress1Conn100Stream100Msg10MB,
ttransport.SubtestStreamOpenStress,
ttransport.SubtestStreamReset,
}
ttransport.SubtestTransportWithFs(t, ta, tb, zero, peerA, subtestsToRun)
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/net/reuseport"
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
@@ -33,6 +34,9 @@ type canKeepAlive interface {
var _ canKeepAlive = &net.TCPConn{}
// Deprecated: Use tcpreuse.ReuseportIsAvailable
var ReuseportIsAvailable = tcpreuse.ReuseportIsAvailable
func tryKeepAlive(conn net.Conn, keepAlive bool) {
keepAliveConn, ok := conn.(canKeepAlive)
if !ok {
@@ -122,6 +126,9 @@ type TcpTransport struct {
disableReuseport bool // Explicitly disable reuseport.
enableMetrics bool
// share and demultiplex TCP listeners across multiple transports
sharedTcp *tcpreuse.ConnMgr
// TCP connect timeout
connectTimeout time.Duration
@@ -134,8 +141,8 @@ var _ transport.Transport = &TcpTransport{}
var _ transport.DialUpdater = &TcpTransport{}
// NewTCPTransport creates a tcp transport object that tracks dialers and listeners
// created. It represents an entire TCP stack (though it might not necessarily be).
func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*TcpTransport, error) {
// created.
func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, sharedTCP *tcpreuse.ConnMgr, opts ...Option) (*TcpTransport, error) {
if rcmgr == nil {
rcmgr = &network.NullResourceManager{}
}
@@ -143,6 +150,7 @@ func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager,
upgrader: upgrader,
connectTimeout: defaultConnectTimeout, // can be set by using the WithConnectionTimeout option
rcmgr: rcmgr,
sharedTcp: sharedTCP,
}
for _, o := range opts {
if err := o(tr); err != nil {
@@ -168,6 +176,10 @@ func (t *TcpTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Co
defer cancel()
}
if t.sharedTcp != nil {
return t.sharedTcp.DialContext(ctx, raddr)
}
if t.UseReuseport() {
return t.reuse.DialContext(ctx, raddr)
}
@@ -233,10 +245,10 @@ func (t *TcpTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p
// UseReuseport returns true if reuseport is enabled and available.
func (t *TcpTransport) UseReuseport() bool {
return !t.disableReuseport && ReuseportIsAvailable()
return !t.disableReuseport && tcpreuse.ReuseportIsAvailable()
}
func (t *TcpTransport) maListen(laddr ma.Multiaddr) (manet.Listener, error) {
func (t *TcpTransport) unsharedMAListen(laddr ma.Multiaddr) (manet.Listener, error) {
if t.UseReuseport() {
return t.reuse.Listen(laddr)
}
@@ -245,10 +257,18 @@ func (t *TcpTransport) maListen(laddr ma.Multiaddr) (manet.Listener, error) {
// Listen listens on the given multiaddr.
func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) {
list, err := t.maListen(laddr)
var list manet.Listener
var err error
if t.sharedTcp == nil {
list, err = t.unsharedMAListen(laddr)
} else {
list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.DemultiplexedConnType_MultistreamSelect)
}
if err != nil {
return nil, err
}
if t.enableMetrics {
list = newTracingListener(&tcpListener{list, 0})
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/muxer/yamux"
tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader"
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite"
ma "github.com/multiformats/go-multiaddr"
@@ -31,19 +32,19 @@ func TestTcpTransport(t *testing.T) {
ua, err := tptu.New(ia, muxers, nil, nil, nil)
require.NoError(t, err)
ta, err := NewTCPTransport(ua, nil)
ta, err := NewTCPTransport(ua, nil, nil)
require.NoError(t, err)
ub, err := tptu.New(ib, muxers, nil, nil, nil)
require.NoError(t, err)
tb, err := NewTCPTransport(ub, nil)
tb, err := NewTCPTransport(ub, nil, nil)
require.NoError(t, err)
zero := "/ip4/127.0.0.1/tcp/0"
ttransport.SubtestTransport(t, ta, tb, zero, peerA)
envReuseportVal = false
tcpreuse.EnvReuseportVal = false
}
envReuseportVal = true
tcpreuse.EnvReuseportVal = true
}
func TestTcpTransportWithMetrics(t *testing.T) {
@@ -52,11 +53,11 @@ func TestTcpTransportWithMetrics(t *testing.T) {
ua, err := tptu.New(ia, muxers, nil, nil, nil)
require.NoError(t, err)
ta, err := NewTCPTransport(ua, nil, WithMetrics())
ta, err := NewTCPTransport(ua, nil, nil, WithMetrics())
require.NoError(t, err)
ub, err := tptu.New(ib, muxers, nil, nil, nil)
require.NoError(t, err)
tb, err := NewTCPTransport(ub, nil, WithMetrics())
tb, err := NewTCPTransport(ub, nil, nil, WithMetrics())
require.NoError(t, err)
zero := "/ip4/127.0.0.1/tcp/0"
@@ -72,7 +73,7 @@ func TestResourceManager(t *testing.T) {
ua, err := tptu.New(ia, muxers, nil, nil, nil)
require.NoError(t, err)
ta, err := NewTCPTransport(ua, nil)
ta, err := NewTCPTransport(ua, nil, nil)
require.NoError(t, err)
ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0"))
require.NoError(t, err)
@@ -81,7 +82,7 @@ func TestResourceManager(t *testing.T) {
ub, err := tptu.New(ib, muxers, nil, nil, nil)
require.NoError(t, err)
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
tb, err := NewTCPTransport(ub, rcmgr)
tb, err := NewTCPTransport(ub, rcmgr, nil)
require.NoError(t, err)
t.Run("success", func(t *testing.T) {
@@ -119,16 +120,16 @@ func TestTcpTransportCantDialDNS(t *testing.T) {
require.NoError(t, err)
var u transport.Upgrader
tpt, err := NewTCPTransport(u, nil)
tpt, err := NewTCPTransport(u, nil, nil)
require.NoError(t, err)
if tpt.CanDial(dnsa) {
t.Fatal("shouldn't be able to dial dns")
}
envReuseportVal = false
tcpreuse.EnvReuseportVal = false
}
envReuseportVal = true
tcpreuse.EnvReuseportVal = true
}
func TestTcpTransportCantListenUtp(t *testing.T) {
@@ -137,15 +138,15 @@ func TestTcpTransportCantListenUtp(t *testing.T) {
require.NoError(t, err)
var u transport.Upgrader
tpt, err := NewTCPTransport(u, nil)
tpt, err := NewTCPTransport(u, nil, nil)
require.NoError(t, err)
_, err = tpt.Listen(utpa)
require.Error(t, err, "shouldn't be able to listen on utp addr with tcp transport")
envReuseportVal = false
tcpreuse.EnvReuseportVal = false
}
envReuseportVal = true
tcpreuse.EnvReuseportVal = true
}
func TestDialWithUpdates(t *testing.T) {
@@ -154,7 +155,7 @@ func TestDialWithUpdates(t *testing.T) {
ua, err := tptu.New(ia, muxers, nil, nil, nil)
require.NoError(t, err)
ta, err := NewTCPTransport(ua, nil)
ta, err := NewTCPTransport(ua, nil, nil)
require.NoError(t, err)
ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0"))
require.NoError(t, err)
@@ -162,7 +163,7 @@ func TestDialWithUpdates(t *testing.T) {
ub, err := tptu.New(ib, muxers, nil, nil, nil)
require.NoError(t, err)
tb, err := NewTCPTransport(ub, nil)
tb, err := NewTCPTransport(ub, nil, nil)
require.NoError(t, err)
updCh := make(chan transport.DialUpdate, 1)

View File

@@ -0,0 +1,26 @@
package tcpreuse
import (
"fmt"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse/internal/sampledconn"
manet "github.com/multiformats/go-multiaddr/net"
)
type connWithScope struct {
sampledconn.ManetTCPConnInterface
scope network.ConnManagementScope
}
func (c connWithScope) Scope() network.ConnManagementScope {
return c.scope
}
func manetConnWithScope(c manet.Conn, scope network.ConnManagementScope) (manet.Conn, error) {
if tcpconn, ok := c.(sampledconn.ManetTCPConnInterface); ok {
return &connWithScope{tcpconn, scope}, nil
}
return nil, fmt.Errorf("manet.Conn is not a TCP Conn")
}

View File

@@ -0,0 +1,97 @@
package tcpreuse
import (
"errors"
"fmt"
"time"
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse/internal/sampledconn"
manet "github.com/multiformats/go-multiaddr/net"
)
// This is readiung the first 3 bytes of the packet. It should be instant.
const identifyConnTimeout = 1 * time.Second
type DemultiplexedConnType int
const (
DemultiplexedConnType_Unknown DemultiplexedConnType = iota
DemultiplexedConnType_MultistreamSelect
DemultiplexedConnType_HTTP
DemultiplexedConnType_TLS
)
func (t DemultiplexedConnType) String() string {
switch t {
case DemultiplexedConnType_MultistreamSelect:
return "MultistreamSelect"
case DemultiplexedConnType_HTTP:
return "HTTP"
case DemultiplexedConnType_TLS:
return "TLS"
default:
return fmt.Sprintf("Unknown(%d)", int(t))
}
}
func (t DemultiplexedConnType) IsKnown() bool {
return t >= 1 || t <= 3
}
// identifyConnType attempts to identify the connection type by peeking at the
// first few bytes.
// It Callers must not use the passed in Conn after this
// function returns. if an error is returned, the connection will be closed.
func identifyConnType(c manet.Conn) (DemultiplexedConnType, manet.Conn, error) {
if err := c.SetReadDeadline(time.Now().Add(identifyConnTimeout)); err != nil {
closeErr := c.Close()
return 0, nil, errors.Join(err, closeErr)
}
s, c, err := sampledconn.PeekBytes(c)
if err != nil {
closeErr := c.Close()
return 0, nil, errors.Join(err, closeErr)
}
if err := c.SetReadDeadline(time.Time{}); err != nil {
closeErr := c.Close()
return 0, nil, errors.Join(err, closeErr)
}
if IsMultistreamSelect(s) {
return DemultiplexedConnType_MultistreamSelect, c, nil
}
if IsTLS(s) {
return DemultiplexedConnType_TLS, c, nil
}
if IsHTTP(s) {
return DemultiplexedConnType_HTTP, c, nil
}
return DemultiplexedConnType_Unknown, c, nil
}
// Matchers are implemented here instead of in the transports so we can easily fuzz them together.
type Prefix = [3]byte
func IsMultistreamSelect(s Prefix) bool {
return string(s[:]) == "\x13/m"
}
func IsHTTP(s Prefix) bool {
switch string(s[:]) {
case "GET", "HEA", "POS", "PUT", "DEL", "CON", "OPT", "TRA", "PAT":
return true
default:
return false
}
}
func IsTLS(s Prefix) bool {
switch string(s[:]) {
case "\x16\x03\x01", "\x16\x03\x02", "\x16\x03\x03":
return true
default:
return false
}
}

View File

@@ -0,0 +1,50 @@
package tcpreuse
import "testing"
func FuzzClash(f *testing.F) {
// make untyped literals type correctly
add := func(a, b, c byte) { f.Add(a, b, c) }
// multistream-select
add('\x13', '/', 'm')
// http
add('G', 'E', 'T')
add('H', 'E', 'A')
add('P', 'O', 'S')
add('P', 'U', 'T')
add('D', 'E', 'L')
add('C', 'O', 'N')
add('O', 'P', 'T')
add('T', 'R', 'A')
add('P', 'A', 'T')
// tls
add('\x16', '\x03', '\x01')
add('\x16', '\x03', '\x02')
add('\x16', '\x03', '\x03')
add('\x16', '\x03', '\x04')
f.Fuzz(func(t *testing.T, a, b, c byte) {
s := Prefix{a, b, c}
var total uint
ms := IsMultistreamSelect(s)
if ms {
total++
}
http := IsHTTP(s)
if http {
total++
}
tls := IsTLS(s)
if tls {
total++
}
if total > 1 {
t.Errorf("clash on: %q; ms: %v; http: %v; tls: %v", s, ms, http, tls)
}
})
}

View File

@@ -0,0 +1,16 @@
package tcpreuse
import (
"context"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
// DialContext is like Dial but takes a context.
func (t *ConnMgr) DialContext(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) {
if t.useReuseport() {
return t.reuse.DialContext(ctx, raddr)
}
var d manet.Dialer
return d.DialContext(ctx, raddr)
}

View File

@@ -0,0 +1,89 @@
package sampledconn
import (
"errors"
"io"
"net"
"syscall"
"time"
manet "github.com/multiformats/go-multiaddr/net"
)
const peekSize = 3
type PeekedBytes = [peekSize]byte
var errNotSupported = errors.New("not supported on this platform")
var ErrNotTCPConn = errors.New("passed conn is not a TCPConn")
func PeekBytes(conn manet.Conn) (PeekedBytes, manet.Conn, error) {
if c, ok := conn.(syscall.Conn); ok {
b, err := OSPeekConn(c)
if err == nil {
return b, conn, nil
}
if err != errNotSupported {
return PeekedBytes{}, nil, err
}
// Fallback to wrapping the coonn
}
if c, ok := conn.(ManetTCPConnInterface); ok {
return newFallbackSampledConn(c)
}
return PeekedBytes{}, nil, ErrNotTCPConn
}
type fallbackPeekingConn struct {
ManetTCPConnInterface
peekedBytes PeekedBytes
bytesPeeked uint8
}
// tcpConnInterface is the interface for TCPConn's functions
// NOTE: `SyscallConn() (syscall.RawConn, error)` is here to make using this as
// a TCP Conn easier, but it's a potential footgun as you could skipped the
// peeked bytes if using the fallback
type tcpConnInterface interface {
net.Conn
syscall.Conn
CloseRead() error
CloseWrite() error
SetLinger(sec int) error
SetKeepAlive(keepalive bool) error
SetKeepAlivePeriod(d time.Duration) error
SetNoDelay(noDelay bool) error
MultipathTCP() (bool, error)
io.ReaderFrom
io.WriterTo
}
type ManetTCPConnInterface interface {
manet.Conn
tcpConnInterface
}
func newFallbackSampledConn(conn ManetTCPConnInterface) (PeekedBytes, *fallbackPeekingConn, error) {
s := &fallbackPeekingConn{ManetTCPConnInterface: conn}
_, err := io.ReadFull(conn, s.peekedBytes[:])
if err != nil {
return s.peekedBytes, nil, err
}
return s.peekedBytes, s, nil
}
func (sc *fallbackPeekingConn) Read(b []byte) (int, error) {
if int(sc.bytesPeeked) != len(sc.peekedBytes) {
red := copy(b, sc.peekedBytes[sc.bytesPeeked:])
sc.bytesPeeked += uint8(red)
return red, nil
}
return sc.ManetTCPConnInterface.Read(b)
}

View File

@@ -0,0 +1,11 @@
//go:build !unix && !windows
package sampledconn
import (
"syscall"
)
func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) {
return PeekedBytes{}, errNotSupported
}

View File

@@ -0,0 +1,78 @@
package sampledconn
import (
"io"
"syscall"
"testing"
"time"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/stretchr/testify/assert"
)
func TestSampledConn(t *testing.T) {
testCases := []string{
"platform",
"fallback",
}
// Start a TCP server
listener, err := manet.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0"))
assert.NoError(t, err)
defer listener.Close()
serverAddr := listener.Multiaddr()
// Server goroutine
go func() {
for i := 0; i < len(testCases); i++ {
conn, err := listener.Accept()
assert.NoError(t, err)
defer conn.Close()
// Write some data to the connection
_, err = conn.Write([]byte("hello"))
assert.NoError(t, err)
}
}()
// Give the server a moment to start
time.Sleep(100 * time.Millisecond)
for _, tc := range testCases {
t.Run(tc, func(t *testing.T) {
// Create a TCP client
clientConn, err := manet.Dial(serverAddr)
assert.NoError(t, err)
defer clientConn.Close()
if tc == "platform" {
// Wrap the client connection in SampledConn
peeked, clientConn, err := PeekBytes(clientConn.(interface {
manet.Conn
syscall.Conn
}))
assert.NoError(t, err)
assert.Equal(t, "hel", string(peeked[:]))
buf := make([]byte, 5)
_, err = clientConn.Read(buf)
assert.NoError(t, err)
assert.Equal(t, "hello", string(buf))
} else {
// Wrap the client connection in SampledConn
sample, sampledConn, err := newFallbackSampledConn(clientConn.(ManetTCPConnInterface))
assert.NoError(t, err)
assert.Equal(t, "hel", string(sample[:]))
buf := make([]byte, 5)
_, err = io.ReadFull(sampledConn, buf)
assert.NoError(t, err)
assert.Equal(t, "hello", string(buf))
}
})
}
}

View File

@@ -0,0 +1,42 @@
//go:build unix
package sampledconn
import (
"errors"
"syscall"
)
func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) {
s := PeekedBytes{}
rawConn, err := conn.SyscallConn()
if err != nil {
return s, err
}
readBytes := 0
var readErr error
err = rawConn.Read(func(fd uintptr) bool {
for readBytes < peekSize {
var n int
n, _, readErr = syscall.Recvfrom(int(fd), s[readBytes:], syscall.MSG_PEEK)
if errors.Is(readErr, syscall.EAGAIN) {
return false
}
if readErr != nil {
return true
}
readBytes += n
}
return true
})
if readErr != nil {
return s, readErr
}
if err != nil {
return s, err
}
return s, nil
}

View File

@@ -0,0 +1,49 @@
//go:build windows
package sampledconn
import (
"errors"
"golang.org/x/sys/windows"
"syscall"
)
func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) {
s := PeekedBytes{}
rawConn, err := conn.SyscallConn()
if err != nil {
return s, err
}
readBytes := 0
var readErr error
err = rawConn.Read(func(fd uintptr) bool {
for readBytes < peekSize {
var n uint32
flags := uint32(windows.MSG_PEEK)
wsabuf := windows.WSABuf{
Len: uint32(len(s) - readBytes),
Buf: &s[readBytes],
}
readErr = windows.WSARecv(windows.Handle(fd), &wsabuf, 1, &n, &flags, nil, nil)
if errors.Is(readErr, windows.WSAEWOULDBLOCK) {
return false
}
if readErr != nil {
return true
}
readBytes += int(n)
}
return true
})
if readErr != nil {
return s, readErr
}
if err != nil {
return s, err
}
return s, nil
}

View File

@@ -0,0 +1,329 @@
package tcpreuse
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
logging "github.com/ipfs/go-log/v2"
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/net/reuseport"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
const acceptQueueSize = 64 // It is fine to read 3 bytes from 64 connections in parallel.
// How long we wait for a connection to be accepted before dropping it.
const acceptTimeout = 30 * time.Second
var log = logging.Logger("tcp-demultiplex")
// ConnMgr enables you to share the same listen address between TCP and WebSocket transports.
type ConnMgr struct {
enableReuseport bool
reuse reuseport.Transport
connGater connmgr.ConnectionGater
rcmgr network.ResourceManager
mx sync.Mutex
listeners map[string]*multiplexedListener
}
func NewConnMgr(enableReuseport bool, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *ConnMgr {
if rcmgr == nil {
rcmgr = &network.NullResourceManager{}
}
return &ConnMgr{
enableReuseport: enableReuseport,
reuse: reuseport.Transport{},
connGater: gater,
rcmgr: rcmgr,
listeners: make(map[string]*multiplexedListener),
}
}
func (t *ConnMgr) maListen(listenAddr ma.Multiaddr) (manet.Listener, error) {
if t.useReuseport() {
return t.reuse.Listen(listenAddr)
} else {
return manet.Listen(listenAddr)
}
}
func (t *ConnMgr) useReuseport() bool {
return t.enableReuseport && ReuseportIsAvailable()
}
func getTCPAddr(listenAddr ma.Multiaddr) (ma.Multiaddr, error) {
haveTCP := false
addr, _ := ma.SplitFunc(listenAddr, func(c ma.Component) bool {
if haveTCP {
return true
}
if c.Protocol().Code == ma.P_TCP {
haveTCP = true
}
return false
})
if !haveTCP {
return nil, fmt.Errorf("invalid listen addr %s, need tcp address", listenAddr)
}
return addr, nil
}
// DemultiplexedListen returns a listener for laddr listening for `connType` connections. The connections
// accepted from returned listeners need to be upgraded with a `transport.Upgrader`.
// NOTE: All listeners for port 0 share the same underlying socket, so they have the same specific port.
func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType DemultiplexedConnType) (manet.Listener, error) {
if !connType.IsKnown() {
return nil, fmt.Errorf("unknown connection type: %s", connType)
}
laddr, err := getTCPAddr(laddr)
if err != nil {
return nil, err
}
t.mx.Lock()
defer t.mx.Unlock()
ml, ok := t.listeners[laddr.String()]
if ok {
dl, err := ml.DemultiplexedListen(connType)
if err != nil {
return nil, err
}
return dl, nil
}
l, err := t.maListen(laddr)
if err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(context.Background())
cancelFunc := func() error {
cancel()
t.mx.Lock()
defer t.mx.Unlock()
delete(t.listeners, laddr.String())
delete(t.listeners, l.Multiaddr().String())
return l.Close()
}
ml = &multiplexedListener{
Listener: l,
listeners: make(map[DemultiplexedConnType]*demultiplexedListener),
ctx: ctx,
closeFn: cancelFunc,
connGater: t.connGater,
rcmgr: t.rcmgr,
}
t.listeners[laddr.String()] = ml
t.listeners[l.Multiaddr().String()] = ml
dl, err := ml.DemultiplexedListen(connType)
if err != nil {
cerr := ml.Close()
return nil, errors.Join(err, cerr)
}
ml.wg.Add(1)
go ml.run()
return dl, nil
}
var _ manet.Listener = &demultiplexedListener{}
type multiplexedListener struct {
manet.Listener
listeners map[DemultiplexedConnType]*demultiplexedListener
mx sync.RWMutex
connGater connmgr.ConnectionGater
rcmgr network.ResourceManager
ctx context.Context
closeFn func() error
wg sync.WaitGroup
}
var ErrListenerExists = errors.New("listener already exists for this conn type on this address")
func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (manet.Listener, error) {
if !connType.IsKnown() {
return nil, fmt.Errorf("unknown connection type: %s", connType)
}
m.mx.Lock()
defer m.mx.Unlock()
if _, ok := m.listeners[connType]; ok {
return nil, ErrListenerExists
}
ctx, cancel := context.WithCancel(m.ctx)
l := &demultiplexedListener{
buffer: make(chan manet.Conn),
inner: m.Listener,
ctx: ctx,
cancelFunc: cancel,
closeFn: func() error { m.removeDemultiplexedListener(connType); return nil },
}
m.listeners[connType] = l
return l, nil
}
func (m *multiplexedListener) run() error {
defer m.Close()
defer m.wg.Done()
acceptQueue := make(chan struct{}, acceptQueueSize)
for {
c, err := m.Listener.Accept()
if err != nil {
return err
}
// Gate and resource limit the connection here.
// If done after sampling the connection, we'll be vulnerable to DOS attacks by a single peer
// which clogs up our entire connection queue.
// This duplicates the responsibility of gating and resource limiting between here and the upgrader. The
// alternative without duplication requires moving the process of upgrading the connection here, which forces
// us to establish the websocket connection here. That is more duplication, or a significant breaking change.
//
// Bugs around multiple calls to OpenConnection or InterceptAccept are prevented by the transport
// integration tests.
if m.connGater != nil && !m.connGater.InterceptAccept(c) {
log.Debugf("gater blocked incoming connection on local addr %s from %s",
c.LocalMultiaddr(), c.RemoteMultiaddr())
if err := c.Close(); err != nil {
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
}
continue
}
connScope, err := m.rcmgr.OpenConnection(network.DirInbound, true, c.RemoteMultiaddr())
if err != nil {
log.Debugw("resource manager blocked accept of new connection", "error", err)
if err := c.Close(); err != nil {
log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err)
}
continue
}
select {
case acceptQueue <- struct{}{}:
// NOTE: We can drop the connection, but this is similar to the behaviour in the upgrader.
case <-m.ctx.Done():
c.Close()
log.Debugf("accept queue full, dropping connection: %s", c.RemoteMultiaddr())
}
m.wg.Add(1)
go func() {
defer func() { <-acceptQueue }()
defer m.wg.Done()
ctx, cancelCtx := context.WithTimeout(m.ctx, acceptTimeout)
defer cancelCtx()
t, c, err := identifyConnType(c)
if err != nil {
connScope.Done()
closeErr := c.Close()
err = errors.Join(err, closeErr)
log.Debugf("error demultiplexing connection: %s", err.Error())
return
}
connWithScope, err := manetConnWithScope(c, connScope)
if err != nil {
connScope.Done()
closeErr := c.Close()
err = errors.Join(err, closeErr)
log.Debugf("error wrapping connection with scope: %s", err.Error())
return
}
m.mx.RLock()
demux, ok := m.listeners[t]
m.mx.RUnlock()
if !ok {
closeErr := connWithScope.Close()
if closeErr != nil {
log.Debugf("no registered listener for demultiplex connection %s. Error closing the connection %s", t, closeErr.Error())
} else {
log.Debugf("no registered listener for demultiplex connection %s", t)
}
return
}
select {
case demux.buffer <- connWithScope:
case <-ctx.Done():
connWithScope.Close()
}
}()
}
}
func (m *multiplexedListener) Close() error {
m.mx.Lock()
for _, l := range m.listeners {
l.cancelFunc()
}
err := m.closeListener()
m.mx.Unlock()
m.wg.Wait()
return err
}
func (m *multiplexedListener) closeListener() error {
lerr := m.Listener.Close()
cerr := m.closeFn()
return errors.Join(lerr, cerr)
}
func (m *multiplexedListener) removeDemultiplexedListener(c DemultiplexedConnType) {
m.mx.Lock()
defer m.mx.Unlock()
delete(m.listeners, c)
if len(m.listeners) == 0 {
m.closeListener()
m.mx.Unlock()
m.wg.Wait()
m.mx.Lock()
}
}
type demultiplexedListener struct {
buffer chan manet.Conn
inner manet.Listener
ctx context.Context
cancelFunc context.CancelFunc
closeFn func() error
}
func (m *demultiplexedListener) Accept() (manet.Conn, error) {
select {
case c := <-m.buffer:
return c, nil
case <-m.ctx.Done():
return nil, transport.ErrListenerClosed
}
}
func (m *demultiplexedListener) Close() error {
m.cancelFunc()
return m.closeFn()
}
func (m *demultiplexedListener) Multiaddr() ma.Multiaddr {
return m.inner.Multiaddr()
}
func (m *demultiplexedListener) Addr() net.Addr {
return m.inner.Addr()
}

View File

@@ -0,0 +1,449 @@
package tcpreuse
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
"net"
"net/http"
"sync"
"testing"
"time"
"github.com/gorilla/websocket"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multistream"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func selfSignedTLSConfig(t *testing.T) *tls.Config {
t.Helper()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
certTemplate := x509.Certificate{
SerialNumber: &big.Int{},
Subject: pkix.Name{
Organization: []string{"Test"},
},
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
derBytes, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &priv.PublicKey, priv)
require.NoError(t, err)
cert := tls.Certificate{
Certificate: [][]byte{derBytes},
PrivateKey: priv,
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
return tlsConfig
}
type wsHandler struct{ conns chan *websocket.Conn }
func (wh wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
u := websocket.Upgrader{}
c, _ := u.Upgrade(w, r, http.Header{})
wh.conns <- c
}
func TestListenerSingle(t *testing.T) {
listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0")
const N = 64
for _, enableReuseport := range []bool{true, false} {
t.Run(fmt.Sprintf("multistream-reuseport:%v", enableReuseport), func(t *testing.T) {
cm := NewConnMgr(enableReuseport, nil, nil)
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
require.NoError(t, err)
go func() {
d := net.Dialer{}
for i := 0; i < N; i++ {
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
conn, err := d.DialContext(ctx, l.Addr().Network(), l.Addr().String())
if err != nil {
t.Error("failed to dial", err, i)
return
}
lconn := multistream.NewMSSelect(conn, "a")
buf := make([]byte, 10)
_, err = lconn.Write([]byte("hello-multistream"))
if err != nil {
t.Error(err)
}
_, err = lconn.Read(buf)
if err == nil {
t.Error("expected EOF got nil")
}
}()
}
}()
var wg sync.WaitGroup
for i := 0; i < N; i++ {
c, err := l.Accept()
require.NoError(t, err)
wg.Add(1)
go func() {
defer wg.Done()
cc := multistream.NewMSSelect(c, "a")
defer cc.Close()
buf := make([]byte, 30)
n, err := cc.Read(buf)
if !assert.NoError(t, err) {
return
}
if !assert.Equal(t, "hello-multistream", string(buf[:n])) {
return
}
}()
}
wg.Wait()
})
t.Run(fmt.Sprintf("WebSocket-reuseport:%v", enableReuseport), func(t *testing.T) {
cm := NewConnMgr(enableReuseport, nil, nil)
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
require.NoError(t, err)
wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
go func() {
http.Serve(manet.NetListener(l), wh)
}()
go func() {
d := websocket.Dialer{}
for i := 0; i < N; i++ {
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
conn, _, err := d.DialContext(ctx, fmt.Sprintf("ws://%s", l.Addr().String()), http.Header{})
if err != nil {
t.Error("failed to dial", err, i)
return
}
err = conn.WriteMessage(websocket.TextMessage, []byte("hello"))
if err != nil {
t.Error(err)
}
_, _, err = conn.ReadMessage()
if err == nil {
t.Error("expected EOF got nil")
}
}()
}
}()
var wg sync.WaitGroup
for i := 0; i < N; i++ {
c := <-wh.conns
wg.Add(1)
go func() {
defer wg.Done()
defer c.Close()
msgType, buf, err := c.ReadMessage()
if !assert.NoError(t, err) {
return
}
if !assert.Equal(t, msgType, websocket.TextMessage) {
return
}
if !assert.Equal(t, "hello", string(buf)) {
return
}
}()
}
wg.Wait()
})
t.Run(fmt.Sprintf("WebSocketTLS-reuseport:%v", enableReuseport), func(t *testing.T) {
cm := NewConnMgr(enableReuseport, nil, nil)
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS)
require.NoError(t, err)
defer l.Close()
wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
go func() {
s := http.Server{Handler: wh, TLSConfig: selfSignedTLSConfig(t)}
s.ServeTLS(manet.NetListener(l), "", "")
}()
go func() {
d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
for i := 0; i < N; i++ {
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
conn, _, err := d.DialContext(ctx, fmt.Sprintf("wss://%s", l.Addr().String()), http.Header{})
if err != nil {
t.Error("failed to dial", err, i)
return
}
err = conn.WriteMessage(websocket.TextMessage, []byte("hello"))
if err != nil {
t.Error(err)
}
_, _, err = conn.ReadMessage()
if err == nil {
t.Error("expected EOF got nil")
}
}()
}
}()
var wg sync.WaitGroup
for i := 0; i < N; i++ {
c := <-wh.conns
wg.Add(1)
go func() {
defer wg.Done()
defer c.Close()
msgType, buf, err := c.ReadMessage()
if !assert.NoError(t, err) {
return
}
if !assert.Equal(t, msgType, websocket.TextMessage) {
return
}
if !assert.Equal(t, "hello", string(buf)) {
return
}
}()
}
wg.Wait()
})
}
}
func TestListenerMultiplexed(t *testing.T) {
listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0")
const N = 20
for _, enableReuseport := range []bool{true, false} {
cm := NewConnMgr(enableReuseport, nil, nil)
msl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
require.NoError(t, err)
defer msl.Close()
wsl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
require.NoError(t, err)
defer wsl.Close()
require.Equal(t, wsl.Multiaddr(), msl.Multiaddr())
wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
go func() {
http.Serve(manet.NetListener(wsl), wh)
}()
wssl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS)
require.NoError(t, err)
defer wssl.Close()
require.Equal(t, wssl.Multiaddr(), wsl.Multiaddr())
whs := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
go func() {
s := http.Server{Handler: whs, TLSConfig: selfSignedTLSConfig(t)}
s.ServeTLS(manet.NetListener(wssl), "", "")
}()
// multistream connections
go func() {
d := net.Dialer{}
for i := 0; i < N; i++ {
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
conn, err := d.DialContext(ctx, msl.Addr().Network(), msl.Addr().String())
if err != nil {
t.Error("failed to dial", err, i)
return
}
lconn := multistream.NewMSSelect(conn, "a")
buf := make([]byte, 10)
_, err = lconn.Write([]byte("multistream"))
if err != nil {
t.Error(err)
}
_, err = lconn.Read(buf)
if err == nil {
t.Error("expected EOF got nil")
}
}()
}
}()
// ws connections
go func() {
d := websocket.Dialer{}
for i := 0; i < N; i++ {
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
conn, _, err := d.DialContext(ctx, fmt.Sprintf("ws://%s", msl.Addr().String()), http.Header{})
if err != nil {
t.Error("failed to dial", err, i)
return
}
err = conn.WriteMessage(websocket.TextMessage, []byte("websocket"))
if err != nil {
t.Error(err)
}
_, _, err = conn.ReadMessage()
if err == nil {
t.Error("expected EOF got nil")
}
}()
}
}()
// wss connections
go func() {
d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
for i := 0; i < N; i++ {
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
conn, _, err := d.DialContext(ctx, fmt.Sprintf("wss://%s", msl.Addr().String()), http.Header{})
if err != nil {
t.Error("failed to dial", err, i)
return
}
err = conn.WriteMessage(websocket.TextMessage, []byte("websocket-tls"))
if err != nil {
t.Error(err)
}
_, _, err = conn.ReadMessage()
if err == nil {
t.Error("expected EOF got nil")
}
}()
}
}()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < N; i++ {
c, err := msl.Accept()
if !assert.NoError(t, err) {
return
}
wg.Add(1)
go func() {
defer wg.Done()
cc := multistream.NewMSSelect(c, "a")
defer cc.Close()
buf := make([]byte, 20)
n, err := cc.Read(buf)
if !assert.NoError(t, err) {
return
}
if !assert.Equal(t, "multistream", string(buf[:n])) {
return
}
}()
}
}()
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < N; i++ {
c := <-wh.conns
wg.Add(1)
go func() {
defer wg.Done()
defer c.Close()
msgType, buf, err := c.ReadMessage()
if !assert.NoError(t, err) {
return
}
if !assert.Equal(t, msgType, websocket.TextMessage) {
return
}
if !assert.Equal(t, "websocket", string(buf)) {
return
}
}()
}
}()
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < N; i++ {
c := <-whs.conns
wg.Add(1)
go func() {
defer wg.Done()
defer c.Close()
msgType, buf, err := c.ReadMessage()
if !assert.NoError(t, err) {
return
}
if !assert.Equal(t, msgType, websocket.TextMessage) {
return
}
if !assert.Equal(t, "websocket-tls", string(buf)) {
return
}
}()
}
}()
wg.Wait()
}
}
func TestListenerClose(t *testing.T) {
testClose := func(listenAddr ma.Multiaddr) {
// listen on port 0
cm := NewConnMgr(false, nil, nil)
ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
require.NoError(t, err)
wl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
require.NoError(t, err)
require.Equal(t, wl.Multiaddr(), ml.Multiaddr())
wl.Close()
wl, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
require.NoError(t, err)
require.Equal(t, wl.Multiaddr(), ml.Multiaddr())
ml.Close()
mll, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
require.NoError(t, err)
require.Equal(t, wl.Multiaddr(), ml.Multiaddr())
mll.Close()
wl.Close()
ml, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
require.NoError(t, err)
// Now listen on the specific port previously used
listenAddr = ml.Multiaddr()
wl, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
require.NoError(t, err)
require.Equal(t, wl.Multiaddr(), ml.Multiaddr())
wl.Close()
wl, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
require.NoError(t, err)
require.Equal(t, wl.Multiaddr(), ml.Multiaddr())
ml.Close()
wl.Close()
}
listenAddrs := []ma.Multiaddr{ma.StringCast("/ip4/0.0.0.0/tcp/0"), ma.StringCast("/ip6/::/tcp/0")}
for _, listenAddr := range listenAddrs {
testClose(listenAddr)
}
}

View File

@@ -1,4 +1,4 @@
package tcp
package tcpreuse
import (
"os"
@@ -11,13 +11,13 @@ import (
// It default to true.
const envReuseport = "LIBP2P_TCP_REUSEPORT"
// envReuseportVal stores the value of envReuseport. defaults to true.
var envReuseportVal = true
// EnvReuseportVal stores the value of envReuseport. defaults to true.
var EnvReuseportVal = true
func init() {
v := strings.ToLower(os.Getenv(envReuseport))
if v == "false" || v == "f" || v == "0" {
envReuseportVal = false
EnvReuseportVal = false
log.Infof("REUSEPORT disabled (LIBP2P_TCP_REUSEPORT=%s)", v)
}
}
@@ -31,5 +31,5 @@ func init() {
// If this becomes a sought after feature, we could add this to the config.
// In the end, reuseport is a stop-gap.
func ReuseportIsAvailable() bool {
return envReuseportVal && reuseport.Available()
return EnvReuseportVal && reuseport.Available()
}

View File

@@ -11,7 +11,9 @@ import (
ma "github.com/multiformats/go-multiaddr"
)
var Subtests = []func(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID){
type TransportSubTestFn func(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID)
var Subtests = []TransportSubTestFn{
SubtestProtocols,
SubtestBasic,
SubtestCancel,
@@ -33,12 +35,17 @@ func getFunctionName(i interface{}) string {
}
func SubtestTransport(t *testing.T, ta, tb transport.Transport, addr string, peerA peer.ID) {
t.Helper()
SubtestTransportWithFs(t, ta, tb, addr, peerA, Subtests)
}
func SubtestTransportWithFs(t *testing.T, ta, tb transport.Transport, addr string, peerA peer.ID, tests []TransportSubTestFn) {
maddr, err := ma.NewMultiaddr(addr)
if err != nil {
t.Fatal(err)
}
for _, f := range Subtests {
for _, f := range tests {
t.Run(getFunctionName(f), func(t *testing.T) {
f(t, ta, tb, maddr, peerA)
})

View File

@@ -69,7 +69,7 @@ func TestConvertWebsocketMultiaddrToNetAddr(t *testing.T) {
}
func TestListeningOnDNSAddr(t *testing.T) {
ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil)
ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil, nil)
require.NoError(t, err)
addr := ln.Multiaddr()
first, rest := ma.SplitFirst(addr)

View File

@@ -1,6 +1,7 @@
package websocket
import (
"errors"
"io"
"net"
"sync"
@@ -8,6 +9,8 @@ import (
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/transport"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
ws "github.com/gorilla/websocket"
)
@@ -22,20 +25,53 @@ type Conn struct {
secure bool
DefaultMessageType int
reader io.Reader
closeOnce sync.Once
closeOnceVal func() error
laddr ma.Multiaddr
raddr ma.Multiaddr
readLock, writeLock sync.Mutex
}
var _ net.Conn = (*Conn)(nil)
var _ manet.Conn = (*Conn)(nil)
// NewConn creates a Conn given a regular gorilla/websocket Conn.
//
// Deprecated: There's no reason to use this method externally. It'll be unexported in a future release.
func NewConn(raw *ws.Conn, secure bool) *Conn {
return &Conn{
lna := NewAddrWithScheme(raw.LocalAddr().String(), secure)
laddr, err := manet.FromNetAddr(lna)
if err != nil {
log.Errorf("BUG: invalid localaddr on websocket conn", raw.LocalAddr())
return nil
}
rna := NewAddrWithScheme(raw.RemoteAddr().String(), secure)
raddr, err := manet.FromNetAddr(rna)
if err != nil {
log.Errorf("BUG: invalid remoteaddr on websocket conn", raw.RemoteAddr())
return nil
}
c := &Conn{
Conn: raw,
secure: secure,
DefaultMessageType: ws.BinaryMessage,
laddr: laddr,
raddr: raddr,
}
c.closeOnceVal = sync.OnceValue(c.closeOnceFn)
return c
}
// LocalMultiaddr implements manet.Conn.
func (c *Conn) LocalMultiaddr() ma.Multiaddr {
return c.laddr
}
// RemoteMultiaddr implements manet.Conn.
func (c *Conn) RemoteMultiaddr() ma.Multiaddr {
return c.raddr
}
func (c *Conn) Read(b []byte) (int, error) {
@@ -99,26 +135,31 @@ func (c *Conn) Write(b []byte) (n int, err error) {
return len(b), nil
}
// Close closes the connection. Only the first call to Close will receive the
// close error, subsequent and concurrent calls will return nil.
func (c *Conn) Scope() network.ConnManagementScope {
nc := c.NetConn()
if sc, ok := nc.(interface {
Scope() network.ConnManagementScope
}); ok {
return sc.Scope()
}
return nil
}
// Close closes the connection.
// subsequent and concurrent calls will return the same error value.
// This method is thread-safe.
func (c *Conn) Close() error {
var err error
c.closeOnce.Do(func() {
err1 := c.Conn.WriteControl(
ws.CloseMessage,
ws.FormatCloseMessage(ws.CloseNormalClosure, "closed"),
time.Now().Add(GracefulCloseTimeout),
)
err2 := c.Conn.Close()
switch {
case err1 != nil:
err = err1
case err2 != nil:
err = err2
}
})
return err
return c.closeOnceVal()
}
func (c *Conn) closeOnceFn() error {
err1 := c.Conn.WriteControl(
ws.CloseMessage,
ws.FormatCloseMessage(ws.CloseNormalClosure, "closed"),
time.Now().Add(GracefulCloseTimeout),
)
err2 := c.Conn.Close()
return errors.Join(err1, err2)
}
func (c *Conn) LocalAddr() net.Addr {

View File

@@ -4,14 +4,16 @@ import (
"crypto/tls"
"errors"
"fmt"
"go.uber.org/zap"
"net"
"net/http"
"sync"
"go.uber.org/zap"
logging "github.com/ipfs/go-log/v2"
"github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
@@ -50,7 +52,7 @@ func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr {
// newListener creates a new listener from a raw net.Listener.
// tlsConf may be nil (for unencrypted websockets).
func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) {
func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr) (*listener, error) {
parsed, err := parseWebsocketMultiaddr(a)
if err != nil {
return nil, err
@@ -60,19 +62,36 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) {
return nil, fmt.Errorf("cannot listen on wss address %s without a tls.Config", a)
}
lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr)
if err != nil {
return nil, err
}
nl, err := net.Listen(lnet, lnaddr)
if err != nil {
return nil, err
var nl net.Listener
if sharedTcp == nil {
lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr)
if err != nil {
return nil, err
}
nl, err = net.Listen(lnet, lnaddr)
if err != nil {
return nil, err
}
} else {
var connType tcpreuse.DemultiplexedConnType
if parsed.isWSS {
connType = tcpreuse.DemultiplexedConnType_TLS
} else {
connType = tcpreuse.DemultiplexedConnType_HTTP
}
mal, err := sharedTcp.DemultiplexedListen(parsed.restMultiaddr, connType)
if err != nil {
return nil, err
}
nl = manet.NetListener(mal)
}
laddr, err := manet.FromNetAddr(nl.Addr())
if err != nil {
return nil, err
}
first, _ := ma.SplitFirst(a)
// Don't resolve dns addresses.
// We want to be able to announce domain names, so the peer can validate the TLS certificate.
@@ -111,7 +130,12 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// The upgrader writes a response for us.
return
}
nc := NewConn(c, l.isWss)
if nc == nil {
c.Close()
w.WriteHeader(500)
return
}
select {
case l.incoming <- NewConn(c, l.isWss):
case <-l.closed:
@@ -126,13 +150,7 @@ func (l *listener) Accept() (manet.Conn, error) {
if !ok {
return nil, transport.ErrListenerClosed
}
mnc, err := manet.WrapNetConn(c)
if err != nil {
c.Close()
return nil, err
}
return mnc, nil
return c, nil
case <-l.closed:
return nil, transport.ErrListenerClosed
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
ma "github.com/multiformats/go-multiaddr"
mafmt "github.com/multiformats/go-multiaddr-fmt"
@@ -87,11 +88,13 @@ type WebsocketTransport struct {
tlsClientConf *tls.Config
tlsConf *tls.Config
sharedTcp *tcpreuse.ConnMgr
}
var _ transport.Transport = (*WebsocketTransport)(nil)
func New(u transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*WebsocketTransport, error) {
func New(u transport.Upgrader, rcmgr network.ResourceManager, sharedTCP *tcpreuse.ConnMgr, opts ...Option) (*WebsocketTransport, error) {
if rcmgr == nil {
rcmgr = &network.NullResourceManager{}
}
@@ -99,6 +102,7 @@ func New(u transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*
upgrader: u,
rcmgr: rcmgr,
tlsClientConf: &tls.Config{},
sharedTcp: sharedTCP,
}
for _, opt := range opts {
if err := opt(t); err != nil {
@@ -233,7 +237,7 @@ func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) {
if t.tlsConf != nil {
tlsConf = t.tlsConf.Clone()
}
l, err := newListener(a, tlsConf)
l, err := newListener(a, tlsConf, t.sharedTcp)
if err != nil {
return nil, err
}

View File

@@ -154,7 +154,7 @@ func testWSSServer(t *testing.T, listenAddr ma.Multiaddr) (ma.Multiaddr, peer.ID
}
id, u := newSecureUpgrader(t)
tpt, err := New(u, &network.NullResourceManager{}, WithTLSConfig(tlsConf))
tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSConfig(tlsConf))
if err != nil {
t.Fatal(err)
}
@@ -237,7 +237,7 @@ func TestHostHeaderWss(t *testing.T) {
tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA
_, u := newSecureUpgrader(t)
tpt, err := New(u, &network.NullResourceManager{}, WithTLSClientConfig(tlsConfig))
tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig))
require.NoError(t, err)
masToDial, err := tpt.Resolve(context.Background(), serverMA)
@@ -256,7 +256,7 @@ func TestDialWss(t *testing.T) {
tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA
_, u := newSecureUpgrader(t)
tpt, err := New(u, &network.NullResourceManager{}, WithTLSClientConfig(tlsConfig))
tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig))
require.NoError(t, err)
masToDial, err := tpt.Resolve(context.Background(), serverMA)
@@ -279,7 +279,7 @@ func TestDialWssNoClientCert(t *testing.T) {
require.Contains(t, serverMA.String(), "tls")
_, u := newSecureUpgrader(t)
tpt, err := New(u, &network.NullResourceManager{})
tpt, err := New(u, &network.NullResourceManager{}, nil)
require.NoError(t, err)
masToDial, err := tpt.Resolve(context.Background(), serverMA)
@@ -294,12 +294,12 @@ func TestDialWssNoClientCert(t *testing.T) {
func TestWebsocketTransport(t *testing.T) {
peerA, ua := newUpgrader(t)
ta, err := New(ua, nil)
ta, err := New(ua, nil, nil)
if err != nil {
t.Fatal(err)
}
_, ub := newUpgrader(t)
tb, err := New(ub, nil)
tb, err := New(ub, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -325,7 +325,7 @@ func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) {
opts = append(opts, WithTLSConfig(tlsConf))
}
server, u := newUpgrader(t)
tpt, err := New(u, &network.NullResourceManager{}, opts...)
tpt, err := New(u, &network.NullResourceManager{}, nil, opts...)
require.NoError(t, err)
l, err := tpt.Listen(laddr)
require.NoError(t, err)
@@ -344,7 +344,7 @@ func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) {
opts = append(opts, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
}
_, u := newUpgrader(t)
tpt, err := New(u, &network.NullResourceManager{}, opts...)
tpt, err := New(u, &network.NullResourceManager{}, nil, opts...)
require.NoError(t, err)
c, err := tpt.Dial(context.Background(), l.Multiaddr(), server)
require.NoError(t, err)
@@ -382,7 +382,7 @@ func TestWebsocketConnection(t *testing.T) {
func TestWebsocketListenSecureFailWithoutTLSConfig(t *testing.T) {
_, u := newUpgrader(t)
tpt, err := New(u, &network.NullResourceManager{})
tpt, err := New(u, &network.NullResourceManager{}, nil)
require.NoError(t, err)
addr := ma.StringCast("/ip4/127.0.0.1/tcp/0/wss")
_, err = tpt.Listen(addr)
@@ -391,7 +391,7 @@ func TestWebsocketListenSecureFailWithoutTLSConfig(t *testing.T) {
func TestWebsocketListenSecureAndInsecure(t *testing.T) {
serverID, serverUpgrader := newUpgrader(t)
server, err := New(serverUpgrader, &network.NullResourceManager{}, WithTLSConfig(generateTLSConfig(t)))
server, err := New(serverUpgrader, &network.NullResourceManager{}, nil, WithTLSConfig(generateTLSConfig(t)))
require.NoError(t, err)
lnInsecure, err := server.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
@@ -401,7 +401,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) {
t.Run("insecure", func(t *testing.T) {
_, clientUpgrader := newUpgrader(t)
client, err := New(clientUpgrader, &network.NullResourceManager{}, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
require.NoError(t, err)
// dialing the insecure address should succeed
@@ -418,7 +418,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) {
t.Run("secure", func(t *testing.T) {
_, clientUpgrader := newUpgrader(t)
client, err := New(clientUpgrader, &network.NullResourceManager{}, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
require.NoError(t, err)
// dialing the insecure address should succeed
@@ -436,7 +436,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) {
func TestConcurrentClose(t *testing.T) {
_, u := newUpgrader(t)
tpt, err := New(u, &network.NullResourceManager{})
tpt, err := New(u, &network.NullResourceManager{}, nil)
require.NoError(t, err)
l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
if err != nil {
@@ -474,7 +474,7 @@ func TestConcurrentClose(t *testing.T) {
func TestWriteZero(t *testing.T) {
_, u := newUpgrader(t)
tpt, err := New(u, &network.NullResourceManager{})
tpt, err := New(u, &network.NullResourceManager{}, nil)
if err != nil {
t.Fatal(err)
}