refactor(backend): asset and account

This commit is contained in:
pycook
2025-05-05 11:15:42 +08:00
parent bf3bd1dc40
commit b12040b3c4
15 changed files with 825 additions and 341 deletions

View File

@@ -1,84 +1,53 @@
package controller
import (
"errors"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"github.com/spf13/cast"
"golang.org/x/crypto/ssh"
"gorm.io/gorm"
"github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/internal/service"
"github.com/veops/oneterm/pkg/config"
dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/pkg/utils"
)
var (
accountService = service.NewAccountService()
accountPreHooks = []preHook[*model.Account]{
// Validate public key
func(ctx *gin.Context, data *model.Account) {
if data.AccountType == model.AUTHMETHOD_PUBLICKEY {
if data.Phrase == "" {
_, err := ssh.ParsePrivateKey([]byte(data.Pk))
if err != nil {
if err := accountService.ValidatePublicKey(data); err != nil {
ctx.AbortWithError(http.StatusBadRequest, &ApiError{Code: ErrWrongPvk, Data: nil})
return
}
} else {
_, err := ssh.ParsePrivateKeyWithPassphrase([]byte(data.Pk), []byte(data.Phrase))
if err != nil {
ctx.AbortWithError(http.StatusBadRequest, &ApiError{Code: ErrWrongPvk, Data: nil})
return
}
}
}
},
// Encrypt sensitive data
func(ctx *gin.Context, data *model.Account) {
data.Password = utils.EncryptAES(data.Password)
data.Pk = utils.EncryptAES(data.Pk)
data.Phrase = utils.EncryptAES(data.Phrase)
accountService.EncryptSensitiveData(data)
},
}
accountPostHooks = []postHook[*model.Account]{
// Attach asset count
func(ctx *gin.Context, data []*model.Account) {
acs := make([]*model.AccountCount, 0)
if err := dbpkg.DB.
Model(&model.Authorization{}).
Select("account_id AS id, COUNT(*) as count").
Group("account_id").
Where("account_id IN ?", lo.Map(data, func(d *model.Account, _ int) int { return d.Id })).
Find(&acs).
Error; err != nil {
if err := accountService.AttachAssetCount(ctx, data); err != nil {
return
}
m := lo.SliceToMap(acs, func(ac *model.AccountCount) (int, int64) { return ac.Id, ac.Count })
for _, d := range data {
d.AssetCount = m[d.Id]
}
},
// Decrypt sensitive data
func(ctx *gin.Context, data []*model.Account) {
for _, d := range data {
d.Password = utils.DecryptAES(d.Password)
d.Pk = utils.DecryptAES(d.Pk)
d.Phrase = utils.DecryptAES(d.Phrase)
}
accountService.DecryptSensitiveData(data)
},
}
accountDcs = []deleteCheck{
// Check dependencies
func(ctx *gin.Context, id int) {
assetName := ""
err := dbpkg.DB.
Model(model.DefaultAsset).
Select("name").
Where("id = (?)", dbpkg.DB.Model(&model.Authorization{}).Select("asset_id").Where("account_id = ?", id).Limit(1)).
First(&assetName).
Error
if errors.Is(err, gorm.ErrRecordNotFound) {
assetName, err := accountService.CheckAssetDependencies(ctx, id)
if err == nil && assetName == "" {
return
}
code := lo.Ternary(err == nil, http.StatusBadRequest, http.StatusInternalServerError)
@@ -136,45 +105,37 @@ func (c *Controller) GetAccounts(ctx *gin.Context) {
currentUser, _ := acl.GetSessionFromCtx(ctx)
info := cast.ToBool(ctx.Query("info"))
db := dbpkg.DB.Model(&model.Account{})
db = filterEqual(ctx, db, "id", "type")
db = filterLike(ctx, db, "name")
db = filterSearch(ctx, db, "name", "account")
if q, ok := ctx.GetQuery("ids"); ok {
db = db.Where("id IN ?", lo.Map(strings.Split(q, ","), func(s string, _ int) int { return cast.ToInt(s) }))
}
// Build base query using service layer
db := accountService.BuildQuery(ctx)
// Apply select fields for info mode
if info {
db = db.Select("id", "name", "account")
// Apply authorization filter if needed
if !acl.IsAdmin(currentUser) {
ids, err := GetAccountIdsByAuthorization(ctx)
if err != nil {
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
return
}
db = db.Where("id IN ?", ids)
}
}
db = db.Order("name")
doGet(ctx, !info, db, config.RESOURCE_ACCOUNT, accountPostHooks...)
}
func GetAccountIdsByAuthorization(ctx *gin.Context) (ids []int, err error) {
assetIds, err := GetAssetIdsByAuthorization(ctx)
if err != nil {
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
return
}
ss := make([]model.Slice[string], 0)
if err = dbpkg.DB.Model(model.DefaultAsset).Where("id IN ?", assetIds).Pluck("JSON_KEYS(authorization)", &ss).Error; err != nil {
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
return
}
ids = lo.Uniq(lo.Map(lo.Flatten(ss), func(s string, _ int) int { return cast.ToInt(s) }))
_, _, accountIds := getIdsByAuthorizationIds(ctx)
ids = lo.Uniq(append(ids, accountIds...))
return
// Filter accounts by asset IDs
db = accountService.FilterByAssetIds(db, assetIds)
}
}
doGet(ctx, !info, db, config.RESOURCE_ACCOUNT, accountPostHooks...)
}
// GetAccountIdsByAuthorization gets account IDs by authorization
func GetAccountIdsByAuthorization(ctx *gin.Context) ([]int, error) {
assetIds, err := GetAssetIdsByAuthorization(ctx)
if err != nil {
return nil, err
}
_, _, authorizationIds := getIdsByAuthorizationIds(ctx)
return accountService.GetAccountIdsByAuthorization(ctx, assetIds, authorizationIds)
}

View File

@@ -1,22 +1,16 @@
package controller
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"github.com/spf13/cast"
"go.uber.org/zap"
"github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/internal/repository"
"github.com/veops/oneterm/internal/schedule"
"github.com/veops/oneterm/internal/service"
"github.com/veops/oneterm/pkg/config"
dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/pkg/logger"
)
@@ -28,16 +22,36 @@ const (
)
var (
assetService = service.NewAssetService()
assetPreHooks = []preHook[*model.Asset]{
// Preprocess asset data
func(ctx *gin.Context, data *model.Asset) {
data.Ip = strings.TrimSpace(data.Ip)
data.Protocols = lo.Map(data.Protocols, func(s string, _ int) string { return strings.TrimSpace(s) })
if data.Authorization == nil {
data.Authorization = make(model.Map[int, model.Slice[int]])
}
assetService.PreprocessAssetData(data)
},
}
assetPostHooks = []postHook[*model.Asset]{
// Attach node chain
func(ctx *gin.Context, data []*model.Asset) {
if err := assetService.AttachNodeChain(ctx, data); err != nil {
logger.L().Error("attach node chain failed", zap.Error(err))
return
}
},
// Apply authorization filters
func(ctx *gin.Context, data []*model.Asset) {
currentUser, _ := acl.GetSessionFromCtx(ctx)
if acl.IsAdmin(currentUser) {
return
}
authorizationIds, _ := ctx.Value(kAuthorizationIds).([]*model.AuthorizationIds)
nodeIds, _ := ctx.Value(kNodeIds).([]int)
accountIds, _ := ctx.Value(kAccountIds).([]int)
assetService.ApplyAuthorizationFilters(ctx, data, authorizationIds, nodeIds, accountIds)
},
}
assetPostHooks = []postHook[*model.Asset]{assetPostHookCount, assetPostHookAuth}
)
// CreateAsset godoc
@@ -50,7 +64,7 @@ func (c *Controller) CreateAsset(ctx *gin.Context) {
asset := &model.Asset{}
doCreate(ctx, true, asset, config.RESOURCE_ASSET, assetPreHooks...)
schedule.UpdateConnectables(asset.Id)
assetService.UpdateConnectables(asset.Id)
}
// DeleteAsset godoc
@@ -71,8 +85,8 @@ func (c *Controller) DeleteAsset(ctx *gin.Context) {
// @Success 200 {object} HttpResponse
// @Router /asset/:id [put]
func (c *Controller) UpdateAsset(ctx *gin.Context) {
doUpdate(ctx, true, &model.Asset{}, config.RESOURCE_ASSET)
schedule.UpdateConnectables(cast.ToInt(ctx.Param("id")))
doUpdate(ctx, true, &model.Asset{}, config.RESOURCE_ASSET, assetPreHooks...)
assetService.UpdateConnectables(cast.ToInt(ctx.Param("id")))
}
// GetAssets godoc
@@ -93,25 +107,28 @@ func (c *Controller) GetAssets(ctx *gin.Context) {
currentUser, _ := acl.GetSessionFromCtx(ctx)
info := cast.ToBool(ctx.Query("info"))
db := dbpkg.DB.Model(model.DefaultAsset)
db = filterEqual(ctx, db, "id")
db = filterLike(ctx, db, "name", "ip")
db = filterSearch(ctx, db, "name", "ip")
if q, ok := ctx.GetQuery("ids"); ok {
db = db.Where("id IN ?", lo.Map(strings.Split(q, ","), func(s string, _ int) int { return cast.ToInt(s) }))
}
if q, ok := ctx.GetQuery("parent_id"); ok {
parentIds, err := handleParentId(ctx, cast.ToInt(q))
// Build base query using service layer
db, err := assetService.BuildQuery(ctx)
if err != nil {
logger.L().Error("parent id found failed", zap.Error(err))
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
return
}
// Apply parent_id filter if needed
if q, ok := ctx.GetQuery("parent_id"); ok {
db, err = assetService.FilterByParentId(db, cast.ToInt(q))
if err != nil {
logger.L().Error("parent id filtering failed", zap.Error(err))
ctx.AbortWithError(http.StatusInternalServerError, &ApiError{Code: ErrInternal, Data: map[string]any{"err": err}})
return
}
db = db.Where("parent_id IN ?", parentIds)
}
// Apply info mode settings
if info {
db = db.Select("id", "parent_id", "name", "ip", "protocols", "connectable", "authorization")
// Apply authorization filter if needed
if !acl.IsAdmin(currentUser) {
ids, err := GetAssetIdsByAuthorization(ctx)
if err != nil {
@@ -122,149 +139,16 @@ func (c *Controller) GetAssets(ctx *gin.Context) {
}
}
db = db.Order("name")
doGet(ctx, !info, db, config.RESOURCE_ASSET, assetPostHooks...)
}
func assetPostHookCount(ctx *gin.Context, data []*model.Asset) {
nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil {
return
}
g := make(map[int][]model.Pair[int, string])
for _, n := range nodes {
g[n.ParentId] = append(g[n.ParentId], model.Pair[int, string]{First: n.Id, Second: n.Name})
}
m := make(map[int]string)
var dfs func(int, string)
dfs = func(x int, s string) {
m[x] = s
for _, node := range g[x] {
dfs(node.First, fmt.Sprintf("%s/%s", s, node.Second))
}
}
dfs(0, "")
for _, d := range data {
d.NodeChain = m[d.ParentId]
}
}
func assetPostHookAuth(ctx *gin.Context, data []*model.Asset) {
currentUser, _ := acl.GetSessionFromCtx(ctx)
if acl.IsAdmin(currentUser) {
return
}
info := cast.ToBool(ctx.Query("info"))
noInfoIds := make([]int, 0)
if !info {
t := dbpkg.DB.Model(model.DefaultAsset)
assetResIds, _ := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_ASSET)
t, _ = handleAssetIds(ctx, t, assetResIds)
t.Pluck("id", &noInfoIds)
}
authorizationIds, _ := ctx.Value(kAuthorizationIds).([]*model.AuthorizationIds)
nodeIds, _, accountIds := getIdsByAuthorizationIds(ctx)
nodeIds, _ = handleSelfChild(ctx, nodeIds...)
for _, a := range data {
if lo.Contains(nodeIds, a.ParentId) || lo.Contains(noInfoIds, a.Id) {
continue
}
if lo.ContainsBy(authorizationIds, func(item *model.AuthorizationIds) bool {
return item.AssetId == a.Id && item.NodeId == 0 && item.AccountId == 0
}) {
continue
}
ids := lo.Map(lo.Filter(authorizationIds, func(item *model.AuthorizationIds, _ int) bool {
return item.AssetId == a.Id && item.AccountId != 0 && item.NodeId == 0
}),
func(item *model.AuthorizationIds, _ int) int { return item.AccountId })
for k := range a.Authorization {
if !lo.Contains(ids, k) && !lo.Contains(accountIds, k) {
delete(a.Authorization, k)
}
}
}
}
func handleParentId(ctx context.Context, parentId int) (pids []int, err error) {
nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil {
return
}
g := make(map[int][]int)
for _, n := range nodes {
g[n.ParentId] = append(g[n.ParentId], n.Id)
}
var dfs func(int)
dfs = func(x int) {
pids = append(pids, x)
for _, y := range g[x] {
dfs(y)
}
}
dfs(parentId)
return
}
func GetAssetIdsByAuthorization(ctx *gin.Context) (ids []int, err error) {
// GetAssetIdsByAuthorization gets asset IDs by authorization
func GetAssetIdsByAuthorization(ctx *gin.Context) ([]int, error) {
authIds, err := getAuthorizationIds(ctx)
if err != nil {
return
return nil, err
}
ctx.Set(kAuthorizationIds, authIds)
nodeIds, ids, accountIds := getIdsByAuthorizationIds(ctx)
tmp, err := handleSelfChild(ctx, nodeIds...)
if err != nil {
return
}
nodeIds = append(nodeIds, tmp...)
ctx.Set(kNodeIds, nodeIds)
ctx.Set(kAccountIds, accountIds)
tmp, err = getAssetIdsByNodeAccount(ctx, nodeIds, accountIds)
if err != nil {
return
}
ids = lo.Uniq(append(ids, tmp...))
return
}
func getIdsByAuthorizationIds(ctx *gin.Context) (nodeIds, assetIds, accountIds []int) {
authIds, _ := ctx.Value(kAuthorizationIds).([]*model.AuthorizationIds)
info := cast.ToBool(ctx.Query("info"))
for _, a := range authIds {
if a.NodeId != 0 && a.AssetId == 0 && a.AccountId == 0 {
nodeIds = append(nodeIds, a.NodeId)
}
if a.AssetId != 0 && a.NodeId == 0 && (info || a.AccountId == 0) {
assetIds = append(assetIds, a.AssetId)
}
if a.AccountId != 0 && a.AssetId == 0 && (info || a.NodeId == 0) {
accountIds = append(accountIds, a.AccountId)
}
}
return
}
func getAssetIdsByNodeAccount(ctx context.Context, nodeIds, accountIds []int) (assetIds []int, err error) {
assets, err := repository.GetAllFromCacheDb(ctx, model.DefaultAsset)
if err != nil {
return
}
assets = lo.Filter(assets, func(a *model.Asset, _ int) bool {
return lo.Contains(nodeIds, a.ParentId) || len(lo.Intersect(lo.Keys(a.Authorization), accountIds)) > 0
})
assetIds = lo.Map(assets, func(a *model.Asset, _ int) int { return a.Id })
return
_, assetIds, _, err := assetService.GetAssetIdsByAuthorization(ctx, authIds)
return assetIds, err
}

View File

@@ -159,7 +159,7 @@ func getNodeAssetAccoutIdsByAction(ctx context.Context, action string) (nodeIds,
}
nodes = lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(resIds, n.ResourceId) })
nodeIds = lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id })
nodeIds, err = handleSelfChild(ctx, nodeIds...)
nodeIds, err = repository.HandleSelfChild(ctx, nodeIds...)
return
})
@@ -403,7 +403,7 @@ func hasAuthorization(ctx *gin.Context, sess *gsession.Session) (ok bool) {
ctx.Set(kAuthorizationIds, authIds)
nodeIds, assetIds, accountIds := getIdsByAuthorizationIds(ctx)
tmp, err := handleSelfChild(ctx, nodeIds...)
tmp, err := repository.HandleSelfChild(ctx, nodeIds...)
if err != nil {
logger.L().Error("", zap.Error(err))
return
@@ -421,3 +421,15 @@ func hasAuthorization(ctx *gin.Context, sess *gsession.Session) (ok bool) {
return lo.Contains(ids, sess.AssetId)
}
func getIdsByAuthorizationIds(ctx *gin.Context) (nodeIds, assetIds, accountIds []int) {
authorizationIds, ok := ctx.Value(kAuthorizationIds).([]*model.AuthorizationIds)
if !ok || len(authorizationIds) == 0 {
return
}
return assetService.GetIdsByAuthorizationIds(ctx, authorizationIds)
}
func getAssetIdsByNodeAccount(ctx context.Context, nodeIds, accountIds []int) ([]int, error) {
return assetService.GetAssetIdsByNodeAccount(ctx, nodeIds, accountIds)
}

View File

@@ -35,6 +35,7 @@ import (
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/internal/service"
gsession "github.com/veops/oneterm/internal/session"
"github.com/veops/oneterm/internal/tunneling"
dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/pkg/logger"
)
@@ -389,7 +390,7 @@ func connectSsh(ctx *gin.Context, sess *gsession.Session, asset *model.Asset, ac
}
}()
ip, port, err := service.Proxy(false, sess.SessionId, "ssh", asset, gateway)
ip, port, err := tunneling.Proxy(false, sess.SessionId, "ssh", asset, gateway)
if err != nil {
return
}
@@ -554,7 +555,7 @@ func connectOther(ctx *gin.Context, sess *gsession.Session, asset *model.Asset,
}()
protocol := strings.Split(sess.Protocol, ":")[0]
ip, port, err := service.Proxy(false, sess.SessionId, protocol, asset, gateway)
ip, port, err := tunneling.Proxy(false, sess.SessionId, protocol, asset, gateway)
if err != nil {
return
}

View File

@@ -501,9 +501,9 @@ func hasPerm[T model.Model](ctx context.Context, md T, resourceTypeName, action
pids := make([]int, 0)
switch t := any(md).(type) {
case *model.Asset:
pids, _ = handleSelfParent(ctx, t.ParentId)
pids, _ = repository.HandleSelfParent(ctx, t.ParentId)
case *model.Node:
pids, _ = handleSelfParent(ctx, t.Id)
pids, _ = repository.HandleSelfParent(ctx, t.Id)
}
if len(pids) > 0 {
@@ -619,7 +619,7 @@ func handleNodeIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.DB
}
nodes = lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(resIds, n.ResourceId) })
ids := lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id })
if ids, err = handleSelfChild(ctx, ids...); err != nil {
if ids, err = repository.HandleSelfChild(ctx, ids...); err != nil {
return
}
@@ -633,7 +633,7 @@ func handleNodeIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.DB
}
ids = append(ids, lo.Map(assets, func(a *model.AssetIdPid, _ int) int { return a.ParentId })...)
ids, err = handleSelfParent(ctx, ids...)
ids, err = repository.HandleSelfParent(ctx, ids...)
if err != nil {
return
}
@@ -656,7 +656,7 @@ func handleAssetIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.D
}
nodes = lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(nodeResIds, n.ResourceId) })
nodeIds := lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id })
if nodeIds, err = handleSelfChild(ctx, nodeIds...); err != nil {
if nodeIds, err = repository.HandleSelfChild(ctx, nodeIds...); err != nil {
return
}

View File

@@ -102,7 +102,7 @@ func (c *Controller) GetNodes(ctx *gin.Context) {
}
if id, ok := ctx.GetQuery("self_parent"); ok {
ids, err := handleSelfParent(ctx, cast.ToInt(id))
ids, err := repository.HandleSelfParent(ctx, cast.ToInt(id))
if err != nil {
return
}
@@ -116,7 +116,7 @@ func (c *Controller) GetNodes(ctx *gin.Context) {
if err != nil {
return
}
if ids, err = handleSelfParent(ctx, ids...); err != nil {
if ids, err = repository.HandleSelfParent(ctx, ids...); err != nil {
return
}
db = db.Where("id IN ?", ids)
@@ -216,7 +216,7 @@ func nodePostHookHasChild(ctx *gin.Context, data []*model.Node) {
assetIds, _ := GetAssetIdsByAuthorization(ctx)
assets = lo.Filter(assets, func(a *model.Asset, _ int) bool { return lo.Contains(assetIds, a.Id) })
pids := lo.Map(assets, func(a *model.Asset, _ int) int { return a.ParentId })
pids, _ = handleSelfParent(ctx, pids...)
pids, _ = repository.HandleSelfParent(ctx, pids...)
nodes = lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(pids, n.Id) })
ps := lo.SliceToMap(nodes, func(a *model.Node) (int, bool) { return a.ParentId, true })
for _, n := range data {
@@ -229,14 +229,14 @@ func nodePostHookHasChild(ctx *gin.Context, data []*model.Node) {
assetResIds, err = acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_ASSET)
assets = lo.Filter(assets, func(a *model.Asset, _ int) bool { return lo.Contains(assetResIds, a.ResourceId) })
pids = lo.Map(assets, func(n *model.Asset, _ int) int { return n.ParentId })
pids, _ = handleSelfParent(ctx, pids...)
pids, _ = repository.HandleSelfParent(ctx, pids...)
return
})
eg.Go(func() (err error) {
nodeResIds, err = acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_NODE)
ns := lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(nodeResIds, n.ResourceId) })
nids, _ = handleSelfChild(ctx, lo.Map(ns, func(n *model.Node, _ int) int { return n.Id })...)
nids, _ = handleSelfParent(ctx, nids...)
nids, _ = repository.HandleSelfChild(ctx, lo.Map(ns, func(n *model.Node, _ int) int { return n.Id })...)
nids, _ = repository.HandleSelfParent(ctx, nids...)
return
})
eg.Wait()
@@ -261,61 +261,6 @@ func nodeDelHook(ctx *gin.Context, id int) {
ctx.AbortWithError(http.StatusBadRequest, err)
}
func handleSelfChild(ctx context.Context, ids ...int) (res []int, err error) {
nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil {
return
}
g := make(map[int][]int)
for _, n := range nodes {
g[n.ParentId] = append(g[n.ParentId], n.Id)
}
var dfs func(int, bool)
dfs = func(x int, b bool) {
if b {
res = append(res, x)
}
for _, y := range g[x] {
dfs(y, b || lo.Contains(ids, x))
}
}
dfs(0, false)
res = lo.Uniq(append(res, ids...))
return
}
func handleSelfParent(ctx context.Context, ids ...int) (res []int, err error) {
nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil {
return
}
g := make(map[int][]int)
for _, n := range nodes {
g[n.ParentId] = append(g[n.ParentId], n.Id)
}
t := make([]int, 0)
var dfs func(int)
dfs = func(x int) {
t = append(t, x)
if lo.Contains(ids, x) {
res = append(res, t...)
}
for _, y := range g[x] {
dfs(y)
}
t = t[:len(t)-1]
}
dfs(0)
res = lo.Uniq(append(res, ids...))
return
}
func handleNoSelfChild(ctx context.Context, ids ...int) (res []int, err error) {
nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil {
@@ -323,7 +268,7 @@ func handleNoSelfChild(ctx context.Context, ids ...int) (res []int, err error) {
}
allids := lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id })
res, err = handleSelfChild(ctx, ids...)
res, err = repository.HandleSelfChild(ctx, ids...)
if err != nil {
return
}
@@ -339,7 +284,7 @@ func handleNoSelfParent(ctx context.Context, ids ...int) (res []int, err error)
}
allids := lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id })
res, err = handleSelfParent(ctx, ids...)
res, err = repository.HandleSelfParent(ctx, ids...)
if err != nil {
return
}

