mirror of
https://github.com/libp2p/go-reuseport.git
synced 2025-09-27 11:12:10 +08:00
661 lines
14 KiB
Go
661 lines
14 KiB
Go
package reuseport
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"os/exec"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// echo writes every byte read from c straight back to c until reading stops
// (EOF or an error, which is ignored), then closes c.
func echo(c net.Conn) {
	defer c.Close()
	io.Copy(c, c)
}
|
|
|
|
// packetEcho sends every datagram received on c back to the address it came
// from, using a single 64 KiB scratch buffer. It returns, closing c, at the
// first read or write error.
func packetEcho(c net.PacketConn) {
	defer c.Close()

	buf := make([]byte, 65536)
	for {
		n, from, err := c.ReadFrom(buf)
		if err == nil {
			_, err = c.WriteTo(buf[:n], from)
		}
		if err != nil {
			return
		}
	}
}
|
|
|
|
func acceptAndEcho(l net.Listener) {
|
|
for {
|
|
c, err := l.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
go echo(c)
|
|
}
|
|
}
|
|
|
|
// CI reports whether the tests appear to be running under Travis CI, i.e.
// whether the TRAVIS environment variable is exactly "true". Tests use it to
// skip cases that bind fixed port numbers.
func CI() bool {
	travis, _ := os.LookupEnv("TRAVIS")
	return travis == "true"
}
|
|
|
|
func TestStreamListenSamePort(t *testing.T) {
|
|
|
|
// any ports
|
|
any := [][]string{
|
|
[]string{"tcp", "0.0.0.0:0"},
|
|
[]string{"tcp4", "0.0.0.0:0"},
|
|
[]string{"tcp6", "[::]:0"},
|
|
|
|
[]string{"tcp", "127.0.0.1:0"},
|
|
[]string{"tcp", "[::1]:0"},
|
|
[]string{"tcp4", "127.0.0.1:0"},
|
|
[]string{"tcp6", "[::1]:0"},
|
|
}
|
|
|
|
// specific ports. off in CI
|
|
specific := [][]string{
|
|
[]string{"tcp", "127.0.0.1:5556"},
|
|
[]string{"tcp", "[::1]:5557"},
|
|
[]string{"tcp4", "127.0.0.1:5558"},
|
|
[]string{"tcp6", "[::1]:5559"},
|
|
}
|
|
|
|
testCases := any
|
|
if !CI() {
|
|
testCases = append(testCases, specific...)
|
|
}
|
|
|
|
for _, tcase := range testCases {
|
|
network := tcase[0]
|
|
addr := tcase[1]
|
|
t.Log("testing", network, addr)
|
|
|
|
l1, err := Listen(network, addr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
continue
|
|
}
|
|
defer l1.Close()
|
|
t.Log("listening", l1.Addr())
|
|
|
|
l2, err := Listen(l1.Addr().Network(), l1.Addr().String())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
continue
|
|
}
|
|
defer l2.Close()
|
|
t.Log("listening", l2.Addr())
|
|
|
|
l3, err := Listen(l2.Addr().Network(), l2.Addr().String())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
continue
|
|
}
|
|
defer l3.Close()
|
|
t.Log("listening", l3.Addr())
|
|
|
|
if l1.Addr().String() != l2.Addr().String() {
|
|
t.Fatal("addrs should match", l1.Addr(), l2.Addr())
|
|
}
|
|
|
|
if l1.Addr().String() != l3.Addr().String() {
|
|
t.Fatal("addrs should match", l1.Addr(), l3.Addr())
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDialSelf(t *testing.T) {
|
|
l, err := Listen("tcp4", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
_, err = Dial("tcp4", l.Addr().String(), l.Addr().String())
|
|
if err == nil {
|
|
t.Fatal("should have gotten an error for dialing self")
|
|
}
|
|
}
|
|
|
|
func TestPacketListenSamePort(t *testing.T) {
|
|
|
|
// any ports
|
|
any := [][]string{
|
|
[]string{"udp", "0.0.0.0:0"},
|
|
[]string{"udp4", "0.0.0.0:0"},
|
|
[]string{"udp6", "[::]:0"},
|
|
|
|
[]string{"udp", "127.0.0.1:0"},
|
|
[]string{"udp", "[::1]:0"},
|
|
[]string{"udp4", "127.0.0.1:0"},
|
|
[]string{"udp6", "[::1]:0"},
|
|
}
|
|
|
|
// specific ports. off in CI
|
|
specific := [][]string{
|
|
[]string{"udp", "127.0.0.1:5560"},
|
|
[]string{"udp", "[::1]:5561"},
|
|
[]string{"udp4", "127.0.0.1:5562"},
|
|
[]string{"udp6", "[::1]:5563"},
|
|
}
|
|
|
|
testCases := any
|
|
if !CI() {
|
|
testCases = append(testCases, specific...)
|
|
}
|
|
|
|
for _, tcase := range testCases {
|
|
network := tcase[0]
|
|
addr := tcase[1]
|
|
t.Log("testing", network, addr)
|
|
|
|
l1, err := ListenPacket(network, addr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
continue
|
|
}
|
|
defer l1.Close()
|
|
t.Log("listening", l1.LocalAddr())
|
|
|
|
l2, err := ListenPacket(l1.LocalAddr().Network(), l1.LocalAddr().String())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
continue
|
|
}
|
|
defer l2.Close()
|
|
t.Log("listening", l2.LocalAddr())
|
|
|
|
l3, err := ListenPacket(l2.LocalAddr().Network(), l2.LocalAddr().String())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
continue
|
|
}
|
|
defer l3.Close()
|
|
t.Log("listening", l3.LocalAddr())
|
|
|
|
if l1.LocalAddr().String() != l2.LocalAddr().String() {
|
|
t.Fatal("addrs should match", l1.LocalAddr(), l2.LocalAddr())
|
|
}
|
|
|
|
if l1.LocalAddr().String() != l3.LocalAddr().String() {
|
|
t.Fatal("addrs should match", l1.LocalAddr(), l3.LocalAddr())
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestStreamListenDialSamePort(t *testing.T) {
|
|
|
|
any := [][]string{
|
|
[]string{"tcp", "0.0.0.0:0", "0.0.0.0:0"},
|
|
[]string{"tcp4", "0.0.0.0:0", "0.0.0.0:0"},
|
|
[]string{"tcp6", "[::]:0", "[::]:0"},
|
|
|
|
[]string{"tcp", "127.0.0.1:0", "127.0.0.1:0"},
|
|
[]string{"tcp4", "127.0.0.1:0", "127.0.0.1:0"},
|
|
[]string{"tcp6", "[::1]:0", "[::1]:0"},
|
|
}
|
|
|
|
specific := [][]string{
|
|
[]string{"tcp", "127.0.0.1:0", "127.0.0.1:5571"},
|
|
[]string{"tcp4", "127.0.0.1:0", "127.0.0.1:5573"},
|
|
[]string{"tcp6", "[::1]:0", "[::1]:5574"},
|
|
[]string{"tcp", "127.0.0.1:5570", "127.0.0.1:0"},
|
|
[]string{"tcp4", "127.0.0.1:5572", "127.0.0.1:0"},
|
|
[]string{"tcp6", "[::1]:5573", "[::1]:0"},
|
|
}
|
|
|
|
testCases := any
|
|
if !CI() {
|
|
testCases = append(testCases, specific...)
|
|
}
|
|
|
|
for _, tcase := range testCases {
|
|
t.Log("testing", tcase)
|
|
network := tcase[0]
|
|
addr1 := tcase[1]
|
|
addr2 := tcase[2]
|
|
|
|
l1, err := Listen(network, addr1)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
continue
|
|
}
|
|
defer l1.Close()
|
|
t.Log("listening", l1.Addr())
|
|
|
|
l2, err := Listen(network, addr2)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
continue
|
|
}
|
|
defer l2.Close()
|
|
t.Log("listening", l2.Addr())
|
|
|
|
go acceptAndEcho(l1)
|
|
go acceptAndEcho(l2)
|
|
|
|
c1, err := Dial(network, l1.Addr().String(), l2.Addr().String())
|
|
if err != nil {
|
|
t.Fatal(err, network, l1.Addr().String(), l2.Addr().String())
|
|
continue
|
|
}
|
|
defer c1.Close()
|
|
t.Log("dialed", c1, c1.LocalAddr(), c1.RemoteAddr())
|
|
|
|
if getPort(l1.Addr()) != getPort(c1.LocalAddr()) {
|
|
t.Fatal("addrs should match", l1.Addr(), c1.LocalAddr())
|
|
}
|
|
|
|
if getPort(l2.Addr()) != getPort(c1.RemoteAddr()) {
|
|
t.Fatal("addrs should match", l2.Addr(), c1.RemoteAddr())
|
|
}
|
|
|
|
hello1 := []byte("hello world")
|
|
hello2 := make([]byte, len(hello1))
|
|
if _, err := c1.Write(hello1); err != nil {
|
|
t.Fatal(err)
|
|
continue
|
|
}
|
|
|
|
if _, err := c1.Read(hello2); err != nil {
|
|
t.Fatal(err)
|
|
continue
|
|
}
|
|
|
|
if !bytes.Equal(hello1, hello2) {
|
|
t.Fatal("echo failed", string(hello1), "!=", string(hello2))
|
|
}
|
|
t.Log("echoed", string(hello2))
|
|
c1.Close()
|
|
}
|
|
}
|
|
|
|
func TestStreamListenDialSamePortStressManyMsgs(t *testing.T) {
|
|
testCases := [][]string{
|
|
[]string{"tcp", "127.0.0.1:0"},
|
|
[]string{"tcp4", "127.0.0.1:0"},
|
|
[]string{"tcp6", "[::]:0"},
|
|
}
|
|
|
|
for _, tcase := range testCases {
|
|
t.Run(tcase[0], func(t *testing.T) {
|
|
subestStreamListenDialSamePortStress(t, tcase[0], tcase[1], 2, 1000)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestStreamListenDialSamePortStressManyNodes(t *testing.T) {
|
|
testCases := [][]string{
|
|
[]string{"tcp", "127.0.0.1:0"},
|
|
[]string{"tcp4", "127.0.0.1:0"},
|
|
[]string{"tcp6", "[::]:0"},
|
|
}
|
|
|
|
for _, tcase := range testCases {
|
|
t.Run(tcase[0], func(t *testing.T) {
|
|
subestStreamListenDialSamePortStress(t, tcase[0], tcase[1], 50, 1)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestStreamListenDialSamePortStressManyMsgsManyNodes(t *testing.T) {
|
|
testCases := [][]string{
|
|
[]string{"tcp", "127.0.0.1:0"},
|
|
[]string{"tcp4", "127.0.0.1:0"},
|
|
[]string{"tcp6", "[::]:0"},
|
|
}
|
|
for _, tcase := range testCases {
|
|
t.Run(tcase[0], func(t *testing.T) {
|
|
subestStreamListenDialSamePortStress(t, tcase[0], tcase[1], 50, 50)
|
|
})
|
|
}
|
|
}
|
|
|
|
func subestStreamListenDialSamePortStress(t *testing.T, network, addr string, nodes int, msgs int) {
|
|
|
|
var ls []net.Listener
|
|
for i := 0; i < nodes; i++ {
|
|
l, err := Listen(network, addr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer l.Close()
|
|
go acceptAndEcho(l)
|
|
ls = append(ls, l)
|
|
}
|
|
|
|
// connect them all
|
|
var cs []net.Conn
|
|
for i := 0; i < nodes; i++ {
|
|
for j := 0; j < i; j++ {
|
|
if i == j {
|
|
continue // cannot do self.
|
|
}
|
|
|
|
ia := ls[i].Addr().String()
|
|
ja := ls[j].Addr().String()
|
|
c, err := Dial(network, ia, ja)
|
|
if err != nil {
|
|
t.Fatal(network, ia, ja, err)
|
|
}
|
|
defer c.Close()
|
|
cs = append(cs, c)
|
|
}
|
|
}
|
|
|
|
errs := make(chan error)
|
|
|
|
send := func(c net.Conn, buf []byte) {
|
|
if _, err := c.Write(buf); err != nil {
|
|
errs <- err
|
|
}
|
|
}
|
|
|
|
recv := func(c net.Conn, buf []byte) {
|
|
buf2 := make([]byte, len(buf))
|
|
if _, err := c.Read(buf2); err != nil {
|
|
errs <- err
|
|
}
|
|
if !bytes.Equal(buf, buf2) {
|
|
errs <- fmt.Errorf("recv failure: %s <--> %s -- %s %s", c.RemoteAddr(), c.LocalAddr(), buf, buf2)
|
|
}
|
|
}
|
|
|
|
t.Logf("sending %d msgs per conn", msgs)
|
|
go func() {
|
|
var wg sync.WaitGroup
|
|
for _, c := range cs {
|
|
wg.Add(1)
|
|
go func(c net.Conn) {
|
|
defer wg.Done()
|
|
for i := 0; i < msgs; i++ {
|
|
msg := []byte(fmt.Sprintf("message %d", i))
|
|
send(c, msg)
|
|
recv(c, msg)
|
|
}
|
|
}(c)
|
|
}
|
|
wg.Wait()
|
|
close(errs)
|
|
}()
|
|
|
|
for err := range errs {
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestPacketListenDialSamePort(t *testing.T) {
|
|
t.Skip("these don't pass reliably.")
|
|
|
|
any := [][]string{
|
|
[]string{"udp", "0.0.0.0:0", "0.0.0.0:0"},
|
|
[]string{"udp4", "0.0.0.0:0", "0.0.0.0:0"},
|
|
[]string{"udp6", "[::]:0", "[::]:0"},
|
|
|
|
[]string{"udp", "127.0.0.1:0", "127.0.0.1:0"},
|
|
[]string{"udp4", "127.0.0.1:0", "127.0.0.1:0"},
|
|
[]string{"udp6", "[::1]:0", "[::1]:0"},
|
|
}
|
|
|
|
specific := [][]string{
|
|
[]string{"udp", "127.0.0.1:5670", "127.0.0.1:5671"},
|
|
[]string{"udp4", "127.0.0.1:5672", "127.0.0.1:5673"},
|
|
[]string{"udp6", "[::1]:5673", "[::1]:5674"},
|
|
}
|
|
|
|
testCases := any
|
|
if !CI() {
|
|
testCases = append(testCases, specific...)
|
|
}
|
|
|
|
for _, tcase := range testCases {
|
|
t.Run(tcase[0]+"/"+tcase[1], func(t *testing.T) {
|
|
network := tcase[0]
|
|
addr1 := tcase[1]
|
|
addr2 := tcase[2]
|
|
|
|
l1, err := ListenPacket(network, addr1)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer l1.Close()
|
|
t.Log("listening", l1.LocalAddr())
|
|
|
|
l2, err := ListenPacket(network, addr2)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer l2.Close()
|
|
t.Log("listening", l2.LocalAddr())
|
|
|
|
go packetEcho(l1)
|
|
go packetEcho(l2)
|
|
|
|
c1, err := Dial(network, l1.LocalAddr().String(), l2.LocalAddr().String())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer c1.Close()
|
|
t.Log("dialed", c1.LocalAddr(), c1.RemoteAddr())
|
|
|
|
if getPort(l1.LocalAddr()) != getPort(c1.LocalAddr()) {
|
|
t.Fatal("addrs should match", l1.LocalAddr(), c1.LocalAddr())
|
|
}
|
|
|
|
if getPort(l2.LocalAddr()) != getPort(c1.RemoteAddr()) {
|
|
t.Fatal("addrs should match", l2.LocalAddr(), c1.RemoteAddr())
|
|
}
|
|
|
|
hello1 := []byte("hello world")
|
|
hello2 := make([]byte, len(hello1))
|
|
if _, err := c1.Write(hello1); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if err := c1.SetReadDeadline(time.Now().Add(time.Second * 2)); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if _, err := c1.Read(hello2); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if !bytes.Equal(hello1, hello2) {
|
|
t.Fatal("echo failed", string(hello1), "!=", string(hello2))
|
|
}
|
|
t.Log("echoed", string(hello2))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDialRespectsTimeout(t *testing.T) {
|
|
|
|
testCases := [][]string{
|
|
[]string{"tcp", "127.0.0.1:6780", "1.2.3.4:6781"},
|
|
[]string{"tcp4", "127.0.0.1:6782", "1.2.3.4:6783"},
|
|
[]string{"tcp6", "[::1]:6784", "[::2]:6785"},
|
|
}
|
|
|
|
timeout := 50 * time.Millisecond
|
|
|
|
for _, tcase := range testCases {
|
|
network := tcase[0]
|
|
laddr := tcase[1]
|
|
raddr := tcase[2]
|
|
|
|
// l, err := Listen(network, raddr)
|
|
// if err != nil {
|
|
// t.Error("without a listener it wont work")
|
|
// continue
|
|
// }
|
|
// defer l.Close()
|
|
|
|
nladdr, err := ResolveAddr(network, laddr)
|
|
if err != nil {
|
|
t.Error("failed to resolve addr", network, laddr, err)
|
|
continue
|
|
}
|
|
t.Log("testing", network, nladdr, raddr)
|
|
|
|
d := Dialer{
|
|
D: net.Dialer{
|
|
LocalAddr: nil,
|
|
Timeout: timeout,
|
|
},
|
|
}
|
|
|
|
errs := make(chan error)
|
|
go func() {
|
|
c, err := d.Dial(network, raddr)
|
|
if err == nil {
|
|
c.Close()
|
|
errs <- errors.New("should've not connected")
|
|
return
|
|
}
|
|
close(errs) // success!
|
|
}()
|
|
|
|
ErrDrain:
|
|
select {
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("took too long")
|
|
case err, more := <-errs:
|
|
if !more {
|
|
break
|
|
}
|
|
t.Error(err)
|
|
goto ErrDrain
|
|
}
|
|
|
|
}
|
|
}
|
|
|
|
func TestDialRespectsContext(t *testing.T) {
|
|
|
|
testCases := [][]string{
|
|
[]string{"tcp", "127.0.0.1:6780", "1.2.3.4:6781"},
|
|
[]string{"tcp4", "127.0.0.1:6782", "1.2.3.4:6783"},
|
|
[]string{"tcp6", "[::1]:6784", "[::2]:6785"},
|
|
}
|
|
|
|
timeout := 10 * time.Second
|
|
|
|
ctxTimeout := 50 * time.Millisecond
|
|
|
|
for _, tcase := range testCases {
|
|
network := tcase[0]
|
|
laddr := tcase[1]
|
|
raddr := tcase[2]
|
|
|
|
// l, err := Listen(network, raddr)
|
|
// if err != nil {
|
|
// t.Error("without a listener it wont work")
|
|
// continue
|
|
// }
|
|
// defer l.Close()
|
|
|
|
nladdr, err := ResolveAddr(network, laddr)
|
|
if err != nil {
|
|
t.Fatal("failed to resolve addr", network, laddr, err)
|
|
}
|
|
t.Log("testing", network, nladdr, raddr)
|
|
|
|
d := Dialer{
|
|
D: net.Dialer{
|
|
LocalAddr: nil,
|
|
Timeout: timeout,
|
|
},
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout)
|
|
defer cancel()
|
|
|
|
errs := make(chan error, 1)
|
|
go func(ctx context.Context) {
|
|
c, err := d.DialContext(ctx, network, raddr)
|
|
if err == nil {
|
|
c.Close()
|
|
errs <- errors.New("should've not connected")
|
|
return
|
|
}
|
|
close(errs) // success!
|
|
}(ctx)
|
|
|
|
ErrDrain:
|
|
select {
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("took too long")
|
|
case err, more := <-errs:
|
|
if !more {
|
|
break
|
|
}
|
|
t.Error(err)
|
|
goto ErrDrain
|
|
}
|
|
|
|
}
|
|
}
|
|
|
|
func TestUnixNotSupported(t *testing.T) {
|
|
|
|
testCases := [][]string{
|
|
[]string{"unix", "/tmp/foo"},
|
|
}
|
|
|
|
for _, tcase := range testCases {
|
|
network := tcase[0]
|
|
addr := tcase[1]
|
|
t.Log("testing", network, addr)
|
|
|
|
l, err := Listen(network, addr)
|
|
if err == nil {
|
|
l.Close()
|
|
t.Fatal("unix supported")
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestOpenFDs(t *testing.T) {
|
|
// this is a totally ad-hoc limit. test harnesses may add fds.
|
|
// but if this is really much higher than 20, there's obviously leaks.
|
|
limit := 20
|
|
start := time.Now()
|
|
for countOpenFiles(t) > limit {
|
|
<-time.After(time.Second)
|
|
t.Log("open fds:", countOpenFiles(t))
|
|
if time.Now().Sub(start) > (time.Second * 15) {
|
|
t.Error("fd leak!")
|
|
}
|
|
}
|
|
}
|
|
|
|
// countOpenFiles returns the number of lines `lsof -p <this pid>` prints. This
// is an approximation of the process's open files: lsof output typically also
// includes a header line and non-descriptor entries, so compare against a
// generous limit. The calling test is skipped if lsof is not installed, and
// fails if lsof cannot be run.
func countOpenFiles(t *testing.T) int {
	t.Helper()

	lsof, err := exec.LookPath("lsof")
	if err != nil {
		t.Skip("lsof not available:", err)
	}

	// Run lsof directly: no shell is needed to pass a PID argument, and this
	// avoids depending on /bin/sh.
	out, err := exec.Command(lsof, "-p", fmt.Sprint(os.Getpid())).Output()
	if err != nil {
		t.Fatal(err)
	}
	return bytes.Count(out, []byte("\n"))
}
|
|
|
|
// getPort returns the port portion of a's string form: everything after the
// last ':'. It returns "" if a is nil or its string form contains no ':'
// (for example a unix socket path).
//
// Splitting at the last colon matters for IPv6. "[::1]:5557" contains several
// colons, and taking the field after the first one yields "" for every IPv6
// address, which would make port comparisons between IPv6 endpoints
// trivially equal.
func getPort(a net.Addr) string {
	if a == nil {
		return ""
	}
	s := a.String()
	i := strings.LastIndexByte(s, ':')
	if i < 0 {
		return ""
	}
	return s[i+1:]
}
|