feat: support keep connection to all avaliable peer and add handshake time

This commit is contained in:
VaalaCat
2025-12-13 15:31:18 +00:00
parent 5ad5cff89c
commit 15fa2a2a83
21 changed files with 1636 additions and 422 deletions

View File

@@ -6,6 +6,7 @@ import (
"github.com/VaalaCat/frp-panel/defs"
"github.com/VaalaCat/frp-panel/pb"
"github.com/VaalaCat/frp-panel/services/app"
"github.com/sirupsen/logrus"
)
func PullWireGuards(appInstance app.Application, clientID, clientSecret string) error {
@@ -32,7 +33,8 @@ func PullWireGuards(appInstance app.Application, clientID, clientSecret string)
}
log.Debugf("client [%s] has [%d] wireguards, check their status", clientID, len(resp.GetWireguardConfigs()))
log.Tracef("wireguardConfigs: %v", resp.GetWireguardConfigs())
log.Tracef("wireguardConfigs: %s", resp.String())
wgMgr := ctx.GetApp().GetWireGuardManager()
successCnt := 0
for _, wireGuard := range resp.GetWireguardConfigs() {
@@ -43,7 +45,7 @@ func PullWireGuards(appInstance app.Application, clientID, clientSecret string)
wgMgr.RemoveService(wireGuard.GetInterfaceName())
} else {
log.Debugf("wireguard [%s] already exists, skip create, update peers if need", wireGuard.GetInterfaceName())
wgSvc.PatchPeers(wgCfg.GetParsedPeers())
syncExistingWireGuard(log, wgSvc, wgCfg)
continue
}
}
@@ -65,3 +67,18 @@ func PullWireGuards(appInstance app.Application, clientID, clientSecret string)
return nil
}
func syncExistingWireGuard(log *logrus.Entry, wgSvc app.WireGuard, wgCfg *defs.WireGuardConfig) {
if wgSvc == nil || wgCfg == nil {
return
}
// 主链路:先更新 adjs再 patch peers。wg 内部会基于最新拓扑做预连接补齐/不可直连清理。
if err := wgSvc.UpdateAdjs(wgCfg.GetAdjs()); err != nil {
log.WithError(err).Warn("update adjs failed while syncing existing wireguard")
return
}
if _, err := wgSvc.PatchPeers(wgCfg.GetParsedPeers()); err != nil {
log.WithError(err).Warn("patch peers failed while syncing existing wireguard")
return
}
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/VaalaCat/frp-panel/pb"
"github.com/VaalaCat/frp-panel/services/app"
"github.com/samber/lo"
"github.com/sirupsen/logrus"
)
func UpdateWireGuard(ctx *app.Context, req *pb.UpdateWireGuardRequest) (*pb.UpdateWireGuardResponse, error) {
@@ -38,19 +39,15 @@ func AddPeer(ctx *app.Context, wgSvc app.WireGuard, req *pb.UpdateWireGuardReque
log.Debugf("add peer, peer_config: %+v", req.GetWireguardConfig().GetPeers())
for _, peer := range req.GetWireguardConfig().GetPeers() {
err := wgSvc.AddPeer(&defs.WireGuardPeerConfig{WireGuardPeerConfig: peer})
if err != nil {
log.WithError(err).Errorf("add peer failed")
continue
}
}
if err := wgSvc.UpdateAdjs(req.GetWireguardConfig().GetAdjs()); err != nil {
log.WithError(err).Errorf("update adjs failed, adjs: %+v", req.GetWireguardConfig().GetAdjs())
// 主链路:先更新 adjs保证后续 wg 内部的预连接/清理逻辑使用最新拓扑)
if err := updateAdjsFirst(log, wgSvc, req); err != nil {
return nil, err
}
applyPeerOps(log, req.GetWireguardConfig().GetPeers(), "add peer", func(peer *pb.WireGuardPeerConfig) error {
return wgSvc.AddPeer(&defs.WireGuardPeerConfig{WireGuardPeerConfig: peer})
})
log.Infof("add peer done")
return &pb.UpdateWireGuardResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}}, nil
@@ -61,19 +58,15 @@ func RemovePeer(ctx *app.Context, wgSvc app.WireGuard, req *pb.UpdateWireGuardRe
log.Debugf("remove peer, peer_config: %+v", req.GetWireguardConfig().GetPeers())
for _, peer := range req.GetWireguardConfig().GetPeers() {
err := wgSvc.RemovePeer(peer.GetPublicKey())
if err != nil {
log.WithError(err).Errorf("remove peer failed")
continue
}
}
if err := wgSvc.UpdateAdjs(req.GetWireguardConfig().GetAdjs()); err != nil {
log.WithError(err).Errorf("update adjs failed, adjs: %+v", req.GetWireguardConfig().GetAdjs())
// 主链路:先更新 adjs保证后续 wg 内部的预连接/清理逻辑使用最新拓扑)
if err := updateAdjsFirst(log, wgSvc, req); err != nil {
return nil, err
}
applyPeerOps(log, req.GetWireguardConfig().GetPeers(), "remove peer routes", func(peer *pb.WireGuardPeerConfig) error {
return wgSvc.RemovePeer(peer.GetPublicKey())
})
log.Infof("remove peer done")
return &pb.UpdateWireGuardResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}}, nil
@@ -84,19 +77,15 @@ func UpdatePeer(ctx *app.Context, wgSvc app.WireGuard, req *pb.UpdateWireGuardRe
log.Debugf("update peer, peer_config: %+v", req.GetWireguardConfig().GetPeers())
for _, peer := range req.GetWireguardConfig().GetPeers() {
err := wgSvc.UpdatePeer(&defs.WireGuardPeerConfig{WireGuardPeerConfig: peer})
if err != nil {
log.WithError(err).Errorf("update peer failed")
continue
}
}
if err := wgSvc.UpdateAdjs(req.GetWireguardConfig().GetAdjs()); err != nil {
log.WithError(err).Errorf("update adjs failed, adjs: %+v", req.GetWireguardConfig().GetAdjs())
// 主链路:先更新 adjs保证后续 wg 内部的预连接/清理逻辑使用最新拓扑)
if err := updateAdjsFirst(log, wgSvc, req); err != nil {
return nil, err
}
applyPeerOps(log, req.GetWireguardConfig().GetPeers(), "update peer", func(peer *pb.WireGuardPeerConfig) error {
return wgSvc.UpdatePeer(&defs.WireGuardPeerConfig{WireGuardPeerConfig: peer})
})
log.Infof("update peer done")
return &pb.UpdateWireGuardResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}}, nil
@@ -107,6 +96,11 @@ func PatchPeers(ctx *app.Context, wgSvc app.WireGuard, req *pb.UpdateWireGuardRe
log.Debugf("patch peers, peer_config: %+v", req.GetWireguardConfig().GetPeers())
// 主链路:先更新 adjs保证后续 wg 内部的预连接/清理逻辑使用最新拓扑)
if err := updateAdjsFirst(log, wgSvc, req); err != nil {
return nil, err
}
wgCfg := &defs.WireGuardConfig{WireGuardConfig: req.GetWireguardConfig()}
diffResp, err := wgSvc.PatchPeers(wgCfg.GetParsedPeers())
@@ -115,14 +109,32 @@ func PatchPeers(ctx *app.Context, wgSvc app.WireGuard, req *pb.UpdateWireGuardRe
return nil, err
}
if err = wgSvc.UpdateAdjs(req.GetWireguardConfig().GetAdjs()); err != nil {
log.WithError(err).Errorf("update adjs failed, adjs: %+v", req.GetWireguardConfig().GetAdjs())
return nil, err
}
log.Debugf("patch peers done, add_peers: %+v, remove_peers: %+v",
lo.Map(diffResp.AddPeers, func(item *defs.WireGuardPeerConfig, _ int) string { return item.GetClientId() }),
lo.Map(diffResp.RemovePeers, func(item *defs.WireGuardPeerConfig, _ int) string { return item.GetClientId() }))
return &pb.UpdateWireGuardResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}}, nil
}
func updateAdjsFirst(log *logrus.Entry, wgSvc app.WireGuard, req *pb.UpdateWireGuardRequest) error {
if req == nil || req.GetWireguardConfig() == nil {
return nil
}
if err := wgSvc.UpdateAdjs(req.GetWireguardConfig().GetAdjs()); err != nil {
log.WithError(err).Errorf("update adjs failed, adjs: %+v", req.GetWireguardConfig().GetAdjs())
return err
}
return nil
}
func applyPeerOps(log *logrus.Entry, peers []*pb.WireGuardPeerConfig, op string, fn func(peer *pb.WireGuardPeerConfig) error) {
for _, peer := range peers {
if peer == nil {
continue
}
if err := fn(peer); err != nil {
log.WithError(err).Errorf("%s failed", op)
continue
}
}
}

View File

@@ -1,12 +1,15 @@
package wg
import (
"sort"
"github.com/VaalaCat/frp-panel/models"
"github.com/VaalaCat/frp-panel/pb"
"github.com/VaalaCat/frp-panel/services/app"
"github.com/VaalaCat/frp-panel/services/dao"
wgsvc "github.com/VaalaCat/frp-panel/services/wg"
"github.com/samber/lo"
"github.com/sirupsen/logrus"
)
func ListClientWireGuards(ctx *app.Context, req *pb.ListClientWireGuardsRequest) (*pb.ListClientWireGuardsResponse, error) {
@@ -79,6 +82,15 @@ func ListClientWireGuards(ctx *app.Context, req *pb.ListClientWireGuardsRequest)
return nil
}
// 构建 network 内 WireGuard 索引,用于补齐“可直连 peer 的基础配置”
idToWg := make(map[uint32]*models.WireGuard, len(networkPeers[wgCfg.NetworkID]))
for _, item := range networkPeers[wgCfg.NetworkID] {
if item == nil {
continue
}
idToWg[uint32(item.ID)] = item
}
r := wgCfg.ToPB()
r.Peers = lo.Map(networkPeerConfigsMap[wgCfg.NetworkID][wgCfg.ID],
func(peerCfg *pb.WireGuardPeerConfig, _ int) *pb.WireGuardPeerConfig {
@@ -87,9 +99,89 @@ func ListClientWireGuards(ctx *app.Context, req *pb.ListClientWireGuardsRequest)
r.Adjs = adjsToPB(networkAllEdgesMap[wgCfg.NetworkID])
fillConnectablePeersAsPreconnect(r, uint32(wgCfg.ID), idToWg, log)
sortPeersStable(r)
return r
}),
}
return resp, nil
}
// fillConnectablePeersAsPreconnect 将 adj[localID] 中可直连的 peer 补齐到 r.peers 中,并将 AllowedIPs 置空(只预连接,不承载路由)。
func fillConnectablePeersAsPreconnect(r *pb.WireGuardConfig, localID uint32, idToWg map[uint32]*models.WireGuard, log *logrus.Entry) {
if r == nil || localID == 0 {
return
}
exists := make(map[uint32]struct{}, len(r.GetPeers()))
for _, p := range r.GetPeers() {
if p == nil {
continue
}
if p.GetId() != 0 {
exists[p.GetId()] = struct{}{}
}
if p.GetEndpoint() != nil && p.GetEndpoint().GetWireguardId() != 0 {
exists[p.GetEndpoint().GetWireguardId()] = struct{}{}
}
}
links := r.GetAdjs()[localID]
if links == nil {
return
}
for _, l := range links.GetLinks() {
if l == nil {
continue
}
toID := l.GetToWireguardId()
if toID == 0 || toID == localID {
continue
}
if _, ok := exists[toID]; ok {
continue
}
remote, ok := idToWg[toID]
if !ok || remote == nil {
continue
}
// 优先使用链路显式 to_endpoint
var specifiedEndpoint *models.Endpoint
if l.GetToEndpoint() != nil {
m := &models.Endpoint{}
m.FromPB(l.GetToEndpoint())
specifiedEndpoint = m
}
base, err := remote.AsBasePeerConfig(specifiedEndpoint)
if err != nil {
log.WithError(err).Warnf("failed to build base peer config for preconnect: local=%d to=%d", localID, toID)
continue
}
base.AllowedIps = nil
r.Peers = append(r.Peers, base)
exists[toID] = struct{}{}
}
}
func sortPeersStable(r *pb.WireGuardConfig) {
if r == nil || len(r.Peers) <= 1 {
return
}
sort.SliceStable(r.Peers, func(i, j int) bool {
pi := r.Peers[i]
pj := r.Peers[j]
if pi == nil && pj == nil {
return false
}
if pi == nil {
return false
}
if pj == nil {
return true
}
return pi.GetClientId() < pj.GetClientId()
})
}

