diff --git a/p2p/net/swarm/dial_test.go b/p2p/net/swarm/dial_test.go index de41e7f74..5fe4bdc22 100644 --- a/p2p/net/swarm/dial_test.go +++ b/p2p/net/swarm/dial_test.go @@ -28,12 +28,7 @@ func init() { transport.DialTimeout = time.Second } -type swarmWithBackoff interface { - network.Network - Backoff() *DialBackoff -} - -func closeSwarms(swarms []network.Network) { +func closeSwarms(swarms []*Swarm) { for _, s := range swarms { s.Close() } @@ -101,7 +96,7 @@ func TestSimultDials(t *testing.T) { // connect everyone { var wg sync.WaitGroup - connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { + connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { // copy for other peer log.Debugf("TestSimultOpen: connecting: %s --> %s (%s)", s.LocalPeer(), dst, addr) s.Peerstore().AddAddr(dst, addr, peerstore.TempAddrTTL) @@ -193,7 +188,7 @@ func TestDialWait(t *testing.T) { t.Error("> 2*transport.DialTimeout * DialAttempts not being respected", duration, 2*transport.DialTimeout*DialAttempts) } - if !s1.(swarmWithBackoff).Backoff().Backoff(s2p, s2addr) { + if !s1.Backoff().Backoff(s2p, s2addr) { t.Error("s2 should now be on backoff") } } @@ -330,10 +325,10 @@ func TestDialBackoff(t *testing.T) { } // check backoff state - if s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { + if s1.Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { t.Error("s2 should not be on backoff") } - if !s1.(swarmWithBackoff).Backoff().Backoff(s3p, s3addr) { + if !s1.Backoff().Backoff(s3p, s3addr) { t.Error("s3 should be on backoff") } @@ -400,10 +395,10 @@ func TestDialBackoff(t *testing.T) { } // check backoff state (the same) - if s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { + if s1.Backoff().Backoff(s2.LocalPeer(), s2addrs[0]) { t.Error("s2 should not be on backoff") } - if !s1.(swarmWithBackoff).Backoff().Backoff(s3p, s3addr) { + if !s1.Backoff().Backoff(s3p, s3addr) { t.Error("s3 should be on backoff") } } @@ -445,7 +440,7 @@ func TestDialBackoffClears(t *testing.T) { t.Error("> 2*transport.DialTimeout * DialAttempts not being respected", duration, 2*transport.DialTimeout*DialAttempts) } - if !s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2bad) { + if !s1.Backoff().Backoff(s2.LocalPeer(), s2bad) { t.Error("s2 should now be on backoff") } else { t.Log("correctly added to backoff") @@ -472,7 +467,7 @@ func TestDialBackoffClears(t *testing.T) { t.Log("correctly connected") } - if s1.(swarmWithBackoff).Backoff().Backoff(s2.LocalPeer(), s2bad) { + if s1.Backoff().Backoff(s2.LocalPeer(), s2bad) { t.Error("s2 should no longer be on backoff") } else { t.Log("correctly cleared backoff") diff --git a/p2p/net/swarm/peers_test.go b/p2p/net/swarm/peers_test.go index 908abe91a..3145d862b 100644 --- a/p2p/net/swarm/peers_test.go +++ b/p2p/net/swarm/peers_test.go @@ -9,6 +9,8 @@ import ( "github.com/libp2p/go-libp2p-core/peerstore" ma "github.com/multiformats/go-multiaddr" + + . "github.com/libp2p/go-libp2p-swarm" ) func TestPeers(t *testing.T) { @@ -17,7 +19,7 @@ func TestPeers(t *testing.T) { s1 := swarms[0] s2 := swarms[1] - connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { + connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { // TODO: make a DialAddr func. s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL) // t.Logf("connections from %s", s.LocalPeer()) @@ -53,7 +55,7 @@ func TestPeers(t *testing.T) { log.Infof("%s swarm routing table: %s", s.LocalPeer(), s.Peers()) } - test := func(s network.Network) { + test := func(s *Swarm) { expect := 1 actual := len(s.Peers()) if actual != expect { diff --git a/p2p/net/swarm/simul_test.go b/p2p/net/swarm/simul_test.go index 326c4e21b..aa4eb5909 100644 --- a/p2p/net/swarm/simul_test.go +++ b/p2p/net/swarm/simul_test.go @@ -7,12 +7,12 @@ import ( "testing" "time" - "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peerstore" ma "github.com/multiformats/go-multiaddr" + . "github.com/libp2p/go-libp2p-swarm" swarmt "github.com/libp2p/go-libp2p-swarm/testing" "github.com/libp2p/go-libp2p-testing/ci" ) @@ -24,7 +24,7 @@ func TestSimultOpen(t *testing.T) { // connect everyone { var wg sync.WaitGroup - connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { + connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { defer wg.Done() // copy for other peer log.Debugf("TestSimultOpen: connecting: %s --> %s (%s)", s.LocalPeer(), dst, addr) diff --git a/p2p/net/swarm/swarm_notif_test.go b/p2p/net/swarm/swarm_notif_test.go index 157405cba..c0c6f82db 100644 --- a/p2p/net/swarm/swarm_notif_test.go +++ b/p2p/net/swarm/swarm_notif_test.go @@ -94,7 +94,7 @@ func TestNotifications(t *testing.T) { } } - complement := func(c network.Conn) (network.Network, *netNotifiee, *Conn) { + complement := func(c network.Conn) (*Swarm, *netNotifiee, *Conn) { for i, s := range swarms { for _, c2 := range s.Conns() { if c.LocalMultiaddr().Equal(c2.RemoteMultiaddr()) && diff --git a/p2p/net/swarm/swarm_test.go b/p2p/net/swarm/swarm_test.go index e6cb0ba30..9f799f663 100644 --- a/p2p/net/swarm/swarm_test.go +++ b/p2p/net/swarm/swarm_test.go @@ -58,14 +58,14 @@ func EchoStreamHandler(stream network.Stream) { }() } -func makeDialOnlySwarm(t *testing.T) network.Network { +func makeDialOnlySwarm(t *testing.T) *Swarm { swarm := GenSwarm(t, OptDialOnly) swarm.SetStreamHandler(EchoStreamHandler) return swarm } -func makeSwarms(t *testing.T, num int, opts ...Option) []network.Network { - swarms := make([]network.Network, 0, num) +func makeSwarms(t *testing.T, num int, opts ...Option) []*Swarm { + swarms := make([]*Swarm, 0, num) for i := 0; i < num; i++ { swarm := GenSwarm(t, opts...) swarm.SetStreamHandler(EchoStreamHandler) @@ -74,9 +74,9 @@ func makeSwarms(t *testing.T, num int, opts ...Option) []network.Network { return swarms } -func connectSwarms(t *testing.T, ctx context.Context, swarms []network.Network) { +func connectSwarms(t *testing.T, ctx context.Context, swarms []*Swarm) { var wg sync.WaitGroup - connect := func(s network.Network, dst peer.ID, addr ma.Multiaddr) { + connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { // TODO: make a DialAddr func. s.Peerstore().AddAddr(dst, addr, peerstore.PermanentAddrTTL) if _, err := s.DialPeer(ctx, dst); err != nil { diff --git a/p2p/net/swarm/testing/testing.go b/p2p/net/swarm/testing/testing.go index d6091354b..201b4f0f1 100644 --- a/p2p/net/swarm/testing/testing.go +++ b/p2p/net/swarm/testing/testing.go @@ -71,7 +71,7 @@ func OptPeerPrivateKey(sk crypto.PrivKey) Option { } // GenUpgrader creates a new connection upgrader for use with this swarm. -func GenUpgrader(n network.Network) *tptu.Upgrader { +func GenUpgrader(n *swarm.Swarm) *tptu.Upgrader { id := n.LocalPeer() pk := n.Peerstore().PrivKey(id) secMuxer := new(csms.SSMuxer) @@ -86,18 +86,8 @@ func GenUpgrader(n network.Network) *tptu.Upgrader { } } -type mSwarm struct { - *swarm.Swarm - ps peerstore.Peerstore -} - -func (s *mSwarm) Close() error { - s.ps.Close() - return s.Swarm.Close() -} - // GenSwarm generates a new test swarm. -func GenSwarm(t *testing.T, opts ...Option) network.Network { +func GenSwarm(t *testing.T, opts ...Option) *swarm.Swarm { var cfg config for _, o := range opts { o(t, &cfg) @@ -121,10 +111,9 @@ func GenSwarm(t *testing.T, opts ...Option) network.Network { ps := pstoremem.NewPeerstore() ps.AddPubKey(p.ID, p.PubKey) ps.AddPrivKey(p.ID, p.PrivKey) - s := &mSwarm{ - Swarm: swarm.NewSwarm(p.ID, ps, metrics.NewBandwidthCounter(), cfg.connectionGater), - ps: ps, - } + t.Cleanup(func() { ps.Close() }) + + s := swarm.NewSwarm(p.ID, ps, metrics.NewBandwidthCounter(), cfg.connectionGater) upgrader := GenUpgrader(s) upgrader.ConnGater = cfg.connectionGater diff --git a/p2p/net/swarm/transport_test.go b/p2p/net/swarm/transport_test.go index 527260265..6d5913cf5 100644 --- a/p2p/net/swarm/transport_test.go +++ b/p2p/net/swarm/transport_test.go @@ -7,7 +7,6 @@ import ( swarm "github.com/libp2p/go-libp2p-swarm" swarmt "github.com/libp2p/go-libp2p-swarm/testing" - "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/transport" @@ -46,28 +45,19 @@ func (dt *dummyTransport) Close() error { return nil } -type swarmWithTransport interface { - network.Network - AddTransport(transport.Transport) error -} - func TestUselessTransport(t *testing.T) { s := swarmt.GenSwarm(t) - err := s.(swarmWithTransport).AddTransport(new(dummyTransport)) - if err == nil { - t.Fatal("adding a transport that supports no protocols should have failed") - } + require.Error(t, s.AddTransport(new(dummyTransport)), "adding a transport that supports no protocols should have failed") } func TestTransportClose(t *testing.T) { s := swarmt.GenSwarm(t) tpt := &dummyTransport{protocols: []int{1}} - require.NoError(t, s.(swarmWithTransport).AddTransport(tpt)) + require.NoError(t, s.AddTransport(tpt)) _ = s.Close() if !tpt.closed { t.Fatal("expected transport to be closed") } - } func TestTransportAfterClose(t *testing.T) { @@ -75,7 +65,7 @@ func TestTransportAfterClose(t *testing.T) { s.Close() tpt := &dummyTransport{protocols: []int{1}} - if err := s.(swarmWithTransport).AddTransport(tpt); err != swarm.ErrSwarmClosed { + if err := s.AddTransport(tpt); err != swarm.ErrSwarmClosed { t.Fatal("expected swarm closed error, got: ", err) } }