diff --git a/p2p/transport/quicreuse/connmgr.go b/p2p/transport/quicreuse/connmgr.go index dedca7488..6a28a4eb7 100644 --- a/p2p/transport/quicreuse/connmgr.go +++ b/p2p/transport/quicreuse/connmgr.go @@ -2,13 +2,13 @@ // for reusing QUIC transports for various purposes, like listening & dialing, having // multiple QUIC listeners on the same address with different ALPNs, and sharing the // same address with non QUIC transports like WebRTC. + package quicreuse import ( "context" "crypto/tls" "errors" - "fmt" "io" "net" "sync" @@ -224,7 +224,7 @@ func (c *ConnManager) ListenQUICAndAssociate(association any, addr ma.Multiaddr, key := laddr.String() entry, ok := c.quicListeners[key] if !ok { - tr, err := c.transportForListen(association, netw, laddr) + tr, err := c.transportForListen(netw, laddr) if err != nil { return nil, err } @@ -234,20 +234,15 @@ func (c *ConnManager) ListenQUICAndAssociate(association any, addr ma.Multiaddr, } key = tr.LocalAddr().String() entry = quicListenerEntry{ln: ln} - } else if c.enableReuseport && association != nil { - reuse, err := c.getReuse(netw) - if err != nil { - return nil, fmt.Errorf("reuse error: %w", err) - } - err = reuse.AssertTransportExists(entry.ln.transport) - if err != nil { - return nil, fmt.Errorf("reuse assert transport failed: %w", err) - } - if tr, ok := entry.ln.transport.(*refcountedTransport); ok { - tr.associate(association) + } + if c.enableReuseport && association != nil { + if _, ok := entry.ln.transport.(*refcountedTransport); !ok { + log.Warnf("reuseport is enabled, association is non-nil, but the transport is not a refcountedTransport.") } } - l, err := entry.ln.Add(tlsConf, allowWindowIncrease, func() { c.onListenerClosed(key) }) + l, err := entry.ln.Add(association, tlsConf, allowWindowIncrease, func() { + c.onListenerClosed(key) + }) if err != nil { if entry.refCount <= 0 { entry.ln.Close() @@ -296,7 +291,7 @@ func (c *ConnManager) SharedNonQUICPacketConn(_ string, laddr *net.UDPAddr) (net return nil, errors.New("expected to be able to share with a QUIC listener, but the QUIC listener is not using a refcountedTransport. `DisableReuseport` should not be set") } -func (c *ConnManager) transportForListen(association any, network string, laddr *net.UDPAddr) (RefCountedQUICTransport, error) { +func (c *ConnManager) transportForListen(network string, laddr *net.UDPAddr) (RefCountedQUICTransport, error) { if c.enableReuseport { reuse, err := c.getReuse(network) if err != nil { @@ -306,7 +301,6 @@ func (c *ConnManager) transportForListen(association any, network string, laddr if err != nil { return nil, err } - tr.associate(association) return tr, nil } diff --git a/p2p/transport/quicreuse/connmgr_test.go b/p2p/transport/quicreuse/connmgr_test.go index 8d85dffae..f99646cdc 100644 --- a/p2p/transport/quicreuse/connmgr_test.go +++ b/p2p/transport/quicreuse/connmgr_test.go @@ -95,7 +95,7 @@ func TestConnectionPassedToQUICForListening(t *testing.T) { _, err = cm.ListenQUIC(raddr, &tls.Config{NextProtos: []string{"proto"}}, nil) require.NoError(t, err) - quicTr, err := cm.transportForListen(nil, netw, naddr) + quicTr, err := cm.transportForListen(netw, naddr) require.NoError(t, err) defer quicTr.Close() if _, ok := quicTr.(*singleOwnerTransport).Transport.(*wrappedQUICTransport).Conn.(quic.OOBCapablePacketConn); !ok { @@ -489,3 +489,139 @@ func TestConnContext(t *testing.T) { }) } } + +func TestAssociationCleanup(t *testing.T) { + cm, err := NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{}) + require.NoError(t, err) + defer cm.Close() + + // Create 3 listeners with 3 different associations + lp2pTLS := &tls.Config{NextProtos: []string{"libp2p"}} + assoc1 := "test-association-1" + assoc2 := "test-association-2" + assoc3 := "test-association-3" + + ln1, err := cm.ListenQUICAndAssociate(assoc1, ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), lp2pTLS, nil) + require.NoError(t, err) + defer ln1.Close() + + addr := ln1.Multiaddrs()[0] + port, err := addr.ValueForProtocol(ma.P_UDP) + require.NoError(t, err) + + h3TLS := &tls.Config{NextProtos: []string{"h3"}} + ln2, err := cm.ListenQUICAndAssociate(assoc2, ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/%s/quic-v1", port)), h3TLS, nil) + require.NoError(t, err) + defer ln2.Close() + + ln3, err := cm.ListenQUICAndAssociate(assoc3, ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), lp2pTLS, nil) + require.NoError(t, err) + defer ln3.Close() + + // Get the listen addresses for verification + addr1 := ln1.Addr().String() + addr2 := ln2.Addr().String() + addr3 := ln3.Addr().String() + require.Equal(t, addr1, addr2) + + // Test that dialing with assoc1 uses the first listener's address + dialAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} + + numTries := 100 + + for i := 0; i < numTries; i++ { + tr, err := cm.TransportWithAssociationForDial(assoc1, "udp4", dialAddr) + require.NoError(t, err) + require.Equal(t, addr1, tr.LocalAddr().String(), "assoc1 should use addr1") + } + + // Close the first listener + ln1.Close() + + // Call TransportWithAssociationForDial 10 times with assoc1 and check if we get at least one different address + foundDifferentAddr := false + for i := 0; i < numTries; i++ { + tr, err := cm.TransportWithAssociationForDial(assoc1, "udp4", dialAddr) + require.NoError(t, err) + actualAddr := tr.LocalAddr().String() + if actualAddr != addr1 { + foundDifferentAddr = true + break + } + } + require.True(t, foundDifferentAddr, "assoc1 should use a different address than addr1 at least once after ln1 is closed") + + for i := 0; i < numTries; i++ { + // Test that dialing with assoc2 still uses the second listener's address + tr2Still, err := cm.TransportWithAssociationForDial(assoc2, "udp4", dialAddr) + require.NoError(t, err) + require.Equal(t, addr2, tr2Still.LocalAddr().String(), "assoc2 should still use addr2") + } + + // Close the second listener + ln2.Close() + + // Call TransportWithAssociationForDial 10 times with assoc2 and check if we get at least one different address + foundDifferentAddr2 := false + for i := 0; i < numTries; i++ { + tr, err := cm.TransportWithAssociationForDial(assoc2, "udp4", dialAddr) + require.NoError(t, err) + actualAddr := tr.LocalAddr().String() + if actualAddr != addr2 { + foundDifferentAddr2 = true + } + } + require.True(t, foundDifferentAddr2, "assoc2 should use a different address than addr2 at least once after ln2 is closed") + + for i := 0; i < numTries; i++ { + // Test that dialing with assoc3 still uses the third listener's address + tr3Still, err := cm.TransportWithAssociationForDial(assoc3, "udp4", dialAddr) + require.NoError(t, err) + require.Equal(t, addr3, tr3Still.LocalAddr().String(), "assoc3 should still use addr3") + } +} + +func TestConnManagerIsolation(t *testing.T) { + // Create two separate ConnManager instances + cm1, err := NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{}) + require.NoError(t, err) + defer cm1.Close() + + cm2, err := NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{}) + require.NoError(t, err) + defer cm2.Close() + + // Create listeners in both ConnManagers + lp2pTLS := &tls.Config{NextProtos: []string{"libp2p"}} + assoc1 := "cm1-association" + assoc2 := "cm2-association" + + ln1, err := cm1.ListenQUICAndAssociate(assoc1, ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), lp2pTLS, nil) + require.NoError(t, err) + defer ln1.Close() + + ln2, err := cm2.ListenQUICAndAssociate(assoc2, ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), lp2pTLS, nil) + require.NoError(t, err) + defer ln2.Close() + + // Verify that each ConnManager has its own isolated associations + + // Verify associations are isolated + cm1.quicListenersMu.Lock() + key1 := ln1.Addr().String() + entry1 := cm1.quicListeners[key1] + tr1, ok := entry1.ln.transport.(*refcountedTransport) + require.True(t, ok) + require.True(t, tr1.hasAssociation(assoc1)) + require.False(t, tr1.hasAssociation(assoc2)) + cm1.quicListenersMu.Unlock() + + cm2.quicListenersMu.Lock() + key2 := ln2.Addr().String() + entry2 := cm2.quicListeners[key2] + tr2, ok := entry2.ln.transport.(*refcountedTransport) + require.True(t, ok) + require.True(t, tr2.hasAssociation(assoc2)) + require.False(t, tr2.hasAssociation(assoc1)) + cm2.quicListenersMu.Unlock() +} diff --git a/p2p/transport/quicreuse/listener.go b/p2p/transport/quicreuse/listener.go index 48c075203..71a896b77 100644 --- a/p2p/transport/quicreuse/listener.go +++ b/p2p/transport/quicreuse/listener.go @@ -91,7 +91,7 @@ func (l *quicListener) allowWindowIncrease(conn *quic.Conn, delta uint64) bool { return conf.allowWindowIncrease(conn, delta) } -func (l *quicListener) Add(tlsConf *tls.Config, allowWindowIncrease func(conn *quic.Conn, delta uint64) bool, onRemove func()) (Listener, error) { +func (l *quicListener) Add(association any, tlsConf *tls.Config, allowWindowIncrease func(conn *quic.Conn, delta uint64) bool, onRemove func()) (*listener, error) { l.protocolsMu.Lock() defer l.protocolsMu.Unlock() @@ -105,14 +105,32 @@ func (l *quicListener) Add(tlsConf *tls.Config, allowWindowIncrease func(conn *q } } - ln := newSingleListener(l.l.Addr(), l.addrs, func() { + ln := &listener{ + queue: make(chan *quic.Conn, queueLen), + acceptLoopRunning: l.running, + addr: l.l.Addr(), + addrs: l.addrs, + } + if association != nil { + if tr, ok := l.transport.(*refcountedTransport); ok { + tr.associateForListener(association, ln) + } + } + + ln.remove = func() { + if association != nil { + if tr, ok := l.transport.(*refcountedTransport); ok { + tr.RemoveAssociationsForListener(ln) + } + } l.protocolsMu.Lock() for _, proto := range tlsConf.NextProtos { delete(l.protocols, proto) } l.protocolsMu.Unlock() onRemove() - }, l.running) + } + for _, proto := range tlsConf.NextProtos { l.protocols[proto] = protoConf{ ln: ln, @@ -167,16 +185,6 @@ type listener struct { var _ Listener = &listener{} -func newSingleListener(addr net.Addr, addrs []ma.Multiaddr, remove func(), running chan struct{}) *listener { - return &listener{ - queue: make(chan *quic.Conn, queueLen), - acceptLoopRunning: running, - remove: remove, - addr: addr, - addrs: addrs, - } -} - func (l *listener) add(c *quic.Conn) { select { case l.queue <- c: diff --git a/p2p/transport/quicreuse/reuse.go b/p2p/transport/quicreuse/reuse.go index 1593e2266..b13cda208 100644 --- a/p2p/transport/quicreuse/reuse.go +++ b/p2p/transport/quicreuse/reuse.go @@ -88,25 +88,44 @@ type refcountedTransport struct { // channel to signal to the owner that we are done with it. borrowDoneSignal chan struct{} - assocations map[any]struct{} + // Store associations as association -> set of listener objects + associations map[any]map[*listener]struct{} } type connContextFunc = func(context.Context, *quic.ClientInfo) (context.Context, error) -// associate an arbitrary value with this transport. +// associateForListener associates an arbitrary value with this transport for a specific listener. // This lets us "tag" the refcountedTransport when listening so we can use it -// later for dialing. Necessary for holepunching and learning about our own -// observed listening address. -func (c *refcountedTransport) associate(a any) { +// later for dialing. The listener parameter allows proper cleanup when the listener closes. +// Necessary for holepunching and learning about our own observed listening address. +func (c *refcountedTransport) associateForListener(a any, ln *listener) { if a == nil { return } c.mutex.Lock() defer c.mutex.Unlock() - if c.assocations == nil { - c.assocations = make(map[any]struct{}) + if c.associations == nil { + c.associations = make(map[any]map[*listener]struct{}) + } + if c.associations[a] == nil { + c.associations[a] = make(map[*listener]struct{}) + } + c.associations[a][ln] = struct{}{} +} + +// RemoveAssociationsForListener removes ALL associations added by a specific listener +func (c *refcountedTransport) RemoveAssociationsForListener(ln *listener) { + c.mutex.Lock() + defer c.mutex.Unlock() + + // Remove this listener from all associations + for association, listeners := range c.associations { + delete(listeners, ln) + // If no listeners remain for this association, remove the association entirely + if len(listeners) == 0 { + delete(c.associations, association) + } } - c.assocations[a] = struct{}{} } // hasAssociation returns true if the transport has the given association. @@ -117,8 +136,8 @@ func (c *refcountedTransport) hasAssociation(a any) bool { } c.mutex.Lock() defer c.mutex.Unlock() - _, ok := c.assocations[a] - return ok + listeners, ok := c.associations[a] + return ok && len(listeners) > 0 } func (c *refcountedTransport) IncreaseCount() { @@ -367,33 +386,6 @@ func (r *reuse) AddTransport(tr *refcountedTransport, laddr *net.UDPAddr) error return nil } -func (r *reuse) AssertTransportExists(tr RefCountedQUICTransport) error { - t, ok := tr.(*refcountedTransport) - if !ok { - return fmt.Errorf("invalid transport type: expected: *refcountedTransport, got: %T", tr) - } - laddr := t.LocalAddr().(*net.UDPAddr) - if laddr.IP.IsUnspecified() { - if lt, ok := r.globalListeners[laddr.Port]; ok { - if lt == t { - return nil - } - return errors.New("two global listeners on the same port") - } - return errors.New("transport not found") - } - if m, ok := r.unicast[laddr.IP.String()]; ok { - if lt, ok := m[laddr.Port]; ok { - if lt == t { - return nil - } - return errors.New("two unicast listeners on same ip:port") - } - return errors.New("transport not found") - } - return errors.New("transport not found") -} - func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcountedTransport, error) { r.mutex.Lock() defer r.mutex.Unlock()