quicreuse: clean up associations for closed listeners. (#3306)

Co-authored-by: Prithvi Shahi <shahi.prithvi@gmail.com>
Co-authored-by: sukun <sukunrt@gmail.com>
This commit is contained in:
Probot
2025-08-07 22:39:44 +05:30
committed by GitHub
parent b75e678e64
commit 02e583d319
4 changed files with 197 additions and 67 deletions

View File

@@ -2,13 +2,13 @@
// for reusing QUIC transports for various purposes, like listening & dialing, having
// multiple QUIC listeners on the same address with different ALPNs, and sharing the
// same address with non QUIC transports like WebRTC.
package quicreuse
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"sync"
@@ -224,7 +224,7 @@ func (c *ConnManager) ListenQUICAndAssociate(association any, addr ma.Multiaddr,
key := laddr.String()
entry, ok := c.quicListeners[key]
if !ok {
tr, err := c.transportForListen(association, netw, laddr)
tr, err := c.transportForListen(netw, laddr)
if err != nil {
return nil, err
}
@@ -234,20 +234,15 @@ func (c *ConnManager) ListenQUICAndAssociate(association any, addr ma.Multiaddr,
}
key = tr.LocalAddr().String()
entry = quicListenerEntry{ln: ln}
} else if c.enableReuseport && association != nil {
reuse, err := c.getReuse(netw)
if err != nil {
return nil, fmt.Errorf("reuse error: %w", err)
}
err = reuse.AssertTransportExists(entry.ln.transport)
if err != nil {
return nil, fmt.Errorf("reuse assert transport failed: %w", err)
}
if tr, ok := entry.ln.transport.(*refcountedTransport); ok {
tr.associate(association)
}
if c.enableReuseport && association != nil {
if _, ok := entry.ln.transport.(*refcountedTransport); !ok {
log.Warnf("reuseport is enabled, association is non-nil, but the transport is not a refcountedTransport.")
}
}
l, err := entry.ln.Add(tlsConf, allowWindowIncrease, func() { c.onListenerClosed(key) })
l, err := entry.ln.Add(association, tlsConf, allowWindowIncrease, func() {
c.onListenerClosed(key)
})
if err != nil {
if entry.refCount <= 0 {
entry.ln.Close()
@@ -296,7 +291,7 @@ func (c *ConnManager) SharedNonQUICPacketConn(_ string, laddr *net.UDPAddr) (net
return nil, errors.New("expected to be able to share with a QUIC listener, but the QUIC listener is not using a refcountedTransport. `DisableReuseport` should not be set")
}
func (c *ConnManager) transportForListen(association any, network string, laddr *net.UDPAddr) (RefCountedQUICTransport, error) {
func (c *ConnManager) transportForListen(network string, laddr *net.UDPAddr) (RefCountedQUICTransport, error) {
if c.enableReuseport {
reuse, err := c.getReuse(network)
if err != nil {
@@ -306,7 +301,6 @@ func (c *ConnManager) transportForListen(association any, network string, laddr
if err != nil {
return nil, err
}
tr.associate(association)
return tr, nil
}

View File

@@ -95,7 +95,7 @@ func TestConnectionPassedToQUICForListening(t *testing.T) {
_, err = cm.ListenQUIC(raddr, &tls.Config{NextProtos: []string{"proto"}}, nil)
require.NoError(t, err)
quicTr, err := cm.transportForListen(nil, netw, naddr)
quicTr, err := cm.transportForListen(netw, naddr)
require.NoError(t, err)
defer quicTr.Close()
if _, ok := quicTr.(*singleOwnerTransport).Transport.(*wrappedQUICTransport).Conn.(quic.OOBCapablePacketConn); !ok {
@@ -489,3 +489,139 @@ func TestConnContext(t *testing.T) {
})
}
}
func TestAssociationCleanup(t *testing.T) {
cm, err := NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{})
require.NoError(t, err)
defer cm.Close()
// Create 3 listeners with 3 different associations
lp2pTLS := &tls.Config{NextProtos: []string{"libp2p"}}
assoc1 := "test-association-1"
assoc2 := "test-association-2"
assoc3 := "test-association-3"
ln1, err := cm.ListenQUICAndAssociate(assoc1, ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), lp2pTLS, nil)
require.NoError(t, err)
defer ln1.Close()
addr := ln1.Multiaddrs()[0]
port, err := addr.ValueForProtocol(ma.P_UDP)
require.NoError(t, err)
h3TLS := &tls.Config{NextProtos: []string{"h3"}}
ln2, err := cm.ListenQUICAndAssociate(assoc2, ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/%s/quic-v1", port)), h3TLS, nil)
require.NoError(t, err)
defer ln2.Close()
ln3, err := cm.ListenQUICAndAssociate(assoc3, ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), lp2pTLS, nil)
require.NoError(t, err)
defer ln3.Close()
// Get the listen addresses for verification
addr1 := ln1.Addr().String()
addr2 := ln2.Addr().String()
addr3 := ln3.Addr().String()
require.Equal(t, addr1, addr2)
// Test that dialing with assoc1 uses the first listener's address
dialAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
numTries := 100
for i := 0; i < numTries; i++ {
tr, err := cm.TransportWithAssociationForDial(assoc1, "udp4", dialAddr)
require.NoError(t, err)
require.Equal(t, addr1, tr.LocalAddr().String(), "assoc1 should use addr1")
}
// Close the first listener
ln1.Close()
// Call TransportWithAssociationForDial 10 times with assoc1 and check if we get at least one different address
foundDifferentAddr := false
for i := 0; i < numTries; i++ {
tr, err := cm.TransportWithAssociationForDial(assoc1, "udp4", dialAddr)
require.NoError(t, err)
actualAddr := tr.LocalAddr().String()
if actualAddr != addr1 {
foundDifferentAddr = true
break
}
}
require.True(t, foundDifferentAddr, "assoc1 should use a different address than addr1 at least once after ln1 is closed")
for i := 0; i < numTries; i++ {
// Test that dialing with assoc2 still uses the second listener's address
tr2Still, err := cm.TransportWithAssociationForDial(assoc2, "udp4", dialAddr)
require.NoError(t, err)
require.Equal(t, addr2, tr2Still.LocalAddr().String(), "assoc2 should still use addr2")
}
// Close the second listener
ln2.Close()
// Call TransportWithAssociationForDial 10 times with assoc2 and check if we get at least one different address
foundDifferentAddr2 := false
for i := 0; i < numTries; i++ {
tr, err := cm.TransportWithAssociationForDial(assoc2, "udp4", dialAddr)
require.NoError(t, err)
actualAddr := tr.LocalAddr().String()
if actualAddr != addr2 {
foundDifferentAddr2 = true
}
}
require.True(t, foundDifferentAddr2, "assoc2 should use a different address than addr2 at least once after ln2 is closed")
for i := 0; i < numTries; i++ {
// Test that dialing with assoc3 still uses the third listener's address
tr3Still, err := cm.TransportWithAssociationForDial(assoc3, "udp4", dialAddr)
require.NoError(t, err)
require.Equal(t, addr3, tr3Still.LocalAddr().String(), "assoc3 should still use addr3")
}
}
func TestConnManagerIsolation(t *testing.T) {
// Create two separate ConnManager instances
cm1, err := NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{})
require.NoError(t, err)
defer cm1.Close()
cm2, err := NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{})
require.NoError(t, err)
defer cm2.Close()
// Create listeners in both ConnManagers
lp2pTLS := &tls.Config{NextProtos: []string{"libp2p"}}
assoc1 := "cm1-association"
assoc2 := "cm2-association"
ln1, err := cm1.ListenQUICAndAssociate(assoc1, ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), lp2pTLS, nil)
require.NoError(t, err)
defer ln1.Close()
ln2, err := cm2.ListenQUICAndAssociate(assoc2, ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), lp2pTLS, nil)
require.NoError(t, err)
defer ln2.Close()
// Verify that each ConnManager has its own isolated associations
// Verify associations are isolated
cm1.quicListenersMu.Lock()
key1 := ln1.Addr().String()
entry1 := cm1.quicListeners[key1]
tr1, ok := entry1.ln.transport.(*refcountedTransport)
require.True(t, ok)
require.True(t, tr1.hasAssociation(assoc1))
require.False(t, tr1.hasAssociation(assoc2))
cm1.quicListenersMu.Unlock()
cm2.quicListenersMu.Lock()
key2 := ln2.Addr().String()
entry2 := cm2.quicListeners[key2]
tr2, ok := entry2.ln.transport.(*refcountedTransport)
require.True(t, ok)
require.True(t, tr2.hasAssociation(assoc2))
require.False(t, tr2.hasAssociation(assoc1))
cm2.quicListenersMu.Unlock()
}

View File

@@ -91,7 +91,7 @@ func (l *quicListener) allowWindowIncrease(conn *quic.Conn, delta uint64) bool {
return conf.allowWindowIncrease(conn, delta)
}
func (l *quicListener) Add(tlsConf *tls.Config, allowWindowIncrease func(conn *quic.Conn, delta uint64) bool, onRemove func()) (Listener, error) {
func (l *quicListener) Add(association any, tlsConf *tls.Config, allowWindowIncrease func(conn *quic.Conn, delta uint64) bool, onRemove func()) (*listener, error) {
l.protocolsMu.Lock()
defer l.protocolsMu.Unlock()
@@ -105,14 +105,32 @@ func (l *quicListener) Add(tlsConf *tls.Config, allowWindowIncrease func(conn *q
}
}
ln := newSingleListener(l.l.Addr(), l.addrs, func() {
ln := &listener{
queue: make(chan *quic.Conn, queueLen),
acceptLoopRunning: l.running,
addr: l.l.Addr(),
addrs: l.addrs,
}
if association != nil {
if tr, ok := l.transport.(*refcountedTransport); ok {
tr.associateForListener(association, ln)
}
}
ln.remove = func() {
if association != nil {
if tr, ok := l.transport.(*refcountedTransport); ok {
tr.RemoveAssociationsForListener(ln)
}
}
l.protocolsMu.Lock()
for _, proto := range tlsConf.NextProtos {
delete(l.protocols, proto)
}
l.protocolsMu.Unlock()
onRemove()
}, l.running)
}
for _, proto := range tlsConf.NextProtos {
l.protocols[proto] = protoConf{
ln: ln,
@@ -167,16 +185,6 @@ type listener struct {
var _ Listener = &listener{}
func newSingleListener(addr net.Addr, addrs []ma.Multiaddr, remove func(), running chan struct{}) *listener {
return &listener{
queue: make(chan *quic.Conn, queueLen),
acceptLoopRunning: running,
remove: remove,
addr: addr,
addrs: addrs,
}
}
func (l *listener) add(c *quic.Conn) {
select {
case l.queue <- c:

View File

@@ -88,25 +88,44 @@ type refcountedTransport struct {
// channel to signal to the owner that we are done with it.
borrowDoneSignal chan struct{}
assocations map[any]struct{}
// Store associations as association -> set of listener objects
associations map[any]map[*listener]struct{}
}
type connContextFunc = func(context.Context, *quic.ClientInfo) (context.Context, error)
// associate an arbitrary value with this transport.
// associateForListener associates an arbitrary value with this transport for a specific listener.
// This lets us "tag" the refcountedTransport when listening so we can use it
// later for dialing. Necessary for holepunching and learning about our own
// observed listening address.
func (c *refcountedTransport) associate(a any) {
// later for dialing. The listener parameter allows proper cleanup when the listener closes.
// Necessary for holepunching and learning about our own observed listening address.
func (c *refcountedTransport) associateForListener(a any, ln *listener) {
if a == nil {
return
}
c.mutex.Lock()
defer c.mutex.Unlock()
if c.assocations == nil {
c.assocations = make(map[any]struct{})
if c.associations == nil {
c.associations = make(map[any]map[*listener]struct{})
}
if c.associations[a] == nil {
c.associations[a] = make(map[*listener]struct{})
}
c.associations[a][ln] = struct{}{}
}
// RemoveAssociationsForListener removes ALL associations added by a specific listener
func (c *refcountedTransport) RemoveAssociationsForListener(ln *listener) {
c.mutex.Lock()
defer c.mutex.Unlock()
// Remove this listener from all associations
for association, listeners := range c.associations {
delete(listeners, ln)
// If no listeners remain for this association, remove the association entirely
if len(listeners) == 0 {
delete(c.associations, association)
}
}
c.assocations[a] = struct{}{}
}
// hasAssociation returns true if the transport has the given association.
@@ -117,8 +136,8 @@ func (c *refcountedTransport) hasAssociation(a any) bool {
}
c.mutex.Lock()
defer c.mutex.Unlock()
_, ok := c.assocations[a]
return ok
listeners, ok := c.associations[a]
return ok && len(listeners) > 0
}
func (c *refcountedTransport) IncreaseCount() {
@@ -367,33 +386,6 @@ func (r *reuse) AddTransport(tr *refcountedTransport, laddr *net.UDPAddr) error
return nil
}
func (r *reuse) AssertTransportExists(tr RefCountedQUICTransport) error {
t, ok := tr.(*refcountedTransport)
if !ok {
return fmt.Errorf("invalid transport type: expected: *refcountedTransport, got: %T", tr)
}
laddr := t.LocalAddr().(*net.UDPAddr)
if laddr.IP.IsUnspecified() {
if lt, ok := r.globalListeners[laddr.Port]; ok {
if lt == t {
return nil
}
return errors.New("two global listeners on the same port")
}
return errors.New("transport not found")
}
if m, ok := r.unicast[laddr.IP.String()]; ok {
if lt, ok := m[laddr.Port]; ok {
if lt == t {
return nil
}
return errors.New("two unicast listeners on same ip:port")
}
return errors.New("transport not found")
}
return errors.New("transport not found")
}
func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcountedTransport, error) {
r.mutex.Lock()
defer r.mutex.Unlock()