Files
go-libp2p/p2p/transport/websocket/websocket_test.go
2025-06-11 10:00:47 +01:00

743 lines
22 KiB
Go

package websocket
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/url"
"strings"
"testing"
"time"
gws "github.com/gorilla/websocket"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/libp2p/go-libp2p/core/sec/insecure"
"github.com/libp2p/go-libp2p/core/test"
"github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/muxer/yamux"
tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader"
"github.com/libp2p/go-libp2p/p2p/security/noise"
ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// newUpgrader returns a fresh peer ID together with an upgrader that pairs the
// insecure security transport for that identity with the yamux stream muxer.
// The test is failed immediately if the upgrader cannot be constructed.
func newUpgrader(t *testing.T) (peer.ID, transport.Upgrader) {
	t.Helper()
	localID, secTpts := newInsecureMuxer(t)
	muxers := []tptu.StreamMuxer{{ID: "/yamux", Muxer: yamux.DefaultTransport}}
	upgrader, err := tptu.New(secTpts, muxers, nil, nil, nil)
	if err != nil {
		t.Fatal(err)
	}
	return localID, upgrader
}
// newSecureUpgrader returns a fresh peer ID together with an upgrader that
// pairs the noise security transport for that identity with the yamux stream
// muxer. The test is failed immediately if the upgrader cannot be constructed.
func newSecureUpgrader(t *testing.T) (peer.ID, transport.Upgrader) {
	t.Helper()
	localID, secTpts := newSecureMuxer(t)
	muxers := []tptu.StreamMuxer{{ID: "/yamux", Muxer: yamux.DefaultTransport}}
	upgrader, err := tptu.New(secTpts, muxers, nil, nil, nil)
	if err != nil {
		t.Fatal(err)
	}
	return localID, upgrader
}
// newInsecureMuxer generates a random Ed25519 key pair and returns the derived
// peer ID along with a one-element list containing the insecure security
// transport bound to that identity.
func newInsecureMuxer(t *testing.T) (peer.ID, []sec.SecureTransport) {
	t.Helper()
	sk, _, err := test.RandTestKeyPair(crypto.Ed25519, 256)
	require.NoError(t, err)
	pid, err := peer.IDFromPrivateKey(sk)
	require.NoError(t, err)
	insecureTpt := insecure.NewWithIdentity(insecure.ID, pid, sk)
	return pid, []sec.SecureTransport{insecureTpt}
}
// newSecureMuxer generates a random Ed25519 key pair and returns the derived
// peer ID along with a one-element list containing a noise security transport
// created from the private key.
func newSecureMuxer(t *testing.T) (peer.ID, []sec.SecureTransport) {
	t.Helper()

	sk, _, err := test.RandTestKeyPair(crypto.Ed25519, 256)
	if err != nil {
		t.Fatal(err)
	}

	pid, err := peer.IDFromPrivateKey(sk)
	if err != nil {
		t.Fatal(err)
	}

	secureTpt, err := noise.New(noise.ID, sk, nil)
	require.NoError(t, err)

	return pid, []sec.SecureTransport{secureTpt}
}
// lastComponent splits off the final component of a and maps it onto the
// package-level wsComponent or wssComponent value it equals. The test is
// failed if the final component is missing or is neither of the two.
func lastComponent(t *testing.T, a ma.Multiaddr) *ma.Component {
	t.Helper()
	_, last := ma.SplitLast(a)
	require.NotNil(t, last)
	switch {
	case last.Equal(wsComponent):
		return wsComponent
	case last.Equal(wssComponent):
		return wssComponent
	default:
		t.Fatal("expected a ws or wss component")
		return nil
	}
}
// generateTLSConfig returns a TLS config holding a single self-signed
// certificate backed by a freshly generated 2048-bit RSA key. The certificate
// is valid for one hour starting now.
func generateTLSConfig(t *testing.T) *tls.Config {
	t.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(t, err)

	template := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{},
		SignatureAlgorithm:    x509.SHA256WithRSA,
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(time.Hour), // valid for an hour
		BasicConstraintsValid: true,
	}

	// Self-signed: the template acts as its own parent certificate.
	der, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key)
	require.NoError(t, err)

	selfSigned := tls.Certificate{
		Certificate: [][]byte{der},
		PrivateKey:  key,
	}
	return &tls.Config{Certificates: []tls.Certificate{selfSigned}}
}
// TestCanDial checks which multiaddr shapes the websocket transport reports as
// dialable: ws, wss, tls/ws and tls/sni/<host>/ws addresses are accepted, a
// bare tcp address is not.
func TestCanDial(t *testing.T) {
	d := &WebsocketTransport{}

	cases := []struct {
		addr     string
		dialable bool
		failMsg  string
	}{
		{"/ip4/127.0.0.1/tcp/5555/ws", true, "expected to match websocket maddr, but did not"},
		{"/ip4/127.0.0.1/tcp/5555/wss", true, "expected to match secure websocket maddr, but did not"},
		{"/ip4/127.0.0.1/tcp/5555", false, "expected to not match tcp maddr, but did"},
		{"/ip4/127.0.0.1/tcp/5555/tls/ws", true, "expected to match secure websocket maddr, but did not"},
		{"/ip4/127.0.0.1/tcp/5555/tls/sni/example.com/ws", true, "expected to match secure websocket maddr with sni, but did not"},
		{"/dns4/example.com/tcp/5555/tls/sni/example.com/ws", true, "expected to match secure websocket maddr with sni, but did not"},
		{"/dnsaddr/example.com/tcp/5555/tls/sni/example.com/ws", true, "expected to match secure websocket maddr with sni, but did not"},
	}

	for _, tc := range cases {
		if d.CanDial(ma.StringCast(tc.addr)) != tc.dialable {
			t.Fatal(tc.failMsg)
		}
	}
}
// testWSSServer starts a websocket listener with a TLS config on listenAddr
// and serves a single connection in the background.
//
// It returns the listener's multiaddr, the server's peer ID, and a channel
// that reports the outcome: the channel receives an error if the TLS client
// hello does not carry the SNI "example.com" or if accepting the connection or
// its first stream fails, and is closed once a stream has been accepted. The
// listener is closed via t.Cleanup.
func testWSSServer(t *testing.T, listenAddr ma.Multiaddr) (ma.Multiaddr, peer.ID, chan error) {
	results := make(chan error, 1)

	serverTLS := getTLSConf(t, net.ParseIP("::"), time.Now(), time.Now().Add(time.Hour))
	// Inspect the client hello so that a wrong SNI is reported to the caller.
	serverTLS.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
		if hello.ServerName != "example.com" {
			results <- fmt.Errorf("didn't get the expected sni")
		}
		return serverTLS, nil
	}

	serverID, upgrader := newSecureUpgrader(t)
	tpt, err := New(upgrader, &network.NullResourceManager{}, nil, WithTLSConfig(serverTLS))
	if err != nil {
		t.Fatal(err)
	}

	ln, err := tpt.Listen(listenAddr)
	require.NoError(t, err)
	t.Cleanup(func() { ln.Close() })

	go func() {
		conn, err := ln.Accept()
		if err != nil {
			results <- fmt.Errorf("error in accepting conn: %w", err)
			return
		}
		defer conn.Close()

		stream, err := conn.AcceptStream()
		if err != nil {
			results <- fmt.Errorf("error in accepting stream: %w", err)
			return
		}
		defer stream.Close()

		// Closing the channel signals success to the receiver.
		close(results)
	}()

	return ln.Multiaddr(), serverID, results
}
// getTLSConf returns a TLS config holding a single self-signed ECDSA P-256
// certificate that lists ip as its only IP SAN and is valid from start to end.
// The parsed certificate is also stored as the Leaf of the tls.Certificate.
func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config {
	t.Helper()

	template := &x509.Certificate{
		SerialNumber:          big.NewInt(1234),
		Subject:               pkix.Name{Organization: []string{"websocket"}},
		NotBefore:             start,
		NotAfter:              end,
		IsCA:                  true,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
		IPAddresses:           []net.IP{ip},
	}

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)

	// Self-signed: the template acts as its own parent certificate.
	der, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
	require.NoError(t, err)

	leaf, err := x509.ParseCertificate(der)
	require.NoError(t, err)

	serverCert := tls.Certificate{
		Certificate: [][]byte{leaf.Raw},
		PrivateKey:  key,
		Leaf:        leaf,
	}
	return &tls.Config{Certificates: []tls.Certificate{serverCert}}
}
// TestHostHeaderWss checks that dialing a /tls/sni/example.com/ws address
// sends a Host header containing "example.com" in the HTTP request.
//
// The HTTPS test server answers every request with 404, so the dial itself is
// expected to fail; the handler reports what it saw through errChan, which it
// closes when it returns.
func TestHostHeaderWss(t *testing.T) {
	l, err := net.Listen("tcp", ":0")
	require.NoError(t, err)

	errChan := make(chan error, 1)

	// Configure the server on the test goroutine: getTLSConf asserts with
	// require, which must not run on any other goroutine, and the server's
	// fields must be set before it starts serving.
	server := &http.Server{
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			defer close(errChan)
			if !strings.Contains(r.Host, "example.com") {
				errChan <- errors.New("Didn't see host header")
			}
			w.WriteHeader(http.StatusNotFound)
		}),
		TLSConfig: getTLSConf(t, net.ParseIP("127.0.0.1"), time.Now(), time.Now().Add(time.Hour)),
	}
	defer server.Close()

	go func() {
		// The returned error is ignored on purpose; the server is shut down
		// by the deferred server.Close above.
		_ = server.ServeTLS(l, "", "")
	}()

	_, port, err := net.SplitHostPort(l.Addr().String())
	require.NoError(t, err)
	serverMA := ma.StringCast("/ip4/127.0.0.1/tcp/" + port + "/tls/sni/example.com/ws")

	tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA
	_, u := newSecureUpgrader(t)
	tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig))
	require.NoError(t, err)

	masToDial, err := tpt.Resolve(context.Background(), serverMA)
	require.NoError(t, err)

	// The server replies 404, so the websocket dial must fail.
	_, err = tpt.Dial(context.Background(), masToDial[0], test.RandPeerIDFatal(t))
	require.Error(t, err)

	// A nil value (closed channel) means the handler saw the expected host.
	err = <-errChan
	require.NoError(t, err)
}
func TestDialWss(t *testing.T) {
serverMA, rid, errChan := testWSSServer(t, ma.StringCast("/ip4/127.0.0.1/tcp/0/tls/sni/example.com/ws"))
require.Contains(t, serverMA.String(), "tls")
tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA
_, u := newSecureUpgrader(t)
tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig))
require.NoError(t, err)
masToDial, err := tpt.Resolve(context.Background(), serverMA)
require.NoError(t, err)
conn, err := tpt.Dial(context.Background(), masToDial[0], rid)
require.NoError(t, err)
defer conn.Close()
stream, err := conn.OpenStream(context.Background())
require.NoError(t, err)
defer stream.Close()
err = <-errChan
require.NoError(t, err)
}
func TestDialWssNoClientCert(t *testing.T) {
serverMA, rid, _ := testWSSServer(t, ma.StringCast("/ip4/127.0.0.1/tcp/0/tls/sni/example.com/ws"))
require.Contains(t, serverMA.String(), "tls")
_, u := newSecureUpgrader(t)
tpt, err := New(u, &network.NullResourceManager{}, nil)
require.NoError(t, err)
masToDial, err := tpt.Resolve(context.Background(), serverMA)
require.NoError(t, err)
_, err = tpt.Dial(context.Background(), masToDial[0], rid)
require.Error(t, err)
// The server doesn't have a signed certificate
require.Contains(t, err.Error(), "x509")
}
// TestWebsocketTransport runs the shared transport test suite against a pair
// of websocket transports, once over plain websockets and once over secure
// websockets, in both directions each time.
func TestWebsocketTransport(t *testing.T) {
	t.Run("/ws", func(t *testing.T) {
		peerA, ua := newUpgrader(t)
		ta, err := New(ua, nil, nil)
		if err != nil {
			t.Fatal(err)
		}
		peerB, ub := newUpgrader(t)
		tb, err := New(ub, nil, nil)
		if err != nil {
			t.Fatal(err)
		}
		ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/ws", peerA)
		ttransport.SubtestTransport(t, tb, ta, "/ip4/127.0.0.1/tcp/0/ws", peerB)
	})
	t.Run("/wss", func(t *testing.T) {
		peerA, ua := newUpgrader(t)
		tca := generateTLSConfig(t)
		ta, err := New(ua, nil, nil, WithTLSConfig(tca), WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
		if err != nil {
			t.Fatal(err)
		}
		peerB, ub := newUpgrader(t)
		tcb := generateTLSConfig(t)
		tb, err := New(ub, nil, nil, WithTLSConfig(tcb), WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
		if err != nil {
			t.Fatal(err)
		}
		// Both directions use a wss address so that each transport's TLS
		// listener config (tca and tcb) is exercised.
		ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/wss", peerA)
		ttransport.SubtestTransport(t, tb, ta, "/ip4/127.0.0.1/tcp/0/wss", peerB)
	})
}
// isWSS reports whether addr contains a wss component. It returns false when
// addr has no wss component but does contain a ws component, and panics when
// it contains neither.
func isWSS(addr ma.Multiaddr) bool {
	_, err := addr.ValueForProtocol(ma.P_WSS)
	if err == nil {
		return true
	}
	_, err = addr.ValueForProtocol(ma.P_WS)
	if err != nil {
		panic("not a WebSocket address")
	}
	return false
}
// connectAndExchangeData listens on laddr, dials the listener from a second
// transport, sends a short message over a new stream, and checks that the
// listener side receives exactly that message.
//
// When secure is true the listener is given a generated TLS config and the
// dialer skips certificate verification; both ends are also expected to report
// wss local and remote addresses. When secure is false they are expected to
// report ws addresses.
func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) {
	var serverOpts []Option
	if secure {
		serverOpts = append(serverOpts, WithTLSConfig(generateTLSConfig(t)))
	}
	server, u := newUpgrader(t)
	tpt, err := New(u, &network.NullResourceManager{}, nil, serverOpts...)
	require.NoError(t, err)
	l, err := tpt.Listen(laddr)
	require.NoError(t, err)
	// Register the close before any assertion that may abort the test.
	defer l.Close()
	if secure {
		require.Contains(t, l.Multiaddr().String(), "tls")
	} else {
		require.Equal(t, wsComponent.String(), lastComponent(t, l.Multiaddr()).String())
	}

	// Build the dialing transport on the test goroutine: the helpers used
	// here call t.Fatal/require, which must not run on other goroutines.
	var clientOpts []Option
	if secure {
		clientOpts = append(clientOpts, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
	}
	_, clientUpgrader := newUpgrader(t)
	client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, clientOpts...)
	require.NoError(t, err)

	msg := []byte("HELLO WORLD")
	go func() {
		// Only assert (never require) in this goroutine. On failure, close
		// whatever the test goroutine is blocked on so the test fails
		// instead of hanging.
		c, err := client.Dial(context.Background(), l.Multiaddr(), server)
		if !assert.NoError(t, err) {
			l.Close() // unblocks l.Accept below
			return
		}
		assert.Equal(t, secure, isWSS(c.LocalMultiaddr()))
		assert.Equal(t, secure, isWSS(c.RemoteMultiaddr()))
		str, err := c.OpenStream(context.Background())
		if !assert.NoError(t, err) {
			c.Close() // unblocks AcceptStream below
			return
		}
		// Closing the stream lets the reader's io.ReadAll return.
		defer str.Close()
		_, err = str.Write(msg)
		if !assert.NoError(t, err) {
			c.Close() // unblocks io.ReadAll below
		}
	}()

	c, err := l.Accept()
	require.NoError(t, err)
	defer c.Close()
	require.Equal(t, secure, isWSS(c.LocalMultiaddr()))
	require.Equal(t, secure, isWSS(c.RemoteMultiaddr()))
	str, err := c.AcceptStream()
	require.NoError(t, err)
	defer str.Close()
	out, err := io.ReadAll(str)
	require.NoError(t, err)
	require.Equal(t, msg, out, "got wrong message")
}
// TestWebsocketConnection exchanges data over a websocket connection, once on
// a ws address without TLS and once on a wss address with TLS.
func TestWebsocketConnection(t *testing.T) {
	variants := []struct {
		name   string
		addr   string
		secure bool
	}{
		{name: "unencrypted", addr: "/ip4/127.0.0.1/tcp/0/ws", secure: false},
		{name: "encrypted", addr: "/ip4/127.0.0.1/tcp/0/wss", secure: true},
	}
	for _, v := range variants {
		t.Run(v.name, func(t *testing.T) {
			connectAndExchangeData(t, ma.StringCast(v.addr), v.secure)
		})
	}
}
// TestWebsocketListenSecureFailWithoutTLSConfig checks that listening on a wss
// address is rejected with a descriptive error when the transport was created
// without a TLS config.
func TestWebsocketListenSecureFailWithoutTLSConfig(t *testing.T) {
	_, upgrader := newUpgrader(t)
	tpt, err := New(upgrader, &network.NullResourceManager{}, nil)
	require.NoError(t, err)

	addr := ma.StringCast("/ip4/127.0.0.1/tcp/0/wss")
	wantMsg := fmt.Sprintf("cannot listen on wss address %s without a tls.Config", addr)

	_, listenErr := tpt.Listen(addr)
	require.EqualError(t, listenErr, wantMsg)
}
// TestWebsocketListenSecureAndInsecure starts one transport that listens on
// both a ws and a wss address and checks that a client configured with a TLS
// client config can dial either listener, and that the connection's local and
// remote addresses end in the component matching the listener that was dialed.
func TestWebsocketListenSecureAndInsecure(t *testing.T) {
	serverID, serverUpgrader := newUpgrader(t)
	server, err := New(serverUpgrader, &network.NullResourceManager{}, nil, WithTLSConfig(generateTLSConfig(t)))
	require.NoError(t, err)

	lnInsecure, err := server.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
	require.NoError(t, err)
	defer lnInsecure.Close()
	lnSecure, err := server.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/wss"))
	require.NoError(t, err)
	defer lnSecure.Close()

	t.Run("insecure", func(t *testing.T) {
		_, clientUpgrader := newUpgrader(t)
		client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
		require.NoError(t, err)

		// dialing the insecure address should succeed
		conn, err := client.Dial(context.Background(), lnInsecure.Multiaddr(), serverID)
		require.NoError(t, err)
		defer conn.Close()
		require.Equal(t, wsComponent.String(), lastComponent(t, conn.RemoteMultiaddr()).String())
		require.Equal(t, wsComponent.String(), lastComponent(t, conn.LocalMultiaddr()).String())

		// the same client must also be able to dial the secure address
		secureConn, err := client.Dial(context.Background(), lnSecure.Multiaddr(), serverID)
		require.NoError(t, err)
		defer secureConn.Close()
	})

	t.Run("secure", func(t *testing.T) {
		_, clientUpgrader := newUpgrader(t)
		client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
		require.NoError(t, err)

		// dialing the secure address should succeed
		conn, err := client.Dial(context.Background(), lnSecure.Multiaddr(), serverID)
		require.NoError(t, err)
		defer conn.Close()
		require.Equal(t, wssComponent.String(), lastComponent(t, conn.RemoteMultiaddr()).String())
		require.Equal(t, wssComponent.String(), lastComponent(t, conn.LocalMultiaddr()).String())

		// the same client must also be able to dial the insecure address
		insecureConn, err := client.Dial(context.Background(), lnInsecure.Multiaddr(), serverID)
		require.NoError(t, err)
		defer insecureConn.Close()
	})
}
// TestConcurrentClose dials the raw websocket listener 100 times and, for each
// dialed connection, races a Write against a Close in separate goroutines
// while the listener side accepts and closes every connection.
func TestConcurrentClose(t *testing.T) {
	const rounds = 100

	_, upgrader := newUpgrader(t)
	tpt, err := New(upgrader, &network.NullResourceManager{}, nil)
	require.NoError(t, err)

	ln, err := tpt.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()

	payload := []byte("HELLO WORLD")

	// Dialer side: results of Write and Close are deliberately ignored, the
	// test only cares that racing them does not misbehave.
	go func() {
		for n := 0; n < rounds; n++ {
			conn, err := tpt.maDial(context.Background(), ln.Multiaddr(), &network.NullScope{})
			if err != nil {
				t.Error(err)
				return
			}
			go func() { _, _ = conn.Write(payload) }()
			go func() { _ = conn.Close() }()
		}
	}()

	// Listener side: accept each connection and close it right away.
	for n := 0; n < rounds; n++ {
		conn, _, err := ln.Accept()
		if err != nil {
			t.Fatal(err)
		}
		conn.Close()
	}
}
// TestWriteZero writes a nil payload 100 times over a raw websocket connection
// and expects every write to report zero bytes without error. After the dialer
// closes the connection, the accepting side expects its read to return zero
// bytes and io.EOF.
func TestWriteZero(t *testing.T) {
	const writes = 100

	_, upgrader := newUpgrader(t)
	tpt, err := New(upgrader, &network.NullResourceManager{}, nil)
	if err != nil {
		t.Fatal(err)
	}

	ln, err := tpt.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()

	var empty []byte

	go func() {
		conn, err := tpt.maDial(context.Background(), ln.Multiaddr(), &network.NullScope{})
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()

		for k := 0; k < writes; k++ {
			written, err := conn.Write(empty)
			if written != 0 {
				t.Errorf("expected to write 0 bytes, wrote %d", written)
			}
			if err != nil {
				t.Error(err)
				return
			}
		}
	}()

	conn, _, err := ln.Accept()
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	scratch := make([]byte, 100)
	got, err := conn.Read(scratch)
	if got != 0 {
		t.Errorf("read %d bytes, expected 0", got)
	}
	if err != io.EOF {
		t.Errorf("expected EOF, got err: %s", err)
	}
}
func TestResolveMultiaddr(t *testing.T) {
// map[unresolved]resolved
testCases := map[string]string{
"/dns/example.com/tcp/1234/wss": "/dns/example.com/tcp/1234/tls/sni/example.com/ws",
"/dns4/example.com/tcp/1234/wss": "/dns4/example.com/tcp/1234/tls/sni/example.com/ws",
"/dns6/example.com/tcp/1234/wss": "/dns6/example.com/tcp/1234/tls/sni/example.com/ws",
"/dnsaddr/example.com/tcp/1234/wss": "/dnsaddr/example.com/tcp/1234/wss",
"/dns4/example.com/tcp/1234/tls/ws": "/dns4/example.com/tcp/1234/tls/sni/example.com/ws",
"/dns6/example.com/tcp/1234/tls/ws": "/dns6/example.com/tcp/1234/tls/sni/example.com/ws",
"/dnsaddr/example.com/tcp/1234/tls/ws": "/dnsaddr/example.com/tcp/1234/tls/ws",
}
for unresolved, expectedMA := range testCases {
t.Run(unresolved, func(t *testing.T) {
m1 := ma.StringCast(unresolved)
wsTpt := WebsocketTransport{}
ctx := context.Background()
addrs, err := wsTpt.Resolve(ctx, m1)
require.NoError(t, err)
require.Len(t, addrs, 1)
require.Equal(t, expectedMA, addrs[0].String())
})
}
}
// TestSocksProxy checks that websocket dials go through the proxy configured
// on gorilla's default dialer, for plain, TLS, and TLS+SNI addresses.
//
// A minimal fake proxy accepts one connection, validates the SOCKS5 greeting,
// acknowledges it, and then hangs up. The dial is therefore expected to fail
// while reading the proxy's connect reply, which proves the proxy was used.
func TestSocksProxy(t *testing.T) {
	testCases := []string{
		"/ip4/1.2.3.4/tcp/1/ws",                      // No TLS
		"/ip4/1.2.3.4/tcp/1/tls/ws",                  // TLS no SNI
		"/ip4/1.2.3.4/tcp/1/tls/sni/example.com/ws", // TLS with an SNI
	}
	for _, tc := range testCases {
		t.Run(tc, func(t *testing.T) {
			proxyServer, err := net.Listen("tcp", "127.0.0.1:0")
			require.NoError(t, err)
			// Also close from the test goroutine so that the proxy goroutine
			// is unblocked from Accept if the test aborts early.
			defer proxyServer.Close()

			proxyServerErr := make(chan error, 1)
			go func() {
				defer proxyServer.Close()
				c, err := proxyServer.Accept()
				if err != nil {
					proxyServerErr <- err
					return
				}
				defer c.Close()

				greeting := make([]byte, 3)
				if _, err = io.ReadFull(c, greeting); err != nil {
					proxyServerErr <- err
					return
				}
				// Handshake a SOCKS5 client: https://www.rfc-editor.org/rfc/rfc1928.html#section-3
				// Expect version 5, one method offered, "no authentication".
				if !bytes.Equal([]byte{0x05, 0x01, 0x00}, greeting) {
					// Report a real error; err is nil at this point.
					proxyServerErr <- fmt.Errorf("expected SOCKS5 connect request, got % x", greeting)
					return
				}
				// Select "no authentication", then drop the connection.
				if _, err = c.Write([]byte{0x05, 0x00}); err != nil {
					proxyServerErr <- err
					return
				}
				proxyServerErr <- nil
			}()

			// Temporarily route gorilla's default dialer through the proxy.
			orig := gws.DefaultDialer.Proxy
			defer func() { gws.DefaultDialer.Proxy = orig }()
			proxyURL, err := url.Parse("socks5://" + proxyServer.Addr().String())
			require.NoError(t, err)
			gws.DefaultDialer.Proxy = http.ProxyURL(proxyURL)

			tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA
			_, u := newSecureUpgrader(t)
			tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig))
			require.NoError(t, err)

			// This can be any ws or wss address. We aren't actually going to dial it.
			maToDial := ma.StringCast(tc)
			_, err = tpt.Dial(context.Background(), maToDial, "")
			require.ErrorContains(t, err, "failed to read connect reply from SOCKS5 proxy", "This should error as we don't have a real socks server")

			// Surface any error seen by the fake proxy; give up silently
			// after a second if it never reports.
			select {
			case <-time.After(1 * time.Second):
			case err := <-proxyServerErr:
				if err != nil {
					t.Fatal(err)
				}
			}
		})
	}
}
// TestListenerAddr checks that a listener's Addr is rendered as a URL with the
// scheme matching its address: ws:// for a ws listener and wss:// for a wss
// listener, followed by 127.0.0.1 and a port.
func TestListenerAddr(t *testing.T) {
	_, upgrader := newUpgrader(t)
	tpt, err := New(upgrader, &network.NullResourceManager{}, nil, WithTLSConfig(generateTLSConfig(t)))
	require.NoError(t, err)

	wsLn, err := tpt.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
	require.NoError(t, err)
	defer wsLn.Close()
	require.Regexp(t, `^ws://127\.0\.0\.1:[\d]+$`, wsLn.Addr().String())

	wssLn, err := tpt.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/wss"))
	require.NoError(t, err)
	defer wssLn.Close()
	require.Regexp(t, `^wss://127\.0\.0\.1:[\d]+$`, wssLn.Addr().String())
}
// TestHandshakeTimeout checks the listener-side handshake timeout configured
// with WithHandshakeTimeout, for both ws and wss listeners.
//
// Each subtest first dials with a client that proceeds immediately and expects
// success, then dials with a client that sleeps for twice the handshake
// timeout after the TCP connection is established and expects that dial to
// fail.
func TestHandshakeTimeout(t *testing.T) {
	handshakeTimeout := 200 * time.Millisecond
	_, upgrader := newUpgrader(t)
	tpt, err := New(upgrader, &network.NullResourceManager{}, nil, WithHandshakeTimeout(handshakeTimeout), WithTLSConfig(generateTLSConfig(t)))
	require.NoError(t, err)

	// newDialer builds a gorilla dialer that stalls for delay after opening
	// the TCP connection, before the connection is handed back to gorilla.
	newDialer := func(delay time.Duration) gws.Dialer {
		return gws.Dialer{
			HandshakeTimeout: 10 * handshakeTimeout,
			TLSClientConfig:  &tls.Config{InsecureSkipVerify: true},
			NetDial: func(_, addr string) (net.Conn, error) {
				tcpConn, err := net.Dial("tcp", addr)
				if !assert.NoError(t, err) {
					return nil, err
				}
				if delay > 0 {
					// wait to simulate a slow handshake
					time.Sleep(delay)
				}
				return tcpConn, nil
			},
		}
	}
	fastWSDialer := newDialer(0)
	slowWSDialer := newDialer(2 * handshakeTimeout)

	t.Run("ws", func(t *testing.T) {
		// test the gatedMaListener as we're interested in the websocket handshake timeout and not the upgrader steps.
		ln, err := tpt.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
		require.NoError(t, err)
		defer ln.Close()

		fastCtx, cancelFast := context.WithTimeout(context.Background(), 10*handshakeTimeout)
		defer cancelFast()
		conn, resp, err := fastWSDialer.DialContext(fastCtx, ln.Addr().String(), nil)
		if !assert.NoError(t, err) {
			return
		}
		conn.Close()
		resp.Body.Close()

		slowCtx, cancelSlow := context.WithTimeout(context.Background(), 10*handshakeTimeout)
		defer cancelSlow()
		conn, resp, err = slowWSDialer.DialContext(slowCtx, ln.Addr().String(), nil)
		if err == nil {
			conn.Close()
			resp.Body.Close()
			t.Fatal("should error as the handshake will time out")
		}
	})

	t.Run("wss", func(t *testing.T) {
		// test the gatedMaListener as we're interested in the websocket handshake timeout and not the upgrader steps.
		ln, err := tpt.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/wss"))
		require.NoError(t, err)
		defer ln.Close()

		// Test that the normal dial works fine
		fastCtx, cancelFast := context.WithTimeout(context.Background(), 10*handshakeTimeout)
		defer cancelFast()
		wsConn, resp, err := fastWSDialer.DialContext(fastCtx, ln.Addr().String(), nil)
		require.NoError(t, err)
		wsConn.Close()
		resp.Body.Close()

		slowCtx, cancelSlow := context.WithTimeout(context.Background(), 10*handshakeTimeout)
		defer cancelSlow()
		wsConn, resp, err = slowWSDialer.DialContext(slowCtx, ln.Addr().String(), nil)
		if err == nil {
			wsConn.Close()
			resp.Body.Close()
			t.Fatal("websocket handshake should have timed out")
		}
	})
}