Files
go-libp2p/p2p/transport/websocket/websocket_test.go
sukun 6249e685e9 transport: add GatedMaListener type (#3186)
This introduces a new GatedMaListener type which gates conns
accepted from a manet.Listener with a gater and creates the rcmgr
scope for it. Explicitly passing the scope allows for many guardrails
that the previous interface assertion didn't.

This breaks the previous responsibility of the upgradeListener method
into two, one gating the connection initially, and the other upgrading
the connection with a security and muxer selection.

This split makes it easy to gate the connection with the resource
manager as early as possible. This is especially true for websocket
because we want to gate the connection just after the TCP connection is
established, and not after the TLS handshake and websocket upgrade are
completed.
2025-03-25 22:09:57 +05:30

743 lines
22 KiB
Go

package websocket
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/url"
"strings"
"testing"
"time"
gws "github.com/gorilla/websocket"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/libp2p/go-libp2p/core/sec/insecure"
"github.com/libp2p/go-libp2p/core/test"
"github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/muxer/yamux"
tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader"
"github.com/libp2p/go-libp2p/p2p/security/noise"
ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newUpgrader returns a fresh peer ID and an upgrader that secures connections
// with the insecure security transport and multiplexes them with yamux.
func newUpgrader(t *testing.T) (peer.ID, transport.Upgrader) {
	t.Helper()
	self, secTpts := newInsecureMuxer(t)
	muxers := []tptu.StreamMuxer{{ID: "/yamux", Muxer: yamux.DefaultTransport}}
	upgrader, err := tptu.New(secTpts, muxers, nil, nil, nil)
	if err != nil {
		t.Fatal(err)
	}
	return self, upgrader
}
// newSecureUpgrader returns a fresh peer ID and an upgrader that secures
// connections with Noise and multiplexes them with yamux.
func newSecureUpgrader(t *testing.T) (peer.ID, transport.Upgrader) {
	t.Helper()
	self, secTpts := newSecureMuxer(t)
	muxers := []tptu.StreamMuxer{{ID: "/yamux", Muxer: yamux.DefaultTransport}}
	upgrader, err := tptu.New(secTpts, muxers, nil, nil, nil)
	if err != nil {
		t.Fatal(err)
	}
	return self, upgrader
}
// newInsecureMuxer generates a random Ed25519 identity and returns its peer ID
// together with a single insecure (unencrypted) security transport for it.
func newInsecureMuxer(t *testing.T) (peer.ID, []sec.SecureTransport) {
	t.Helper()
	key, _, err := test.RandTestKeyPair(crypto.Ed25519, 256)
	require.NoError(t, err)
	self, err := peer.IDFromPrivateKey(key)
	require.NoError(t, err)
	insecureTpt := insecure.NewWithIdentity(insecure.ID, self, key)
	return self, []sec.SecureTransport{insecureTpt}
}
// newSecureMuxer generates a random Ed25519 identity and returns its peer ID
// together with a single Noise security transport bound to that key.
func newSecureMuxer(t *testing.T) (peer.ID, []sec.SecureTransport) {
	t.Helper()
	key, _, err := test.RandTestKeyPair(crypto.Ed25519, 256)
	require.NoError(t, err)
	self, err := peer.IDFromPrivateKey(key)
	require.NoError(t, err)
	noiseTpt, err := noise.New(noise.ID, key, nil)
	require.NoError(t, err)
	return self, []sec.SecureTransport{noiseTpt}
}
// lastComponent returns the package-level ws or wss component that equals the
// final component of a, and fails the test if a ends in anything else.
func lastComponent(t *testing.T, a ma.Multiaddr) *ma.Component {
	t.Helper()
	_, last := ma.SplitLast(a)
	require.NotNil(t, last)
	for _, candidate := range []*ma.Component{wsComponent, wssComponent} {
		if last.Equal(candidate) {
			return candidate
		}
	}
	t.Fatal("expected a ws or wss component")
	return nil
}
// generateTLSConfig returns a tls.Config carrying a freshly generated,
// self-signed (template is its own parent) 2048-bit RSA certificate that is
// valid for one hour from now.
func generateTLSConfig(t *testing.T) *tls.Config {
	t.Helper()
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(t, err)

	template := x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{},
		SignatureAlgorithm:    x509.SHA256WithRSA,
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(time.Hour),
		BasicConstraintsValid: true,
	}
	der, err := x509.CreateCertificate(rand.Reader, &template, &template, key.Public(), key)
	require.NoError(t, err)

	cert := tls.Certificate{
		Certificate: [][]byte{der},
		PrivateKey:  key,
	}
	return &tls.Config{Certificates: []tls.Certificate{cert}}
}
// TestCanDial checks which multiaddrs the websocket transport claims it can
// dial: ws, wss and tls/ws forms (with or without an SNI) are accepted, while a
// bare TCP address is rejected.
func TestCanDial(t *testing.T) {
	tpt := &WebsocketTransport{}
	cases := []struct {
		addr string
		want bool
		msg  string
	}{
		{"/ip4/127.0.0.1/tcp/5555/ws", true, "expected to match websocket maddr, but did not"},
		{"/ip4/127.0.0.1/tcp/5555/wss", true, "expected to match secure websocket maddr, but did not"},
		{"/ip4/127.0.0.1/tcp/5555", false, "expected to not match tcp maddr, but did"},
		{"/ip4/127.0.0.1/tcp/5555/tls/ws", true, "expected to match secure websocket maddr, but did not"},
		{"/ip4/127.0.0.1/tcp/5555/tls/sni/example.com/ws", true, "expected to match secure websocket maddr with sni, but did not"},
		{"/dns4/example.com/tcp/5555/tls/sni/example.com/ws", true, "expected to match secure websocket maddr with sni, but did not"},
		{"/dnsaddr/example.com/tcp/5555/tls/sni/example.com/ws", true, "expected to match secure websocket maddr with sni, but did not"},
	}
	for _, c := range cases {
		if tpt.CanDial(ma.StringCast(c.addr)) != c.want {
			t.Fatal(c.msg)
		}
	}
}
// testWSSServer starts a secure websocket listener on listenAddr. Its TLS
// config records an error if the client's SNI is not "example.com".
//
// It returns the listener's multiaddr, the server's peer ID, and a channel that
// either receives the first error the server saw (SNI mismatch, Accept or
// AcceptStream failure) or is closed once one connection and one stream have
// been accepted. The listener is closed on test cleanup. Intended for a single
// client connection.
func testWSSServer(t *testing.T, listenAddr ma.Multiaddr) (ma.Multiaddr, peer.ID, chan error) {
	t.Helper()
	errChan := make(chan error, 1)
	// report records err without blocking. Only the first error is kept, so
	// neither the TLS callback nor the accept goroutine can get stuck on a
	// full channel that nobody drains.
	report := func(err error) {
		select {
		case errChan <- err:
		default:
		}
	}

	ip := net.ParseIP("::")
	tlsConf := getTLSConf(t, ip, time.Now(), time.Now().Add(time.Hour))
	tlsConf.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
		if chi.ServerName != "example.com" {
			report(fmt.Errorf("didn't get the expected sni"))
		}
		return tlsConf, nil
	}

	id, u := newSecureUpgrader(t)
	tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSConfig(tlsConf))
	require.NoError(t, err)
	l, err := tpt.Listen(listenAddr)
	require.NoError(t, err)
	t.Cleanup(func() {
		l.Close()
	})

	go func() {
		conn, err := l.Accept()
		if err != nil {
			report(fmt.Errorf("error in accepting conn: %w", err))
			return
		}
		defer conn.Close()
		strm, err := conn.AcceptStream()
		if err != nil {
			report(fmt.Errorf("error in accepting stream: %w", err))
			return
		}
		defer strm.Close()
		close(errChan)
	}()
	return l.Multiaddr(), id, errChan
}
// getTLSConf returns a tls.Config holding a self-signed ECDSA P-256 CA
// certificate for ip, valid from start until end, and marked for both client
// and server authentication.
func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config {
	t.Helper()
	template := &x509.Certificate{
		SerialNumber:          big.NewInt(1234),
		Subject:               pkix.Name{Organization: []string{"websocket"}},
		NotBefore:             start,
		NotAfter:              end,
		IsCA:                  true,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
		IPAddresses:           []net.IP{ip},
	}
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)

	der, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
	require.NoError(t, err)
	leaf, err := x509.ParseCertificate(der)
	require.NoError(t, err)

	pair := tls.Certificate{
		Certificate: [][]byte{leaf.Raw},
		PrivateKey:  key,
		Leaf:        leaf,
	}
	return &tls.Config{Certificates: []tls.Certificate{pair}}
}
func TestHostHeaderWss(t *testing.T) {
server := &http.Server{}
l, err := net.Listen("tcp", ":0")
require.NoError(t, err)
defer server.Close()
errChan := make(chan error, 1)
go func() {
server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer close(errChan)
if !strings.Contains(r.Host, "example.com") {
errChan <- errors.New("Didn't see host header")
}
w.WriteHeader(http.StatusNotFound)
})
server.TLSConfig = getTLSConf(t, net.ParseIP("127.0.0.1"), time.Now(), time.Now().Add(time.Hour))
server.ServeTLS(l, "", "")
}()
_, port, err := net.SplitHostPort(l.Addr().String())
require.NoError(t, err)
serverMA := ma.StringCast("/ip4/127.0.0.1/tcp/" + port + "/tls/sni/example.com/ws")
tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA
_, u := newSecureUpgrader(t)
tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig))
require.NoError(t, err)
masToDial, err := tpt.Resolve(context.Background(), serverMA)
require.NoError(t, err)
_, err = tpt.Dial(context.Background(), masToDial[0], test.RandPeerIDFatal(t))
require.Error(t, err)
err = <-errChan
require.NoError(t, err)
}
func TestDialWss(t *testing.T) {
serverMA, rid, errChan := testWSSServer(t, ma.StringCast("/ip4/127.0.0.1/tcp/0/tls/sni/example.com/ws"))
require.Contains(t, serverMA.String(), "tls")
tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA
_, u := newSecureUpgrader(t)
tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig))
require.NoError(t, err)
masToDial, err := tpt.Resolve(context.Background(), serverMA)
require.NoError(t, err)
conn, err := tpt.Dial(context.Background(), masToDial[0], rid)
require.NoError(t, err)
defer conn.Close()
stream, err := conn.OpenStream(context.Background())
require.NoError(t, err)
defer stream.Close()
err = <-errChan
require.NoError(t, err)
}
func TestDialWssNoClientCert(t *testing.T) {
serverMA, rid, _ := testWSSServer(t, ma.StringCast("/ip4/127.0.0.1/tcp/0/tls/sni/example.com/ws"))
require.Contains(t, serverMA.String(), "tls")
_, u := newSecureUpgrader(t)
tpt, err := New(u, &network.NullResourceManager{}, nil)
require.NoError(t, err)
masToDial, err := tpt.Resolve(context.Background(), serverMA)
require.NoError(t, err)
_, err = tpt.Dial(context.Background(), masToDial[0], rid)
require.Error(t, err)
// The server doesn't have a signed certificate
require.Contains(t, err.Error(), "x509")
}
// TestWebsocketTransport runs the generic transport test suite against a pair
// of websocket transports in both directions (each side acting once as the
// listener), over plain /ws and TLS-secured /wss listen addresses.
func TestWebsocketTransport(t *testing.T) {
	t.Run("/ws", func(t *testing.T) {
		peerA, ua := newUpgrader(t)
		ta, err := New(ua, nil, nil)
		require.NoError(t, err)
		peerB, ub := newUpgrader(t)
		tb, err := New(ub, nil, nil)
		require.NoError(t, err)

		ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/ws", peerA)
		ttransport.SubtestTransport(t, tb, ta, "/ip4/127.0.0.1/tcp/0/ws", peerB)
	})
	t.Run("/wss", func(t *testing.T) {
		// Each side serves its own self-signed certificate, so clients skip
		// verification.
		peerA, ua := newUpgrader(t)
		ta, err := New(ua, nil, nil, WithTLSConfig(generateTLSConfig(t)), WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
		require.NoError(t, err)
		peerB, ub := newUpgrader(t)
		tb, err := New(ub, nil, nil, WithTLSConfig(generateTLSConfig(t)), WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
		require.NoError(t, err)

		// Both directions listen on /wss so TLS is exercised whichever
		// transport is the listener.
		ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/wss", peerA)
		ttransport.SubtestTransport(t, tb, ta, "/ip4/127.0.0.1/tcp/0/wss", peerB)
	})
}
// isWSS reports whether addr is a secure websocket address. An address with a
// /wss component is secure, and one with only a /ws component is not. Any
// other address is a bug in the calling test and causes a panic.
func isWSS(addr ma.Multiaddr) bool {
	_, wssErr := addr.ValueForProtocol(ma.P_WSS)
	if wssErr == nil {
		return true
	}
	if _, wsErr := addr.ValueForProtocol(ma.P_WS); wsErr != nil {
		panic("not a WebSocket address")
	}
	return false
}
// connectAndExchangeData listens on laddr, using a freshly generated TLS
// certificate when secure is true. It dials that listener from a second
// transport, writes a message on a stream opened by the dialer, and checks
// that the listener reads exactly that message. Both ends must report wss
// multiaddrs if and only if secure is true.
func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) {
	t.Helper()
	var opts []Option
	if secure {
		opts = append(opts, WithTLSConfig(generateTLSConfig(t)))
	}
	server, u := newUpgrader(t)
	tpt, err := New(u, &network.NullResourceManager{}, nil, opts...)
	require.NoError(t, err)
	l, err := tpt.Listen(laddr)
	require.NoError(t, err)
	defer l.Close()
	if secure {
		require.Contains(t, l.Multiaddr().String(), "tls")
	} else {
		require.Equal(t, wsComponent.String(), lastComponent(t, l.Multiaddr()).String())
	}

	// Build the dialing transport here on the test goroutine: newUpgrader and
	// require may call t.FailNow, which must not run on other goroutines.
	var clientOpts []Option
	if secure {
		clientOpts = append(clientOpts, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
	}
	_, cu := newUpgrader(t)
	client, err := New(cu, &network.NullResourceManager{}, nil, clientOpts...)
	require.NoError(t, err)

	msg := []byte("HELLO WORLD")
	readDone := make(chan struct{}) // closed once the listener side is done reading
	dialDone := make(chan struct{}) // closed when the dialer goroutine exits

	go func() {
		defer close(dialDone)
		// Off the test goroutine: record failures with assert, never require.
		c, err := client.Dial(context.Background(), l.Multiaddr(), server)
		if !assert.NoError(t, err) {
			// Unblock the Accept below instead of leaving the test hanging.
			l.Close()
			return
		}
		defer c.Close()
		assert.Equal(t, secure, isWSS(c.LocalMultiaddr()))
		assert.Equal(t, secure, isWSS(c.RemoteMultiaddr()))
		str, err := c.OpenStream(context.Background())
		if !assert.NoError(t, err) {
			return
		}
		_, err = str.Write(msg)
		assert.NoError(t, err)
		// Closing the stream lets the listener's io.ReadAll return.
		str.Close()
		// Keep the connection open until the listener has read the message.
		<-readDone
	}()
	defer func() {
		close(readDone)
		<-dialDone
	}()

	c, err := l.Accept()
	require.NoError(t, err)
	defer c.Close()
	require.Equal(t, secure, isWSS(c.LocalMultiaddr()))
	require.Equal(t, secure, isWSS(c.RemoteMultiaddr()))
	str, err := c.AcceptStream()
	require.NoError(t, err)
	defer str.Close()
	out, err := io.ReadAll(str)
	require.NoError(t, err)
	require.Equal(t, msg, out, "got wrong message")
}
// TestWebsocketConnection exchanges data over a plain ws listener and over a
// TLS-secured wss listener.
func TestWebsocketConnection(t *testing.T) {
	for _, tc := range []struct {
		name   string
		addr   string
		secure bool
	}{
		{"unencrypted", "/ip4/127.0.0.1/tcp/0/ws", false},
		{"encrypted", "/ip4/127.0.0.1/tcp/0/wss", true},
	} {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			connectAndExchangeData(t, ma.StringCast(tc.addr), tc.secure)
		})
	}
}
// TestWebsocketListenSecureFailWithoutTLSConfig checks that listening on a wss
// address fails with a descriptive error when the transport was built without
// a server TLS config.
func TestWebsocketListenSecureFailWithoutTLSConfig(t *testing.T) {
	_, upg := newUpgrader(t)
	tpt, err := New(upg, &network.NullResourceManager{}, nil)
	require.NoError(t, err)

	wssAddr := ma.StringCast("/ip4/127.0.0.1/tcp/0/wss")
	_, err = tpt.Listen(wssAddr)
	want := fmt.Sprintf("cannot listen on wss address %s without a tls.Config", wssAddr)
	require.EqualError(t, err, want)
}
// TestWebsocketListenSecureAndInsecure runs a ws and a wss listener on the same
// server transport. It checks that a client with a verification-skipping TLS
// client config can dial both listeners, and that each connection reports the
// matching ws/wss component on both of its ends.
func TestWebsocketListenSecureAndInsecure(t *testing.T) {
	serverID, serverUpgrader := newUpgrader(t)
	server, err := New(serverUpgrader, &network.NullResourceManager{}, nil, WithTLSConfig(generateTLSConfig(t)))
	require.NoError(t, err)
	lnInsecure, err := server.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
	require.NoError(t, err)
	defer lnInsecure.Close()
	lnSecure, err := server.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/wss"))
	require.NoError(t, err)
	defer lnSecure.Close()

	t.Run("insecure", func(t *testing.T) {
		_, clientUpgrader := newUpgrader(t)
		client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
		require.NoError(t, err)
		// dialing the insecure address should succeed
		conn, err := client.Dial(context.Background(), lnInsecure.Multiaddr(), serverID)
		require.NoError(t, err)
		defer conn.Close()
		require.Equal(t, wsComponent.String(), lastComponent(t, conn.RemoteMultiaddr()).String())
		require.Equal(t, wsComponent.String(), lastComponent(t, conn.LocalMultiaddr()).String())
		// the same client can also dial the secure address, since it has a TLS client config
		secureConn, err := client.Dial(context.Background(), lnSecure.Multiaddr(), serverID)
		require.NoError(t, err)
		defer secureConn.Close()
	})
	t.Run("secure", func(t *testing.T) {
		_, clientUpgrader := newUpgrader(t)
		client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
		require.NoError(t, err)
		// dialing the secure address should succeed
		conn, err := client.Dial(context.Background(), lnSecure.Multiaddr(), serverID)
		require.NoError(t, err)
		defer conn.Close()
		require.Equal(t, wssComponent.String(), lastComponent(t, conn.RemoteMultiaddr()).String())
		require.Equal(t, wssComponent.String(), lastComponent(t, conn.LocalMultiaddr()).String())
		// the same client can also dial the insecure address
		insecureConn, err := client.Dial(context.Background(), lnInsecure.Multiaddr(), serverID)
		require.NoError(t, err)
		defer insecureConn.Close()
	})
}
// TestConcurrentClose dials the gated websocket listener (no security/muxer
// upgrade) 100 times and, for each dialed connection, deliberately races a
// Write against a Close. Meanwhile the test goroutine accepts and immediately
// closes the same number of connections.
func TestConcurrentClose(t *testing.T) {
	_, upg := newUpgrader(t)
	tpt, err := New(upg, &network.NullResourceManager{}, nil)
	require.NoError(t, err)
	ln, err := tpt.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()

	const rounds = 100
	payload := []byte("HELLO WORLD")

	go func() {
		for n := 0; n < rounds; n++ {
			conn, dialErr := tpt.maDial(context.Background(), ln.Multiaddr(), &network.NullScope{})
			if dialErr != nil {
				t.Error(dialErr)
				return
			}
			// The outcome of either call is irrelevant; only the race matters.
			go func() { _, _ = conn.Write(payload) }()
			go func() { _ = conn.Close() }()
		}
	}()

	for n := 0; n < rounds; n++ {
		conn, _, acceptErr := ln.Accept()
		if acceptErr != nil {
			t.Fatal(acceptErr)
		}
		conn.Close()
	}
}
// TestWriteZero writes a nil payload 100 times over a gated (non-upgraded)
// websocket connection. Each Write must report 0 bytes and no error. The first
// Read on the accepting side must then return 0 bytes and io.EOF after the
// writer closes.
func TestWriteZero(t *testing.T) {
	_, upg := newUpgrader(t)
	tpt, err := New(upg, &network.NullResourceManager{}, nil)
	if err != nil {
		t.Fatal(err)
	}
	ln, err := tpt.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()

	var empty []byte
	go func() {
		conn, dialErr := tpt.maDial(context.Background(), ln.Multiaddr(), &network.NullScope{})
		if dialErr != nil {
			t.Error(dialErr)
			return
		}
		defer conn.Close()
		for n := 0; n < 100; n++ {
			written, writeErr := conn.Write(empty)
			if written != 0 {
				t.Errorf("expected to write 0 bytes, wrote %d", written)
			}
			if writeErr != nil {
				t.Error(writeErr)
				return
			}
		}
	}()

	conn, _, err := ln.Accept()
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	got, readErr := conn.Read(make([]byte, 100))
	if got != 0 {
		t.Errorf("read %d bytes, expected 0", got)
	}
	if readErr != io.EOF {
		t.Errorf("expected EOF, got err: %s", readErr)
	}
}
func TestResolveMultiaddr(t *testing.T) {
// map[unresolved]resolved
testCases := map[string]string{
"/dns/example.com/tcp/1234/wss": "/dns/example.com/tcp/1234/tls/sni/example.com/ws",
"/dns4/example.com/tcp/1234/wss": "/dns4/example.com/tcp/1234/tls/sni/example.com/ws",
"/dns6/example.com/tcp/1234/wss": "/dns6/example.com/tcp/1234/tls/sni/example.com/ws",
"/dnsaddr/example.com/tcp/1234/wss": "/dnsaddr/example.com/tcp/1234/wss",
"/dns4/example.com/tcp/1234/tls/ws": "/dns4/example.com/tcp/1234/tls/sni/example.com/ws",
"/dns6/example.com/tcp/1234/tls/ws": "/dns6/example.com/tcp/1234/tls/sni/example.com/ws",
"/dnsaddr/example.com/tcp/1234/tls/ws": "/dnsaddr/example.com/tcp/1234/tls/ws",
}
for unresolved, expectedMA := range testCases {
t.Run(unresolved, func(t *testing.T) {
m1 := ma.StringCast(unresolved)
wsTpt := WebsocketTransport{}
ctx := context.Background()
addrs, err := wsTpt.Resolve(ctx, m1)
require.NoError(t, err)
require.Len(t, addrs, 1)
require.Equal(t, expectedMA, addrs[0].String())
})
}
}
// TestSocksProxy checks that dials go through the SOCKS5 proxy configured on
// gws.DefaultDialer, for ws and tls/ws addresses with and without an SNI. A
// minimal fake proxy accepts the no-auth greeting and then hangs up without
// answering the connect request, so every dial must fail while reading that
// reply.
func TestSocksProxy(t *testing.T) {
	testCases := []string{
		"/ip4/1.2.3.4/tcp/1/ws",                     // No TLS
		"/ip4/1.2.3.4/tcp/1/tls/ws",                 // TLS no SNI
		"/ip4/1.2.3.4/tcp/1/tls/sni/example.com/ws", // TLS with an SNI
	}
	for _, tc := range testCases {
		tc := tc
		t.Run(tc, func(t *testing.T) {
			proxyServer, err := net.Listen("tcp", "127.0.0.1:0")
			require.NoError(t, err)
			// Also closed by the proxy goroutine; this covers early test failures.
			defer proxyServer.Close()

			proxyServerErr := make(chan error, 1)
			go func() {
				defer proxyServer.Close()
				c, err := proxyServer.Accept()
				if err != nil {
					proxyServerErr <- err
					return
				}
				defer c.Close()
				var greeting [3]byte
				if _, err := io.ReadFull(c, greeting[:]); err != nil {
					proxyServerErr <- err
					return
				}
				// Handshake a SOCKS5 client: https://www.rfc-editor.org/rfc/rfc1928.html#section-3
				// Expect version 5 offering exactly one method: 0x00 (no authentication).
				if !bytes.Equal([]byte{0x05, 0x01, 0x00}, greeting[:]) {
					proxyServerErr <- fmt.Errorf("expected SOCKS5 no-auth greeting, got %x", greeting[:])
					return
				}
				// Select "no authentication required".
				if _, err := c.Write([]byte{0x05, 0x00}); err != nil {
					proxyServerErr <- err
					return
				}
				proxyServerErr <- nil
			}()

			orig := gws.DefaultDialer.Proxy
			defer func() { gws.DefaultDialer.Proxy = orig }()
			proxyURL, err := url.Parse("socks5://" + proxyServer.Addr().String())
			require.NoError(t, err)
			gws.DefaultDialer.Proxy = http.ProxyURL(proxyURL)

			tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA
			_, u := newSecureUpgrader(t)
			tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig))
			require.NoError(t, err)

			// This can be any wss address. We aren't actually going to dial it.
			maToDial := ma.StringCast(tc)
			_, err = tpt.Dial(context.Background(), maToDial, "")
			require.ErrorContains(t, err, "failed to read connect reply from SOCKS5 proxy", "This should error as we don't have a real socks server")

			// The proxy reports its result before hanging up, which is what makes
			// the dial fail, so the result is available by now.
			select {
			case err := <-proxyServerErr:
				require.NoError(t, err)
			case <-time.After(1 * time.Second):
				t.Fatal("timed out waiting for the fake SOCKS5 proxy")
			}
		})
	}
}
// TestListenerAddr checks that a listener's net.Addr renders ws listen
// addresses as ws:// URLs and wss listen addresses as wss:// URLs.
func TestListenerAddr(t *testing.T) {
	_, upg := newUpgrader(t)
	tpt, err := New(upg, &network.NullResourceManager{}, nil, WithTLSConfig(generateTLSConfig(t)))
	require.NoError(t, err)

	for _, tc := range []struct{ maddr, pattern string }{
		{"/ip4/127.0.0.1/tcp/0/ws", `^ws://127\.0\.0\.1:[\d]+$`},
		{"/ip4/127.0.0.1/tcp/0/wss", `^wss://127\.0\.0\.1:[\d]+$`},
	} {
		ln, err := tpt.Listen(ma.StringCast(tc.maddr))
		require.NoError(t, err)
		// Deferred on purpose: every listener stays open until the test ends.
		defer ln.Close()
		require.Regexp(t, tc.pattern, ln.Addr().String())
	}
}
// TestHandshakeTimeout checks the server-side websocket handshake timeout on
// ws and wss listeners. A client that sends its handshake promptly must
// succeed. A client that stalls for twice the timeout before handshaking must
// fail.
func TestHandshakeTimeout(t *testing.T) {
	handshakeTimeout := 200 * time.Millisecond
	_, upg := newUpgrader(t)
	tpt, err := New(upg, &network.NullResourceManager{}, nil, WithHandshakeTimeout(handshakeTimeout), WithTLSConfig(generateTLSConfig(t)))
	require.NoError(t, err)

	// newDialer builds a gorilla dialer whose TCP dial sleeps for delay after
	// connecting, which postpones the websocket handshake by that amount.
	newDialer := func(delay time.Duration) gws.Dialer {
		return gws.Dialer{
			HandshakeTimeout: 10 * handshakeTimeout,
			TLSClientConfig:  &tls.Config{InsecureSkipVerify: true},
			NetDial: func(_, addr string) (net.Conn, error) {
				tcpConn, err := net.Dial("tcp", addr)
				if !assert.NoError(t, err) {
					return nil, err
				}
				if delay > 0 {
					time.Sleep(delay)
				}
				return tcpConn, nil
			},
		}
	}
	fastWSDialer := newDialer(0)
	slowWSDialer := newDialer(2 * handshakeTimeout)

	for _, tc := range []struct {
		name, addr, slowDialMsg string
	}{
		{"ws", "/ip4/127.0.0.1/tcp/0/ws", "should error as the handshake will time out"},
		{"wss", "/ip4/127.0.0.1/tcp/0/wss", "websocket handshake should have timed out"},
	} {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			// test the gatedMaListener as we're interested in the websocket handshake timeout and not the upgrader steps.
			ln, err := tpt.gatedMaListen(ma.StringCast(tc.addr))
			require.NoError(t, err)
			defer ln.Close()
			target := ln.Addr().String()

			// A prompt handshake must succeed.
			fastCtx, fastCancel := context.WithTimeout(context.Background(), 10*handshakeTimeout)
			defer fastCancel()
			conn, resp, err := fastWSDialer.DialContext(fastCtx, target, nil)
			require.NoError(t, err)
			conn.Close()
			resp.Body.Close()

			// A handshake that stalls past the server's timeout must fail.
			slowCtx, slowCancel := context.WithTimeout(context.Background(), 10*handshakeTimeout)
			defer slowCancel()
			if conn, resp, err := slowWSDialer.DialContext(slowCtx, target, nil); err == nil {
				conn.Close()
				resp.Body.Close()
				t.Fatal(tc.slowDialMsg)
			}
		})
	}
}