mirror of
https://github.com/libp2p/go-libp2p.git
synced 2025-09-26 20:21:26 +08:00
feat(tcpreuse): add options for sharing TCP listeners amongst TCP, WS and WSS transports (#2984)
Allows the same socket to be shared amongst TCP,WS,WSS transports. --------- Co-authored-by: sukun <sukunrt@gmail.com> Co-authored-by: Marco Munizaga <git@marcopolo.io>
This commit is contained in:
@@ -38,6 +38,7 @@ import (
|
||||
"github.com/libp2p/go-libp2p/p2p/protocol/holepunch"
|
||||
"github.com/libp2p/go-libp2p/p2p/protocol/identify"
|
||||
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
|
||||
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
|
||||
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
@@ -145,6 +146,8 @@ type Config struct {
|
||||
CustomIPv6BlackHoleSuccessCounter bool
|
||||
|
||||
UserFxOptions []fx.Option
|
||||
|
||||
ShareTCPListener bool
|
||||
}
|
||||
|
||||
func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swarm, error) {
|
||||
@@ -289,6 +292,12 @@ func (cfg *Config) addTransports() ([]fx.Option, error) {
|
||||
fx.Provide(func() connmgr.ConnectionGater { return cfg.ConnectionGater }),
|
||||
fx.Provide(func() pnet.PSK { return cfg.PSK }),
|
||||
fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }),
|
||||
fx.Provide(func(gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *tcpreuse.ConnMgr {
|
||||
if !cfg.ShareTCPListener {
|
||||
return nil
|
||||
}
|
||||
return tcpreuse.NewConnMgr(tcpreuse.EnvReuseportVal, gater, rcmgr)
|
||||
}),
|
||||
fx.Provide(func(cm *quicreuse.ConnManager, sw *swarm.Swarm) libp2pwebrtc.ListenUDPFn {
|
||||
hasQuicAddrPortFor := func(network string, laddr *net.UDPAddr) bool {
|
||||
quicAddrPorts := map[string]struct{}{}
|
||||
|
@@ -59,7 +59,7 @@ func TestTransportConstructor(t *testing.T) {
|
||||
_ connmgr.ConnectionGater,
|
||||
upgrader transport.Upgrader,
|
||||
) transport.Transport {
|
||||
tpt, err := tcp.NewTCPTransport(upgrader, nil)
|
||||
tpt, err := tcp.NewTCPTransport(upgrader, nil, nil)
|
||||
require.NoError(t, err)
|
||||
return tpt
|
||||
}
|
||||
@@ -751,3 +751,27 @@ func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config {
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func TestSharedTCPAddr(t *testing.T) {
|
||||
h, err := New(
|
||||
ShareTCPListener(),
|
||||
Transport(tcp.NewTCPTransport),
|
||||
Transport(websocket.New),
|
||||
ListenAddrStrings("/ip4/0.0.0.0/tcp/8888"),
|
||||
ListenAddrStrings("/ip4/0.0.0.0/tcp/8888/ws"),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
sawTCP := false
|
||||
sawWS := false
|
||||
for _, addr := range h.Addrs() {
|
||||
if strings.HasSuffix(addr.String(), "/tcp/8888") {
|
||||
sawTCP = true
|
||||
}
|
||||
if strings.HasSuffix(addr.String(), "/tcp/8888/ws") {
|
||||
sawWS = true
|
||||
}
|
||||
}
|
||||
require.True(t, sawTCP)
|
||||
require.True(t, sawWS)
|
||||
h.Close()
|
||||
}
|
||||
|
12
options.go
12
options.go
@@ -643,3 +643,15 @@ func WithFxOption(opts ...fx.Option) Option {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ShareTCPListener shares the same listen address between TCP and Websocket
|
||||
// transports. This lets both transports use the same TCP port.
|
||||
//
|
||||
// Currently this behavior is Opt-in. In a future release this will be the
|
||||
// default, and this option will be removed.
|
||||
func ShareTCPListener() Option {
|
||||
return func(cfg *Config) error {
|
||||
cfg.ShareTCPListener = true
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@@ -84,7 +84,7 @@ func makeSwarmWithNoListenAddrs(t *testing.T, opts ...Option) *Swarm {
|
||||
upgrader := makeUpgrader(t, s)
|
||||
var tcpOpts []tcp.Option
|
||||
tcpOpts = append(tcpOpts, tcp.DisableReuseport())
|
||||
tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...)
|
||||
tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, nil, tcpOpts...)
|
||||
require.NoError(t, err)
|
||||
if err := s.AddTransport(tcpTransport); err != nil {
|
||||
t.Fatal(err)
|
||||
|
@@ -79,7 +79,7 @@ func TestDialAddressSelection(t *testing.T) {
|
||||
s, err := swarm.NewSwarm("local", nil, eventbus.NewBus())
|
||||
require.NoError(t, err)
|
||||
|
||||
tcpTr, err := tcp.NewTCPTransport(nil, nil)
|
||||
tcpTr, err := tcp.NewTCPTransport(nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, s.AddTransport(tcpTr))
|
||||
reuse, err := quicreuse.NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{})
|
||||
|
@@ -53,7 +53,7 @@ func TestAddrsForDial(t *testing.T) {
|
||||
ps.AddPrivKey(id, priv)
|
||||
t.Cleanup(func() { ps.Close() })
|
||||
|
||||
tpt, err := websocket.New(nil, &network.NullResourceManager{})
|
||||
tpt, err := websocket.New(nil, &network.NullResourceManager{}, nil)
|
||||
require.NoError(t, err)
|
||||
s, err := NewSwarm(id, ps, eventbus.NewBus(), WithMultiaddrResolver(ResolverFromMaDNS{resolver}))
|
||||
require.NoError(t, err)
|
||||
@@ -100,7 +100,7 @@ func TestDedupAddrsForDial(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{})
|
||||
tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}, nil)
|
||||
require.NoError(t, err)
|
||||
err = s.AddTransport(tpt)
|
||||
require.NoError(t, err)
|
||||
@@ -134,7 +134,7 @@ func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm {
|
||||
})
|
||||
|
||||
// Add a tcp transport so that we know we can dial a tcp multiaddr and we don't filter it out.
|
||||
tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{})
|
||||
tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{}, nil)
|
||||
require.NoError(t, err)
|
||||
err = s.AddTransport(tpt)
|
||||
require.NoError(t, err)
|
||||
@@ -151,7 +151,7 @@ func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm {
|
||||
err = s.AddTransport(wtTpt)
|
||||
require.NoError(t, err)
|
||||
|
||||
wsTpt, err := websocket.New(nil, &network.NullResourceManager{})
|
||||
wsTpt, err := websocket.New(nil, &network.NullResourceManager{}, nil)
|
||||
require.NoError(t, err)
|
||||
err = s.AddTransport(wsTpt)
|
||||
require.NoError(t, err)
|
||||
|
@@ -164,7 +164,7 @@ func GenSwarm(t testing.TB, opts ...Option) *swarm.Swarm {
|
||||
if cfg.disableReuseport {
|
||||
tcpOpts = append(tcpOpts, tcp.DisableReuseport())
|
||||
}
|
||||
tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...)
|
||||
tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, nil, tcpOpts...)
|
||||
require.NoError(t, err)
|
||||
if err := s.AddTransport(tcpTransport); err != nil {
|
||||
t.Fatal(err)
|
||||
|
@@ -84,23 +84,33 @@ func (l *listener) handleIncoming() {
|
||||
}
|
||||
catcher.Reset()
|
||||
|
||||
// gate the connection if applicable
|
||||
if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) {
|
||||
log.Debugf("gater blocked incoming connection on local addr %s from %s",
|
||||
maconn.LocalMultiaddr(), maconn.RemoteMultiaddr())
|
||||
if err := maconn.Close(); err != nil {
|
||||
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
|
||||
}
|
||||
continue
|
||||
// Check if we already have a connection scope. See the comment in tcpreuse/listener.go for an explanation.
|
||||
var connScope network.ConnManagementScope
|
||||
if sc, ok := maconn.(interface {
|
||||
Scope() network.ConnManagementScope
|
||||
}); ok {
|
||||
connScope = sc.Scope()
|
||||
}
|
||||
|
||||
connScope, err := l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr())
|
||||
if err != nil {
|
||||
log.Debugw("resource manager blocked accept of new connection", "error", err)
|
||||
if err := maconn.Close(); err != nil {
|
||||
log.Warnf("failed to incoming connection rejected by resource manager: %s", err)
|
||||
if connScope == nil {
|
||||
// gate the connection if applicable
|
||||
if l.upgrader.connGater != nil && !l.upgrader.connGater.InterceptAccept(maconn) {
|
||||
log.Debugf("gater blocked incoming connection on local addr %s from %s",
|
||||
maconn.LocalMultiaddr(), maconn.RemoteMultiaddr())
|
||||
if err := maconn.Close(); err != nil {
|
||||
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
var err error
|
||||
connScope, err = l.rcmgr.OpenConnection(network.DirInbound, true, maconn.RemoteMultiaddr())
|
||||
if err != nil {
|
||||
log.Debugw("resource manager blocked accept of new connection", "error", err)
|
||||
if err := maconn.Close(); err != nil {
|
||||
log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// The go routine below calls Release when the context is
|
||||
|
@@ -60,7 +60,7 @@ func getNetHosts(t *testing.T, ctx context.Context, n int) (hosts []host.Host, u
|
||||
upgrader := swarmt.GenUpgrader(t, netw, nil)
|
||||
upgraders = append(upgraders, upgrader)
|
||||
|
||||
tpt, err := tcp.NewTCPTransport(upgrader, nil)
|
||||
tpt, err := tcp.NewTCPTransport(upgrader, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@@ -2,6 +2,8 @@ package transport_integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -30,6 +32,23 @@ func stripCertHash(addr ma.Multiaddr) ma.Multiaddr {
|
||||
return addr
|
||||
}
|
||||
|
||||
func addrPort(addr ma.Multiaddr) netip.AddrPort {
|
||||
a := netip.Addr{}
|
||||
p := uint16(0)
|
||||
ma.ForEach(addr, func(c ma.Component) bool {
|
||||
if c.Protocol().Code == ma.P_IP4 || c.Protocol().Code == ma.P_IP6 {
|
||||
a, _ = netip.AddrFromSlice(c.RawValue())
|
||||
return false
|
||||
}
|
||||
if c.Protocol().Code == ma.P_UDP || c.Protocol().Code == ma.P_TCP {
|
||||
p = binary.BigEndian.Uint16(c.RawValue())
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
return netip.AddrPortFrom(a, p)
|
||||
}
|
||||
|
||||
func TestInterceptPeerDial(t *testing.T) {
|
||||
if race.WithRace() {
|
||||
t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.")
|
||||
@@ -173,10 +192,14 @@ func TestInterceptAccept(t *testing.T) {
|
||||
// remove the certhash component from WebTransport addresses
|
||||
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
|
||||
}).AnyTimes()
|
||||
} else if strings.Contains(tc.Name, "WebSocket-Shared") {
|
||||
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
|
||||
require.Equal(t, addrPort(h2.Addrs()[0]), addrPort(addrs.LocalMultiaddr()))
|
||||
})
|
||||
} else {
|
||||
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
|
||||
// remove the certhash component from WebTransport addresses
|
||||
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
|
||||
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr(), "%s\n%s", h2.Addrs()[0], addrs.LocalMultiaddr())
|
||||
})
|
||||
}
|
||||
|
||||
|
@@ -99,6 +99,38 @@ var transportsToTest = []TransportTestCase{
|
||||
return h
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "TCP-Shared / TLS / Yamux",
|
||||
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
|
||||
libp2pOpts := transformOpts(opts)
|
||||
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
|
||||
libp2pOpts = append(libp2pOpts, libp2p.Security(tls.ID, tls.New))
|
||||
libp2pOpts = append(libp2pOpts, libp2p.Muxer(yamux.ID, yamux.DefaultTransport))
|
||||
if opts.NoListen {
|
||||
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
|
||||
} else {
|
||||
libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0"))
|
||||
}
|
||||
h, err := libp2p.New(libp2pOpts...)
|
||||
require.NoError(t, err)
|
||||
return h
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "WebSocket-Shared",
|
||||
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
|
||||
libp2pOpts := transformOpts(opts)
|
||||
libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener())
|
||||
if opts.NoListen {
|
||||
libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs)
|
||||
} else {
|
||||
libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/ws"))
|
||||
}
|
||||
h, err := libp2p.New(libp2pOpts...)
|
||||
require.NoError(t, err)
|
||||
return h
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "WebSocket",
|
||||
HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host {
|
||||
|
54
p2p/transport/tcp/metrics_unix_test.go
Normal file
54
p2p/transport/tcp/metrics_unix_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
// go:build: unix
|
||||
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader"
|
||||
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
|
||||
ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTcpTransportCollectsMetricsWithSharedTcpSocket(t *testing.T) {
|
||||
|
||||
peerA, ia := makeInsecureMuxer(t)
|
||||
_, ib := makeInsecureMuxer(t)
|
||||
|
||||
sharedTCPSocketA := tcpreuse.NewConnMgr(false, nil, nil)
|
||||
sharedTCPSocketB := tcpreuse.NewConnMgr(false, nil, nil)
|
||||
|
||||
ua, err := tptu.New(ia, muxers, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
ta, err := NewTCPTransport(ua, nil, sharedTCPSocketA, WithMetrics())
|
||||
require.NoError(t, err)
|
||||
ub, err := tptu.New(ib, muxers, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
tb, err := NewTCPTransport(ub, nil, sharedTCPSocketB, WithMetrics())
|
||||
require.NoError(t, err)
|
||||
|
||||
zero := "/ip4/127.0.0.1/tcp/0"
|
||||
|
||||
// Not running any test that needs more than 1 conn because the testsuite
|
||||
// opens multiple conns via multiple listeners, which is not expected to work
|
||||
// with the shared TCP socket.
|
||||
subtestsToRun := []ttransport.TransportSubTestFn{
|
||||
ttransport.SubtestProtocols,
|
||||
ttransport.SubtestBasic,
|
||||
ttransport.SubtestCancel,
|
||||
ttransport.SubtestPingPong,
|
||||
|
||||
// Stolen from the stream muxer test suite.
|
||||
ttransport.SubtestStress1Conn1Stream1Msg,
|
||||
ttransport.SubtestStress1Conn1Stream100Msg,
|
||||
ttransport.SubtestStress1Conn100Stream100Msg,
|
||||
ttransport.SubtestStress1Conn1000Stream10Msg,
|
||||
ttransport.SubtestStress1Conn100Stream100Msg10MB,
|
||||
ttransport.SubtestStreamOpenStress,
|
||||
ttransport.SubtestStreamReset,
|
||||
}
|
||||
|
||||
ttransport.SubtestTransportWithFs(t, ta, tb, zero, peerA, subtestsToRun)
|
||||
}
|
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/libp2p/go-libp2p/core/transport"
|
||||
"github.com/libp2p/go-libp2p/p2p/net/reuseport"
|
||||
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
|
||||
|
||||
logging "github.com/ipfs/go-log/v2"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
@@ -33,6 +34,9 @@ type canKeepAlive interface {
|
||||
|
||||
var _ canKeepAlive = &net.TCPConn{}
|
||||
|
||||
// Deprecated: Use tcpreuse.ReuseportIsAvailable
|
||||
var ReuseportIsAvailable = tcpreuse.ReuseportIsAvailable
|
||||
|
||||
func tryKeepAlive(conn net.Conn, keepAlive bool) {
|
||||
keepAliveConn, ok := conn.(canKeepAlive)
|
||||
if !ok {
|
||||
@@ -122,6 +126,9 @@ type TcpTransport struct {
|
||||
disableReuseport bool // Explicitly disable reuseport.
|
||||
enableMetrics bool
|
||||
|
||||
// share and demultiplex TCP listeners across multiple transports
|
||||
sharedTcp *tcpreuse.ConnMgr
|
||||
|
||||
// TCP connect timeout
|
||||
connectTimeout time.Duration
|
||||
|
||||
@@ -134,8 +141,8 @@ var _ transport.Transport = &TcpTransport{}
|
||||
var _ transport.DialUpdater = &TcpTransport{}
|
||||
|
||||
// NewTCPTransport creates a tcp transport object that tracks dialers and listeners
|
||||
// created. It represents an entire TCP stack (though it might not necessarily be).
|
||||
func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*TcpTransport, error) {
|
||||
// created.
|
||||
func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager, sharedTCP *tcpreuse.ConnMgr, opts ...Option) (*TcpTransport, error) {
|
||||
if rcmgr == nil {
|
||||
rcmgr = &network.NullResourceManager{}
|
||||
}
|
||||
@@ -143,6 +150,7 @@ func NewTCPTransport(upgrader transport.Upgrader, rcmgr network.ResourceManager,
|
||||
upgrader: upgrader,
|
||||
connectTimeout: defaultConnectTimeout, // can be set by using the WithConnectionTimeout option
|
||||
rcmgr: rcmgr,
|
||||
sharedTcp: sharedTCP,
|
||||
}
|
||||
for _, o := range opts {
|
||||
if err := o(tr); err != nil {
|
||||
@@ -168,6 +176,10 @@ func (t *TcpTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Co
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
if t.sharedTcp != nil {
|
||||
return t.sharedTcp.DialContext(ctx, raddr)
|
||||
}
|
||||
|
||||
if t.UseReuseport() {
|
||||
return t.reuse.DialContext(ctx, raddr)
|
||||
}
|
||||
@@ -233,10 +245,10 @@ func (t *TcpTransport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p
|
||||
|
||||
// UseReuseport returns true if reuseport is enabled and available.
|
||||
func (t *TcpTransport) UseReuseport() bool {
|
||||
return !t.disableReuseport && ReuseportIsAvailable()
|
||||
return !t.disableReuseport && tcpreuse.ReuseportIsAvailable()
|
||||
}
|
||||
|
||||
func (t *TcpTransport) maListen(laddr ma.Multiaddr) (manet.Listener, error) {
|
||||
func (t *TcpTransport) unsharedMAListen(laddr ma.Multiaddr) (manet.Listener, error) {
|
||||
if t.UseReuseport() {
|
||||
return t.reuse.Listen(laddr)
|
||||
}
|
||||
@@ -245,10 +257,18 @@ func (t *TcpTransport) maListen(laddr ma.Multiaddr) (manet.Listener, error) {
|
||||
|
||||
// Listen listens on the given multiaddr.
|
||||
func (t *TcpTransport) Listen(laddr ma.Multiaddr) (transport.Listener, error) {
|
||||
list, err := t.maListen(laddr)
|
||||
var list manet.Listener
|
||||
var err error
|
||||
|
||||
if t.sharedTcp == nil {
|
||||
list, err = t.unsharedMAListen(laddr)
|
||||
} else {
|
||||
list, err = t.sharedTcp.DemultiplexedListen(laddr, tcpreuse.DemultiplexedConnType_MultistreamSelect)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if t.enableMetrics {
|
||||
list = newTracingListener(&tcpListener{list, 0})
|
||||
}
|
||||
|
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/libp2p/go-libp2p/core/transport"
|
||||
"github.com/libp2p/go-libp2p/p2p/muxer/yamux"
|
||||
tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader"
|
||||
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
|
||||
ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite"
|
||||
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
@@ -31,19 +32,19 @@ func TestTcpTransport(t *testing.T) {
|
||||
|
||||
ua, err := tptu.New(ia, muxers, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
ta, err := NewTCPTransport(ua, nil)
|
||||
ta, err := NewTCPTransport(ua, nil, nil)
|
||||
require.NoError(t, err)
|
||||
ub, err := tptu.New(ib, muxers, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
tb, err := NewTCPTransport(ub, nil)
|
||||
tb, err := NewTCPTransport(ub, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
zero := "/ip4/127.0.0.1/tcp/0"
|
||||
ttransport.SubtestTransport(t, ta, tb, zero, peerA)
|
||||
|
||||
envReuseportVal = false
|
||||
tcpreuse.EnvReuseportVal = false
|
||||
}
|
||||
envReuseportVal = true
|
||||
tcpreuse.EnvReuseportVal = true
|
||||
}
|
||||
|
||||
func TestTcpTransportWithMetrics(t *testing.T) {
|
||||
@@ -52,11 +53,11 @@ func TestTcpTransportWithMetrics(t *testing.T) {
|
||||
|
||||
ua, err := tptu.New(ia, muxers, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
ta, err := NewTCPTransport(ua, nil, WithMetrics())
|
||||
ta, err := NewTCPTransport(ua, nil, nil, WithMetrics())
|
||||
require.NoError(t, err)
|
||||
ub, err := tptu.New(ib, muxers, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
tb, err := NewTCPTransport(ub, nil, WithMetrics())
|
||||
tb, err := NewTCPTransport(ub, nil, nil, WithMetrics())
|
||||
require.NoError(t, err)
|
||||
|
||||
zero := "/ip4/127.0.0.1/tcp/0"
|
||||
@@ -72,7 +73,7 @@ func TestResourceManager(t *testing.T) {
|
||||
|
||||
ua, err := tptu.New(ia, muxers, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
ta, err := NewTCPTransport(ua, nil)
|
||||
ta, err := NewTCPTransport(ua, nil, nil)
|
||||
require.NoError(t, err)
|
||||
ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0"))
|
||||
require.NoError(t, err)
|
||||
@@ -81,7 +82,7 @@ func TestResourceManager(t *testing.T) {
|
||||
ub, err := tptu.New(ib, muxers, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
|
||||
tb, err := NewTCPTransport(ub, rcmgr)
|
||||
tb, err := NewTCPTransport(ub, rcmgr, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
@@ -119,16 +120,16 @@ func TestTcpTransportCantDialDNS(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
var u transport.Upgrader
|
||||
tpt, err := NewTCPTransport(u, nil)
|
||||
tpt, err := NewTCPTransport(u, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tpt.CanDial(dnsa) {
|
||||
t.Fatal("shouldn't be able to dial dns")
|
||||
}
|
||||
|
||||
envReuseportVal = false
|
||||
tcpreuse.EnvReuseportVal = false
|
||||
}
|
||||
envReuseportVal = true
|
||||
tcpreuse.EnvReuseportVal = true
|
||||
}
|
||||
|
||||
func TestTcpTransportCantListenUtp(t *testing.T) {
|
||||
@@ -137,15 +138,15 @@ func TestTcpTransportCantListenUtp(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
var u transport.Upgrader
|
||||
tpt, err := NewTCPTransport(u, nil)
|
||||
tpt, err := NewTCPTransport(u, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tpt.Listen(utpa)
|
||||
require.Error(t, err, "shouldn't be able to listen on utp addr with tcp transport")
|
||||
|
||||
envReuseportVal = false
|
||||
tcpreuse.EnvReuseportVal = false
|
||||
}
|
||||
envReuseportVal = true
|
||||
tcpreuse.EnvReuseportVal = true
|
||||
}
|
||||
|
||||
func TestDialWithUpdates(t *testing.T) {
|
||||
@@ -154,7 +155,7 @@ func TestDialWithUpdates(t *testing.T) {
|
||||
|
||||
ua, err := tptu.New(ia, muxers, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
ta, err := NewTCPTransport(ua, nil)
|
||||
ta, err := NewTCPTransport(ua, nil, nil)
|
||||
require.NoError(t, err)
|
||||
ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0"))
|
||||
require.NoError(t, err)
|
||||
@@ -162,7 +163,7 @@ func TestDialWithUpdates(t *testing.T) {
|
||||
|
||||
ub, err := tptu.New(ib, muxers, nil, nil, nil)
|
||||
require.NoError(t, err)
|
||||
tb, err := NewTCPTransport(ub, nil)
|
||||
tb, err := NewTCPTransport(ub, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
updCh := make(chan transport.DialUpdate, 1)
|
||||
|
26
p2p/transport/tcpreuse/connwithscope.go
Normal file
26
p2p/transport/tcpreuse/connwithscope.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package tcpreuse
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse/internal/sampledconn"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
)
|
||||
|
||||
type connWithScope struct {
|
||||
sampledconn.ManetTCPConnInterface
|
||||
scope network.ConnManagementScope
|
||||
}
|
||||
|
||||
func (c connWithScope) Scope() network.ConnManagementScope {
|
||||
return c.scope
|
||||
}
|
||||
|
||||
func manetConnWithScope(c manet.Conn, scope network.ConnManagementScope) (manet.Conn, error) {
|
||||
if tcpconn, ok := c.(sampledconn.ManetTCPConnInterface); ok {
|
||||
return &connWithScope{tcpconn, scope}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("manet.Conn is not a TCP Conn")
|
||||
}
|
97
p2p/transport/tcpreuse/demultiplex.go
Normal file
97
p2p/transport/tcpreuse/demultiplex.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package tcpreuse
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse/internal/sampledconn"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
)
|
||||
|
||||
// This is readiung the first 3 bytes of the packet. It should be instant.
|
||||
const identifyConnTimeout = 1 * time.Second
|
||||
|
||||
type DemultiplexedConnType int
|
||||
|
||||
const (
|
||||
DemultiplexedConnType_Unknown DemultiplexedConnType = iota
|
||||
DemultiplexedConnType_MultistreamSelect
|
||||
DemultiplexedConnType_HTTP
|
||||
DemultiplexedConnType_TLS
|
||||
)
|
||||
|
||||
func (t DemultiplexedConnType) String() string {
|
||||
switch t {
|
||||
case DemultiplexedConnType_MultistreamSelect:
|
||||
return "MultistreamSelect"
|
||||
case DemultiplexedConnType_HTTP:
|
||||
return "HTTP"
|
||||
case DemultiplexedConnType_TLS:
|
||||
return "TLS"
|
||||
default:
|
||||
return fmt.Sprintf("Unknown(%d)", int(t))
|
||||
}
|
||||
}
|
||||
|
||||
func (t DemultiplexedConnType) IsKnown() bool {
|
||||
return t >= 1 || t <= 3
|
||||
}
|
||||
|
||||
// identifyConnType attempts to identify the connection type by peeking at the
|
||||
// first few bytes.
|
||||
// It Callers must not use the passed in Conn after this
|
||||
// function returns. if an error is returned, the connection will be closed.
|
||||
func identifyConnType(c manet.Conn) (DemultiplexedConnType, manet.Conn, error) {
|
||||
if err := c.SetReadDeadline(time.Now().Add(identifyConnTimeout)); err != nil {
|
||||
closeErr := c.Close()
|
||||
return 0, nil, errors.Join(err, closeErr)
|
||||
}
|
||||
|
||||
s, c, err := sampledconn.PeekBytes(c)
|
||||
if err != nil {
|
||||
closeErr := c.Close()
|
||||
return 0, nil, errors.Join(err, closeErr)
|
||||
}
|
||||
|
||||
if err := c.SetReadDeadline(time.Time{}); err != nil {
|
||||
closeErr := c.Close()
|
||||
return 0, nil, errors.Join(err, closeErr)
|
||||
}
|
||||
|
||||
if IsMultistreamSelect(s) {
|
||||
return DemultiplexedConnType_MultistreamSelect, c, nil
|
||||
}
|
||||
if IsTLS(s) {
|
||||
return DemultiplexedConnType_TLS, c, nil
|
||||
}
|
||||
if IsHTTP(s) {
|
||||
return DemultiplexedConnType_HTTP, c, nil
|
||||
}
|
||||
return DemultiplexedConnType_Unknown, c, nil
|
||||
}
|
||||
|
||||
// Matchers are implemented here instead of in the transports so we can easily fuzz them together.
|
||||
type Prefix = [3]byte
|
||||
|
||||
func IsMultistreamSelect(s Prefix) bool {
|
||||
return string(s[:]) == "\x13/m"
|
||||
}
|
||||
|
||||
func IsHTTP(s Prefix) bool {
|
||||
switch string(s[:]) {
|
||||
case "GET", "HEA", "POS", "PUT", "DEL", "CON", "OPT", "TRA", "PAT":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func IsTLS(s Prefix) bool {
|
||||
switch string(s[:]) {
|
||||
case "\x16\x03\x01", "\x16\x03\x02", "\x16\x03\x03":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
50
p2p/transport/tcpreuse/demultiplex_test.go
Normal file
50
p2p/transport/tcpreuse/demultiplex_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package tcpreuse
|
||||
|
||||
import "testing"
|
||||
|
||||
func FuzzClash(f *testing.F) {
|
||||
// make untyped literals type correctly
|
||||
add := func(a, b, c byte) { f.Add(a, b, c) }
|
||||
|
||||
// multistream-select
|
||||
add('\x13', '/', 'm')
|
||||
// http
|
||||
add('G', 'E', 'T')
|
||||
add('H', 'E', 'A')
|
||||
add('P', 'O', 'S')
|
||||
add('P', 'U', 'T')
|
||||
add('D', 'E', 'L')
|
||||
add('C', 'O', 'N')
|
||||
add('O', 'P', 'T')
|
||||
add('T', 'R', 'A')
|
||||
add('P', 'A', 'T')
|
||||
// tls
|
||||
add('\x16', '\x03', '\x01')
|
||||
add('\x16', '\x03', '\x02')
|
||||
add('\x16', '\x03', '\x03')
|
||||
add('\x16', '\x03', '\x04')
|
||||
|
||||
f.Fuzz(func(t *testing.T, a, b, c byte) {
|
||||
s := Prefix{a, b, c}
|
||||
var total uint
|
||||
|
||||
ms := IsMultistreamSelect(s)
|
||||
if ms {
|
||||
total++
|
||||
}
|
||||
|
||||
http := IsHTTP(s)
|
||||
if http {
|
||||
total++
|
||||
}
|
||||
|
||||
tls := IsTLS(s)
|
||||
if tls {
|
||||
total++
|
||||
}
|
||||
|
||||
if total > 1 {
|
||||
t.Errorf("clash on: %q; ms: %v; http: %v; tls: %v", s, ms, http, tls)
|
||||
}
|
||||
})
|
||||
}
|
16
p2p/transport/tcpreuse/dialer.go
Normal file
16
p2p/transport/tcpreuse/dialer.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package tcpreuse
|
||||
|
||||
import (
|
||||
"context"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
)
|
||||
|
||||
// DialContext is like Dial but takes a context.
|
||||
func (t *ConnMgr) DialContext(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) {
|
||||
if t.useReuseport() {
|
||||
return t.reuse.DialContext(ctx, raddr)
|
||||
}
|
||||
var d manet.Dialer
|
||||
return d.DialContext(ctx, raddr)
|
||||
}
|
@@ -0,0 +1,89 @@
|
||||
package sampledconn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
)
|
||||
|
||||
const peekSize = 3
|
||||
|
||||
type PeekedBytes = [peekSize]byte
|
||||
|
||||
var errNotSupported = errors.New("not supported on this platform")
|
||||
|
||||
var ErrNotTCPConn = errors.New("passed conn is not a TCPConn")
|
||||
|
||||
func PeekBytes(conn manet.Conn) (PeekedBytes, manet.Conn, error) {
|
||||
if c, ok := conn.(syscall.Conn); ok {
|
||||
b, err := OSPeekConn(c)
|
||||
if err == nil {
|
||||
return b, conn, nil
|
||||
}
|
||||
if err != errNotSupported {
|
||||
return PeekedBytes{}, nil, err
|
||||
}
|
||||
// Fallback to wrapping the coonn
|
||||
}
|
||||
|
||||
if c, ok := conn.(ManetTCPConnInterface); ok {
|
||||
return newFallbackSampledConn(c)
|
||||
}
|
||||
|
||||
return PeekedBytes{}, nil, ErrNotTCPConn
|
||||
}
|
||||
|
||||
type fallbackPeekingConn struct {
|
||||
ManetTCPConnInterface
|
||||
peekedBytes PeekedBytes
|
||||
bytesPeeked uint8
|
||||
}
|
||||
|
||||
// tcpConnInterface is the interface for TCPConn's functions
|
||||
// NOTE: `SyscallConn() (syscall.RawConn, error)` is here to make using this as
|
||||
// a TCP Conn easier, but it's a potential footgun as you could skipped the
|
||||
// peeked bytes if using the fallback
|
||||
type tcpConnInterface interface {
|
||||
net.Conn
|
||||
syscall.Conn
|
||||
|
||||
CloseRead() error
|
||||
CloseWrite() error
|
||||
|
||||
SetLinger(sec int) error
|
||||
SetKeepAlive(keepalive bool) error
|
||||
SetKeepAlivePeriod(d time.Duration) error
|
||||
SetNoDelay(noDelay bool) error
|
||||
MultipathTCP() (bool, error)
|
||||
|
||||
io.ReaderFrom
|
||||
io.WriterTo
|
||||
}
|
||||
|
||||
type ManetTCPConnInterface interface {
|
||||
manet.Conn
|
||||
tcpConnInterface
|
||||
}
|
||||
|
||||
func newFallbackSampledConn(conn ManetTCPConnInterface) (PeekedBytes, *fallbackPeekingConn, error) {
|
||||
s := &fallbackPeekingConn{ManetTCPConnInterface: conn}
|
||||
_, err := io.ReadFull(conn, s.peekedBytes[:])
|
||||
if err != nil {
|
||||
return s.peekedBytes, nil, err
|
||||
}
|
||||
return s.peekedBytes, s, nil
|
||||
}
|
||||
|
||||
func (sc *fallbackPeekingConn) Read(b []byte) (int, error) {
|
||||
if int(sc.bytesPeeked) != len(sc.peekedBytes) {
|
||||
red := copy(b, sc.peekedBytes[sc.bytesPeeked:])
|
||||
sc.bytesPeeked += uint8(red)
|
||||
return red, nil
|
||||
}
|
||||
|
||||
return sc.ManetTCPConnInterface.Read(b)
|
||||
}
|
@@ -0,0 +1,11 @@
|
||||
//go:build !unix && !windows
|
||||
|
||||
package sampledconn
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) {
|
||||
return PeekedBytes{}, errNotSupported
|
||||
}
|
@@ -0,0 +1,78 @@
|
||||
package sampledconn
|
||||
|
||||
import (
|
||||
"io"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSampledConn(t *testing.T) {
|
||||
testCases := []string{
|
||||
"platform",
|
||||
"fallback",
|
||||
}
|
||||
|
||||
// Start a TCP server
|
||||
listener, err := manet.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0"))
|
||||
assert.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
serverAddr := listener.Multiaddr()
|
||||
|
||||
// Server goroutine
|
||||
go func() {
|
||||
for i := 0; i < len(testCases); i++ {
|
||||
conn, err := listener.Accept()
|
||||
assert.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Write some data to the connection
|
||||
_, err = conn.Write([]byte("hello"))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Give the server a moment to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc, func(t *testing.T) {
|
||||
// Create a TCP client
|
||||
clientConn, err := manet.Dial(serverAddr)
|
||||
assert.NoError(t, err)
|
||||
defer clientConn.Close()
|
||||
|
||||
if tc == "platform" {
|
||||
// Wrap the client connection in SampledConn
|
||||
peeked, clientConn, err := PeekBytes(clientConn.(interface {
|
||||
manet.Conn
|
||||
syscall.Conn
|
||||
}))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "hel", string(peeked[:]))
|
||||
|
||||
buf := make([]byte, 5)
|
||||
_, err = clientConn.Read(buf)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "hello", string(buf))
|
||||
} else {
|
||||
// Wrap the client connection in SampledConn
|
||||
sample, sampledConn, err := newFallbackSampledConn(clientConn.(ManetTCPConnInterface))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "hel", string(sample[:]))
|
||||
|
||||
buf := make([]byte, 5)
|
||||
_, err = io.ReadFull(sampledConn, buf)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "hello", string(buf))
|
||||
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -0,0 +1,42 @@
|
||||
//go:build unix
|
||||
|
||||
package sampledconn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) {
|
||||
s := PeekedBytes{}
|
||||
|
||||
rawConn, err := conn.SyscallConn()
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
readBytes := 0
|
||||
var readErr error
|
||||
err = rawConn.Read(func(fd uintptr) bool {
|
||||
for readBytes < peekSize {
|
||||
var n int
|
||||
n, _, readErr = syscall.Recvfrom(int(fd), s[readBytes:], syscall.MSG_PEEK)
|
||||
if errors.Is(readErr, syscall.EAGAIN) {
|
||||
return false
|
||||
}
|
||||
if readErr != nil {
|
||||
return true
|
||||
}
|
||||
readBytes += n
|
||||
}
|
||||
return true
|
||||
})
|
||||
if readErr != nil {
|
||||
return s, readErr
|
||||
}
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
@@ -0,0 +1,49 @@
|
||||
//go:build windows
|
||||
|
||||
package sampledconn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"golang.org/x/sys/windows"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) {
|
||||
s := PeekedBytes{}
|
||||
|
||||
rawConn, err := conn.SyscallConn()
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
readBytes := 0
|
||||
var readErr error
|
||||
err = rawConn.Read(func(fd uintptr) bool {
|
||||
for readBytes < peekSize {
|
||||
var n uint32
|
||||
flags := uint32(windows.MSG_PEEK)
|
||||
wsabuf := windows.WSABuf{
|
||||
Len: uint32(len(s) - readBytes),
|
||||
Buf: &s[readBytes],
|
||||
}
|
||||
|
||||
readErr = windows.WSARecv(windows.Handle(fd), &wsabuf, 1, &n, &flags, nil, nil)
|
||||
if errors.Is(readErr, windows.WSAEWOULDBLOCK) {
|
||||
return false
|
||||
}
|
||||
if readErr != nil {
|
||||
return true
|
||||
}
|
||||
readBytes += int(n)
|
||||
}
|
||||
return true
|
||||
})
|
||||
if readErr != nil {
|
||||
return s, readErr
|
||||
}
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
329
p2p/transport/tcpreuse/listener.go
Normal file
329
p2p/transport/tcpreuse/listener.go
Normal file
@@ -0,0 +1,329 @@
|
||||
package tcpreuse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
logging "github.com/ipfs/go-log/v2"
|
||||
"github.com/libp2p/go-libp2p/core/connmgr"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/transport"
|
||||
"github.com/libp2p/go-libp2p/p2p/net/reuseport"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
)
|
||||
|
||||
const acceptQueueSize = 64 // It is fine to read 3 bytes from 64 connections in parallel.
|
||||
|
||||
// How long we wait for a connection to be accepted before dropping it.
|
||||
const acceptTimeout = 30 * time.Second
|
||||
|
||||
var log = logging.Logger("tcp-demultiplex")
|
||||
|
||||
// ConnMgr enables you to share the same listen address between TCP and WebSocket transports.
|
||||
type ConnMgr struct {
|
||||
enableReuseport bool
|
||||
reuse reuseport.Transport
|
||||
connGater connmgr.ConnectionGater
|
||||
rcmgr network.ResourceManager
|
||||
|
||||
mx sync.Mutex
|
||||
listeners map[string]*multiplexedListener
|
||||
}
|
||||
|
||||
func NewConnMgr(enableReuseport bool, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) *ConnMgr {
|
||||
if rcmgr == nil {
|
||||
rcmgr = &network.NullResourceManager{}
|
||||
}
|
||||
return &ConnMgr{
|
||||
enableReuseport: enableReuseport,
|
||||
reuse: reuseport.Transport{},
|
||||
connGater: gater,
|
||||
rcmgr: rcmgr,
|
||||
listeners: make(map[string]*multiplexedListener),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ConnMgr) maListen(listenAddr ma.Multiaddr) (manet.Listener, error) {
|
||||
if t.useReuseport() {
|
||||
return t.reuse.Listen(listenAddr)
|
||||
} else {
|
||||
return manet.Listen(listenAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ConnMgr) useReuseport() bool {
|
||||
return t.enableReuseport && ReuseportIsAvailable()
|
||||
}
|
||||
|
||||
func getTCPAddr(listenAddr ma.Multiaddr) (ma.Multiaddr, error) {
|
||||
haveTCP := false
|
||||
addr, _ := ma.SplitFunc(listenAddr, func(c ma.Component) bool {
|
||||
if haveTCP {
|
||||
return true
|
||||
}
|
||||
if c.Protocol().Code == ma.P_TCP {
|
||||
haveTCP = true
|
||||
}
|
||||
return false
|
||||
})
|
||||
if !haveTCP {
|
||||
return nil, fmt.Errorf("invalid listen addr %s, need tcp address", listenAddr)
|
||||
}
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
// DemultiplexedListen returns a listener for laddr listening for `connType` connections. The connections
|
||||
// accepted from returned listeners need to be upgraded with a `transport.Upgrader`.
|
||||
// NOTE: All listeners for port 0 share the same underlying socket, so they have the same specific port.
|
||||
func (t *ConnMgr) DemultiplexedListen(laddr ma.Multiaddr, connType DemultiplexedConnType) (manet.Listener, error) {
|
||||
if !connType.IsKnown() {
|
||||
return nil, fmt.Errorf("unknown connection type: %s", connType)
|
||||
}
|
||||
laddr, err := getTCPAddr(laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.mx.Lock()
|
||||
defer t.mx.Unlock()
|
||||
ml, ok := t.listeners[laddr.String()]
|
||||
if ok {
|
||||
dl, err := ml.DemultiplexedListen(connType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dl, nil
|
||||
}
|
||||
|
||||
l, err := t.maListen(laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancelFunc := func() error {
|
||||
cancel()
|
||||
t.mx.Lock()
|
||||
defer t.mx.Unlock()
|
||||
delete(t.listeners, laddr.String())
|
||||
delete(t.listeners, l.Multiaddr().String())
|
||||
return l.Close()
|
||||
}
|
||||
ml = &multiplexedListener{
|
||||
Listener: l,
|
||||
listeners: make(map[DemultiplexedConnType]*demultiplexedListener),
|
||||
ctx: ctx,
|
||||
closeFn: cancelFunc,
|
||||
connGater: t.connGater,
|
||||
rcmgr: t.rcmgr,
|
||||
}
|
||||
t.listeners[laddr.String()] = ml
|
||||
t.listeners[l.Multiaddr().String()] = ml
|
||||
|
||||
dl, err := ml.DemultiplexedListen(connType)
|
||||
if err != nil {
|
||||
cerr := ml.Close()
|
||||
return nil, errors.Join(err, cerr)
|
||||
}
|
||||
|
||||
ml.wg.Add(1)
|
||||
go ml.run()
|
||||
|
||||
return dl, nil
|
||||
}
|
||||
|
||||
var _ manet.Listener = &demultiplexedListener{}
|
||||
|
||||
type multiplexedListener struct {
|
||||
manet.Listener
|
||||
listeners map[DemultiplexedConnType]*demultiplexedListener
|
||||
mx sync.RWMutex
|
||||
|
||||
connGater connmgr.ConnectionGater
|
||||
rcmgr network.ResourceManager
|
||||
ctx context.Context
|
||||
closeFn func() error
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
var ErrListenerExists = errors.New("listener already exists for this conn type on this address")
|
||||
|
||||
func (m *multiplexedListener) DemultiplexedListen(connType DemultiplexedConnType) (manet.Listener, error) {
|
||||
if !connType.IsKnown() {
|
||||
return nil, fmt.Errorf("unknown connection type: %s", connType)
|
||||
}
|
||||
|
||||
m.mx.Lock()
|
||||
defer m.mx.Unlock()
|
||||
if _, ok := m.listeners[connType]; ok {
|
||||
return nil, ErrListenerExists
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(m.ctx)
|
||||
l := &demultiplexedListener{
|
||||
buffer: make(chan manet.Conn),
|
||||
inner: m.Listener,
|
||||
ctx: ctx,
|
||||
cancelFunc: cancel,
|
||||
closeFn: func() error { m.removeDemultiplexedListener(connType); return nil },
|
||||
}
|
||||
|
||||
m.listeners[connType] = l
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func (m *multiplexedListener) run() error {
|
||||
defer m.Close()
|
||||
defer m.wg.Done()
|
||||
acceptQueue := make(chan struct{}, acceptQueueSize)
|
||||
for {
|
||||
c, err := m.Listener.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Gate and resource limit the connection here.
|
||||
// If done after sampling the connection, we'll be vulnerable to DOS attacks by a single peer
|
||||
// which clogs up our entire connection queue.
|
||||
// This duplicates the responsibility of gating and resource limiting between here and the upgrader. The
|
||||
// alternative without duplication requires moving the process of upgrading the connection here, which forces
|
||||
// us to establish the websocket connection here. That is more duplication, or a significant breaking change.
|
||||
//
|
||||
// Bugs around multiple calls to OpenConnection or InterceptAccept are prevented by the transport
|
||||
// integration tests.
|
||||
if m.connGater != nil && !m.connGater.InterceptAccept(c) {
|
||||
log.Debugf("gater blocked incoming connection on local addr %s from %s",
|
||||
c.LocalMultiaddr(), c.RemoteMultiaddr())
|
||||
if err := c.Close(); err != nil {
|
||||
log.Warnf("failed to close incoming connection rejected by gater: %s", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
connScope, err := m.rcmgr.OpenConnection(network.DirInbound, true, c.RemoteMultiaddr())
|
||||
if err != nil {
|
||||
log.Debugw("resource manager blocked accept of new connection", "error", err)
|
||||
if err := c.Close(); err != nil {
|
||||
log.Warnf("failed to open incoming connection. Rejected by resource manager: %s", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case acceptQueue <- struct{}{}:
|
||||
// NOTE: We can drop the connection, but this is similar to the behaviour in the upgrader.
|
||||
case <-m.ctx.Done():
|
||||
c.Close()
|
||||
log.Debugf("accept queue full, dropping connection: %s", c.RemoteMultiaddr())
|
||||
}
|
||||
|
||||
m.wg.Add(1)
|
||||
go func() {
|
||||
defer func() { <-acceptQueue }()
|
||||
defer m.wg.Done()
|
||||
ctx, cancelCtx := context.WithTimeout(m.ctx, acceptTimeout)
|
||||
defer cancelCtx()
|
||||
t, c, err := identifyConnType(c)
|
||||
if err != nil {
|
||||
connScope.Done()
|
||||
closeErr := c.Close()
|
||||
err = errors.Join(err, closeErr)
|
||||
log.Debugf("error demultiplexing connection: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
connWithScope, err := manetConnWithScope(c, connScope)
|
||||
if err != nil {
|
||||
connScope.Done()
|
||||
closeErr := c.Close()
|
||||
err = errors.Join(err, closeErr)
|
||||
log.Debugf("error wrapping connection with scope: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
m.mx.RLock()
|
||||
demux, ok := m.listeners[t]
|
||||
m.mx.RUnlock()
|
||||
if !ok {
|
||||
closeErr := connWithScope.Close()
|
||||
if closeErr != nil {
|
||||
log.Debugf("no registered listener for demultiplex connection %s. Error closing the connection %s", t, closeErr.Error())
|
||||
} else {
|
||||
log.Debugf("no registered listener for demultiplex connection %s", t)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case demux.buffer <- connWithScope:
|
||||
case <-ctx.Done():
|
||||
connWithScope.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *multiplexedListener) Close() error {
|
||||
m.mx.Lock()
|
||||
for _, l := range m.listeners {
|
||||
l.cancelFunc()
|
||||
}
|
||||
err := m.closeListener()
|
||||
m.mx.Unlock()
|
||||
m.wg.Wait()
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *multiplexedListener) closeListener() error {
|
||||
lerr := m.Listener.Close()
|
||||
cerr := m.closeFn()
|
||||
return errors.Join(lerr, cerr)
|
||||
}
|
||||
|
||||
func (m *multiplexedListener) removeDemultiplexedListener(c DemultiplexedConnType) {
|
||||
m.mx.Lock()
|
||||
defer m.mx.Unlock()
|
||||
|
||||
delete(m.listeners, c)
|
||||
if len(m.listeners) == 0 {
|
||||
m.closeListener()
|
||||
m.mx.Unlock()
|
||||
m.wg.Wait()
|
||||
m.mx.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
type demultiplexedListener struct {
|
||||
buffer chan manet.Conn
|
||||
inner manet.Listener
|
||||
ctx context.Context
|
||||
cancelFunc context.CancelFunc
|
||||
closeFn func() error
|
||||
}
|
||||
|
||||
func (m *demultiplexedListener) Accept() (manet.Conn, error) {
|
||||
select {
|
||||
case c := <-m.buffer:
|
||||
return c, nil
|
||||
case <-m.ctx.Done():
|
||||
return nil, transport.ErrListenerClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (m *demultiplexedListener) Close() error {
|
||||
m.cancelFunc()
|
||||
return m.closeFn()
|
||||
}
|
||||
|
||||
func (m *demultiplexedListener) Multiaddr() ma.Multiaddr {
|
||||
return m.inner.Multiaddr()
|
||||
}
|
||||
|
||||
func (m *demultiplexedListener) Addr() net.Addr {
|
||||
return m.inner.Addr()
|
||||
}
|
449
p2p/transport/tcpreuse/listener_test.go
Normal file
449
p2p/transport/tcpreuse/listener_test.go
Normal file
@@ -0,0 +1,449 @@
|
||||
package tcpreuse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
"github.com/multiformats/go-multistream"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func selfSignedTLSConfig(t *testing.T) *tls.Config {
|
||||
t.Helper()
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
certTemplate := x509.Certificate{
|
||||
SerialNumber: &big.Int{},
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test"},
|
||||
},
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &priv.PublicKey, priv)
|
||||
require.NoError(t, err)
|
||||
|
||||
cert := tls.Certificate{
|
||||
Certificate: [][]byte{derBytes},
|
||||
PrivateKey: priv,
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
return tlsConfig
|
||||
}
|
||||
|
||||
type wsHandler struct{ conns chan *websocket.Conn }
|
||||
|
||||
func (wh wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
u := websocket.Upgrader{}
|
||||
c, _ := u.Upgrade(w, r, http.Header{})
|
||||
wh.conns <- c
|
||||
}
|
||||
|
||||
func TestListenerSingle(t *testing.T) {
|
||||
listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0")
|
||||
const N = 64
|
||||
for _, enableReuseport := range []bool{true, false} {
|
||||
t.Run(fmt.Sprintf("multistream-reuseport:%v", enableReuseport), func(t *testing.T) {
|
||||
cm := NewConnMgr(enableReuseport, nil, nil)
|
||||
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
|
||||
require.NoError(t, err)
|
||||
go func() {
|
||||
d := net.Dialer{}
|
||||
for i := 0; i < N; i++ {
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
conn, err := d.DialContext(ctx, l.Addr().Network(), l.Addr().String())
|
||||
if err != nil {
|
||||
t.Error("failed to dial", err, i)
|
||||
return
|
||||
}
|
||||
lconn := multistream.NewMSSelect(conn, "a")
|
||||
buf := make([]byte, 10)
|
||||
_, err = lconn.Write([]byte("hello-multistream"))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
_, err = lconn.Read(buf)
|
||||
if err == nil {
|
||||
t.Error("expected EOF got nil")
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < N; i++ {
|
||||
c, err := l.Accept()
|
||||
require.NoError(t, err)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
cc := multistream.NewMSSelect(c, "a")
|
||||
defer cc.Close()
|
||||
buf := make([]byte, 30)
|
||||
n, err := cc.Read(buf)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
if !assert.Equal(t, "hello-multistream", string(buf[:n])) {
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
t.Run(fmt.Sprintf("WebSocket-reuseport:%v", enableReuseport), func(t *testing.T) {
|
||||
cm := NewConnMgr(enableReuseport, nil, nil)
|
||||
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
|
||||
require.NoError(t, err)
|
||||
wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
|
||||
go func() {
|
||||
http.Serve(manet.NetListener(l), wh)
|
||||
}()
|
||||
go func() {
|
||||
d := websocket.Dialer{}
|
||||
for i := 0; i < N; i++ {
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
conn, _, err := d.DialContext(ctx, fmt.Sprintf("ws://%s", l.Addr().String()), http.Header{})
|
||||
if err != nil {
|
||||
t.Error("failed to dial", err, i)
|
||||
return
|
||||
}
|
||||
err = conn.WriteMessage(websocket.TextMessage, []byte("hello"))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
_, _, err = conn.ReadMessage()
|
||||
if err == nil {
|
||||
t.Error("expected EOF got nil")
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < N; i++ {
|
||||
c := <-wh.conns
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer c.Close()
|
||||
msgType, buf, err := c.ReadMessage()
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
if !assert.Equal(t, msgType, websocket.TextMessage) {
|
||||
return
|
||||
}
|
||||
if !assert.Equal(t, "hello", string(buf)) {
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
t.Run(fmt.Sprintf("WebSocketTLS-reuseport:%v", enableReuseport), func(t *testing.T) {
|
||||
cm := NewConnMgr(enableReuseport, nil, nil)
|
||||
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
|
||||
go func() {
|
||||
s := http.Server{Handler: wh, TLSConfig: selfSignedTLSConfig(t)}
|
||||
s.ServeTLS(manet.NetListener(l), "", "")
|
||||
}()
|
||||
go func() {
|
||||
d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
for i := 0; i < N; i++ {
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
conn, _, err := d.DialContext(ctx, fmt.Sprintf("wss://%s", l.Addr().String()), http.Header{})
|
||||
if err != nil {
|
||||
t.Error("failed to dial", err, i)
|
||||
return
|
||||
}
|
||||
err = conn.WriteMessage(websocket.TextMessage, []byte("hello"))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
_, _, err = conn.ReadMessage()
|
||||
if err == nil {
|
||||
t.Error("expected EOF got nil")
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < N; i++ {
|
||||
c := <-wh.conns
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer c.Close()
|
||||
msgType, buf, err := c.ReadMessage()
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
if !assert.Equal(t, msgType, websocket.TextMessage) {
|
||||
return
|
||||
}
|
||||
if !assert.Equal(t, "hello", string(buf)) {
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenerMultiplexed(t *testing.T) {
|
||||
listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0")
|
||||
const N = 20
|
||||
for _, enableReuseport := range []bool{true, false} {
|
||||
cm := NewConnMgr(enableReuseport, nil, nil)
|
||||
msl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
|
||||
require.NoError(t, err)
|
||||
defer msl.Close()
|
||||
|
||||
wsl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
|
||||
require.NoError(t, err)
|
||||
defer wsl.Close()
|
||||
require.Equal(t, wsl.Multiaddr(), msl.Multiaddr())
|
||||
wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
|
||||
go func() {
|
||||
http.Serve(manet.NetListener(wsl), wh)
|
||||
}()
|
||||
|
||||
wssl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS)
|
||||
require.NoError(t, err)
|
||||
defer wssl.Close()
|
||||
require.Equal(t, wssl.Multiaddr(), wsl.Multiaddr())
|
||||
whs := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
|
||||
go func() {
|
||||
s := http.Server{Handler: whs, TLSConfig: selfSignedTLSConfig(t)}
|
||||
s.ServeTLS(manet.NetListener(wssl), "", "")
|
||||
}()
|
||||
|
||||
// multistream connections
|
||||
go func() {
|
||||
d := net.Dialer{}
|
||||
for i := 0; i < N; i++ {
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
conn, err := d.DialContext(ctx, msl.Addr().Network(), msl.Addr().String())
|
||||
if err != nil {
|
||||
t.Error("failed to dial", err, i)
|
||||
return
|
||||
}
|
||||
lconn := multistream.NewMSSelect(conn, "a")
|
||||
buf := make([]byte, 10)
|
||||
_, err = lconn.Write([]byte("multistream"))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
_, err = lconn.Read(buf)
|
||||
if err == nil {
|
||||
t.Error("expected EOF got nil")
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
// ws connections
|
||||
go func() {
|
||||
d := websocket.Dialer{}
|
||||
for i := 0; i < N; i++ {
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
conn, _, err := d.DialContext(ctx, fmt.Sprintf("ws://%s", msl.Addr().String()), http.Header{})
|
||||
if err != nil {
|
||||
t.Error("failed to dial", err, i)
|
||||
return
|
||||
}
|
||||
err = conn.WriteMessage(websocket.TextMessage, []byte("websocket"))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
_, _, err = conn.ReadMessage()
|
||||
if err == nil {
|
||||
t.Error("expected EOF got nil")
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
// wss connections
|
||||
go func() {
|
||||
d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
for i := 0; i < N; i++ {
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
conn, _, err := d.DialContext(ctx, fmt.Sprintf("wss://%s", msl.Addr().String()), http.Header{})
|
||||
if err != nil {
|
||||
t.Error("failed to dial", err, i)
|
||||
return
|
||||
}
|
||||
err = conn.WriteMessage(websocket.TextMessage, []byte("websocket-tls"))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
_, _, err = conn.ReadMessage()
|
||||
if err == nil {
|
||||
t.Error("expected EOF got nil")
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < N; i++ {
|
||||
c, err := msl.Accept()
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
cc := multistream.NewMSSelect(c, "a")
|
||||
defer cc.Close()
|
||||
buf := make([]byte, 20)
|
||||
n, err := cc.Read(buf)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
if !assert.Equal(t, "multistream", string(buf[:n])) {
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < N; i++ {
|
||||
c := <-wh.conns
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer c.Close()
|
||||
msgType, buf, err := c.ReadMessage()
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
if !assert.Equal(t, msgType, websocket.TextMessage) {
|
||||
return
|
||||
}
|
||||
if !assert.Equal(t, "websocket", string(buf)) {
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < N; i++ {
|
||||
c := <-whs.conns
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer c.Close()
|
||||
msgType, buf, err := c.ReadMessage()
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
if !assert.Equal(t, msgType, websocket.TextMessage) {
|
||||
return
|
||||
}
|
||||
if !assert.Equal(t, "websocket-tls", string(buf)) {
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenerClose(t *testing.T) {
|
||||
testClose := func(listenAddr ma.Multiaddr) {
|
||||
// listen on port 0
|
||||
cm := NewConnMgr(false, nil, nil)
|
||||
ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
|
||||
require.NoError(t, err)
|
||||
wl, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wl.Multiaddr(), ml.Multiaddr())
|
||||
wl.Close()
|
||||
|
||||
wl, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wl.Multiaddr(), ml.Multiaddr())
|
||||
|
||||
ml.Close()
|
||||
|
||||
mll, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wl.Multiaddr(), ml.Multiaddr())
|
||||
|
||||
mll.Close()
|
||||
wl.Close()
|
||||
|
||||
ml, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now listen on the specific port previously used
|
||||
listenAddr = ml.Multiaddr()
|
||||
wl, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wl.Multiaddr(), ml.Multiaddr())
|
||||
wl.Close()
|
||||
|
||||
wl, err = cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wl.Multiaddr(), ml.Multiaddr())
|
||||
|
||||
ml.Close()
|
||||
wl.Close()
|
||||
}
|
||||
listenAddrs := []ma.Multiaddr{ma.StringCast("/ip4/0.0.0.0/tcp/0"), ma.StringCast("/ip6/::/tcp/0")}
|
||||
for _, listenAddr := range listenAddrs {
|
||||
testClose(listenAddr)
|
||||
}
|
||||
}
|
@@ -1,4 +1,4 @@
|
||||
package tcp
|
||||
package tcpreuse
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -11,13 +11,13 @@ import (
|
||||
// It default to true.
|
||||
const envReuseport = "LIBP2P_TCP_REUSEPORT"
|
||||
|
||||
// envReuseportVal stores the value of envReuseport. defaults to true.
|
||||
var envReuseportVal = true
|
||||
// EnvReuseportVal stores the value of envReuseport. defaults to true.
|
||||
var EnvReuseportVal = true
|
||||
|
||||
func init() {
|
||||
v := strings.ToLower(os.Getenv(envReuseport))
|
||||
if v == "false" || v == "f" || v == "0" {
|
||||
envReuseportVal = false
|
||||
EnvReuseportVal = false
|
||||
log.Infof("REUSEPORT disabled (LIBP2P_TCP_REUSEPORT=%s)", v)
|
||||
}
|
||||
}
|
||||
@@ -31,5 +31,5 @@ func init() {
|
||||
// If this becomes a sought after feature, we could add this to the config.
|
||||
// In the end, reuseport is a stop-gap.
|
||||
func ReuseportIsAvailable() bool {
|
||||
return envReuseportVal && reuseport.Available()
|
||||
return EnvReuseportVal && reuseport.Available()
|
||||
}
|
@@ -11,7 +11,9 @@ import (
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
)
|
||||
|
||||
var Subtests = []func(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID){
|
||||
type TransportSubTestFn func(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, peerA peer.ID)
|
||||
|
||||
var Subtests = []TransportSubTestFn{
|
||||
SubtestProtocols,
|
||||
SubtestBasic,
|
||||
SubtestCancel,
|
||||
@@ -33,12 +35,17 @@ func getFunctionName(i interface{}) string {
|
||||
}
|
||||
|
||||
func SubtestTransport(t *testing.T, ta, tb transport.Transport, addr string, peerA peer.ID) {
|
||||
t.Helper()
|
||||
SubtestTransportWithFs(t, ta, tb, addr, peerA, Subtests)
|
||||
}
|
||||
|
||||
func SubtestTransportWithFs(t *testing.T, ta, tb transport.Transport, addr string, peerA peer.ID, tests []TransportSubTestFn) {
|
||||
maddr, err := ma.NewMultiaddr(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, f := range Subtests {
|
||||
for _, f := range tests {
|
||||
t.Run(getFunctionName(f), func(t *testing.T) {
|
||||
f(t, ta, tb, maddr, peerA)
|
||||
})
|
||||
|
@@ -69,7 +69,7 @@ func TestConvertWebsocketMultiaddrToNetAddr(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestListeningOnDNSAddr(t *testing.T) {
|
||||
ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil)
|
||||
ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil, nil)
|
||||
require.NoError(t, err)
|
||||
addr := ln.Multiaddr()
|
||||
first, rest := ma.SplitFirst(addr)
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
@@ -8,6 +9,8 @@ import (
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/transport"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
|
||||
ws "github.com/gorilla/websocket"
|
||||
)
|
||||
@@ -22,20 +25,53 @@ type Conn struct {
|
||||
secure bool
|
||||
DefaultMessageType int
|
||||
reader io.Reader
|
||||
closeOnce sync.Once
|
||||
closeOnceVal func() error
|
||||
laddr ma.Multiaddr
|
||||
raddr ma.Multiaddr
|
||||
|
||||
readLock, writeLock sync.Mutex
|
||||
}
|
||||
|
||||
var _ net.Conn = (*Conn)(nil)
|
||||
var _ manet.Conn = (*Conn)(nil)
|
||||
|
||||
// NewConn creates a Conn given a regular gorilla/websocket Conn.
|
||||
//
|
||||
// Deprecated: There's no reason to use this method externally. It'll be unexported in a future release.
|
||||
func NewConn(raw *ws.Conn, secure bool) *Conn {
|
||||
return &Conn{
|
||||
lna := NewAddrWithScheme(raw.LocalAddr().String(), secure)
|
||||
laddr, err := manet.FromNetAddr(lna)
|
||||
if err != nil {
|
||||
log.Errorf("BUG: invalid localaddr on websocket conn", raw.LocalAddr())
|
||||
return nil
|
||||
}
|
||||
|
||||
rna := NewAddrWithScheme(raw.RemoteAddr().String(), secure)
|
||||
raddr, err := manet.FromNetAddr(rna)
|
||||
if err != nil {
|
||||
log.Errorf("BUG: invalid remoteaddr on websocket conn", raw.RemoteAddr())
|
||||
return nil
|
||||
}
|
||||
|
||||
c := &Conn{
|
||||
Conn: raw,
|
||||
secure: secure,
|
||||
DefaultMessageType: ws.BinaryMessage,
|
||||
laddr: laddr,
|
||||
raddr: raddr,
|
||||
}
|
||||
c.closeOnceVal = sync.OnceValue(c.closeOnceFn)
|
||||
return c
|
||||
}
|
||||
|
||||
// LocalMultiaddr implements manet.Conn.
|
||||
func (c *Conn) LocalMultiaddr() ma.Multiaddr {
|
||||
return c.laddr
|
||||
}
|
||||
|
||||
// RemoteMultiaddr implements manet.Conn.
|
||||
func (c *Conn) RemoteMultiaddr() ma.Multiaddr {
|
||||
return c.raddr
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (int, error) {
|
||||
@@ -99,26 +135,31 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
// Close closes the connection. Only the first call to Close will receive the
|
||||
// close error, subsequent and concurrent calls will return nil.
|
||||
func (c *Conn) Scope() network.ConnManagementScope {
|
||||
nc := c.NetConn()
|
||||
if sc, ok := nc.(interface {
|
||||
Scope() network.ConnManagementScope
|
||||
}); ok {
|
||||
return sc.Scope()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
// subsequent and concurrent calls will return the same error value.
|
||||
// This method is thread-safe.
|
||||
func (c *Conn) Close() error {
|
||||
var err error
|
||||
c.closeOnce.Do(func() {
|
||||
err1 := c.Conn.WriteControl(
|
||||
ws.CloseMessage,
|
||||
ws.FormatCloseMessage(ws.CloseNormalClosure, "closed"),
|
||||
time.Now().Add(GracefulCloseTimeout),
|
||||
)
|
||||
err2 := c.Conn.Close()
|
||||
switch {
|
||||
case err1 != nil:
|
||||
err = err1
|
||||
case err2 != nil:
|
||||
err = err2
|
||||
}
|
||||
})
|
||||
return err
|
||||
return c.closeOnceVal()
|
||||
}
|
||||
|
||||
func (c *Conn) closeOnceFn() error {
|
||||
err1 := c.Conn.WriteControl(
|
||||
ws.CloseMessage,
|
||||
ws.FormatCloseMessage(ws.CloseNormalClosure, "closed"),
|
||||
time.Now().Add(GracefulCloseTimeout),
|
||||
)
|
||||
err2 := c.Conn.Close()
|
||||
return errors.Join(err1, err2)
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
|
@@ -4,14 +4,16 @@ import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
logging "github.com/ipfs/go-log/v2"
|
||||
|
||||
"github.com/libp2p/go-libp2p/core/transport"
|
||||
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
|
||||
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
@@ -50,7 +52,7 @@ func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr {
|
||||
|
||||
// newListener creates a new listener from a raw net.Listener.
|
||||
// tlsConf may be nil (for unencrypted websockets).
|
||||
func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) {
|
||||
func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr) (*listener, error) {
|
||||
parsed, err := parseWebsocketMultiaddr(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -60,19 +62,36 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) {
|
||||
return nil, fmt.Errorf("cannot listen on wss address %s without a tls.Config", a)
|
||||
}
|
||||
|
||||
lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nl, err := net.Listen(lnet, lnaddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var nl net.Listener
|
||||
|
||||
if sharedTcp == nil {
|
||||
lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nl, err = net.Listen(lnet, lnaddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
var connType tcpreuse.DemultiplexedConnType
|
||||
if parsed.isWSS {
|
||||
connType = tcpreuse.DemultiplexedConnType_TLS
|
||||
} else {
|
||||
connType = tcpreuse.DemultiplexedConnType_HTTP
|
||||
}
|
||||
mal, err := sharedTcp.DemultiplexedListen(parsed.restMultiaddr, connType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nl = manet.NetListener(mal)
|
||||
}
|
||||
|
||||
laddr, err := manet.FromNetAddr(nl.Addr())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
first, _ := ma.SplitFirst(a)
|
||||
// Don't resolve dns addresses.
|
||||
// We want to be able to announce domain names, so the peer can validate the TLS certificate.
|
||||
@@ -111,7 +130,12 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// The upgrader writes a response for us.
|
||||
return
|
||||
}
|
||||
|
||||
nc := NewConn(c, l.isWss)
|
||||
if nc == nil {
|
||||
c.Close()
|
||||
w.WriteHeader(500)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case l.incoming <- NewConn(c, l.isWss):
|
||||
case <-l.closed:
|
||||
@@ -126,13 +150,7 @@ func (l *listener) Accept() (manet.Conn, error) {
|
||||
if !ok {
|
||||
return nil, transport.ErrListenerClosed
|
||||
}
|
||||
|
||||
mnc, err := manet.WrapNetConn(c)
|
||||
if err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
return mnc, nil
|
||||
return c, nil
|
||||
case <-l.closed:
|
||||
return nil, transport.ErrListenerClosed
|
||||
}
|
||||
|
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/libp2p/go-libp2p/core/transport"
|
||||
"github.com/libp2p/go-libp2p/p2p/transport/tcpreuse"
|
||||
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
mafmt "github.com/multiformats/go-multiaddr-fmt"
|
||||
@@ -87,11 +88,13 @@ type WebsocketTransport struct {
|
||||
|
||||
tlsClientConf *tls.Config
|
||||
tlsConf *tls.Config
|
||||
|
||||
sharedTcp *tcpreuse.ConnMgr
|
||||
}
|
||||
|
||||
var _ transport.Transport = (*WebsocketTransport)(nil)
|
||||
|
||||
func New(u transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*WebsocketTransport, error) {
|
||||
func New(u transport.Upgrader, rcmgr network.ResourceManager, sharedTCP *tcpreuse.ConnMgr, opts ...Option) (*WebsocketTransport, error) {
|
||||
if rcmgr == nil {
|
||||
rcmgr = &network.NullResourceManager{}
|
||||
}
|
||||
@@ -99,6 +102,7 @@ func New(u transport.Upgrader, rcmgr network.ResourceManager, opts ...Option) (*
|
||||
upgrader: u,
|
||||
rcmgr: rcmgr,
|
||||
tlsClientConf: &tls.Config{},
|
||||
sharedTcp: sharedTCP,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
if err := opt(t); err != nil {
|
||||
@@ -233,7 +237,7 @@ func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) {
|
||||
if t.tlsConf != nil {
|
||||
tlsConf = t.tlsConf.Clone()
|
||||
}
|
||||
l, err := newListener(a, tlsConf)
|
||||
l, err := newListener(a, tlsConf, t.sharedTcp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@@ -154,7 +154,7 @@ func testWSSServer(t *testing.T, listenAddr ma.Multiaddr) (ma.Multiaddr, peer.ID
|
||||
}
|
||||
|
||||
id, u := newSecureUpgrader(t)
|
||||
tpt, err := New(u, &network.NullResourceManager{}, WithTLSConfig(tlsConf))
|
||||
tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSConfig(tlsConf))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -237,7 +237,7 @@ func TestHostHeaderWss(t *testing.T) {
|
||||
|
||||
tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA
|
||||
_, u := newSecureUpgrader(t)
|
||||
tpt, err := New(u, &network.NullResourceManager{}, WithTLSClientConfig(tlsConfig))
|
||||
tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig))
|
||||
require.NoError(t, err)
|
||||
|
||||
masToDial, err := tpt.Resolve(context.Background(), serverMA)
|
||||
@@ -256,7 +256,7 @@ func TestDialWss(t *testing.T) {
|
||||
|
||||
tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA
|
||||
_, u := newSecureUpgrader(t)
|
||||
tpt, err := New(u, &network.NullResourceManager{}, WithTLSClientConfig(tlsConfig))
|
||||
tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig))
|
||||
require.NoError(t, err)
|
||||
|
||||
masToDial, err := tpt.Resolve(context.Background(), serverMA)
|
||||
@@ -279,7 +279,7 @@ func TestDialWssNoClientCert(t *testing.T) {
|
||||
require.Contains(t, serverMA.String(), "tls")
|
||||
|
||||
_, u := newSecureUpgrader(t)
|
||||
tpt, err := New(u, &network.NullResourceManager{})
|
||||
tpt, err := New(u, &network.NullResourceManager{}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
masToDial, err := tpt.Resolve(context.Background(), serverMA)
|
||||
@@ -294,12 +294,12 @@ func TestDialWssNoClientCert(t *testing.T) {
|
||||
|
||||
func TestWebsocketTransport(t *testing.T) {
|
||||
peerA, ua := newUpgrader(t)
|
||||
ta, err := New(ua, nil)
|
||||
ta, err := New(ua, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, ub := newUpgrader(t)
|
||||
tb, err := New(ub, nil)
|
||||
tb, err := New(ub, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -325,7 +325,7 @@ func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) {
|
||||
opts = append(opts, WithTLSConfig(tlsConf))
|
||||
}
|
||||
server, u := newUpgrader(t)
|
||||
tpt, err := New(u, &network.NullResourceManager{}, opts...)
|
||||
tpt, err := New(u, &network.NullResourceManager{}, nil, opts...)
|
||||
require.NoError(t, err)
|
||||
l, err := tpt.Listen(laddr)
|
||||
require.NoError(t, err)
|
||||
@@ -344,7 +344,7 @@ func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) {
|
||||
opts = append(opts, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
|
||||
}
|
||||
_, u := newUpgrader(t)
|
||||
tpt, err := New(u, &network.NullResourceManager{}, opts...)
|
||||
tpt, err := New(u, &network.NullResourceManager{}, nil, opts...)
|
||||
require.NoError(t, err)
|
||||
c, err := tpt.Dial(context.Background(), l.Multiaddr(), server)
|
||||
require.NoError(t, err)
|
||||
@@ -382,7 +382,7 @@ func TestWebsocketConnection(t *testing.T) {
|
||||
|
||||
func TestWebsocketListenSecureFailWithoutTLSConfig(t *testing.T) {
|
||||
_, u := newUpgrader(t)
|
||||
tpt, err := New(u, &network.NullResourceManager{})
|
||||
tpt, err := New(u, &network.NullResourceManager{}, nil)
|
||||
require.NoError(t, err)
|
||||
addr := ma.StringCast("/ip4/127.0.0.1/tcp/0/wss")
|
||||
_, err = tpt.Listen(addr)
|
||||
@@ -391,7 +391,7 @@ func TestWebsocketListenSecureFailWithoutTLSConfig(t *testing.T) {
|
||||
|
||||
func TestWebsocketListenSecureAndInsecure(t *testing.T) {
|
||||
serverID, serverUpgrader := newUpgrader(t)
|
||||
server, err := New(serverUpgrader, &network.NullResourceManager{}, WithTLSConfig(generateTLSConfig(t)))
|
||||
server, err := New(serverUpgrader, &network.NullResourceManager{}, nil, WithTLSConfig(generateTLSConfig(t)))
|
||||
require.NoError(t, err)
|
||||
|
||||
lnInsecure, err := server.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
|
||||
@@ -401,7 +401,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) {
|
||||
|
||||
t.Run("insecure", func(t *testing.T) {
|
||||
_, clientUpgrader := newUpgrader(t)
|
||||
client, err := New(clientUpgrader, &network.NullResourceManager{}, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
|
||||
client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
|
||||
require.NoError(t, err)
|
||||
|
||||
// dialing the insecure address should succeed
|
||||
@@ -418,7 +418,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) {
|
||||
|
||||
t.Run("secure", func(t *testing.T) {
|
||||
_, clientUpgrader := newUpgrader(t)
|
||||
client, err := New(clientUpgrader, &network.NullResourceManager{}, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
|
||||
client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
|
||||
require.NoError(t, err)
|
||||
|
||||
// dialing the insecure address should succeed
|
||||
@@ -436,7 +436,7 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) {
|
||||
|
||||
func TestConcurrentClose(t *testing.T) {
|
||||
_, u := newUpgrader(t)
|
||||
tpt, err := New(u, &network.NullResourceManager{})
|
||||
tpt, err := New(u, &network.NullResourceManager{}, nil)
|
||||
require.NoError(t, err)
|
||||
l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
|
||||
if err != nil {
|
||||
@@ -474,7 +474,7 @@ func TestConcurrentClose(t *testing.T) {
|
||||
|
||||
func TestWriteZero(t *testing.T) {
|
||||
_, u := newUpgrader(t)
|
||||
tpt, err := New(u, &network.NullResourceManager{})
|
||||
tpt, err := New(u, &network.NullResourceManager{}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
Reference in New Issue
Block a user