Files
oneterm/backend/internal/repository/asset.go
2025-07-16 18:11:04 +08:00

267 lines
8.1 KiB
Go

package repository
import (
"context"
"fmt"
"strings"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"github.com/spf13/cast"
"github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/pkg/config"
dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/pkg/utils"
"gorm.io/gorm"
)
const (
kFmtAssetIds = "assetIds-%d"
kAuthorizationIds = "authorizationIds"
kNodeIds = "nodeIds"
kAccountIds = "accountIds"
)
// AssetRepository interface for asset data access
type AssetRepository interface {
GetById(ctx context.Context, id int) (*model.Asset, error)
AttachNodeChain(ctx context.Context, assets []*model.Asset) error
BuildQuery(ctx *gin.Context) (*gorm.DB, error)
FilterByParentId(db *gorm.DB, parentId int) (*gorm.DB, error)
GetAssetIdsByAuthorization(ctx *gin.Context, authorizationIds []*model.AuthorizationIds) ([]int, []int, []int, error)
GetIdsByAuthorizationIds(ctx *gin.Context, authorizationIds []*model.AuthorizationIds) ([]int, []int, []int)
GetAssetIdsByNodeAccount(ctx context.Context, nodeIds, accountIds []int) ([]int, error)
}
// assetRepository implements AssetRepository
type assetRepository struct{}
// NewAssetRepository creates a new asset repository
func NewAssetRepository() AssetRepository {
return &assetRepository{}
}
// GetById retrieves an asset by its ID
func (r *assetRepository) GetById(ctx context.Context, id int) (*model.Asset, error) {
asset := &model.Asset{}
err := dbpkg.DB.Where("id = ?", id).First(asset).Error
return asset, err
}
// BuildQuery builds the base query for assets with filters
func (r *assetRepository) BuildQuery(ctx *gin.Context) (*gorm.DB, error) {
db := dbpkg.DB.Model(model.DefaultAsset)
// Apply filters
db = dbpkg.FilterEqual(ctx, db, "id")
db = dbpkg.FilterLike(ctx, db, "name", "ip")
db = dbpkg.FilterSearch(ctx, db, "name", "ip")
// Handle IDs parameter
if q, ok := ctx.GetQuery("ids"); ok {
db = db.Where("id IN ?", lo.Map(strings.Split(q, ","),
func(s string, _ int) int { return cast.ToInt(s) }))
}
// Sort by name
db = db.Order("name")
return db, nil
}
// FilterByParentId filters assets by parent ID and its children
func (r *assetRepository) FilterByParentId(db *gorm.DB, parentId int) (*gorm.DB, error) {
parentIds, err := r.handleParentId(context.Background(), parentId)
if err != nil {
return nil, err
}
return db.Where("parent_id IN ?", parentIds), nil
}
// AttachNodeChain attaches node chain to assets
func (r *assetRepository) AttachNodeChain(ctx context.Context, assets []*model.Asset) error {
nodes, err := GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil {
return err
}
g := make(map[int][]model.Pair[int, string])
for _, n := range nodes {
g[n.ParentId] = append(g[n.ParentId], model.Pair[int, string]{First: n.Id, Second: n.Name})
}
m := make(map[int]string)
var dfs func(int, string)
dfs = func(x int, s string) {
m[x] = s
for _, node := range g[x] {
dfs(node.First, fmt.Sprintf("%s/%s", s, node.Second))
}
}
dfs(0, "")
for _, d := range assets {
d.NodeChain = m[d.ParentId]
}
return nil
}
// GetAssetIdsByAuthorization gets asset IDs by authorization
func (r *assetRepository) GetAssetIdsByAuthorization(ctx *gin.Context, authorizationIds []*model.AuthorizationIds) ([]int, []int, []int, error) {
ctx.Set(kAuthorizationIds, authorizationIds)
nodeIds, assetIds, accountIds := r.GetIdsByAuthorizationIds(ctx, authorizationIds)
tmp, err := HandleSelfChild(ctx, nodeIds...)
if err != nil {
return nil, nil, nil, err
}
nodeIds = append(nodeIds, tmp...)
ctx.Set(kNodeIds, nodeIds)
ctx.Set(kAccountIds, accountIds)
assetIdsFromNode, err := r.GetAssetIdsByNodeAccount(ctx, nodeIds, accountIds)
if err != nil {
return nil, nil, nil, err
}
allAssetIds := lo.Uniq(append(assetIds, assetIdsFromNode...))
return nodeIds, allAssetIds, accountIds, nil
}
// GetIdsByAuthorizationIds extracts node IDs, asset IDs, and account IDs from authorization IDs
func (r *assetRepository) GetIdsByAuthorizationIds(ctx *gin.Context, authorizationIds []*model.AuthorizationIds) (nodeIds, assetIds, accountIds []int) {
info := cast.ToBool(ctx.Query("info"))
for _, a := range authorizationIds {
if a.NodeId != 0 && a.AssetId == 0 && a.AccountId == 0 {
nodeIds = append(nodeIds, a.NodeId)
}
if a.AssetId != 0 && a.NodeId == 0 && (info || a.AccountId == 0) {
assetIds = append(assetIds, a.AssetId)
}
if a.AccountId != 0 && a.AssetId == 0 && (info || a.NodeId == 0) {
accountIds = append(accountIds, a.AccountId)
}
}
return
}
// GetAssetIdsByNodeAccount gets asset IDs by node IDs and account IDs
func (r *assetRepository) GetAssetIdsByNodeAccount(ctx context.Context, nodeIds, accountIds []int) (assetIds []int, err error) {
assets, err := GetAllFromCacheDb(ctx, model.DefaultAsset)
if err != nil {
return nil, err
}
assets = lo.Filter(assets, func(a *model.Asset, _ int) bool {
return lo.Contains(nodeIds, a.ParentId) || len(lo.Intersect(lo.Keys(a.Authorization), accountIds)) > 0
})
assetIds = lo.Map(assets, func(a *model.Asset, _ int) int { return a.Id })
return assetIds, nil
}
// handleParentId builds a slice of parent IDs including the given parent ID and all its children
func (r *assetRepository) handleParentId(ctx context.Context, parentId int) (pids []int, err error) {
nodes, err := GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil {
return nil, err
}
g := make(map[int][]int)
for _, n := range nodes {
g[n.ParentId] = append(g[n.ParentId], n.Id)
}
var dfs func(int)
dfs = func(x int) {
pids = append(pids, x)
for _, y := range g[x] {
dfs(y)
}
}
dfs(parentId)
return pids, nil
}
// handleAssetIds filters assets by resource IDs and node IDs
func (r *assetRepository) handleAssetIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.DB, err error) {
currentUser, _ := acl.GetSessionFromCtx(ctx)
nodes, err := GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil {
return nil, err
}
nodeResIds, err := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_NODE)
if err != nil {
return nil, err
}
nodes = lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(nodeResIds, n.ResourceId) })
nodeIds := lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id })
if nodeIds, err = HandleSelfChild(ctx, nodeIds...); err != nil {
return nil, err
}
d := dbpkg.DB.Where("resource_id IN ?", resIds).Or("parent_id IN?", nodeIds)
db = dbFind.Where(d)
return db, nil
}
// HandleAssetIds filters asset queries based on resource IDs
func HandleAssetIds(ctx context.Context, dbFind *gorm.DB, resIds []int) (db *gorm.DB, err error) {
currentUser, _ := acl.GetSessionFromCtx(ctx)
nodes, err := GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil {
return
}
nodeResIds, err := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_NODE)
if err != nil {
return
}
nodes = lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(nodeResIds, n.ResourceId) })
nodeIds := lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id })
if nodeIds, err = HandleSelfChild(ctx, nodeIds...); err != nil {
return
}
d := dbpkg.DB.Where("resource_id IN ?", resIds).Or("parent_id IN?", nodeIds)
db = dbFind.Where(d)
return
}
// GetAAG retrieves Asset, Account, and Gateway by their IDs with decrypted credentials
func GetAAG(assetId int, accountId int) (asset *model.Asset, account *model.Account, gateway *model.Gateway, err error) {
asset, account, gateway = &model.Asset{}, &model.Account{}, &model.Gateway{}
if err = dbpkg.DB.Model(asset).Where("id = ?", assetId).First(asset).Error; err != nil {
return
}
if err = dbpkg.DB.Model(account).Where("id = ?", accountId).First(account).Error; err != nil {
return
}
account.Password = utils.DecryptAES(account.Password)
account.Pk = utils.DecryptAES(account.Pk)
account.Phrase = utils.DecryptAES(account.Phrase)
if asset.GatewayId != 0 {
if err = dbpkg.DB.Model(gateway).Where("id = ?", asset.GatewayId).First(gateway).Error; err != nil {
return
}
gateway.Password = utils.DecryptAES(gateway.Password)
gateway.Pk = utils.DecryptAES(gateway.Pk)
gateway.Phrase = utils.DecryptAES(gateway.Phrase)
}
return
}