From fa555b17f57b056d6b747bcb83b1d3a66e012cb4 Mon Sep 17 00:00:00 2001 From: VaalaCat Date: Sun, 14 Dec 2025 04:50:06 +0000 Subject: [PATCH] refactor: rewrite routing rule and do not ping when no allowed ips --- services/wg/ping.go | 11 +- services/wg/routing_planner.go | 795 ++++++++++++++++------------ services/wg/routing_planner_test.go | 379 ++++++++----- 3 files changed, 710 insertions(+), 475 deletions(-) diff --git a/services/wg/ping.go b/services/wg/ping.go index 2a70c0b..88ad3b0 100644 --- a/services/wg/ping.go +++ b/services/wg/ping.go @@ -142,6 +142,12 @@ func (w *wireGuard) scheduleVirtualAddrPings(log *logrus.Entry, ifceConfig *defs peers := ifceConfig.Peers for _, peer := range peers { p := peer + // 没有 AllowedIPs 的 peer(例如仅用于保持连接/预连接的常驻 peer)不参与 virt addr 探测: + // - 此时本机通常没有到对端 virtIP 的路由,探测必然失败 + // - 若失败写入不可达哨兵,会污染 master 的拓扑 cache,导致 SPF 误判为不可达,进而“完全不连通” + if p == nil || len(p.GetAllowedIps()) == 0 { + continue + } addr := p.GetVirtualIp() if addr == "" { continue @@ -157,7 +163,10 @@ func (w *wireGuard) scheduleVirtualAddrPings(log *logrus.Entry, ifceConfig *defs avg, err := tcpPingAvg(tcpAddr, endpointPingCount, endpointPingTimeout) if err != nil { log.WithError(err).Errorf("failed to tcp ping virt addr %s via %s", addr, tcpAddr) - w.storeVirtAddrPing(addr, math.MaxUint32) + // 失败时不写入不可达哨兵,避免污染拓扑;删除该条记录即可回退到 endpoint latency + if w.virtAddrPingMap != nil { + w.virtAddrPingMap.Delete(addr) + } return } diff --git a/services/wg/routing_planner.go b/services/wg/routing_planner.go index aa7fa14..9d88ec0 100644 --- a/services/wg/routing_planner.go +++ b/services/wg/routing_planner.go @@ -2,7 +2,9 @@ package wg import ( "errors" + "fmt" "math" + "net/netip" "sort" "time" @@ -12,14 +14,28 @@ import ( "github.com/VaalaCat/frp-panel/pb" ) +// WireGuard 的 AllowedIPs 同时承担两件事: +// 1) 出站选路:目的 IP 匹配哪个 peer 的 AllowedIPs,就把包发给哪个 peer +// 2) 入站源地址校验:从某 peer 解密出来的 inner packet,其 source IP 必须落在该 peer 的 AllowedIPs +// +// 因此,多跳转发时,某节点 i 从“上一跳 peer=j”收到的包,其 inner source 仍是“原始源节点 s 的 /32”, +// 所以 i 配置 peer(j) 的 AllowedIPs 必须包含这些会经由 j 转发进来的“源地址集合”,否则会直接丢包。 +// 思路: +// - 在一个“对称权重”的图上做最短路(保证路径可逆,避免重复/冲突) +// - 同时产出: +// - Out(i->nextHop): i 出站时,哪些目的 /32 应走 nextHop(目的集合) +// - In(i<-prevHop): i 入站时,从 prevHop 过来的包允许哪些源 /32(源集合) +// - 最终对每个 i 的每个直连 peer(j),AllowedIPs = Out(i->j) ∪ In(i<-j) +// - 严格校验:对同一节点 i,不允许出现同一个 /32 同时出现在多个 peer 的 AllowedIPs(否则 WG 行为不确定) + type AllowedIPsPlanner interface { // Compute 基于拓扑与链路指标,计算每个节点应配置到直连邻居的 AllowedIPs。 // 输入的 peers 应包含同一 Network 下的所有 WireGuard 节点,links 为其有向链路。 - // 返回节点ID->PeerConfig 列表,节点所有 ID->Edge 列表。 + // 返回:节点ID->PeerConfig 列表,节点ID->Edge 列表(完整候选图,用于展示)。 Compute(peers []*models.WireGuard, links []*models.WireGuardLink) (map[uint][]*pb.WireGuardPeerConfig, map[uint][]Edge, error) - // BuildGraph 基于拓扑与链路指标,计算每个节点应配置到直连邻居的 AllowedIPs,并返回节点ID->Edge 列表。 + // BuildGraph 基于拓扑与链路指标,返回完整候选图(用于展示/诊断)。 BuildGraph(peers []*models.WireGuard, links []*models.WireGuardLink) (map[uint][]Edge, error) - // BuildFinalGraph 最短路径算法,返回节点ID->Edge 列表。 + // BuildFinalGraph 返回“最终下发的直连边”与其 routes(用于展示 SPF 结果)。 BuildFinalGraph(peers []*models.WireGuard, links []*models.WireGuardLink) (map[uint][]Edge, error) } @@ -40,37 +56,44 @@ func (p *dijkstraAllowedIPsPlanner) Compute(peers []*models.WireGuard, links []* return map[uint][]*pb.WireGuardPeerConfig{}, map[uint][]Edge{}, nil } - idToPeer, order := buildNodeIndex(peers) - adj := buildAdjacency(order, idToPeer, links, p.policy) - spfAdj := filterAdjacencyForSPF(order, adj, p.policy) - // 路由(AllowedIPs)依赖 WireGuard 的“源地址校验”:下一跳收到的包会按“来自哪个 peer”做匹配, - // 并校验 inner packet 的 source IP 是否落在该 peer 的 AllowedIPs 中。 - // 因此用于承载路由的直连边必须是双向的:若存在单向边,最短路会产生单向选路,导致中间节点丢包。 - spfAdj = filterAdjacencyForSymmetricLinks(order, spfAdj) - aggByNode, edgeInfoMap := runAllPairsDijkstra(order, spfAdj, idToPeer, p.policy) - result, err := assemblePeerConfigs(order, aggByNode, edgeInfoMap, idToPeer) + idToPeer, order := buildNodeIndexSorted(peers) + cidrByID, err := buildNodeCIDRMap(order, idToPeer) if err != nil { return nil, nil, err } - fillIsolates(order, result) - if err := ensureRoutingPeerSymmetry(order, result, idToPeer); err != nil { + + adj := buildAdjacency(order, idToPeer, links, p.policy) + // SPF 参与的边:显式边 + 已探测可达的推断边,并要求“可用于转发”的边必须双向存在 + spfAdj := filterAdjacencyForSPF(order, adj, p.policy) + spfAdj = filterAdjacencyForSymmetricLinks(order, spfAdj) + + peerCfgs, finalEdges, err := computeAllowedIPs(order, idToPeer, cidrByID, spfAdj, adj, p.policy) + if err != nil { return nil, nil, err } - // 填充没有链路的节点 + // 填充没有链路的节点(展示用) for _, id := range order { if _, ok := adj[id]; !ok { adj[id] = []Edge{} } + if _, ok := finalEdges[id]; !ok { + finalEdges[id] = []Edge{} + } + if _, ok := peerCfgs[id]; !ok { + peerCfgs[id] = []*pb.WireGuardPeerConfig{} + } } - return result, adj, nil + return peerCfgs, adj, nil } func (p *dijkstraAllowedIPsPlanner) BuildGraph(peers []*models.WireGuard, links []*models.WireGuardLink) (map[uint][]Edge, error) { - idToPeer, order := buildNodeIndex(peers) + if len(peers) == 0 { + return map[uint][]Edge{}, nil + } + idToPeer, order := buildNodeIndexSorted(peers) adj := buildAdjacency(order, idToPeer, links, p.policy) - // 填充没有链路的节点 for _, id := range order { if _, ok := adj[id]; !ok { adj[id] = []Edge{} @@ -80,56 +103,39 @@ func (p *dijkstraAllowedIPsPlanner) BuildGraph(peers []*models.WireGuard, links } func (p *dijkstraAllowedIPsPlanner) BuildFinalGraph(peers []*models.WireGuard, links []*models.WireGuardLink) (map[uint][]Edge, error) { - idToPeer, order := buildNodeIndex(peers) + if len(peers) == 0 { + return map[uint][]Edge{}, nil + } + + idToPeer, order := buildNodeIndexSorted(peers) + cidrByID, err := buildNodeCIDRMap(order, idToPeer) + if err != nil { + return nil, err + } + adj := buildAdjacency(order, idToPeer, links, p.policy) spfAdj := filterAdjacencyForSPF(order, adj, p.policy) spfAdj = filterAdjacencyForSymmetricLinks(order, spfAdj) - routesInfoMap, edgeInfoMap := runAllPairsDijkstra(order, spfAdj, idToPeer, p.policy) - ret := map[uint][]Edge{} - for src, edgeInfo := range edgeInfoMap { - for next := range edgeInfo { - if _, ok := adj[src]; !ok { - continue - } - originEdge := Edge{} - finded := false - for _, e := range adj[src] { - if e.to == next { - originEdge = e - finded = true - break - } - } - if !finded { - continue - } - - routesInfo := routesInfoMap[src][next] - - ret[src] = append(ret[src], Edge{ - to: next, - latency: originEdge.latency, - upMbps: originEdge.upMbps, - toEndpoint: originEdge.toEndpoint, - routes: lo.Keys(routesInfo), - }) - } + _, finalEdges, err := computeAllowedIPs(order, idToPeer, cidrByID, spfAdj, adj, p.policy) + if err != nil { + return nil, err } for _, id := range order { - if _, ok := ret[id]; !ok { - ret[id] = []Edge{} + if _, ok := finalEdges[id]; !ok { + finalEdges[id] = []Edge{} } } - return ret, nil + return finalEdges, nil } +// Edge 表示候选/最终图里的“有向直连边”。 type Edge struct { to uint latency uint32 upMbps uint32 toEndpoint *models.Endpoint // 指定的目标端点,可能为 nil - routes []string // 路由信息 + routes []string // 最终展示:该直连 peer 承载的路由(AllowedIPs) explicit bool // true: 显式 link;false: 推断/探测用 link } @@ -147,45 +153,24 @@ func (e *Edge) ToPB() *pb.WireGuardLink { return link } -// filterAdjacencyForSymmetricLinks 仅保留“存在反向直连边”的邻接(用于 SPF)。 -// 这样最短路产生的每一步转发 hop 都对应一个双向直连 peer,避免出现单向路由导致的丢包。 -func filterAdjacencyForSymmetricLinks(order []uint, adj map[uint][]Edge) map[uint][]Edge { - ret := make(map[uint][]Edge, len(order)) - edgeSet := make(map[[2]uint]struct{}, 16) - - for from, edges := range adj { - for _, e := range edges { - edgeSet[[2]uint{from, e.to}] = struct{}{} +// buildNodeIndexSorted 返回:id->peer 映射 与 按 id 排序的 order(用于确定性) +func buildNodeIndexSorted(peers []*models.WireGuard) (map[uint]*models.WireGuard, []uint) { + idToPeer := make(map[uint]*models.WireGuard, len(peers)) + order := make([]uint, 0, len(peers)) + for _, p := range peers { + if p == nil { + continue } + id := uint(p.ID) + idToPeer[id] = p + order = append(order, id) } - - for from, edges := range adj { - for _, e := range edges { - if _, ok := edgeSet[[2]uint{e.to, from}]; !ok { - continue - } - ret[from] = append(ret[from], e) - } - } - - for _, id := range order { - if _, ok := ret[id]; !ok { - ret[id] = []Edge{} - } - } - return ret + sort.Slice(order, func(i, j int) bool { return order[i] < order[j] }) + return idToPeer, order } -// ensureRoutingPeerSymmetry 确保:如果 src 的 peers 中存在 nextHop(承载路由),则 nextHop 的 peers 中也必须存在 src。 -// 这里“对称”不是指两端 routes/AllowedIPs 集合一致,而是指两端都必须配置对方这个 peer, -// 以满足 WG 的解密与源地址校验(否则 nextHop 会丢弃来自 src 的转发包)。 -func ensureRoutingPeerSymmetry(order []uint, peerCfgs map[uint][]*pb.WireGuardPeerConfig, idToPeer map[uint]*models.WireGuard) error { - if len(order) == 0 { - return nil - } - - // 预计算每个节点自身的 /32 CIDR(AsBasePeerConfig 返回的 AllowedIps[0]) - selfCIDR := make(map[uint]string, len(order)) +func buildNodeCIDRMap(order []uint, idToPeer map[uint]*models.WireGuard) (map[uint]string, error) { + out := make(map[uint]string, len(order)) for _, id := range order { p := idToPeer[id] if p == nil { @@ -193,108 +178,66 @@ func ensureRoutingPeerSymmetry(order []uint, peerCfgs map[uint][]*pb.WireGuardPe } base, err := p.AsBasePeerConfig(nil) if err != nil || len(base.GetAllowedIps()) == 0 { - continue + return nil, fmt.Errorf("invalid wireguard local address for id=%d", id) } - selfCIDR[id] = base.GetAllowedIps()[0] + out[id] = base.GetAllowedIps()[0] } - - hasPeer := func(owner uint, peerID uint) bool { - for _, pc := range peerCfgs[owner] { - if pc == nil { - continue - } - if uint(pc.GetId()) == peerID { - return true - } - } - return false - } - - for _, src := range order { - for _, pc := range peerCfgs[src] { - if pc == nil { - continue - } - if len(pc.GetAllowedIps()) == 0 { - continue - } - nextHop := uint(pc.GetId()) - if nextHop == 0 || nextHop == src { - continue - } - if hasPeer(nextHop, src) { - continue - } - - remote := idToPeer[src] - if remote == nil { - continue - } - base, err := remote.AsBasePeerConfig(nil) - if err != nil { - return err - } - if cidr := selfCIDR[src]; cidr != "" { - base.AllowedIps = []string{cidr} - } - peerCfgs[nextHop] = append(peerCfgs[nextHop], base) - } - } - - for _, id := range order { - sort.SliceStable(peerCfgs[id], func(i, j int) bool { - return peerCfgs[id][i].GetClientId() < peerCfgs[id][j].GetClientId() - }) - } - return nil -} - -func buildNodeIndex(peers []*models.WireGuard) (map[uint]*models.WireGuard, []uint) { - idToPeer := make(map[uint]*models.WireGuard, len(peers)) - order := make([]uint, 0, len(peers)) - for _, p := range peers { - idToPeer[uint(p.ID)] = p - order = append(order, uint(p.ID)) - } - return idToPeer, order + return out, nil } +// buildAdjacency 构建“候选直连边”: +// 1) 显式链路(管理员配置)直接加入 +// 2) 若某节点具备 endpoint,则其他节点可按 ACL 推断直连它(用于探测/候选) func buildAdjacency(order []uint, idToPeer map[uint]*models.WireGuard, links []*models.WireGuardLink, policy RoutingPolicy) map[uint][]Edge { adj := make(map[uint][]Edge, len(order)) + + online := func(id uint) bool { + if policy.CliMgr == nil { + return true + } + p := idToPeer[id] + if p == nil || p.ClientID == "" { + return false + } + lastSeenAt, ok := policy.CliMgr.GetLastSeenAt(p.ClientID) + if !ok { + return false + } + if policy.OfflineThreshold > 0 && time.Since(lastSeenAt) > policy.OfflineThreshold { + return false + } + return true + } + // 1) 显式链路 for _, l := range links { - if !l.Active { + if l == nil || !l.Active { continue } from := l.FromWireGuardID to := l.ToWireGuardID - if _, ok := idToPeer[from]; !ok { continue } - if _, ok := idToPeer[to]; !ok { continue } - - if lastSeenAt, ok := policy.CliMgr.GetLastSeenAt(idToPeer[from].ClientID); !ok || time.Since(lastSeenAt) > policy.OfflineThreshold { + if !online(from) || !online(to) { continue } - - if lastSeenAt, ok := policy.CliMgr.GetLastSeenAt(idToPeer[to].ClientID); !ok || time.Since(lastSeenAt) > policy.OfflineThreshold { - continue - } - - // 如果两个peer都没有endpoint,则不建立链路 + // 如果两个 peer 都没有 endpoint,则不建立链路(无法直连) if len(idToPeer[from].AdvertisedEndpoints) == 0 && len(idToPeer[to].AdvertisedEndpoints) == 0 { continue } latency := l.LatencyMs - if latency == 0 { // 如果指定latency为0,则使用真实值 - if latencyMs, ok := policy.NetworkTopologyCache.GetLatencyMs(from, to); ok { - latency = latencyMs - } else { + if latency == 0 { + if policy.NetworkTopologyCache != nil { + if latencyMs, ok := policy.NetworkTopologyCache.GetLatencyMs(from, to); ok { + latency = latencyMs + } + } + if latency == 0 { latency = policy.DefaultEndpointLatencyMs } } @@ -308,18 +251,17 @@ func buildAdjacency(order []uint, idToPeer map[uint]*models.WireGuard, links []* }) } - // 2) 若某节点具备 endpoint,则所有其他节点可直连它 - edgeSet := make(map[[2]uint]struct{}, 16) + // 2) 推断/探测用边:若某节点具备 endpoint,则所有其他节点可直连它 + edgeSet := make(map[[2]uint]struct{}, 64) for from, edges := range adj { - for _, e := range edges { // 先拿到所有直连的节点 + for _, e := range edges { edgeSet[[2]uint{from, e.to}] = struct{}{} - edgeSet[[2]uint{e.to, from}] = struct{}{} } } for _, to := range order { - peer := idToPeer[to] - if peer == nil || len(peer.AdvertisedEndpoints) == 0 { + peerTo := idToPeer[to] + if peerTo == nil || len(peerTo.AdvertisedEndpoints) == 0 { continue } for _, from := range order { @@ -329,83 +271,96 @@ func buildAdjacency(order []uint, idToPeer map[uint]*models.WireGuard, links []* if _, ok := idToPeer[from]; !ok { continue } + if !online(from) || !online(to) { + continue + } latency := policy.DefaultEndpointLatencyMs - // GetLatencyMs 已自带“正反向兜底 + endpoint/virt ping 组合”,这里避免重复查询与覆盖,减少抖动 - if latencyMs, ok := policy.NetworkTopologyCache.GetLatencyMs(from, to); ok { - latency = latencyMs - } - - if lastSeenAt, ok := policy.CliMgr.GetLastSeenAt(idToPeer[from].ClientID); !ok || time.Since(lastSeenAt) > policy.OfflineThreshold { - continue - } - - if lastSeenAt, ok := policy.CliMgr.GetLastSeenAt(idToPeer[to].ClientID); !ok || time.Since(lastSeenAt) > policy.OfflineThreshold { - continue - } - - // 有 acl 限制 - if policy.ACL.CanConnect(idToPeer[from], idToPeer[to]) { - key1 := [2]uint{from, to} - if _, exists := edgeSet[key1]; exists { - continue + if policy.NetworkTopologyCache != nil { + if latencyMs, ok := policy.NetworkTopologyCache.GetLatencyMs(from, to); ok { + latency = latencyMs } - - adj[from] = append(adj[from], Edge{ - to: to, - latency: latency, - upMbps: policy.DefaultEndpointUpMbps, - explicit: false, - }) - edgeSet[key1] = struct{}{} } - if policy.ACL.CanConnect(idToPeer[to], idToPeer[from]) { - key2 := [2]uint{to, from} - if _, exists := edgeSet[key2]; exists { - continue + // 注意:推断边需要按“两个方向”分别判断 ACL 并分别建边。 + // 这样即使 from 没有 endpoint,也能被 endpoint 节点纳入邻接(满足对称直连 peer 的要求)。 + + // from -> to + if policy.ACL == nil || policy.ACL.CanConnect(idToPeer[from], idToPeer[to]) { + key := [2]uint{from, to} + if _, exists := edgeSet[key]; !exists { + adj[from] = append(adj[from], Edge{ + to: to, + latency: latency, + upMbps: policy.DefaultEndpointUpMbps, + explicit: false, + }) + edgeSet[key] = struct{}{} + } + } + + // to -> from(反向边同样使用同一对节点的 latency 估计;GetLatencyMs 本身已做正反向兜底) + if policy.ACL == nil || policy.ACL.CanConnect(idToPeer[to], idToPeer[from]) { + key := [2]uint{to, from} + if _, exists := edgeSet[key]; !exists { + adj[to] = append(adj[to], Edge{ + to: from, + latency: latency, + upMbps: policy.DefaultEndpointUpMbps, + explicit: false, + }) + edgeSet[key] = struct{}{} } - adj[to] = append(adj[to], Edge{ - to: from, - latency: latency, - upMbps: policy.DefaultEndpointUpMbps, - explicit: false, - }) - edgeSet[key2] = struct{}{} } } } + + // 稳定排序:保证遍历顺序确定性 + for _, from := range order { + if edges, ok := adj[from]; ok { + sort.SliceStable(edges, func(i, j int) bool { + if edges[i].explicit != edges[j].explicit { + return edges[i].explicit // explicit 优先 + } + return edges[i].to < edges[j].to + }) + adj[from] = edges + } + } + return adj } -// filterAdjacencyForSPF 将“用于探测的候选邻接(adj)”过滤为“允许进入 SPF 的邻接”。 -// -// 参考 OSPF:新邻接必须先被确认可达(这里用 runtime ping/virt ping 的存在性作为信号)后, -// 才能参与最短路计算。否则在节点刚更新/刚加入时,会因为默认权重过低被误选,导致部分节点不可达。 +func isUnreachableLatency(latency uint32) bool { + // 兼容两类不可达哨兵: + // - math.MaxUint32(历史实现) + // - math.MaxInt32(部分展示/转换链路里会出现 2147483647) + return latency == math.MaxUint32 || latency == uint32(math.MaxInt32) +} + +// filterAdjacencyForSPF:显式边直接保留;推断边必须有探测数据,且不可达哨兵值剔除 func filterAdjacencyForSPF(order []uint, adj map[uint][]Edge, policy RoutingPolicy) map[uint][]Edge { ret := make(map[uint][]Edge, len(order)) - for from, edges := range adj { for _, e := range edges { - // 显式 link:管理员配置的边,允许进入 SPF if e.explicit { ret[from] = append(ret[from], e) continue } - - // 推断/探测用 link:必须已存在探测数据,且不可达哨兵值要剔除 + if policy.NetworkTopologyCache == nil { + continue + } latency, ok := policy.NetworkTopologyCache.GetLatencyMs(from, e.to) if !ok { continue } - if latency == math.MaxUint32 { + if isUnreachableLatency(latency) { continue } e.latency = latency ret[from] = append(ret[from], e) } } - for _, id := range order { if _, ok := ret[id]; !ok { ret[id] = []Edge{} @@ -414,78 +369,320 @@ func filterAdjacencyForSPF(order []uint, adj map[uint][]Edge, policy RoutingPoli return ret } -// EdgeInfo 保存边的端点信息,用于后续组装 PeerConfig -type EdgeInfo struct { - toEndpoint *models.Endpoint +// filterAdjacencyForSymmetricLinks 仅保留“存在反向直连边”的邻接(用于可转发 SPF)。 +func filterAdjacencyForSymmetricLinks(order []uint, adj map[uint][]Edge) map[uint][]Edge { + ret := make(map[uint][]Edge, len(order)) + edgeSet := make(map[[2]uint]struct{}, 64) + for from, edges := range adj { + for _, e := range edges { + edgeSet[[2]uint{from, e.to}] = struct{}{} + } + } + for from, edges := range adj { + for _, e := range edges { + if _, ok := edgeSet[[2]uint{e.to, from}]; !ok { + continue + } + ret[from] = append(ret[from], e) + } + } + for _, id := range order { + if _, ok := ret[id]; !ok { + ret[id] = []Edge{} + } + } + return ret } -// runAllPairsDijkstra returns: map[src]map[nextHop]map[CIDR], map[src]map[nextHop]*EdgeInfo -func runAllPairsDijkstra(order []uint, adj map[uint][]Edge, idToPeer map[uint]*models.WireGuard, policy RoutingPolicy) (map[uint]map[uint]map[string]struct{}, map[uint]map[uint]*EdgeInfo) { - aggByNode := make(map[uint]map[uint]map[string]struct{}, len(order)) - edgeInfoMap := make(map[uint]map[uint]*EdgeInfo, len(order)) // 保存 src -> nextHop 的边信息 +type directedEdgeInfo struct { + latency uint32 + upMbps uint32 + toEndpoint *models.Endpoint + explicit bool +} + +type undirectedNeighbor struct { + to uint + weight float64 +} + +// computeAllowedIPs 是“最终下发路由”的核心: +// - 在 spfAdj 上构建“对称权重的无向图” +// - 对每个 src 做一次 Dijkstra,得到最短路树 prev +// - 同时生成 Out(dst prefixes) 与 In(src prefixes) 并合并到每条直连 peer 的 AllowedIPs +func computeAllowedIPs( + order []uint, + idToPeer map[uint]*models.WireGuard, + cidrByID map[uint]string, + spfAdj map[uint][]Edge, + fullAdj map[uint][]Edge, // 用于展示补齐 latency/up/endpoint + policy RoutingPolicy, +) (map[uint][]*pb.WireGuardPeerConfig, map[uint][]Edge, error) { + // 构建 directed edge info(用于 endpoint/展示),并构建 undirected graph(对称权重) + dInfo := make(map[[2]uint]*directedEdgeInfo, 128) + undir := make(map[uint][]undirectedNeighbor, len(order)) + + // 先把 spfAdj 的 directed info 记下来 + for _, from := range order { + for _, e := range spfAdj[from] { + key := [2]uint{from, e.to} + dInfo[key] = &directedEdgeInfo{ + latency: e.latency, + upMbps: e.upMbps, + toEndpoint: e.toEndpoint, + explicit: e.explicit, + } + } + } + + // 无向图:只添加“成对存在”的边,weight 用 max(w_uv, w_vu) 保证对称 + added := make(map[[2]uint]struct{}, 128) + for _, u := range order { + for _, e := range spfAdj[u] { + v := e.to + if u == v { + continue + } + // 只处理一次 pair(u,v) + a, b := u, v + if a > b { + a, b = b, a + } + pair := [2]uint{a, b} + if _, ok := added[pair]; ok { + continue + } + // 需要双向边信息 + uv, ok1 := dInfo[[2]uint{u, v}] + vu, ok2 := dInfo[[2]uint{v, u}] + if !ok1 || !ok2 || uv == nil || vu == nil { + continue + } + // 用 policy.EdgeWeight 计算双向权重并取 max 做对称 + wuv := policy.EdgeWeight(u, Edge{to: v, latency: uv.latency, upMbps: uv.upMbps, toEndpoint: uv.toEndpoint, explicit: uv.explicit}, idToPeer) + wvu := policy.EdgeWeight(v, Edge{to: u, latency: vu.latency, upMbps: vu.upMbps, toEndpoint: vu.toEndpoint, explicit: vu.explicit}, idToPeer) + w := math.Max(wuv, wvu) + undir[a] = append(undir[a], undirectedNeighbor{to: b, weight: w}) + undir[b] = append(undir[b], undirectedNeighbor{to: a, weight: w}) + added[pair] = struct{}{} + } + } + + // 稳定排序 + for _, u := range order { + neis := undir[u] + sort.SliceStable(neis, func(i, j int) bool { return neis[i].to < neis[j].to }) + undir[u] = neis + } + + // Out/ In 聚合:owner -> peer -> set[cidr] + allowed := make(map[uint]map[uint]map[string]struct{}, len(order)) for _, src := range order { - dist, prev, visited := initSSSP(order) + dist := make(map[uint]float64, len(order)) + prev := make(map[uint]uint, len(order)) // prev[dst] = predecessor of dst on path from src + visited := make(map[uint]bool, len(order)) + for _, id := range order { + dist[id] = math.Inf(1) + } dist[src] = 0 + // Dijkstra(O(n^2),节点数通常不大;同时保证确定性) for { u, ok := pickNext(order, dist, visited) if !ok { break } visited[u] = true - for _, e := range adj[u] { - w := policy.EdgeWeight(u, e, idToPeer) - alt := dist[u] + w - if alt < dist[e.to] { - dist[e.to] = alt - prev[e.to] = u + for _, nb := range undir[u] { + v := nb.to + if visited[v] { + continue + } + alt := dist[u] + nb.weight + if alt < dist[v] { + dist[v] = alt + prev[v] = u + continue + } + // tie-break:相同距离时,选择更小的 predecessor,确保稳定 + if alt == dist[v] { + if cur, ok := prev[v]; !ok || u < cur { + prev[v] = u + } } } } - // 累计 nextHop -> CIDR,并保存边信息 + // 1) 出站目的集合:dstCIDR -> nextHop(src,dst) for _, dst := range order { if dst == src { continue } if _, ok := prev[dst]; !ok { - continue + continue // unreachable } next := findNextHop(src, dst, prev) if next == 0 { continue } - dstPeer := idToPeer[dst] - allowed, err := dstPeer.AsBasePeerConfig(nil) // 这里只获取 CIDR,不需要指定 endpoint - if err != nil || len(allowed.GetAllowedIps()) == 0 { + cidr := cidrByID[dst] + if cidr == "" { continue } - cidr := allowed.GetAllowedIps()[0] - if _, ok := aggByNode[src]; !ok { - aggByNode[src] = make(map[uint]map[string]struct{}) - } - if _, ok := aggByNode[src][next]; !ok { - aggByNode[src][next] = map[string]struct{}{} - } - aggByNode[src][next][cidr] = struct{}{} + ensureAllowedSet(allowed, src, next)[cidr] = struct{}{} + } - // 保存从 src 到 next 的边信息(查找直接边) - if _, ok := edgeInfoMap[src]; !ok { - edgeInfoMap[src] = make(map[uint]*EdgeInfo) - } - if _, ok := edgeInfoMap[src][next]; !ok { - // 查找从 src 到 next 的边 - for _, e := range adj[src] { - if e.to == next { - edgeInfoMap[src][next] = &EdgeInfo{toEndpoint: e.toEndpoint} - break - } + // 2) 入站源集合:srcCIDR -> prevHop(src,dst) 归到 dst 节点的 peer(prevHop) + srcCIDR := cidrByID[src] + if srcCIDR != "" { + for _, dst := range order { + if dst == src { + continue } + pred, ok := prev[dst] + if !ok || pred == 0 { + continue + } + ensureAllowedSet(allowed, dst, pred)[srcCIDR] = struct{}{} } } } - return aggByNode, edgeInfoMap + + // 构建 PeerConfigs,并做强校验(同一节点不允许 CIDR 分配到多个 peer) + result := make(map[uint][]*pb.WireGuardPeerConfig, len(order)) + finalEdges := make(map[uint][]Edge, len(order)) + + for _, owner := range order { + peerToCIDRs := allowed[owner] + if len(peerToCIDRs) == 0 { + result[owner] = []*pb.WireGuardPeerConfig{} + finalEdges[owner] = []Edge{} + continue + } + + seen := make(map[string]uint, 128) + peerIDs := lo.Keys(peerToCIDRs) + sort.Slice(peerIDs, func(i, j int) bool { return peerIDs[i] < peerIDs[j] }) + + pcs := make([]*pb.WireGuardPeerConfig, 0, len(peerIDs)) + edges := make([]Edge, 0, len(peerIDs)) + + for _, peerID := range peerIDs { + cset := peerToCIDRs[peerID] + if len(cset) == 0 { + continue + } + remote := idToPeer[peerID] + if remote == nil { + continue + } + + // endpoint:优先使用 spfAdj 的直连边的 toEndpoint(与实际更一致) + var specifiedEndpoint *models.Endpoint + if info := dInfo[[2]uint{owner, peerID}]; info != nil && info.toEndpoint != nil { + specifiedEndpoint = info.toEndpoint + } + + base, err := remote.AsBasePeerConfig(specifiedEndpoint) + if err != nil { + return nil, nil, errors.Join(errors.New("build peer base config failed"), err) + } + + cidrs := make([]string, 0, len(cset)) + for c := range cset { + if prevOwner, ok := seen[c]; ok && prevOwner != peerID { + return nil, nil, fmt.Errorf("duplicate allowed ip on node %d: %s appears in peer %d and peer %d", owner, c, prevOwner, peerID) + } + seen[c] = peerID + cidrs = append(cidrs, c) + } + sort.Strings(cidrs) + base.AllowedIps = lo.Uniq(cidrs) + pcs = append(pcs, base) + + // 用 fullAdj 补齐展示指标(latency/up/endpoint) + lat, up, ep, explicit := lookupEdgeForDisplay(fullAdj, owner, peerID) + edges = append(edges, Edge{ + to: peerID, + latency: lat, + upMbps: up, + toEndpoint: ep, + routes: base.AllowedIps, + explicit: explicit, + }) + } + + // 按 client_id 稳定排序(保持原接口习惯) + sort.SliceStable(pcs, func(i, j int) bool { return pcs[i].GetClientId() < pcs[j].GetClientId() }) + sort.SliceStable(edges, func(i, j int) bool { return edges[i].to < edges[j].to }) + + result[owner] = pcs + finalEdges[owner] = edges + } + + return result, finalEdges, nil +} + +func ensureAllowedSet(m map[uint]map[uint]map[string]struct{}, owner, peer uint) map[string]struct{} { + if _, ok := m[owner]; !ok { + m[owner] = make(map[uint]map[string]struct{}, 8) + } + if _, ok := m[owner][peer]; !ok { + m[owner][peer] = make(map[string]struct{}, 32) + } + return m[owner][peer] +} + +func lookupEdgeForDisplay(fullAdj map[uint][]Edge, from, to uint) (latency uint32, up uint32, ep *models.Endpoint, explicit bool) { + edges := fullAdj[from] + for _, e := range edges { + if e.to == to { + return e.latency, e.upMbps, e.toEndpoint, e.explicit + } + } + return 0, 0, nil, false +} + +func pickNext(order []uint, dist map[uint]float64, visited map[uint]bool) (uint, bool) { + best := uint(0) + bestVal := math.Inf(1) + found := false + for _, vid := range order { + if visited[vid] { + continue + } + if dist[vid] < bestVal { + bestVal = dist[vid] + best = vid + found = true + } + } + return best, found +} + +// findNextHop 返回从 src 到 dst 的 nextHop(src 的直连邻居),依赖 prev[dst] = predecessor(dst) +func findNextHop(src, dst uint, prev map[uint]uint) uint { + next := dst + for { + p, ok := prev[next] + if !ok { + return 0 + } + if p == src { + return next + } + next = p + } +} + +// 仅用于测试/诊断:解析 /32 的 host ip(校验格式) +func parseHostFromCIDR(c string) (netip.Addr, bool) { + p, err := netip.ParsePrefix(c) + if err != nil { + return netip.Addr{}, false + } + return p.Addr(), true } // getHandshakeAgeBetween 返回 a<->b 间 peer handshake 的“最大”年龄(只要任意一侧可观测到握手时间就生效)。 @@ -508,7 +705,7 @@ func getHandshakeAgeBetween(aWGID, bWGID uint, idToPeer map[uint]*models.WireGua return ageB, true } -// getOneWayHandshakeAge 从 fromWGID 的 runtimeInfo 中,查找到 toWGID 对应 peer 的 last_handshake_time_sec/nsec,返回握手“距离现在”的时间差。 +// getOneWayHandshakeAge 从 fromWGID 的 runtimeInfo 中查到 toWGID 对应 peer 的 last_handshake_time_sec/nsec,返回握手“距离现在”的时间差。 func getOneWayHandshakeAge(fromWGID, toWGID uint, idToPeer map[uint]*models.WireGuard, policy RoutingPolicy) (time.Duration, bool) { if policy.NetworkTopologyCache == nil { return 0, false @@ -544,85 +741,3 @@ func getOneWayHandshakeAge(fromWGID, toWGID uint, idToPeer map[uint]*models.Wire } return age, true } - -func initSSSP(order []uint) (map[uint]float64, map[uint]uint, map[uint]bool) { - dist := make(map[uint]float64, len(order)) - prev := make(map[uint]uint, len(order)) - visited := make(map[uint]bool, len(order)) - for _, vid := range order { - dist[vid] = math.Inf(1) - } - return dist, prev, visited -} - -func pickNext(order []uint, dist map[uint]float64, visited map[uint]bool) (uint, bool) { - best := uint(0) - bestVal := math.Inf(1) - found := false - for _, vid := range order { - if visited[vid] { - continue - } - if dist[vid] < bestVal { - bestVal = dist[vid] - best = vid - found = true - } - } - return best, found -} - -func findNextHop(src, dst uint, prev map[uint]uint) uint { - next := dst - for { - p, ok := prev[next] - if !ok { - return 0 - } - if p == src { - return next - } - next = p - } -} - -func assemblePeerConfigs(order []uint, aggByNode map[uint]map[uint]map[string]struct{}, edgeInfoMap map[uint]map[uint]*EdgeInfo, idToPeer map[uint]*models.WireGuard) (map[uint][]*pb.WireGuardPeerConfig, error) { - result := make(map[uint][]*pb.WireGuardPeerConfig, len(order)) - for src, nextMap := range aggByNode { - peersForSrc := make([]*pb.WireGuardPeerConfig, 0, len(nextMap)) - for nextHop, cidrSet := range nextMap { - remote := idToPeer[nextHop] - - // 获取从 src 到 nextHop 的边信息,确定使用哪个 endpoint - var specifiedEndpoint *models.Endpoint - if edgeInfo, ok := edgeInfoMap[src][nextHop]; ok && edgeInfo != nil && edgeInfo.toEndpoint != nil { - specifiedEndpoint = edgeInfo.toEndpoint - } - - base, err := remote.AsBasePeerConfig(specifiedEndpoint) - if err != nil { - return nil, errors.Join(errors.New("build peer base config failed"), err) - } - cidrs := make([]string, 0, len(cidrSet)) - for c := range cidrSet { - cidrs = append(cidrs, c) - } - sort.Strings(cidrs) - base.AllowedIps = lo.Uniq(cidrs) - peersForSrc = append(peersForSrc, base) - } - sort.SliceStable(peersForSrc, func(i, j int) bool { - return peersForSrc[i].GetClientId() < peersForSrc[j].GetClientId() - }) - result[src] = peersForSrc - } - return result, nil -} - -func fillIsolates(order []uint, result map[uint][]*pb.WireGuardPeerConfig) { - for _, id := range order { - if _, ok := result[id]; !ok { - result[id] = []*pb.WireGuardPeerConfig{} - } - } -} diff --git a/services/wg/routing_planner_test.go b/services/wg/routing_planner_test.go index 24cb6b2..2493b28 100644 --- a/services/wg/routing_planner_test.go +++ b/services/wg/routing_planner_test.go @@ -8,6 +8,7 @@ import ( "github.com/VaalaCat/frp-panel/models" "github.com/VaalaCat/frp-panel/pb" + "github.com/samber/lo" ) type fakeTopologyCache struct { @@ -90,93 +91,86 @@ func TestFilterAdjacencyForSPF(t *testing.T) { } } -func TestRunAllPairsDijkstra_PreferFreshHandshake(t *testing.T) { - // 1 -> 2 (stale handshake) - // 1 -> 3 (fresh) - // 3 -> 2 (fresh) - // 期望:从 1 到 2 的 nextHop 选择 3,而不是 2 +func TestPlanAllowedIPs_PreferFreshHandshake(t *testing.T) { + // 1 <-> 2:低延迟但握手过旧(应被惩罚) + // 1 <-> 3 <-> 2:略高延迟但握手新(应被选为 1->2 的 nextHop=3) now := time.Now().Unix() priv1, _ := wgtypes.GeneratePrivateKey() priv2, _ := wgtypes.GeneratePrivateKey() priv3, _ := wgtypes.GeneratePrivateKey() - p1 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ - ClientID: "c1", - PrivateKey: priv1.String(), - LocalAddress: "10.0.0.1/32", - }} + p1 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ClientID: "c1", PrivateKey: priv1.String(), LocalAddress: "10.0.0.1/32"}} p1.ID = 1 - p2 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ - ClientID: "c2", - PrivateKey: priv2.String(), - LocalAddress: "10.0.0.2/32", - }} + p2 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ClientID: "c2", PrivateKey: priv2.String(), LocalAddress: "10.0.0.2/32"}} p2.ID = 2 - p3 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ - ClientID: "c3", - PrivateKey: priv3.String(), - LocalAddress: "10.0.0.3/32", - }} + p3 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ClientID: "c3", PrivateKey: priv3.String(), LocalAddress: "10.0.0.3/32"}} p3.ID = 3 - idToPeer := map[uint]*models.WireGuard{1: p1, 2: p2, 3: p3} - order := []uint{1, 2, 3} - adj := map[uint][]Edge{ - 1: { - {to: 2, latency: 10, upMbps: 50, explicit: true}, - {to: 3, latency: 5, upMbps: 50, explicit: true}, - }, - 3: { - {to: 2, latency: 5, upMbps: 50, explicit: true}, - }, - 2: {}, + // 显式链路也要求至少一侧存在 endpoint(符合真实运行时:需要可连接入口) + p1.AdvertisedEndpoints = []*models.Endpoint{{EndpointEntity: &models.EndpointEntity{Host: "redacted.example", Port: 61820, Type: "ws", WireGuardID: 1, ClientID: "c1"}}} + p2.AdvertisedEndpoints = []*models.Endpoint{{EndpointEntity: &models.EndpointEntity{Host: "redacted.example", Port: 61820, Type: "ws", WireGuardID: 2, ClientID: "c2"}}} + p3.AdvertisedEndpoints = []*models.Endpoint{{EndpointEntity: &models.EndpointEntity{Host: "redacted.example", Port: 61820, Type: "ws", WireGuardID: 3, ClientID: "c3"}}} + + peers := []*models.WireGuard{p1, p2, p3} + link := func(from, to uint, latency uint32) *models.WireGuardLink { + return &models.WireGuardLink{WireGuardLinkEntity: &models.WireGuardLinkEntity{ + FromWireGuardID: from, + ToWireGuardID: to, + UpBandwidthMbps: 50, + LatencyMs: latency, + Active: true, + }} + } + links := []*models.WireGuardLink{ + link(1, 2, 5), link(2, 1, 5), + link(1, 3, 8), link(3, 1, 8), + link(3, 2, 8), link(2, 3, 8), } cache := &fakeTopologyCache{ rt: map[uint]*pb.WGDeviceRuntimeInfo{ - 1: { - Peers: []*pb.WGPeerRuntimeInfo{ - {ClientId: "c2", LastHandshakeTimeSec: uint64(now - 3600)}, // stale - {ClientId: "c3", LastHandshakeTimeSec: uint64(now)}, // fresh - }, - }, - 3: { - Peers: []*pb.WGPeerRuntimeInfo{ - {ClientId: "c2", LastHandshakeTimeSec: uint64(now)}, // fresh - }, - }, + 1: {Peers: []*pb.WGPeerRuntimeInfo{ + {ClientId: "c2", LastHandshakeTimeSec: uint64(now - 3600)}, // stale + {ClientId: "c3", LastHandshakeTimeSec: uint64(now)}, // fresh + }}, + 2: {Peers: []*pb.WGPeerRuntimeInfo{ + {ClientId: "c1", LastHandshakeTimeSec: uint64(now - 3600)}, // stale (对称) + {ClientId: "c3", LastHandshakeTimeSec: uint64(now)}, // fresh + }}, + 3: {Peers: []*pb.WGPeerRuntimeInfo{ + {ClientId: "c1", LastHandshakeTimeSec: uint64(now)}, + {ClientId: "c2", LastHandshakeTimeSec: uint64(now)}, + }}, }, } - policy := RoutingPolicy{ - LatencyWeight: 1, - InverseBandwidthWeight: 0, - HopWeight: 0, - HandshakeStaleThreshold: 1 * time.Second, - HandshakeStalePenalty: 100, - NetworkTopologyCache: cache, + policy := DefaultRoutingPolicy(NewACL(), cache, nil) + policy.HandshakeStaleThreshold = 1 * time.Second + policy.HandshakeStalePenalty = 1000 + policy.InverseBandwidthWeight = 0 + policy.HopWeight = 0 + policy.LatencyLogScale = 0 + + peerCfgs, _, err := PlanAllowedIPs(peers, links, policy) + if err != nil { + t.Fatalf("PlanAllowedIPs err: %v", err) } - aggByNode, _ := runAllPairsDijkstra(order, adj, idToPeer, policy) - if aggByNode[1] == nil { - t.Fatalf("aggByNode[1] should not be nil") + // 对 node1:10.0.0.2/32 应走 peer(3) 而不是 peer(2) + wantDst := "10.0.0.2/32" + var gotPeer uint32 + for _, pc := range peerCfgs[1] { + if pc == nil { + continue + } + if lo.Contains(pc.GetAllowedIps(), wantDst) { + gotPeer = pc.GetId() + } } - // dst=2 的 CIDR 应该被聚合到 nextHop=3 下(而不是 nextHop=2) - if _, ok := aggByNode[1][3]; !ok { - t.Fatalf("want nextHop=3 for src=1, got keys=%v", keysUint(aggByNode[1])) + if gotPeer != 3 { + t.Fatalf("want node1 route %s via peer 3, got peer %d", wantDst, gotPeer) } - if _, ok := aggByNode[1][2]; ok { - t.Fatalf("did not expect nextHop=2 for src=1 when handshake is stale") - } -} - -func keysUint(m map[uint]map[string]struct{}) []uint { - ret := make([]uint, 0, len(m)) - for k := range m { - ret = append(ret, k) - } - return ret } func TestSymmetrizeAdjacencyForPeers_FillReverseEdge(t *testing.T) { @@ -199,88 +193,205 @@ func TestFilterAdjacencyForSymmetricLinks_DropOneWay(t *testing.T) { } func TestEnsureRoutingPeerSymmetry_AddReversePeer(t *testing.T) { - // 构造一个“1 直连 2,但 2 到 1 会更偏好走 3”的场景: - // 1->2 成为承载路由的 nextHop,但 2 的路由结果中可能不包含 peer(1),需要对称补齐。 - now := time.Now().Unix() + t.Skip("routing planner rewritten: inbound-source-set generation replaces old symmetry patching") +} - priv1, _ := wgtypes.GeneratePrivateKey() - priv2, _ := wgtypes.GeneratePrivateKey() - priv3, _ := wgtypes.GeneratePrivateKey() +func TestPlanAllowedIPs_Regression_NoDuplicateAllowedIPs_And_TransitSourceValidation(t *testing.T) { + // 复现 & 防回归: + // 1) 同一节点的 AllowedIPs 不允许在多个 peer 间重复(例如 10.10.0.4/32 只能分配给一个 nextHop) + // 2) 多跳转发时,入站 source validation 需要允许“原始源地址”: + // 构造 21(10.10.0.8) -> 16(10.10.0.2) 走 24 中转, + // 期望 16 的 peer(24) AllowedIPs 包含 10.10.0.8/32(否则 16 会丢弃来自 24 的转发包)。 - p1 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ - ClientID: "c1", - PrivateKey: priv1.String(), - LocalAddress: "10.0.0.1/32", - }} - p1.ID = 1 - p2 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ - ClientID: "c2", - PrivateKey: priv2.String(), - LocalAddress: "10.0.0.2/32", - }} - p2.ID = 2 - p3 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ - ClientID: "c3", - PrivateKey: priv3.String(), - LocalAddress: "10.0.0.3/32", - }} - p3.ID = 3 - - idToPeer := map[uint]*models.WireGuard{1: p1, 2: p2, 3: p3} - order := []uint{1, 2, 3} - - // 全双向连通,但设置权重让 2->1 更偏好 2->3->1 - adj := map[uint][]Edge{ - 1: { - {to: 2, latency: 1, upMbps: 50, explicit: true}, - {to: 3, latency: 100, upMbps: 50, explicit: true}, - }, - 2: { - {to: 1, latency: 100, upMbps: 50, explicit: true}, - {to: 3, latency: 1, upMbps: 50, explicit: true}, - }, - 3: { - {to: 1, latency: 1, upMbps: 50, explicit: true}, - {to: 2, latency: 100, upMbps: 50, explicit: true}, - }, + type node struct { + id uint + cid string + addr string + tags []string + hasEP bool } - cache := &fakeTopologyCache{ - rt: map[uint]*pb.WGDeviceRuntimeInfo{ - 1: {Peers: []*pb.WGPeerRuntimeInfo{{ClientId: "c2", LastHandshakeTimeSec: uint64(now)}}}, - 2: {Peers: []*pb.WGPeerRuntimeInfo{{ClientId: "c1", LastHandshakeTimeSec: uint64(now)}}}, - }, + nodes := []node{ + {id: 4, cid: "c4", addr: "10.10.0.4/24", tags: []string{"cn", "bj"}, hasEP: true}, + {id: 11, cid: "c11", addr: "10.10.0.1/24", tags: []string{"cn", "wh"}, hasEP: false}, + {id: 16, cid: "c16", addr: "10.10.0.2/24", tags: []string{"cn", "bj", "ali"}, hasEP: true}, + {id: 17, cid: "c17", addr: "10.10.0.3/24", tags: []string{"cn", "wh"}, hasEP: false}, + {id: 18, cid: "c18", addr: "10.10.0.6/24", tags: []string{"us"}, hasEP: true}, + {id: 20, cid: "c20", addr: "10.10.0.7/24", tags: []string{"us"}, hasEP: false}, + {id: 21, cid: "c21", addr: "10.10.0.8/24", tags: []string{"cn", "nc"}, hasEP: false}, + {id: 22, cid: "c22", addr: "10.10.0.9/24", tags: []string{"cn", "nc"}, hasEP: false}, + {id: 24, cid: "c24", addr: "10.10.0.5/24", tags: []string{"cn", "nc"}, hasEP: true}, } - policy := RoutingPolicy{ - LatencyWeight: 1, - InverseBandwidthWeight: 0, - HopWeight: 0, - HandshakeStaleThreshold: 1 * time.Hour, - HandshakeStalePenalty: 0, - NetworkTopologyCache: cache, + makePeer := func(n node) *models.WireGuard { + priv, _ := wgtypes.GeneratePrivateKey() + wg := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ + ClientID: n.cid, + PrivateKey: priv.String(), + LocalAddress: n.addr, + Tags: n.tags, + }} + wg.ID = n.id + if n.hasEP { + wg.AdvertisedEndpoints = []*models.Endpoint{ + {EndpointEntity: &models.EndpointEntity{ + Host: "redacted.example", + Port: 61820, + Type: "ws", + WireGuardID: n.id, + ClientID: n.cid, + }}, + } + } + return wg } - aggByNode, edgeInfo := runAllPairsDijkstra(order, adj, idToPeer, policy) - peersMap, err := assemblePeerConfigs(order, aggByNode, edgeInfo, idToPeer) + peers := lo.Map(nodes, func(n node, _ int) *models.WireGuard { return makePeer(n) }) + + // 构造 ACL(与用户提供一致:只验证 tag 匹配逻辑正确,不涉及公网信息) + acl := NewACL().LoadFromPB(&pb.AclConfig{Acls: []*pb.AclRuleConfig{ + {Action: "allow", Src: []string{"bj", "wh"}, Dst: []string{"bj", "wh"}}, + {Action: "allow", Src: []string{"nc", "wh"}, Dst: []string{"nc", "wh"}}, + {Action: "allow", Src: []string{"nc", "ali"}, Dst: []string{"nc", "ali"}}, + {Action: "allow", Src: []string{"wh", "ali"}, Dst: []string{"wh", "ali"}}, + {Action: "allow", Src: []string{"us"}, Dst: []string{"us"}}, + }}) + + // 只需要 latency cache 为推断边提供“探测存在性”,这里直接手动构造显式 links,更可控 + // 关键:让 21->16 走 24 中转(21-24-16 低延迟,21-16 高延迟) + link := func(from, to uint, latency uint32) *models.WireGuardLink { + return &models.WireGuardLink{WireGuardLinkEntity: &models.WireGuardLinkEntity{ + FromWireGuardID: from, + ToWireGuardID: to, + UpBandwidthMbps: 50, + LatencyMs: latency, + Active: true, + }} + } + links := []*models.WireGuardLink{ + link(21, 24, 10), link(24, 21, 10), + link(24, 16, 10), link(16, 24, 10), + link(21, 16, 200), link(16, 21, 200), + + // 再补一些连通边,确保能算出包含 4 的路由 + link(11, 16, 30), link(16, 11, 30), + link(16, 4, 5), link(4, 16, 5), + link(11, 4, 50), link(4, 11, 50), + } + + policy := DefaultRoutingPolicy(acl, &fakeTopologyCache{lat: map[[2]uint]uint32{}}, nil) + policy.HandshakeStalePenalty = 0 + policy.HandshakeStaleThreshold = 0 + policy.InverseBandwidthWeight = 0 + policy.HopWeight = 0 + policy.LatencyLogScale = 0 + + peerCfgs, _, err := PlanAllowedIPs(peers, links, policy) if err != nil { - t.Fatalf("assemblePeerConfigs err: %v", err) + t.Fatalf("PlanAllowedIPs err: %v", err) } - fillIsolates(order, peersMap) - // 预期:在没有对称补齐前,2 可能不会包含 peer(1) - _ = ensureRoutingPeerSymmetry(order, peersMap, idToPeer) - - found := false - for _, pc := range peersMap[2] { - if pc != nil && pc.GetId() == 1 { - found = true - if len(pc.GetAllowedIps()) == 0 || pc.GetAllowedIps()[0] != "10.0.0.1/32" { - t.Fatalf("peer(1) on node2 should include 10.0.0.1/32, got=%v", pc.GetAllowedIps()) + // 1) 断言:每个节点的 AllowedIPs 在不同 peer 间不重复 + for owner, pcs := range peerCfgs { + seen := map[string]uint32{} + for _, pc := range pcs { + if pc == nil { + continue + } + for _, cidr := range pc.GetAllowedIps() { + if prev, ok := seen[cidr]; ok && prev != pc.GetId() { + t.Fatalf("node %d has duplicate cidr %s on peer %d and peer %d", owner, cidr, prev, pc.GetId()) + } + seen[cidr] = pc.GetId() } } } + + // 2) 断言:16 的 peer(24) 必须包含 10.10.0.8/32(21 的 /32),用于入站 source validation + wantSrc := "10.10.0.8/32" + found := false + for _, pc := range peerCfgs[16] { + if pc == nil || pc.GetId() != 24 { + continue + } + if lo.Contains(pc.GetAllowedIps(), wantSrc) { + found = true + } + } if !found { - t.Fatalf("node2 should contain peer(1) after ensureRoutingPeerSymmetry") + t.Fatalf("node 16 peer(24) should contain %s for transit source validation", wantSrc) + } + + // 3) 断言:11 节点的 10.10.0.4/32 不能同时出现在多个 peer + wantC4 := "10.10.0.4/32" + var peersWithC4 []uint32 + for _, pc := range peerCfgs[11] { + if pc == nil { + continue + } + if lo.Contains(pc.GetAllowedIps(), wantC4) { + peersWithC4 = append(peersWithC4, pc.GetId()) + } + } + if len(peersWithC4) != 1 { + t.Fatalf("node 11 should have exactly one peer carrying %s, got peers=%v", wantC4, peersWithC4) + } +} + +func TestBuildAdjacency_InferredEdgesAreBidirectionalWhenACLAllows(t *testing.T) { + // 回归:推断边必须支持 to(with endpoint) -> from(no endpoint) 的反向补齐, + // 否则 filterAdjacencyForSymmetricLinks 会把所有 “no-endpoint 节点” 剔除,导致 SPF 结果为空。 + + privA, _ := wgtypes.GeneratePrivateKey() + privB, _ := wgtypes.GeneratePrivateKey() + + a := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ + ClientID: "ca", + PrivateKey: privA.String(), + LocalAddress: "10.0.0.1/24", + Tags: []string{"t1"}, + }} + a.ID = 1 // no endpoint + + b := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ + ClientID: "cb", + PrivateKey: privB.String(), + LocalAddress: "10.0.0.2/24", + Tags: []string{"t1"}, + }} + b.ID = 2 + b.AdvertisedEndpoints = []*models.Endpoint{{EndpointEntity: &models.EndpointEntity{ + Host: "redacted.example", + Port: 61820, + Type: "ws", + WireGuardID: 2, + ClientID: "cb", + }}} + + idToPeer, order := buildNodeIndexSorted([]*models.WireGuard{a, b}) + acl := NewACL().LoadFromPB(&pb.AclConfig{Acls: []*pb.AclRuleConfig{ + {Action: "allow", Src: []string{"t1"}, Dst: []string{"t1"}}, + }}) + policy := DefaultRoutingPolicy(acl, &fakeTopologyCache{lat: map[[2]uint]uint32{ + {1, 2}: 10, + {2, 1}: 10, + }}, nil) + + adj := buildAdjacency(order, idToPeer, nil, policy) + // 期望:1->2 与 2->1 都存在(推断边双向) + has12 := false + for _, e := range adj[1] { + if e.to == 2 { + has12 = true + } + } + has21 := false + for _, e := range adj[2] { + if e.to == 1 { + has21 = true + } + } + if !has12 || !has21 { + t.Fatalf("want inferred edges 1->2 and 2->1, got has12=%v has21=%v adj=%#v", has12, has21, adj) } }