mirror of
				https://github.com/veops/oneterm.git
				synced 2025-11-01 03:12:39 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			199 lines
		
	
	
		
			5.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			199 lines
		
	
	
		
			5.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package service
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"fmt"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/gin-gonic/gin"
 | |
| 	"github.com/google/uuid"
 | |
| 	"github.com/samber/lo"
 | |
| 	"github.com/spf13/cast"
 | |
| 	"golang.org/x/sync/errgroup"
 | |
| 	"gorm.io/gorm"
 | |
| 
 | |
| 	"github.com/veops/oneterm/internal/acl"
 | |
| 	"github.com/veops/oneterm/internal/model"
 | |
| 	"github.com/veops/oneterm/internal/repository"
 | |
| )
 | |
| 
 | |
| // ShareService handles share business logic
 | |
| type ShareService struct {
 | |
| 	repo repository.ShareRepository
 | |
| }
 | |
| 
 | |
| // NewShareService creates a new share service
 | |
| func NewShareService() *ShareService {
 | |
| 	return &ShareService{
 | |
| 		repo: repository.NewShareRepository(),
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // HasPermission checks if the user has permission on the share
 | |
| func (s *ShareService) HasPermission(ctx context.Context, share *model.Share, action string) bool {
 | |
| 	currentUser, _ := acl.GetSessionFromCtx(ctx)
 | |
| 
 | |
| 	if acl.IsAdmin(currentUser) {
 | |
| 		return true
 | |
| 	}
 | |
| 
 | |
| 	_, assetIds, accountIds, err := getNodeAssetAccoutIdsByAction(ctx, action)
 | |
| 	if err != nil {
 | |
| 		return false
 | |
| 	}
 | |
| 
 | |
| 	return lo.Contains(assetIds, share.AssetId) || lo.Contains(accountIds, share.AccountId)
 | |
| }
 | |
| 
 | |
| // CreateShares creates new shares
 | |
| func (s *ShareService) CreateShares(ctx context.Context, shares []*model.Share) ([]string, error) {
 | |
| 	// Add UUID and return the UUIDs
 | |
| 	uuids := lo.Map(shares, func(share *model.Share, _ int) string {
 | |
| 		share.Uuid = uuid.New().String()
 | |
| 		return share.Uuid
 | |
| 	})
 | |
| 
 | |
| 	err := s.repo.CreateShares(ctx, shares)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return uuids, nil
 | |
| }
 | |
| 
 | |
| // BuildQuery constructs a query for shares
 | |
| func (s *ShareService) BuildQuery(ctx *gin.Context, isAdmin bool) (*gorm.DB, error) {
 | |
| 	var assetIds, accountIds []int
 | |
| 
 | |
| 	// If not admin, get accessible assets and accounts
 | |
| 	if !isAdmin {
 | |
| 		var err error
 | |
| 		_, assetIds, accountIds, err = getNodeAssetAccoutIdsByAction(ctx, acl.GRANT)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return s.repo.BuildQuery(ctx, assetIds, accountIds)
 | |
| }
 | |
| 
 | |
| // GetShareByID retrieves a share by ID
 | |
| func (s *ShareService) GetShareByID(ctx context.Context, id int) (*model.Share, error) {
 | |
| 	return s.repo.GetShareByID(ctx, id)
 | |
| }
 | |
| 
 | |
| // GetShareByUUID retrieves a share by UUID
 | |
| func (s *ShareService) GetShareByUUID(ctx context.Context, uuid string) (*model.Share, error) {
 | |
| 	return s.repo.GetShareByUUID(ctx, uuid)
 | |
| }
 | |
| 
 | |
| // ValidateShareForConnection validates a share for connection
 | |
| func (s *ShareService) ValidateShareForConnection(ctx context.Context, uuid string) (*model.Share, error) {
 | |
| 	share, err := s.repo.GetShareByUUID(ctx, uuid)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	now := time.Now()
 | |
| 	if now.Before(share.Start) || now.After(share.End) {
 | |
| 		return nil, fmt.Errorf("share expired or not started")
 | |
| 	}
 | |
| 
 | |
| 	if share.NoLimit {
 | |
| 		return share, nil
 | |
| 	}
 | |
| 
 | |
| 	rowsAffected, err := s.repo.DecrementShareTimes(ctx, share.Uuid)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	if rowsAffected != 1 {
 | |
| 		return nil, fmt.Errorf("no times left")
 | |
| 	}
 | |
| 
 | |
| 	return share, nil
 | |
| }
 | |
| 
 | |
| // SetupConnectionParams sets up connection parameters from share
 | |
| func (s *ShareService) SetupConnectionParams(ctx *gin.Context, share *model.Share) {
 | |
| 	// Filter out and add new parameters
 | |
| 	ctx.Params = lo.Filter(ctx.Params, func(p gin.Param, _ int) bool {
 | |
| 		return !lo.Contains([]string{"account_id", "asset_id", "protocol"}, p.Key)
 | |
| 	})
 | |
| 
 | |
| 	ctx.Params = append(ctx.Params, gin.Param{Key: "account_id", Value: cast.ToString(share.AccountId)})
 | |
| 	ctx.Params = append(ctx.Params, gin.Param{Key: "asset_id", Value: cast.ToString(share.AssetId)})
 | |
| 	ctx.Params = append(ctx.Params, gin.Param{Key: "protocol", Value: cast.ToString(share.Protocol)})
 | |
| 
 | |
| 	// Set share context values
 | |
| 	ctx.Set("shareId", share.Id)
 | |
| 	ctx.Set("session", &acl.Session{})
 | |
| 	ctx.Set("shareEnd", share.End)
 | |
| }
 | |
| 
 | |
| // Helper functions that should be moved to appropriate services
 | |
| func getNodeAssetAccoutIdsByAction(ctx context.Context, action string) (nodeIds, assetIds, accountIds []int, err error) {
 | |
| 	currentUser, _ := acl.GetSessionFromCtx(ctx)
 | |
| 
 | |
| 	eg := &errgroup.Group{}
 | |
| 	ch := make(chan bool)
 | |
| 
 | |
| 	eg.Go(func() (err error) {
 | |
| 		defer close(ch)
 | |
| 		res, err := acl.GetRoleResources(ctx, currentUser.GetRid(), "node")
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		res = lo.Filter(res, func(r *acl.Resource, _ int) bool { return lo.Contains(r.Permissions, action) })
 | |
| 		resIds := lo.Map(res, func(r *acl.Resource, _ int) int { return r.ResourceId })
 | |
| 		nodes, err := repository.GetAllFromCacheDb(ctx, model.DefaultNode)
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		nodes = lo.Filter(nodes, func(n *model.Node, _ int) bool { return lo.Contains(resIds, n.ResourceId) })
 | |
| 		nodeIds = lo.Map(nodes, func(n *model.Node, _ int) int { return n.Id })
 | |
| 		nodeIds, err = repository.HandleSelfChild(ctx, nodeIds...)
 | |
| 		return
 | |
| 	})
 | |
| 
 | |
| 	eg.Go(func() (err error) {
 | |
| 		res, err := acl.GetRoleResources(ctx, currentUser.GetRid(), "asset")
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		res = lo.Filter(res, func(r *acl.Resource, _ int) bool { return lo.Contains(r.Permissions, action) })
 | |
| 		resIds := lo.Map(res, func(r *acl.Resource, _ int) int { return r.ResourceId })
 | |
| 		<-ch
 | |
| 		assets, err := repository.GetAllFromCacheDb(ctx, model.DefaultAsset)
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		assets = lo.Filter(assets, func(a *model.Asset, _ int) bool {
 | |
| 			return lo.Contains(resIds, a.ResourceId) || lo.Contains(nodeIds, a.ParentId)
 | |
| 		})
 | |
| 		assetIds = lo.Map(assets, func(a *model.Asset, _ int) int { return a.Id })
 | |
| 		return
 | |
| 	})
 | |
| 
 | |
| 	eg.Go(func() (err error) {
 | |
| 		res, err := acl.GetRoleResources(ctx, currentUser.GetRid(), "account")
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		res = lo.Filter(res, func(r *acl.Resource, _ int) bool { return lo.Contains(r.Permissions, action) })
 | |
| 		resIds := lo.Map(res, func(r *acl.Resource, _ int) int { return r.ResourceId })
 | |
| 		accounts, err := repository.GetAllFromCacheDb(ctx, model.DefaultAccount)
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		accounts = lo.Filter(accounts, func(a *model.Account, _ int) bool { return lo.Contains(resIds, a.ResourceId) })
 | |
| 		accountIds = lo.Map(accounts, func(a *model.Account, _ int) int { return a.Id })
 | |
| 		return
 | |
| 	})
 | |
| 
 | |
| 	err = eg.Wait()
 | |
| 
 | |
| 	return
 | |
| }
 | 
