diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 67a61253e..a4518c4f3 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -190,12 +190,21 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I var lastErr error for _, pid := range pids { s, err := h.newStream(ctx, p, pid) - if err == nil { - h.setPreferredProtocol(p, pid) - return s, nil + if err != nil { + lastErr = err + log.Infof("NewStream to %s for %s failed: %s", p, pid, err) + continue } - lastErr = err - log.Infof("NewStream to %s for %s failed: %s", p, pid, err) + + _, err = s.Read(nil) + if err != nil { + lastErr = err + log.Infof("NewStream to %s for %s failed (on read): %s", p, pid, err) + continue + } + + h.setPreferredProtocol(p, pid) + return s, nil } return nil, lastErr @@ -248,6 +257,8 @@ func (h *BasicHost) newStream(ctx context.Context, p peer.ID, pid protocol.ID) ( return nil, err } + s.SetProtocol(string(pid)) + logStream := mstream.WrapStream(s, pid, h.bwc) lzcon := msmux.NewMSSelect(logStream, string(pid)) diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index a7c893ef8..c52d9279c 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -4,7 +4,9 @@ import ( "bytes" "io" "testing" + "time" + host "github.com/libp2p/go-libp2p/p2p/host" inet "github.com/libp2p/go-libp2p/p2p/net" protocol "github.com/libp2p/go-libp2p/p2p/protocol" testutil "github.com/libp2p/go-libp2p/p2p/test/util" @@ -61,3 +63,156 @@ func TestHostSimple(t *testing.T) { t.Fatal("buf1 != buf3 -- %x != %x", buf1, buf3) } } + +func getHostPair(ctx context.Context, t *testing.T) (host.Host, host.Host) { + h1 := testutil.GenHostSwarm(t, ctx) + h2 := testutil.GenHostSwarm(t, ctx) + + h2pi := h2.Peerstore().PeerInfo(h2.ID()) + if err := h1.Connect(ctx, h2pi); err != nil { + t.Fatal(err) + } + + return h1, h2 +} + +func assertWait(t *testing.T, c chan string, exp string) { + select { + case proto := <-c: + if proto != exp { + t.Fatal("should have connected on ", exp) + } + case <-time.After(time.Second * 5): + t.Fatal("timeout waiting for stream") + } +} + +func TestHostProtoPreference(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h1, h2 := getHostPair(ctx, t) + defer h1.Close() + defer h2.Close() + + protoOld := protocol.ID("/testing") + protoNew := protocol.ID("/testing/1.1.0") + protoMinor := protocol.ID("/testing/1.2.0") + + connectedOn := make(chan string, 16) + + handler := func(s inet.Stream) { + connectedOn <- s.Protocol() + s.Close() + } + + h1.SetStreamHandler(protoOld, handler) + + s, err := h2.NewStream(ctx, h1.ID(), protoMinor, protoNew, protoOld) + if err != nil { + t.Fatal(err) + } + + assertWait(t, connectedOn, string(protoOld)) + s.Close() + + mfunc, err := host.MultistreamSemverMatcher(string(protoMinor)) + if err != nil { + t.Fatal(err) + } + + h1.SetStreamHandlerMatch(protoMinor, mfunc, handler) + + // remembered preference will be chosen first, even when the other side newly supports it + s2, err := h2.NewStream(ctx, h1.ID(), protoMinor, protoNew, protoOld) + if err != nil { + t.Fatal(err) + } + + // required to force 'lazy' handshake + _, err = s2.Write([]byte("hello")) + if err != nil { + t.Fatal(err) + } + + assertWait(t, connectedOn, string(protoOld)) + + s2.Close() + + s3, err := h2.NewStream(ctx, h1.ID(), protoMinor) + if err != nil { + t.Fatal(err) + } + + _, err = s3.Read(nil) + if err != nil { + t.Fatal(err) + } + + assertWait(t, connectedOn, string(protoMinor)) + s3.Close() +} + +func TestHostProtoMismatch(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h1, h2 := getHostPair(ctx, t) + defer h1.Close() + defer h2.Close() + + h1.SetStreamHandler("/super", func(s inet.Stream) { + t.Error("shouldnt get here") + s.Close() + }) + + _, err := h2.NewStream(ctx, h1.ID(), "/foo", "/bar", "/baz/1.0.0") + if err == nil { + t.Fatal("expected new stream to fail") + } +} + +func TestHostProtoPreknowledge(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h1 := testutil.GenHostSwarm(t, ctx) + h2 := testutil.GenHostSwarm(t, ctx) + + conn := make(chan string, 16) + handler := func(s inet.Stream) { + conn <- s.Protocol() + s.Close() + } + + h1.SetStreamHandler("/super", handler) + + h2pi := h2.Peerstore().PeerInfo(h2.ID()) + if err := h1.Connect(ctx, h2pi); err != nil { + t.Fatal(err) + } + defer h1.Close() + defer h2.Close() + + h1.SetStreamHandler("/foo", handler) + + s, err := h2.NewStream(ctx, h1.ID(), "/foo", "/bar", "/super") + if err != nil { + t.Fatal(err) + } + + select { + case <-conn: + t.Fatal("shouldnt have gotten connection yet, we should have a lazy stream") + case <-time.After(time.Millisecond * 50): + } + + _, err = s.Read(nil) + if err != nil { + t.Fatal(err) + } + + assertWait(t, conn, "/super") + + s.Close() +}