Files
oneterm/backend/internal/repository/account.go
2025-06-08 20:38:02 +08:00

174 lines
5.3 KiB
Go

package repository
import (
"context"
"errors"
"fmt"
"strings"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"github.com/spf13/cast"
"github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/pkg/config"
dbpkg "github.com/veops/oneterm/pkg/db"
"golang.org/x/crypto/ssh"
"gorm.io/gorm"
)
// AccountRepository interface for account data access
type AccountRepository interface {
AttachAssetCount(ctx context.Context, accounts []*model.Account) error
CheckAssetDependencies(ctx context.Context, id int) (string, error)
BuildQuery(ctx *gin.Context) *gorm.DB
FilterByAssetIds(db *gorm.DB, assetIds []int) *gorm.DB
GetAccountIdsByAuthorization(ctx context.Context, assetIds []int, authorizationIds []int) ([]int, error)
}
// accountRepository is the AccountRepository implementation of this package.
// It has no fields; its methods run their queries through the package-level
// dbpkg.DB handle.
type accountRepository struct{}
// NewAccountRepository returns an AccountRepository backed by the unexported
// accountRepository type.
func NewAccountRepository() AccountRepository {
	return new(accountRepository)
}
// BuildQuery assembles the base account query for a request. It applies
// dbpkg.FilterEqual on "id" and "type", dbpkg.FilterLike on "name" and
// dbpkg.FilterSearch on "name" and "account", optionally restricts the result
// to the IDs listed in the "ids" query parameter, and orders by name.
func (r *accountRepository) BuildQuery(ctx *gin.Context) *gorm.DB {
	query := dbpkg.DB.Model(&model.Account{})
	query = dbpkg.FilterEqual(ctx, query, "id", "type")
	query = dbpkg.FilterLike(ctx, query, "name")
	query = dbpkg.FilterSearch(ctx, query, "name", "account")

	// "ids" is a comma-separated list; every element is converted with
	// cast.ToInt before being bound to the IN clause.
	if raw, found := ctx.GetQuery("ids"); found {
		parts := strings.Split(raw, ",")
		ids := make([]int, 0, len(parts))
		for _, part := range parts {
			ids = append(ids, cast.ToInt(part))
		}
		query = query.Where("id IN ?", ids)
	}

	return query.Order("name")
}
// FilterByAssetIds narrows db to the accounts that have an Authorization row
// for at least one of assetIds. With no asset IDs the returned query carries
// an always-false condition so that it matches nothing.
func (r *accountRepository) FilterByAssetIds(db *gorm.DB, assetIds []int) *gorm.DB {
	if len(assetIds) > 0 {
		// Sub-query: the distinct account IDs authorized on the given assets.
		linked := dbpkg.DB.Model(&model.Authorization{})
		linked = linked.Select("account_id")
		linked = linked.Where("asset_id IN ?", assetIds)
		linked = linked.Group("account_id")
		return db.Where("id IN (?)", linked)
	}
	return db.Where("0 = 1")
}
// AttachAssetCount sets AssetCount on every account in accounts to the number
// of Authorization rows that reference it. Accounts without any such row get
// a count of zero.
//
// An empty accounts slice returns nil immediately without touching the
// database. Any query error is returned unchanged, in which case no account
// is modified.
//
// NOTE(review): ctx is accepted to satisfy AccountRepository but is not
// forwarded to the query.
func (r *accountRepository) AttachAssetCount(ctx context.Context, accounts []*model.Account) error {
	// Nothing to annotate, so skip the round-trip entirely.
	if len(accounts) == 0 {
		return nil
	}

	ids := lo.Map(accounts, func(a *model.Account, _ int) int { return a.Id })

	counts := make([]*model.AccountCount, 0)
	err := dbpkg.DB.
		Model(&model.Authorization{}).
		Select("account_id AS id, COUNT(*) as count").
		Where("account_id IN ?", ids).
		Group("account_id").
		Find(&counts).
		Error
	if err != nil {
		return err
	}

	// Index the per-account totals so each account is resolved with one lookup.
	byID := make(map[int]int64, len(counts))
	for _, c := range counts {
		byID[c.Id] = c.Count
	}
	for _, a := range accounts {
		// A missing key yields the zero value, i.e. no dependent assets.
		a.AssetCount = byID[a.Id]
	}
	return nil
}
// CheckAssetDependencies looks for an asset that is tied to the account with
// the given id through an Authorization row.
//
// Results:
//   - ("", nil) when no such asset exists;
//   - (asset name, non-nil error) when one is found, the error stating that
//     the account has dependent assets;
//   - ("", err) when the lookup itself fails.
func (r *accountRepository) CheckAssetDependencies(ctx context.Context, id int) (string, error) {
	// Sub-query: one asset ID authorized for this account.
	linkedAsset := dbpkg.DB.
		Model(&model.Authorization{}).
		Select("asset_id").
		Where("account_id = ?", id).
		Limit(1)

	var name string
	err := dbpkg.DB.
		Model(model.DefaultAsset).
		Select("name").
		Where("id = (?)", linkedAsset).
		First(&name).
		Error

	switch {
	case err == nil:
		return name, errors.New("account has dependent assets")
	case errors.Is(err, gorm.ErrRecordNotFound):
		return "", nil
	default:
		return "", err
	}
}
// GetAccountIdsByAuthorization returns a de-duplicated list of IDs made of
// the keys of the "authorization" JSON column of the assets in assetIds
// (each key converted with cast.ToInt), followed by authorizationIds.
//
// NOTE(review): authorizationIds are appended as-is, i.e. they are treated
// as account IDs in the result; confirm that callers pass them accordingly.
func (r *accountRepository) GetAccountIdsByAuthorization(ctx context.Context, assetIds []int, authorizationIds []int) ([]int, error) {
	keySets := make([]model.Slice[string], 0)
	lookup := dbpkg.DB.
		Model(model.DefaultAsset).
		Where("id IN ?", assetIds).
		Pluck("JSON_KEYS(authorization)", &keySets)
	if lookup.Error != nil {
		return nil, lookup.Error
	}

	// Each row contributes one list of keys; flatten them and convert to ints.
	keys := lo.Flatten(keySets)
	fromAssets := make([]int, 0, len(keys))
	for _, key := range keys {
		fromAssets = append(fromAssets, cast.ToInt(key))
	}
	fromAssets = lo.Uniq(fromAssets)

	merged := append(fromAssets, authorizationIds...)
	return lo.Uniq(merged), nil
}
// HandleAccountIds restricts dbFind to the accounts the current session may
// see: those whose resource_id is in resIds, or whose id appears as a key of
// the "authorization" JSON column of an asset selected by HandleAssetIds for
// the session role's asset resources.
//
// It returns the narrowed query, or a nil query and the error if the session
// cannot be read, the role's asset resources cannot be fetched, the asset
// query cannot be built, or the authorization keys cannot be loaded.
func HandleAccountIds(ctx context.Context, dbFind *gorm.DB, resIds []int) (db *gorm.DB, err error) {
	// A missing session must stop here; calling GetRid on its result would
	// otherwise operate on an invalid value.
	currentUser, err := acl.GetSessionFromCtx(ctx)
	if err != nil {
		return nil, err
	}

	assetResIds, err := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_ASSET)
	if err != nil {
		return nil, err
	}

	// Propagate a failure to build the asset query instead of using its result.
	assetQuery, err := HandleAssetIds(ctx, dbpkg.DB.Model(model.DefaultAsset), assetResIds)
	if err != nil {
		return nil, err
	}

	keySets := make([]model.Slice[string], 0)
	if err = assetQuery.Pluck("JSON_KEYS(authorization)", &keySets).Error; err != nil {
		return nil, err
	}
	ids := lo.Uniq(lo.Map(lo.Flatten(keySets), func(s string, _ int) int { return cast.ToInt(s) }))

	// Both alternatives are built on a separate handle and handed to Where as
	// one condition, so the OR does not mix with filters already on dbFind.
	visible := dbpkg.DB.Where("resource_id IN ?", resIds).Or("id IN ?", ids)
	return dbFind.Where(visible), nil
}
// GetAuth turns an account's stored credentials into an ssh.AuthMethod.
//
// Password accounts yield ssh.Password. Public-key accounts parse account.Pk
// as a private key, using account.Phrase as the passphrase when it is not
// empty, and yield ssh.PublicKeys; a parse failure is returned unchanged.
// Any other account type results in an "invalid authmethod" error.
func GetAuth(account *model.Account) (ssh.AuthMethod, error) {
	switch account.AccountType {
	case model.AUTHMETHOD_PASSWORD:
		return ssh.Password(account.Password), nil
	case model.AUTHMETHOD_PUBLICKEY:
		var (
			signer ssh.Signer
			err    error
		)
		// Only the parsing step differs between protected and plain keys.
		if key := []byte(account.Pk); account.Phrase != "" {
			signer, err = ssh.ParsePrivateKeyWithPassphrase(key, []byte(account.Phrase))
		} else {
			signer, err = ssh.ParsePrivateKey(key)
		}
		if err != nil {
			return nil, err
		}
		return ssh.PublicKeys(signer), nil
	}
	return nil, fmt.Errorf("invalid authmethod %d", account.AccountType)
}