mirror of
https://github.com/veops/oneterm.git
synced 2025-10-05 07:16:57 +08:00
231 lines
6.6 KiB
Go
231 lines
6.6 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/veops/oneterm/internal/model"
|
|
gsession "github.com/veops/oneterm/internal/session"
|
|
dbpkg "github.com/veops/oneterm/pkg/db"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// SessionRepository defines the interface for session repository
|
|
type SessionRepository interface {
|
|
GetSession(ctx context.Context, sessionId string) (*model.Session, error)
|
|
BuildQuery(ctx *gin.Context, isAdmin bool, uid int) (*gorm.DB, error)
|
|
BuildCmdQuery(ctx *gin.Context, sessionId string) *gorm.DB
|
|
GetSessionOptionAssets(ctx context.Context) ([]*model.SessionOptionAsset, error)
|
|
GetSessionOptionClientIps(ctx context.Context) ([]string, error)
|
|
CreateSessionCmd(ctx context.Context, cmd *model.SessionCmd) error
|
|
GetSessionCmdCounts(ctx context.Context, sessionIds []string) (map[string]int64, error)
|
|
GetOnlineSessionByID(ctx context.Context, sessionID string) (*gsession.Session, error)
|
|
GetSshParserCommands(ctx context.Context, cmdIDs []int) ([]*model.Command, error)
|
|
// GetRecentSessionsByUser retrieves recent sessions deduplicated by asset_id and account_id combination
|
|
GetRecentSessionsByUser(ctx context.Context, uid int, limit int) ([]*model.Session, error)
|
|
}
|
|
|
|
type sessionRepository struct{}
|
|
|
|
// NewSessionRepository creates a new session repository
|
|
func NewSessionRepository() SessionRepository {
|
|
return &sessionRepository{}
|
|
}
|
|
|
|
// GetSession retrieves a session by session ID
|
|
func (r *sessionRepository) GetSession(ctx context.Context, sessionId string) (*model.Session, error) {
|
|
session := &model.Session{}
|
|
if err := dbpkg.DB.Where("session_id = ?", sessionId).First(session).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return session, nil
|
|
}
|
|
|
|
// BuildQuery constructs a query for sessions with filters
|
|
func (r *sessionRepository) BuildQuery(ctx *gin.Context, isAdmin bool, uid int) (*gorm.DB, error) {
|
|
db := dbpkg.DB.Model(model.DefaultSession)
|
|
|
|
// Apply user filter if not admin
|
|
if !isAdmin {
|
|
db = db.Where("uid = ?", uid)
|
|
}
|
|
|
|
// Apply text search
|
|
if q, ok := ctx.GetQuery("search"); ok && q != "" {
|
|
db = db.Where("user_name LIKE ? OR asset_info LIKE ? OR gateway_info LIKE ? OR account_info LIKE ?",
|
|
"%"+q+"%", "%"+q+"%", "%"+q+"%", "%"+q+"%")
|
|
}
|
|
|
|
// Apply date range filters
|
|
if start, ok := ctx.GetQuery("start"); ok {
|
|
t, err := time.Parse(time.RFC3339, start)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
db = db.Where("created_at >= ?", t)
|
|
}
|
|
|
|
if end, ok := ctx.GetQuery("end"); ok {
|
|
t, err := time.Parse(time.RFC3339, end)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
db = db.Where("created_at <= ?", t)
|
|
}
|
|
|
|
// Apply exact match filters
|
|
for _, field := range []string{"status", "uid", "asset_id", "client_ip"} {
|
|
if q, ok := ctx.GetQuery(field); ok && q != "" {
|
|
db = db.Where(field+" = ?", q)
|
|
}
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
// BuildCmdQuery constructs a query for session commands
|
|
func (r *sessionRepository) BuildCmdQuery(ctx *gin.Context, sessionId string) *gorm.DB {
|
|
db := dbpkg.DB.Model(&model.SessionCmd{})
|
|
db = db.Where("session_id = ?", sessionId)
|
|
|
|
// Apply text search
|
|
if q, ok := ctx.GetQuery("search"); ok && q != "" {
|
|
db = db.Where("cmd LIKE ? OR result LIKE ?", "%"+q+"%", "%"+q+"%")
|
|
}
|
|
|
|
return db
|
|
}
|
|
|
|
// GetSessionOptionAssets retrieves session option assets
|
|
func (r *sessionRepository) GetSessionOptionAssets(ctx context.Context) ([]*model.SessionOptionAsset, error) {
|
|
opts := make([]*model.SessionOptionAsset, 0)
|
|
if err := dbpkg.DB.
|
|
Model(model.DefaultAsset).
|
|
Select("id, name").
|
|
Find(&opts).
|
|
Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return opts, nil
|
|
}
|
|
|
|
// GetSessionOptionClientIps retrieves distinct client IPs
|
|
func (r *sessionRepository) GetSessionOptionClientIps(ctx context.Context) ([]string, error) {
|
|
opts := make([]string, 0)
|
|
if err := dbpkg.DB.
|
|
Model(model.DefaultSession).
|
|
Distinct("client_ip").
|
|
Find(&opts).
|
|
Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return opts, nil
|
|
}
|
|
|
|
// CreateSessionCmd creates a new session command
|
|
func (r *sessionRepository) CreateSessionCmd(ctx context.Context, cmd *model.SessionCmd) error {
|
|
return dbpkg.DB.Create(cmd).Error
|
|
}
|
|
|
|
// GetSessionCmdCounts retrieves command counts for sessions
|
|
func (r *sessionRepository) GetSessionCmdCounts(ctx context.Context, sessionIds []string) (map[string]int64, error) {
|
|
if len(sessionIds) <= 0 {
|
|
return map[string]int64{}, nil
|
|
}
|
|
|
|
post := make([]*model.CmdCount, 0)
|
|
if err := dbpkg.DB.
|
|
Model(&model.SessionCmd{}).
|
|
Select("session_id, COUNT(*) AS count").
|
|
Where("session_id IN ?", sessionIds).
|
|
Group("session_id").
|
|
Find(&post).
|
|
Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Convert to map
|
|
result := make(map[string]int64)
|
|
for _, p := range post {
|
|
result[p.SessionId] = p.Count
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// GetOnlineSessionByID retrieves an online session by ID
|
|
func (r *sessionRepository) GetOnlineSessionByID(ctx context.Context, sessionID string) (*gsession.Session, error) {
|
|
session := &gsession.Session{}
|
|
err := dbpkg.DB.
|
|
Model(session).
|
|
Where("session_id = ?", sessionID).
|
|
Where("status = ?", model.SESSIONSTATUS_ONLINE).
|
|
First(session).
|
|
Error
|
|
return session, err
|
|
}
|
|
|
|
// GetSshParserCommands retrieves SSH parser commands by IDs
|
|
func (r *sessionRepository) GetSshParserCommands(ctx context.Context, cmdIDs []int) ([]*model.Command, error) {
|
|
var commands []*model.Command
|
|
err := dbpkg.DB.
|
|
Where("id IN ? AND enable=?", cmdIDs, true).
|
|
Find(&commands).
|
|
Error
|
|
return commands, err
|
|
}
|
|
|
|
// GetRecentSessionsByUser retrieves recent sessions for a user, deduplicated by asset_id and account_id
|
|
func (r *sessionRepository) GetRecentSessionsByUser(ctx context.Context, uid int, limit int) ([]*model.Session, error) {
|
|
var sessions []*model.Session
|
|
|
|
// First, get the MAX session ID for each asset+account combination
|
|
// This approach avoids LIMIT in subquery which MySQL doesn't support
|
|
type MaxSession struct {
|
|
AssetId int
|
|
AccountId int
|
|
MaxId int
|
|
}
|
|
|
|
var maxSessions []MaxSession
|
|
err := dbpkg.DB.Model(&model.Session{}).
|
|
Select("asset_id, account_id, MAX(id) as max_id").
|
|
Where("uid = ?", uid).
|
|
Where("asset_id > 0").
|
|
Where("account_id > 0").
|
|
Where("protocol NOT LIKE ?", "rdp%").
|
|
Where("protocol NOT LIKE ?", "vnc%").
|
|
Group("asset_id, account_id").
|
|
Find(&maxSessions).Error
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Extract the IDs
|
|
var sessionIds []int
|
|
for _, ms := range maxSessions {
|
|
sessionIds = append(sessionIds, ms.MaxId)
|
|
}
|
|
|
|
if len(sessionIds) == 0 {
|
|
return sessions, nil
|
|
}
|
|
|
|
// Get the full session records for those IDs
|
|
err = dbpkg.DB.Model(&model.Session{}).
|
|
Where("id IN ?", sessionIds).
|
|
Where("protocol NOT LIKE ?", "rdp%").
|
|
Where("protocol NOT LIKE ?", "vnc%").
|
|
Order("created_at DESC").
|
|
Limit(limit).
|
|
Find(&sessions).Error
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return sessions, nil
|
|
}
|