diff --git a/biz/master/shell/mgr.go b/biz/master/shell/mgr.go index 199545d..6aa4b0c 100644 --- a/biz/master/shell/mgr.go +++ b/biz/master/shell/mgr.go @@ -2,6 +2,7 @@ package shell import ( "github.com/VaalaCat/frp-panel/pb" + "github.com/VaalaCat/frp-panel/services/app" "github.com/VaalaCat/frp-panel/utils" ) @@ -31,10 +32,9 @@ func (m *PTYMgr) Add(sessionID string, conn pb.Master_PTYConnectServer) { m.doneMap.Store(sessionID, make(chan bool)) } -func NewPTYMgr() *PTYMgr { +func NewPTYMgr() app.ShellPTYMgr { return &PTYMgr{ SyncMap: &utils.SyncMap[string, pb.Master_PTYConnectServer]{}, doneMap: &utils.SyncMap[string, chan bool]{}, } } - diff --git a/biz/master/streamlog/collect_log.go b/biz/master/streamlog/collect_log.go index 6ae7a88..4d8962f 100644 --- a/biz/master/streamlog/collect_log.go +++ b/biz/master/streamlog/collect_log.go @@ -28,7 +28,7 @@ func (c *ClientLogManager) GetClientLock(clientId string) *sync.Mutex { return lock } -func NewClientLogManager() *ClientLogManager { +func NewClientLogManager() app.ClientLogManager { return &ClientLogManager{ SyncMap: &utils.SyncMap[string, chan string]{}, clientLocksMap: &utils.SyncMap[string, *sync.Mutex]{}, diff --git a/cmd/frpp/shared/client.go b/cmd/frpp/shared/client.go index a729929..1de1af4 100644 --- a/cmd/frpp/shared/client.go +++ b/cmd/frpp/shared/client.go @@ -8,6 +8,7 @@ import ( "github.com/VaalaCat/frp-panel/defs" "github.com/VaalaCat/frp-panel/pb" "github.com/VaalaCat/frp-panel/services/app" + "github.com/VaalaCat/frp-panel/services/rpc" "github.com/VaalaCat/frp-panel/services/rpcclient" "github.com/VaalaCat/frp-panel/services/tunnel" "github.com/VaalaCat/frp-panel/services/watcher" @@ -25,7 +26,8 @@ type runClientParam struct { AppInstance app.Application TaskManager watcher.Client `name:"clientTaskManager"` WorkersManager app.WorkersManager - Cfg conf.Config + + Cfg conf.Config } func runClient(param runClientParam) { @@ -53,7 +55,7 @@ func runClient(param runClientParam) { param.Lc.Append(fx.Hook{ OnStart: func(ctx context.Context) error { appInstance.SetRPCCred(NewClientCred(appInstance)) - appInstance.SetMasterCli(NewClientMasterCli(appInstance)) + appInstance.SetMasterCli(rpc.NewMasterCli(appInstance)) appInstance.SetClientController(tunnel.NewClientController()) cliRpcHandler := rpcclient.NewClientRPCHandler( diff --git a/cmd/frpp/shared/modules.go b/cmd/frpp/shared/modules.go index 4d21292..5a2ea74 100644 --- a/cmd/frpp/shared/modules.go +++ b/cmd/frpp/shared/modules.go @@ -1,6 +1,13 @@ package shared import ( + bizmaster "github.com/VaalaCat/frp-panel/biz/master" + "github.com/VaalaCat/frp-panel/biz/master/shell" + "github.com/VaalaCat/frp-panel/biz/master/streamlog" + bizserver "github.com/VaalaCat/frp-panel/biz/server" + "github.com/VaalaCat/frp-panel/conf" + "github.com/VaalaCat/frp-panel/services/rpc" + "github.com/VaalaCat/frp-panel/utils/logger" "go.uber.org/fx" ) @@ -14,21 +21,22 @@ var ( serverMod = fx.Module("cmd.server", fx.Provide( fx.Annotate(NewServerAPI, fx.ResultTags(`name:"serverApiService"`)), - fx.Annotate(NewServerRouter, fx.ResultTags(`name:"serverRouter"`)), + fx.Annotate(bizserver.NewRouter, fx.ResultTags(`name:"serverRouter"`)), fx.Annotate(NewWatcher, fx.ResultTags(`name:"serverTaskManager"`)), )) masterMod = fx.Module("cmd.master", fx.Provide( NewPermissionManager, NewEnforcer, - NewListenerOptions, + conf.GetListener, NewDBManager, NewWSListener, NewMasterTLSConfig, NewWSUpgrader, - NewClientLogManager, + streamlog.NewClientLogManager, + // wireguard.NewWireGuardManager, fx.Annotate(NewWatcher, fx.ResultTags(`name:"masterTaskManager"`)), - fx.Annotate(NewMasterRouter, fx.ResultTags(`name:"masterRouter"`)), + fx.Annotate(bizmaster.NewRouter, fx.ResultTags(`name:"masterRouter"`)), fx.Annotate(NewHTTPMasterService, fx.ResultTags(`name:"httpMasterService"`)), fx.Annotate(NewHTTPMasterService, fx.ResultTags(`name:"wsMasterService"`)), fx.Annotate(NewTLSMasterService, fx.ResultTags(`name:"tlsMasterService"`)), @@ -38,12 +46,15 @@ var ( )) commonMod = fx.Module("common", fx.Provide( + logger.Logger, + logger.Instance, NewLogHookManager, - NewPTYManager, + shell.NewPTYMgr, NewBaseApp, NewContext, - NewClientsManager, - NewAutoJoin, + NewAndFinishNormalContext, + rpc.NewClientsManager, + NewAutoJoin, // provide final config fx.Annotate(NewPatchedConfig, fx.ResultTags(`name:"argsPatchedConfig"`)), )) ) diff --git a/cmd/frpp/shared/providers.go b/cmd/frpp/shared/providers.go index 0f66062..6fe3e87 100644 --- a/cmd/frpp/shared/providers.go +++ b/cmd/frpp/shared/providers.go @@ -3,7 +3,6 @@ package shared import ( "context" "crypto/tls" - "embed" "net" "net/http" "os" @@ -11,10 +10,6 @@ import ( "sync" bizcommon "github.com/VaalaCat/frp-panel/biz/common" - bizmaster "github.com/VaalaCat/frp-panel/biz/master" - "github.com/VaalaCat/frp-panel/biz/master/shell" - "github.com/VaalaCat/frp-panel/biz/master/streamlog" - bizserver "github.com/VaalaCat/frp-panel/biz/server" "github.com/VaalaCat/frp-panel/conf" "github.com/VaalaCat/frp-panel/defs" "github.com/VaalaCat/frp-panel/models" @@ -43,12 +38,14 @@ import ( "gorm.io/gorm" ) -func NewLogHookManager() app.StreamLogHookMgr { - return &bizcommon.HookMgr{} +type Finish struct { + fx.Out + + Context context.Context } -func NewPTYManager() app.ShellPTYMgr { - return shell.NewPTYMgr() +func NewLogHookManager() app.StreamLogHookMgr { + return &bizcommon.HookMgr{} } func NewBaseApp(param struct { @@ -68,10 +65,6 @@ func NewBaseApp(param struct { return appInstance } -func NewClientsManager() app.ClientsManager { - return rpc.NewClientsManager() -} - func NewPatchedConfig(param struct { fx.In @@ -88,8 +81,16 @@ func NewContext(appInstance app.Application) *app.Context { return app.NewContext(context.Background(), appInstance) } -func NewClientLogManager() app.ClientLogManager { - return streamlog.NewClientLogManager() +func NewAndFinishNormalContext(param struct { + fx.In + + Ctx *app.Context + Cfg conf.Config +}) Finish { + + return Finish{ + Context: param.Ctx, + } } func NewDBManager(ctx *app.Context, appInstance app.Application) app.DBManager { @@ -151,14 +152,6 @@ func NewMasterTLSConfig(ctx *app.Context) *tls.Config { return dao.NewQuery(ctx).InitCert(conf.GetCertTemplate(ctx.GetApp().GetConfig())) } -func NewMasterRouter(fs embed.FS, appInstance app.Application) *gin.Engine { - return bizmaster.NewRouter(fs, appInstance) -} - -func NewListenerOptions(ctx *app.Context, cfg conf.Config) conf.LisOpt { - return conf.GetListener(ctx, cfg) -} - func NewTLSMasterService(appInstance app.Application, masterTLSConfig *tls.Config) master.MasterService { return master.NewMasterService(appInstance, credentials.NewTLS(masterTLSConfig)) } @@ -167,14 +160,6 @@ func NewHTTPMasterService(appInstance app.Application) master.MasterService { return master.NewMasterService(appInstance, insecure.NewCredentials()) } -func NewServerMasterCli(appInstance app.Application) app.MasterClient { - return rpc.NewMasterCli(appInstance) -} - -func NewClientMasterCli(appInstance app.Application) app.MasterClient { - return rpc.NewMasterCli(appInstance) -} - func NewMux(param struct { fx.In @@ -214,10 +199,6 @@ func NewWSUpgrader(ctx *app.Context, cfg conf.Config) *websocket.Upgrader { } } -func NewServerRouter(appInstance app.Application) *gin.Engine { - return bizserver.NewRouter(appInstance) -} - func NewServerAPI(param struct { fx.In Ctx *app.Context @@ -305,7 +286,7 @@ func NewAutoJoin(param struct { Ctx *app.Context Cfg conf.Config `name:"argsPatchedConfig"` CommonArgs CommonArgs -}) conf.Config { +}) conf.Config { // provide final config var ( ctx = param.Ctx clientID = param.Cfg.Client.ID diff --git a/cmd/frpp/shared/server.go b/cmd/frpp/shared/server.go index c7e20aa..e41e6d9 100644 --- a/cmd/frpp/shared/server.go +++ b/cmd/frpp/shared/server.go @@ -8,6 +8,7 @@ import ( "github.com/VaalaCat/frp-panel/defs" "github.com/VaalaCat/frp-panel/pb" "github.com/VaalaCat/frp-panel/services/app" + "github.com/VaalaCat/frp-panel/services/rpc" "github.com/VaalaCat/frp-panel/services/rpcclient" "github.com/VaalaCat/frp-panel/services/tunnel" "github.com/VaalaCat/frp-panel/services/watcher" @@ -52,7 +53,7 @@ func runServer(param runServerParam) { OnStart: func(ctx context.Context) error { logger.Logger(ctx).Infof("start to run server, serverID: [%s]", clientID) appInstance.SetRPCCred(NewServerCred(appInstance)) - appInstance.SetMasterCli(NewServerMasterCli(appInstance)) + appInstance.SetMasterCli(rpc.NewMasterCli(appInstance)) cliHandler := rpcclient.NewClientRPCHandler( appInstance, diff --git a/defs/const.go b/defs/const.go index 3d9ed62..7aa650d 100644 --- a/defs/const.go +++ b/defs/const.go @@ -175,3 +175,12 @@ const ( FrpProxyAnnotationsKey_WorkerId = "worker_id" FrpProxyAnnotationsKey_LoadBalancerGroup = "load_balancer_group" ) + +const ( + PlaceholderPrivateKey = "" + PlaceholderPeerVPNAddressCIDR = "" +) + +var VaalaMagicBytes = []byte("vaala-ping") + +const VaalaMagicBytesCookie = uint32(1630367849) diff --git a/defs/types_rpc.go b/defs/types_rpc.go new file mode 100644 index 0000000..5143914 --- /dev/null +++ b/defs/types_rpc.go @@ -0,0 +1,9 @@ +package defs + +import "github.com/VaalaCat/frp-panel/pb" + +type Connector struct { + CliID string + Conn pb.Master_ServerSendServer + CliType string +} diff --git a/services/app/app_impl.go b/services/app/app_impl.go index 0e53398..d902fe6 100644 --- a/services/app/app_impl.go +++ b/services/app/app_impl.go @@ -1,10 +1,14 @@ package app import ( + "context" "sync" "github.com/VaalaCat/frp-panel/conf" + "github.com/VaalaCat/frp-panel/pb" + "github.com/VaalaCat/frp-panel/utils/logger" "github.com/casbin/casbin/v2" + "github.com/sirupsen/logrus" "google.golang.org/grpc/credentials" ) @@ -28,6 +32,33 @@ type application struct { enforcer *casbin.Enforcer workerExecManager WorkerExecManager workersManager WorkersManager + + loggerInstance *logrus.Logger +} + +func (a *application) GetClientBase() *pb.ClientBase { + return &pb.ClientBase{ + ClientId: a.GetConfig().Client.ID, + ClientSecret: a.GetConfig().Client.Secret, + } +} + +func (a *application) GetServerBase() *pb.ServerBase { + return &pb.ServerBase{ + ServerId: a.GetConfig().Client.ID, + ServerSecret: a.GetConfig().Client.Secret, + } +} + +func (a *application) SetLogger(l *logrus.Logger) { + a.loggerInstance = l +} + +func (a *application) Logger(ctx context.Context) *logrus.Entry { + if a.loggerInstance == nil { + return logger.Logger(ctx) + } + return a.loggerInstance.WithContext(ctx) } // GetWorkersManager implements Application. diff --git a/services/app/application.go b/services/app/application.go index 21d0ae2..3fae0b0 100644 --- a/services/app/application.go +++ b/services/app/application.go @@ -5,8 +5,10 @@ import ( "sync" "github.com/VaalaCat/frp-panel/conf" + "github.com/VaalaCat/frp-panel/pb" "github.com/casbin/casbin/v2" "github.com/gin-gonic/gin" + "github.com/sirupsen/logrus" "google.golang.org/grpc/credentials" ) @@ -47,11 +49,16 @@ type Application interface { SetWorkerExecManager(WorkerExecManager) GetWorkersManager() WorkersManager SetWorkersManager(WorkersManager) + SetLogger(*logrus.Logger) + Logger(ctx context.Context) *logrus.Entry + GetClientBase() *pb.ClientBase + GetServerBase() *pb.ServerBase } type Context struct { context.Context - appInstance Application + appInstance Application + loggerInstance *logrus.Logger } func (c *Context) GetApp() Application { @@ -70,6 +77,31 @@ func (c *Context) Background() *Context { return NewContext(context.Background(), c.appInstance) } +func (c *Context) Copy() *Context { + return NewContext(c.Context, c.appInstance) +} + +func (c *Context) CopyWithCancel() (*Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(c.Context) + return NewContext(ctx, c.appInstance), cancel +} + +func (c *Context) BackgroundWithCancel() (*Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + return NewContext(ctx, c.appInstance), cancel +} + +func (c *Context) Logger() *logrus.Entry { + if c.loggerInstance != nil { + return c.loggerInstance.WithContext(c) + } + return c.GetApp().Logger(c) +} + +func (c *Context) SetLogger(logger *logrus.Logger) { + c.loggerInstance = logger +} + func NewContext(c context.Context, appInstance Application) *Context { return &Context{ Context: c, diff --git a/services/app/provider.go b/services/app/provider.go index 0658fa1..249738e 100644 --- a/services/app/provider.go +++ b/services/app/provider.go @@ -78,19 +78,13 @@ type DBManager interface { } type ClientsManager interface { - Get(cliID string) *Connector + Get(cliID string) *defs.Connector Set(cliID, clientType string, sender pb.Master_ServerSendServer) Remove(cliID string) ClientAddr(cliID string) string ConnectTime(cliID string) (time.Time, bool) } -type Connector struct { - CliID string - Conn pb.Master_ServerSendServer - CliType string -} - type Service interface { Run() Stop() diff --git a/services/port/manager.go b/services/port/manager.go deleted file mode 100644 index 5a71675..0000000 --- a/services/port/manager.go +++ /dev/null @@ -1,48 +0,0 @@ -package tunnel - -import ( - "context" - - "github.com/VaalaCat/frp-panel/defs" - "github.com/VaalaCat/frp-panel/utils" - "github.com/VaalaCat/frp-panel/utils/logger" -) - -type PortManager interface { - ClaimWorkerPort(c context.Context, workerID string) int32 - GetWorkerPort(c context.Context, workerID string) (int32, bool) -} - -type portManager struct { - portMap *utils.SyncMap[string, int32] -} - -func (p *portManager) ClaimWorkerPort(c context.Context, workerID string) int32 { - port, err := utils.GetAvailablePort(defs.DefaultHostName) - if err != nil { - logger.Logger(c).WithError(err).Panic("get available port failed") - } - p.portMap.Store(workerID, int32(port)) - return int32(port) -} - -func (p *portManager) GetWorkerPort(c context.Context, workerID string) (int32, bool) { - return p.portMap.Load(workerID) -} - -var ( - mgr PortManager -) - -func NewPortManager() PortManager { - return &portManager{ - portMap: &utils.SyncMap[string, int32]{}, - } -} - -func GetPortManager() PortManager { - if mgr == nil { - mgr = NewPortManager() - } - return mgr -} diff --git a/services/rpc/client_manager.go b/services/rpc/client_manager.go index e6951a2..f2349ce 100644 --- a/services/rpc/client_manager.go +++ b/services/rpc/client_manager.go @@ -3,6 +3,7 @@ package rpc import ( "time" + "github.com/VaalaCat/frp-panel/defs" "github.com/VaalaCat/frp-panel/pb" "github.com/VaalaCat/frp-panel/services/app" "github.com/VaalaCat/frp-panel/utils" @@ -10,7 +11,7 @@ import ( ) type ClientsManager interface { - Get(cliID string) *app.Connector + Get(cliID string) *defs.Connector Set(cliID, clientType string, sender pb.Master_ServerSendServer) Remove(cliID string) ClientAddr(cliID string) string @@ -18,12 +19,12 @@ type ClientsManager interface { } type ClientsManagerImpl struct { - senders *utils.SyncMap[string, *app.Connector] + senders *utils.SyncMap[string, *defs.Connector] connectTime *utils.SyncMap[string, time.Time] } // Get implements ClientsManager. -func (c *ClientsManagerImpl) Get(cliID string) *app.Connector { +func (c *ClientsManagerImpl) Get(cliID string) *defs.Connector { cliAny, ok := c.senders.Load(cliID) if !ok { return nil @@ -33,7 +34,7 @@ func (c *ClientsManagerImpl) Get(cliID string) *app.Connector { // Set implements ClientsManager. func (c *ClientsManagerImpl) Set(cliID, clientType string, sender pb.Master_ServerSendServer) { - c.senders.Store(cliID, &app.Connector{ + c.senders.Store(cliID, &defs.Connector{ CliID: cliID, Conn: sender, CliType: clientType, @@ -66,9 +67,9 @@ func (c *ClientsManagerImpl) ConnectTime(cliID string) (time.Time, bool) { return t, true } -func NewClientsManager() *ClientsManagerImpl { +func NewClientsManager() app.ClientsManager { return &ClientsManagerImpl{ - senders: &utils.SyncMap[string, *app.Connector]{}, + senders: &utils.SyncMap[string, *defs.Connector]{}, connectTime: &utils.SyncMap[string, time.Time]{}, } } diff --git a/utils/net.go b/utils/net.go new file mode 100644 index 0000000..0ceae14 --- /dev/null +++ b/utils/net.go @@ -0,0 +1,52 @@ +package utils + +import ( + "fmt" + "net" +) + +// GetLocalIPv4s 返回本地所有活跃网络接口的 IPv4 地址列表。 +// 忽略未启用、回环以及非 IPv4 地址。 +func GetLocalIPv4s() ([]net.IP, error) { + var ips []net.IP + ifaces, err := net.Interfaces() + if err != nil { + return nil, fmt.Errorf("list interfaces failed: %w", err) + } + + for _, iface := range ifaces { + // 跳过未启用或回环接口 + if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 { + continue + } + addrs, err := iface.Addrs() + if err != nil { + // 某些接口可能无权限,此处跳过 + continue + } + + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + + // 仅保留 IPv4,过滤回环 + if ip == nil || ip.IsLoopback() || ip.To4() == nil { + continue + } + + ipstr := ip.String() + if ipstr == "" { + continue + } + + ips = append(ips, ip) + } + } + + return ips, nil +} diff --git a/utils/port.go b/utils/port.go index 1b87c62..5957d2b 100644 --- a/utils/port.go +++ b/utils/port.go @@ -3,50 +3,58 @@ package utils import ( "fmt" "net" - "time" - - "github.com/sirupsen/logrus" + "strconv" + "strings" ) -func GetAvailablePort(addr string) (int, error) { - address, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:0", addr)) - if err != nil { - return 0, err - } +// GetFreePort asks the kernel for a free port for the given network. +// Valid networks: "tcp4", "tcp6", "udp4", "udp6". +func GetFreePort(network string) (uint32, error) { + network = strings.ToLower(network) + var ( + port int + err error + ) - listener, err := net.ListenTCP("tcp", address) - if err != nil { - return 0, err - } - - defer listener.Close() - return listener.Addr().(*net.TCPAddr).Port, nil -} - -func IsPortAvailable(port int, addr string) bool { - - address := fmt.Sprintf("%s:%d", addr, port) - listener, err := net.Listen("tcp", address) - if err != nil { - logrus.Infof("port %s is taken: %s", address, err) - return false - } - - defer listener.Close() - return true -} - -func WaitForPort(host string, port int) { - for { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", host, port)) - if err == nil { - conn.Close() - break + addr := ":0" // let OS choose + switch network { + case "tcp4", "tcp6": + var ln net.Listener + ln, err = net.Listen(network, addr) + if err != nil { + return 0, fmt.Errorf("listen %s failed: %w", network, err) } + defer ln.Close() + port, err = extractPort(ln.Addr().String()) - logrus.Warnf("Target port %s:%d is not open yet, waiting...\n", host, port) - time.Sleep(time.Second * 5) + case "udp4", "udp6": + var pc net.PacketConn + pc, err = net.ListenPacket(network, addr) + if err != nil { + return 0, fmt.Errorf("listenpacket %s failed: %w", network, err) + } + defer pc.Close() + port, err = extractPort(pc.LocalAddr().String()) + + default: + return 0, fmt.Errorf("unsupported network %q", network) } - logrus.Infof("Target port %s:%d is open", host, port) - time.Sleep(time.Second * 1) + + if err != nil { + return 0, err + } + return uint32(port), nil +} + +// extractPort splits "host:port" and returns port as int. +func extractPort(hostport string) (int, error) { + _, portStr, err := net.SplitHostPort(hostport) + if err != nil { + return 0, fmt.Errorf("split hostport %q: %w", hostport, err) + } + p, err := strconv.Atoi(portStr) + if err != nil { + return 0, fmt.Errorf("invalid port %q: %w", portStr, err) + } + return p, nil } diff --git a/utils/proto.go b/utils/proto.go new file mode 100644 index 0000000..692b6ba --- /dev/null +++ b/utils/proto.go @@ -0,0 +1,9 @@ +package utils + +import ( + "google.golang.org/protobuf/proto" +) + +func DeepCopyProto[T proto.Message](msg T) T { + return proto.Clone(msg).(T) +} diff --git a/utils/rand.go b/utils/rand.go new file mode 100644 index 0000000..c71b6cf --- /dev/null +++ b/utils/rand.go @@ -0,0 +1,12 @@ +package utils + +import ( + "math/rand" +) + +func RandomInt(a, b int) int { + if a > b { + a, b = b, a + } + return rand.Intn(b-a+1) + a +} diff --git a/utils/udp.go b/utils/udp.go new file mode 100644 index 0000000..0d584ef --- /dev/null +++ b/utils/udp.go @@ -0,0 +1,101 @@ +package utils + +import ( + "context" + "errors" + "math" + "net" + "sync" + "time" + + "github.com/VaalaCat/frp-panel/defs" +) + +// ProbeEndpoint sends a small UDP packet to addr and waits for a reply. +// It returns the measured RTT or an error. +func ProbeEndpoint(ctx context.Context, addr EndpointGettable, timeout time.Duration) (time.Duration, error) { + // Resolve UDP address + udpAddr, err := net.ResolveUDPAddr("udp", addr.GetEndpoint()) + if err != nil { + return 0, err + } + + // Dial UDP + conn, err := net.DialUDP("udp", nil, udpAddr) + if err != nil { + return 0, err + } + defer conn.Close() + + // Prepare a simple ping payload + payload := []byte(defs.VaalaMagicBytes) + + // Set deadlines + deadline := time.Now().Add(timeout) + conn.SetDeadline(deadline) + + start := time.Now() + if _, err := conn.Write(payload); err != nil { + return 0, err + } + + // Buffer for response + buf := make([]byte, 64) + if _, _, err := conn.ReadFrom(buf); err != nil { + return 0, err + } + rtt := time.Since(start) + + return rtt, nil +} + +type EndpointGettable interface { + GetEndpoint() string +} + +// SelectFastestEndpoint concurrently probes all candidates and returns the fastest. +func SelectFastestEndpoint(ctx context.Context, candidates []EndpointGettable, timeout time.Duration) (EndpointGettable, error) { + var ( + wg sync.WaitGroup + mu sync.Mutex + bestEP EndpointGettable + bestRTT = time.Duration(math.MaxInt64) + firstErr error + ) + + wg.Add(len(candidates)) + for _, addr := range candidates { + go func(addr EndpointGettable) { + defer wg.Done() + + rtt, err := ProbeEndpoint(ctx, addr, timeout) + if err != nil { + // record first error + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + return + } + + mu.Lock() + if rtt < bestRTT { + bestRTT = rtt + bestEP = addr + } + mu.Unlock() + }(addr) + } + + wg.Wait() + + if bestEP == nil { + if firstErr != nil { + return nil, firstErr + } + return nil, errors.New("no endpoint reachable") + } + + return bestEP, nil +}