mirror of
https://github.com/veops/oneterm.git
synced 2025-11-02 03:42:36 +08:00
refactor(backend): asset and account
This commit is contained in:
@@ -1,84 +1,53 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"github.com/spf13/cast"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/veops/oneterm/internal/acl"
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
"github.com/veops/oneterm/internal/service"
|
||||
"github.com/veops/oneterm/pkg/config"
|
||||
dbpkg "github.com/veops/oneterm/pkg/db"
|
||||
"github.com/veops/oneterm/pkg/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
accountService = service.NewAccountService()
|
||||
|
||||
accountPreHooks = []preHook[*model.Account]{
|
||||
// Validate public key
|
||||
func(ctx *gin.Context, data *model.Account) {
|
||||
if data.AccountType == model.AUTHMETHOD_PUBLICKEY {
|
||||
if data.Phrase == "" {
|
||||
_, err := ssh.ParsePrivateKey([]byte(data.Pk))
|
||||
if err != nil {
|
||||
if err := accountService.ValidatePublicKey(data); err != nil {
|
||||
ctx.AbortWithError(http.StatusBadRequest, &ApiError{Code: ErrWrongPvk, Data: nil})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
_, err := ssh.ParsePrivateKeyWithPassphrase([]byte(data.Pk), []byte(data.Phrase))
|
||||
if err != nil {
|
||||
ctx.AbortWithError(http.StatusBadRequest, &ApiError{Code: ErrWrongPvk, Data: nil})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
// Encrypt sensitive data
|
||||
func(ctx *gin.Context, data *model.Account) {
|
||||
data.Password = utils.EncryptAES(data.Password)
|
||||
data.Pk = utils.EncryptAES(data.Pk)
|
||||
data.Phrase = utils.EncryptAES(data.Phrase)
|
||||
accountService.EncryptSensitiveData(data)
|
||||
},
|
||||
}
|
||||
|
||||
accountPostHooks = []postHook[*model.Account]{
|
||||
// Attach asset count
|
||||
func(ctx *gin.Context, data []*model.Account) {
|
||||
acs := make([]*model.AccountCount, 0)
|
||||
if err := dbpkg.DB.
|
||||
Model(&model.Authorization{}).
|
||||
Select("account_id AS id, COUNT(*) as count").
|
||||
Group("account_id").
|
||||
Where("account_id IN ?", lo.Map(data, func(d *model.Account, _ int) int { return d.Id })).
|
||||
Find(&acs).
|
||||
Error; err != nil {
|
||||
if err := accountService.AttachAssetCount(ctx, data); err != nil {
|
||||
return
|
||||
}
|
||||
m := lo.SliceToMap(acs, func(ac *model.AccountCount) (int, int64) { return ac.Id, ac.Count })
|
||||
for _, d := range data {
|
||||
d.AssetCount = m[d.Id]
|
||||
}
|
||||
},
|
||||
// Decrypt sensitive data
|
||||
func(ctx *gin.Context, data []*model.Account) {
|
||||
for _, d := range data {
|
||||
d.Password = utils.DecryptAES(d.Password)
|
||||
d.Pk = utils.DecryptAES(d.Pk)
|
||||
d.Phrase = utils.DecryptAES(d.Phrase)
|
||||
}
|
||||
accountService.DecryptSensitiveData(data)
|
||||
},
|
||||
}
|
||||
|
||||
accountDcs = []deleteCheck{
|
||||
// Check dependencies
|
||||
func(ctx *gin.Context, id int) {
|
||||
assetName := ""
|
||||
err := dbpkg.DB.
|
||||
Model(model.DefaultAsset).
|
||||
Select("name").
|
||||
Where("id = (?)", dbpkg.DB.Model(&model.Authorization{}).Select("asset_id").Where("account_id = ?", id).Limit(1)).
|
||||
First(&assetName).
|
||||
Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
assetName, err := accountService.CheckAssetDependencies(ctx, id)
|
||||
if err == nil && assetName == "" {
|
||||
return
|
||||
}
|
||||
code := lo.Ternary(err == nil, http.StatusBadRequest, http.StatusInternalServerError)
|
||||
@@ -136,45 +105,37 @@ func (c *Controller) GetAccounts(ctx *gin.Context) {
|
||||
currentUser, _ := acl.GetSessionFromCtx(ctx)
|
||||
info := cast.ToBool(ctx.Query("info"))
|
||||
|
||||
db := dbpkg.DB.Model(&model.Account{})
|
||||
db = filterEqual(ctx, db, "id", "type")
|
||||
db = filterLike(ctx, db, "name")
|
||||
db = filterSearch(ctx, db, "name", "account")
|
||||
if q, ok := ctx.GetQuery("ids"); ok {
|
||||
db = db.Where("id IN ?", lo.Map(strings.Split(q, ","), func(s string, _ int) int { return cast.ToInt(s) }))
|
||||
}
|
||||
// Build base query using service layer
|
||||
db := accountService.BuildQuery(ctx)
|
||||
|
||||
// Apply select fields for info mode
|
||||
if info {
|
||||
db = db.Select("id", "name", "account")
|
||||
|
||||
// Apply authorization filter if needed
|
||||
if !acl.IsAdmin(currentUser) {
|
||||
ids, err := GetAccountIdsByAuthorization(ctx)
|
||||
if err != nil {
|
||||
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
|
||||
return
|
||||
}
|
||||
db = db.Where("id IN ?", ids)
|
||||
}
|
||||
}
|
||||
|
||||
db = db.Order("name")
|
||||
|
||||
doGet(ctx, !info, db, config.RESOURCE_ACCOUNT, accountPostHooks...)
|
||||
}
|
||||
|
||||
func GetAccountIdsByAuthorization(ctx *gin.Context) (ids []int, err error) {
|
||||
assetIds, err := GetAssetIdsByAuthorization(ctx)
|
||||
if err != nil {
|
||||
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
|
||||
return
|
||||
}
|
||||
ss := make([]model.Slice[string], 0)
|
||||
if err = dbpkg.DB.Model(model.DefaultAsset).Where("id IN ?", assetIds).Pluck("JSON_KEYS(authorization)", &ss).Error; err != nil {
|
||||
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
|
||||
return
|
||||
}
|
||||
ids = lo.Uniq(lo.Map(lo.Flatten(ss), func(s string, _ int) int { return cast.ToInt(s) }))
|
||||
_, _, accountIds := getIdsByAuthorizationIds(ctx)
|
||||
ids = lo.Uniq(append(ids, accountIds...))
|
||||
|
||||
return
|
||||
// Filter accounts by asset IDs
|
||||
db = accountService.FilterByAssetIds(db, assetIds)
|
||||
}
|
||||
}
|
||||
|
||||
doGet(ctx, !info, db, config.RESOURCE_ACCOUNT, accountPostHooks...)
|
||||
}
|
||||
|
||||
// GetAccountIdsByAuthorization gets account IDs by authorization
|
||||
func GetAccountIdsByAuthorization(ctx *gin.Context) ([]int, error) {
|
||||
assetIds, err := GetAssetIdsByAuthorization(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, _, authorizationIds := getIdsByAuthorizationIds(ctx)
|
||||
|
||||
return accountService.GetAccountIdsByAuthorization(ctx, assetIds, authorizationIds)
|
||||
}
|
||||
|
||||
@@ -1,22 +1,16 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"github.com/spf13/cast"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/veops/oneterm/internal/acl"
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
"github.com/veops/oneterm/internal/repository"
|
||||
"github.com/veops/oneterm/internal/schedule"
|
||||
"github.com/veops/oneterm/internal/service"
|
||||
"github.com/veops/oneterm/pkg/config"
|
||||
dbpkg "github.com/veops/oneterm/pkg/db"
|
||||
"github.com/veops/oneterm/pkg/logger"
|
||||
)
|
||||
|
||||
@@ -28,16 +22,36 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
assetService = service.NewAssetService()
|
||||
|
||||
assetPreHooks = []preHook[*model.Asset]{
|
||||
// Preprocess asset data
|
||||
func(ctx *gin.Context, data *model.Asset) {
|
||||
data.Ip = strings.TrimSpace(data.Ip)
|
||||
data.Protocols = lo.Map(data.Protocols, func(s string, _ int) string { return strings.TrimSpace(s) })
|
||||
if data.Authorization == nil {
|
||||
data.Authorization = make(model.Map[int, model.Slice[int]])
|
||||
}
|
||||
assetService.PreprocessAssetData(data)
|
||||
},
|
||||
}
|
||||
assetPostHooks = []postHook[*model.Asset]{
|
||||
// Attach node chain
|
||||
func(ctx *gin.Context, data []*model.Asset) {
|
||||
if err := assetService.AttachNodeChain(ctx, data); err != nil {
|
||||
logger.L().Error("attach node chain failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
},
|
||||
// Apply authorization filters
|
||||
func(ctx *gin.Context, data []*model.Asset) {
|
||||
currentUser, _ := acl.GetSessionFromCtx(ctx)
|
||||
if acl.IsAdmin(currentUser) {
|
||||
return
|
||||
}
|
||||
|
||||
authorizationIds, _ := ctx.Value(kAuthorizationIds).([]*model.AuthorizationIds)
|
||||
nodeIds, _ := ctx.Value(kNodeIds).([]int)
|
||||
accountIds, _ := ctx.Value(kAccountIds).([]int)
|
||||
|
||||
assetService.ApplyAuthorizationFilters(ctx, data, authorizationIds, nodeIds, accountIds)
|
||||
},
|
||||
}
|
||||
assetPostHooks = []postHook[*model.Asset]{assetPostHookCount, assetPostHookAuth}
|
||||
)
|
||||
|
||||
// CreateAsset godoc
|
||||
@@ -50,7 +64,7 @@ func (c *Controller) CreateAsset(ctx *gin.Context) {
|
||||
asset := &model.Asset{}
|
||||
doCreate(ctx, true, asset, config.RESOURCE_ASSET, assetPreHooks...)
|
||||
|
||||
schedule.UpdateConnectables(asset.Id)
|
||||
assetService.UpdateConnectables(asset.Id)
|
||||
}
|
||||
|
||||
// DeleteAsset godoc
|
||||
@@ -71,8 +85,8 @@ func (c *Controller) DeleteAsset(ctx *gin.Context) {
|
||||
// @Success 200 {object} HttpResponse
|
||||
// @Router /asset/:id [put]
|
||||
func (c *Controller) UpdateAsset(ctx *gin.Context) {
|
||||
doUpdate(ctx, true, &model.Asset{}, config.RESOURCE_ASSET)
|
||||
schedule.UpdateConnectables(cast.ToInt(ctx.Param("id")))
|
||||
doUpdate(ctx, true, &model.Asset{}, config.RESOURCE_ASSET, assetPreHooks...)
|
||||
assetService.UpdateConnectables(cast.ToInt(ctx.Param("id")))
|
||||
}
|
||||
|
||||
// GetAssets godoc
|
||||
@@ -93,25 +107,28 @@ func (c *Controller) GetAssets(ctx *gin.Context) {
|
||||
currentUser, _ := acl.GetSessionFromCtx(ctx)
|
||||
info := cast.ToBool(ctx.Query("info"))
|
||||
|
||||
db := dbpkg.DB.Model(model.DefaultAsset)
|
||||
db = filterEqual(ctx, db, "id")
|
||||
db = filterLike(ctx, db, "name", "ip")
|
||||
db = filterSearch(ctx, db, "name", "ip")
|
||||
if q, ok := ctx.GetQuery("ids"); ok {
|
||||
db = db.Where("id IN ?", lo.Map(strings.Split(q, ","), func(s string, _ int) int { return cast.ToInt(s) }))
|
||||
}
|
||||
if q, ok := ctx.GetQuery("parent_id"); ok {
|
||||
parentIds, err := handleParentId(ctx, cast.ToInt(q))
|
||||
// Build base query using service layer
|
||||
db, err := assetService.BuildQuery(ctx)
|
||||
if err != nil {
|
||||
logger.L().Error("parent id found failed", zap.Error(err))
|
||||
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
|
||||
return
|
||||
}
|
||||
|
||||
// Apply parent_id filter if needed
|
||||
if q, ok := ctx.GetQuery("parent_id"); ok {
|
||||
db, err = assetService.FilterByParentId(db, cast.ToInt(q))
|
||||
if err != nil {
|
||||
logger.L().Error("parent id filtering failed", zap.Error(err))
|
||||
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
|
||||
return
|
||||
}
|
||||
db = db.Where("parent_id IN ?", parentIds)
|
||||
}
|
||||
|
||||
// Apply info mode settings
|
||||
if info {
|
||||
db = db.Select("id", "parent_id", "name", "ip", "protocols", "connectable", "authorization")
|
||||
|
||||
// Apply authorization filter if needed
|
||||
if !acl.IsAdmin(currentUser) {
|
||||
ids, err := GetAssetIdsByAuthorization(ctx)
|
||||
if err != nil {
|
||||
@@ -122,149 +139,16 @@ func (c *Controller) GetAssets(ctx *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
db = db.Order("name")
|
||||
|
||||
doGet(ctx, !info, db, config.RESOURCE_ASSET, assetPostHooks...)
|
||||
}
|
||||
|
||||
func assetPostHookCount(ctx *gin.Context, data []*model.Asset) {
|
||||
nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
g := make(map[int][]model.Pair[int, string])
|
||||
for _, n := range nodes {
|
||||
g[n.ParentId] = append(g[n.ParentId], model.Pair[int, string]{First: n.Id, Second: n.Name})
|
||||
}
|
||||
m := make(map[int]string)
|
||||
var dfs func(int, string)
|
||||
dfs = func(x int, s string) {
|
||||
m[x] = s
|
||||
for _, node := range g[x] {
|
||||
dfs(node.First, fmt.Sprintf("%s/%s", s, node.Second))
|
||||
}
|
||||
}
|
||||
dfs(0, "")
|
||||
|
||||
for _, d := range data {
|
||||
d.NodeChain = m[d.ParentId]
|
||||
}
|
||||
}
|
||||
|
||||
func assetPostHookAuth(ctx *gin.Context, data []*model.Asset) {
|
||||
currentUser, _ := acl.GetSessionFromCtx(ctx)
|
||||
if acl.IsAdmin(currentUser) {
|
||||
return
|
||||
}
|
||||
info := cast.ToBool(ctx.Query("info"))
|
||||
noInfoIds := make([]int, 0)
|
||||
if !info {
|
||||
t := dbpkg.DB.Model(model.DefaultAsset)
|
||||
assetResIds, _ := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_ASSET)
|
||||
t, _ = handleAssetIds(ctx, t, assetResIds)
|
||||
t.Pluck("id", &noInfoIds)
|
||||
}
|
||||
|
||||
authorizationIds, _ := ctx.Value(kAuthorizationIds).([]*model.AuthorizationIds)
|
||||
nodeIds, _, accountIds := getIdsByAuthorizationIds(ctx)
|
||||
nodeIds, _ = handleSelfChild(ctx, nodeIds...)
|
||||
|
||||
for _, a := range data {
|
||||
if lo.Contains(nodeIds, a.ParentId) || lo.Contains(noInfoIds, a.Id) {
|
||||
continue
|
||||
}
|
||||
if lo.ContainsBy(authorizationIds, func(item *model.AuthorizationIds) bool {
|
||||
return item.AssetId == a.Id && item.NodeId == 0 && item.AccountId == 0
|
||||
}) {
|
||||
continue
|
||||
}
|
||||
ids := lo.Map(lo.Filter(authorizationIds, func(item *model.AuthorizationIds, _ int) bool {
|
||||
return item.AssetId == a.Id && item.AccountId != 0 && item.NodeId == 0
|
||||
}),
|
||||
func(item *model.AuthorizationIds, _ int) int { return item.AccountId })
|
||||
|
||||
for k := range a.Authorization {
|
||||
if !lo.Contains(ids, k) && !lo.Contains(accountIds, k) {
|
||||
delete(a.Authorization, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleParentId(ctx context.Context, parentId int) (pids []int, err error) {
|
||||
nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
g := make(map[int][]int)
|
||||
for _, n := range nodes {
|
||||
g[n.ParentId] = append(g[n.ParentId], n.Id)
|
||||
}
|
||||
var dfs func(int)
|
||||
dfs = func(x int) {
|
||||
pids = append(pids, x)
|
||||
for _, y := range g[x] {
|
||||
dfs(y)
|
||||
}
|
||||
}
|
||||
dfs(parentId)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func GetAssetIdsByAuthorization(ctx *gin.Context) (ids []int, err error) {
|
||||
// GetAssetIdsByAuthorization gets asset IDs by authorization
|
||||
func GetAssetIdsByAuthorization(ctx *gin.Context) ([]int, error) {
|
||||
authIds, err := getAuthorizationIds(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
ctx.Set(kAuthorizationIds, authIds)
|
||||
|
||||
nodeIds, ids, accountIds := getIdsByAuthorizationIds(ctx)
|
||||
|
||||
tmp, err := handleSelfChild(ctx, nodeIds...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
nodeIds = append(nodeIds, tmp...)
|
||||
ctx.Set(kNodeIds, nodeIds)
|
||||
ctx.Set(kAccountIds, accountIds)
|
||||
tmp, err = getAssetIdsByNodeAccount(ctx, nodeIds, accountIds)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ids = lo.Uniq(append(ids, tmp...))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func getIdsByAuthorizationIds(ctx *gin.Context) (nodeIds, assetIds, accountIds []int) {
|
||||
authIds, _ := ctx.Value(kAuthorizationIds).([]*model.AuthorizationIds)
|
||||
info := cast.ToBool(ctx.Query("info"))
|
||||
for _, a := range authIds {
|
||||
if a.NodeId != 0 && a.AssetId == 0 && a.AccountId == 0 {
|
||||
nodeIds = append(nodeIds, a.NodeId)
|
||||
}
|
||||
if a.AssetId != 0 && a.NodeId == 0 && (info || a.AccountId == 0) {
|
||||
assetIds = append(assetIds, a.AssetId)
|
||||
}
|
||||
if a.AccountId != 0 && a.AssetId == 0 && (info || a.NodeId == 0) {
|
||||
accountIds = append(accountIds, a.AccountId)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getAssetIdsByNodeAccount(ctx context.Context, nodeIds, accountIds []int) (assetIds []int, err error) {
|
||||
assets, err := repository.GetAllFromCacheDb(ctx, model.DefaultAsset)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
assets = lo.Filter(assets, func(a *model.Asset, _ int) bool {
|
||||
return lo.Contains(nodeIds, a.ParentId) || len(lo.Intersect(lo.Keys(a.Authorization), accountIds)) > 0
|
||||
})
|
||||
assetIds = lo.Map(assets, func(a *model.Asset, _ int) int { return a.Id })
|
||||
|
||||
return
|
||||
_, assetIds, _, err := assetService.GetAssetIdsByAuthorization(ctx, authIds)
|
||||
return assetIds, err
|
||||
}
|
||||
|
||||
@@ -159,7 +159,7 @@ func getNodeAssetAccoutIdsByAction(ctx context.Context, action string) (nodeIds,
|
||||
}
|
||||
nodes = lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(resIds, n.ResourceId) })
|
||||
nodeIds = lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id })
|
||||
nodeIds, err = handleSelfChild(ctx, nodeIds...)
|
||||
nodeIds, err = repository.HandleSelfChild(ctx, nodeIds...)
|
||||
return
|
||||
})
|
||||
|
||||
@@ -403,7 +403,7 @@ func hasAuthorization(ctx *gin.Context, sess *gsession.Session) (ok bool) {
|
||||
ctx.Set(kAuthorizationIds, authIds)
|
||||
|
||||
nodeIds, assetIds, accountIds := getIdsByAuthorizationIds(ctx)
|
||||
tmp, err := handleSelfChild(ctx, nodeIds...)
|
||||
tmp, err := repository.HandleSelfChild(ctx, nodeIds...)
|
||||
if err != nil {
|
||||
logger.L().Error("", zap.Error(err))
|
||||
return
|
||||
@@ -421,3 +421,15 @@ func hasAuthorization(ctx *gin.Context, sess *gsession.Session) (ok bool) {
|
||||
|
||||
return lo.Contains(ids, sess.AssetId)
|
||||
}
|
||||
|
||||
func getIdsByAuthorizationIds(ctx *gin.Context) (nodeIds, assetIds, accountIds []int) {
|
||||
authorizationIds, ok := ctx.Value(kAuthorizationIds).([]*model.AuthorizationIds)
|
||||
if !ok || len(authorizationIds) == 0 {
|
||||
return
|
||||
}
|
||||
return assetService.GetIdsByAuthorizationIds(ctx, authorizationIds)
|
||||
}
|
||||
|
||||
func getAssetIdsByNodeAccount(ctx context.Context, nodeIds, accountIds []int) ([]int, error) {
|
||||
return assetService.GetAssetIdsByNodeAccount(ctx, nodeIds, accountIds)
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ import (
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
"github.com/veops/oneterm/internal/service"
|
||||
gsession "github.com/veops/oneterm/internal/session"
|
||||
"github.com/veops/oneterm/internal/tunneling"
|
||||
dbpkg "github.com/veops/oneterm/pkg/db"
|
||||
"github.com/veops/oneterm/pkg/logger"
|
||||
)
|
||||
@@ -389,7 +390,7 @@ func connectSsh(ctx *gin.Context, sess *gsession.Session, asset *model.Asset, ac
|
||||
}
|
||||
}()
|
||||
|
||||
ip, port, err := service.Proxy(false, sess.SessionId, "ssh", asset, gateway)
|
||||
ip, port, err := tunneling.Proxy(false, sess.SessionId, "ssh", asset, gateway)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -554,7 +555,7 @@ func connectOther(ctx *gin.Context, sess *gsession.Session, asset *model.Asset,
|
||||
}()
|
||||
|
||||
protocol := strings.Split(sess.Protocol, ":")[0]
|
||||
ip, port, err := service.Proxy(false, sess.SessionId, protocol, asset, gateway)
|
||||
ip, port, err := tunneling.Proxy(false, sess.SessionId, protocol, asset, gateway)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -501,9 +501,9 @@ func hasPerm[T model.Model](ctx context.Context, md T, resourceTypeName, action
|
||||
pids := make([]int, 0)
|
||||
switch t := any(md).(type) {
|
||||
case *model.Asset:
|
||||
pids, _ = handleSelfParent(ctx, t.ParentId)
|
||||
pids, _ = repository.HandleSelfParent(ctx, t.ParentId)
|
||||
case *model.Node:
|
||||
pids, _ = handleSelfParent(ctx, t.Id)
|
||||
pids, _ = repository.HandleSelfParent(ctx, t.Id)
|
||||
}
|
||||
|
||||
if len(pids) > 0 {
|
||||
@@ -619,7 +619,7 @@ func handleNodeIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.DB
|
||||
}
|
||||
nodes = lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(resIds, n.ResourceId) })
|
||||
ids := lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id })
|
||||
if ids, err = handleSelfChild(ctx, ids...); err != nil {
|
||||
if ids, err = repository.HandleSelfChild(ctx, ids...); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -633,7 +633,7 @@ func handleNodeIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.DB
|
||||
}
|
||||
ids = append(ids, lo.Map(assets, func(a *model.AssetIdPid, _ int) int { return a.ParentId })...)
|
||||
|
||||
ids, err = handleSelfParent(ctx, ids...)
|
||||
ids, err = repository.HandleSelfParent(ctx, ids...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -656,7 +656,7 @@ func handleAssetIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.D
|
||||
}
|
||||
nodes = lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(nodeResIds, n.ResourceId) })
|
||||
nodeIds := lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id })
|
||||
if nodeIds, err = handleSelfChild(ctx, nodeIds...); err != nil {
|
||||
if nodeIds, err = repository.HandleSelfChild(ctx, nodeIds...); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -102,7 +102,7 @@ func (c *Controller) GetNodes(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
if id, ok := ctx.GetQuery("self_parent"); ok {
|
||||
ids, err := handleSelfParent(ctx, cast.ToInt(id))
|
||||
ids, err := repository.HandleSelfParent(ctx, cast.ToInt(id))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -116,7 +116,7 @@ func (c *Controller) GetNodes(ctx *gin.Context) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if ids, err = handleSelfParent(ctx, ids...); err != nil {
|
||||
if ids, err = repository.HandleSelfParent(ctx, ids...); err != nil {
|
||||
return
|
||||
}
|
||||
db = db.Where("id IN ?", ids)
|
||||
@@ -216,7 +216,7 @@ func nodePostHookHasChild(ctx *gin.Context, data []*model.Node) {
|
||||
assetIds, _ := GetAssetIdsByAuthorization(ctx)
|
||||
assets = lo.Filter(assets, func(a *model.Asset, _ int) bool { return lo.Contains(assetIds, a.Id) })
|
||||
pids := lo.Map(assets, func(a *model.Asset, _ int) int { return a.ParentId })
|
||||
pids, _ = handleSelfParent(ctx, pids...)
|
||||
pids, _ = repository.HandleSelfParent(ctx, pids...)
|
||||
nodes = lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(pids, n.Id) })
|
||||
ps := lo.SliceToMap(nodes, func(a *model.Node) (int, bool) { return a.ParentId, true })
|
||||
for _, n := range data {
|
||||
@@ -229,14 +229,14 @@ func nodePostHookHasChild(ctx *gin.Context, data []*model.Node) {
|
||||
assetResIds, err = acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_ASSET)
|
||||
assets = lo.Filter(assets, func(a *model.Asset, _ int) bool { return lo.Contains(assetResIds, a.ResourceId) })
|
||||
pids = lo.Map(assets, func(n *model.Asset, _ int) int { return n.ParentId })
|
||||
pids, _ = handleSelfParent(ctx, pids...)
|
||||
pids, _ = repository.HandleSelfParent(ctx, pids...)
|
||||
return
|
||||
})
|
||||
eg.Go(func() (err error) {
|
||||
nodeResIds, err = acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_NODE)
|
||||
ns := lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(nodeResIds, n.ResourceId) })
|
||||
nids, _ = handleSelfChild(ctx, lo.Map(ns, func(n *model.Node, _ int) int { return n.Id })...)
|
||||
nids, _ = handleSelfParent(ctx, nids...)
|
||||
nids, _ = repository.HandleSelfChild(ctx, lo.Map(ns, func(n *model.Node, _ int) int { return n.Id })...)
|
||||
nids, _ = repository.HandleSelfParent(ctx, nids...)
|
||||
return
|
||||
})
|
||||
eg.Wait()
|
||||
@@ -261,61 +261,6 @@ func nodeDelHook(ctx *gin.Context, id int) {
|
||||
ctx.AbortWithError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
func handleSelfChild(ctx context.Context, ids ...int) (res []int, err error) {
|
||||
nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
g := make(map[int][]int)
|
||||
for _, n := range nodes {
|
||||
g[n.ParentId] = append(g[n.ParentId], n.Id)
|
||||
}
|
||||
var dfs func(int, bool)
|
||||
dfs = func(x int, b bool) {
|
||||
if b {
|
||||
res = append(res, x)
|
||||
}
|
||||
for _, y := range g[x] {
|
||||
dfs(y, b || lo.Contains(ids, x))
|
||||
}
|
||||
}
|
||||
dfs(0, false)
|
||||
|
||||
res = lo.Uniq(append(res, ids...))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func handleSelfParent(ctx context.Context, ids ...int) (res []int, err error) {
|
||||
nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
g := make(map[int][]int)
|
||||
for _, n := range nodes {
|
||||
g[n.ParentId] = append(g[n.ParentId], n.Id)
|
||||
}
|
||||
t := make([]int, 0)
|
||||
var dfs func(int)
|
||||
dfs = func(x int) {
|
||||
t = append(t, x)
|
||||
if lo.Contains(ids, x) {
|
||||
res = append(res, t...)
|
||||
}
|
||||
for _, y := range g[x] {
|
||||
dfs(y)
|
||||
}
|
||||
t = t[:len(t)-1]
|
||||
}
|
||||
dfs(0)
|
||||
|
||||
res = lo.Uniq(append(res, ids...))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func handleNoSelfChild(ctx context.Context, ids ...int) (res []int, err error) {
|
||||
nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
|
||||
if err != nil {
|
||||
@@ -323,7 +268,7 @@ func handleNoSelfChild(ctx context.Context, ids ...int) (res []int, err error) {
|
||||
}
|
||||
allids := lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id })
|
||||
|
||||
res, err = handleSelfChild(ctx, ids...)
|
||||
res, err = repository.HandleSelfChild(ctx, ids...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -339,7 +284,7 @@ func handleNoSelfParent(ctx context.Context, ids ...int) (res []int, err error)
|
||||
}
|
||||
allids := lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id })
|
||||
|
||||
res, err = handleSelfParent(ctx, ids...)
|
||||
res, err = repository.HandleSelfParent(ctx, ids...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
165
backend/internal/repository/account.go
Normal file
165
backend/internal/repository/account.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"github.com/spf13/cast"
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
dbpkg "github.com/veops/oneterm/pkg/db"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AccountRepository interface for account data access
|
||||
type AccountRepository interface {
|
||||
AttachAssetCount(ctx context.Context, accounts []*model.Account) error
|
||||
CheckAssetDependencies(ctx context.Context, id int) (string, error)
|
||||
BuildQuery(ctx *gin.Context) *gorm.DB
|
||||
FilterByAssetIds(db *gorm.DB, assetIds []int) *gorm.DB
|
||||
GetAccountIdsByAuthorization(ctx context.Context, assetIds []int, authorizationIds []int) ([]int, error)
|
||||
}
|
||||
|
||||
// accountRepository implements AccountRepository
|
||||
type accountRepository struct{}
|
||||
|
||||
// NewAccountRepository creates a new account repository
|
||||
func NewAccountRepository() AccountRepository {
|
||||
return &accountRepository{}
|
||||
}
|
||||
|
||||
// BuildQuery builds the base query for accounts with filters
|
||||
func (r *accountRepository) BuildQuery(ctx *gin.Context) *gorm.DB {
|
||||
db := dbpkg.DB.Model(&model.Account{})
|
||||
|
||||
// Apply filters
|
||||
db = r.filterEqual(ctx, db, "id", "type")
|
||||
db = r.filterLike(ctx, db, "name")
|
||||
db = r.filterSearch(ctx, db, "name", "account")
|
||||
|
||||
// Handle IDs parameter
|
||||
if q, ok := ctx.GetQuery("ids"); ok {
|
||||
db = db.Where("id IN ?", lo.Map(strings.Split(q, ","),
|
||||
func(s string, _ int) int { return cast.ToInt(s) }))
|
||||
}
|
||||
|
||||
// Sort by name
|
||||
db = db.Order("name")
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// FilterByAssetIds filters accounts by related asset IDs
|
||||
func (r *accountRepository) FilterByAssetIds(db *gorm.DB, assetIds []int) *gorm.DB {
|
||||
if len(assetIds) == 0 {
|
||||
return db.Where("0 = 1") // Return empty result if no asset IDs
|
||||
}
|
||||
|
||||
// 查询与指定资产关联的账户ID
|
||||
subQuery := dbpkg.DB.Model(&model.Authorization{}).
|
||||
Select("account_id").
|
||||
Where("asset_id IN ?", assetIds).
|
||||
Group("account_id")
|
||||
|
||||
return db.Where("id IN (?)", subQuery)
|
||||
}
|
||||
|
||||
// AttachAssetCount attaches asset count to accounts
|
||||
func (r *accountRepository) AttachAssetCount(ctx context.Context, accounts []*model.Account) error {
|
||||
acs := make([]*model.AccountCount, 0)
|
||||
if err := dbpkg.DB.
|
||||
Model(&model.Authorization{}).
|
||||
Select("account_id AS id, COUNT(*) as count").
|
||||
Group("account_id").
|
||||
Where("account_id IN ?", lo.Map(accounts, func(d *model.Account, _ int) int { return d.Id })).
|
||||
Find(&acs).
|
||||
Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m := lo.SliceToMap(acs, func(ac *model.AccountCount) (int, int64) { return ac.Id, ac.Count })
|
||||
for _, d := range accounts {
|
||||
d.AssetCount = m[d.Id]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckAssetDependencies checks if account has dependent assets
|
||||
func (r *accountRepository) CheckAssetDependencies(ctx context.Context, id int) (string, error) {
|
||||
var assetName string
|
||||
err := dbpkg.DB.
|
||||
Model(model.DefaultAsset).
|
||||
Select("name").
|
||||
Where("id = (?)", dbpkg.DB.Model(&model.Authorization{}).Select("asset_id").Where("account_id = ?", id).Limit(1)).
|
||||
First(&assetName).
|
||||
Error
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return assetName, errors.New("account has dependent assets")
|
||||
}
|
||||
|
||||
// GetAccountIdsByAuthorization gets account IDs by authorization and asset IDs
|
||||
func (r *accountRepository) GetAccountIdsByAuthorization(ctx context.Context, assetIds []int, authorizationIds []int) ([]int, error) {
|
||||
// 从资产的授权列表中获取账户ID
|
||||
ss := make([]model.Slice[string], 0)
|
||||
if err := dbpkg.DB.Model(model.DefaultAsset).Where("id IN ?", assetIds).Pluck("JSON_KEYS(authorization)", &ss).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 处理从资产中获取的账户IDs
|
||||
accountIds := lo.Uniq(lo.Map(lo.Flatten(ss), func(s string, _ int) int { return cast.ToInt(s) }))
|
||||
|
||||
// 合并从授权中获取的账户IDs
|
||||
return lo.Uniq(append(accountIds, authorizationIds...)), nil
|
||||
}
|
||||
|
||||
// Filter helpers
|
||||
func (r *accountRepository) filterEqual(ctx *gin.Context, db *gorm.DB, fields ...string) *gorm.DB {
|
||||
for _, f := range fields {
|
||||
if q, ok := ctx.GetQuery(f); ok {
|
||||
db = db.Where(fmt.Sprintf("%s = ?", f), q)
|
||||
}
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func (r *accountRepository) filterLike(ctx *gin.Context, db *gorm.DB, fields ...string) *gorm.DB {
|
||||
likes := false
|
||||
d := dbpkg.DB
|
||||
for _, f := range fields {
|
||||
if q, ok := ctx.GetQuery(f); ok && q != "" {
|
||||
d = d.Or(fmt.Sprintf("%s LIKE ?", f), fmt.Sprintf("%%%s%%", q))
|
||||
likes = true
|
||||
}
|
||||
}
|
||||
if !likes {
|
||||
return db
|
||||
}
|
||||
db = db.Where(d)
|
||||
return db
|
||||
}
|
||||
|
||||
func (r *accountRepository) filterSearch(ctx *gin.Context, db *gorm.DB, fields ...string) *gorm.DB {
|
||||
q, ok := ctx.GetQuery("search")
|
||||
if !ok || len(fields) <= 0 {
|
||||
return db
|
||||
}
|
||||
|
||||
d := dbpkg.DB
|
||||
for _, f := range fields {
|
||||
d = d.Or(fmt.Sprintf("%s LIKE ?", f), fmt.Sprintf("%%%s%%", q))
|
||||
}
|
||||
|
||||
db = db.Where(d)
|
||||
return db
|
||||
}
|
||||
289
backend/internal/repository/asset.go
Normal file
289
backend/internal/repository/asset.go
Normal file
@@ -0,0 +1,289 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"github.com/spf13/cast"
|
||||
"github.com/veops/oneterm/internal/acl"
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
"github.com/veops/oneterm/pkg/config"
|
||||
dbpkg "github.com/veops/oneterm/pkg/db"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
kFmtAssetIds = "assetIds-%d"
|
||||
kAuthorizationIds = "authorizationIds"
|
||||
kNodeIds = "nodeIds"
|
||||
kAccountIds = "accountIds"
|
||||
)
|
||||
|
||||
// AssetRepository interface for asset data access
|
||||
type AssetRepository interface {
|
||||
AttachNodeChain(ctx context.Context, assets []*model.Asset) error
|
||||
ApplyAuthorizationFilters(ctx *gin.Context, assets []*model.Asset, authorizationIds []*model.AuthorizationIds, nodeIds, accountIds []int)
|
||||
BuildQuery(ctx *gin.Context) (*gorm.DB, error)
|
||||
FilterByParentId(db *gorm.DB, parentId int) (*gorm.DB, error)
|
||||
GetAssetIdsByAuthorization(ctx *gin.Context, authorizationIds []*model.AuthorizationIds) ([]int, []int, []int, error)
|
||||
GetIdsByAuthorizationIds(ctx *gin.Context, authorizationIds []*model.AuthorizationIds) ([]int, []int, []int)
|
||||
GetAssetIdsByNodeAccount(ctx context.Context, nodeIds, accountIds []int) ([]int, error)
|
||||
}
|
||||
|
||||
// assetRepository implements AssetRepository
|
||||
type assetRepository struct{}
|
||||
|
||||
// NewAssetRepository creates a new asset repository
|
||||
func NewAssetRepository() AssetRepository {
|
||||
return &assetRepository{}
|
||||
}
|
||||
|
||||
// BuildQuery builds the base query for assets with filters
|
||||
func (r *assetRepository) BuildQuery(ctx *gin.Context) (*gorm.DB, error) {
|
||||
db := dbpkg.DB.Model(model.DefaultAsset)
|
||||
|
||||
// Apply filters
|
||||
db = r.filterEqual(ctx, db, "id")
|
||||
db = r.filterLike(ctx, db, "name", "ip")
|
||||
db = r.filterSearch(ctx, db, "name", "ip")
|
||||
|
||||
// Handle IDs parameter
|
||||
if q, ok := ctx.GetQuery("ids"); ok {
|
||||
db = db.Where("id IN ?", lo.Map(strings.Split(q, ","),
|
||||
func(s string, _ int) int { return cast.ToInt(s) }))
|
||||
}
|
||||
|
||||
// Sort by name
|
||||
db = db.Order("name")
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// FilterByParentId filters assets by parent ID and its children
|
||||
func (r *assetRepository) FilterByParentId(db *gorm.DB, parentId int) (*gorm.DB, error) {
|
||||
parentIds, err := r.handleParentId(context.Background(), parentId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db.Where("parent_id IN ?", parentIds), nil
|
||||
}
|
||||
|
||||
// AttachNodeChain attaches node chain to assets
|
||||
func (r *assetRepository) AttachNodeChain(ctx context.Context, assets []*model.Asset) error {
|
||||
nodes, err := GetAllFromCacheDb(ctx, model.DefaultNode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
g := make(map[int][]model.Pair[int, string])
|
||||
for _, n := range nodes {
|
||||
g[n.ParentId] = append(g[n.ParentId], model.Pair[int, string]{First: n.Id, Second: n.Name})
|
||||
}
|
||||
|
||||
m := make(map[int]string)
|
||||
var dfs func(int, string)
|
||||
dfs = func(x int, s string) {
|
||||
m[x] = s
|
||||
for _, node := range g[x] {
|
||||
dfs(node.First, fmt.Sprintf("%s/%s", s, node.Second))
|
||||
}
|
||||
}
|
||||
dfs(0, "")
|
||||
|
||||
for _, d := range assets {
|
||||
d.NodeChain = m[d.ParentId]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyAuthorizationFilters applies authorization filters to assets
|
||||
func (r *assetRepository) ApplyAuthorizationFilters(ctx *gin.Context, assets []*model.Asset, authorizationIds []*model.AuthorizationIds, nodeIds, accountIds []int) {
|
||||
currentUser, _ := acl.GetSessionFromCtx(ctx)
|
||||
if acl.IsAdmin(currentUser) {
|
||||
return
|
||||
}
|
||||
|
||||
info := cast.ToBool(ctx.Query("info"))
|
||||
noInfoIds := make([]int, 0)
|
||||
|
||||
if !info {
|
||||
t := dbpkg.DB.Model(model.DefaultAsset)
|
||||
assetResIds, _ := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_ASSET)
|
||||
t, _ = r.handleAssetIds(ctx, t, assetResIds)
|
||||
t.Pluck("id", &noInfoIds)
|
||||
}
|
||||
|
||||
for _, a := range assets {
|
||||
if lo.Contains(nodeIds, a.ParentId) || lo.Contains(noInfoIds, a.Id) {
|
||||
continue
|
||||
}
|
||||
if lo.ContainsBy(authorizationIds, func(item *model.AuthorizationIds) bool {
|
||||
return item.AssetId == a.Id && item.NodeId == 0 && item.AccountId == 0
|
||||
}) {
|
||||
continue
|
||||
}
|
||||
ids := lo.Map(lo.Filter(authorizationIds, func(item *model.AuthorizationIds, _ int) bool {
|
||||
return item.AssetId == a.Id && item.AccountId != 0 && item.NodeId == 0
|
||||
}),
|
||||
func(item *model.AuthorizationIds, _ int) int { return item.AccountId })
|
||||
|
||||
for k := range a.Authorization {
|
||||
if !lo.Contains(ids, k) && !lo.Contains(accountIds, k) {
|
||||
delete(a.Authorization, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetAssetIdsByAuthorization gets asset IDs by authorization
|
||||
func (r *assetRepository) GetAssetIdsByAuthorization(ctx *gin.Context, authorizationIds []*model.AuthorizationIds) ([]int, []int, []int, error) {
|
||||
ctx.Set(kAuthorizationIds, authorizationIds)
|
||||
|
||||
nodeIds, assetIds, accountIds := r.GetIdsByAuthorizationIds(ctx, authorizationIds)
|
||||
|
||||
tmp, err := HandleSelfChild(ctx, nodeIds...)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
nodeIds = append(nodeIds, tmp...)
|
||||
ctx.Set(kNodeIds, nodeIds)
|
||||
ctx.Set(kAccountIds, accountIds)
|
||||
|
||||
assetIdsFromNode, err := r.GetAssetIdsByNodeAccount(ctx, nodeIds, accountIds)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
allAssetIds := lo.Uniq(append(assetIds, assetIdsFromNode...))
|
||||
|
||||
return nodeIds, allAssetIds, accountIds, nil
|
||||
}
|
||||
|
||||
// GetIdsByAuthorizationIds extracts node IDs, asset IDs, and account IDs from authorization IDs
|
||||
func (r *assetRepository) GetIdsByAuthorizationIds(ctx *gin.Context, authorizationIds []*model.AuthorizationIds) (nodeIds, assetIds, accountIds []int) {
|
||||
info := cast.ToBool(ctx.Query("info"))
|
||||
for _, a := range authorizationIds {
|
||||
if a.NodeId != 0 && a.AssetId == 0 && a.AccountId == 0 {
|
||||
nodeIds = append(nodeIds, a.NodeId)
|
||||
}
|
||||
if a.AssetId != 0 && a.NodeId == 0 && (info || a.AccountId == 0) {
|
||||
assetIds = append(assetIds, a.AssetId)
|
||||
}
|
||||
if a.AccountId != 0 && a.AssetId == 0 && (info || a.NodeId == 0) {
|
||||
accountIds = append(accountIds, a.AccountId)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetAssetIdsByNodeAccount gets asset IDs by node IDs and account IDs
|
||||
func (r *assetRepository) GetAssetIdsByNodeAccount(ctx context.Context, nodeIds, accountIds []int) (assetIds []int, err error) {
|
||||
assets, err := GetAllFromCacheDb(ctx, model.DefaultAsset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
assets = lo.Filter(assets, func(a *model.Asset, _ int) bool {
|
||||
return lo.Contains(nodeIds, a.ParentId) || len(lo.Intersect(lo.Keys(a.Authorization), accountIds)) > 0
|
||||
})
|
||||
|
||||
assetIds = lo.Map(assets, func(a *model.Asset, _ int) int { return a.Id })
|
||||
return assetIds, nil
|
||||
}
|
||||
|
||||
// handleParentId builds a slice of parent IDs including the given parent ID and all its children
|
||||
func (r *assetRepository) handleParentId(ctx context.Context, parentId int) (pids []int, err error) {
|
||||
nodes, err := GetAllFromCacheDb(ctx, model.DefaultNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
g := make(map[int][]int)
|
||||
for _, n := range nodes {
|
||||
g[n.ParentId] = append(g[n.ParentId], n.Id)
|
||||
}
|
||||
|
||||
var dfs func(int)
|
||||
dfs = func(x int) {
|
||||
pids = append(pids, x)
|
||||
for _, y := range g[x] {
|
||||
dfs(y)
|
||||
}
|
||||
}
|
||||
dfs(parentId)
|
||||
|
||||
return pids, nil
|
||||
}
|
||||
|
||||
// handleAssetIds filters assets by resource IDs and node IDs
|
||||
func (r *assetRepository) handleAssetIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.DB, err error) {
|
||||
currentUser, _ := acl.GetSessionFromCtx(ctx)
|
||||
|
||||
nodes, err := GetAllFromCacheDb(ctx, model.DefaultNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodeResIds, err := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_NODE)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodes = lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(nodeResIds, n.ResourceId) })
|
||||
nodeIds := lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id })
|
||||
|
||||
if nodeIds, err = HandleSelfChild(ctx, nodeIds...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d := dbpkg.DB.Where("resource_id IN ?", resIds).Or("parent_id IN?", nodeIds)
|
||||
db = dbFind.Where(d)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// Filter helpers
|
||||
func (r *assetRepository) filterEqual(ctx *gin.Context, db *gorm.DB, fields ...string) *gorm.DB {
|
||||
for _, f := range fields {
|
||||
if q, ok := ctx.GetQuery(f); ok {
|
||||
db = db.Where(fmt.Sprintf("%s = ?", f), q)
|
||||
}
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func (r *assetRepository) filterLike(ctx *gin.Context, db *gorm.DB, fields ...string) *gorm.DB {
|
||||
likes := false
|
||||
d := dbpkg.DB
|
||||
for _, f := range fields {
|
||||
if q, ok := ctx.GetQuery(f); ok && q != "" {
|
||||
d = d.Or(fmt.Sprintf("%s LIKE ?", f), fmt.Sprintf("%%%s%%", q))
|
||||
likes = true
|
||||
}
|
||||
}
|
||||
if !likes {
|
||||
return db
|
||||
}
|
||||
db = db.Where(d)
|
||||
return db
|
||||
}
|
||||
|
||||
func (r *assetRepository) filterSearch(ctx *gin.Context, db *gorm.DB, fields ...string) *gorm.DB {
|
||||
q, ok := ctx.GetQuery("search")
|
||||
if !ok || len(fields) <= 0 {
|
||||
return db
|
||||
}
|
||||
|
||||
d := dbpkg.DB
|
||||
for _, f := range fields {
|
||||
d = d.Or(fmt.Sprintf("%s LIKE ?", f), fmt.Sprintf("%%%s%%", q))
|
||||
}
|
||||
|
||||
db = db.Where(d)
|
||||
return db
|
||||
}
|
||||
67
backend/internal/repository/node.go
Normal file
67
backend/internal/repository/node.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
)
|
||||
|
||||
// HandleSelfChild gets IDs of nodes that are children of the specified node IDs
|
||||
func HandleSelfChild(ctx context.Context, ids ...int) (res []int, err error) {
|
||||
nodes, err := GetAllFromCacheDb(ctx, model.DefaultNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
g := make(map[int][]int)
|
||||
for _, n := range nodes {
|
||||
g[n.ParentId] = append(g[n.ParentId], n.Id)
|
||||
}
|
||||
|
||||
var dfs func(int, bool)
|
||||
dfs = func(x int, b bool) {
|
||||
if b {
|
||||
res = append(res, x)
|
||||
}
|
||||
for _, y := range g[x] {
|
||||
dfs(y, b || lo.Contains(ids, x))
|
||||
}
|
||||
}
|
||||
dfs(0, false)
|
||||
|
||||
res = lo.Uniq(append(res, ids...))
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// HandleSelfParent gets IDs of nodes that are parents of the specified node IDs
|
||||
func HandleSelfParent(ctx context.Context, ids ...int) (res []int, err error) {
|
||||
nodes, err := GetAllFromCacheDb(ctx, model.DefaultNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
g := make(map[int][]int)
|
||||
for _, n := range nodes {
|
||||
g[n.ParentId] = append(g[n.ParentId], n.Id)
|
||||
}
|
||||
|
||||
t := make([]int, 0)
|
||||
var dfs func(int)
|
||||
dfs = func(x int) {
|
||||
t = append(t, x)
|
||||
if lo.Contains(ids, x) {
|
||||
res = append(res, t...)
|
||||
}
|
||||
for _, y := range g[x] {
|
||||
dfs(y)
|
||||
}
|
||||
t = t[:len(t)-1]
|
||||
}
|
||||
dfs(0)
|
||||
|
||||
res = lo.Uniq(append(res, ids...))
|
||||
|
||||
return res, nil
|
||||
}
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
"github.com/veops/oneterm/internal/service"
|
||||
"github.com/veops/oneterm/internal/tunneling"
|
||||
dbpkg "github.com/veops/oneterm/pkg/db"
|
||||
"github.com/veops/oneterm/pkg/logger"
|
||||
@@ -81,7 +80,7 @@ func UpdateConnectables(ids ...int) (err error) {
|
||||
func updateConnectable(asset *model.Asset, gateway *model.Gateway) (sid string, ok bool) {
|
||||
sid = uuid.New().String()
|
||||
ps := strings.Join(lo.Map(asset.Protocols, func(p string, _ int) string { return strings.Split(p, ":")[0] }), ",")
|
||||
ip, port, err := service.Proxy(true, sid, ps, asset, gateway)
|
||||
ip, port, err := tunneling.Proxy(true, sid, ps, asset, gateway)
|
||||
if err != nil {
|
||||
logger.L().Debug("connectable proxy failed", zap.String("protocol", ps), zap.Error(err))
|
||||
return
|
||||
@@ -110,3 +109,8 @@ func updateConnectable(asset *model.Asset, gateway *model.Gateway) (sid string,
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
// UpdateAssetConnectables is used by service/asset.go to update connectables
|
||||
func UpdateAssetConnectables(ids ...int) error {
|
||||
return UpdateConnectables(ids...)
|
||||
}
|
||||
|
||||
80
backend/internal/service/account.go
Normal file
80
backend/internal/service/account.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
"github.com/veops/oneterm/internal/repository"
|
||||
"github.com/veops/oneterm/pkg/utils"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AccountService handles account business logic
|
||||
type AccountService struct {
|
||||
repo repository.AccountRepository
|
||||
}
|
||||
|
||||
// NewAccountService creates a new account service
|
||||
func NewAccountService() *AccountService {
|
||||
return &AccountService{
|
||||
repo: repository.NewAccountRepository(),
|
||||
}
|
||||
}
|
||||
|
||||
// ValidatePublicKey validates the given public key
|
||||
func (s *AccountService) ValidatePublicKey(account *model.Account) error {
|
||||
if account.AccountType != model.AUTHMETHOD_PUBLICKEY {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if account.Phrase == "" {
|
||||
_, err = ssh.ParsePrivateKey([]byte(account.Pk))
|
||||
} else {
|
||||
_, err = ssh.ParsePrivateKeyWithPassphrase([]byte(account.Pk), []byte(account.Phrase))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// EncryptSensitiveData encrypts sensitive account data
|
||||
func (s *AccountService) EncryptSensitiveData(account *model.Account) {
|
||||
account.Password = utils.EncryptAES(account.Password)
|
||||
account.Pk = utils.EncryptAES(account.Pk)
|
||||
account.Phrase = utils.EncryptAES(account.Phrase)
|
||||
}
|
||||
|
||||
// DecryptSensitiveData decrypts sensitive account data
|
||||
func (s *AccountService) DecryptSensitiveData(accounts []*model.Account) {
|
||||
for _, a := range accounts {
|
||||
a.Password = utils.DecryptAES(a.Password)
|
||||
a.Pk = utils.DecryptAES(a.Pk)
|
||||
a.Phrase = utils.DecryptAES(a.Phrase)
|
||||
}
|
||||
}
|
||||
|
||||
// AttachAssetCount attaches asset count to accounts
|
||||
func (s *AccountService) AttachAssetCount(ctx context.Context, accounts []*model.Account) error {
|
||||
return s.repo.AttachAssetCount(ctx, accounts)
|
||||
}
|
||||
|
||||
// CheckAssetDependencies checks if account has dependent assets
|
||||
func (s *AccountService) CheckAssetDependencies(ctx context.Context, id int) (string, error) {
|
||||
return s.repo.CheckAssetDependencies(ctx, id)
|
||||
}
|
||||
|
||||
// BuildQuery constructs account query with basic filters
|
||||
func (s *AccountService) BuildQuery(ctx *gin.Context) *gorm.DB {
|
||||
return s.repo.BuildQuery(ctx)
|
||||
}
|
||||
|
||||
// FilterByAssetIds filters accounts by related asset IDs
|
||||
func (s *AccountService) FilterByAssetIds(db *gorm.DB, assetIds []int) *gorm.DB {
|
||||
return s.repo.FilterByAssetIds(db, assetIds)
|
||||
}
|
||||
|
||||
// GetAccountIdsByAuthorization gets account IDs by authorization
|
||||
func (s *AccountService) GetAccountIdsByAuthorization(ctx context.Context, assetIds []int, authorizationIds []int) ([]int, error) {
|
||||
return s.repo.GetAccountIdsByAuthorization(ctx, assetIds, authorizationIds)
|
||||
}
|
||||
74
backend/internal/service/asset.go
Normal file
74
backend/internal/service/asset.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
"github.com/veops/oneterm/internal/repository"
|
||||
"github.com/veops/oneterm/internal/schedule"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AssetService handles asset business logic
|
||||
type AssetService struct {
|
||||
repo repository.AssetRepository
|
||||
}
|
||||
|
||||
// NewAssetService creates a new asset service
|
||||
func NewAssetService() *AssetService {
|
||||
return &AssetService{
|
||||
repo: repository.NewAssetRepository(),
|
||||
}
|
||||
}
|
||||
|
||||
// PreprocessAssetData preprocesses asset data before saving
|
||||
func (s *AssetService) PreprocessAssetData(asset *model.Asset) {
|
||||
asset.Ip = strings.TrimSpace(asset.Ip)
|
||||
asset.Protocols = lo.Map(asset.Protocols, func(s string, _ int) string { return strings.TrimSpace(s) })
|
||||
if asset.Authorization == nil {
|
||||
asset.Authorization = make(model.Map[int, model.Slice[int]])
|
||||
}
|
||||
}
|
||||
|
||||
// AttachNodeChain attaches node chain to assets
|
||||
func (s *AssetService) AttachNodeChain(ctx context.Context, assets []*model.Asset) error {
|
||||
return s.repo.AttachNodeChain(ctx, assets)
|
||||
}
|
||||
|
||||
// ApplyAuthorizationFilters applies authorization filters to assets
|
||||
func (s *AssetService) ApplyAuthorizationFilters(ctx *gin.Context, assets []*model.Asset, authorizationIds []*model.AuthorizationIds, nodeIds, accountIds []int) {
|
||||
s.repo.ApplyAuthorizationFilters(ctx, assets, authorizationIds, nodeIds, accountIds)
|
||||
}
|
||||
|
||||
// BuildQuery constructs asset query with basic filters
|
||||
func (s *AssetService) BuildQuery(ctx *gin.Context) (*gorm.DB, error) {
|
||||
return s.repo.BuildQuery(ctx)
|
||||
}
|
||||
|
||||
// FilterByParentId filters assets by parent ID
|
||||
func (s *AssetService) FilterByParentId(db *gorm.DB, parentId int) (*gorm.DB, error) {
|
||||
return s.repo.FilterByParentId(db, parentId)
|
||||
}
|
||||
|
||||
// GetAssetIdsByAuthorization gets asset IDs by authorization
|
||||
func (s *AssetService) GetAssetIdsByAuthorization(ctx *gin.Context, authorizationIds []*model.AuthorizationIds) ([]int, []int, []int, error) {
|
||||
return s.repo.GetAssetIdsByAuthorization(ctx, authorizationIds)
|
||||
}
|
||||
|
||||
// GetIdsByAuthorizationIds extracts node IDs, asset IDs, and account IDs from authorization IDs
|
||||
func (s *AssetService) GetIdsByAuthorizationIds(ctx *gin.Context, authorizationIds []*model.AuthorizationIds) ([]int, []int, []int) {
|
||||
return s.repo.GetIdsByAuthorizationIds(ctx, authorizationIds)
|
||||
}
|
||||
|
||||
// GetAssetIdsByNodeAccount gets asset IDs by node IDs and account IDs
|
||||
func (s *AssetService) GetAssetIdsByNodeAccount(ctx context.Context, nodeIds, accountIds []int) ([]int, error) {
|
||||
return s.repo.GetAssetIdsByNodeAccount(ctx, nodeIds, accountIds)
|
||||
}
|
||||
|
||||
// UpdateConnectables updates asset connectability status
|
||||
func (s *AssetService) UpdateConnectables(ids ...int) error {
|
||||
return schedule.UpdateAssetConnectables(ids...)
|
||||
}
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/sftp"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/veops/oneterm/internal/tunneling"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -73,7 +75,7 @@ func (fm *FileManager) GetFileClient(assetId, accountId int) (cli *sftp.Client,
|
||||
return
|
||||
}
|
||||
|
||||
ip, port, err := Proxy(false, uuid.New().String(), "sftp,ssh", asset, gateway)
|
||||
ip, port, err := tunneling.Proxy(false, uuid.New().String(), "sftp,ssh", asset, gateway)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,13 +2,10 @@ package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
"github.com/veops/oneterm/internal/tunneling"
|
||||
dbpkg "github.com/veops/oneterm/pkg/db"
|
||||
"github.com/veops/oneterm/pkg/utils"
|
||||
)
|
||||
@@ -58,27 +55,3 @@ func GetAuth(account *model.Account) (ssh.AuthMethod, error) {
|
||||
return nil, fmt.Errorf("invalid authmethod %d", account.AccountType)
|
||||
}
|
||||
}
|
||||
|
||||
func Proxy(isConnectable bool, sessionId string, protocol string, asset *model.Asset, gateway *model.Gateway) (ip string, port int, err error) {
|
||||
ip, port = asset.Ip, 0
|
||||
for _, tp := range strings.Split(protocol, ",") {
|
||||
for _, p := range asset.Protocols {
|
||||
if strings.HasPrefix(strings.ToLower(p), tp) {
|
||||
if port = cast.ToInt(strings.Split(p, ":")[1]); port != 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if asset.GatewayId == 0 || gateway == nil {
|
||||
return
|
||||
}
|
||||
|
||||
g, err := tunneling.OpenTunnel(isConnectable, sessionId, ip, port, gateway)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ip, port = g.LocalIp, g.LocalPort
|
||||
return
|
||||
}
|
||||
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
@@ -231,3 +233,28 @@ func getAvailablePort() (int, error) {
|
||||
|
||||
return l.Addr().(*net.TCPAddr).Port, nil
|
||||
}
|
||||
|
||||
// Proxy establishes a proxy connection to an asset through a gateway if necessary
|
||||
func Proxy(isConnectable bool, sessionId string, protocol string, asset *model.Asset, gateway *model.Gateway) (ip string, port int, err error) {
|
||||
ip, port = asset.Ip, 0
|
||||
for _, tp := range strings.Split(protocol, ",") {
|
||||
for _, p := range asset.Protocols {
|
||||
if strings.HasPrefix(strings.ToLower(p), tp) {
|
||||
if port = cast.ToInt(strings.Split(p, ":")[1]); port != 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if asset.GatewayId == 0 || gateway == nil {
|
||||
return
|
||||
}
|
||||
|
||||
g, err := OpenTunnel(isConnectable, sessionId, ip, port, gateway)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ip, port = g.LocalIp, g.LocalPort
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user