mirror of
https://github.com/veops/oneterm.git
synced 2025-10-31 02:46:29 +08:00
refactor(backend): file service
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/veops/oneterm/internal/api/router"
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
"github.com/veops/oneterm/internal/service"
|
||||
fileservice "github.com/veops/oneterm/internal/service/file"
|
||||
"github.com/veops/oneterm/pkg/config"
|
||||
"github.com/veops/oneterm/pkg/db"
|
||||
"github.com/veops/oneterm/pkg/logger"
|
||||
@@ -45,7 +46,7 @@ func initDB() {
|
||||
func initServices() {
|
||||
service.InitAuthorizationService()
|
||||
|
||||
service.InitFileService()
|
||||
fileservice.InitFileService()
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
"github.com/veops/oneterm/internal/acl"
|
||||
"github.com/veops/oneterm/internal/guacd"
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
"github.com/veops/oneterm/internal/service"
|
||||
fileservice "github.com/veops/oneterm/internal/service/file"
|
||||
gsession "github.com/veops/oneterm/internal/session"
|
||||
myErrors "github.com/veops/oneterm/pkg/errors"
|
||||
"github.com/veops/oneterm/pkg/logger"
|
||||
@@ -51,7 +51,7 @@ const (
|
||||
func (c *Controller) GetFileHistory(ctx *gin.Context) {
|
||||
currentUser, _ := acl.GetSessionFromCtx(ctx)
|
||||
|
||||
db := service.DefaultFileService.BuildFileHistoryQuery(ctx)
|
||||
db := fileservice.DefaultFileService.BuildFileHistoryQuery(ctx)
|
||||
|
||||
// Apply user permissions - non-admin users can only see their own history
|
||||
if !acl.IsAdmin(currentUser) {
|
||||
@@ -87,9 +87,9 @@ func (c *Controller) FileLS(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Use global file service
|
||||
info, err := service.DefaultFileService.ReadDir(ctx, sess.Session.AssetId, sess.Session.AccountId, ctx.Query("dir"))
|
||||
info, err := fileservice.DefaultFileService.ReadDir(ctx, sess.Session.AssetId, sess.Session.AccountId, ctx.Query("dir"))
|
||||
if err != nil {
|
||||
if service.IsPermissionError(err) {
|
||||
if fileservice.IsPermissionError(err) {
|
||||
ctx.AbortWithError(http.StatusForbidden, fmt.Errorf("permission denied"))
|
||||
} else {
|
||||
ctx.AbortWithError(http.StatusBadRequest, &myErrors.ApiError{Code: myErrors.ErrInvalidArgument, Data: map[string]any{"err": err}})
|
||||
@@ -110,7 +110,7 @@ func (c *Controller) FileLS(ctx *gin.Context) {
|
||||
List: lo.Map(info, func(f fs.FileInfo, _ int) any {
|
||||
var target string
|
||||
if f.Mode()&os.ModeSymlink != 0 {
|
||||
cli, err := service.GetFileManager().GetFileClient(sess.Session.AssetId, sess.Session.AccountId)
|
||||
cli, err := fileservice.GetFileManager().GetFileClient(sess.Session.AssetId, sess.Session.AccountId)
|
||||
if err == nil {
|
||||
linkPath := filepath.Join(ctx.Query("dir"), f.Name())
|
||||
if linkTarget, err := cli.ReadLink(linkPath); err == nil {
|
||||
@@ -118,7 +118,7 @@ func (c *Controller) FileLS(ctx *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
return &service.FileInfo{
|
||||
return &fileservice.FileInfo{
|
||||
Name: f.Name(),
|
||||
IsDir: f.IsDir(),
|
||||
Size: f.Size(),
|
||||
@@ -157,13 +157,13 @@ func (c *Controller) FileMkdir(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Use global file service
|
||||
if err := service.DefaultFileService.MkdirAll(ctx, sess.Session.AssetId, sess.Session.AccountId, ctx.Query("dir")); err != nil {
|
||||
if err := fileservice.DefaultFileService.MkdirAll(ctx, sess.Session.AssetId, sess.Session.AccountId, ctx.Query("dir")); err != nil {
|
||||
ctx.AbortWithError(http.StatusInternalServerError, &myErrors.ApiError{Code: myErrors.ErrInvalidArgument, Data: map[string]any{"err": err}})
|
||||
return
|
||||
}
|
||||
|
||||
// Record file history using unified method
|
||||
if err := service.DefaultFileService.RecordFileHistory(ctx, "mkdir", ctx.Query("dir"), "", sess.Session.AssetId, sess.Session.AccountId); err != nil {
|
||||
if err := fileservice.DefaultFileService.RecordFileHistory(ctx, "mkdir", ctx.Query("dir"), "", sess.Session.AssetId, sess.Session.AccountId); err != nil {
|
||||
logger.L().Error("Failed to record file history", zap.Error(err))
|
||||
}
|
||||
|
||||
@@ -209,7 +209,7 @@ func (c *Controller) FileUpload(ctx *gin.Context) {
|
||||
transferId = fmt.Sprintf("%d-%d-%d", sess.Session.AssetId, sess.Session.AccountId, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
service.CreateTransferProgress(transferId, "sftp")
|
||||
fileservice.CreateTransferProgress(transferId, "sftp")
|
||||
|
||||
// Parse multipart form
|
||||
if err := ctx.Request.ParseMultipartForm(MaxMemoryForParsing); err != nil {
|
||||
@@ -243,7 +243,7 @@ func (c *Controller) FileUpload(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Update transfer progress with file size
|
||||
service.UpdateTransferProgress(transferId, fileSize, 0, "")
|
||||
fileservice.UpdateTransferProgress(transferId, fileSize, 0, "")
|
||||
|
||||
targetPath := filepath.Join(targetDir, filename)
|
||||
|
||||
@@ -290,10 +290,10 @@ func (c *Controller) FileUpload(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Phase 2: Transfer to target machine using SFTP (synchronous)
|
||||
service.UpdateTransferProgress(transferId, fileSize, 0, "transferring")
|
||||
fileservice.UpdateTransferProgress(transferId, fileSize, 0, "transferring")
|
||||
|
||||
if err := service.TransferToTarget(transferId, "", tempFilePath, targetPath, sess.Session.AssetId, sess.Session.AccountId); err != nil {
|
||||
service.UpdateTransferProgress(transferId, 0, -1, "failed")
|
||||
if err := fileservice.TransferToTarget(transferId, "", tempFilePath, targetPath, sess.Session.AssetId, sess.Session.AccountId); err != nil {
|
||||
fileservice.UpdateTransferProgress(transferId, 0, -1, "failed")
|
||||
os.Remove(tempFilePath)
|
||||
ctx.JSON(http.StatusInternalServerError, HttpResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
@@ -303,13 +303,13 @@ func (c *Controller) FileUpload(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Mark transfer as completed
|
||||
service.UpdateTransferProgress(transferId, 0, -1, "completed")
|
||||
fileservice.UpdateTransferProgress(transferId, 0, -1, "completed")
|
||||
|
||||
// Clean up temp file after successful transfer
|
||||
os.Remove(tempFilePath)
|
||||
|
||||
// Record file history using unified method
|
||||
if err := service.DefaultFileService.RecordFileHistory(ctx, "upload", targetDir, filename, sess.Session.AssetId, sess.Session.AccountId); err != nil {
|
||||
if err := fileservice.DefaultFileService.RecordFileHistory(ctx, "upload", targetDir, filename, sess.Session.AssetId, sess.Session.AccountId); err != nil {
|
||||
logger.L().Error("Failed to record file history", zap.Error(err))
|
||||
}
|
||||
|
||||
@@ -329,7 +329,7 @@ func (c *Controller) FileUpload(ctx *gin.Context) {
|
||||
// Clean up progress record after a short delay
|
||||
go func() {
|
||||
time.Sleep(30 * time.Second) // Keep for 30 seconds for any delayed queries
|
||||
service.CleanupTransferProgress(transferId, 0)
|
||||
fileservice.CleanupTransferProgress(transferId, 0)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -378,9 +378,9 @@ func (c *Controller) FileDownload(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
reader, downloadFilename, fileSize, err := service.DefaultFileService.DownloadMultiple(ctx, sess.Session.AssetId, sess.Session.AccountId, ctx.Query("dir"), filenames)
|
||||
reader, downloadFilename, fileSize, err := fileservice.DefaultFileService.DownloadMultiple(ctx, sess.Session.AssetId, sess.Session.AccountId, ctx.Query("dir"), filenames)
|
||||
if err != nil {
|
||||
if service.IsPermissionError(err) {
|
||||
if fileservice.IsPermissionError(err) {
|
||||
ctx.AbortWithError(http.StatusForbidden, &myErrors.ApiError{Code: myErrors.ErrNoPerm, Data: map[string]any{"err": err}})
|
||||
} else {
|
||||
ctx.AbortWithError(http.StatusInternalServerError, &myErrors.ApiError{Code: myErrors.ErrInvalidArgument, Data: map[string]any{"err": err}})
|
||||
@@ -390,7 +390,7 @@ func (c *Controller) FileDownload(ctx *gin.Context) {
|
||||
defer reader.Close()
|
||||
|
||||
// Record file operation history using unified method
|
||||
if err := service.DefaultFileService.RecordFileHistory(ctx, "download", ctx.Query("dir"), strings.Join(filenames, ","), sess.Session.AssetId, sess.Session.AccountId); err != nil {
|
||||
if err := fileservice.DefaultFileService.RecordFileHistory(ctx, "download", ctx.Query("dir"), strings.Join(filenames, ","), sess.Session.AssetId, sess.Session.AccountId); err != nil {
|
||||
logger.L().Error("Failed to record file history", zap.Error(err))
|
||||
}
|
||||
|
||||
@@ -448,7 +448,7 @@ func (c *Controller) RDPFileList(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Check if RDP drive is enabled
|
||||
if !service.IsRDPDriveEnabled(tunnel) {
|
||||
if !fileservice.IsRDPDriveEnabled(tunnel) {
|
||||
ctx.JSON(http.StatusBadRequest, HttpResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
Message: "RDP drive is not enabled for this session",
|
||||
@@ -457,7 +457,7 @@ func (c *Controller) RDPFileList(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Send file list request through Guacamole protocol
|
||||
files, err := service.RequestRDPFileList(tunnel, path)
|
||||
files, err := fileservice.RequestRDPFileList(tunnel, path)
|
||||
if err != nil {
|
||||
logger.L().Error("Failed to get RDP file list", zap.Error(err))
|
||||
ctx.JSON(http.StatusInternalServerError, HttpResponse{
|
||||
@@ -501,7 +501,7 @@ func (c *Controller) RDPFileUpload(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Create progress record IMMEDIATELY when request starts
|
||||
service.CreateTransferProgress(transferId, "rdp")
|
||||
fileservice.CreateTransferProgress(transferId, "rdp")
|
||||
|
||||
tunnel, err := c.validateRDPAccess(ctx, sessionId)
|
||||
if err != nil {
|
||||
@@ -519,7 +519,7 @@ func (c *Controller) RDPFileUpload(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !service.IsRDPDriveEnabled(tunnel) {
|
||||
if !fileservice.IsRDPDriveEnabled(tunnel) {
|
||||
logger.L().Error("RDP drive is not enabled for session", zap.String("sessionId", sessionId))
|
||||
ctx.JSON(http.StatusBadRequest, HttpResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
@@ -528,7 +528,7 @@ func (c *Controller) RDPFileUpload(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !service.IsRDPUploadAllowed(tunnel) {
|
||||
if !fileservice.IsRDPUploadAllowed(tunnel) {
|
||||
logger.L().Error("RDP upload is disabled for session", zap.String("sessionId", sessionId))
|
||||
ctx.JSON(http.StatusForbidden, HttpResponse{
|
||||
Code: http.StatusForbidden,
|
||||
@@ -653,7 +653,7 @@ func (c *Controller) RDPFileUpload(ctx *gin.Context) {
|
||||
fullPath := filepath.Join(targetPath, filename)
|
||||
|
||||
// Update progress record with file size
|
||||
service.UpdateTransferProgress(transferId, fileSize, 0, "transferring")
|
||||
fileservice.UpdateTransferProgress(transferId, fileSize, 0, "transferring")
|
||||
|
||||
// Open temp file for reading and upload synchronously
|
||||
tempFile, err := os.Open(tempFilePath)
|
||||
@@ -668,9 +668,9 @@ func (c *Controller) RDPFileUpload(ctx *gin.Context) {
|
||||
defer tempFile.Close()
|
||||
|
||||
// Perform RDP upload synchronously
|
||||
err = service.UploadRDPFileStreamWithID(tunnel, transferId, sessionId, fullPath, tempFile, fileSize)
|
||||
err = fileservice.UploadRDPFileStreamWithID(tunnel, transferId, sessionId, fullPath, tempFile, fileSize)
|
||||
if err != nil {
|
||||
service.UpdateTransferProgress(transferId, 0, -1, "failed")
|
||||
fileservice.UpdateTransferProgress(transferId, 0, -1, "failed")
|
||||
os.Remove(tempFilePath)
|
||||
ctx.JSON(http.StatusInternalServerError, HttpResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
@@ -683,7 +683,7 @@ func (c *Controller) RDPFileUpload(ctx *gin.Context) {
|
||||
os.Remove(tempFilePath)
|
||||
|
||||
// Record file history using session-based method
|
||||
if err := service.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "upload", fullPath); err != nil {
|
||||
if err := fileservice.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "upload", fullPath); err != nil {
|
||||
logger.L().Error("Failed to record file history", zap.Error(err))
|
||||
}
|
||||
|
||||
@@ -708,7 +708,7 @@ func (c *Controller) RDPFileUpload(ctx *gin.Context) {
|
||||
// Clean up progress record after a short delay
|
||||
go func() {
|
||||
time.Sleep(30 * time.Second) // Keep for 30 seconds for any delayed queries
|
||||
service.CleanupTransferProgress(transferId, 0)
|
||||
fileservice.CleanupTransferProgress(transferId, 0)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -742,7 +742,7 @@ func (c *Controller) RDPFileDownload(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !service.IsRDPDriveEnabled(tunnel) {
|
||||
if !fileservice.IsRDPDriveEnabled(tunnel) {
|
||||
ctx.JSON(http.StatusForbidden, HttpResponse{
|
||||
Code: http.StatusForbidden,
|
||||
Message: "Drive redirection not enabled",
|
||||
@@ -750,7 +750,7 @@ func (c *Controller) RDPFileDownload(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !service.IsRDPDownloadAllowed(tunnel) {
|
||||
if !fileservice.IsRDPDownloadAllowed(tunnel) {
|
||||
ctx.JSON(http.StatusForbidden, HttpResponse{
|
||||
Code: http.StatusForbidden,
|
||||
Message: "File download not allowed",
|
||||
@@ -803,7 +803,7 @@ func (c *Controller) RDPFileDownload(ctx *gin.Context) {
|
||||
if len(filenames) == 1 {
|
||||
// Single file download (memory-efficient streaming)
|
||||
path := filepath.Join(dir, filenames[0])
|
||||
reader, fileSize, err = service.DownloadRDPFile(tunnel, path)
|
||||
reader, fileSize, err = fileservice.DownloadRDPFile(tunnel, path)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, HttpResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
@@ -815,7 +815,7 @@ func (c *Controller) RDPFileDownload(ctx *gin.Context) {
|
||||
downloadFilename = filenames[0]
|
||||
} else {
|
||||
// Multiple files download as ZIP
|
||||
reader, downloadFilename, fileSize, err = service.DownloadRDPMultiple(tunnel, dir, filenames)
|
||||
reader, downloadFilename, fileSize, err = fileservice.DownloadRDPMultiple(tunnel, dir, filenames)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, HttpResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
@@ -827,7 +827,7 @@ func (c *Controller) RDPFileDownload(ctx *gin.Context) {
|
||||
defer reader.Close()
|
||||
|
||||
// Record file operation history using session-based method
|
||||
if err := service.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "download", filepath.Join(dir, strings.Join(filenames, ","))); err != nil {
|
||||
if err := fileservice.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "download", filepath.Join(dir, strings.Join(filenames, ","))); err != nil {
|
||||
logger.L().Error("Failed to record file history", zap.Error(err))
|
||||
}
|
||||
|
||||
@@ -861,13 +861,13 @@ func (c *Controller) RDPFileDownload(ctx *gin.Context) {
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param session_id path string true "Session ID"
|
||||
// @Param request body service.RDPMkdirRequest true "Directory creation request"
|
||||
// @Param request body fileservice.RDPMkdirRequest true "Directory creation request"
|
||||
// @Success 200 {object} HttpResponse
|
||||
// @Router /rdp/sessions/{session_id}/files/mkdir [post]
|
||||
func (c *Controller) RDPFileMkdir(ctx *gin.Context) {
|
||||
sessionId := ctx.Param("session_id")
|
||||
|
||||
var req service.RDPMkdirRequest
|
||||
var req fileservice.RDPMkdirRequest
|
||||
if err := ctx.ShouldBindJSON(&req); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, HttpResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
@@ -893,7 +893,7 @@ func (c *Controller) RDPFileMkdir(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Check if upload is allowed (mkdir is considered an upload operation)
|
||||
if !service.IsRDPUploadAllowed(tunnel) {
|
||||
if !fileservice.IsRDPUploadAllowed(tunnel) {
|
||||
ctx.JSON(http.StatusForbidden, HttpResponse{
|
||||
Code: http.StatusForbidden,
|
||||
Message: "Directory creation is disabled for this session",
|
||||
@@ -902,7 +902,7 @@ func (c *Controller) RDPFileMkdir(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Send mkdir request through Guacamole protocol
|
||||
err := service.CreateRDPDirectory(tunnel, req.Path)
|
||||
err := fileservice.CreateRDPDirectory(tunnel, req.Path)
|
||||
if err != nil {
|
||||
logger.L().Error("Failed to create directory in RDP session", zap.Error(err))
|
||||
ctx.JSON(http.StatusInternalServerError, HttpResponse{
|
||||
@@ -913,7 +913,7 @@ func (c *Controller) RDPFileMkdir(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Record file operation history using session-based method
|
||||
if err := service.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "mkdir", req.Path); err != nil {
|
||||
if err := fileservice.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "mkdir", req.Path); err != nil {
|
||||
logger.L().Error("Failed to record file history", zap.Error(err))
|
||||
}
|
||||
|
||||
@@ -967,7 +967,7 @@ func (c *Controller) SftpFileLS(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Check if session is active
|
||||
if !service.DefaultFileService.IsSessionActive(sessionId) {
|
||||
if !fileservice.DefaultFileService.IsSessionActive(sessionId) {
|
||||
ctx.JSON(http.StatusNotFound, HttpResponse{
|
||||
Code: http.StatusNotFound,
|
||||
Message: "Session not found or inactive",
|
||||
@@ -995,14 +995,14 @@ func (c *Controller) SftpFileLS(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Use session-based file service
|
||||
fileInfos, err := service.DefaultFileService.SessionLS(ctx, sessionId, dir)
|
||||
fileInfos, err := fileservice.DefaultFileService.SessionLS(ctx, sessionId, dir)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrSessionNotFound) {
|
||||
if errors.Is(err, fileservice.ErrSessionNotFound) {
|
||||
ctx.JSON(http.StatusNotFound, HttpResponse{
|
||||
Code: http.StatusNotFound,
|
||||
Message: "Session not found",
|
||||
})
|
||||
} else if service.IsPermissionError(err) {
|
||||
} else if fileservice.IsPermissionError(err) {
|
||||
ctx.JSON(http.StatusForbidden, HttpResponse{
|
||||
Code: http.StatusForbidden,
|
||||
Message: "Permission denied",
|
||||
@@ -1019,7 +1019,7 @@ func (c *Controller) SftpFileLS(ctx *gin.Context) {
|
||||
// Filter hidden files unless show_hidden is true
|
||||
showHidden := cast.ToBool(ctx.Query("show_hidden"))
|
||||
if !showHidden {
|
||||
var filtered []service.FileInfo
|
||||
var filtered []fileservice.FileInfo
|
||||
for _, f := range fileInfos {
|
||||
if !strings.HasPrefix(f.Name, ".") {
|
||||
filtered = append(filtered, f)
|
||||
@@ -1030,7 +1030,7 @@ func (c *Controller) SftpFileLS(ctx *gin.Context) {
|
||||
|
||||
res := &ListData{
|
||||
Count: int64(len(fileInfos)),
|
||||
List: lo.Map(fileInfos, func(f service.FileInfo, _ int) any { return f }),
|
||||
List: lo.Map(fileInfos, func(f fileservice.FileInfo, _ int) any { return f }),
|
||||
}
|
||||
ctx.JSON(http.StatusOK, NewHttpResponseWithData(res))
|
||||
}
|
||||
@@ -1055,7 +1055,7 @@ func (c *Controller) SftpFileMkdir(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Check if session is active
|
||||
if !service.DefaultFileService.IsSessionActive(sessionId) {
|
||||
if !fileservice.DefaultFileService.IsSessionActive(sessionId) {
|
||||
ctx.JSON(http.StatusNotFound, HttpResponse{
|
||||
Code: http.StatusNotFound,
|
||||
Message: "Session not found or inactive",
|
||||
@@ -1083,8 +1083,8 @@ func (c *Controller) SftpFileMkdir(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Use session-based file service
|
||||
if err := service.DefaultFileService.SessionMkdir(ctx, sessionId, dir); err != nil {
|
||||
if errors.Is(err, service.ErrSessionNotFound) {
|
||||
if err := fileservice.DefaultFileService.SessionMkdir(ctx, sessionId, dir); err != nil {
|
||||
if errors.Is(err, fileservice.ErrSessionNotFound) {
|
||||
ctx.JSON(http.StatusNotFound, HttpResponse{
|
||||
Code: http.StatusNotFound,
|
||||
Message: "Session not found",
|
||||
@@ -1099,7 +1099,7 @@ func (c *Controller) SftpFileMkdir(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Record history using session-based method
|
||||
if err := service.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "mkdir", dir); err != nil {
|
||||
if err := fileservice.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "mkdir", dir); err != nil {
|
||||
logger.L().Error("Failed to record file history", zap.Error(err))
|
||||
}
|
||||
|
||||
@@ -1113,7 +1113,7 @@ func (c *Controller) TransferProgressById(ctx *gin.Context) {
|
||||
transferId := ctx.Param("transfer_id")
|
||||
|
||||
// First check unified progress tracking
|
||||
progress, exists := service.GetTransferProgressById(transferId)
|
||||
progress, exists := fileservice.GetTransferProgressById(transferId)
|
||||
|
||||
if exists {
|
||||
// Calculate transfer progress
|
||||
@@ -1136,7 +1136,7 @@ func (c *Controller) TransferProgressById(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Fallback: check RDP guacd transfer manager
|
||||
rdpProgress, err := service.GetRDPTransferProgressById(transferId)
|
||||
rdpProgress, err := fileservice.GetRDPTransferProgressById(transferId)
|
||||
if err == nil {
|
||||
ctx.JSON(http.StatusOK, HttpResponse{
|
||||
Code: 0,
|
||||
@@ -1181,7 +1181,7 @@ func (c *Controller) RDPFileTransferPrepare(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Create unified progress tracking entry
|
||||
service.CreateTransferProgress(transferId, "rdp")
|
||||
fileservice.CreateTransferProgress(transferId, "rdp")
|
||||
|
||||
ctx.JSON(http.StatusOK, HttpResponse{
|
||||
Code: 0,
|
||||
@@ -1221,10 +1221,10 @@ func (c *Controller) SftpFileUpload(ctx *gin.Context) {
|
||||
transferId = fmt.Sprintf("%s-%d", sessionId, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
service.CreateTransferProgress(transferId, "sftp")
|
||||
fileservice.CreateTransferProgress(transferId, "sftp")
|
||||
|
||||
// Validate session
|
||||
if !service.DefaultFileService.IsSessionActive(sessionId) {
|
||||
if !fileservice.DefaultFileService.IsSessionActive(sessionId) {
|
||||
ctx.JSON(http.StatusNotFound, HttpResponse{
|
||||
Code: http.StatusNotFound,
|
||||
Message: "Session not found or inactive",
|
||||
@@ -1280,7 +1280,7 @@ func (c *Controller) SftpFileUpload(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Update transfer progress with file size now that we have it
|
||||
service.UpdateTransferProgress(transferId, fileSize, 0, "")
|
||||
fileservice.UpdateTransferProgress(transferId, fileSize, 0, "")
|
||||
|
||||
// Phase 1: Save file to server temp directory
|
||||
tempDir := filepath.Join(os.TempDir(), "oneterm-uploads", sessionId)
|
||||
@@ -1327,11 +1327,11 @@ func (c *Controller) SftpFileUpload(ctx *gin.Context) {
|
||||
targetPath := filepath.Join(targetDir, filename)
|
||||
|
||||
// Phase 2: Transfer to target machine using SFTP (synchronous)
|
||||
service.UpdateTransferProgress(transferId, fileSize, 0, "transferring")
|
||||
fileservice.UpdateTransferProgress(transferId, fileSize, 0, "transferring")
|
||||
|
||||
if err := service.TransferToTarget(transferId, sessionId, tempFilePath, targetPath, 0, 0); err != nil {
|
||||
if err := fileservice.TransferToTarget(transferId, sessionId, tempFilePath, targetPath, 0, 0); err != nil {
|
||||
// Mark transfer as failed and clean up
|
||||
service.UpdateTransferProgress(transferId, 0, -1, "failed")
|
||||
fileservice.UpdateTransferProgress(transferId, 0, -1, "failed")
|
||||
os.Remove(tempFilePath)
|
||||
ctx.JSON(http.StatusInternalServerError, HttpResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
@@ -1341,13 +1341,13 @@ func (c *Controller) SftpFileUpload(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// Mark transfer as completed (success)
|
||||
service.UpdateTransferProgress(transferId, 0, -1, "completed")
|
||||
fileservice.UpdateTransferProgress(transferId, 0, -1, "completed")
|
||||
|
||||
// Clean up temp file after successful transfer
|
||||
os.Remove(tempFilePath)
|
||||
|
||||
// Record file history using session-based method
|
||||
if err := service.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "upload", filepath.Join(targetDir, filename)); err != nil {
|
||||
if err := fileservice.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "upload", filepath.Join(targetDir, filename)); err != nil {
|
||||
logger.L().Error("Failed to record file history", zap.Error(err))
|
||||
}
|
||||
|
||||
@@ -1367,7 +1367,7 @@ func (c *Controller) SftpFileUpload(ctx *gin.Context) {
|
||||
// Clean up progress record after a short delay
|
||||
go func() {
|
||||
time.Sleep(30 * time.Second) // Keep for 30 seconds for any delayed queries
|
||||
service.CleanupTransferProgress(transferId, 0)
|
||||
fileservice.CleanupTransferProgress(transferId, 0)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -1383,7 +1383,7 @@ func (c *Controller) SftpFileDownload(ctx *gin.Context) {
|
||||
sessionId := ctx.Param("session_id")
|
||||
|
||||
// Check if session is active
|
||||
if !service.DefaultFileService.IsSessionActive(sessionId) {
|
||||
if !fileservice.DefaultFileService.IsSessionActive(sessionId) {
|
||||
ctx.JSON(http.StatusNotFound, HttpResponse{
|
||||
Code: http.StatusNotFound,
|
||||
Message: "Session not found or inactive",
|
||||
@@ -1436,14 +1436,14 @@ func (c *Controller) SftpFileDownload(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
reader, downloadFilename, fileSize, err := service.DefaultFileService.SessionDownloadMultiple(ctx, sessionId, ctx.Query("dir"), filenames)
|
||||
reader, downloadFilename, fileSize, err := fileservice.DefaultFileService.SessionDownloadMultiple(ctx, sessionId, ctx.Query("dir"), filenames)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrSessionNotFound) {
|
||||
if errors.Is(err, fileservice.ErrSessionNotFound) {
|
||||
ctx.JSON(http.StatusNotFound, HttpResponse{
|
||||
Code: http.StatusNotFound,
|
||||
Message: "Session not found",
|
||||
})
|
||||
} else if service.IsPermissionError(err) {
|
||||
} else if fileservice.IsPermissionError(err) {
|
||||
ctx.JSON(http.StatusForbidden, HttpResponse{
|
||||
Code: http.StatusForbidden,
|
||||
Message: "Permission denied",
|
||||
@@ -1459,7 +1459,7 @@ func (c *Controller) SftpFileDownload(ctx *gin.Context) {
|
||||
defer reader.Close()
|
||||
|
||||
// Record file operation history using session-based method
|
||||
if err := service.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "download", filepath.Join(ctx.Query("dir"), strings.Join(filenames, ","))); err != nil {
|
||||
if err := fileservice.DefaultFileService.RecordFileHistoryBySession(ctx, sessionId, "download", filepath.Join(ctx.Query("dir"), strings.Join(filenames, ","))); err != nil {
|
||||
logger.L().Error("Failed to record file history", zap.Error(err))
|
||||
}
|
||||
|
||||
|
||||
@@ -22,7 +22,9 @@ import (
|
||||
"github.com/veops/oneterm/internal/connector/protocols"
|
||||
"github.com/veops/oneterm/internal/connector/protocols/db"
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
"github.com/veops/oneterm/internal/repository"
|
||||
"github.com/veops/oneterm/internal/service"
|
||||
fileservice "github.com/veops/oneterm/internal/service/file"
|
||||
gsession "github.com/veops/oneterm/internal/session"
|
||||
myErrors "github.com/veops/oneterm/pkg/errors"
|
||||
"github.com/veops/oneterm/pkg/logger"
|
||||
@@ -163,7 +165,7 @@ func DoConnect(ctx *gin.Context, ws *websocket.Conn) (sess *gsession.Session, er
|
||||
currentUser, _ := acl.GetSessionFromCtx(ctx)
|
||||
|
||||
assetId, accountId := cast.ToInt(ctx.Param("asset_id")), cast.ToInt(ctx.Param("account_id"))
|
||||
asset, account, gateway, err := service.GetAAG(assetId, accountId)
|
||||
asset, account, gateway, err := repository.GetAAG(assetId, accountId)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -265,7 +267,7 @@ func DoConnect(ctx *gin.Context, ws *websocket.Conn) (sess *gsession.Session, er
|
||||
// Only for SSH-based protocols that support SFTP
|
||||
protocol := strings.Split(sess.Protocol, ":")[0]
|
||||
if protocol == "ssh" {
|
||||
if err := service.DefaultFileService.InitSessionFileClient(sess.SessionId, sess.AssetId, sess.AccountId); err != nil {
|
||||
if err := fileservice.DefaultFileService.InitSessionFileClient(sess.SessionId, sess.AssetId, sess.AccountId); err != nil {
|
||||
logger.L().Warn("Failed to initialize session file client",
|
||||
zap.String("sessionId", sess.SessionId),
|
||||
zap.Int("assetId", sess.AssetId),
|
||||
@@ -296,7 +298,7 @@ func HandleTerm(sess *gsession.Session, ctx *gin.Context) (err error) {
|
||||
// Clean up session-based file client (only for SSH-based protocols)
|
||||
protocol := strings.Split(sess.Protocol, ":")[0]
|
||||
if protocol == "ssh" {
|
||||
service.DefaultFileService.CloseSessionFileClient(sess.SessionId)
|
||||
fileservice.DefaultFileService.CloseSessionFileClient(sess.SessionId)
|
||||
// Clear SSH client from session to ensure proper cleanup
|
||||
sess.ClearSSHClient()
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
"github.com/veops/oneterm/internal/service"
|
||||
"github.com/veops/oneterm/internal/repository"
|
||||
gsession "github.com/veops/oneterm/internal/session"
|
||||
"github.com/veops/oneterm/internal/tunneling"
|
||||
"github.com/veops/oneterm/pkg/logger"
|
||||
@@ -33,7 +33,7 @@ func ConnectSsh(ctx *gin.Context, sess *gsession.Session, asset *model.Asset, ac
|
||||
return
|
||||
}
|
||||
|
||||
auth, err := service.GetAuth(account)
|
||||
auth, err := repository.GetAuth(account)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
|
||||
myi18n "github.com/veops/oneterm/internal/i18n"
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
"github.com/veops/oneterm/internal/service"
|
||||
fileservice "github.com/veops/oneterm/internal/service/file"
|
||||
gsession "github.com/veops/oneterm/internal/session"
|
||||
myErrors "github.com/veops/oneterm/pkg/errors"
|
||||
"github.com/veops/oneterm/pkg/logger"
|
||||
@@ -227,7 +227,7 @@ func OfflineSession(ctx *gin.Context, sessionId string, closer string) {
|
||||
defer gsession.GetOnlineSession().Delete(sessionId)
|
||||
|
||||
// Clean up session-based file client
|
||||
service.DefaultFileService.CloseSessionFileClient(sessionId)
|
||||
fileservice.DefaultFileService.CloseSessionFileClient(sessionId)
|
||||
|
||||
session := gsession.GetOnlineSessionById(sessionId)
|
||||
if session == nil {
|
||||
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
"github.com/veops/oneterm/pkg/config"
|
||||
dbpkg "github.com/veops/oneterm/pkg/db"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -145,3 +147,27 @@ func HandleAccountIds(ctx context.Context, dbFind *gorm.DB, resIds []int) (db *g
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetAuth creates SSH authentication method from account credentials
|
||||
func GetAuth(account *model.Account) (ssh.AuthMethod, error) {
|
||||
switch account.AccountType {
|
||||
case model.AUTHMETHOD_PASSWORD:
|
||||
return ssh.Password(account.Password), nil
|
||||
case model.AUTHMETHOD_PUBLICKEY:
|
||||
if account.Phrase == "" {
|
||||
pk, err := ssh.ParsePrivateKey([]byte(account.Pk))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.PublicKeys(pk), nil
|
||||
} else {
|
||||
pk, err := ssh.ParsePrivateKeyWithPassphrase([]byte(account.Pk), []byte(account.Phrase))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.PublicKeys(pk), nil
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid authmethod %d", account.AccountType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
"github.com/veops/oneterm/pkg/config"
|
||||
dbpkg "github.com/veops/oneterm/pkg/db"
|
||||
"github.com/veops/oneterm/pkg/utils"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -279,3 +280,27 @@ func HandleAssetIds(ctx context.Context, dbFind *gorm.DB, resIds []int) (db *gor
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetAAG retrieves Asset, Account, and Gateway by their IDs with decrypted credentials
|
||||
func GetAAG(assetId int, accountId int) (asset *model.Asset, account *model.Account, gateway *model.Gateway, err error) {
|
||||
asset, account, gateway = &model.Asset{}, &model.Account{}, &model.Gateway{}
|
||||
if err = dbpkg.DB.Model(asset).Where("id = ?", assetId).First(asset).Error; err != nil {
|
||||
return
|
||||
}
|
||||
if err = dbpkg.DB.Model(account).Where("id = ?", accountId).First(account).Error; err != nil {
|
||||
return
|
||||
}
|
||||
account.Password = utils.DecryptAES(account.Password)
|
||||
account.Pk = utils.DecryptAES(account.Pk)
|
||||
account.Phrase = utils.DecryptAES(account.Phrase)
|
||||
if asset.GatewayId != 0 {
|
||||
if err = dbpkg.DB.Model(gateway).Where("id = ?", asset.GatewayId).First(gateway).Error; err != nil {
|
||||
return
|
||||
}
|
||||
gateway.Password = utils.DecryptAES(gateway.Password)
|
||||
gateway.Pk = utils.DecryptAES(gateway.Pk)
|
||||
gateway.Phrase = utils.DecryptAES(gateway.Phrase)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
666
backend/internal/service/file/file.go
Normal file
666
backend/internal/service/file/file.go
Normal file
@@ -0,0 +1,666 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/sftp"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/veops/oneterm/internal/acl"
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
gsession "github.com/veops/oneterm/internal/session"
|
||||
dbpkg "github.com/veops/oneterm/pkg/db"
|
||||
"github.com/veops/oneterm/pkg/logger"
|
||||
)
|
||||
|
||||
// Global file service instance
|
||||
var DefaultFileService IFileService
|
||||
|
||||
// InitFileService initializes the global file service
|
||||
func InitFileService() {
|
||||
DefaultFileService = NewFileService(&FileRepository{
|
||||
db: dbpkg.DB,
|
||||
})
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Legacy file manager cleanup
|
||||
go func() {
|
||||
tk := time.NewTicker(time.Minute)
|
||||
for {
|
||||
<-tk.C
|
||||
func() {
|
||||
GetFileManager().mtx.Lock()
|
||||
defer GetFileManager().mtx.Unlock()
|
||||
for k, v := range GetFileManager().lastTime {
|
||||
if v.Before(time.Now().Add(time.Minute * 10)) {
|
||||
delete(GetFileManager().sftps, k)
|
||||
delete(GetFileManager().lastTime, k)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
// Session-based file manager cleanup
|
||||
go func() {
|
||||
ticker := time.NewTicker(5 * time.Minute) // Check every 5 minutes
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
<-ticker.C
|
||||
// Clean up sessions inactive for more than 30 minutes
|
||||
GetSessionFileManager().CleanupInactiveSessions(30 * time.Minute)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// NewFileService creates a new file service instance
|
||||
func NewFileService(repo IFileRepository) IFileService {
|
||||
return &FileService{
|
||||
repo: repo,
|
||||
}
|
||||
}
|
||||
|
||||
// Legacy asset-based operations
|
||||
func (s *FileService) ReadDir(ctx context.Context, assetId, accountId int, dir string) ([]fs.FileInfo, error) {
|
||||
cli, err := GetFileManager().GetFileClient(assetId, accountId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cli.ReadDir(dir)
|
||||
}
|
||||
|
||||
func (s *FileService) MkdirAll(ctx context.Context, assetId, accountId int, dir string) error {
|
||||
cli, err := GetFileManager().GetFileClient(assetId, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return cli.MkdirAll(dir)
|
||||
}
|
||||
|
||||
func (s *FileService) Create(ctx context.Context, assetId, accountId int, path string) (io.WriteCloser, error) {
|
||||
cli, err := GetFileManager().GetFileClient(assetId, accountId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if parent directory exists, create only if not exists
|
||||
parentDir := filepath.Dir(path)
|
||||
if parentDir != "" && parentDir != "." && parentDir != "/" {
|
||||
if _, err := cli.Stat(parentDir); err != nil {
|
||||
// Directory doesn't exist, create it
|
||||
if err := cli.MkdirAll(parentDir); err != nil {
|
||||
return nil, fmt.Errorf("failed to create parent directory %s: %w", parentDir, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cli.Create(path)
|
||||
}
|
||||
|
||||
func (s *FileService) Open(ctx context.Context, assetId, accountId int, path string) (io.ReadCloser, error) {
|
||||
cli, err := GetFileManager().GetFileClient(assetId, accountId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cli.Open(path)
|
||||
}
|
||||
|
||||
func (s *FileService) Stat(ctx context.Context, assetId, accountId int, path string) (fs.FileInfo, error) {
|
||||
cli, err := GetFileManager().GetFileClient(assetId, accountId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cli.Stat(path)
|
||||
}
|
||||
|
||||
// DownloadMultiple handles downloading single file or multiple files/directories as ZIP
|
||||
func (s *FileService) DownloadMultiple(ctx context.Context, assetId, accountId int, dir string, filenames []string) (io.ReadCloser, string, int64, error) {
|
||||
cli, err := GetFileManager().GetFileClient(assetId, accountId)
|
||||
if err != nil {
|
||||
return nil, "", 0, err
|
||||
}
|
||||
|
||||
// Validate and sanitize all filenames for security
|
||||
var sanitizedFilenames []string
|
||||
for _, filename := range filenames {
|
||||
sanitized, err := sanitizeFilename(filename)
|
||||
if err != nil {
|
||||
return nil, "", 0, fmt.Errorf("invalid filename '%s': %v", filename, err)
|
||||
}
|
||||
sanitizedFilenames = append(sanitizedFilenames, sanitized)
|
||||
}
|
||||
|
||||
// If only one file, check if it's a regular file first
|
||||
if len(sanitizedFilenames) == 1 {
|
||||
fullPath := filepath.Join(dir, sanitizedFilenames[0])
|
||||
fileInfo, err := cli.Stat(fullPath)
|
||||
if err != nil {
|
||||
return nil, "", 0, err
|
||||
}
|
||||
|
||||
// If it's a regular file, return directly
|
||||
if !fileInfo.IsDir() {
|
||||
reader, err := cli.Open(fullPath)
|
||||
if err != nil {
|
||||
return nil, "", 0, err
|
||||
}
|
||||
return reader, sanitizedFilenames[0], fileInfo.Size(), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Multiple files or contains directory, create ZIP
|
||||
return s.createZipArchive(cli, dir, sanitizedFilenames)
|
||||
}
|
||||
|
||||
// createZipArchive creates a ZIP archive containing the specified files/directories
|
||||
func (s *FileService) createZipArchive(cli *sftp.Client, baseDir string, filenames []string) (io.ReadCloser, string, int64, error) {
|
||||
// Generate ZIP filename
|
||||
var zipName string
|
||||
if len(filenames) == 1 {
|
||||
zipName = filenames[0] + ".zip"
|
||||
} else {
|
||||
zipName = "download.zip"
|
||||
}
|
||||
|
||||
// Use pipe for true streaming without memory buffering
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
|
||||
// Create ZIP in a separate goroutine
|
||||
go func() {
|
||||
defer pipeWriter.Close()
|
||||
|
||||
zipWriter := zip.NewWriter(pipeWriter)
|
||||
defer zipWriter.Close()
|
||||
|
||||
// Add each file/directory to ZIP
|
||||
for _, filename := range filenames {
|
||||
fullPath := filepath.Join(baseDir, filename)
|
||||
|
||||
if err := s.addToZip(cli, zipWriter, baseDir, filename, fullPath); err != nil {
|
||||
logger.L().Error("Failed to add file to ZIP", zap.String("path", fullPath), zap.Error(err))
|
||||
pipeWriter.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Return pipe reader for streaming, size unknown (-1)
|
||||
return pipeReader, zipName, -1, nil
|
||||
}
|
||||
|
||||
// addToZip recursively adds files/directories to ZIP archive
|
||||
func (s *FileService) addToZip(cli *sftp.Client, zipWriter *zip.Writer, baseDir, relativePath, fullPath string) error {
|
||||
fileInfo, err := cli.Stat(fullPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if fileInfo.IsDir() {
|
||||
// Add directory
|
||||
return s.addDirToZip(cli, zipWriter, fullPath, relativePath)
|
||||
} else {
|
||||
// Add file
|
||||
return s.addFileToZip(cli, zipWriter, fullPath, relativePath)
|
||||
}
|
||||
}
|
||||
|
||||
// addFileToZip adds a single file to ZIP archive
|
||||
func (s *FileService) addFileToZip(cli *sftp.Client, zipWriter *zip.Writer, fullPath, relativePath string) error {
|
||||
// Get file info first to preserve metadata
|
||||
fileInfo, err := cli.Stat(fullPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Open remote file
|
||||
file, err := cli.Open(fullPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Create FileHeader with original file metadata
|
||||
header := &zip.FileHeader{
|
||||
Name: relativePath,
|
||||
Method: zip.Deflate,
|
||||
Modified: fileInfo.ModTime(), // Preserve original modification time
|
||||
}
|
||||
|
||||
// Set file mode
|
||||
header.SetMode(fileInfo.Mode())
|
||||
|
||||
// Create file in ZIP with header
|
||||
zipFile, err := zipWriter.CreateHeader(header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Copy file content
|
||||
_, err = io.Copy(zipFile, file)
|
||||
return err
|
||||
}
|
||||
|
||||
// addDirToZip recursively adds a directory to ZIP archive
|
||||
func (s *FileService) addDirToZip(cli *sftp.Client, zipWriter *zip.Writer, fullPath, relativePath string) error {
|
||||
// Get directory info to preserve metadata
|
||||
dirInfo, err := cli.Stat(fullPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read directory contents
|
||||
entries, err := cli.ReadDir(fullPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If directory is empty, create directory entry with preserved timestamp
|
||||
if len(entries) == 0 {
|
||||
header := &zip.FileHeader{
|
||||
Name: relativePath + "/",
|
||||
Method: zip.Store, // Directories are not compressed
|
||||
Modified: dirInfo.ModTime(), // Preserve original modification time
|
||||
}
|
||||
header.SetMode(dirInfo.Mode())
|
||||
|
||||
_, err := zipWriter.CreateHeader(header)
|
||||
return err
|
||||
}
|
||||
|
||||
// Recursively add each entry in the directory
|
||||
for _, entry := range entries {
|
||||
entryFullPath := filepath.Join(fullPath, entry.Name())
|
||||
entryRelativePath := filepath.Join(relativePath, entry.Name())
|
||||
|
||||
if err := s.addToZip(cli, zipWriter, fullPath, entryRelativePath, entryFullPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Session-based operations
|
||||
func (s *FileService) SessionLS(ctx context.Context, sessionId, dir string) ([]FileInfo, error) {
|
||||
cli, err := GetSessionFileManager().GetSessionSFTP(sessionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
entries, err := cli.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fileInfos []FileInfo
|
||||
for _, entry := range entries {
|
||||
var target string
|
||||
if entry.Mode()&fs.ModeSymlink != 0 {
|
||||
linkPath := filepath.Join(dir, entry.Name())
|
||||
if linkTarget, err := cli.ReadLink(linkPath); err == nil {
|
||||
target = linkTarget
|
||||
}
|
||||
}
|
||||
|
||||
fileInfos = append(fileInfos, FileInfo{
|
||||
Name: entry.Name(),
|
||||
IsDir: entry.IsDir(),
|
||||
Size: entry.Size(),
|
||||
Mode: entry.Mode().String(),
|
||||
IsLink: entry.Mode()&fs.ModeSymlink != 0,
|
||||
Target: target,
|
||||
ModTime: entry.ModTime().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
return fileInfos, nil
|
||||
}
|
||||
|
||||
func (s *FileService) SessionMkdir(ctx context.Context, sessionId, dir string) error {
|
||||
cli, err := GetSessionFileManager().GetSessionSFTP(sessionId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return cli.MkdirAll(dir)
|
||||
}
|
||||
|
||||
func (s *FileService) SessionUpload(ctx context.Context, sessionId, targetPath string, file io.Reader, filename string, size int64) error {
|
||||
return s.SessionUploadWithID(ctx, "", sessionId, targetPath, file, filename, size)
|
||||
}
|
||||
|
||||
func (s *FileService) SessionUploadWithID(ctx context.Context, transferID, sessionId, targetPath string, file io.Reader, filename string, size int64) error {
|
||||
client, err := GetSessionFileManager().GetSessionSFTP(sessionId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var transfer *SessionFileTransfer
|
||||
transferManager := GetSessionTransferManager()
|
||||
|
||||
if transferID != "" {
|
||||
// Use existing transfer or create new one
|
||||
transfer = transferManager.GetTransfer(transferID)
|
||||
if transfer == nil {
|
||||
transfer = transferManager.CreateTransferWithID(transferID, sessionId, filename, size, true)
|
||||
if transfer == nil {
|
||||
return fmt.Errorf("transfer ID already exists: %s", transferID)
|
||||
}
|
||||
// Set initial status
|
||||
transfer.UpdateProgress(0)
|
||||
}
|
||||
} else {
|
||||
// Create new transfer with auto-generated ID
|
||||
transfer = transferManager.CreateTransfer(sessionId, filename, size, true)
|
||||
transfer.UpdateProgress(0)
|
||||
}
|
||||
|
||||
parentDir := filepath.Dir(targetPath)
|
||||
if parentDir != "" && parentDir != "." && parentDir != "/" {
|
||||
if _, err := client.Stat(parentDir); err != nil {
|
||||
if err := client.MkdirAll(parentDir); err != nil {
|
||||
transfer.SetError(fmt.Errorf("failed to create parent directory %s: %w", parentDir, err))
|
||||
return fmt.Errorf("failed to create parent directory %s: %w", parentDir, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
remoteFile, err := client.Create(targetPath)
|
||||
if err != nil {
|
||||
transfer.SetError(fmt.Errorf("failed to create remote file: %w", err))
|
||||
return fmt.Errorf("failed to create remote file: %w", err)
|
||||
}
|
||||
defer remoteFile.Close()
|
||||
|
||||
progressWriter := NewProgressWriter(remoteFile, transfer)
|
||||
|
||||
buffer := make([]byte, 32*1024)
|
||||
_, err = io.CopyBuffer(progressWriter, file, buffer)
|
||||
if err != nil {
|
||||
transfer.SetError(fmt.Errorf("failed to upload file: %w", err))
|
||||
return fmt.Errorf("failed to upload file: %w", err)
|
||||
}
|
||||
|
||||
transfer.UpdateProgress(size)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *FileService) SessionDownload(ctx context.Context, sessionId, filePath string) (io.ReadCloser, int64, error) {
|
||||
cli, err := GetSessionFileManager().GetSessionSFTP(sessionId)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Get file info
|
||||
info, err := cli.Stat(filePath)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Open file
|
||||
file, err := cli.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return file, info.Size(), nil
|
||||
}
|
||||
|
||||
func (s *FileService) SessionDownloadMultiple(ctx context.Context, sessionId, dir string, filenames []string) (io.ReadCloser, string, int64, error) {
|
||||
client, err := GetSessionFileManager().GetSessionSFTP(sessionId)
|
||||
if err != nil {
|
||||
return nil, "", 0, err
|
||||
}
|
||||
|
||||
// Validate and sanitize filenames
|
||||
var sanitizedFilenames []string
|
||||
for _, filename := range filenames {
|
||||
sanitized, err := sanitizeFilename(filename)
|
||||
if err != nil {
|
||||
return nil, "", 0, fmt.Errorf("invalid filename '%s': %v", filename, err)
|
||||
}
|
||||
sanitizedFilenames = append(sanitizedFilenames, sanitized)
|
||||
}
|
||||
|
||||
// If only one file and it's not a directory, return directly
|
||||
if len(sanitizedFilenames) == 1 {
|
||||
fullPath := filepath.Join(dir, sanitizedFilenames[0])
|
||||
fileInfo, err := client.Stat(fullPath)
|
||||
if err != nil {
|
||||
return nil, "", 0, err
|
||||
}
|
||||
|
||||
if !fileInfo.IsDir() {
|
||||
reader, err := client.Open(fullPath)
|
||||
if err != nil {
|
||||
return nil, "", 0, err
|
||||
}
|
||||
return reader, sanitizedFilenames[0], fileInfo.Size(), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Multiple files or contains directory, create ZIP
|
||||
return s.createZipArchive(client, dir, sanitizedFilenames)
|
||||
}
|
||||
|
||||
// Session lifecycle management
|
||||
func (s *FileService) InitSessionFileClient(sessionId string, assetId, accountId int) error {
|
||||
return GetSessionFileManager().InitSessionSFTP(sessionId, assetId, accountId)
|
||||
}
|
||||
|
||||
func (s *FileService) CloseSessionFileClient(sessionId string) {
|
||||
GetSessionFileManager().CloseSessionSFTP(sessionId)
|
||||
}
|
||||
|
||||
func (s *FileService) IsSessionActive(sessionId string) bool {
|
||||
return GetSessionFileManager().IsSessionActive(sessionId)
|
||||
}
|
||||
|
||||
// Progress tracking
|
||||
func (s *FileService) GetSessionTransferProgress(ctx context.Context, sessionId string) ([]*SessionFileTransferProgress, error) {
|
||||
return GetSessionTransferManager().GetSessionProgress(sessionId), nil
|
||||
}
|
||||
|
||||
func (s *FileService) GetTransferProgress(ctx context.Context, transferId string) (*SessionFileTransferProgress, error) {
|
||||
transfer := GetSessionTransferManager().GetTransfer(transferId)
|
||||
if transfer == nil {
|
||||
return nil, fmt.Errorf("transfer not found")
|
||||
}
|
||||
return transfer.GetProgress(), nil
|
||||
}
|
||||
|
||||
// File history operations
|
||||
func (s *FileService) AddFileHistory(ctx context.Context, history *model.FileHistory) error {
|
||||
if s.repo == nil {
|
||||
return fmt.Errorf("repository not initialized")
|
||||
}
|
||||
return s.repo.AddFileHistory(ctx, history)
|
||||
}
|
||||
|
||||
func (s *FileService) BuildFileHistoryQuery(ctx *gin.Context) *gorm.DB {
|
||||
db := dbpkg.DB.Model(&model.FileHistory{})
|
||||
|
||||
db = dbpkg.FilterSearch(ctx, db, "dir", "filename")
|
||||
|
||||
// Apply exact match filters
|
||||
db = dbpkg.FilterEqual(ctx, db, "status", "uid", "asset_id", "account_id", "action")
|
||||
|
||||
// Apply client IP filter
|
||||
if clientIp := ctx.Query("client_ip"); clientIp != "" {
|
||||
db = db.Where("client_ip = ?", clientIp)
|
||||
}
|
||||
|
||||
// Apply date range filters
|
||||
if start := ctx.Query("start"); start != "" {
|
||||
db = db.Where("created_at >= ?", start)
|
||||
}
|
||||
|
||||
if end := ctx.Query("end"); end != "" {
|
||||
db = db.Where("created_at <= ?", end)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func (s *FileService) RecordFileHistory(ctx context.Context, operation, dir, filename string, assetId, accountId int, sessionId ...string) error {
|
||||
// Extract user information from context
|
||||
currentUser, err := acl.GetSessionFromCtx(ctx)
|
||||
if err != nil || currentUser == nil {
|
||||
// If no user context, still record the operation but with empty user info
|
||||
logger.L().Warn("No user context found when recording file history", zap.String("operation", operation))
|
||||
}
|
||||
|
||||
var uid int
|
||||
var userName string
|
||||
var clientIP string
|
||||
|
||||
if currentUser != nil {
|
||||
uid = currentUser.GetUid()
|
||||
userName = currentUser.GetUserName()
|
||||
}
|
||||
|
||||
// Get client IP from gin context
|
||||
if ginCtx, ok := ctx.(*gin.Context); ok {
|
||||
clientIP = ginCtx.ClientIP()
|
||||
}
|
||||
|
||||
history := &model.FileHistory{
|
||||
Uid: uid,
|
||||
UserName: userName,
|
||||
AssetId: assetId,
|
||||
AccountId: accountId,
|
||||
ClientIp: clientIP,
|
||||
Action: s.GetActionCode(operation),
|
||||
Dir: dir,
|
||||
Filename: filename,
|
||||
}
|
||||
|
||||
if err := s.AddFileHistory(ctx, history); err != nil {
|
||||
// Log error details including sessionId if provided for debugging
|
||||
sessionIdStr := ""
|
||||
if len(sessionId) > 0 {
|
||||
sessionIdStr = sessionId[0]
|
||||
}
|
||||
logger.L().Error("Failed to record file history",
|
||||
zap.Error(err),
|
||||
zap.String("operation", operation),
|
||||
zap.String("sessionId", sessionIdStr),
|
||||
zap.Any("history", history))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *FileService) RecordFileHistoryBySession(ctx context.Context, sessionId, operation, path string) error {
|
||||
// Get session info to extract asset and account information
|
||||
onlineSession := gsession.GetOnlineSessionById(sessionId)
|
||||
if onlineSession == nil {
|
||||
logger.L().Warn("Cannot record file history: session not found", zap.String("sessionId", sessionId))
|
||||
return fmt.Errorf("session not found: %s", sessionId)
|
||||
}
|
||||
|
||||
// Extract directory and filename from path
|
||||
dir := filepath.Dir(path)
|
||||
filename := filepath.Base(path)
|
||||
|
||||
return s.RecordFileHistory(ctx, operation, dir, filename, onlineSession.AssetId, onlineSession.AccountId, sessionId)
|
||||
}
|
||||
|
||||
// Utility methods
|
||||
func (s *FileService) GetActionCode(operation string) int {
|
||||
switch operation {
|
||||
case "upload":
|
||||
return model.FILE_ACTION_UPLOAD
|
||||
case "download":
|
||||
return model.FILE_ACTION_DOWNLOAD
|
||||
case "mkdir":
|
||||
return model.FILE_ACTION_MKDIR
|
||||
default:
|
||||
return model.FILE_ACTION_LS
|
||||
}
|
||||
}
|
||||
|
||||
func (s *FileService) ValidateAndNormalizePath(basePath, userPath string) (string, error) {
|
||||
// Clean the user-provided path
|
||||
cleanPath := filepath.Clean(userPath)
|
||||
|
||||
// Ensure it doesn't contain directory traversal attempts
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return "", fmt.Errorf("path contains directory traversal: %s", userPath)
|
||||
}
|
||||
|
||||
// Join with base path and clean again
|
||||
fullPath := filepath.Join(basePath, cleanPath)
|
||||
|
||||
// Ensure the resulting path is still within the base directory
|
||||
relPath, err := filepath.Rel(basePath, fullPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid path: %s", userPath)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(relPath, "..") {
|
||||
return "", fmt.Errorf("path outside base directory: %s", userPath)
|
||||
}
|
||||
|
||||
return fullPath, nil
|
||||
}
|
||||
|
||||
func (s *FileService) GetRDPDrivePath(assetId int) (string, error) {
|
||||
var drivePath string
|
||||
|
||||
// Priority 1: Get from environment variable (highest priority)
|
||||
drivePath = os.Getenv("ONETERM_RDP_DRIVE_PATH")
|
||||
|
||||
// Priority 2: Use default path based on OS
|
||||
if drivePath == "" {
|
||||
if runtime.GOOS == "windows" {
|
||||
drivePath = filepath.Join("C:", "temp", "oneterm", "rdp")
|
||||
} else {
|
||||
drivePath = filepath.Join("/tmp", "oneterm", "rdp")
|
||||
}
|
||||
}
|
||||
|
||||
// Create asset-specific subdirectory
|
||||
fullDrivePath := filepath.Join(drivePath, fmt.Sprintf("asset_%d", assetId))
|
||||
|
||||
// Ensure directory exists
|
||||
if err := os.MkdirAll(fullDrivePath, 0755); err != nil {
|
||||
return "", fmt.Errorf("failed to create RDP drive directory %s: %w", fullDrivePath, err)
|
||||
}
|
||||
|
||||
// Clear macOS extended attributes that might interfere with Docker volume mounting
|
||||
if runtime.GOOS == "darwin" {
|
||||
// Clear attributes for the directory and all its contents
|
||||
exec.Command("find", fullDrivePath, "-exec", "xattr", "-c", "{}", ";").Run()
|
||||
}
|
||||
|
||||
return fullDrivePath, nil
|
||||
}
|
||||
|
||||
// Simple FileRepository implementation
|
||||
type FileRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func (r *FileRepository) AddFileHistory(ctx context.Context, history *model.FileHistory) error {
|
||||
return r.db.Create(history).Error
|
||||
}
|
||||
|
||||
func (r *FileRepository) BuildFileHistoryQuery(ctx *gin.Context) *gorm.DB {
|
||||
return r.db.Model(&model.FileHistory{})
|
||||
}
|
||||
465
backend/internal/service/file/rdp.go
Normal file
465
backend/internal/service/file/rdp.go
Normal file
@@ -0,0 +1,465 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/veops/oneterm/internal/guacd"
|
||||
gsession "github.com/veops/oneterm/internal/session"
|
||||
)
|
||||
|
||||
// RDP file operation functions
|
||||
|
||||
// NewRDPProgressWriter creates a new RDP progress writer
|
||||
func NewRDPProgressWriter(writer io.Writer, transfer *guacd.FileTransfer, transferId string) *RDPProgressWriter {
|
||||
return &RDPProgressWriter{
|
||||
writer: writer,
|
||||
transfer: transfer,
|
||||
transferId: transferId,
|
||||
written: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (pw *RDPProgressWriter) Write(p []byte) (int, error) {
|
||||
n, err := pw.writer.Write(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
pw.written += int64(n)
|
||||
|
||||
// Update unified progress tracking
|
||||
UpdateTransferProgress(pw.transferId, 0, pw.written, "transferring")
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// IsRDPDriveEnabled checks if RDP drive is enabled
|
||||
func IsRDPDriveEnabled(tunnel *guacd.Tunnel) bool {
|
||||
if tunnel == nil || tunnel.Config == nil {
|
||||
return false
|
||||
}
|
||||
driveEnabled := tunnel.Config.Parameters["enable-drive"] == "true"
|
||||
return driveEnabled
|
||||
}
|
||||
|
||||
// IsRDPUploadAllowed checks if RDP upload is allowed
|
||||
func IsRDPUploadAllowed(tunnel *guacd.Tunnel) bool {
|
||||
if tunnel == nil || tunnel.Config == nil {
|
||||
return false
|
||||
}
|
||||
return tunnel.Config.Parameters["disable-upload"] != "true"
|
||||
}
|
||||
|
||||
// IsRDPDownloadAllowed checks if RDP download is allowed
|
||||
func IsRDPDownloadAllowed(tunnel *guacd.Tunnel) bool {
|
||||
if tunnel == nil || tunnel.Config == nil {
|
||||
return false
|
||||
}
|
||||
return tunnel.Config.Parameters["disable-download"] != "true"
|
||||
}
|
||||
|
||||
// RequestRDPFileList gets file list for RDP session
|
||||
func RequestRDPFileList(tunnel *guacd.Tunnel, path string) ([]RDPFileInfo, error) {
|
||||
// Implementation placeholder - this would need to be implemented based on Guacamole protocol
|
||||
return RequestRDPFileListViaDirect(tunnel, path)
|
||||
}
|
||||
|
||||
func RequestRDPFileListViaDirect(tunnel *guacd.Tunnel, path string) ([]RDPFileInfo, error) {
|
||||
// Get session to extract asset ID
|
||||
sessionId := tunnel.SessionId
|
||||
onlineSession := gsession.GetOnlineSessionById(sessionId)
|
||||
if onlineSession == nil {
|
||||
return nil, fmt.Errorf("session not found: %s", sessionId)
|
||||
}
|
||||
|
||||
// Get drive path with proper fallback handling
|
||||
drivePath, err := DefaultFileService.GetRDPDrivePath(onlineSession.AssetId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get drive path: %w", err)
|
||||
}
|
||||
|
||||
// Validate and construct full filesystem path
|
||||
fullPath, err := DefaultFileService.ValidateAndNormalizePath(drivePath, path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid path: %w", err)
|
||||
}
|
||||
|
||||
// Read directory contents
|
||||
entries, err := os.ReadDir(fullPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read directory: %w", err)
|
||||
}
|
||||
|
||||
var files []RDPFileInfo
|
||||
for _, entry := range entries {
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
continue // Skip entries with errors
|
||||
}
|
||||
|
||||
files = append(files, RDPFileInfo{
|
||||
Name: entry.Name(),
|
||||
Size: info.Size(),
|
||||
IsDir: entry.IsDir(),
|
||||
ModTime: info.ModTime().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
// DownloadRDPFile downloads a single file from RDP session
|
||||
func DownloadRDPFile(tunnel *guacd.Tunnel, path string) (io.ReadCloser, int64, error) {
|
||||
// Get session to extract asset ID
|
||||
sessionId := tunnel.SessionId
|
||||
onlineSession := gsession.GetOnlineSessionById(sessionId)
|
||||
if onlineSession == nil {
|
||||
return nil, 0, fmt.Errorf("session not found: %s", sessionId)
|
||||
}
|
||||
|
||||
// Get drive path with proper fallback handling
|
||||
drivePath, err := DefaultFileService.GetRDPDrivePath(onlineSession.AssetId)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to get drive path: %w", err)
|
||||
}
|
||||
|
||||
// Validate and construct full filesystem path
|
||||
fullPath, err := DefaultFileService.ValidateAndNormalizePath(drivePath, path)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("invalid path: %w", err)
|
||||
}
|
||||
|
||||
// Check if path exists and is a file
|
||||
info, err := os.Stat(fullPath)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("file not found: %w", err)
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return nil, 0, fmt.Errorf("path is a directory, not a file")
|
||||
}
|
||||
|
||||
// Open file for streaming (memory-efficient)
|
||||
file, err := os.Open(fullPath)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
|
||||
return file, info.Size(), nil
|
||||
}
|
||||
|
||||
// DownloadRDPMultiple downloads multiple files from RDP session as ZIP
|
||||
func DownloadRDPMultiple(tunnel *guacd.Tunnel, dir string, filenames []string) (io.ReadCloser, string, int64, error) {
|
||||
var sanitizedFilenames []string
|
||||
for _, filename := range filenames {
|
||||
if filename == "" || strings.Contains(filename, "..") || strings.Contains(filename, "/") {
|
||||
return nil, "", 0, fmt.Errorf("invalid filename: %s", filename)
|
||||
}
|
||||
sanitizedFilenames = append(sanitizedFilenames, filename)
|
||||
}
|
||||
|
||||
if len(sanitizedFilenames) == 1 {
|
||||
fullPath := filepath.Join(dir, sanitizedFilenames[0])
|
||||
|
||||
// Check if it's a directory or file
|
||||
sessionId := tunnel.SessionId
|
||||
onlineSession := gsession.GetOnlineSessionById(sessionId)
|
||||
if onlineSession == nil {
|
||||
return nil, "", 0, fmt.Errorf("session not found: %s", sessionId)
|
||||
}
|
||||
|
||||
drivePath, err := DefaultFileService.GetRDPDrivePath(onlineSession.AssetId)
|
||||
if err != nil {
|
||||
return nil, "", 0, fmt.Errorf("failed to get drive path: %w", err)
|
||||
}
|
||||
|
||||
realPath, err := DefaultFileService.ValidateAndNormalizePath(drivePath, fullPath)
|
||||
if err != nil {
|
||||
return nil, "", 0, fmt.Errorf("invalid path: %w", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(realPath)
|
||||
if err != nil {
|
||||
return nil, "", 0, fmt.Errorf("file not found: %w", err)
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
// For directory, create a zip with directory contents
|
||||
return CreateRDPZip(tunnel, dir, sanitizedFilenames)
|
||||
} else {
|
||||
// For single file, download directly
|
||||
reader, fileSize, err := DownloadRDPFile(tunnel, fullPath)
|
||||
if err != nil {
|
||||
return nil, "", 0, err
|
||||
}
|
||||
return reader, sanitizedFilenames[0], fileSize, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Multiple files/directories - always create zip
|
||||
return CreateRDPZip(tunnel, dir, sanitizedFilenames)
|
||||
}
|
||||
|
||||
// CreateRDPZip creates a ZIP archive of multiple RDP files
|
||||
func CreateRDPZip(tunnel *guacd.Tunnel, dir string, filenames []string) (io.ReadCloser, string, int64, error) {
|
||||
var buf bytes.Buffer
|
||||
zipWriter := zip.NewWriter(&buf)
|
||||
|
||||
for _, filename := range filenames {
|
||||
fullPath := filepath.Join(dir, filename)
|
||||
err := AddToRDPZip(tunnel, zipWriter, fullPath, filename)
|
||||
if err != nil {
|
||||
zipWriter.Close()
|
||||
return nil, "", 0, fmt.Errorf("failed to add %s to zip: %w", filename, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := zipWriter.Close(); err != nil {
|
||||
return nil, "", 0, fmt.Errorf("failed to close zip: %w", err)
|
||||
}
|
||||
|
||||
downloadFilename := fmt.Sprintf("rdp_files_%s.zip", time.Now().Format("20060102_150405"))
|
||||
reader := io.NopCloser(bytes.NewReader(buf.Bytes()))
|
||||
return reader, downloadFilename, int64(buf.Len()), nil
|
||||
}
|
||||
|
||||
// AddToRDPZip adds a file or directory to the ZIP archive
|
||||
func AddToRDPZip(tunnel *guacd.Tunnel, zipWriter *zip.Writer, fullPath, zipPath string) error {
|
||||
// Get session to extract asset ID
|
||||
sessionId := tunnel.SessionId
|
||||
onlineSession := gsession.GetOnlineSessionById(sessionId)
|
||||
if onlineSession == nil {
|
||||
return fmt.Errorf("session not found: %s", sessionId)
|
||||
}
|
||||
|
||||
// Get drive path with proper fallback handling
|
||||
drivePath, err := DefaultFileService.GetRDPDrivePath(onlineSession.AssetId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get drive path: %w", err)
|
||||
}
|
||||
|
||||
// Validate and construct full filesystem path
|
||||
realPath, err := DefaultFileService.ValidateAndNormalizePath(drivePath, fullPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid path: %w", err)
|
||||
}
|
||||
|
||||
// Check if path exists
|
||||
info, err := os.Stat(realPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("file not found: %w", err)
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
// Add directory entries recursively
|
||||
entries, err := os.ReadDir(realPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read directory: %w", err)
|
||||
}
|
||||
|
||||
// Create directory entry in zip if not empty
|
||||
if len(entries) == 0 {
|
||||
// Create empty directory entry
|
||||
dirHeader := &zip.FileHeader{
|
||||
Name: zipPath + "/",
|
||||
Method: zip.Store,
|
||||
}
|
||||
_, err := zipWriter.CreateHeader(dirHeader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create directory entry: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Add all files in directory
|
||||
for _, entry := range entries {
|
||||
entryPath := filepath.Join(fullPath, entry.Name())
|
||||
entryZipPath := zipPath + "/" + entry.Name()
|
||||
err := AddToRDPZip(tunnel, zipWriter, entryPath, entryZipPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Add file to zip
|
||||
file, err := os.Open(realPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
writer, err := zipWriter.Create(zipPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create zip entry: %w", err)
|
||||
}
|
||||
|
||||
// Stream file content to zip (memory-efficient)
|
||||
if _, err := io.Copy(writer, file); err != nil {
|
||||
return fmt.Errorf("failed to write file to zip: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateRDPDirectory creates a directory in RDP session
|
||||
func CreateRDPDirectory(tunnel *guacd.Tunnel, path string) error {
|
||||
// Get session to extract asset ID
|
||||
sessionId := tunnel.SessionId
|
||||
onlineSession := gsession.GetOnlineSessionById(sessionId)
|
||||
if onlineSession == nil {
|
||||
return fmt.Errorf("session not found: %s", sessionId)
|
||||
}
|
||||
|
||||
// Get drive path with proper fallback handling
|
||||
drivePath, err := DefaultFileService.GetRDPDrivePath(onlineSession.AssetId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get drive path: %w", err)
|
||||
}
|
||||
|
||||
// Validate and construct full filesystem path
|
||||
fullPath, err := DefaultFileService.ValidateAndNormalizePath(drivePath, path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid path: %w", err)
|
||||
}
|
||||
|
||||
// Create directory with proper permissions
|
||||
if err := os.MkdirAll(fullPath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
// Send refresh notification to RDP session
|
||||
NotifyRDPDirectoryRefresh(sessionId)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UploadRDPFileStreamWithID uploads file to RDP session with progress tracking
|
||||
func UploadRDPFileStreamWithID(tunnel *guacd.Tunnel, transferID, sessionId, path string, reader io.Reader, totalSize int64) error {
|
||||
// Get session to extract asset ID
|
||||
onlineSession := gsession.GetOnlineSessionById(sessionId)
|
||||
if onlineSession == nil {
|
||||
return fmt.Errorf("session not found: %s", sessionId)
|
||||
}
|
||||
|
||||
// Get drive path with proper fallback handling
|
||||
drivePath, err := DefaultFileService.GetRDPDrivePath(onlineSession.AssetId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get drive path: %w", err)
|
||||
}
|
||||
|
||||
// Create transfer tracker
|
||||
var transfer *guacd.FileTransfer
|
||||
if transferID != "" {
|
||||
transfer, err = guacd.DefaultFileTransferManager.CreateUploadWithID(transferID, sessionId, filepath.Base(path), drivePath)
|
||||
} else {
|
||||
transfer, err = guacd.DefaultFileTransferManager.CreateUpload(sessionId, filepath.Base(path), drivePath)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create transfer tracker: %w", err)
|
||||
}
|
||||
// Note: Don't remove transfer immediately - let it be cleaned up later so progress can be queried
|
||||
|
||||
transfer.SetSize(totalSize)
|
||||
|
||||
// Validate and construct full filesystem path
|
||||
fullPath, err := DefaultFileService.ValidateAndNormalizePath(drivePath, path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid path: %w", err)
|
||||
}
|
||||
|
||||
destDir := filepath.Dir(fullPath)
|
||||
if err := os.MkdirAll(destDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create destination directory: %w", err)
|
||||
}
|
||||
|
||||
destFile, err := os.Create(fullPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create file: %w", err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
progressWriter := NewRDPProgressWriter(destFile, transfer, transferID)
|
||||
|
||||
buffer := make([]byte, 32*1024)
|
||||
written, err := io.CopyBuffer(progressWriter, reader, buffer)
|
||||
if err != nil {
|
||||
os.Remove(fullPath)
|
||||
return fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
|
||||
if totalSize > 0 && written != totalSize {
|
||||
os.Remove(fullPath)
|
||||
// Mark as failed in unified tracking
|
||||
UpdateTransferProgress(transferID, 0, -1, "failed")
|
||||
return fmt.Errorf("file size mismatch: expected %d, wrote %d", totalSize, written)
|
||||
}
|
||||
|
||||
// CRITICAL: Explicitly close and sync file before marking as completed
|
||||
// This ensures the file is fully written to disk and visible to mounted containers
|
||||
if err := destFile.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close file: %w", err)
|
||||
}
|
||||
|
||||
// Clear macOS extended attributes that might interfere with Docker volume mounting
|
||||
if runtime.GOOS == "darwin" {
|
||||
// Clear attributes for the file and parent directory
|
||||
exec.Command("xattr", "-c", fullPath).Run()
|
||||
exec.Command("xattr", "-c", filepath.Dir(fullPath)).Run()
|
||||
}
|
||||
|
||||
// Mark as completed in unified tracking
|
||||
UpdateTransferProgress(transferID, 0, written, "completed")
|
||||
|
||||
// Send refresh notification to frontend with a slight delay
|
||||
// This ensures the file system operations are fully completed
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
NotifyRDPDirectoryRefresh(sessionId)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRDPTransferProgressById gets RDP transfer progress by ID
|
||||
func GetRDPTransferProgressById(transferId string) (interface{}, error) {
|
||||
progress, err := guacd.DefaultFileTransferManager.GetTransferProgress(transferId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return progress, nil
|
||||
}
|
||||
|
||||
// NotifyRDPDirectoryRefresh sends F5 key to refresh Windows Explorer
|
||||
func NotifyRDPDirectoryRefresh(sessionId string) {
|
||||
// Get the active session and tunnel
|
||||
onlineSession := gsession.GetOnlineSessionById(sessionId)
|
||||
if onlineSession == nil {
|
||||
return
|
||||
}
|
||||
|
||||
tunnel := onlineSession.GuacdTunnel
|
||||
if tunnel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Send F5 key to refresh Windows Explorer
|
||||
// F5 key code: 65474
|
||||
f5DownInstruction := guacd.NewInstruction("key", "65474", "1")
|
||||
if _, err := tunnel.WriteInstruction(f5DownInstruction); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
f5UpInstruction := guacd.NewInstruction("key", "65474", "0")
|
||||
tunnel.WriteInstruction(f5UpInstruction)
|
||||
}
|
||||
475
backend/internal/service/file/sftp.go
Normal file
475
backend/internal/service/file/sftp.go
Normal file
@@ -0,0 +1,475 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/veops/oneterm/internal/repository"
|
||||
gsession "github.com/veops/oneterm/internal/session"
|
||||
"github.com/veops/oneterm/internal/tunneling"
|
||||
"github.com/veops/oneterm/pkg/logger"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// SFTP Operations - Managers defined in parent file service
|
||||
// =============================================================================
|
||||
|
||||
// =============================================================================
|
||||
// SFTP Upload/Download Operations with Progress Tracking
|
||||
// =============================================================================
|
||||
|
||||
// TransferToTarget handles transfer routing (session-based or asset-based)
|
||||
func TransferToTarget(transferId, sessionIdOrCustom, tempFilePath, targetPath string, assetId, accountId int) error {
|
||||
// For session-based transfers, try to reuse existing SFTP connection first
|
||||
if assetId == 0 && accountId == 0 && sessionIdOrCustom != "" {
|
||||
return SessionBasedTransfer(transferId, sessionIdOrCustom, tempFilePath, targetPath)
|
||||
}
|
||||
|
||||
// For asset/account-based transfers, fall back to creating new connection
|
||||
return AssetBasedTransfer(transferId, tempFilePath, targetPath, assetId, accountId)
|
||||
}
|
||||
|
||||
// SessionBasedTransfer uses existing session SFTP connection for optimal performance
|
||||
func SessionBasedTransfer(transferId, sessionId, tempFilePath, targetPath string) error {
|
||||
// Try to get existing SFTP client from session manager
|
||||
sessionFM := GetSessionFileManager()
|
||||
sftpClient, err := sessionFM.GetSessionSFTP(sessionId)
|
||||
if err != nil {
|
||||
// If no existing connection, create one
|
||||
onlineSession := gsession.GetOnlineSessionById(sessionId)
|
||||
if onlineSession == nil {
|
||||
return fmt.Errorf("session %s not found", sessionId)
|
||||
}
|
||||
|
||||
// Initialize SFTP connection for this session
|
||||
if initErr := sessionFM.InitSessionSFTP(sessionId, onlineSession.AssetId, onlineSession.AccountId); initErr != nil {
|
||||
return fmt.Errorf("failed to initialize SFTP for session %s: %w", sessionId, initErr)
|
||||
}
|
||||
|
||||
// Get the newly created SFTP client
|
||||
sftpClient, err = sessionFM.GetSessionSFTP(sessionId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get SFTP client for session %s: %w", sessionId, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Use existing SFTP client for transfer (no need to close it as it's managed by SessionFileManager)
|
||||
return SftpUploadWithExistingClient(sftpClient, transferId, tempFilePath, targetPath)
|
||||
}
|
||||
|
||||
// AssetBasedTransfer creates new connection for asset/account-based transfers (legacy)
|
||||
func AssetBasedTransfer(transferId, tempFilePath, targetPath string, assetId, accountId int) error {
|
||||
asset, account, gateway, err := repository.GetAAG(assetId, accountId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get asset/account info: %w", err)
|
||||
}
|
||||
sessionId := fmt.Sprintf("upload_%d_%d_%d", assetId, accountId, time.Now().UnixNano())
|
||||
|
||||
// Get SSH connection details
|
||||
ip, port, err := tunneling.Proxy(false, sessionId, "ssh", asset, gateway)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup tunnel: %w", err)
|
||||
}
|
||||
|
||||
auth, err := repository.GetAuth(account)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get auth: %w", err)
|
||||
}
|
||||
|
||||
// Create SSH client with maximum performance optimizations for SFTP
|
||||
sshClient, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", ip, port), &ssh.ClientConfig{
|
||||
User: account.Account,
|
||||
Auth: []ssh.AuthMethod{auth},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 30 * time.Second,
|
||||
// Ultra-high performance optimizations - fastest algorithms first
|
||||
Config: ssh.Config{
|
||||
Ciphers: []string{
|
||||
"aes128-ctr", // Fastest for most CPUs with AES-NI
|
||||
"aes128-gcm@openssh.com", // Hardware accelerated AEAD cipher
|
||||
"chacha20-poly1305@openssh.com", // Fast on ARM/systems without AES-NI
|
||||
"aes256-ctr", // Fallback high-performance option
|
||||
},
|
||||
MACs: []string{
|
||||
"hmac-sha2-256-etm@openssh.com", // Encrypt-then-MAC (fastest + most secure)
|
||||
"hmac-sha2-256", // Standard high-performance MAC
|
||||
},
|
||||
KeyExchanges: []string{
|
||||
"curve25519-sha256@libssh.org", // Modern elliptic curve (fastest)
|
||||
"curve25519-sha256", // Equivalent modern KEX
|
||||
"ecdh-sha2-nistp256", // Fast NIST curve fallback
|
||||
},
|
||||
},
|
||||
// Optimize connection algorithms for speed
|
||||
HostKeyAlgorithms: []string{
|
||||
"rsa-sha2-256", // Fast RSA with SHA-2
|
||||
"rsa-sha2-512", // Alternative fast RSA
|
||||
"ssh-ed25519", // Modern EdDSA (very fast verification)
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect SSH: %w", err)
|
||||
}
|
||||
defer sshClient.Close()
|
||||
|
||||
// Use optimized SFTP to transfer file
|
||||
return SftpUploadWithProgress(sshClient, transferId, tempFilePath, targetPath)
|
||||
}
|
||||
|
||||
// SftpUploadWithProgress uploads file using optimized SFTP protocol with accurate progress tracking
|
||||
func SftpUploadWithProgress(client *ssh.Client, transferId, localPath, remotePath string) error {
|
||||
// Create SFTP client with maximum performance settings
|
||||
sftpClient, err := sftp.NewClient(client,
|
||||
sftp.MaxPacket(1024*32), // 32KB packets - maximum safe size for most servers
|
||||
sftp.MaxConcurrentRequestsPerFile(64), // High concurrency for maximum throughput
|
||||
sftp.UseConcurrentReads(true), // Enable concurrent reads for better performance
|
||||
sftp.UseConcurrentWrites(true), // Enable concurrent writes for better performance
|
||||
)
|
||||
if err != nil {
|
||||
logger.L().Error("Failed to create SFTP client",
|
||||
zap.String("transferId", transferId),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("failed to create SFTP client: %w", err)
|
||||
}
|
||||
defer sftpClient.Close()
|
||||
|
||||
// Open local file
|
||||
localFile, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open local file: %w", err)
|
||||
}
|
||||
defer localFile.Close()
|
||||
|
||||
// Get file info
|
||||
fileInfo, err := localFile.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get file info: %w", err)
|
||||
}
|
||||
|
||||
// Create parent directory on remote if needed
|
||||
parentDir := filepath.Dir(remotePath)
|
||||
if parentDir != "" && parentDir != "." && parentDir != "/" {
|
||||
if err := sftpClient.MkdirAll(parentDir); err != nil {
|
||||
logger.L().Warn("Failed to create parent directory via SFTP", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Create remote file
|
||||
remoteFile, err := sftpClient.Create(remotePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create remote file: %w", err)
|
||||
}
|
||||
defer remoteFile.Close()
|
||||
|
||||
// Create progress tracking writer with SFTP-specific optimizations
|
||||
progressWriter := NewFileProgressWriter(remoteFile, transferId)
|
||||
|
||||
// Transfer file content with ultra-high performance buffer for SFTP
|
||||
// Use 2MB buffer to minimize round trips and maximize throughput
|
||||
buffer := make([]byte, 2*1024*1024) // 2MB buffer for ultra-high SFTP performance
|
||||
|
||||
// Manual optimized copy loop to avoid io.CopyBuffer overhead
|
||||
var transferred int64
|
||||
for {
|
||||
n, readErr := localFile.Read(buffer)
|
||||
if n > 0 {
|
||||
written, writeErr := progressWriter.Write(buffer[:n])
|
||||
transferred += int64(written)
|
||||
if writeErr != nil {
|
||||
err = writeErr
|
||||
break
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
if readErr == io.EOF {
|
||||
break // Normal end of file
|
||||
}
|
||||
err = readErr
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.L().Error("SFTP file transfer failed during copy",
|
||||
zap.String("transferId", transferId),
|
||||
zap.Int64("transferred", transferred),
|
||||
zap.Int64("fileSize", fileInfo.Size()),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("failed to transfer file content via SFTP: %w", err)
|
||||
}
|
||||
|
||||
// Force final progress update
|
||||
UpdateTransferProgress(transferId, 0, transferred, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
// SftpUploadWithExistingClient uploads file using existing SFTP client with accurate progress tracking
|
||||
func SftpUploadWithExistingClient(client *sftp.Client, transferId, localPath, remotePath string) error {
|
||||
// Open local file
|
||||
localFile, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open local file: %w", err)
|
||||
}
|
||||
defer localFile.Close()
|
||||
|
||||
// Get file info
|
||||
fileInfo, err := localFile.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get file info: %w", err)
|
||||
}
|
||||
|
||||
// Create parent directory on remote if needed
|
||||
parentDir := filepath.Dir(remotePath)
|
||||
if parentDir != "" && parentDir != "." && parentDir != "/" {
|
||||
if err := client.MkdirAll(parentDir); err != nil {
|
||||
logger.L().Warn("Failed to create parent directory via SFTP", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Create remote file
|
||||
remoteFile, err := client.Create(remotePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create remote file: %w", err)
|
||||
}
|
||||
defer remoteFile.Close()
|
||||
|
||||
// Create progress tracking writer
|
||||
progressWriter := NewFileProgressWriter(remoteFile, transferId)
|
||||
|
||||
// Transfer file content with ultra-high performance buffer for SFTP
|
||||
// Use 2MB buffer to minimize round trips and maximize throughput
|
||||
buffer := make([]byte, 2*1024*1024) // 2MB buffer for ultra-high SFTP performance
|
||||
var transferred int64
|
||||
|
||||
for {
|
||||
n, readErr := localFile.Read(buffer)
|
||||
if n > 0 {
|
||||
written, writeErr := progressWriter.Write(buffer[:n])
|
||||
transferred += int64(written)
|
||||
if writeErr != nil {
|
||||
err = writeErr
|
||||
break
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
if readErr == io.EOF {
|
||||
break // Normal end of file
|
||||
}
|
||||
err = readErr
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.L().Error("SFTP file transfer failed",
|
||||
zap.String("transferId", transferId),
|
||||
zap.Int64("transferred", transferred),
|
||||
zap.Int64("fileSize", fileInfo.Size()),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("failed to transfer file: %w", err)
|
||||
}
|
||||
|
||||
// Force final progress update
|
||||
UpdateTransferProgress(transferId, 0, transferred, "")
|
||||
logger.L().Info("SFTP file transfer completed",
|
||||
zap.String("transferId", transferId),
|
||||
zap.String("remotePath", remotePath),
|
||||
zap.Int64("size", transferred))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SFTP Download Operations with ZIP Support
|
||||
// =============================================================================
|
||||
|
||||
// SftpDownloadMultiple downloads multiple files as ZIP or single file
|
||||
func SftpDownloadMultiple(ctx context.Context, assetId, accountId int, dir string, filenames []string) (io.ReadCloser, string, int64, error) {
|
||||
cli, err := GetFileManager().GetFileClient(assetId, accountId)
|
||||
if err != nil {
|
||||
return nil, "", 0, fmt.Errorf("failed to get SFTP client: %w", err)
|
||||
}
|
||||
|
||||
if len(filenames) == 1 {
|
||||
// Single file download
|
||||
fullPath := filepath.Join(dir, filenames[0])
|
||||
file, err := cli.Open(fullPath)
|
||||
if err != nil {
|
||||
return nil, "", 0, fmt.Errorf("failed to open file %s: %w", fullPath, err)
|
||||
}
|
||||
|
||||
// Get file size
|
||||
info, err := cli.Stat(fullPath)
|
||||
if err != nil {
|
||||
file.Close()
|
||||
return nil, "", 0, fmt.Errorf("failed to get file info: %w", err)
|
||||
}
|
||||
|
||||
return file, filenames[0], info.Size(), nil
|
||||
}
|
||||
|
||||
// Multiple files - create ZIP
|
||||
return createSftpZipArchive(cli, dir, filenames)
|
||||
}
|
||||
|
||||
// createSftpZipArchive creates a ZIP archive of multiple SFTP files
|
||||
func createSftpZipArchive(cli *sftp.Client, baseDir string, filenames []string) (io.ReadCloser, string, int64, error) {
|
||||
// Create a buffer to write the ZIP archive
|
||||
var buffer bytes.Buffer
|
||||
zipWriter := zip.NewWriter(&buffer)
|
||||
|
||||
for _, filename := range filenames {
|
||||
fullPath := filepath.Join(baseDir, filename)
|
||||
if err := addSftpFileToZip(cli, zipWriter, baseDir, filename, fullPath); err != nil {
|
||||
zipWriter.Close()
|
||||
return nil, "", 0, fmt.Errorf("failed to add %s to ZIP: %w", filename, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := zipWriter.Close(); err != nil {
|
||||
return nil, "", 0, fmt.Errorf("failed to close ZIP writer: %w", err)
|
||||
}
|
||||
|
||||
// Generate ZIP filename
|
||||
var zipFilename string
|
||||
if len(filenames) == 1 {
|
||||
zipFilename = strings.TrimSuffix(filenames[0], filepath.Ext(filenames[0])) + ".zip"
|
||||
} else {
|
||||
zipFilename = fmt.Sprintf("sftp_files_%d_items.zip", len(filenames))
|
||||
}
|
||||
|
||||
reader := bytes.NewReader(buffer.Bytes())
|
||||
return io.NopCloser(reader), zipFilename, int64(buffer.Len()), nil
|
||||
}
|
||||
|
||||
// addSftpFileToZip adds a file or directory to the ZIP archive
|
||||
func addSftpFileToZip(cli *sftp.Client, zipWriter *zip.Writer, baseDir, relativePath, fullPath string) error {
|
||||
// Get file info
|
||||
info, err := cli.Stat(fullPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stat %s: %w", fullPath, err)
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
// Handle directory
|
||||
return addSftpDirToZip(cli, zipWriter, baseDir, relativePath, fullPath)
|
||||
}
|
||||
|
||||
// Handle regular file
|
||||
return addSftpRegularFileToZip(cli, zipWriter, fullPath, relativePath)
|
||||
}
|
||||
|
||||
// addSftpRegularFileToZip adds a regular file to ZIP
|
||||
func addSftpRegularFileToZip(cli *sftp.Client, zipWriter *zip.Writer, fullPath, relativePath string) error {
|
||||
// Open remote file
|
||||
file, err := cli.Open(fullPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file %s: %w", fullPath, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Create ZIP entry
|
||||
header := &zip.FileHeader{
|
||||
Name: relativePath,
|
||||
Method: zip.Deflate,
|
||||
}
|
||||
|
||||
writer, err := zipWriter.CreateHeader(header)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create ZIP entry: %w", err)
|
||||
}
|
||||
|
||||
// Copy file content to ZIP
|
||||
_, err = io.Copy(writer, file)
|
||||
return err
|
||||
}
|
||||
|
||||
// addSftpDirToZip adds a directory to ZIP recursively
|
||||
func addSftpDirToZip(cli *sftp.Client, zipWriter *zip.Writer, baseDir, relativePath, fullPath string) error {
|
||||
// Read directory contents
|
||||
entries, err := cli.ReadDir(fullPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read directory %s: %w", fullPath, err)
|
||||
}
|
||||
|
||||
// Add directory entry to ZIP
|
||||
if relativePath != "" && relativePath != "." {
|
||||
header := &zip.FileHeader{
|
||||
Name: relativePath + "/",
|
||||
}
|
||||
_, err = zipWriter.CreateHeader(header)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create directory entry: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add directory contents recursively
|
||||
for _, entry := range entries {
|
||||
entryRelPath := filepath.Join(relativePath, entry.Name())
|
||||
entryFullPath := filepath.Join(fullPath, entry.Name())
|
||||
|
||||
if err := addSftpFileToZip(cli, zipWriter, baseDir, entryRelPath, entryFullPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SFTP Progress Writers
|
||||
// =============================================================================
|
||||
|
||||
// SftpProgressWriter tracks SFTP transfer progress
|
||||
type SftpProgressWriter struct {
|
||||
writer io.Writer
|
||||
transferId string
|
||||
written int64
|
||||
lastUpdate time.Time
|
||||
updateBytes int64 // Bytes written since last progress update
|
||||
updateTicker int64 // Simple counter to reduce time.Now() calls
|
||||
}
|
||||
|
||||
// NewSftpProgressWriter creates a new SFTP progress writer
|
||||
func NewSftpProgressWriter(writer io.Writer, transferId string) *SftpProgressWriter {
|
||||
return &SftpProgressWriter{
|
||||
writer: writer,
|
||||
transferId: transferId,
|
||||
lastUpdate: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (pw *SftpProgressWriter) Write(p []byte) (int, error) {
|
||||
n, err := pw.writer.Write(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
pw.written += int64(n)
|
||||
pw.updateBytes += int64(n)
|
||||
pw.updateTicker++
|
||||
|
||||
// Update progress every 64KB bytes OR every 1000 write calls (reduces time.Now() overhead)
|
||||
if pw.updateBytes >= 65536 || pw.updateTicker >= 1000 {
|
||||
now := time.Now()
|
||||
// Only update if enough time has passed (reduce lock contention)
|
||||
if pw.updateBytes >= 65536 || now.Sub(pw.lastUpdate) >= 50*time.Millisecond {
|
||||
UpdateTransferProgress(pw.transferId, 0, pw.written, "")
|
||||
pw.lastUpdate = now
|
||||
pw.updateBytes = 0
|
||||
pw.updateTicker = 0
|
||||
}
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
722
backend/internal/service/file/types.go
Normal file
722
backend/internal/service/file/types.go
Normal file
@@ -0,0 +1,722 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/sftp"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/veops/oneterm/internal/guacd"
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
"github.com/veops/oneterm/internal/repository"
|
||||
gsession "github.com/veops/oneterm/internal/session"
|
||||
"github.com/veops/oneterm/internal/tunneling"
|
||||
"github.com/veops/oneterm/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
// Legacy asset-based file manager
|
||||
fm = &FileManager{
|
||||
sftps: map[string]*sftp.Client{},
|
||||
lastTime: map[string]time.Time{},
|
||||
mtx: sync.Mutex{},
|
||||
}
|
||||
|
||||
// New session-based file manager
|
||||
sessionFM = &SessionFileManager{
|
||||
sessionSFTP: make(map[string]*sftp.Client),
|
||||
sessionSSH: make(map[string]*ssh.Client),
|
||||
lastActive: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
// Session file transfer manager
|
||||
sessionTransferManager = &SessionFileTransferManager{
|
||||
transfers: make(map[string]*SessionFileTransfer),
|
||||
}
|
||||
|
||||
// Progress tracking state
|
||||
fileTransfers = make(map[string]*FileTransferProgress)
|
||||
transferMutex sync.RWMutex
|
||||
)
|
||||
|
||||
// Session-based file operation errors
|
||||
var (
|
||||
ErrSessionNotFound = errors.New("session not found")
|
||||
ErrSessionClosed = errors.New("session has been closed")
|
||||
ErrSessionInactive = errors.New("session is inactive")
|
||||
ErrSFTPNotAvailable = errors.New("SFTP not available for this session")
|
||||
)
|
||||
|
||||
// Global getter functions
|
||||
func GetFileManager() *FileManager {
|
||||
return fm
|
||||
}
|
||||
|
||||
func GetSessionFileManager() *SessionFileManager {
|
||||
return sessionFM
|
||||
}
|
||||
|
||||
func GetSessionTransferManager() *SessionFileTransferManager {
|
||||
return sessionTransferManager
|
||||
}
|
||||
|
||||
// FileInfo represents file information
|
||||
type FileInfo struct {
|
||||
Name string `json:"name"`
|
||||
IsDir bool `json:"is_dir"`
|
||||
Size int64 `json:"size"`
|
||||
Mode string `json:"mode"`
|
||||
IsLink bool `json:"is_link"`
|
||||
Target string `json:"target"`
|
||||
ModTime string `json:"mod_time"`
|
||||
}
|
||||
|
||||
// FileManager manages SFTP connections (legacy asset-based)
|
||||
type FileManager struct {
|
||||
sftps map[string]*sftp.Client
|
||||
lastTime map[string]time.Time
|
||||
mtx sync.Mutex
|
||||
}
|
||||
|
||||
func (fm *FileManager) GetFileClient(assetId, accountId int) (cli *sftp.Client, err error) {
|
||||
fm.mtx.Lock()
|
||||
defer fm.mtx.Unlock()
|
||||
|
||||
key := fmt.Sprintf("%d-%d", assetId, accountId)
|
||||
defer func() {
|
||||
fm.lastTime[key] = time.Now()
|
||||
}()
|
||||
|
||||
cli, ok := fm.sftps[key]
|
||||
if ok {
|
||||
return
|
||||
}
|
||||
|
||||
asset, account, gateway, err := repository.GetAAG(assetId, accountId)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ip, port, err := tunneling.Proxy(false, uuid.New().String(), "sftp,ssh", asset, gateway)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
auth, err := repository.GetAuth(account)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
sshCli, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", ip, port), &ssh.ClientConfig{
|
||||
User: account.Account,
|
||||
Auth: []ssh.AuthMethod{auth},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Create optimized SFTP client
|
||||
cli, err = sftp.NewClient(sshCli,
|
||||
sftp.MaxPacket(32768), // 32KB packets for maximum compatibility
|
||||
sftp.MaxConcurrentRequestsPerFile(16), // Increase concurrent requests
|
||||
sftp.UseConcurrentWrites(true), // Enable concurrent writes
|
||||
sftp.UseConcurrentReads(true), // Enable concurrent reads
|
||||
sftp.UseFstat(false), // Disable fstat for better compatibility
|
||||
)
|
||||
if err != nil {
|
||||
sshCli.Close()
|
||||
return
|
||||
}
|
||||
fm.sftps[key] = cli
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// SessionFileManager manages SFTP connections per session
|
||||
type SessionFileManager struct {
|
||||
sessionSFTP map[string]*sftp.Client // sessionId -> SFTP client
|
||||
sessionSSH map[string]*ssh.Client // sessionId -> SSH client
|
||||
lastActive map[string]time.Time // sessionId -> last active time
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (sfm *SessionFileManager) InitSessionSFTP(sessionId string, assetId, accountId int) error {
|
||||
sfm.mutex.Lock()
|
||||
defer sfm.mutex.Unlock()
|
||||
|
||||
// Check if already exists
|
||||
if _, exists := sfm.sessionSFTP[sessionId]; exists {
|
||||
sfm.lastActive[sessionId] = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
// CRITICAL OPTIMIZATION: Try to reuse existing SSH connection from terminal session
|
||||
onlineSession := gsession.GetOnlineSessionById(sessionId)
|
||||
var sshClient *ssh.Client
|
||||
var shouldCloseClient = false
|
||||
|
||||
if onlineSession != nil && onlineSession.HasSSHClient() {
|
||||
sshClient = onlineSession.GetSSHClient()
|
||||
logger.L().Info("REUSING existing SSH connection from terminal session",
|
||||
zap.String("sessionId", sessionId))
|
||||
} else {
|
||||
// Fallback: Create new SSH connection if no existing connection found
|
||||
asset, account, gateway, err := repository.GetAAG(assetId, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Use sessionId as proxy identifier for connection reuse
|
||||
ip, port, err := tunneling.Proxy(false, sessionId, "sftp,ssh", asset, gateway)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
auth, err := repository.GetAuth(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sshClient, err = ssh.Dial("tcp", fmt.Sprintf("%s:%d", ip, port), &ssh.ClientConfig{
|
||||
User: account.Account,
|
||||
Auth: []ssh.AuthMethod{auth},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 10 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect SSH for session %s: %w", sessionId, err)
|
||||
}
|
||||
shouldCloseClient = true // We created it, so we manage its lifecycle
|
||||
logger.L().Info("Created new SSH connection for file transfer",
|
||||
zap.String("sessionId", sessionId))
|
||||
}
|
||||
|
||||
// Create SFTP client with optimized settings for better performance
|
||||
sftpClient, err := sftp.NewClient(sshClient,
|
||||
sftp.MaxPacket(32768), // 32KB packets for maximum compatibility
|
||||
sftp.MaxConcurrentRequestsPerFile(16), // Increase concurrent requests per file (default is 3)
|
||||
sftp.UseConcurrentWrites(true), // Enable concurrent writes
|
||||
sftp.UseConcurrentReads(true), // Enable concurrent reads
|
||||
sftp.UseFstat(false), // Disable fstat for better compatibility
|
||||
)
|
||||
if err != nil {
|
||||
if shouldCloseClient {
|
||||
sshClient.Close()
|
||||
}
|
||||
return fmt.Errorf("failed to create SFTP client for session %s: %w", sessionId, err)
|
||||
}
|
||||
|
||||
// Store clients (only store SSH client if we created it)
|
||||
sfm.sessionSFTP[sessionId] = sftpClient
|
||||
if shouldCloseClient {
|
||||
sfm.sessionSSH[sessionId] = sshClient
|
||||
}
|
||||
sfm.lastActive[sessionId] = time.Now()
|
||||
|
||||
logger.L().Info("SFTP connection initialized for session",
|
||||
zap.String("sessionId", sessionId),
|
||||
zap.Bool("reusedConnection", !shouldCloseClient))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sfm *SessionFileManager) GetSessionSFTP(sessionId string) (*sftp.Client, error) {
|
||||
sfm.mutex.RLock()
|
||||
defer sfm.mutex.RUnlock()
|
||||
|
||||
client, exists := sfm.sessionSFTP[sessionId]
|
||||
if !exists {
|
||||
return nil, ErrSessionNotFound
|
||||
}
|
||||
|
||||
// Update last active time
|
||||
sfm.lastActive[sessionId] = time.Now()
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (sfm *SessionFileManager) CloseSessionSFTP(sessionId string) {
|
||||
sfm.mutex.Lock()
|
||||
defer sfm.mutex.Unlock()
|
||||
|
||||
if sftpClient, exists := sfm.sessionSFTP[sessionId]; exists {
|
||||
sftpClient.Close()
|
||||
delete(sfm.sessionSFTP, sessionId)
|
||||
}
|
||||
|
||||
// Only close SSH client if we created it (not reused from terminal session)
|
||||
if sshClient, exists := sfm.sessionSSH[sessionId]; exists {
|
||||
sshClient.Close()
|
||||
delete(sfm.sessionSSH, sessionId)
|
||||
logger.L().Info("SFTP SSH connection closed for session", zap.String("sessionId", sessionId))
|
||||
} else {
|
||||
logger.L().Info("SFTP connection closed for session (SSH connection reused)", zap.String("sessionId", sessionId))
|
||||
}
|
||||
|
||||
delete(sfm.lastActive, sessionId)
|
||||
}
|
||||
|
||||
func (sfm *SessionFileManager) IsSessionActive(sessionId string) bool {
|
||||
sfm.mutex.RLock()
|
||||
defer sfm.mutex.RUnlock()
|
||||
|
||||
_, exists := sfm.sessionSFTP[sessionId]
|
||||
return exists
|
||||
}
|
||||
|
||||
func (sfm *SessionFileManager) CleanupInactiveSessions(timeout time.Duration) {
|
||||
sfm.mutex.Lock()
|
||||
defer sfm.mutex.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-timeout)
|
||||
for sessionId, lastActive := range sfm.lastActive {
|
||||
if lastActive.Before(cutoff) {
|
||||
// Close and remove inactive session
|
||||
if sftpClient, exists := sfm.sessionSFTP[sessionId]; exists {
|
||||
sftpClient.Close()
|
||||
delete(sfm.sessionSFTP, sessionId)
|
||||
}
|
||||
|
||||
if sshClient, exists := sfm.sessionSSH[sessionId]; exists {
|
||||
sshClient.Close()
|
||||
delete(sfm.sessionSSH, sessionId)
|
||||
}
|
||||
|
||||
delete(sfm.lastActive, sessionId)
|
||||
logger.L().Info("Cleaned up inactive SFTP session",
|
||||
zap.String("sessionId", sessionId),
|
||||
zap.Duration("inactiveFor", time.Since(lastActive)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sfm *SessionFileManager) GetActiveSessionCount() int {
|
||||
sfm.mutex.RLock()
|
||||
defer sfm.mutex.RUnlock()
|
||||
return len(sfm.sessionSFTP)
|
||||
}
|
||||
|
||||
// IFileService defines the file service interface
|
||||
type IFileService interface {
|
||||
// Legacy asset-based operations (for backward compatibility)
|
||||
ReadDir(ctx context.Context, assetId, accountId int, dir string) ([]fs.FileInfo, error)
|
||||
MkdirAll(ctx context.Context, assetId, accountId int, dir string) error
|
||||
Create(ctx context.Context, assetId, accountId int, path string) (io.WriteCloser, error)
|
||||
Open(ctx context.Context, assetId, accountId int, path string) (io.ReadCloser, error)
|
||||
Stat(ctx context.Context, assetId, accountId int, path string) (fs.FileInfo, error)
|
||||
DownloadMultiple(ctx context.Context, assetId, accountId int, dir string, filenames []string) (io.ReadCloser, string, int64, error)
|
||||
|
||||
// Session-based operations (NEW - high performance)
|
||||
SessionLS(ctx context.Context, sessionId, dir string) ([]FileInfo, error)
|
||||
SessionMkdir(ctx context.Context, sessionId, dir string) error
|
||||
SessionUpload(ctx context.Context, sessionId, targetPath string, file io.Reader, filename string, size int64) error
|
||||
SessionUploadWithID(ctx context.Context, transferID, sessionId, targetPath string, file io.Reader, filename string, size int64) error
|
||||
SessionDownload(ctx context.Context, sessionId, filePath string) (io.ReadCloser, int64, error)
|
||||
SessionDownloadMultiple(ctx context.Context, sessionId, dir string, filenames []string) (io.ReadCloser, string, int64, error)
|
||||
|
||||
// Session lifecycle management
|
||||
InitSessionFileClient(sessionId string, assetId, accountId int) error
|
||||
CloseSessionFileClient(sessionId string)
|
||||
IsSessionActive(sessionId string) bool
|
||||
|
||||
// Progress tracking for session transfers
|
||||
GetSessionTransferProgress(ctx context.Context, sessionId string) ([]*SessionFileTransferProgress, error)
|
||||
GetTransferProgress(ctx context.Context, transferId string) (*SessionFileTransferProgress, error)
|
||||
|
||||
// File history and other operations
|
||||
AddFileHistory(ctx context.Context, history *model.FileHistory) error
|
||||
BuildFileHistoryQuery(ctx *gin.Context) *gorm.DB
|
||||
RecordFileHistory(ctx context.Context, operation, dir, filename string, assetId, accountId int, sessionId ...string) error
|
||||
RecordFileHistoryBySession(ctx context.Context, sessionId, operation, path string) error
|
||||
|
||||
// Utility methods
|
||||
GetActionCode(operation string) int
|
||||
ValidateAndNormalizePath(basePath, userPath string) (string, error)
|
||||
GetRDPDrivePath(assetId int) (string, error)
|
||||
}
|
||||
|
||||
// FileService implements IFileService
|
||||
type FileService struct {
|
||||
repo IFileRepository
|
||||
}
|
||||
|
||||
// RDP File related structures
|
||||
type RDPFileInfo struct {
|
||||
Name string `json:"name"`
|
||||
Size int64 `json:"size"`
|
||||
IsDir bool `json:"is_dir"`
|
||||
ModTime string `json:"mod_time"`
|
||||
}
|
||||
|
||||
type RDPMkdirRequest struct {
|
||||
Path string `json:"path" binding:"required"`
|
||||
}
|
||||
|
||||
type RDPProgressWriter struct {
|
||||
writer io.Writer
|
||||
transfer *guacd.FileTransfer
|
||||
transferId string
|
||||
written int64
|
||||
}
|
||||
|
||||
// Session transfer related structures
|
||||
type SessionFileTransfer struct {
|
||||
ID string `json:"id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Filename string `json:"filename"`
|
||||
Size int64 `json:"size"`
|
||||
Offset int64 `json:"offset"`
|
||||
Status string `json:"status"` // "pending", "uploading", "completed", "failed"
|
||||
IsUpload bool `json:"is_upload"`
|
||||
Created time.Time `json:"created"`
|
||||
Updated time.Time `json:"updated"`
|
||||
Error string `json:"error,omitempty"`
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
type SessionFileTransferProgress struct {
|
||||
ID string `json:"id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Filename string `json:"filename"`
|
||||
Size int64 `json:"size"`
|
||||
Offset int64 `json:"offset"`
|
||||
Percentage float64 `json:"percentage"`
|
||||
Status string `json:"status"`
|
||||
IsUpload bool `json:"is_upload"`
|
||||
Created time.Time `json:"created"`
|
||||
Updated time.Time `json:"updated"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Speed int64 `json:"speed"` // bytes per second
|
||||
ETA int64 `json:"eta"` // estimated time to completion in seconds
|
||||
}
|
||||
|
||||
type SessionFileTransferManager struct {
|
||||
transfers map[string]*SessionFileTransfer
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (m *SessionFileTransferManager) CreateTransfer(sessionID, filename string, size int64, isUpload bool) *SessionFileTransfer {
|
||||
return m.CreateTransferWithID(generateTransferID(), sessionID, filename, size, isUpload)
|
||||
}
|
||||
|
||||
func (m *SessionFileTransferManager) CreateTransferWithID(transferID, sessionID, filename string, size int64, isUpload bool) *SessionFileTransfer {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
transfer := &SessionFileTransfer{
|
||||
ID: transferID,
|
||||
SessionID: sessionID,
|
||||
Filename: filename,
|
||||
Size: size,
|
||||
Offset: 0,
|
||||
Status: "pending",
|
||||
IsUpload: isUpload,
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
}
|
||||
|
||||
m.transfers[transferID] = transfer
|
||||
return transfer
|
||||
}
|
||||
|
||||
func (m *SessionFileTransferManager) GetTransfer(id string) *SessionFileTransfer {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return m.transfers[id]
|
||||
}
|
||||
|
||||
func (m *SessionFileTransferManager) GetTransfersBySession(sessionID string) []*SessionFileTransfer {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
var transfers []*SessionFileTransfer
|
||||
for _, transfer := range m.transfers {
|
||||
if transfer.SessionID == sessionID {
|
||||
transfers = append(transfers, transfer)
|
||||
}
|
||||
}
|
||||
return transfers
|
||||
}
|
||||
|
||||
func (m *SessionFileTransferManager) GetSessionProgress(sessionID string) []*SessionFileTransferProgress {
|
||||
transfers := m.GetTransfersBySession(sessionID)
|
||||
var progressList []*SessionFileTransferProgress
|
||||
for _, transfer := range transfers {
|
||||
progressList = append(progressList, transfer.GetProgress())
|
||||
}
|
||||
return progressList
|
||||
}
|
||||
|
||||
func (m *SessionFileTransferManager) RemoveTransfer(id string) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
delete(m.transfers, id)
|
||||
}
|
||||
|
||||
func (m *SessionFileTransferManager) CleanupCompletedTransfers(maxAge time.Duration) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-maxAge)
|
||||
for id, transfer := range m.transfers {
|
||||
if (transfer.Status == "completed" || transfer.Status == "failed") && transfer.Updated.Before(cutoff) {
|
||||
delete(m.transfers, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Progress tracking structures
|
||||
type FileTransferProgress struct {
|
||||
TotalSize int64
|
||||
TransferredSize int64
|
||||
Status string // "transferring", "completed", "failed"
|
||||
Type string // "sftp", "rdp"
|
||||
}
|
||||
|
||||
type ProgressWriter struct {
|
||||
writer io.Writer
|
||||
transfer *SessionFileTransfer
|
||||
written int64
|
||||
}
|
||||
|
||||
func NewProgressWriter(writer io.Writer, transfer *SessionFileTransfer) *ProgressWriter {
|
||||
return &ProgressWriter{
|
||||
writer: writer,
|
||||
transfer: transfer,
|
||||
written: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (pw *ProgressWriter) Write(p []byte) (int, error) {
|
||||
n, err := pw.writer.Write(p)
|
||||
if n > 0 {
|
||||
pw.written += int64(n)
|
||||
pw.transfer.UpdateProgress(pw.written)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
pw.transfer.SetError(err)
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
type FileProgressWriter struct {
|
||||
writer io.Writer
|
||||
transferId string
|
||||
written int64
|
||||
lastUpdate time.Time
|
||||
updateBytes int64 // Bytes written since last progress update
|
||||
updateTicker int64 // Simple counter to reduce time.Now() calls
|
||||
}
|
||||
|
||||
func NewFileProgressWriter(writer io.Writer, transferId string) *FileProgressWriter {
|
||||
return &FileProgressWriter{
|
||||
writer: writer,
|
||||
transferId: transferId,
|
||||
written: 0,
|
||||
lastUpdate: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (pw *FileProgressWriter) Write(p []byte) (int, error) {
|
||||
n, err := pw.writer.Write(p)
|
||||
if n > 0 {
|
||||
pw.written += int64(n)
|
||||
pw.updateBytes += int64(n)
|
||||
pw.updateTicker++
|
||||
|
||||
// Update progress every 64KB or every 100 writes to reduce overhead
|
||||
if pw.updateBytes >= 65536 || pw.updateTicker%100 == 0 {
|
||||
UpdateTransferProgress(pw.transferId, 0, pw.written, "transferring")
|
||||
pw.updateBytes = 0
|
||||
pw.lastUpdate = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
UpdateTransferProgress(pw.transferId, 0, -1, "failed")
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Repository interface
|
||||
type IFileRepository interface {
|
||||
AddFileHistory(ctx context.Context, history *model.FileHistory) error
|
||||
BuildFileHistoryQuery(ctx *gin.Context) *gorm.DB
|
||||
}
|
||||
|
||||
// SessionFileTransfer methods
|
||||
func (t *SessionFileTransfer) UpdateProgress(offset int64) {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
t.Offset = offset
|
||||
t.Updated = time.Now()
|
||||
|
||||
if t.Status == "pending" {
|
||||
t.Status = "uploading"
|
||||
}
|
||||
|
||||
if t.Size > 0 && t.Offset >= t.Size {
|
||||
t.Status = "completed"
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SessionFileTransfer) SetError(err error) {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
t.Error = err.Error()
|
||||
t.Status = "failed"
|
||||
t.Updated = time.Now()
|
||||
}
|
||||
|
||||
// GetProgress returns the current progress information
|
||||
func (t *SessionFileTransfer) GetProgress() *SessionFileTransferProgress {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
var percentage float64
|
||||
if t.Size > 0 {
|
||||
percentage = float64(t.Offset) / float64(t.Size) * 100
|
||||
}
|
||||
|
||||
// Calculate speed and ETA
|
||||
var speed int64
|
||||
var eta int64
|
||||
if !t.Created.Equal(t.Updated) && t.Offset > 0 {
|
||||
duration := t.Updated.Sub(t.Created).Seconds()
|
||||
speed = int64(float64(t.Offset) / duration)
|
||||
|
||||
if speed > 0 && t.Size > t.Offset {
|
||||
eta = (t.Size - t.Offset) / speed
|
||||
}
|
||||
}
|
||||
|
||||
return &SessionFileTransferProgress{
|
||||
ID: t.ID,
|
||||
SessionID: t.SessionID,
|
||||
Filename: t.Filename,
|
||||
Size: t.Size,
|
||||
Offset: t.Offset,
|
||||
Percentage: percentage,
|
||||
Status: t.Status,
|
||||
IsUpload: t.IsUpload,
|
||||
Created: t.Created,
|
||||
Updated: t.Updated,
|
||||
Error: t.Error,
|
||||
Speed: speed,
|
||||
ETA: eta,
|
||||
}
|
||||
}
|
||||
|
||||
// Progress tracking functions
|
||||
func CreateTransferProgress(transferId, transferType string) {
|
||||
transferMutex.Lock()
|
||||
fileTransfers[transferId] = &FileTransferProgress{
|
||||
TotalSize: 0,
|
||||
TransferredSize: 0,
|
||||
Status: "transferring",
|
||||
Type: transferType,
|
||||
}
|
||||
transferMutex.Unlock()
|
||||
}
|
||||
|
||||
func UpdateTransferProgress(transferId string, totalSize, transferredSize int64, status string) {
|
||||
transferMutex.Lock()
|
||||
if progress, exists := fileTransfers[transferId]; exists {
|
||||
if totalSize > 0 {
|
||||
progress.TotalSize = totalSize
|
||||
}
|
||||
if transferredSize >= 0 {
|
||||
progress.TransferredSize = transferredSize
|
||||
}
|
||||
if status != "" {
|
||||
progress.Status = status
|
||||
}
|
||||
}
|
||||
transferMutex.Unlock()
|
||||
}
|
||||
|
||||
func CleanupTransferProgress(transferId string, delay time.Duration) {
|
||||
go func() {
|
||||
time.Sleep(delay)
|
||||
transferMutex.Lock()
|
||||
delete(fileTransfers, transferId)
|
||||
transferMutex.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
func GetTransferProgressById(transferId string) (*FileTransferProgress, bool) {
|
||||
transferMutex.RLock()
|
||||
progress, exists := fileTransfers[transferId]
|
||||
transferMutex.RUnlock()
|
||||
return progress, exists
|
||||
}
|
||||
|
||||
// Utility functions
|
||||
func generateTransferID() string {
|
||||
bytes := make([]byte, 16)
|
||||
rand.Read(bytes)
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
func sanitizeFilename(filename string) (string, error) {
|
||||
// Remove any directory separators
|
||||
cleaned := filepath.Base(filename)
|
||||
|
||||
// Check for dangerous patterns
|
||||
if strings.Contains(cleaned, "..") || cleaned == "." || cleaned == "" {
|
||||
return "", fmt.Errorf("invalid filename")
|
||||
}
|
||||
|
||||
// Additional security checks
|
||||
if strings.HasPrefix(cleaned, ".") && len(cleaned) > 1 {
|
||||
// Allow hidden files but validate they're not dangerous
|
||||
}
|
||||
|
||||
return cleaned, nil
|
||||
}
|
||||
|
||||
func IsPermissionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
errStr := strings.ToLower(err.Error())
|
||||
permissionKeywords := []string{
|
||||
"permission denied",
|
||||
"access denied",
|
||||
"unauthorized",
|
||||
"forbidden",
|
||||
"not authorized",
|
||||
"insufficient privileges",
|
||||
"operation not permitted",
|
||||
"sftp: permission denied",
|
||||
"ssh: permission denied",
|
||||
}
|
||||
|
||||
for _, keyword := range permissionKeywords {
|
||||
if strings.Contains(errStr, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -1,57 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/veops/oneterm/internal/model"
|
||||
dbpkg "github.com/veops/oneterm/pkg/db"
|
||||
"github.com/veops/oneterm/pkg/utils"
|
||||
)
|
||||
|
||||
func GetAAG(assetId int, accountId int) (asset *model.Asset, account *model.Account, gateway *model.Gateway, err error) {
|
||||
asset, account, gateway = &model.Asset{}, &model.Account{}, &model.Gateway{}
|
||||
if err = dbpkg.DB.Model(asset).Where("id = ?", assetId).First(asset).Error; err != nil {
|
||||
return
|
||||
}
|
||||
if err = dbpkg.DB.Model(account).Where("id = ?", accountId).First(account).Error; err != nil {
|
||||
return
|
||||
}
|
||||
account.Password = utils.DecryptAES(account.Password)
|
||||
account.Pk = utils.DecryptAES(account.Pk)
|
||||
account.Phrase = utils.DecryptAES(account.Phrase)
|
||||
if asset.GatewayId != 0 {
|
||||
if err = dbpkg.DB.Model(gateway).Where("id = ?", asset.GatewayId).First(gateway).Error; err != nil {
|
||||
return
|
||||
}
|
||||
gateway.Password = utils.DecryptAES(gateway.Password)
|
||||
gateway.Pk = utils.DecryptAES(gateway.Pk)
|
||||
gateway.Phrase = utils.DecryptAES(gateway.Phrase)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func GetAuth(account *model.Account) (ssh.AuthMethod, error) {
|
||||
switch account.AccountType {
|
||||
case model.AUTHMETHOD_PASSWORD:
|
||||
return ssh.Password(account.Password), nil
|
||||
case model.AUTHMETHOD_PUBLICKEY:
|
||||
if account.Phrase == "" {
|
||||
pk, err := ssh.ParsePrivateKey([]byte(account.Pk))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.PublicKeys(pk), nil
|
||||
} else {
|
||||
pk, err := ssh.ParsePrivateKeyWithPassphrase([]byte(account.Pk), []byte(account.Phrase))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.PublicKeys(pk), nil
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid authmethod %d", account.AccountType)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user