mirror of
https://github.com/libp2p/go-libp2p.git
synced 2025-10-15 12:42:26 +08:00
fix: basichost: Use NegotiationTimeout as fallback timeout for NewStream (#3020)
This commit is contained in:
@@ -122,9 +122,10 @@ type HostOpts struct {
|
|||||||
// MultistreamMuxer is essential for the *BasicHost and will use a sensible default value if omitted.
|
// MultistreamMuxer is essential for the *BasicHost and will use a sensible default value if omitted.
|
||||||
MultistreamMuxer *msmux.MultistreamMuxer[protocol.ID]
|
MultistreamMuxer *msmux.MultistreamMuxer[protocol.ID]
|
||||||
|
|
||||||
// NegotiationTimeout determines the read and write timeouts on streams.
|
// NegotiationTimeout determines the read and write timeouts when negotiating
|
||||||
// If 0 or omitted, it will use DefaultNegotiationTimeout.
|
// protocols for streams. If 0 or omitted, it will use
|
||||||
// If below 0, timeouts on streams will be deactivated.
|
// DefaultNegotiationTimeout. If below 0, timeouts on streams will be
|
||||||
|
// deactivated.
|
||||||
NegotiationTimeout time.Duration
|
NegotiationTimeout time.Duration
|
||||||
|
|
||||||
// AddrsFactory holds a function which can be used to override or filter the result of Addrs.
|
// AddrsFactory holds a function which can be used to override or filter the result of Addrs.
|
||||||
@@ -689,6 +690,14 @@ func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) {
|
|||||||
// to create one. If ProtocolID is "", writes no header.
|
// to create one. If ProtocolID is "", writes no header.
|
||||||
// (Thread-safe)
|
// (Thread-safe)
|
||||||
func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (str network.Stream, strErr error) {
|
func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.ID) (str network.Stream, strErr error) {
|
||||||
|
if _, ok := ctx.Deadline(); !ok {
|
||||||
|
if h.negtimeout > 0 {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithTimeout(ctx, h.negtimeout)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// If the caller wants to prevent the host from dialing, it should use the NoDial option.
|
// If the caller wants to prevent the host from dialing, it should use the NoDial option.
|
||||||
if nodial, _ := network.GetNoDial(ctx); !nodial {
|
if nodial, _ := network.GetNoDial(ctx); !nodial {
|
||||||
err := h.Connect(ctx, peer.AddrInfo{ID: p})
|
err := h.Connect(ctx, peer.AddrInfo{ID: p})
|
||||||
|
@@ -2,6 +2,7 @@ package basichost
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"reflect"
|
"reflect"
|
||||||
@@ -941,3 +942,56 @@ func TestTrimHostAddrList(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHostTimeoutNewStream(t *testing.T) {
|
||||||
|
h1, err := NewHost(swarmt.GenSwarm(t), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
h1.Start()
|
||||||
|
defer h1.Close()
|
||||||
|
|
||||||
|
const proto = "/testing"
|
||||||
|
h2 := swarmt.GenSwarm(t)
|
||||||
|
|
||||||
|
h2.SetStreamHandler(func(s network.Stream) {
|
||||||
|
// First message is multistream header. Just echo it
|
||||||
|
msHeader := []byte("\x19/multistream/1.0.0\n")
|
||||||
|
_, err := s.Read(msHeader)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
_, err = s.Write(msHeader)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, err := s.Read(buf)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
msgLen, varintN := binary.Uvarint(buf[:n])
|
||||||
|
buf = buf[varintN:]
|
||||||
|
proto := buf[:int(msgLen)]
|
||||||
|
if string(proto) == "/ipfs/id/1.0.0\n" {
|
||||||
|
// Signal we don't support identify
|
||||||
|
na := []byte("na\n")
|
||||||
|
n := binary.PutUvarint(buf, uint64(len(na)))
|
||||||
|
copy(buf[n:], na)
|
||||||
|
|
||||||
|
_, err = s.Write(buf[:int(n)+len(na)])
|
||||||
|
assert.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
// Stall
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
}
|
||||||
|
t.Log("Resetting")
|
||||||
|
s.Reset()
|
||||||
|
})
|
||||||
|
|
||||||
|
err = h1.Connect(context.Background(), peer.AddrInfo{
|
||||||
|
ID: h2.LocalPeer(),
|
||||||
|
Addrs: h2.ListenAddresses(),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// No context passed in, fallback to negtimeout
|
||||||
|
h1.negtimeout = time.Second
|
||||||
|
_, err = h1.NewStream(context.Background(), h2.LocalPeer(), proto)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorContains(t, err, "context deadline exceeded")
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user