fix(dcutr): Fix end to end tests and add legacy behavior flag (default=true) (#3044)

This commit is contained in:
Marco Munizaga
2025-02-24 20:53:47 -08:00
committed by GitHub
parent 3e51326ff1
commit 93014e1148
11 changed files with 1216 additions and 232 deletions

179
p2p/net/simconn/router.go Normal file
View File

@@ -0,0 +1,179 @@
package simconn
import (
"errors"
"fmt"
"net"
"sync"
"time"
)
// PacketReciever is implemented by anything a router can hand a packet to.
// (The spelling "Reciever" is part of the exported name.)
type PacketReciever interface {
// RecvPacket delivers p to the receiver.
RecvPacket(p Packet)
}
// PerfectRouter is a router that has no latency or jitter and can route to
// every node
//
// The zero value is usable: AddNode allocates the node table on first use.
// Nodes are keyed by the net.Addr interface value itself, so Go's interface
// equality applies; a pointer-typed address such as *net.UDPAddr only matches
// the very same pointer that was registered.
type PerfectRouter struct {
mu sync.Mutex // held by SendPacket and AddNode while they touch nodes
nodes map[net.Addr]PacketReciever
}
// SendPacket implements Router.
//
// The receiver registered for p.To is handed the packet synchronously while
// the router lock is held. If no receiver is registered for that address an
// "unknown destination" error is returned.
func (r *PerfectRouter) SendPacket(p Packet) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if dst, found := r.nodes[p.To]; found {
		dst.RecvPacket(p)
		return nil
	}
	return errors.New("unknown destination")
}
// AddNode registers conn as the receiver for packets addressed to addr,
// replacing whatever was stored under the same key. The node table is created
// lazily so a zero-value router can be used directly.
func (r *PerfectRouter) AddNode(addr net.Addr, conn PacketReciever) {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.nodes != nil {
		r.nodes[addr] = conn
		return
	}
	r.nodes = map[net.Addr]PacketReciever{addr: conn}
}
// RemoveNode unregisters the receiver stored under addr. Removing an address
// that was never added (including on a zero-value router, whose table is nil)
// is a no-op.
//
// The router lock is taken so the map delete cannot race with concurrent
// SendPacket or AddNode calls, which access the same map under r.mu.
func (r *PerfectRouter) RemoveNode(addr net.Addr) {
	r.mu.Lock()
	defer r.mu.Unlock()
	delete(r.nodes, addr)
}
var _ Router = &PerfectRouter{}
// DelayedPacketReciever wraps another PacketReciever and postpones every
// delivery by a fixed duration.
type DelayedPacketReciever struct {
inner PacketReciever // receiver that eventually gets the packet
delay time.Duration // how long RecvPacket waits before forwarding
}
// RecvPacket schedules p to be forwarded to the wrapped receiver once r.delay
// has elapsed. The call itself returns immediately; forwarding happens when
// the timer fires.
func (r *DelayedPacketReciever) RecvPacket(p Packet) {
	forward := func() {
		r.inner.RecvPacket(p)
	}
	time.AfterFunc(r.delay, forward)
}
// FixedLatencyRouter is a PerfectRouter whose nodes receive every packet
// after a constant delay.
type FixedLatencyRouter struct {
PerfectRouter
latency time.Duration // delay applied to receivers registered through AddNode
}
// SendPacket implements Router by delegating to the embedded PerfectRouter.
// No delay is added here; the latency comes from the DelayedPacketReciever
// wrappers that this router's AddNode registers.
func (r *FixedLatencyRouter) SendPacket(p Packet) error {
return r.PerfectRouter.SendPacket(p)
}
// AddNode registers conn under addr, wrapped so that every packet routed to
// it is delivered r.latency after it was sent. The latency is captured at
// registration time.
func (r *FixedLatencyRouter) AddNode(addr net.Addr, conn PacketReciever) {
	delayed := &DelayedPacketReciever{inner: conn, delay: r.latency}
	r.PerfectRouter.AddNode(addr, delayed)
}
var _ Router = &FixedLatencyRouter{}
// simpleNodeFirewall is the per-node state kept by SimpleFirewallRouter. A
// node accepts an inbound packet if it is publicly reachable or if it has
// previously sent a packet to the sender's address.
type simpleNodeFirewall struct {
mu sync.Mutex
publiclyReachable bool // when true, every inbound packet is allowed
packetsOutTo map[string]struct{} // set of destination addresses (String form) this node has sent to
node *SimConn // connection that receives the allowed packets
}
// MarkPacketSentOut records that this node has sent a packet to p.To, which
// opens the firewall for later packets coming back from that address. The
// set is allocated on first use.
func (f *simpleNodeFirewall) MarkPacketSentOut(p Packet) {
	f.mu.Lock()
	defer f.mu.Unlock()

	if f.packetsOutTo == nil {
		f.packetsOutTo = map[string]struct{}{}
	}
	dest := p.To.String()
	f.packetsOutTo[dest] = struct{}{}
}
// IsPacketInAllowed reports whether p may be delivered to this node: always
// for a publicly reachable node, otherwise only if the node has already sent
// a packet to p.From.
func (f *simpleNodeFirewall) IsPacketInAllowed(p Packet) bool {
	f.mu.Lock()
	defer f.mu.Unlock()

	if !f.publiclyReachable {
		_, contacted := f.packetsOutTo[p.From.String()]
		return contacted
	}
	return true
}
// String renders the firewall state (reachability flag and the set of
// contacted addresses) for debugging. It does not take f.mu.
func (f *simpleNodeFirewall) String() string {
	const layout = "public: %v, packetsOutTo: %v"
	return fmt.Sprintf(layout, f.publiclyReachable, f.packetsOutTo)
}
// SimpleFirewallRouter routes packets between SimConn nodes and applies a
// per-node firewall: a node that is not publicly reachable only receives
// packets from addresses it has already sent to. The zero value is usable;
// the node table is allocated by the Add* methods.
type SimpleFirewallRouter struct {
mu sync.Mutex
nodes map[string]*simpleNodeFirewall // keyed by the node address's String form
}
// String lists the firewall state of every registered node for debugging.
// Entries follow map iteration order, so their order is not stable.
func (r *SimpleFirewallRouter) String() string {
	r.mu.Lock()
	defer r.mu.Unlock()

	descriptions := make([]string, len(r.nodes))
	i := 0
	for _, fw := range r.nodes {
		descriptions[i] = fw.String()
		i++
	}
	return fmt.Sprint(descriptions)
}
// SendPacket implements Router.
//
// Both p.To and p.From must be registered nodes, otherwise "unknown
// destination" or "unknown source" is returned (destination is checked
// first). The send is always recorded on the source node's firewall, even
// when the destination's firewall rejects the packet; a rejected packet is
// dropped without an error.
func (r *SimpleFirewallRouter) SendPacket(p Packet) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	dst, ok := r.nodes[p.To.String()]
	if !ok {
		return errors.New("unknown destination")
	}
	src, ok := r.nodes[p.From.String()]
	if !ok {
		return errors.New("unknown source")
	}

	// Remember the outbound packet first: this is what lets replies from
	// p.To through the sender's firewall later on.
	src.MarkPacketSentOut(p)

	if dst.IsPacketInAllowed(p) {
		dst.node.RecvPacket(p)
	}
	// Blocked packets are dropped silently.
	return nil
}
// AddNode registers conn under addr as a firewalled node: it only receives
// packets from addresses it has sent to. Any node already stored under the
// same address string is replaced, together with its firewall state.
func (r *SimpleFirewallRouter) AddNode(addr net.Addr, conn *SimConn) {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.nodes == nil {
		r.nodes = map[string]*simpleNodeFirewall{}
	}
	fw := new(simpleNodeFirewall)
	fw.packetsOutTo = make(map[string]struct{})
	fw.node = conn
	r.nodes[addr.String()] = fw
}
// AddPubliclyReachableNode registers conn under addr as a node without an
// inbound firewall: packets from any registered source are delivered to it.
// Any node already stored under the same address string is replaced.
func (r *SimpleFirewallRouter) AddPubliclyReachableNode(addr net.Addr, conn *SimConn) {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.nodes == nil {
		r.nodes = map[string]*simpleNodeFirewall{}
	}
	fw := new(simpleNodeFirewall)
	fw.publiclyReachable = true
	fw.node = conn
	r.nodes[addr.String()] = fw
}
// RemoveNode unregisters the node stored under addr, discarding its firewall
// state. It does nothing when no node table has been allocated yet.
func (r *SimpleFirewallRouter) RemoveNode(addr net.Addr) {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.nodes != nil {
		delete(r.nodes, addr.String())
	}
}
var _ Router = &SimpleFirewallRouter{}

218
p2p/net/simconn/simconn.go Normal file
View File