View File

@@ -0,0 +1,165 @@
package repository
import (
"context"
"errors"
"fmt"
"strings"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"github.com/spf13/cast"
"github.com/veops/oneterm/internal/model"
dbpkg "github.com/veops/oneterm/pkg/db"
"gorm.io/gorm"
)
// AccountRepository interface for account data access
type AccountRepository interface {
AttachAssetCount(ctx context.Context, accounts []*model.Account) error
CheckAssetDependencies(ctx context.Context, id int) (string, error)
BuildQuery(ctx *gin.Context) *gorm.DB
FilterByAssetIds(db *gorm.DB, assetIds []int) *gorm.DB
GetAccountIdsByAuthorization(ctx context.Context, assetIds []int, authorizationIds []int) ([]int, error)
}
// accountRepository implements AccountRepository
type accountRepository struct{}
// NewAccountRepository creates a new account repository
func NewAccountRepository() AccountRepository {
return &accountRepository{}
}
// BuildQuery builds the base query for accounts with filters
func (r *accountRepository) BuildQuery(ctx *gin.Context) *gorm.DB {
db := dbpkg.DB.Model(&model.Account{})
// Apply filters
db = r.filterEqual(ctx, db, "id", "type")
db = r.filterLike(ctx, db, "name")
db = r.filterSearch(ctx, db, "name", "account")
// Handle IDs parameter
if q, ok := ctx.GetQuery("ids"); ok {
db = db.Where("id IN ?", lo.Map(strings.Split(q, ","),
func(s string, _ int) int { return cast.ToInt(s) }))
}
// Sort by name
db = db.Order("name")
return db
}
// FilterByAssetIds filters accounts by related asset IDs
func (r *accountRepository) FilterByAssetIds(db *gorm.DB, assetIds []int) *gorm.DB {
if len(assetIds) == 0 {
return db.Where("0 = 1") // Return empty result if no asset IDs
}
// 查询与指定资产关联的账户ID
subQuery := dbpkg.DB.Model(&model.Authorization{}).
Select("account_id").
Where("asset_id IN ?", assetIds).
Group("account_id")
return db.Where("id IN (?)", subQuery)
}
// AttachAssetCount attaches asset count to accounts
func (r *accountRepository) AttachAssetCount(ctx context.Context, accounts []*model.Account) error {
acs := make([]*model.AccountCount, 0)
if err := dbpkg.DB.
Model(&model.Authorization{}).
Select("account_id AS id, COUNT(*) as count").
Group("account_id").
Where("account_id IN ?", lo.Map(accounts, func(d *model.Account, _ int) int { return d.Id })).
Find(&acs).
Error; err != nil {
return err
}
m := lo.SliceToMap(acs, func(ac *model.AccountCount) (int, int64) { return ac.Id, ac.Count })
for _, d := range accounts {
d.AssetCount = m[d.Id]
}
return nil
}
// CheckAssetDependencies checks if account has dependent assets
func (r *accountRepository) CheckAssetDependencies(ctx context.Context, id int) (string, error) {
var assetName string
err := dbpkg.DB.
Model(model.DefaultAsset).
Select("name").
Where("id = (?)", dbpkg.DB.Model(&model.Authorization{}).Select("asset_id").Where("account_id = ?", id).Limit(1)).
First(&assetName).
Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", nil
}
if err != nil {
return "", err
}
return assetName, errors.New("account has dependent assets")
}
// GetAccountIdsByAuthorization gets account IDs by authorization and asset IDs
func (r *accountRepository) GetAccountIdsByAuthorization(ctx context.Context, assetIds []int, authorizationIds []int) ([]int, error) {
// 从资产的授权列表中获取账户ID
ss := make([]model.Slice[string], 0)
if err := dbpkg.DB.Model(model.DefaultAsset).Where("id IN ?", assetIds).Pluck("JSON_KEYS(authorization)", &ss).Error; err != nil {
return nil, err
}
// 处理从资产中获取的账户IDs
accountIds := lo.Uniq(lo.Map(lo.Flatten(ss), func(s string, _ int) int { return cast.ToInt(s) }))
// 合并从授权中获取的账户IDs
return lo.Uniq(append(accountIds, authorizationIds...)), nil
}
// Filter helpers
func (r *accountRepository) filterEqual(ctx *gin.Context, db *gorm.DB, fields ...string) *gorm.DB {
for _, f := range fields {
if q, ok := ctx.GetQuery(f); ok {
db = db.Where(fmt.Sprintf("%s = ?", f), q)
}
}
return db
}
func (r *accountRepository) filterLike(ctx *gin.Context, db *gorm.DB, fields ...string) *gorm.DB {
likes := false
d := dbpkg.DB
for _, f := range fields {
if q, ok := ctx.GetQuery(f); ok && q != "" {
d = d.Or(fmt.Sprintf("%s LIKE ?", f), fmt.Sprintf("%%%s%%", q))
likes = true
}
}
if !likes {
return db
}
db = db.Where(d)
return db
}
func (r *accountRepository) filterSearch(ctx *gin.Context, db *gorm.DB, fields ...string) *gorm.DB {
q, ok := ctx.GetQuery("search")
if !ok || len(fields) <= 0 {
return db
}
d := dbpkg.DB
for _, f := range fields {
d = d.Or(fmt.Sprintf("%s LIKE ?", f), fmt.Sprintf("%%%s%%", q))
}
db = db.Where(d)
return db
}

