mirror of
https://github.com/libp2p/go-libp2p.git
synced 2025-12-24 13:29:35 +08:00
quicreuse: clean up associations for closed listeners. (#3306)
Co-authored-by: Prithvi Shahi <shahi.prithvi@gmail.com> Co-authored-by: sukun <sukunrt@gmail.com>
This commit is contained in:
@@ -2,13 +2,13 @@
|
||||
// for reusing QUIC transports for various purposes, like listening & dialing, having
|
||||
// multiple QUIC listeners on the same address with different ALPNs, and sharing the
|
||||
// same address with non QUIC transports like WebRTC.
|
||||
|
||||
package quicreuse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
@@ -224,7 +224,7 @@ func (c *ConnManager) ListenQUICAndAssociate(association any, addr ma.Multiaddr,
|
||||
key := laddr.String()
|
||||
entry, ok := c.quicListeners[key]
|
||||
if !ok {
|
||||
tr, err := c.transportForListen(association, netw, laddr)
|
||||
tr, err := c.transportForListen(netw, laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -234,20 +234,15 @@ func (c *ConnManager) ListenQUICAndAssociate(association any, addr ma.Multiaddr,
|
||||
}
|
||||
key = tr.LocalAddr().String()
|
||||
entry = quicListenerEntry{ln: ln}
|
||||
} else if c.enableReuseport && association != nil {
|
||||
reuse, err := c.getReuse(netw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reuse error: %w", err)
|
||||
}
|
||||
err = reuse.AssertTransportExists(entry.ln.transport)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reuse assert transport failed: %w", err)
|
||||
}
|
||||
if tr, ok := entry.ln.transport.(*refcountedTransport); ok {
|
||||
tr.associate(association)
|
||||
}
|
||||
if c.enableReuseport && association != nil {
|
||||
if _, ok := entry.ln.transport.(*refcountedTransport); !ok {
|
||||
log.Warnf("reuseport is enabled, association is non-nil, but the transport is not a refcountedTransport.")
|
||||
}
|
||||
}
|
||||
l, err := entry.ln.Add(tlsConf, allowWindowIncrease, func() { c.onListenerClosed(key) })
|
||||
l, err := entry.ln.Add(association, tlsConf, allowWindowIncrease, func() {
|
||||
c.onListenerClosed(key)
|
||||
})
|
||||
if err != nil {
|
||||
if entry.refCount <= 0 {
|
||||
entry.ln.Close()
|
||||
@@ -296,7 +291,7 @@ func (c *ConnManager) SharedNonQUICPacketConn(_ string, laddr *net.UDPAddr) (net
|
||||
return nil, errors.New("expected to be able to share with a QUIC listener, but the QUIC listener is not using a refcountedTransport. `DisableReuseport` should not be set")
|
||||
}
|
||||
|
||||
func (c *ConnManager) transportForListen(association any, network string, laddr *net.UDPAddr) (RefCountedQUICTransport, error) {
|
||||
func (c *ConnManager) transportForListen(network string, laddr *net.UDPAddr) (RefCountedQUICTransport, error) {
|
||||
if c.enableReuseport {
|
||||
reuse, err := c.getReuse(network)
|
||||
if err != nil {
|
||||
@@ -306,7 +301,6 @@ func (c *ConnManager) transportForListen(association any, network string, laddr
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tr.associate(association)
|
||||
return tr, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ func TestConnectionPassedToQUICForListening(t *testing.T) {
|
||||
|
||||
_, err = cm.ListenQUIC(raddr, &tls.Config{NextProtos: []string{"proto"}}, nil)
|
||||
require.NoError(t, err)
|
||||
quicTr, err := cm.transportForListen(nil, netw, naddr)
|
||||
quicTr, err := cm.transportForListen(netw, naddr)
|
||||
require.NoError(t, err)
|
||||
defer quicTr.Close()
|
||||
if _, ok := quicTr.(*singleOwnerTransport).Transport.(*wrappedQUICTransport).Conn.(quic.OOBCapablePacketConn); !ok {
|
||||
@@ -489,3 +489,139 @@ func TestConnContext(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssociationCleanup(t *testing.T) {
|
||||
cm, err := NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{})
|
||||
require.NoError(t, err)
|
||||
defer cm.Close()
|
||||
|
||||
// Create 3 listeners with 3 different associations
|
||||
lp2pTLS := &tls.Config{NextProtos: []string{"libp2p"}}
|
||||
assoc1 := "test-association-1"
|
||||
assoc2 := "test-association-2"
|
||||
assoc3 := "test-association-3"
|
||||
|
||||
ln1, err := cm.ListenQUICAndAssociate(assoc1, ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), lp2pTLS, nil)
|
||||
require.NoError(t, err)
|
||||
defer ln1.Close()
|
||||
|
||||
addr := ln1.Multiaddrs()[0]
|
||||
port, err := addr.ValueForProtocol(ma.P_UDP)
|
||||
require.NoError(t, err)
|
||||
|
||||
h3TLS := &tls.Config{NextProtos: []string{"h3"}}
|
||||
ln2, err := cm.ListenQUICAndAssociate(assoc2, ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/%s/quic-v1", port)), h3TLS, nil)
|
||||
require.NoError(t, err)
|
||||
defer ln2.Close()
|
||||
|
||||
ln3, err := cm.ListenQUICAndAssociate(assoc3, ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), lp2pTLS, nil)
|
||||
require.NoError(t, err)
|
||||
defer ln3.Close()
|
||||
|
||||
// Get the listen addresses for verification
|
||||
addr1 := ln1.Addr().String()
|
||||
addr2 := ln2.Addr().String()
|
||||
addr3 := ln3.Addr().String()
|
||||
require.Equal(t, addr1, addr2)
|
||||
|
||||
// Test that dialing with assoc1 uses the first listener's address
|
||||
dialAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
|
||||
|
||||
numTries := 100
|
||||
|
||||
for i := 0; i < numTries; i++ {
|
||||
tr, err := cm.TransportWithAssociationForDial(assoc1, "udp4", dialAddr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, addr1, tr.LocalAddr().String(), "assoc1 should use addr1")
|
||||
}
|
||||
|
||||
// Close the first listener
|
||||
ln1.Close()
|
||||
|
||||
// Call TransportWithAssociationForDial 10 times with assoc1 and check if we get at least one different address
|
||||
foundDifferentAddr := false
|
||||
for i := 0; i < numTries; i++ {
|
||||
tr, err := cm.TransportWithAssociationForDial(assoc1, "udp4", dialAddr)
|
||||
require.NoError(t, err)
|
||||
actualAddr := tr.LocalAddr().String()
|
||||
if actualAddr != addr1 {
|
||||
foundDifferentAddr = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundDifferentAddr, "assoc1 should use a different address than addr1 at least once after ln1 is closed")
|
||||
|
||||
for i := 0; i < numTries; i++ {
|
||||
// Test that dialing with assoc2 still uses the second listener's address
|
||||
tr2Still, err := cm.TransportWithAssociationForDial(assoc2, "udp4", dialAddr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, addr2, tr2Still.LocalAddr().String(), "assoc2 should still use addr2")
|
||||
}
|
||||
|
||||
// Close the second listener
|
||||
ln2.Close()
|
||||
|
||||
// Call TransportWithAssociationForDial 10 times with assoc2 and check if we get at least one different address
|
||||
foundDifferentAddr2 := false
|
||||
for i := 0; i < numTries; i++ {
|
||||
tr, err := cm.TransportWithAssociationForDial(assoc2, "udp4", dialAddr)
|
||||
require.NoError(t, err)
|
||||
actualAddr := tr.LocalAddr().String()
|
||||
if actualAddr != addr2 {
|
||||
foundDifferentAddr2 = true
|
||||
}
|
||||
}
|
||||
require.True(t, foundDifferentAddr2, "assoc2 should use a different address than addr2 at least once after ln2 is closed")
|
||||
|
||||
for i := 0; i < numTries; i++ {
|
||||
// Test that dialing with assoc3 still uses the third listener's address
|
||||
tr3Still, err := cm.TransportWithAssociationForDial(assoc3, "udp4", dialAddr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, addr3, tr3Still.LocalAddr().String(), "assoc3 should still use addr3")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnManagerIsolation(t *testing.T) {
|
||||
// Create two separate ConnManager instances
|
||||
cm1, err := NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{})
|
||||
require.NoError(t, err)
|
||||
defer cm1.Close()
|
||||
|
||||
cm2, err := NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{})
|
||||
require.NoError(t, err)
|
||||
defer cm2.Close()
|
||||
|
||||
// Create listeners in both ConnManagers
|
||||
lp2pTLS := &tls.Config{NextProtos: []string{"libp2p"}}
|
||||
assoc1 := "cm1-association"
|
||||
assoc2 := "cm2-association"
|
||||
|
||||
ln1, err := cm1.ListenQUICAndAssociate(assoc1, ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), lp2pTLS, nil)
|
||||
require.NoError(t, err)
|
||||
defer ln1.Close()
|
||||
|
||||
ln2, err := cm2.ListenQUICAndAssociate(assoc2, ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), lp2pTLS, nil)
|
||||
require.NoError(t, err)
|
||||
defer ln2.Close()
|
||||
|
||||
// Verify that each ConnManager has its own isolated associations
|
||||
|
||||
// Verify associations are isolated
|
||||
cm1.quicListenersMu.Lock()
|
||||
key1 := ln1.Addr().String()
|
||||
entry1 := cm1.quicListeners[key1]
|
||||
tr1, ok := entry1.ln.transport.(*refcountedTransport)
|
||||
require.True(t, ok)
|
||||
require.True(t, tr1.hasAssociation(assoc1))
|
||||
require.False(t, tr1.hasAssociation(assoc2))
|
||||
cm1.quicListenersMu.Unlock()
|
||||
|
||||
cm2.quicListenersMu.Lock()
|
||||
key2 := ln2.Addr().String()
|
||||
entry2 := cm2.quicListeners[key2]
|
||||
tr2, ok := entry2.ln.transport.(*refcountedTransport)
|
||||
require.True(t, ok)
|
||||
require.True(t, tr2.hasAssociation(assoc2))
|
||||
require.False(t, tr2.hasAssociation(assoc1))
|
||||
cm2.quicListenersMu.Unlock()
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ func (l *quicListener) allowWindowIncrease(conn *quic.Conn, delta uint64) bool {
|
||||
return conf.allowWindowIncrease(conn, delta)
|
||||
}
|
||||
|
||||
func (l *quicListener) Add(tlsConf *tls.Config, allowWindowIncrease func(conn *quic.Conn, delta uint64) bool, onRemove func()) (Listener, error) {
|
||||
func (l *quicListener) Add(association any, tlsConf *tls.Config, allowWindowIncrease func(conn *quic.Conn, delta uint64) bool, onRemove func()) (*listener, error) {
|
||||
l.protocolsMu.Lock()
|
||||
defer l.protocolsMu.Unlock()
|
||||
|
||||
@@ -105,14 +105,32 @@ func (l *quicListener) Add(tlsConf *tls.Config, allowWindowIncrease func(conn *q
|
||||
}
|
||||
}
|
||||
|
||||
ln := newSingleListener(l.l.Addr(), l.addrs, func() {
|
||||
ln := &listener{
|
||||
queue: make(chan *quic.Conn, queueLen),
|
||||
acceptLoopRunning: l.running,
|
||||
addr: l.l.Addr(),
|
||||
addrs: l.addrs,
|
||||
}
|
||||
if association != nil {
|
||||
if tr, ok := l.transport.(*refcountedTransport); ok {
|
||||
tr.associateForListener(association, ln)
|
||||
}
|
||||
}
|
||||
|
||||
ln.remove = func() {
|
||||
if association != nil {
|
||||
if tr, ok := l.transport.(*refcountedTransport); ok {
|
||||
tr.RemoveAssociationsForListener(ln)
|
||||
}
|
||||
}
|
||||
l.protocolsMu.Lock()
|
||||
for _, proto := range tlsConf.NextProtos {
|
||||
delete(l.protocols, proto)
|
||||
}
|
||||
l.protocolsMu.Unlock()
|
||||
onRemove()
|
||||
}, l.running)
|
||||
}
|
||||
|
||||
for _, proto := range tlsConf.NextProtos {
|
||||
l.protocols[proto] = protoConf{
|
||||
ln: ln,
|
||||
@@ -167,16 +185,6 @@ type listener struct {
|
||||
|
||||
var _ Listener = &listener{}
|
||||
|
||||
func newSingleListener(addr net.Addr, addrs []ma.Multiaddr, remove func(), running chan struct{}) *listener {
|
||||
return &listener{
|
||||
queue: make(chan *quic.Conn, queueLen),
|
||||
acceptLoopRunning: running,
|
||||
remove: remove,
|
||||
addr: addr,
|
||||
addrs: addrs,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *listener) add(c *quic.Conn) {
|
||||
select {
|
||||
case l.queue <- c:
|
||||
|
||||
@@ -88,25 +88,44 @@ type refcountedTransport struct {
|
||||
// channel to signal to the owner that we are done with it.
|
||||
borrowDoneSignal chan struct{}
|
||||
|
||||
assocations map[any]struct{}
|
||||
// Store associations as association -> set of listener objects
|
||||
associations map[any]map[*listener]struct{}
|
||||
}
|
||||
|
||||
type connContextFunc = func(context.Context, *quic.ClientInfo) (context.Context, error)
|
||||
|
||||
// associate an arbitrary value with this transport.
|
||||
// associateForListener associates an arbitrary value with this transport for a specific listener.
|
||||
// This lets us "tag" the refcountedTransport when listening so we can use it
|
||||
// later for dialing. Necessary for holepunching and learning about our own
|
||||
// observed listening address.
|
||||
func (c *refcountedTransport) associate(a any) {
|
||||
// later for dialing. The listener parameter allows proper cleanup when the listener closes.
|
||||
// Necessary for holepunching and learning about our own observed listening address.
|
||||
func (c *refcountedTransport) associateForListener(a any, ln *listener) {
|
||||
if a == nil {
|
||||
return
|
||||
}
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
if c.assocations == nil {
|
||||
c.assocations = make(map[any]struct{})
|
||||
if c.associations == nil {
|
||||
c.associations = make(map[any]map[*listener]struct{})
|
||||
}
|
||||
if c.associations[a] == nil {
|
||||
c.associations[a] = make(map[*listener]struct{})
|
||||
}
|
||||
c.associations[a][ln] = struct{}{}
|
||||
}
|
||||
|
||||
// RemoveAssociationsForListener removes ALL associations added by a specific listener
|
||||
func (c *refcountedTransport) RemoveAssociationsForListener(ln *listener) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Remove this listener from all associations
|
||||
for association, listeners := range c.associations {
|
||||
delete(listeners, ln)
|
||||
// If no listeners remain for this association, remove the association entirely
|
||||
if len(listeners) == 0 {
|
||||
delete(c.associations, association)
|
||||
}
|
||||
}
|
||||
c.assocations[a] = struct{}{}
|
||||
}
|
||||
|
||||
// hasAssociation returns true if the transport has the given association.
|
||||
@@ -117,8 +136,8 @@ func (c *refcountedTransport) hasAssociation(a any) bool {
|
||||
}
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
_, ok := c.assocations[a]
|
||||
return ok
|
||||
listeners, ok := c.associations[a]
|
||||
return ok && len(listeners) > 0
|
||||
}
|
||||
|
||||
func (c *refcountedTransport) IncreaseCount() {
|
||||
@@ -367,33 +386,6 @@ func (r *reuse) AddTransport(tr *refcountedTransport, laddr *net.UDPAddr) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *reuse) AssertTransportExists(tr RefCountedQUICTransport) error {
|
||||
t, ok := tr.(*refcountedTransport)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid transport type: expected: *refcountedTransport, got: %T", tr)
|
||||
}
|
||||
laddr := t.LocalAddr().(*net.UDPAddr)
|
||||
if laddr.IP.IsUnspecified() {
|
||||
if lt, ok := r.globalListeners[laddr.Port]; ok {
|
||||
if lt == t {
|
||||
return nil
|
||||
}
|
||||
return errors.New("two global listeners on the same port")
|
||||
}
|
||||
return errors.New("transport not found")
|
||||
}
|
||||
if m, ok := r.unicast[laddr.IP.String()]; ok {
|
||||
if lt, ok := m[laddr.Port]; ok {
|
||||
if lt == t {
|
||||
return nil
|
||||
}
|
||||
return errors.New("two unicast listeners on same ip:port")
|
||||
}
|
||||
return errors.New("transport not found")
|
||||
}
|
||||
return errors.New("transport not found")
|
||||
}
|
||||
|
||||
func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcountedTransport, error) {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
Reference in New Issue
Block a user