feat: WebRTC reuse QUIC conn (#2889)

* feat: WebRTC reuse QUIC conn

* Fix transport constructor in test

* Move provide to where the transports are
This commit is contained in:
Marco Munizaga
2024-08-01 07:36:46 -07:00
committed by GitHub
parent ae1645d24e
commit db41da3b26
8 changed files with 269 additions and 9 deletions

View File

@@ -5,6 +5,7 @@ import (
"crypto/rand"
"errors"
"fmt"
"net"
"time"
"github.com/libp2p/go-libp2p/core/connmgr"
@@ -35,10 +36,12 @@ import (
relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
"github.com/libp2p/go-libp2p/p2p/protocol/holepunch"
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
"github.com/prometheus/client_golang/prometheus"
ma "github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/quic-go/quic-go"
"go.uber.org/fx"
"go.uber.org/fx/fxevent"
@@ -284,6 +287,29 @@ func (cfg *Config) addTransports() ([]fx.Option, error) {
fx.Provide(func() pnet.PSK { return cfg.PSK }),
fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }),
fx.Provide(func() *madns.Resolver { return cfg.MultiaddrResolver }),
fx.Provide(func(cm *quicreuse.ConnManager, sw *swarm.Swarm) libp2pwebrtc.ListenUDPFn {
hasQuicAddrPortFor := func(network string, laddr *net.UDPAddr) bool {
quicAddrPorts := map[string]struct{}{}
for _, addr := range sw.ListenAddresses() {
if _, err := addr.ValueForProtocol(ma.P_QUIC_V1); err == nil {
netw, addr, err := manet.DialArgs(addr)
if err != nil {
return false
}
quicAddrPorts[netw+"_"+addr] = struct{}{}
}
}
_, ok := quicAddrPorts[network+"_"+laddr.String()]
return ok
}
return func(network string, laddr *net.UDPAddr) (net.PacketConn, error) {
if hasQuicAddrPortFor(network, laddr) {
return cm.SharedNonQUICPacketConn(network, laddr)
}
return net.ListenUDP(network, laddr)
}
}),
}
fxopts = append(fxopts, cfg.Transports...)
if cfg.Insecure {

View File

@@ -6,6 +6,8 @@ import (
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"testing"
"time"
@@ -24,7 +26,9 @@ import (
"github.com/libp2p/go-libp2p/p2p/security/noise"
tls "github.com/libp2p/go-libp2p/p2p/security/tls"
quic "github.com/libp2p/go-libp2p/p2p/transport/quic"
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc"
webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
"go.uber.org/goleak"
@@ -416,6 +420,7 @@ func TestMain(m *testing.M) {
m,
// This will return eventually (5s timeout) but doesn't take a context.
goleak.IgnoreAnyFunction("github.com/koron/go-ssdp.Search"),
goleak.IgnoreAnyFunction("github.com/pion/sctp.(*Stream).SetReadDeadline.func1"),
// Logging & Stats
goleak.IgnoreTopFunction("github.com/ipfs/go-log/v2/writer.(*MirrorWriter).logRoutine"),
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
@@ -490,3 +495,84 @@ func TestHostAddrsFactoryAddsCerthashes(t *testing.T) {
}, 5*time.Second, 50*time.Millisecond)
h.Close()
}
func TestWebRTCReuseAddrWithQUIC(t *testing.T) {
order := [][]string{
{"/ip4/127.0.0.1/udp/54322/quic-v1", "/ip4/127.0.0.1/udp/54322/webrtc-direct"},
{"/ip4/127.0.0.1/udp/54322/webrtc-direct", "/ip4/127.0.0.1/udp/54322/quic-v1"},
// We do not support WebRTC automatically reusing QUIC addresses if port is not specified, yet.
// {"/ip4/127.0.0.1/udp/0/webrtc-direct", "/ip4/127.0.0.1/udp/0/quic-v1"},
}
for i, addrs := range order {
t.Run("Order "+strconv.Itoa(i), func(t *testing.T) {
h1, err := New(ListenAddrStrings(addrs...), Transport(quic.NewTransport), Transport(libp2pwebrtc.New))
require.NoError(t, err)
defer h1.Close()
seenPorts := make(map[string]struct{})
for _, addr := range h1.Addrs() {
s, err := addr.ValueForProtocol(ma.P_UDP)
require.NoError(t, err)
seenPorts[s] = struct{}{}
}
require.Len(t, seenPorts, 1)
quicClient, err := New(NoListenAddrs, Transport(quic.NewTransport))
require.NoError(t, err)
defer quicClient.Close()
webrtcClient, err := New(NoListenAddrs, Transport(libp2pwebrtc.New))
require.NoError(t, err)
defer webrtcClient.Close()
for _, client := range []host.Host{quicClient, webrtcClient} {
err := client.Connect(context.Background(), peer.AddrInfo{ID: h1.ID(), Addrs: h1.Addrs()})
require.NoError(t, err)
}
t.Run("quic client can connect", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
p := ping.NewPingService(quicClient)
resCh := p.Ping(ctx, h1.ID())
res := <-resCh
require.NoError(t, res.Error)
})
t.Run("webrtc client can connect", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
p := ping.NewPingService(webrtcClient)
resCh := p.Ping(ctx, h1.ID())
res := <-resCh
require.NoError(t, res.Error)
})
})
}
swapPort := func(addrStrs []string, newPort string) []string {
out := make([]string, 0, len(addrStrs))
for _, addrStr := range addrStrs {
out = append(out, strings.Replace(addrStr, "54322", newPort, 1))
}
return out
}
t.Run("setup with no reuseport. Should fail", func(t *testing.T) {
h1, err := New(ListenAddrStrings(swapPort(order[0], "54323")...), Transport(quic.NewTransport), Transport(libp2pwebrtc.New), QUICReuse(quicreuse.NewConnManager, quicreuse.DisableReuseport()))
require.NoError(t, err) // It's a bug/feature that swarm.Listen does not error if at least one transport succeeds in listening.
defer h1.Close()
// Check that webrtc did fail to listen
require.Equal(t, 1, len(h1.Addrs()))
require.Contains(t, h1.Addrs()[0].String(), "quic-v1")
})
t.Run("setup with autonat", func(t *testing.T) {
h1, err := New(EnableAutoNATv2(), ListenAddrStrings(swapPort(order[0], "54324")...), Transport(quic.NewTransport), Transport(libp2pwebrtc.New), QUICReuse(quicreuse.NewConnManager, quicreuse.DisableReuseport()))
require.NoError(t, err) // It's a bug/feature that swarm.Listen does not error if at least one transport succeeds in listening.
defer h1.Close()
// Check that webrtc did fail to listen
require.Equal(t, 1, len(h1.Addrs()))
require.Contains(t, h1.Addrs()[0].String(), "quic-v1")
})
}

View File

@@ -3,6 +3,7 @@ package swarm
import (
"errors"
"fmt"
"slices"
"time"
"github.com/libp2p/go-libp2p/core/canonicallog"
@@ -12,13 +13,44 @@ import (
ma "github.com/multiformats/go-multiaddr"
)
type OrderedListener interface {
// Transports optionally implement this interface to indicate the relative
// ordering that listeners should be setup. Some transports may optionally
// make use of other listeners if they are setup. e.g. WebRTC may reuse the
// same UDP port as QUIC, but only when QUIC is setup first.
// lower values are setup first.
ListenOrder() int
}
// Listen sets up listeners for all of the given addresses.
// It returns as long as we successfully listen on at least *one* address.
func (s *Swarm) Listen(addrs ...ma.Multiaddr) error {
errs := make([]error, len(addrs))
var succeeded int
for i, a := range addrs {
if err := s.AddListenAddr(a); err != nil {
type addrAndListener struct {
addr ma.Multiaddr
lTpt transport.Transport
}
sortedAddrsAndTpts := make([]addrAndListener, 0, len(addrs))
for _, a := range addrs {
t := s.TransportForListening(a)
sortedAddrsAndTpts = append(sortedAddrsAndTpts, addrAndListener{addr: a, lTpt: t})
}
slices.SortFunc(sortedAddrsAndTpts, func(a, b addrAndListener) int {
aOrder := 0
bOrder := 0
if l, ok := a.lTpt.(OrderedListener); ok {
aOrder = l.ListenOrder()
}
if l, ok := b.lTpt.(OrderedListener); ok {
bOrder = l.ListenOrder()
}
return aOrder - bOrder
})
for i, a := range sortedAddrsAndTpts {
if err := s.AddListenAddr(a.addr); err != nil {
errs[i] = err
} else {
succeeded++
@@ -27,11 +59,11 @@ func (s *Swarm) Listen(addrs ...ma.Multiaddr) error {
for i, e := range errs {
if e != nil {
log.Warnw("listening failed", "on", addrs[i], "error", errs[i])
log.Warnw("listening failed", "on", sortedAddrsAndTpts[i].addr, "error", errs[i])
}
}
if succeeded == 0 && len(addrs) > 0 {
if succeeded == 0 && len(sortedAddrsAndTpts) > 0 {
return fmt.Errorf("failed to listen on any addresses: %s", errs)
}

View File

@@ -26,6 +26,8 @@ import (
"github.com/quic-go/quic-go"
)
const ListenOrder = 1
var log = logging.Logger("quic-transport")
var ErrHolePunching = errors.New("hole punching attempted; no active dial")
@@ -103,6 +105,10 @@ func NewTransport(key ic.PrivKey, connManager *quicreuse.ConnManager, psk pnet.P
}, nil
}
func (t *transport) ListenOrder() int {
return ListenOrder
}
// Dial dials a new QUIC connection
func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (_c tpt.CapableConn, _err error) {
if ok, isClient, _ := network.GetSimultaneousConnect(ctx); ok && !isClient {

View File

@@ -154,6 +154,28 @@ func (c *ConnManager) onListenerClosed(key string) {
}
}
func (c *ConnManager) SharedNonQUICPacketConn(network string, laddr *net.UDPAddr) (net.PacketConn, error) {
c.quicListenersMu.Lock()
defer c.quicListenersMu.Unlock()
key := laddr.String()
entry, ok := c.quicListeners[key]
if !ok {
return nil, errors.New("expected to be able to share with a QUIC listener, but no QUIC listener found. The QUIC listener should start first")
}
t := entry.ln.transport
if t, ok := t.(*refcountedTransport); ok {
t.IncreaseCount()
ctx, cancel := context.WithCancel(context.Background())
return &nonQUICPacketConn{
ctx: ctx,
ctxCancel: cancel,
owningTransport: t,
tr: &t.Transport,
}, nil
}
return nil, errors.New("expected to be able to share with a QUIC listener, but the QUIC listener is not using a refcountedTransport. `DisableReuseport` should not be set")
}
func (c *ConnManager) transportForListen(network string, laddr *net.UDPAddr) (refCountedQuicTransport, error) {
if c.enableReuseport {
reuse, err := c.getReuse(network)

View File

@@ -0,0 +1,74 @@
package quicreuse
import (
"context"
"net"
"time"
"github.com/quic-go/quic-go"
)
// nonQUICPacketConn is a net.PacketConn that can be used to read and write
// non-QUIC packets on a quic.Transport. This lets us reuse this UDP port for
// other transports like WebRTC.
type nonQUICPacketConn struct {
owningTransport refCountedQuicTransport
tr *quic.Transport
ctx context.Context
ctxCancel context.CancelFunc
readCtx context.Context
readCancel context.CancelFunc
}
// Close implements net.PacketConn.
func (n *nonQUICPacketConn) Close() error {
n.ctxCancel()
// Don't actually close the underlying transport since someone else might be using it.
// reuse has it's own gc to close unused transports.
n.owningTransport.DecreaseCount()
return nil
}
// LocalAddr implements net.PacketConn.
func (n *nonQUICPacketConn) LocalAddr() net.Addr {
return n.tr.Conn.LocalAddr()
}
// ReadFrom implements net.PacketConn.
func (n *nonQUICPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
ctx := n.readCtx
if ctx == nil {
ctx = n.ctx
}
return n.tr.ReadNonQUICPacket(ctx, p)
}
// SetDeadline implements net.PacketConn.
func (n *nonQUICPacketConn) SetDeadline(t time.Time) error {
// Only used for reads.
return n.SetReadDeadline(t)
}
// SetReadDeadline implements net.PacketConn.
func (n *nonQUICPacketConn) SetReadDeadline(t time.Time) error {
if t.IsZero() && n.readCtx != nil {
n.readCancel()
n.readCtx = nil
}
n.readCtx, n.readCancel = context.WithDeadline(n.ctx, t)
return nil
}
// SetWriteDeadline implements net.PacketConn.
func (n *nonQUICPacketConn) SetWriteDeadline(t time.Time) error {
// Unused. quic-go doesn't support deadlines for writes.
return nil
}
// WriteTo implements net.PacketConn.
func (n *nonQUICPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
return n.tr.WriteTo(p, addr)
}
var _ net.PacketConn = &nonQUICPacketConn{}

View File

@@ -26,6 +26,7 @@ import (
"github.com/libp2p/go-libp2p/core/sec"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/security/noise"
libp2pquic "github.com/libp2p/go-libp2p/p2p/transport/quic"
"github.com/libp2p/go-libp2p/p2p/transport/webrtc/pb"
"github.com/libp2p/go-msgio"
@@ -78,6 +79,8 @@ type WebRTCTransport struct {
noiseTpt *noise.Transport
localPeerId peer.ID
listenUDP func(network string, laddr *net.UDPAddr) (net.PacketConn, error)
// timeouts
peerConnectionTimeouts iceTimeouts
@@ -95,7 +98,9 @@ type iceTimeouts struct {
Keepalive time.Duration
}
func New(privKey ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, opts ...Option) (*WebRTCTransport, error) {
type ListenUDPFn func(network string, laddr *net.UDPAddr) (net.PacketConn, error)
func New(privKey ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, listenUDP ListenUDPFn, opts ...Option) (*WebRTCTransport, error) {
if psk != nil {
log.Error("WebRTC doesn't support private networks yet.")
return nil, fmt.Errorf("WebRTC doesn't support private networks yet")
@@ -141,6 +146,7 @@ func New(privKey ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater, rcmgr
noiseTpt: noiseTpt,
localPeerId: localPeerID,
listenUDP: listenUDP,
peerConnectionTimeouts: iceTimeouts{
Disconnect: DefaultDisconnectedTimeout,
Failed: DefaultFailedTimeout,
@@ -157,6 +163,10 @@ func New(privKey ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater, rcmgr
return transport, nil
}
func (t *WebRTCTransport) ListenOrder() int {
return libp2pquic.ListenOrder + 1 // We want to listen after QUIC listens so we can possibly reuse the same port.
}
func (t *WebRTCTransport) Protocols() []int {
return []int{ma.P_WEBRTC_DIRECT}
}
@@ -190,7 +200,7 @@ func (t *WebRTCTransport) Listen(addr ma.Multiaddr) (tpt.Listener, error) {
return nil, fmt.Errorf("listener could not resolve udp address: %w", err)
}
socket, err := net.ListenUDP(nw, udpAddr)
socket, err := t.listenUDP(nw, udpAddr)
if err != nil {
return nil, fmt.Errorf("listen on udp: %w", err)
}
@@ -203,7 +213,7 @@ func (t *WebRTCTransport) Listen(addr ma.Multiaddr) (tpt.Listener, error) {
return listener, nil
}
func (t *WebRTCTransport) listenSocket(socket *net.UDPConn) (tpt.Listener, error) {
func (t *WebRTCTransport) listenSocket(socket net.PacketConn) (tpt.Listener, error) {
listenerMultiaddr, err := manet.FromNetAddr(socket.LocalAddr())
if err != nil {
return nil, err

View File

@@ -29,12 +29,16 @@ import (
"golang.org/x/crypto/sha3"
)
var netListenUDP ListenUDPFn = func(network string, laddr *net.UDPAddr) (net.PacketConn, error) {
return net.ListenUDP(network, laddr)
}
func getTransport(t *testing.T, opts ...Option) (*WebRTCTransport, peer.ID) {
t.Helper()
privKey, _, err := crypto.GenerateKeyPair(crypto.Ed25519, -1)
require.NoError(t, err)
rcmgr := &network.NullResourceManager{}
transport, err := New(privKey, nil, nil, rcmgr, opts...)
transport, err := New(privKey, nil, nil, rcmgr, netListenUDP, opts...)
require.NoError(t, err)
peerID, err := peer.IDFromPrivateKey(privKey)
require.NoError(t, err)
@@ -45,7 +49,7 @@ func getTransport(t *testing.T, opts ...Option) (*WebRTCTransport, peer.ID) {
func TestNullRcmgrTransport(t *testing.T) {
privKey, _, err := crypto.GenerateKeyPair(crypto.Ed25519, -1)
require.NoError(t, err)
transport, err := New(privKey, nil, nil, nil)
transport, err := New(privKey, nil, nil, nil, netListenUDP)
require.NoError(t, err)
listenTransport, pid := getTransport(t)