View File

@@ -48,19 +48,26 @@ func GetNetworkTopology(ctx *app.Context, req *pb.GetNetworkTopologyRequest) (*p
ctx.GetApp().GetClientsManager(),
)
var resp map[uint][]wg.Edge
if req.GetSpf() {
resp, err = wg.NewDijkstraAllowedIPsPlanner(policy).BuildFinalGraph(peers, links)
} else {
resp, err = wg.NewDijkstraAllowedIPsPlanner(policy).BuildGraph(peers, links)
// SPF 模式:展示“真实下发的路由表”(即 PeerConfig.AllowedIps确保与实际一致。
peerCfgs, allEdges, err := wg.PlanAllowedIPs(peers, links, policy)
if err != nil {
log.WithError(err).Errorf("failed to plan allowed ips")
return nil, err
}
adjs := peerConfigsToPBAdjs(peerCfgs, allEdges)
return &pb.GetNetworkTopologyResponse{
Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"},
Adjs: adjs,
}, nil
}
resp, err := wg.NewDijkstraAllowedIPsPlanner(policy).BuildGraph(peers, links)
if err != nil {
log.WithError(err).Errorf("failed to build graph")
return nil, err
}
adjs := adjsToPB(resp)
return &pb.GetNetworkTopologyResponse{

View File

@@ -33,3 +33,66 @@ func adjsToPB(resp map[uint][]wg.Edge) map[uint32]*pb.WireGuardLinks {
return adjs
}
// peerConfigsToPBAdjs 将“真实下发给节点的路由表PeerConfig.AllowedIps”转换为拓扑展示所需的 Adjs。
//
// - routes: 直接使用 peerCfg.AllowedIps这才是 WireGuard 实际使用的路由表)
// - latency/up/down: 尽量从 allEdgesbuildAdjacency 的直连边指标)中补齐,仅用于展示
// - endpoint: 优先使用 peerCfg.Endpoint与实际下发一致
func peerConfigsToPBAdjs(peerCfgs map[uint][]*pb.WireGuardPeerConfig, allEdges map[uint][]wg.Edge) map[uint32]*pb.WireGuardLinks {
adjs := make(map[uint32]*pb.WireGuardLinks, len(peerCfgs))
for src, pcs := range peerCfgs {
srcID := uint32(src)
links := make([]*pb.WireGuardLink, 0, len(pcs))
// 构建 toID -> edge 指标索引(仅用于展示 latency/up
edgePBByTo := make(map[uint32]*pb.WireGuardLink, 16)
for _, e := range allEdges[src] {
epb := e.ToPB()
edgePBByTo[epb.GetToWireguardId()] = epb
}
for _, pc := range pcs {
if pc == nil || pc.GetId() == 0 {
continue
}
toID := pc.GetId()
var latency uint32
var up uint32
if epb, ok := edgePBByTo[toID]; ok && epb != nil {
latency = epb.GetLatencyMs()
up = epb.GetUpBandwidthMbps()
}
links = append(links, &pb.WireGuardLink{
FromWireguardId: srcID,
ToWireguardId: toID,
LatencyMs: latency,
UpBandwidthMbps: up,
DownBandwidthMbps: 0, // 下面统一填充
Active: true,
ToEndpoint: pc.GetEndpoint(),
Routes: pc.GetAllowedIps(),
})
}
adjs[srcID] = &pb.WireGuardLinks{Links: links}
}
// 填充 down bandwidth参考 adjsToPB 的做法:取反向边的 up
for id, links := range adjs {
for _, link := range links.GetLinks() {
toWireguardEdges, ok := adjs[uint32(link.GetToWireguardId())]
if ok {
for _, edge := range toWireguardEdges.GetLinks() {
if edge.GetToWireguardId() == uint32(uint(id)) {
link.DownBandwidthMbps = edge.GetUpBandwidthMbps()
}
}
}
}
}
return adjs
}

View File

@@ -5,11 +5,15 @@ import (
"testing"
"github.com/VaalaCat/frp-panel/models"
"github.com/stretchr/testify/assert"
)
func TestParseIPOrCIDRWithNetip(t *testing.T) {
ip, cidr, _ := models.ParseIPOrCIDRWithNetip("192.168.1.1/24")
t.Errorf("ip: %v, cidr: %v", ip, cidr)
t.Logf("ip: %v, cidr: %v", ip, cidr)
assert.Equal(t, ip, netip.MustParseAddr("192.168.1.1"))
assert.Equal(t, cidr, netip.MustParsePrefix("192.168.1.1/24"))
newcidr := netip.PrefixFrom(ip, 32)
t.Errorf("newcidr: %v", newcidr)
assert.Equal(t, newcidr, netip.MustParsePrefix("192.168.1.1/32"))
t.Logf("newcidr: %v", newcidr)
}

View File

@@ -24,7 +24,7 @@ func TestGenerateKeys(t *testing.T) {
got := wg.GenerateKeys()
// TODO: update the condition below to compare got with tt.want.
if true {
t.Errorf("GenerateKeys() = %v, want %v", got, tt.want)
t.Logf("GenerateKeys() = %v, want %v", got, tt.want)
}
})
}

View File

