refactor(backend): file service

This commit is contained in:
pycook
2025-06-08 20:38:02 +08:00
parent 070fccb5db
commit 6ba8d8056a
13 changed files with 2456 additions and 2201 deletions

View File

@@ -12,6 +12,7 @@ import (
"github.com/veops/oneterm/internal/api/router"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/internal/service"
fileservice "github.com/veops/oneterm/internal/service/file"
"github.com/veops/oneterm/pkg/config"
"github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/pkg/logger"
@@ -45,7 +46,7 @@ func initDB() {
func initServices() {
service.InitAuthorizationService()
service.InitFileService()
fileservice.InitFileService()
}

View File

@@ -21,7 +21,7 @@ import (
"github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/internal/guacd"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/internal/service"
fileservice "github.com/veops/oneterm/internal/service/file"
gsession "github.com/veops/oneterm/internal/session"
myErrors "github.com/veops/oneterm/pkg/errors"
"github.com/veops/oneterm/pkg/logger"
@@ -51,7 +51,7 @@ const (
func (c *Controller) GetFileHistory(ctx *gin.Context) {
currentUser, _ := acl.GetSessionFromCtx(ctx)
db := service.DefaultFileService.BuildFileHistoryQuery(ctx)
db := fileservice.DefaultFileService.BuildFileHistoryQuery(ctx)
// Apply user permissions - non-admin users can only see their own history
if !acl.IsAdmin(currentUser) {
@@ -87,9 +87,9 @@ func (c *Controller) FileLS(ctx *gin.Context) {
}
// Use global file service
info, err := service.DefaultFileService.ReadDir(ctx, sess.Session.AssetId, sess.Session.AccountId, ctx.Query("dir"))
info, err := fileservice.DefaultFileService.ReadDir(ctx, sess.Session.AssetId, sess.Session.AccountId, ctx.Query("dir"))
if err != nil {
if service.IsPermissionError(err) {
if fileservice.IsPermissionError(err) {
ctx.AbortWithError(http.StatusForbidden, fmt.Errorf("permission denied"))
} else {
ctx.AbortWithError(http.StatusBadRequest, &myErrors.ApiError{Code: myErrors.ErrInvalidArgument, Data: map[string]any{"err": err}})
@@ -110,7 +110,7 @@ func (c *Controller) FileLS(ctx *gin.Context) {
List: lo.Map(info, func(f fs.FileInfo, _ int) any {
var target string
if f.Mode()&os.ModeSymlink != 0 {
cli, err := service.GetFileManager().GetFileClient(sess.Session.AssetId, sess.Session.AccountId)
cli, err := fileservice.GetFileManager().GetFileClient(sess.Session.AssetId, sess.Session.AccountId)
if err == nil {
linkPath := filepath.Join(ctx.Query("dir"), f.Name())
if linkTarget, err := cli.ReadLink(linkPath); err == nil {
@@ -118,7 +118,7 @@ func (c *Controller) FileLS(ctx *gin.Context) {
}
}
}
return &service.FileInfo{
return &fileservice.FileInfo{
Name: f.Name(),
IsDir: f.IsDir(),
Size: f.Size(),
@@ -157,13 +157,13 @@ func (c *Controller) FileMkdir(ctx *gin.Context) {
}
// Use global file service
if err := service.DefaultFileService.MkdirAll(ctx, sess.Session.AssetId, sess.Session.AccountId, ctx.Query("dir")); err != nil {
if err := fileservice.DefaultFileService.MkdirAll(ctx, sess.Session.AssetId, sess.Session.AccountId, ctx.Query("dir")); err != nil {
ctx.AbortWithError(http.StatusInternalServerError, &myErrors.ApiError{Code: myErrors.ErrInvalidArgument, Data: map[string]any{"err": err}})
return
}
// Record file history using unified method
if err := service.DefaultFileService.RecordFileHistory(ctx, "mkdir", ctx.Query("dir"), "", sess.Session.AssetId, sess.Session.AccountId); err != nil {
if err := fileservice.DefaultFileService.RecordFileHistory(ctx, "mkdir", ctx.Query("dir"), "", sess.Session.AssetId, sess.Session.AccountId); err != nil {
logger.L().Error("Failed to record file history", zap.Error(err))
}
@@ -209,7 +209,7 @@ func (c *Controller) FileUpload(ctx *gin.Context) {
transferId = fmt.Sprintf("%d-%d-%d", sess.Session.AssetId, sess.Session.AccountId, time.Now().UnixNano())
}
service.CreateTransferProgress(transferId, "sftp")
fileservice.CreateTransferProgress(transferId, "sftp")
// Parse multipart form
if err := ctx.Request.ParseMultipartForm(MaxMemoryForParsing); err != nil {
@@ -243,7 +243,7 @@ func (c *Controller) FileUpload(ctx *gin.Context) {
}
// Update transfer progress with file size
service.UpdateTransferProgress(transferId, fileSize, 0, "")
fileservice.UpdateTransferProgress(transferId, fileSize, 0, "")
targetPath := filepath.Join(targetDir, filename)
@@ -290,10 +290,10 @@ func (c *Controller) FileUpload(ctx *gin.Context) {
}
// Phase 2: Transfer to target machine using SFTP (synchronous)
service.UpdateTransferProgress(transferId, fileSize, 0, "transferring")
fileservice.UpdateTransferProgress(transferId, fileSize, 0, "transferring")
if err := service.TransferToTarget(transferId, "", tempFilePath, targetPath, sess.Session.AssetId, sess.Session.AccountId); err != nil {
service.UpdateTransferProgress(transferId, 0, -1, "failed")
if err := fileservice.TransferToTarget(transferId, "", tempFilePath, targetPath, sess.Session.AssetId, sess.Session.AccountId); err != nil {
fileservice.UpdateTransferProgress(transferId, 0, -1, "failed")
os.Remove(tempFilePath)
ctx.JSON(http.StatusInternalServerError, HttpResponse{
Code: http.StatusInternalServerError,
@@ -303,13 +303,13 @@ func (c *Controller) FileUpload(ctx *gin.Context) {
}
// Mark transfer as completed
service.UpdateTransferProgress(transferId, 0, -1, "completed")
fileservice.UpdateTransferProgress(transferId, 0, -1, "completed")
// Clean up temp file after successful transfer
os.Remove(tempFilePath)
// Record file history using unified method
if err := service.DefaultFileService.RecordFileHistory(ctx, "upload", targetDir, filename, sess.Session.AssetId, sess.Session.AccountId); err != nil {
if err := fileservice.DefaultFileService.RecordFileHistory(ctx, "upload", targetDir, filename, sess.Session.AssetId, sess.Session.AccountId); err != nil {
logger.L().Error("Failed to record file history", zap.Error(err))
}
@@ -329,7 +329,7 @@ func (c *Controller) FileUpload(ctx *gin.Context) {
// Clean up progress record after a short delay
go func() {
time.Sleep(30 * time.Second) // Keep for 30 seconds for any delayed queries
service.CleanupTransferProgress(transferId, 0)
fileservice.CleanupTransferProgress(transferId, 0)
}()
}
@@ -378,9 +378,9 @@ func (c *Controller) FileDownload(ctx *gin.Context) {
return
}
reader, downloadFilename, fileSize, err := service.DefaultFileService.DownloadMultiple(ctx, sess.Session.AssetId, sess.Session.AccountId, ctx.Query("dir"), filenames)
reader, downloadFilename, fileSize, err := fileservice.DefaultFileService.DownloadMultiple(ctx, sess.Session.AssetId, sess.Session.AccountId, ctx.Query("dir"), filenames)
if err != nil {
if service.IsPermissionError(err) {
if fileservice.IsPermissionError(err) {
ctx.AbortWithError(http.StatusForbidden, &myErrors.ApiError{Code: myErrors.ErrNoPerm, Data: map[string]any{"err": err}})
} else {
ctx.AbortWithError(http.StatusInternalServerError, &myErrors.ApiError{Code: myErrors.ErrInvalidArgument, Data: map[string]any{"err": err}})
@@ -390,7 +390,7 @@ func (c *Controller) FileDownload(ctx *gin.Context) {
defer reader.Close()
// Record file operation history using unified method
if err := service.DefaultFileService.RecordFileHistory(ctx, "download", ctx.Query("dir"), strings.Join(filenames, ","), sess.Session.AssetId, sess.Session.AccountId); err != nil {
if err := fileservice.DefaultFileService.RecordFileHistory(ctx, "download", ctx.Query("dir"), strings.Join(filenames, ","), sess.Session.AssetId, sess.Session.AccountId); err != nil {
logger.L().Error("Failed to record file history", zap.Error(err))
}
@@ -448,7 +448,7 @@ func (c *Controller) RDPFileList(ctx *gin.Context) {
}
// Check if RDP drive is enabled
if !service.IsRDPDriveEnabled(tunnel) {
if !fileservice.IsRDPDriveEnabled(tunnel) {
ctx.JSON(http.StatusBadRequest, HttpResponse{
Code: http.StatusBadRequest,
Message: "RDP drive is not enabled for this session",
@@ -457,7 +457,7 @@ func (c *Controller) RDPFileList(ctx *gin.Context) {
}
// Send file list request through Guacamole protocol
files, err := service.RequestRDPFileList(tunnel, path)
files, err := fileservice.RequestRDPFileList(tunnel, path)
if err != nil {
logger.L().Error("Failed to get RDP file list", zap.Error(err))
ctx.JSON(http.StatusInternalServerError, HttpResponse{
@@ -501,7 +501,7 @@ func (c *Controller) RDPFileUpload(ctx *gin.Context) {
}
// Create progress record IMMEDIATELY when request starts
service.CreateTransferProgress(transferId, "rdp")
fileservice.CreateTransferProgress(transferId, "rdp")
tunnel, err := c.validateRDPAccess(ctx, sessionId)
if err != nil {
@@ -519,7 +519,7 @@ func (c *Controller) RDPFileUpload(ctx *gin.Context) {
return
}
if !service.IsRDPDriveEnabled(tunnel) {
if !fileservice.IsRDPDriveEnabled(tunnel) {
logger.L().Error("RDP drive is not enabled for session", zap.String("sessionId", sessionId))
ctx.JSON(http.StatusBadRequest, HttpResponse{
Code: http.StatusBadRequest,
@@ -528,7 +528,7 @@ func (c *Controller) RDPFileUpload(ctx *gin.Context) {
return
}
if !service.IsRDPUploadAllowed(tunnel) {
if !fileservice.IsRDPUploadAllowed(tunnel) {
logger.L().Error("RDP upload is disabled for session", zap.String("sessionId", sessionId))
ctx.JSON(http.StatusForbidden, HttpResponse{
Code: http.StatusForbidden,
@@ -653,7 +653,7 @@ func (c *Controller) RDPFileUpload(ctx *gin.Context) {
fullPath := filepath.Join(targetPath, filename)
// Update progress record with file size
service.UpdateTransferProgress(transferId, fileSize, 0, "transferring")
fileservice.UpdateTransferProgress(transferId, fileSize, 0, "transferring")
// Open temp file for reading and upload synchronously
tempFile, err := os.Open(tempFilePath)
@@ -668,9 +668,9 @@ func (c *Controller) RDPFileUpload(ctx *gin.Context) {
defer tempFile.Close()
// Perform RDP upload synchronously
err = service.UploadRDPFileStreamWithID(tunnel, transferId, sessionId, fullPath, tempFile, fileSize)
err = fileservice.UploadRDPFileStreamWithID(tunnel, transferId, sessionId, fullPath, tempFile, fileSize)
if err != nil {
service.UpdateTransferProgress(transferId, 0, -1, "failed")
fileservice.UpdateTransferProgress(transferId, 0, -1, "failed")
os.Remove(tempFilePath)
ctx.JSON(http.StatusInternalServerError, HttpResponse{
Code: http.StatusInternalServerError,
@@ -683,7 +683,7 @@ func (c *Controller) RDPFileUpload(ctx *gin.Context) {
os.Remove(tempFilePath)
// Record file history using session-based method
if err := service.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "upload", fullPath); err != nil {
if err := fileservice.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "upload", fullPath); err != nil {
logger.L().Error("Failed to record file history", zap.Error(err))
}
@@ -708,7 +708,7 @@ func (c *Controller) RDPFileUpload(ctx *gin.Context) {
// Clean up progress record after a short delay
go func() {
time.Sleep(30 * time.Second) // Keep for 30 seconds for any delayed queries
service.CleanupTransferProgress(transferId, 0)
fileservice.CleanupTransferProgress(transferId, 0)
}()
}
@@ -742,7 +742,7 @@ func (c *Controller) RDPFileDownload(ctx *gin.Context) {
return
}
if !service.IsRDPDriveEnabled(tunnel) {
if !fileservice.IsRDPDriveEnabled(tunnel) {
ctx.JSON(http.StatusForbidden, HttpResponse{
Code: http.StatusForbidden,
Message: "Drive redirection not enabled",
@@ -750,7 +750,7 @@ func (c *Controller) RDPFileDownload(ctx *gin.Context) {
return
}
if !service.IsRDPDownloadAllowed(tunnel) {
if !fileservice.IsRDPDownloadAllowed(tunnel) {
ctx.JSON(http.StatusForbidden, HttpResponse{
Code: http.StatusForbidden,
Message: "File download not allowed",
@@ -803,7 +803,7 @@ func (c *Controller) RDPFileDownload(ctx *gin.Context) {
if len(filenames) == 1 {
// Single file download (memory-efficient streaming)
path := filepath.Join(dir, filenames[0])
reader, fileSize, err = service.DownloadRDPFile(tunnel, path)
reader, fileSize, err = fileservice.DownloadRDPFile(tunnel, path)
if err != nil {
ctx.JSON(http.StatusInternalServerError, HttpResponse{
Code: http.StatusInternalServerError,
@@ -815,7 +815,7 @@ func (c *Controller) RDPFileDownload(ctx *gin.Context) {
downloadFilename = filenames[0]
} else {
// Multiple files download as ZIP
reader, downloadFilename, fileSize, err = service.DownloadRDPMultiple(tunnel, dir, filenames)
reader, downloadFilename, fileSize, err = fileservice.DownloadRDPMultiple(tunnel, dir, filenames)
if err != nil {
ctx.JSON(http.StatusInternalServerError, HttpResponse{
Code: http.StatusInternalServerError,
@@ -827,7 +827,7 @@ func (c *Controller) RDPFileDownload(ctx *gin.Context) {
defer reader.Close()
// Record file operation history using session-based method
if err := service.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "download", filepath.Join(dir, strings.Join(filenames, ","))); err != nil {
if err := fileservice.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "download", filepath.Join(dir, strings.Join(filenames, ","))); err != nil {
logger.L().Error("Failed to record file history", zap.Error(err))
}
@@ -861,13 +861,13 @@ func (c *Controller) RDPFileDownload(ctx *gin.Context) {
// @Accept json
// @Produce json
// @Param session_id path string true "Session ID"
// @Param request body service.RDPMkdirRequest true "Directory creation request"
// @Param request body fileservice.RDPMkdirRequest true "Directory creation request"
// @Success 200 {object} HttpResponse
// @Router /rdp/sessions/{session_id}/files/mkdir [post]
func (c *Controller) RDPFileMkdir(ctx *gin.Context) {
sessionId := ctx.Param("session_id")
var req service.RDPMkdirRequest
var req fileservice.RDPMkdirRequest
if err := ctx.ShouldBindJSON(&req); err != nil {
ctx.JSON(http.StatusBadRequest, HttpResponse{
Code: http.StatusBadRequest,
@@ -893,7 +893,7 @@ func (c *Controller) RDPFileMkdir(ctx *gin.Context) {
}
// Check if upload is allowed (mkdir is considered an upload operation)
if !service.IsRDPUploadAllowed(tunnel) {
if !fileservice.IsRDPUploadAllowed(tunnel) {
ctx.JSON(http.StatusForbidden, HttpResponse{
Code: http.StatusForbidden,
Message: "Directory creation is disabled for this session",
@@ -902,7 +902,7 @@ func (c *Controller) RDPFileMkdir(ctx *gin.Context) {
}
// Send mkdir request through Guacamole protocol
err := service.CreateRDPDirectory(tunnel, req.Path)
err := fileservice.CreateRDPDirectory(tunnel, req.Path)
if err != nil {
logger.L().Error("Failed to create directory in RDP session", zap.Error(err))
ctx.JSON(http.StatusInternalServerError, HttpResponse{
@@ -913,7 +913,7 @@ func (c *Controller) RDPFileMkdir(ctx *gin.Context) {
}
// Record file operation history using session-based method
if err := service.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "mkdir", req.Path); err != nil {
if err := fileservice.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "mkdir", req.Path); err != nil {
logger.L().Error("Failed to record file history", zap.Error(err))
}
@@ -967,7 +967,7 @@ func (c *Controller) SftpFileLS(ctx *gin.Context) {
}
// Check if session is active
if !service.DefaultFileService.IsSessionActive(sessionId) {
if !fileservice.DefaultFileService.IsSessionActive(sessionId) {
ctx.JSON(http.StatusNotFound, HttpResponse{
Code: http.StatusNotFound,
Message: "Session not found or inactive",
@@ -995,14 +995,14 @@ func (c *Controller) SftpFileLS(ctx *gin.Context) {
}
// Use session-based file service
fileInfos, err := service.DefaultFileService.SessionLS(ctx, sessionId, dir)
fileInfos, err := fileservice.DefaultFileService.SessionLS(ctx, sessionId, dir)
if err != nil {
if errors.Is(err, service.ErrSessionNotFound) {
if errors.Is(err, fileservice.ErrSessionNotFound) {
ctx.JSON(http.StatusNotFound, HttpResponse{
Code: http.StatusNotFound,
Message: "Session not found",
})
} else if service.IsPermissionError(err) {
} else if fileservice.IsPermissionError(err) {
ctx.JSON(http.StatusForbidden, HttpResponse{
Code: http.StatusForbidden,
Message: "Permission denied",
@@ -1019,7 +1019,7 @@ func (c *Controller) SftpFileLS(ctx *gin.Context) {
// Filter hidden files unless show_hidden is true
showHidden := cast.ToBool(ctx.Query("show_hidden"))
if !showHidden {
var filtered []service.FileInfo
var filtered []fileservice.FileInfo
for _, f := range fileInfos {
if !strings.HasPrefix(f.Name, ".") {
filtered = append(filtered, f)
@@ -1030,7 +1030,7 @@ func (c *Controller) SftpFileLS(ctx *gin.Context) {
res := &ListData{
Count: int64(len(fileInfos)),
List: lo.Map(fileInfos, func(f service.FileInfo, _ int) any { return f }),
List: lo.Map(fileInfos, func(f fileservice.FileInfo, _ int) any { return f }),
}
ctx.JSON(http.StatusOK, NewHttpResponseWithData(res))
}
@@ -1055,7 +1055,7 @@ func (c *Controller) SftpFileMkdir(ctx *gin.Context) {
}
// Check if session is active
if !service.DefaultFileService.IsSessionActive(sessionId) {
if !fileservice.DefaultFileService.IsSessionActive(sessionId) {
ctx.JSON(http.StatusNotFound, HttpResponse{
Code: http.StatusNotFound,
Message: "Session not found or inactive",
@@ -1083,8 +1083,8 @@ func (c *Controller) SftpFileMkdir(ctx *gin.Context) {
}
// Use session-based file service
if err := service.DefaultFileService.SessionMkdir(ctx, sessionId, dir); err != nil {
if errors.Is(err, service.ErrSessionNotFound) {
if err := fileservice.DefaultFileService.SessionMkdir(ctx, sessionId, dir); err != nil {
if errors.Is(err, fileservice.ErrSessionNotFound) {
ctx.JSON(http.StatusNotFound, HttpResponse{
Code: http.StatusNotFound,
Message: "Session not found",
@@ -1099,7 +1099,7 @@ func (c *Controller) SftpFileMkdir(ctx *gin.Context) {
}
// Record history using session-based method
if err := service.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "mkdir", dir); err != nil {
if err := fileservice.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "mkdir", dir); err != nil {
logger.L().Error("Failed to record file history", zap.Error(err))
}
@@ -1113,7 +1113,7 @@ func (c *Controller) TransferProgressById(ctx *gin.Context) {
transferId := ctx.Param("transfer_id")
// First check unified progress tracking
progress, exists := service.GetTransferProgressById(transferId)
progress, exists := fileservice.GetTransferProgressById(transferId)
if exists {
// Calculate transfer progress
@@ -1136,7 +1136,7 @@ func (c *Controller) TransferProgressById(ctx *gin.Context) {
}
// Fallback: check RDP guacd transfer manager
rdpProgress, err := service.GetRDPTransferProgressById(transferId)
rdpProgress, err := fileservice.GetRDPTransferProgressById(transferId)
if err == nil {
ctx.JSON(http.StatusOK, HttpResponse{
Code: 0,
@@ -1181,7 +1181,7 @@ func (c *Controller) RDPFileTransferPrepare(ctx *gin.Context) {
}
// Create unified progress tracking entry
service.CreateTransferProgress(transferId, "rdp")
fileservice.CreateTransferProgress(transferId, "rdp")
ctx.JSON(http.StatusOK, HttpResponse{
Code: 0,
@@ -1221,10 +1221,10 @@ func (c *Controller) SftpFileUpload(ctx *gin.Context) {
transferId = fmt.Sprintf("%s-%d", sessionId, time.Now().UnixNano())
}
service.CreateTransferProgress(transferId, "sftp")
fileservice.CreateTransferProgress(transferId, "sftp")
// Validate session
if !service.DefaultFileService.IsSessionActive(sessionId) {
if !fileservice.DefaultFileService.IsSessionActive(sessionId) {
ctx.JSON(http.StatusNotFound, HttpResponse{
Code: http.StatusNotFound,
Message: "Session not found or inactive",
@@ -1280,7 +1280,7 @@ func (c *Controller) SftpFileUpload(ctx *gin.Context) {
}
// Update transfer progress with file size now that we have it
service.UpdateTransferProgress(transferId, fileSize, 0, "")
fileservice.UpdateTransferProgress(transferId, fileSize, 0, "")
// Phase 1: Save file to server temp directory
tempDir := filepath.Join(os.TempDir(), "oneterm-uploads", sessionId)
@@ -1327,11 +1327,11 @@ func (c *Controller) SftpFileUpload(ctx *gin.Context) {
targetPath := filepath.Join(targetDir, filename)
// Phase 2: Transfer to target machine using SFTP (synchronous)
service.UpdateTransferProgress(transferId, fileSize, 0, "transferring")
fileservice.UpdateTransferProgress(transferId, fileSize, 0, "transferring")
if err := service.TransferToTarget(transferId, sessionId, tempFilePath, targetPath, 0, 0); err != nil {
if err := fileservice.TransferToTarget(transferId, sessionId, tempFilePath, targetPath, 0, 0); err != nil {
// Mark transfer as failed and clean up
service.UpdateTransferProgress(transferId, 0, -1, "failed")
fileservice.UpdateTransferProgress(transferId, 0, -1, "failed")
os.Remove(tempFilePath)
ctx.JSON(http.StatusInternalServerError, HttpResponse{
Code: http.StatusInternalServerError,
@@ -1341,13 +1341,13 @@ func (c *Controller) SftpFileUpload(ctx *gin.Context) {
}
// Mark transfer as completed (success)
service.UpdateTransferProgress(transferId, 0, -1, "completed")
fileservice.UpdateTransferProgress(transferId, 0, -1, "completed")
// Clean up temp file after successful transfer
os.Remove(tempFilePath)
// Record file history using session-based method
if err := service.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "upload", filepath.Join(targetDir, filename)); err != nil {
if err := fileservice.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "upload", filepath.Join(targetDir, filename)); err != nil {
logger.L().Error("Failed to record file history", zap.Error(err))
}
@@ -1367,7 +1367,7 @@ func (c *Controller) SftpFileUpload(ctx *gin.Context) {
// Clean up progress record after a short delay
go func() {
time.Sleep(30 * time.Second) // Keep for 30 seconds for any delayed queries
service.CleanupTransferProgress(transferId, 0)
fileservice.CleanupTransferProgress(transferId, 0)
}()
}
@@ -1383,7 +1383,7 @@ func (c *Controller) SftpFileDownload(ctx *gin.Context) {
sessionId := ctx.Param("session_id")
// Check if session is active
if !service.DefaultFileService.IsSessionActive(sessionId) {
if !fileservice.DefaultFileService.IsSessionActive(sessionId) {
ctx.JSON(http.StatusNotFound, HttpResponse{
Code: http.StatusNotFound,
Message: "Session not found or inactive",
@@ -1436,14 +1436,14 @@ func (c *Controller) SftpFileDownload(ctx *gin.Context) {
return
}
reader, downloadFilename, fileSize, err := service.DefaultFileService.SessionDownloadMultiple(ctx, sessionId, ctx.Query("dir"), filenames)
reader, downloadFilename, fileSize, err := fileservice.DefaultFileService.SessionDownloadMultiple(ctx, sessionId, ctx.Query("dir"), filenames)
if err != nil {
if errors.Is(err, service.ErrSessionNotFound) {
if errors.Is(err, fileservice.ErrSessionNotFound) {
ctx.JSON(http.StatusNotFound, HttpResponse{
Code: http.StatusNotFound,
Message: "Session not found",
})
} else if service.IsPermissionError(err) {
} else if fileservice.IsPermissionError(err) {
ctx.JSON(http.StatusForbidden, HttpResponse{
Code: http.StatusForbidden,
Message: "Permission denied",
@@ -1459,7 +1459,7 @@ func (c *Controller) SftpFileDownload(ctx *gin.Context) {
defer reader.Close()
// Record file operation history using session-based method
if err := service.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "download", filepath.Join(ctx.Query("dir"), strings.Join(filenames, ","))); err != nil {
if err := fileservice.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "download", filepath.Join(ctx.Query("dir"), strings.Join(filenames, ","))); err != nil {
logger.L().Error("Failed to record file history", zap.Error(err))
}

View File

@@ -22,7 +22,9 @@ import (
"github.com/veops/oneterm/internal/connector/protocols"
"github.com/veops/oneterm/internal/connector/protocols/db"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/internal/repository"
"github.com/veops/oneterm/internal/service"
fileservice "github.com/veops/oneterm/internal/service/file"
gsession "github.com/veops/oneterm/internal/session"
myErrors "github.com/veops/oneterm/pkg/errors"
"github.com/veops/oneterm/pkg/logger"
@@ -163,7 +165,7 @@ func DoConnect(ctx *gin.Context, ws *websocket.Conn) (sess *gsession.Session, er
currentUser, _ := acl.GetSessionFromCtx(ctx)
assetId, accountId := cast.ToInt(ctx.Param("asset_id")), cast.ToInt(ctx.Param("account_id"))
asset, account, gateway, err := service.GetAAG(assetId, accountId)
asset, account, gateway, err := repository.GetAAG(assetId, accountId)
if err != nil {
return
}
@@ -265,7 +267,7 @@ func DoConnect(ctx *gin.Context, ws *websocket.Conn) (sess *gsession.Session, er
// Only for SSH-based protocols that support SFTP
protocol := strings.Split(sess.Protocol, ":")[0]
if protocol == "ssh" {
if err := service.DefaultFileService.InitSessionFileClient(sess.SessionId, sess.AssetId, sess.AccountId); err != nil {
if err := fileservice.DefaultFileService.InitSessionFileClient(sess.SessionId, sess.AssetId, sess.AccountId); err != nil {
logger.L().Warn("Failed to initialize session file client",
zap.String("sessionId", sess.SessionId),
zap.Int("assetId", sess.AssetId),
@@ -296,7 +298,7 @@ func HandleTerm(sess *gsession.Session, ctx *gin.Context) (err error) {
// Clean up session-based file client (only for SSH-based protocols)
protocol := strings.Split(sess.Protocol, ":")[0]
if protocol == "ssh" {
service.DefaultFileService.CloseSessionFileClient(sess.SessionId)
fileservice.DefaultFileService.CloseSessionFileClient(sess.SessionId)
// Clear SSH client from session to ensure proper cleanup
sess.ClearSSHClient()
}

View File

@@ -12,7 +12,7 @@ import (
gossh "golang.org/x/crypto/ssh"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/internal/service"
"github.com/veops/oneterm/internal/repository"
gsession "github.com/veops/oneterm/internal/session"
"github.com/veops/oneterm/internal/tunneling"
"github.com/veops/oneterm/pkg/logger"
@@ -33,7 +33,7 @@ func ConnectSsh(ctx *gin.Context, sess *gsession.Session, asset *model.Asset, ac
return
}
auth, err := service.GetAuth(account)
auth, err := repository.GetAuth(account)
if err != nil {
return
}

View File

@@ -16,7 +16,7 @@ import (
myi18n "github.com/veops/oneterm/internal/i18n"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/internal/service"
fileservice "github.com/veops/oneterm/internal/service/file"
gsession "github.com/veops/oneterm/internal/session"
myErrors "github.com/veops/oneterm/pkg/errors"
"github.com/veops/oneterm/pkg/logger"
@@ -227,7 +227,7 @@ func OfflineSession(ctx *gin.Context, sessionId string, closer string) {
defer gsession.GetOnlineSession().Delete(sessionId)
// Clean up session-based file client
service.DefaultFileService.CloseSessionFileClient(sessionId)
fileservice.DefaultFileService.CloseSessionFileClient(sessionId)
session := gsession.GetOnlineSessionById(sessionId)
if session == nil {

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"errors"
"fmt"
"strings"
"github.com/gin-gonic/gin"
@@ -12,6 +13,7 @@ import (
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/pkg/config"
dbpkg "github.com/veops/oneterm/pkg/db"
"golang.org/x/crypto/ssh"
"gorm.io/gorm"
)
@@ -145,3 +147,27 @@ func HandleAccountIds(ctx context.Context, dbFind *gorm.DB, resIds []int) (db *g
return
}
// GetAuth creates SSH authentication method from account credentials.
// Password accounts map to ssh.Password; public-key accounts parse the
// stored private key (with passphrase when one is set).
func GetAuth(account *model.Account) (ssh.AuthMethod, error) {
	switch account.AccountType {
	case model.AUTHMETHOD_PASSWORD:
		return ssh.Password(account.Password), nil
	case model.AUTHMETHOD_PUBLICKEY:
		var (
			signer ssh.Signer
			err    error
		)
		if account.Phrase == "" {
			signer, err = ssh.ParsePrivateKey([]byte(account.Pk))
		} else {
			signer, err = ssh.ParsePrivateKeyWithPassphrase([]byte(account.Pk), []byte(account.Phrase))
		}
		if err != nil {
			return nil, err
		}
		return ssh.PublicKeys(signer), nil
	default:
		return nil, fmt.Errorf("invalid authmethod %d", account.AccountType)
	}
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/pkg/config"
dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/pkg/utils"
"gorm.io/gorm"
)
@@ -279,3 +280,27 @@ func HandleAssetIds(ctx context.Context, dbFind *gorm.DB, resIds []int) (db *gor
return
}
// GetAAG retrieves Asset, Account, and Gateway by their IDs with decrypted credentials.
// The gateway is only loaded when the asset references one (GatewayId != 0).
func GetAAG(assetId int, accountId int) (asset *model.Asset, account *model.Account, gateway *model.Gateway, err error) {
	asset, account, gateway = &model.Asset{}, &model.Account{}, &model.Gateway{}

	// Load the asset first; the gateway lookup below depends on it.
	if err = dbpkg.DB.Model(asset).Where("id = ?", assetId).First(asset).Error; err != nil {
		return
	}
	if err = dbpkg.DB.Model(account).Where("id = ?", accountId).First(account).Error; err != nil {
		return
	}

	// Credentials are stored encrypted at rest; decrypt before returning.
	account.Password = utils.DecryptAES(account.Password)
	account.Pk = utils.DecryptAES(account.Pk)
	account.Phrase = utils.DecryptAES(account.Phrase)

	// A zero GatewayId means the asset is reached directly (no jump host).
	if asset.GatewayId == 0 {
		return
	}
	if err = dbpkg.DB.Model(gateway).Where("id = ?", asset.GatewayId).First(gateway).Error; err != nil {
		return
	}
	gateway.Password = utils.DecryptAES(gateway.Password)
	gateway.Pk = utils.DecryptAES(gateway.Pk)
	gateway.Phrase = utils.DecryptAES(gateway.Phrase)
	return
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,666 @@
package file
import (
"archive/zip"
"context"
"fmt"
"io"
"io/fs"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/pkg/sftp"
"go.uber.org/zap"
"gorm.io/gorm"
"github.com/veops/oneterm/internal/acl"
"github.com/veops/oneterm/internal/model"
gsession "github.com/veops/oneterm/internal/session"
dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/pkg/logger"
)
// Global file service instance
var DefaultFileService IFileService

// InitFileService wires the global file service to a repository backed
// by the shared database handle. Call once during application startup.
func InitFileService() {
	repo := &FileRepository{db: dbpkg.DB}
	DefaultFileService = NewFileService(repo)
}
// init starts two background janitor goroutines:
//   - a per-minute sweep that evicts legacy file-manager SFTP clients
//     idle for more than 10 minutes
//   - a 5-minute sweep that cleans up session-based file clients
//     inactive for more than 30 minutes
func init() {
	// Legacy file manager cleanup
	go func() {
		tk := time.NewTicker(time.Minute)
		defer tk.Stop()
		for range tk.C {
			func() {
				GetFileManager().mtx.Lock()
				defer GetFileManager().mtx.Unlock()
				// Evict entries whose last use is older than 10 minutes.
				// BUG FIX: the previous condition compared against
				// time.Now().Add(+10*time.Minute) — a moment in the future —
				// so every cached client was evicted on every sweep.
				cutoff := time.Now().Add(-10 * time.Minute)
				for k, v := range GetFileManager().lastTime {
					if v.Before(cutoff) {
						// NOTE(review): evicted clients are not explicitly
						// Closed here — confirm the underlying connection is
						// released elsewhere.
						delete(GetFileManager().sftps, k)
						delete(GetFileManager().lastTime, k)
					}
				}
			}()
		}
	}()

	// Session-based file manager cleanup
	go func() {
		ticker := time.NewTicker(5 * time.Minute) // Check every 5 minutes
		defer ticker.Stop()
		for range ticker.C {
			// Clean up sessions inactive for more than 30 minutes
			GetSessionFileManager().CleanupInactiveSessions(30 * time.Minute)
		}
	}()
}
// NewFileService creates a new file service instance backed by the
// given repository.
func NewFileService(repo IFileRepository) IFileService {
	svc := &FileService{repo: repo}
	return svc
}
// Legacy asset-based operations

// ReadDir lists the entries of dir on the asset identified by
// assetId/accountId, using a file client from the global manager.
func (s *FileService) ReadDir(ctx context.Context, assetId, accountId int, dir string) ([]fs.FileInfo, error) {
	client, err := GetFileManager().GetFileClient(assetId, accountId)
	if err != nil {
		return nil, err
	}
	return client.ReadDir(dir)
}
// MkdirAll creates dir (and any missing parents) on the remote asset.
func (s *FileService) MkdirAll(ctx context.Context, assetId, accountId int, dir string) error {
	client, err := GetFileManager().GetFileClient(assetId, accountId)
	if err != nil {
		return err
	}
	return client.MkdirAll(dir)
}
// Create opens path for writing on the remote asset, first ensuring the
// parent directory exists (it is created when the Stat probe fails).
func (s *FileService) Create(ctx context.Context, assetId, accountId int, path string) (io.WriteCloser, error) {
	client, err := GetFileManager().GetFileClient(assetId, accountId)
	if err != nil {
		return nil, err
	}

	// Only probe/create the parent for non-trivial directories.
	parent := filepath.Dir(path)
	switch parent {
	case "", ".", "/":
		// Root-level or relative-current path: nothing to create.
	default:
		if _, statErr := client.Stat(parent); statErr != nil {
			// Directory doesn't exist, create it
			if mkErr := client.MkdirAll(parent); mkErr != nil {
				return nil, fmt.Errorf("failed to create parent directory %s: %w", parent, mkErr)
			}
		}
	}
	return client.Create(path)
}
// Open opens path for reading on the remote asset.
func (s *FileService) Open(ctx context.Context, assetId, accountId int, path string) (io.ReadCloser, error) {
	client, err := GetFileManager().GetFileClient(assetId, accountId)
	if err != nil {
		return nil, err
	}
	return client.Open(path)
}
// Stat returns file metadata for path on the remote asset.
func (s *FileService) Stat(ctx context.Context, assetId, accountId int, path string) (fs.FileInfo, error) {
	client, err := GetFileManager().GetFileClient(assetId, accountId)
	if err != nil {
		return nil, err
	}
	return client.Stat(path)
}
// DownloadMultiple handles downloading single file or multiple files/directories as ZIP.
// A lone regular file is streamed directly; anything else (several names,
// or a single directory) is packaged into a ZIP archive on the fly.
// Returns the content reader, the download filename, and the size
// (the size of a ZIP stream is not known up front).
func (s *FileService) DownloadMultiple(ctx context.Context, assetId, accountId int, dir string, filenames []string) (io.ReadCloser, string, int64, error) {
	client, err := GetFileManager().GetFileClient(assetId, accountId)
	if err != nil {
		return nil, "", 0, err
	}

	// Reject unsafe names (path traversal etc.) before touching the remote.
	sanitized := make([]string, 0, len(filenames))
	for _, name := range filenames {
		clean, sErr := sanitizeFilename(name)
		if sErr != nil {
			return nil, "", 0, fmt.Errorf("invalid filename '%s': %v", name, sErr)
		}
		sanitized = append(sanitized, clean)
	}

	// Single regular file: stream it directly without zipping.
	if len(sanitized) == 1 {
		fullPath := filepath.Join(dir, sanitized[0])
		info, statErr := client.Stat(fullPath)
		if statErr != nil {
			return nil, "", 0, statErr
		}
		if !info.IsDir() {
			reader, openErr := client.Open(fullPath)
			if openErr != nil {
				return nil, "", 0, openErr
			}
			return reader, sanitized[0], info.Size(), nil
		}
	}

	// Multiple files or contains directory, create ZIP
	return s.createZipArchive(client, dir, sanitized)
}
// createZipArchive streams the named files/directories under baseDir as a ZIP.
// The archive is produced on the fly through an io.Pipe — no full in-memory
// buffering — so the returned size is -1 (unknown).
func (s *FileService) createZipArchive(cli *sftp.Client, baseDir string, filenames []string) (io.ReadCloser, string, int64, error) {
	zipName := "download.zip"
	if len(filenames) == 1 {
		zipName = filenames[0] + ".zip"
	}
	pr, pw := io.Pipe()
	go func() {
		defer pw.Close()
		zw := zip.NewWriter(pw)
		defer zw.Close()
		for _, name := range filenames {
			full := filepath.Join(baseDir, name)
			if err := s.addToZip(cli, zw, baseDir, name, full); err != nil {
				logger.L().Error("Failed to add file to ZIP", zap.String("path", full), zap.Error(err))
				// Propagate the failure to the reader side of the pipe.
				pw.CloseWithError(err)
				return
			}
		}
	}()
	return pr, zipName, -1, nil
}
// addToZip dispatches one entry to the file- or directory-specific archiver
// depending on what fullPath refers to.
// NOTE(review): baseDir is currently unused; kept for signature compatibility.
func (s *FileService) addToZip(cli *sftp.Client, zipWriter *zip.Writer, baseDir, relativePath, fullPath string) error {
	info, err := cli.Stat(fullPath)
	if err != nil {
		return err
	}
	if info.IsDir() {
		return s.addDirToZip(cli, zipWriter, fullPath, relativePath)
	}
	return s.addFileToZip(cli, zipWriter, fullPath, relativePath)
}
// addFileToZip streams one remote file into the archive, preserving its mode
// and modification time in the entry header.
func (s *FileService) addFileToZip(cli *sftp.Client, zipWriter *zip.Writer, fullPath, relativePath string) error {
	info, err := cli.Stat(fullPath)
	if err != nil {
		return err
	}
	src, err := cli.Open(fullPath)
	if err != nil {
		return err
	}
	defer src.Close()
	header := &zip.FileHeader{
		Name:     relativePath,
		Method:   zip.Deflate,
		Modified: info.ModTime(), // keep the original mtime in the archive
	}
	header.SetMode(info.Mode())
	dst, err := zipWriter.CreateHeader(header)
	if err != nil {
		return err
	}
	_, err = io.Copy(dst, src)
	return err
}
// addDirToZip walks fullPath recursively, mirroring its structure under
// relativePath inside the archive. Empty directories get an explicit stored
// entry so they survive extraction.
func (s *FileService) addDirToZip(cli *sftp.Client, zipWriter *zip.Writer, fullPath, relativePath string) error {
	dirInfo, err := cli.Stat(fullPath)
	if err != nil {
		return err
	}
	entries, err := cli.ReadDir(fullPath)
	if err != nil {
		return err
	}
	if len(entries) == 0 {
		// Directories carry no payload: store (don't deflate) an entry that
		// preserves the original mode and mtime.
		header := &zip.FileHeader{
			Name:     relativePath + "/",
			Method:   zip.Store,
			Modified: dirInfo.ModTime(),
		}
		header.SetMode(dirInfo.Mode())
		_, hErr := zipWriter.CreateHeader(header)
		return hErr
	}
	for _, entry := range entries {
		childFull := filepath.Join(fullPath, entry.Name())
		childRel := filepath.Join(relativePath, entry.Name())
		if err := s.addToZip(cli, zipWriter, fullPath, childRel, childFull); err != nil {
			return err
		}
	}
	return nil
}
// Session-based operations
// SessionLS lists dir through the session's cached SFTP client, resolving
// symlink targets on a best-effort basis.
func (s *FileService) SessionLS(ctx context.Context, sessionId, dir string) ([]FileInfo, error) {
	client, err := GetSessionFileManager().GetSessionSFTP(sessionId)
	if err != nil {
		return nil, err
	}
	entries, err := client.ReadDir(dir)
	if err != nil {
		return nil, err
	}
	var infos []FileInfo
	for _, entry := range entries {
		isLink := entry.Mode()&fs.ModeSymlink != 0
		target := ""
		if isLink {
			// An unreadable link simply reports an empty target.
			if resolved, lErr := client.ReadLink(filepath.Join(dir, entry.Name())); lErr == nil {
				target = resolved
			}
		}
		infos = append(infos, FileInfo{
			Name:    entry.Name(),
			IsDir:   entry.IsDir(),
			Size:    entry.Size(),
			Mode:    entry.Mode().String(),
			IsLink:  isLink,
			Target:  target,
			ModTime: entry.ModTime().Format(time.RFC3339),
		})
	}
	return infos, nil
}
// SessionMkdir creates dir (and missing parents) via the session's SFTP client.
func (s *FileService) SessionMkdir(ctx context.Context, sessionId, dir string) error {
	client, err := GetSessionFileManager().GetSessionSFTP(sessionId)
	if err != nil {
		return err
	}
	return client.MkdirAll(dir)
}
// SessionUpload uploads with an auto-generated transfer ID; it delegates to
// SessionUploadWithID with an empty ID.
func (s *FileService) SessionUpload(ctx context.Context, sessionId, targetPath string, file io.Reader, filename string, size int64) error {
	const autoGeneratedID = ""
	return s.SessionUploadWithID(ctx, autoGeneratedID, sessionId, targetPath, file, filename, size)
}
// SessionUploadWithID uploads file content to targetPath over the session's
// SFTP client, tracking progress under transferID (a fresh transfer is created
// when transferID is empty or unknown).
func (s *FileService) SessionUploadWithID(ctx context.Context, transferID, sessionId, targetPath string, file io.Reader, filename string, size int64) error {
	client, err := GetSessionFileManager().GetSessionSFTP(sessionId)
	if err != nil {
		return err
	}
	manager := GetSessionTransferManager()
	var transfer *SessionFileTransfer
	if transferID == "" {
		// No caller-supplied ID: always start a fresh tracked transfer.
		transfer = manager.CreateTransfer(sessionId, filename, size, true)
		transfer.UpdateProgress(0)
	} else if transfer = manager.GetTransfer(transferID); transfer == nil {
		// Caller-supplied ID not seen before: register it now.
		transfer = manager.CreateTransferWithID(transferID, sessionId, filename, size, true)
		if transfer == nil {
			return fmt.Errorf("transfer ID already exists: %s", transferID)
		}
		transfer.UpdateProgress(0)
	}
	// Ensure the destination directory chain exists before creating the file.
	if parent := filepath.Dir(targetPath); parent != "" && parent != "." && parent != "/" {
		if _, statErr := client.Stat(parent); statErr != nil {
			if mkErr := client.MkdirAll(parent); mkErr != nil {
				wrapped := fmt.Errorf("failed to create parent directory %s: %w", parent, mkErr)
				transfer.SetError(wrapped)
				return wrapped
			}
		}
	}
	remote, err := client.Create(targetPath)
	if err != nil {
		wrapped := fmt.Errorf("failed to create remote file: %w", err)
		transfer.SetError(wrapped)
		return wrapped
	}
	defer remote.Close()
	progress := NewProgressWriter(remote, transfer)
	if _, err = io.CopyBuffer(progress, file, make([]byte, 32*1024)); err != nil {
		wrapped := fmt.Errorf("failed to upload file: %w", err)
		transfer.SetError(wrapped)
		return wrapped
	}
	transfer.UpdateProgress(size)
	return nil
}
// SessionDownload opens filePath via the session's SFTP client and returns the
// stream together with the file's size.
func (s *FileService) SessionDownload(ctx context.Context, sessionId, filePath string) (io.ReadCloser, int64, error) {
	client, err := GetSessionFileManager().GetSessionSFTP(sessionId)
	if err != nil {
		return nil, 0, err
	}
	info, err := client.Stat(filePath)
	if err != nil {
		return nil, 0, err
	}
	reader, err := client.Open(filePath)
	if err != nil {
		return nil, 0, err
	}
	return reader, info.Size(), nil
}
// SessionDownloadMultiple downloads the named entries under dir through the
// session's SFTP client: a single regular file streams directly, anything else
// is delivered as a streaming ZIP.
func (s *FileService) SessionDownloadMultiple(ctx context.Context, sessionId, dir string, filenames []string) (io.ReadCloser, string, int64, error) {
	client, err := GetSessionFileManager().GetSessionSFTP(sessionId)
	if err != nil {
		return nil, "", 0, err
	}
	// Reject unsafe names before touching the remote side.
	var safeNames []string
	for _, name := range filenames {
		clean, sErr := sanitizeFilename(name)
		if sErr != nil {
			return nil, "", 0, fmt.Errorf("invalid filename '%s': %v", name, sErr)
		}
		safeNames = append(safeNames, clean)
	}
	if len(safeNames) == 1 {
		target := filepath.Join(dir, safeNames[0])
		info, statErr := client.Stat(target)
		if statErr != nil {
			return nil, "", 0, statErr
		}
		if !info.IsDir() {
			reader, openErr := client.Open(target)
			if openErr != nil {
				return nil, "", 0, openErr
			}
			return reader, safeNames[0], info.Size(), nil
		}
	}
	return s.createZipArchive(client, dir, safeNames)
}
// Session lifecycle management
// InitSessionFileClient establishes the SFTP client backing sessionId.
func (s *FileService) InitSessionFileClient(sessionId string, assetId, accountId int) error {
	manager := GetSessionFileManager()
	return manager.InitSessionSFTP(sessionId, assetId, accountId)
}
// CloseSessionFileClient tears down the SFTP client for sessionId.
func (s *FileService) CloseSessionFileClient(sessionId string) {
	manager := GetSessionFileManager()
	manager.CloseSessionSFTP(sessionId)
}
// IsSessionActive reports whether sessionId has a live file client.
func (s *FileService) IsSessionActive(sessionId string) bool {
	manager := GetSessionFileManager()
	return manager.IsSessionActive(sessionId)
}
// Progress tracking
// GetSessionTransferProgress returns progress for every transfer in sessionId.
func (s *FileService) GetSessionTransferProgress(ctx context.Context, sessionId string) ([]*SessionFileTransferProgress, error) {
	progress := GetSessionTransferManager().GetSessionProgress(sessionId)
	return progress, nil
}
// GetTransferProgress returns progress for a single transfer by ID.
func (s *FileService) GetTransferProgress(ctx context.Context, transferId string) (*SessionFileTransferProgress, error) {
	transfer := GetSessionTransferManager().GetTransfer(transferId)
	if transfer == nil {
		return nil, fmt.Errorf("transfer not found")
	}
	return transfer.GetProgress(), nil
}
// File history operations
// AddFileHistory persists one history record through the injected repository.
func (s *FileService) AddFileHistory(ctx context.Context, history *model.FileHistory) error {
	if s.repo == nil {
		// Guards against use before InitFileService wired the repository in.
		return fmt.Errorf("repository not initialized")
	}
	return s.repo.AddFileHistory(ctx, history)
}
// BuildFileHistoryQuery assembles a gorm query over FileHistory from the
// request's search, equality, client_ip and created_at range parameters.
func (s *FileService) BuildFileHistoryQuery(ctx *gin.Context) *gorm.DB {
	query := dbpkg.DB.Model(&model.FileHistory{})
	// Fuzzy search across directory and filename columns.
	query = dbpkg.FilterSearch(ctx, query, "dir", "filename")
	// Exact-match filters taken verbatim from the query string.
	query = dbpkg.FilterEqual(ctx, query, "status", "uid", "asset_id", "account_id", "action")
	if ip := ctx.Query("client_ip"); ip != "" {
		query = query.Where("client_ip = ?", ip)
	}
	// Optional creation-time window.
	if from := ctx.Query("start"); from != "" {
		query = query.Where("created_at >= ?", from)
	}
	if to := ctx.Query("end"); to != "" {
		query = query.Where("created_at <= ?", to)
	}
	return query
}
// RecordFileHistory persists one file-operation audit record, pulling the
// acting user and client IP from ctx when available. The optional sessionId
// variadic is used only to enrich error logs.
func (s *FileService) RecordFileHistory(ctx context.Context, operation, dir, filename string, assetId, accountId int, sessionId ...string) error {
	currentUser, err := acl.GetSessionFromCtx(ctx)
	if err != nil || currentUser == nil {
		// Still record the operation, just without user attribution.
		logger.L().Warn("No user context found when recording file history", zap.String("operation", operation))
	}
	var (
		uid      int
		userName string
		clientIP string
	)
	if currentUser != nil {
		uid = currentUser.GetUid()
		userName = currentUser.GetUserName()
	}
	// The client IP is only recoverable when ctx is a gin request context.
	if ginCtx, ok := ctx.(*gin.Context); ok {
		clientIP = ginCtx.ClientIP()
	}
	record := &model.FileHistory{
		Uid:       uid,
		UserName:  userName,
		AssetId:   assetId,
		AccountId: accountId,
		ClientIp:  clientIP,
		Action:    s.GetActionCode(operation),
		Dir:       dir,
		Filename:  filename,
	}
	if err := s.AddFileHistory(ctx, record); err != nil {
		sid := ""
		if len(sessionId) > 0 {
			sid = sessionId[0]
		}
		logger.L().Error("Failed to record file history",
			zap.Error(err),
			zap.String("operation", operation),
			zap.String("sessionId", sid),
			zap.Any("history", record))
		return err
	}
	return nil
}
// RecordFileHistoryBySession records a file operation using asset/account
// information resolved from the live session.
func (s *FileService) RecordFileHistoryBySession(ctx context.Context, sessionId, operation, path string) error {
	onlineSession := gsession.GetOnlineSessionById(sessionId)
	if onlineSession == nil {
		logger.L().Warn("Cannot record file history: session not found", zap.String("sessionId", sessionId))
		return fmt.Errorf("session not found: %s", sessionId)
	}
	// Split path into the (dir, filename) pair the history schema expects.
	return s.RecordFileHistory(ctx, operation,
		filepath.Dir(path), filepath.Base(path),
		onlineSession.AssetId, onlineSession.AccountId, sessionId)
}
// Utility methods
// GetActionCode maps an operation keyword to its FileHistory action constant;
// unrecognized operations fall back to the "ls" action.
func (s *FileService) GetActionCode(operation string) int {
	codes := map[string]int{
		"upload":   model.FILE_ACTION_UPLOAD,
		"download": model.FILE_ACTION_DOWNLOAD,
		"mkdir":    model.FILE_ACTION_MKDIR,
	}
	if code, ok := codes[operation]; ok {
		return code
	}
	return model.FILE_ACTION_LS
}
// ValidateAndNormalizePath joins userPath onto basePath while refusing any
// input that would escape basePath (directory traversal).
func (s *FileService) ValidateAndNormalizePath(basePath, userPath string) (string, error) {
	cleaned := filepath.Clean(userPath)
	// Coarse filter: any ".." sequence is rejected outright.
	if strings.Contains(cleaned, "..") {
		return "", fmt.Errorf("path contains directory traversal: %s", userPath)
	}
	joined := filepath.Join(basePath, cleaned)
	// Definitive check: the joined result must stay inside basePath.
	rel, err := filepath.Rel(basePath, joined)
	if err != nil {
		return "", fmt.Errorf("invalid path: %s", userPath)
	}
	if strings.HasPrefix(rel, "..") {
		return "", fmt.Errorf("path outside base directory: %s", userPath)
	}
	return joined, nil
}
// GetRDPDrivePath returns (and ensures the existence of) the per-asset host
// directory backing the RDP mapped drive. Resolution order: the
// ONETERM_RDP_DRIVE_PATH env var, then an OS-specific default.
func (s *FileService) GetRDPDrivePath(assetId int) (string, error) {
	base := os.Getenv("ONETERM_RDP_DRIVE_PATH")
	if base == "" {
		if runtime.GOOS == "windows" {
			base = filepath.Join("C:", "temp", "oneterm", "rdp")
		} else {
			base = filepath.Join("/tmp", "oneterm", "rdp")
		}
	}
	assetDir := filepath.Join(base, fmt.Sprintf("asset_%d", assetId))
	if err := os.MkdirAll(assetDir, 0755); err != nil {
		return "", fmt.Errorf("failed to create RDP drive directory %s: %w", assetDir, err)
	}
	if runtime.GOOS == "darwin" {
		// macOS extended attributes can break Docker volume mounts; strip them
		// from the directory tree (best effort, errors ignored).
		exec.Command("find", assetDir, "-exec", "xattr", "-c", "{}", ";").Run()
	}
	return assetDir, nil
}
// FileRepository is a minimal gorm-backed implementation of the file history
// repository used by FileService.
type FileRepository struct {
	db *gorm.DB // gorm handle used for all persistence; presumably injected by a constructor outside this view — TODO confirm
}
// AddFileHistory inserts one history record via gorm.
func (r *FileRepository) AddFileHistory(ctx context.Context, history *model.FileHistory) error {
	result := r.db.Create(history)
	return result.Error
}
// BuildFileHistoryQuery returns a bare FileHistory query; request-level
// filtering is applied by the service layer.
func (r *FileRepository) BuildFileHistoryQuery(ctx *gin.Context) *gorm.DB {
	query := r.db.Model(&model.FileHistory{})
	return query
}

View File

@@ -0,0 +1,465 @@
package file
import (
"archive/zip"
"bytes"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/veops/oneterm/internal/guacd"
gsession "github.com/veops/oneterm/internal/session"
)
// RDP file operation functions
// NewRDPProgressWriter wraps writer so that every write updates the unified
// progress tracker for transferId.
func NewRDPProgressWriter(writer io.Writer, transfer *guacd.FileTransfer, transferId string) *RDPProgressWriter {
	pw := &RDPProgressWriter{
		writer:     writer,
		transfer:   transfer,
		transferId: transferId,
		// written starts at the zero value (0 bytes transferred).
	}
	return pw
}
// Write forwards p to the wrapped writer and, on success, records the
// cumulative byte count in the unified transfer-progress tracker.
func (pw *RDPProgressWriter) Write(p []byte) (int, error) {
	n, err := pw.writer.Write(p)
	if err == nil {
		pw.written += int64(n)
		UpdateTransferProgress(pw.transferId, 0, pw.written, "transferring")
	}
	return n, err
}
// IsRDPDriveEnabled reports whether the tunnel has drive redirection enabled.
func IsRDPDriveEnabled(tunnel *guacd.Tunnel) bool {
	if tunnel == nil || tunnel.Config == nil {
		return false
	}
	return tunnel.Config.Parameters["enable-drive"] == "true"
}
// IsRDPUploadAllowed reports whether uploads are permitted on this tunnel
// (allowed unless disable-upload is explicitly "true").
func IsRDPUploadAllowed(tunnel *guacd.Tunnel) bool {
	if tunnel == nil || tunnel.Config == nil {
		return false
	}
	return tunnel.Config.Parameters["disable-upload"] != "true"
}
// IsRDPDownloadAllowed reports whether downloads are permitted on this tunnel
// (allowed unless disable-download is explicitly "true").
func IsRDPDownloadAllowed(tunnel *guacd.Tunnel) bool {
	if tunnel == nil || tunnel.Config == nil {
		return false
	}
	return tunnel.Config.Parameters["disable-download"] != "true"
}
// RequestRDPFileList lists path for an RDP session. Currently implemented by
// reading the drive's backing directory on the host directly.
func RequestRDPFileList(tunnel *guacd.Tunnel, path string) ([]RDPFileInfo, error) {
	return RequestRDPFileListViaDirect(tunnel, path)
}
// RequestRDPFileListViaDirect lists a directory of the RDP mapped drive by
// reading the host filesystem that backs the session's drive.
func RequestRDPFileListViaDirect(tunnel *guacd.Tunnel, path string) ([]RDPFileInfo, error) {
	sessionId := tunnel.SessionId
	onlineSession := gsession.GetOnlineSessionById(sessionId)
	if onlineSession == nil {
		return nil, fmt.Errorf("session not found: %s", sessionId)
	}
	drivePath, err := DefaultFileService.GetRDPDrivePath(onlineSession.AssetId)
	if err != nil {
		return nil, fmt.Errorf("failed to get drive path: %w", err)
	}
	// Reject traversal attempts and anchor path under the drive directory.
	fullPath, err := DefaultFileService.ValidateAndNormalizePath(drivePath, path)
	if err != nil {
		return nil, fmt.Errorf("invalid path: %w", err)
	}
	entries, err := os.ReadDir(fullPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read directory: %w", err)
	}
	var files []RDPFileInfo
	for _, entry := range entries {
		info, infoErr := entry.Info()
		if infoErr != nil {
			continue // unreadable entry: skip rather than fail the whole listing
		}
		files = append(files, RDPFileInfo{
			Name:    entry.Name(),
			Size:    info.Size(),
			IsDir:   entry.IsDir(),
			ModTime: info.ModTime().Format(time.RFC3339),
		})
	}
	return files, nil
}
// DownloadRDPFile opens a single regular file from the session's RDP drive for
// streaming and returns the reader together with its size.
func DownloadRDPFile(tunnel *guacd.Tunnel, path string) (io.ReadCloser, int64, error) {
	sessionId := tunnel.SessionId
	onlineSession := gsession.GetOnlineSessionById(sessionId)
	if onlineSession == nil {
		return nil, 0, fmt.Errorf("session not found: %s", sessionId)
	}
	drivePath, err := DefaultFileService.GetRDPDrivePath(onlineSession.AssetId)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to get drive path: %w", err)
	}
	// Anchor path inside the drive directory and reject traversal.
	fullPath, err := DefaultFileService.ValidateAndNormalizePath(drivePath, path)
	if err != nil {
		return nil, 0, fmt.Errorf("invalid path: %w", err)
	}
	info, err := os.Stat(fullPath)
	if err != nil {
		return nil, 0, fmt.Errorf("file not found: %w", err)
	}
	if info.IsDir() {
		return nil, 0, fmt.Errorf("path is a directory, not a file")
	}
	// Stream from disk rather than loading the file into memory.
	file, err := os.Open(fullPath)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to open file: %w", err)
	}
	return file, info.Size(), nil
}
// DownloadRDPMultiple downloads the named entries under dir from the RDP
// drive: a single regular file streams directly; multiple entries or any
// directory are delivered as a ZIP.
func DownloadRDPMultiple(tunnel *guacd.Tunnel, dir string, filenames []string) (io.ReadCloser, string, int64, error) {
	var safeNames []string
	for _, name := range filenames {
		// Reject empty names and anything resembling path traversal.
		if name == "" || strings.Contains(name, "..") || strings.Contains(name, "/") {
			return nil, "", 0, fmt.Errorf("invalid filename: %s", name)
		}
		safeNames = append(safeNames, name)
	}
	if len(safeNames) != 1 {
		// Multiple entries are always archived.
		return CreateRDPZip(tunnel, dir, safeNames)
	}
	fullPath := filepath.Join(dir, safeNames[0])
	sessionId := tunnel.SessionId
	onlineSession := gsession.GetOnlineSessionById(sessionId)
	if onlineSession == nil {
		return nil, "", 0, fmt.Errorf("session not found: %s", sessionId)
	}
	drivePath, err := DefaultFileService.GetRDPDrivePath(onlineSession.AssetId)
	if err != nil {
		return nil, "", 0, fmt.Errorf("failed to get drive path: %w", err)
	}
	realPath, err := DefaultFileService.ValidateAndNormalizePath(drivePath, fullPath)
	if err != nil {
		return nil, "", 0, fmt.Errorf("invalid path: %w", err)
	}
	info, err := os.Stat(realPath)
	if err != nil {
		return nil, "", 0, fmt.Errorf("file not found: %w", err)
	}
	if info.IsDir() {
		// A single directory is delivered as an archive of its contents.
		return CreateRDPZip(tunnel, dir, safeNames)
	}
	reader, fileSize, err := DownloadRDPFile(tunnel, fullPath)
	if err != nil {
		return nil, "", 0, err
	}
	return reader, safeNames[0], fileSize, nil
}
// CreateRDPZip builds a ZIP of the named entries under dir and returns it as a
// reader with a timestamped filename. NOTE(review): the archive is assembled
// fully in memory before returning (unlike the streaming SFTP variant), which
// is why its exact size can be reported.
func CreateRDPZip(tunnel *guacd.Tunnel, dir string, filenames []string) (io.ReadCloser, string, int64, error) {
	var archive bytes.Buffer
	zw := zip.NewWriter(&archive)
	for _, name := range filenames {
		if err := AddToRDPZip(tunnel, zw, filepath.Join(dir, name), name); err != nil {
			zw.Close()
			return nil, "", 0, fmt.Errorf("failed to add %s to zip: %w", name, err)
		}
	}
	if err := zw.Close(); err != nil {
		return nil, "", 0, fmt.Errorf("failed to close zip: %w", err)
	}
	zipName := fmt.Sprintf("rdp_files_%s.zip", time.Now().Format("20060102_150405"))
	reader := io.NopCloser(bytes.NewReader(archive.Bytes()))
	return reader, zipName, int64(archive.Len()), nil
}
// AddToRDPZip adds a file or directory (recursively) from the session's RDP
// drive to the ZIP archive at zipPath.
//
// Fix: the session lookup and drive-path resolution are now performed once per
// call tree instead of once per recursive step — the original re-resolved the
// drive path for every nested entry, and GetRDPDrivePath also re-creates the
// directory and re-runs the darwin xattr sweep each time.
func AddToRDPZip(tunnel *guacd.Tunnel, zipWriter *zip.Writer, fullPath, zipPath string) error {
	sessionId := tunnel.SessionId
	onlineSession := gsession.GetOnlineSessionById(sessionId)
	if onlineSession == nil {
		return fmt.Errorf("session not found: %s", sessionId)
	}
	drivePath, err := DefaultFileService.GetRDPDrivePath(onlineSession.AssetId)
	if err != nil {
		return fmt.Errorf("failed to get drive path: %w", err)
	}
	return addRDPPathToZip(drivePath, zipWriter, fullPath, zipPath)
}

// addRDPPathToZip validates fullPath against drivePath and writes the file or
// directory tree rooted there into the archive, recursing for directories.
func addRDPPathToZip(drivePath string, zipWriter *zip.Writer, fullPath, zipPath string) error {
	// Anchor fullPath inside the drive directory and reject traversal.
	realPath, err := DefaultFileService.ValidateAndNormalizePath(drivePath, fullPath)
	if err != nil {
		return fmt.Errorf("invalid path: %w", err)
	}
	info, err := os.Stat(realPath)
	if err != nil {
		return fmt.Errorf("file not found: %w", err)
	}
	if !info.IsDir() {
		// Regular file: stream its content into the archive (memory-efficient).
		file, err := os.Open(realPath)
		if err != nil {
			return fmt.Errorf("failed to open file: %w", err)
		}
		defer file.Close()
		writer, err := zipWriter.Create(zipPath)
		if err != nil {
			return fmt.Errorf("failed to create zip entry: %w", err)
		}
		if _, err := io.Copy(writer, file); err != nil {
			return fmt.Errorf("failed to write file to zip: %w", err)
		}
		return nil
	}
	entries, err := os.ReadDir(realPath)
	if err != nil {
		return fmt.Errorf("failed to read directory: %w", err)
	}
	if len(entries) == 0 {
		// Preserve empty directories with an explicit stored entry.
		dirHeader := &zip.FileHeader{
			Name:   zipPath + "/",
			Method: zip.Store,
		}
		if _, err := zipWriter.CreateHeader(dirHeader); err != nil {
			return fmt.Errorf("failed to create directory entry: %w", err)
		}
		return nil
	}
	for _, entry := range entries {
		entryPath := filepath.Join(fullPath, entry.Name())
		entryZipPath := zipPath + "/" + entry.Name()
		if err := addRDPPathToZip(drivePath, zipWriter, entryPath, entryZipPath); err != nil {
			return err
		}
	}
	return nil
}
// CreateRDPDirectory creates path inside the session's RDP drive and nudges
// the remote Explorer to refresh its view.
func CreateRDPDirectory(tunnel *guacd.Tunnel, path string) error {
	sessionId := tunnel.SessionId
	onlineSession := gsession.GetOnlineSessionById(sessionId)
	if onlineSession == nil {
		return fmt.Errorf("session not found: %s", sessionId)
	}
	drivePath, err := DefaultFileService.GetRDPDrivePath(onlineSession.AssetId)
	if err != nil {
		return fmt.Errorf("failed to get drive path: %w", err)
	}
	// Anchor the new directory inside the drive and reject traversal.
	fullPath, err := DefaultFileService.ValidateAndNormalizePath(drivePath, path)
	if err != nil {
		return fmt.Errorf("invalid path: %w", err)
	}
	if err := os.MkdirAll(fullPath, 0755); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}
	NotifyRDPDirectoryRefresh(sessionId)
	return nil
}
// UploadRDPFileStreamWithID streams reader into the session's RDP drive at
// path, tracking progress under transferID (auto-generated when empty).
//
// Fixes over the previous version:
//   - the destination file was closed twice (a deferred Close plus an explicit
//     one, with the second error silently ignored); it is now closed exactly
//     once on every path;
//   - on failure the file is closed *before* os.Remove, which is required for
//     the removal to succeed on Windows.
func UploadRDPFileStreamWithID(tunnel *guacd.Tunnel, transferID, sessionId, path string, reader io.Reader, totalSize int64) error {
	onlineSession := gsession.GetOnlineSessionById(sessionId)
	if onlineSession == nil {
		return fmt.Errorf("session not found: %s", sessionId)
	}
	drivePath, err := DefaultFileService.GetRDPDrivePath(onlineSession.AssetId)
	if err != nil {
		return fmt.Errorf("failed to get drive path: %w", err)
	}
	// Create the progress tracker; it is intentionally not removed here so the
	// frontend can keep querying progress after completion.
	var transfer *guacd.FileTransfer
	if transferID != "" {
		transfer, err = guacd.DefaultFileTransferManager.CreateUploadWithID(transferID, sessionId, filepath.Base(path), drivePath)
	} else {
		transfer, err = guacd.DefaultFileTransferManager.CreateUpload(sessionId, filepath.Base(path), drivePath)
	}
	if err != nil {
		return fmt.Errorf("failed to create transfer tracker: %w", err)
	}
	transfer.SetSize(totalSize)
	// Anchor the target inside the drive directory and reject traversal.
	fullPath, err := DefaultFileService.ValidateAndNormalizePath(drivePath, path)
	if err != nil {
		return fmt.Errorf("invalid path: %w", err)
	}
	if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil {
		return fmt.Errorf("failed to create destination directory: %w", err)
	}
	destFile, err := os.Create(fullPath)
	if err != nil {
		return fmt.Errorf("failed to create file: %w", err)
	}
	progressWriter := NewRDPProgressWriter(destFile, transfer, transferID)
	written, err := io.CopyBuffer(progressWriter, reader, make([]byte, 32*1024))
	if err != nil {
		destFile.Close() // close before removing so removal also works on Windows
		os.Remove(fullPath)
		return fmt.Errorf("failed to write file: %w", err)
	}
	if totalSize > 0 && written != totalSize {
		destFile.Close()
		os.Remove(fullPath)
		// Mark as failed in unified tracking.
		UpdateTransferProgress(transferID, 0, -1, "failed")
		return fmt.Errorf("file size mismatch: expected %d, wrote %d", totalSize, written)
	}
	// Close before reporting completion so the data is fully written to disk
	// and visible to containers that mount this directory.
	if err := destFile.Close(); err != nil {
		return fmt.Errorf("failed to close file: %w", err)
	}
	if runtime.GOOS == "darwin" {
		// macOS extended attributes can interfere with Docker volume mounts.
		exec.Command("xattr", "-c", fullPath).Run()
		exec.Command("xattr", "-c", filepath.Dir(fullPath)).Run()
	}
	UpdateTransferProgress(transferID, 0, written, "completed")
	// Delayed refresh so filesystem operations settle before the RDP side is
	// told to re-read the directory.
	go func() {
		time.Sleep(500 * time.Millisecond)
		NotifyRDPDirectoryRefresh(sessionId)
	}()
	return nil
}
// GetRDPTransferProgressById looks up the progress of one RDP transfer.
// A literal nil is returned on error to avoid handing callers a non-nil
// interface wrapping a typed nil.
func GetRDPTransferProgressById(transferId string) (interface{}, error) {
	progress, err := guacd.DefaultFileTransferManager.GetTransferProgress(transferId)
	if err != nil {
		return nil, err
	}
	return progress, nil
}
// NotifyRDPDirectoryRefresh simulates an F5 key press in the RDP session so
// Windows Explorer re-reads the mapped drive contents. Best effort: missing
// sessions or write failures are silently ignored.
func NotifyRDPDirectoryRefresh(sessionId string) {
	onlineSession := gsession.GetOnlineSessionById(sessionId)
	if onlineSession == nil {
		return
	}
	tunnel := onlineSession.GuacdTunnel
	if tunnel == nil {
		return
	}
	const f5KeyCode = "65474"
	// Key down, short pause, key up — mirrors a real key press.
	if _, err := tunnel.WriteInstruction(guacd.NewInstruction("key", f5KeyCode, "1")); err != nil {
		return
	}
	time.Sleep(100 * time.Millisecond)
	tunnel.WriteInstruction(guacd.NewInstruction("key", f5KeyCode, "0"))
}

View File

@@ -0,0 +1,475 @@
package file
import (
"archive/zip"
"bytes"
"context"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"time"
"github.com/pkg/sftp"
"go.uber.org/zap"
"golang.org/x/crypto/ssh"
"github.com/veops/oneterm/internal/repository"
gsession "github.com/veops/oneterm/internal/session"
"github.com/veops/oneterm/internal/tunneling"
"github.com/veops/oneterm/pkg/logger"
)
// =============================================================================
// SFTP Operations - Managers defined in parent file service
// =============================================================================
// =============================================================================
// SFTP Upload/Download Operations with Progress Tracking
// =============================================================================
// TransferToTarget routes an upload either through an existing session's SFTP
// connection (when only a session ID is supplied) or through a fresh
// asset/account-based connection.
func TransferToTarget(transferId, sessionIdOrCustom, tempFilePath, targetPath string, assetId, accountId int) error {
	useSessionConnection := assetId == 0 && accountId == 0 && sessionIdOrCustom != ""
	if useSessionConnection {
		return SessionBasedTransfer(transferId, sessionIdOrCustom, tempFilePath, targetPath)
	}
	return AssetBasedTransfer(transferId, tempFilePath, targetPath, assetId, accountId)
}
// SessionBasedTransfer uploads through the session's pooled SFTP connection,
// lazily establishing one when the session has no client yet.
func SessionBasedTransfer(transferId, sessionId, tempFilePath, targetPath string) error {
	sessionFM := GetSessionFileManager()
	sftpClient, err := sessionFM.GetSessionSFTP(sessionId)
	if err != nil {
		// No cached client: bootstrap one from the live session's asset/account.
		onlineSession := gsession.GetOnlineSessionById(sessionId)
		if onlineSession == nil {
			return fmt.Errorf("session %s not found", sessionId)
		}
		if initErr := sessionFM.InitSessionSFTP(sessionId, onlineSession.AssetId, onlineSession.AccountId); initErr != nil {
			return fmt.Errorf("failed to initialize SFTP for session %s: %w", sessionId, initErr)
		}
		if sftpClient, err = sessionFM.GetSessionSFTP(sessionId); err != nil {
			return fmt.Errorf("failed to get SFTP client for session %s: %w", sessionId, err)
		}
	}
	// The client is owned by the SessionFileManager; do not close it here.
	return SftpUploadWithExistingClient(sftpClient, transferId, tempFilePath, targetPath)
}
// AssetBasedTransfer uploads over a dedicated SSH connection resolved from
// asset/account IDs (legacy path used when no live session exists).
func AssetBasedTransfer(transferId, tempFilePath, targetPath string, assetId, accountId int) error {
	asset, account, gateway, err := repository.GetAAG(assetId, accountId)
	if err != nil {
		return fmt.Errorf("failed to get asset/account info: %w", err)
	}
	sessionId := fmt.Sprintf("upload_%d_%d_%d", assetId, accountId, time.Now().UnixNano())
	ip, port, err := tunneling.Proxy(false, sessionId, "ssh", asset, gateway)
	if err != nil {
		return fmt.Errorf("failed to setup tunnel: %w", err)
	}
	auth, err := repository.GetAuth(account)
	if err != nil {
		return fmt.Errorf("failed to get auth: %w", err)
	}
	// Prefer the fastest negotiable algorithms: hardware-accelerated AES and
	// modern elliptic-curve key exchange minimize handshake and per-packet cost.
	sshConfig := &ssh.ClientConfig{
		User:            account.Account,
		Auth:            []ssh.AuthMethod{auth},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         30 * time.Second,
		Config: ssh.Config{
			Ciphers: []string{
				"aes128-ctr",                     // fastest on CPUs with AES-NI
				"aes128-gcm@openssh.com",         // hardware-accelerated AEAD
				"chacha20-poly1305@openssh.com",  // fast without AES-NI (e.g. ARM)
				"aes256-ctr",                     // high-performance fallback
			},
			MACs: []string{
				"hmac-sha2-256-etm@openssh.com", // encrypt-then-MAC
				"hmac-sha2-256",
			},
			KeyExchanges: []string{
				"curve25519-sha256@libssh.org",
				"curve25519-sha256",
				"ecdh-sha2-nistp256",
			},
		},
		HostKeyAlgorithms: []string{
			"rsa-sha2-256",
			"rsa-sha2-512",
			"ssh-ed25519",
		},
	}
	sshClient, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", ip, port), sshConfig)
	if err != nil {
		return fmt.Errorf("failed to connect SSH: %w", err)
	}
	defer sshClient.Close()
	return SftpUploadWithProgress(sshClient, transferId, tempFilePath, targetPath)
}
// SftpUploadWithProgress uploads file using optimized SFTP protocol with accurate progress tracking
func SftpUploadWithProgress(client *ssh.Client, transferId, localPath, remotePath string) error {
// Create SFTP client with maximum performance settings
sftpClient, err := sftp.NewClient(client,
sftp.MaxPacket(1024*32), // 32KB packets - maximum safe size for most servers
sftp.MaxConcurrentRequestsPerFile(64), // High concurrency for maximum throughput
sftp.UseConcurrentReads(true), // Enable concurrent reads for better performance
sftp.UseConcurrentWrites(true), // Enable concurrent writes for better performance
)
if err != nil {
logger.L().Error("Failed to create SFTP client",
zap.String("transferId", transferId),
zap.Error(err))
return fmt.Errorf("failed to create SFTP client: %w", err)
}
defer sftpClient.Close()
// Open local file
localFile, err := os.Open(localPath)
if err != nil {
return fmt.Errorf("failed to open local file: %w", err)
}
defer localFile.Close()
// Get file info
fileInfo, err := localFile.Stat()
if err != nil {
return fmt.Errorf("failed to get file info: %w", err)
}
// Create parent directory on remote if needed
parentDir := filepath.Dir(remotePath)
if parentDir != "" && parentDir != "." && parentDir != "/" {
if err := sftpClient.MkdirAll(parentDir); err != nil {
logger.L().Warn("Failed to create parent directory via SFTP", zap.Error(err))
}
}
// Create remote file
remoteFile, err := sftpClient.Create(remotePath)
if err != nil {
return fmt.Errorf("failed to create remote file: %w", err)
}
defer remoteFile.Close()
// Create progress tracking writer with SFTP-specific optimizations
progressWriter := NewFileProgressWriter(remoteFile, transferId)
// Transfer file content with ultra-high performance buffer for SFTP
// Use 2MB buffer to minimize round trips and maximize throughput
buffer := make([]byte, 2*1024*1024) // 2MB buffer for ultra-high SFTP performance
// Manual optimized copy loop to avoid io.CopyBuffer overhead
var transferred int64
for {
n, readErr := localFile.Read(buffer)
if n > 0 {
written, writeErr := progressWriter.Write(buffer[:n])
transferred += int64(written)
if writeErr != nil {
err = writeErr
break
}
}
if readErr != nil {
if readErr == io.EOF {
break // Normal end of file
}
err = readErr
break
}
}
if err != nil {
logger.L().Error("SFTP file transfer failed during copy",
zap.String("transferId", transferId),
zap.Int64("transferred", transferred),
zap.Int64("fileSize", fileInfo.Size()),
zap.Error(err))
return fmt.Errorf("failed to transfer file content via SFTP: %w", err)
}
// Force final progress update
UpdateTransferProgress(transferId, 0, transferred, "")
return nil
}
// SftpUploadWithExistingClient uploads localPath to remotePath using an
// already-established SFTP client (owned by the caller — it is not closed
// here), reporting byte-level progress under transferId.
//
// The copy is done with a manual read/write loop over a 2MB buffer rather
// than io.Copy so the FileProgressWriter can account for every chunk.
func SftpUploadWithExistingClient(client *sftp.Client, transferId, localPath, remotePath string) error {
	// Open local file
	localFile, err := os.Open(localPath)
	if err != nil {
		return fmt.Errorf("failed to open local file: %w", err)
	}
	defer localFile.Close()
	// Stat is used only for logging the expected size on failure.
	fileInfo, err := localFile.Stat()
	if err != nil {
		return fmt.Errorf("failed to get file info: %w", err)
	}
	// Best-effort creation of the remote parent directory; a failure here is
	// only logged because Create below will surface the real error if the
	// directory truly does not exist.
	parentDir := filepath.Dir(remotePath)
	if parentDir != "" && parentDir != "." && parentDir != "/" {
		if err := client.MkdirAll(parentDir); err != nil {
			logger.L().Warn("Failed to create parent directory via SFTP", zap.Error(err))
		}
	}
	// Create (truncate) the remote file.
	remoteFile, err := client.Create(remotePath)
	if err != nil {
		return fmt.Errorf("failed to create remote file: %w", err)
	}
	defer remoteFile.Close()
	// Wrap the remote file so each write updates the transfer's progress entry.
	progressWriter := NewFileProgressWriter(remoteFile, transferId)
	// 2MB buffer to minimize round trips and maximize SFTP throughput.
	buffer := make([]byte, 2*1024*1024) // 2MB buffer for ultra-high SFTP performance
	var transferred int64
	for {
		n, readErr := localFile.Read(buffer)
		if n > 0 {
			// io.Writer contract: a short write must come with a non-nil error,
			// so counting `written` and checking writeErr is sufficient.
			written, writeErr := progressWriter.Write(buffer[:n])
			transferred += int64(written)
			if writeErr != nil {
				err = writeErr
				break
			}
		}
		if readErr != nil {
			if readErr == io.EOF {
				break // Normal end of file
			}
			err = readErr
			break
		}
	}
	if err != nil {
		logger.L().Error("SFTP file transfer failed",
			zap.String("transferId", transferId),
			zap.Int64("transferred", transferred),
			zap.Int64("fileSize", fileInfo.Size()),
			zap.Error(err))
		return fmt.Errorf("failed to transfer file: %w", err)
	}
	// Force a final progress update so the entry reflects the full byte count.
	UpdateTransferProgress(transferId, 0, transferred, "")
	logger.L().Info("SFTP file transfer completed",
		zap.String("transferId", transferId),
		zap.String("remotePath", remotePath),
		zap.Int64("size", transferred))
	return nil
}
// =============================================================================
// SFTP Download Operations with ZIP Support
// =============================================================================

// SftpDownloadMultiple returns a reader for the requested files under dir.
// A single filename streams the file directly; multiple filenames are packed
// into an in-memory ZIP archive. Returns (reader, download filename, size).
func SftpDownloadMultiple(ctx context.Context, assetId, accountId int, dir string, filenames []string) (io.ReadCloser, string, int64, error) {
	cli, err := GetFileManager().GetFileClient(assetId, accountId)
	if err != nil {
		return nil, "", 0, fmt.Errorf("failed to get SFTP client: %w", err)
	}
	if len(filenames) == 1 {
		// Single file download — stream straight from the SFTP file handle.
		// NOTE(review): filepath.Join uses the OS separator; SFTP paths are
		// always '/'-separated, so on a Windows host this would build a wrong
		// remote path — confirm deployment is POSIX-only or switch to path.Join.
		fullPath := filepath.Join(dir, filenames[0])
		file, err := cli.Open(fullPath)
		if err != nil {
			return nil, "", 0, fmt.Errorf("failed to open file %s: %w", fullPath, err)
		}
		// Stat separately to report the exact size to the caller.
		info, err := cli.Stat(fullPath)
		if err != nil {
			file.Close()
			return nil, "", 0, fmt.Errorf("failed to get file info: %w", err)
		}
		return file, filenames[0], info.Size(), nil
	}
	// Multiple files - create ZIP
	return createSftpZipArchive(cli, dir, filenames)
}

// createSftpZipArchive builds a ZIP of the named files/directories entirely
// in memory and returns it as a ReadCloser with a generated archive name.
// NOTE(review): the whole archive is buffered in RAM; very large selections
// will spike memory — consider streaming to the response writer instead.
func createSftpZipArchive(cli *sftp.Client, baseDir string, filenames []string) (io.ReadCloser, string, int64, error) {
	// Create a buffer to write the ZIP archive
	var buffer bytes.Buffer
	zipWriter := zip.NewWriter(&buffer)
	for _, filename := range filenames {
		fullPath := filepath.Join(baseDir, filename)
		if err := addSftpFileToZip(cli, zipWriter, baseDir, filename, fullPath); err != nil {
			zipWriter.Close()
			return nil, "", 0, fmt.Errorf("failed to add %s to ZIP: %w", filename, err)
		}
	}
	if err := zipWriter.Close(); err != nil {
		return nil, "", 0, fmt.Errorf("failed to close ZIP writer: %w", err)
	}
	// Generate ZIP filename. (The len==1 branch is effectively dead — the only
	// caller in this file uses this function for multi-file selections.)
	var zipFilename string
	if len(filenames) == 1 {
		zipFilename = strings.TrimSuffix(filenames[0], filepath.Ext(filenames[0])) + ".zip"
	} else {
		zipFilename = fmt.Sprintf("sftp_files_%d_items.zip", len(filenames))
	}
	reader := bytes.NewReader(buffer.Bytes())
	return io.NopCloser(reader), zipFilename, int64(buffer.Len()), nil
}

// addSftpFileToZip adds one entry (file or directory, dispatched by Stat) to
// the archive. relativePath becomes the entry name inside the ZIP.
func addSftpFileToZip(cli *sftp.Client, zipWriter *zip.Writer, baseDir, relativePath, fullPath string) error {
	// Get file info
	info, err := cli.Stat(fullPath)
	if err != nil {
		return fmt.Errorf("failed to stat %s: %w", fullPath, err)
	}
	if info.IsDir() {
		// Recurse into directories.
		return addSftpDirToZip(cli, zipWriter, baseDir, relativePath, fullPath)
	}
	// Handle regular file
	return addSftpRegularFileToZip(cli, zipWriter, fullPath, relativePath)
}

// addSftpRegularFileToZip streams one remote file into a deflated ZIP entry.
func addSftpRegularFileToZip(cli *sftp.Client, zipWriter *zip.Writer, fullPath, relativePath string) error {
	// Open remote file
	file, err := cli.Open(fullPath)
	if err != nil {
		return fmt.Errorf("failed to open file %s: %w", fullPath, err)
	}
	defer file.Close()
	// Create ZIP entry
	header := &zip.FileHeader{
		Name:   relativePath,
		Method: zip.Deflate,
	}
	writer, err := zipWriter.CreateHeader(header)
	if err != nil {
		return fmt.Errorf("failed to create ZIP entry: %w", err)
	}
	// Copy file content to ZIP
	_, err = io.Copy(writer, file)
	return err
}

// addSftpDirToZip adds a directory entry (trailing slash) and then recurses
// into its contents via addSftpFileToZip.
// NOTE(review): ZIP entry names must use '/'; filepath.Join below would emit
// '\' on Windows — confirm POSIX-only deployment or use path.Join for entries.
func addSftpDirToZip(cli *sftp.Client, zipWriter *zip.Writer, baseDir, relativePath, fullPath string) error {
	// Read directory contents
	entries, err := cli.ReadDir(fullPath)
	if err != nil {
		return fmt.Errorf("failed to read directory %s: %w", fullPath, err)
	}
	// Add directory entry to ZIP (skipped for the archive root).
	if relativePath != "" && relativePath != "." {
		header := &zip.FileHeader{
			Name: relativePath + "/",
		}
		_, err = zipWriter.CreateHeader(header)
		if err != nil {
			return fmt.Errorf("failed to create directory entry: %w", err)
		}
	}
	// Add directory contents recursively
	for _, entry := range entries {
		entryRelPath := filepath.Join(relativePath, entry.Name())
		entryFullPath := filepath.Join(fullPath, entry.Name())
		if err := addSftpFileToZip(cli, zipWriter, baseDir, entryRelPath, entryFullPath); err != nil {
			return err
		}
	}
	return nil
}
// =============================================================================
// SFTP Progress Writers
// =============================================================================

// SftpProgressWriter wraps an io.Writer and periodically publishes the total
// bytes written to the shared transfer-progress map. Updates are throttled so
// the hot write path avoids calling time.Now() and taking the progress lock
// on every chunk.
type SftpProgressWriter struct {
	writer       io.Writer
	transferId   string
	written      int64     // total bytes written so far
	lastUpdate   time.Time // time of last published progress update
	updateBytes  int64     // Bytes written since last progress update
	updateTicker int64     // Simple counter to reduce time.Now() calls
}

// NewSftpProgressWriter creates a progress writer that reports under transferId.
func NewSftpProgressWriter(writer io.Writer, transferId string) *SftpProgressWriter {
	return &SftpProgressWriter{
		writer:     writer,
		transferId: transferId,
		lastUpdate: time.Now(),
	}
}

// Write forwards to the underlying writer and publishes throttled progress.
// Not safe for concurrent use on the same instance.
func (pw *SftpProgressWriter) Write(p []byte) (int, error) {
	n, err := pw.writer.Write(p)
	if err != nil {
		return n, err
	}
	pw.written += int64(n)
	pw.updateBytes += int64(n)
	pw.updateTicker++
	// Update progress every 64KB bytes OR every 1000 write calls (reduces time.Now() overhead)
	if pw.updateBytes >= 65536 || pw.updateTicker >= 1000 {
		now := time.Now()
		// Only update if enough time has passed (reduce lock contention). The
		// 64KB condition short-circuits the time check so byte-driven updates
		// are never rate-limited.
		if pw.updateBytes >= 65536 || now.Sub(pw.lastUpdate) >= 50*time.Millisecond {
			UpdateTransferProgress(pw.transferId, 0, pw.written, "")
			pw.lastUpdate = now
			pw.updateBytes = 0
			pw.updateTicker = 0
		}
	}
	return n, nil
}

View File

@@ -0,0 +1,722 @@
package file
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"io"
"io/fs"
"path/filepath"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/pkg/sftp"
"go.uber.org/zap"
"golang.org/x/crypto/ssh"
"gorm.io/gorm"
"github.com/veops/oneterm/internal/guacd"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/internal/repository"
gsession "github.com/veops/oneterm/internal/session"
"github.com/veops/oneterm/internal/tunneling"
"github.com/veops/oneterm/pkg/logger"
)
var (
	// Legacy asset-based file manager: caches one SFTP client per
	// "{assetId}-{accountId}" key.
	fm = &FileManager{
		sftps:    map[string]*sftp.Client{},
		lastTime: map[string]time.Time{},
		mtx:      sync.Mutex{},
	}

	// Session-based file manager: caches SFTP (and optionally SSH) clients
	// keyed by interactive session ID.
	sessionFM = &SessionFileManager{
		sessionSFTP: make(map[string]*sftp.Client),
		sessionSSH:  make(map[string]*ssh.Client),
		lastActive:  make(map[string]time.Time),
	}

	// Tracks per-session upload/download transfer records.
	sessionTransferManager = &SessionFileTransferManager{
		transfers: make(map[string]*SessionFileTransfer),
	}

	// Global transfer progress state, guarded by transferMutex.
	fileTransfers = make(map[string]*FileTransferProgress)
	transferMutex sync.RWMutex
)

// Session-based file operation errors.
var (
	ErrSessionNotFound  = errors.New("session not found")
	ErrSessionClosed    = errors.New("session has been closed")
	ErrSessionInactive  = errors.New("session is inactive")
	ErrSFTPNotAvailable = errors.New("SFTP not available for this session")
)

// GetFileManager returns the package-level legacy asset-based file manager.
func GetFileManager() *FileManager {
	return fm
}

// GetSessionFileManager returns the package-level session-based file manager.
func GetSessionFileManager() *SessionFileManager {
	return sessionFM
}

// GetSessionTransferManager returns the package-level transfer record manager.
func GetSessionTransferManager() *SessionFileTransferManager {
	return sessionTransferManager
}

// FileInfo is the JSON-serializable directory-listing entry returned to the API.
type FileInfo struct {
	Name    string `json:"name"`
	IsDir   bool   `json:"is_dir"`
	Size    int64  `json:"size"`
	Mode    string `json:"mode"`
	IsLink  bool   `json:"is_link"`
	Target  string `json:"target"`
	ModTime string `json:"mod_time"`
}

// FileManager manages SFTP connections keyed by asset/account (legacy path).
// All fields are guarded by mtx.
type FileManager struct {
	sftps    map[string]*sftp.Client
	lastTime map[string]time.Time
	mtx      sync.Mutex
}
// GetFileClient returns the cached SFTP client for the asset/account pair,
// dialing a new SSH+SFTP connection on first use. The per-key last-used time
// is refreshed on every call (including error paths, via the deferred write).
//
// NOTE(review): cached clients are never health-checked or evicted here — a
// dropped connection stays cached until something removes it; verify a reaper
// exists elsewhere. The 1-second dial timeout also looks aggressive for
// gateway-tunneled targets — confirm intentional.
func (fm *FileManager) GetFileClient(assetId, accountId int) (cli *sftp.Client, err error) {
	fm.mtx.Lock()
	defer fm.mtx.Unlock()

	key := fmt.Sprintf("%d-%d", assetId, accountId)
	defer func() {
		fm.lastTime[key] = time.Now()
	}()

	// Cache hit: return the existing client.
	cli, ok := fm.sftps[key]
	if ok {
		return
	}

	// Cache miss: resolve asset/account/gateway and dial a fresh connection.
	asset, account, gateway, err := repository.GetAAG(assetId, accountId)
	if err != nil {
		return
	}

	ip, port, err := tunneling.Proxy(false, uuid.New().String(), "sftp,ssh", asset, gateway)
	if err != nil {
		return
	}
	auth, err := repository.GetAuth(account)
	if err != nil {
		return
	}
	sshCli, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", ip, port), &ssh.ClientConfig{
		User:            account.Account,
		Auth:            []ssh.AuthMethod{auth},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         time.Second,
	})
	if err != nil {
		return
	}

	// Create optimized SFTP client
	cli, err = sftp.NewClient(sshCli,
		sftp.MaxPacket(32768),                 // 32KB packets for maximum compatibility
		sftp.MaxConcurrentRequestsPerFile(16), // Increase concurrent requests
		sftp.UseConcurrentWrites(true),        // Enable concurrent writes
		sftp.UseConcurrentReads(true),         // Enable concurrent reads
		sftp.UseFstat(false),                  // Disable fstat for better compatibility
	)
	if err != nil {
		// SFTP negotiation failed: don't leak the underlying SSH connection.
		sshCli.Close()
		return
	}

	fm.sftps[key] = cli
	return
}
// SessionFileManager manages SFTP connections per interactive session.
// sessionSSH holds only SSH clients that THIS manager dialed itself; SSH
// connections borrowed from the terminal session are never stored (and so
// never closed) here. All maps are guarded by mutex.
type SessionFileManager struct {
	sessionSFTP map[string]*sftp.Client // sessionId -> SFTP client
	sessionSSH  map[string]*ssh.Client  // sessionId -> SSH client (owned connections only)
	lastActive  map[string]time.Time    // sessionId -> last active time
	mutex       sync.RWMutex
}

// InitSessionSFTP ensures an SFTP client exists for sessionId. It prefers to
// reuse the SSH connection of the live terminal session (avoiding a second
// handshake); otherwise it dials a new SSH connection for assetId/accountId.
// Idempotent: an existing client just has its activity timestamp refreshed.
func (sfm *SessionFileManager) InitSessionSFTP(sessionId string, assetId, accountId int) error {
	sfm.mutex.Lock()
	defer sfm.mutex.Unlock()

	// Check if already exists
	if _, exists := sfm.sessionSFTP[sessionId]; exists {
		sfm.lastActive[sessionId] = time.Now()
		return nil
	}

	// CRITICAL OPTIMIZATION: Try to reuse existing SSH connection from terminal session
	onlineSession := gsession.GetOnlineSessionById(sessionId)
	var sshClient *ssh.Client
	// shouldCloseClient marks connections we dialed (and therefore own).
	var shouldCloseClient = false

	if onlineSession != nil && onlineSession.HasSSHClient() {
		sshClient = onlineSession.GetSSHClient()
		logger.L().Info("REUSING existing SSH connection from terminal session",
			zap.String("sessionId", sessionId))
	} else {
		// Fallback: Create new SSH connection if no existing connection found
		asset, account, gateway, err := repository.GetAAG(assetId, accountId)
		if err != nil {
			return err
		}

		// Use sessionId as proxy identifier for connection reuse
		ip, port, err := tunneling.Proxy(false, sessionId, "sftp,ssh", asset, gateway)
		if err != nil {
			return err
		}
		auth, err := repository.GetAuth(account)
		if err != nil {
			return err
		}

		sshClient, err = ssh.Dial("tcp", fmt.Sprintf("%s:%d", ip, port), &ssh.ClientConfig{
			User:            account.Account,
			Auth:            []ssh.AuthMethod{auth},
			HostKeyCallback: ssh.InsecureIgnoreHostKey(),
			Timeout:         10 * time.Second,
		})
		if err != nil {
			return fmt.Errorf("failed to connect SSH for session %s: %w", sessionId, err)
		}
		shouldCloseClient = true // We created it, so we manage its lifecycle
		logger.L().Info("Created new SSH connection for file transfer",
			zap.String("sessionId", sessionId))
	}

	// Create SFTP client with optimized settings for better performance
	sftpClient, err := sftp.NewClient(sshClient,
		sftp.MaxPacket(32768),                 // 32KB packets for maximum compatibility
		sftp.MaxConcurrentRequestsPerFile(16), // Increase concurrent requests per file (default is 3)
		sftp.UseConcurrentWrites(true),        // Enable concurrent writes
		sftp.UseConcurrentReads(true),         // Enable concurrent reads
		sftp.UseFstat(false),                  // Disable fstat for better compatibility
	)
	if err != nil {
		// Only tear down the SSH connection if we dialed it ourselves; a
		// borrowed terminal connection must stay alive.
		if shouldCloseClient {
			sshClient.Close()
		}
		return fmt.Errorf("failed to create SFTP client for session %s: %w", sessionId, err)
	}

	// Store clients (only store SSH client if we created it)
	sfm.sessionSFTP[sessionId] = sftpClient
	if shouldCloseClient {
		sfm.sessionSSH[sessionId] = sshClient
	}
	sfm.lastActive[sessionId] = time.Now()

	logger.L().Info("SFTP connection initialized for session",
		zap.String("sessionId", sessionId),
		zap.Bool("reusedConnection", !shouldCloseClient))

	return nil
}
func (sfm *SessionFileManager) GetSessionSFTP(sessionId string) (*sftp.Client, error) {
sfm.mutex.RLock()
defer sfm.mutex.RUnlock()
client, exists := sfm.sessionSFTP[sessionId]
if !exists {
return nil, ErrSessionNotFound
}
// Update last active time
sfm.lastActive[sessionId] = time.Now()
return client, nil
}
// CloseSessionSFTP tears down the SFTP client for sessionId. The underlying
// SSH connection is closed only if this manager dialed it itself (present in
// sessionSSH); SSH connections borrowed from the terminal session are left
// untouched.
func (sfm *SessionFileManager) CloseSessionSFTP(sessionId string) {
	sfm.mutex.Lock()
	defer sfm.mutex.Unlock()

	if sftpClient, exists := sfm.sessionSFTP[sessionId]; exists {
		sftpClient.Close()
		delete(sfm.sessionSFTP, sessionId)
	}

	// Only close SSH client if we created it (not reused from terminal session)
	if sshClient, exists := sfm.sessionSSH[sessionId]; exists {
		sshClient.Close()
		delete(sfm.sessionSSH, sessionId)
		logger.L().Info("SFTP SSH connection closed for session", zap.String("sessionId", sessionId))
	} else {
		logger.L().Info("SFTP connection closed for session (SSH connection reused)", zap.String("sessionId", sessionId))
	}

	delete(sfm.lastActive, sessionId)
}

// IsSessionActive reports whether an SFTP client is cached for sessionId.
func (sfm *SessionFileManager) IsSessionActive(sessionId string) bool {
	sfm.mutex.RLock()
	defer sfm.mutex.RUnlock()

	_, exists := sfm.sessionSFTP[sessionId]
	return exists
}

// CleanupInactiveSessions closes and removes every session whose last
// activity is older than timeout. Intended to be called periodically.
func (sfm *SessionFileManager) CleanupInactiveSessions(timeout time.Duration) {
	sfm.mutex.Lock()
	defer sfm.mutex.Unlock()

	cutoff := time.Now().Add(-timeout)
	for sessionId, lastActive := range sfm.lastActive {
		if lastActive.Before(cutoff) {
			// Close and remove inactive session
			if sftpClient, exists := sfm.sessionSFTP[sessionId]; exists {
				sftpClient.Close()
				delete(sfm.sessionSFTP, sessionId)
			}
			if sshClient, exists := sfm.sessionSSH[sessionId]; exists {
				sshClient.Close()
				delete(sfm.sessionSSH, sessionId)
			}
			delete(sfm.lastActive, sessionId)

			logger.L().Info("Cleaned up inactive SFTP session",
				zap.String("sessionId", sessionId),
				zap.Duration("inactiveFor", time.Since(lastActive)))
		}
	}
}

// GetActiveSessionCount returns how many sessions currently hold SFTP clients.
func (sfm *SessionFileManager) GetActiveSessionCount() int {
	sfm.mutex.RLock()
	defer sfm.mutex.RUnlock()
	return len(sfm.sessionSFTP)
}
// IFileService defines the file service interface. It spans three concerns:
// legacy asset/account-addressed SFTP operations, session-addressed
// operations that reuse live terminal connections, and file-history/progress
// bookkeeping.
type IFileService interface {
	// Legacy asset-based operations (for backward compatibility)
	ReadDir(ctx context.Context, assetId, accountId int, dir string) ([]fs.FileInfo, error)
	MkdirAll(ctx context.Context, assetId, accountId int, dir string) error
	Create(ctx context.Context, assetId, accountId int, path string) (io.WriteCloser, error)
	Open(ctx context.Context, assetId, accountId int, path string) (io.ReadCloser, error)
	Stat(ctx context.Context, assetId, accountId int, path string) (fs.FileInfo, error)
	DownloadMultiple(ctx context.Context, assetId, accountId int, dir string, filenames []string) (io.ReadCloser, string, int64, error)

	// Session-based operations (NEW - high performance)
	SessionLS(ctx context.Context, sessionId, dir string) ([]FileInfo, error)
	SessionMkdir(ctx context.Context, sessionId, dir string) error
	SessionUpload(ctx context.Context, sessionId, targetPath string, file io.Reader, filename string, size int64) error
	SessionUploadWithID(ctx context.Context, transferID, sessionId, targetPath string, file io.Reader, filename string, size int64) error
	SessionDownload(ctx context.Context, sessionId, filePath string) (io.ReadCloser, int64, error)
	SessionDownloadMultiple(ctx context.Context, sessionId, dir string, filenames []string) (io.ReadCloser, string, int64, error)

	// Session lifecycle management
	InitSessionFileClient(sessionId string, assetId, accountId int) error
	CloseSessionFileClient(sessionId string)
	IsSessionActive(sessionId string) bool

	// Progress tracking for session transfers
	GetSessionTransferProgress(ctx context.Context, sessionId string) ([]*SessionFileTransferProgress, error)
	GetTransferProgress(ctx context.Context, transferId string) (*SessionFileTransferProgress, error)

	// File history and other operations
	AddFileHistory(ctx context.Context, history *model.FileHistory) error
	BuildFileHistoryQuery(ctx *gin.Context) *gorm.DB
	RecordFileHistory(ctx context.Context, operation, dir, filename string, assetId, accountId int, sessionId ...string) error
	RecordFileHistoryBySession(ctx context.Context, sessionId, operation, path string) error

	// Utility methods
	GetActionCode(operation string) int
	ValidateAndNormalizePath(basePath, userPath string) (string, error)
	GetRDPDrivePath(assetId int) (string, error)
}

// FileService implements IFileService on top of an IFileRepository.
type FileService struct {
	repo IFileRepository
}

// RDPFileInfo is the listing entry for files on an RDP drive redirect.
type RDPFileInfo struct {
	Name    string `json:"name"`
	Size    int64  `json:"size"`
	IsDir   bool   `json:"is_dir"`
	ModTime string `json:"mod_time"`
}

// RDPMkdirRequest is the request body for creating a directory over RDP.
type RDPMkdirRequest struct {
	Path string `json:"path" binding:"required"`
}

// RDPProgressWriter mirrors FileProgressWriter for guacd-backed RDP transfers.
type RDPProgressWriter struct {
	writer     io.Writer
	transfer   *guacd.FileTransfer
	transferId string
	written    int64
}

// SessionFileTransfer is one upload/download record tracked per session.
// Offset is the bytes completed so far; mutex guards all mutable fields.
type SessionFileTransfer struct {
	ID        string    `json:"id"`
	SessionID string    `json:"session_id"`
	Filename  string    `json:"filename"`
	Size      int64     `json:"size"`
	Offset    int64     `json:"offset"`
	Status    string    `json:"status"` // "pending", "uploading", "completed", "failed"
	IsUpload  bool      `json:"is_upload"`
	Created   time.Time `json:"created"`
	Updated   time.Time `json:"updated"`
	Error     string    `json:"error,omitempty"`
	mutex     sync.Mutex
}

// SessionFileTransferProgress is the read-only snapshot of a transfer exposed
// to API callers, with derived percentage/speed/ETA.
type SessionFileTransferProgress struct {
	ID         string    `json:"id"`
	SessionID  string    `json:"session_id"`
	Filename   string    `json:"filename"`
	Size       int64     `json:"size"`
	Offset     int64     `json:"offset"`
	Percentage float64   `json:"percentage"`
	Status     string    `json:"status"`
	IsUpload   bool      `json:"is_upload"`
	Created    time.Time `json:"created"`
	Updated    time.Time `json:"updated"`
	Error      string    `json:"error,omitempty"`
	Speed      int64     `json:"speed"` // bytes per second
	ETA        int64     `json:"eta"`   // estimated time to completion in seconds
}

// SessionFileTransferManager is a mutex-guarded registry of transfer records
// keyed by transfer ID.
type SessionFileTransferManager struct {
	transfers map[string]*SessionFileTransfer
	mutex     sync.RWMutex
}
// CreateTransfer registers a new transfer record with a random generated ID.
func (m *SessionFileTransferManager) CreateTransfer(sessionID, filename string, size int64, isUpload bool) *SessionFileTransfer {
	return m.CreateTransferWithID(generateTransferID(), sessionID, filename, size, isUpload)
}

// CreateTransferWithID registers a new "pending" transfer under a
// caller-chosen ID. An existing record with the same ID is overwritten.
func (m *SessionFileTransferManager) CreateTransferWithID(transferID, sessionID, filename string, size int64, isUpload bool) *SessionFileTransfer {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	transfer := &SessionFileTransfer{
		ID:        transferID,
		SessionID: sessionID,
		Filename:  filename,
		Size:      size,
		Offset:    0,
		Status:    "pending",
		IsUpload:  isUpload,
		Created:   time.Now(),
		Updated:   time.Now(),
	}

	m.transfers[transferID] = transfer
	return transfer
}

// GetTransfer returns the transfer record for id, or nil if unknown.
func (m *SessionFileTransferManager) GetTransfer(id string) *SessionFileTransfer {
	m.mutex.RLock()
	defer m.mutex.RUnlock()
	return m.transfers[id]
}

// GetTransfersBySession returns all transfer records for sessionID
// (linear scan; nil slice when there are none).
func (m *SessionFileTransferManager) GetTransfersBySession(sessionID string) []*SessionFileTransfer {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	var transfers []*SessionFileTransfer
	for _, transfer := range m.transfers {
		if transfer.SessionID == sessionID {
			transfers = append(transfers, transfer)
		}
	}
	return transfers
}

// GetSessionProgress returns progress snapshots for every transfer of sessionID.
func (m *SessionFileTransferManager) GetSessionProgress(sessionID string) []*SessionFileTransferProgress {
	transfers := m.GetTransfersBySession(sessionID)
	var progressList []*SessionFileTransferProgress
	for _, transfer := range transfers {
		progressList = append(progressList, transfer.GetProgress())
	}
	return progressList
}

// RemoveTransfer deletes the record for id (no-op if unknown).
func (m *SessionFileTransferManager) RemoveTransfer(id string) {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	delete(m.transfers, id)
}

// CleanupCompletedTransfers drops finished ("completed"/"failed") records not
// updated within maxAge; in-flight transfers are never removed.
func (m *SessionFileTransferManager) CleanupCompletedTransfers(maxAge time.Duration) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	cutoff := time.Now().Add(-maxAge)
	for id, transfer := range m.transfers {
		if (transfer.Status == "completed" || transfer.Status == "failed") && transfer.Updated.Before(cutoff) {
			delete(m.transfers, id)
		}
	}
}
// FileTransferProgress is the entry stored in the global fileTransfers map,
// keyed by transfer ID and guarded by transferMutex.
type FileTransferProgress struct {
	TotalSize       int64
	TransferredSize int64
	Status          string // "transferring", "completed", "failed"
	Type            string // "sftp", "rdp"
}

// ProgressWriter wraps an io.Writer and mirrors every write into a
// SessionFileTransfer record (offset on success, error state on failure).
type ProgressWriter struct {
	writer   io.Writer
	transfer *SessionFileTransfer
	written  int64
}

// NewProgressWriter creates a writer that reports into transfer.
func NewProgressWriter(writer io.Writer, transfer *SessionFileTransfer) *ProgressWriter {
	return &ProgressWriter{
		writer:   writer,
		transfer: transfer,
		written:  0,
	}
}

// Write forwards to the underlying writer, updating the transfer's offset for
// any bytes written and marking the transfer failed on error.
func (pw *ProgressWriter) Write(p []byte) (int, error) {
	n, err := pw.writer.Write(p)
	if n > 0 {
		pw.written += int64(n)
		pw.transfer.UpdateProgress(pw.written)
	}
	if err != nil {
		pw.transfer.SetError(err)
	}
	return n, err
}

// FileProgressWriter wraps an io.Writer and publishes throttled byte counts
// into the global transfer-progress map under transferId.
type FileProgressWriter struct {
	writer       io.Writer
	transferId   string
	written      int64
	lastUpdate   time.Time // written on each update but currently never read
	updateBytes  int64     // Bytes written since last progress update
	updateTicker int64     // Simple counter to reduce time.Now() calls
}

// NewFileProgressWriter creates a progress writer that reports under transferId.
func NewFileProgressWriter(writer io.Writer, transferId string) *FileProgressWriter {
	return &FileProgressWriter{
		writer:     writer,
		transferId: transferId,
		written:    0,
		lastUpdate: time.Now(),
	}
}

// Write forwards to the underlying writer and publishes progress every 64KB
// or every 100th write call, whichever comes first. On error the transfer is
// marked "failed" (the -1 sentinel leaves TransferredSize untouched).
// Not safe for concurrent use on the same instance.
func (pw *FileProgressWriter) Write(p []byte) (int, error) {
	n, err := pw.writer.Write(p)
	if n > 0 {
		pw.written += int64(n)
		pw.updateBytes += int64(n)
		pw.updateTicker++

		// Update progress every 64KB or every 100 writes to reduce overhead.
		// NOTE(review): updateTicker is never reset, so the %100 branch keeps
		// firing on a fixed cadence regardless of the byte-based resets.
		if pw.updateBytes >= 65536 || pw.updateTicker%100 == 0 {
			UpdateTransferProgress(pw.transferId, 0, pw.written, "transferring")
			pw.updateBytes = 0
			pw.lastUpdate = time.Now()
		}
	}

	if err != nil {
		UpdateTransferProgress(pw.transferId, 0, -1, "failed")
	}

	return n, err
}
// IFileRepository abstracts the persistence layer for file-operation history.
type IFileRepository interface {
	AddFileHistory(ctx context.Context, history *model.FileHistory) error
	BuildFileHistoryQuery(ctx *gin.Context) *gorm.DB
}

// UpdateProgress records that `offset` bytes have completed, promoting the
// status from "pending" to "uploading" on first progress and to "completed"
// once the full size is reached. (The "uploading" label is used for downloads
// too.) Safe for concurrent use.
func (t *SessionFileTransfer) UpdateProgress(offset int64) {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	t.Offset = offset
	t.Updated = time.Now()

	if t.Status == "pending" {
		t.Status = "uploading"
	}

	if t.Size > 0 && t.Offset >= t.Size {
		t.Status = "completed"
	}
}

// SetError marks the transfer failed and stores the error text.
func (t *SessionFileTransfer) SetError(err error) {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	t.Error = err.Error()
	t.Status = "failed"
	t.Updated = time.Now()
}

// GetProgress returns a snapshot of the transfer with derived fields:
// percentage of Size completed, average speed since Created (bytes/sec),
// and an ETA extrapolated from that average.
func (t *SessionFileTransfer) GetProgress() *SessionFileTransferProgress {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	var percentage float64
	if t.Size > 0 {
		percentage = float64(t.Offset) / float64(t.Size) * 100
	}

	// Speed is a whole-transfer average, not an instantaneous rate; it needs
	// at least one progress update after creation to be defined.
	var speed int64
	var eta int64
	if !t.Created.Equal(t.Updated) && t.Offset > 0 {
		duration := t.Updated.Sub(t.Created).Seconds()
		speed = int64(float64(t.Offset) / duration)
		if speed > 0 && t.Size > t.Offset {
			eta = (t.Size - t.Offset) / speed
		}
	}

	return &SessionFileTransferProgress{
		ID:         t.ID,
		SessionID:  t.SessionID,
		Filename:   t.Filename,
		Size:       t.Size,
		Offset:     t.Offset,
		Percentage: percentage,
		Status:     t.Status,
		IsUpload:   t.IsUpload,
		Created:    t.Created,
		Updated:    t.Updated,
		Error:      t.Error,
		Speed:      speed,
		ETA:        eta,
	}
}
// CreateTransferProgress registers a fresh "transferring" entry for transferId
// in the global progress map. transferType is "sftp" or "rdp".
func CreateTransferProgress(transferId, transferType string) {
	transferMutex.Lock()
	fileTransfers[transferId] = &FileTransferProgress{
		TotalSize:       0,
		TransferredSize: 0,
		Status:          "transferring",
		Type:            transferType,
	}
	transferMutex.Unlock()
}

// UpdateTransferProgress patches the entry for transferId. Sentinel arguments
// leave fields unchanged: totalSize <= 0, transferredSize < 0, status == "".
// Unknown transfer IDs are silently ignored.
func UpdateTransferProgress(transferId string, totalSize, transferredSize int64, status string) {
	transferMutex.Lock()
	if progress, exists := fileTransfers[transferId]; exists {
		if totalSize > 0 {
			progress.TotalSize = totalSize
		}
		if transferredSize >= 0 {
			progress.TransferredSize = transferredSize
		}
		if status != "" {
			progress.Status = status
		}
	}
	transferMutex.Unlock()
}

// CleanupTransferProgress removes the entry for transferId after `delay`,
// giving clients a grace period to read the final state. The deletion runs in
// a fire-and-forget goroutine.
func CleanupTransferProgress(transferId string, delay time.Duration) {
	go func() {
		time.Sleep(delay)
		transferMutex.Lock()
		delete(fileTransfers, transferId)
		transferMutex.Unlock()
	}()
}

// GetTransferProgressById returns the live entry for transferId and whether it
// exists. Note: the returned pointer is shared with concurrent updaters.
func GetTransferProgressById(transferId string) (*FileTransferProgress, bool) {
	transferMutex.RLock()
	progress, exists := fileTransfers[transferId]
	transferMutex.RUnlock()
	return progress, exists
}
// Utility functions

// generateTransferID returns a 32-character hex transfer identifier built
// from 16 cryptographically random bytes.
//
// FIX: the previous version ignored the error from rand.Read; on platforms
// where the system randomness source fails, every "random" ID would silently
// be the all-zero string, causing transfer-record collisions. Fall back to a
// time-derived identifier instead of returning a degenerate constant.
func generateTransferID() string {
	bytes := make([]byte, 16)
	if _, err := rand.Read(bytes); err != nil {
		// Extremely unlikely; keep the 32-hex-char shape for consistency.
		return fmt.Sprintf("%032x", time.Now().UnixNano())
	}
	return hex.EncodeToString(bytes)
}
// sanitizeFilename strips any directory components from filename and rejects
// names usable for path traversal. It returns the bare base name, or an error
// for empty names, ".", or any name containing "..". Hidden files (leading
// dot) are allowed.
func sanitizeFilename(filename string) (string, error) {
	base := filepath.Base(filename)
	switch {
	case base == "", base == ".", strings.Contains(base, ".."):
		return "", fmt.Errorf("invalid filename")
	default:
		return base, nil
	}
}
// IsPermissionError reports whether err looks like a permission/authorization
// failure. Matching is a case-insensitive substring search against a fixed
// list of phrases produced by SFTP/SSH servers and common OS errors; a nil
// error never matches.
func IsPermissionError(err error) bool {
	if err == nil {
		return false
	}
	msg := strings.ToLower(err.Error())
	for _, phrase := range []string{
		"permission denied",
		"access denied",
		"unauthorized",
		"forbidden",
		"not authorized",
		"insufficient privileges",
		"operation not permitted",
		"sftp: permission denied",
		"ssh: permission denied",
	} {
		if strings.Contains(msg, phrase) {
			return true
		}
	}
	return false
}

View File

@@ -1,57 +0,0 @@
package service
import (
"fmt"
"golang.org/x/crypto/ssh"
"github.com/veops/oneterm/internal/model"
dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/pkg/utils"
)
// GetAAG loads the asset, account, and (optional) gateway rows for the given
// IDs and decrypts the stored credentials (password, private key, passphrase)
// with AES. When the asset has no gateway (GatewayId == 0), the returned
// gateway is an untouched zero-value struct. On error the partially populated
// results should not be used.
func GetAAG(assetId int, accountId int) (asset *model.Asset, account *model.Account, gateway *model.Gateway, err error) {
	asset, account, gateway = &model.Asset{}, &model.Account{}, &model.Gateway{}
	if err = dbpkg.DB.Model(asset).Where("id = ?", assetId).First(asset).Error; err != nil {
		return
	}
	if err = dbpkg.DB.Model(account).Where("id = ?", accountId).First(account).Error; err != nil {
		return
	}
	// Credentials are stored encrypted at rest; decrypt for in-memory use only.
	account.Password = utils.DecryptAES(account.Password)
	account.Pk = utils.DecryptAES(account.Pk)
	account.Phrase = utils.DecryptAES(account.Phrase)
	if asset.GatewayId != 0 {
		if err = dbpkg.DB.Model(gateway).Where("id = ?", asset.GatewayId).First(gateway).Error; err != nil {
			return
		}
		gateway.Password = utils.DecryptAES(gateway.Password)
		gateway.Pk = utils.DecryptAES(gateway.Pk)
		gateway.Phrase = utils.DecryptAES(gateway.Phrase)
	}
	return
}
// GetAuth builds an ssh.AuthMethod from the account's (already decrypted)
// credentials: a password for AUTHMETHOD_PASSWORD, or a parsed private key
// (with passphrase when Phrase is non-empty) for AUTHMETHOD_PUBLICKEY.
// Any other AccountType is rejected.
func GetAuth(account *model.Account) (ssh.AuthMethod, error) {
	switch account.AccountType {
	case model.AUTHMETHOD_PASSWORD:
		return ssh.Password(account.Password), nil
	case model.AUTHMETHOD_PUBLICKEY:
		if account.Phrase == "" {
			pk, err := ssh.ParsePrivateKey([]byte(account.Pk))
			if err != nil {
				return nil, err
			}
			return ssh.PublicKeys(pk), nil
		} else {
			// Key is protected by a passphrase.
			pk, err := ssh.ParsePrivateKeyWithPassphrase([]byte(account.Pk), []byte(account.Phrase))
			if err != nil {
				return nil, err
			}
			return ssh.PublicKeys(pk), nil
		}
	default:
		return nil, fmt.Errorf("invalid authmethod %d", account.AccountType)
	}
}