Files
oneterm/backend/internal/service/file/file.go

654 lines
18 KiB
Go

package file
import (
"archive/zip"
"context"
"fmt"
"io"
"io/fs"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/pkg/sftp"
"go.uber.org/zap"
"gorm.io/gorm"
"github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/internal/repository"
gsession "github.com/veops/oneterm/internal/session"
dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/pkg/logger"
)
// DefaultFileService is the process-wide file service. Its zero value is nil;
// InitFileService populates it.
var DefaultFileService IFileService

// InitFileService builds the global file service on top of the shared
// database handle dbpkg.DB and stores it in DefaultFileService.
func InitFileService() {
	DefaultFileService = NewFileService(repository.NewFileRepository(dbpkg.DB))
}
// init starts two background janitors that run for the lifetime of the process:
//   - every minute, legacy SFTP clients (GetFileManager) whose last use is older
//     than legacyIdleTTL are dropped from the manager's maps;
//   - every 5 minutes, the session file manager is asked to clean up sessions
//     inactive for more than 30 minutes.
func init() {
	const legacyIdleTTL = 10 * time.Minute

	// Legacy file manager cleanup.
	go func() {
		tk := time.NewTicker(time.Minute)
		defer tk.Stop()
		for range tk.C {
			func() {
				fm := GetFileManager()
				fm.mtx.Lock()
				defer fm.mtx.Unlock()

				// An entry is idle when its last use is before now-TTL. The former
				// comparison against now+TTL matched every past timestamp and so
				// evicted every cached client on each tick.
				cutoff := time.Now().Add(-legacyIdleTTL)
				for k, last := range fm.lastTime {
					if last.Before(cutoff) {
						// NOTE(review): the evicted client is only removed from the map,
						// not closed; confirm whether the sftps value needs Close().
						delete(fm.sftps, k)
						delete(fm.lastTime, k)
					}
				}
			}()
		}
	}()

	// Session-based file manager cleanup.
	go func() {
		ticker := time.NewTicker(5 * time.Minute) // Check every 5 minutes
		defer ticker.Stop()
		for range ticker.C {
			// Clean up sessions inactive for more than 30 minutes.
			GetSessionFileManager().CleanupInactiveSessions(30 * time.Minute)
		}
	}()
}
// NewFileService returns an IFileService whose file-history records are
// persisted through repo.
func NewFileService(repo repository.IFileRepository) IFileService {
	svc := &FileService{}
	svc.repo = repo
	return svc
}
// Legacy asset-based operations
// ReadDir lists the entries of dir on the asset, using the legacy file
// manager's SFTP client for the asset/account pair.
func (s *FileService) ReadDir(ctx context.Context, assetId, accountId int, dir string) ([]fs.FileInfo, error) {
	client, clientErr := GetFileManager().GetFileClient(assetId, accountId)
	if clientErr != nil {
		return nil, clientErr
	}
	entries, readErr := client.ReadDir(dir)
	return entries, readErr
}
// MkdirAll creates dir, including any missing parents, on the asset through
// the legacy file manager's SFTP client for the asset/account pair.
func (s *FileService) MkdirAll(ctx context.Context, assetId, accountId int, dir string) error {
	client, clientErr := GetFileManager().GetFileClient(assetId, accountId)
	if clientErr != nil {
		return clientErr
	}
	mkErr := client.MkdirAll(dir)
	return mkErr
}
// Create opens path for writing on the asset through the legacy file manager's
// SFTP client for the asset/account pair. If Stat on the parent directory fails,
// the parent is created with MkdirAll first. The current directory and the root
// are never created.
//
// The caller must Close the returned writer. On failure the writer is a true
// nil interface. Returning the client's Create result directly would wrap a
// nil *sftp.File in a non-nil io.WriteCloser.
func (s *FileService) Create(ctx context.Context, assetId, accountId int, path string) (io.WriteCloser, error) {
	cli, err := GetFileManager().GetFileClient(assetId, accountId)
	if err != nil {
		return nil, err
	}

	// Check if the parent directory exists; create it only if it does not.
	parentDir := filepath.Dir(path)
	if parentDir != "" && parentDir != "." && parentDir != "/" {
		if _, err := cli.Stat(parentDir); err != nil {
			if err := cli.MkdirAll(parentDir); err != nil {
				return nil, fmt.Errorf("failed to create parent directory %s: %w", parentDir, err)
			}
		}
	}

	file, err := cli.Create(path)
	if err != nil {
		return nil, err
	}
	return file, nil
}
// Open opens path for reading on the asset through the legacy file manager's
// SFTP client for the asset/account pair. The caller must Close the reader.
//
// On failure the reader is a true nil interface. Returning the client's Open
// result directly would wrap a nil *sftp.File in a non-nil io.ReadCloser,
// defeating `reader != nil` checks in callers.
func (s *FileService) Open(ctx context.Context, assetId, accountId int, path string) (io.ReadCloser, error) {
	cli, err := GetFileManager().GetFileClient(assetId, accountId)
	if err != nil {
		return nil, err
	}
	file, err := cli.Open(path)
	if err != nil {
		return nil, err
	}
	return file, nil
}
// Stat returns metadata for path on the asset, using the legacy file manager's
// SFTP client for the asset/account pair.
func (s *FileService) Stat(ctx context.Context, assetId, accountId int, path string) (fs.FileInfo, error) {
	client, clientErr := GetFileManager().GetFileClient(assetId, accountId)
	if clientErr != nil {
		return nil, clientErr
	}
	info, statErr := client.Stat(path)
	return info, statErr
}
// DownloadMultiple prepares a download of one or more entries of dir on the
// asset, using the legacy file manager's SFTP client for the asset/account pair.
//
// Every filename is validated with sanitizeFilename before it touches the
// remote filesystem. When exactly one entry is requested and it is not a
// directory, the file is streamed as-is with its real size. Otherwise the
// entries are packed into a streamed ZIP archive whose size is reported as -1.
// An empty selection is rejected rather than producing an empty archive.
//
// It returns the reader (which the caller must Close), the suggested download
// name, and the size.
func (s *FileService) DownloadMultiple(ctx context.Context, assetId, accountId int, dir string, filenames []string) (io.ReadCloser, string, int64, error) {
	if len(filenames) == 0 {
		return nil, "", 0, fmt.Errorf("no files specified")
	}

	cli, err := GetFileManager().GetFileClient(assetId, accountId)
	if err != nil {
		return nil, "", 0, err
	}

	// Validate and sanitize all filenames for security.
	sanitizedFilenames := make([]string, 0, len(filenames))
	for _, filename := range filenames {
		sanitized, err := sanitizeFilename(filename)
		if err != nil {
			// %w keeps the sanitizer's error inspectable with errors.Is/As.
			return nil, "", 0, fmt.Errorf("invalid filename '%s': %w", filename, err)
		}
		sanitizedFilenames = append(sanitizedFilenames, sanitized)
	}

	// A single regular file is returned directly, without a ZIP wrapper.
	if len(sanitizedFilenames) == 1 {
		fullPath := filepath.Join(dir, sanitizedFilenames[0])
		fileInfo, err := cli.Stat(fullPath)
		if err != nil {
			return nil, "", 0, err
		}
		if !fileInfo.IsDir() {
			reader, err := cli.Open(fullPath)
			if err != nil {
				return nil, "", 0, err
			}
			return reader, sanitizedFilenames[0], fileInfo.Size(), nil
		}
	}

	// Multiple entries, or a single directory: stream a ZIP.
	return s.createZipArchive(cli, dir, sanitizedFilenames)
}
// createZipArchive streams a ZIP archive of filenames (relative to baseDir)
// through an io.Pipe. A background goroutine builds the archive while the
// caller reads, so nothing is buffered in memory and the size is reported as -1.
//
// The archive is named "<entry>.zip" for a single entry and "download.zip"
// otherwise. Any failure while adding an entry, or while finalizing the archive
// (zip.Writer.Close writes the central directory), reaches the reader as a read
// error, so a truncated archive is never presented as a clean EOF. The caller
// must Close the returned reader. Closing it early makes the producer's writes
// fail, which ends the goroutine.
func (s *FileService) createZipArchive(cli *sftp.Client, baseDir string, filenames []string) (io.ReadCloser, string, int64, error) {
	zipName := "download.zip"
	if len(filenames) == 1 {
		zipName = filenames[0] + ".zip"
	}

	// Use a pipe for true streaming without memory buffering.
	pipeReader, pipeWriter := io.Pipe()

	go func() {
		zipWriter := zip.NewWriter(pipeWriter)

		for _, filename := range filenames {
			fullPath := filepath.Join(baseDir, filename)
			if err := s.addToZip(cli, zipWriter, baseDir, filename, fullPath); err != nil {
				logger.L().Error("Failed to add file to ZIP", zap.String("path", fullPath), zap.Error(err))
				pipeWriter.CloseWithError(err)
				return
			}
		}

		// Without the central directory the archive is unreadable, so the Close
		// error must reach the reader instead of being dropped by a defer.
		// CloseWithError(nil) behaves like Close: the reader sees io.EOF.
		pipeWriter.CloseWithError(zipWriter.Close())
	}()

	return pipeReader, zipName, -1, nil
}
// addToZip adds the remote entry at fullPath to the archive under relativePath.
// Directories go to addDirToZip (which recurses) and everything else goes to
// addFileToZip. The entry is classified with Stat rather than Lstat, so a
// symlink is treated as its target. baseDir is currently unused.
func (s *FileService) addToZip(cli *sftp.Client, zipWriter *zip.Writer, baseDir, relativePath, fullPath string) error {
	info, statErr := cli.Stat(fullPath)
	switch {
	case statErr != nil:
		return statErr
	case info.IsDir():
		return s.addDirToZip(cli, zipWriter, fullPath, relativePath)
	default:
		return s.addFileToZip(cli, zipWriter, fullPath, relativePath)
	}
}
// addFileToZip copies the remote file at fullPath into the archive as
// relativePath. The entry is Deflate-compressed and carries the remote
// modification time and mode.
//
// Only regular files are archived. Other non-directory entries (named pipes,
// sockets, device nodes) are skipped. Reading them over SFTP can block
// indefinitely (a FIFO with no writer) or never reach EOF (e.g. /dev/zero),
// which would stall the whole download.
func (s *FileService) addFileToZip(cli *sftp.Client, zipWriter *zip.Writer, fullPath, relativePath string) error {
	// Stat first: the metadata is preserved and the file type is checked
	// before anything is opened.
	fileInfo, err := cli.Stat(fullPath)
	if err != nil {
		return err
	}
	if !fileInfo.Mode().IsRegular() {
		return nil
	}

	file, err := cli.Open(fullPath)
	if err != nil {
		return err
	}
	defer file.Close()

	header := &zip.FileHeader{
		Name:     relativePath,
		Method:   zip.Deflate,
		Modified: fileInfo.ModTime(), // Preserve original modification time
	}
	header.SetMode(fileInfo.Mode())

	zipFile, err := zipWriter.CreateHeader(header)
	if err != nil {
		return err
	}

	_, err = io.Copy(zipFile, file)
	return err
}
// addDirToZip adds the remote directory fullPath to the archive under
// relativePath, recursing into each child through addToZip.
//
// A directory with no children gets an explicit "relativePath/" entry (stored
// uncompressed, with the remote modification time and mode) so that it
// survives extraction. Non-empty directories are implied by their children's
// paths and get no entry of their own.
func (s *FileService) addDirToZip(cli *sftp.Client, zipWriter *zip.Writer, fullPath, relativePath string) error {
	entries, err := cli.ReadDir(fullPath)
	if err != nil {
		return err
	}

	if len(entries) == 0 {
		// Only empty directories need their own metadata, so Stat here rather
		// than spending an extra SFTP round trip on every directory.
		dirInfo, err := cli.Stat(fullPath)
		if err != nil {
			return err
		}
		header := &zip.FileHeader{
			Name:     relativePath + "/",
			Method:   zip.Store,         // Directories are not compressed
			Modified: dirInfo.ModTime(), // Preserve original modification time
		}
		header.SetMode(dirInfo.Mode())
		_, err = zipWriter.CreateHeader(header)
		return err
	}

	for _, entry := range entries {
		// NOTE(review): addToZip classifies entries with Stat, so a symlink that
		// points at an ancestor directory is followed repeatedly, presumably
		// until the remote server rejects the growing path.
		entryFullPath := filepath.Join(fullPath, entry.Name())
		entryRelativePath := filepath.Join(relativePath, entry.Name())
		if err := s.addToZip(cli, zipWriter, fullPath, entryRelativePath, entryFullPath); err != nil {
			return err
		}
	}
	return nil
}
// Session-based operations
// SessionLS lists dir through the SFTP client bound to sessionId.
//
// For symbolic links, Target holds the link destination when ReadLink
// succeeds. A failed ReadLink is tolerated and leaves Target empty. ModTime is
// formatted as RFC 3339. An empty directory yields an empty, non-nil slice, so
// a JSON encoder emits [] rather than null.
func (s *FileService) SessionLS(ctx context.Context, sessionId, dir string) ([]FileInfo, error) {
	cli, err := GetSessionFileManager().GetSessionSFTP(sessionId)
	if err != nil {
		return nil, err
	}

	entries, err := cli.ReadDir(dir)
	if err != nil {
		return nil, err
	}

	fileInfos := make([]FileInfo, 0, len(entries))
	for _, entry := range entries {
		isLink := entry.Mode()&fs.ModeSymlink != 0

		var target string
		if isLink {
			// Best effort: an unreadable link still gets listed, just without a target.
			if linkTarget, err := cli.ReadLink(filepath.Join(dir, entry.Name())); err == nil {
				target = linkTarget
			}
		}

		fileInfos = append(fileInfos, FileInfo{
			Name:    entry.Name(),
			IsDir:   entry.IsDir(),
			Size:    entry.Size(),
			Mode:    entry.Mode().String(),
			IsLink:  isLink,
			Target:  target,
			ModTime: entry.ModTime().Format(time.RFC3339),
		})
	}
	return fileInfos, nil
}
// SessionMkdir creates dir, including any missing parents, through the SFTP
// client bound to sessionId.
func (s *FileService) SessionMkdir(ctx context.Context, sessionId, dir string) error {
	client, lookupErr := GetSessionFileManager().GetSessionSFTP(sessionId)
	if lookupErr != nil {
		return lookupErr
	}
	mkErr := client.MkdirAll(dir)
	return mkErr
}
// SessionUpload uploads file to targetPath over the session's SFTP client
// under a newly generated transfer ID (see SessionUploadWithID).
func (s *FileService) SessionUpload(ctx context.Context, sessionId, targetPath string, file io.Reader, filename string, size int64) error {
	const autoTransferID = "" // an empty ID asks SessionUploadWithID to allocate one
	return s.SessionUploadWithID(ctx, autoTransferID, sessionId, targetPath, file, filename, size)
}
// SessionUploadWithID streams file to targetPath over the SFTP client bound to
// sessionId, tracking progress in the session transfer manager.
//
// If transferID is non-empty, an existing transfer with that ID is reused;
// otherwise a new one is registered under it. If transferID is empty, a
// transfer with a generated ID is created. Parent directories of targetPath
// are created when Stat on them fails. On any failure after the transfer
// exists, the error is recorded on the transfer and returned. On success the
// transfer's progress is set to size.
//
// The remote file's Close error is reported rather than dropped by a defer,
// so a failed close is not counted as a completed upload.
func (s *FileService) SessionUploadWithID(ctx context.Context, transferID, sessionId, targetPath string, file io.Reader, filename string, size int64) error {
	client, err := GetSessionFileManager().GetSessionSFTP(sessionId)
	if err != nil {
		return err
	}

	var transfer *SessionFileTransfer
	transferManager := GetSessionTransferManager()
	if transferID != "" {
		// Use the existing transfer or create one under the caller's ID.
		transfer = transferManager.GetTransfer(transferID)
		if transfer == nil {
			transfer = transferManager.CreateTransferWithID(transferID, sessionId, filename, size, true)
			if transfer == nil {
				return fmt.Errorf("transfer ID already exists: %s", transferID)
			}
			transfer.UpdateProgress(0)
		}
	} else {
		// Create a new transfer with an auto-generated ID.
		transfer = transferManager.CreateTransfer(sessionId, filename, size, true)
		transfer.UpdateProgress(0)
	}

	// fail records err on the transfer so progress pollers see it, then returns it.
	fail := func(err error) error {
		transfer.SetError(err)
		return err
	}

	parentDir := filepath.Dir(targetPath)
	if parentDir != "" && parentDir != "." && parentDir != "/" {
		if _, err := client.Stat(parentDir); err != nil {
			if err := client.MkdirAll(parentDir); err != nil {
				return fail(fmt.Errorf("failed to create parent directory %s: %w", parentDir, err))
			}
		}
	}

	remoteFile, err := client.Create(targetPath)
	if err != nil {
		return fail(fmt.Errorf("failed to create remote file: %w", err))
	}

	progressWriter := NewProgressWriter(remoteFile, transfer)
	buffer := make([]byte, 32*1024)
	if _, err := io.CopyBuffer(progressWriter, file, buffer); err != nil {
		// The copy error is the one worth reporting; closing here is cleanup only.
		_ = remoteFile.Close()
		return fail(fmt.Errorf("failed to upload file: %w", err))
	}
	if err := remoteFile.Close(); err != nil {
		return fail(fmt.Errorf("failed to close remote file: %w", err))
	}

	transfer.UpdateProgress(size)
	return nil
}
// SessionDownload opens filePath for reading through the SFTP client bound to
// sessionId. It returns the reader together with the size reported by a Stat
// taken just before opening. The caller must Close the reader.
func (s *FileService) SessionDownload(ctx context.Context, sessionId, filePath string) (io.ReadCloser, int64, error) {
	client, lookupErr := GetSessionFileManager().GetSessionSFTP(sessionId)
	if lookupErr != nil {
		return nil, 0, lookupErr
	}

	meta, statErr := client.Stat(filePath)
	if statErr != nil {
		return nil, 0, statErr
	}
	size := meta.Size()

	reader, openErr := client.Open(filePath)
	if openErr != nil {
		return nil, 0, openErr
	}
	return reader, size, nil
}
// SessionDownloadMultiple prepares a download of one or more entries of dir
// through the SFTP client bound to sessionId.
//
// Every filename is validated with sanitizeFilename first. A single
// non-directory entry is streamed as-is with its real size. Anything else is
// packed into a streamed ZIP archive whose size is reported as -1. An empty
// selection is rejected rather than producing an empty archive.
//
// It returns the reader (which the caller must Close), the suggested download
// name, and the size.
func (s *FileService) SessionDownloadMultiple(ctx context.Context, sessionId, dir string, filenames []string) (io.ReadCloser, string, int64, error) {
	if len(filenames) == 0 {
		return nil, "", 0, fmt.Errorf("no files specified")
	}

	client, err := GetSessionFileManager().GetSessionSFTP(sessionId)
	if err != nil {
		return nil, "", 0, err
	}

	// Validate and sanitize filenames.
	sanitizedFilenames := make([]string, 0, len(filenames))
	for _, filename := range filenames {
		sanitized, err := sanitizeFilename(filename)
		if err != nil {
			// %w keeps the sanitizer's error inspectable with errors.Is/As.
			return nil, "", 0, fmt.Errorf("invalid filename '%s': %w", filename, err)
		}
		sanitizedFilenames = append(sanitizedFilenames, sanitized)
	}

	// A single regular file is returned directly, without a ZIP wrapper.
	if len(sanitizedFilenames) == 1 {
		fullPath := filepath.Join(dir, sanitizedFilenames[0])
		fileInfo, err := client.Stat(fullPath)
		if err != nil {
			return nil, "", 0, err
		}
		if !fileInfo.IsDir() {
			reader, err := client.Open(fullPath)
			if err != nil {
				return nil, "", 0, err
			}
			return reader, sanitizedFilenames[0], fileInfo.Size(), nil
		}
	}

	// Multiple entries, or a single directory: stream a ZIP.
	return s.createZipArchive(client, dir, sanitizedFilenames)
}
// Session lifecycle management
// InitSessionFileClient asks the session file manager to set up the SFTP
// client for sessionId against the given asset and account.
func (s *FileService) InitSessionFileClient(sessionId string, assetId, accountId int) error {
	manager := GetSessionFileManager()
	return manager.InitSessionSFTP(sessionId, assetId, accountId)
}

