// go-libp2p/p2p/net/swarm/dial_worker_test.go

package swarm

import (
	"context"
	"crypto/rand"
	"errors"
	"fmt"
	"sync"
	"testing"
	"time"

	"github.com/libp2p/go-libp2p/core/crypto"
	"github.com/libp2p/go-libp2p/core/peer"
	"github.com/libp2p/go-libp2p/core/peerstore"
	"github.com/libp2p/go-libp2p/core/sec"
	"github.com/libp2p/go-libp2p/core/sec/insecure"
	"github.com/libp2p/go-libp2p/core/transport"
	"github.com/libp2p/go-libp2p/p2p/host/eventbus"
	"github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem"
	"github.com/libp2p/go-libp2p/p2p/muxer/yamux"
	tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader"
	quic "github.com/libp2p/go-libp2p/p2p/transport/quic"
	"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
	"github.com/libp2p/go-libp2p/p2p/transport/tcp"
	ma "github.com/multiformats/go-multiaddr"
	manet "github.com/multiformats/go-multiaddr/net"
	"github.com/stretchr/testify/require"
)
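
// newPeer generates a fresh Ed25519 key pair and returns it together with the
// derived peer ID.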
func newPeer(t *testing.T) (crypto.PrivKey, peer.ID) {
	priv, _, err := crypto.GenerateEd25519Key(rand.Reader)
	require.NoError(t, err)
	id, err := peer.IDFromPrivateKey(priv)
	require.NoError(t, err)
	return priv, id
}
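
// makeSwarm constructs a Swarm with a one-second dial timeout that listens on
// loopback over both TCP and QUIC, so the tests can exercise dials across two
// transports.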
func makeSwarm(t *testing.T) *Swarm {
	priv, id := newPeer(t)

	ps, err := pstoremem.NewPeerstore()
	require.NoError(t, err)
	ps.AddPubKey(context.Background(), id, priv.GetPublic())
	ps.AddPrivKey(context.Background(), id, priv)
	t.Cleanup(func() { ps.Close() })

	s, err := NewSwarm(id, ps, eventbus.NewBus(), WithDialTimeout(time.Second))
	require.NoError(t, err)

	upgrader := makeUpgrader(t, s)
	var tcpOpts []tcp.Option
	tcpOpts = append(tcpOpts, tcp.DisableReuseport())
	tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...)
	require.NoError(t, err)
	if err := s.AddTransport(tcpTransport); err != nil {
		t.Fatal(err)
	}
	if err := s.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")); err != nil {
		t.Fatal(err)
	}

	reuse, err := quicreuse.NewConnManager([32]byte{})
	if err != nil {
		t.Fatal(err)
	}
	quicTransport, err := quic.NewTransport(priv, reuse, nil, nil, nil)
	if err != nil {
		t.Fatal(err)
	}
	if err := s.AddTransport(quicTransport); err != nil {
		t.Fatal(err)
	}
	if err := s.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic")); err != nil {
		t.Fatal(err)
	}

	return s
}
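
// makeUpgrader wraps the swarm's identity in an insecure security transport
// and a yamux stream muxer, which is all the dial worker tests require.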
func makeUpgrader(t *testing.T, n *Swarm) transport.Upgrader {
	id := n.LocalPeer()
	pk := n.Peerstore().PrivKey(context.Background(), id)
	st := insecure.NewWithIdentity(insecure.ID, id, pk)
	u, err := tptu.New([]sec.SecureTransport{st}, []tptu.StreamMuxer{{ID: yamux.ID, Muxer: yamux.DefaultTransport}}, nil, nil, nil)
	require.NoError(t, err)
	return u
}
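
// TestDialWorkerLoopBasic sends two dial requests for the same peer through
// one worker and verifies that both are served by the same connection.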
func TestDialWorkerLoopBasic(t *testing.T) {
	s1 := makeSwarm(t)
	s2 := makeSwarm(t)
	defer s1.Close()
	defer s2.Close()

	// Only pass in a single address here, otherwise we might end up with a TCP and QUIC connection dialed.
	s1.Peerstore().AddAddrs(context.Background(), s2.LocalPeer(), []ma.Multiaddr{s2.ListenAddresses()[0]}, peerstore.PermanentAddrTTL)

	reqch := make(chan dialRequest)
	resch := make(chan dialResponse)
	worker := newDialWorker(s1, s2.LocalPeer(), reqch)
	go worker.loop()

	var conn *Conn
	reqch <- dialRequest{ctx: context.Background(), resch: resch}
	select {
	case res := <-resch:
		require.NoError(t, res.err)
		conn = res.conn
	case <-time.After(10 * time.Second):
		t.Fatal("dial didn't complete")
	}

	s, err := conn.NewStream(context.Background())
	require.NoError(t, err)
	s.Close()

	var conn2 *Conn
	reqch <- dialRequest{ctx: context.Background(), resch: resch}
	select {
	case res := <-resch:
		require.NoError(t, res.err)
		conn2 = res.conn
	case <-time.After(10 * time.Second):
		t.Fatal("dial didn't complete")
	}

	// can't use require.Equal here, as this does a deep comparison
	if conn != conn2 {
		t.Fatalf("expecting the same connection from both dials. %s <-> %s vs. %s <-> %s", conn.LocalMultiaddr(), conn.RemoteMultiaddr(), conn2.LocalMultiaddr(), conn2.RemoteMultiaddr())
	}

	close(reqch)
	worker.wg.Wait()
}
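
// TestDialWorkerLoopConcurrent fires 100 dial requests at the worker in
// parallel and expects every one of them to succeed.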
func TestDialWorkerLoopConcurrent(t *testing.T) {
	s1 := makeSwarm(t)
	s2 := makeSwarm(t)
	defer s1.Close()
	defer s2.Close()

	s1.Peerstore().AddAddrs(context.Background(), s2.LocalPeer(), s2.ListenAddresses(), peerstore.PermanentAddrTTL)

	reqch := make(chan dialRequest)
	worker := newDialWorker(s1, s2.LocalPeer(), reqch)
	go worker.loop()

	const dials = 100
	var wg sync.WaitGroup
	resch := make(chan dialResponse, dials)
	for i := 0; i < dials; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			reschgo := make(chan dialResponse, 1)
			reqch <- dialRequest{ctx: context.Background(), resch: reschgo}
			select {
			case res := <-reschgo:
				resch <- res
			case <-time.After(time.Minute):
				resch <- dialResponse{err: errors.New("timed out!")}
			}
		}()
	}
	wg.Wait()

	for i := 0; i < dials; i++ {
		res := <-resch
		require.NoError(t, res.err)
	}

	t.Log("all concurrent dials done")

	close(reqch)
	worker.wg.Wait()
}
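
// TestDialWorkerLoopFailure dials a peer that only has unreachable addresses
// and expects the worker to report an error.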
func TestDialWorkerLoopFailure(t *testing.T) {
	s1 := makeSwarm(t)
	defer s1.Close()

	_, p2 := newPeer(t)
	s1.Peerstore().AddAddrs(context.Background(), p2, []ma.Multiaddr{ma.StringCast("/ip4/11.0.0.1/tcp/1234"), ma.StringCast("/ip4/11.0.0.1/udp/1234/quic")}, peerstore.PermanentAddrTTL)

	reqch := make(chan dialRequest)
	resch := make(chan dialResponse)
	worker := newDialWorker(s1, p2, reqch)
	go worker.loop()

	reqch <- dialRequest{ctx: context.Background(), resch: resch}
	select {
	case res := <-resch:
		require.Error(t, res.err)
	case <-time.After(time.Minute):
		t.Fatal("dial didn't complete")
	}

	close(reqch)
	worker.wg.Wait()
}
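
// TestDialWorkerLoopConcurrentFailure issues 100 parallel dials to an
// unreachable peer and expects every request to fail with a real dial error,
// not a timeout.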
func TestDialWorkerLoopConcurrentFailure(t *testing.T) {
	s1 := makeSwarm(t)
	defer s1.Close()

	_, p2 := newPeer(t)
	s1.Peerstore().AddAddrs(context.Background(), p2, []ma.Multiaddr{ma.StringCast("/ip4/11.0.0.1/tcp/1234"), ma.StringCast("/ip4/11.0.0.1/udp/1234/quic")}, peerstore.PermanentAddrTTL)

	reqch := make(chan dialRequest)
	worker := newDialWorker(s1, p2, reqch)
	go worker.loop()

	const dials = 100
	var errTimeout = errors.New("timed out!")
	var wg sync.WaitGroup
	resch := make(chan dialResponse, dials)
	for i := 0; i < dials; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			reschgo := make(chan dialResponse, 1)
			reqch <- dialRequest{ctx: context.Background(), resch: reschgo}
			select {
			case res := <-reschgo:
				resch <- res
			case <-time.After(time.Minute):
				resch <- dialResponse{err: errTimeout}
			}
		}()
	}
	wg.Wait()

	for i := 0; i < dials; i++ {
		res := <-resch
		require.Error(t, res.err)
		if res.err == errTimeout {
			t.Fatal("dial response timed out")
		}
	}

	t.Log("all concurrent dials done")

	close(reqch)
	worker.wg.Wait()
}
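
// TestDialWorkerLoopConcurrentMix mixes reachable and unreachable addresses
// for the same peer; all 100 concurrent dials should still succeed.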
func TestDialWorkerLoopConcurrentMix(t *testing.T) {
	s1 := makeSwarm(t)
	s2 := makeSwarm(t)
	defer s1.Close()
	defer s2.Close()

	s1.Peerstore().AddAddrs(context.Background(), s2.LocalPeer(), s2.ListenAddresses(), peerstore.PermanentAddrTTL)
	s1.Peerstore().AddAddrs(context.Background(), s2.LocalPeer(), []ma.Multiaddr{ma.StringCast("/ip4/11.0.0.1/tcp/1234"), ma.StringCast("/ip4/11.0.0.1/udp/1234/quic")}, peerstore.PermanentAddrTTL)

	reqch := make(chan dialRequest)
	worker := newDialWorker(s1, s2.LocalPeer(), reqch)
	go worker.loop()

	const dials = 100
	var wg sync.WaitGroup
	resch := make(chan dialResponse, dials)
	for i := 0; i < dials; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			reschgo := make(chan dialResponse, 1)
			reqch <- dialRequest{ctx: context.Background(), resch: reschgo}
			select {
			case res := <-reschgo:
				resch <- res
			case <-time.After(time.Minute):
				resch <- dialResponse{err: errors.New("timed out!")}
			}
		}()
	}
	wg.Wait()

	for i := 0; i < dials; i++ {
		res := <-resch
		require.NoError(t, res.err)
	}

	t.Log("all concurrent dials done")

	close(reqch)
	worker.wg.Wait()
}
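
// TestDialWorkerLoopConcurrentFailureStress hands the worker 16 unreachable
// addresses and 100 concurrent requests to stress failure handling.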
func TestDialWorkerLoopConcurrentFailureStress(t *testing.T) {
	s1 := makeSwarm(t)
	defer s1.Close()

	_, p2 := newPeer(t)
	var addrs []ma.Multiaddr
	for i := 0; i < 16; i++ {
		addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/11.0.0.%d/tcp/%d", i%256, 1234+i)))
	}
	s1.Peerstore().AddAddrs(context.Background(), p2, addrs, peerstore.PermanentAddrTTL)

	reqch := make(chan dialRequest)
	worker := newDialWorker(s1, p2, reqch)
	go worker.loop()

	const dials = 100
	var errTimeout = errors.New("timed out!")
	var wg sync.WaitGroup
	resch := make(chan dialResponse, dials)
	for i := 0; i < dials; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			reschgo := make(chan dialResponse, 1)
			reqch <- dialRequest{ctx: context.Background(), resch: reschgo}
			select {
			case res := <-reschgo:
				t.Log("received result")
				resch <- res
			case <-time.After(15 * time.Second):
				resch <- dialResponse{err: errTimeout}
			}
		}()
	}
	wg.Wait()

	for i := 0; i < dials; i++ {
		res := <-resch
		require.Error(t, res.err)
		if res.err == errTimeout {
			t.Fatal("dial response timed out")
		}
	}

	t.Log("all concurrent dials done")

	close(reqch)
	worker.wg.Wait()
}
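
// TestDialWorkerLoopAddrDedup verifies that the worker deduplicates addresses
// by value rather than by multiaddr object: re-requesting an address that was
// already dialed must fail without a second connection attempt reaching the
// listener.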
func TestDialWorkerLoopAddrDedup(t *testing.T) {
	s1 := makeSwarm(t)
	s2 := makeSwarm(t)
	defer s1.Close()
	defer s2.Close()

	t1 := ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", 10000))
	t2 := ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", 10000))

	// acceptAndClose accepts a connection and closes it
	acceptAndClose := func(a ma.Multiaddr, ch chan struct{}, closech chan struct{}) {
		list, err := manet.Listen(a)
		if err != nil {
			t.Error(err)
			return
		}
		go func() {
			ch <- struct{}{}
			for {
				conn, err := list.Accept()
				if err != nil {
					return
				}
				ch <- struct{}{}
				conn.Close()
			}
		}()
		<-closech
		list.Close()
	}

	ch := make(chan struct{}, 1)
	closeCh := make(chan struct{})
	go acceptAndClose(t1, ch, closeCh)
	defer close(closeCh)
	<-ch // the goroutine has started listening on addr

	s1.Peerstore().AddAddrs(context.Background(), s2.LocalPeer(), []ma.Multiaddr{t1}, peerstore.PermanentAddrTTL)

	reqch := make(chan dialRequest)
	resch := make(chan dialResponse, 2)
	worker := newDialWorker(s1, s2.LocalPeer(), reqch)
	go worker.loop()
	defer worker.wg.Wait()
	defer close(reqch)

	reqch <- dialRequest{ctx: context.Background(), resch: resch}
	<-ch
	<-resch

	// Need to clear backoff otherwise the dial attempt would not be made
	s1.Backoff().Clear(s2.LocalPeer())

	s1.Peerstore().ClearAddrs(context.Background(), s2.LocalPeer())
	s1.Peerstore().AddAddrs(context.Background(), s2.LocalPeer(), []ma.Multiaddr{t2}, peerstore.PermanentAddrTTL)

	reqch <- dialRequest{ctx: context.Background(), resch: resch}
	select {
	case r := <-resch:
		require.Error(t, r.err)
	case <-ch:
		t.Errorf("didn't expect a connection attempt")
	case <-time.After(5 * time.Second):
		t.Errorf("expected a fail response")
	}
}