View File

@@ -0,0 +1,289 @@
package repository
import (
"context"
"fmt"
"strings"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"github.com/spf13/cast"
"github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/pkg/config"
dbpkg "github.com/veops/oneterm/pkg/db"
"gorm.io/gorm"
)
const (
kFmtAssetIds = "assetIds-%d"
kAuthorizationIds = "authorizationIds"
kNodeIds = "nodeIds"
kAccountIds = "accountIds"
)
// AssetRepository interface for asset data access
type AssetRepository interface {
AttachNodeChain(ctx context.Context, assets []*model.Asset) error
ApplyAuthorizationFilters(ctx *gin.Context, assets []*model.Asset, authorizationIds []*model.AuthorizationIds, nodeIds, accountIds []int)
BuildQuery(ctx *gin.Context) (*gorm.DB, error)
FilterByParentId(db *gorm.DB, parentId int) (*gorm.DB, error)
GetAssetIdsByAuthorization(ctx *gin.Context, authorizationIds []*model.AuthorizationIds) ([]int, []int, []int, error)
GetIdsByAuthorizationIds(ctx *gin.Context, authorizationIds []*model.AuthorizationIds) ([]int, []int, []int)
GetAssetIdsByNodeAccount(ctx context.Context, nodeIds, accountIds []int) ([]int, error)
}
// assetRepository implements AssetRepository
type assetRepository struct{}
// NewAssetRepository creates a new asset repository
func NewAssetRepository() AssetRepository {
return &assetRepository{}
}
// BuildQuery builds the base query for assets with filters
func (r *assetRepository) BuildQuery(ctx *gin.Context) (*gorm.DB, error) {
db := dbpkg.DB.Model(model.DefaultAsset)
// Apply filters
db = r.filterEqual(ctx, db, "id")
db = r.filterLike(ctx, db, "name", "ip")
db = r.filterSearch(ctx, db, "name", "ip")
// Handle IDs parameter
if q, ok := ctx.GetQuery("ids"); ok {
db = db.Where("id IN ?", lo.Map(strings.Split(q, ","),
func(s string, _ int) int { return cast.ToInt(s) }))
}
// Sort by name
db = db.Order("name")
return db, nil
}
// FilterByParentId filters assets by parent ID and its children
func (r *assetRepository) FilterByParentId(db *gorm.DB, parentId int) (*gorm.DB, error) {
parentIds, err := r.handleParentId(context.Background(), parentId)
if err != nil {
return nil, err
}
return db.Where("parent_id IN ?", parentIds), nil
}
// AttachNodeChain attaches node chain to assets
func (r *assetRepository) AttachNodeChain(ctx context.Context, assets []*model.Asset) error {
nodes, err := GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil {
return err
}
g := make(map[int][]model.Pair[int, string])
for _, n := range nodes {
g[n.ParentId] = append(g[n.ParentId], model.Pair[int, string]{First: n.Id, Second: n.Name})
}
m := make(map[int]string)
var dfs func(int, string)
dfs = func(x int, s string) {
m[x] = s
for _, node := range g[x] {
dfs(node.First, fmt.Sprintf("%s/%s", s, node.Second))
}
}
dfs(0, "")
for _, d := range assets {
d.NodeChain = m[d.ParentId]
}
return nil
}
// ApplyAuthorizationFilters applies authorization filters to assets
func (r *assetRepository) ApplyAuthorizationFilters(ctx *gin.Context, assets []*model.Asset, authorizationIds []*model.AuthorizationIds, nodeIds, accountIds []int) {
currentUser, _ := acl.GetSessionFromCtx(ctx)
if acl.IsAdmin(currentUser) {
return
}
info := cast.ToBool(ctx.Query("info"))
noInfoIds := make([]int, 0)
if !info {
t := dbpkg.DB.Model(model.DefaultAsset)
assetResIds, _ := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_ASSET)
t, _ = r.handleAssetIds(ctx, t, assetResIds)
t.Pluck("id", &noInfoIds)
}
for _, a := range assets {
if lo.Contains(nodeIds, a.ParentId) || lo.Contains(noInfoIds, a.Id) {
continue
}
if lo.ContainsBy(authorizationIds, func(item *model.AuthorizationIds) bool {
return item.AssetId == a.Id && item.NodeId == 0 && item.AccountId == 0
}) {
continue
}
ids := lo.Map(lo.Filter(authorizationIds, func(item *model.AuthorizationIds, _ int) bool {
return item.AssetId == a.Id && item.AccountId != 0 && item.NodeId == 0
}),
func(item *model.AuthorizationIds, _ int) int { return item.AccountId })
for k := range a.Authorization {
if !lo.Contains(ids, k) && !lo.Contains(accountIds, k) {
delete(a.Authorization, k)
}
}
}
}
// GetAssetIdsByAuthorization gets asset IDs by authorization
func (r *assetRepository) GetAssetIdsByAuthorization(ctx *gin.Context, authorizationIds []*model.AuthorizationIds) ([]int, []int, []int, error) {
ctx.Set(kAuthorizationIds, authorizationIds)
nodeIds, assetIds, accountIds := r.GetIdsByAuthorizationIds(ctx, authorizationIds)
tmp, err := HandleSelfChild(ctx, nodeIds...)
if err != nil {
return nil, nil, nil, err
}
nodeIds = append(nodeIds, tmp...)
ctx.Set(kNodeIds, nodeIds)
ctx.Set(kAccountIds, accountIds)
assetIdsFromNode, err := r.GetAssetIdsByNodeAccount(ctx, nodeIds, accountIds)
if err != nil {
return nil, nil, nil, err
}
allAssetIds := lo.Uniq(append(assetIds, assetIdsFromNode...))
return nodeIds, allAssetIds, accountIds, nil
}
// GetIdsByAuthorizationIds extracts node IDs, asset IDs, and account IDs from authorization IDs
func (r *assetRepository) GetIdsByAuthorizationIds(ctx *gin.Context, authorizationIds []*model.AuthorizationIds) (nodeIds, assetIds, accountIds []int) {
info := cast.ToBool(ctx.Query("info"))
for _, a := range authorizationIds {
if a.NodeId != 0 && a.AssetId == 0 && a.AccountId == 0 {
nodeIds = append(nodeIds, a.NodeId)
}
if a.AssetId != 0 && a.NodeId == 0 && (info || a.AccountId == 0) {
assetIds = append(assetIds, a.AssetId)
}
if a.AccountId != 0 && a.AssetId == 0 && (info || a.NodeId == 0) {
accountIds = append(accountIds, a.AccountId)
}
}
return
}
// GetAssetIdsByNodeAccount gets asset IDs by node IDs and account IDs
func (r *assetRepository) GetAssetIdsByNodeAccount(ctx context.Context, nodeIds, accountIds []int) (assetIds []int, err error) {
assets, err := GetAllFromCacheDb(ctx, model.DefaultAsset)
if err != nil {
return nil, err
}
assets = lo.Filter(assets, func(a *model.Asset, _ int) bool {
return lo.Contains(nodeIds, a.ParentId) || len(lo.Intersect(lo.Keys(a.Authorization), accountIds)) > 0
})
assetIds = lo.Map(assets, func(a *model.Asset, _ int) int { return a.Id })
return assetIds, nil
}
// handleParentId builds a slice of parent IDs including the given parent ID and all its children
func (r *assetRepository) handleParentId(ctx context.Context, parentId int) (pids []int, err error) {
nodes, err := GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil {
return nil, err
}
g := make(map[int][]int)
for _, n := range nodes {
g[n.ParentId] = append(g[n.ParentId], n.Id)
}
var dfs func(int)
dfs = func(x int) {
pids = append(pids, x)
for _, y := range g[x] {
dfs(y)
}
}
dfs(parentId)
return pids, nil
}
// handleAssetIds filters assets by resource IDs and node IDs
func (r *assetRepository) handleAssetIds(ctx *gin.Context, dbFind *gorm.DB, resIds []int) (db *gorm.DB, err error) {
currentUser, _ := acl.GetSessionFromCtx(ctx)
nodes, err := GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil {
return nil, err
}
nodeResIds, err := acl.GetRoleResourceIds(ctx, currentUser.GetRid(), config.RESOURCE_NODE)
if err != nil {
return nil, err
}
nodes = lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(nodeResIds, n.ResourceId) })
nodeIds := lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id })
if nodeIds, err = HandleSelfChild(ctx, nodeIds...); err != nil {
return nil, err
}
d := dbpkg.DB.Where("resource_id IN ?", resIds).Or("parent_id IN?", nodeIds)
db = dbFind.Where(d)
return db, nil
}
// Filter helpers
func (r *assetRepository) filterEqual(ctx *gin.Context, db *gorm.DB, fields ...string) *gorm.DB {
for _, f := range fields {
if q, ok := ctx.GetQuery(f); ok {
db = db.Where(fmt.Sprintf("%s = ?", f), q)
}
}
return db
}
func (r *assetRepository) filterLike(ctx *gin.Context, db *gorm.DB, fields ...string) *gorm.DB {
likes := false
d := dbpkg.DB
for _, f := range fields {
if q, ok := ctx.GetQuery(f); ok && q != "" {
d = d.Or(fmt.Sprintf("%s LIKE ?", f), fmt.Sprintf("%%%s%%", q))
likes = true
}
}
if !likes {
return db
}
db = db.Where(d)
return db
}
func (r *assetRepository) filterSearch(ctx *gin.Context, db *gorm.DB, fields ...string) *gorm.DB {
q, ok := ctx.GetQuery("search")
if !ok || len(fields) <= 0 {
return db
}
d := dbpkg.DB
for _, f := range fields {
d = d.Or(fmt.Sprintf("%s LIKE ?", f), fmt.Sprintf("%%%s%%", q))
}
db = db.Where(d)
return db
}