// CloseSessionFileClient asks the session file manager to close the SFTP
// client bound to sessionId.
func (s *FileService) CloseSessionFileClient(sessionId string) {
	manager := GetSessionFileManager()
	manager.CloseSessionSFTP(sessionId)
}

// IsSessionActive reports whether the session file manager considers
// sessionId active.
func (s *FileService) IsSessionActive(sessionId string) bool {
	manager := GetSessionFileManager()
	return manager.IsSessionActive(sessionId)
}
// Progress tracking
// GetSessionTransferProgress returns the progress entries the transfer manager
// tracks for sessionId. The error result is always nil.
func (s *FileService) GetSessionTransferProgress(ctx context.Context, sessionId string) ([]*SessionFileTransferProgress, error) {
	progress := GetSessionTransferManager().GetSessionProgress(sessionId)
	return progress, nil
}

// GetTransferProgress returns the progress of a single transfer, or an error
// when transferId is not known to the transfer manager.
func (s *FileService) GetTransferProgress(ctx context.Context, transferId string) (*SessionFileTransferProgress, error) {
	if transfer := GetSessionTransferManager().GetTransfer(transferId); transfer != nil {
		return transfer.GetProgress(), nil
	}
	return nil, fmt.Errorf("transfer not found")
}
// File history operations
// AddFileHistory persists a file-history record through the service's
// repository. It returns an error when the service has no repository.
func (s *FileService) AddFileHistory(ctx context.Context, history *model.FileHistory) error {
	if repo := s.repo; repo != nil {
		return repo.AddFileHistory(ctx, history)
	}
	return fmt.Errorf("repository not initialized")
}
// BuildFileHistoryQuery builds, without executing, a GORM query over
// model.FileHistory from the request's query parameters:
//   - dbpkg.FilterSearch over the dir and filename columns;
//   - dbpkg.FilterEqual on status, uid, asset_id, account_id and action;
//   - exact match on client_ip;
//   - inclusive created_at bounds from start and end.
//
// Parameters that are absent or empty add no condition.
func (s *FileService) BuildFileHistoryQuery(ctx *gin.Context) *gorm.DB {
	query := dbpkg.FilterSearch(ctx, dbpkg.DB.Model(&model.FileHistory{}), "dir", "filename")
	query = dbpkg.FilterEqual(ctx, query, "status", "uid", "asset_id", "account_id", "action")

	// Each optional parameter maps to one WHERE clause, applied in this order.
	optionalFilters := []struct {
		param  string
		clause string
	}{
		{param: "client_ip", clause: "client_ip = ?"},
		{param: "start", clause: "created_at >= ?"},
		{param: "end", clause: "created_at <= ?"},
	}
	for _, f := range optionalFilters {
		if value := ctx.Query(f.param); value != "" {
			query = query.Where(f.clause, value)
		}
	}
	return query
}
// RecordFileHistory writes a file-history entry for operation on dir/filename
// of the given asset and account.
//
// The acting user is read from the ACL session in ctx. If none can be resolved,
// a warning is logged (including the lookup error) and the entry is still
// written with a zero uid and empty user name. The client IP is filled in only
// when ctx is a *gin.Context. The optional sessionId only enriches the error
// log. A persistence error is logged and returned.
func (s *FileService) RecordFileHistory(ctx context.Context, operation, dir, filename string, assetId, accountId int, sessionId ...string) error {
	currentUser, err := acl.GetSessionFromCtx(ctx)
	if err != nil || currentUser == nil {
		// Still record the operation, but keep the lookup error for diagnosis
		// (zap.Error(nil) is a no-op field).
		logger.L().Warn("No user context found when recording file history",
			zap.String("operation", operation),
			zap.Error(err))
	}

	var (
		uid      int
		userName string
		clientIP string
	)
	if currentUser != nil {
		uid = currentUser.GetUid()
		userName = currentUser.GetUserName()
	}
	if ginCtx, ok := ctx.(*gin.Context); ok {
		clientIP = ginCtx.ClientIP()
	}

	history := &model.FileHistory{
		Uid:       uid,
		UserName:  userName,
		AssetId:   assetId,
		AccountId: accountId,
		ClientIp:  clientIP,
		Action:    s.GetActionCode(operation),
		Dir:       dir,
		Filename:  filename,
	}

	if err := s.AddFileHistory(ctx, history); err != nil {
		sessionIdStr := ""
		if len(sessionId) > 0 {
			sessionIdStr = sessionId[0]
		}
		logger.L().Error("Failed to record file history",
			zap.Error(err),
			zap.String("operation", operation),
			zap.String("sessionId", sessionIdStr),
			zap.Any("history", history))
		return err
	}
	return nil
}
// RecordFileHistoryBySession records a file-history entry for path, taking
// the asset and account from the online session sessionId. It logs a warning
// and returns an error when the session is not online.
func (s *FileService) RecordFileHistoryBySession(ctx context.Context, sessionId, operation, path string) error {
	online := gsession.GetOnlineSessionById(sessionId)
	if online == nil {
		logger.L().Warn("Cannot record file history: session not found", zap.String("sessionId", sessionId))
		return fmt.Errorf("session not found: %s", sessionId)
	}
	return s.RecordFileHistory(ctx, operation, filepath.Dir(path), filepath.Base(path),
		online.AssetId, online.AccountId, sessionId)
}
// Utility methods
// GetActionCode maps an operation name ("upload", "download", "mkdir") to its
// model.FILE_ACTION_* code. Any other name maps to model.FILE_ACTION_LS.
func (s *FileService) GetActionCode(operation string) int {
	code := model.FILE_ACTION_LS
	switch operation {
	case "upload":
		code = model.FILE_ACTION_UPLOAD
	case "download":
		code = model.FILE_ACTION_DOWNLOAD
	case "mkdir":
		code = model.FILE_ACTION_MKDIR
	}
	return code
}
// ValidateAndNormalizePath joins the user-supplied userPath onto basePath and
// returns the cleaned result. It returns an error when the path would escape
// basePath.
//
// Traversal is detected per path element, not by substring. "..", "../x" and
// "a/../../x" are rejected, while names that merely contain dots, such as
// "v1..2.txt" or "..hidden", are accepted. An absolute userPath is treated as
// relative to basePath.
func (s *FileService) ValidateAndNormalizePath(basePath, userPath string) (string, error) {
	// escapes reports whether a cleaned, relative path begins with a ".." element.
	// After filepath.Clean, ".." can only appear as leading elements.
	escapes := func(p string) bool {
		return p == ".." || strings.HasPrefix(p, ".."+string(filepath.Separator))
	}

	cleanPath := filepath.Clean(userPath)
	if escapes(cleanPath) {
		return "", fmt.Errorf("path contains directory traversal: %s", userPath)
	}

	fullPath := filepath.Join(basePath, cleanPath)

	// Defense in depth: confirm the joined path is still inside basePath.
	relPath, err := filepath.Rel(basePath, fullPath)
	if err != nil {
		return "", fmt.Errorf("invalid path: %s", userPath)
	}
	if escapes(relPath) {
		return "", fmt.Errorf("path outside base directory: %s", userPath)
	}

	return fullPath, nil
}
// GetRDPDrivePath returns, creating it if needed with mode 0755, the host
// directory used as the RDP drive for assetId: "<root>/asset_<assetId>".
//
// <root> is ONETERM_RDP_DRIVE_PATH when set. Otherwise it is C:\temp\oneterm\rdp
// on Windows and /tmp/oneterm/rdp elsewhere. On macOS, extended attributes are
// cleared recursively as a best effort, because they can interfere with Docker
// volume mounting.
func (s *FileService) GetRDPDrivePath(assetId int) (string, error) {
	// Priority 1: environment variable.
	drivePath := os.Getenv("ONETERM_RDP_DRIVE_PATH")

	// Priority 2: OS-specific default.
	if drivePath == "" {
		if runtime.GOOS == "windows" {
			// Use `C:\`, not "C:": filepath.Join("C:", "temp") yields the
			// drive-relative "C:temp", which resolves against the current directory.
			drivePath = filepath.Join(`C:\`, "temp", "oneterm", "rdp")
		} else {
			drivePath = filepath.Join("/tmp", "oneterm", "rdp")
		}
	}

	fullDrivePath := filepath.Join(drivePath, fmt.Sprintf("asset_%d", assetId))
	if err := os.MkdirAll(fullDrivePath, 0755); err != nil {
		return "", fmt.Errorf("failed to create RDP drive directory %s: %w", fullDrivePath, err)
	}

	if runtime.GOOS == "darwin" {
		// Best effort: failing to clear attributes (e.g. xattr unavailable) must
		// not prevent the drive from being used, so the error is deliberately ignored.
		_ = exec.Command("find", fullDrivePath, "-exec", "xattr", "-c", "{}", ";").Run()
	}

	return fullDrivePath, nil
}