mirror of
https://github.com/libp2p/go-libp2p.git
synced 2025-10-12 19:40:29 +08:00
fix(tcpreuse): handle connection that failed to be sampled (#3036)
Co-authored-by: Marco Munizaga <git@marcopolo.io>
This commit is contained in:
@@ -9,8 +9,11 @@ import (
|
|||||||
manet "github.com/multiformats/go-multiaddr/net"
|
manet "github.com/multiformats/go-multiaddr/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// This is readiung the first 3 bytes of the packet. It should be instant.
|
// This is reading the first 3 bytes of the first packet after the handshake.
|
||||||
const identifyConnTimeout = 1 * time.Second
|
// It's set to the default TCP connect timeout in the TCP Transport.
|
||||||
|
//
|
||||||
|
// A var so we can change it in tests.
|
||||||
|
var identifyConnTimeout = 5 * time.Second
|
||||||
|
|
||||||
type DemultiplexedConnType int
|
type DemultiplexedConnType int
|
||||||
|
|
||||||
@@ -40,35 +43,35 @@ func (t DemultiplexedConnType) IsKnown() bool {
|
|||||||
|
|
||||||
// identifyConnType attempts to identify the connection type by peeking at the
|
// identifyConnType attempts to identify the connection type by peeking at the
|
||||||
// first few bytes.
|
// first few bytes.
|
||||||
// It Callers must not use the passed in Conn after this
|
// Its Callers must not use the passed in Conn after this function returns.
|
||||||
// function returns. if an error is returned, the connection will be closed.
|
// If an error is returned, the connection will be closed.
|
||||||
func identifyConnType(c manet.Conn) (DemultiplexedConnType, manet.Conn, error) {
|
func identifyConnType(c manet.Conn) (DemultiplexedConnType, manet.Conn, error) {
|
||||||
if err := c.SetReadDeadline(time.Now().Add(identifyConnTimeout)); err != nil {
|
if err := c.SetReadDeadline(time.Now().Add(identifyConnTimeout)); err != nil {
|
||||||
closeErr := c.Close()
|
closeErr := c.Close()
|
||||||
return 0, nil, errors.Join(err, closeErr)
|
return 0, nil, errors.Join(err, closeErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
s, c, err := sampledconn.PeekBytes(c)
|
s, peekedConn, err := sampledconn.PeekBytes(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
closeErr := c.Close()
|
closeErr := c.Close()
|
||||||
return 0, nil, errors.Join(err, closeErr)
|
return 0, nil, errors.Join(err, closeErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.SetReadDeadline(time.Time{}); err != nil {
|
if err := peekedConn.SetReadDeadline(time.Time{}); err != nil {
|
||||||
closeErr := c.Close()
|
closeErr := peekedConn.Close()
|
||||||
return 0, nil, errors.Join(err, closeErr)
|
return 0, nil, errors.Join(err, closeErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if IsMultistreamSelect(s) {
|
if IsMultistreamSelect(s) {
|
||||||
return DemultiplexedConnType_MultistreamSelect, c, nil
|
return DemultiplexedConnType_MultistreamSelect, peekedConn, nil
|
||||||
}
|
}
|
||||||
if IsTLS(s) {
|
if IsTLS(s) {
|
||||||
return DemultiplexedConnType_TLS, c, nil
|
return DemultiplexedConnType_TLS, peekedConn, nil
|
||||||
}
|
}
|
||||||
if IsHTTP(s) {
|
if IsHTTP(s) {
|
||||||
return DemultiplexedConnType_HTTP, c, nil
|
return DemultiplexedConnType_HTTP, peekedConn, nil
|
||||||
}
|
}
|
||||||
return DemultiplexedConnType_Unknown, c, nil
|
return DemultiplexedConnType_Unknown, peekedConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Matchers are implemented here instead of in the transports so we can easily fuzz them together.
|
// Matchers are implemented here instead of in the transports so we can easily fuzz them together.
|
||||||
|
@@ -231,8 +231,6 @@ func (m *multiplexedListener) run() error {
|
|||||||
t, c, err := identifyConnType(c)
|
t, c, err := identifyConnType(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
connScope.Done()
|
connScope.Done()
|
||||||
closeErr := c.Close()
|
|
||||||
err = errors.Join(err, closeErr)
|
|
||||||
log.Debugf("error demultiplexing connection: %s", err.Error())
|
log.Debugf("error demultiplexing connection: %s", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@@ -447,3 +447,30 @@ func TestListenerClose(t *testing.T) {
|
|||||||
testClose(listenAddr)
|
testClose(listenAddr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setDeferReset[T any](t testing.TB, ptr *T, val T) {
|
||||||
|
t.Helper()
|
||||||
|
orig := *ptr
|
||||||
|
*ptr = val
|
||||||
|
t.Cleanup(func() { *ptr = orig })
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHitTimeout asserts that we don't panic in case we fail to peek at the connection.
|
||||||
|
func TestHitTimeout(t *testing.T) {
|
||||||
|
setDeferReset(t, &identifyConnTimeout, 100*time.Millisecond)
|
||||||
|
// listen on port 0
|
||||||
|
cm := NewConnMgr(false, nil, nil)
|
||||||
|
|
||||||
|
listenAddr := ma.StringCast("/ip4/127.0.0.1/tcp/0")
|
||||||
|
ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer ml.Close()
|
||||||
|
|
||||||
|
tcpConn, err := net.Dial(ml.Addr().Network(), ml.Addr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Stall tcp conn for over the timeout.
|
||||||
|
time.Sleep(identifyConnTimeout + 100*time.Millisecond)
|
||||||
|
|
||||||
|
tcpConn.Close()
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user