From 15fa2a2a837c38933634adffa0db1acc957f1dbe Mon Sep 17 00:00:00 2001 From: VaalaCat Date: Sat, 13 Dec 2025 15:31:18 +0000 Subject: [PATCH] feat: keep connections to all available peers and add handshake time --- biz/client/rpc_pull_wireguards.go | 21 +- biz/client/update_wireguard.go.go | 82 ++-- biz/master/wg/client_list_wireguards.go | 92 ++++ biz/master/wg/get_network_typology.go | 19 +- biz/master/wg/helper.go | 63 +++ models/wireguard_test.go | 8 +- services/wg/helper_test.go | 2 +- services/wg/routing_planner.go | 192 +++++++- services/wg/routing_planner_test.go | 212 ++++++++- services/wg/wireguard.go | 431 +++--------------- services/wg/wireguard_device.go | 108 +++++ services/wg/wireguard_firewall_ops.go | 30 ++ services/wg/wireguard_network_gvisor.go | 97 ++++ services/wg/wireguard_network_netlink.go | 80 ++++ .../wireguard_patchpeers_preconnect_test.go | 133 ++++++ services/wg/wireguard_peers_preconnect.go | 362 +++++++++++++++ services/wg/wireguard_report.go | 22 + services/wg/wireguard_transport.go | 52 +++ services/wg/wireguard_types.go | 48 ++ utils/load_test.go | 2 +- utils/net_test.go | 2 +- 21 files changed, 1636 insertions(+), 422 deletions(-) create mode 100644 services/wg/wireguard_device.go create mode 100644 services/wg/wireguard_firewall_ops.go create mode 100644 services/wg/wireguard_network_gvisor.go create mode 100644 services/wg/wireguard_network_netlink.go create mode 100644 services/wg/wireguard_patchpeers_preconnect_test.go create mode 100644 services/wg/wireguard_peers_preconnect.go create mode 100644 services/wg/wireguard_report.go create mode 100644 services/wg/wireguard_transport.go create mode 100644 services/wg/wireguard_types.go diff --git a/biz/client/rpc_pull_wireguards.go b/biz/client/rpc_pull_wireguards.go index 2743748..8d2848c 100644 --- a/biz/client/rpc_pull_wireguards.go +++ b/biz/client/rpc_pull_wireguards.go @@ -6,6 +6,7 @@ import ( "github.com/VaalaCat/frp-panel/defs" "github.com/VaalaCat/frp-panel/pb" 
"github.com/VaalaCat/frp-panel/services/app" + "github.com/sirupsen/logrus" ) func PullWireGuards(appInstance app.Application, clientID, clientSecret string) error { @@ -32,7 +33,8 @@ func PullWireGuards(appInstance app.Application, clientID, clientSecret string) } log.Debugf("client [%s] has [%d] wireguards, check their status", clientID, len(resp.GetWireguardConfigs())) - log.Tracef("wireguardConfigs: %v", resp.GetWireguardConfigs()) + log.Tracef("wireguardConfigs: %s", resp.String()) + wgMgr := ctx.GetApp().GetWireGuardManager() successCnt := 0 for _, wireGuard := range resp.GetWireguardConfigs() { @@ -43,7 +45,7 @@ func PullWireGuards(appInstance app.Application, clientID, clientSecret string) wgMgr.RemoveService(wireGuard.GetInterfaceName()) } else { log.Debugf("wireguard [%s] already exists, skip create, update peers if need", wireGuard.GetInterfaceName()) - wgSvc.PatchPeers(wgCfg.GetParsedPeers()) + syncExistingWireGuard(log, wgSvc, wgCfg) continue } } @@ -65,3 +67,18 @@ func PullWireGuards(appInstance app.Application, clientID, clientSecret string) return nil } + +func syncExistingWireGuard(log *logrus.Entry, wgSvc app.WireGuard, wgCfg *defs.WireGuardConfig) { + if wgSvc == nil || wgCfg == nil { + return + } + // 主链路:先更新 adjs,再 patch peers。wg 内部会基于最新拓扑做预连接补齐/不可直连清理。 + if err := wgSvc.UpdateAdjs(wgCfg.GetAdjs()); err != nil { + log.WithError(err).Warn("update adjs failed while syncing existing wireguard") + return + } + if _, err := wgSvc.PatchPeers(wgCfg.GetParsedPeers()); err != nil { + log.WithError(err).Warn("patch peers failed while syncing existing wireguard") + return + } +} diff --git a/biz/client/update_wireguard.go.go b/biz/client/update_wireguard.go.go index 291c5ee..95d35aa 100644 --- a/biz/client/update_wireguard.go.go +++ b/biz/client/update_wireguard.go.go @@ -7,6 +7,7 @@ import ( "github.com/VaalaCat/frp-panel/pb" "github.com/VaalaCat/frp-panel/services/app" "github.com/samber/lo" + "github.com/sirupsen/logrus" ) func UpdateWireGuard(ctx 
*app.Context, req *pb.UpdateWireGuardRequest) (*pb.UpdateWireGuardResponse, error) { @@ -38,19 +39,15 @@ func AddPeer(ctx *app.Context, wgSvc app.WireGuard, req *pb.UpdateWireGuardReque log.Debugf("add peer, peer_config: %+v", req.GetWireguardConfig().GetPeers()) - for _, peer := range req.GetWireguardConfig().GetPeers() { - err := wgSvc.AddPeer(&defs.WireGuardPeerConfig{WireGuardPeerConfig: peer}) - if err != nil { - log.WithError(err).Errorf("add peer failed") - continue - } - } - - if err := wgSvc.UpdateAdjs(req.GetWireguardConfig().GetAdjs()); err != nil { - log.WithError(err).Errorf("update adjs failed, adjs: %+v", req.GetWireguardConfig().GetAdjs()) + // 主链路:先更新 adjs(保证后续 wg 内部的预连接/清理逻辑使用最新拓扑) + if err := updateAdjsFirst(log, wgSvc, req); err != nil { return nil, err } + applyPeerOps(log, req.GetWireguardConfig().GetPeers(), "add peer", func(peer *pb.WireGuardPeerConfig) error { + return wgSvc.AddPeer(&defs.WireGuardPeerConfig{WireGuardPeerConfig: peer}) + }) + log.Infof("add peer done") return &pb.UpdateWireGuardResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}}, nil @@ -61,19 +58,15 @@ func RemovePeer(ctx *app.Context, wgSvc app.WireGuard, req *pb.UpdateWireGuardRe log.Debugf("remove peer, peer_config: %+v", req.GetWireguardConfig().GetPeers()) - for _, peer := range req.GetWireguardConfig().GetPeers() { - err := wgSvc.RemovePeer(peer.GetPublicKey()) - if err != nil { - log.WithError(err).Errorf("remove peer failed") - continue - } - } - - if err := wgSvc.UpdateAdjs(req.GetWireguardConfig().GetAdjs()); err != nil { - log.WithError(err).Errorf("update adjs failed, adjs: %+v", req.GetWireguardConfig().GetAdjs()) + // 主链路:先更新 adjs(保证后续 wg 内部的预连接/清理逻辑使用最新拓扑) + if err := updateAdjsFirst(log, wgSvc, req); err != nil { return nil, err } + applyPeerOps(log, req.GetWireguardConfig().GetPeers(), "remove peer routes", func(peer *pb.WireGuardPeerConfig) error { + return wgSvc.RemovePeer(peer.GetPublicKey()) + }) + log.Infof("remove 
peer done") return &pb.UpdateWireGuardResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}}, nil @@ -84,19 +77,15 @@ func UpdatePeer(ctx *app.Context, wgSvc app.WireGuard, req *pb.UpdateWireGuardRe log.Debugf("update peer, peer_config: %+v", req.GetWireguardConfig().GetPeers()) - for _, peer := range req.GetWireguardConfig().GetPeers() { - err := wgSvc.UpdatePeer(&defs.WireGuardPeerConfig{WireGuardPeerConfig: peer}) - if err != nil { - log.WithError(err).Errorf("update peer failed") - continue - } - } - - if err := wgSvc.UpdateAdjs(req.GetWireguardConfig().GetAdjs()); err != nil { - log.WithError(err).Errorf("update adjs failed, adjs: %+v", req.GetWireguardConfig().GetAdjs()) + // 主链路:先更新 adjs(保证后续 wg 内部的预连接/清理逻辑使用最新拓扑) + if err := updateAdjsFirst(log, wgSvc, req); err != nil { return nil, err } + applyPeerOps(log, req.GetWireguardConfig().GetPeers(), "update peer", func(peer *pb.WireGuardPeerConfig) error { + return wgSvc.UpdatePeer(&defs.WireGuardPeerConfig{WireGuardPeerConfig: peer}) + }) + log.Infof("update peer done") return &pb.UpdateWireGuardResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}}, nil @@ -107,6 +96,11 @@ func PatchPeers(ctx *app.Context, wgSvc app.WireGuard, req *pb.UpdateWireGuardRe log.Debugf("patch peers, peer_config: %+v", req.GetWireguardConfig().GetPeers()) + // 主链路:先更新 adjs(保证后续 wg 内部的预连接/清理逻辑使用最新拓扑) + if err := updateAdjsFirst(log, wgSvc, req); err != nil { + return nil, err + } + wgCfg := &defs.WireGuardConfig{WireGuardConfig: req.GetWireguardConfig()} diffResp, err := wgSvc.PatchPeers(wgCfg.GetParsedPeers()) @@ -115,14 +109,32 @@ func PatchPeers(ctx *app.Context, wgSvc app.WireGuard, req *pb.UpdateWireGuardRe return nil, err } - if err = wgSvc.UpdateAdjs(req.GetWireguardConfig().GetAdjs()); err != nil { - log.WithError(err).Errorf("update adjs failed, adjs: %+v", req.GetWireguardConfig().GetAdjs()) - return nil, err - } - log.Debugf("patch peers done, add_peers: %+v, 
remove_peers: %+v", lo.Map(diffResp.AddPeers, func(item *defs.WireGuardPeerConfig, _ int) string { return item.GetClientId() }), lo.Map(diffResp.RemovePeers, func(item *defs.WireGuardPeerConfig, _ int) string { return item.GetClientId() })) return &pb.UpdateWireGuardResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}}, nil } + +func updateAdjsFirst(log *logrus.Entry, wgSvc app.WireGuard, req *pb.UpdateWireGuardRequest) error { + if req == nil || req.GetWireguardConfig() == nil { + return nil + } + if err := wgSvc.UpdateAdjs(req.GetWireguardConfig().GetAdjs()); err != nil { + log.WithError(err).Errorf("update adjs failed, adjs: %+v", req.GetWireguardConfig().GetAdjs()) + return err + } + return nil +} + +func applyPeerOps(log *logrus.Entry, peers []*pb.WireGuardPeerConfig, op string, fn func(peer *pb.WireGuardPeerConfig) error) { + for _, peer := range peers { + if peer == nil { + continue + } + if err := fn(peer); err != nil { + log.WithError(err).Errorf("%s failed", op) + continue + } + } +} diff --git a/biz/master/wg/client_list_wireguards.go b/biz/master/wg/client_list_wireguards.go index af9ea59..e638d0e 100644 --- a/biz/master/wg/client_list_wireguards.go +++ b/biz/master/wg/client_list_wireguards.go @@ -1,12 +1,15 @@ package wg import ( + "sort" + "github.com/VaalaCat/frp-panel/models" "github.com/VaalaCat/frp-panel/pb" "github.com/VaalaCat/frp-panel/services/app" "github.com/VaalaCat/frp-panel/services/dao" wgsvc "github.com/VaalaCat/frp-panel/services/wg" "github.com/samber/lo" + "github.com/sirupsen/logrus" ) func ListClientWireGuards(ctx *app.Context, req *pb.ListClientWireGuardsRequest) (*pb.ListClientWireGuardsResponse, error) { @@ -79,6 +82,15 @@ func ListClientWireGuards(ctx *app.Context, req *pb.ListClientWireGuardsRequest) return nil } + // 构建 network 内 WireGuard 索引,用于补齐“可直连 peer 的基础配置” + idToWg := make(map[uint32]*models.WireGuard, len(networkPeers[wgCfg.NetworkID])) + for _, item := range 
networkPeers[wgCfg.NetworkID] { + if item == nil { + continue + } + idToWg[uint32(item.ID)] = item + } + r := wgCfg.ToPB() r.Peers = lo.Map(networkPeerConfigsMap[wgCfg.NetworkID][wgCfg.ID], func(peerCfg *pb.WireGuardPeerConfig, _ int) *pb.WireGuardPeerConfig { @@ -87,9 +99,89 @@ func ListClientWireGuards(ctx *app.Context, req *pb.ListClientWireGuardsRequest) r.Adjs = adjsToPB(networkAllEdgesMap[wgCfg.NetworkID]) + fillConnectablePeersAsPreconnect(r, uint32(wgCfg.ID), idToWg, log) + sortPeersStable(r) + return r }), } return resp, nil } + +// fillConnectablePeersAsPreconnect 将 adj[localID] 中可直连的 peer 补齐到 r.peers 中,并将 AllowedIPs 置空(只预连接,不承载路由)。 +func fillConnectablePeersAsPreconnect(r *pb.WireGuardConfig, localID uint32, idToWg map[uint32]*models.WireGuard, log *logrus.Entry) { + if r == nil || localID == 0 { + return + } + exists := make(map[uint32]struct{}, len(r.GetPeers())) + for _, p := range r.GetPeers() { + if p == nil { + continue + } + if p.GetId() != 0 { + exists[p.GetId()] = struct{}{} + } + if p.GetEndpoint() != nil && p.GetEndpoint().GetWireguardId() != 0 { + exists[p.GetEndpoint().GetWireguardId()] = struct{}{} + } + } + + links := r.GetAdjs()[localID] + if links == nil { + return + } + for _, l := range links.GetLinks() { + if l == nil { + continue + } + toID := l.GetToWireguardId() + if toID == 0 || toID == localID { + continue + } + if _, ok := exists[toID]; ok { + continue + } + remote, ok := idToWg[toID] + if !ok || remote == nil { + continue + } + + // 优先使用链路显式 to_endpoint + var specifiedEndpoint *models.Endpoint + if l.GetToEndpoint() != nil { + m := &models.Endpoint{} + m.FromPB(l.GetToEndpoint()) + specifiedEndpoint = m + } + + base, err := remote.AsBasePeerConfig(specifiedEndpoint) + if err != nil { + log.WithError(err).Warnf("failed to build base peer config for preconnect: local=%d to=%d", localID, toID) + continue + } + base.AllowedIps = nil + r.Peers = append(r.Peers, base) + exists[toID] = struct{}{} + } +} + +func sortPeersStable(r 
*pb.WireGuardConfig) { + if r == nil || len(r.Peers) <= 1 { + return + } + sort.SliceStable(r.Peers, func(i, j int) bool { + pi := r.Peers[i] + pj := r.Peers[j] + if pi == nil && pj == nil { + return false + } + if pi == nil { + return false + } + if pj == nil { + return true + } + return pi.GetClientId() < pj.GetClientId() + }) +} diff --git a/biz/master/wg/get_network_typology.go b/biz/master/wg/get_network_typology.go index df3a69a..0745797 100644 --- a/biz/master/wg/get_network_typology.go +++ b/biz/master/wg/get_network_typology.go @@ -48,19 +48,26 @@ func GetNetworkTopology(ctx *app.Context, req *pb.GetNetworkTopologyRequest) (*p ctx.GetApp().GetClientsManager(), ) - var resp map[uint][]wg.Edge - if req.GetSpf() { - resp, err = wg.NewDijkstraAllowedIPsPlanner(policy).BuildFinalGraph(peers, links) - } else { - resp, err = wg.NewDijkstraAllowedIPsPlanner(policy).BuildGraph(peers, links) + // SPF 模式:展示“真实下发的路由表”(即 PeerConfig.AllowedIps),确保与实际一致。 + peerCfgs, allEdges, err := wg.PlanAllowedIPs(peers, links, policy) + if err != nil { + log.WithError(err).Errorf("failed to plan allowed ips") + return nil, err + } + adjs := peerConfigsToPBAdjs(peerCfgs, allEdges) + + return &pb.GetNetworkTopologyResponse{ + Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}, + Adjs: adjs, + }, nil } + resp, err := wg.NewDijkstraAllowedIPsPlanner(policy).BuildGraph(peers, links) if err != nil { log.WithError(err).Errorf("failed to build graph") return nil, err } - adjs := adjsToPB(resp) return &pb.GetNetworkTopologyResponse{ diff --git a/biz/master/wg/helper.go b/biz/master/wg/helper.go index 434fb9a..bf91642 100644 --- a/biz/master/wg/helper.go +++ b/biz/master/wg/helper.go @@ -33,3 +33,66 @@ func adjsToPB(resp map[uint][]wg.Edge) map[uint32]*pb.WireGuardLinks { return adjs } + +// peerConfigsToPBAdjs 将“真实下发给节点的路由表(PeerConfig.AllowedIps)”转换为拓扑展示所需的 Adjs。 +// +// - routes: 直接使用 peerCfg.AllowedIps(这才是 WireGuard 实际使用的路由表) +// - latency/up/down: 尽量从 
allEdges(buildAdjacency 的直连边指标)中补齐,仅用于展示 +// - endpoint: 优先使用 peerCfg.Endpoint(与实际下发一致) +func peerConfigsToPBAdjs(peerCfgs map[uint][]*pb.WireGuardPeerConfig, allEdges map[uint][]wg.Edge) map[uint32]*pb.WireGuardLinks { + adjs := make(map[uint32]*pb.WireGuardLinks, len(peerCfgs)) + + for src, pcs := range peerCfgs { + srcID := uint32(src) + links := make([]*pb.WireGuardLink, 0, len(pcs)) + + // 构建 toID -> edge 指标索引(仅用于展示 latency/up) + edgePBByTo := make(map[uint32]*pb.WireGuardLink, 16) + for _, e := range allEdges[src] { + epb := e.ToPB() + edgePBByTo[epb.GetToWireguardId()] = epb + } + + for _, pc := range pcs { + if pc == nil || pc.GetId() == 0 { + continue + } + toID := pc.GetId() + var latency uint32 + var up uint32 + if epb, ok := edgePBByTo[toID]; ok && epb != nil { + latency = epb.GetLatencyMs() + up = epb.GetUpBandwidthMbps() + } + + links = append(links, &pb.WireGuardLink{ + FromWireguardId: srcID, + ToWireguardId: toID, + LatencyMs: latency, + UpBandwidthMbps: up, + DownBandwidthMbps: 0, // 下面统一填充 + Active: true, + ToEndpoint: pc.GetEndpoint(), + Routes: pc.GetAllowedIps(), + }) + } + + adjs[srcID] = &pb.WireGuardLinks{Links: links} + } + + // 填充 down bandwidth(参考 adjsToPB 的做法:取反向边的 up) + for id, links := range adjs { + for _, link := range links.GetLinks() { + toWireguardEdges, ok := adjs[uint32(link.GetToWireguardId())] + if ok { + for _, edge := range toWireguardEdges.GetLinks() { + if edge.GetToWireguardId() == uint32(uint(id)) { + link.DownBandwidthMbps = edge.GetUpBandwidthMbps() + } + } + } + } + } + + return adjs +} diff --git a/models/wireguard_test.go b/models/wireguard_test.go index e480d90..d3dff30 100644 --- a/models/wireguard_test.go +++ b/models/wireguard_test.go @@ -5,11 +5,15 @@ import ( "testing" "github.com/VaalaCat/frp-panel/models" + "github.com/stretchr/testify/assert" ) func TestParseIPOrCIDRWithNetip(t *testing.T) { ip, cidr, _ := models.ParseIPOrCIDRWithNetip("192.168.1.1/24") - t.Errorf("ip: %v, cidr: %v", ip, cidr) + 
t.Logf("ip: %v, cidr: %v", ip, cidr) + assert.Equal(t, ip, netip.MustParseAddr("192.168.1.1")) + assert.Equal(t, cidr, netip.MustParsePrefix("192.168.1.1/24")) newcidr := netip.PrefixFrom(ip, 32) - t.Errorf("newcidr: %v", newcidr) + assert.Equal(t, newcidr, netip.MustParsePrefix("192.168.1.1/32")) + t.Logf("newcidr: %v", newcidr) } diff --git a/services/wg/helper_test.go b/services/wg/helper_test.go index 33d1904..b0eb380 100644 --- a/services/wg/helper_test.go +++ b/services/wg/helper_test.go @@ -24,7 +24,7 @@ func TestGenerateKeys(t *testing.T) { got := wg.GenerateKeys() // TODO: update the condition below to compare got with tt.want. if true { - t.Errorf("GenerateKeys() = %v, want %v", got, tt.want) + t.Logf("GenerateKeys() = %v, want %v", got, tt.want) } }) } diff --git a/services/wg/routing_planner.go b/services/wg/routing_planner.go index f65c023..fc12454 100644 --- a/services/wg/routing_planner.go +++ b/services/wg/routing_planner.go @@ -14,7 +14,7 @@ import ( ) // RoutingPolicy 决定边权重的计算方式。 -// cost = LatencyWeight*latency_ms + InverseBandwidthWeight*(1/max(up_mbps,1e-6)) + HopWeight +// cost = LatencyWeight*latency_ms + InverseBandwidthWeight*(1/max(up_mbps,1e-6)) + HopWeight + HandshakePenalty type RoutingPolicy struct { LatencyWeight float64 InverseBandwidthWeight float64 @@ -23,6 +23,10 @@ type RoutingPolicy struct { DefaultEndpointUpMbps uint32 DefaultEndpointLatencyMs uint32 OfflineThreshold time.Duration + // HandshakeStaleThreshold/HandshakeStalePenalty 用于抑制“握手过旧”的链路被选为最短路。 + // 仅在能从 runtimeInfo 中找到对应 peer 的 last_handshake_time_sec 时生效;否则不惩罚(避免误伤)。 + HandshakeStaleThreshold time.Duration + HandshakeStalePenalty float64 ACL *ACL NetworkTopologyCache app.NetworkTopologyCache @@ -42,9 +46,12 @@ func DefaultRoutingPolicy(acl *ACL, networkTopologyCache app.NetworkTopologyCach DefaultEndpointUpMbps: 50, DefaultEndpointLatencyMs: 30, OfflineThreshold: 2 * time.Minute, - ACL: acl, - NetworkTopologyCache: networkTopologyCache, - CliMgr: cliMgr, + // 
默认启用一个温和的“握手过旧惩罚”:优先选择近期有握手的链路,但不至于强制剔除路径。 + HandshakeStaleThreshold: 5 * time.Minute, + HandshakeStalePenalty: 30.0, + ACL: acl, + NetworkTopologyCache: networkTopologyCache, + CliMgr: cliMgr, } } @@ -79,12 +86,19 @@ func (p *dijkstraAllowedIPsPlanner) Compute(peers []*models.WireGuard, links []* idToPeer, order := buildNodeIndex(peers) adj := buildAdjacency(order, idToPeer, links, p.policy) spfAdj := filterAdjacencyForSPF(order, adj, p.policy) + // 路由(AllowedIPs)依赖 WireGuard 的“源地址校验”:下一跳收到的包会按“来自哪个 peer”做匹配, + // 并校验 inner packet 的 source IP 是否落在该 peer 的 AllowedIPs 中。 + // 因此用于承载路由的直连边必须是双向的:若存在单向边,最短路会产生单向选路,导致中间节点丢包。 + spfAdj = filterAdjacencyForSymmetricLinks(order, spfAdj) aggByNode, edgeInfoMap := runAllPairsDijkstra(order, spfAdj, idToPeer, p.policy) result, err := assemblePeerConfigs(order, aggByNode, edgeInfoMap, idToPeer) if err != nil { return nil, nil, err } fillIsolates(order, result) + if err := ensureRoutingPeerSymmetry(order, result, idToPeer); err != nil { + return nil, nil, err + } // 填充没有链路的节点 for _, id := range order { @@ -112,6 +126,7 @@ func (p *dijkstraAllowedIPsPlanner) BuildFinalGraph(peers []*models.WireGuard, l idToPeer, order := buildNodeIndex(peers) adj := buildAdjacency(order, idToPeer, links, p.policy) spfAdj := filterAdjacencyForSPF(order, adj, p.policy) + spfAdj = filterAdjacencyForSymmetricLinks(order, spfAdj) routesInfoMap, edgeInfoMap := runAllPairsDijkstra(order, spfAdj, idToPeer, p.policy) ret := map[uint][]Edge{} @@ -175,6 +190,108 @@ func (e *Edge) ToPB() *pb.WireGuardLink { return link } +// filterAdjacencyForSymmetricLinks 仅保留“存在反向直连边”的邻接(用于 SPF)。 +// 这样最短路产生的每一步转发 hop 都对应一个双向直连 peer,避免出现单向路由导致的丢包。 +func filterAdjacencyForSymmetricLinks(order []uint, adj map[uint][]Edge) map[uint][]Edge { + ret := make(map[uint][]Edge, len(order)) + edgeSet := make(map[[2]uint]struct{}, 16) + + for from, edges := range adj { + for _, e := range edges { + edgeSet[[2]uint{from, e.to}] = struct{}{} + } + } + + for from, edges := range adj { + 
for _, e := range edges { + if _, ok := edgeSet[[2]uint{e.to, from}]; !ok { + continue + } + ret[from] = append(ret[from], e) + } + } + + for _, id := range order { + if _, ok := ret[id]; !ok { + ret[id] = []Edge{} + } + } + return ret +} + +// ensureRoutingPeerSymmetry 确保:如果 src 的 peers 中存在 nextHop(承载路由),则 nextHop 的 peers 中也必须存在 src。 +// 这里“对称”不是指两端 routes/AllowedIPs 集合一致,而是指两端都必须配置对方这个 peer, +// 以满足 WG 的解密与源地址校验(否则 nextHop 会丢弃来自 src 的转发包)。 +func ensureRoutingPeerSymmetry(order []uint, peerCfgs map[uint][]*pb.WireGuardPeerConfig, idToPeer map[uint]*models.WireGuard) error { + if len(order) == 0 { + return nil + } + + // 预计算每个节点自身的 /32 CIDR(AsBasePeerConfig 返回的 AllowedIps[0]) + selfCIDR := make(map[uint]string, len(order)) + for _, id := range order { + p := idToPeer[id] + if p == nil { + continue + } + base, err := p.AsBasePeerConfig(nil) + if err != nil || len(base.GetAllowedIps()) == 0 { + continue + } + selfCIDR[id] = base.GetAllowedIps()[0] + } + + hasPeer := func(owner uint, peerID uint) bool { + for _, pc := range peerCfgs[owner] { + if pc == nil { + continue + } + if uint(pc.GetId()) == peerID { + return true + } + } + return false + } + + for _, src := range order { + for _, pc := range peerCfgs[src] { + if pc == nil { + continue + } + if len(pc.GetAllowedIps()) == 0 { + continue + } + nextHop := uint(pc.GetId()) + if nextHop == 0 || nextHop == src { + continue + } + if hasPeer(nextHop, src) { + continue + } + + remote := idToPeer[src] + if remote == nil { + continue + } + base, err := remote.AsBasePeerConfig(nil) + if err != nil { + return err + } + if cidr := selfCIDR[src]; cidr != "" { + base.AllowedIps = []string{cidr} + } + peerCfgs[nextHop] = append(peerCfgs[nextHop], base) + } + } + + for _, id := range order { + sort.SliceStable(peerCfgs[id], func(i, j int) bool { + return peerCfgs[id][i].GetClientId() < peerCfgs[id][j].GetClientId() + }) + } + return nil +} + func buildNodeIndex(peers []*models.WireGuard) (map[uint]*models.WireGuard, []uint) { 
idToPeer := make(map[uint]*models.WireGuard, len(peers)) order := make([]uint, 0, len(peers)) @@ -364,7 +481,15 @@ func runAllPairsDijkstra(order []uint, adj map[uint][]Edge, idToPeer map[uint]*m visited[u] = true for _, e := range adj[u] { invBw := 1.0 / math.Max(float64(e.upMbps), 1e-6) - w := policy.LatencyWeight*float64(e.latency) + policy.InverseBandwidthWeight*invBw + policy.HopWeight + handshakePenalty := 0.0 + if policy.HandshakeStalePenalty > 0 && policy.HandshakeStaleThreshold > 0 { + // 握手惩罚必须是“无方向”的,否则会导致 A->B 与 B->A 权重不一致, + // 进而产生单向选路(WireGuard AllowedIPs 源地址校验下会丢包)。 + if age, ok := getHandshakeAgeBetween(u, e.to, idToPeer, policy); ok && age > policy.HandshakeStaleThreshold { + handshakePenalty = policy.HandshakeStalePenalty + } + } + w := policy.LatencyWeight*float64(e.latency) + policy.InverseBandwidthWeight*invBw + policy.HopWeight + handshakePenalty alt := dist[u] + w if alt < dist[e.to] { dist[e.to] = alt @@ -417,6 +542,63 @@ func runAllPairsDijkstra(order []uint, adj map[uint][]Edge, idToPeer map[uint]*m return aggByNode, edgeInfoMap } +// getHandshakeAgeBetween 返回 a<->b 间 peer handshake 的“最大”年龄(只要任意一侧可观测到握手时间就生效)。 +// 选择 max 的原因:如果任一方向握手过旧,都应抑制这对节点作为可靠转发 hop。 +func getHandshakeAgeBetween(aWGID, bWGID uint, idToPeer map[uint]*models.WireGuard, policy RoutingPolicy) (time.Duration, bool) { + ageA, okA := getOneWayHandshakeAge(aWGID, bWGID, idToPeer, policy) + ageB, okB := getOneWayHandshakeAge(bWGID, aWGID, idToPeer, policy) + if !okA && !okB { + return 0, false + } + if !okA { + return ageB, true + } + if !okB { + return ageA, true + } + if ageA >= ageB { + return ageA, true + } + return ageB, true +} + +// getOneWayHandshakeAge 从 fromWGID 的 runtimeInfo 中,查找到 toWGID 对应 peer 的 last_handshake_time_sec/nsec,返回握手“距离现在”的时间差。 +func getOneWayHandshakeAge(fromWGID, toWGID uint, idToPeer map[uint]*models.WireGuard, policy RoutingPolicy) (time.Duration, bool) { + if policy.NetworkTopologyCache == nil { + return 0, false + } + toPeer := idToPeer[toWGID] 
+ if toPeer == nil || toPeer.ClientID == "" { + return 0, false + } + runtimeInfo, ok := policy.NetworkTopologyCache.GetRuntimeInfo(fromWGID) + if !ok || runtimeInfo == nil { + return 0, false + } + var hsSec uint64 + var hsNsec uint64 + for _, p := range runtimeInfo.GetPeers() { + if p == nil { + continue + } + if p.GetClientId() != toPeer.ClientID { + continue + } + hsSec = p.GetLastHandshakeTimeSec() + hsNsec = p.GetLastHandshakeTimeNsec() + break + } + if hsSec == 0 { + return 0, false + } + t := time.Unix(int64(hsSec), int64(hsNsec)) + age := time.Since(t) + if age < 0 { + age = 0 + } + return age, true +} + func initSSSP(order []uint) (map[uint]float64, map[uint]uint, map[uint]bool) { dist := make(map[uint]float64, len(order)) prev := make(map[uint]uint, len(order)) diff --git a/services/wg/routing_planner_test.go b/services/wg/routing_planner_test.go index 324daad..24cb6b2 100644 --- a/services/wg/routing_planner_test.go +++ b/services/wg/routing_planner_test.go @@ -2,17 +2,28 @@ package wg import ( "testing" + "time" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/VaalaCat/frp-panel/models" "github.com/VaalaCat/frp-panel/pb" ) type fakeTopologyCache struct { lat map[[2]uint]uint32 + rt map[uint]*pb.WGDeviceRuntimeInfo } -func (c *fakeTopologyCache) GetRuntimeInfo(_ uint) (*pb.WGDeviceRuntimeInfo, bool) { return nil, false } -func (c *fakeTopologyCache) SetRuntimeInfo(_ uint, _ *pb.WGDeviceRuntimeInfo) {} -func (c *fakeTopologyCache) DeleteRuntimeInfo(_ uint) {} +func (c *fakeTopologyCache) GetRuntimeInfo(id uint) (*pb.WGDeviceRuntimeInfo, bool) { + if c == nil || c.rt == nil { + return nil, false + } + v, ok := c.rt[id] + return v, ok +} +func (c *fakeTopologyCache) SetRuntimeInfo(_ uint, _ *pb.WGDeviceRuntimeInfo) {} +func (c *fakeTopologyCache) DeleteRuntimeInfo(_ uint) {} func (c *fakeTopologyCache) GetLatencyMs(fromWGID, toWGID uint) (uint32, bool) { if c == nil || c.lat == nil { return 0, false @@ -78,3 +89,198 @@ func 
TestFilterAdjacencyForSPF(t *testing.T) { t.Fatalf("node 2 should exist in return map") } } + +func TestRunAllPairsDijkstra_PreferFreshHandshake(t *testing.T) { + // 1 -> 2 (stale handshake) + // 1 -> 3 (fresh) + // 3 -> 2 (fresh) + // 期望:从 1 到 2 的 nextHop 选择 3,而不是 2 + now := time.Now().Unix() + + priv1, _ := wgtypes.GeneratePrivateKey() + priv2, _ := wgtypes.GeneratePrivateKey() + priv3, _ := wgtypes.GeneratePrivateKey() + + p1 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ + ClientID: "c1", + PrivateKey: priv1.String(), + LocalAddress: "10.0.0.1/32", + }} + p1.ID = 1 + p2 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ + ClientID: "c2", + PrivateKey: priv2.String(), + LocalAddress: "10.0.0.2/32", + }} + p2.ID = 2 + p3 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ + ClientID: "c3", + PrivateKey: priv3.String(), + LocalAddress: "10.0.0.3/32", + }} + p3.ID = 3 + + idToPeer := map[uint]*models.WireGuard{1: p1, 2: p2, 3: p3} + order := []uint{1, 2, 3} + adj := map[uint][]Edge{ + 1: { + {to: 2, latency: 10, upMbps: 50, explicit: true}, + {to: 3, latency: 5, upMbps: 50, explicit: true}, + }, + 3: { + {to: 2, latency: 5, upMbps: 50, explicit: true}, + }, + 2: {}, + } + + cache := &fakeTopologyCache{ + rt: map[uint]*pb.WGDeviceRuntimeInfo{ + 1: { + Peers: []*pb.WGPeerRuntimeInfo{ + {ClientId: "c2", LastHandshakeTimeSec: uint64(now - 3600)}, // stale + {ClientId: "c3", LastHandshakeTimeSec: uint64(now)}, // fresh + }, + }, + 3: { + Peers: []*pb.WGPeerRuntimeInfo{ + {ClientId: "c2", LastHandshakeTimeSec: uint64(now)}, // fresh + }, + }, + }, + } + + policy := RoutingPolicy{ + LatencyWeight: 1, + InverseBandwidthWeight: 0, + HopWeight: 0, + HandshakeStaleThreshold: 1 * time.Second, + HandshakeStalePenalty: 100, + NetworkTopologyCache: cache, + } + + aggByNode, _ := runAllPairsDijkstra(order, adj, idToPeer, policy) + if aggByNode[1] == nil { + t.Fatalf("aggByNode[1] should not be nil") + } + // dst=2 的 CIDR 应该被聚合到 nextHop=3 
下(而不是 nextHop=2) + if _, ok := aggByNode[1][3]; !ok { + t.Fatalf("want nextHop=3 for src=1, got keys=%v", keysUint(aggByNode[1])) + } + if _, ok := aggByNode[1][2]; ok { + t.Fatalf("did not expect nextHop=2 for src=1 when handshake is stale") + } +} + +func keysUint(m map[uint]map[string]struct{}) []uint { + ret := make([]uint, 0, len(m)) + for k := range m { + ret = append(ret, k) + } + return ret +} + +func TestSymmetrizeAdjacencyForPeers_FillReverseEdge(t *testing.T) { + t.Skip("symmetrizeAdjacencyForPeers 已移除:路由承载的边必须双向存在,不应自动补齐单向边") +} + +func TestFilterAdjacencyForSymmetricLinks_DropOneWay(t *testing.T) { + order := []uint{1, 2} + adj := map[uint][]Edge{ + 1: {{to: 2, latency: 10, upMbps: 50, explicit: true}}, // 单向 + 2: {}, + } + ret := filterAdjacencyForSymmetricLinks(order, adj) + if len(ret[1]) != 0 { + t.Fatalf("want 0 edges for node 1 after symmetric filter, got %d: %#v", len(ret[1]), ret[1]) + } + if _, ok := ret[2]; !ok { + t.Fatalf("node 2 should exist in return map") + } +} + +func TestEnsureRoutingPeerSymmetry_AddReversePeer(t *testing.T) { + // 构造一个“1 直连 2,但 2 到 1 会更偏好走 3”的场景: + // 1->2 成为承载路由的 nextHop,但 2 的路由结果中可能不包含 peer(1),需要对称补齐。 + now := time.Now().Unix() + + priv1, _ := wgtypes.GeneratePrivateKey() + priv2, _ := wgtypes.GeneratePrivateKey() + priv3, _ := wgtypes.GeneratePrivateKey() + + p1 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ + ClientID: "c1", + PrivateKey: priv1.String(), + LocalAddress: "10.0.0.1/32", + }} + p1.ID = 1 + p2 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ + ClientID: "c2", + PrivateKey: priv2.String(), + LocalAddress: "10.0.0.2/32", + }} + p2.ID = 2 + p3 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ + ClientID: "c3", + PrivateKey: priv3.String(), + LocalAddress: "10.0.0.3/32", + }} + p3.ID = 3 + + idToPeer := map[uint]*models.WireGuard{1: p1, 2: p2, 3: p3} + order := []uint{1, 2, 3} + + // 全双向连通,但设置权重让 2->1 更偏好 2->3->1 + adj := map[uint][]Edge{ + 1: { + {to: 2, 
latency: 1, upMbps: 50, explicit: true}, + {to: 3, latency: 100, upMbps: 50, explicit: true}, + }, + 2: { + {to: 1, latency: 100, upMbps: 50, explicit: true}, + {to: 3, latency: 1, upMbps: 50, explicit: true}, + }, + 3: { + {to: 1, latency: 1, upMbps: 50, explicit: true}, + {to: 2, latency: 100, upMbps: 50, explicit: true}, + }, + } + + cache := &fakeTopologyCache{ + rt: map[uint]*pb.WGDeviceRuntimeInfo{ + 1: {Peers: []*pb.WGPeerRuntimeInfo{{ClientId: "c2", LastHandshakeTimeSec: uint64(now)}}}, + 2: {Peers: []*pb.WGPeerRuntimeInfo{{ClientId: "c1", LastHandshakeTimeSec: uint64(now)}}}, + }, + } + + policy := RoutingPolicy{ + LatencyWeight: 1, + InverseBandwidthWeight: 0, + HopWeight: 0, + HandshakeStaleThreshold: 1 * time.Hour, + HandshakeStalePenalty: 0, + NetworkTopologyCache: cache, + } + + aggByNode, edgeInfo := runAllPairsDijkstra(order, adj, idToPeer, policy) + peersMap, err := assemblePeerConfigs(order, aggByNode, edgeInfo, idToPeer) + if err != nil { + t.Fatalf("assemblePeerConfigs err: %v", err) + } + fillIsolates(order, peersMap) + + // 预期:在没有对称补齐前,2 可能不会包含 peer(1) + _ = ensureRoutingPeerSymmetry(order, peersMap, idToPeer) + + found := false + for _, pc := range peersMap[2] { + if pc != nil && pc.GetId() == 1 { + found = true + if len(pc.GetAllowedIps()) == 0 || pc.GetAllowedIps()[0] != "10.0.0.1/32" { + t.Fatalf("peer(1) on node2 should include 10.0.0.1/32, got=%v", pc.GetAllowedIps()) + } + } + } + if !found { + t.Fatalf("node2 should contain peer(1) after ensureRoutingPeerSymmetry") + } +} diff --git a/services/wg/wireguard.go b/services/wg/wireguard.go index aded4b9..11e9f2d 100644 --- a/services/wg/wireguard.go +++ b/services/wg/wireguard.go @@ -4,67 +4,18 @@ package wg import ( - "context" "errors" "fmt" - "net/http" - "net/netip" - "os" - "reflect" - "sync" - "time" - "unsafe" - "github.com/gin-gonic/gin" - "github.com/samber/lo" "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" - "golang.zx2c4.com/wireguard/conn" - 
"golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/tun/netstack" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" - "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/stack" "github.com/VaalaCat/frp-panel/defs" "github.com/VaalaCat/frp-panel/pb" "github.com/VaalaCat/frp-panel/services/app" - "github.com/VaalaCat/frp-panel/services/wg/multibind" - "github.com/VaalaCat/frp-panel/services/wg/transport/ws" "github.com/VaalaCat/frp-panel/utils" ) -const ( - ReportInterval = time.Second * 60 -) - -var ( - _ app.WireGuard = (*wireGuard)(nil) -) - -type wireGuard struct { - sync.RWMutex - - ifce *defs.WireGuardConfig - endpointPingMap *utils.SyncMap[uint32, uint32] // ms - virtAddrPingMap *utils.SyncMap[string, uint32] // ms - - wgDevice *device.Device - tunDevice tun.Device - multiBind *multibind.MultiBind - gvisorNet *netstack.Net - fwManager *firewallManager - - running bool - useGvisorNet bool // if true, use gvisor netstack - - svcLogger *logrus.Entry - ctx *app.Context - cancel context.CancelFunc -} - func NewWireGuard(ctx *app.Context, ifce defs.WireGuardConfig, logger *logrus.Entry) (app.WireGuard, error) { if logger == nil { defaultLog := logrus.New() @@ -86,7 +37,6 @@ func NewWireGuard(ctx *app.Context, ifce defs.WireGuardConfig, logger *logrus.En fwManager := newFirewallManager(logger.WithField("component", "iptables")) return &wireGuard{ - RWMutex: sync.RWMutex{}, ifce: &cfg, ctx: svcCtx, cancel: cancel, @@ -95,6 +45,8 @@ func NewWireGuard(ctx *app.Context, ifce defs.WireGuardConfig, logger *logrus.En useGvisorNet: useGvisorNet, virtAddrPingMap: &utils.SyncMap[string, uint32]{}, fwManager: fwManager, + peerDirectory: make(map[uint32]*pb.WireGuardPeerConfig, 64), + preconnectPeers: make(map[uint32]struct{}, 64), }, nil } @@ -191,6 +143,7 @@ func (w *wireGuard) AddPeer(peer *defs.WireGuardPeerConfig) error { defer w.Unlock() w.ifce.Peers = append(w.ifce.Peers, 
peer.WireGuardPeerConfig) + w.indexPeerDirectoryLocked(peerCfg) uapiBuilder := NewUAPIBuilder().AddPeerConfig(peerCfg) log.Debugf("uapiBuilder: %s", uapiBuilder.Build()) @@ -199,6 +152,9 @@ func (w *wireGuard) AddPeer(peer *defs.WireGuardPeerConfig) error { return errors.Join(errors.New("add peer IpcSet error"), err) } + // 补齐本节点可连接的常驻 peer(若目录里有基础信息) + w.onPeersChangedLocked("after AddPeer") + return nil } @@ -248,23 +204,27 @@ func (w *wireGuard) RemovePeer(peerNameOrPk string) error { w.Lock() defer w.Unlock() - newPeers := []*pb.WireGuardPeerConfig{} - var peerToRemove *defs.WireGuardPeerConfig + // 语义:真正移除 peer(下发 remove=true),并从本地配置中删除。 + // 如需要“只移除路由但保持连接”,请使用 PatchPeers/UpdatePeer 下发 AllowedIPs=nil 的更新策略。 + + var removedPeerPB *pb.WireGuardPeerConfig + newPeers := make([]*pb.WireGuardPeerConfig, 0, len(w.ifce.Peers)) for _, p := range w.ifce.Peers { - if p.ClientId != peerNameOrPk && p.PublicKey != peerNameOrPk { - newPeers = append(newPeers, p) + if p.ClientId == peerNameOrPk || p.PublicKey == peerNameOrPk { + removedPeerPB = p continue } - peerToRemove = &defs.WireGuardPeerConfig{WireGuardPeerConfig: p} + newPeers = append(newPeers, p) } - if len(newPeers) == len(w.ifce.Peers) { + if removedPeerPB == nil { return errors.New("peer not found") } - w.ifce.Peers = newPeers + removedPeer := &defs.WireGuardPeerConfig{WireGuardPeerConfig: removedPeerPB} + log.Debugf("remove peer completely: key=%s pk=%s", truncate(peerNameOrPk, 10), truncate(removedPeer.GetPublicKey(), 10)) - uapiBuilder := NewUAPIBuilder().RemovePeerByKey(peerToRemove.GetParsedPublicKey()) + uapiBuilder := NewUAPIBuilder().RemovePeerByKey(removedPeer.GetParsedPublicKey()) log.Debugf("uapiBuilder: %s", uapiBuilder.Build()) @@ -272,6 +232,22 @@ func (w *wireGuard) RemovePeer(peerNameOrPk string) error { return errors.Join(errors.New("remove peer IpcSet error"), err) } + // IpcSet 成功后再更新本地缓存,避免不一致 + w.ifce.Peers = newPeers + w.deletePeerDirectoryLocked(removedPeer) + if id := removedPeer.GetId(); 
id != 0 { + delete(w.preconnectPeers, id) + } + if removedPeer.GetEndpoint() != nil { + if id := removedPeer.GetEndpoint().GetWireguardId(); id != 0 { + delete(w.preconnectPeers, id) + } + } + w.cleanupPreconnectPeersLocked() + + // 移除后也补齐其他可连接 peer(若目录里有基础信息) + w.onPeersChangedLocked("after RemovePeer") + return nil } @@ -297,6 +273,7 @@ func (w *wireGuard) UpdatePeer(peer *defs.WireGuardPeerConfig) error { } w.ifce.Peers = newPeers + w.indexPeerDirectoryLocked(peerCfg) uapiBuilder := NewUAPIBuilder().UpdatePeerConfig(peerCfg) @@ -306,6 +283,9 @@ func (w *wireGuard) UpdatePeer(peer *defs.WireGuardPeerConfig) error { return errors.Join(errors.New("update peer IpcSet error"), err) } + // 更新后补齐常驻 peer(若目录里有基础信息) + w.onPeersChangedLocked("after UpdatePeer") + return nil } @@ -323,6 +303,24 @@ func (w *wireGuard) PatchPeers(newPeers []*defs.WireGuardPeerConfig) (*app.WireG return nil, err } + // 优化:从 adj 中提取“本节点可直连/可连接”的 peer,确保它们常驻但不分配 AllowedIPs。 + // 这样后续路由变化只需要更新 AllowedIPs,不需要 remove+add 重新建立连接。 + beforeMerge := len(typedNewPeers) + typedNewPeers = mergeConnectablePeersFromAdj(w.ifce, typedNewPeers, oldPeers) + if delta := len(typedNewPeers) - beforeMerge; delta > 0 { + log.Debugf("merged connectable peers from adjs: +%d (desired=%d -> %d)", delta, beforeMerge, len(typedNewPeers)) + } + + // 更新 peerDirectory(把本次看到的 peer 基础信息都缓存起来,后续 adjs 变化可用于补齐常驻 peer) + w.Lock() + for _, p := range typedNewPeers { + if p == nil { + continue + } + w.indexPeerDirectoryLocked(p) + } + w.Unlock() + oldByPK := make(map[string]*defs.WireGuardPeerConfig, len(oldPeers)) for _, p := range oldPeers { if p == nil || p.GetPublicKey() == "" { @@ -398,6 +396,12 @@ func (w *wireGuard) PatchPeers(newPeers []*defs.WireGuardPeerConfig) (*app.WireG } w.ifce.Peers = newPBPeers + // 清理 preconnectPeers 中已不存在的 peer id,避免无限增长 + w.cleanupPreconnectPeersLocked() + + // PatchPeers 后再次补齐(主要覆盖:adjs 先于 peers 变化、且目录已有信息的场景) + w.onPeersChangedLocked("after PatchPeers") + return resp, nil } @@ -456,7 +460,9 @@ 
func (w *wireGuard) UpdateAdjs(adjs map[uint32]*pb.WireGuardLinks) error { defer w.Unlock() w.ifce.Adjs = adjs - return nil + + // adjs 变化后立刻补齐“可连接 peer 常驻”(只要 peerDirectory 中已有其基础信息) + return w.ensureConnectablePeersLocked() } func (w *wireGuard) NeedRecreate(newCfg *defs.WireGuardConfig) bool { @@ -470,310 +476,3 @@ func (w *wireGuard) NeedRecreate(newCfg *defs.WireGuardConfig) bool { w.ifce.GetUseGvisorNet() != newCfg.GetUseGvisorNet() || w.ifce.GetNetworkId() != newCfg.GetNetworkId() } - -func (w *wireGuard) initTransports() error { - log := w.svcLogger.WithField("op", "initTransports") - - wsTrans := ws.NewWSBind(w.ctx) - w.multiBind = multibind.NewMultiBind( - w.svcLogger, - multibind.NewTransport(conn.NewDefaultBind(), "udp"), - multibind.NewTransport(wsTrans, "ws"), - ) - - engine := gin.New() - engine.Any(defs.DefaultWSHandlerPath, func(c *gin.Context) { - err := wsTrans.HandleHTTP(c.Writer, c.Request) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - }) - - // if ws listen port not set, use wg listen port, share tcp and udp port - listenPort := w.ifce.GetWsListenPort() - if listenPort == 0 { - listenPort = w.ifce.GetListenPort() - } - go func() { - if err := engine.Run(fmt.Sprintf(":%d", listenPort)); err != nil { - w.svcLogger.WithError(err).Errorf("failed to run gin engine for ws transport on port %d", listenPort) - } - }() - - log.Infof("WS transport engine running on port %d", listenPort) - - return nil -} - -func (w *wireGuard) initWGDevice() error { - log := w.svcLogger.WithField("op", "initWGDevice") - - log.Debugf("start to create TUN device '%s' (MTU %d)", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu()) - - var err error - - if w.useGvisorNet { - log.Infof("using gvisor netstack for TUN device") - prf, err := netip.ParsePrefix(w.ifce.GetLocalAddress()) - if err != nil { - return errors.Join(fmt.Errorf("parse local addr '%s' for netip", w.ifce.GetLocalAddress()), err) - } - - addrs := 
lo.Map(w.ifce.GetDnsServers(), func(s string, _ int) netip.Addr { - addr, err := netip.ParseAddr(s) - if err != nil { - return netip.Addr{} - } - return addr - }) - if len(addrs) == 0 { - addrs = []netip.Addr{netip.AddrFrom4([4]byte{1, 2, 4, 8})} - } - log.Debugf("create netstack TUN with addr '%s' and dns servers '%v'", prf.Addr().String(), addrs) - w.tunDevice, w.gvisorNet, err = netstack.CreateNetTUN([]netip.Addr{prf.Addr()}, addrs, 1200) - if err != nil { - return errors.Join(fmt.Errorf("create netstack TUN device '%s' (MTU %d) failed", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu()), err) - } - } else { - w.tunDevice, err = tun.CreateTUN(w.ifce.GetInterfaceName(), int(w.ifce.GetInterfaceMtu())) - if err != nil { - return errors.Join(fmt.Errorf("create TUN device '%s' (MTU %d) failed", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu()), err) - } - } - - log.Debugf("TUN device '%s' (MTU %d) created successfully", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu()) - - log.Debugf("start to create WireGuard device '%s'", w.ifce.GetInterfaceName()) - - w.wgDevice = device.NewDevice(w.tunDevice, w.multiBind, &device.Logger{ - Verbosef: w.svcLogger.WithField("wg-dev-iface", w.ifce.GetInterfaceName()).Debugf, - Errorf: w.svcLogger.WithField("wg-dev-iface", w.ifce.GetInterfaceName()).Errorf, - }) - - log.Debugf("WireGuard device '%s' created successfully", w.ifce.GetInterfaceName()) - - return nil -} - -func (w *wireGuard) applyPeerConfig() error { - log := w.svcLogger.WithField("op", "applyConfig") - - log.Debugf("start to apply config to WireGuard device '%s'", w.ifce.GetInterfaceName()) - - if w.wgDevice == nil { - return errors.New("wgDevice is nil, please init WG device first") - } - - wgTypedPeerConfigs, err := parseAndValidatePeerConfigs(w.ifce.GetParsedPeers()) - if err != nil { - return errors.Join(errors.New("parse/validate peers"), err) - } - - log.Debugf("wgTypedPeerConfigs: %v", wgTypedPeerConfigs) - - uapiConfigString := 
generateUAPIConfigString(w.ifce, w.ifce.GetParsedPrivKey(), wgTypedPeerConfigs, !w.running, false) - - log.Debugf("uapiBuilder: %s", uapiConfigString) - - log.Debugf("calling IpcSet...") - if err = w.wgDevice.IpcSet(uapiConfigString); err != nil { - return errors.Join(errors.New("IpcSet error"), err) - } - log.Debugf("IpcSet completed successfully") - - return nil -} - -func (w *wireGuard) initNetwork() error { - log := w.svcLogger.WithField("op", "initNetwork") - - // 等待 TUN 设备在内核中完全注册,避免竞态条件 - var link netlink.Link - var err error - maxRetries := 10 - for i := 0; i < maxRetries; i++ { - link, err = netlink.LinkByName(w.ifce.GetInterfaceName()) - if err == nil { - break - } - if i < maxRetries-1 { - log.Debugf("attempt %d: waiting for iface '%s' to be ready, will retry...", i+1, w.ifce.GetInterfaceName()) - time.Sleep(100 * time.Millisecond) - } - } - if err != nil { - return errors.Join(fmt.Errorf("get iface '%s' via netlink after %d retries", w.ifce.GetInterfaceName(), maxRetries), err) - } - log.Debugf("successfully found interface '%s' via netlink", w.ifce.GetInterfaceName()) - - addr, err := netlink.ParseAddr(w.ifce.GetLocalAddress()) - if err != nil { - return errors.Join(fmt.Errorf("parse local addr '%s' for netlink", w.ifce.GetLocalAddress()), err) - } - - if err = netlink.AddrAdd(link, addr); err != nil && !os.IsExist(err) { - return errors.Join(fmt.Errorf("add IP '%s' to '%s'", w.ifce.GetLocalAddress(), w.ifce.GetInterfaceName()), err) - } else if os.IsExist(err) { - log.Infof("IP %s already on '%s'.", w.ifce.GetLocalAddress(), w.ifce.GetInterfaceName()) - } else { - log.Infof("IP %s added to '%s'.", w.ifce.GetLocalAddress(), w.ifce.GetInterfaceName()) - } - - if err = netlink.LinkSetMTU(link, int(w.ifce.GetInterfaceMtu())); err != nil { - log.Warnf("Set MTU %d on '%s' via netlink: %v. 
TUN MTU is %d.", - w.ifce.GetInterfaceMtu(), w.ifce.GetInterfaceName(), err, w.ifce.GetInterfaceMtu()) - } else { - log.Infof("Iface '%s' MTU %d set via netlink.", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu()) - } - - if err = netlink.LinkSetUp(link); err != nil { - return errors.Join(fmt.Errorf("bring up iface '%s' via netlink", w.ifce.GetInterfaceName()), err) - } - - log.Infof("Iface '%s' up via netlink.", w.ifce.GetInterfaceName()) - return nil -} - -func (w *wireGuard) initGvisorNetwork() error { - log := w.svcLogger.WithField("op", "initGvisorNetwork") - - if w.gvisorNet == nil { - return errors.New("gvisorNet is nil, cannot initialize network") - } - - // wg-go dose not expose the stack field, so we need to use reflection to access it - netValue := reflect.ValueOf(w.gvisorNet).Elem() - stackField := netValue.FieldByName("stack") - - if !stackField.IsValid() { - return errors.New("cannot find stack field in gvisorNet") - } - - stackPtrValue := reflect.NewAt(stackField.Type(), unsafe.Pointer(stackField.UnsafeAddr())).Elem() - if !stackPtrValue.IsValid() || stackPtrValue.IsNil() { - return errors.New("gvisor stack is nil or invalid") - } - - gvisorStack := stackPtrValue.Interface().(*stack.Stack) - if gvisorStack == nil { - return errors.New("gvisor stack is nil after conversion") - } - - log.Infof("successfully accessed gvisor stack, enabling IP forwarding") - - if err := gvisorStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { - log.Warnf("failed to enable IPv4 forwarding: %v, relay may not work", err) - } else { - log.Infof("IPv4 forwarding enabled for gvisor netstack") - } - - if err := gvisorStack.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { - log.Warnf("failed to enable IPv6 forwarding: %v", err) - } else { - log.Infof("IPv6 forwarding enabled for gvisor netstack") - } - - for _, peer := range w.ifce.Peers { - for _, allowedIP := range peer.AllowedIps { - prefix, err := 
netip.ParsePrefix(allowedIP) - if err != nil { - log.WithError(err).Warnf("failed to parse allowed IP: %s", allowedIP) - continue - } - - addr := tcpip.AddrFromSlice(prefix.Addr().AsSlice()) - - ones := prefix.Bits() - maskBytes := make([]byte, len(prefix.Addr().AsSlice())) - for i := 0; i < len(maskBytes); i++ { - if ones >= 8 { - maskBytes[i] = 0xff - ones -= 8 - } else if ones > 0 { - maskBytes[i] = byte(0xff << (8 - ones)) - ones = 0 - } - } - - subnet, err := tcpip.NewSubnet(addr, tcpip.MaskFromBytes(maskBytes)) - if err != nil { - log.WithError(err).Warnf("failed to create subnet for %s", allowedIP) - continue - } - - route := tcpip.Route{ - Destination: subnet, - NIC: 1, - } - - gvisorStack.AddRoute(route) - log.Debugf("added route for peer allowed IP: %s via NIC 1", allowedIP) - } - } - - log.Infof("gvisor netstack initialized with IP forwarding enabled") - return nil -} - -func (w *wireGuard) cleanupNetwork() { - log := w.svcLogger.WithField("op", "cleanupNetwork") - - if w.useGvisorNet { - log.Infof("skip network cleanup for gvisor netstack") - return - } - - link, err := netlink.LinkByName(w.ifce.GetInterfaceName()) - if err == nil { - if err := netlink.LinkSetDown(link); err != nil { - log.Warnf("Failed to LinkSetDown '%s' after wgDevice.Up() error: %v", w.ifce.GetInterfaceName(), err) - } - } - log.Debug("Cleanup network complete.") -} - -func (w *wireGuard) cleanupWGDevice() { - log := w.svcLogger.WithField("op", "cleanupWGDevice") - - if w.wgDevice != nil { - w.wgDevice.Close() - } else if w.tunDevice != nil { - w.tunDevice.Close() - } - w.wgDevice = nil - w.tunDevice = nil - log.Debug("Cleanup WG device complete.") -} - -func (w *wireGuard) applyFirewallRulesLocked() error { - if w.useGvisorNet || w.fwManager == nil { - return nil - } - - prefix, err := netip.ParsePrefix(w.ifce.GetLocalAddress()) - if err != nil { - return errors.Join(fmt.Errorf("parse local address '%s' for firewall", w.ifce.GetLocalAddress()), err) - } - - return 
w.fwManager.ApplyRelayRules(w.ifce.GetInterfaceName(), prefix.Masked().String()) -} - -func (w *wireGuard) cleanupFirewallRulesLocked() error { - if w.useGvisorNet || w.fwManager == nil { - return nil - } - return w.fwManager.Cleanup(w.ifce.GetInterfaceName()) -} - -func (w *wireGuard) reportStatusTask() { - for { - select { - case <-w.ctx.Done(): - return - default: - w.pingPeers() - time.Sleep(ReportInterval) - } - } -} diff --git a/services/wg/wireguard_device.go b/services/wg/wireguard_device.go new file mode 100644 index 0000000..c0f50a2 --- /dev/null +++ b/services/wg/wireguard_device.go @@ -0,0 +1,108 @@ +//go:build !windows +// +build !windows + +package wg + +import ( + "errors" + "fmt" + "net/netip" + + "github.com/samber/lo" + + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/tun/netstack" +) + +func (w *wireGuard) initWGDevice() error { + log := w.svcLogger.WithField("op", "initWGDevice") + + log.Debugf("start to create TUN device '%s' (MTU %d)", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu()) + + var err error + + if w.useGvisorNet { + log.Infof("using gvisor netstack for TUN device") + prf, err := netip.ParsePrefix(w.ifce.GetLocalAddress()) + if err != nil { + return errors.Join(fmt.Errorf("parse local addr '%s' for netip", w.ifce.GetLocalAddress()), err) + } + + addrs := lo.Map(w.ifce.GetDnsServers(), func(s string, _ int) netip.Addr { + addr, err := netip.ParseAddr(s) + if err != nil { + return netip.Addr{} + } + return addr + }) + if len(addrs) == 0 { + addrs = []netip.Addr{netip.AddrFrom4([4]byte{1, 2, 4, 8})} + } + log.Debugf("create netstack TUN with addr '%s' and dns servers '%v'", prf.Addr().String(), addrs) + w.tunDevice, w.gvisorNet, err = netstack.CreateNetTUN([]netip.Addr{prf.Addr()}, addrs, 1200) + if err != nil { + return errors.Join(fmt.Errorf("create netstack TUN device '%s' (MTU %d) failed", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu()), err) + } + } else { + 
w.tunDevice, err = tun.CreateTUN(w.ifce.GetInterfaceName(), int(w.ifce.GetInterfaceMtu())) + if err != nil { + return errors.Join(fmt.Errorf("create TUN device '%s' (MTU %d) failed", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu()), err) + } + } + + log.Debugf("TUN device '%s' (MTU %d) created successfully", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu()) + + log.Debugf("start to create WireGuard device '%s'", w.ifce.GetInterfaceName()) + + w.wgDevice = device.NewDevice(w.tunDevice, w.multiBind, &device.Logger{ + Verbosef: w.svcLogger.WithField("wg-dev-iface", w.ifce.GetInterfaceName()).Debugf, + Errorf: w.svcLogger.WithField("wg-dev-iface", w.ifce.GetInterfaceName()).Errorf, + }) + + log.Debugf("WireGuard device '%s' created successfully", w.ifce.GetInterfaceName()) + + return nil +} + +func (w *wireGuard) applyPeerConfig() error { + log := w.svcLogger.WithField("op", "applyConfig") + + log.Debugf("start to apply config to WireGuard device '%s'", w.ifce.GetInterfaceName()) + + if w.wgDevice == nil { + return errors.New("wgDevice is nil, please init WG device first") + } + + wgTypedPeerConfigs, err := parseAndValidatePeerConfigs(w.ifce.GetParsedPeers()) + if err != nil { + return errors.Join(errors.New("parse/validate peers"), err) + } + + log.Debugf("wgTypedPeerConfigs: %v", wgTypedPeerConfigs) + + uapiConfigString := generateUAPIConfigString(w.ifce, w.ifce.GetParsedPrivKey(), wgTypedPeerConfigs, !w.running, false) + + log.Debugf("uapiBuilder: %s", uapiConfigString) + + log.Debugf("calling IpcSet...") + if err = w.wgDevice.IpcSet(uapiConfigString); err != nil { + return errors.Join(errors.New("IpcSet error"), err) + } + log.Debugf("IpcSet completed successfully") + + return nil +} + +func (w *wireGuard) cleanupWGDevice() { + log := w.svcLogger.WithField("op", "cleanupWGDevice") + + if w.wgDevice != nil { + w.wgDevice.Close() + } else if w.tunDevice != nil { + w.tunDevice.Close() + } + w.wgDevice = nil + w.tunDevice = nil + log.Debug("Cleanup WG device 
complete.") +} diff --git a/services/wg/wireguard_firewall_ops.go b/services/wg/wireguard_firewall_ops.go new file mode 100644 index 0000000..a7e97d8 --- /dev/null +++ b/services/wg/wireguard_firewall_ops.go @@ -0,0 +1,30 @@ +//go:build !windows +// +build !windows + +package wg + +import ( + "errors" + "fmt" + "net/netip" +) + +func (w *wireGuard) applyFirewallRulesLocked() error { + if w.useGvisorNet || w.fwManager == nil { + return nil + } + + prefix, err := netip.ParsePrefix(w.ifce.GetLocalAddress()) + if err != nil { + return errors.Join(fmt.Errorf("parse local address '%s' for firewall", w.ifce.GetLocalAddress()), err) + } + + return w.fwManager.ApplyRelayRules(w.ifce.GetInterfaceName(), prefix.Masked().String()) +} + +func (w *wireGuard) cleanupFirewallRulesLocked() error { + if w.useGvisorNet || w.fwManager == nil { + return nil + } + return w.fwManager.Cleanup(w.ifce.GetInterfaceName()) +} diff --git a/services/wg/wireguard_network_gvisor.go b/services/wg/wireguard_network_gvisor.go new file mode 100644 index 0000000..b758759 --- /dev/null +++ b/services/wg/wireguard_network_gvisor.go @@ -0,0 +1,97 @@ +//go:build !windows +// +build !windows + +package wg + +import ( + "errors" + "net/netip" + "reflect" + "unsafe" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +func (w *wireGuard) initGvisorNetwork() error { + log := w.svcLogger.WithField("op", "initGvisorNetwork") + + if w.gvisorNet == nil { + return errors.New("gvisorNet is nil, cannot initialize network") + } + + // wg-go dose not expose the stack field, so we need to use reflection to access it + netValue := reflect.ValueOf(w.gvisorNet).Elem() + stackField := netValue.FieldByName("stack") + + if !stackField.IsValid() { + return errors.New("cannot find stack field in gvisorNet") + } + + stackPtrValue := reflect.NewAt(stackField.Type(), unsafe.Pointer(stackField.UnsafeAddr())).Elem() + 
if !stackPtrValue.IsValid() || stackPtrValue.IsNil() { + return errors.New("gvisor stack is nil or invalid") + } + + gvisorStack := stackPtrValue.Interface().(*stack.Stack) + if gvisorStack == nil { + return errors.New("gvisor stack is nil after conversion") + } + + log.Infof("successfully accessed gvisor stack, enabling IP forwarding") + + if err := gvisorStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + log.Warnf("failed to enable IPv4 forwarding: %v, relay may not work", err) + } else { + log.Infof("IPv4 forwarding enabled for gvisor netstack") + } + + if err := gvisorStack.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + log.Warnf("failed to enable IPv6 forwarding: %v", err) + } else { + log.Infof("IPv6 forwarding enabled for gvisor netstack") + } + + for _, peer := range w.ifce.Peers { + for _, allowedIP := range peer.AllowedIps { + prefix, err := netip.ParsePrefix(allowedIP) + if err != nil { + log.WithError(err).Warnf("failed to parse allowed IP: %s", allowedIP) + continue + } + + addr := tcpip.AddrFromSlice(prefix.Addr().AsSlice()) + + ones := prefix.Bits() + maskBytes := make([]byte, len(prefix.Addr().AsSlice())) + for i := 0; i < len(maskBytes); i++ { + if ones >= 8 { + maskBytes[i] = 0xff + ones -= 8 + } else if ones > 0 { + maskBytes[i] = byte(0xff << (8 - ones)) + ones = 0 + } + } + + subnet, err := tcpip.NewSubnet(addr, tcpip.MaskFromBytes(maskBytes)) + if err != nil { + log.WithError(err).Warnf("failed to create subnet for %s", allowedIP) + continue + } + + route := tcpip.Route{ + Destination: subnet, + NIC: 1, + } + + gvisorStack.AddRoute(route) + log.Debugf("added route for peer allowed IP: %s via NIC 1", allowedIP) + } + } + + log.Infof("gvisor netstack initialized with IP forwarding enabled") + return nil +} diff --git a/services/wg/wireguard_network_netlink.go b/services/wg/wireguard_network_netlink.go new file mode 100644 index 0000000..7b732b4 --- /dev/null +++ 
b/services/wg/wireguard_network_netlink.go @@ -0,0 +1,80 @@ +//go:build !windows +// +build !windows + +package wg + +import ( + "errors" + "fmt" + "os" + "time" + + "github.com/vishvananda/netlink" +) + +func (w *wireGuard) initNetwork() error { + log := w.svcLogger.WithField("op", "initNetwork") + + // 等待 TUN 设备在内核中完全注册,避免竞态条件 + var link netlink.Link + var err error + maxRetries := 10 + for i := 0; i < maxRetries; i++ { + link, err = netlink.LinkByName(w.ifce.GetInterfaceName()) + if err == nil { + break + } + if i < maxRetries-1 { + log.Debugf("attempt %d: waiting for iface '%s' to be ready, will retry...", i+1, w.ifce.GetInterfaceName()) + time.Sleep(100 * time.Millisecond) + } + } + if err != nil { + return errors.Join(fmt.Errorf("get iface '%s' via netlink after %d retries", w.ifce.GetInterfaceName(), maxRetries), err) + } + log.Debugf("successfully found interface '%s' via netlink", w.ifce.GetInterfaceName()) + + addr, err := netlink.ParseAddr(w.ifce.GetLocalAddress()) + if err != nil { + return errors.Join(fmt.Errorf("parse local addr '%s' for netlink", w.ifce.GetLocalAddress()), err) + } + + if err = netlink.AddrAdd(link, addr); err != nil && !os.IsExist(err) { + return errors.Join(fmt.Errorf("add IP '%s' to '%s'", w.ifce.GetLocalAddress(), w.ifce.GetInterfaceName()), err) + } else if os.IsExist(err) { + log.Infof("IP %s already on '%s'.", w.ifce.GetLocalAddress(), w.ifce.GetInterfaceName()) + } else { + log.Infof("IP %s added to '%s'.", w.ifce.GetLocalAddress(), w.ifce.GetInterfaceName()) + } + + if err = netlink.LinkSetMTU(link, int(w.ifce.GetInterfaceMtu())); err != nil { + log.Warnf("Set MTU %d on '%s' via netlink: %v. 
TUN MTU is %d.", + w.ifce.GetInterfaceMtu(), w.ifce.GetInterfaceName(), err, w.ifce.GetInterfaceMtu()) + } else { + log.Infof("Iface '%s' MTU %d set via netlink.", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu()) + } + + if err = netlink.LinkSetUp(link); err != nil { + return errors.Join(fmt.Errorf("bring up iface '%s' via netlink", w.ifce.GetInterfaceName()), err) + } + + log.Infof("Iface '%s' up via netlink.", w.ifce.GetInterfaceName()) + return nil +} + +func (w *wireGuard) cleanupNetwork() { + log := w.svcLogger.WithField("op", "cleanupNetwork") + + if w.useGvisorNet { + log.Infof("skip network cleanup for gvisor netstack") + return + } + + link, err := netlink.LinkByName(w.ifce.GetInterfaceName()) + if err == nil { + if err := netlink.LinkSetDown(link); err != nil { + log.Warnf("Failed to LinkSetDown '%s' after wgDevice.Up() error: %v", w.ifce.GetInterfaceName(), err) + } + } + log.Debug("Cleanup network complete.") +} diff --git a/services/wg/wireguard_patchpeers_preconnect_test.go b/services/wg/wireguard_patchpeers_preconnect_test.go new file mode 100644 index 0000000..7bd712a --- /dev/null +++ b/services/wg/wireguard_patchpeers_preconnect_test.go @@ -0,0 +1,133 @@ +package wg + +import ( + "testing" + + "github.com/VaalaCat/frp-panel/defs" + "github.com/VaalaCat/frp-panel/pb" +) + +func TestMergeConnectablePeersFromAdj_AddsMissingPeerWithEmptyAllowedIPs(t *testing.T) { + ifce := &defs.WireGuardConfig{WireGuardConfig: &pb.WireGuardConfig{ + Id: 1, + Adjs: map[uint32]*pb.WireGuardLinks{ + 1: {Links: []*pb.WireGuardLink{ + {ToWireguardId: 2, ToEndpoint: &pb.Endpoint{Host: "1.2.3.4", Port: 51820}}, + }}, + }, + }} + + desired := []*defs.WireGuardPeerConfig{ + {WireGuardPeerConfig: &pb.WireGuardPeerConfig{ + Id: 3, + PublicKey: "pk-3", + AllowedIps: []string{ + "10.0.0.3/32", + }, + }}, + } + + known := []*defs.WireGuardPeerConfig{ + {WireGuardPeerConfig: &pb.WireGuardPeerConfig{ + Id: 2, + PublicKey: "pk-2", + AllowedIps: []string{ + "10.0.0.2/32", + }, + 
PersistentKeepalive: 0, // should be defaulted by parseAndValidatePeerConfig + }}, + } + + got := mergeConnectablePeersFromAdj(ifce, desired, known) + + var found *defs.WireGuardPeerConfig + for _, p := range got { + if p != nil && p.GetId() == 2 { + found = p + break + } + } + if found == nil { + t.Fatalf("expected peer id=2 to be added") + } + if len(found.GetAllowedIps()) != 0 { + t.Fatalf("expected allowed_ips empty, got=%v", found.GetAllowedIps()) + } + if found.GetEndpoint() == nil || found.GetEndpoint().GetHost() != "1.2.3.4" || found.GetEndpoint().GetPort() != 51820 { + t.Fatalf("expected endpoint from adj to be applied, got=%v", found.GetEndpoint()) + } + if found.GetPersistentKeepalive() == 0 { + t.Fatalf("expected persistent_keepalive defaulted, got=0") + } +} + +func TestMergeConnectablePeersFromAdj_DoesNotOverrideExistingDesiredPeer(t *testing.T) { + ifce := &defs.WireGuardConfig{WireGuardConfig: &pb.WireGuardConfig{ + Id: 1, + Adjs: map[uint32]*pb.WireGuardLinks{ + 1: {Links: []*pb.WireGuardLink{ + {ToWireguardId: 2, ToEndpoint: &pb.Endpoint{Host: "9.9.9.9", Port: 9999}}, + }}, + }, + }} + + desired := []*defs.WireGuardPeerConfig{ + {WireGuardPeerConfig: &pb.WireGuardPeerConfig{ + Id: 2, + PublicKey: "pk-2", + AllowedIps: []string{ + "10.0.0.2/32", + }, + Endpoint: &pb.Endpoint{Host: "1.1.1.1", Port: 1111}, + }}, + } + + got := mergeConnectablePeersFromAdj(ifce, desired, nil) + + if len(got) != 1 { + t.Fatalf("expected no extra peers added, got len=%d", len(got)) + } + if len(got[0].GetAllowedIps()) != 1 || got[0].GetAllowedIps()[0] != "10.0.0.2/32" { + t.Fatalf("expected desired allowed_ips preserved, got=%v", got[0].GetAllowedIps()) + } + // endpoint should not be overridden because peer already existed in desired list + if got[0].GetEndpoint() == nil || got[0].GetEndpoint().GetHost() != "1.1.1.1" { + t.Fatalf("expected desired endpoint preserved, got=%v", got[0].GetEndpoint()) + } +} + +func 
TestMergeConnectablePeersFromAdj_UsesEndpointWireguardIDWhenPeerIDMissing(t *testing.T) { + ifce := &defs.WireGuardConfig{WireGuardConfig: &pb.WireGuardConfig{ + Id: 1, + Adjs: map[uint32]*pb.WireGuardLinks{ + 1: {Links: []*pb.WireGuardLink{ + {ToWireguardId: 2, ToEndpoint: &pb.Endpoint{Host: "2.2.2.2", Port: 2222}}, + }}, + }, + }} + + desired := []*defs.WireGuardPeerConfig{} + + known := []*defs.WireGuardPeerConfig{ + {WireGuardPeerConfig: &pb.WireGuardPeerConfig{ + Id: 0, // 模拟:peer.id 未下发 + PublicKey: "pk-2", + Endpoint: &pb.Endpoint{WireguardId: 2, Host: "old", Port: 1}, + }}, + } + + got := mergeConnectablePeersFromAdj(ifce, desired, known) + if len(got) != 1 { + t.Fatalf("expected 1 peer added, got len=%d", len(got)) + } + if got[0].GetPublicKey() != "pk-2" { + t.Fatalf("expected pk-2, got=%s", got[0].GetPublicKey()) + } + // endpoint should be overridden by adj's to_endpoint + if got[0].GetEndpoint() == nil || got[0].GetEndpoint().GetHost() != "2.2.2.2" || got[0].GetEndpoint().GetPort() != 2222 { + t.Fatalf("expected endpoint from adj, got=%v", got[0].GetEndpoint()) + } + if len(got[0].GetAllowedIps()) != 0 { + t.Fatalf("expected allowed_ips empty, got=%v", got[0].GetAllowedIps()) + } +} diff --git a/services/wg/wireguard_peers_preconnect.go b/services/wg/wireguard_peers_preconnect.go new file mode 100644 index 0000000..fcee60f --- /dev/null +++ b/services/wg/wireguard_peers_preconnect.go @@ -0,0 +1,362 @@ +//go:build !windows +// +build !windows + +package wg + +import ( + "fmt" + + "google.golang.org/protobuf/proto" + + "github.com/VaalaCat/frp-panel/defs" + "github.com/VaalaCat/frp-panel/pb" +) + +// peer 预连接/常驻补齐逻辑 +// +// 目标:在拓扑(adjs)变化时,确保本节点对“可直连/可连接”的 peer 已配置到 wg 设备, +// 但 AllowedIPs 为空(只保持连接,不承载路由),从而避免路由变化导致频繁 remove+add 造成断链。 + +func (w *wireGuard) onPeersChangedLocked(reason string) { + if w == nil { + return + } + if err := w.ensureConnectablePeersLocked(); err != nil { + w.svcLogger.WithError(err).WithField("op", 
"onPeersChanged").Warnf("ensure connectable peers failed (%s)", reason) + } +} + +func (w *wireGuard) cleanupPreconnectPeersLocked() { + if w == nil { + return + } + if len(w.preconnectPeers) == 0 { + return + } + exists := make(map[uint32]struct{}, len(w.ifce.Peers)) + for _, p := range w.ifce.GetParsedPeers() { + if p == nil { + continue + } + if id := p.GetId(); id != 0 { + exists[id] = struct{}{} + } + if p.GetEndpoint() != nil { + if id := p.GetEndpoint().GetWireguardId(); id != 0 { + exists[id] = struct{}{} + } + } + } + for id := range w.preconnectPeers { + if _, ok := exists[id]; !ok { + delete(w.preconnectPeers, id) + } + } +} + +func (w *wireGuard) indexPeerDirectoryLocked(p *defs.WireGuardPeerConfig) { + if p == nil || p.WireGuardPeerConfig == nil { + return + } + // 1) peer.id + if id := p.GetId(); id != 0 { + w.peerDirectory[id] = proto.Clone(p.WireGuardPeerConfig).(*pb.WireGuardPeerConfig) + } + // 2) endpoint.wireguard_id(部分场景 peer.id 可能未填,但 endpoint 带 wireguard_id) + if p.GetEndpoint() != nil { + if id := p.GetEndpoint().GetWireguardId(); id != 0 { + w.peerDirectory[id] = proto.Clone(p.WireGuardPeerConfig).(*pb.WireGuardPeerConfig) + } + } +} + +func (w *wireGuard) deletePeerDirectoryLocked(p *defs.WireGuardPeerConfig) { + if p == nil { + return + } + if id := p.GetId(); id != 0 { + delete(w.peerDirectory, id) + } + if p.GetEndpoint() != nil { + if id := p.GetEndpoint().GetWireguardId(); id != 0 { + delete(w.peerDirectory, id) + } + } +} + +// ensureConnectablePeersLocked 保证本节点在 wg 设备里已配置所有“当前可直连/可连接”的 peer, +// 但这些补齐 peer 的 AllowedIPs 为空(只保持连接,不承载路由)。 +// +// 约束:必须在持有 w.Lock() 的情况下调用。 +func (w *wireGuard) ensureConnectablePeersLocked() error { + if w == nil || w.ifce == nil || w.wgDevice == nil { + return nil + } + localID := w.ifce.GetId() + if localID == 0 { + return nil + } + adjs := w.ifce.GetAdjs() + if adjs == nil { + return nil + } + localLinks, ok := adjs[localID] + if !ok || localLinks == nil || len(localLinks.GetLinks()) == 0 { + return 
nil + } + + // 当前可直连/可连接的 peer id 集合(来自 adj) + connectable := make(map[uint32]struct{}, len(localLinks.GetLinks())) + for _, l := range localLinks.GetLinks() { + if l == nil { + continue + } + toID := l.GetToWireguardId() + if toID == 0 || toID == localID { + continue + } + connectable[toID] = struct{}{} + } + + log := w.svcLogger.WithField("op", "ensureConnectablePeers") + log.Debugf("ensure connectable peers: local=%d connectable=%d peers=%d preconnect=%d directory=%d", + localID, len(connectable), len(w.ifce.Peers), len(w.preconnectPeers), len(w.peerDirectory)) + + // 当前已配置的 peer:用 peer.id 与 endpoint.wireguard_id 双索引,避免 peer.id 缺失导致重复补齐 + exists := make(map[uint32]struct{}, len(w.ifce.Peers)) + for _, p := range w.ifce.GetParsedPeers() { + if p == nil { + continue + } + if id := p.GetId(); id != 0 { + exists[id] = struct{}{} + } + if p.GetEndpoint() != nil { + if id := p.GetEndpoint().GetWireguardId(); id != 0 { + exists[id] = struct{}{} + } + } + } + + uapiBuilder := NewUAPIBuilder() + added := 0 + removed := 0 + skippedNoBase := 0 + skippedAlready := 0 + + // 先清理:相比上次,本次拓扑中已“完全不可直连”的 peer,需要彻底从设备移除 + // 仅清理“AllowedIPs 为空”的 peer(也就是不承载路由、只为保持连接而存在的 peer) + newPeers := make([]*pb.WireGuardPeerConfig, 0, len(w.ifce.Peers)) + for _, raw := range w.ifce.GetParsedPeers() { + if raw == nil || raw.WireGuardPeerConfig == nil { + continue + } + // 仅对 AllowedIPs 为空的 peer 做自动清理 + if len(raw.GetAllowedIps()) != 0 { + newPeers = append(newPeers, raw.WireGuardPeerConfig) + continue + } + + var peerID uint32 + if raw.GetId() != 0 { + peerID = raw.GetId() + } else if raw.GetEndpoint() != nil && raw.GetEndpoint().GetWireguardId() != 0 { + peerID = raw.GetEndpoint().GetWireguardId() + } + + // 无法识别 peer id:保守起见不清理 + if peerID == 0 { + newPeers = append(newPeers, raw.WireGuardPeerConfig) + continue + } + + // 当前拓扑不可直连:彻底移除 + if _, ok := connectable[peerID]; !ok { + log.Debugf("preconnect remove: peerID=%d pk=%s (reason=not_connectable)", peerID, truncate(raw.GetPublicKey(), 10)) 
+ uapiBuilder.RemovePeerByKey(raw.GetParsedPublicKey()) + delete(w.preconnectPeers, peerID) + delete(exists, peerID) + removed++ + continue + } + + newPeers = append(newPeers, raw.WireGuardPeerConfig) + } + // 如果有清理发生,先更新本地缓存(设备更新在最后统一 IpcSet) + if removed > 0 { + w.ifce.Peers = newPeers + } + + for _, l := range localLinks.GetLinks() { + if l == nil { + continue + } + toID := l.GetToWireguardId() + if toID == 0 || toID == localID { + continue + } + if _, ok := exists[toID]; ok { + skippedAlready++ + continue + } + + base, ok := w.peerDirectory[toID] + if !ok || base == nil || base.GetPublicKey() == "" { + skippedNoBase++ + continue + } + cloned := &defs.WireGuardPeerConfig{WireGuardPeerConfig: proto.Clone(base).(*pb.WireGuardPeerConfig)} + cloned.AllowedIps = nil + if l.GetToEndpoint() != nil { + cloned.Endpoint = l.GetToEndpoint() + } + if _, err := parseAndValidatePeerConfig(cloned); err != nil { + continue + } + + log.Debugf("preconnect add: peerID=%d pk=%s endpoint=%s", + toID, truncate(cloned.GetPublicKey(), 10), endpointForLog(cloned.GetEndpoint())) + uapiBuilder.AddPeerConfig(cloned) + w.ifce.Peers = append(w.ifce.Peers, cloned.WireGuardPeerConfig) + w.indexPeerDirectoryLocked(cloned) + exists[toID] = struct{}{} + w.preconnectPeers[toID] = struct{}{} + added++ + } + + if added == 0 && removed == 0 { + log.Debugf("ensure result: no-op (skippedAlready=%d skippedNoBase=%d)", skippedAlready, skippedNoBase) + return nil + } + log.Debugf("ensure result: add=%d remove=%d skippedAlready=%d skippedNoBase=%d", added, removed, skippedAlready, skippedNoBase) + if err := w.wgDevice.IpcSet(uapiBuilder.Build()); err != nil { + log.WithError(err).Debugf("ensure IpcSet failed (add=%d remove=%d)", added, removed) + return err + } + return nil +} + +// mergeConnectablePeersFromAdj 将本节点 adj 图中可直连的 peer 合并进目标 peers。 +// +// - **只在目标 peers 中缺失时才补齐**(按 PublicKey 去重),避免覆盖由路由规划器计算出的 AllowedIPs。 +// - **补齐的 peer AllowedIPs 置空**,确保不会引入额外路由。 +// - 如链路显式携带 to_endpoint,则优先用它覆盖 
peer.endpoint(用于快速恢复直连)。 +// +// knownPeers 用于在目标 peers 缺失时提供“可用的 peer 基础信息”(公钥/预共享密钥/端点等),通常传 oldPeers。 +func mergeConnectablePeersFromAdj(ifce *defs.WireGuardConfig, desiredPeers []*defs.WireGuardPeerConfig, knownPeers []*defs.WireGuardPeerConfig) []*defs.WireGuardPeerConfig { + if ifce == nil { + return desiredPeers + } + localID := ifce.GetId() + if localID == 0 { + return desiredPeers + } + adjs := ifce.GetAdjs() + if adjs == nil { + return desiredPeers + } + localLinks, ok := adjs[localID] + if !ok || localLinks == nil || len(localLinks.GetLinks()) == 0 { + return desiredPeers + } + + desiredByPK := make(map[string]*defs.WireGuardPeerConfig, len(desiredPeers)) + for _, p := range desiredPeers { + if p == nil || p.GetPublicKey() == "" { + continue + } + desiredByPK[p.GetPublicKey()] = p + } + + // build id -> peer 基础信息索引(优先 desired,其次 known) + idToPeer := make(map[uint32]*defs.WireGuardPeerConfig, len(desiredPeers)+len(knownPeers)) + putPeerIDs := func(p *defs.WireGuardPeerConfig) { + if p == nil { + return + } + // 1) peer.id + if id := p.GetId(); id != 0 { + if _, exists := idToPeer[id]; !exists { + idToPeer[id] = p + } + } + // 2) endpoint.wireguard_id(有些下发场景可能不填 peer.id,但 endpoint 里带 wireguard_id) + if p.GetEndpoint() != nil { + if id := p.GetEndpoint().GetWireguardId(); id != 0 { + if _, exists := idToPeer[id]; !exists { + idToPeer[id] = p + } + } + } + } + for _, p := range desiredPeers { + putPeerIDs(p) + } + for _, p := range knownPeers { + putPeerIDs(p) + } + + // 仅补齐:adj 中的直连节点(to_wireguard_id) + for _, l := range localLinks.GetLinks() { + if l == nil { + continue + } + toID := l.GetToWireguardId() + if toID == 0 || toID == localID { + continue + } + + base, ok := idToPeer[toID] + if !ok || base == nil || base.GetPublicKey() == "" { + continue + } + if _, exists := desiredByPK[base.GetPublicKey()]; exists { + // 已在目标列表中(通常含有路由规划器计算出的 AllowedIPs),不覆盖。 + continue + } + + // 复制一份(避免直接改 oldPeers / knownPeers 的底层 pb 指针) + cloned := clonePeerConfig(base) + 
// 不分配路由:AllowedIPs 置空 + cloned.AllowedIps = nil + // 显式链路 endpoint 优先 + if l.GetToEndpoint() != nil { + cloned.Endpoint = l.GetToEndpoint() + } + // 确保 keepalive/AllowedIPs 格式一致(复用现有校验逻辑) + if _, err := parseAndValidatePeerConfig(cloned); err != nil { + continue + } + + desiredPeers = append(desiredPeers, cloned) + desiredByPK[cloned.GetPublicKey()] = cloned + } + + return desiredPeers +} + +func clonePeerConfig(p *defs.WireGuardPeerConfig) *defs.WireGuardPeerConfig { + if p == nil || p.WireGuardPeerConfig == nil { + return &defs.WireGuardPeerConfig{} + } + + // 使用 proto.Clone 避免直接拷贝 protoimpl.MessageState(内部含 mutex,会触发拷贝锁值的告警) + cp, _ := proto.Clone(p.WireGuardPeerConfig).(*pb.WireGuardPeerConfig) + if cp == nil { + return &defs.WireGuardPeerConfig{} + } + return &defs.WireGuardPeerConfig{WireGuardPeerConfig: cp} +} + +func endpointForLog(ep *pb.Endpoint) string { + if ep == nil { + return "" + } + if ep.GetUri() != "" { + return ep.GetUri() + } + if ep.GetHost() != "" || ep.GetPort() != 0 { + return fmt.Sprintf("%s:%d", ep.GetHost(), ep.GetPort()) + } + return "" +} diff --git a/services/wg/wireguard_report.go b/services/wg/wireguard_report.go new file mode 100644 index 0000000..7395945 --- /dev/null +++ b/services/wg/wireguard_report.go @@ -0,0 +1,22 @@ +//go:build !windows +// +build !windows + +package wg + +import "time" + +const ( + ReportInterval = time.Second * 60 +) + +func (w *wireGuard) reportStatusTask() { + for { + select { + case <-w.ctx.Done(): + return + default: + w.pingPeers() + time.Sleep(ReportInterval) + } + } +} diff --git a/services/wg/wireguard_transport.go b/services/wg/wireguard_transport.go new file mode 100644 index 0000000..3230064 --- /dev/null +++ b/services/wg/wireguard_transport.go @@ -0,0 +1,52 @@ +//go:build !windows +// +build !windows + +package wg + +import ( + "fmt" + "net/http" + + "github.com/gin-gonic/gin" + + "github.com/VaalaCat/frp-panel/defs" + "github.com/VaalaCat/frp-panel/services/wg/multibind" + 
"github.com/VaalaCat/frp-panel/services/wg/transport/ws" + + "golang.zx2c4.com/wireguard/conn" +) + +func (w *wireGuard) initTransports() error { + log := w.svcLogger.WithField("op", "initTransports") + + wsTrans := ws.NewWSBind(w.ctx) + w.multiBind = multibind.NewMultiBind( + w.svcLogger, + multibind.NewTransport(conn.NewDefaultBind(), "udp"), + multibind.NewTransport(wsTrans, "ws"), + ) + + engine := gin.New() + engine.Any(defs.DefaultWSHandlerPath, func(c *gin.Context) { + err := wsTrans.HandleHTTP(c.Writer, c.Request) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + }) + + // if ws listen port not set, use wg listen port, share tcp and udp port + listenPort := w.ifce.GetWsListenPort() + if listenPort == 0 { + listenPort = w.ifce.GetListenPort() + } + go func() { + if err := engine.Run(fmt.Sprintf(":%d", listenPort)); err != nil { + w.svcLogger.WithError(err).Errorf("failed to run gin engine for ws transport on port %d", listenPort) + } + }() + + log.Infof("WS transport engine running on port %d", listenPort) + + return nil +} diff --git a/services/wg/wireguard_types.go b/services/wg/wireguard_types.go new file mode 100644 index 0000000..6352efa --- /dev/null +++ b/services/wg/wireguard_types.go @@ -0,0 +1,48 @@ +//go:build !windows +// +build !windows + +package wg + +import ( + "context" + "sync" + + "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/tun/netstack" + + "github.com/VaalaCat/frp-panel/defs" + "github.com/VaalaCat/frp-panel/pb" + "github.com/VaalaCat/frp-panel/services/app" + "github.com/VaalaCat/frp-panel/services/wg/multibind" + "github.com/VaalaCat/frp-panel/utils" +) + +var ( + _ app.WireGuard = (*wireGuard)(nil) +) + +type wireGuard struct { + sync.RWMutex + + ifce *defs.WireGuardConfig + endpointPingMap *utils.SyncMap[uint32, uint32] // ms + virtAddrPingMap *utils.SyncMap[string, uint32] // ms + peerDirectory 
map[uint32]*pb.WireGuardPeerConfig + // 仅用于“预连接/保持连接”的 peer(AllowedIPs 为空),用于后续根据拓扑变化做增删 + preconnectPeers map[uint32]struct{} + + wgDevice *device.Device + tunDevice tun.Device + multiBind *multibind.MultiBind + gvisorNet *netstack.Net + fwManager *firewallManager + + running bool + useGvisorNet bool // if true, use gvisor netstack + + svcLogger *logrus.Entry + ctx *app.Context + cancel context.CancelFunc +} diff --git a/utils/load_test.go b/utils/load_test.go index 1325a67..aaa4198 100644 --- a/utils/load_test.go +++ b/utils/load_test.go @@ -19,5 +19,5 @@ remotePort = 6000`) if err := LoadConfigureFromContent(content, allCfg, true); err != nil { t.Error(err) } - t.Errorf("%+v", allCfg) + t.Logf("%+v", allCfg) } diff --git a/utils/net_test.go b/utils/net_test.go index f6145f6..51cc7bb 100644 --- a/utils/net_test.go +++ b/utils/net_test.go @@ -11,5 +11,5 @@ func TestAllocateIP(t *testing.T) { if err != nil { t.Errorf("AllocateIP() failed: %v", err) } - t.Errorf("AllocateIP() = %v", ip) + t.Logf("AllocateIP() = %v", ip) }