mirror of
https://github.com/libp2p/go-libp2p.git
synced 2025-09-26 20:21:26 +08:00
fix(dcutr): Fix end to end tests and add legacy behavior flag (default=true) (#3044)
This commit is contained in:
179
p2p/net/simconn/router.go
Normal file
179
p2p/net/simconn/router.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package simconn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type PacketReciever interface {
|
||||
RecvPacket(p Packet)
|
||||
}
|
||||
|
||||
// PerfectRouter is a router that has no latency or jitter and can route to
|
||||
// every node
|
||||
type PerfectRouter struct {
|
||||
mu sync.Mutex
|
||||
nodes map[net.Addr]PacketReciever
|
||||
}
|
||||
|
||||
// SendPacket implements Router.
|
||||
func (r *PerfectRouter) SendPacket(p Packet) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
conn, ok := r.nodes[p.To]
|
||||
if !ok {
|
||||
return errors.New("unknown destination")
|
||||
}
|
||||
|
||||
conn.RecvPacket(p)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *PerfectRouter) AddNode(addr net.Addr, conn PacketReciever) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.nodes == nil {
|
||||
r.nodes = make(map[net.Addr]PacketReciever)
|
||||
}
|
||||
r.nodes[addr] = conn
|
||||
}
|
||||
|
||||
func (r *PerfectRouter) RemoveNode(addr net.Addr) {
|
||||
delete(r.nodes, addr)
|
||||
}
|
||||
|
||||
var _ Router = &PerfectRouter{}
|
||||
|
||||
type DelayedPacketReciever struct {
|
||||
inner PacketReciever
|
||||
delay time.Duration
|
||||
}
|
||||
|
||||
func (r *DelayedPacketReciever) RecvPacket(p Packet) {
|
||||
time.AfterFunc(r.delay, func() { r.inner.RecvPacket(p) })
|
||||
}
|
||||
|
||||
type FixedLatencyRouter struct {
|
||||
PerfectRouter
|
||||
latency time.Duration
|
||||
}
|
||||
|
||||
func (r *FixedLatencyRouter) SendPacket(p Packet) error {
|
||||
return r.PerfectRouter.SendPacket(p)
|
||||
}
|
||||
|
||||
func (r *FixedLatencyRouter) AddNode(addr net.Addr, conn PacketReciever) {
|
||||
r.PerfectRouter.AddNode(addr, &DelayedPacketReciever{
|
||||
inner: conn,
|
||||
delay: r.latency,
|
||||
})
|
||||
}
|
||||
|
||||
var _ Router = &FixedLatencyRouter{}
|
||||
|
||||
type simpleNodeFirewall struct {
|
||||
mu sync.Mutex
|
||||
publiclyReachable bool
|
||||
packetsOutTo map[string]struct{}
|
||||
node *SimConn
|
||||
}
|
||||
|
||||
func (f *simpleNodeFirewall) MarkPacketSentOut(p Packet) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
if f.packetsOutTo == nil {
|
||||
f.packetsOutTo = make(map[string]struct{})
|
||||
}
|
||||
f.packetsOutTo[p.To.String()] = struct{}{}
|
||||
}
|
||||
|
||||
func (f *simpleNodeFirewall) IsPacketInAllowed(p Packet) bool {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
if f.publiclyReachable {
|
||||
return true
|
||||
}
|
||||
|
||||
_, ok := f.packetsOutTo[p.From.String()]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (f *simpleNodeFirewall) String() string {
|
||||
return fmt.Sprintf("public: %v, packetsOutTo: %v", f.publiclyReachable, f.packetsOutTo)
|
||||
}
|
||||
|
||||
type SimpleFirewallRouter struct {
|
||||
mu sync.Mutex
|
||||
nodes map[string]*simpleNodeFirewall
|
||||
}
|
||||
|
||||
func (r *SimpleFirewallRouter) String() string {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
nodes := make([]string, 0, len(r.nodes))
|
||||
for _, node := range r.nodes {
|
||||
nodes = append(nodes, node.String())
|
||||
}
|
||||
return fmt.Sprintf("%v", nodes)
|
||||
}
|
||||
|
||||
func (r *SimpleFirewallRouter) SendPacket(p Packet) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
toNode, exists := r.nodes[p.To.String()]
|
||||
if !exists {
|
||||
return errors.New("unknown destination")
|
||||
}
|
||||
|
||||
// Record that this node is sending a packet to the destination
|
||||
fromNode, exists := r.nodes[p.From.String()]
|
||||
if !exists {
|
||||
return errors.New("unknown source")
|
||||
}
|
||||
fromNode.MarkPacketSentOut(p)
|
||||
|
||||
if !toNode.IsPacketInAllowed(p) {
|
||||
return nil // Silently drop blocked packets
|
||||
}
|
||||
|
||||
toNode.node.RecvPacket(p)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SimpleFirewallRouter) AddNode(addr net.Addr, conn *SimConn) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.nodes == nil {
|
||||
r.nodes = make(map[string]*simpleNodeFirewall)
|
||||
}
|
||||
r.nodes[addr.String()] = &simpleNodeFirewall{
|
||||
packetsOutTo: make(map[string]struct{}),
|
||||
node: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SimpleFirewallRouter) AddPubliclyReachableNode(addr net.Addr, conn *SimConn) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.nodes == nil {
|
||||
r.nodes = make(map[string]*simpleNodeFirewall)
|
||||
}
|
||||
r.nodes[addr.String()] = &simpleNodeFirewall{
|
||||
publiclyReachable: true,
|
||||
node: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SimpleFirewallRouter) RemoveNode(addr net.Addr) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.nodes == nil {
|
||||
return
|
||||
}
|
||||
delete(r.nodes, addr.String())
|
||||
}
|
||||
|
||||
var _ Router = &SimpleFirewallRouter{}
|
218
p2p/net/simconn/simconn.go
Normal file
218
p2p/net/simconn/simconn.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package simconn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrDeadlineExceeded = errors.New("deadline exceeded")
|
||||
|
||||
type Router interface {
|
||||
SendPacket(p Packet) error
|
||||
}
|
||||
|
||||
type Packet struct {
|
||||
To net.Addr
|
||||
From net.Addr
|
||||
buf []byte
|
||||
}
|
||||
|
||||
type SimConn struct {
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
closedChan chan struct{}
|
||||
|
||||
packetsSent atomic.Uint64
|
||||
packetsRcvd atomic.Uint64
|
||||
bytesSent atomic.Int64
|
||||
bytesRcvd atomic.Int64
|
||||
|
||||
router Router
|
||||
|
||||
myAddr *net.UDPAddr
|
||||
myLocalAddr net.Addr
|
||||
packetsToRead chan Packet
|
||||
|
||||
readDeadline time.Time
|
||||
writeDeadline time.Time
|
||||
}
|
||||
|
||||
// NewSimConn creates a new simulated connection with the specified parameters
|
||||
func NewSimConn(addr *net.UDPAddr, rtr Router) *SimConn {
|
||||
return &SimConn{
|
||||
router: rtr,
|
||||
myAddr: addr,
|
||||
packetsToRead: make(chan Packet, 512), // buffered channel to prevent blocking
|
||||
closedChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
type ConnStats struct {
|
||||
BytesSent int
|
||||
BytesRcvd int
|
||||
PacketsSent int
|
||||
PacketsRcvd int
|
||||
}
|
||||
|
||||
func (c *SimConn) Stats() ConnStats {
|
||||
return ConnStats{
|
||||
BytesSent: int(c.bytesSent.Load()),
|
||||
BytesRcvd: int(c.bytesRcvd.Load()),
|
||||
PacketsSent: int(c.packetsSent.Load()),
|
||||
PacketsRcvd: int(c.packetsRcvd.Load()),
|
||||
}
|
||||
}
|
||||
|
||||
// SetLocalAddr only changes what `.LocalAddr()` returns.
|
||||
// Packets will still come From the initially configured addr.
|
||||
func (c *SimConn) SetLocalAddr(addr net.Addr) {
|
||||
c.myLocalAddr = addr
|
||||
}
|
||||
|
||||
func (c *SimConn) RecvPacket(p Packet) {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
c.mu.Unlock()
|
||||
c.packetsRcvd.Add(1)
|
||||
c.bytesRcvd.Add(int64(len(p.buf)))
|
||||
|
||||
select {
|
||||
case c.packetsToRead <- p:
|
||||
default:
|
||||
// drop the packet if the channel is full
|
||||
}
|
||||
}
|
||||
|
||||
var _ net.PacketConn = &SimConn{}
|
||||
|
||||
// Close implements net.PacketConn
|
||||
func (c *SimConn) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closed {
|
||||
return nil
|
||||
}
|
||||
c.closed = true
|
||||
close(c.closedChan)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadFrom implements net.PacketConn
|
||||
func (c *SimConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
deadline := c.readDeadline
|
||||
c.mu.Unlock()
|
||||
|
||||
if !deadline.IsZero() && time.Now().After(deadline) {
|
||||
return 0, nil, ErrDeadlineExceeded
|
||||
}
|
||||
|
||||
var pkt Packet
|
||||
if !deadline.IsZero() {
|
||||
select {
|
||||
case pkt = <-c.packetsToRead:
|
||||
case <-time.After(time.Until(deadline)):
|
||||
return 0, nil, ErrDeadlineExceeded
|
||||
}
|
||||
} else {
|
||||
pkt = <-c.packetsToRead
|
||||
}
|
||||
|
||||
n = copy(p, pkt.buf)
|
||||
// if the provided buffer is not enough to read the whole packet, we drop
|
||||
// the rest of the data. this is similar to what `recvfrom` does on Linux
|
||||
// and macOS.
|
||||
return n, pkt.From, nil
|
||||
}
|
||||
|
||||
// WriteTo implements net.PacketConn
|
||||
func (c *SimConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
deadline := c.writeDeadline
|
||||
c.mu.Unlock()
|
||||
|
||||
if !deadline.IsZero() && time.Now().After(deadline) {
|
||||
return 0, ErrDeadlineExceeded
|
||||
}
|
||||
|
||||
c.packetsSent.Add(1)
|
||||
c.bytesSent.Add(int64(len(p)))
|
||||
|
||||
pkt := Packet{
|
||||
From: c.myAddr,
|
||||
To: addr,
|
||||
buf: slices.Clone(p),
|
||||
}
|
||||
return len(p), c.router.SendPacket(pkt)
|
||||
}
|
||||
|
||||
func (c *SimConn) UnicastAddr() net.Addr {
|
||||
return c.myAddr
|
||||
}
|
||||
|
||||
// LocalAddr implements net.PacketConn
|
||||
func (c *SimConn) LocalAddr() net.Addr {
|
||||
if c.myLocalAddr != nil {
|
||||
return c.myLocalAddr
|
||||
}
|
||||
return c.myAddr
|
||||
}
|
||||
|
||||
// SetDeadline implements net.PacketConn
|
||||
func (c *SimConn) SetDeadline(t time.Time) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.readDeadline = t
|
||||
c.writeDeadline = t
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline implements net.PacketConn
|
||||
func (c *SimConn) SetReadDeadline(t time.Time) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.readDeadline = t
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline implements net.PacketConn
|
||||
func (c *SimConn) SetWriteDeadline(t time.Time) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.writeDeadline = t
|
||||
return nil
|
||||
}
|
||||
|
||||
func IntToPublicIPv4(n int) net.IP {
|
||||
n += 1
|
||||
// Avoid private IP ranges
|
||||
b := make([]byte, 4)
|
||||
b[0] = byte((n>>24)&0xFF | 1)
|
||||
b[1] = byte((n >> 16) & 0xFF)
|
||||
b[2] = byte((n >> 8) & 0xFF)
|
||||
b[3] = byte(n & 0xFF)
|
||||
|
||||
ip := net.IPv4(b[0], b[1], b[2], b[3])
|
||||
|
||||
// Check and modify if it's in private ranges
|
||||
if ip.IsPrivate() {
|
||||
b[0] = 1 // Use 1.x.x.x as public range
|
||||
}
|
||||
|
||||
return ip
|
||||
}
|
315
p2p/net/simconn/simconn_test.go
Normal file
315
p2p/net/simconn/simconn_test.go
Normal file
@@ -0,0 +1,315 @@
|
||||
package simconn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSimConnBasicConnectivity(t *testing.T) {
|
||||
router := &PerfectRouter{}
|
||||
|
||||
// Create two endpoints
|
||||
addr1 := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234}
|
||||
addr2 := &net.UDPAddr{IP: IntToPublicIPv4(2), Port: 1234}
|
||||
|
||||
conn1 := NewSimConn(addr1, router)
|
||||
conn2 := NewSimConn(addr2, router)
|
||||
|
||||
router.AddNode(addr1, conn1)
|
||||
router.AddNode(addr2, conn2)
|
||||
|
||||
// Test sending data from conn1 to conn2
|
||||
testData := []byte("hello world")
|
||||
n, err := conn1.WriteTo(testData, addr2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, len(testData), n)
|
||||
|
||||
// Read data from conn2
|
||||
buf := make([]byte, 1024)
|
||||
n, addr, err := conn2.ReadFrom(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testData, buf[:n])
|
||||
require.Equal(t, addr1, addr)
|
||||
|
||||
// Check stats
|
||||
stats1 := conn1.Stats()
|
||||
require.Equal(t, len(testData), stats1.BytesSent)
|
||||
require.Equal(t, 1, stats1.PacketsSent)
|
||||
|
||||
stats2 := conn2.Stats()
|
||||
require.Equal(t, len(testData), stats2.BytesRcvd)
|
||||
require.Equal(t, 1, stats2.PacketsRcvd)
|
||||
}
|
||||
|
||||
func TestSimConnDeadlines(t *testing.T) {
|
||||
router := &PerfectRouter{}
|
||||
|
||||
addr1 := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234}
|
||||
conn := NewSimConn(addr1, router)
|
||||
router.AddNode(addr1, conn)
|
||||
|
||||
t.Run("read deadline", func(t *testing.T) {
|
||||
deadline := time.Now().Add(10 * time.Millisecond)
|
||||
err := conn.SetReadDeadline(deadline)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
_, _, err = conn.ReadFrom(buf)
|
||||
require.ErrorIs(t, err, ErrDeadlineExceeded)
|
||||
})
|
||||
|
||||
t.Run("write deadline", func(t *testing.T) {
|
||||
deadline := time.Now().Add(-time.Second) // Already expired
|
||||
err := conn.SetWriteDeadline(deadline)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = conn.WriteTo([]byte("test"), &net.UDPAddr{})
|
||||
require.ErrorIs(t, err, ErrDeadlineExceeded)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSimConnClose(t *testing.T) {
|
||||
router := &PerfectRouter{}
|
||||
|
||||
addr1 := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234}
|
||||
conn := NewSimConn(addr1, router)
|
||||
router.AddNode(addr1, conn)
|
||||
|
||||
err := conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify operations fail after close
|
||||
_, err = conn.WriteTo([]byte("test"), addr1)
|
||||
require.ErrorIs(t, err, net.ErrClosed)
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
_, _, err = conn.ReadFrom(buf)
|
||||
require.ErrorIs(t, err, net.ErrClosed)
|
||||
|
||||
// Second close should not error
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSimConnLocalAddr(t *testing.T) {
|
||||
router := &PerfectRouter{}
|
||||
|
||||
addr1 := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234}
|
||||
conn := NewSimConn(addr1, router)
|
||||
|
||||
// Test default local address
|
||||
require.Equal(t, addr1, conn.LocalAddr())
|
||||
|
||||
// Test setting custom local address
|
||||
customAddr := &net.UDPAddr{IP: IntToPublicIPv4(3), Port: 5678}
|
||||
conn.SetLocalAddr(customAddr)
|
||||
require.Equal(t, customAddr, conn.LocalAddr())
|
||||
}
|
||||
|
||||
func TestSimConnDeadlinesWithLatency(t *testing.T) {
|
||||
router := &FixedLatencyRouter{
|
||||
PerfectRouter: PerfectRouter{},
|
||||
latency: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
addr1 := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234}
|
||||
addr2 := &net.UDPAddr{IP: IntToPublicIPv4(2), Port: 1234}
|
||||
|
||||
conn1 := NewSimConn(addr1, router)
|
||||
conn2 := NewSimConn(addr2, router)
|
||||
|
||||
router.AddNode(addr1, conn1)
|
||||
router.AddNode(addr2, conn2)
|
||||
|
||||
reset := func() {
|
||||
router.RemoveNode(addr1)
|
||||
router.RemoveNode(addr2)
|
||||
|
||||
conn1 = NewSimConn(addr1, router)
|
||||
conn2 = NewSimConn(addr2, router)
|
||||
|
||||
router.AddNode(addr1, conn1)
|
||||
router.AddNode(addr2, conn2)
|
||||
}
|
||||
|
||||
t.Run("write succeeds within deadline", func(t *testing.T) {
|
||||
deadline := time.Now().Add(200 * time.Millisecond)
|
||||
err := conn1.SetWriteDeadline(deadline)
|
||||
require.NoError(t, err)
|
||||
|
||||
n, err := conn1.WriteTo([]byte("test"), addr2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 4, n)
|
||||
reset()
|
||||
})
|
||||
|
||||
t.Run("write fails after past deadline", func(t *testing.T) {
|
||||
deadline := time.Now().Add(-time.Second) // Already expired
|
||||
err := conn1.SetWriteDeadline(deadline)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = conn1.WriteTo([]byte("test"), addr2)
|
||||
require.ErrorIs(t, err, ErrDeadlineExceeded)
|
||||
reset()
|
||||
})
|
||||
|
||||
t.Run("read succeeds within deadline", func(t *testing.T) {
|
||||
// Reset deadline and send a message
|
||||
conn2.SetReadDeadline(time.Time{})
|
||||
testData := []byte("hello")
|
||||
deadline := time.Now().Add(200 * time.Millisecond)
|
||||
conn1.SetWriteDeadline(deadline)
|
||||
_, err := conn1.WriteTo(testData, addr2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set read deadline and try to read
|
||||
deadline = time.Now().Add(200 * time.Millisecond)
|
||||
err = conn2.SetReadDeadline(deadline)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
n, addr, err := conn2.ReadFrom(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, addr1, addr)
|
||||
require.Equal(t, testData, buf[:n])
|
||||
reset()
|
||||
})
|
||||
|
||||
t.Run("read fails after deadline", func(t *testing.T) {
|
||||
defer reset()
|
||||
// Set a short deadline
|
||||
deadline := time.Now().Add(50 * time.Millisecond) // Less than router latency
|
||||
err := conn2.SetReadDeadline(deadline)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
// Send data after setting deadline
|
||||
_, err := conn1.WriteTo([]byte("test"), addr2)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Read should fail due to deadline
|
||||
buf := make([]byte, 1024)
|
||||
_, _, err = conn2.ReadFrom(buf)
|
||||
require.ErrorIs(t, err, ErrDeadlineExceeded)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSimpleHolePunch(t *testing.T) {
|
||||
router := &SimpleFirewallRouter{
|
||||
nodes: make(map[string]*simpleNodeFirewall),
|
||||
}
|
||||
|
||||
// Create two peers
|
||||
addr1 := &net.UDPAddr{IP: IntToPublicIPv4(1), Port: 1234}
|
||||
addr2 := &net.UDPAddr{IP: IntToPublicIPv4(2), Port: 1234}
|
||||
|
||||
peer1 := NewSimConn(addr1, router)
|
||||
peer2 := NewSimConn(addr2, router)
|
||||
|
||||
router.AddNode(addr1, peer1)
|
||||
router.AddNode(addr2, peer2)
|
||||
|
||||
reset := func() {
|
||||
router.RemoveNode(addr1)
|
||||
router.RemoveNode(addr2)
|
||||
|
||||
peer1 = NewSimConn(addr1, router)
|
||||
peer2 = NewSimConn(addr2, router)
|
||||
|
||||
router.AddNode(addr1, peer1)
|
||||
router.AddNode(addr2, peer2)
|
||||
}
|
||||
|
||||
// Initially, direct communication between peer1 and peer2 should fail
|
||||
t.Run("direct communication blocked initially", func(t *testing.T) {
|
||||
_, err := peer1.WriteTo([]byte("direct message"), addr2)
|
||||
require.NoError(t, err) // Write succeeds but packet is dropped
|
||||
|
||||
// Try to read from peer2
|
||||
peer2.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
|
||||
buf := make([]byte, 1024)
|
||||
_, _, err = peer2.ReadFrom(buf)
|
||||
require.ErrorIs(t, err, ErrDeadlineExceeded)
|
||||
reset()
|
||||
})
|
||||
|
||||
holePunchMsg := []byte("hole punch")
|
||||
// Simulate hole punching
|
||||
t.Run("hole punch and direct communication", func(t *testing.T) {
|
||||
// Both peers send packets to each other simultaneously
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := peer1.WriteTo(holePunchMsg, addr2)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := peer2.WriteTo(holePunchMsg, addr1)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Now direct communication should work both ways
|
||||
t.Run("peer1 to peer2", func(t *testing.T) {
|
||||
testMsg := []byte("direct message after hole punch")
|
||||
_, err := peer1.WriteTo(testMsg, addr2)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
peer2.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
n, addr, err := peer2.ReadFrom(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, addr1, addr)
|
||||
if bytes.Equal(buf[:n], holePunchMsg) {
|
||||
// Read again to get the actual message
|
||||
n, addr, err = peer2.ReadFrom(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, addr1, addr)
|
||||
}
|
||||
require.Equal(t, string(testMsg), string(buf[:n]))
|
||||
})
|
||||
|
||||
t.Run("peer2 to peer1", func(t *testing.T) {
|
||||
testMsg := []byte("response from peer2")
|
||||
_, err := peer2.WriteTo(testMsg, addr1)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
peer1.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
|
||||
n, addr, err := peer1.ReadFrom(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, addr2, addr)
|
||||
if bytes.Equal(buf[:n], holePunchMsg) {
|
||||
// Read again to get the actual message
|
||||
n, addr, err = peer1.ReadFrom(buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, addr2, addr)
|
||||
}
|
||||
require.Equal(t, string(testMsg), string(buf[:n]))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestPublicIP(t *testing.T) {
|
||||
err := quick.Check(func(n int) bool {
|
||||
ip := IntToPublicIPv4(n)
|
||||
return !ip.IsPrivate()
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
@@ -2,28 +2,30 @@ package holepunch_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p"
|
||||
"github.com/libp2p/go-libp2p-testing/race"
|
||||
"github.com/libp2p/go-libp2p/core/host"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/libp2p/go-libp2p/core/peerstore"
|
||||
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/proto"
|
||||
"github.com/libp2p/go-libp2p/p2p/net/simconn"
|
||||
relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
|
||||
"github.com/libp2p/go-libp2p/p2p/protocol/holepunch"
|
||||
holepunch_pb "github.com/libp2p/go-libp2p/p2p/protocol/holepunch/pb"
|
||||
"github.com/libp2p/go-libp2p/p2p/protocol/identify"
|
||||
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
|
||||
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
|
||||
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
|
||||
"go.uber.org/fx"
|
||||
|
||||
"github.com/libp2p/go-msgio/pbio"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -81,90 +83,239 @@ func (s *mockIDService) OwnObservedAddrs() []ma.Multiaddr {
|
||||
}
|
||||
|
||||
func TestNoHolePunchIfDirectConnExists(t *testing.T) {
|
||||
tr := &mockEventTracer{}
|
||||
h1, hps := mkHostWithHolePunchSvc(t, holepunch.WithTracer(tr))
|
||||
defer h1.Close()
|
||||
h2, _ := mkHostWithHolePunchSvc(t)
|
||||
defer h2.Close()
|
||||
require.NoError(t, h1.Connect(context.Background(), peer.AddrInfo{
|
||||
ID: h2.ID(),
|
||||
Addrs: h2.Addrs(),
|
||||
}))
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
nc1 := len(h1.Network().ConnsToPeer(h2.ID()))
|
||||
require.GreaterOrEqual(t, nc1, 1)
|
||||
nc2 := len(h2.Network().ConnsToPeer(h1.ID()))
|
||||
require.GreaterOrEqual(t, nc2, 1)
|
||||
require.NoError(t, hps.DirectConnect(h2.ID()))
|
||||
require.Len(t, h1.Network().ConnsToPeer(h2.ID()), nc1)
|
||||
require.Len(t, h2.Network().ConnsToPeer(h1.ID()), nc2)
|
||||
require.Empty(t, tr.getEvents())
|
||||
}
|
||||
|
||||
func TestDirectDialWorks(t *testing.T) {
|
||||
if race.WithRace() {
|
||||
t.Skip("modifying manet.Private4 is racy")
|
||||
}
|
||||
|
||||
// mark all addresses as public
|
||||
cpy := manet.Private4
|
||||
manet.Private4 = []*net.IPNet{}
|
||||
defer func() { manet.Private4 = cpy }()
|
||||
router := &simconn.SimpleFirewallRouter{}
|
||||
relay := MustNewHost(t,
|
||||
quicSimConn(true, router),
|
||||
libp2p.ListenAddrs(ma.StringCast("/ip4/1.2.0.1/udp/8000/quic-v1")),
|
||||
libp2p.DisableRelay(),
|
||||
libp2p.ResourceManager(&network.NullResourceManager{}),
|
||||
libp2p.WithFxOption(fx.Invoke(func(h host.Host) {
|
||||
// Setup relay service
|
||||
_, err := relayv2.New(h)
|
||||
require.NoError(t, err)
|
||||
})),
|
||||
)
|
||||
|
||||
tr := &mockEventTracer{}
|
||||
h1, h1ps := mkHostWithHolePunchSvc(t, holepunch.WithTracer(tr))
|
||||
defer h1.Close()
|
||||
h2, _ := mkHostWithHolePunchSvc(t)
|
||||
defer h2.Close()
|
||||
h2.RemoveStreamHandler(holepunch.Protocol)
|
||||
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), peerstore.ConnectedAddrTTL)
|
||||
h1 := MustNewHost(t,
|
||||
quicSimConn(false, router),
|
||||
libp2p.EnableHolePunching(holepunch.DirectDialTimeout(100*time.Millisecond)),
|
||||
libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.1/udp/8000/quic-v1")),
|
||||
libp2p.ResourceManager(&network.NullResourceManager{}),
|
||||
)
|
||||
|
||||
// try to hole punch without any connection and streams, if it works -> it's a direct connection
|
||||
require.Empty(t, h1.Network().ConnsToPeer(h2.ID()))
|
||||
require.NoError(t, h1ps.DirectConnect(h2.ID()))
|
||||
require.GreaterOrEqual(t, len(h1.Network().ConnsToPeer(h2.ID())), 1)
|
||||
require.GreaterOrEqual(t, len(h2.Network().ConnsToPeer(h1.ID())), 1)
|
||||
events := tr.getEvents()
|
||||
require.Len(t, events, 1)
|
||||
require.Equal(t, holepunch.DirectDialEvtT, events[0].Type)
|
||||
}
|
||||
h2 := MustNewHost(t,
|
||||
quicSimConn(true, router),
|
||||
libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.2/udp/8001/quic-v1")),
|
||||
libp2p.ResourceManager(&network.NullResourceManager{}),
|
||||
libp2p.ForceReachabilityPublic(),
|
||||
connectToRelay(&relay),
|
||||
libp2p.EnableHolePunching(holepunch.WithTracer(tr), holepunch.DirectDialTimeout(100*time.Millisecond)),
|
||||
)
|
||||
|
||||
func TestEndToEndSimConnect(t *testing.T) {
|
||||
h1tr := &mockEventTracer{}
|
||||
h2tr := &mockEventTracer{}
|
||||
h1, h2, relay, _ := makeRelayedHosts(t, []holepunch.Option{holepunch.WithTracer(h1tr)}, []holepunch.Option{holepunch.WithTracer(h2tr)}, true)
|
||||
defer h1.Close()
|
||||
defer h2.Close()
|
||||
defer relay.Close()
|
||||
|
||||
// wait till a direct connection is complete
|
||||
ensureDirectConn(t, h1, h2)
|
||||
// ensure no hole-punching streams are open on either side
|
||||
ensureNoHolePunchingStream(t, h1, h2)
|
||||
var h2Events []*holepunch.Event
|
||||
require.Eventually(t,
|
||||
func() bool {
|
||||
h2Events = h2tr.getEvents()
|
||||
return len(h2Events) == 3
|
||||
},
|
||||
time.Second,
|
||||
10*time.Millisecond,
|
||||
)
|
||||
require.Equal(t, holepunch.StartHolePunchEvtT, h2Events[0].Type)
|
||||
require.Equal(t, holepunch.HolePunchAttemptEvtT, h2Events[1].Type)
|
||||
require.Equal(t, holepunch.EndHolePunchEvtT, h2Events[2].Type)
|
||||
waitForHolePunchingSvcActive(t, h1)
|
||||
waitForHolePunchingSvcActive(t, h2)
|
||||
|
||||
h1Events := h1tr.getEvents()
|
||||
// We don't really expect a hole-punched connection to be established in this test,
|
||||
// as we probably don't get the timing right for the TCP simultaneous open.
|
||||
// From time to time, it still happens occasionally, and then we get a EndHolePunchEvtT here.
|
||||
if len(h1Events) != 2 && len(h1Events) != 3 {
|
||||
t.Fatal("expected either 2 or 3 events")
|
||||
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), peerstore.ConnectedAddrTTL)
|
||||
// try to hole punch without any connection and streams, if it works -> it's a direct connection
|
||||
require.Empty(t, h1.Network().ConnsToPeer(h2.ID()))
|
||||
pingAtoB(t, h1, h2)
|
||||
|
||||
nc1 := len(h1.Network().ConnsToPeer(h2.ID()))
|
||||
require.Equal(t, nc1, 1)
|
||||
nc2 := len(h2.Network().ConnsToPeer(h1.ID()))
|
||||
require.Equal(t, nc2, 1)
|
||||
assert.Never(t, func() bool {
|
||||
return (len(h1.Network().ConnsToPeer(h2.ID())) != nc1 ||
|
||||
len(h2.Network().ConnsToPeer(h1.ID())) != nc2 ||
|
||||
len(tr.getEvents()) != 0)
|
||||
}, time.Second, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestDirectDialWorks(t *testing.T) {
|
||||
router := &simconn.SimpleFirewallRouter{}
|
||||
relay := MustNewHost(t,
|
||||
quicSimConn(true, router),
|
||||
libp2p.ListenAddrs(ma.StringCast("/ip4/1.2.0.1/udp/8000/quic-v1")),
|
||||
libp2p.DisableRelay(),
|
||||
libp2p.ResourceManager(&network.NullResourceManager{}),
|
||||
libp2p.WithFxOption(fx.Invoke(func(h host.Host) {
|
||||
// Setup relay service
|
||||
_, err := relayv2.New(h)
|
||||
require.NoError(t, err)
|
||||
})),
|
||||
)
|
||||
|
||||
tr := &mockEventTracer{}
|
||||
// h1 is public
|
||||
h1 := MustNewHost(t,
|
||||
quicSimConn(true, router),
|
||||
libp2p.ForceReachabilityPublic(),
|
||||
libp2p.EnableHolePunching(holepunch.DirectDialTimeout(100*time.Millisecond)),
|
||||
libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.1/udp/8000/quic-v1")),
|
||||
libp2p.ResourceManager(&network.NullResourceManager{}),
|
||||
)
|
||||
|
||||
h2 := MustNewHost(t,
|
||||
quicSimConn(false, router),
|
||||
libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.2/udp/8001/quic-v1")),
|
||||
libp2p.ResourceManager(&network.NullResourceManager{}),
|
||||
connectToRelay(&relay),
|
||||
libp2p.EnableHolePunching(holepunch.WithTracer(tr), holepunch.DirectDialTimeout(100*time.Millisecond)),
|
||||
libp2p.ForceReachabilityPrivate(),
|
||||
)
|
||||
|
||||
defer h1.Close()
|
||||
defer h2.Close()
|
||||
defer relay.Close()
|
||||
|
||||
// wait for dcutr to be available
|
||||
waitForHolePunchingSvcActive(t, h2)
|
||||
|
||||
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), peerstore.ConnectedAddrTTL)
|
||||
// try to hole punch without any connection and streams, if it works -> it's a direct connection
|
||||
require.Empty(t, h1.Network().ConnsToPeer(h2.ID()))
|
||||
pingAtoB(t, h1, h2)
|
||||
|
||||
// require.NoError(t, h1ps.DirectConnect(h2.ID()))
|
||||
require.GreaterOrEqual(t, len(h1.Network().ConnsToPeer(h2.ID())), 1)
|
||||
require.GreaterOrEqual(t, len(h2.Network().ConnsToPeer(h1.ID())), 1)
|
||||
require.EventuallyWithT(t, func(collect *assert.CollectT) {
|
||||
events := tr.getEvents()
|
||||
fmt.Println("events:", events)
|
||||
if !assert.Len(collect, events, 1) {
|
||||
return
|
||||
}
|
||||
assert.Equal(t, holepunch.DirectDialEvtT, events[0].Type)
|
||||
}, 2*time.Second, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
func connectToRelay(relayPtr *host.Host) libp2p.Option {
|
||||
return func(cfg *libp2p.Config) error {
|
||||
if relayPtr == nil {
|
||||
return nil
|
||||
}
|
||||
relay := *relayPtr
|
||||
pi := peer.AddrInfo{
|
||||
ID: relay.ID(),
|
||||
Addrs: relay.Addrs(),
|
||||
}
|
||||
|
||||
return cfg.Apply(
|
||||
libp2p.EnableRelay(),
|
||||
libp2p.EnableAutoRelayWithStaticRelays([]peer.AddrInfo{pi}),
|
||||
)
|
||||
}
|
||||
require.Equal(t, holepunch.StartHolePunchEvtT, h1Events[0].Type)
|
||||
require.Equal(t, holepunch.HolePunchAttemptEvtT, h1Events[1].Type)
|
||||
if len(h1Events) == 3 {
|
||||
require.Equal(t, holepunch.EndHolePunchEvtT, h1Events[2].Type)
|
||||
}
|
||||
|
||||
func learnAddrs(h1, h2 host.Host) {
|
||||
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), peerstore.ConnectedAddrTTL)
|
||||
h2.Peerstore().AddAddrs(h1.ID(), h1.Addrs(), peerstore.ConnectedAddrTTL)
|
||||
}
|
||||
|
||||
func pingAtoB(t *testing.T, a, b host.Host) {
|
||||
t.Helper()
|
||||
p1 := ping.NewPingService(a)
|
||||
require.NoError(t, a.Connect(context.Background(), peer.AddrInfo{
|
||||
ID: b.ID(),
|
||||
Addrs: b.Addrs(),
|
||||
}))
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
res := p1.Ping(ctx, b.ID())
|
||||
result := <-res
|
||||
require.NoError(t, result.Error)
|
||||
}
|
||||
|
||||
func MustNewHost(t *testing.T, opts ...libp2p.Option) host.Host {
|
||||
t.Helper()
|
||||
h, err := libp2p.New(opts...)
|
||||
require.NoError(t, err)
|
||||
return h
|
||||
}
|
||||
|
||||
func TestEndToEndSimConnect(t *testing.T) {
|
||||
for _, useLegacyHolePunchingBehavior := range []bool{true, false} {
|
||||
t.Run(fmt.Sprintf("legacy=%t", useLegacyHolePunchingBehavior), func(t *testing.T) {
|
||||
h1tr := &mockEventTracer{}
|
||||
h2tr := &mockEventTracer{}
|
||||
|
||||
router := &simconn.SimpleFirewallRouter{}
|
||||
relay := MustNewHost(t,
|
||||
quicSimConn(true, router),
|
||||
libp2p.ListenAddrs(ma.StringCast("/ip4/1.2.0.1/udp/8000/quic-v1")),
|
||||
libp2p.DisableRelay(),
|
||||
libp2p.ResourceManager(&network.NullResourceManager{}),
|
||||
libp2p.WithFxOption(fx.Invoke(func(h host.Host) {
|
||||
// Setup relay service
|
||||
_, err := relayv2.New(h)
|
||||
require.NoError(t, err)
|
||||
})),
|
||||
)
|
||||
|
||||
h1 := MustNewHost(t,
|
||||
quicSimConn(false, router),
|
||||
libp2p.EnableHolePunching(holepunch.WithTracer(h1tr), holepunch.DirectDialTimeout(100*time.Millisecond), SetLegacyBehavior(useLegacyHolePunchingBehavior)),
|
||||
libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.1/udp/8000/quic-v1")),
|
||||
libp2p.ResourceManager(&network.NullResourceManager{}),
|
||||
libp2p.ForceReachabilityPrivate(),
|
||||
)
|
||||
|
||||
h2 := MustNewHost(t,
|
||||
quicSimConn(false, router),
|
||||
libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.2/udp/8001/quic-v1")),
|
||||
libp2p.ResourceManager(&network.NullResourceManager{}),
|
||||
connectToRelay(&relay),
|
||||
libp2p.EnableHolePunching(holepunch.WithTracer(h2tr), holepunch.DirectDialTimeout(100*time.Millisecond), SetLegacyBehavior(useLegacyHolePunchingBehavior)),
|
||||
libp2p.ForceReachabilityPrivate(),
|
||||
)
|
||||
|
||||
defer h1.Close()
|
||||
defer h2.Close()
|
||||
defer relay.Close()
|
||||
|
||||
// Wait for holepunch service to start
|
||||
waitForHolePunchingSvcActive(t, h1)
|
||||
waitForHolePunchingSvcActive(t, h2)
|
||||
|
||||
learnAddrs(h1, h2)
|
||||
pingAtoB(t, h1, h2)
|
||||
|
||||
// wait till a direct connection is complete
|
||||
ensureDirectConn(t, h1, h2)
|
||||
// ensure no hole-punching streams are open on either side
|
||||
ensureNoHolePunchingStream(t, h1, h2)
|
||||
var h2Events []*holepunch.Event
|
||||
require.Eventually(t,
|
||||
func() bool {
|
||||
h2Events = h2tr.getEvents()
|
||||
return len(h2Events) == 4
|
||||
},
|
||||
time.Second,
|
||||
100*time.Millisecond,
|
||||
)
|
||||
require.Equal(t, holepunch.DirectDialEvtT, h2Events[0].Type)
|
||||
require.Equal(t, holepunch.StartHolePunchEvtT, h2Events[1].Type)
|
||||
require.Equal(t, holepunch.HolePunchAttemptEvtT, h2Events[2].Type)
|
||||
require.Equal(t, holepunch.EndHolePunchEvtT, h2Events[3].Type)
|
||||
|
||||
h1Events := h1tr.getEvents()
|
||||
// We don't really expect a hole-punched connection to be established in this test,
|
||||
// as we probably don't get the timing right for the TCP simultaneous open.
|
||||
// From time to time, it still happens occasionally, and then we get a EndHolePunchEvtT here.
|
||||
if len(h1Events) != 2 && len(h1Events) != 3 {
|
||||
t.Fatal("expected either 2 or 3 events")
|
||||
}
|
||||
require.Equal(t, holepunch.StartHolePunchEvtT, h1Events[0].Type)
|
||||
require.Equal(t, holepunch.HolePunchAttemptEvtT, h1Events[1].Type)
|
||||
if len(h1Events) == 3 {
|
||||
require.Equal(t, holepunch.EndHolePunchEvtT, h1Events[2].Type)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -192,7 +343,7 @@ func TestFailuresOnInitiator(t *testing.T) {
|
||||
errMsg: "failed to read CONNECT message",
|
||||
},
|
||||
"responder does NOT reply within hole punch deadline": {
|
||||
holePunchTimeout: 10 * time.Millisecond,
|
||||
holePunchTimeout: 200 * time.Millisecond,
|
||||
rhandler: func(s network.Stream) { time.Sleep(5 * time.Second) },
|
||||
errMsg: "i/o deadline reached",
|
||||
},
|
||||
@@ -213,13 +364,43 @@ func TestFailuresOnInitiator(t *testing.T) {
|
||||
defer func() { holepunch.StreamTimeout = cpy }()
|
||||
}
|
||||
|
||||
tr := &mockEventTracer{}
|
||||
h1, h2, relay, _ := makeRelayedHosts(t, nil, nil, false)
|
||||
router := &simconn.SimpleFirewallRouter{}
|
||||
relay := MustNewHost(t,
|
||||
quicSimConn(true, router),
|
||||
libp2p.ListenAddrs(ma.StringCast("/ip4/1.2.0.1/udp/8000/quic-v1")),
|
||||
libp2p.DisableRelay(),
|
||||
libp2p.ResourceManager(&network.NullResourceManager{}),
|
||||
libp2p.WithFxOption(fx.Invoke(func(h host.Host) {
|
||||
// Setup relay service
|
||||
_, err := relayv2.New(h)
|
||||
require.NoError(t, err)
|
||||
})),
|
||||
)
|
||||
|
||||
// h1 does not have a holepunching service because we'll mock the holepunching stream handler below.
|
||||
h1 := MustNewHost(t,
|
||||
quicSimConn(false, router),
|
||||
libp2p.ForceReachabilityPrivate(),
|
||||
libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.1/udp/8000/quic-v1")),
|
||||
libp2p.ResourceManager(&network.NullResourceManager{}),
|
||||
connectToRelay(&relay),
|
||||
)
|
||||
|
||||
h2 := MustNewHost(t,
|
||||
quicSimConn(false, router),
|
||||
libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.2/udp/8001/quic-v1")),
|
||||
libp2p.ResourceManager(&network.NullResourceManager{}),
|
||||
connectToRelay(&relay),
|
||||
)
|
||||
|
||||
defer h1.Close()
|
||||
defer h2.Close()
|
||||
defer relay.Close()
|
||||
|
||||
opts := []holepunch.Option{holepunch.WithTracer(tr)}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
tr := &mockEventTracer{}
|
||||
opts := []holepunch.Option{holepunch.WithTracer(tr), holepunch.DirectDialTimeout(100 * time.Millisecond)}
|
||||
if tc.filter != nil {
|
||||
f := mockMaddrFilter{
|
||||
filterLocal: tc.filter,
|
||||
@@ -228,19 +409,18 @@ func TestFailuresOnInitiator(t *testing.T) {
|
||||
opts = append(opts, holepunch.WithAddrFilter(f))
|
||||
}
|
||||
|
||||
hps := addHolePunchService(t, h2, opts...)
|
||||
// wait until the hole punching protocol has actually started
|
||||
require.Eventually(t, func() bool {
|
||||
protos, _ := h2.Peerstore().SupportsProtocols(h1.ID(), holepunch.Protocol)
|
||||
return len(protos) > 0
|
||||
}, 200*time.Millisecond, 10*time.Millisecond)
|
||||
|
||||
hps := addHolePunchService(t, h2, []ma.Multiaddr{ma.StringCast("/ip4/2.2.0.2/udp/8001/quic-v1")}, opts...)
|
||||
// We are only holepunching from h2 to h1. Remove h2's holepunching stream handler to avoid confusion.
|
||||
h2.RemoveStreamHandler(holepunch.Protocol)
|
||||
if tc.rhandler != nil {
|
||||
h1.SetStreamHandler(holepunch.Protocol, tc.rhandler)
|
||||
} else {
|
||||
h1.RemoveStreamHandler(holepunch.Protocol)
|
||||
}
|
||||
|
||||
require.NoError(t, h2.Connect(context.Background(), peer.AddrInfo{
|
||||
ID: h1.ID(),
|
||||
Addrs: h1.Addrs(),
|
||||
}))
|
||||
|
||||
err := hps.DirectConnect(h1.ID())
|
||||
require.Error(t, err)
|
||||
if tc.errMsg != "" {
|
||||
@@ -325,7 +505,7 @@ func TestFailuresOnResponder(t *testing.T) {
|
||||
}
|
||||
tr := &mockEventTracer{}
|
||||
|
||||
opts := []holepunch.Option{holepunch.WithTracer(tr)}
|
||||
opts := []holepunch.Option{holepunch.WithTracer(tr), holepunch.DirectDialTimeout(100 * time.Millisecond)}
|
||||
if tc.filter != nil {
|
||||
f := mockMaddrFilter{
|
||||
filterLocal: tc.filter,
|
||||
@@ -334,11 +514,49 @@ func TestFailuresOnResponder(t *testing.T) {
|
||||
opts = append(opts, holepunch.WithAddrFilter(f))
|
||||
}
|
||||
|
||||
h1, h2, relay, _ := makeRelayedHosts(t, opts, nil, false)
|
||||
router := &simconn.SimpleFirewallRouter{}
|
||||
relay := MustNewHost(t,
|
||||
quicSimConn(true, router),
|
||||
libp2p.ListenAddrs(ma.StringCast("/ip4/1.2.0.1/udp/8000/quic-v1")),
|
||||
libp2p.DisableRelay(),
|
||||
libp2p.ResourceManager(&network.NullResourceManager{}),
|
||||
libp2p.WithFxOption(fx.Invoke(func(h host.Host) {
|
||||
// Setup relay service
|
||||
_, err := relayv2.New(h)
|
||||
require.NoError(t, err)
|
||||
})),
|
||||
)
|
||||
h1 := MustNewHost(t,
|
||||
quicSimConn(false, router),
|
||||
libp2p.EnableHolePunching(opts...),
|
||||
libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.1/udp/8000/quic-v1")),
|
||||
libp2p.ResourceManager(&network.NullResourceManager{}),
|
||||
connectToRelay(&relay),
|
||||
libp2p.ForceReachabilityPrivate(),
|
||||
)
|
||||
|
||||
h2 := MustNewHost(t,
|
||||
quicSimConn(false, router),
|
||||
libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.2/udp/8001/quic-v1")),
|
||||
libp2p.ResourceManager(&network.NullResourceManager{}),
|
||||
connectToRelay(&relay),
|
||||
libp2p.ForceReachabilityPrivate(),
|
||||
)
|
||||
|
||||
defer h1.Close()
|
||||
defer h2.Close()
|
||||
defer relay.Close()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
require.NoError(t, h1.Connect(context.Background(), peer.AddrInfo{
|
||||
ID: h2.ID(),
|
||||
Addrs: h2.Addrs(),
|
||||
}))
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
assert.Contains(c, h1.Mux().Protocols(), holepunch.Protocol)
|
||||
}, time.Second, 100*time.Millisecond)
|
||||
|
||||
s, err := h2.NewStream(network.WithAllowLimitedConn(context.Background(), "holepunch"), h1.ID(), holepunch.Protocol)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -407,108 +625,60 @@ func ensureDirectConn(t *testing.T, h1, h2 host.Host) {
|
||||
}, 5*time.Second, 50*time.Millisecond)
|
||||
}
|
||||
|
||||
func mkHostWithStaticAutoRelay(t *testing.T, relay host.Host) host.Host {
|
||||
if race.WithRace() {
|
||||
t.Skip("modifying manet.Private4 is racy")
|
||||
}
|
||||
pi := peer.AddrInfo{
|
||||
ID: relay.ID(),
|
||||
Addrs: relay.Addrs(),
|
||||
}
|
||||
type MockSourceIPSelector struct {
|
||||
ip atomic.Pointer[net.IP]
|
||||
}
|
||||
|
||||
cpy := manet.Private4
|
||||
manet.Private4 = []*net.IPNet{}
|
||||
defer func() { manet.Private4 = cpy }()
|
||||
func (m *MockSourceIPSelector) PreferredSourceIPForDestination(dst *net.UDPAddr) (net.IP, error) {
|
||||
return *m.ip.Load(), nil
|
||||
}
|
||||
|
||||
h, err := libp2p.New(
|
||||
libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0")),
|
||||
libp2p.EnableRelay(),
|
||||
libp2p.EnableAutoRelayWithStaticRelays([]peer.AddrInfo{pi}),
|
||||
libp2p.ForceReachabilityPrivate(),
|
||||
libp2p.ResourceManager(&network.NullResourceManager{}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// wait till we have a relay addr
|
||||
require.Eventually(t, func() bool {
|
||||
for _, a := range h.Addrs() {
|
||||
if _, err := a.ValueForProtocol(ma.P_CIRCUIT); err == nil {
|
||||
return true
|
||||
func quicSimConn(isPubliclyReachably bool, router *simconn.SimpleFirewallRouter) libp2p.Option {
|
||||
m := &MockSourceIPSelector{}
|
||||
return libp2p.QUICReuse(
|
||||
quicreuse.NewConnManager,
|
||||
quicreuse.OverrideSourceIPSelector(func() (quicreuse.SourceIPSelector, error) {
|
||||
return m, nil
|
||||
}),
|
||||
quicreuse.OverrideListenUDP(func(network string, address *net.UDPAddr) (net.PacketConn, error) {
|
||||
m.ip.Store(&address.IP)
|
||||
c := simconn.NewSimConn(address, router)
|
||||
if isPubliclyReachably {
|
||||
router.AddPubliclyReachableNode(address, c)
|
||||
} else {
|
||||
router.AddNode(address, c)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, 5*time.Second, 50*time.Millisecond)
|
||||
return h
|
||||
return c, nil
|
||||
}))
|
||||
}
|
||||
|
||||
func makeRelayedHosts(t *testing.T, h1opt, h2opt []holepunch.Option, addHolePuncher bool) (h1, h2, relay host.Host, hps *holepunch.Service) {
|
||||
t.Helper()
|
||||
h1, _ = mkHostWithHolePunchSvc(t, h1opt...)
|
||||
var err error
|
||||
relay, err = libp2p.New(
|
||||
libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0")),
|
||||
libp2p.DisableRelay(),
|
||||
libp2p.ResourceManager(&network.NullResourceManager{}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
_, err = relayv2.New(relay)
|
||||
require.NoError(t, err)
|
||||
|
||||
// make sure the relay service is started and advertised by Identify
|
||||
h, err := libp2p.New(
|
||||
libp2p.NoListenAddrs,
|
||||
libp2p.Transport(tcp.NewTCPTransport),
|
||||
libp2p.DisableRelay(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer h.Close()
|
||||
require.NoError(t, h.Connect(context.Background(), peer.AddrInfo{ID: relay.ID(), Addrs: relay.Addrs()}))
|
||||
require.Eventually(t, func() bool {
|
||||
supported, err := h.Peerstore().SupportsProtocols(relay.ID(), proto.ProtoIDv2Hop)
|
||||
return err == nil && len(supported) > 0
|
||||
}, 3*time.Second, 100*time.Millisecond)
|
||||
|
||||
h2 = mkHostWithStaticAutoRelay(t, relay)
|
||||
if addHolePuncher {
|
||||
hps = addHolePunchService(t, h2, h2opt...)
|
||||
}
|
||||
|
||||
// h2 has a relay addr
|
||||
var raddr ma.Multiaddr
|
||||
for _, a := range h2.Addrs() {
|
||||
if _, err := a.ValueForProtocol(ma.P_CIRCUIT); err == nil {
|
||||
raddr = a
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotEmpty(t, raddr)
|
||||
// h1 should connect to the relay addr
|
||||
require.NoError(t, h1.Connect(context.Background(), peer.AddrInfo{
|
||||
ID: h2.ID(),
|
||||
Addrs: []ma.Multiaddr{raddr},
|
||||
}))
|
||||
return
|
||||
}
|
||||
|
||||
func addHolePunchService(t *testing.T, h host.Host, opts ...holepunch.Option) *holepunch.Service {
|
||||
func addHolePunchService(t *testing.T, h host.Host, extraAddrs []ma.Multiaddr, opts ...holepunch.Option) *holepunch.Service {
|
||||
t.Helper()
|
||||
hps, err := holepunch.NewService(h, newMockIDService(t, h), func() []ma.Multiaddr {
|
||||
addrs := h.Addrs()
|
||||
addrs = slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { return !manet.IsPublicAddr(a) })
|
||||
return append(addrs, ma.StringCast("/ip4/1.2.3.4/tcp/1234"))
|
||||
addrs = append(addrs, extraAddrs...)
|
||||
return addrs
|
||||
}, opts...)
|
||||
require.NoError(t, err)
|
||||
return hps
|
||||
}
|
||||
|
||||
func mkHostWithHolePunchSvc(t *testing.T, opts ...holepunch.Option) (host.Host, *holepunch.Service) {
|
||||
t.Helper()
|
||||
h, err := libp2p.New(
|
||||
libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0"), ma.StringCast("/ip6/::1/tcp/0")),
|
||||
libp2p.ForceReachabilityPrivate(),
|
||||
libp2p.ResourceManager(&network.NullResourceManager{}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
hps := addHolePunchService(t, h, opts...)
|
||||
return h, hps
|
||||
func waitForHolePunchingSvcActive(t *testing.T, h host.Host) {
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
assert.Contains(c, h.Mux().Protocols(), holepunch.Protocol)
|
||||
}, time.Second, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
// setLegacyBehavior is an option that controls the isClient behavior of the hole punching service.
|
||||
// Prior to https://github.com/libp2p/go-libp2p/pull/3044, go-libp2p would
|
||||
// pick the opposite roles for client/server a hole punch. Setting this to
|
||||
// true preserves that behavior.
|
||||
//
|
||||
// Currently, only exposed for testing purposes.
|
||||
// Do not set this unless you know what you are doing.
|
||||
func SetLegacyBehavior(legacyBehavior bool) holepunch.Option {
|
||||
return func(s *holepunch.Service) error {
|
||||
s.SetLegacyBehavior(legacyBehavior)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@@ -20,10 +20,7 @@ import (
|
||||
// ErrHolePunchActive is returned from DirectConnect when another hole punching attempt is currently running
|
||||
var ErrHolePunchActive = errors.New("another hole punching attempt to this peer is active")
|
||||
|
||||
const (
|
||||
dialTimeout = 5 * time.Second
|
||||
maxRetries = 3
|
||||
)
|
||||
const maxRetries = 3
|
||||
|
||||
// The holePuncher is run on the peer that's behind a NAT / Firewall.
|
||||
// It observes new incoming connections via a relay that it has a reservation with,
|
||||
@@ -40,6 +37,8 @@ type holePuncher struct {
|
||||
ids identify.IDService
|
||||
listenAddrs func() []ma.Multiaddr
|
||||
|
||||
directDialTimeout time.Duration
|
||||
|
||||
// active hole punches for deduplicating
|
||||
activeMx sync.Mutex
|
||||
active map[peer.ID]struct{}
|
||||
@@ -49,6 +48,11 @@ type holePuncher struct {
|
||||
|
||||
tracer *tracer
|
||||
filter AddrFilter
|
||||
|
||||
// Prior to https://github.com/libp2p/go-libp2p/pull/3044, go-libp2p would
|
||||
// pick the opposite roles for client/server a hole punch. Setting this to
|
||||
// true preserves that behavior
|
||||
legacyBehavior bool
|
||||
}
|
||||
|
||||
func newHolePuncher(h host.Host, ids identify.IDService, listenAddrs func() []ma.Multiaddr, tracer *tracer, filter AddrFilter) *holePuncher {
|
||||
@@ -59,6 +63,8 @@ func newHolePuncher(h host.Host, ids identify.IDService, listenAddrs func() []ma
|
||||
tracer: tracer,
|
||||
filter: filter,
|
||||
listenAddrs: listenAddrs,
|
||||
|
||||
legacyBehavior: true,
|
||||
}
|
||||
hp.ctx, hp.ctxCancel = context.WithCancel(context.Background())
|
||||
h.Network().Notify((*netNotifiee)(hp))
|
||||
@@ -86,6 +92,7 @@ func (hp *holePuncher) beginDirectConnect(p peer.ID) error {
|
||||
// It first attempts a direct dial (if we have a public address of that peer), and then
|
||||
// coordinates a hole punch over the given relay connection.
|
||||
func (hp *holePuncher) DirectConnect(p peer.ID) error {
|
||||
log.Debugw("beginDirectConnect", "host", hp.host.ID(), "peer", p)
|
||||
if err := hp.beginDirectConnect(p); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -102,14 +109,17 @@ func (hp *holePuncher) DirectConnect(p peer.ID) error {
|
||||
func (hp *holePuncher) directConnect(rp peer.ID) error {
|
||||
// short-circuit check to see if we already have a direct connection
|
||||
if getDirectConnection(hp.host, rp) != nil {
|
||||
log.Debugw("already connected", "host", hp.host.ID(), "peer", rp)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugw("attempting direct dial", "host", hp.host.ID(), "peer", rp, "addrs", hp.host.Peerstore().Addrs(rp))
|
||||
// short-circuit hole punching if a direct dial works.
|
||||
// attempt a direct connection ONLY if we have a public address for the remote peer
|
||||
for _, a := range hp.host.Peerstore().Addrs(rp) {
|
||||
if !isRelayAddress(a) && manet.IsPublicAddr(a) {
|
||||
forceDirectConnCtx := network.WithForceDirectDial(hp.ctx, "hole-punching")
|
||||
dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, dialTimeout)
|
||||
dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, hp.directDialTimeout)
|
||||
|
||||
tstart := time.Now()
|
||||
// This dials *all* addresses, public and private, from the peerstore.
|
||||
@@ -150,7 +160,13 @@ func (hp *holePuncher) directConnect(rp peer.ID) error {
|
||||
}
|
||||
hp.tracer.StartHolePunch(rp, addrs, rtt)
|
||||
hp.tracer.HolePunchAttempt(pi.ID)
|
||||
err := holePunchConnect(hp.ctx, hp.host, pi, true)
|
||||
ctx, cancel := context.WithTimeout(hp.ctx, hp.directDialTimeout)
|
||||
isClient := true
|
||||
if hp.legacyBehavior {
|
||||
isClient = false
|
||||
}
|
||||
err := holePunchConnect(ctx, hp.host, pi, isClient)
|
||||
cancel()
|
||||
dt := time.Since(start)
|
||||
hp.tracer.EndHolePunch(rp, dt, err)
|
||||
if err == nil {
|
||||
@@ -180,6 +196,7 @@ func (hp *holePuncher) initiateHolePunch(rp peer.ID) ([]ma.Multiaddr, []ma.Multi
|
||||
return nil, nil, 0, fmt.Errorf("failed to open hole-punching stream: %w", err)
|
||||
}
|
||||
defer str.Close()
|
||||
log.Debugf("initiateHolePunch: %s, %s", str.Conn().RemotePeer(), str.Conn().RemoteMultiaddr())
|
||||
|
||||
addr, obsAddr, rtt, err := hp.initiateHolePunchImpl(str)
|
||||
if err != nil {
|
||||
@@ -212,6 +229,7 @@ func (hp *holePuncher) initiateHolePunchImpl(str network.Stream) ([]ma.Multiaddr
|
||||
if len(obsAddrs) == 0 {
|
||||
return nil, nil, 0, errors.New("aborting hole punch initiation as we have no public address")
|
||||
}
|
||||
log.Debugf("initiating hole punch with %s", obsAddrs)
|
||||
|
||||
start := time.Now()
|
||||
if err := w.WriteMsg(&pb.HolePunch{
|
||||
|
@@ -19,6 +19,8 @@ import (
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
)
|
||||
|
||||
const defaultDirectDialTimeout = 10 * time.Second
|
||||
|
||||
// Protocol is the libp2p protocol for Hole Punching.
|
||||
const Protocol protocol.ID = "/libp2p/dcutr"
|
||||
|
||||
@@ -38,6 +40,13 @@ var ErrClosed = errors.New("hole punching service closing")
|
||||
|
||||
type Option func(*Service) error
|
||||
|
||||
func DirectDialTimeout(timeout time.Duration) Option {
|
||||
return func(s *Service) error {
|
||||
s.directDialTimeout = timeout
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// The Service runs on every node that supports the DCUtR protocol.
|
||||
type Service struct {
|
||||
ctx context.Context
|
||||
@@ -52,8 +61,9 @@ type Service struct {
|
||||
// publicly reachable relay addresses.
|
||||
listenAddrs func() []ma.Multiaddr
|
||||
|
||||
holePuncherMx sync.Mutex
|
||||
holePuncher *holePuncher
|
||||
directDialTimeout time.Duration
|
||||
holePuncherMx sync.Mutex
|
||||
holePuncher *holePuncher
|
||||
|
||||
hasPublicAddrsChan chan struct{}
|
||||
|
||||
@@ -61,6 +71,17 @@ type Service struct {
|
||||
filter AddrFilter
|
||||
|
||||
refCount sync.WaitGroup
|
||||
|
||||
// Prior to https://github.com/libp2p/go-libp2p/pull/3044, go-libp2p would
|
||||
// pick the opposite roles for client/server a hole punch. Setting this to
|
||||
// true preserves that behavior
|
||||
legacyBehavior bool
|
||||
}
|
||||
|
||||
// SetLegacyBehavior is only exposed for testing purposes.
|
||||
// Do not set this unless you know what you are doing.
|
||||
func (s *Service) SetLegacyBehavior(legacyBehavior bool) {
|
||||
s.legacyBehavior = legacyBehavior
|
||||
}
|
||||
|
||||
// NewService creates a new service that can be used for hole punching
|
||||
@@ -83,6 +104,8 @@ func NewService(h host.Host, ids identify.IDService, listenAddrs func() []ma.Mul
|
||||
ids: ids,
|
||||
listenAddrs: listenAddrs,
|
||||
hasPublicAddrsChan: make(chan struct{}),
|
||||
directDialTimeout: defaultDirectDialTimeout,
|
||||
legacyBehavior: true,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
@@ -102,7 +125,7 @@ func NewService(h host.Host, ids identify.IDService, listenAddrs func() []ma.Mul
|
||||
func (s *Service) waitForPublicAddr() {
|
||||
defer s.refCount.Done()
|
||||
|
||||
log.Debug("waiting until we have at least one public address", "peer", s.host.ID())
|
||||
log.Debugw("waiting until we have at least one public address", "peer", s.host.ID())
|
||||
|
||||
// TODO: We should have an event here that fires when identify discovers a new
|
||||
// address.
|
||||
@@ -114,7 +137,7 @@ func (s *Service) waitForPublicAddr() {
|
||||
defer t.Stop()
|
||||
for {
|
||||
if len(s.listenAddrs()) > 0 {
|
||||
log.Debug("Host now has a public address. Starting holepunch protocol.")
|
||||
log.Debugf("Host %s now has a public address (%s). Starting holepunch protocol.", s.host.ID(), s.host.Addrs())
|
||||
s.host.SetStreamHandler(Protocol, s.handleNewStream)
|
||||
break
|
||||
}
|
||||
@@ -137,6 +160,8 @@ func (s *Service) waitForPublicAddr() {
|
||||
return
|
||||
}
|
||||
s.holePuncher = newHolePuncher(s.host, s.ids, s.listenAddrs, s.tracer, s.filter)
|
||||
s.holePuncher.directDialTimeout = s.directDialTimeout
|
||||
s.holePuncher.legacyBehavior = s.legacyBehavior
|
||||
s.holePuncherMx.Unlock()
|
||||
close(s.hasPublicAddrsChan)
|
||||
}
|
||||
@@ -258,7 +283,13 @@ func (s *Service) handleNewStream(str network.Stream) {
|
||||
log.Debugw("starting hole punch", "peer", rp)
|
||||
start := time.Now()
|
||||
s.tracer.HolePunchAttempt(pi.ID)
|
||||
err = holePunchConnect(s.ctx, s.host, pi, false)
|
||||
ctx, cancel := context.WithTimeout(s.ctx, s.directDialTimeout)
|
||||
isClient := false
|
||||
if s.legacyBehavior {
|
||||
isClient = true
|
||||
}
|
||||
err = holePunchConnect(ctx, s.host, pi, isClient)
|
||||
cancel()
|
||||
dt := time.Since(start)
|
||||
s.tracer.EndHolePunch(rp, dt, err)
|
||||
s.tracer.HolePunchFinished("receiver", 1, addrs, ownAddrs, getDirectConnection(s.host, rp))
|
||||
|
@@ -51,10 +51,9 @@ func getDirectConnection(h host.Host, p peer.ID) network.Conn {
|
||||
func holePunchConnect(ctx context.Context, host host.Host, pi peer.AddrInfo, isClient bool) error {
|
||||
holePunchCtx := network.WithSimultaneousConnect(ctx, isClient, "hole-punching")
|
||||
forceDirectConnCtx := network.WithForceDirectDial(holePunchCtx, "hole-punching")
|
||||
dialCtx, cancel := context.WithTimeout(forceDirectConnCtx, dialTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := host.Connect(dialCtx, pi); err != nil {
|
||||
log.Debugw("holepunchConnect", "host", host.ID(), "peer", pi.ID, "addrs", pi.Addrs)
|
||||
if err := host.Connect(forceDirectConnCtx, pi); err != nil {
|
||||
log.Debugw("hole punch attempt with peer failed", "peer ID", pi.ID, "error", err)
|
||||
return err
|
||||
}
|
||||
|
@@ -8,6 +8,7 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/libp2p/go-netroute"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
manet "github.com/multiformats/go-multiaddr/net"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
@@ -37,6 +38,9 @@ type ConnManager struct {
|
||||
reuseUDP6 *reuse
|
||||
enableReuseport bool
|
||||
|
||||
listenUDP listenUDP
|
||||
sourceIPSelectorFn func() (SourceIPSelector, error)
|
||||
|
||||
enableMetrics bool
|
||||
registerer prometheus.Registerer
|
||||
|
||||
@@ -55,13 +59,24 @@ type quicListenerEntry struct {
|
||||
ln *quicListener
|
||||
}
|
||||
|
||||
func defaultListenUDP(network string, laddr *net.UDPAddr) (net.PacketConn, error) {
|
||||
return net.ListenUDP(network, laddr)
|
||||
}
|
||||
|
||||
func defaultSourceIPSelectorFn() (SourceIPSelector, error) {
|
||||
r, err := netroute.New()
|
||||
return &netrouteSourceIPSelector{routes: r}, err
|
||||
}
|
||||
|
||||
func NewConnManager(statelessResetKey quic.StatelessResetKey, tokenKey quic.TokenGeneratorKey, opts ...Option) (*ConnManager, error) {
|
||||
cm := &ConnManager{
|
||||
enableReuseport: true,
|
||||
quicListeners: make(map[string]quicListenerEntry),
|
||||
srk: statelessResetKey,
|
||||
tokenKey: tokenKey,
|
||||
registerer: prometheus.DefaultRegisterer,
|
||||
enableReuseport: true,
|
||||
quicListeners: make(map[string]quicListenerEntry),
|
||||
srk: statelessResetKey,
|
||||
tokenKey: tokenKey,
|
||||
registerer: prometheus.DefaultRegisterer,
|
||||
listenUDP: defaultListenUDP,
|
||||
sourceIPSelectorFn: defaultSourceIPSelectorFn,
|
||||
}
|
||||
for _, o := range opts {
|
||||
if err := o(cm); err != nil {
|
||||
@@ -76,8 +91,8 @@ func NewConnManager(statelessResetKey quic.StatelessResetKey, tokenKey quic.Toke
|
||||
cm.clientConfig = quicConf
|
||||
cm.serverConfig = serverConfig
|
||||
if cm.enableReuseport {
|
||||
cm.reuseUDP4 = newReuse(&statelessResetKey, &tokenKey)
|
||||
cm.reuseUDP6 = newReuse(&statelessResetKey, &tokenKey)
|
||||
cm.reuseUDP4 = newReuse(&statelessResetKey, &tokenKey, cm.listenUDP, cm.sourceIPSelectorFn)
|
||||
cm.reuseUDP6 = newReuse(&statelessResetKey, &tokenKey, cm.listenUDP, cm.sourceIPSelectorFn)
|
||||
}
|
||||
return cm, nil
|
||||
}
|
||||
@@ -238,7 +253,7 @@ func (c *ConnManager) transportForListen(association any, network string, laddr
|
||||
return tr, nil
|
||||
}
|
||||
|
||||
conn, err := net.ListenUDP(network, laddr)
|
||||
conn, err := c.listenUDP(network, laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -320,7 +335,8 @@ func (c *ConnManager) TransportWithAssociationForDial(association any, network s
|
||||
case "udp6":
|
||||
laddr = &net.UDPAddr{IP: net.IPv6zero, Port: 0}
|
||||
}
|
||||
conn, err := net.ListenUDP(network, laddr)
|
||||
conn, err := c.listenUDP(network, laddr)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@@ -1,9 +1,29 @@
|
||||
package quicreuse
|
||||
|
||||
import "github.com/prometheus/client_golang/prometheus"
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
type Option func(*ConnManager) error
|
||||
|
||||
type listenUDP func(network string, laddr *net.UDPAddr) (net.PacketConn, error)
|
||||
|
||||
func OverrideListenUDP(f listenUDP) Option {
|
||||
return func(m *ConnManager) error {
|
||||
m.listenUDP = f
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func OverrideSourceIPSelector(f func() (SourceIPSelector, error)) Option {
|
||||
return func(m *ConnManager) error {
|
||||
m.sourceIPSelectorFn = f
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func DisableReuseport() Option {
|
||||
return func(m *ConnManager) error {
|
||||
m.enableReuseport = false
|
||||
|
@@ -10,7 +10,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket/routing"
|
||||
"github.com/libp2p/go-netroute"
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
@@ -168,7 +167,11 @@ type reuse struct {
|
||||
closeChan chan struct{}
|
||||
gcStopChan chan struct{}
|
||||
|
||||
routes routing.Router
|
||||
listenUDP listenUDP
|
||||
|
||||
sourceIPSelectorFn func() (SourceIPSelector, error)
|
||||
|
||||
routes SourceIPSelector
|
||||
unicast map[string] /* IP.String() */ map[int] /* port */ *refcountedTransport
|
||||
// globalListeners contains transports that are listening on 0.0.0.0 / ::
|
||||
globalListeners map[int]*refcountedTransport
|
||||
@@ -181,15 +184,17 @@ type reuse struct {
|
||||
tokenGeneratorKey *quic.TokenGeneratorKey
|
||||
}
|
||||
|
||||
func newReuse(srk *quic.StatelessResetKey, tokenKey *quic.TokenGeneratorKey) *reuse {
|
||||
func newReuse(srk *quic.StatelessResetKey, tokenKey *quic.TokenGeneratorKey, listenUDP listenUDP, sourceIPSelectorFn func() (SourceIPSelector, error)) *reuse {
|
||||
r := &reuse{
|
||||
unicast: make(map[string]map[int]*refcountedTransport),
|
||||
globalListeners: make(map[int]*refcountedTransport),
|
||||
globalDialers: make(map[int]*refcountedTransport),
|
||||
closeChan: make(chan struct{}),
|
||||
gcStopChan: make(chan struct{}),
|
||||
statelessResetKey: srk,
|
||||
tokenGeneratorKey: tokenKey,
|
||||
unicast: make(map[string]map[int]*refcountedTransport),
|
||||
globalListeners: make(map[int]*refcountedTransport),
|
||||
globalDialers: make(map[int]*refcountedTransport),
|
||||
closeChan: make(chan struct{}),
|
||||
gcStopChan: make(chan struct{}),
|
||||
listenUDP: listenUDP,
|
||||
sourceIPSelectorFn: sourceIPSelectorFn,
|
||||
statelessResetKey: srk,
|
||||
tokenGeneratorKey: tokenKey,
|
||||
}
|
||||
go r.gc()
|
||||
return r
|
||||
@@ -250,7 +255,7 @@ func (r *reuse) gc() {
|
||||
} else {
|
||||
// Ignore the error, there's nothing we can do about
|
||||
// it.
|
||||
r.routes, _ = netroute.New()
|
||||
r.routes, _ = r.sourceIPSelectorFn()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -270,7 +275,7 @@ func (r *reuse) transportWithAssociationForDial(association any, network string,
|
||||
r.mutex.Unlock()
|
||||
|
||||
if router != nil {
|
||||
_, _, src, err := router.Route(raddr.IP)
|
||||
src, err := router.PreferredSourceIPForDestination(raddr)
|
||||
if err == nil && !src.IsUnspecified() {
|
||||
ip = &src
|
||||
}
|
||||
@@ -323,7 +328,7 @@ func (r *reuse) transportForDialLocked(association any, network string, source *
|
||||
case "udp6":
|
||||
addr = &net.UDPAddr{IP: net.IPv6zero, Port: 0}
|
||||
}
|
||||
conn, err := net.ListenUDP(network, addr)
|
||||
conn, err := r.listenUDP(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -389,7 +394,7 @@ func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcoun
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := net.ListenUDP(network, laddr)
|
||||
conn, err := r.listenUDP(network, laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -418,7 +423,7 @@ func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcoun
|
||||
r.unicast[localAddr.IP.String()] = make(map[int]*refcountedTransport)
|
||||
// Assume the system's routes may have changed if we're adding a new listener.
|
||||
// Ignore the error, there's nothing we can do.
|
||||
r.routes, _ = netroute.New()
|
||||
r.routes, _ = r.sourceIPSelectorFn()
|
||||
}
|
||||
|
||||
// The kernel already checked that the laddr is not already listen
|
||||
@@ -432,3 +437,16 @@ func (r *reuse) Close() error {
|
||||
<-r.gcStopChan
|
||||
return nil
|
||||
}
|
||||
|
||||
type SourceIPSelector interface {
|
||||
PreferredSourceIPForDestination(dst *net.UDPAddr) (net.IP, error)
|
||||
}
|
||||
|
||||
type netrouteSourceIPSelector struct {
|
||||
routes routing.Router
|
||||
}
|
||||
|
||||
func (s *netrouteSourceIPSelector) PreferredSourceIPForDestination(dst *net.UDPAddr) (net.IP, error) {
|
||||
_, _, src, err := s.routes.Route(dst.IP)
|
||||
return src, err
|
||||
}
|
||||
|
@@ -61,7 +61,7 @@ func cleanup(t *testing.T, reuse *reuse) {
|
||||
}
|
||||
|
||||
func TestReuseListenOnAllIPv4(t *testing.T) {
|
||||
reuse := newReuse(nil, nil)
|
||||
reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
|
||||
require.Eventually(t, isGarbageCollectorRunning, 500*time.Millisecond, 50*time.Millisecond, "expected garbage collector to be running")
|
||||
cleanup(t, reuse)
|
||||
|
||||
@@ -73,7 +73,7 @@ func TestReuseListenOnAllIPv4(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestReuseListenOnAllIPv6(t *testing.T) {
|
||||
reuse := newReuse(nil, nil)
|
||||
reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
|
||||
require.Eventually(t, isGarbageCollectorRunning, 500*time.Millisecond, 50*time.Millisecond, "expected garbage collector to be running")
|
||||
cleanup(t, reuse)
|
||||
|
||||
@@ -86,7 +86,7 @@ func TestReuseListenOnAllIPv6(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestReuseCreateNewGlobalConnOnDial(t *testing.T) {
|
||||
reuse := newReuse(nil, nil)
|
||||
reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
|
||||
cleanup(t, reuse)
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
|
||||
@@ -100,7 +100,7 @@ func TestReuseCreateNewGlobalConnOnDial(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestReuseConnectionWhenDialing(t *testing.T) {
|
||||
reuse := newReuse(nil, nil)
|
||||
reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
|
||||
cleanup(t, reuse)
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0")
|
||||
@@ -117,7 +117,7 @@ func TestReuseConnectionWhenDialing(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestReuseConnectionWhenListening(t *testing.T) {
|
||||
reuse := newReuse(nil, nil)
|
||||
reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
|
||||
cleanup(t, reuse)
|
||||
|
||||
raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
|
||||
@@ -132,7 +132,7 @@ func TestReuseConnectionWhenListening(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestReuseConnectionWhenDialBeforeListen(t *testing.T) {
|
||||
reuse := newReuse(nil, nil)
|
||||
reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
|
||||
cleanup(t, reuse)
|
||||
|
||||
// dial any address
|
||||
@@ -166,7 +166,7 @@ func TestReuseListenOnSpecificInterface(t *testing.T) {
|
||||
if platformHasRoutingTables() {
|
||||
t.Skip("this test only works on platforms that support routing tables")
|
||||
}
|
||||
reuse := newReuse(nil, nil)
|
||||
reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
|
||||
cleanup(t, reuse)
|
||||
|
||||
router, err := netroute.New()
|
||||
@@ -203,7 +203,7 @@ func TestReuseGarbageCollect(t *testing.T) {
|
||||
maxUnusedDuration = 10 * maxUnusedDuration
|
||||
}
|
||||
|
||||
reuse := newReuse(nil, nil)
|
||||
reuse := newReuse(nil, nil, defaultListenUDP, defaultSourceIPSelectorFn)
|
||||
cleanup(t, reuse)
|
||||
|
||||
numGlobals := func() int {
|
||||
|
Reference in New Issue
Block a user