diff --git a/core/network/conn.go b/core/network/conn.go index aa6b96f71..f82da29e6 100644 --- a/core/network/conn.go +++ b/core/network/conn.go @@ -87,6 +87,20 @@ type Conn interface { // IsClosed returns whether a connection is fully closed, so it can // be garbage collected. IsClosed() bool + + // As finds the first conn in Conn's wrapped types that matches target, and + // if one is found, sets target to that conn value and returns true. + // Otherwise, it returns false. Similar to errors.As. + // + // target must be a pointer to the type you are matching against. + // + // This is an EXPERIMENTAL API. Getting access to the underlying type can + // lead to hard to debug issues. For example, if you mutate connection state + // on the underlying type, hooks that relied on only mutating that state + // from the wrapped connection would never be called. + // + // You very likely do not need to use this method. + As(target any) bool } // ConnectionState holds information about the connection. diff --git a/core/network/mux.go b/core/network/mux.go index be61ccf62..8a62b81cb 100644 --- a/core/network/mux.go +++ b/core/network/mux.go @@ -137,6 +137,20 @@ type MuxedConn interface { // AcceptStream accepts a stream opened by the other side. AcceptStream() (MuxedStream, error) + + // As finds the first conn in MuxedConn's wrapped types that matches target, + // and if one is found, sets target to that conn value and returns true. + // Otherwise, it returns false. Similar to errors.As. + // + // target must be a pointer to the type you are matching against. + // + // This is an EXPERIMENTAL API. Getting access to the underlying type can + // lead to hard to debug issues. For example, if you mutate connection state + // on the underlying type, hooks that relied on only mutating that state + // from the wrapped connection would never be called. + // + // You very likely do not need to use this method. + As(target any) bool } // Multiplexer wraps a net.Conn with a stream multiplexing diff --git a/libp2p_test.go b/libp2p_test.go index df1c793d5..ba8c475ab 100644 --- a/libp2p_test.go +++ b/libp2p_test.go @@ -40,6 +40,10 @@ import ( libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" "github.com/libp2p/go-libp2p/p2p/transport/websocket" webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" + "github.com/libp2p/go-yamux/v5" + "github.com/pion/webrtc/v4" + quicgo "github.com/quic-go/quic-go" + wtgo "github.com/quic-go/webtransport-go" "go.uber.org/goleak" ma "github.com/multiformats/go-multiaddr" @@ -842,3 +846,76 @@ func BenchmarkAllAddrs(b *testing.B) { addrsHost.AllAddrs() } } + +func TestConnAs(t *testing.T) { + type testCase struct { + name string + listenAddr string + testAs func(t *testing.T, c network.Conn) + } + + testCases := []testCase{ + { + "QUIC", + "/ip4/0.0.0.0/udp/0/quic-v1", + func(t *testing.T, c network.Conn) { + var quicConn *quicgo.Conn + require.True(t, c.As(&quicConn)) + }, + }, + { + "TCP+Yamux", + "/ip4/0.0.0.0/tcp/0", + func(t *testing.T, c network.Conn) { + var yamuxSession *yamux.Session + require.True(t, c.As(&yamuxSession)) + }, + }, + { + "WebRTC", + "/ip4/0.0.0.0/udp/0/webrtc-direct", + func(t *testing.T, c network.Conn) { + var webrtcPC *webrtc.PeerConnection + require.True(t, c.As(&webrtcPC)) + }, + }, + { + "WebTransport Session", + "/ip4/0.0.0.0/udp/0/quic-v1/webtransport", + func(t *testing.T, c network.Conn) { + var s *wtgo.Session + require.True(t, c.As(&s)) + }, + }, + { + "WebTransport QUIC Conn", + "/ip4/0.0.0.0/udp/0/quic-v1/webtransport", + func(t *testing.T, c network.Conn) { + var quicConn *quicgo.Conn + require.True(t, c.As(&quicConn)) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + h1, err := New(ListenAddrStrings( + tc.listenAddr, + )) + require.NoError(t, err) + defer h1.Close() + h2, err := New(ListenAddrStrings( + tc.listenAddr, + )) + require.NoError(t, err) + defer h2.Close() + err = h1.Connect(context.Background(), peer.AddrInfo{ + ID: h2.ID(), + Addrs: h2.Addrs(), + }) + require.NoError(t, err) + c := h1.Network().ConnsToPeer(h2.ID())[0] + tc.testAs(t, c) + }) + } +} diff --git a/p2p/muxer/yamux/conn.go b/p2p/muxer/yamux/conn.go index 54a856e58..d33f3b00a 100644 --- a/p2p/muxer/yamux/conn.go +++ b/p2p/muxer/yamux/conn.go @@ -13,6 +13,14 @@ type conn yamux.Session var _ network.MuxedConn = &conn{} +func (c *conn) As(target any) bool { + if t, ok := target.(**yamux.Session); ok { + *t = (*yamux.Session)(c) + return true + } + return false +} + // NewMuxedConn constructs a new MuxedConn from a yamux.Session. func NewMuxedConn(m *yamux.Session) network.MuxedConn { return (*conn)(m) diff --git a/p2p/net/connmgr/connmgr_test.go b/p2p/net/connmgr/connmgr_test.go index e8e61914d..b9cc5c00e 100644 --- a/p2p/net/connmgr/connmgr_test.go +++ b/p2p/net/connmgr/connmgr_test.go @@ -818,6 +818,7 @@ func (m mockConn) NewStream(_ context.Context) (network.Stream, error) { panic(" func (m mockConn) GetStreams() []network.Stream { panic("implement me") } func (m mockConn) Scope() network.ConnScope { panic("implement me") } func (m mockConn) ConnState() network.ConnectionState { return network.ConnectionState{} } +func (m mockConn) As(_ any) bool { return false } func makeSegmentsWithPeerInfos(peerInfos peerInfos) *segments { var s = func() *segments { diff --git a/p2p/net/mock/mock_conn.go b/p2p/net/mock/mock_conn.go index 36c60f53f..216e0d9b1 100644 --- a/p2p/net/mock/mock_conn.go +++ b/p2p/net/mock/mock_conn.go @@ -86,6 +86,10 @@ func (c *conn) Close() error { return nil } +func (c *conn) As(_ any) bool { + return false +} + func (c *conn) teardown() { for _, s := range c.allStreams() { s.Reset() diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 503f62f7b..b97fe669a 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -833,6 +833,10 @@ func wrapWithMetrics(capableConn transport.CapableConn, metricsTracer MetricsTra return c } +func (c *connWithMetrics) As(target any) bool { + return c.CapableConn.As(target) +} + func (c *connWithMetrics) completedHandshake() { c.metricsTracer.CompletedHandshake(time.Since(c.opened), c.ConnState(), c.LocalMultiaddr()) } diff --git a/p2p/net/swarm/swarm_conn.go b/p2p/net/swarm/swarm_conn.go index 1d6cf96b4..7bcda66b9 100644 --- a/p2p/net/swarm/swarm_conn.go +++ b/p2p/net/swarm/swarm_conn.go @@ -42,6 +42,10 @@ type Conn struct { var _ network.Conn = &Conn{} +func (c *Conn) As(target any) bool { + return c.conn.As(target) +} + func (c *Conn) IsClosed() bool { return c.conn.IsClosed() } diff --git a/p2p/net/upgrader/conn.go b/p2p/net/upgrader/conn.go index 2cc4dcfbb..d8ed022e1 100644 --- a/p2p/net/upgrader/conn.go +++ b/p2p/net/upgrader/conn.go @@ -23,6 +23,10 @@ type transportConn struct { var _ transport.CapableConn = &transportConn{} +func (c *transportConn) As(target any) bool { + return c.MuxedConn.As(target) +} + func (t *transportConn) Transport() transport.Transport { return t.transport } diff --git a/p2p/transport/quic/conn.go b/p2p/transport/quic/conn.go index a8dba723f..0bd5e5ab3 100644 --- a/p2p/transport/quic/conn.go +++ b/p2p/transport/quic/conn.go @@ -25,6 +25,15 @@ type conn struct { remoteMultiaddr ma.Multiaddr } +func (c *conn) As(target any) bool { + if t, ok := target.(**quic.Conn); ok { + *t = c.quicConn + return true + } + + return false +} + var _ tpt.CapableConn = &conn{} // Close closes the connection. diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go index d75c309c5..bb97c55fd 100644 --- a/p2p/transport/webrtc/connection.go +++ b/p2p/transport/webrtc/connection.go @@ -132,6 +132,14 @@ func (c *connection) Close() error { return nil } +func (c *connection) As(target any) bool { + if target, ok := target.(**webrtc.PeerConnection); ok { + *target = c.pc + return true + } + return false +} + // CloseWithError closes the connection ignoring the error code. As there's no way to signal // the remote peer on closing the underlying peerconnection, we ignore the error code. func (c *connection) CloseWithError(_ network.ConnErrorCode) error { diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index 44b4d2fb8..fd4b1187b 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -89,3 +89,15 @@ func (c *conn) Transport() tpt.Transport { return c.transport } func (c *conn) ConnState() network.ConnectionState { return network.ConnectionState{Transport: "webtransport"} } + +func (c *conn) As(target any) bool { + if target, ok := target.(**quic.Conn); ok { + *target = c.qconn + return true + } + if target, ok := target.(**webtransport.Session); ok { + *target = c.session + return true + } + return false +}