@@ -0,0 +1,218 @@
package simconn
import (
"errors"
"net"
"slices"
"sync"
"sync/atomic"
"time"
)
// ErrDeadlineExceeded is returned by SimConn.ReadFrom and SimConn.WriteTo
// when the corresponding deadline has passed. It is a package-local sentinel
// created with errors.New; compare against it with errors.Is.
var ErrDeadlineExceeded = errors.New("deadline exceeded")
// Router is what a SimConn hands outgoing packets to.
type Router interface {
// SendPacket routes p towards p.To and returns an error if it cannot.
SendPacket(p Packet) error
}
// Packet is a single simulated datagram travelling through a Router.
type Packet struct {
To net.Addr // destination address
From net.Addr // source address
buf []byte // payload; SimConn.WriteTo stores a copy of the caller's slice here
}
// SimConn is an in-memory net.PacketConn. Writes are handed to a Router and
// reads are served from a buffered queue that the router fills through
// RecvPacket. Create instances with NewSimConn.
type SimConn struct {
mu sync.Mutex // guards closed and the two deadlines
closed bool
closedChan chan struct{} // closed by Close
// Traffic counters, exposed through Stats.
packetsSent atomic.Uint64
packetsRcvd atomic.Uint64
bytesSent atomic.Int64
bytesRcvd atomic.Int64
router Router
myAddr *net.UDPAddr // source address stamped on outgoing packets
myLocalAddr net.Addr // optional override for LocalAddr, see SetLocalAddr
packetsToRead chan Packet // inbound queue drained by ReadFrom
readDeadline time.Time
writeDeadline time.Time
}
// NewSimConn returns a SimConn that sends through rtr and stamps addr as the
// source of every outgoing packet. The caller still has to register the
// connection with the router for it to receive anything.
func NewSimConn(addr *net.UDPAddr, rtr Router) *SimConn {
	// Inbound packets are queued so that routers never block on delivery.
	const readQueueSize = 512

	c := new(SimConn)
	c.router = rtr
	c.myAddr = addr
	c.packetsToRead = make(chan Packet, readQueueSize)
	c.closedChan = make(chan struct{})
	return c
}
// ConnStats is a snapshot of a SimConn's traffic counters, as returned by
// SimConn.Stats.
type ConnStats struct {
BytesSent int
BytesRcvd int
PacketsSent int
PacketsRcvd int
}
// Stats returns the current traffic counters. Each counter is loaded
// atomically on its own, so the four values are not one consistent snapshot
// if traffic is flowing concurrently.
func (c *SimConn) Stats() ConnStats {
	var s ConnStats
	s.BytesSent = int(c.bytesSent.Load())
	s.BytesRcvd = int(c.bytesRcvd.Load())
	s.PacketsSent = int(c.packetsSent.Load())
	s.PacketsRcvd = int(c.packetsRcvd.Load())
	return s
}
// SetLocalAddr only changes what `.LocalAddr()` returns.
// Packets will still come From the initially configured addr.
//
// NOTE(review): the field is written without holding c.mu, so this looks
// intended to be called during setup, before the conn is shared — confirm.
func (c *SimConn) SetLocalAddr(addr net.Addr) {
c.myLocalAddr = addr
}
// RecvPacket is called by a router to deliver p to this connection. Packets
// arriving after Close are ignored. Otherwise the receive counters are
// bumped and the packet is queued for ReadFrom; if the queue is full the
// packet is dropped (it has still been counted).
func (c *SimConn) RecvPacket(p Packet) {
	c.mu.Lock()
	isClosed := c.closed
	c.mu.Unlock()
	if isClosed {
		return
	}

	c.packetsRcvd.Add(1)
	c.bytesRcvd.Add(int64(len(p.buf)))

	select {
	case c.packetsToRead <- p:
	default:
		// Queue full: drop instead of blocking the router.
	}
}
var _ net.PacketConn = &SimConn{}
// Close implements net.PacketConn. The first call marks the connection
// closed and closes closedChan; later calls do nothing. It always returns
// nil.
func (c *SimConn) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.closed {
		c.closed = true
		close(c.closedChan)
	}
	return nil
}
// ReadFrom implements net.PacketConn.
//
// It blocks until a packet is queued, the read deadline passes, or the
// connection is closed. It returns the number of bytes copied into p and the
// packet's source address. The error is net.ErrClosed when the connection is
// (or becomes) closed, and ErrDeadlineExceeded when the read deadline has
// passed.
//
// The deadline is sampled once on entry; changing it with SetReadDeadline
// does not affect a ReadFrom call that is already blocked.
func (c *SimConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
	c.mu.Lock()
	if c.closed {
		c.mu.Unlock()
		return 0, nil, net.ErrClosed
	}
	deadline := c.readDeadline
	c.mu.Unlock()

	// A nil channel never becomes ready in a select, so leaving expired nil
	// expresses "no deadline" without a second code path.
	var expired <-chan time.Time
	if !deadline.IsZero() {
		if time.Now().After(deadline) {
			return 0, nil, ErrDeadlineExceeded
		}
		// Use an explicit timer so it is released as soon as we return
		// rather than lingering until the deadline fires.
		timer := time.NewTimer(time.Until(deadline))
		defer timer.Stop()
		expired = timer.C
	}

	var pkt Packet
	select {
	case pkt = <-c.packetsToRead:
	case <-c.closedChan:
		// Close was called while we were waiting; without this case a
		// reader with no deadline would block forever.
		return 0, nil, net.ErrClosed
	case <-expired:
		return 0, nil, ErrDeadlineExceeded
	}

	n = copy(p, pkt.buf)
	// if the provided buffer is not enough to read the whole packet, we drop
	// the rest of the data. this is similar to what `recvfrom` does on Linux
	// and macOS.
	return n, pkt.From, nil
}
// WriteTo implements net.PacketConn.
//
// It fails with net.ErrClosed on a closed connection and with
// ErrDeadlineExceeded when the write deadline has already passed. Otherwise
// the send counters are bumped and a copy of p is handed to the router,
// addressed to addr. The returned byte count is len(p) even when the router
// reports an error.
func (c *SimConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
	c.mu.Lock()
	isClosed, deadline := c.closed, c.writeDeadline
	c.mu.Unlock()

	if isClosed {
		return 0, net.ErrClosed
	}
	if !deadline.IsZero() && time.Now().After(deadline) {
		return 0, ErrDeadlineExceeded
	}

	c.packetsSent.Add(1)
	c.bytesSent.Add(int64(len(p)))

	// Clone the payload: the caller is free to reuse p after we return.
	err = c.router.SendPacket(Packet{
		From: c.myAddr,
		To:   addr,
		buf:  slices.Clone(p),
	})
	return len(p), err
}
// UnicastAddr returns the address the connection was created with, which is
// also the From address of every packet it sends. Unlike LocalAddr it is not
// affected by SetLocalAddr.
func (c *SimConn) UnicastAddr() net.Addr {
return c.myAddr
}
// LocalAddr implements net.PacketConn. It returns the override installed via
// SetLocalAddr when there is one, and the address the connection was created
// with otherwise.
func (c *SimConn) LocalAddr() net.Addr {
	if override := c.myLocalAddr; override != nil {
		return override
	}
	return c.myAddr
}
// SetDeadline implements net.PacketConn. It sets the read and the write
// deadline to t; the zero time clears both. It always returns nil.
func (c *SimConn) SetDeadline(t time.Time) error {
	c.mu.Lock()
	c.readDeadline, c.writeDeadline = t, t
	c.mu.Unlock()
	return nil
}
// SetReadDeadline implements net.PacketConn. It stores t as the deadline
// consulted by subsequent ReadFrom calls; the zero time means no deadline.
// It always returns nil.
func (c *SimConn) SetReadDeadline(t time.Time) error {
	c.mu.Lock()
	c.readDeadline = t
	c.mu.Unlock()
	return nil
}
// SetWriteDeadline implements net.PacketConn. It stores t as the deadline
// consulted by subsequent WriteTo calls; the zero time means no deadline.
// It always returns nil.
func (c *SimConn) SetWriteDeadline(t time.Time) error {
	c.mu.Lock()
	c.writeDeadline = t
	c.mu.Unlock()
	return nil
}
// IntToPublicIPv4 deterministically maps n to an IPv4 address that is usable
// as a simulated public address.
//
// The low 32 bits of n+1 supply the four octets, with the lowest bit of the
// first octet forced to 1 so the address never starts with 0. Forcing the
// first octet odd also means the RFC 1918 private blocks (10/8, 172.16/12,
// 192.168/16) cannot be produced, but other non-public ranges with an odd
// first octet still can: loopback (127/8), link-local (169.254/16),
// multicast (225-239) and reserved (241-255). Whenever the candidate falls in
// any of those ranges the first octet is replaced by 1, giving 1.x.x.x.
//
// The mapping is not injective: distinct n may yield the same address.
func IntToPublicIPv4(n int) net.IP {
	n++
	second, third, fourth := byte(n>>16), byte(n>>8), byte(n)
	first := byte(n>>24) | 1

	ip := net.IPv4(first, second, third, fourth)
	nonPublic := first >= 240 || // reserved / broadcast
		ip.IsPrivate() ||
		ip.IsLoopback() ||
		ip.IsLinkLocalUnicast() ||
		ip.IsMulticast()
	if nonPublic {
		// Build the result after choosing the replacement octet; adjusting
		// the octet once the net.IP exists would leave the result unchanged.
		ip = net.IPv4(1, second, third, fourth)
	}
	return ip
}

View File