View File

@@ -0,0 +1,67 @@
package repository
import (
"context"
"github.com/samber/lo"
"github.com/veops/oneterm/internal/model"
)
// HandleSelfChild gets IDs of nodes that are children of the specified node IDs
func HandleSelfChild(ctx context.Context, ids ...int) (res []int, err error) {
nodes, err := GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil {
return nil, err
}
g := make(map[int][]int)
for _, n := range nodes {
g[n.ParentId] = append(g[n.ParentId], n.Id)
}
var dfs func(int, bool)
dfs = func(x int, b bool) {
if b {
res = append(res, x)
}
for _, y := range g[x] {
dfs(y, b || lo.Contains(ids, x))
}
}
dfs(0, false)
res = lo.Uniq(append(res, ids...))
return res, nil
}
// HandleSelfParent gets IDs of nodes that are parents of the specified node IDs
func HandleSelfParent(ctx context.Context, ids ...int) (res []int, err error) {
nodes, err := GetAllFromCacheDb(ctx, model.DefaultNode)
if err != nil {
return nil, err
}
g := make(map[int][]int)
for _, n := range nodes {
g[n.ParentId] = append(g[n.ParentId], n.Id)
}
t := make([]int, 0)
var dfs func(int)
dfs = func(x int) {
t = append(t, x)
if lo.Contains(ids, x) {
res = append(res, t...)
}
for _, y := range g[x] {
dfs(y)
}
t = t[:len(t)-1]
}
dfs(0)
res = lo.Uniq(append(res, ids...))
return res, nil
}

