diff --git a/config/config.go b/config/config.go index 81cfff293..6e32ca44b 100644 --- a/config/config.go +++ b/config/config.go @@ -13,7 +13,6 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/pnet" - "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/routing" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/host/autonat" @@ -71,7 +70,7 @@ type Config struct { PeerKey crypto.PrivKey Transports []fx.Option - Muxers []Muxer + Muxers []tptu.StreamMuxer SecurityTransports []fx.Option Insecure bool PSK pnet.PSK @@ -168,31 +167,21 @@ func (cfg *Config) addTransports(h host.Host) error { return fmt.Errorf("swarm does not support transports") } - muxers := make([]protocol.ID, 0, len(cfg.Muxers)) - for _, m := range cfg.Muxers { - muxers = append(muxers, m.ID) - } - var security []fx.Option if cfg.Insecure { security = append(security, fx.Provide(makeInsecureTransport)) } else { security = cfg.SecurityTransports } - muxer, err := makeMuxer(cfg.Muxers) - if err != nil { - return err - } fxopts := []fx.Option{ fx.WithLogger(func() fxevent.Logger { return getFXLogger() }), fx.Provide(tptu.New), - fx.Provide(func() network.Multiplexer { return muxer }), fx.Provide(fx.Annotate( makeSecurityMuxer, fx.ParamTags(`group:"security"`), )), - fx.Supply(muxers), + fx.Supply(cfg.Muxers), fx.Provide(func() host.Host { return h }), fx.Provide(func() crypto.PrivKey { return h.Peerstore().PrivKey(h.ID()) }), fx.Provide(func() connmgr.ConnectionGater { return cfg.ConnectionGater }), diff --git a/config/muxer.go b/config/muxer.go deleted file mode 100644 index 448db65e8..000000000 --- a/config/muxer.go +++ /dev/null @@ -1,29 +0,0 @@ -package config - -import ( - "fmt" - - "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/protocol" - msmux "github.com/libp2p/go-libp2p/p2p/muxer/muxer-multistream" -) - -type Muxer struct { - ID protocol.ID - Multiplexer network.Multiplexer -} - -func makeMuxer(muxers []Muxer) (network.Multiplexer, error) { - muxMuxer := msmux.NewBlankTransport() - transportSet := make(map[protocol.ID]struct{}, len(muxers)) - for _, m := range muxers { - if _, ok := transportSet[m.ID]; ok { - return nil, fmt.Errorf("duplicate muxer transport: %s", m.ID) - } - transportSet[m.ID] = struct{}{} - } - for _, m := range muxers { - muxMuxer.AddTransport(string(m.ID), m.Multiplexer) - } - return muxMuxer, nil -} diff --git a/options.go b/options.go index aac614698..0ff63b5f3 100644 --- a/options.go +++ b/options.go @@ -23,6 +23,7 @@ import ( "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/host/autorelay" bhost "github.com/libp2p/go-libp2p/p2p/host/basic" + tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" @@ -106,7 +107,7 @@ var NoSecurity Option = func(cfg *Config) error { // name is the protocol name. func Muxer(name string, muxer network.Multiplexer) Option { return func(cfg *Config) error { - cfg.Muxers = append(cfg.Muxers, config.Muxer{Multiplexer: muxer, ID: protocol.ID(name)}) + cfg.Muxers = append(cfg.Muxers, tptu.StreamMuxer{Muxer: muxer, ID: protocol.ID(name)}) return nil } } diff --git a/p2p/muxer/muxer-multistream/multistream.go b/p2p/muxer/muxer-multistream/multistream.go deleted file mode 100644 index e81ae0ded..000000000 --- a/p2p/muxer/muxer-multistream/multistream.go +++ /dev/null @@ -1,80 +0,0 @@ -// Package muxer_multistream implements a peerstream transport using -// go-multistream to select the underlying stream muxer -package muxer_multistream - -import ( - "fmt" - "net" - "time" - - "github.com/libp2p/go-libp2p/core/network" - - mss "github.com/multiformats/go-multistream" -) - -var DefaultNegotiateTimeout = time.Second * 60 - -type Transport struct { - mux *mss.MultistreamMuxer - - tpts map[string]network.Multiplexer - - NegotiateTimeout time.Duration - - OrderPreference []string -} - -func NewBlankTransport() *Transport { - return &Transport{ - mux: mss.NewMultistreamMuxer(), - tpts: make(map[string]network.Multiplexer), - NegotiateTimeout: DefaultNegotiateTimeout, - } -} - -func (t *Transport) AddTransport(path string, tpt network.Multiplexer) { - t.mux.AddHandler(path, nil) - t.tpts[path] = tpt - t.OrderPreference = append(t.OrderPreference, path) -} - -func (t *Transport) NewConn(nc net.Conn, isServer bool, scope network.PeerScope) (network.MuxedConn, error) { - if t.NegotiateTimeout != 0 { - if err := nc.SetDeadline(time.Now().Add(t.NegotiateTimeout)); err != nil { - return nil, err - } - } - - var proto string - if isServer { - selected, _, err := t.mux.Negotiate(nc) - if err != nil { - return nil, err - } - proto = selected - } else { - selected, err := mss.SelectOneOf(t.OrderPreference, nc) - if err != nil { - return nil, err - } - proto = selected - } - - if t.NegotiateTimeout != 0 { - if err := nc.SetDeadline(time.Time{}); err != nil { - return nil, err - } - } - - tpt, ok := t.tpts[proto] - if !ok { - return nil, fmt.Errorf("selected protocol we don't have a transport for") - } - - return tpt.NewConn(nc, isServer, scope) -} - -func (t *Transport) GetTransportByKey(key string) (network.Multiplexer, bool) { - val, ok := t.tpts[key] - return val, ok -} diff --git a/p2p/net/swarm/dial_worker_test.go b/p2p/net/swarm/dial_worker_test.go index 728c81da3..14ff3d64d 100644 --- a/p2p/net/swarm/dial_worker_test.go +++ b/p2p/net/swarm/dial_worker_test.go @@ -15,7 +15,6 @@ import ( "github.com/libp2p/go-libp2p/core/sec/insecure" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem" - msmux "github.com/libp2p/go-libp2p/p2p/muxer/muxer-multistream" "github.com/libp2p/go-libp2p/p2p/muxer/yamux" csms "github.com/libp2p/go-libp2p/p2p/net/conn-security-multistream" tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" @@ -79,9 +78,7 @@ func makeUpgrader(t *testing.T, n *Swarm) transport.Upgrader { secMuxer := new(csms.SSMuxer) secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(insecure.ID, id, pk)) - stMuxer := msmux.NewBlankTransport() - stMuxer.AddTransport("/yamux/1.0.0", yamux.DefaultTransport) - u, err := tptu.New(secMuxer, stMuxer, nil, nil, nil) + u, err := tptu.New(secMuxer, []tptu.StreamMuxer{{ID: "/yamux/1.0.0", Muxer: yamux.DefaultTransport}}, nil, nil, nil) require.NoError(t, err) return u } diff --git a/p2p/net/swarm/testing/testing.go b/p2p/net/swarm/testing/testing.go index a28b488ef..8b9420c44 100644 --- a/p2p/net/swarm/testing/testing.go +++ b/p2p/net/swarm/testing/testing.go @@ -15,7 +15,6 @@ import ( "github.com/libp2p/go-libp2p/core/sec/insecure" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem" - msmux "github.com/libp2p/go-libp2p/p2p/muxer/muxer-multistream" "github.com/libp2p/go-libp2p/p2p/muxer/yamux" csms "github.com/libp2p/go-libp2p/p2p/net/conn-security-multistream" "github.com/libp2p/go-libp2p/p2p/net/swarm" @@ -105,9 +104,7 @@ func GenUpgrader(t *testing.T, n *swarm.Swarm, connGater connmgr.ConnectionGater secMuxer := new(csms.SSMuxer) secMuxer.AddTransport(insecure.ID, insecure.NewWithIdentity(insecure.ID, id, pk)) - stMuxer := msmux.NewBlankTransport() - stMuxer.AddTransport("/yamux/1.0.0", yamux.DefaultTransport) - u, err := tptu.New(secMuxer, stMuxer, nil, nil, connGater, opts...) + u, err := tptu.New(secMuxer, []tptu.StreamMuxer{{ID: "/yamux/1.0.0", Muxer: yamux.DefaultTransport}}, nil, nil, connGater, opts...) require.NoError(t, err) return u } diff --git a/p2p/net/upgrader/listener_test.go b/p2p/net/upgrader/listener_test.go index 67fb292b1..5b5410753 100644 --- a/p2p/net/upgrader/listener_test.go +++ b/p2p/net/upgrader/listener_test.go @@ -134,7 +134,7 @@ func TestConnectionsClosedIfNotAccepted(t *testing.T) { func TestFailedUpgradeOnListen(t *testing.T) { require := require.New(t) - id, u := createUpgraderWithMuxer(t, &errorMuxer{}, nil, nil) + id, u := createUpgraderWithMuxers(t, []upgrader.StreamMuxer{{ID: "errorMuxer", Muxer: &errorMuxer{}}}, nil, nil) ln := createListener(t, u) errCh := make(chan error) @@ -225,7 +225,7 @@ func TestConcurrentAccept(t *testing.T) { var num = 3 * upgrader.AcceptQueueLength blockingMuxer := newBlockingMuxer() - id, u := createUpgraderWithMuxer(t, blockingMuxer, nil, nil) + id, u := createUpgraderWithMuxers(t, []upgrader.StreamMuxer{{ID: "blockingMuxer", Muxer: blockingMuxer}}, nil, nil) ln := createListener(t, u) defer ln.Close() diff --git a/p2p/net/upgrader/upgrader.go b/p2p/net/upgrader/upgrader.go index f18972068..5351c3f2f 100644 --- a/p2p/net/upgrader/upgrader.go +++ b/p2p/net/upgrader/upgrader.go @@ -11,12 +11,13 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" ipnet "github.com/libp2p/go-libp2p/core/pnet" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/core/transport" - msmux "github.com/libp2p/go-libp2p/p2p/muxer/muxer-multistream" "github.com/libp2p/go-libp2p/p2p/net/pnet" manet "github.com/multiformats/go-multiaddr/net" + mss "github.com/multiformats/go-multistream" ) // ErrNilPeer is returned when attempting to upgrade an outbound connection @@ -26,7 +27,10 @@ var ErrNilPeer = errors.New("nil peer") // AcceptQueueLength is the number of connections to fully setup before not accepting any new connections var AcceptQueueLength = 16 -const defaultAcceptTimeout = 15 * time.Second +const ( + defaultAcceptTimeout = 15 * time.Second + defaultNegotiateTimeout = 60 * time.Second +) type Option func(*upgrader) error @@ -37,16 +41,24 @@ func WithAcceptTimeout(t time.Duration) Option { } } +type StreamMuxer struct { + ID protocol.ID + Muxer network.Multiplexer +} + // Upgrader is a multistream upgrader that can upgrade an underlying connection // to a full transport connection (secure and multiplexed). type upgrader struct { secure sec.SecureMuxer - muxer network.Multiplexer psk ipnet.PSK connGater connmgr.ConnectionGater rcmgr network.ResourceManager + msmuxer *mss.MultistreamMuxer + muxers []StreamMuxer + muxerIDs []string + // AcceptTimeout is the maximum duration an Accept is allowed to take. // This includes the time between accepting the raw network connection, // protocol selection as well as the handshake, if applicable. @@ -57,14 +69,15 @@ type upgrader struct { var _ transport.Upgrader = &upgrader{} -func New(secureMuxer sec.SecureMuxer, muxer network.Multiplexer, psk ipnet.PSK, rcmgr network.ResourceManager, connGater connmgr.ConnectionGater, opts ...Option) (transport.Upgrader, error) { +func New(secureMuxer sec.SecureMuxer, muxers []StreamMuxer, psk ipnet.PSK, rcmgr network.ResourceManager, connGater connmgr.ConnectionGater, opts ...Option) (transport.Upgrader, error) { u := &upgrader{ secure: secureMuxer, - muxer: muxer, acceptTimeout: defaultAcceptTimeout, rcmgr: rcmgr, connGater: connGater, psk: psk, + msmuxer: mss.NewMultistreamMuxer(), + muxers: muxers, } for _, opt := range opts { if err := opt(u); err != nil { @@ -74,6 +87,11 @@ func New(secureMuxer sec.SecureMuxer, muxer network.Multiplexer, psk ipnet.PSK, if u.rcmgr == nil { u.rcmgr = &network.NullResourceManager{} } + u.muxerIDs = make([]string, 0, len(muxers)) + for _, m := range muxers { + u.msmuxer.AddHandler(string(m.ID), nil) + u.muxerIDs = append(u.muxerIDs, string(m.ID)) + } return u, nil } @@ -177,17 +195,54 @@ func (u *upgrader) setupSecurity(ctx context.Context, conn net.Conn, p peer.ID, return u.secure.SecureOutbound(ctx, conn, p) } +func (u *upgrader) negotiateMuxer(nc net.Conn, isServer bool) (*StreamMuxer, error) { + if err := nc.SetDeadline(time.Now().Add(defaultNegotiateTimeout)); err != nil { + return nil, err + } + + var proto string + if isServer { + selected, _, err := u.msmuxer.Negotiate(nc) + if err != nil { + return nil, err + } + proto = selected + } else { + selected, err := mss.SelectOneOf(u.muxerIDs, nc) + if err != nil { + return nil, err + } + proto = selected + } + + if err := nc.SetDeadline(time.Time{}); err != nil { + return nil, err + } + + if m := u.getMuxerByID(proto); m != nil { + return m, nil + } + return nil, fmt.Errorf("selected protocol we don't have a transport for") +} + +func (u *upgrader) getMuxerByID(id string) *StreamMuxer { + for _, m := range u.muxers { + if string(m.ID) == id { + return &m + } + } + return nil +} + func (u *upgrader) setupMuxer(ctx context.Context, conn sec.SecureConn, server bool, scope network.PeerScope) (network.MuxedConn, error) { - msmuxer, ok := u.muxer.(*msmux.Transport) muxerSelected := conn.ConnState().NextProto // Use muxer selected from security handshake if available. Otherwise fall back to multistream-selection. - if ok && len(muxerSelected) > 0 { - tpt, ok := msmuxer.GetTransportByKey(muxerSelected) - if !ok { + if len(muxerSelected) > 0 { + m := u.getMuxerByID(muxerSelected) + if m == nil { return nil, fmt.Errorf("selected a muxer we don't know: %s", muxerSelected) } - - return tpt.NewConn(conn, server, scope) + return m.Muxer.NewConn(conn, server, scope) } done := make(chan struct{}) @@ -197,7 +252,12 @@ func (u *upgrader) setupMuxer(ctx context.Context, conn sec.SecureConn, server b // TODO: The muxer should take a context. go func() { defer close(done) - smconn, err = u.muxer.NewConn(conn, server, scope) + var m *StreamMuxer + m, err = u.negotiateMuxer(conn, server) + if err != nil { + return + } + smconn, err = m.Muxer.NewConn(conn, server, scope) }() select { diff --git a/p2p/net/upgrader/upgrader_test.go b/p2p/net/upgrader/upgrader_test.go index d39d36022..106752ab6 100644 --- a/p2p/net/upgrader/upgrader_test.go +++ b/p2p/net/upgrader/upgrader_test.go @@ -24,27 +24,27 @@ import ( ) func createUpgrader(t *testing.T) (peer.ID, transport.Upgrader) { - return createUpgraderWithMuxer(t, &negotiatingMuxer{}, nil, nil) + return createUpgraderWithMuxers(t, []upgrader.StreamMuxer{{ID: "negotiate", Muxer: &negotiatingMuxer{}}}, nil, nil) } func createUpgraderWithConnGater(t *testing.T, connGater connmgr.ConnectionGater) (peer.ID, transport.Upgrader) { - return createUpgraderWithMuxer(t, &negotiatingMuxer{}, nil, connGater) + return createUpgraderWithMuxers(t, []upgrader.StreamMuxer{{ID: "negotiate", Muxer: &negotiatingMuxer{}}}, nil, connGater) } func createUpgraderWithResourceManager(t *testing.T, rcmgr network.ResourceManager) (peer.ID, transport.Upgrader) { - return createUpgraderWithMuxer(t, &negotiatingMuxer{}, rcmgr, nil) + return createUpgraderWithMuxers(t, []upgrader.StreamMuxer{{ID: "negotiate", Muxer: &negotiatingMuxer{}}}, rcmgr, nil) } func createUpgraderWithOpts(t *testing.T, opts ...upgrader.Option) (peer.ID, transport.Upgrader) { - return createUpgraderWithMuxer(t, &negotiatingMuxer{}, nil, nil, opts...) + return createUpgraderWithMuxers(t, []upgrader.StreamMuxer{{ID: "negotiate", Muxer: &negotiatingMuxer{}}}, nil, nil, opts...) } -func createUpgraderWithMuxer(t *testing.T, muxer network.Multiplexer, rcmgr network.ResourceManager, connGater connmgr.ConnectionGater, opts ...upgrader.Option) (peer.ID, transport.Upgrader) { +func createUpgraderWithMuxers(t *testing.T, muxers []upgrader.StreamMuxer, rcmgr network.ResourceManager, connGater connmgr.ConnectionGater, opts ...upgrader.Option) (peer.ID, transport.Upgrader) { priv, _, err := test.RandTestKeyPair(crypto.Ed25519, 256) require.NoError(t, err) id, err := peer.IDFromPrivateKey(priv) require.NoError(t, err) - u, err := upgrader.New(&MuxAdapter{tpt: insecure.NewWithIdentity(insecure.ID, id, priv)}, muxer, nil, rcmgr, connGater, opts...) + u, err := upgrader.New(&MuxAdapter{tpt: insecure.NewWithIdentity(insecure.ID, id, priv)}, muxers, nil, rcmgr, connGater, opts...) require.NoError(t, err) return id, u } @@ -177,7 +177,7 @@ func TestOutboundResourceManagement(t *testing.T) { }) t.Run("failed negotiation", func(t *testing.T) { - id, upgrader := createUpgraderWithMuxer(t, &errorMuxer{}, nil, nil) + id, upgrader := createUpgraderWithMuxers(t, []upgrader.StreamMuxer{{ID: "errorMuxer", Muxer: &errorMuxer{}}}, nil, nil) ln := createListener(t, upgrader) defer ln.Close() diff --git a/p2p/security/noise/transport.go b/p2p/security/noise/transport.go index e436c46aa..e5640197e 100644 --- a/p2p/security/noise/transport.go +++ b/p2p/security/noise/transport.go @@ -9,6 +9,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" + tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" "github.com/libp2p/go-libp2p/p2p/security/noise/pb" manet "github.com/multiformats/go-multiaddr/net" @@ -29,15 +30,15 @@ var _ sec.SecureTransport = &Transport{} // New creates a new Noise transport using the given private key as its // libp2p identity key. -func New(id protocol.ID, privkey crypto.PrivKey, muxers []protocol.ID) (*Transport, error) { +func New(id protocol.ID, privkey crypto.PrivKey, muxers []tptu.StreamMuxer) (*Transport, error) { localID, err := peer.IDFromPrivateKey(privkey) if err != nil { return nil, err } smuxers := make([]string, 0, len(muxers)) - for _, muxer := range muxers { - smuxers = append(smuxers, string(muxer)) + for _, m := range muxers { + smuxers = append(smuxers, string(m.ID)) } return &Transport{ diff --git a/p2p/security/tls/transport.go b/p2p/security/tls/transport.go index 754a33513..bb134e960 100644 --- a/p2p/security/tls/transport.go +++ b/p2p/security/tls/transport.go @@ -15,6 +15,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" + tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" manet "github.com/multiformats/go-multiaddr/net" ) @@ -35,16 +36,20 @@ type Transport struct { var _ sec.SecureTransport = &Transport{} // New creates a TLS encrypted transport -func New(id protocol.ID, key ci.PrivKey, muxers []protocol.ID) (*Transport, error) { +func New(id protocol.ID, key ci.PrivKey, muxers []tptu.StreamMuxer) (*Transport, error) { localPeer, err := peer.IDFromPrivateKey(key) if err != nil { return nil, err } + muxerIDs := make([]protocol.ID, 0, len(muxers)) + for _, m := range muxers { + muxerIDs = append(muxerIDs, m.ID) + } t := &Transport{ protocolID: id, localPeer: localPeer, privKey: key, - muxers: muxers, + muxers: muxerIDs, } identity, err := NewIdentity(key) diff --git a/p2p/security/tls/transport_test.go b/p2p/security/tls/transport_test.go index 8c9fa7ced..986dc63cd 100644 --- a/p2p/security/tls/transport_test.go +++ b/p2p/security/tls/transport_test.go @@ -24,6 +24,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" + tptu "github.com/libp2p/go-libp2p/p2p/net/upgrader" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -183,16 +184,15 @@ type testcase struct { } func TestHandshakeWithNextProtoSucceeds(t *testing.T) { - tests := []testcase{ {clientProtos: nil, serverProtos: nil, expectedResult: ""}, - {[]protocol.ID{"muxer1/1.0.0", "muxer2/1.0.1"}, []protocol.ID{"muxer2/1.0.1", "muxer1/1.0.0"}, "muxer2/1.0.1"}, - {[]protocol.ID{"muxer1/1.0.0", "muxer2/1.0.1", "libp2p"}, []protocol.ID{"muxer2/1.0.1", "muxer1/1.0.0", "libp2p"}, "muxer2/1.0.1"}, - {[]protocol.ID{"muxer1/1.0.0", "libp2p"}, []protocol.ID{"libp2p"}, ""}, - {[]protocol.ID{"libp2p"}, []protocol.ID{"libp2p"}, ""}, - {[]protocol.ID{"muxer1"}, []protocol.ID{}, ""}, - {[]protocol.ID{}, []protocol.ID{"muxer1"}, ""}, - {[]protocol.ID{"muxer2"}, []protocol.ID{"muxer1"}, ""}, + {clientProtos: []protocol.ID{"muxer1", "muxer2"}, serverProtos: []protocol.ID{"muxer2", "muxer1"}, expectedResult: "muxer2"}, + {clientProtos: []protocol.ID{"muxer1", "muxer2", "libp2p"}, serverProtos: []protocol.ID{"muxer2", "muxer1", "libp2p"}, expectedResult: "muxer2"}, + {clientProtos: []protocol.ID{"muxer1", "libp2p"}, serverProtos: []protocol.ID{"libp2p"}, expectedResult: ""}, + {clientProtos: []protocol.ID{"libp2p"}, serverProtos: []protocol.ID{"libp2p"}, expectedResult: ""}, + {clientProtos: []protocol.ID{"muxer1"}, serverProtos: []protocol.ID{}, expectedResult: ""}, + {clientProtos: []protocol.ID{}, serverProtos: []protocol.ID{"muxer1"}, expectedResult: ""}, + {clientProtos: []protocol.ID{"muxer2"}, serverProtos: []protocol.ID{"muxer1"}, expectedResult: ""}, } clientID, clientKey := createPeer(t) @@ -240,9 +240,17 @@ func TestHandshakeWithNextProtoSucceeds(t *testing.T) { // Iterate through the NextProto combinations. for _, test := range tests { - clientTransport, err := New(ID, clientKey, test.clientProtos) + clientMuxers := make([]tptu.StreamMuxer, 0, len(test.clientProtos)) + for _, id := range test.clientProtos { + clientMuxers = append(clientMuxers, tptu.StreamMuxer{ID: id}) + } + clientTransport, err := New(ID, clientKey, clientMuxers) require.NoError(t, err) - serverTransport, err := New(ID, serverKey, test.serverProtos) + serverMuxers := make([]tptu.StreamMuxer, 0, len(test.clientProtos)) + for _, id := range test.serverProtos { + serverMuxers = append(serverMuxers, tptu.StreamMuxer{ID: id}) + } + serverTransport, err := New(ID, serverKey, serverMuxers) require.NoError(t, err) t.Run("TLS handshake with ALPN extension", func(t *testing.T) { diff --git a/p2p/transport/tcp/tcp_test.go b/p2p/transport/tcp/tcp_test.go index eec1657dd..920404632 100644 --- a/p2p/transport/tcp/tcp_test.go +++ b/p2p/transport/tcp/tcp_test.go @@ -22,16 +22,18 @@ import ( "github.com/stretchr/testify/require" ) +var muxers = []tptu.StreamMuxer{{ID: "/yamux", Muxer: yamux.DefaultTransport}} + func TestTcpTransport(t *testing.T) { for i := 0; i < 2; i++ { peerA, ia := makeInsecureMuxer(t) _, ib := makeInsecureMuxer(t) - ua, err := tptu.New(ia, yamux.DefaultTransport, nil, nil, nil) + ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) ta, err := NewTCPTransport(ua, nil) require.NoError(t, err) - ub, err := tptu.New(ib, yamux.DefaultTransport, nil, nil, nil) + ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) tb, err := NewTCPTransport(ub, nil) require.NoError(t, err) @@ -48,11 +50,11 @@ func TestTcpTransportWithMetrics(t *testing.T) { peerA, ia := makeInsecureMuxer(t) _, ib := makeInsecureMuxer(t) - ua, err := tptu.New(ia, yamux.DefaultTransport, nil, nil, nil) + ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) ta, err := NewTCPTransport(ua, nil, WithMetrics()) require.NoError(t, err) - ub, err := tptu.New(ib, yamux.DefaultTransport, nil, nil, nil) + ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) tb, err := NewTCPTransport(ub, nil, WithMetrics()) require.NoError(t, err) @@ -68,7 +70,7 @@ func TestResourceManager(t *testing.T) { peerA, ia := makeInsecureMuxer(t) _, ib := makeInsecureMuxer(t) - ua, err := tptu.New(ia, yamux.DefaultTransport, nil, nil, nil) + ua, err := tptu.New(ia, muxers, nil, nil, nil) require.NoError(t, err) ta, err := NewTCPTransport(ua, nil) require.NoError(t, err) @@ -76,7 +78,7 @@ func TestResourceManager(t *testing.T) { require.NoError(t, err) defer ln.Close() - ub, err := tptu.New(ib, yamux.DefaultTransport, nil, nil, nil) + ub, err := tptu.New(ib, muxers, nil, nil, nil) require.NoError(t, err) rcmgr := mocknetwork.NewMockResourceManager(ctrl) tb, err := NewTCPTransport(ub, rcmgr) diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 714fe89a6..016dbe59b 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -39,7 +39,7 @@ import ( func newUpgrader(t *testing.T) (peer.ID, transport.Upgrader) { t.Helper() id, m := newInsecureMuxer(t) - u, err := tptu.New(m, yamux.DefaultTransport, nil, nil, nil) + u, err := tptu.New(m, []tptu.StreamMuxer{{ID: "/yamux", Muxer: yamux.DefaultTransport}}, nil, nil, nil) if err != nil { t.Fatal(err) } @@ -49,7 +49,7 @@ func newUpgrader(t *testing.T) (peer.ID, transport.Upgrader) { func newSecureUpgrader(t *testing.T) (peer.ID, transport.Upgrader) { t.Helper() id, m := newSecureMuxer(t) - u, err := tptu.New(m, yamux.DefaultTransport, nil, nil, nil) + u, err := tptu.New(m, []tptu.StreamMuxer{{ID: "/yamux", Muxer: yamux.DefaultTransport}}, nil, nil, nil) if err != nil { t.Fatal(err) }