Files
go-libp2p/p2p/transport/websocket/websocket_test.go
Marten Seemann 94ce988fbb use yamux instead of mplex in tests (#1456)
* use yamux in TCP transport tests

* use yamux in websocket tests
2022-04-26 03:19:06 -07:00

351 lines
9.1 KiB
Go

package websocket
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"io"
"io/ioutil"
"math/big"
"net"
"testing"
"time"
"github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/sec"
"github.com/libp2p/go-libp2p-core/sec/insecure"
"github.com/libp2p/go-libp2p-core/test"
"github.com/libp2p/go-libp2p-core/transport"
"github.com/libp2p/go-libp2p/p2p/muxer/yamux"
csms "github.com/libp2p/go-conn-security-multistream"
ttransport "github.com/libp2p/go-libp2p-testing/suites/transport"
tptu "github.com/libp2p/go-libp2p-transport-upgrader"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/require"
)
// newUpgrader creates a fresh peer identity and returns it along with a
// transport upgrader that secures connections via newSecureMuxer and
// multiplexes them with yamux.
func newUpgrader(t *testing.T) (peer.ID, transport.Upgrader) {
	t.Helper()
	self, secMuxer := newSecureMuxer(t)
	upgrader, err := tptu.New(secMuxer, yamux.DefaultTransport)
	if err != nil {
		t.Fatal(err)
	}
	return self, upgrader
}
// newSecureMuxer generates a random Ed25519 key pair, derives the peer ID from
// it, and returns that ID together with a security multistream muxer whose
// only registered transport is the insecure (plaintext) one for this identity.
func newSecureMuxer(t *testing.T) (peer.ID, sec.SecureMuxer) {
	t.Helper()
	privKey, _, err := test.RandTestKeyPair(crypto.Ed25519, 256)
	if err != nil {
		t.Fatal(err)
	}
	self, err := peer.IDFromPrivateKey(privKey)
	if err != nil {
		t.Fatal(err)
	}
	muxer := new(csms.SSMuxer)
	muxer.AddTransport(insecure.ID, insecure.NewWithIdentity(self, privKey))
	return self, muxer
}
// lastComponent inspects the final component of a and returns wsma or wssma
// when it equals one of them. Any other final component fails the test.
func lastComponent(t *testing.T, a ma.Multiaddr) ma.Multiaddr {
	t.Helper()
	_, last := ma.SplitLast(a)
	require.NotNil(t, last)
	switch {
	case last.Equal(wsma):
		return wsma
	case last.Equal(wssma):
		return wssma
	}
	t.Fatal("expected a ws or wss component")
	return nil
}
// generateTLSConfig returns a tls.Config carrying a freshly generated 2048-bit
// RSA key and a matching self-signed certificate that is valid for one hour.
func generateTLSConfig(t *testing.T) *tls.Config {
	t.Helper()
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(t, err)

	template := x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{},
		SignatureAlgorithm:    x509.SHA256WithRSA,
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(time.Hour),
		BasicConstraintsValid: true,
	}
	// Using the template as its own parent makes the certificate self-signed.
	der, err := x509.CreateCertificate(rand.Reader, &template, &template, key.Public(), key)
	require.NoError(t, err)

	cert := tls.Certificate{
		Certificate: [][]byte{der},
		PrivateKey:  key,
	}
	return &tls.Config{Certificates: []tls.Certificate{cert}}
}
// TestCanDial checks that the transport claims ws and wss multiaddrs and
// rejects a plain TCP multiaddr.
func TestCanDial(t *testing.T) {
	tpt := &WebsocketTransport{}
	cases := []struct {
		addr    string
		want    bool
		failMsg string
	}{
		{"/ip4/127.0.0.1/tcp/5555/ws", true, "expected to match websocket maddr, but did not"},
		{"/ip4/127.0.0.1/tcp/5555/wss", true, "expected to match secure websocket maddr, but did not"},
		{"/ip4/127.0.0.1/tcp/5555", false, "expected to not match tcp maddr, but did"},
	}
	for _, tc := range cases {
		if got := tpt.CanDial(ma.StringCast(tc.addr)); got != tc.want {
			t.Fatal(tc.failMsg)
		}
	}
}
// TestDialWss dials a public libp2p bootstrap node over secure websockets and
// opens a stream on the resulting connection. The test is skipped when the
// bootstrap host name cannot be resolved. That is taken as a sign that there is
// no internet access.
func TestDialWss(t *testing.T) {
	if _, err := net.LookupIP("nyc-1.bootstrap.libp2p.io"); err != nil {
		t.Skip("this test requires an internet connection and it seems like we currently don't have one")
	}
	raddr := ma.StringCast("/dns4/nyc-1.bootstrap.libp2p.io/tcp/443/wss")
	rid, err := peer.Decode("QmSoLueR4xBeUbY9WZ9xGUUxunbKWcrNFTDAadQJmocnWm")
	if err != nil {
		t.Fatal(err)
	}
	// Only the transport is under test, so the server certificate is not verified.
	tlsConfig := &tls.Config{InsecureSkipVerify: true}
	_, u := newUpgrader(t)
	tpt, err := New(u, network.NullResourceManager, WithTLSClientConfig(tlsConfig))
	if err != nil {
		t.Fatal(err)
	}
	// Bound the remote operations so an unresponsive host fails the test
	// instead of hanging the whole suite.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	conn, err := tpt.Dial(ctx, raddr, rid)
	if err != nil {
		t.Fatal(err)
	}
	// Close the connection so it does not outlive the test.
	defer conn.Close()
	stream, err := conn.OpenStream(ctx)
	if err != nil {
		t.Fatal(err)
	}
	defer stream.Close()
}
// TestWebsocketTransport runs the shared transport test suite against two
// websocket transports. It is currently skipped (see the linked issue).
func TestWebsocketTransport(t *testing.T) {
	t.Skip("This test is failing, see https://github.com/libp2p/go-ws-transport/issues/99")
	// Use the null resource manager explicitly, like every other test in this
	// file, rather than passing nil.
	_, ua := newUpgrader(t)
	ta, err := New(ua, network.NullResourceManager)
	if err != nil {
		t.Fatal(err)
	}
	_, ub := newUpgrader(t)
	tb, err := New(ub, network.NullResourceManager)
	if err != nil {
		t.Fatal(err)
	}
	ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/ws", "peerA")
}
// connectAndExchangeData listens on laddr, which is a ws address, or a wss
// address when secure is true. It dials the listener from a second transport,
// writes a message on a client-opened stream, and checks that the server reads
// the same bytes back.
//
// The dialer runs on its own goroutine. It reports failures through a channel
// instead of calling require, because t.FailNow may only be called from the
// goroutine running the test.
func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) {
	t.Helper()

	var serverOpts, clientOpts []Option
	if secure {
		serverOpts = append(serverOpts, WithTLSConfig(generateTLSConfig(t)))
		// The server certificate is self-signed, so the client skips verification.
		clientOpts = append(clientOpts, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
	}

	server, serverUpgrader := newUpgrader(t)
	serverTpt, err := New(serverUpgrader, network.NullResourceManager, serverOpts...)
	require.NoError(t, err)
	l, err := serverTpt.Listen(laddr)
	require.NoError(t, err)
	defer l.Close()
	if secure {
		require.Equal(t, wssma, lastComponent(t, l.Multiaddr()))
	} else {
		require.Equal(t, wsma, lastComponent(t, l.Multiaddr()))
	}

	// Build the client on the test goroutine. newUpgrader and require may
	// abort the test, and that is only allowed here.
	_, clientUpgrader := newUpgrader(t)
	clientTpt, err := New(clientUpgrader, network.NullResourceManager, clientOpts...)
	require.NoError(t, err)

	msg := []byte("HELLO WORLD")
	// readDone is closed once the server side is finished. This keeps the
	// client from tearing down its connection while data is still in flight.
	readDone := make(chan struct{})
	clientErr := make(chan error, 1)
	go func() {
		err := func() error {
			c, err := clientTpt.Dial(context.Background(), l.Multiaddr(), server)
			if err != nil {
				return err
			}
			defer c.Close()
			str, err := c.OpenStream(context.Background())
			if err != nil {
				return err
			}
			if _, err := str.Write(msg); err != nil {
				return err
			}
			// Closing the stream delivers EOF to the server's ReadAll.
			if err := str.Close(); err != nil {
				return err
			}
			<-readDone
			return nil
		}()
		if err != nil {
			// Unblock the server's Accept in case the client never connected.
			l.Close()
		}
		clientErr <- err
	}()
	// Defers run LIFO, so this runs before the deferred l.Close above. It
	// releases the client, waits for its goroutine to exit and reports any
	// error the client hit.
	defer func() {
		close(readDone)
		if err := <-clientErr; err != nil {
			t.Errorf("client: %v", err)
		}
	}()

	c, err := l.Accept()
	require.NoError(t, err)
	defer c.Close()
	str, err := c.AcceptStream()
	require.NoError(t, err)
	defer str.Close()
	out, err := ioutil.ReadAll(str)
	require.NoError(t, err)
	require.Equal(t, msg, out, "got wrong message")
}
// TestWebsocketConnection exchanges data over a plain ws listener and over a
// TLS-protected wss listener, each in its own subtest.
func TestWebsocketConnection(t *testing.T) {
	subtests := []struct {
		name   string
		addr   string
		secure bool
	}{
		{name: "unencrypted", addr: "/ip4/127.0.0.1/tcp/0/ws", secure: false},
		{name: "encrypted", addr: "/ip4/127.0.0.1/tcp/0/wss", secure: true},
	}
	for _, st := range subtests {
		st := st
		t.Run(st.name, func(t *testing.T) {
			connectAndExchangeData(t, ma.StringCast(st.addr), st.secure)
		})
	}
}
// TestWebsocketListenSecureFailWithoutTLSConfig checks that listening on a wss
// address is refused, with a descriptive error, when the transport was built
// without a TLS config.
func TestWebsocketListenSecureFailWithoutTLSConfig(t *testing.T) {
	_, upgrader := newUpgrader(t)
	tpt, err := New(upgrader, network.NullResourceManager)
	require.NoError(t, err)

	wssAddr := ma.StringCast("/ip4/127.0.0.1/tcp/0/wss")
	wantErr := fmt.Sprintf("cannot listen on wss address %s without a tls.Config", wssAddr)
	_, err = tpt.Listen(wssAddr)
	require.EqualError(t, err, wantErr)
}
// TestWebsocketListenSecureAndInsecure checks two things. First, a single
// transport configured with a TLS config can listen on a ws and a wss address
// at the same time. Second, a client carrying a TLS client config can dial
// both, and each connection's local and remote multiaddrs end in the matching
// ws/wss component.
func TestWebsocketListenSecureAndInsecure(t *testing.T) {
	serverID, serverUpgrader := newUpgrader(t)
	server, err := New(serverUpgrader, network.NullResourceManager, WithTLSConfig(generateTLSConfig(t)))
	require.NoError(t, err)

	lnInsecure, err := server.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
	require.NoError(t, err)
	defer lnInsecure.Close()
	lnSecure, err := server.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/wss"))
	require.NoError(t, err)
	defer lnSecure.Close()

	t.Run("insecure", func(t *testing.T) {
		_, clientUpgrader := newUpgrader(t)
		client, err := New(clientUpgrader, network.NullResourceManager, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
		require.NoError(t, err)

		// dialing the insecure address should succeed
		conn, err := client.Dial(context.Background(), lnInsecure.Multiaddr(), serverID)
		require.NoError(t, err)
		defer conn.Close()
		require.Equal(t, wsma, lastComponent(t, conn.RemoteMultiaddr()))
		require.Equal(t, wsma, lastComponent(t, conn.LocalMultiaddr()))

		// The client also has a TLS client config, so dialing the secure
		// address succeeds as well.
		other, err := client.Dial(context.Background(), lnSecure.Multiaddr(), serverID)
		require.NoError(t, err)
		defer other.Close()
	})

	t.Run("secure", func(t *testing.T) {
		_, clientUpgrader := newUpgrader(t)
		client, err := New(clientUpgrader, network.NullResourceManager, WithTLSClientConfig(&tls.Config{InsecureSkipVerify: true}))
		require.NoError(t, err)

		// dialing the secure address should succeed
		conn, err := client.Dial(context.Background(), lnSecure.Multiaddr(), serverID)
		require.NoError(t, err)
		defer conn.Close()
		require.Equal(t, wssma, lastComponent(t, conn.RemoteMultiaddr()))
		require.Equal(t, wssma, lastComponent(t, conn.LocalMultiaddr()))

		// A client with a TLS client config can still dial the plain ws address.
		other, err := client.Dial(context.Background(), lnInsecure.Multiaddr(), serverID)
		require.NoError(t, err)
		defer other.Close()
	})
}
// TestConcurrentClose dials the listener 100 times and, on every client
// connection, runs a Write and a Close on separate goroutines so that they
// race. Meanwhile the test goroutine accepts and closes the matching
// server-side connections. The results of the racing Write/Close calls are
// deliberately ignored.
func TestConcurrentClose(t *testing.T) {
	_, u := newUpgrader(t)
	tpt, err := New(u, network.NullResourceManager)
	require.NoError(t, err)
	l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
	require.NoError(t, err)

	msg := []byte("HELLO WORLD")
	dialerDone := make(chan struct{})
	go func() {
		defer close(dialerDone)
		for i := 0; i < 100; i++ {
			c, err := tpt.maDial(context.Background(), l.Multiaddr())
			if err != nil {
				t.Error(err)
				return
			}
			go func() {
				_, _ = c.Write(msg)
			}()
			go func() {
				_ = c.Close()
			}()
		}
	}()
	// Close the listener first so an unfinished dial loop fails fast, then
	// wait for it. t.Error must not be called after the test has returned.
	defer func() {
		l.Close()
		<-dialerDone
	}()

	for i := 0; i < 100; i++ {
		c, err := l.Accept()
		require.NoError(t, err)
		c.Close()
	}
}
// TestWriteZero writes an empty message 100 times from a dialed connection,
// expecting each Write to report 0 bytes and no error. It then checks that
// the server-side connection reads nothing and gets io.EOF once the dialer
// closes.
func TestWriteZero(t *testing.T) {
	_, u := newUpgrader(t)
	tpt, err := New(u, network.NullResourceManager)
	require.NoError(t, err)
	l, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws"))
	require.NoError(t, err)

	msg := []byte(nil)
	dialerDone := make(chan struct{})
	go func() {
		defer close(dialerDone)
		c, err := tpt.maDial(context.Background(), l.Multiaddr())
		if err != nil {
			t.Error(err)
			return
		}
		// Closing the client connection is what delivers EOF to the reader below.
		defer c.Close()
		for i := 0; i < 100; i++ {
			n, err := c.Write(msg)
			if n != 0 {
				t.Errorf("expected to write 0 bytes, wrote %d", n)
			}
			if err != nil {
				t.Error(err)
				return
			}
		}
	}()
	// Close the listener, then wait for the dialer. Its t.Error calls must
	// happen before the test returns.
	defer func() {
		l.Close()
		<-dialerDone
	}()

	c, err := l.Accept()
	require.NoError(t, err)
	defer c.Close()

	buf := make([]byte, 100)
	n, err := c.Read(buf)
	if n != 0 {
		t.Errorf("read %d bytes, expected 0", n)
	}
	if err != io.EOF {
		t.Errorf("expected EOF, got err: %v", err)
	}
}