@@ -14,7 +14,7 @@ import (
)
// RoutingPolicy 决定边权重的计算方式。
// cost = LatencyWeight*latency_ms + InverseBandwidthWeight*(1/max(up_mbps,1e-6)) + HopWeight
// cost = LatencyWeight*latency_ms + InverseBandwidthWeight*(1/max(up_mbps,1e-6)) + HopWeight + HandshakePenalty
type RoutingPolicy struct {
LatencyWeight float64
InverseBandwidthWeight float64
@@ -23,6 +23,10 @@ type RoutingPolicy struct {
DefaultEndpointUpMbps uint32
DefaultEndpointLatencyMs uint32
OfflineThreshold time.Duration
// HandshakeStaleThreshold/HandshakeStalePenalty 用于抑制“握手过旧”的链路被选为最短路。
// 仅在能从 runtimeInfo 中找到对应 peer 的 last_handshake_time_sec 时生效;否则不惩罚(避免误伤)。
HandshakeStaleThreshold time.Duration
HandshakeStalePenalty float64
ACL *ACL
NetworkTopologyCache app.NetworkTopologyCache
@@ -42,9 +46,12 @@ func DefaultRoutingPolicy(acl *ACL, networkTopologyCache app.NetworkTopologyCach
DefaultEndpointUpMbps: 50,
DefaultEndpointLatencyMs: 30,
OfflineThreshold: 2 * time.Minute,
ACL: acl,
NetworkTopologyCache: networkTopologyCache,
CliMgr: cliMgr,
// 默认启用一个温和的“握手过旧惩罚”:优先选择近期有握手的链路,但不至于强制剔除路径。
HandshakeStaleThreshold: 5 * time.Minute,
HandshakeStalePenalty: 30.0,
ACL: acl,
NetworkTopologyCache: networkTopologyCache,
CliMgr: cliMgr,
}
}
@@ -79,12 +86,19 @@ func (p *dijkstraAllowedIPsPlanner) Compute(peers []*models.WireGuard, links []*
idToPeer, order := buildNodeIndex(peers)
adj := buildAdjacency(order, idToPeer, links, p.policy)
spfAdj := filterAdjacencyForSPF(order, adj, p.policy)
// 路由AllowedIPs依赖 WireGuard 的“源地址校验”:下一跳收到的包会按“来自哪个 peer”做匹配
// 并校验 inner packet 的 source IP 是否落在该 peer 的 AllowedIPs 中。
// 因此用于承载路由的直连边必须是双向的:若存在单向边,最短路会产生单向选路,导致中间节点丢包。
spfAdj = filterAdjacencyForSymmetricLinks(order, spfAdj)
aggByNode, edgeInfoMap := runAllPairsDijkstra(order, spfAdj, idToPeer, p.policy)
result, err := assemblePeerConfigs(order, aggByNode, edgeInfoMap, idToPeer)
if err != nil {
return nil, nil, err
}
fillIsolates(order, result)
if err := ensureRoutingPeerSymmetry(order, result, idToPeer); err != nil {
return nil, nil, err
}
// 填充没有链路的节点
for _, id := range order {
@@ -112,6 +126,7 @@ func (p *dijkstraAllowedIPsPlanner) BuildFinalGraph(peers []*models.WireGuard, l
idToPeer, order := buildNodeIndex(peers)
adj := buildAdjacency(order, idToPeer, links, p.policy)
spfAdj := filterAdjacencyForSPF(order, adj, p.policy)
spfAdj = filterAdjacencyForSymmetricLinks(order, spfAdj)
routesInfoMap, edgeInfoMap := runAllPairsDijkstra(order, spfAdj, idToPeer, p.policy)
ret := map[uint][]Edge{}
@@ -175,6 +190,108 @@ func (e *Edge) ToPB() *pb.WireGuardLink {
return link
}
// filterAdjacencyForSymmetricLinks 仅保留“存在反向直连边”的邻接(用于 SPF
// 这样最短路产生的每一步转发 hop 都对应一个双向直连 peer避免出现单向路由导致的丢包。
func filterAdjacencyForSymmetricLinks(order []uint, adj map[uint][]Edge) map[uint][]Edge {
ret := make(map[uint][]Edge, len(order))
edgeSet := make(map[[2]uint]struct{}, 16)
for from, edges := range adj {
for _, e := range edges {
edgeSet[[2]uint{from, e.to}] = struct{}{}
}
}
for from, edges := range adj {
for _, e := range edges {
if _, ok := edgeSet[[2]uint{e.to, from}]; !ok {
continue
}
ret[from] = append(ret[from], e)
}
}
for _, id := range order {
if _, ok := ret[id]; !ok {
ret[id] = []Edge{}
}
}
return ret
}
// ensureRoutingPeerSymmetry 确保:如果 src 的 peers 中存在 nextHop承载路由则 nextHop 的 peers 中也必须存在 src。
// 这里“对称”不是指两端 routes/AllowedIPs 集合一致,而是指两端都必须配置对方这个 peer
// 以满足 WG 的解密与源地址校验(否则 nextHop 会丢弃来自 src 的转发包)。
func ensureRoutingPeerSymmetry(order []uint, peerCfgs map[uint][]*pb.WireGuardPeerConfig, idToPeer map[uint]*models.WireGuard) error {
if len(order) == 0 {
return nil
}
// 预计算每个节点自身的 /32 CIDRAsBasePeerConfig 返回的 AllowedIps[0]
selfCIDR := make(map[uint]string, len(order))
for _, id := range order {
p := idToPeer[id]
if p == nil {
continue
}
base, err := p.AsBasePeerConfig(nil)
if err != nil || len(base.GetAllowedIps()) == 0 {
continue
}
selfCIDR[id] = base.GetAllowedIps()[0]
}
hasPeer := func(owner uint, peerID uint) bool {
for _, pc := range peerCfgs[owner] {
if pc == nil {
continue
}
if uint(pc.GetId()) == peerID {
return true
}
}
return false
}
for _, src := range order {
for _, pc := range peerCfgs[src] {
if pc == nil {
continue
}
if len(pc.GetAllowedIps()) == 0 {
continue
}
nextHop := uint(pc.GetId())
if nextHop == 0 || nextHop == src {
continue
}
if hasPeer(nextHop, src) {
continue
}
remote := idToPeer[src]
if remote == nil {
continue
}
base, err := remote.AsBasePeerConfig(nil)
if err != nil {
return err
}
if cidr := selfCIDR[src]; cidr != "" {
base.AllowedIps = []string{cidr}
}
peerCfgs[nextHop] = append(peerCfgs[nextHop], base)
}
}
for _, id := range order {
sort.SliceStable(peerCfgs[id], func(i, j int) bool {
return peerCfgs[id][i].GetClientId() < peerCfgs[id][j].GetClientId()
})
}
return nil
}
func buildNodeIndex(peers []*models.WireGuard) (map[uint]*models.WireGuard, []uint) {
idToPeer := make(map[uint]*models.WireGuard, len(peers))
order := make([]uint, 0, len(peers))
@@ -364,7 +481,15 @@ func runAllPairsDijkstra(order []uint, adj map[uint][]Edge, idToPeer map[uint]*m
visited[u] = true
for _, e := range adj[u] {
invBw := 1.0 / math.Max(float64(e.upMbps), 1e-6)
w := policy.LatencyWeight*float64(e.latency) + policy.InverseBandwidthWeight*invBw + policy.HopWeight
handshakePenalty := 0.0
if policy.HandshakeStalePenalty > 0 && policy.HandshakeStaleThreshold > 0 {
// 握手惩罚必须是“无方向”的,否则会导致 A->B 与 B->A 权重不一致,
// 进而产生单向选路WireGuard AllowedIPs 源地址校验下会丢包)。
if age, ok := getHandshakeAgeBetween(u, e.to, idToPeer, policy); ok && age > policy.HandshakeStaleThreshold {
handshakePenalty = policy.HandshakeStalePenalty
}
}
w := policy.LatencyWeight*float64(e.latency) + policy.InverseBandwidthWeight*invBw + policy.HopWeight + handshakePenalty
alt := dist[u] + w
if alt < dist[e.to] {
dist[e.to] = alt
@@ -417,6 +542,63 @@ func runAllPairsDijkstra(order []uint, adj map[uint][]Edge, idToPeer map[uint]*m
return aggByNode, edgeInfoMap
}
// getHandshakeAgeBetween 返回 a<->b 间 peer handshake 的“最大”年龄(只要任意一侧可观测到握手时间就生效)。
// 选择 max 的原因:如果任一方向握手过旧,都应抑制这对节点作为可靠转发 hop。
func getHandshakeAgeBetween(aWGID, bWGID uint, idToPeer map[uint]*models.WireGuard, policy RoutingPolicy) (time.Duration, bool) {
ageA, okA := getOneWayHandshakeAge(aWGID, bWGID, idToPeer, policy)
ageB, okB := getOneWayHandshakeAge(bWGID, aWGID, idToPeer, policy)
if !okA && !okB {
return 0, false
}
if !okA {
return ageB, true
}
if !okB {
return ageA, true
}
if ageA >= ageB {
return ageA, true
}
return ageB, true
}
// getOneWayHandshakeAge 从 fromWGID 的 runtimeInfo 中,查找到 toWGID 对应 peer 的 last_handshake_time_sec/nsec返回握手“距离现在”的时间差。
func getOneWayHandshakeAge(fromWGID, toWGID uint, idToPeer map[uint]*models.WireGuard, policy RoutingPolicy) (time.Duration, bool) {
if policy.NetworkTopologyCache == nil {
return 0, false
}
toPeer := idToPeer[toWGID]
if toPeer == nil || toPeer.ClientID == "" {
return 0, false
}
runtimeInfo, ok := policy.NetworkTopologyCache.GetRuntimeInfo(fromWGID)
if !ok || runtimeInfo == nil {
return 0, false
}
var hsSec uint64
var hsNsec uint64
for _, p := range runtimeInfo.GetPeers() {
if p == nil {
continue
}
if p.GetClientId() != toPeer.ClientID {
continue
}
hsSec = p.GetLastHandshakeTimeSec()
hsNsec = p.GetLastHandshakeTimeNsec()
break
}
if hsSec == 0 {
return 0, false
}
t := time.Unix(int64(hsSec), int64(hsNsec))
age := time.Since(t)
if age < 0 {
age = 0
}
return age, true
}
func initSSSP(order []uint) (map[uint]float64, map[uint]uint, map[uint]bool) {
dist := make(map[uint]float64, len(order))
prev := make(map[uint]uint, len(order))

View File

@@ -2,17 +2,28 @@ package wg
import (
"testing"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/VaalaCat/frp-panel/models"
"github.com/VaalaCat/frp-panel/pb"
)
type fakeTopologyCache struct {
lat map[[2]uint]uint32
rt map[uint]*pb.WGDeviceRuntimeInfo
}
func (c *fakeTopologyCache) GetRuntimeInfo(_ uint) (*pb.WGDeviceRuntimeInfo, bool) { return nil, false }
func (c *fakeTopologyCache) SetRuntimeInfo(_ uint, _ *pb.WGDeviceRuntimeInfo) {}
func (c *fakeTopologyCache) DeleteRuntimeInfo(_ uint) {}
func (c *fakeTopologyCache) GetRuntimeInfo(id uint) (*pb.WGDeviceRuntimeInfo, bool) {
if c == nil || c.rt == nil {
return nil, false
}
v, ok := c.rt[id]
return v, ok
}
func (c *fakeTopologyCache) SetRuntimeInfo(_ uint, _ *pb.WGDeviceRuntimeInfo) {}
func (c *fakeTopologyCache) DeleteRuntimeInfo(_ uint) {}
func (c *fakeTopologyCache) GetLatencyMs(fromWGID, toWGID uint) (uint32, bool) {
if c == nil || c.lat == nil {
return 0, false
@@ -78,3 +89,198 @@ func TestFilterAdjacencyForSPF(t *testing.T) {
t.Fatalf("node 2 should exist in return map")
}
}
func TestRunAllPairsDijkstra_PreferFreshHandshake(t *testing.T) {
// 1 -> 2 (stale handshake)
// 1 -> 3 (fresh)
// 3 -> 2 (fresh)
// 期望:从 1 到 2 的 nextHop 选择 3而不是 2
now := time.Now().Unix()
priv1, _ := wgtypes.GeneratePrivateKey()
priv2, _ := wgtypes.GeneratePrivateKey()
priv3, _ := wgtypes.GeneratePrivateKey()
p1 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{
ClientID: "c1",
PrivateKey: priv1.String(),
LocalAddress: "10.0.0.1/32",
}}
p1.ID = 1
p2 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{
ClientID: "c2",
PrivateKey: priv2.String(),
LocalAddress: "10.0.0.2/32",
}}
p2.ID = 2
p3 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{
ClientID: "c3",
PrivateKey: priv3.String(),
LocalAddress: "10.0.0.3/32",
}}
p3.ID = 3
idToPeer := map[uint]*models.WireGuard{1: p1, 2: p2, 3: p3}
order := []uint{1, 2, 3}
adj := map[uint][]Edge{
1: {
{to: 2, latency: 10, upMbps: 50, explicit: true},
{to: 3, latency: 5, upMbps: 50, explicit: true},
},
3: {
{to: 2, latency: 5, upMbps: 50, explicit: true},
},
2: {},
}
cache := &fakeTopologyCache{
rt: map[uint]*pb.WGDeviceRuntimeInfo{
1: {
Peers: []*pb.WGPeerRuntimeInfo{
{ClientId: "c2", LastHandshakeTimeSec: uint64(now - 3600)}, // stale
{ClientId: "c3", LastHandshakeTimeSec: uint64(now)}, // fresh
},
},
3: {
Peers: []*pb.WGPeerRuntimeInfo{
{ClientId: "c2", LastHandshakeTimeSec: uint64(now)}, // fresh
},
},
},
}
policy := RoutingPolicy{
LatencyWeight: 1,
InverseBandwidthWeight: 0,
HopWeight: 0,
HandshakeStaleThreshold: 1 * time.Second,
HandshakeStalePenalty: 100,
NetworkTopologyCache: cache,
}
aggByNode, _ := runAllPairsDijkstra(order, adj, idToPeer, policy)
if aggByNode[1] == nil {
t.Fatalf("aggByNode[1] should not be nil")
}
// dst=2 的 CIDR 应该被聚合到 nextHop=3 下(而不是 nextHop=2
if _, ok := aggByNode[1][3]; !ok {
t.Fatalf("want nextHop=3 for src=1, got keys=%v", keysUint(aggByNode[1]))
}
if _, ok := aggByNode[1][2]; ok {
t.Fatalf("did not expect nextHop=2 for src=1 when handshake is stale")
}
}
func keysUint(m map[uint]map[string]struct{}) []uint {
ret := make([]uint, 0, len(m))
for k := range m {
ret = append(ret, k)
}
return ret
}
func TestSymmetrizeAdjacencyForPeers_FillReverseEdge(t *testing.T) {
t.Skip("symmetrizeAdjacencyForPeers 已移除:路由承载的边必须双向存在,不应自动补齐单向边")
}
func TestFilterAdjacencyForSymmetricLinks_DropOneWay(t *testing.T) {
order := []uint{1, 2}
adj := map[uint][]Edge{
1: {{to: 2, latency: 10, upMbps: 50, explicit: true}}, // 单向
2: {},
}
ret := filterAdjacencyForSymmetricLinks(order, adj)
if len(ret[1]) != 0 {
t.Fatalf("want 0 edges for node 1 after symmetric filter, got %d: %#v", len(ret[1]), ret[1])
}
if _, ok := ret[2]; !ok {
t.Fatalf("node 2 should exist in return map")
}
}
func TestEnsureRoutingPeerSymmetry_AddReversePeer(t *testing.T) {
// 构造一个“1 直连 2但 2 到 1 会更偏好走 3”的场景
// 1->2 成为承载路由的 nextHop但 2 的路由结果中可能不包含 peer(1),需要对称补齐。
now := time.Now().Unix()
priv1, _ := wgtypes.GeneratePrivateKey()
priv2, _ := wgtypes.GeneratePrivateKey()
priv3, _ := wgtypes.GeneratePrivateKey()
p1 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{
ClientID: "c1",
PrivateKey: priv1.String(),
LocalAddress: "10.0.0.1/32",
}}
p1.ID = 1
p2 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{
ClientID: "c2",
PrivateKey: priv2.String(),
LocalAddress: "10.0.0.2/32",
}}
p2.ID = 2
p3 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{
ClientID: "c3",
PrivateKey: priv3.String(),
LocalAddress: "10.0.0.3/32",
}}
p3.ID = 3
idToPeer := map[uint]*models.WireGuard{1: p1, 2: p2, 3: p3}
order := []uint{1, 2, 3}
// 全双向连通,但设置权重让 2->1 更偏好 2->3->1
adj := map[uint][]Edge{
1: {
{to: 2, latency: 1, upMbps: 50, explicit: true},
{to: 3, latency: 100, upMbps: 50, explicit: true},
},
2: {
{to: 1, latency: 100, upMbps: 50, explicit: true},
{to: 3, latency: 1, upMbps: 50, explicit: true},
},
3: {
{to: 1, latency: 1, upMbps: 50, explicit: true},
{to: 2, latency: 100, upMbps: 50, explicit: true},
},
}
cache := &fakeTopologyCache{
rt: map[uint]*pb.WGDeviceRuntimeInfo{
1: {Peers: []*pb.WGPeerRuntimeInfo{{ClientId: "c2", LastHandshakeTimeSec: uint64(now)}}},
2: {Peers: []*pb.WGPeerRuntimeInfo{{ClientId: "c1", LastHandshakeTimeSec: uint64(now)}}},
},
}
policy := RoutingPolicy{
LatencyWeight: 1,
InverseBandwidthWeight: 0,
HopWeight: 0,
HandshakeStaleThreshold: 1 * time.Hour,
HandshakeStalePenalty: 0,
NetworkTopologyCache: cache,
}
aggByNode, edgeInfo := runAllPairsDijkstra(order, adj, idToPeer, policy)
peersMap, err := assemblePeerConfigs(order, aggByNode, edgeInfo, idToPeer)
if err != nil {
t.Fatalf("assemblePeerConfigs err: %v", err)
}
fillIsolates(order, peersMap)
// 预期在没有对称补齐前2 可能不会包含 peer(1)
_ = ensureRoutingPeerSymmetry(order, peersMap, idToPeer)
found := false
for _, pc := range peersMap[2] {
if pc != nil && pc.GetId() == 1 {
found = true
if len(pc.GetAllowedIps()) == 0 || pc.GetAllowedIps()[0] != "10.0.0.1/32" {
t.Fatalf("peer(1) on node2 should include 10.0.0.1/32, got=%v", pc.GetAllowedIps())
}
}
}
if !found {
t.Fatalf("node2 should contain peer(1) after ensureRoutingPeerSymmetry")
}
}

View File

@@ -4,67 +4,18 @@
package wg
import (
"context"
"errors"
"fmt"
"net/http"
"net/netip"
"os"
"reflect"
"sync"
"time"
"unsafe"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"github.com/VaalaCat/frp-panel/defs"
"github.com/VaalaCat/frp-panel/pb"
"github.com/VaalaCat/frp-panel/services/app"
"github.com/VaalaCat/frp-panel/services/wg/multibind"
"github.com/VaalaCat/frp-panel/services/wg/transport/ws"
"github.com/VaalaCat/frp-panel/utils"
)
const (
ReportInterval = time.Second * 60
)
var (
_ app.WireGuard = (*wireGuard)(nil)
)
type wireGuard struct {
sync.RWMutex
ifce *defs.WireGuardConfig
endpointPingMap *utils.SyncMap[uint32, uint32] // ms
virtAddrPingMap *utils.SyncMap[string, uint32] // ms
wgDevice *device.Device
tunDevice tun.Device
multiBind *multibind.MultiBind
gvisorNet *netstack.Net
fwManager *firewallManager
running bool
useGvisorNet bool // if true, use gvisor netstack
svcLogger *logrus.Entry
ctx *app.Context
cancel context.CancelFunc
}
func NewWireGuard(ctx *app.Context, ifce defs.WireGuardConfig, logger *logrus.Entry) (app.WireGuard, error) {
if logger == nil {
defaultLog := logrus.New()
@@ -86,7 +37,6 @@ func NewWireGuard(ctx *app.Context, ifce defs.WireGuardConfig, logger *logrus.En
fwManager := newFirewallManager(logger.WithField("component", "iptables"))
return &wireGuard{
RWMutex: sync.RWMutex{},
ifce: &cfg,
ctx: svcCtx,
cancel: cancel,
@@ -95,6 +45,8 @@ func NewWireGuard(ctx *app.Context, ifce defs.WireGuardConfig, logger *logrus.En
useGvisorNet: useGvisorNet,
virtAddrPingMap: &utils.SyncMap[string, uint32]{},
fwManager: fwManager,
peerDirectory: make(map[uint32]*pb.WireGuardPeerConfig, 64),
preconnectPeers: make(map[uint32]struct{}, 64),
}, nil
}
@@ -191,6 +143,7 @@ func (w *wireGuard) AddPeer(peer *defs.WireGuardPeerConfig) error {
defer w.Unlock()
w.ifce.Peers = append(w.ifce.Peers, peer.WireGuardPeerConfig)
w.indexPeerDirectoryLocked(peerCfg)
uapiBuilder := NewUAPIBuilder().AddPeerConfig(peerCfg)
log.Debugf("uapiBuilder: %s", uapiBuilder.Build())
@@ -199,6 +152,9 @@ func (w *wireGuard) AddPeer(peer *defs.WireGuardPeerConfig) error {
return errors.Join(errors.New("add peer IpcSet error"), err)
}
// 补齐本节点可连接的常驻 peer若目录里有基础信息
w.onPeersChangedLocked("after AddPeer")
return nil
}
@@ -248,23 +204,27 @@ func (w *wireGuard) RemovePeer(peerNameOrPk string) error {
w.Lock()
defer w.Unlock()
newPeers := []*pb.WireGuardPeerConfig{}
var peerToRemove *defs.WireGuardPeerConfig
// 语义:真正移除 peer下发 remove=true并从本地配置中删除。
// 如需要“只移除路由但保持连接”,请使用 PatchPeers/UpdatePeer 下发 AllowedIPs=nil 的更新策略。
var removedPeerPB *pb.WireGuardPeerConfig
newPeers := make([]*pb.WireGuardPeerConfig, 0, len(w.ifce.Peers))
for _, p := range w.ifce.Peers {
if p.ClientId != peerNameOrPk && p.PublicKey != peerNameOrPk {
newPeers = append(newPeers, p)
if p.ClientId == peerNameOrPk || p.PublicKey == peerNameOrPk {
removedPeerPB = p
continue
}
peerToRemove = &defs.WireGuardPeerConfig{WireGuardPeerConfig: p}
newPeers = append(newPeers, p)
}
if len(newPeers) == len(w.ifce.Peers) {
if removedPeerPB == nil {
return errors.New("peer not found")
}
w.ifce.Peers = newPeers
removedPeer := &defs.WireGuardPeerConfig{WireGuardPeerConfig: removedPeerPB}
log.Debugf("remove peer completely: key=%s pk=%s", truncate(peerNameOrPk, 10), truncate(removedPeer.GetPublicKey(), 10))
uapiBuilder := NewUAPIBuilder().RemovePeerByKey(peerToRemove.GetParsedPublicKey())
uapiBuilder := NewUAPIBuilder().RemovePeerByKey(removedPeer.GetParsedPublicKey())
log.Debugf("uapiBuilder: %s", uapiBuilder.Build())
@@ -272,6 +232,22 @@ func (w *wireGuard) RemovePeer(peerNameOrPk string) error {
return errors.Join(errors.New("remove peer IpcSet error"), err)
}
// IpcSet 成功后再更新本地缓存,避免不一致
w.ifce.Peers = newPeers
w.deletePeerDirectoryLocked(removedPeer)
if id := removedPeer.GetId(); id != 0 {
delete(w.preconnectPeers, id)
}
if removedPeer.GetEndpoint() != nil {
if id := removedPeer.GetEndpoint().GetWireguardId(); id != 0 {
delete(w.preconnectPeers, id)
}
}
w.cleanupPreconnectPeersLocked()
// 移除后也补齐其他可连接 peer若目录里有基础信息
w.onPeersChangedLocked("after RemovePeer")
return nil
}
@@ -297,6 +273,7 @@ func (w *wireGuard) UpdatePeer(peer *defs.WireGuardPeerConfig) error {
}
w.ifce.Peers = newPeers
w.indexPeerDirectoryLocked(peerCfg)
uapiBuilder := NewUAPIBuilder().UpdatePeerConfig(peerCfg)
@@ -306,6 +283,9 @@ func (w *wireGuard) UpdatePeer(peer *defs.WireGuardPeerConfig) error {
return errors.Join(errors.New("update peer IpcSet error"), err)
}
// 更新后补齐常驻 peer若目录里有基础信息
w.onPeersChangedLocked("after UpdatePeer")
return nil
}
@@ -323,6 +303,24 @@ func (w *wireGuard) PatchPeers(newPeers []*defs.WireGuardPeerConfig) (*app.WireG
return nil, err
}
// 优化:从 adj 中提取“本节点可直连/可连接”的 peer确保它们常驻但不分配 AllowedIPs。
// 这样后续路由变化只需要更新 AllowedIPs不需要 remove+add 重新建立连接。
beforeMerge := len(typedNewPeers)
typedNewPeers = mergeConnectablePeersFromAdj(w.ifce, typedNewPeers, oldPeers)
if delta := len(typedNewPeers) - beforeMerge; delta > 0 {
log.Debugf("merged connectable peers from adjs: +%d (desired=%d -> %d)", delta, beforeMerge, len(typedNewPeers))
}
// 更新 peerDirectory把本次看到的 peer 基础信息都缓存起来,后续 adjs 变化可用于补齐常驻 peer
w.Lock()
for _, p := range typedNewPeers {
if p == nil {
continue
}
w.indexPeerDirectoryLocked(p)
}
w.Unlock()
oldByPK := make(map[string]*defs.WireGuardPeerConfig, len(oldPeers))
for _, p := range oldPeers {
if p == nil || p.GetPublicKey() == "" {
@@ -398,6 +396,12 @@ func (w *wireGuard) PatchPeers(newPeers []*defs.WireGuardPeerConfig) (*app.WireG
}
w.ifce.Peers = newPBPeers
// 清理 preconnectPeers 中已不存在的 peer id避免无限增长
w.cleanupPreconnectPeersLocked()
// PatchPeers 后再次补齐主要覆盖adjs 先于 peers 变化、且目录已有信息的场景)
w.onPeersChangedLocked("after PatchPeers")
return resp, nil
}
@@ -456,7 +460,9 @@ func (w *wireGuard) UpdateAdjs(adjs map[uint32]*pb.WireGuardLinks) error {
defer w.Unlock()
w.ifce.Adjs = adjs
return nil
// adjs 变化后立刻补齐“可连接 peer 常驻”(只要 peerDirectory 中已有其基础信息)
return w.ensureConnectablePeersLocked()
}
func (w *wireGuard) NeedRecreate(newCfg *defs.WireGuardConfig) bool {
@@ -470,310 +476,3 @@ func (w *wireGuard) NeedRecreate(newCfg *defs.WireGuardConfig) bool {
w.ifce.GetUseGvisorNet() != newCfg.GetUseGvisorNet() ||
w.ifce.GetNetworkId() != newCfg.GetNetworkId()
}
func (w *wireGuard) initTransports() error {
log := w.svcLogger.WithField("op", "initTransports")
wsTrans := ws.NewWSBind(w.ctx)
w.multiBind = multibind.NewMultiBind(
w.svcLogger,
multibind.NewTransport(conn.NewDefaultBind(), "udp"),
multibind.NewTransport(wsTrans, "ws"),
)
engine := gin.New()
engine.Any(defs.DefaultWSHandlerPath, func(c *gin.Context) {
err := wsTrans.HandleHTTP(c.Writer, c.Request)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
})
// if ws listen port not set, use wg listen port, share tcp and udp port
listenPort := w.ifce.GetWsListenPort()
if listenPort == 0 {
listenPort = w.ifce.GetListenPort()
}
go func() {
if err := engine.Run(fmt.Sprintf(":%d", listenPort)); err != nil {
w.svcLogger.WithError(err).Errorf("failed to run gin engine for ws transport on port %d", listenPort)
}
}()
log.Infof("WS transport engine running on port %d", listenPort)
return nil
}
func (w *wireGuard) initWGDevice() error {
log := w.svcLogger.WithField("op", "initWGDevice")
log.Debugf("start to create TUN device '%s' (MTU %d)", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu())
var err error
if w.useGvisorNet {
log.Infof("using gvisor netstack for TUN device")
prf, err := netip.ParsePrefix(w.ifce.GetLocalAddress())
if err != nil {
return errors.Join(fmt.Errorf("parse local addr '%s' for netip", w.ifce.GetLocalAddress()), err)
}
addrs := lo.Map(w.ifce.GetDnsServers(), func(s string, _ int) netip.Addr {
addr, err := netip.ParseAddr(s)
if err != nil {
return netip.Addr{}
}
return addr
})
if len(addrs) == 0 {
addrs = []netip.Addr{netip.AddrFrom4([4]byte{1, 2, 4, 8})}
}
log.Debugf("create netstack TUN with addr '%s' and dns servers '%v'", prf.Addr().String(), addrs)
w.tunDevice, w.gvisorNet, err = netstack.CreateNetTUN([]netip.Addr{prf.Addr()}, addrs, 1200)
if err != nil {
return errors.Join(fmt.Errorf("create netstack TUN device '%s' (MTU %d) failed", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu()), err)
}
} else {
w.tunDevice, err = tun.CreateTUN(w.ifce.GetInterfaceName(), int(w.ifce.GetInterfaceMtu()))
if err != nil {
return errors.Join(fmt.Errorf("create TUN device '%s' (MTU %d) failed", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu()), err)
}
}
log.Debugf("TUN device '%s' (MTU %d) created successfully", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu())
log.Debugf("start to create WireGuard device '%s'", w.ifce.GetInterfaceName())
w.wgDevice = device.NewDevice(w.tunDevice, w.multiBind, &device.Logger{
Verbosef: w.svcLogger.WithField("wg-dev-iface", w.ifce.GetInterfaceName()).Debugf,
Errorf: w.svcLogger.WithField("wg-dev-iface", w.ifce.GetInterfaceName()).Errorf,
})
log.Debugf("WireGuard device '%s' created successfully", w.ifce.GetInterfaceName())
return nil
}
func (w *wireGuard) applyPeerConfig() error {
log := w.svcLogger.WithField("op", "applyConfig")
log.Debugf("start to apply config to WireGuard device '%s'", w.ifce.GetInterfaceName())
if w.wgDevice == nil {
return errors.New("wgDevice is nil, please init WG device first")
}
wgTypedPeerConfigs, err := parseAndValidatePeerConfigs(w.ifce.GetParsedPeers())
if err != nil {
return errors.Join(errors.New("parse/validate peers"), err)
}
log.Debugf("wgTypedPeerConfigs: %v", wgTypedPeerConfigs)
uapiConfigString := generateUAPIConfigString(w.ifce, w.ifce.GetParsedPrivKey(), wgTypedPeerConfigs, !w.running, false)
log.Debugf("uapiBuilder: %s", uapiConfigString)
log.Debugf("calling IpcSet...")
if err = w.wgDevice.IpcSet(uapiConfigString); err != nil {
return errors.Join(errors.New("IpcSet error"), err)
}
log.Debugf("IpcSet completed successfully")
return nil
}
func (w *wireGuard) initNetwork() error {
log := w.svcLogger.WithField("op", "initNetwork")
// 等待 TUN 设备在内核中完全注册,避免竞态条件
var link netlink.Link
var err error
maxRetries := 10
for i := 0; i < maxRetries; i++ {
link, err = netlink.LinkByName(w.ifce.GetInterfaceName())
if err == nil {
break
}
if i < maxRetries-1 {
log.Debugf("attempt %d: waiting for iface '%s' to be ready, will retry...", i+1, w.ifce.GetInterfaceName())
time.Sleep(100 * time.Millisecond)
}
}
if err != nil {
return errors.Join(fmt.Errorf("get iface '%s' via netlink after %d retries", w.ifce.GetInterfaceName(), maxRetries), err)
}
log.Debugf("successfully found interface '%s' via netlink", w.ifce.GetInterfaceName())
addr, err := netlink.ParseAddr(w.ifce.GetLocalAddress())
if err != nil {
return errors.Join(fmt.Errorf("parse local addr '%s' for netlink", w.ifce.GetLocalAddress()), err)
}
if err = netlink.AddrAdd(link, addr); err != nil && !os.IsExist(err) {
return errors.Join(fmt.Errorf("add IP '%s' to '%s'", w.ifce.GetLocalAddress(), w.ifce.GetInterfaceName()), err)
} else if os.IsExist(err) {
log.Infof("IP %s already on '%s'.", w.ifce.GetLocalAddress(), w.ifce.GetInterfaceName())
} else {
log.Infof("IP %s added to '%s'.", w.ifce.GetLocalAddress(), w.ifce.GetInterfaceName())
}
if err = netlink.LinkSetMTU(link, int(w.ifce.GetInterfaceMtu())); err != nil {
log.Warnf("Set MTU %d on '%s' via netlink: %v. TUN MTU is %d.",
w.ifce.GetInterfaceMtu(), w.ifce.GetInterfaceName(), err, w.ifce.GetInterfaceMtu())
} else {
log.Infof("Iface '%s' MTU %d set via netlink.", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu())
}
if err = netlink.LinkSetUp(link); err != nil {
return errors.Join(fmt.Errorf("bring up iface '%s' via netlink", w.ifce.GetInterfaceName()), err)
}
log.Infof("Iface '%s' up via netlink.", w.ifce.GetInterfaceName())
return nil
}
func (w *wireGuard) initGvisorNetwork() error {
log := w.svcLogger.WithField("op", "initGvisorNetwork")
if w.gvisorNet == nil {
return errors.New("gvisorNet is nil, cannot initialize network")
}
// wg-go dose not expose the stack field, so we need to use reflection to access it
netValue := reflect.ValueOf(w.gvisorNet).Elem()
stackField := netValue.FieldByName("stack")
if !stackField.IsValid() {
return errors.New("cannot find stack field in gvisorNet")
}
stackPtrValue := reflect.NewAt(stackField.Type(), unsafe.Pointer(stackField.UnsafeAddr())).Elem()
if !stackPtrValue.IsValid() || stackPtrValue.IsNil() {
return errors.New("gvisor stack is nil or invalid")
}
gvisorStack := stackPtrValue.Interface().(*stack.Stack)
if gvisorStack == nil {
return errors.New("gvisor stack is nil after conversion")
}
log.Infof("successfully accessed gvisor stack, enabling IP forwarding")
if err := gvisorStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
log.Warnf("failed to enable IPv4 forwarding: %v, relay may not work", err)
} else {
log.Infof("IPv4 forwarding enabled for gvisor netstack")
}
if err := gvisorStack.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
log.Warnf("failed to enable IPv6 forwarding: %v", err)
} else {
log.Infof("IPv6 forwarding enabled for gvisor netstack")
}
for _, peer := range w.ifce.Peers {
for _, allowedIP := range peer.AllowedIps {
prefix, err := netip.ParsePrefix(allowedIP)
if err != nil {
log.WithError(err).Warnf("failed to parse allowed IP: %s", allowedIP)
continue
}
addr := tcpip.AddrFromSlice(prefix.Addr().AsSlice())
ones := prefix.Bits()
maskBytes := make([]byte, len(prefix.Addr().AsSlice()))
for i := 0; i < len(maskBytes); i++ {
if ones >= 8 {
maskBytes[i] = 0xff
ones -= 8
} else if ones > 0 {
maskBytes[i] = byte(0xff << (8 - ones))
ones = 0
}
}
subnet, err := tcpip.NewSubnet(addr, tcpip.MaskFromBytes(maskBytes))
if err != nil {
log.WithError(err).Warnf("failed to create subnet for %s", allowedIP)
continue
}
route := tcpip.Route{
Destination: subnet,
NIC: 1,
}
gvisorStack.AddRoute(route)
log.Debugf("added route for peer allowed IP: %s via NIC 1", allowedIP)
}
}
log.Infof("gvisor netstack initialized with IP forwarding enabled")
return nil
}
func (w *wireGuard) cleanupNetwork() {
log := w.svcLogger.WithField("op", "cleanupNetwork")
if w.useGvisorNet {
log.Infof("skip network cleanup for gvisor netstack")
return
}
link, err := netlink.LinkByName(w.ifce.GetInterfaceName())
if err == nil {
if err := netlink.LinkSetDown(link); err != nil {
log.Warnf("Failed to LinkSetDown '%s' after wgDevice.Up() error: %v", w.ifce.GetInterfaceName(), err)
}
}
log.Debug("Cleanup network complete.")
}
func (w *wireGuard) cleanupWGDevice() {
log := w.svcLogger.WithField("op", "cleanupWGDevice")
if w.wgDevice != nil {
w.wgDevice.Close()
} else if w.tunDevice != nil {
w.tunDevice.Close()
}
w.wgDevice = nil
w.tunDevice = nil
log.Debug("Cleanup WG device complete.")
}
func (w *wireGuard) applyFirewallRulesLocked() error {
if w.useGvisorNet || w.fwManager == nil {
return nil
}
prefix, err := netip.ParsePrefix(w.ifce.GetLocalAddress())
if err != nil {
return errors.Join(fmt.Errorf("parse local address '%s' for firewall", w.ifce.GetLocalAddress()), err)
}
return w.fwManager.ApplyRelayRules(w.ifce.GetInterfaceName(), prefix.Masked().String())
}
func (w *wireGuard) cleanupFirewallRulesLocked() error {
if w.useGvisorNet || w.fwManager == nil {
return nil
}
return w.fwManager.Cleanup(w.ifce.GetInterfaceName())
}
func (w *wireGuard) reportStatusTask() {
for {
select {
case <-w.ctx.Done():
return
default:
w.pingPeers()
time.Sleep(ReportInterval)
}
}
}

View File

@@ -0,0 +1,108 @@
//go:build !windows
// +build !windows
package wg
import (
"errors"
"fmt"
"net/netip"
"github.com/samber/lo"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
)
func (w *wireGuard) initWGDevice() error {
log := w.svcLogger.WithField("op", "initWGDevice")
log.Debugf("start to create TUN device '%s' (MTU %d)", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu())
var err error
if w.useGvisorNet {
log.Infof("using gvisor netstack for TUN device")
prf, err := netip.ParsePrefix(w.ifce.GetLocalAddress())
if err != nil {
return errors.Join(fmt.Errorf("parse local addr '%s' for netip", w.ifce.GetLocalAddress()), err)
}
addrs := lo.Map(w.ifce.GetDnsServers(), func(s string, _ int) netip.Addr {
addr, err := netip.ParseAddr(s)
if err != nil {
return netip.Addr{}
}
return addr
})
if len(addrs) == 0 {
addrs = []netip.Addr{netip.AddrFrom4([4]byte{1, 2, 4, 8})}
}
log.Debugf("create netstack TUN with addr '%s' and dns servers '%v'", prf.Addr().String(), addrs)
w.tunDevice, w.gvisorNet, err = netstack.CreateNetTUN([]netip.Addr{prf.Addr()}, addrs, 1200)
if err != nil {
return errors.Join(fmt.Errorf("create netstack TUN device '%s' (MTU %d) failed", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu()), err)
}
} else {
w.tunDevice, err = tun.CreateTUN(w.ifce.GetInterfaceName(), int(w.ifce.GetInterfaceMtu()))
if err != nil {
return errors.Join(fmt.Errorf("create TUN device '%s' (MTU %d) failed", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu()), err)
}
}
log.Debugf("TUN device '%s' (MTU %d) created successfully", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu())
log.Debugf("start to create WireGuard device '%s'", w.ifce.GetInterfaceName())
w.wgDevice = device.NewDevice(w.tunDevice, w.multiBind, &device.Logger{
Verbosef: w.svcLogger.WithField("wg-dev-iface", w.ifce.GetInterfaceName()).Debugf,
Errorf: w.svcLogger.WithField("wg-dev-iface", w.ifce.GetInterfaceName()).Errorf,
})
log.Debugf("WireGuard device '%s' created successfully", w.ifce.GetInterfaceName())
return nil
}
func (w *wireGuard) applyPeerConfig() error {
log := w.svcLogger.WithField("op", "applyConfig")
log.Debugf("start to apply config to WireGuard device '%s'", w.ifce.GetInterfaceName())
if w.wgDevice == nil {
return errors.New("wgDevice is nil, please init WG device first")
}
wgTypedPeerConfigs, err := parseAndValidatePeerConfigs(w.ifce.GetParsedPeers())
if err != nil {
return errors.Join(errors.New("parse/validate peers"), err)
}
log.Debugf("wgTypedPeerConfigs: %v", wgTypedPeerConfigs)
uapiConfigString := generateUAPIConfigString(w.ifce, w.ifce.GetParsedPrivKey(), wgTypedPeerConfigs, !w.running, false)
log.Debugf("uapiBuilder: %s", uapiConfigString)
log.Debugf("calling IpcSet...")
if err = w.wgDevice.IpcSet(uapiConfigString); err != nil {
return errors.Join(errors.New("IpcSet error"), err)
}
log.Debugf("IpcSet completed successfully")
return nil
}
func (w *wireGuard) cleanupWGDevice() {
log := w.svcLogger.WithField("op", "cleanupWGDevice")
if w.wgDevice != nil {
w.wgDevice.Close()
} else if w.tunDevice != nil {
w.tunDevice.Close()
}
w.wgDevice = nil
w.tunDevice = nil
log.Debug("Cleanup WG device complete.")
}

View File

@@ -0,0 +1,30 @@
//go:build !windows
// +build !windows
package wg
import (
"errors"
"fmt"
"net/netip"
)
func (w *wireGuard) applyFirewallRulesLocked() error {
if w.useGvisorNet || w.fwManager == nil {
return nil
}
prefix, err := netip.ParsePrefix(w.ifce.GetLocalAddress())
if err != nil {
return errors.Join(fmt.Errorf("parse local address '%s' for firewall", w.ifce.GetLocalAddress()), err)
}
return w.fwManager.ApplyRelayRules(w.ifce.GetInterfaceName(), prefix.Masked().String())
}
func (w *wireGuard) cleanupFirewallRulesLocked() error {
if w.useGvisorNet || w.fwManager == nil {
return nil
}
return w.fwManager.Cleanup(w.ifce.GetInterfaceName())
}

View File

@@ -0,0 +1,97 @@
//go:build !windows
// +build !windows
package wg
import (
"errors"
"net/netip"
"reflect"
"unsafe"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
func (w *wireGuard) initGvisorNetwork() error {
log := w.svcLogger.WithField("op", "initGvisorNetwork")
if w.gvisorNet == nil {
return errors.New("gvisorNet is nil, cannot initialize network")
}
// wg-go dose not expose the stack field, so we need to use reflection to access it
netValue := reflect.ValueOf(w.gvisorNet).Elem()
stackField := netValue.FieldByName("stack")
if !stackField.IsValid() {
return errors.New("cannot find stack field in gvisorNet")
}
stackPtrValue := reflect.NewAt(stackField.Type(), unsafe.Pointer(stackField.UnsafeAddr())).Elem()
if !stackPtrValue.IsValid() || stackPtrValue.IsNil() {
return errors.New("gvisor stack is nil or invalid")
}
gvisorStack := stackPtrValue.Interface().(*stack.Stack)
if gvisorStack == nil {
return errors.New("gvisor stack is nil after conversion")
}
log.Infof("successfully accessed gvisor stack, enabling IP forwarding")
if err := gvisorStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
log.Warnf("failed to enable IPv4 forwarding: %v, relay may not work", err)
} else {
log.Infof("IPv4 forwarding enabled for gvisor netstack")
}
if err := gvisorStack.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
log.Warnf("failed to enable IPv6 forwarding: %v", err)
} else {
log.Infof("IPv6 forwarding enabled for gvisor netstack")
}
for _, peer := range w.ifce.Peers {
for _, allowedIP := range peer.AllowedIps {
prefix, err := netip.ParsePrefix(allowedIP)
if err != nil {
log.WithError(err).Warnf("failed to parse allowed IP: %s", allowedIP)
continue
}
addr := tcpip.AddrFromSlice(prefix.Addr().AsSlice())
ones := prefix.Bits()
maskBytes := make([]byte, len(prefix.Addr().AsSlice()))
for i := 0; i < len(maskBytes); i++ {
if ones >= 8 {
maskBytes[i] = 0xff
ones -= 8
} else if ones > 0 {
maskBytes[i] = byte(0xff << (8 - ones))
ones = 0
}
}
subnet, err := tcpip.NewSubnet(addr, tcpip.MaskFromBytes(maskBytes))
if err != nil {
log.WithError(err).Warnf("failed to create subnet for %s", allowedIP)
continue
}
route := tcpip.Route{
Destination: subnet,
NIC: 1,
}
gvisorStack.AddRoute(route)
log.Debugf("added route for peer allowed IP: %s via NIC 1", allowedIP)
}
}
log.Infof("gvisor netstack initialized with IP forwarding enabled")
return nil
}

View File

@@ -0,0 +1,80 @@
//go:build !windows
// +build !windows
package wg
import (
"errors"
"fmt"
"os"
"time"
"github.com/vishvananda/netlink"
)
func (w *wireGuard) initNetwork() error {
log := w.svcLogger.WithField("op", "initNetwork")
// 等待 TUN 设备在内核中完全注册,避免竞态条件
var link netlink.Link
var err error
maxRetries := 10
for i := 0; i < maxRetries; i++ {
link, err = netlink.LinkByName(w.ifce.GetInterfaceName())
if err == nil {
break
}
if i < maxRetries-1 {
log.Debugf("attempt %d: waiting for iface '%s' to be ready, will retry...", i+1, w.ifce.GetInterfaceName())
time.Sleep(100 * time.Millisecond)
}
}
if err != nil {
return errors.Join(fmt.Errorf("get iface '%s' via netlink after %d retries", w.ifce.GetInterfaceName(), maxRetries), err)
}
log.Debugf("successfully found interface '%s' via netlink", w.ifce.GetInterfaceName())
addr, err := netlink.ParseAddr(w.ifce.GetLocalAddress())
if err != nil {
return errors.Join(fmt.Errorf("parse local addr '%s' for netlink", w.ifce.GetLocalAddress()), err)
}
if err = netlink.AddrAdd(link, addr); err != nil && !os.IsExist(err) {
return errors.Join(fmt.Errorf("add IP '%s' to '%s'", w.ifce.GetLocalAddress(), w.ifce.GetInterfaceName()), err)
} else if os.IsExist(err) {
log.Infof("IP %s already on '%s'.", w.ifce.GetLocalAddress(), w.ifce.GetInterfaceName())
} else {
log.Infof("IP %s added to '%s'.", w.ifce.GetLocalAddress(), w.ifce.GetInterfaceName())
}
if err = netlink.LinkSetMTU(link, int(w.ifce.GetInterfaceMtu())); err != nil {
log.Warnf("Set MTU %d on '%s' via netlink: %v. TUN MTU is %d.",
w.ifce.GetInterfaceMtu(), w.ifce.GetInterfaceName(), err, w.ifce.GetInterfaceMtu())
} else {
log.Infof("Iface '%s' MTU %d set via netlink.", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu())
}
if err = netlink.LinkSetUp(link); err != nil {
return errors.Join(fmt.Errorf("bring up iface '%s' via netlink", w.ifce.GetInterfaceName()), err)
}
log.Infof("Iface '%s' up via netlink.", w.ifce.GetInterfaceName())
return nil
}
func (w *wireGuard) cleanupNetwork() {
log := w.svcLogger.WithField("op", "cleanupNetwork")
if w.useGvisorNet {
log.Infof("skip network cleanup for gvisor netstack")
return
}
link, err := netlink.LinkByName(w.ifce.GetInterfaceName())
if err == nil {
if err := netlink.LinkSetDown(link); err != nil {
log.Warnf("Failed to LinkSetDown '%s' after wgDevice.Up() error: %v", w.ifce.GetInterfaceName(), err)
}
}
log.Debug("Cleanup network complete.")
}

View File

@@ -0,0 +1,133 @@
package wg
import (
"testing"
"github.com/VaalaCat/frp-panel/defs"
"github.com/VaalaCat/frp-panel/pb"
)
func TestMergeConnectablePeersFromAdj_AddsMissingPeerWithEmptyAllowedIPs(t *testing.T) {
ifce := &defs.WireGuardConfig{WireGuardConfig: &pb.WireGuardConfig{
Id: 1,
Adjs: map[uint32]*pb.WireGuardLinks{
1: {Links: []*pb.WireGuardLink{
{ToWireguardId: 2, ToEndpoint: &pb.Endpoint{Host: "1.2.3.4", Port: 51820}},
}},
},
}}
desired := []*defs.WireGuardPeerConfig{
{WireGuardPeerConfig: &pb.WireGuardPeerConfig{
Id: 3,
PublicKey: "pk-3",
AllowedIps: []string{
"10.0.0.3/32",
},
}},
}
known := []*defs.WireGuardPeerConfig{
{WireGuardPeerConfig: &pb.WireGuardPeerConfig{
Id: 2,
PublicKey: "pk-2",
AllowedIps: []string{
"10.0.0.2/32",
},
PersistentKeepalive: 0, // should be defaulted by parseAndValidatePeerConfig
}},
}
got := mergeConnectablePeersFromAdj(ifce, desired, known)
var found *defs.WireGuardPeerConfig
for _, p := range got {
if p != nil && p.GetId() == 2 {
found = p
break
}
}
if found == nil {
t.Fatalf("expected peer id=2 to be added")
}
if len(found.GetAllowedIps()) != 0 {
t.Fatalf("expected allowed_ips empty, got=%v", found.GetAllowedIps())
}
if found.GetEndpoint() == nil || found.GetEndpoint().GetHost() != "1.2.3.4" || found.GetEndpoint().GetPort() != 51820 {
t.Fatalf("expected endpoint from adj to be applied, got=%v", found.GetEndpoint())
}
if found.GetPersistentKeepalive() == 0 {
t.Fatalf("expected persistent_keepalive defaulted, got=0")
}
}
func TestMergeConnectablePeersFromAdj_DoesNotOverrideExistingDesiredPeer(t *testing.T) {
ifce := &defs.WireGuardConfig{WireGuardConfig: &pb.WireGuardConfig{
Id: 1,
Adjs: map[uint32]*pb.WireGuardLinks{
1: {Links: []*pb.WireGuardLink{
{ToWireguardId: 2, ToEndpoint: &pb.Endpoint{Host: "9.9.9.9", Port: 9999}},
}},
},
}}
desired := []*defs.WireGuardPeerConfig{
{WireGuardPeerConfig: &pb.WireGuardPeerConfig{
Id: 2,
PublicKey: "pk-2",
AllowedIps: []string{
"10.0.0.2/32",
},
Endpoint: &pb.Endpoint{Host: "1.1.1.1", Port: 1111},
}},
}
got := mergeConnectablePeersFromAdj(ifce, desired, nil)
if len(got) != 1 {
t.Fatalf("expected no extra peers added, got len=%d", len(got))
}
if len(got[0].GetAllowedIps()) != 1 || got[0].GetAllowedIps()[0] != "10.0.0.2/32" {
t.Fatalf("expected desired allowed_ips preserved, got=%v", got[0].GetAllowedIps())
}
// endpoint should not be overridden because peer already existed in desired list
if got[0].GetEndpoint() == nil || got[0].GetEndpoint().GetHost() != "1.1.1.1" {
t.Fatalf("expected desired endpoint preserved, got=%v", got[0].GetEndpoint())
}
}
func TestMergeConnectablePeersFromAdj_UsesEndpointWireguardIDWhenPeerIDMissing(t *testing.T) {
ifce := &defs.WireGuardConfig{WireGuardConfig: &pb.WireGuardConfig{
Id: 1,
Adjs: map[uint32]*pb.WireGuardLinks{
1: {Links: []*pb.WireGuardLink{
{ToWireguardId: 2, ToEndpoint: &pb.Endpoint{Host: "2.2.2.2", Port: 2222}},
}},
},
}}
desired := []*defs.WireGuardPeerConfig{}
known := []*defs.WireGuardPeerConfig{
{WireGuardPeerConfig: &pb.WireGuardPeerConfig{
Id: 0, // 模拟peer.id 未下发
PublicKey: "pk-2",
Endpoint: &pb.Endpoint{WireguardId: 2, Host: "old", Port: 1},
}},
}
got := mergeConnectablePeersFromAdj(ifce, desired, known)
if len(got) != 1 {
t.Fatalf("expected 1 peer added, got len=%d", len(got))
}
if got[0].GetPublicKey() != "pk-2" {
t.Fatalf("expected pk-2, got=%s", got[0].GetPublicKey())
}
// endpoint should be overridden by adj's to_endpoint
if got[0].GetEndpoint() == nil || got[0].GetEndpoint().GetHost() != "2.2.2.2" || got[0].GetEndpoint().GetPort() != 2222 {
t.Fatalf("expected endpoint from adj, got=%v", got[0].GetEndpoint())
}
if len(got[0].GetAllowedIps()) != 0 {
t.Fatalf("expected allowed_ips empty, got=%v", got[0].GetAllowedIps())
}
}

View File

@@ -0,0 +1,362 @@
//go:build !windows
// +build !windows
package wg
import (
"fmt"
"google.golang.org/protobuf/proto"
"github.com/VaalaCat/frp-panel/defs"
"github.com/VaalaCat/frp-panel/pb"
)
// peer 预连接/常驻补齐逻辑
//
// 目标:在拓扑(adjs)变化时,确保本节点对“可直连/可连接”的 peer 已配置到 wg 设备,
// 但 AllowedIPs 为空(只保持连接,不承载路由),从而避免路由变化导致频繁 remove+add 造成断链。
func (w *wireGuard) onPeersChangedLocked(reason string) {
if w == nil {
return
}
if err := w.ensureConnectablePeersLocked(); err != nil {
w.svcLogger.WithError(err).WithField("op", "onPeersChanged").Warnf("ensure connectable peers failed (%s)", reason)
}
}
func (w *wireGuard) cleanupPreconnectPeersLocked() {
if w == nil {
return
}
if len(w.preconnectPeers) == 0 {
return
}
exists := make(map[uint32]struct{}, len(w.ifce.Peers))
for _, p := range w.ifce.GetParsedPeers() {
if p == nil {
continue
}
if id := p.GetId(); id != 0 {
exists[id] = struct{}{}
}
if p.GetEndpoint() != nil {
if id := p.GetEndpoint().GetWireguardId(); id != 0 {
exists[id] = struct{}{}
}
}
}
for id := range w.preconnectPeers {
if _, ok := exists[id]; !ok {
delete(w.preconnectPeers, id)
}
}
}
func (w *wireGuard) indexPeerDirectoryLocked(p *defs.WireGuardPeerConfig) {
if p == nil || p.WireGuardPeerConfig == nil {
return
}
// 1) peer.id
if id := p.GetId(); id != 0 {
w.peerDirectory[id] = proto.Clone(p.WireGuardPeerConfig).(*pb.WireGuardPeerConfig)
}
// 2) endpoint.wireguard_id部分场景 peer.id 可能未填,但 endpoint 带 wireguard_id
if p.GetEndpoint() != nil {
if id := p.GetEndpoint().GetWireguardId(); id != 0 {
w.peerDirectory[id] = proto.Clone(p.WireGuardPeerConfig).(*pb.WireGuardPeerConfig)
}
}
}
func (w *wireGuard) deletePeerDirectoryLocked(p *defs.WireGuardPeerConfig) {
if p == nil {
return
}
if id := p.GetId(); id != 0 {
delete(w.peerDirectory, id)
}
if p.GetEndpoint() != nil {
if id := p.GetEndpoint().GetWireguardId(); id != 0 {
delete(w.peerDirectory, id)
}
}
}
// ensureConnectablePeersLocked 保证本节点在 wg 设备里已配置所有“当前可直连/可连接”的 peer
// 但这些补齐 peer 的 AllowedIPs 为空(只保持连接,不承载路由)。
//
// 约束:必须在持有 w.Lock() 的情况下调用。
func (w *wireGuard) ensureConnectablePeersLocked() error {
if w == nil || w.ifce == nil || w.wgDevice == nil {
return nil
}
localID := w.ifce.GetId()
if localID == 0 {
return nil
}
adjs := w.ifce.GetAdjs()
if adjs == nil {
return nil
}
localLinks, ok := adjs[localID]
if !ok || localLinks == nil || len(localLinks.GetLinks()) == 0 {
return nil
}
// 当前可直连/可连接的 peer id 集合(来自 adj
connectable := make(map[uint32]struct{}, len(localLinks.GetLinks()))
for _, l := range localLinks.GetLinks() {
if l == nil {
continue
}
toID := l.GetToWireguardId()
if toID == 0 || toID == localID {
continue
}
connectable[toID] = struct{}{}
}
log := w.svcLogger.WithField("op", "ensureConnectablePeers")
log.Debugf("ensure connectable peers: local=%d connectable=%d peers=%d preconnect=%d directory=%d",
localID, len(connectable), len(w.ifce.Peers), len(w.preconnectPeers), len(w.peerDirectory))
// 当前已配置的 peer用 peer.id 与 endpoint.wireguard_id 双索引,避免 peer.id 缺失导致重复补齐
exists := make(map[uint32]struct{}, len(w.ifce.Peers))
for _, p := range w.ifce.GetParsedPeers() {
if p == nil {
continue
}
if id := p.GetId(); id != 0 {
exists[id] = struct{}{}
}
if p.GetEndpoint() != nil {
if id := p.GetEndpoint().GetWireguardId(); id != 0 {
exists[id] = struct{}{}
}
}
}
uapiBuilder := NewUAPIBuilder()
added := 0
removed := 0
skippedNoBase := 0
skippedAlready := 0
// 先清理:相比上次,本次拓扑中已“完全不可直连”的 peer需要彻底从设备移除
// 仅清理“AllowedIPs 为空”的 peer也就是不承载路由、只为保持连接而存在的 peer
newPeers := make([]*pb.WireGuardPeerConfig, 0, len(w.ifce.Peers))
for _, raw := range w.ifce.GetParsedPeers() {
if raw == nil || raw.WireGuardPeerConfig == nil {
continue
}
// 仅对 AllowedIPs 为空的 peer 做自动清理
if len(raw.GetAllowedIps()) != 0 {
newPeers = append(newPeers, raw.WireGuardPeerConfig)
continue
}
var peerID uint32
if raw.GetId() != 0 {
peerID = raw.GetId()
} else if raw.GetEndpoint() != nil && raw.GetEndpoint().GetWireguardId() != 0 {
peerID = raw.GetEndpoint().GetWireguardId()
}
// 无法识别 peer id保守起见不清理
if peerID == 0 {
newPeers = append(newPeers, raw.WireGuardPeerConfig)
continue
}
// 当前拓扑不可直连:彻底移除
if _, ok := connectable[peerID]; !ok {
log.Debugf("preconnect remove: peerID=%d pk=%s (reason=not_connectable)", peerID, truncate(raw.GetPublicKey(), 10))
uapiBuilder.RemovePeerByKey(raw.GetParsedPublicKey())
delete(w.preconnectPeers, peerID)
delete(exists, peerID)
removed++
continue
}
newPeers = append(newPeers, raw.WireGuardPeerConfig)
}
// 如果有清理发生,先更新本地缓存(设备更新在最后统一 IpcSet
if removed > 0 {
w.ifce.Peers = newPeers
}
for _, l := range localLinks.GetLinks() {
if l == nil {
continue
}
toID := l.GetToWireguardId()
if toID == 0 || toID == localID {
continue
}
if _, ok := exists[toID]; ok {
skippedAlready++
continue
}
base, ok := w.peerDirectory[toID]
if !ok || base == nil || base.GetPublicKey() == "" {
skippedNoBase++
continue
}
cloned := &defs.WireGuardPeerConfig{WireGuardPeerConfig: proto.Clone(base).(*pb.WireGuardPeerConfig)}
cloned.AllowedIps = nil
if l.GetToEndpoint() != nil {
cloned.Endpoint = l.GetToEndpoint()
}
if _, err := parseAndValidatePeerConfig(cloned); err != nil {
continue
}
log.Debugf("preconnect add: peerID=%d pk=%s endpoint=%s",
toID, truncate(cloned.GetPublicKey(), 10), endpointForLog(cloned.GetEndpoint()))
uapiBuilder.AddPeerConfig(cloned)
w.ifce.Peers = append(w.ifce.Peers, cloned.WireGuardPeerConfig)
w.indexPeerDirectoryLocked(cloned)
exists[toID] = struct{}{}
w.preconnectPeers[toID] = struct{}{}
added++
}
if added == 0 && removed == 0 {
log.Debugf("ensure result: no-op (skippedAlready=%d skippedNoBase=%d)", skippedAlready, skippedNoBase)
return nil
}
log.Debugf("ensure result: add=%d remove=%d skippedAlready=%d skippedNoBase=%d", added, removed, skippedAlready, skippedNoBase)
if err := w.wgDevice.IpcSet(uapiBuilder.Build()); err != nil {
log.WithError(err).Debugf("ensure IpcSet failed (add=%d remove=%d)", added, removed)
return err
}
return nil
}
// mergeConnectablePeersFromAdj 将本节点 adj 图中可直连的 peer 合并进目标 peers。
//
// - **只在目标 peers 中缺失时才补齐**(按 PublicKey 去重),避免覆盖由路由规划器计算出的 AllowedIPs。
// - **补齐的 peer AllowedIPs 置空**,确保不会引入额外路由。
// - 如链路显式携带 to_endpoint则优先用它覆盖 peer.endpoint用于快速恢复直连
//
// knownPeers 用于在目标 peers 缺失时提供“可用的 peer 基础信息”(公钥/预共享密钥/端点等),通常传 oldPeers。
func mergeConnectablePeersFromAdj(ifce *defs.WireGuardConfig, desiredPeers []*defs.WireGuardPeerConfig, knownPeers []*defs.WireGuardPeerConfig) []*defs.WireGuardPeerConfig {
if ifce == nil {
return desiredPeers
}
localID := ifce.GetId()
if localID == 0 {
return desiredPeers
}
adjs := ifce.GetAdjs()
if adjs == nil {
return desiredPeers
}
localLinks, ok := adjs[localID]
if !ok || localLinks == nil || len(localLinks.GetLinks()) == 0 {
return desiredPeers
}
desiredByPK := make(map[string]*defs.WireGuardPeerConfig, len(desiredPeers))
for _, p := range desiredPeers {
if p == nil || p.GetPublicKey() == "" {
continue
}
desiredByPK[p.GetPublicKey()] = p
}
// build id -> peer 基础信息索引(优先 desired其次 known
idToPeer := make(map[uint32]*defs.WireGuardPeerConfig, len(desiredPeers)+len(knownPeers))
putPeerIDs := func(p *defs.WireGuardPeerConfig) {
if p == nil {
return
}
// 1) peer.id
if id := p.GetId(); id != 0 {
if _, exists := idToPeer[id]; !exists {
idToPeer[id] = p
}
}
// 2) endpoint.wireguard_id有些下发场景可能不填 peer.id但 endpoint 里带 wireguard_id
if p.GetEndpoint() != nil {
if id := p.GetEndpoint().GetWireguardId(); id != 0 {
if _, exists := idToPeer[id]; !exists {
idToPeer[id] = p
}
}
}
}
for _, p := range desiredPeers {
putPeerIDs(p)
}
for _, p := range knownPeers {
putPeerIDs(p)
}
// 仅补齐adj 中的直连节点to_wireguard_id
for _, l := range localLinks.GetLinks() {
if l == nil {
continue
}
toID := l.GetToWireguardId()
if toID == 0 || toID == localID {
continue
}
base, ok := idToPeer[toID]
if !ok || base == nil || base.GetPublicKey() == "" {
continue
}
if _, exists := desiredByPK[base.GetPublicKey()]; exists {
// 已在目标列表中(通常含有路由规划器计算出的 AllowedIPs不覆盖。
continue
}
// 复制一份(避免直接改 oldPeers / knownPeers 的底层 pb 指针)
cloned := clonePeerConfig(base)
// 不分配路由AllowedIPs 置空
cloned.AllowedIps = nil
// 显式链路 endpoint 优先
if l.GetToEndpoint() != nil {
cloned.Endpoint = l.GetToEndpoint()
}
// 确保 keepalive/AllowedIPs 格式一致(复用现有校验逻辑)
if _, err := parseAndValidatePeerConfig(cloned); err != nil {
continue
}
desiredPeers = append(desiredPeers, cloned)
desiredByPK[cloned.GetPublicKey()] = cloned
}
return desiredPeers
}
func clonePeerConfig(p *defs.WireGuardPeerConfig) *defs.WireGuardPeerConfig {
if p == nil || p.WireGuardPeerConfig == nil {
return &defs.WireGuardPeerConfig{}
}
// 使用 proto.Clone 避免直接拷贝 protoimpl.MessageState内部含 mutex会触发拷贝锁值的告警
cp, _ := proto.Clone(p.WireGuardPeerConfig).(*pb.WireGuardPeerConfig)
if cp == nil {
return &defs.WireGuardPeerConfig{}
}
return &defs.WireGuardPeerConfig{WireGuardPeerConfig: cp}
}
func endpointForLog(ep *pb.Endpoint) string {
if ep == nil {
return ""
}
if ep.GetUri() != "" {
return ep.GetUri()
}
if ep.GetHost() != "" || ep.GetPort() != 0 {
return fmt.Sprintf("%s:%d", ep.GetHost(), ep.GetPort())
}
return ""
}

View File

@@ -0,0 +1,22 @@
//go:build !windows
// +build !windows
package wg
import "time"
const (
ReportInterval = time.Second * 60
)
func (w *wireGuard) reportStatusTask() {
for {
select {
case <-w.ctx.Done():
return
default:
w.pingPeers()
time.Sleep(ReportInterval)
}
}
}

View File

@@ -0,0 +1,52 @@
//go:build !windows
// +build !windows
package wg
import (
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"github.com/VaalaCat/frp-panel/defs"
"github.com/VaalaCat/frp-panel/services/wg/multibind"
"github.com/VaalaCat/frp-panel/services/wg/transport/ws"
"golang.zx2c4.com/wireguard/conn"
)
func (w *wireGuard) initTransports() error {
log := w.svcLogger.WithField("op", "initTransports")
wsTrans := ws.NewWSBind(w.ctx)
w.multiBind = multibind.NewMultiBind(
w.svcLogger,
multibind.NewTransport(conn.NewDefaultBind(), "udp"),
multibind.NewTransport(wsTrans, "ws"),
)
engine := gin.New()
engine.Any(defs.DefaultWSHandlerPath, func(c *gin.Context) {
err := wsTrans.HandleHTTP(c.Writer, c.Request)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
})
// if ws listen port not set, use wg listen port, share tcp and udp port
listenPort := w.ifce.GetWsListenPort()
if listenPort == 0 {
listenPort = w.ifce.GetListenPort()
}
go func() {
if err := engine.Run(fmt.Sprintf(":%d", listenPort)); err != nil {
w.svcLogger.WithError(err).Errorf("failed to run gin engine for ws transport on port %d", listenPort)
}
}()
log.Infof("WS transport engine running on port %d", listenPort)
return nil
}

View File

@@ -0,0 +1,48 @@
//go:build !windows
// +build !windows
package wg
import (
"context"
"sync"
"github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/VaalaCat/frp-panel/defs"
"github.com/VaalaCat/frp-panel/pb"
"github.com/VaalaCat/frp-panel/services/app"
"github.com/VaalaCat/frp-panel/services/wg/multibind"
"github.com/VaalaCat/frp-panel/utils"
)
var (
_ app.WireGuard = (*wireGuard)(nil)
)
type wireGuard struct {
sync.RWMutex
ifce *defs.WireGuardConfig
endpointPingMap *utils.SyncMap[uint32, uint32] // ms
virtAddrPingMap *utils.SyncMap[string, uint32] // ms
peerDirectory map[uint32]*pb.WireGuardPeerConfig
// 仅用于“预连接/保持连接”的 peerAllowedIPs 为空),用于后续根据拓扑变化做增删
preconnectPeers map[uint32]struct{}
wgDevice *device.Device
tunDevice tun.Device
multiBind *multibind.MultiBind
gvisorNet *netstack.Net
fwManager *firewallManager
running bool
useGvisorNet bool // if true, use gvisor netstack
svcLogger *logrus.Entry
ctx *app.Context
cancel context.CancelFunc
}

View File

@@ -19,5 +19,5 @@ remotePort = 6000`)
if err := LoadConfigureFromContent(content, allCfg, true); err != nil {
t.Error(err)
}
t.Errorf("%+v", allCfg)
t.Logf("%+v", allCfg)
}

View File

@@ -11,5 +11,5 @@ func TestAllocateIP(t *testing.T) {
if err != nil {
t.Errorf("AllocateIP() failed: %v", err)
}
t.Errorf("AllocateIP() = %v", ip)
t.Logf("AllocateIP() = %v", ip)
}