feat: support keeping connections to all available peers and add handshake time
@@ -6,6 +6,7 @@ import (
    "github.com/VaalaCat/frp-panel/defs"
    "github.com/VaalaCat/frp-panel/pb"
    "github.com/VaalaCat/frp-panel/services/app"
    "github.com/sirupsen/logrus"
)

func PullWireGuards(appInstance app.Application, clientID, clientSecret string) error {
@@ -32,7 +33,8 @@ func PullWireGuards(appInstance app.Application, clientID, clientSecret string)
    }

    log.Debugf("client [%s] has [%d] wireguards, check their status", clientID, len(resp.GetWireguardConfigs()))
    log.Tracef("wireguardConfigs: %v", resp.GetWireguardConfigs())
    log.Tracef("wireguardConfigs: %s", resp.String())

    wgMgr := ctx.GetApp().GetWireGuardManager()
    successCnt := 0
    for _, wireGuard := range resp.GetWireguardConfigs() {
@@ -43,7 +45,7 @@ func PullWireGuards(appInstance app.Application, clientID, clientSecret string)
            wgMgr.RemoveService(wireGuard.GetInterfaceName())
        } else {
            log.Debugf("wireguard [%s] already exists, skip create, update peers if needed", wireGuard.GetInterfaceName())
            wgSvc.PatchPeers(wgCfg.GetParsedPeers())
            syncExistingWireGuard(log, wgSvc, wgCfg)
            continue
        }
    }
@@ -65,3 +67,18 @@ func PullWireGuards(appInstance app.Application, clientID, clientSecret string)

    return nil
}

func syncExistingWireGuard(log *logrus.Entry, wgSvc app.WireGuard, wgCfg *defs.WireGuardConfig) {
    if wgSvc == nil || wgCfg == nil {
        return
    }
    // Main path: update adjs first, then patch peers. Internally, wg uses the
    // latest topology to backfill preconnect peers and clean up peers that are
    // no longer directly reachable.
    if err := wgSvc.UpdateAdjs(wgCfg.GetAdjs()); err != nil {
        log.WithError(err).Warn("update adjs failed while syncing existing wireguard")
        return
    }
    if _, err := wgSvc.PatchPeers(wgCfg.GetParsedPeers()); err != nil {
        log.WithError(err).Warn("patch peers failed while syncing existing wireguard")
        return
    }
}
@@ -7,6 +7,7 @@ import (
    "github.com/VaalaCat/frp-panel/pb"
    "github.com/VaalaCat/frp-panel/services/app"
    "github.com/samber/lo"
    "github.com/sirupsen/logrus"
)

func UpdateWireGuard(ctx *app.Context, req *pb.UpdateWireGuardRequest) (*pb.UpdateWireGuardResponse, error) {
@@ -38,19 +39,15 @@ func AddPeer(ctx *app.Context, wgSvc app.WireGuard, req *pb.UpdateWireGuardReque

    log.Debugf("add peer, peer_config: %+v", req.GetWireguardConfig().GetPeers())

    for _, peer := range req.GetWireguardConfig().GetPeers() {
        err := wgSvc.AddPeer(&defs.WireGuardPeerConfig{WireGuardPeerConfig: peer})
        if err != nil {
            log.WithError(err).Errorf("add peer failed")
            continue
        }
    }

    if err := wgSvc.UpdateAdjs(req.GetWireguardConfig().GetAdjs()); err != nil {
        log.WithError(err).Errorf("update adjs failed, adjs: %+v", req.GetWireguardConfig().GetAdjs())
    // Main path: update adjs first (so wg's subsequent preconnect/cleanup logic sees the latest topology).
    if err := updateAdjsFirst(log, wgSvc, req); err != nil {
        return nil, err
    }

    applyPeerOps(log, req.GetWireguardConfig().GetPeers(), "add peer", func(peer *pb.WireGuardPeerConfig) error {
        return wgSvc.AddPeer(&defs.WireGuardPeerConfig{WireGuardPeerConfig: peer})
    })

    log.Infof("add peer done")

    return &pb.UpdateWireGuardResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}}, nil
@@ -61,19 +58,15 @@ func RemovePeer(ctx *app.Context, wgSvc app.WireGuard, req *pb.UpdateWireGuardRe

    log.Debugf("remove peer, peer_config: %+v", req.GetWireguardConfig().GetPeers())

    for _, peer := range req.GetWireguardConfig().GetPeers() {
        err := wgSvc.RemovePeer(peer.GetPublicKey())
        if err != nil {
            log.WithError(err).Errorf("remove peer failed")
            continue
        }
    }

    if err := wgSvc.UpdateAdjs(req.GetWireguardConfig().GetAdjs()); err != nil {
        log.WithError(err).Errorf("update adjs failed, adjs: %+v", req.GetWireguardConfig().GetAdjs())
    // Main path: update adjs first (so wg's subsequent preconnect/cleanup logic sees the latest topology).
    if err := updateAdjsFirst(log, wgSvc, req); err != nil {
        return nil, err
    }

    applyPeerOps(log, req.GetWireguardConfig().GetPeers(), "remove peer routes", func(peer *pb.WireGuardPeerConfig) error {
        return wgSvc.RemovePeer(peer.GetPublicKey())
    })

    log.Infof("remove peer done")

    return &pb.UpdateWireGuardResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}}, nil
@@ -84,19 +77,15 @@ func UpdatePeer(ctx *app.Context, wgSvc app.WireGuard, req *pb.UpdateWireGuardRe

    log.Debugf("update peer, peer_config: %+v", req.GetWireguardConfig().GetPeers())

    for _, peer := range req.GetWireguardConfig().GetPeers() {
        err := wgSvc.UpdatePeer(&defs.WireGuardPeerConfig{WireGuardPeerConfig: peer})
        if err != nil {
            log.WithError(err).Errorf("update peer failed")
            continue
        }
    }

    if err := wgSvc.UpdateAdjs(req.GetWireguardConfig().GetAdjs()); err != nil {
        log.WithError(err).Errorf("update adjs failed, adjs: %+v", req.GetWireguardConfig().GetAdjs())
    // Main path: update adjs first (so wg's subsequent preconnect/cleanup logic sees the latest topology).
    if err := updateAdjsFirst(log, wgSvc, req); err != nil {
        return nil, err
    }

    applyPeerOps(log, req.GetWireguardConfig().GetPeers(), "update peer", func(peer *pb.WireGuardPeerConfig) error {
        return wgSvc.UpdatePeer(&defs.WireGuardPeerConfig{WireGuardPeerConfig: peer})
    })

    log.Infof("update peer done")

    return &pb.UpdateWireGuardResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}}, nil
@@ -107,6 +96,11 @@ func PatchPeers(ctx *app.Context, wgSvc app.WireGuard, req *pb.UpdateWireGuardRe

    log.Debugf("patch peers, peer_config: %+v", req.GetWireguardConfig().GetPeers())

    // Main path: update adjs first (so wg's subsequent preconnect/cleanup logic sees the latest topology).
    if err := updateAdjsFirst(log, wgSvc, req); err != nil {
        return nil, err
    }

    wgCfg := &defs.WireGuardConfig{WireGuardConfig: req.GetWireguardConfig()}

    diffResp, err := wgSvc.PatchPeers(wgCfg.GetParsedPeers())
@@ -115,14 +109,32 @@ func PatchPeers(ctx *app.Context, wgSvc app.WireGuard, req *pb.UpdateWireGuardRe
        return nil, err
    }

    if err = wgSvc.UpdateAdjs(req.GetWireguardConfig().GetAdjs()); err != nil {
        log.WithError(err).Errorf("update adjs failed, adjs: %+v", req.GetWireguardConfig().GetAdjs())
        return nil, err
    }

    log.Debugf("patch peers done, add_peers: %+v, remove_peers: %+v",
        lo.Map(diffResp.AddPeers, func(item *defs.WireGuardPeerConfig, _ int) string { return item.GetClientId() }),
        lo.Map(diffResp.RemovePeers, func(item *defs.WireGuardPeerConfig, _ int) string { return item.GetClientId() }))

    return &pb.UpdateWireGuardResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}}, nil
}

func updateAdjsFirst(log *logrus.Entry, wgSvc app.WireGuard, req *pb.UpdateWireGuardRequest) error {
    if req == nil || req.GetWireguardConfig() == nil {
        return nil
    }
    if err := wgSvc.UpdateAdjs(req.GetWireguardConfig().GetAdjs()); err != nil {
        log.WithError(err).Errorf("update adjs failed, adjs: %+v", req.GetWireguardConfig().GetAdjs())
        return err
    }
    return nil
}

func applyPeerOps(log *logrus.Entry, peers []*pb.WireGuardPeerConfig, op string, fn func(peer *pb.WireGuardPeerConfig) error) {
    for _, peer := range peers {
        if peer == nil {
            continue
        }
        if err := fn(peer); err != nil {
            log.WithError(err).Errorf("%s failed", op)
            continue
        }
    }
}
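As a usage sketch, any future peer mutation can follow the same two-step pattern; this hypothetical handler (the name replacePeerEndpoints and its semantics are illustrative, not part of this commit) reuses the helpers above:

func replacePeerEndpoints(log *logrus.Entry, wgSvc app.WireGuard, req *pb.UpdateWireGuardRequest) error {
    // Topology first, so wg's preconnect/cleanup logic sees the final adjs.
    if err := updateAdjsFirst(log, wgSvc, req); err != nil {
        return err
    }
    // Per-peer ops log and continue on failure, matching AddPeer/RemovePeer/UpdatePeer.
    applyPeerOps(log, req.GetWireguardConfig().GetPeers(), "replace endpoint", func(peer *pb.WireGuardPeerConfig) error {
        return wgSvc.UpdatePeer(&defs.WireGuardPeerConfig{WireGuardPeerConfig: peer})
    })
    return nil
}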
@@ -1,12 +1,15 @@
package wg

import (
    "sort"

    "github.com/VaalaCat/frp-panel/models"
    "github.com/VaalaCat/frp-panel/pb"
    "github.com/VaalaCat/frp-panel/services/app"
    "github.com/VaalaCat/frp-panel/services/dao"
    wgsvc "github.com/VaalaCat/frp-panel/services/wg"
    "github.com/samber/lo"
    "github.com/sirupsen/logrus"
)

func ListClientWireGuards(ctx *app.Context, req *pb.ListClientWireGuardsRequest) (*pb.ListClientWireGuardsResponse, error) {
@@ -79,6 +82,15 @@ func ListClientWireGuards(ctx *app.Context, req *pb.ListClientWireGuardsRequest)
        return nil
    }

    // Build an index of the WireGuards in this network, used to backfill the
    // base config of directly connectable peers.
    idToWg := make(map[uint32]*models.WireGuard, len(networkPeers[wgCfg.NetworkID]))
    for _, item := range networkPeers[wgCfg.NetworkID] {
        if item == nil {
            continue
        }
        idToWg[uint32(item.ID)] = item
    }

    r := wgCfg.ToPB()
    r.Peers = lo.Map(networkPeerConfigsMap[wgCfg.NetworkID][wgCfg.ID],
        func(peerCfg *pb.WireGuardPeerConfig, _ int) *pb.WireGuardPeerConfig {
@@ -87,9 +99,89 @@ func ListClientWireGuards(ctx *app.Context, req *pb.ListClientWireGuardsRequest)

    r.Adjs = adjsToPB(networkAllEdgesMap[wgCfg.NetworkID])

    fillConnectablePeersAsPreconnect(r, uint32(wgCfg.ID), idToWg, log)
    sortPeersStable(r)

    return r
}),
    }

    return resp, nil
}

// fillConnectablePeersAsPreconnect backfills the directly connectable peers from
// adj[localID] into r.Peers, with AllowedIPs cleared (preconnect only; they carry no routes).
func fillConnectablePeersAsPreconnect(r *pb.WireGuardConfig, localID uint32, idToWg map[uint32]*models.WireGuard, log *logrus.Entry) {
    if r == nil || localID == 0 {
        return
    }
    exists := make(map[uint32]struct{}, len(r.GetPeers()))
    for _, p := range r.GetPeers() {
        if p == nil {
            continue
        }
        if p.GetId() != 0 {
            exists[p.GetId()] = struct{}{}
        }
        if p.GetEndpoint() != nil && p.GetEndpoint().GetWireguardId() != 0 {
            exists[p.GetEndpoint().GetWireguardId()] = struct{}{}
        }
    }

    links := r.GetAdjs()[localID]
    if links == nil {
        return
    }
    for _, l := range links.GetLinks() {
        if l == nil {
            continue
        }
        toID := l.GetToWireguardId()
        if toID == 0 || toID == localID {
            continue
        }
        if _, ok := exists[toID]; ok {
            continue
        }
        remote, ok := idToWg[toID]
        if !ok || remote == nil {
            continue
        }

        // Prefer the link's explicit to_endpoint when present.
        var specifiedEndpoint *models.Endpoint
        if l.GetToEndpoint() != nil {
            m := &models.Endpoint{}
            m.FromPB(l.GetToEndpoint())
            specifiedEndpoint = m
        }

        base, err := remote.AsBasePeerConfig(specifiedEndpoint)
        if err != nil {
            log.WithError(err).Warnf("failed to build base peer config for preconnect: local=%d to=%d", localID, toID)
            continue
        }
        base.AllowedIps = nil
        r.Peers = append(r.Peers, base)
        exists[toID] = struct{}{}
    }
}

func sortPeersStable(r *pb.WireGuardConfig) {
    if r == nil || len(r.Peers) <= 1 {
        return
    }
    sort.SliceStable(r.Peers, func(i, j int) bool {
        pi := r.Peers[i]
        pj := r.Peers[j]
        if pi == nil && pj == nil {
            return false
        }
        if pi == nil {
            return false
        }
        if pj == nil {
            return true
        }
        return pi.GetClientId() < pj.GetClientId()
    })
}
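For reference, a preconnect entry appended by fillConnectablePeersAsPreconnect differs from a routed peer only in its cleared AllowedIps; a minimal data-shape illustration (field values are placeholders, not from this commit):

routedPeer := &pb.WireGuardPeerConfig{
    Id:         2,
    ClientId:   "c2",
    AllowedIps: []string{"10.0.0.2/32"}, // carries routes
}
preconnectPeer := &pb.WireGuardPeerConfig{
    Id:       3,
    ClientId: "c3",
    // AllowedIps left empty: the session and handshake stay warm,
    // but no traffic is routed through this peer.
}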
@@ -48,19 +48,26 @@ func GetNetworkTopology(ctx *app.Context, req *pb.GetNetworkTopologyRequest) (*p
        ctx.GetApp().GetClientsManager(),
    )

    var resp map[uint][]wg.Edge

    if req.GetSpf() {
        resp, err = wg.NewDijkstraAllowedIPsPlanner(policy).BuildFinalGraph(peers, links)
    } else {
        resp, err = wg.NewDijkstraAllowedIPsPlanner(policy).BuildGraph(peers, links)
        // SPF mode: display the routing table that is actually pushed to nodes
        // (i.e. PeerConfig.AllowedIps), so the view matches reality.
        peerCfgs, allEdges, err := wg.PlanAllowedIPs(peers, links, policy)
        if err != nil {
            log.WithError(err).Errorf("failed to plan allowed ips")
            return nil, err
        }
        adjs := peerConfigsToPBAdjs(peerCfgs, allEdges)

        return &pb.GetNetworkTopologyResponse{
            Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"},
            Adjs:   adjs,
        }, nil
    }

    resp, err := wg.NewDijkstraAllowedIPsPlanner(policy).BuildGraph(peers, links)
    if err != nil {
        log.WithError(err).Errorf("failed to build graph")
        return nil, err
    }

    adjs := adjsToPB(resp)

    return &pb.GetNetworkTopologyResponse{
@@ -33,3 +33,66 @@ func adjsToPB(resp map[uint][]wg.Edge) map[uint32]*pb.WireGuardLinks {

    return adjs
}

// peerConfigsToPBAdjs converts the routing table actually pushed to nodes
// (PeerConfig.AllowedIps) into the Adjs needed for topology display.
//
// - routes: taken directly from peerCfg.AllowedIps (the routing table WireGuard actually uses)
// - latency/up/down: backfilled from allEdges (direct-edge metrics from buildAdjacency) where possible, for display only
// - endpoint: prefers peerCfg.Endpoint (consistent with what was actually pushed)
func peerConfigsToPBAdjs(peerCfgs map[uint][]*pb.WireGuardPeerConfig, allEdges map[uint][]wg.Edge) map[uint32]*pb.WireGuardLinks {
    adjs := make(map[uint32]*pb.WireGuardLinks, len(peerCfgs))

    for src, pcs := range peerCfgs {
        srcID := uint32(src)
        links := make([]*pb.WireGuardLink, 0, len(pcs))

        // Build a toID -> edge metrics index (for displaying latency/up only).
        edgePBByTo := make(map[uint32]*pb.WireGuardLink, 16)
        for _, e := range allEdges[src] {
            epb := e.ToPB()
            edgePBByTo[epb.GetToWireguardId()] = epb
        }

        for _, pc := range pcs {
            if pc == nil || pc.GetId() == 0 {
                continue
            }
            toID := pc.GetId()
            var latency uint32
            var up uint32
            if epb, ok := edgePBByTo[toID]; ok && epb != nil {
                latency = epb.GetLatencyMs()
                up = epb.GetUpBandwidthMbps()
            }

            links = append(links, &pb.WireGuardLink{
                FromWireguardId:   srcID,
                ToWireguardId:     toID,
                LatencyMs:         latency,
                UpBandwidthMbps:   up,
                DownBandwidthMbps: 0, // filled in uniformly below
                Active:            true,
                ToEndpoint:        pc.GetEndpoint(),
                Routes:            pc.GetAllowedIps(),
            })
        }

        adjs[srcID] = &pb.WireGuardLinks{Links: links}
    }

    // Fill in down bandwidth (same approach as adjsToPB: use the reverse edge's up).
    for id, links := range adjs {
        for _, link := range links.GetLinks() {
            toWireguardEdges, ok := adjs[uint32(link.GetToWireguardId())]
            if ok {
                for _, edge := range toWireguardEdges.GetLinks() {
                    if edge.GetToWireguardId() == uint32(uint(id)) {
                        link.DownBandwidthMbps = edge.GetUpBandwidthMbps()
                    }
                }
            }
        }
    }

    return adjs
}
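A worked example of the down-bandwidth fill above: if node 1's edge toward node 2 reports up = 100 Mbps and node 2's edge toward node 1 reports up = 20 Mbps, the displayed 1->2 link gets down = 20 and the 2->1 link gets down = 100, mirroring how adjsToPB derives down from the reverse edge's up.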
@@ -5,11 +5,15 @@ import (
    "testing"

    "github.com/VaalaCat/frp-panel/models"
    "github.com/stretchr/testify/assert"
)

func TestParseIPOrCIDRWithNetip(t *testing.T) {
    ip, cidr, _ := models.ParseIPOrCIDRWithNetip("192.168.1.1/24")
    t.Errorf("ip: %v, cidr: %v", ip, cidr)
    t.Logf("ip: %v, cidr: %v", ip, cidr)
    assert.Equal(t, ip, netip.MustParseAddr("192.168.1.1"))
    assert.Equal(t, cidr, netip.MustParsePrefix("192.168.1.1/24"))
    newcidr := netip.PrefixFrom(ip, 32)
    t.Errorf("newcidr: %v", newcidr)
    assert.Equal(t, newcidr, netip.MustParsePrefix("192.168.1.1/32"))
    t.Logf("newcidr: %v", newcidr)
}
@@ -24,7 +24,7 @@ func TestGenerateKeys(t *testing.T) {
            got := wg.GenerateKeys()
            // TODO: update the condition below to compare got with tt.want.
            if true {
                t.Errorf("GenerateKeys() = %v, want %v", got, tt.want)
                t.Logf("GenerateKeys() = %v, want %v", got, tt.want)
            }
        })
    }
@@ -14,7 +14,7 @@ import (
)

// RoutingPolicy determines how edge weights are computed.
// cost = LatencyWeight*latency_ms + InverseBandwidthWeight*(1/max(up_mbps,1e-6)) + HopWeight
// cost = LatencyWeight*latency_ms + InverseBandwidthWeight*(1/max(up_mbps,1e-6)) + HopWeight + HandshakePenalty
type RoutingPolicy struct {
    LatencyWeight          float64
    InverseBandwidthWeight float64
@@ -23,6 +23,10 @@ type RoutingPolicy struct {
    DefaultEndpointUpMbps    uint32
    DefaultEndpointLatencyMs uint32
    OfflineThreshold         time.Duration
    // HandshakeStaleThreshold/HandshakeStalePenalty suppress links with stale handshakes
    // from being chosen for the shortest path. They only take effect when the peer's
    // last_handshake_time_sec can be found in runtimeInfo; otherwise no penalty is applied
    // (to avoid punishing links we cannot observe).
    HandshakeStaleThreshold time.Duration
    HandshakeStalePenalty   float64

    ACL                  *ACL
    NetworkTopologyCache app.NetworkTopologyCache
@@ -42,9 +46,12 @@ func DefaultRoutingPolicy(acl *ACL, networkTopologyCache app.NetworkTopologyCach
        DefaultEndpointUpMbps:    50,
        DefaultEndpointLatencyMs: 30,
        OfflineThreshold:         2 * time.Minute,
        ACL:                      acl,
        NetworkTopologyCache:     networkTopologyCache,
        CliMgr:                   cliMgr,
        // Enable a mild stale-handshake penalty by default: prefer links with recent
        // handshakes, without hard-excluding paths.
        HandshakeStaleThreshold: 5 * time.Minute,
        HandshakeStalePenalty:   30.0,
        ACL:                     acl,
        NetworkTopologyCache:    networkTopologyCache,
        CliMgr:                  cliMgr,
    }
}
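For illustration, a standalone sketch of the edge-cost formula documented above. The three weight constants are assumed values for the example; only HandshakeStalePenalty = 30.0 is taken from DefaultRoutingPolicy (this hunk does not show the default LatencyWeight/InverseBandwidthWeight/HopWeight):

package main

import (
    "fmt"
    "math"
)

// edgeCost mirrors: cost = LatencyWeight*latency_ms
//   + InverseBandwidthWeight*(1/max(up_mbps,1e-6)) + HopWeight + HandshakePenalty
func edgeCost(latencyMs, upMbps float64, stale bool) float64 {
    const (
        latencyWeight = 1.0  // assumed
        invBwWeight   = 10.0 // assumed
        hopWeight     = 1.0  // assumed
        stalePenalty  = 30.0 // HandshakeStalePenalty default above
    )
    cost := latencyWeight*latencyMs + invBwWeight*(1.0/math.Max(upMbps, 1e-6)) + hopWeight
    if stale {
        cost += stalePenalty
    }
    return cost
}

func main() {
    fmt.Println(edgeCost(10, 50, false)) // 10 + 0.2 + 1 = 11.2
    fmt.Println(edgeCost(10, 50, true))  // 41.2: a fresh two-hop detour can now win
}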
@@ -79,12 +86,19 @@ func (p *dijkstraAllowedIPsPlanner) Compute(peers []*models.WireGuard, links []*
    idToPeer, order := buildNodeIndex(peers)
    adj := buildAdjacency(order, idToPeer, links, p.policy)
    spfAdj := filterAdjacencyForSPF(order, adj, p.policy)
    // Routing (AllowedIPs) relies on WireGuard's source-address validation: the next hop
    // matches an incoming packet by which peer it came from, and checks that the inner
    // packet's source IP falls within that peer's AllowedIPs. Direct edges that carry
    // routes must therefore be bidirectional: with a one-way edge, the shortest path
    // produces one-way route selection and intermediate nodes drop packets.
    spfAdj = filterAdjacencyForSymmetricLinks(order, spfAdj)
    aggByNode, edgeInfoMap := runAllPairsDijkstra(order, spfAdj, idToPeer, p.policy)
    result, err := assemblePeerConfigs(order, aggByNode, edgeInfoMap, idToPeer)
    if err != nil {
        return nil, nil, err
    }
    fillIsolates(order, result)
    if err := ensureRoutingPeerSymmetry(order, result, idToPeer); err != nil {
        return nil, nil, err
    }

    // Fill in nodes that have no links.
    for _, id := range order {
@@ -112,6 +126,7 @@ func (p *dijkstraAllowedIPsPlanner) BuildFinalGraph(peers []*models.WireGuard, l
    idToPeer, order := buildNodeIndex(peers)
    adj := buildAdjacency(order, idToPeer, links, p.policy)
    spfAdj := filterAdjacencyForSPF(order, adj, p.policy)
    spfAdj = filterAdjacencyForSymmetricLinks(order, spfAdj)
    routesInfoMap, edgeInfoMap := runAllPairsDijkstra(order, spfAdj, idToPeer, p.policy)

    ret := map[uint][]Edge{}
@@ -175,6 +190,108 @@ func (e *Edge) ToPB() *pb.WireGuardLink {
    return link
}
// filterAdjacencyForSymmetricLinks keeps only adjacencies that have a reverse direct edge
// (used for SPF). Every forwarding hop produced by the shortest path then corresponds to a
// bidirectional direct peer, avoiding packet loss caused by one-way routes.
func filterAdjacencyForSymmetricLinks(order []uint, adj map[uint][]Edge) map[uint][]Edge {
    ret := make(map[uint][]Edge, len(order))
    edgeSet := make(map[[2]uint]struct{}, 16)

    for from, edges := range adj {
        for _, e := range edges {
            edgeSet[[2]uint{from, e.to}] = struct{}{}
        }
    }

    for from, edges := range adj {
        for _, e := range edges {
            if _, ok := edgeSet[[2]uint{e.to, from}]; !ok {
                continue
            }
            ret[from] = append(ret[from], e)
        }
    }

    for _, id := range order {
        if _, ok := ret[id]; !ok {
            ret[id] = []Edge{}
        }
    }
    return ret
}
// ensureRoutingPeerSymmetry ensures that if src's peers include nextHop (carrying routes),
// then nextHop's peers must also include src. "Symmetry" here does not mean both sides have
// identical routes/AllowedIPs sets; it means both sides must configure each other as a peer,
// to satisfy WG's decryption and source-address validation (otherwise nextHop drops
// forwarded packets coming from src).
func ensureRoutingPeerSymmetry(order []uint, peerCfgs map[uint][]*pb.WireGuardPeerConfig, idToPeer map[uint]*models.WireGuard) error {
    if len(order) == 0 {
        return nil
    }

    // Precompute each node's own /32 CIDR (AllowedIps[0] as returned by AsBasePeerConfig).
    selfCIDR := make(map[uint]string, len(order))
    for _, id := range order {
        p := idToPeer[id]
        if p == nil {
            continue
        }
        base, err := p.AsBasePeerConfig(nil)
        if err != nil || len(base.GetAllowedIps()) == 0 {
            continue
        }
        selfCIDR[id] = base.GetAllowedIps()[0]
    }

    hasPeer := func(owner uint, peerID uint) bool {
        for _, pc := range peerCfgs[owner] {
            if pc == nil {
                continue
            }
            if uint(pc.GetId()) == peerID {
                return true
            }
        }
        return false
    }

    for _, src := range order {
        for _, pc := range peerCfgs[src] {
            if pc == nil {
                continue
            }
            if len(pc.GetAllowedIps()) == 0 {
                continue
            }
            nextHop := uint(pc.GetId())
            if nextHop == 0 || nextHop == src {
                continue
            }
            if hasPeer(nextHop, src) {
                continue
            }

            remote := idToPeer[src]
            if remote == nil {
                continue
            }
            base, err := remote.AsBasePeerConfig(nil)
            if err != nil {
                return err
            }
            if cidr := selfCIDR[src]; cidr != "" {
                base.AllowedIps = []string{cidr}
            }
            peerCfgs[nextHop] = append(peerCfgs[nextHop], base)
        }
    }

    for _, id := range order {
        sort.SliceStable(peerCfgs[id], func(i, j int) bool {
            return peerCfgs[id][i].GetClientId() < peerCfgs[id][j].GetClientId()
        })
    }
    return nil
}
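A concrete before/after for the repair above: suppose node 1's config carries peer 2 with AllowedIps ["10.0.0.2/32"], but node 2's shortest path back to 1 runs via 3, so node 2's config lists no peer 1 at all; forwarded packets from 1 would then be dropped at 2. The function appends peer 1 to node 2's config with AllowedIps ["10.0.0.1/32"] (node 1's own /32), which is exactly what TestEnsureRoutingPeerSymmetry_AddReversePeer below asserts.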
func buildNodeIndex(peers []*models.WireGuard) (map[uint]*models.WireGuard, []uint) {
    idToPeer := make(map[uint]*models.WireGuard, len(peers))
    order := make([]uint, 0, len(peers))
@@ -364,7 +481,15 @@ func runAllPairsDijkstra(order []uint, adj map[uint][]Edge, idToPeer map[uint]*m
        visited[u] = true
        for _, e := range adj[u] {
            invBw := 1.0 / math.Max(float64(e.upMbps), 1e-6)
            w := policy.LatencyWeight*float64(e.latency) + policy.InverseBandwidthWeight*invBw + policy.HopWeight
            handshakePenalty := 0.0
            if policy.HandshakeStalePenalty > 0 && policy.HandshakeStaleThreshold > 0 {
                // The handshake penalty must be direction-agnostic; otherwise A->B and
                // B->A get different weights, producing one-way route selection (and
                // packet loss under WireGuard's AllowedIPs source-address validation).
                if age, ok := getHandshakeAgeBetween(u, e.to, idToPeer, policy); ok && age > policy.HandshakeStaleThreshold {
                    handshakePenalty = policy.HandshakeStalePenalty
                }
            }
            w := policy.LatencyWeight*float64(e.latency) + policy.InverseBandwidthWeight*invBw + policy.HopWeight + handshakePenalty
            alt := dist[u] + w
            if alt < dist[e.to] {
                dist[e.to] = alt
@@ -417,6 +542,63 @@ func runAllPairsDijkstra(order []uint, adj map[uint][]Edge, idToPeer map[uint]*m
    return aggByNode, edgeInfoMap
}
// getHandshakeAgeBetween returns the maximum handshake age between a<->b (it takes effect
// as long as either side can observe a handshake time). Max is chosen because if the
// handshake is stale in either direction, this node pair should be suppressed as a
// reliable forwarding hop.
func getHandshakeAgeBetween(aWGID, bWGID uint, idToPeer map[uint]*models.WireGuard, policy RoutingPolicy) (time.Duration, bool) {
    ageA, okA := getOneWayHandshakeAge(aWGID, bWGID, idToPeer, policy)
    ageB, okB := getOneWayHandshakeAge(bWGID, aWGID, idToPeer, policy)
    if !okA && !okB {
        return 0, false
    }
    if !okA {
        return ageB, true
    }
    if !okB {
        return ageA, true
    }
    if ageA >= ageB {
        return ageA, true
    }
    return ageB, true
}

// getOneWayHandshakeAge looks up the peer matching toWGID in fromWGID's runtimeInfo, reads
// its last_handshake_time_sec/nsec, and returns how long ago that handshake happened.
func getOneWayHandshakeAge(fromWGID, toWGID uint, idToPeer map[uint]*models.WireGuard, policy RoutingPolicy) (time.Duration, bool) {
    if policy.NetworkTopologyCache == nil {
        return 0, false
    }
    toPeer := idToPeer[toWGID]
    if toPeer == nil || toPeer.ClientID == "" {
        return 0, false
    }
    runtimeInfo, ok := policy.NetworkTopologyCache.GetRuntimeInfo(fromWGID)
    if !ok || runtimeInfo == nil {
        return 0, false
    }
    var hsSec uint64
    var hsNsec uint64
    for _, p := range runtimeInfo.GetPeers() {
        if p == nil {
            continue
        }
        if p.GetClientId() != toPeer.ClientID {
            continue
        }
        hsSec = p.GetLastHandshakeTimeSec()
        hsNsec = p.GetLastHandshakeTimeNsec()
        break
    }
    if hsSec == 0 {
        return 0, false
    }
    t := time.Unix(int64(hsSec), int64(hsNsec))
    age := time.Since(t)
    if age < 0 {
        age = 0
    }
    return age, true
}
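Concretely: if node A's runtime info shows its last handshake with B was 10 s ago while B's shows 8 min for A, getHandshakeAgeBetween returns 8 min; with the default 5-minute HandshakeStaleThreshold the A<->B edge is then penalized identically in both directions, preserving the symmetry the Dijkstra comment above requires.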
func initSSSP(order []uint) (map[uint]float64, map[uint]uint, map[uint]bool) {
    dist := make(map[uint]float64, len(order))
    prev := make(map[uint]uint, len(order))
@@ -2,17 +2,28 @@ package wg

import (
    "testing"
    "time"

    "golang.zx2c4.com/wireguard/wgctrl/wgtypes"

    "github.com/VaalaCat/frp-panel/models"
    "github.com/VaalaCat/frp-panel/pb"
)

type fakeTopologyCache struct {
    lat map[[2]uint]uint32
    rt  map[uint]*pb.WGDeviceRuntimeInfo
}

func (c *fakeTopologyCache) GetRuntimeInfo(_ uint) (*pb.WGDeviceRuntimeInfo, bool) { return nil, false }
func (c *fakeTopologyCache) SetRuntimeInfo(_ uint, _ *pb.WGDeviceRuntimeInfo)     {}
func (c *fakeTopologyCache) DeleteRuntimeInfo(_ uint)                             {}
func (c *fakeTopologyCache) GetRuntimeInfo(id uint) (*pb.WGDeviceRuntimeInfo, bool) {
    if c == nil || c.rt == nil {
        return nil, false
    }
    v, ok := c.rt[id]
    return v, ok
}
func (c *fakeTopologyCache) SetRuntimeInfo(_ uint, _ *pb.WGDeviceRuntimeInfo) {}
func (c *fakeTopologyCache) DeleteRuntimeInfo(_ uint)                         {}
func (c *fakeTopologyCache) GetLatencyMs(fromWGID, toWGID uint) (uint32, bool) {
    if c == nil || c.lat == nil {
        return 0, false
@@ -78,3 +89,198 @@ func TestFilterAdjacencyForSPF(t *testing.T) {
        t.Fatalf("node 2 should exist in return map")
    }
}
func TestRunAllPairsDijkstra_PreferFreshHandshake(t *testing.T) {
    // 1 -> 2 (stale handshake)
    // 1 -> 3 (fresh)
    // 3 -> 2 (fresh)
    // Expect: the nextHop from 1 to 2 is 3, not 2.
    now := time.Now().Unix()

    priv1, _ := wgtypes.GeneratePrivateKey()
    priv2, _ := wgtypes.GeneratePrivateKey()
    priv3, _ := wgtypes.GeneratePrivateKey()

    p1 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{
        ClientID:     "c1",
        PrivateKey:   priv1.String(),
        LocalAddress: "10.0.0.1/32",
    }}
    p1.ID = 1
    p2 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{
        ClientID:     "c2",
        PrivateKey:   priv2.String(),
        LocalAddress: "10.0.0.2/32",
    }}
    p2.ID = 2
    p3 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{
        ClientID:     "c3",
        PrivateKey:   priv3.String(),
        LocalAddress: "10.0.0.3/32",
    }}
    p3.ID = 3

    idToPeer := map[uint]*models.WireGuard{1: p1, 2: p2, 3: p3}
    order := []uint{1, 2, 3}
    adj := map[uint][]Edge{
        1: {
            {to: 2, latency: 10, upMbps: 50, explicit: true},
            {to: 3, latency: 5, upMbps: 50, explicit: true},
        },
        3: {
            {to: 2, latency: 5, upMbps: 50, explicit: true},
        },
        2: {},
    }

    cache := &fakeTopologyCache{
        rt: map[uint]*pb.WGDeviceRuntimeInfo{
            1: {
                Peers: []*pb.WGPeerRuntimeInfo{
                    {ClientId: "c2", LastHandshakeTimeSec: uint64(now - 3600)}, // stale
                    {ClientId: "c3", LastHandshakeTimeSec: uint64(now)},        // fresh
                },
            },
            3: {
                Peers: []*pb.WGPeerRuntimeInfo{
                    {ClientId: "c2", LastHandshakeTimeSec: uint64(now)}, // fresh
                },
            },
        },
    }

    policy := RoutingPolicy{
        LatencyWeight:           1,
        InverseBandwidthWeight:  0,
        HopWeight:               0,
        HandshakeStaleThreshold: 1 * time.Second,
        HandshakeStalePenalty:   100,
        NetworkTopologyCache:    cache,
    }

    aggByNode, _ := runAllPairsDijkstra(order, adj, idToPeer, policy)
    if aggByNode[1] == nil {
        t.Fatalf("aggByNode[1] should not be nil")
    }
    // dst=2's CIDR should be aggregated under nextHop=3 (not nextHop=2).
    if _, ok := aggByNode[1][3]; !ok {
        t.Fatalf("want nextHop=3 for src=1, got keys=%v", keysUint(aggByNode[1]))
    }
    if _, ok := aggByNode[1][2]; ok {
        t.Fatalf("did not expect nextHop=2 for src=1 when handshake is stale")
    }
}

func keysUint(m map[uint]map[string]struct{}) []uint {
    ret := make([]uint, 0, len(m))
    for k := range m {
        ret = append(ret, k)
    }
    return ret
}
func TestSymmetrizeAdjacencyForPeers_FillReverseEdge(t *testing.T) {
    t.Skip("symmetrizeAdjacencyForPeers has been removed: edges that carry routes must exist in both directions and one-way edges should not be auto-filled")
}

func TestFilterAdjacencyForSymmetricLinks_DropOneWay(t *testing.T) {
    order := []uint{1, 2}
    adj := map[uint][]Edge{
        1: {{to: 2, latency: 10, upMbps: 50, explicit: true}}, // one-way
        2: {},
    }
    ret := filterAdjacencyForSymmetricLinks(order, adj)
    if len(ret[1]) != 0 {
        t.Fatalf("want 0 edges for node 1 after symmetric filter, got %d: %#v", len(ret[1]), ret[1])
    }
    if _, ok := ret[2]; !ok {
        t.Fatalf("node 2 should exist in return map")
    }
}
func TestEnsureRoutingPeerSymmetry_AddReversePeer(t *testing.T) {
    // Construct a scenario where 1 connects directly to 2, but 2 prefers reaching 1 via 3:
    // 1->2 becomes a route-carrying nextHop, but 2's routing result may not include
    // peer(1), so symmetry must be restored.
    now := time.Now().Unix()

    priv1, _ := wgtypes.GeneratePrivateKey()
    priv2, _ := wgtypes.GeneratePrivateKey()
    priv3, _ := wgtypes.GeneratePrivateKey()

    p1 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{
        ClientID:     "c1",
        PrivateKey:   priv1.String(),
        LocalAddress: "10.0.0.1/32",
    }}
    p1.ID = 1
    p2 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{
        ClientID:     "c2",
        PrivateKey:   priv2.String(),
        LocalAddress: "10.0.0.2/32",
    }}
    p2.ID = 2
    p3 := &models.WireGuard{WireGuardEntity: &models.WireGuardEntity{
        ClientID:     "c3",
        PrivateKey:   priv3.String(),
        LocalAddress: "10.0.0.3/32",
    }}
    p3.ID = 3

    idToPeer := map[uint]*models.WireGuard{1: p1, 2: p2, 3: p3}
    order := []uint{1, 2, 3}

    // Fully bidirectional connectivity, but weights make 2->1 prefer 2->3->1.
    adj := map[uint][]Edge{
        1: {
            {to: 2, latency: 1, upMbps: 50, explicit: true},
            {to: 3, latency: 100, upMbps: 50, explicit: true},
        },
        2: {
            {to: 1, latency: 100, upMbps: 50, explicit: true},
            {to: 3, latency: 1, upMbps: 50, explicit: true},
        },
        3: {
            {to: 1, latency: 1, upMbps: 50, explicit: true},
            {to: 2, latency: 100, upMbps: 50, explicit: true},
        },
    }

    cache := &fakeTopologyCache{
        rt: map[uint]*pb.WGDeviceRuntimeInfo{
            1: {Peers: []*pb.WGPeerRuntimeInfo{{ClientId: "c2", LastHandshakeTimeSec: uint64(now)}}},
            2: {Peers: []*pb.WGPeerRuntimeInfo{{ClientId: "c1", LastHandshakeTimeSec: uint64(now)}}},
        },
    }

    policy := RoutingPolicy{
        LatencyWeight:           1,
        InverseBandwidthWeight:  0,
        HopWeight:               0,
        HandshakeStaleThreshold: 1 * time.Hour,
        HandshakeStalePenalty:   0,
        NetworkTopologyCache:    cache,
    }

    aggByNode, edgeInfo := runAllPairsDijkstra(order, adj, idToPeer, policy)
    peersMap, err := assemblePeerConfigs(order, aggByNode, edgeInfo, idToPeer)
    if err != nil {
        t.Fatalf("assemblePeerConfigs err: %v", err)
    }
    fillIsolates(order, peersMap)

    // Expectation: before the symmetry fix, 2 may not contain peer(1).
    _ = ensureRoutingPeerSymmetry(order, peersMap, idToPeer)

    found := false
    for _, pc := range peersMap[2] {
        if pc != nil && pc.GetId() == 1 {
            found = true
            if len(pc.GetAllowedIps()) == 0 || pc.GetAllowedIps()[0] != "10.0.0.1/32" {
                t.Fatalf("peer(1) on node2 should include 10.0.0.1/32, got=%v", pc.GetAllowedIps())
            }
        }
    }
    if !found {
        t.Fatalf("node2 should contain peer(1) after ensureRoutingPeerSymmetry")
    }
}
@@ -4,67 +4,18 @@
package wg

import (
    "context"
    "errors"
    "fmt"
    "net/http"
    "net/netip"
    "os"
    "reflect"
    "sync"
    "time"
    "unsafe"

    "github.com/gin-gonic/gin"
    "github.com/samber/lo"
    "github.com/sirupsen/logrus"
    "github.com/vishvananda/netlink"
    "golang.zx2c4.com/wireguard/conn"
    "golang.zx2c4.com/wireguard/device"
    "golang.zx2c4.com/wireguard/tun"
    "golang.zx2c4.com/wireguard/tun/netstack"
    "gvisor.dev/gvisor/pkg/tcpip"
    "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    "gvisor.dev/gvisor/pkg/tcpip/stack"

    "github.com/VaalaCat/frp-panel/defs"
    "github.com/VaalaCat/frp-panel/pb"
    "github.com/VaalaCat/frp-panel/services/app"
    "github.com/VaalaCat/frp-panel/services/wg/multibind"
    "github.com/VaalaCat/frp-panel/services/wg/transport/ws"
    "github.com/VaalaCat/frp-panel/utils"
)

const (
    ReportInterval = time.Second * 60
)

var (
    _ app.WireGuard = (*wireGuard)(nil)
)

type wireGuard struct {
    sync.RWMutex

    ifce            *defs.WireGuardConfig
    endpointPingMap *utils.SyncMap[uint32, uint32] // ms
    virtAddrPingMap *utils.SyncMap[string, uint32] // ms

    wgDevice  *device.Device
    tunDevice tun.Device
    multiBind *multibind.MultiBind
    gvisorNet *netstack.Net
    fwManager *firewallManager

    running      bool
    useGvisorNet bool // if true, use gvisor netstack

    svcLogger *logrus.Entry
    ctx       *app.Context
    cancel    context.CancelFunc
}

func NewWireGuard(ctx *app.Context, ifce defs.WireGuardConfig, logger *logrus.Entry) (app.WireGuard, error) {
    if logger == nil {
        defaultLog := logrus.New()
@@ -86,7 +37,6 @@ func NewWireGuard(ctx *app.Context, ifce defs.WireGuardConfig, logger *logrus.En
    fwManager := newFirewallManager(logger.WithField("component", "iptables"))

    return &wireGuard{
        RWMutex: sync.RWMutex{},
        ifce:    &cfg,
        ctx:     svcCtx,
        cancel:  cancel,
@@ -95,6 +45,8 @@ func NewWireGuard(ctx *app.Context, ifce defs.WireGuardConfig, logger *logrus.En
        useGvisorNet:    useGvisorNet,
        virtAddrPingMap: &utils.SyncMap[string, uint32]{},
        fwManager:       fwManager,
        peerDirectory:   make(map[uint32]*pb.WireGuardPeerConfig, 64),
        preconnectPeers: make(map[uint32]struct{}, 64),
    }, nil
}

@@ -191,6 +143,7 @@ func (w *wireGuard) AddPeer(peer *defs.WireGuardPeerConfig) error {
    defer w.Unlock()

    w.ifce.Peers = append(w.ifce.Peers, peer.WireGuardPeerConfig)
    w.indexPeerDirectoryLocked(peerCfg)
    uapiBuilder := NewUAPIBuilder().AddPeerConfig(peerCfg)

    log.Debugf("uapiBuilder: %s", uapiBuilder.Build())
@@ -199,6 +152,9 @@ func (w *wireGuard) AddPeer(peer *defs.WireGuardPeerConfig) error {
        return errors.Join(errors.New("add peer IpcSet error"), err)
    }

    // Backfill this node's resident connectable peers (when the directory has their base info).
    w.onPeersChangedLocked("after AddPeer")

    return nil
}

@@ -248,23 +204,27 @@ func (w *wireGuard) RemovePeer(peerNameOrPk string) error {
    w.Lock()
    defer w.Unlock()

    newPeers := []*pb.WireGuardPeerConfig{}
    var peerToRemove *defs.WireGuardPeerConfig
    // Semantics: actually remove the peer (push remove=true) and delete it from local
    // config. To drop only the routes while keeping the connection, use
    // PatchPeers/UpdatePeer to push an update with AllowedIPs=nil instead.

    var removedPeerPB *pb.WireGuardPeerConfig
    newPeers := make([]*pb.WireGuardPeerConfig, 0, len(w.ifce.Peers))
    for _, p := range w.ifce.Peers {
        if p.ClientId != peerNameOrPk && p.PublicKey != peerNameOrPk {
            newPeers = append(newPeers, p)
        if p.ClientId == peerNameOrPk || p.PublicKey == peerNameOrPk {
            removedPeerPB = p
            continue
        }
        peerToRemove = &defs.WireGuardPeerConfig{WireGuardPeerConfig: p}
        newPeers = append(newPeers, p)
    }

    if len(newPeers) == len(w.ifce.Peers) {
    if removedPeerPB == nil {
        return errors.New("peer not found")
    }

    w.ifce.Peers = newPeers
    removedPeer := &defs.WireGuardPeerConfig{WireGuardPeerConfig: removedPeerPB}
    log.Debugf("remove peer completely: key=%s pk=%s", truncate(peerNameOrPk, 10), truncate(removedPeer.GetPublicKey(), 10))

    uapiBuilder := NewUAPIBuilder().RemovePeerByKey(peerToRemove.GetParsedPublicKey())
    uapiBuilder := NewUAPIBuilder().RemovePeerByKey(removedPeer.GetParsedPublicKey())

    log.Debugf("uapiBuilder: %s", uapiBuilder.Build())

@@ -272,6 +232,22 @@ func (w *wireGuard) RemovePeer(peerNameOrPk string) error {
        return errors.Join(errors.New("remove peer IpcSet error"), err)
    }

    // Update local caches only after IpcSet succeeds, to avoid inconsistency.
    w.ifce.Peers = newPeers
    w.deletePeerDirectoryLocked(removedPeer)
    if id := removedPeer.GetId(); id != 0 {
        delete(w.preconnectPeers, id)
    }
    if removedPeer.GetEndpoint() != nil {
        if id := removedPeer.GetEndpoint().GetWireguardId(); id != 0 {
            delete(w.preconnectPeers, id)
        }
    }
    w.cleanupPreconnectPeersLocked()

    // After removal, also backfill other connectable peers (when the directory has their base info).
    w.onPeersChangedLocked("after RemovePeer")

    return nil
}
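As the comment above notes, dropping only a peer's routes while keeping the tunnel alive goes through an update with empty AllowedIPs instead of RemovePeer; a minimal sketch (the helper name is illustrative; it assumes the caller already holds the peer's current config):

// keepConnectedDropRoutes clears a peer's routes but leaves the session up,
// turning it into a preconnect-only peer.
func keepConnectedDropRoutes(w app.WireGuard, peer *defs.WireGuardPeerConfig) error {
    peer.AllowedIps = nil // no routes; the handshake and NAT bindings stay warm
    return w.UpdatePeer(peer)
}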
@@ -297,6 +273,7 @@ func (w *wireGuard) UpdatePeer(peer *defs.WireGuardPeerConfig) error {
    }

    w.ifce.Peers = newPeers
    w.indexPeerDirectoryLocked(peerCfg)

    uapiBuilder := NewUAPIBuilder().UpdatePeerConfig(peerCfg)

@@ -306,6 +283,9 @@ func (w *wireGuard) UpdatePeer(peer *defs.WireGuardPeerConfig) error {
        return errors.Join(errors.New("update peer IpcSet error"), err)
    }

    // After updating, backfill resident peers (when the directory has their base info).
    w.onPeersChangedLocked("after UpdatePeer")

    return nil
}

@@ -323,6 +303,24 @@ func (w *wireGuard) PatchPeers(newPeers []*defs.WireGuardPeerConfig) (*app.WireG
        return nil, err
    }

    // Optimization: extract the peers this node can directly reach from adj and keep
    // them resident without assigning AllowedIPs. Subsequent route changes then only
    // need to update AllowedIPs instead of re-establishing connections via remove+add.
    beforeMerge := len(typedNewPeers)
    typedNewPeers = mergeConnectablePeersFromAdj(w.ifce, typedNewPeers, oldPeers)
    if delta := len(typedNewPeers) - beforeMerge; delta > 0 {
        log.Debugf("merged connectable peers from adjs: +%d (desired=%d -> %d)", delta, beforeMerge, len(typedNewPeers))
    }

    // Update peerDirectory (cache the base info of every peer seen in this pass;
    // later adjs changes can use it to backfill resident peers).
    w.Lock()
    for _, p := range typedNewPeers {
        if p == nil {
            continue
        }
        w.indexPeerDirectoryLocked(p)
    }
    w.Unlock()

    oldByPK := make(map[string]*defs.WireGuardPeerConfig, len(oldPeers))
    for _, p := range oldPeers {
        if p == nil || p.GetPublicKey() == "" {
@@ -398,6 +396,12 @@ func (w *wireGuard) PatchPeers(newPeers []*defs.WireGuardPeerConfig) (*app.WireG
    }
    w.ifce.Peers = newPBPeers

    // Clean stale peer IDs out of preconnectPeers, to avoid unbounded growth.
    w.cleanupPreconnectPeersLocked()

    // Backfill again after PatchPeers (mainly covers adjs changing before peers while the directory already has the info).
    w.onPeersChangedLocked("after PatchPeers")

    return resp, nil
}

@@ -456,7 +460,9 @@ func (w *wireGuard) UpdateAdjs(adjs map[uint32]*pb.WireGuardLinks) error {
    defer w.Unlock()

    w.ifce.Adjs = adjs
    return nil

    // Immediately backfill resident connectable peers after adjs change
    // (as long as peerDirectory already has their base info).
    return w.ensureConnectablePeersLocked()
}

func (w *wireGuard) NeedRecreate(newCfg *defs.WireGuardConfig) bool {
@@ -470,310 +476,3 @@ func (w *wireGuard) NeedRecreate(newCfg *defs.WireGuardConfig) bool {
        w.ifce.GetUseGvisorNet() != newCfg.GetUseGvisorNet() ||
        w.ifce.GetNetworkId() != newCfg.GetNetworkId()
}
func (w *wireGuard) initTransports() error {
    log := w.svcLogger.WithField("op", "initTransports")

    wsTrans := ws.NewWSBind(w.ctx)
    w.multiBind = multibind.NewMultiBind(
        w.svcLogger,
        multibind.NewTransport(conn.NewDefaultBind(), "udp"),
        multibind.NewTransport(wsTrans, "ws"),
    )

    engine := gin.New()
    engine.Any(defs.DefaultWSHandlerPath, func(c *gin.Context) {
        err := wsTrans.HandleHTTP(c.Writer, c.Request)
        if err != nil {
            c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
            return
        }
    })

    // If the ws listen port is not set, use the wg listen port, sharing the TCP and UDP port.
    listenPort := w.ifce.GetWsListenPort()
    if listenPort == 0 {
        listenPort = w.ifce.GetListenPort()
    }
    go func() {
        if err := engine.Run(fmt.Sprintf(":%d", listenPort)); err != nil {
            w.svcLogger.WithError(err).Errorf("failed to run gin engine for ws transport on port %d", listenPort)
        }
    }()

    log.Infof("WS transport engine running on port %d", listenPort)

    return nil
}

func (w *wireGuard) initWGDevice() error {
    log := w.svcLogger.WithField("op", "initWGDevice")

    log.Debugf("start to create TUN device '%s' (MTU %d)", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu())

    var err error

    if w.useGvisorNet {
        log.Infof("using gvisor netstack for TUN device")
        prf, err := netip.ParsePrefix(w.ifce.GetLocalAddress())
        if err != nil {
            return errors.Join(fmt.Errorf("parse local addr '%s' for netip", w.ifce.GetLocalAddress()), err)
        }

        addrs := lo.Map(w.ifce.GetDnsServers(), func(s string, _ int) netip.Addr {
            addr, err := netip.ParseAddr(s)
            if err != nil {
                return netip.Addr{}
            }
            return addr
        })
        if len(addrs) == 0 {
            addrs = []netip.Addr{netip.AddrFrom4([4]byte{1, 2, 4, 8})}
        }
        log.Debugf("create netstack TUN with addr '%s' and dns servers '%v'", prf.Addr().String(), addrs)
        w.tunDevice, w.gvisorNet, err = netstack.CreateNetTUN([]netip.Addr{prf.Addr()}, addrs, 1200)
        if err != nil {
            return errors.Join(fmt.Errorf("create netstack TUN device '%s' (MTU %d) failed", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu()), err)
        }
    } else {
        w.tunDevice, err = tun.CreateTUN(w.ifce.GetInterfaceName(), int(w.ifce.GetInterfaceMtu()))
        if err != nil {
            return errors.Join(fmt.Errorf("create TUN device '%s' (MTU %d) failed", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu()), err)
        }
    }

    log.Debugf("TUN device '%s' (MTU %d) created successfully", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu())

    log.Debugf("start to create WireGuard device '%s'", w.ifce.GetInterfaceName())

    w.wgDevice = device.NewDevice(w.tunDevice, w.multiBind, &device.Logger{
        Verbosef: w.svcLogger.WithField("wg-dev-iface", w.ifce.GetInterfaceName()).Debugf,
        Errorf:   w.svcLogger.WithField("wg-dev-iface", w.ifce.GetInterfaceName()).Errorf,
    })

    log.Debugf("WireGuard device '%s' created successfully", w.ifce.GetInterfaceName())

    return nil
}

func (w *wireGuard) applyPeerConfig() error {
    log := w.svcLogger.WithField("op", "applyConfig")

    log.Debugf("start to apply config to WireGuard device '%s'", w.ifce.GetInterfaceName())

    if w.wgDevice == nil {
        return errors.New("wgDevice is nil, please init WG device first")
    }

    wgTypedPeerConfigs, err := parseAndValidatePeerConfigs(w.ifce.GetParsedPeers())
    if err != nil {
        return errors.Join(errors.New("parse/validate peers"), err)
    }

    log.Debugf("wgTypedPeerConfigs: %v", wgTypedPeerConfigs)

    uapiConfigString := generateUAPIConfigString(w.ifce, w.ifce.GetParsedPrivKey(), wgTypedPeerConfigs, !w.running, false)

    log.Debugf("uapiBuilder: %s", uapiConfigString)

    log.Debugf("calling IpcSet...")
    if err = w.wgDevice.IpcSet(uapiConfigString); err != nil {
        return errors.Join(errors.New("IpcSet error"), err)
    }
    log.Debugf("IpcSet completed successfully")

    return nil
}

func (w *wireGuard) initNetwork() error {
    log := w.svcLogger.WithField("op", "initNetwork")

    // Wait for the TUN device to be fully registered in the kernel, to avoid a race condition.
    var link netlink.Link
    var err error
    maxRetries := 10
    for i := 0; i < maxRetries; i++ {
        link, err = netlink.LinkByName(w.ifce.GetInterfaceName())
        if err == nil {
            break
        }
        if i < maxRetries-1 {
            log.Debugf("attempt %d: waiting for iface '%s' to be ready, will retry...", i+1, w.ifce.GetInterfaceName())
            time.Sleep(100 * time.Millisecond)
        }
    }
    if err != nil {
        return errors.Join(fmt.Errorf("get iface '%s' via netlink after %d retries", w.ifce.GetInterfaceName(), maxRetries), err)
    }
    log.Debugf("successfully found interface '%s' via netlink", w.ifce.GetInterfaceName())

    addr, err := netlink.ParseAddr(w.ifce.GetLocalAddress())
    if err != nil {
        return errors.Join(fmt.Errorf("parse local addr '%s' for netlink", w.ifce.GetLocalAddress()), err)
    }

    if err = netlink.AddrAdd(link, addr); err != nil && !os.IsExist(err) {
        return errors.Join(fmt.Errorf("add IP '%s' to '%s'", w.ifce.GetLocalAddress(), w.ifce.GetInterfaceName()), err)
    } else if os.IsExist(err) {
        log.Infof("IP %s already on '%s'.", w.ifce.GetLocalAddress(), w.ifce.GetInterfaceName())
    } else {
        log.Infof("IP %s added to '%s'.", w.ifce.GetLocalAddress(), w.ifce.GetInterfaceName())
    }

    if err = netlink.LinkSetMTU(link, int(w.ifce.GetInterfaceMtu())); err != nil {
        log.Warnf("Set MTU %d on '%s' via netlink: %v. TUN MTU is %d.",
            w.ifce.GetInterfaceMtu(), w.ifce.GetInterfaceName(), err, w.ifce.GetInterfaceMtu())
    } else {
        log.Infof("Iface '%s' MTU %d set via netlink.", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu())
    }

    if err = netlink.LinkSetUp(link); err != nil {
        return errors.Join(fmt.Errorf("bring up iface '%s' via netlink", w.ifce.GetInterfaceName()), err)
    }

    log.Infof("Iface '%s' up via netlink.", w.ifce.GetInterfaceName())
    return nil
}

func (w *wireGuard) initGvisorNetwork() error {
    log := w.svcLogger.WithField("op", "initGvisorNetwork")

    if w.gvisorNet == nil {
        return errors.New("gvisorNet is nil, cannot initialize network")
    }

    // wg-go does not expose the stack field, so we need to use reflection to access it.
    netValue := reflect.ValueOf(w.gvisorNet).Elem()
    stackField := netValue.FieldByName("stack")

    if !stackField.IsValid() {
        return errors.New("cannot find stack field in gvisorNet")
    }

    stackPtrValue := reflect.NewAt(stackField.Type(), unsafe.Pointer(stackField.UnsafeAddr())).Elem()
    if !stackPtrValue.IsValid() || stackPtrValue.IsNil() {
        return errors.New("gvisor stack is nil or invalid")
    }

    gvisorStack := stackPtrValue.Interface().(*stack.Stack)
    if gvisorStack == nil {
        return errors.New("gvisor stack is nil after conversion")
    }

    log.Infof("successfully accessed gvisor stack, enabling IP forwarding")

    if err := gvisorStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
        log.Warnf("failed to enable IPv4 forwarding: %v, relay may not work", err)
    } else {
        log.Infof("IPv4 forwarding enabled for gvisor netstack")
    }

    if err := gvisorStack.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
        log.Warnf("failed to enable IPv6 forwarding: %v", err)
    } else {
        log.Infof("IPv6 forwarding enabled for gvisor netstack")
    }

    for _, peer := range w.ifce.Peers {
        for _, allowedIP := range peer.AllowedIps {
            prefix, err := netip.ParsePrefix(allowedIP)
            if err != nil {
                log.WithError(err).Warnf("failed to parse allowed IP: %s", allowedIP)
                continue
            }

            addr := tcpip.AddrFromSlice(prefix.Addr().AsSlice())

            // Expand the prefix length into a netmask, 8 bits per byte.
            ones := prefix.Bits()
            maskBytes := make([]byte, len(prefix.Addr().AsSlice()))
            for i := 0; i < len(maskBytes); i++ {
                if ones >= 8 {
                    maskBytes[i] = 0xff
                    ones -= 8
                } else if ones > 0 {
                    maskBytes[i] = byte(0xff << (8 - ones))
                    ones = 0
                }
            }

            subnet, err := tcpip.NewSubnet(addr, tcpip.MaskFromBytes(maskBytes))
            if err != nil {
                log.WithError(err).Warnf("failed to create subnet for %s", allowedIP)
                continue
            }

            route := tcpip.Route{
                Destination: subnet,
                NIC:         1,
            }

            gvisorStack.AddRoute(route)
            log.Debugf("added route for peer allowed IP: %s via NIC 1", allowedIP)
        }
    }

    log.Infof("gvisor netstack initialized with IP forwarding enabled")
    return nil
}

func (w *wireGuard) cleanupNetwork() {
    log := w.svcLogger.WithField("op", "cleanupNetwork")

    if w.useGvisorNet {
        log.Infof("skip network cleanup for gvisor netstack")
        return
    }

    link, err := netlink.LinkByName(w.ifce.GetInterfaceName())
    if err == nil {
        if err := netlink.LinkSetDown(link); err != nil {
            log.Warnf("Failed to LinkSetDown '%s' after wgDevice.Up() error: %v", w.ifce.GetInterfaceName(), err)
        }
    }
    log.Debug("Cleanup network complete.")
}

func (w *wireGuard) cleanupWGDevice() {
    log := w.svcLogger.WithField("op", "cleanupWGDevice")

    if w.wgDevice != nil {
        w.wgDevice.Close()
    } else if w.tunDevice != nil {
        w.tunDevice.Close()
    }
    w.wgDevice = nil
    w.tunDevice = nil
    log.Debug("Cleanup WG device complete.")
}

func (w *wireGuard) applyFirewallRulesLocked() error {
    if w.useGvisorNet || w.fwManager == nil {
        return nil
    }

    prefix, err := netip.ParsePrefix(w.ifce.GetLocalAddress())
    if err != nil {
        return errors.Join(fmt.Errorf("parse local address '%s' for firewall", w.ifce.GetLocalAddress()), err)
    }

    return w.fwManager.ApplyRelayRules(w.ifce.GetInterfaceName(), prefix.Masked().String())
}

func (w *wireGuard) cleanupFirewallRulesLocked() error {
    if w.useGvisorNet || w.fwManager == nil {
        return nil
    }
    return w.fwManager.Cleanup(w.ifce.GetInterfaceName())
}

func (w *wireGuard) reportStatusTask() {
    for {
        select {
        case <-w.ctx.Done():
            return
        default:
            w.pingPeers()
            time.Sleep(ReportInterval)
        }
    }
}
services/wg/wireguard_device.go (new file, 108 lines)
@@ -0,0 +1,108 @@
//go:build !windows
// +build !windows

package wg

import (
    "errors"
    "fmt"
    "net/netip"

    "github.com/samber/lo"

    "golang.zx2c4.com/wireguard/device"
    "golang.zx2c4.com/wireguard/tun"
    "golang.zx2c4.com/wireguard/tun/netstack"
)

func (w *wireGuard) initWGDevice() error {
    log := w.svcLogger.WithField("op", "initWGDevice")

    log.Debugf("start to create TUN device '%s' (MTU %d)", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu())

    var err error

    if w.useGvisorNet {
        log.Infof("using gvisor netstack for TUN device")
        prf, err := netip.ParsePrefix(w.ifce.GetLocalAddress())
        if err != nil {
            return errors.Join(fmt.Errorf("parse local addr '%s' for netip", w.ifce.GetLocalAddress()), err)
        }

        addrs := lo.Map(w.ifce.GetDnsServers(), func(s string, _ int) netip.Addr {
            addr, err := netip.ParseAddr(s)
            if err != nil {
                return netip.Addr{}
            }
            return addr
        })
        if len(addrs) == 0 {
            addrs = []netip.Addr{netip.AddrFrom4([4]byte{1, 2, 4, 8})}
        }
        log.Debugf("create netstack TUN with addr '%s' and dns servers '%v'", prf.Addr().String(), addrs)
        w.tunDevice, w.gvisorNet, err = netstack.CreateNetTUN([]netip.Addr{prf.Addr()}, addrs, 1200)
        if err != nil {
            return errors.Join(fmt.Errorf("create netstack TUN device '%s' (MTU %d) failed", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu()), err)
        }
    } else {
        w.tunDevice, err = tun.CreateTUN(w.ifce.GetInterfaceName(), int(w.ifce.GetInterfaceMtu()))
        if err != nil {
            return errors.Join(fmt.Errorf("create TUN device '%s' (MTU %d) failed", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu()), err)
        }
    }

    log.Debugf("TUN device '%s' (MTU %d) created successfully", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu())

    log.Debugf("start to create WireGuard device '%s'", w.ifce.GetInterfaceName())

    w.wgDevice = device.NewDevice(w.tunDevice, w.multiBind, &device.Logger{
        Verbosef: w.svcLogger.WithField("wg-dev-iface", w.ifce.GetInterfaceName()).Debugf,
        Errorf:   w.svcLogger.WithField("wg-dev-iface", w.ifce.GetInterfaceName()).Errorf,
    })

    log.Debugf("WireGuard device '%s' created successfully", w.ifce.GetInterfaceName())

    return nil
}

func (w *wireGuard) applyPeerConfig() error {
    log := w.svcLogger.WithField("op", "applyConfig")

    log.Debugf("start to apply config to WireGuard device '%s'", w.ifce.GetInterfaceName())

    if w.wgDevice == nil {
        return errors.New("wgDevice is nil, please init WG device first")
    }

    wgTypedPeerConfigs, err := parseAndValidatePeerConfigs(w.ifce.GetParsedPeers())
    if err != nil {
        return errors.Join(errors.New("parse/validate peers"), err)
    }

    log.Debugf("wgTypedPeerConfigs: %v", wgTypedPeerConfigs)

    uapiConfigString := generateUAPIConfigString(w.ifce, w.ifce.GetParsedPrivKey(), wgTypedPeerConfigs, !w.running, false)

    log.Debugf("uapiBuilder: %s", uapiConfigString)

    log.Debugf("calling IpcSet...")
    if err = w.wgDevice.IpcSet(uapiConfigString); err != nil {
        return errors.Join(errors.New("IpcSet error"), err)
    }
    log.Debugf("IpcSet completed successfully")

    return nil
}

func (w *wireGuard) cleanupWGDevice() {
    log := w.svcLogger.WithField("op", "cleanupWGDevice")

    if w.wgDevice != nil {
        w.wgDevice.Close()
    } else if w.tunDevice != nil {
        w.tunDevice.Close()
    }
    w.wgDevice = nil
    w.tunDevice = nil
    log.Debug("Cleanup WG device complete.")
}
services/wg/wireguard_firewall_ops.go (new file, 30 lines)
@@ -0,0 +1,30 @@
//go:build !windows
// +build !windows

package wg

import (
	"errors"
	"fmt"
	"net/netip"
)

func (w *wireGuard) applyFirewallRulesLocked() error {
	if w.useGvisorNet || w.fwManager == nil {
		return nil
	}

	prefix, err := netip.ParsePrefix(w.ifce.GetLocalAddress())
	if err != nil {
		return errors.Join(fmt.Errorf("parse local address '%s' for firewall", w.ifce.GetLocalAddress()), err)
	}

	return w.fwManager.ApplyRelayRules(w.ifce.GetInterfaceName(), prefix.Masked().String())
}

func (w *wireGuard) cleanupFirewallRulesLocked() error {
	if w.useGvisorNet || w.fwManager == nil {
		return nil
	}
	return w.fwManager.Cleanup(w.ifce.GetInterfaceName())
}
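
ApplyRelayRules receives the network address rather than the host address because netip's Masked() zeroes the host bits of the prefix; for example:

prefix, _ := netip.ParsePrefix("10.0.0.5/24")
fmt.Println(prefix.Masked().String()) // prints "10.0.0.0/24"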
97
services/wg/wireguard_network_gvisor.go
Normal file
@@ -0,0 +1,97 @@
//go:build !windows
// +build !windows

package wg

import (
	"errors"
	"net/netip"
	"reflect"
	"unsafe"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

func (w *wireGuard) initGvisorNetwork() error {
	log := w.svcLogger.WithField("op", "initGvisorNetwork")

	if w.gvisorNet == nil {
		return errors.New("gvisorNet is nil, cannot initialize network")
	}

	// wireguard-go does not expose the netstack's stack field, so we use reflection to access it
	netValue := reflect.ValueOf(w.gvisorNet).Elem()
	stackField := netValue.FieldByName("stack")

	if !stackField.IsValid() {
		return errors.New("cannot find stack field in gvisorNet")
	}

	stackPtrValue := reflect.NewAt(stackField.Type(), unsafe.Pointer(stackField.UnsafeAddr())).Elem()
	if !stackPtrValue.IsValid() || stackPtrValue.IsNil() {
		return errors.New("gvisor stack is nil or invalid")
	}

	gvisorStack := stackPtrValue.Interface().(*stack.Stack)
	if gvisorStack == nil {
		return errors.New("gvisor stack is nil after conversion")
	}

	log.Infof("successfully accessed gvisor stack, enabling IP forwarding")

	if err := gvisorStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
		log.Warnf("failed to enable IPv4 forwarding: %v, relay may not work", err)
	} else {
		log.Infof("IPv4 forwarding enabled for gvisor netstack")
	}

	if err := gvisorStack.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
		log.Warnf("failed to enable IPv6 forwarding: %v", err)
	} else {
		log.Infof("IPv6 forwarding enabled for gvisor netstack")
	}

	for _, peer := range w.ifce.Peers {
		for _, allowedIP := range peer.AllowedIps {
			prefix, err := netip.ParsePrefix(allowedIP)
			if err != nil {
				log.WithError(err).Warnf("failed to parse allowed IP: %s", allowedIP)
				continue
			}

			addr := tcpip.AddrFromSlice(prefix.Addr().AsSlice())

			// build the netmask byte-by-byte from the prefix length
			ones := prefix.Bits()
			maskBytes := make([]byte, len(prefix.Addr().AsSlice()))
			for i := 0; i < len(maskBytes); i++ {
				if ones >= 8 {
					maskBytes[i] = 0xff
					ones -= 8
				} else if ones > 0 {
					maskBytes[i] = byte(0xff << (8 - ones))
					ones = 0
				}
			}

			subnet, err := tcpip.NewSubnet(addr, tcpip.MaskFromBytes(maskBytes))
			if err != nil {
				log.WithError(err).Warnf("failed to create subnet for %s", allowedIP)
				continue
			}

			route := tcpip.Route{
				Destination: subnet,
				NIC:         1,
			}

			gvisorStack.AddRoute(route)
			log.Debugf("added route for peer allowed IP: %s via NIC 1", allowedIP)
		}
	}

	log.Infof("gvisor netstack initialized with IP forwarding enabled")
	return nil
}
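
The reflect.NewAt + unsafe.Pointer pattern above is the standard workaround for reading an unexported field; a self-contained toy illustration, where the hidden type is a stand-in for netstack.Net (which keeps its *stack.Stack private):

package main

import (
	"fmt"
	"reflect"
	"unsafe"
)

type hidden struct {
	stack *int // stand-in for the unexported *stack.Stack field
}

func main() {
	n := 42
	h := &hidden{stack: &n}

	f := reflect.ValueOf(h).Elem().FieldByName("stack")
	// f.Interface() would panic on an unexported field, so re-materialize an
	// addressable value at the same memory location first.
	p := reflect.NewAt(f.Type(), unsafe.Pointer(f.UnsafeAddr())).Elem()
	fmt.Println(*p.Interface().(*int)) // 42
}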
80
services/wg/wireguard_network_netlink.go
Normal file
@@ -0,0 +1,80 @@
//go:build !windows
// +build !windows

package wg

import (
	"errors"
	"fmt"
	"os"
	"time"

	"github.com/vishvananda/netlink"
)

func (w *wireGuard) initNetwork() error {
	log := w.svcLogger.WithField("op", "initNetwork")

	// wait for the TUN device to be fully registered in the kernel, avoiding a race condition
	var link netlink.Link
	var err error
	maxRetries := 10
	for i := 0; i < maxRetries; i++ {
		link, err = netlink.LinkByName(w.ifce.GetInterfaceName())
		if err == nil {
			break
		}
		if i < maxRetries-1 {
			log.Debugf("attempt %d: waiting for iface '%s' to be ready, will retry...", i+1, w.ifce.GetInterfaceName())
			time.Sleep(100 * time.Millisecond)
		}
	}
	if err != nil {
		return errors.Join(fmt.Errorf("get iface '%s' via netlink after %d retries", w.ifce.GetInterfaceName(), maxRetries), err)
	}
	log.Debugf("successfully found interface '%s' via netlink", w.ifce.GetInterfaceName())

	addr, err := netlink.ParseAddr(w.ifce.GetLocalAddress())
	if err != nil {
		return errors.Join(fmt.Errorf("parse local addr '%s' for netlink", w.ifce.GetLocalAddress()), err)
	}

	if err = netlink.AddrAdd(link, addr); err != nil && !os.IsExist(err) {
		return errors.Join(fmt.Errorf("add IP '%s' to '%s'", w.ifce.GetLocalAddress(), w.ifce.GetInterfaceName()), err)
	} else if os.IsExist(err) {
		log.Infof("IP %s already on '%s'.", w.ifce.GetLocalAddress(), w.ifce.GetInterfaceName())
	} else {
		log.Infof("IP %s added to '%s'.", w.ifce.GetLocalAddress(), w.ifce.GetInterfaceName())
	}

	if err = netlink.LinkSetMTU(link, int(w.ifce.GetInterfaceMtu())); err != nil {
		log.Warnf("Set MTU %d on '%s' via netlink: %v. TUN MTU is %d.",
			w.ifce.GetInterfaceMtu(), w.ifce.GetInterfaceName(), err, w.ifce.GetInterfaceMtu())
	} else {
		log.Infof("Iface '%s' MTU %d set via netlink.", w.ifce.GetInterfaceName(), w.ifce.GetInterfaceMtu())
	}

	if err = netlink.LinkSetUp(link); err != nil {
		return errors.Join(fmt.Errorf("bring up iface '%s' via netlink", w.ifce.GetInterfaceName()), err)
	}

	log.Infof("Iface '%s' up via netlink.", w.ifce.GetInterfaceName())
	return nil
}
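
Stripped of the retry loop and the os.IsExist tolerance, the happy path above condenses to four vishvananda/netlink calls; a hypothetical helper for illustration (name and CIDR are caller-supplied placeholders):

// bringUp assigns an address, sets the MTU, and brings a link up.
func bringUp(name, cidr string, mtu int) error {
	link, err := netlink.LinkByName(name)
	if err != nil {
		return err
	}
	addr, err := netlink.ParseAddr(cidr)
	if err != nil {
		return err
	}
	if err := netlink.AddrAdd(link, addr); err != nil {
		return err
	}
	if err := netlink.LinkSetMTU(link, mtu); err != nil {
		return err
	}
	return netlink.LinkSetUp(link)
}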
func (w *wireGuard) cleanupNetwork() {
	log := w.svcLogger.WithField("op", "cleanupNetwork")

	if w.useGvisorNet {
		log.Infof("skip network cleanup for gvisor netstack")
		return
	}

	link, err := netlink.LinkByName(w.ifce.GetInterfaceName())
	if err == nil {
		if err := netlink.LinkSetDown(link); err != nil {
			log.Warnf("Failed to LinkSetDown '%s' after wgDevice.Up() error: %v", w.ifce.GetInterfaceName(), err)
		}
	}
	log.Debug("Cleanup network complete.")
}
133
services/wg/wireguard_patchpeers_preconnect_test.go
Normal file
@@ -0,0 +1,133 @@
package wg

import (
	"testing"

	"github.com/VaalaCat/frp-panel/defs"
	"github.com/VaalaCat/frp-panel/pb"
)

func TestMergeConnectablePeersFromAdj_AddsMissingPeerWithEmptyAllowedIPs(t *testing.T) {
	ifce := &defs.WireGuardConfig{WireGuardConfig: &pb.WireGuardConfig{
		Id: 1,
		Adjs: map[uint32]*pb.WireGuardLinks{
			1: {Links: []*pb.WireGuardLink{
				{ToWireguardId: 2, ToEndpoint: &pb.Endpoint{Host: "1.2.3.4", Port: 51820}},
			}},
		},
	}}

	desired := []*defs.WireGuardPeerConfig{
		{WireGuardPeerConfig: &pb.WireGuardPeerConfig{
			Id:        3,
			PublicKey: "pk-3",
			AllowedIps: []string{
				"10.0.0.3/32",
			},
		}},
	}

	known := []*defs.WireGuardPeerConfig{
		{WireGuardPeerConfig: &pb.WireGuardPeerConfig{
			Id:        2,
			PublicKey: "pk-2",
			AllowedIps: []string{
				"10.0.0.2/32",
			},
			PersistentKeepalive: 0, // should be defaulted by parseAndValidatePeerConfig
		}},
	}

	got := mergeConnectablePeersFromAdj(ifce, desired, known)

	var found *defs.WireGuardPeerConfig
	for _, p := range got {
		if p != nil && p.GetId() == 2 {
			found = p
			break
		}
	}
	if found == nil {
		t.Fatalf("expected peer id=2 to be added")
	}
	if len(found.GetAllowedIps()) != 0 {
		t.Fatalf("expected allowed_ips empty, got=%v", found.GetAllowedIps())
	}
	if found.GetEndpoint() == nil || found.GetEndpoint().GetHost() != "1.2.3.4" || found.GetEndpoint().GetPort() != 51820 {
		t.Fatalf("expected endpoint from adj to be applied, got=%v", found.GetEndpoint())
	}
	if found.GetPersistentKeepalive() == 0 {
		t.Fatalf("expected persistent_keepalive defaulted, got=0")
	}
}

func TestMergeConnectablePeersFromAdj_DoesNotOverrideExistingDesiredPeer(t *testing.T) {
	ifce := &defs.WireGuardConfig{WireGuardConfig: &pb.WireGuardConfig{
		Id: 1,
		Adjs: map[uint32]*pb.WireGuardLinks{
			1: {Links: []*pb.WireGuardLink{
				{ToWireguardId: 2, ToEndpoint: &pb.Endpoint{Host: "9.9.9.9", Port: 9999}},
			}},
		},
	}}

	desired := []*defs.WireGuardPeerConfig{
		{WireGuardPeerConfig: &pb.WireGuardPeerConfig{
			Id:        2,
			PublicKey: "pk-2",
			AllowedIps: []string{
				"10.0.0.2/32",
			},
			Endpoint: &pb.Endpoint{Host: "1.1.1.1", Port: 1111},
		}},
	}

	got := mergeConnectablePeersFromAdj(ifce, desired, nil)

	if len(got) != 1 {
		t.Fatalf("expected no extra peers added, got len=%d", len(got))
	}
	if len(got[0].GetAllowedIps()) != 1 || got[0].GetAllowedIps()[0] != "10.0.0.2/32" {
		t.Fatalf("expected desired allowed_ips preserved, got=%v", got[0].GetAllowedIps())
	}
	// endpoint should not be overridden because peer already existed in desired list
	if got[0].GetEndpoint() == nil || got[0].GetEndpoint().GetHost() != "1.1.1.1" {
		t.Fatalf("expected desired endpoint preserved, got=%v", got[0].GetEndpoint())
	}
}

func TestMergeConnectablePeersFromAdj_UsesEndpointWireguardIDWhenPeerIDMissing(t *testing.T) {
	ifce := &defs.WireGuardConfig{WireGuardConfig: &pb.WireGuardConfig{
		Id: 1,
		Adjs: map[uint32]*pb.WireGuardLinks{
			1: {Links: []*pb.WireGuardLink{
				{ToWireguardId: 2, ToEndpoint: &pb.Endpoint{Host: "2.2.2.2", Port: 2222}},
			}},
		},
	}}

	desired := []*defs.WireGuardPeerConfig{}

	known := []*defs.WireGuardPeerConfig{
		{WireGuardPeerConfig: &pb.WireGuardPeerConfig{
			Id:        0, // simulates the case where peer.id is not delivered
			PublicKey: "pk-2",
			Endpoint:  &pb.Endpoint{WireguardId: 2, Host: "old", Port: 1},
		}},
	}

	got := mergeConnectablePeersFromAdj(ifce, desired, known)
	if len(got) != 1 {
		t.Fatalf("expected 1 peer added, got len=%d", len(got))
	}
	if got[0].GetPublicKey() != "pk-2" {
		t.Fatalf("expected pk-2, got=%s", got[0].GetPublicKey())
	}
	// endpoint should be overridden by adj's to_endpoint
	if got[0].GetEndpoint() == nil || got[0].GetEndpoint().GetHost() != "2.2.2.2" || got[0].GetEndpoint().GetPort() != 2222 {
		t.Fatalf("expected endpoint from adj, got=%v", got[0].GetEndpoint())
	}
	if len(got[0].GetAllowedIps()) != 0 {
		t.Fatalf("expected allowed_ips empty, got=%v", got[0].GetAllowedIps())
	}
}
362
services/wg/wireguard_peers_preconnect.go
Normal file
@@ -0,0 +1,362 @@
//go:build !windows
// +build !windows

package wg

import (
	"fmt"

	"google.golang.org/protobuf/proto"

	"github.com/VaalaCat/frp-panel/defs"
	"github.com/VaalaCat/frp-panel/pb"
)

// Peer pre-connect / keep-alive backfill logic.
//
// Goal: when the topology (adjs) changes, ensure this node has every
// "directly reachable/connectable" peer configured on the wg device, but with
// empty AllowedIPs (keep the connection alive without carrying routes),
// avoiding the frequent remove+add cycles on route changes that break links.

func (w *wireGuard) onPeersChangedLocked(reason string) {
	if w == nil {
		return
	}
	if err := w.ensureConnectablePeersLocked(); err != nil {
		w.svcLogger.WithError(err).WithField("op", "onPeersChanged").Warnf("ensure connectable peers failed (%s)", reason)
	}
}

func (w *wireGuard) cleanupPreconnectPeersLocked() {
	if w == nil {
		return
	}
	if len(w.preconnectPeers) == 0 {
		return
	}
	exists := make(map[uint32]struct{}, len(w.ifce.Peers))
	for _, p := range w.ifce.GetParsedPeers() {
		if p == nil {
			continue
		}
		if id := p.GetId(); id != 0 {
			exists[id] = struct{}{}
		}
		if p.GetEndpoint() != nil {
			if id := p.GetEndpoint().GetWireguardId(); id != 0 {
				exists[id] = struct{}{}
			}
		}
	}
	for id := range w.preconnectPeers {
		if _, ok := exists[id]; !ok {
			delete(w.preconnectPeers, id)
		}
	}
}

func (w *wireGuard) indexPeerDirectoryLocked(p *defs.WireGuardPeerConfig) {
	if p == nil || p.WireGuardPeerConfig == nil {
		return
	}
	// 1) peer.id
	if id := p.GetId(); id != 0 {
		w.peerDirectory[id] = proto.Clone(p.WireGuardPeerConfig).(*pb.WireGuardPeerConfig)
	}
	// 2) endpoint.wireguard_id (in some cases peer.id may be unset while the endpoint carries a wireguard_id)
	if p.GetEndpoint() != nil {
		if id := p.GetEndpoint().GetWireguardId(); id != 0 {
			w.peerDirectory[id] = proto.Clone(p.WireGuardPeerConfig).(*pb.WireGuardPeerConfig)
		}
	}
}

func (w *wireGuard) deletePeerDirectoryLocked(p *defs.WireGuardPeerConfig) {
	if p == nil {
		return
	}
	if id := p.GetId(); id != 0 {
		delete(w.peerDirectory, id)
	}
	if p.GetEndpoint() != nil {
		if id := p.GetEndpoint().GetWireguardId(); id != 0 {
			delete(w.peerDirectory, id)
		}
	}
}

// ensureConnectablePeersLocked guarantees that every peer currently
// "directly reachable/connectable" from this node is configured on the wg
// device, with the backfilled peers carrying empty AllowedIPs (connection
// only, no routes).
//
// Constraint: must be called while holding w.Lock().
func (w *wireGuard) ensureConnectablePeersLocked() error {
	if w == nil || w.ifce == nil || w.wgDevice == nil {
		return nil
	}
	localID := w.ifce.GetId()
	if localID == 0 {
		return nil
	}
	adjs := w.ifce.GetAdjs()
	if adjs == nil {
		return nil
	}
	localLinks, ok := adjs[localID]
	if !ok || localLinks == nil || len(localLinks.GetLinks()) == 0 {
		return nil
	}

	// set of currently connectable peer ids (from the adjacency graph)
	connectable := make(map[uint32]struct{}, len(localLinks.GetLinks()))
	for _, l := range localLinks.GetLinks() {
		if l == nil {
			continue
		}
		toID := l.GetToWireguardId()
		if toID == 0 || toID == localID {
			continue
		}
		connectable[toID] = struct{}{}
	}

	log := w.svcLogger.WithField("op", "ensureConnectablePeers")
	log.Debugf("ensure connectable peers: local=%d connectable=%d peers=%d preconnect=%d directory=%d",
		localID, len(connectable), len(w.ifce.Peers), len(w.preconnectPeers), len(w.peerDirectory))

	// currently configured peers: indexed by both peer.id and endpoint.wireguard_id,
	// so a missing peer.id does not cause duplicate backfills
	exists := make(map[uint32]struct{}, len(w.ifce.Peers))
	for _, p := range w.ifce.GetParsedPeers() {
		if p == nil {
			continue
		}
		if id := p.GetId(); id != 0 {
			exists[id] = struct{}{}
		}
		if p.GetEndpoint() != nil {
			if id := p.GetEndpoint().GetWireguardId(); id != 0 {
				exists[id] = struct{}{}
			}
		}
	}

	uapiBuilder := NewUAPIBuilder()
	added := 0
	removed := 0
	skippedNoBase := 0
	skippedAlready := 0

	// Cleanup first: peers that have become completely unreachable in this
	// topology (compared to last time) must be fully removed from the device.
	// Only peers with empty AllowedIPs are cleaned up, i.e. peers that carry no
	// routes and exist solely to keep the connection alive.
	newPeers := make([]*pb.WireGuardPeerConfig, 0, len(w.ifce.Peers))
	for _, raw := range w.ifce.GetParsedPeers() {
		if raw == nil || raw.WireGuardPeerConfig == nil {
			continue
		}
		// only auto-clean peers whose AllowedIPs are empty
		if len(raw.GetAllowedIps()) != 0 {
			newPeers = append(newPeers, raw.WireGuardPeerConfig)
			continue
		}

		var peerID uint32
		if raw.GetId() != 0 {
			peerID = raw.GetId()
		} else if raw.GetEndpoint() != nil && raw.GetEndpoint().GetWireguardId() != 0 {
			peerID = raw.GetEndpoint().GetWireguardId()
		}

		// peer id cannot be identified: keep it, to be safe
		if peerID == 0 {
			newPeers = append(newPeers, raw.WireGuardPeerConfig)
			continue
		}

		// not directly reachable in the current topology: remove entirely
		if _, ok := connectable[peerID]; !ok {
			log.Debugf("preconnect remove: peerID=%d pk=%s (reason=not_connectable)", peerID, truncate(raw.GetPublicKey(), 10))
			uapiBuilder.RemovePeerByKey(raw.GetParsedPublicKey())
			delete(w.preconnectPeers, peerID)
			delete(exists, peerID)
			removed++
			continue
		}

		newPeers = append(newPeers, raw.WireGuardPeerConfig)
	}
	// if anything was cleaned up, update the local cache first (the device is updated by a single IpcSet at the end)
	if removed > 0 {
		w.ifce.Peers = newPeers
	}

	for _, l := range localLinks.GetLinks() {
		if l == nil {
			continue
		}
		toID := l.GetToWireguardId()
		if toID == 0 || toID == localID {
			continue
		}
		if _, ok := exists[toID]; ok {
			skippedAlready++
			continue
		}

		base, ok := w.peerDirectory[toID]
		if !ok || base == nil || base.GetPublicKey() == "" {
			skippedNoBase++
			continue
		}
		cloned := &defs.WireGuardPeerConfig{WireGuardPeerConfig: proto.Clone(base).(*pb.WireGuardPeerConfig)}
		cloned.AllowedIps = nil
		if l.GetToEndpoint() != nil {
			cloned.Endpoint = l.GetToEndpoint()
		}
		if _, err := parseAndValidatePeerConfig(cloned); err != nil {
			continue
		}

		log.Debugf("preconnect add: peerID=%d pk=%s endpoint=%s",
			toID, truncate(cloned.GetPublicKey(), 10), endpointForLog(cloned.GetEndpoint()))
		uapiBuilder.AddPeerConfig(cloned)
		w.ifce.Peers = append(w.ifce.Peers, cloned.WireGuardPeerConfig)
		w.indexPeerDirectoryLocked(cloned)
		exists[toID] = struct{}{}
		w.preconnectPeers[toID] = struct{}{}
		added++
	}

	if added == 0 && removed == 0 {
		log.Debugf("ensure result: no-op (skippedAlready=%d skippedNoBase=%d)", skippedAlready, skippedNoBase)
		return nil
	}
	log.Debugf("ensure result: add=%d remove=%d skippedAlready=%d skippedNoBase=%d", added, removed, skippedAlready, skippedNoBase)
	if err := w.wgDevice.IpcSet(uapiBuilder.Build()); err != nil {
		log.WithError(err).Debugf("ensure IpcSet failed (add=%d remove=%d)", added, removed)
		return err
	}
	return nil
}
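
NewUAPIBuilder, RemovePeerByKey, and AddPeerConfig are internal to this repo, but the IpcSet payload they assemble is standard WireGuard UAPI, which accepts multiple peer sections in one set operation. A removal followed by a keepalive-only add (no allowed_ip lines, so the peer carries no routes) looks roughly like this (hex keys elided):

public_key=<hex of removed peer>
remove=true
public_key=<hex of preconnect peer>
endpoint=203.0.113.10:51820
persistent_keepalive_interval=25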
// mergeConnectablePeersFromAdj merges the peers directly reachable in this
// node's adjacency graph into the desired peer list.
//
//   - Peers are only backfilled when missing from the desired list (deduplicated
//     by PublicKey), so AllowedIPs computed by the route planner are never overwritten.
//   - Backfilled peers get empty AllowedIPs, guaranteeing no extra routes are introduced.
//   - If a link explicitly carries a to_endpoint, it takes precedence over
//     peer.endpoint (used to quickly restore a direct connection).
//
// knownPeers supplies base peer information (public key / preshared key /
// endpoint, etc.) when a peer is missing from the desired list; typically oldPeers is passed.
func mergeConnectablePeersFromAdj(ifce *defs.WireGuardConfig, desiredPeers []*defs.WireGuardPeerConfig, knownPeers []*defs.WireGuardPeerConfig) []*defs.WireGuardPeerConfig {
	if ifce == nil {
		return desiredPeers
	}
	localID := ifce.GetId()
	if localID == 0 {
		return desiredPeers
	}
	adjs := ifce.GetAdjs()
	if adjs == nil {
		return desiredPeers
	}
	localLinks, ok := adjs[localID]
	if !ok || localLinks == nil || len(localLinks.GetLinks()) == 0 {
		return desiredPeers
	}

	desiredByPK := make(map[string]*defs.WireGuardPeerConfig, len(desiredPeers))
	for _, p := range desiredPeers {
		if p == nil || p.GetPublicKey() == "" {
			continue
		}
		desiredByPK[p.GetPublicKey()] = p
	}

	// build an id -> base peer info index (desired takes precedence over known)
	idToPeer := make(map[uint32]*defs.WireGuardPeerConfig, len(desiredPeers)+len(knownPeers))
	putPeerIDs := func(p *defs.WireGuardPeerConfig) {
		if p == nil {
			return
		}
		// 1) peer.id
		if id := p.GetId(); id != 0 {
			if _, exists := idToPeer[id]; !exists {
				idToPeer[id] = p
			}
		}
		// 2) endpoint.wireguard_id (some delivery paths may leave peer.id unset but set wireguard_id on the endpoint)
		if p.GetEndpoint() != nil {
			if id := p.GetEndpoint().GetWireguardId(); id != 0 {
				if _, exists := idToPeer[id]; !exists {
					idToPeer[id] = p
				}
			}
		}
	}
	for _, p := range desiredPeers {
		putPeerIDs(p)
	}
	for _, p := range knownPeers {
		putPeerIDs(p)
	}

	// backfill only: directly linked nodes from the adjacency graph (to_wireguard_id)
	for _, l := range localLinks.GetLinks() {
		if l == nil {
			continue
		}
		toID := l.GetToWireguardId()
		if toID == 0 || toID == localID {
			continue
		}

		base, ok := idToPeer[toID]
		if !ok || base == nil || base.GetPublicKey() == "" {
			continue
		}
		if _, exists := desiredByPK[base.GetPublicKey()]; exists {
			// already in the desired list (usually with planner-computed AllowedIPs); do not override
			continue
		}

		// clone to avoid mutating the underlying pb pointers of oldPeers / knownPeers
		cloned := clonePeerConfig(base)
		// no routes assigned: clear AllowedIPs
		cloned.AllowedIps = nil
		// an explicit link endpoint takes precedence
		if l.GetToEndpoint() != nil {
			cloned.Endpoint = l.GetToEndpoint()
		}
		// keep keepalive/AllowedIPs formatting consistent (reuse existing validation)
		if _, err := parseAndValidatePeerConfig(cloned); err != nil {
			continue
		}

		desiredPeers = append(desiredPeers, cloned)
		desiredByPK[cloned.GetPublicKey()] = cloned
	}

	return desiredPeers
}
func clonePeerConfig(p *defs.WireGuardPeerConfig) *defs.WireGuardPeerConfig {
	if p == nil || p.WireGuardPeerConfig == nil {
		return &defs.WireGuardPeerConfig{}
	}

	// use proto.Clone instead of a direct struct copy: copying protoimpl.MessageState
	// (which embeds a mutex) would trigger copy-locks warnings
	cp, _ := proto.Clone(p.WireGuardPeerConfig).(*pb.WireGuardPeerConfig)
	if cp == nil {
		return &defs.WireGuardPeerConfig{}
	}
	return &defs.WireGuardPeerConfig{WireGuardPeerConfig: cp}
}
func endpointForLog(ep *pb.Endpoint) string {
	if ep == nil {
		return ""
	}
	if ep.GetUri() != "" {
		return ep.GetUri()
	}
	if ep.GetHost() != "" || ep.GetPort() != 0 {
		return fmt.Sprintf("%s:%d", ep.GetHost(), ep.GetPort())
	}
	return ""
}
22
services/wg/wireguard_report.go
Normal file
@@ -0,0 +1,22 @@
//go:build !windows
// +build !windows

package wg

import "time"

const (
	ReportInterval = time.Second * 60
)

func (w *wireGuard) reportStatusTask() {
	for {
		select {
		case <-w.ctx.Done():
			return
		default:
			w.pingPeers()
			time.Sleep(ReportInterval)
		}
	}
}
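
Because the loop sleeps inside the default branch, cancellation of w.ctx is only observed after a full ReportInterval elapses. A ticker-based variant (a sketch of an alternative, not the committed code) would react to cancellation immediately:

func (w *wireGuard) reportStatusTaskWithTicker() {
	ticker := time.NewTicker(ReportInterval)
	defer ticker.Stop()
	for {
		w.pingPeers()
		select {
		case <-w.ctx.Done():
			return
		case <-ticker.C:
			// fall through to the next pingPeers round
		}
	}
}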
52
services/wg/wireguard_transport.go
Normal file
@@ -0,0 +1,52 @@
//go:build !windows
// +build !windows

package wg

import (
	"fmt"
	"net/http"

	"github.com/gin-gonic/gin"

	"github.com/VaalaCat/frp-panel/defs"
	"github.com/VaalaCat/frp-panel/services/wg/multibind"
	"github.com/VaalaCat/frp-panel/services/wg/transport/ws"

	"golang.zx2c4.com/wireguard/conn"
)

func (w *wireGuard) initTransports() error {
	log := w.svcLogger.WithField("op", "initTransports")

	wsTrans := ws.NewWSBind(w.ctx)
	w.multiBind = multibind.NewMultiBind(
		w.svcLogger,
		multibind.NewTransport(conn.NewDefaultBind(), "udp"),
		multibind.NewTransport(wsTrans, "ws"),
	)

	engine := gin.New()
	engine.Any(defs.DefaultWSHandlerPath, func(c *gin.Context) {
		err := wsTrans.HandleHTTP(c.Writer, c.Request)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
	})

	// if the WS listen port is not set, reuse the WG listen port so TCP and UDP share the same port number
	listenPort := w.ifce.GetWsListenPort()
	if listenPort == 0 {
		listenPort = w.ifce.GetListenPort()
	}
	go func() {
		if err := engine.Run(fmt.Sprintf(":%d", listenPort)); err != nil {
			w.svcLogger.WithError(err).Errorf("failed to run gin engine for ws transport on port %d", listenPort)
		}
	}()

	log.Infof("WS transport engine running on port %d", listenPort)

	return nil
}
48
services/wg/wireguard_types.go
Normal file
@@ -0,0 +1,48 @@
//go:build !windows
// +build !windows

package wg

import (
	"context"
	"sync"

	"github.com/sirupsen/logrus"
	"golang.zx2c4.com/wireguard/device"
	"golang.zx2c4.com/wireguard/tun"
	"golang.zx2c4.com/wireguard/tun/netstack"

	"github.com/VaalaCat/frp-panel/defs"
	"github.com/VaalaCat/frp-panel/pb"
	"github.com/VaalaCat/frp-panel/services/app"
	"github.com/VaalaCat/frp-panel/services/wg/multibind"
	"github.com/VaalaCat/frp-panel/utils"
)

var (
	_ app.WireGuard = (*wireGuard)(nil)
)

type wireGuard struct {
	sync.RWMutex

	ifce            *defs.WireGuardConfig
	endpointPingMap *utils.SyncMap[uint32, uint32] // ms
	virtAddrPingMap *utils.SyncMap[string, uint32] // ms
	peerDirectory   map[uint32]*pb.WireGuardPeerConfig
	// peers used only for "pre-connect / keep-alive" (empty AllowedIPs); tracked so
	// they can be added or removed as the topology changes
	preconnectPeers map[uint32]struct{}

	wgDevice  *device.Device
	tunDevice tun.Device
	multiBind *multibind.MultiBind
	gvisorNet *netstack.Net
	fwManager *firewallManager

	running      bool
	useGvisorNet bool // if true, use gvisor netstack

	svcLogger *logrus.Entry
	ctx       *app.Context
	cancel    context.CancelFunc
}
@@ -19,5 +19,5 @@ remotePort = 6000`)
	if err := LoadConfigureFromContent(content, allCfg, true); err != nil {
		t.Error(err)
	}
	t.Errorf("%+v", allCfg)
	t.Logf("%+v", allCfg)
}
@@ -11,5 +11,5 @@ func TestAllocateIP(t *testing.T) {
	if err != nil {
		t.Errorf("AllocateIP() failed: %v", err)
	}
	t.Errorf("AllocateIP() = %v", ip)
	t.Logf("AllocateIP() = %v", ip)
}