Files
oneterm/backend/internal/schedule/connectable.go
2025-07-16 18:11:04 +08:00

325 lines
8.5 KiB
Go

package schedule
import (
"fmt"
"math"
"net"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/internal/tunneling"
dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/pkg/logger"
"github.com/veops/oneterm/pkg/utils"
)
// ConnectableResult represents the result of a connectivity check for a single
// asset. processBatch produces one per asset and updateConnectableStatus
// consumes them to persist the "connectable" flag and close tunnels.
type ConnectableResult struct {
	// AssetID is the id of the probed asset.
	AssetID int
	// SessionID is the tunneling session id generated for the probe; it is
	// handed to tunneling.CloseTunnels once the results are persisted.
	SessionID string
	// Success reports whether every step of the probe succeeded.
	Success bool
	// Error is never set by processBatch in this file, so it is always nil
	// there. NOTE(review): either record the probe failure here or drop the field.
	Error error
}
// UpdateConnectables probes the network reachability of assets and stores the
// outcome in their "connectable" column.
//
// If ids is non-empty, only those assets are probed. Otherwise getAssetsToCheck
// picks stale or currently-unreachable assets. Gateways referenced by the
// assets are loaded once and shared by all concurrent batches. The total run
// time is logged on exit: at Warn level when an error is returned, Info otherwise.
func UpdateConnectables(ids ...int) (err error) {
	startedAt := time.Now()
	defer func() {
		elapsed := time.Since(startedAt)
		if err == nil {
			logger.L().Info("Connectivity check completed",
				zap.Duration("duration", elapsed))
			return
		}
		logger.L().Warn("Connectivity check failed",
			zap.Error(err),
			zap.Duration("duration", elapsed))
	}()

	targets, err := getAssetsToCheck(ids...)
	switch {
	case err != nil:
		return err
	case len(targets) == 0:
		logger.L().Debug("No assets to check connectivity")
		return nil
	}

	logger.L().Info("Starting connectivity check",
		zap.Int("total_assets", len(targets)),
		zap.Int("batch_size", scheduleConfig.BatchSize),
		zap.Int("concurrent_workers", scheduleConfig.ConcurrentWorkers))

	// Gateway credentials are decrypted once here rather than per asset.
	gateways, err := getGatewayMap(targets)
	if err != nil {
		return err
	}

	return updateConnectableStatus(processConcurrentBatches(targets, gateways))
}
// getAssetsToCheck loads the assets whose connectivity should be probed.
//
// With explicit ids, exactly those rows are selected. Without ids, it selects
// assets that are marked not connectable, or whose updated_at is older than
// ConnectableCheckInterval plus a 30-second margin. Only the columns the probe
// needs are fetched. The authorization column is deliberately left out to
// avoid V1/V2 compatibility issues. Query errors are logged and returned.
func getAssetsToCheck(ids ...int) ([]*model.Asset, error) {
	assets := make([]*model.Asset, 0)
	query := dbpkg.DB.Model(assets)

	if len(ids) == 0 {
		staleBefore := time.Now().Add(-scheduleConfig.ConnectableCheckInterval).Add(-time.Second * 30)
		query = query.Where("updated_at <= ? OR connectable = ?", staleBefore, false)
	} else {
		query = query.Where("id IN ?", ids)
	}

	err := query.
		Select("id", "name", "ip", "protocols", "gateway_id", "connectable", "updated_at").
		Find(&assets).
		Error
	if err != nil {
		logger.L().Error("Failed to get assets for connectivity check", zap.Error(err))
		return nil, err
	}
	return assets, nil
}
// getGatewayMap loads every distinct, non-zero gateway referenced by assets and
// returns them keyed by gateway id. Password, Pk and Phrase are decrypted with
// utils.DecryptAES before the map is returned. When no asset uses a gateway,
// an empty (non-nil) map is returned without querying the database.
func getGatewayMap(assets []*model.Asset) (map[int]*model.Gateway, error) {
	referenced := lo.Map(assets, func(a *model.Asset, _ int) int { return a.GatewayId })
	gatewayIDs := lo.Without(lo.Uniq(referenced), 0)
	if len(gatewayIDs) == 0 {
		return make(map[int]*model.Gateway), nil
	}

	gateways := make([]*model.Gateway, 0)
	err := dbpkg.DB.Model(gateways).Where("id IN ?", gatewayIDs).Find(&gateways).Error
	if err != nil {
		logger.L().Error("Failed to get gateways for connectivity check", zap.Error(err))
		return nil, err
	}

	byID := make(map[int]*model.Gateway, len(gateways))
	for _, gw := range gateways {
		gw.Password = utils.DecryptAES(gw.Password)
		gw.Pk = utils.DecryptAES(gw.Pk)
		gw.Phrase = utils.DecryptAES(gw.Phrase)
		byID[gw.Id] = gw
	}
	return byID, nil
}
// processConcurrentBatches splits assets into chunks of scheduleConfig.BatchSize
// and probes each chunk with processBatch in its own goroutine. At most
// scheduleConfig.ConcurrentWorkers chunks run at the same time. All per-asset
// results are returned in one slice. Batches are appended in completion order,
// so the output order is not the input order.
//
// Non-positive configuration values are clamped rather than trusted:
//   - BatchSize <= 0 processes all assets as a single batch. Otherwise the loop
//     would never advance (i += 0) and would spawn goroutines forever.
//   - ConcurrentWorkers <= 0 runs one batch at a time. Otherwise a zero value
//     creates an unbuffered semaphore on which every worker blocks forever, and
//     a negative value panics in make.
func processConcurrentBatches(assets []*model.Asset, gatewayMap map[int]*model.Gateway) []ConnectableResult {
	if len(assets) == 0 {
		return nil
	}

	batchSize := scheduleConfig.BatchSize
	if batchSize <= 0 {
		batchSize = len(assets)
	}
	concurrentWorkers := scheduleConfig.ConcurrentWorkers
	if concurrentWorkers <= 0 {
		concurrentWorkers = 1
	}

	totalBatches := int(math.Ceil(float64(len(assets)) / float64(batchSize)))
	// Sized for every batch, so no worker ever blocks on send and the channel
	// can be closed once all workers have finished.
	resultChan := make(chan []ConnectableResult, totalBatches)
	// Counting semaphore bounding the number of batches probed concurrently.
	semaphore := make(chan struct{}, concurrentWorkers)
	var wg sync.WaitGroup

	for i := 0; i < len(assets); i += batchSize {
		end := min(i+batchSize, len(assets))
		wg.Add(1)
		go func(batch []*model.Asset, batchNum int) {
			defer wg.Done()
			semaphore <- struct{}{}
			defer func() { <-semaphore }()

			logger.L().Debug("Processing connectivity batch",
				zap.Int("batch_number", batchNum+1),
				zap.Int("batch_size", len(batch)))
			resultChan <- processBatch(batch, gatewayMap)
		}(assets[i:end], i/batchSize)
	}

	wg.Wait()
	close(resultChan)

	// processBatch yields exactly one result per asset.
	allResults := make([]ConnectableResult, 0, len(assets))
	for batchResults := range resultChan {
		allResults = append(allResults, batchResults...)
	}
	return allResults
}
// processBatch probes the assets of one batch sequentially via updateConnectable
// and returns exactly one result per asset, in input order. Each asset's
// gateway is looked up by GatewayId; a missing entry yields a nil gateway. The
// Error field of each result is left nil.
func processBatch(assets []*model.Asset, gatewayMap map[int]*model.Gateway) []ConnectableResult {
	out := make([]ConnectableResult, 0, len(assets))
	for _, a := range assets {
		sid, reachable := updateConnectable(a, gatewayMap[a.GatewayId])
		out = append(out, ConnectableResult{
			AssetID:   a.Id,
			SessionID: sid,
			Success:   reachable,
		})
	}
	return out
}
// updateConnectableStatus persists probe results. Assets whose probe succeeded
// get connectable=true and all others get connectable=false, using one bulk
// UPDATE per outcome (successful first). Every session id in results is passed
// to tunneling.CloseTunnels when the function returns, even on error. The
// first database error is logged, aborts any remaining update, and is returned.
func updateConnectableStatus(results []ConnectableResult) error {
	if len(results) == 0 {
		return nil
	}

	var (
		sessionIDs = make([]string, 0, len(results))
		okIDs      = make([]int, 0)
		failedIDs  = make([]int, 0)
	)
	for _, r := range results {
		sessionIDs = append(sessionIDs, r.SessionID)
		if !r.Success {
			failedIDs = append(failedIDs, r.AssetID)
			continue
		}
		okIDs = append(okIDs, r.AssetID)
	}

	// Release the per-probe tunneling sessions regardless of the DB outcome.
	defer tunneling.CloseTunnels(sessionIDs...)

	// setConnectable bulk-updates the flag for assetIDs; a no-op when empty.
	setConnectable := func(assetIDs []int, connectable bool, failMsg, doneMsg string) error {
		if len(assetIDs) == 0 {
			return nil
		}
		err := dbpkg.DB.Model(&model.Asset{}).
			Where("id IN ?", assetIDs).
			Update("connectable", connectable).Error
		if err != nil {
			logger.L().Error(failMsg,
				zap.Error(err),
				zap.Int("count", len(assetIDs)))
			return err
		}
		logger.L().Debug(doneMsg, zap.Int("count", len(assetIDs)))
		return nil
	}

	if err := setConnectable(okIDs, true, "Failed to update successful assets", "Updated successful assets"); err != nil {
		return err
	}
	if err := setConnectable(failedIDs, false, "Failed to update failed assets", "Updated failed assets"); err != nil {
		return err
	}

	logger.L().Info("Connectivity status updated",
		zap.Int("successful", len(okIDs)),
		zap.Int("failed", len(failedIDs)),
		zap.Int("total", len(results)))
	return nil
}
// min returns the smaller of a and b. It shadows the Go 1.21 builtin of the
// same name within this package.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// updateConnectable probes a single asset and reports whether it is reachable.
//
// The probe runs in these steps:
//   - Generate a fresh session id.
//   - Call tunneling.Proxy with the asset's protocol names (each protocol entry
//     up to its first ':'), joined by commas.
//   - Dial the returned address over TCP within scheduleConfig.ConnectTimeout.
//   - For assets behind a gateway, also require the session's tunnel to report
//     a nil error on its Opened channel within that same timeout.
//
// The session id is returned in every case so the caller can close the
// tunnel. ok is true only when every step succeeded. Failures are logged at
// Debug level.
func updateConnectable(asset *model.Asset, gateway *model.Gateway) (sid string, ok bool) {
	sid = uuid.New().String()

	protocolNames := make([]string, 0, len(asset.Protocols))
	for _, p := range asset.Protocols {
		name, _, _ := strings.Cut(p, ":")
		protocolNames = append(protocolNames, name)
	}
	ps := strings.Join(protocolNames, ",")

	ip, port, err := tunneling.Proxy(true, sid, ps, asset, gateway)
	if err != nil {
		logger.L().Debug("Proxy connection failed",
			zap.String("protocol", ps),
			zap.String("asset_ip", asset.Ip),
			zap.Int("asset_id", asset.Id),
			zap.Error(err))
		return sid, false
	}

	// JoinHostPort brackets hosts containing ':' (IPv6 literals).
	hostPort := net.JoinHostPort(ip, fmt.Sprintf("%d", port))

	conn, err := net.DialTimeout("tcp", hostPort, scheduleConfig.ConnectTimeout)
	if err != nil {
		logger.L().Debug("TCP connection failed",
			zap.String("address", hostPort),
			zap.Int("asset_id", asset.Id),
			zap.Duration("timeout", scheduleConfig.ConnectTimeout),
			zap.Error(err))
		return sid, false
	}
	defer conn.Close()

	if asset.GatewayId != 0 {
		tunnel := tunneling.GetTunnelBySessionId(sid)
		if tunnel == nil {
			logger.L().Debug("Gateway tunnel not found",
				zap.String("session_id", sid),
				zap.Int("asset_id", asset.Id),
				zap.Int("gateway_id", asset.GatewayId))
			return sid, false
		}

		select {
		case err = <-tunnel.Opened:
			// Received into the error-typed err so a typed-nil value behaves
			// exactly as before.
			if err != nil {
				logger.L().Debug("Gateway tunnel failed to open",
					zap.String("session_id", sid),
					zap.Int("asset_id", asset.Id),
					zap.Int("gateway_id", asset.GatewayId),
					zap.Error(err))
				return sid, false
			}
		case <-time.After(scheduleConfig.ConnectTimeout):
			logger.L().Debug("Gateway tunnel open timeout",
				zap.String("session_id", sid),
				zap.Int("asset_id", asset.Id),
				zap.Int("gateway_id", asset.GatewayId),
				zap.Duration("timeout", scheduleConfig.ConnectTimeout))
			return sid, false
		}
	}

	logger.L().Debug("Asset connectivity check successful",
		zap.Int("asset_id", asset.Id),
		zap.String("address", hostPort))
	return sid, true
}
// UpdateAssetConnectables re-checks connectivity for the given asset ids by
// delegating to UpdateConnectables. It exists for service/asset.go to call.
func UpdateAssetConnectables(ids ...int) (err error) {
	err = UpdateConnectables(ids...)
	return err
}