mirror of
https://github.com/veops/oneterm.git
synced 2025-10-16 04:10:39 +08:00
325 lines
8.5 KiB
Go
325 lines
8.5 KiB
Go
package schedule
|
|
|
|
import (
|
|
"fmt"
|
|
"math"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/samber/lo"
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/veops/oneterm/internal/model"
|
|
"github.com/veops/oneterm/internal/tunneling"
|
|
dbpkg "github.com/veops/oneterm/pkg/db"
|
|
"github.com/veops/oneterm/pkg/logger"
|
|
"github.com/veops/oneterm/pkg/utils"
|
|
)
|
|
|
|
// ConnectableResult is the outcome of probing one asset for reachability.
type ConnectableResult struct {
	// AssetID is the id of the asset that was probed.
	AssetID int
	// SessionID is the session id generated for the probe; it is later used
	// to close the tunnel that belongs to the probe.
	SessionID string
	// Success reports whether the asset was reachable.
	Success bool
	// Error optionally carries the failure cause. NOTE(review): processBatch
	// does not populate this field.
	Error error
}
|
|
|
|
func UpdateConnectables(ids ...int) (err error) {
|
|
start := time.Now()
|
|
defer func() {
|
|
duration := time.Since(start)
|
|
if err != nil {
|
|
logger.L().Warn("Connectivity check failed",
|
|
zap.Error(err),
|
|
zap.Duration("duration", duration))
|
|
} else {
|
|
logger.L().Info("Connectivity check completed",
|
|
zap.Duration("duration", duration))
|
|
}
|
|
}()
|
|
|
|
// Load assets to check
|
|
assets, err := getAssetsToCheck(ids...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(assets) == 0 {
|
|
logger.L().Debug("No assets to check connectivity")
|
|
return nil
|
|
}
|
|
|
|
logger.L().Info("Starting connectivity check",
|
|
zap.Int("total_assets", len(assets)),
|
|
zap.Int("batch_size", scheduleConfig.BatchSize),
|
|
zap.Int("concurrent_workers", scheduleConfig.ConcurrentWorkers))
|
|
|
|
// Load and decrypt gateways
|
|
gatewayMap, err := getGatewayMap(assets)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Process assets in concurrent batches
|
|
results := processConcurrentBatches(assets, gatewayMap)
|
|
|
|
// Update database with results
|
|
return updateConnectableStatus(results)
|
|
}
|
|
|
|
func getAssetsToCheck(ids ...int) ([]*model.Asset, error) {
|
|
assets := make([]*model.Asset, 0)
|
|
db := dbpkg.DB.Model(assets)
|
|
|
|
if len(ids) > 0 {
|
|
db = db.Where("id IN ?", ids)
|
|
} else {
|
|
// Only check assets not updated within the configured interval OR offline assets
|
|
checkInterval := scheduleConfig.ConnectableCheckInterval
|
|
db = db.Where("updated_at <= ? OR connectable = ?",
|
|
time.Now().Add(-checkInterval).Add(-time.Second*30), false)
|
|
}
|
|
|
|
// Only select fields needed for connectivity check, exclude authorization to avoid V1/V2 compatibility issues
|
|
if err := db.Select("id", "name", "ip", "protocols", "gateway_id", "connectable", "updated_at").Find(&assets).Error; err != nil {
|
|
logger.L().Error("Failed to get assets for connectivity check", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
|
|
return assets, nil
|
|
}
|
|
|
|
func getGatewayMap(assets []*model.Asset) (map[int]*model.Gateway, error) {
|
|
gids := lo.Without(lo.Uniq(lo.Map(assets, func(a *model.Asset, _ int) int {
|
|
return a.GatewayId
|
|
})), 0)
|
|
|
|
if len(gids) == 0 {
|
|
return make(map[int]*model.Gateway), nil
|
|
}
|
|
|
|
gateways := make([]*model.Gateway, 0)
|
|
if err := dbpkg.DB.Model(gateways).Where("id IN ?", gids).Find(&gateways).Error; err != nil {
|
|
logger.L().Error("Failed to get gateways for connectivity check", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
|
|
// Decrypt gateway credentials
|
|
for _, g := range gateways {
|
|
g.Password = utils.DecryptAES(g.Password)
|
|
g.Pk = utils.DecryptAES(g.Pk)
|
|
g.Phrase = utils.DecryptAES(g.Phrase)
|
|
}
|
|
|
|
return lo.SliceToMap(gateways, func(g *model.Gateway) (int, *model.Gateway) {
|
|
return g.Id, g
|
|
}), nil
|
|
}
|
|
|
|
func processConcurrentBatches(assets []*model.Asset, gatewayMap map[int]*model.Gateway) []ConnectableResult {
|
|
batchSize := scheduleConfig.BatchSize
|
|
concurrentWorkers := scheduleConfig.ConcurrentWorkers
|
|
|
|
totalBatches := int(math.Ceil(float64(len(assets)) / float64(batchSize)))
|
|
resultChan := make(chan []ConnectableResult, totalBatches)
|
|
semaphore := make(chan struct{}, concurrentWorkers)
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
// Process assets in batches
|
|
for i := 0; i < len(assets); i += batchSize {
|
|
end := min(i+batchSize, len(assets))
|
|
batch := assets[i:end]
|
|
|
|
wg.Add(1)
|
|
go func(batch []*model.Asset, batchNum int) {
|
|
defer wg.Done()
|
|
|
|
// Acquire semaphore
|
|
semaphore <- struct{}{}
|
|
defer func() { <-semaphore }()
|
|
|
|
logger.L().Debug("Processing connectivity batch",
|
|
zap.Int("batch_number", batchNum+1),
|
|
zap.Int("batch_size", len(batch)))
|
|
|
|
batchResults := processBatch(batch, gatewayMap)
|
|
resultChan <- batchResults
|
|
}(batch, i/batchSize)
|
|
}
|
|
|
|
// Close result channel when all goroutines complete
|
|
go func() {
|
|
wg.Wait()
|
|
close(resultChan)
|
|
}()
|
|
|
|
// Collect all results
|
|
var allResults []ConnectableResult
|
|
for batchResults := range resultChan {
|
|
allResults = append(allResults, batchResults...)
|
|
}
|
|
|
|
return allResults
|
|
}
|
|
|
|
func processBatch(assets []*model.Asset, gatewayMap map[int]*model.Gateway) []ConnectableResult {
|
|
results := make([]ConnectableResult, len(assets))
|
|
|
|
for i, asset := range assets {
|
|
gateway := gatewayMap[asset.GatewayId]
|
|
sessionID, success := updateConnectable(asset, gateway)
|
|
results[i] = ConnectableResult{
|
|
AssetID: asset.Id,
|
|
SessionID: sessionID,
|
|
Success: success,
|
|
}
|
|
}
|
|
|
|
return results
|
|
}
|
|
|
|
func updateConnectableStatus(results []ConnectableResult) error {
|
|
if len(results) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Collect session IDs for cleanup
|
|
sessionIDs := make([]string, 0, len(results))
|
|
successfulAssets := make([]int, 0)
|
|
failedAssets := make([]int, 0)
|
|
|
|
for _, result := range results {
|
|
sessionIDs = append(sessionIDs, result.SessionID)
|
|
if result.Success {
|
|
successfulAssets = append(successfulAssets, result.AssetID)
|
|
} else {
|
|
failedAssets = append(failedAssets, result.AssetID)
|
|
}
|
|
}
|
|
|
|
// Clean up tunnels
|
|
defer tunneling.CloseTunnels(sessionIDs...)
|
|
|
|
// Update successful assets
|
|
if len(successfulAssets) > 0 {
|
|
if err := dbpkg.DB.Model(&model.Asset{}).
|
|
Where("id IN ?", successfulAssets).
|
|
Update("connectable", true).Error; err != nil {
|
|
logger.L().Error("Failed to update successful assets",
|
|
zap.Error(err),
|
|
zap.Int("count", len(successfulAssets)))
|
|
return err
|
|
}
|
|
logger.L().Debug("Updated successful assets", zap.Int("count", len(successfulAssets)))
|
|
}
|
|
|
|
// Update failed assets
|
|
if len(failedAssets) > 0 {
|
|
if err := dbpkg.DB.Model(&model.Asset{}).
|
|
Where("id IN ?", failedAssets).
|
|
Update("connectable", false).Error; err != nil {
|
|
logger.L().Error("Failed to update failed assets",
|
|
zap.Error(err),
|
|
zap.Int("count", len(failedAssets)))
|
|
return err
|
|
}
|
|
logger.L().Debug("Updated failed assets", zap.Int("count", len(failedAssets)))
|
|
}
|
|
|
|
logger.L().Info("Connectivity status updated",
|
|
zap.Int("successful", len(successfulAssets)),
|
|
zap.Int("failed", len(failedAssets)),
|
|
zap.Int("total", len(results)))
|
|
|
|
return nil
|
|
}
|
|
|
|
// min returns the smaller of a and b; when they are equal, that common value
// is returned.
func min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
|
|
|
|
func updateConnectable(asset *model.Asset, gateway *model.Gateway) (sid string, ok bool) {
|
|
sid = uuid.New().String()
|
|
ps := strings.Join(lo.Map(asset.Protocols, func(p string, _ int) string {
|
|
return strings.Split(p, ":")[0]
|
|
}), ",")
|
|
|
|
ip, port, err := tunneling.Proxy(true, sid, ps, asset, gateway)
|
|
if err != nil {
|
|
logger.L().Debug("Proxy connection failed",
|
|
zap.String("protocol", ps),
|
|
zap.String("asset_ip", asset.Ip),
|
|
zap.Int("asset_id", asset.Id),
|
|
zap.Error(err))
|
|
return
|
|
}
|
|
|
|
var hostPort string
|
|
if strings.Contains(ip, ":") {
|
|
hostPort = fmt.Sprintf("[%s]:%d", ip, port)
|
|
} else {
|
|
hostPort = fmt.Sprintf("%s:%d", ip, port)
|
|
}
|
|
|
|
// Use configurable timeout
|
|
conn, err := net.DialTimeout("tcp", hostPort, scheduleConfig.ConnectTimeout)
|
|
if err != nil {
|
|
logger.L().Debug("TCP connection failed",
|
|
zap.String("address", hostPort),
|
|
zap.Int("asset_id", asset.Id),
|
|
zap.Duration("timeout", scheduleConfig.ConnectTimeout),
|
|
zap.Error(err))
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Verify gateway tunnel if using gateway
|
|
if asset.GatewayId != 0 {
|
|
t := tunneling.GetTunnelBySessionId(sid)
|
|
if t == nil {
|
|
logger.L().Debug("Gateway tunnel not found",
|
|
zap.String("session_id", sid),
|
|
zap.Int("asset_id", asset.Id),
|
|
zap.Int("gateway_id", asset.GatewayId))
|
|
return
|
|
}
|
|
|
|
select {
|
|
case err = <-t.Opened:
|
|
if err != nil {
|
|
logger.L().Debug("Gateway tunnel failed to open",
|
|
zap.String("session_id", sid),
|
|
zap.Int("asset_id", asset.Id),
|
|
zap.Int("gateway_id", asset.GatewayId),
|
|
zap.Error(err))
|
|
return
|
|
}
|
|
case <-time.After(scheduleConfig.ConnectTimeout):
|
|
logger.L().Debug("Gateway tunnel open timeout",
|
|
zap.String("session_id", sid),
|
|
zap.Int("asset_id", asset.Id),
|
|
zap.Int("gateway_id", asset.GatewayId),
|
|
zap.Duration("timeout", scheduleConfig.ConnectTimeout))
|
|
return
|
|
}
|
|
}
|
|
|
|
logger.L().Debug("Asset connectivity check successful",
|
|
zap.Int("asset_id", asset.Id),
|
|
zap.String("address", hostPort))
|
|
ok = true
|
|
return
|
|
}
|
|
|
|
// UpdateAssetConnectables is used by service/asset.go to update connectables
|
|
func UpdateAssetConnectables(ids ...int) error {
|
|
return UpdateConnectables(ids...)
|
|
}
|