mirror of
https://github.com/libp2p/go-libp2p.git
synced 2025-09-26 20:21:26 +08:00

This introduces a new GatedMaListener type which gates conns accepted from a manet.Listener with a gater and creates the rcmgr scope for it. Explicitly passing the scope allows for many guardrails that the previous interface assertion didn't. This breaks the previous responsibility of the upgradeListener method into two, one gating the connection initially, and the other upgrading the connection with a security and muxer selection. This split makes it easy to gate the connection with the resource manager as early as possible. This is especially true for websocket because we want to gate the connection just after the TCP connection is established, and not after the tls handshake + websocket upgrade is completed.
743 lines
22 KiB
Go
743 lines
22 KiB
Go
package websocket
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math/big"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
gws "github.com/gorilla/websocket"
|
|
"github.com/libp2p/go-libp2p/core/crypto"
|
|
"github.com/libp2p/go-libp2p/core/network"
|
|
"github.com/libp2p/go-libp2p/core/peer"
|
|
"github.com/libp2p/go-libp2p/core/sec"
|
|
"github.com/libp2p/go-libp2p/core/sec/insecure"
|
|
"github.com/libp2p/go-libp2p/core/test"
|
|
"github.com/libp2p/go-libp2p/core/transport"
|
|
"github.com/libp2p/go-libp2p/p2p/muxer/yamux"
|
|
tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader"
|
|
"github.com/libp2p/go-libp2p/p2p/security/noise"
|
|
ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite"
|
|
|
|
ma "github.com/multiformats/go-multiaddr"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func newUpgrader(t *testing.T) (peer.ID, transport.Upgrader) {
|
|
t.Helper()
|
|
id, m := newInsecureMuxer(t)
|
|
u, err := tptu.New(m, []tptu.StreamMuxer{{ID: "/yamux", Muxer: yamux.DefaultTransport}}, nil, nil, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return id, u
|
|
}
|
|
|
|
func newSecureUpgrader(t *testing.T) (peer.ID, transport.Upgrader) {
|
|
t.Helper()
|
|
id, m := newSecureMuxer(t)
|
|
u, err := tptu.New(m, []tptu.StreamMuxer{{ID: "/yamux", Muxer: yamux.DefaultTransport}}, nil, nil, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return id, u
|
|
}
|
|
|
|
func newInsecureMuxer(t *testing.T) (peer.ID, []sec.SecureTransport) {
|
|
t.Helper()
|
|
priv, _, err := test.RandTestKeyPair(crypto.Ed25519, 256)
|
|
require.NoError(t, err)
|
|
id, err := peer.IDFromPrivateKey(priv)
|
|
require.NoError(t, err)
|
|
return id, []sec.SecureTransport{insecure.NewWithIdentity(insecure.ID, id, priv)}
|
|
}
|
|
|
|
func newSecureMuxer(t *testing.T) (peer.ID, []sec.SecureTransport) {
|
|
t.Helper()
|
|
priv, _, err := test.RandTestKeyPair(crypto.Ed25519, 256)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
id, err := peer.IDFromPrivateKey(priv)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
noiseTpt, err := noise.New(noise.ID, priv, nil)
|
|
require.NoError(t, err)
|
|
return id, []sec.SecureTransport{noiseTpt}
|
|
}
|
|
|
|
func lastComponent(t *testing.T, a ma.Multiaddr) *ma.Component {
|
|
t.Helper()
|
|
_, wscomponent := ma.SplitLast(a)
|
|
require.NotNil(t, wscomponent)
|
|
if wscomponent.Equal(wsComponent) {
|
|
return wsComponent
|
|
}
|
|
if wscomponent.Equal(wssComponent) {
|
|
return wssComponent
|
|
}
|
|
t.Fatal("expected a ws or wss component")
|
|
return nil
|
|
}
|
|
|
|
func generateTLSConfig(t *testing.T) *tls.Config {
|
|
t.Helper()
|
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
require.NoError(t, err)
|
|
tmpl := &x509.Certificate{
|
|
SerialNumber: big.NewInt(1),
|
|
Subject: pkix.Name{},
|
|
SignatureAlgorithm: x509.SHA256WithRSA,
|
|
NotBefore: time.Now(),
|
|
NotAfter: time.Now().Add(time.Hour), // valid for an hour
|
|
BasicConstraintsValid: true,
|
|
}
|
|
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
|
|
require.NoError(t, err)
|
|
return &tls.Config{
|
|
Certificates: []tls.Certificate{{
|
|
PrivateKey: priv,
|
|
Certificate: [][]byte{certDER},
|
|
}},
|
|
}
|
|
}
|
|
|
|
func TestCanDial(t *testing.T) {
|
|
d := &WebsocketTransport{}
|
|
if !d.CanDial(ma.StringCast("/ip4/127.0.0.1/tcp/5555/ws")) {
|
|
t.Fatal("expected to match websocket maddr, but did not")
|
|
}
|
|
if !d.CanDial(ma.StringCast("/ip4/127.0.0.1/tcp/5555/wss")) {
|
|
t.Fatal("expected to match secure websocket maddr, but did not")
|
|
}
|
|
if d.CanDial(ma.StringCast("/ip4/127.0.0.1/tcp/5555")) {
|
|
t.Fatal("expected to not match tcp maddr, but did")
|
|
}
|
|
if !d.CanDial(ma.StringCast("/ip4/127.0.0.1/tcp/5555/tls/ws")) {
|
|
t.Fatal("expected to match secure websocket maddr, but did not")
|
|
}
|
|
if !d.CanDial(ma.StringCast("/ip4/127.0.0.1/tcp/5555/tls/sni/example.com/ws")) {
|
|
t.Fatal("expected to match secure websocket maddr with sni, but did not")
|
|
}
|
|
if !d.CanDial(ma.StringCast("/dns4/example.com/tcp/5555/tls/sni/example.com/ws")) {
|
|
t.Fatal("expected to match secure websocket maddr with sni, but did not")
|
|
}
|
|
if !d.CanDial(ma.StringCast("/dnsaddr/example.com/tcp/5555/tls/sni/example.com/ws")) {
|
|
t.Fatal("expected to match secure websocket maddr with sni, but did not")
|
|
}
|
|
}
|
|
|
|
// testWSSServer returns a client hello info
|
|
func testWSSServer(t *testing.T, listenAddr ma.Multiaddr) (ma.Multiaddr, peer.ID, chan error) {
|
|
errChan := make(chan error, 1)
|
|
|
|
ip := net.ParseIP("::")
|
|
tlsConf := getTLSConf(t, ip, time.Now(), time.Now().Add(time.Hour))
|
|
tlsConf.GetConfigForClient = func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
|
|
if chi.ServerName != "example.com" {
|
|
errChan <- fmt.Errorf("didn't get the expected sni")
|
|
}
|
|
return tlsConf, nil
|
|
}
|
|
|
|
id, u := newSecureUpgrader(t)
|
|
tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSConfig(tlsConf))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
l, err := tpt.Listen(listenAddr)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
l.Close()
|
|
})
|
|
go func() {
|
|
conn, err := l.Accept()
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error in accepting conn: %w", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
strm, err := conn.AcceptStream()
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("error in accepting stream: %w", err)
|
|
return
|
|
}
|
|
defer strm.Close()
|
|
close(errChan)
|
|
}()
|
|
|
|
return l.Multiaddr(), id, errChan
|
|
}
|
|
|
|
func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config {
|
|
t.Helper()
|
|
certTempl := &x509.Certificate{
|
|
SerialNumber: big.NewInt(1234),
|
|
Subject: pkix.Name{Organization: []string{"websocket"}},
|
|
NotBefore: start,
|
|
NotAfter: end,
|
|
IsCA: true,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
|
BasicConstraintsValid: true,
|
|
IPAddresses: []net.IP{ip},
|
|
}
|
|
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
require.NoError(t, err)
|
|
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &priv.PublicKey, priv)
|
|
require.NoError(t, err)
|
|
cert, err := x509.ParseCertificate(caBytes)
|
|
require.NoError(t, err)
|
|
return &tls.Config{
|
|
Certificates: []tls.Certificate{{
|
|
Certificate: [][]byte{cert.Raw},
|
|
PrivateKey: priv,
|
|
Leaf: cert,
|
|
}},
|
|
}
|
|
}
|
|
|
|
func TestHostHeaderWss(t *testing.T) {
|
|
server := &http.Server{}
|
|
l, err := net.Listen("tcp", ":0")
|
|
require.NoError(t, err)
|
|
defer server.Close()
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
defer close(errChan)
|
|
if !strings.Contains(r.Host, "example.com") {
|
|
errChan <- errors.New("Didn't see host header")
|
|
}
|
|
w.WriteHeader(http.StatusNotFound)
|
|
})
|
|
server.TLSConfig = getTLSConf(t, net.ParseIP("127.0.0.1"), time.Now(), time.Now().Add(time.Hour))
|
|
server.ServeTLS(l, "", "")
|
|
}()
|
|
|
|
_, port, err := net.SplitHostPort(l.Addr().String())
|
|
require.NoError(t, err)
|
|
serverMA := ma.StringCast("/ip4/127.0.0.1/tcp/" + port + "/tls/sni/example.com/ws")
|
|
|
|
tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA
|
|
_, u := newSecureUpgrader(t)
|
|
tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig))
|
|
require.NoError(t, err)
|
|
|
|
masToDial, err := tpt.Resolve(context.Background(), serverMA)
|
|
require.NoError(t, err)
|
|
|
|
_, err = tpt.Dial(context.Background(), masToDial[0], test.RandPeerIDFatal(t))
|
|
require.Error(t, err)
|
|
|
|
err = <-errChan
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestDialWss(t *testing.T) {
|
|
serverMA, rid, errChan := testWSSServer(t, ma.StringCast("/ip4/127.0.0.1/tcp/0/tls/sni/example.com/ws"))
|
|
require.Contains(t, serverMA.String(), "tls")
|
|
|
|
tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA
|
|
_, u := newSecureUpgrader(t)
|
|
tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig))
|
|
require.NoError(t, err)
|
|
|
|
masToDial, err := tpt.Resolve(context.Background(), serverMA)
|
|
require.NoError(t, err)
|
|
|
|
conn, err := tpt.Dial(context.Background(), masToDial[0], rid)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
|
|
stream, err := conn.OpenStream(context.Background())
|
|
require.NoError(t, err)
|
|
defer stream.Close()
|
|
|
|
err = <-errChan
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestDialWssNoClientCert(t *testing.T) {
|
|
serverMA, rid, _ := testWSSServer(t, ma.StringCast("/ip4/127.0.0.1/tcp/0/tls/sni/example.com/ws"))
|
|
require.Contains(t, serverMA.String(), "tls")
|
|
|
|
_, u := newSecureUpgrader(t)
|
|
tpt, err := New(u, &network.NullResourceManager{}, nil)
|
|
require.NoError(t, err)
|
|
|
|
masToDial, err := tpt.Resolve(context.Background(), serverMA)
|
|
require.NoError(t, err)
|
|
|
|
_, err = tpt.Dial(context.Background(), masToDial[0], rid)
|
|
require.Error(t, err)
|
|
|
|
// The server doesn't have a signed certificate
|
|
require.Contains(t, err.Error(), "x509")
|
|
}
|
|
|
|
func TestWebsocketTransport(t *testing.T) {
|
|
t.Run("/ws", func(t *testing.T) {
|
|
peerA, ua := newUpgrader(t)
|
|
ta, err := New(ua, nil, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
peerB, ub := newUpgrader(t)
|
|
tb, err := New(ub, nil, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/ws", peerA)
|
|
ttransport.SubtestTransport(t, tb, ta, "/ip4/127.0.0.1/tcp/0/ws", peerB)
|
|
|
|
})
|
|
t.Run("/wss", func(t *testing.T) {
|
|
peerA, ua := newUpgrader(t)
|
|
tca := generateTLSConfig(t)
|
|
ta, err := New(ua, nil, nil, WithTLSConfig(tca), WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
peerB, ub := newUpgrader(t)
|
|
tcb := generateTLSConfig(t)
|
|
tb, err := New(ub, nil, nil, WithTLSConfig(tcb), WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/wss", peerA)
|
|
ttransport.SubtestTransport(t, tb, ta, "/ip4/127.0.0.1/tcp/0/ws", peerB)
|
|
})
|
|
}
|
|
|
|
func isWSS(addr ma.Multiaddr) bool {
|
|
if _, err := addr.ValueForProtocol(ma.P_WSS); err == nil {
|
|
return true
|
|
}
|
|
if _, err := addr.ValueForProtocol(ma.P_WS); err == nil {
|
|
return false
|
|
}
|
|
panic("not a WebSocket address")
|
|
}
|
|
|
|
func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) {
|
|
var opts []Option
|
|
var tlsConf *tls.Config
|
|
if secure {
|
|
tlsConf = generateTLSConfig(t)
|
|
opts = append(opts, WithTLSConfig(tlsConf))
|
|
}
|
|
server, u := newUpgrader(t)
|
|
tpt, err := New(u, &network.NullResourceManager{}, nil, opts...)
|
|
require.NoError(t, err)
|
|
l, err := tpt.Listen(laddr)
|
|
require.NoError(t, err)
|
|
if secure {
|
|
require.Contains(t, l.Multiaddr().String(), "tls")
|
|
} else {
|
|
require.Equal(t, lastComponent(t, l.Multiaddr()).String(), wsComponent.String())
|
|
}
|
|
defer l.Close()
|
|
|
|
msg := []byte("HELLO WORLD")
|
|
|
|
go func() {
|
|
var opts []Option
|
|
if secure {
|
|
opts = append(opts, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
|
|
}
|
|
_, u := newUpgrader(t)
|
|
tpt, err := New(u, &network.NullResourceManager{}, nil, opts...)
|
|
require.NoError(t, err)
|
|
c, err := tpt.Dial(context.Background(), l.Multiaddr(), server)
|
|
require.NoError(t, err)
|
|
require.Equal(t, secure, isWSS(c.LocalMultiaddr()))
|
|
require.Equal(t, secure, isWSS(c.RemoteMultiaddr()))
|
|
str, err := c.OpenStream(context.Background())
|
|
require.NoError(t, err)
|
|
defer str.Close()
|
|
_, err = str.Write(msg)
|
|
require.NoError(t, err)
|
|
}()
|
|
|
|
c, err := l.Accept()
|
|
require.NoError(t, err)
|
|
defer c.Close()
|
|
require.Equal(t, secure, isWSS(c.LocalMultiaddr()))
|
|
require.Equal(t, secure, isWSS(c.RemoteMultiaddr()))
|
|
str, err := c.AcceptStream()
|
|
require.NoError(t, err)
|
|
defer str.Close()
|
|
|
|
out, err := io.ReadAll(str)
|
|
require.NoError(t, err)
|
|
require.Equal(t, out, msg, "got wrong message")
|
|
}
|
|
|
|
func TestWebsocketConnection(t *testing.T) {
|
|
t.Run("unencrypted", func(t *testing.T) {
|
|
connectAndExchangeData(t, ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"), false)
|
|
})
|
|
t.Run("encrypted", func(t *testing.T) {
|
|
connectAndExchangeData(t, ma.StringCast("/ip4/127.0.0.1/tcp/0/wss"), true)
|
|
})
|
|
}
|
|
|
|
func TestWebsocketListenSecureFailWithoutTLSConfig(t *testing.T) {
|
|
_, u := newUpgrader(t)
|
|
tpt, err := New(u, &network.NullResourceManager{}, nil)
|
|
require.NoError(t, err)
|
|
addr := ma.StringCast("/ip4/127.0.0.1/tcp/0/wss")
|
|
_, err = tpt.Listen(addr)
|
|
require.EqualError(t, err, fmt.Sprintf("cannot listen on wss address %s without a tls.Config", addr))
|
|
}
|
|
|
|
func TestWebsocketListenSecureAndInsecure(t *testing.T) {
|
|
serverID, serverUpgrader := newUpgrader(t)
|
|
server, err := New(serverUpgrader, &network.NullResourceManager{}, nil, WithTLSConfig(generateTLSConfig(t)))
|
|
require.NoError(t, err)
|
|
|
|
lnInsecure, err := server.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
|
|
require.NoError(t, err)
|
|
lnSecure, err := server.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/wss"))
|
|
require.NoError(t, err)
|
|
|
|
t.Run("insecure", func(t *testing.T) {
|
|
_, clientUpgrader := newUpgrader(t)
|
|
client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
|
|
require.NoError(t, err)
|
|
|
|
// dialing the insecure address should succeed
|
|
conn, err := client.Dial(context.Background(), lnInsecure.Multiaddr(), serverID)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
require.Equal(t, lastComponent(t, conn.RemoteMultiaddr()).String(), wsComponent.String())
|
|
require.Equal(t, lastComponent(t, conn.LocalMultiaddr()).String(), wsComponent.String())
|
|
|
|
// dialing the secure address should fail
|
|
_, err = client.Dial(context.Background(), lnSecure.Multiaddr(), serverID)
|
|
require.NoError(t, err)
|
|
})
|
|
|
|
t.Run("secure", func(t *testing.T) {
|
|
_, clientUpgrader := newUpgrader(t)
|
|
client, err := New(clientUpgrader, &network.NullResourceManager{}, nil, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
|
|
require.NoError(t, err)
|
|
|
|
// dialing the insecure address should succeed
|
|
conn, err := client.Dial(context.Background(), lnSecure.Multiaddr(), serverID)
|
|
require.NoError(t, err)
|
|
defer conn.Close()
|
|
require.Equal(t, lastComponent(t, conn.RemoteMultiaddr()).String(), wssComponent.String())
|
|
require.Equal(t, lastComponent(t, conn.LocalMultiaddr()).String(), wssComponent.String())
|
|
|
|
// dialing the insecure address should fail
|
|
_, err = client.Dial(context.Background(), lnInsecure.Multiaddr(), serverID)
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestConcurrentClose(t *testing.T) {
|
|
_, u := newUpgrader(t)
|
|
tpt, err := New(u, &network.NullResourceManager{}, nil)
|
|
require.NoError(t, err)
|
|
l, err := tpt.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer l.Close()
|
|
|
|
msg := []byte("HELLO WORLD")
|
|
|
|
go func() {
|
|
for i := 0; i < 100; i++ {
|
|
c, err := tpt.maDial(context.Background(), l.Multiaddr(), &network.NullScope{})
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
_, _ = c.Write(msg)
|
|
}()
|
|
go func() {
|
|
_ = c.Close()
|
|
}()
|
|
}
|
|
}()
|
|
|
|
for i := 0; i < 100; i++ {
|
|
c, _, err := l.Accept()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
c.Close()
|
|
}
|
|
}
|
|
|
|
func TestWriteZero(t *testing.T) {
|
|
_, u := newUpgrader(t)
|
|
tpt, err := New(u, &network.NullResourceManager{}, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
l, err := tpt.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer l.Close()
|
|
|
|
msg := []byte(nil)
|
|
|
|
go func() {
|
|
c, err := tpt.maDial(context.Background(), l.Multiaddr(), &network.NullScope{})
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
defer c.Close()
|
|
|
|
for i := 0; i < 100; i++ {
|
|
n, err := c.Write(msg)
|
|
if n != 0 {
|
|
t.Errorf("expected to write 0 bytes, wrote %d", n)
|
|
}
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
c, _, err := l.Accept()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer c.Close()
|
|
buf := make([]byte, 100)
|
|
n, err := c.Read(buf)
|
|
if n != 0 {
|
|
t.Errorf("read %d bytes, expected 0", n)
|
|
}
|
|
if err != io.EOF {
|
|
t.Errorf("expected EOF, got err: %s", err)
|
|
}
|
|
}
|
|
|
|
func TestResolveMultiaddr(t *testing.T) {
|
|
// map[unresolved]resolved
|
|
testCases := map[string]string{
|
|
"/dns/example.com/tcp/1234/wss": "/dns/example.com/tcp/1234/tls/sni/example.com/ws",
|
|
"/dns4/example.com/tcp/1234/wss": "/dns4/example.com/tcp/1234/tls/sni/example.com/ws",
|
|
"/dns6/example.com/tcp/1234/wss": "/dns6/example.com/tcp/1234/tls/sni/example.com/ws",
|
|
"/dnsaddr/example.com/tcp/1234/wss": "/dnsaddr/example.com/tcp/1234/wss",
|
|
"/dns4/example.com/tcp/1234/tls/ws": "/dns4/example.com/tcp/1234/tls/sni/example.com/ws",
|
|
"/dns6/example.com/tcp/1234/tls/ws": "/dns6/example.com/tcp/1234/tls/sni/example.com/ws",
|
|
"/dnsaddr/example.com/tcp/1234/tls/ws": "/dnsaddr/example.com/tcp/1234/tls/ws",
|
|
}
|
|
|
|
for unresolved, expectedMA := range testCases {
|
|
t.Run(unresolved, func(t *testing.T) {
|
|
|
|
m1 := ma.StringCast(unresolved)
|
|
wsTpt := WebsocketTransport{}
|
|
ctx := context.Background()
|
|
|
|
addrs, err := wsTpt.Resolve(ctx, m1)
|
|
require.NoError(t, err)
|
|
require.Len(t, addrs, 1)
|
|
|
|
require.Equal(t, expectedMA, addrs[0].String())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSocksProxy(t *testing.T) {
|
|
testCases := []string{
|
|
"/ip4/1.2.3.4/tcp/1/ws", // No TLS
|
|
"/ip4/1.2.3.4/tcp/1/tls/ws", // TLS no SNI
|
|
"/ip4/1.2.3.4/tcp/1/tls/sni/example.com/ws", // TLS with an SNI
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc, func(t *testing.T) {
|
|
proxyServer, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
proxyServerErr := make(chan error, 1)
|
|
|
|
go func() {
|
|
defer proxyServer.Close()
|
|
c, err := proxyServer.Accept()
|
|
if err != nil {
|
|
proxyServerErr <- err
|
|
return
|
|
}
|
|
defer c.Close()
|
|
|
|
req := [32]byte{}
|
|
_, err = io.ReadFull(c, req[:3])
|
|
if err != nil {
|
|
proxyServerErr <- err
|
|
return
|
|
}
|
|
|
|
// Handshake a SOCKS5 client: https://www.rfc-editor.org/rfc/rfc1928.html#section-3
|
|
if !bytes.Equal([]byte{0x05, 0x01, 0x00}, req[:3]) {
|
|
t.Log("expected SOCKS5 connect request")
|
|
proxyServerErr <- err
|
|
return
|
|
}
|
|
_, err = c.Write([]byte{0x05, 0x00})
|
|
if err != nil {
|
|
proxyServerErr <- err
|
|
return
|
|
}
|
|
|
|
proxyServerErr <- nil
|
|
}()
|
|
|
|
orig := gws.DefaultDialer.Proxy
|
|
defer func() { gws.DefaultDialer.Proxy = orig }()
|
|
|
|
proxyUrl, err := url.Parse("socks5://" + proxyServer.Addr().String())
|
|
require.NoError(t, err)
|
|
gws.DefaultDialer.Proxy = http.ProxyURL(proxyUrl)
|
|
|
|
tlsConfig := &tls.Config{InsecureSkipVerify: true} // Our test server doesn't have a cert signed by a CA
|
|
_, u := newSecureUpgrader(t)
|
|
tpt, err := New(u, &network.NullResourceManager{}, nil, WithTLSClientConfig(tlsConfig))
|
|
require.NoError(t, err)
|
|
|
|
// This can be any wss address. We aren't actually going to dial it.
|
|
maToDial := ma.StringCast(tc)
|
|
_, err = tpt.Dial(context.Background(), maToDial, "")
|
|
require.ErrorContains(t, err, "failed to read connect reply from SOCKS5 proxy", "This should error as we don't have a real socks server")
|
|
|
|
select {
|
|
case <-time.After(1 * time.Second):
|
|
case err := <-proxyServerErr:
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestListenerAddr(t *testing.T) {
|
|
_, upgrader := newUpgrader(t)
|
|
transport, err := New(upgrader, &network.NullResourceManager{}, nil, WithTLSConfig(generateTLSConfig(t)))
|
|
require.NoError(t, err)
|
|
l1, err := transport.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
|
|
require.NoError(t, err)
|
|
defer l1.Close()
|
|
require.Regexp(t, `^ws://127\.0\.0\.1:[\d]+$`, l1.Addr().String())
|
|
l2, err := transport.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/wss"))
|
|
require.NoError(t, err)
|
|
defer l2.Close()
|
|
require.Regexp(t, `^wss://127\.0\.0\.1:[\d]+$`, l2.Addr().String())
|
|
}
|
|
func TestHandshakeTimeout(t *testing.T) {
|
|
handshakeTimeout := 200 * time.Millisecond
|
|
_, upgrader := newUpgrader(t)
|
|
tlsconf := generateTLSConfig(t)
|
|
transport, err := New(upgrader, &network.NullResourceManager{}, nil, WithHandshakeTimeout(handshakeTimeout), WithTLSConfig(tlsconf))
|
|
require.NoError(t, err)
|
|
|
|
fastWSDialer := gws.Dialer{
|
|
HandshakeTimeout: 10 * handshakeTimeout,
|
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
|
NetDial: func(network, addr string) (net.Conn, error) {
|
|
tcpConn, err := net.Dial("tcp", addr)
|
|
if !assert.NoError(t, err) {
|
|
return nil, err
|
|
}
|
|
return tcpConn, nil
|
|
},
|
|
}
|
|
|
|
slowWSDialer := gws.Dialer{
|
|
HandshakeTimeout: 10 * handshakeTimeout,
|
|
NetDial: func(network, addr string) (net.Conn, error) {
|
|
tcpConn, err := net.Dial("tcp", addr)
|
|
if !assert.NoError(t, err) {
|
|
return nil, err
|
|
}
|
|
// wait to simulate a slow handshake
|
|
time.Sleep(2 * handshakeTimeout)
|
|
return tcpConn, nil
|
|
},
|
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
|
}
|
|
t.Run("ws", func(t *testing.T) {
|
|
// test the gatedMaListener as we're interested in the websocket handshake timeout and not the upgrader steps.
|
|
wsListener, err := transport.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
|
|
require.NoError(t, err)
|
|
defer wsListener.Close()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*handshakeTimeout)
|
|
defer cancel()
|
|
conn, resp, err := fastWSDialer.DialContext(ctx, wsListener.Addr().String(), nil)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
conn.Close()
|
|
resp.Body.Close()
|
|
|
|
ctx, cancel = context.WithTimeout(context.Background(), 10*handshakeTimeout)
|
|
defer cancel()
|
|
conn, resp, err = slowWSDialer.DialContext(ctx, wsListener.Addr().String(), nil)
|
|
if err == nil {
|
|
conn.Close()
|
|
resp.Body.Close()
|
|
t.Fatal("should error as the handshake will time out")
|
|
}
|
|
})
|
|
|
|
t.Run("wss", func(t *testing.T) {
|
|
// test the gatedMaListener as we're interested in the websocket handshake timeout and not the upgrader steps.
|
|
wsListener, err := transport.gatedMaListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/wss"))
|
|
require.NoError(t, err)
|
|
defer wsListener.Close()
|
|
|
|
// Test that the normal dial works fine
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*handshakeTimeout)
|
|
defer cancel()
|
|
wsConn, resp, err := fastWSDialer.DialContext(ctx, wsListener.Addr().String(), nil)
|
|
require.NoError(t, err)
|
|
wsConn.Close()
|
|
resp.Body.Close()
|
|
|
|
ctx, cancel = context.WithTimeout(context.Background(), 10*handshakeTimeout)
|
|
defer cancel()
|
|
wsConn, resp, err = slowWSDialer.DialContext(ctx, wsListener.Addr().String(), nil)
|
|
if err == nil {
|
|
wsConn.Close()
|
|
resp.Body.Close()
|
|
t.Fatal("websocket handshake should have timed out")
|
|
}
|
|
})
|
|
}
|