Files
oneterm/backend/internal/repository/session.go

231 lines
6.6 KiB
Go

package repository
import (
"context"
"time"
"github.com/gin-gonic/gin"
"github.com/veops/oneterm/internal/model"
gsession "github.com/veops/oneterm/internal/session"
dbpkg "github.com/veops/oneterm/pkg/db"
"gorm.io/gorm"
)
// SessionRepository defines the interface for session repository
type SessionRepository interface {
GetSession(ctx context.Context, sessionId string) (*model.Session, error)
BuildQuery(ctx *gin.Context, isAdmin bool, uid int) (*gorm.DB, error)
BuildCmdQuery(ctx *gin.Context, sessionId string) *gorm.DB
GetSessionOptionAssets(ctx context.Context) ([]*model.SessionOptionAsset, error)
GetSessionOptionClientIps(ctx context.Context) ([]string, error)
CreateSessionCmd(ctx context.Context, cmd *model.SessionCmd) error
GetSessionCmdCounts(ctx context.Context, sessionIds []string) (map[string]int64, error)
GetOnlineSessionByID(ctx context.Context, sessionID string) (*gsession.Session, error)
GetSshParserCommands(ctx context.Context, cmdIDs []int) ([]*model.Command, error)
// GetRecentSessionsByUser retrieves recent sessions deduplicated by asset_id and account_id combination
GetRecentSessionsByUser(ctx context.Context, uid int, limit int) ([]*model.Session, error)
}
type sessionRepository struct{}
// NewSessionRepository creates a new session repository
func NewSessionRepository() SessionRepository {
return &sessionRepository{}
}
// GetSession retrieves a session by session ID
func (r *sessionRepository) GetSession(ctx context.Context, sessionId string) (*model.Session, error) {
session := &model.Session{}
if err := dbpkg.DB.Where("session_id = ?", sessionId).First(session).Error; err != nil {
return nil, err
}
return session, nil
}
// BuildQuery constructs a query for sessions with filters
func (r *sessionRepository) BuildQuery(ctx *gin.Context, isAdmin bool, uid int) (*gorm.DB, error) {
db := dbpkg.DB.Model(model.DefaultSession)
// Apply user filter if not admin
if !isAdmin {
db = db.Where("uid = ?", uid)
}
// Apply text search
if q, ok := ctx.GetQuery("search"); ok && q != "" {
db = db.Where("user_name LIKE ? OR asset_info LIKE ? OR gateway_info LIKE ? OR account_info LIKE ?",
"%"+q+"%", "%"+q+"%", "%"+q+"%", "%"+q+"%")
}
// Apply date range filters
if start, ok := ctx.GetQuery("start"); ok {
t, err := time.Parse(time.RFC3339, start)
if err != nil {
return nil, err
}
db = db.Where("created_at >= ?", t)
}
if end, ok := ctx.GetQuery("end"); ok {
t, err := time.Parse(time.RFC3339, end)
if err != nil {
return nil, err
}
db = db.Where("created_at <= ?", t)
}
// Apply exact match filters
for _, field := range []string{"status", "uid", "asset_id", "client_ip"} {
if q, ok := ctx.GetQuery(field); ok && q != "" {
db = db.Where(field+" = ?", q)
}
}
return db, nil
}
// BuildCmdQuery constructs a query for session commands
func (r *sessionRepository) BuildCmdQuery(ctx *gin.Context, sessionId string) *gorm.DB {
db := dbpkg.DB.Model(&model.SessionCmd{})
db = db.Where("session_id = ?", sessionId)
// Apply text search
if q, ok := ctx.GetQuery("search"); ok && q != "" {
db = db.Where("cmd LIKE ? OR result LIKE ?", "%"+q+"%", "%"+q+"%")
}
return db
}
// GetSessionOptionAssets retrieves session option assets
func (r *sessionRepository) GetSessionOptionAssets(ctx context.Context) ([]*model.SessionOptionAsset, error) {
opts := make([]*model.SessionOptionAsset, 0)
if err := dbpkg.DB.
Model(model.DefaultAsset).
Select("id, name").
Find(&opts).
Error; err != nil {
return nil, err
}
return opts, nil
}
// GetSessionOptionClientIps retrieves distinct client IPs
func (r *sessionRepository) GetSessionOptionClientIps(ctx context.Context) ([]string, error) {
opts := make([]string, 0)
if err := dbpkg.DB.
Model(model.DefaultSession).
Distinct("client_ip").
Find(&opts).
Error; err != nil {
return nil, err
}
return opts, nil
}
// CreateSessionCmd creates a new session command
func (r *sessionRepository) CreateSessionCmd(ctx context.Context, cmd *model.SessionCmd) error {
return dbpkg.DB.Create(cmd).Error
}
// GetSessionCmdCounts retrieves command counts for sessions
func (r *sessionRepository) GetSessionCmdCounts(ctx context.Context, sessionIds []string) (map[string]int64, error) {
if len(sessionIds) <= 0 {
return map[string]int64{}, nil
}
post := make([]*model.CmdCount, 0)
if err := dbpkg.DB.
Model(&model.SessionCmd{}).
Select("session_id, COUNT(*) AS count").
Where("session_id IN ?", sessionIds).
Group("session_id").
Find(&post).
Error; err != nil {
return nil, err
}
// Convert to map
result := make(map[string]int64)
for _, p := range post {
result[p.SessionId] = p.Count
}
return result, nil
}
// GetOnlineSessionByID retrieves an online session by ID
func (r *sessionRepository) GetOnlineSessionByID(ctx context.Context, sessionID string) (*gsession.Session, error) {
session := &gsession.Session{}
err := dbpkg.DB.
Model(session).
Where("session_id = ?", sessionID).
Where("status = ?", model.SESSIONSTATUS_ONLINE).
First(session).
Error
return session, err
}
// GetSshParserCommands retrieves SSH parser commands by IDs
func (r *sessionRepository) GetSshParserCommands(ctx context.Context, cmdIDs []int) ([]*model.Command, error) {
var commands []*model.Command
err := dbpkg.DB.
Where("id IN ? AND enable=?", cmdIDs, true).
Find(&commands).
Error
return commands, err
}
// GetRecentSessionsByUser retrieves recent sessions for a user, deduplicated by asset_id and account_id
func (r *sessionRepository) GetRecentSessionsByUser(ctx context.Context, uid int, limit int) ([]*model.Session, error) {
var sessions []*model.Session
// First, get the MAX session ID for each asset+account combination
// This approach avoids LIMIT in subquery which MySQL doesn't support
type MaxSession struct {
AssetId int
AccountId int
MaxId int
}
var maxSessions []MaxSession
err := dbpkg.DB.Model(&model.Session{}).
Select("asset_id, account_id, MAX(id) as max_id").
Where("uid = ?", uid).
Where("asset_id > 0").
Where("account_id > 0").
Where("protocol NOT LIKE ?", "rdp%").
Where("protocol NOT LIKE ?", "vnc%").
Group("asset_id, account_id").
Find(&maxSessions).Error
if err != nil {
return nil, err
}
// Extract the IDs
var sessionIds []int
for _, ms := range maxSessions {
sessionIds = append(sessionIds, ms.MaxId)
}
if len(sessionIds) == 0 {
return sessions, nil
}
// Get the full session records for those IDs
err = dbpkg.DB.Model(&model.Session{}).
Where("id IN ?", sessionIds).
Where("protocol NOT LIKE ?", "rdp%").
Where("protocol NOT LIKE ?", "vnc%").
Order("created_at DESC").
Limit(limit).
Find(&sessions).Error
if err != nil {
return nil, err
}
return sessions, nil
}