fix(tcpreuse): handle connection that failed to be sampled (#3036)

Co-authored-by: Marco Munizaga <git@marcopolo.io>
This commit is contained in:
Adin Schmahmann
2024-11-20 19:50:42 -05:00
committed by GitHub
parent 1c9e5a1101
commit 773bedf877
3 changed files with 41 additions and 13 deletions

View File

@@ -9,8 +9,11 @@ import (
manet "github.com/multiformats/go-multiaddr/net"
)
// This is readiung the first 3 bytes of the packet. It should be instant.
const identifyConnTimeout = 1 * time.Second
// This is reading the first 3 bytes of the first packet after the handshake.
// It's set to the default TCP connect timeout in the TCP Transport.
//
// A var so we can change it in tests.
var identifyConnTimeout = 5 * time.Second
type DemultiplexedConnType int
@@ -40,35 +43,35 @@ func (t DemultiplexedConnType) IsKnown() bool {
// identifyConnType attempts to identify the connection type by peeking at the
// first few bytes.
// It Callers must not use the passed in Conn after this
// function returns. if an error is returned, the connection will be closed.
// Its Callers must not use the passed in Conn after this function returns.
// If an error is returned, the connection will be closed.
func identifyConnType(c manet.Conn) (DemultiplexedConnType, manet.Conn, error) {
if err := c.SetReadDeadline(time.Now().Add(identifyConnTimeout)); err != nil {
closeErr := c.Close()
return 0, nil, errors.Join(err, closeErr)
}
s, c, err := sampledconn.PeekBytes(c)
s, peekedConn, err := sampledconn.PeekBytes(c)
if err != nil {
closeErr := c.Close()
return 0, nil, errors.Join(err, closeErr)
}
if err := c.SetReadDeadline(time.Time{}); err != nil {
closeErr := c.Close()
if err := peekedConn.SetReadDeadline(time.Time{}); err != nil {
closeErr := peekedConn.Close()
return 0, nil, errors.Join(err, closeErr)
}
if IsMultistreamSelect(s) {
return DemultiplexedConnType_MultistreamSelect, c, nil
return DemultiplexedConnType_MultistreamSelect, peekedConn, nil
}
if IsTLS(s) {
return DemultiplexedConnType_TLS, c, nil
return DemultiplexedConnType_TLS, peekedConn, nil
}
if IsHTTP(s) {
return DemultiplexedConnType_HTTP, c, nil
return DemultiplexedConnType_HTTP, peekedConn, nil
}
return DemultiplexedConnType_Unknown, c, nil
return DemultiplexedConnType_Unknown, peekedConn, nil
}
// Matchers are implemented here instead of in the transports so we can easily fuzz them together.

View File

@@ -231,8 +231,6 @@ func (m *multiplexedListener) run() error {
t, c, err := identifyConnType(c)
if err != nil {
connScope.Done()
closeErr := c.Close()
err = errors.Join(err, closeErr)
log.Debugf("error demultiplexing connection: %s", err.Error())
return
}

View File

@@ -447,3 +447,30 @@ func TestListenerClose(t *testing.T) {
testClose(listenAddr)
}
}
func setDeferReset[T any](t testing.TB, ptr *T, val T) {
t.Helper()
orig := *ptr
*ptr = val
t.Cleanup(func() { *ptr = orig })
}
// TestHitTimeout asserts that we don't panic in case we fail to peek at the connection.
func TestHitTimeout(t *testing.T) {
setDeferReset(t, &identifyConnTimeout, 100*time.Millisecond)
// listen on port 0
cm := NewConnMgr(false, nil, nil)
listenAddr := ma.StringCast("/ip4/127.0.0.1/tcp/0")
ml, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
require.NoError(t, err)
defer ml.Close()
tcpConn, err := net.Dial(ml.Addr().Network(), ml.Addr().String())
require.NoError(t, err)
// Stall tcp conn for over the timeout.
time.Sleep(identifyConnTimeout + 100*time.Millisecond)
tcpConn.Close()
}