diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index a4518c4f3..3eb6d725b 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -116,10 +116,10 @@ func (h *BasicHost) newStreamHandler(s inet.Stream) { } return } + s.SetProtocol(protocol.ID(protoID)) - logStream := mstream.WrapStream(s, protocol.ID(protoID), h.bwc) + logStream := mstream.WrapStream(s, h.bwc) - s.SetProtocol(protoID) go handle(protoID, logStream) } @@ -155,7 +155,7 @@ func (h *BasicHost) IDService() *identify.IDService { func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler) { h.Mux().AddHandler(string(pid), func(p string, rwc io.ReadWriteCloser) error { is := rwc.(inet.Stream) - is.SetProtocol(p) + is.SetProtocol(protocol.ID(p)) handler(is) return nil }) @@ -166,7 +166,7 @@ func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler inet.StreamHandler func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler inet.StreamHandler) { h.Mux().AddHandlerWithFunc(string(pid), m, func(p string, rwc io.ReadWriteCloser) error { is := rwc.(inet.Stream) - is.SetProtocol(p) + is.SetProtocol(protocol.ID(p)) handler(is) return nil }) @@ -187,27 +187,26 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I return h.newStream(ctx, p, pref) } - var lastErr error + var protoStrs []string for _, pid := range pids { - s, err := h.newStream(ctx, p, pid) - if err != nil { - lastErr = err - log.Infof("NewStream to %s for %s failed: %s", p, pid, err) - continue - } - - _, err = s.Read(nil) - if err != nil { - lastErr = err - log.Infof("NewStream to %s for %s failed (on read): %s", p, pid, err) - continue - } - - h.setPreferredProtocol(p, pid) - return s, nil + protoStrs = append(protoStrs, string(pid)) } - return nil, lastErr + s, err := h.Network().NewStream(ctx, p) + if err != nil { + return nil, err + } + + selected, err := msmux.SelectOneOf(protoStrs, s) + if err != nil { + s.Close() + return nil, err + } + selpid := protocol.ID(selected) + s.SetProtocol(selpid) + h.setPreferredProtocol(p, selpid) + + return mstream.WrapStream(s, h.bwc), nil } func (h *BasicHost) preferredProtocol(p peer.ID, pids []protocol.ID) protocol.ID { @@ -257,9 +256,9 @@ func (h *BasicHost) newStream(ctx context.Context, p peer.ID, pid protocol.ID) ( return nil, err } - s.SetProtocol(string(pid)) + s.SetProtocol(pid) - logStream := mstream.WrapStream(s, pid, h.bwc) + logStream := mstream.WrapStream(s, h.bwc) lzcon := msmux.NewMSSelect(logStream, string(pid)) return &streamWrapper{ diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 5b56684bd..e52222d97 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -76,7 +76,7 @@ func getHostPair(ctx context.Context, t *testing.T) (host.Host, host.Host) { return h1, h2 } -func assertWait(t *testing.T, c chan string, exp string) { +func assertWait(t *testing.T, c chan protocol.ID, exp protocol.ID) { select { case proto := <-c: if proto != exp { @@ -99,7 +99,7 @@ func TestHostProtoPreference(t *testing.T) { protoNew := protocol.ID("/testing/1.1.0") protoMinor := protocol.ID("/testing/1.2.0") - connectedOn := make(chan string, 16) + connectedOn := make(chan protocol.ID, 16) handler := func(s inet.Stream) { connectedOn <- s.Protocol() @@ -113,10 +113,10 @@ func TestHostProtoPreference(t *testing.T) { t.Fatal(err) } - assertWait(t, connectedOn, string(protoOld)) + assertWait(t, connectedOn, protoOld) s.Close() - mfunc, err := host.MultistreamSemverMatcher(string(protoMinor)) + mfunc, err := host.MultistreamSemverMatcher(protoMinor) if err != nil { t.Fatal(err) } @@ -135,7 +135,7 @@ func TestHostProtoPreference(t *testing.T) { t.Fatal(err) } - assertWait(t, connectedOn, string(protoOld)) + assertWait(t, connectedOn, protoOld) s2.Close() @@ -144,12 +144,7 @@ func TestHostProtoPreference(t *testing.T) { t.Fatal(err) } - _, err = s3.Read(nil) - if err != nil { - t.Fatal(err) - } - - assertWait(t, connectedOn, string(protoMinor)) + assertWait(t, connectedOn, protoMinor) s3.Close() } @@ -179,7 +174,7 @@ func TestHostProtoPreknowledge(t *testing.T) { h1 := testutil.GenHostSwarm(t, ctx) h2 := testutil.GenHostSwarm(t, ctx) - conn := make(chan string, 16) + conn := make(chan protocol.ID, 16) handler := func(s inet.Stream) { conn <- s.Protocol() s.Close() diff --git a/p2p/host/match.go b/p2p/host/match.go index dfee37e26..571d652c2 100644 --- a/p2p/host/match.go +++ b/p2p/host/match.go @@ -1,13 +1,14 @@ package host import ( + "github.com/libp2p/go-libp2p/p2p/protocol" "strings" semver "github.com/coreos/go-semver/semver" ) -func MultistreamSemverMatcher(base string) (func(string) bool, error) { - parts := strings.Split(base, "/") +func MultistreamSemverMatcher(base protocol.ID) (func(string) bool, error) { + parts := strings.Split(string(base), "/") vers, err := semver.NewVersion(parts[len(parts)-1]) if err != nil { return nil, err diff --git a/p2p/metrics/stream/metered.go b/p2p/metrics/stream/metered.go index 2c7a4c6b9..14de24c57 100644 --- a/p2p/metrics/stream/metered.go +++ b/p2p/metrics/stream/metered.go @@ -19,18 +19,18 @@ type meteredStream struct { mesRecv metrics.StreamMeterCallback } -func newMeteredStream(base inet.Stream, pid protocol.ID, p peer.ID, recvCB, sentCB metrics.StreamMeterCallback) inet.Stream { +func newMeteredStream(base inet.Stream, p peer.ID, recvCB, sentCB metrics.StreamMeterCallback) inet.Stream { return &meteredStream{ Stream: base, mesSent: sentCB, mesRecv: recvCB, - protoKey: pid, + protoKey: base.Protocol(), peerKey: p, } } -func WrapStream(base inet.Stream, pid protocol.ID, bwc metrics.Reporter) inet.Stream { - return newMeteredStream(base, pid, base.Conn().RemotePeer(), bwc.LogRecvMessageStream, bwc.LogSentMessageStream) +func WrapStream(base inet.Stream, bwc metrics.Reporter) inet.Stream { + return newMeteredStream(base, base.Conn().RemotePeer(), bwc.LogRecvMessageStream, bwc.LogSentMessageStream) } func (s *meteredStream) Read(b []byte) (int, error) { diff --git a/p2p/metrics/stream/metered_test.go b/p2p/metrics/stream/metered_test.go index 5af586ee2..22b3f9370 100644 --- a/p2p/metrics/stream/metered_test.go +++ b/p2p/metrics/stream/metered_test.go @@ -24,6 +24,10 @@ func (fs *FakeStream) Write(b []byte) (int, error) { return len(b), nil } +func (fs *FakeStream) Protocol() protocol.ID { + return "TEST" +} + func TestCallbacksWork(t *testing.T) { fake := new(FakeStream) @@ -38,7 +42,7 @@ func TestCallbacksWork(t *testing.T) { recv += n } - ms := newMeteredStream(fake, protocol.ID("TEST"), peer.ID("PEER"), recvCB, sentCB) + ms := newMeteredStream(fake, peer.ID("PEER"), recvCB, sentCB) toWrite := int64(100000) toRead := int64(100000) diff --git a/p2p/net/interface.go b/p2p/net/interface.go index 12826ed8b..8ae607849 100644 --- a/p2p/net/interface.go +++ b/p2p/net/interface.go @@ -8,6 +8,7 @@ import ( ma "github.com/jbenet/go-multiaddr" "github.com/jbenet/goprocess" conn "github.com/libp2p/go-libp2p/p2p/net/conn" + protocol "github.com/libp2p/go-libp2p/p2p/protocol" context "golang.org/x/net/context" ) @@ -26,8 +27,8 @@ type Stream interface { io.Writer io.Closer - Protocol() string - SetProtocol(string) + Protocol() protocol.ID + SetProtocol(protocol.ID) // Conn returns the connection this stream is part of. Conn() Conn diff --git a/p2p/net/mock/mock_stream.go b/p2p/net/mock/mock_stream.go index 2de6df2f3..dac95688f 100644 --- a/p2p/net/mock/mock_stream.go +++ b/p2p/net/mock/mock_stream.go @@ -7,6 +7,7 @@ import ( process "github.com/jbenet/goprocess" inet "github.com/libp2p/go-libp2p/p2p/net" + protocol "github.com/libp2p/go-libp2p/p2p/protocol" ) // stream implements inet.Stream @@ -17,7 +18,7 @@ type stream struct { toDeliver chan *transportObject proc process.Process - protocol string + protocol protocol.ID } type transportObject struct { @@ -50,11 +51,11 @@ func (s *stream) Write(p []byte) (n int, err error) { return len(p), nil } -func (s *stream) Protocol() string { +func (s *stream) Protocol() protocol.ID { return s.protocol } -func (s *stream) SetProtocol(proto string) { +func (s *stream) SetProtocol(proto protocol.ID) { s.protocol = proto } diff --git a/p2p/net/swarm/swarm_stream.go b/p2p/net/swarm/swarm_stream.go index 86719819a..dc365e8c0 100644 --- a/p2p/net/swarm/swarm_stream.go +++ b/p2p/net/swarm/swarm_stream.go @@ -2,6 +2,7 @@ package swarm import ( inet "github.com/libp2p/go-libp2p/p2p/net" + protocol "github.com/libp2p/go-libp2p/p2p/protocol" ps "github.com/jbenet/go-peerstream" ) @@ -10,7 +11,7 @@ import ( // our Conn and Swarm (instead of just the ps.Conn and ps.Swarm) type Stream struct { stream *ps.Stream - protocol string + protocol protocol.ID } // Stream returns the underlying peerstream.Stream @@ -44,11 +45,11 @@ func (s *Stream) Close() error { return s.stream.Close() } -func (s *Stream) Protocol() string { +func (s *Stream) Protocol() protocol.ID { return s.protocol } -func (s *Stream) SetProtocol(p string) { +func (s *Stream) SetProtocol(p protocol.ID) { s.protocol = p } diff --git a/p2p/protocol/identify/id.go b/p2p/protocol/identify/id.go index cba8b8ecc..b86023298 100644 --- a/p2p/protocol/identify/id.go +++ b/p2p/protocol/identify/id.go @@ -86,8 +86,10 @@ func (ids *IDService) IdentifyConn(c inet.Conn) { return } + s.SetProtocol(ID) + bwc := ids.Host.GetBandwidthReporter() - s = mstream.WrapStream(s, ID, bwc) + s = mstream.WrapStream(s, bwc) // ok give the response to our handler. if err := msmux.SelectProtoOrFail(ID, s); err != nil { @@ -115,7 +117,7 @@ func (ids *IDService) RequestHandler(s inet.Stream) { c := s.Conn() bwc := ids.Host.GetBandwidthReporter() - s = mstream.WrapStream(s, ID, bwc) + s = mstream.WrapStream(s, bwc) w := ggio.NewDelimitedWriter(s) mes := pb.Identify{}