View File

@@ -11,7 +11,6 @@ import (
"go.uber.org/zap"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/internal/service"
"github.com/veops/oneterm/internal/tunneling"
dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/pkg/logger"
@@ -81,7 +80,7 @@ func UpdateConnectables(ids ...int) (err error) {
func updateConnectable(asset *model.Asset, gateway *model.Gateway) (sid string, ok bool) {
sid = uuid.New().String()
ps := strings.Join(lo.Map(asset.Protocols, func(p string, _ int) string { return strings.Split(p, ":")[0] }), ",")
ip, port, err := service.Proxy(true, sid, ps, asset, gateway)
ip, port, err := tunneling.Proxy(true, sid, ps, asset, gateway)
if err != nil {
logger.L().Debug("connectable proxy failed", zap.String("protocol", ps), zap.Error(err))
return
@@ -110,3 +109,8 @@ func updateConnectable(asset *model.Asset, gateway *model.Gateway) (sid string,
ok = true
return
}
// UpdateAssetConnectables is used by service/asset.go to update connectables
func UpdateAssetConnectables(ids ...int) error {
return UpdateConnectables(ids...)
}

View File

@@ -0,0 +1,80 @@
package service
import (
"context"
"github.com/gin-gonic/gin"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/internal/repository"
"github.com/veops/oneterm/pkg/utils"
"golang.org/x/crypto/ssh"
"gorm.io/gorm"
)
// AccountService handles account business logic
type AccountService struct {
repo repository.AccountRepository
}
// NewAccountService creates a new account service
func NewAccountService() *AccountService {
return &AccountService{
repo: repository.NewAccountRepository(),
}
}
// ValidatePublicKey validates the given public key
func (s *AccountService) ValidatePublicKey(account *model.Account) error {
if account.AccountType != model.AUTHMETHOD_PUBLICKEY {
return nil
}
var err error
if account.Phrase == "" {
_, err = ssh.ParsePrivateKey([]byte(account.Pk))
} else {
_, err = ssh.ParsePrivateKeyWithPassphrase([]byte(account.Pk), []byte(account.Phrase))
}
return err
}
// EncryptSensitiveData encrypts sensitive account data
func (s *AccountService) EncryptSensitiveData(account *model.Account) {
account.Password = utils.EncryptAES(account.Password)
account.Pk = utils.EncryptAES(account.Pk)
account.Phrase = utils.EncryptAES(account.Phrase)
}
// DecryptSensitiveData decrypts sensitive account data
func (s *AccountService) DecryptSensitiveData(accounts []*model.Account) {
for _, a := range accounts {
a.Password = utils.DecryptAES(a.Password)
a.Pk = utils.DecryptAES(a.Pk)
a.Phrase = utils.DecryptAES(a.Phrase)
}
}
// AttachAssetCount attaches asset count to accounts
func (s *AccountService) AttachAssetCount(ctx context.Context, accounts []*model.Account) error {
return s.repo.AttachAssetCount(ctx, accounts)
}
// CheckAssetDependencies checks if account has dependent assets
func (s *AccountService) CheckAssetDependencies(ctx context.Context, id int) (string, error) {
return s.repo.CheckAssetDependencies(ctx, id)
}
// BuildQuery constructs account query with basic filters
func (s *AccountService) BuildQuery(ctx *gin.Context) *gorm.DB {
return s.repo.BuildQuery(ctx)
}
// FilterByAssetIds filters accounts by related asset IDs
func (s *AccountService) FilterByAssetIds(db *gorm.DB, assetIds []int) *gorm.DB {
return s.repo.FilterByAssetIds(db, assetIds)
}
// GetAccountIdsByAuthorization gets account IDs by authorization
func (s *AccountService) GetAccountIdsByAuthorization(ctx context.Context, assetIds []int, authorizationIds []int) ([]int, error) {
return s.repo.GetAccountIdsByAuthorization(ctx, assetIds, authorizationIds)
}

View File

@@ -0,0 +1,74 @@
package service
import (
"context"
"strings"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/internal/repository"
"github.com/veops/oneterm/internal/schedule"
"gorm.io/gorm"
)
// AssetService handles asset business logic
type AssetService struct {
repo repository.AssetRepository
}
// NewAssetService creates a new asset service
func NewAssetService() *AssetService {
return &AssetService{
repo: repository.NewAssetRepository(),
}
}
// PreprocessAssetData preprocesses asset data before saving
func (s *AssetService) PreprocessAssetData(asset *model.Asset) {
asset.Ip = strings.TrimSpace(asset.Ip)
asset.Protocols = lo.Map(asset.Protocols, func(s string, _ int) string { return strings.TrimSpace(s) })
if asset.Authorization == nil {
asset.Authorization = make(model.Map[int, model.Slice[int]])
}
}
// AttachNodeChain attaches node chain to assets
func (s *AssetService) AttachNodeChain(ctx context.Context, assets []*model.Asset) error {
return s.repo.AttachNodeChain(ctx, assets)
}
// ApplyAuthorizationFilters applies authorization filters to assets
func (s *AssetService) ApplyAuthorizationFilters(ctx *gin.Context, assets []*model.Asset, authorizationIds []*model.AuthorizationIds, nodeIds, accountIds []int) {
s.repo.ApplyAuthorizationFilters(ctx, assets, authorizationIds, nodeIds, accountIds)
}
// BuildQuery constructs asset query with basic filters
func (s *AssetService) BuildQuery(ctx *gin.Context) (*gorm.DB, error) {
return s.repo.BuildQuery(ctx)
}
// FilterByParentId filters assets by parent ID
func (s *AssetService) FilterByParentId(db *gorm.DB, parentId int) (*gorm.DB, error) {
return s.repo.FilterByParentId(db, parentId)
}
// GetAssetIdsByAuthorization gets asset IDs by authorization
func (s *AssetService) GetAssetIdsByAuthorization(ctx *gin.Context, authorizationIds []*model.AuthorizationIds) ([]int, []int, []int, error) {
return s.repo.GetAssetIdsByAuthorization(ctx, authorizationIds)
}
// GetIdsByAuthorizationIds extracts node IDs, asset IDs, and account IDs from authorization IDs
func (s *AssetService) GetIdsByAuthorizationIds(ctx *gin.Context, authorizationIds []*model.AuthorizationIds) ([]int, []int, []int) {
return s.repo.GetIdsByAuthorizationIds(ctx, authorizationIds)
}
// GetAssetIdsByNodeAccount gets asset IDs by node IDs and account IDs
func (s *AssetService) GetAssetIdsByNodeAccount(ctx context.Context, nodeIds, accountIds []int) ([]int, error) {
return s.repo.GetAssetIdsByNodeAccount(ctx, nodeIds, accountIds)
}
// UpdateConnectables updates asset connectability status
func (s *AssetService) UpdateConnectables(ids ...int) error {
return schedule.UpdateAssetConnectables(ids...)
}

View File

@@ -8,6 +8,8 @@ import (
"github.com/google/uuid"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
"github.com/veops/oneterm/internal/tunneling"
)
var (
@@ -73,7 +75,7 @@ func (fm *FileManager) GetFileClient(assetId, accountId int) (cli *sftp.Client,
return
}
ip, port, err := Proxy(false, uuid.New().String(), "sftp,ssh", asset, gateway)
ip, port, err := tunneling.Proxy(false, uuid.New().String(), "sftp,ssh", asset, gateway)
if err != nil {
return
}

View File

@@ -2,13 +2,10 @@ package service
import (
"fmt"
"strings"
"github.com/spf13/cast"
"golang.org/x/crypto/ssh"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/internal/tunneling"
dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/pkg/utils"
)
@@ -58,27 +55,3 @@ func GetAuth(account *model.Account) (ssh.AuthMethod, error) {
return nil, fmt.Errorf("invalid authmethod %d", account.AccountType)
}
}
func Proxy(isConnectable bool, sessionId string, protocol string, asset *model.Asset, gateway *model.Gateway) (ip string, port int, err error) {
ip, port = asset.Ip, 0
for _, tp := range strings.Split(protocol, ",") {
for _, p := range asset.Protocols {
if strings.HasPrefix(strings.ToLower(p), tp) {
if port = cast.ToInt(strings.Split(p, ":")[1]); port != 0 {
break
}
}
}
}
if asset.GatewayId == 0 || gateway == nil {
return
}
g, err := tunneling.OpenTunnel(isConnectable, sessionId, ip, port, gateway)
if err != nil {
return
}
ip, port = g.LocalIp, g.LocalPort
return
}

View File

@@ -5,9 +5,11 @@ import (
"fmt"
"io"
"net"
"strings"
"sync"
"time"
"github.com/spf13/cast"
"go.uber.org/zap"
"golang.org/x/crypto/ssh"
@@ -231,3 +233,28 @@ func getAvailablePort() (int, error) {
return l.Addr().(*net.TCPAddr).Port, nil
}
// Proxy establishes a proxy connection to an asset through a gateway if necessary
func Proxy(isConnectable bool, sessionId string, protocol string, asset *model.Asset, gateway *model.Gateway) (ip string, port int, err error) {
ip, port = asset.Ip, 0
for _, tp := range strings.Split(protocol, ",") {
for _, p := range asset.Protocols {
if strings.HasPrefix(strings.ToLower(p), tp) {
if port = cast.ToInt(strings.Split(p, ":")[1]); port != 0 {
break
}
}
}
}
if asset.GatewayId == 0 || gateway == nil {
return
}
g, err := OpenTunnel(isConnectable, sessionId, ip, port, gateway)
if err != nil {
return
}
ip, port = g.LocalIp, g.LocalPort
return
}