@@ -0,0 +1,315 @@
package simconn
import (
"bytes"
"net"
"sync"
"testing"
"testing/quick"
"time"
"github.com/stretchr/testify/require"
)
// TestSimConnBasicConnectivity sends one datagram between two SimConns joined
// by a PerfectRouter and checks the payload, the reported source address and
// the traffic counters on both ends.
func TestSimConnBasicConnectivity(t *testing.T) {
	router := &PerfectRouter{}

	// Two endpoints registered with the same router.
	senderAddr := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234}
	receiverAddr := &net.UDPAddr{IP: IntToPublicIPv4(2), Port: 1234}
	sender := NewSimConn(senderAddr, router)
	receiver := NewSimConn(receiverAddr, router)
	router.AddNode(senderAddr, sender)
	router.AddNode(receiverAddr, receiver)

	// Sender -> receiver.
	payload := []byte("hello world")
	written, err := sender.WriteTo(payload, receiverAddr)
	require.NoError(t, err)
	require.Equal(t, len(payload), written)

	// The receiver gets exactly what was sent, attributed to the sender.
	buf := make([]byte, 1024)
	read, from, err := receiver.ReadFrom(buf)
	require.NoError(t, err)
	require.Equal(t, payload, buf[:read])
	require.Equal(t, senderAddr, from)

	// Counters reflect the single packet on each side.
	sent := sender.Stats()
	require.Equal(t, len(payload), sent.BytesSent)
	require.Equal(t, 1, sent.PacketsSent)

	rcvd := receiver.Stats()
	require.Equal(t, len(payload), rcvd.BytesRcvd)
	require.Equal(t, 1, rcvd.PacketsRcvd)
}
// TestSimConnDeadlines checks that a read with nothing queued times out at
// its deadline and that a write with an expired deadline is rejected.
func TestSimConnDeadlines(t *testing.T) {
	router := &PerfectRouter{}
	localAddr := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234}
	conn := NewSimConn(localAddr, router)
	router.AddNode(localAddr, conn)

	t.Run("read deadline", func(t *testing.T) {
		require.NoError(t, conn.SetReadDeadline(time.Now().Add(10*time.Millisecond)))

		scratch := make([]byte, 1024)
		_, _, err := conn.ReadFrom(scratch)
		require.ErrorIs(t, err, ErrDeadlineExceeded)
	})

	t.Run("write deadline", func(t *testing.T) {
		// One second in the past: already expired.
		require.NoError(t, conn.SetWriteDeadline(time.Now().Add(-time.Second)))

		_, err := conn.WriteTo([]byte("test"), &net.UDPAddr{})
		require.ErrorIs(t, err, ErrDeadlineExceeded)
	})
}
// TestSimConnClose checks that reads and writes on a closed SimConn fail with
// net.ErrClosed and that closing twice is harmless.
func TestSimConnClose(t *testing.T) {
	router := &PerfectRouter{}
	localAddr := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234}
	conn := NewSimConn(localAddr, router)
	router.AddNode(localAddr, conn)

	require.NoError(t, conn.Close())

	// Every operation is rejected once the conn is closed.
	_, writeErr := conn.WriteTo([]byte("test"), localAddr)
	require.ErrorIs(t, writeErr, net.ErrClosed)

	scratch := make([]byte, 1024)
	_, _, readErr := conn.ReadFrom(scratch)
	require.ErrorIs(t, readErr, net.ErrClosed)

	// Closing again is not an error.
	require.NoError(t, conn.Close())
}
func TestSimConnLocalAddr(t *testing.T) {
router := &PerfectRouter{}
addr1 := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234}
conn := NewSimConn(addr1, router)
// Test default local address
require.Equal(t, addr1, conn.LocalAddr())
// Test setting custom local address
customAddr := &net.UDPAddr{IP: IntToPublicIPv4(3), Port: 5678}
conn.SetLocalAddr(customAddr)
require.Equal(t, customAddr, conn.LocalAddr())
}
// TestSimConnDeadlinesWithLatency exercises read and write deadlines on two
// SimConns joined by a router that delays every delivery by 100ms. Each
// subtest finishes by re-creating both connections so the next one starts
// from a clean state.
func TestSimConnDeadlinesWithLatency(t *testing.T) {
	router := &FixedLatencyRouter{
		PerfectRouter: PerfectRouter{},
		latency:       100 * time.Millisecond,
	}

	addr1 := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234}
	addr2 := &net.UDPAddr{IP: IntToPublicIPv4(2), Port: 1234}
	conn1 := NewSimConn(addr1, router)
	conn2 := NewSimConn(addr2, router)
	router.AddNode(addr1, conn1)
	router.AddNode(addr2, conn2)

	// reset replaces both connections (and their router registrations) with
	// fresh ones so deadlines and queued packets do not leak across subtests.
	reset := func() {
		router.RemoveNode(addr1)
		router.RemoveNode(addr2)
		conn1 = NewSimConn(addr1, router)
		conn2 = NewSimConn(addr2, router)
		router.AddNode(addr1, conn1)
		router.AddNode(addr2, conn2)
	}

	t.Run("write succeeds within deadline", func(t *testing.T) {
		deadline := time.Now().Add(200 * time.Millisecond)
		err := conn1.SetWriteDeadline(deadline)
		require.NoError(t, err)

		n, err := conn1.WriteTo([]byte("test"), addr2)
		require.NoError(t, err)
		require.Equal(t, 4, n)
		reset()
	})

	t.Run("write fails after past deadline", func(t *testing.T) {
		deadline := time.Now().Add(-time.Second) // Already expired
		err := conn1.SetWriteDeadline(deadline)
		require.NoError(t, err)

		_, err = conn1.WriteTo([]byte("test"), addr2)
		require.ErrorIs(t, err, ErrDeadlineExceeded)
		reset()
	})

	t.Run("read succeeds within deadline", func(t *testing.T) {
		// Reset deadline and send a message
		conn2.SetReadDeadline(time.Time{})
		testData := []byte("hello")
		deadline := time.Now().Add(200 * time.Millisecond)
		conn1.SetWriteDeadline(deadline)
		_, err := conn1.WriteTo(testData, addr2)
		require.NoError(t, err)

		// Set read deadline and try to read
		deadline = time.Now().Add(200 * time.Millisecond)
		err = conn2.SetReadDeadline(deadline)
		require.NoError(t, err)

		buf := make([]byte, 1024)
		n, addr, err := conn2.ReadFrom(buf)
		require.NoError(t, err)
		require.Equal(t, addr1, addr)
		require.Equal(t, testData, buf[:n])
		reset()
	})

	t.Run("read fails after deadline", func(t *testing.T) {
		defer reset()
		// Set a short deadline
		deadline := time.Now().Add(50 * time.Millisecond) // Less than router latency
		err := conn2.SetReadDeadline(deadline)
		require.NoError(t, err)

		// Send from a separate goroutine. Its error is only recorded there
		// and asserted below on the test goroutine, because require calls
		// t.FailNow, which must not run on a goroutine the test spawned.
		var (
			wg       sync.WaitGroup
			writeErr error
		)
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Send data after setting deadline
			_, writeErr = conn1.WriteTo([]byte("test"), addr2)
		}()

		// Read should fail due to deadline
		buf := make([]byte, 1024)
		_, _, err = conn2.ReadFrom(buf)

		wg.Wait()
		require.NoError(t, writeErr)
		require.ErrorIs(t, err, ErrDeadlineExceeded)
	})
}
// TestSimpleHolePunch checks the firewall semantics of SimpleFirewallRouter:
// two firewalled peers cannot reach each other until both have sent a packet
// to the other, after which traffic flows in both directions.
func TestSimpleHolePunch(t *testing.T) {
	router := &SimpleFirewallRouter{
		nodes: make(map[string]*simpleNodeFirewall),
	}

	// Create two peers
	addr1 := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234}
	addr2 := &net.UDPAddr{IP: IntToPublicIPv4(2), Port: 1234}
	peer1 := NewSimConn(addr1, router)
	peer2 := NewSimConn(addr2, router)
	router.AddNode(addr1, peer1)
	router.AddNode(addr2, peer2)

	// reset re-registers both peers, which also discards the firewall state
	// accumulated so far.
	reset := func() {
		router.RemoveNode(addr1)
		router.RemoveNode(addr2)
		peer1 = NewSimConn(addr1, router)
		peer2 = NewSimConn(addr2, router)
		router.AddNode(addr1, peer1)
		router.AddNode(addr2, peer2)
	}

	// Initially, direct communication between peer1 and peer2 should fail
	t.Run("direct communication blocked initially", func(t *testing.T) {
		_, err := peer1.WriteTo([]byte("direct message"), addr2)
		require.NoError(t, err) // Write succeeds but packet is dropped

		// Try to read from peer2
		peer2.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
		buf := make([]byte, 1024)
		_, _, err = peer2.ReadFrom(buf)
		require.ErrorIs(t, err, ErrDeadlineExceeded)
		reset()
	})

	holePunchMsg := []byte("hole punch")
	// Simulate hole punching
	t.Run("hole punch and direct communication", func(t *testing.T) {
		// Both peers send packets to each other simultaneously. The write
		// errors are only recorded in the goroutines and asserted after
		// Wait, because require calls t.FailNow, which must run on the test
		// goroutine.
		var (
			wg         sync.WaitGroup
			err1, err2 error
		)
		wg.Add(2)
		go func() {
			defer wg.Done()
			_, err1 = peer1.WriteTo(holePunchMsg, addr2)
		}()
		go func() {
			defer wg.Done()
			_, err2 = peer2.WriteTo(holePunchMsg, addr1)
		}()
		wg.Wait()
		require.NoError(t, err1)
		require.NoError(t, err2)

		// Now direct communication should work both ways
		t.Run("peer1 to peer2", func(t *testing.T) {
			testMsg := []byte("direct message after hole punch")
			_, err := peer1.WriteTo(testMsg, addr2)
			require.NoError(t, err)

			buf := make([]byte, 1024)
			peer2.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
			n, addr, err := peer2.ReadFrom(buf)
			require.NoError(t, err)
			require.Equal(t, addr1, addr)
			// Depending on which hole-punch packet was routed first, one of
			// them may have made it through and be queued ahead of testMsg.
			if bytes.Equal(buf[:n], holePunchMsg) {
				// Read again to get the actual message
				n, addr, err = peer2.ReadFrom(buf)
				require.NoError(t, err)
				require.Equal(t, addr1, addr)
			}
			require.Equal(t, string(testMsg), string(buf[:n]))
		})

		t.Run("peer2 to peer1", func(t *testing.T) {
			testMsg := []byte("response from peer2")
			_, err := peer2.WriteTo(testMsg, addr1)
			require.NoError(t, err)

			buf := make([]byte, 1024)
			peer1.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
			n, addr, err := peer1.ReadFrom(buf)
			require.NoError(t, err)
			require.Equal(t, addr2, addr)
			if bytes.Equal(buf[:n], holePunchMsg) {
				// Read again to get the actual message
				n, addr, err = peer1.ReadFrom(buf)
				require.NoError(t, err)
				require.Equal(t, addr2, addr)
			}
			require.Equal(t, string(testMsg), string(buf[:n]))
		})
	})
}
// TestPublicIP property-tests IntToPublicIPv4: for arbitrary inputs the
// result must never be in a private address range.
func TestPublicIP(t *testing.T) {
	neverPrivate := func(n int) bool {
		return !IntToPublicIPv4(n).IsPrivate()
	}
	require.NoError(t, quick.Check(neverPrivate, nil))
}

View File

