use testing.Cleanup to shut down peerstore and revert most test changes

This commit is contained in:
Marten Seemann
2021-09-07 10:42:50 +01:00
parent 782897ea41
commit a872d26b7c
7 changed files with 29 additions and 53 deletions

View File

@@ -28,12 +28,7 @@ func init() {
transport.DialTimeout = time.Second transport.DialTimeout = time.Second
} }
type swarmWithBackoff interface { func closeSwarms(swarms []*Swarm) {
network.Network
Backoff() *DialBackoff
}
func closeSwarms(swarms []network.Network) {
for _, s := range swarms { for _, s := range swarms {
s.Close() s.Close()
} }
@@ -101,7 +96,7 @@ func TestSimultDials(t *testing.T) {
// connect everyone // connect everyone
{ {
var wg sync.WaitGroup var wg sync.WaitGroup
connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) {
// copy for other peer // copy for other peer
log.Debugf("TestSimultOpen: connecting: %s --> %s (%s)", s.LocalPeer(), dst, addr) log.Debugf("TestSimultOpen: connecting: %s --> %s (%s)", s.LocalPeer(), dst, addr)
s.Peerstore().AddAddr(dst, addr, peerstore.TempAddrTTL) s.Peerstore().AddAddr(dst, addr, peerstore.TempAddrTTL)
@@ -193,7 +188,7 @@ func TestDialWait(t *testing.T) {
t.Error("> 2*transport.DialTimeout * DialAttempts not being respected", duration, 2*transport.DialTimeout*DialAttempts) t.Error("> 2*transport.DialTimeout * DialAttempts not being respected", duration, 2*transport.DialTimeout*DialAttempts)
} }
if !s1.(swarmWithBackoff).Backoff().Backoff(s2p, s2addr) { if !s1.Backoff().Backoff(s2p, s2addr) {
t.Error("s2 should now be on backoff") t.Error("s2 should now be on backoff")
} }
} }
@@ -330,10 +325,10 @@ func TestDialBackoff(t *testing.T) {
} }
// check backoff state // check backoff state
if s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { if s1.Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) {
t.Error("s2 should not be on backoff") t.Error("s2 should not be on backoff")
} }
if !s1.(swarmWithBackoff).Backoff().Backoff(s3p, s3addr) { if !s1.Backoff().Backoff(s3p, s3addr) {
t.Error("s3 should be on backoff") t.Error("s3 should be on backoff")
} }
@@ -400,10 +395,10 @@ func TestDialBackoff(t *testing.T) {
} }
// check backoff state (the same) // check backoff state (the same)
if s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { if s1.Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) {
t.Error("s2 should not be on backoff") t.Error("s2 should not be on backoff")
} }
if !s1.(swarmWithBackoff).Backoff().Backoff(s3p, s3addr) { if !s1.Backoff().Backoff(s3p, s3addr) {
t.Error("s3 should be on backoff") t.Error("s3 should be on backoff")
} }
} }
@@ -445,7 +440,7 @@ func TestDialBackoffClears(t *testing.T) {
t.Error("> 2*transport.DialTimeout * DialAttempts not being respected", duration, 2*transport.DialTimeout*DialAttempts) t.Error("> 2*transport.DialTimeout * DialAttempts not being respected", duration, 2*transport.DialTimeout*DialAttempts)
} }
if !s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2bad) { if !s1.Backoff().Backoff(s2.LocalPeer(), s2bad) {
t.Error("s2 should now be on backoff") t.Error("s2 should now be on backoff")
} else { } else {
t.Log("correctly added to backoff") t.Log("correctly added to backoff")
@@ -472,7 +467,7 @@ func TestDialBackoffClears(t *testing.T) {
t.Log("correctly connected") t.Log("correctly connected")
} }
if s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2bad) { if s1.Backoff().Backoff(s2.LocalPeer(), s2bad) {
t.Error("s2 should no longer be on backoff") t.Error("s2 should no longer be on backoff")
} else { } else {
t.Log("correctly cleared backoff") t.Log("correctly cleared backoff")

View File

@@ -9,6 +9,8 @@ import (
"github.com/libp2p/go-libp2p-core/peerstore" "github.com/libp2p/go-libp2p-core/peerstore"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
. "github.com/libp2p/go-libp2p-swarm"
) )
func TestPeers(t *testing.T) { func TestPeers(t *testing.T) {
@@ -17,7 +19,7 @@ func TestPeers(t *testing.T) {
s1 := swarms[0] s1 := swarms[0]
s2 := swarms[1] s2 := swarms[1]
connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) {
// TODO: make a DialAddr func. // TODO: make a DialAddr func.
s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL) s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL)
// t.Logf("connections from %s", s.LocalPeer()) // t.Logf("connections from %s", s.LocalPeer())
@@ -53,7 +55,7 @@ func TestPeers(t *testing.T) {
log.Infof("%s swarm routing table: %s", s.LocalPeer(), s.Peers()) log.Infof("%s swarm routing table: %s", s.LocalPeer(), s.Peers())
} }
test := func(s network.Network) { test := func(s *Swarm) {
expect := 1 expect := 1
actual := len(s.Peers()) actual := len(s.Peers())
if actual != expect { if actual != expect {

View File

@@ -7,12 +7,12 @@ import (
"testing" "testing"
"time" "time"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/peerstore" "github.com/libp2p/go-libp2p-core/peerstore"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
. "github.com/libp2p/go-libp2p-swarm"
swarmt "github.com/libp2p/go-libp2p-swarm/testing" swarmt "github.com/libp2p/go-libp2p-swarm/testing"
"github.com/libp2p/go-libp2p-testing/ci" "github.com/libp2p/go-libp2p-testing/ci"
) )
@@ -24,7 +24,7 @@ func TestSimultOpen(t *testing.T) {
// connect everyone // connect everyone
{ {
var wg sync.WaitGroup var wg sync.WaitGroup
connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) {
defer wg.Done() defer wg.Done()
// copy for other peer // copy for other peer
log.Debugf("TestSimultOpen: connecting: %s --> %s (%s)", s.LocalPeer(), dst, addr) log.Debugf("TestSimultOpen: connecting: %s --> %s (%s)", s.LocalPeer(), dst, addr)

View File

@@ -94,7 +94,7 @@ func TestNotifications(t *testing.T) {
} }
} }
complement := func(c network.Conn) (network.Network, *netNotifiee, *Conn) { complement := func(c network.Conn) (*Swarm, *netNotifiee, *Conn) {
for i, s := range swarms { for i, s := range swarms {
for _, c2 := range s.Conns() { for _, c2 := range s.Conns() {
if c.LocalMultiaddr().Equal(c2.RemoteMultiaddr()) && if c.LocalMultiaddr().Equal(c2.RemoteMultiaddr()) &&

View File

@@ -58,14 +58,14 @@ func EchoStreamHandler(stream network.Stream) {
}() }()
} }
func makeDialOnlySwarm(t *testing.T) network.Network { func makeDialOnlySwarm(t *testing.T) *Swarm {
swarm := GenSwarm(t, OptDialOnly) swarm := GenSwarm(t, OptDialOnly)
swarm.SetStreamHandler(EchoStreamHandler) swarm.SetStreamHandler(EchoStreamHandler)
return swarm return swarm
} }
func makeSwarms(t *testing.T, num int, opts ...Option) []network.Network { func makeSwarms(t *testing.T, num int, opts ...Option) []*Swarm {
swarms := make([]network.Network, 0, num) swarms := make([]*Swarm, 0, num)
for i := 0; i < num; i++ { for i := 0; i < num; i++ {
swarm := GenSwarm(t, opts...) swarm := GenSwarm(t, opts...)
swarm.SetStreamHandler(EchoStreamHandler) swarm.SetStreamHandler(EchoStreamHandler)
@@ -74,9 +74,9 @@ func makeSwarms(t *testing.T, num int, opts ...Option) []network.Network {
return swarms return swarms
} }
func connectSwarms(t *testing.T, ctx context.Context, swarms []network.Network) { func connectSwarms(t *testing.T, ctx context.Context, swarms []*Swarm) {
var wg sync.WaitGroup var wg sync.WaitGroup
connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) {
// TODO: make a DialAddr func. // TODO: make a DialAddr func.
s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL) s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL)
if _, err := s.DialPeer(ctx, dst); err != nil { if _, err := s.DialPeer(ctx, dst); err != nil {

View File

@@ -71,7 +71,7 @@ func OptPeerPrivateKey(sk crypto.PrivKey) Option {
} }
// GenUpgrader creates a new connection upgrader for use with this swarm. // GenUpgrader creates a new connection upgrader for use with this swarm.
func GenUpgrader(n network.Network) *tptu.Upgrader { func GenUpgrader(n *swarm.Swarm) *tptu.Upgrader {
id := n.LocalPeer() id := n.LocalPeer()
pk := n.Peerstore().PrivKey(id) pk := n.Peerstore().PrivKey(id)
secMuxer := new(csms.SSMuxer) secMuxer := new(csms.SSMuxer)
@@ -86,18 +86,8 @@ func GenUpgrader(n network.Network) *tptu.Upgrader {
} }
} }
type mSwarm struct {
*swarm.Swarm
ps peerstore.Peerstore
}
func (s *mSwarm) Close() error {
s.ps.Close()
return s.Swarm.Close()
}
// GenSwarm generates a new test swarm. // GenSwarm generates a new test swarm.
func GenSwarm(t *testing.T, opts ...Option) network.Network { func GenSwarm(t *testing.T, opts ...Option) *swarm.Swarm {
var cfg config var cfg config
for _, o := range opts { for _, o := range opts {
o(t, &cfg) o(t, &cfg)
@@ -121,10 +111,9 @@ func GenSwarm(t *testing.T, opts ...Option) network.Network {
ps := pstoremem.NewPeerstore() ps := pstoremem.NewPeerstore()
ps.AddPubKey(p.ID, p.PubKey) ps.AddPubKey(p.ID, p.PubKey)
ps.AddPrivKey(p.ID, p.PrivKey) ps.AddPrivKey(p.ID, p.PrivKey)
s := &mSwarm{ t.Cleanup(func() { ps.Close() })
Swarm: swarm.NewSwarm(p.ID, ps, metrics.NewBandwidthCounter(), cfg.connectionGater),
ps: ps, s := swarm.NewSwarm(p.ID, ps, metrics.NewBandwidthCounter(), cfg.connectionGater)
}
upgrader := GenUpgrader(s) upgrader := GenUpgrader(s)
upgrader.ConnGater = cfg.connectionGater upgrader.ConnGater = cfg.connectionGater

View File

@@ -7,7 +7,6 @@ import (
swarm "github.com/libp2p/go-libp2p-swarm" swarm "github.com/libp2p/go-libp2p-swarm"
swarmt "github.com/libp2p/go-libp2p-swarm/testing" swarmt "github.com/libp2p/go-libp2p-swarm/testing"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/transport" "github.com/libp2p/go-libp2p-core/transport"
@@ -46,28 +45,19 @@ func (dt *dummyTransport) Close() error {
return nil return nil
} }
type swarmWithTransport interface {
network.Network
AddTransport(transport.Transport) error
}
func TestUselessTransport(t *testing.T) { func TestUselessTransport(t *testing.T) {
s := swarmt.GenSwarm(t) s := swarmt.GenSwarm(t)
err := s.(swarmWithTransport).AddTransport(new(dummyTransport)) require.Error(t, s.AddTransport(new(dummyTransport)), "adding a transport that supports no protocols should have failed")
if err == nil {
t.Fatal("adding a transport that supports no protocols should have failed")
}
} }
func TestTransportClose(t *testing.T) { func TestTransportClose(t *testing.T) {
s := swarmt.GenSwarm(t) s := swarmt.GenSwarm(t)
tpt := &dummyTransport{protocols: []int{1}} tpt := &dummyTransport{protocols: []int{1}}
require.NoError(t, s.(swarmWithTransport).AddTransport(tpt)) require.NoError(t, s.AddTransport(tpt))
_ = s.Close() _ = s.Close()
if !tpt.closed { if !tpt.closed {
t.Fatal("expected transport to be closed") t.Fatal("expected transport to be closed")
} }
} }
func TestTransportAfterClose(t *testing.T) { func TestTransportAfterClose(t *testing.T) {
@@ -75,7 +65,7 @@ func TestTransportAfterClose(t *testing.T) {
s.Close() s.Close()
tpt := &dummyTransport{protocols: []int{1}} tpt := &dummyTransport{protocols: []int{1}}
if err := s.(swarmWithTransport).AddTransport(tpt); err != swarm.ErrSwarmClosed { if err := s.AddTransport(tpt); err != swarm.ErrSwarmClosed {
t.Fatal("expected swarm closed error, got: ", err) t.Fatal("expected swarm closed error, got: ", err)
} }
} }