// oneterm/backend/internal/service/share.go (Go source, 199 lines, 5.8 KiB)

package service
import (
"context"
"fmt"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/samber/lo"
"github.com/spf13/cast"
"golang.org/x/sync/errgroup"
"gorm.io/gorm"
"github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/internal/repository"
)
// ShareService handles share business logic. It holds no state of its own
// besides the repository it delegates persistence and query building to.
type ShareService struct {
// repo is the share repository used by every method of the service.
repo repository.ShareRepository
}
// NewShareService returns a ShareService wired to a repository obtained from
// repository.NewShareRepository.
func NewShareService() *ShareService {
	svc := new(ShareService)
	svc.repo = repository.NewShareRepository()
	return svc
}
// HasPermission reports whether the session carried by ctx may perform action
// on the share.
//
// Administrators are always allowed. Any other user is allowed when the
// share's asset or the share's account is among the assets/accounts on which
// the user's role holds action. The check fails closed: if the session cannot
// be read from ctx, if the permitted ids cannot be resolved, or if share is
// nil for a non-admin user, the result is false.
func (s *ShareService) HasPermission(ctx context.Context, share *model.Share, action string) bool {
	currentUser, err := acl.GetSessionFromCtx(ctx)
	if err != nil {
		// Without a session there is nobody to authorize; deny instead of
		// handing an unusable session to the ACL helpers.
		return false
	}
	if acl.IsAdmin(currentUser) {
		return true
	}
	if share == nil {
		return false
	}

	_, assetIds, accountIds, err := getNodeAssetAccoutIdsByAction(ctx, action)
	if err != nil {
		return false
	}

	return lo.Contains(assetIds, share.AssetId) || lo.Contains(accountIds, share.AccountId)
}
// CreateShares assigns a freshly generated UUID to every share (mutating the
// passed-in values), persists them through the repository and returns the
// UUIDs in the same order as shares. On a repository error it returns a nil
// slice and that error.
func (s *ShareService) CreateShares(ctx context.Context, shares []*model.Share) ([]string, error) {
	ids := make([]string, len(shares))
	for i, share := range shares {
		share.Uuid = uuid.New().String()
		ids[i] = share.Uuid
	}

	if err := s.repo.CreateShares(ctx, shares); err != nil {
		return nil, err
	}
	return ids, nil
}
// BuildQuery constructs the repository query used to list shares.
//
// For administrators the repository receives nil asset and account id lists.
// For everyone else the lists are the assets and accounts on which the current
// user holds the acl.GRANT action; failing to resolve them aborts with the
// error.
func (s *ShareService) BuildQuery(ctx *gin.Context, isAdmin bool) (*gorm.DB, error) {
	var (
		assetIds   []int
		accountIds []int
	)

	if !isAdmin {
		_, grantedAssets, grantedAccounts, err := getNodeAssetAccoutIdsByAction(ctx, acl.GRANT)
		if err != nil {
			return nil, err
		}
		assetIds, accountIds = grantedAssets, grantedAccounts
	}

	return s.repo.BuildQuery(ctx, assetIds, accountIds)
}
// GetShareByID looks up a share by its numeric id, delegating directly to the
// repository and passing its result and error through unchanged.
func (s *ShareService) GetShareByID(ctx context.Context, id int) (*model.Share, error) {
	share, err := s.repo.GetShareByID(ctx, id)
	return share, err
}
// GetShareByUUID looks up a share by its UUID string, delegating directly to
// the repository and passing its result and error through unchanged.
func (s *ShareService) GetShareByUUID(ctx context.Context, uuid string) (*model.Share, error) {
	share, err := s.repo.GetShareByUUID(ctx, uuid)
	return share, err
}
// ValidateShareForConnection loads the share identified by uuid and checks
// that it may be used for a connection right now.
//
// The share must exist and the current time must lie within [Start, End].
// Shares flagged NoLimit pass at that point. Otherwise one use is consumed via
// the repository's DecrementShareTimes; unless exactly one row was affected the
// share is rejected with "no times left".
func (s *ShareService) ValidateShareForConnection(ctx context.Context, uuid string) (*model.Share, error) {
	share, err := s.repo.GetShareByUUID(ctx, uuid)
	if err != nil {
		return nil, err
	}

	// Reject connections attempted outside the share's validity window.
	if current := time.Now(); current.Before(share.Start) || current.After(share.End) {
		return nil, fmt.Errorf("share expired or not started")
	}

	// Only limited shares consume a use.
	if !share.NoLimit {
		affected, err := s.repo.DecrementShareTimes(ctx, share.Uuid)
		switch {
		case err != nil:
			return nil, err
		case affected != 1:
			return nil, fmt.Errorf("no times left")
		}
	}

	return share, nil
}
// SetupConnectionParams rewrites the request so that the connection targets
// what the share dictates.
//
// Any existing "account_id", "asset_id" and "protocol" route parameters are
// dropped and replaced with the values stored on share. The share id and end
// time are stored on the context under "shareId" and "shareEnd", and an empty
// acl.Session is stored under "session".
func (s *ShareService) SetupConnectionParams(ctx *gin.Context, share *model.Share) {
	// Keep every parameter except the ones the share overrides.
	kept := lo.Filter(ctx.Params, func(param gin.Param, _ int) bool {
		switch param.Key {
		case "account_id", "asset_id", "protocol":
			return false
		default:
			return true
		}
	})

	ctx.Params = append(kept,
		gin.Param{Key: "account_id", Value: cast.ToString(share.AccountId)},
		gin.Param{Key: "asset_id", Value: cast.ToString(share.AssetId)},
		gin.Param{Key: "protocol", Value: cast.ToString(share.Protocol)},
	)

	// Share-related values stored on the request context.
	ctx.Set("shareId", share.Id)
	ctx.Set("session", &acl.Session{})
	ctx.Set("shareEnd", share.End)
}
// getNodeAssetAccoutIdsByAction resolves the ids of the nodes, assets and
// accounts on which the role of the session in ctx holds the given ACL action.
//
// This is a helper that should be moved to a more appropriate service.
//
// The three resource types are resolved concurrently:
//   - nodes: cached nodes whose ACL resource carries action, then passed
//     through repository.HandleSelfChild (presumably adding descendant nodes —
//     confirm against that helper);
//   - assets: cached assets whose ACL resource carries action, or whose
//     ParentId is one of the resolved node ids;
//   - accounts: cached accounts whose ACL resource carries action.
//
// It returns an error when the session cannot be read from ctx or when any
// lookup fails; callers should not rely on the id slices in that case.
func getNodeAssetAccoutIdsByAction(ctx context.Context, action string) (nodeIds, assetIds, accountIds []int, err error) {
	// Resolve the session before spawning goroutines: a failure discovered
	// inside a goroutine would surface as a nil dereference that the caller's
	// goroutine cannot recover from.
	currentUser, err := acl.GetSessionFromCtx(ctx)
	if err != nil {
		return nil, nil, nil, err
	}
	rid := currentUser.GetRid()

	// permitted returns the set of ACL resource ids of resourceType on which
	// the role holds action. It only reads captured values, so the goroutines
	// below can call it concurrently.
	permitted := func(resourceType string) (map[int]struct{}, error) {
		res, err := acl.GetRoleResources(ctx, rid, resourceType)
		if err != nil {
			return nil, err
		}
		ids := make(map[int]struct{}, len(res))
		for _, r := range res {
			if lo.Contains(r.Permissions, action) {
				ids[r.ResourceId] = struct{}{}
			}
		}
		return ids, nil
	}

	// nodesReady is closed once the node goroutine has finished (successfully
	// or not); the asset goroutine waits on it before reading nodeIds.
	nodesReady := make(chan struct{})
	eg := &errgroup.Group{}

	eg.Go(func() error {
		defer close(nodesReady)

		granted, err := permitted("node")
		if err != nil {
			return err
		}
		nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
		if err != nil {
			return err
		}
		ids := make([]int, 0, len(nodes))
		for _, n := range nodes {
			if _, ok := granted[n.ResourceId]; ok {
				ids = append(ids, n.Id)
			}
		}
		nodeIds, err = repository.HandleSelfChild(ctx, ids...)
		return err
	})

	eg.Go(func() error {
		granted, err := permitted("asset")
		if err != nil {
			return err
		}

		// Assets also inherit permission from their parent node, so the node
		// ids must be final before filtering.
		<-nodesReady

		assets, err := repository.GetAllFromCacheDb(ctx, model.DefaultAsset)
		if err != nil {
			return err
		}
		parents := make(map[int]struct{}, len(nodeIds))
		for _, id := range nodeIds {
			parents[id] = struct{}{}
		}
		ids := make([]int, 0, len(assets))
		for _, a := range assets {
			_, direct := granted[a.ResourceId]
			_, inherited := parents[a.ParentId]
			if direct || inherited {
				ids = append(ids, a.Id)
			}
		}
		assetIds = ids
		return nil
	})

	eg.Go(func() error {
		granted, err := permitted("account")
		if err != nil {
			return err
		}
		accounts, err := repository.GetAllFromCacheDb(ctx, model.DefaultAccount)
		if err != nil {
			return err
		}
		ids := make([]int, 0, len(accounts))
		for _, a := range accounts {
			if _, ok := granted[a.ResourceId]; ok {
				ids = append(ids, a.Id)
			}
		}
		accountIds = ids
		return nil
	})

	err = eg.Wait()
	return nodeIds, assetIds, accountIds, err
}