@@ -2,28 +2,30 @@ package holepunch_test
import (
"context"
"fmt"
"net"
"slices"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/libp2p/go-libp2p"
"github.com/libp2p/go-libp2p-testing/race"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/proto"
"github.com/libp2p/go-libp2p/p2p/net/simconn"
relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
"github.com/libp2p/go-libp2p/p2p/protocol/holepunch"
holepunch_pb "github.com/libp2p/go-libp2p/p2p/protocol/holepunch/pb"
"github.com/libp2p/go-libp2p/p2p/protocol/identify"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
"go.uber.org/fx"
"github.com/libp2p/go-msgio/pbio"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -81,90 +83,239 @@ func (s *mockIDService) OwnObservedAddrs() []ma.Multiaddr {
}
func TestNoHolePunchIfDirectConnExists(t *testing.T) {
tr := &mockEventTracer{}
h1, hps := mkHostWithHolePunchSvc(t, holepunch.WithTracer(tr))
defer h1.Close()
h2, _ := mkHostWithHolePunchSvc(t)
defer h2.Close()
require.NoError(t, h1.Connect(context.Background(), peer.AddrInfo{
ID: h2.ID(),
Addrs: h2.Addrs(),
}))
time.Sleep(50 * time.Millisecond)
nc1 := len(h1.Network().ConnsToPeer(h2.ID()))
require.GreaterOrEqual(t, nc1, 1)
nc2 := len(h2.Network().ConnsToPeer(h1.ID()))
require.GreaterOrEqual(t, nc2, 1)
require.NoError(t, hps.DirectConnect(h2.ID()))
require.Len(t, h1.Network().ConnsToPeer(h2.ID()), nc1)
require.Len(t, h2.Network().ConnsToPeer(h1.ID()), nc2)
require.Empty(t, tr.getEvents())
}
func TestDirectDialWorks(t *testing.T) {
if race.WithRace() {
t.Skip("modifying manet.Private4 is racy")
}
// mark all addresses as public
cpy := manet.Private4
manet.Private4 = []*net.IPNet{}
defer func() { manet.Private4 = cpy }()
router := &simconn.SimpleFirewallRouter{}
relay := MustNewHost(t,
quicSimConn(true, router),
libp2p.ListenAddrs(ma.StringCast("/ip4/1.2.0.1/udp/8000/quic-v1")),
libp2p.DisableRelay(),
libp2p.ResourceManager(&network.NullResourceManager{}),
libp2p.WithFxOption(fx.Invoke(func(h host.Host) {
// Setup relay service
_, err := relayv2.New(h)
require.NoError(t, err)
})),
)
tr := &mockEventTracer{}
h1, h1ps := mkHostWithHolePunchSvc(t, holepunch.WithTracer(tr))
defer h1.Close()
h2, _ := mkHostWithHolePunchSvc(t)
defer h2.Close()
h2.RemoveStreamHandler(holepunch.Protocol)
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), peerstore.ConnectedAddrTTL)
h1 := MustNewHost(t,
quicSimConn(false, router),
libp2p.EnableHolePunching(holepunch.DirectDialTimeout(100*time.Millisecond)),
libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.1/udp/8000/quic-v1")),
libp2p.ResourceManager(&network.NullResourceManager{}),
)
// try to hole punch without any connection and streams, if it works -> it's a direct connection
require.Empty(t, h1.Network().ConnsToPeer(h2.ID()))
require.NoError(t, h1ps.DirectConnect(h2.ID()))
require.GreaterOrEqual(t, len(h1.Network().ConnsToPeer(h2.ID())), 1)
require.GreaterOrEqual(t, len(h2.Network().ConnsToPeer(h1.ID())), 1)
events := tr.getEvents()
require.Len(t, events, 1)
require.Equal(t, holepunch.DirectDialEvtT, events[0].Type)
}
h2 := MustNewHost(t,
quicSimConn(true, router),
libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.2/udp/8001/quic-v1")),
libp2p.ResourceManager(&network.NullResourceManager{}),
libp2p.ForceReachabilityPublic(),
connectToRelay(&relay),
libp2p.EnableHolePunching(holepunch.WithTracer(tr), holepunch.DirectDialTimeout(100*time.Millisecond)),
)
func TestEndToEndSimConnect(t *testing.T) {
h1tr := &mockEventTracer{}
h2tr := &mockEventTracer{}
h1, h2, relay, _ := makeRelayedHosts(t, []holepunch.Option{holepunch.WithTracer(h1tr)}, []holepunch.Option{holepunch.WithTracer(h2tr)}, true)
defer h1.Close()
defer h2.Close()
defer relay.Close()
// wait till a direct connection is complete
ensureDirectConn(t, h1, h2)
// ensure no hole-punching streams are open on either side
ensureNoHolePunchingStream(t, h1, h2)
var h2Events []*holepunch.Event
require.Eventually(t,
func() bool {
h2Events = h2tr.getEvents()
return len(h2Events) == 3
},
time.Second,
10*time.Millisecond,
)
require.Equal(t, holepunch.StartHolePunchEvtT, h2Events[0].Type)
require.Equal(t, holepunch.HolePunchAttemptEvtT, h2Events[1].Type)
require.Equal(t, holepunch.EndHolePunchEvtT, h2Events[2].Type)
waitForHolePunchingSvcActive(t, h1)
waitForHolePunchingSvcActive(t, h2)
h1Events := h1tr.getEvents()
// We don't really expect a hole-punched connection to be established in this test,
// as we probably don't get the timing right for the TCP simultaneous open.
// From time to time, it still happens occasionally, and then we get a EndHolePunchEvtT here.
if len(h1Events) != 2 && len(h1Events) != 3 {
t.Fatal("expected either 2 or 3 events")
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), peerstore.ConnectedAddrTTL)
// try to hole punch without any connection and streams, if it works -> it's a direct connection
require.Empty(t, h1.Network().ConnsToPeer(h2.ID()))
pingAtoB(t, h1, h2)
nc1 := len(h1.Network().ConnsToPeer(h2.ID()))
require.Equal(t, nc1, 1)
nc2 := len(h2.Network().ConnsToPeer(h1.ID()))
require.Equal(t, nc2, 1)
assert.Never(t, func() bool {
return (len(h1.Network().ConnsToPeer(h2.ID())) != nc1 ||
len(h2.Network().ConnsToPeer(h1.ID())) != nc2 ||
len(tr.getEvents()) != 0)
}, time.Second, 100*time.Millisecond)
}
// TestDirectDialWorks sets up a relay, a publicly reachable h1 and a
// firewalled h2 on a simulated network, connects h1 to h2 and expects h2's
// hole punch tracer to end up with exactly one event, of type
// DirectDialEvtT.
func TestDirectDialWorks(t *testing.T) {
	router := &simconn.SimpleFirewallRouter{}
	relay := MustNewHost(t,
		quicSimConn(true, router),
		libp2p.ListenAddrs(ma.StringCast("/ip4/1.2.0.1/udp/8000/quic-v1")),
		libp2p.DisableRelay(),
		libp2p.ResourceManager(&network.NullResourceManager{}),
		libp2p.WithFxOption(fx.Invoke(func(h host.Host) {
			// Setup relay service
			_, err := relayv2.New(h)
			require.NoError(t, err)
		})),
	)

	tr := &mockEventTracer{}
	// h1 is public
	h1 := MustNewHost(t,
		quicSimConn(true, router),
		libp2p.ForceReachabilityPublic(),
		libp2p.EnableHolePunching(holepunch.DirectDialTimeout(100*time.Millisecond)),
		libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.1/udp/8000/quic-v1")),
		libp2p.ResourceManager(&network.NullResourceManager{}),
	)

	h2 := MustNewHost(t,
		quicSimConn(false, router),
		libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.2/udp/8001/quic-v1")),
		libp2p.ResourceManager(&network.NullResourceManager{}),
		connectToRelay(&relay),
		libp2p.EnableHolePunching(holepunch.WithTracer(tr), holepunch.DirectDialTimeout(100*time.Millisecond)),
		libp2p.ForceReachabilityPrivate(),
	)

	defer h1.Close()
	defer h2.Close()
	defer relay.Close()

	// wait for dcutr to be available
	waitForHolePunchingSvcActive(t, h2)

	h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), peerstore.ConnectedAddrTTL)
	// try to hole punch without any connection and streams, if it works -> it's a direct connection
	require.Empty(t, h1.Network().ConnsToPeer(h2.ID()))
	pingAtoB(t, h1, h2)
	require.GreaterOrEqual(t, len(h1.Network().ConnsToPeer(h2.ID())), 1)
	require.GreaterOrEqual(t, len(h2.Network().ConnsToPeer(h1.ID())), 1)

	require.EventuallyWithT(t, func(collect *assert.CollectT) {
		events := tr.getEvents()
		if !assert.Len(collect, events, 1) {
			return
		}
		// Report to collect, not t: a failure recorded on t would fail the
		// test immediately instead of letting EventuallyWithT retry.
		assert.Equal(collect, holepunch.DirectDialEvtT, events[0].Type)
	}, 2*time.Second, 100*time.Millisecond)
}
func connectToRelay(relayPtr *host.Host) libp2p.Option {
return func(cfg *libp2p.Config) error {
if relayPtr == nil {
return nil
}
relay := *relayPtr
pi := peer.AddrInfo{
ID: relay.ID(),
Addrs: relay.Addrs(),
}
return cfg.Apply(
libp2p.EnableRelay(),
libp2p.EnableAutoRelayWithStaticRelays([]peer.AddrInfo{pi}),
)
}
require.Equal(t, holepunch.StartHolePunchEvtT, h1Events[0].Type)
require.Equal(t, holepunch.HolePunchAttemptEvtT, h1Events[1].Type)
if len(h1Events) == 3 {
require.Equal(t, holepunch.EndHolePunchEvtT, h1Events[2].Type)
}
// learnAddrs stores each host's current addresses in the other host's
// peerstore, using ConnectedAddrTTL.
func learnAddrs(h1, h2 host.Host) {
	ttl := peerstore.ConnectedAddrTTL
	h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), ttl)
	h2.Peerstore().AddAddrs(h1.ID(), h1.Addrs(), ttl)
}
// pingAtoB connects a to b using b's current addresses, then sends a single
// ping from a to b. The test fails if either the connect or the ping fails.
func pingAtoB(t *testing.T, a, b host.Host) {
	t.Helper()

	pinger := ping.NewPingService(a)
	target := peer.AddrInfo{ID: b.ID(), Addrs: b.Addrs()}
	require.NoError(t, a.Connect(context.Background(), target))

	// Only the first ping result is needed; the deferred cancel ends the
	// ping once it has been read.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	first := <-pinger.Ping(ctx, b.ID())
	require.NoError(t, first.Error)
}
// MustNewHost constructs a libp2p host from opts and fails the test
// immediately if construction returns an error.
func MustNewHost(t *testing.T, opts ...libp2p.Option) host.Host {
	t.Helper()

	newHost, buildErr := libp2p.New(opts...)
	require.NoError(t, buildErr)
	return newHost
}
// TestEndToEndSimConnect runs a full hole punch between two firewalled hosts
// on a simulated network, once with the legacy behavior flag on and once with
// it off. Only h2 is configured with the relay. After h1 pings h2 the test
// expects a direct connection, no lingering hole punch streams, and a fixed
// sequence of tracer events on each side.
func TestEndToEndSimConnect(t *testing.T) {
for _, useLegacyHolePunchingBehavior := range []bool{true, false} {
t.Run(fmt.Sprintf("legacy=%t", useLegacyHolePunchingBehavior), func(t *testing.T) {
h1tr := &mockEventTracer{}
h2tr := &mockEventTracer{}
router := &simconn.SimpleFirewallRouter{}
// The relay is the only publicly reachable node on the simulated network.
relay := MustNewHost(t,
quicSimConn(true, router),
libp2p.ListenAddrs(ma.StringCast("/ip4/1.2.0.1/udp/8000/quic-v1")),
libp2p.DisableRelay(),
libp2p.ResourceManager(&network.NullResourceManager{}),
libp2p.WithFxOption(fx.Invoke(func(h host.Host) {
// Setup relay service
_, err := relayv2.New(h)
require.NoError(t, err)
})),
)
h1 := MustNewHost(t,
quicSimConn(false, router),
libp2p.EnableHolePunching(holepunch.WithTracer(h1tr), holepunch.DirectDialTimeout(100*time.Millisecond), SetLegacyBehavior(useLegacyHolePunchingBehavior)),
libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.1/udp/8000/quic-v1")),
libp2p.ResourceManager(&network.NullResourceManager{}),
libp2p.ForceReachabilityPrivate(),
)
h2 := MustNewHost(t,
quicSimConn(false, router),
libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.2/udp/8001/quic-v1")),
libp2p.ResourceManager(&network.NullResourceManager{}),
connectToRelay(&relay),
libp2p.EnableHolePunching(holepunch.WithTracer(h2tr), holepunch.DirectDialTimeout(100*time.Millisecond), SetLegacyBehavior(useLegacyHolePunchingBehavior)),
libp2p.ForceReachabilityPrivate(),
)
defer h1.Close()
defer h2.Close()
defer relay.Close()
// Wait for holepunch service to start
waitForHolePunchingSvcActive(t, h1)
waitForHolePunchingSvcActive(t, h2)
learnAddrs(h1, h2)
pingAtoB(t, h1, h2)
// wait till a direct connection is complete
ensureDirectConn(t, h1, h2)
// ensure no hole-punching streams are open on either side
ensureNoHolePunchingStream(t, h1, h2)
// h2 is expected to record exactly four events: a direct dial attempt
// followed by the start/attempt/end of a hole punch.
var h2Events []*holepunch.Event
require.Eventually(t,
func() bool {
h2Events = h2tr.getEvents()
return len(h2Events) == 4
},
time.Second,
100*time.Millisecond,
)
require.Equal(t, holepunch.DirectDialEvtT, h2Events[0].Type)
require.Equal(t, holepunch.StartHolePunchEvtT, h2Events[1].Type)
require.Equal(t, holepunch.HolePunchAttemptEvtT, h2Events[2].Type)
require.Equal(t, holepunch.EndHolePunchEvtT, h2Events[3].Type)
h1Events := h1tr.getEvents()
// We don't really expect a hole-punched connection to be established in this test,
// as we probably don't get the timing right for the TCP simultaneous open.
// From time to time, it still happens occasionally, and then we get a EndHolePunchEvtT here.
// NOTE(review): the hosts above listen on QUIC over simconn, so the mention
// of TCP simultaneous open looks stale — confirm and reword.
if len(h1Events) != 2 && len(h1Events) != 3 {
t.Fatal("expected either 2 or 3 events")
}
require.Equal(t, holepunch.StartHolePunchEvtT, h1Events[0].Type)
require.Equal(t, holepunch.HolePunchAttemptEvtT, h1Events[1].Type)
if len(h1Events) == 3 {
require.Equal(t, holepunch.EndHolePunchEvtT, h1Events[2].Type)
}
})
}
}
@@ -192,7 +343,7 @@ func TestFailuresOnInitiator(t *testing.T) {
errMsg: "failed to read CONNECT message",
},
"responder does NOT reply within hole punch deadline": {
holePunchTimeout: 10 * time.Millisecond,
holePunchTimeout: 200 * time.Millisecond,
rhandler: func(s network.Stream) { time.Sleep(5 * time.Second) },
errMsg: "i/o deadline reached",
},
@@ -213,13 +364,43 @@ func TestFailuresOnInitiator(t *testing.T) {
defer func() { holepunch.StreamTimeout = cpy }()
}
tr := &mockEventTracer{}
h1, h2, relay, _ := makeRelayedHosts(t, nil, nil, false)
router := &simconn.SimpleFirewallRouter{}
relay := MustNewHost(t,
quicSimConn(true, router),
libp2p.ListenAddrs(ma.StringCast("/ip4/1.2.0.1/udp/8000/quic-v1")),
libp2p.DisableRelay(),
libp2p.ResourceManager(&network.NullResourceManager{}),
libp2p.WithFxOption(fx.Invoke(func(h host.Host) {
// Setup relay service
_, err := relayv2.New(h)
require.NoError(t, err)
})),
)
// h1 does not have a holepunching service because we'll mock the holepunching stream handler below.
h1 := MustNewHost(t,
quicSimConn(false, router),
libp2p.ForceReachabilityPrivate(),
libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.1/udp/8000/quic-v1")),
libp2p.ResourceManager(&network.NullResourceManager{}),
connectToRelay(&relay),
)
h2 := MustNewHost(t,
quicSimConn(false, router),
libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.2/udp/8001/quic-v1")),
libp2p.ResourceManager(&network.NullResourceManager{}),
connectToRelay(&relay),
)
defer h1.Close()
defer h2.Close()
defer relay.Close()
opts := []holepunch.Option{holepunch.WithTracer(tr)}
time.Sleep(100 * time.Millisecond)
tr := &mockEventTracer{}
opts := []holepunch.Option{holepunch.WithTracer(tr), holepunch.DirectDialTimeout(100 * time.Millisecond)}
if tc.filter != nil {
f := mockMaddrFilter{
filterLocal: tc.filter,
@@ -228,19 +409,18 @@ func TestFailuresOnInitiator(t *testing.T) {
opts = append(opts, holepunch.WithAddrFilter(f))
}
hps := addHolePunchService(t, h2, opts...)
// wait until the hole punching protocol has actually started
require.Eventually(t, func() bool {
protos, _ := h2.Peerstore().SupportsProtocols(h1.ID(), holepunch.Protocol)
return len(protos) > 0
}, 200*time.Millisecond, 10*time.Millisecond)
hps := addHolePunchService(t, h2, []ma.Multiaddr{ma.StringCast("/ip4/2.2.0.2/udp/8001/quic-v1")}, opts...)
// We are only holepunching from h2 to h1. Remove h2's holepunching stream handler to avoid confusion.
h2.RemoveStreamHandler(holepunch.Protocol)
if tc.rhandler != nil {
h1.SetStreamHandler(holepunch.Protocol, tc.rhandler)
} else {
h1.RemoveStreamHandler(holepunch.Protocol)
}
require.NoError(t, h2.Connect(context.Background(), peer.AddrInfo{
ID: h1.ID(),
Addrs: h1.Addrs(),
}))
err := hps.DirectConnect(h1.ID())
require.Error(t, err)
if tc.errMsg != "" {
@@ -325,7 +505,7 @@ func TestFailuresOnResponder(t *testing.T) {
}
tr := &mockEventTracer{}
opts := []holepunch.Option{holepunch.WithTracer(tr)}
opts := []holepunch.Option{holepunch.WithTracer(tr), holepunch.DirectDialTimeout(100 * time.Millisecond)}
if tc.filter != nil {
f := mockMaddrFilter{
filterLocal: tc.filter,
@@ -334,11 +514,49 @@ func TestFailuresOnResponder(t *testing.T) {
opts = append(opts, holepunch.WithAddrFilter(f))
}
h1, h2, relay, _ := makeRelayedHosts(t, opts, nil, false)
router := &simconn.SimpleFirewallRouter{}
relay := MustNewHost(t,
quicSimConn(true, router),
libp2p.ListenAddrs(ma.StringCast("/ip4/1.2.0.1/udp/8000/quic-v1")),
libp2p.DisableRelay(),
libp2p.ResourceManager(&network.NullResourceManager{}),
libp2p.WithFxOption(fx.Invoke(func(h host.Host) {
// Setup relay service
_, err := relayv2.New(h)
require.NoError(t, err)
})),
)
h1 := MustNewHost(t,
quicSimConn(false, router),
libp2p.EnableHolePunching(opts...),
libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.1/udp/8000/quic-v1")),
libp2p.ResourceManager(&network.NullResourceManager{}),
connectToRelay(&relay),
libp2p.ForceReachabilityPrivate(),
)
h2 := MustNewHost(t,
quicSimConn(false, router),
libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.2/udp/8001/quic-v1")),
libp2p.ResourceManager(&network.NullResourceManager{}),
connectToRelay(&relay),
libp2p.ForceReachabilityPrivate(),
)
defer h1.Close()
defer h2.Close()
defer relay.Close()
time.Sleep(100 * time.Millisecond)
require.NoError(t, h1.Connect(context.Background(), peer.AddrInfo{
ID: h2.ID(),
Addrs: h2.Addrs(),
}))
require.EventuallyWithT(t, func(c *assert.CollectT) {
assert.Contains(c, h1.Mux().Protocols(), holepunch.Protocol)
}, time.Second, 100*time.Millisecond)
s, err := h2.NewStream(network.WithAllowLimitedConn(context.Background(), "holepunch"), h1.ID(), holepunch.Protocol)
require.NoError(t, err)
@@ -407,108 +625,60 @@ func ensureDirectConn(t *testing.T, h1, h2 host.Host) {
}, 5*time.Second, 50*time.Millisecond)
}
func mkHostWithStaticAutoRelay(t *testing.T, relay host.Host) host.Host {
if race.WithRace() {
t.Skip("modifying manet.Private4 is racy")
}
pi := peer.AddrInfo{
ID: relay.ID(),
Addrs: relay.Addrs(),
}
// MockSourceIPSelector is a test stand-in for a quicreuse source IP selector.
// It reports whatever IP was most recently stored in ip as the preferred
// source address, regardless of the destination.
type MockSourceIPSelector struct {
// ip is the address to report; quicSimConn stores the listen address's IP here.
ip atomic.Pointer[net.IP]
}
cpy := manet.Private4
manet.Private4 = []*net.IPNet{}
defer func() { manet.Private4 = cpy }()
// PreferredSourceIPForDestination returns the IP most recently stored in m.ip,
// ignoring dst.
//
// If no IP has been stored yet it returns an error instead of dereferencing a
// nil pointer. The quicreuse dial path skips the source-IP preference whenever
// the selector returns an error, so a dial that happens before the first
// listen degrades gracefully rather than panicking.
func (m *MockSourceIPSelector) PreferredSourceIPForDestination(dst *net.UDPAddr) (net.IP, error) {
	ip := m.ip.Load()
	if ip == nil {
		// (*net.UDPAddr).String is nil-safe, so this is fine even for a nil dst.
		return nil, &net.AddrError{Err: "no source IP recorded yet", Addr: dst.String()}
	}
	return *ip, nil
}
h, err := libp2p.New(
libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0")),
libp2p.EnableRelay(),
libp2p.EnableAutoRelayWithStaticRelays([]peer.AddrInfo{pi}),
libp2p.ForceReachabilityPrivate(),
libp2p.ResourceManager(&network.NullResourceManager{}),
)
require.NoError(t, err)
// wait till we have a relay addr
require.Eventually(t, func() bool {
for _, a := range h.Addrs() {
if _, err := a.ValueForProtocol(ma.P_CIRCUIT); err == nil {
return true
func quicSimConn(isPubliclyReachably bool, router *simconn.SimpleFirewallRouter) libp2p.Option {
m := &MockSourceIPSelector{}
return libp2p.QUICReuse(
quicreuse.NewConnManager,
quicreuse.OverrideSourceIPSelector(func() (quicreuse.SourceIPSelector, error) {
return m, nil
}),
quicreuse.OverrideListenUDP(func(network string, address *net.UDPAddr) (net.PacketConn, error) {
m.ip.Store(&address.IP)
c := simconn.NewSimConn(address, router)
if isPubliclyReachably {
router.AddPubliclyReachableNode(address, c)
} else {
router.AddNode(address, c)
}
}
return false
}, 5*time.Second, 50*time.Millisecond)
return h
return c, nil
}))
}
func makeRelayedHosts(t *testing.T, h1opt, h2opt []holepunch.Option, addHolePuncher bool) (h1, h2, relay host.Host, hps *holepunch.Service) {
t.Helper()
h1, _ = mkHostWithHolePunchSvc(t, h1opt...)
var err error
relay, err = libp2p.New(
libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0")),
libp2p.DisableRelay(),
libp2p.ResourceManager(&network.NullResourceManager{}),
)
require.NoError(t, err)
_, err = relayv2.New(relay)
require.NoError(t, err)
// make sure the relay service is started and advertised by Identify
h, err := libp2p.New(
libp2p.NoListenAddrs,
libp2p.Transport(tcp.NewTCPTransport),
libp2p.DisableRelay(),
)
require.NoError(t, err)
defer h.Close()
require.NoError(t, h.Connect(context.Background(), peer.AddrInfo{ID: relay.ID(), Addrs: relay.Addrs()}))
require.Eventually(t, func() bool {
supported, err := h.Peerstore().SupportsProtocols(relay.ID(), proto.ProtoIDv2Hop)
return err == nil && len(supported) > 0
}, 3*time.Second, 100*time.Millisecond)
h2 = mkHostWithStaticAutoRelay(t, relay)
if addHolePuncher {
hps = addHolePunchService(t, h2, h2opt...)
}
// h2 has a relay addr
var raddr ma.Multiaddr
for _, a := range h2.Addrs() {
if _, err := a.ValueForProtocol(ma.P_CIRCUIT); err == nil {
raddr = a
break
}
}
require.NotEmpty(t, raddr)
// h1 should connect to the relay addr
require.NoError(t, h1.Connect(context.Background(), peer.AddrInfo{
ID: h2.ID(),
Addrs: []ma.Multiaddr{raddr},
}))
return
}
func addHolePunchService(t *testing.T, h host.Host, opts ...holepunch.Option) *holepunch.Service {
func addHolePunchService(t *testing.T, h host.Host, extraAddrs []ma.Multiaddr, opts ...holepunch.Option) *holepunch.Service {
t.Helper()
hps, err := holepunch.NewService(h, newMockIDService(t, h), func() []ma.Multiaddr {
addrs := h.Addrs()
addrs = slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { return !manet.IsPublicAddr(a) })
return append(addrs, ma.StringCast("/ip4/1.2.3.4/tcp/1234"))
addrs = append(addrs, extraAddrs...)
return addrs
}, opts...)
require.NoError(t, err)
return hps
}
func mkHostWithHolePunchSvc(t *testing.T, opts ...holepunch.Option) (host.Host, *holepunch.Service) {
t.Helper()
h, err := libp2p.New(
libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0"), ma.StringCast("/ip6/::1/tcp/0")),
libp2p.ForceReachabilityPrivate(),
libp2p.ResourceManager(&network.NullResourceManager{}),
)
require.NoError(t, err)
hps := addHolePunchService(t, h, opts...)
return h, hps
// waitForHolePunchingSvcActive waits until holepunch.Protocol shows up in
// h.Mux().Protocols(), polling every 100ms and failing the test if it has not
// appeared within one second.
func waitForHolePunchingSvcActive(t *testing.T, h host.Host) {
	// Mark as a helper so a timeout is reported at the caller's line.
	t.Helper()
	require.EventuallyWithT(t, func(c *assert.CollectT) {
		assert.Contains(c, h.Mux().Protocols(), holepunch.Protocol)
	}, time.Second, 100*time.Millisecond)
}
// SetLegacyBehavior returns a holepunch.Option that forwards legacyBehavior to
// the service's SetLegacyBehavior method, which controls the isClient choice
// made by the hole punching service.
//
// Prior to https://github.com/libp2p/go-libp2p/pull/3044, go-libp2p would
// pick the opposite roles for client/server in a hole punch. Passing true
// preserves that behavior.
//
// Currently only exposed for testing purposes.
// Do not set this unless you know what you are doing.
func SetLegacyBehavior(legacyBehavior bool) holepunch.Option {
	apply := func(svc *holepunch.Service) error {
		svc.SetLegacyBehavior(legacyBehavior)
		return nil
	}
	return apply
}

View File

@@ -20,10 +20,7 @@ import (
// ErrHolePunchActive is returned from DirectConnect when another hole punching attempt is currently running
var ErrHolePunchActive = errors.New("another hole punching attempt to this peer is active")
const (
dialTimeout = 5 * time.Second
maxRetries = 3
)
const maxRetries = 3
// The holePuncher is run on the peer that's behind a NAT / Firewall.
// It observes new incoming connections via a relay that it has a reservation with,
@@ -40,6 +37,8 @@ type holePuncher struct {
ids identify.IDService
listenAddrs func() []ma.Multiaddr
directDialTimeout time.Duration
// active hole punches for deduplicating
activeMx sync.Mutex
active map[peer.ID]struct{}
@@ -49,6 +48,11 @@ type holePuncher struct {
tracer *tracer
filter AddrFilter
// Prior to https://github.com/libp2p/go-libp2p/pull/3044, go-libp2p would
// pick the opposite roles for client/server in a hole punch. Setting this to
// true preserves that behavior.
legacyBehavior bool
}
func newHolePuncher(h host.Host, ids identify.IDService, listenAddrs func() []ma.Multiaddr, tracer *tracer, filter AddrFilter) *holePuncher {
@@ -59,6 +63,8 @@ func newHolePuncher(h host.Host, ids identify.IDService, listenAddrs func() []ma
tracer: tracer,
filter: filter,
listenAddrs: listenAddrs,
legacyBehavior: true,
}
hp.ctx, hp.ctxCancel = context.WithCancel(context.Background())
h.Network().Notify((*netNotifiee)(hp))
@@ -86,6 +92,7 @@ func (hp *holePuncher) beginDirectConnect(p peer.ID) error {
// It first attempts a direct dial (if we have a public address of that peer), and then
// coordinates a hole punch over the given relay connection.
func (hp *holePuncher) DirectConnect(p peer.ID) error {
log.Debugw("beginDirectConnect", "host", hp.host.ID(), "peer", p)
if err := hp.beginDirectConnect(p); err != nil {
return err
}
@@ -102,14 +109,17 @@ func (hp *holePuncher) DirectConnect(p peer.ID) error {
func (hp *holePuncher) directConnect(rp peer.ID) error {
// short-circuit check to see if we already have a direct connection
if getDirectConnection(hp.host, rp) != nil {
log.Debugw("already connected", "host", hp.host.ID(), "peer", rp)
return nil
}
log.Debugw("attempting direct dial", "host", hp.host.ID(), "peer", rp, "addrs", hp.host.Peerstore().Addrs(rp))
// short-circuit hole punching if a direct dial works.
// attempt a direct connection ONLY if we have a public address for the remote peer
for _, a := range hp.host.Peerstore().Addrs(rp) {
if !isRelayAddress(a) && manet.IsPublicAddr(a) {
forceDirectConnCtx := network.WithForceDirectDial(hp.ctx, "hole-punching")
dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, dialTimeout)
dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, hp.directDialTimeout)
tstart := time.Now()
// This dials *all* addresses, public and private, from the peerstore.
@@ -150,7 +160,13 @@ func (hp *holePuncher) directConnect(rp peer.ID) error {
}
hp.tracer.StartHolePunch(rp, addrs, rtt)
hp.tracer.HolePunchAttempt(pi.ID)
err := holePunchConnect(hp.ctx, hp.host, pi, true)
ctx, cancel := context.WithTimeout(hp.ctx, hp.directDialTimeout)
isClient := true
if hp.legacyBehavior {
isClient = false
}
err := holePunchConnect(ctx, hp.host, pi, isClient)
cancel()
dt := time.Since(start)
hp.tracer.EndHolePunch(rp, dt, err)
if err == nil {
@@ -180,6 +196,7 @@ func (hp *holePuncher) initiateHolePunch(rp peer.ID) ([]ma.Multiaddr, []ma.Multi
return nil, nil, 0, fmt.Errorf("failed to open hole-punching stream: %w", err)
}
defer str.Close()
log.Debugf("initiateHolePunch: %s, %s", str.Conn().RemotePeer(), str.Conn().RemoteMultiaddr())
addr, obsAddr, rtt, err := hp.initiateHolePunchImpl(str)
if err != nil {
@@ -212,6 +229,7 @@ func (hp *holePuncher) initiateHolePunchImpl(str network.Stream) ([]ma.Multiaddr
if len(obsAddrs) == 0 {
return nil, nil, 0, errors.New("aborting hole punch initiation as we have no public address")
}
log.Debugf("initiating hole punch with %s", obsAddrs)
start := time.Now()
if err := w.WriteMsg(&pb.HolePunch{

View File

@@ -19,6 +19,8 @@ import (
ma "github.com/multiformats/go-multiaddr"
)
const defaultDirectDialTimeout = 10 * time.Second
// Protocol is the libp2p protocol for Hole Punching.
const Protocol protocol.ID = "/libp2p/dcutr"
@@ -38,6 +40,13 @@ var ErrClosed = errors.New("hole punching service closing")
type Option func(*Service) error
// DirectDialTimeout returns an Option that sets the Service's directDialTimeout
// to timeout. The service uses this value to bound its direct dial and hole
// punch connection attempts. When the option is not supplied, NewService uses
// defaultDirectDialTimeout.
func DirectDialTimeout(timeout time.Duration) Option {
	apply := func(svc *Service) error {
		svc.directDialTimeout = timeout
		return nil
	}
	return apply
}
// The Service runs on every node that supports the DCUtR protocol.
type Service struct {
ctx context.Context
@@ -52,8 +61,9 @@ type Service struct {
// publicly reachable relay addresses.
listenAddrs func() []ma.Multiaddr
holePuncherMx sync.Mutex
holePuncher *holePuncher
directDialTimeout time.Duration
holePuncherMx sync.Mutex
holePuncher *holePuncher
hasPublicAddrsChan chan struct{}
@@ -61,6 +71,17 @@ type Service struct {
filter AddrFilter
refCount sync.WaitGroup
// Prior to https://github.com/libp2p/go-libp2p/pull/3044, go-libp2p would
// pick the opposite roles for client/server in a hole punch. Setting this to
// true preserves that behavior.
legacyBehavior bool
}
// SetLegacyBehavior sets the Service's legacyBehavior flag. When the flag is
// true (the value NewService starts with), the service keeps the client/server
// role selection used before https://github.com/libp2p/go-libp2p/pull/3044.
//
// NOTE(review): the flag is copied into the holePuncher when waitForPublicAddr
// creates it, so a call made after that point appears to affect only the
// stream-handler side; confirm before relying on late changes. The write is a
// plain assignment with no locking.
//
// SetLegacyBehavior is only exposed for testing purposes.
// Do not set this unless you know what you are doing.
func (s *Service) SetLegacyBehavior(legacyBehavior bool) {
s.legacyBehavior = legacyBehavior
}
// NewService creates a new service that can be used for hole punching
@@ -83,6 +104,8 @@ func NewService(h host.Host, ids identify.IDService, listenAddrs func() []ma.Mul
ids: ids,
listenAddrs: listenAddrs,
hasPublicAddrsChan: make(chan struct{}),
directDialTimeout: defaultDirectDialTimeout,
legacyBehavior: true,
}
for _, opt := range opts {
@@ -102,7 +125,7 @@ func NewService(h host.Host, ids identify.IDService, listenAddrs func() []ma.Mul
func (s *Service) waitForPublicAddr() {
defer s.refCount.Done()
log.Debug("waiting until we have at least one public address", "peer", s.host.ID())
log.Debugw("waiting until we have at least one public address", "peer", s.host.ID())
// TODO: We should have an event here that fires when identify discovers a new
// address.
@@ -114,7 +137,7 @@ func (s *Service) waitForPublicAddr() {
defer t.Stop()
for {
if len(s.listenAddrs()) > 0 {
log.Debug("Host now has a public address. Starting holepunch protocol.")
log.Debugf("Host %s now has a public address (%s). Starting holepunch protocol.", s.host.ID(), s.host.Addrs())
s.host.SetStreamHandler(Protocol, s.handleNewStream)
break
}
@@ -137,6 +160,8 @@ func (s *Service) waitForPublicAddr() {
return
}
s.holePuncher = newHolePuncher(s.host, s.ids, s.listenAddrs, s.tracer, s.filter)
s.holePuncher.directDialTimeout = s.directDialTimeout
s.holePuncher.legacyBehavior = s.legacyBehavior
s.holePuncherMx.Unlock()
close(s.hasPublicAddrsChan)
}
@@ -258,7 +283,13 @@ func (s *Service) handleNewStream(str network.Stream) {
log.Debugw("starting hole punch", "peer", rp)
start := time.Now()
s.tracer.HolePunchAttempt(pi.ID)
err = holePunchConnect(s.ctx, s.host, pi, false)
ctx, cancel := context.WithTimeout(s.ctx, s.directDialTimeout)
isClient := false
if s.legacyBehavior {
isClient = true
}
err = holePunchConnect(ctx, s.host, pi, isClient)
cancel()
dt := time.Since(start)
s.tracer.EndHolePunch(rp, dt, err)
s.tracer.HolePunchFinished("receiver", 1, addrs, ownAddrs, getDirectConnection(s.host, rp))

View File

@@ -51,10 +51,9 @@ func getDirectConnection(h host.Host, p peer.ID) network.Conn {
func holePunchConnect(ctx context.Context, host host.Host, pi peer.AddrInfo, isClient bool) error {
holePunchCtx := network.WithSimultaneousConnect(ctx, isClient, "hole-punching")
forceDirectConnCtx := network.WithForceDirectDial(holePunchCtx, "hole-punching")
dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, dialTimeout)
defer cancel()
if err := host.Connect(dialCtx, pi); err != nil {
log.Debugw("holepunchConnect", "host", host.ID(), "peer", pi.ID, "addrs", pi.Addrs)
if err := host.Connect(forceDirectConnCtx, pi); err != nil {
log.Debugw("hole punch attempt with peer failed", "peer ID", pi.ID, "error", err)
return err
}

View File

@@ -8,6 +8,7 @@ import (
"net"
"sync"
"github.com/libp2p/go-netroute"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/prometheus/client_golang/prometheus"
@@ -37,6 +38,9 @@ type ConnManager struct {
reuseUDP6 *reuse
enableReuseport bool
listenUDP listenUDP
sourceIPSelectorFn func() (SourceIPSelector, error)
enableMetrics bool
registerer prometheus.Registerer
@@ -55,13 +59,24 @@ type quicListenerEntry struct {
ln *quicListener
}
// defaultListenUDP is the production listenUDP implementation: it opens a real
// UDP socket with net.ListenUDP.
//
// On failure it returns a literal nil PacketConn. Returning the *net.UDPConn
// directly would wrap a nil pointer in a non-nil interface value, so a caller
// checking `conn != nil` would be misled.
func defaultListenUDP(network string, laddr *net.UDPAddr) (net.PacketConn, error) {
	conn, err := net.ListenUDP(network, laddr)
	if err != nil {
		return nil, err
	}
	return conn, nil
}
// defaultSourceIPSelectorFn builds the production SourceIPSelector, backed by
// the router returned from netroute.New.
//
// On failure it returns a nil selector together with the error. Callers in
// this package discard the error and only nil-check the selector before using
// it, so handing back a selector that wraps a nil router would cause a nil
// dereference on the next dial.
func defaultSourceIPSelectorFn() (SourceIPSelector, error) {
	r, err := netroute.New()
	if err != nil {
		return nil, err
	}
	return &netrouteSourceIPSelector{routes: r}, nil
}
func NewConnManager(statelessResetKey quic.StatelessResetKey, tokenKey quic.TokenGeneratorKey, opts ...Option) (*ConnManager, error) {
cm := &ConnManager{
enableReuseport: true,
quicListeners: make(map[string]quicListenerEntry),
srk: statelessResetKey,
tokenKey: tokenKey,
registerer: prometheus.DefaultRegisterer,
enableReuseport: true,
quicListeners: make(map[string]quicListenerEntry),
srk: statelessResetKey,
tokenKey: tokenKey,
registerer: prometheus.DefaultRegisterer,
listenUDP: defaultListenUDP,
sourceIPSelectorFn: defaultSourceIPSelectorFn,
}
for _, o := range opts {
if err := o(cm); err != nil {
@@ -76,8 +91,8 @@ func NewConnManager(statelessResetKey quic.StatelessResetKey, tokenKey quic.Toke
cm.clientConfig = quicConf
cm.serverConfig = serverConfig
if cm.enableReuseport {
cm.reuseUDP4 = newReuse(&statelessResetKey, &tokenKey)
cm.reuseUDP6 = newReuse(&statelessResetKey, &tokenKey)
cm.reuseUDP4 = newReuse(&statelessResetKey, &tokenKey, cm.listenUDP, cm.sourceIPSelectorFn)
cm.reuseUDP6 = newReuse(&statelessResetKey, &tokenKey, cm.listenUDP, cm.sourceIPSelectorFn)
}
return cm, nil
}
@@ -238,7 +253,7 @@ func (c *ConnManager) transportForListen(association any, network string, laddr
return tr, nil
}
conn, err := net.ListenUDP(network, laddr)
conn, err := c.listenUDP(network, laddr)
if err != nil {
return nil, err
}
@@ -320,7 +335,8 @@ func (c *ConnManager) TransportWithAssociationForDial(association any, network s
case "udp6":
laddr = &net.UDPAddr{IP: net.IPv6zero, Port: 0}
}
conn, err := net.ListenUDP(network, laddr)
conn, err := c.listenUDP(network, laddr)
if err != nil {
return nil, err
}

View File

@@ -1,9 +1,29 @@
package quicreuse
import "github.com/prometheus/client_golang/prometheus"
import (
"net"
"github.com/prometheus/client_golang/prometheus"
)
type Option func(*ConnManager) error
// listenUDP is the signature of the function used to open UDP sockets. It takes
// the same arguments as net.ListenUDP but returns the more general
// net.PacketConn.
type listenUDP func(network string, laddr *net.UDPAddr) (net.PacketConn, error)
// OverrideListenUDP returns an Option that replaces the function the
// ConnManager uses to open UDP sockets with f. Without this option the
// ConnManager uses defaultListenUDP.
func OverrideListenUDP(f listenUDP) Option {
	apply := func(cm *ConnManager) error {
		cm.listenUDP = f
		return nil
	}
	return apply
}
// OverrideSourceIPSelector returns an Option that replaces the factory the
// ConnManager uses to obtain a SourceIPSelector with f. Without this option
// the ConnManager uses defaultSourceIPSelectorFn.
func OverrideSourceIPSelector(f func() (SourceIPSelector, error)) Option {
	apply := func(cm *ConnManager) error {
		cm.sourceIPSelectorFn = f
		return nil
	}
	return apply
}
func DisableReuseport() Option {
return func(m *ConnManager) error {
m.enableReuseport = false

View File

@@ -10,7 +10,6 @@ import (
"time"
"github.com/google/gopacket/routing"
"github.com/libp2p/go-netroute"
"github.com/quic-go/quic-go"
)
@@ -168,7 +167,11 @@ type reuse struct {
closeChan chan struct{}
gcStopChan chan struct{}
routes routing.Router
listenUDP listenUDP
sourceIPSelectorFn func() (SourceIPSelector, error)
routes SourceIPSelector
unicast map[string] /* IP.String() */ map[int] /* port */ *refcountedTransport
// globalListeners contains transports that are listening on 0.0.0.0 / ::
globalListeners map[int]*refcountedTransport
@@ -181,15 +184,17 @@ type reuse struct {
tokenGeneratorKey *quic.TokenGeneratorKey
}
func newReuse(srk *quic.StatelessResetKey, tokenKey *quic.TokenGeneratorKey) *reuse {
func newReuse(srk *quic.StatelessResetKey, tokenKey *quic.TokenGeneratorKey, listenUDP listenUDP, sourceIPSelectorFn func() (SourceIPSelector, error)) *reuse {
r := &reuse{
unicast: make(map[string]map[int]*refcountedTransport),
globalListeners: make(map[int]*refcountedTransport),
globalDialers: make(map[int]*refcountedTransport),
closeChan: make(chan struct{}),
gcStopChan: make(chan struct{}),
statelessResetKey: srk,
tokenGeneratorKey: tokenKey,
unicast: make(map[string]map[int]*refcountedTransport),
globalListeners: make(map[int]*refcountedTransport),
globalDialers: make(map[int]*refcountedTransport),
closeChan: make(chan struct{}),
gcStopChan: make(chan struct{}),
listenUDP: listenUDP,
sourceIPSelectorFn: sourceIPSelectorFn,
statelessResetKey: srk,
tokenGeneratorKey: tokenKey,
}
go r.gc()
return r
@@ -250,7 +255,7 @@ func (r *reuse) gc() {
} else {
// Ignore the error, there's nothing we can do about
// it.
r.routes, _ = netroute.New()
r.routes, _ = r.sourceIPSelectorFn()
}
}
}
@@ -270,7 +275,7 @@ func (r *reuse) transportWithAssociationForDial(association any, network string,
r.mutex.Unlock()
if router != nil {
_, _, src, err := router.Route(raddr.IP)
src, err := router.PreferredSourceIPForDestination(raddr)
if err == nil && !src.IsUnspecified() {
ip = &src
}
@@ -323,7 +328,7 @@ func (r *reuse) transportForDialLocked(association any, network string, source *
case "udp6":
addr = &net.UDPAddr{IP: net.IPv6zero, Port: 0}
}
conn, err := net.ListenUDP(network, addr)
conn, err := r.listenUDP(network, addr)
if err != nil {
return nil, err
}
@@ -389,7 +394,7 @@ func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcoun
}
}
conn, err := net.ListenUDP(network, laddr)
conn, err := r.listenUDP(network, laddr)
if err != nil {
return nil, err
}
@@ -418,7 +423,7 @@ func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcoun
r.unicast[localAddr.IP.String()] = make(map[int]*refcountedTransport)
// Assume the system's routes may have changed if we're adding a new listener.
// Ignore the error, there's nothing we can do.
r.routes, _ = netroute.New()
r.routes, _ = r.sourceIPSelectorFn()
}
// The kernel already checked that the laddr is not already listen
@@ -432,3 +437,16 @@ func (r *reuse) Close() error {
<-r.gcStopChan
return nil
}
// SourceIPSelector chooses the local source IP to prefer when sending to a
// destination. The reuse dial path ignores the result when the returned error
// is non-nil or the returned IP is unspecified.
type SourceIPSelector interface {
// PreferredSourceIPForDestination returns the preferred source IP for reaching dst.
PreferredSourceIPForDestination(dst *net.UDPAddr) (net.IP, error)
}
// netrouteSourceIPSelector implements SourceIPSelector on top of a
// routing.Router.
type netrouteSourceIPSelector struct {
	routes routing.Router
}

// PreferredSourceIPForDestination asks the router for a route to dst.IP and
// returns the source IP of that route, together with any routing error.
func (s *netrouteSourceIPSelector) PreferredSourceIPForDestination(dst *net.UDPAddr) (net.IP, error) {
	_, _, preferred, routeErr := s.routes.Route(dst.IP)
	return preferred, routeErr
}

View File

@@ -61,7 +61,7 @@ func cleanup(t *testing.T, reuse *reuse) {
}
func TestReuseListenOnAllIPv4(t *testing.T) {
reuse := newReuse(nil, nil)
reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
require.Eventually(t, isGarbageCollectorRunning, 500*time.Millisecond, 50*time.Millisecond, "expected garbage collector to be running")
cleanup(t, reuse)
@@ -73,7 +73,7 @@ func TestReuseListenOnAllIPv4(t *testing.T) {
}
func TestReuseListenOnAllIPv6(t *testing.T) {
reuse := newReuse(nil, nil)
reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
require.Eventually(t, isGarbageCollectorRunning, 500*time.Millisecond, 50*time.Millisecond, "expected garbage collector to be running")
cleanup(t, reuse)
@@ -86,7 +86,7 @@ func TestReuseListenOnAllIPv6(t *testing.T) {
}
func TestReuseCreateNewGlobalConnOnDial(t *testing.T) {
reuse := newReuse(nil, nil)
reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
cleanup(t, reuse)
addr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
@@ -100,7 +100,7 @@ func TestReuseCreateNewGlobalConnOnDial(t *testing.T) {
}
func TestReuseConnectionWhenDialing(t *testing.T) {
reuse := newReuse(nil, nil)
reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
cleanup(t, reuse)
addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0")
@@ -117,7 +117,7 @@ func TestReuseConnectionWhenDialing(t *testing.T) {
}
func TestReuseConnectionWhenListening(t *testing.T) {
reuse := newReuse(nil, nil)
reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
cleanup(t, reuse)
raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
@@ -132,7 +132,7 @@ func TestReuseConnectionWhenListening(t *testing.T) {
}
func TestReuseConnectionWhenDialBeforeListen(t *testing.T) {
reuse := newReuse(nil, nil)
reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
cleanup(t, reuse)
// dial any address
@@ -166,7 +166,7 @@ func TestReuseListenOnSpecificInterface(t *testing.T) {
if platformHasRoutingTables() {
t.Skip("this test only works on platforms that support routing tables")
}
reuse := newReuse(nil, nil)
reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
cleanup(t, reuse)
router, err := netroute.New()
@@ -203,7 +203,7 @@ func TestReuseGarbageCollect(t *testing.T) {
maxUnusedDuration = 10 * maxUnusedDuration
}
reuse := newReuse(nil, nil)
reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
cleanup(t, reuse)
numGlobals := func() int {