Files
oneterm/backend/internal/service/file.go
2025-05-28 21:40:40 +08:00

413 lines
12 KiB
Go

package service
import (
"archive/zip"
"context"
"fmt"
"io"
"io/fs"
"path/filepath"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/pkg/sftp"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/internal/repository"
"github.com/veops/oneterm/internal/tunneling"
dbpkg "github.com/veops/oneterm/pkg/db"
"github.com/veops/oneterm/pkg/logger"
"go.uber.org/zap"
"golang.org/x/crypto/ssh"
)
var (
	// fm caches one SFTP client per "assetId-accountId" key, together with the
	// time each key was last touched by GetFileClient.
	fm = &FileManager{
		sftps:    make(map[string]*sftp.Client),
		lastTime: make(map[string]time.Time),
	}

	// DefaultFileService is the package-wide file service; InitFileService
	// assigns it.
	DefaultFileService IFileService
)
// InitFileService initializes the global file service
func InitFileService() {
repo := repository.NewFileRepository(dbpkg.DB)
DefaultFileService = NewFileService(repo)
}
// init starts a background sweeper that runs once a minute. It evicts every
// cached SFTP client that GetFileClient has not touched for more than
// idleTimeout, and closes it.
//
// NOTE: the last-used time is refreshed only by GetFileClient. A single long
// transfer on an otherwise idle client is therefore interrupted once it
// outlives idleTimeout.
func init() {
	const idleTimeout = 10 * time.Minute

	go func() {
		tk := time.NewTicker(time.Minute)
		defer tk.Stop()
		for range tk.C {
			// Collect stale clients under the lock, but close them only after
			// releasing it. A slow or unresponsive remote then cannot stall
			// GetFileClient callers.
			var stale []io.Closer
			func() {
				fm.mtx.Lock()
				defer fm.mtx.Unlock()
				// lastTime is always in the past, so "idle" means older than now-idleTimeout.
				cutoff := time.Now().Add(-idleTimeout)
				for key, last := range fm.lastTime {
					if !last.Before(cutoff) {
						continue
					}
					if cli := fm.sftps[key]; cli != nil {
						stale = append(stale, cli)
					}
					delete(fm.sftps, key)
					delete(fm.lastTime, key)
				}
			}()
			for _, c := range stale {
				_ = c.Close()
			}
		}
	}()
}
// FileManager caches SFTP clients keyed by "assetId-accountId" and records when
// each key was last touched, so that idle clients can be swept.
type FileManager struct {
	mtx      sync.Mutex              // guards sftps and lastTime
	sftps    map[string]*sftp.Client // cached client per key
	lastTime map[string]time.Time    // when GetFileClient last touched each key
}
// FileInfo is a JSON-serializable description of a single file-system entry.
// NOTE(review): nothing in this part of the file constructs FileInfo.
// Presumably API handlers fill it from an fs.FileInfo — confirm.
// encoding/json emits fields in declaration order, so reordering the fields
// changes the serialized output.
type FileInfo struct {
// Name is the entry's name.
Name string `json:"name"`
// IsDir reports whether the entry is a directory.
IsDir bool `json:"is_dir"`
// Size is the entry size; presumably in bytes — confirm with the producer.
Size int64 `json:"size"`
// Mode is a textual mode/permission string; its format is set by whoever fills it.
Mode string `json:"mode"`
// IsLink reports whether the entry is a symbolic link.
IsLink bool `json:"is_link"`
// Target is presumably the link target when IsLink is true — confirm.
Target string `json:"target"`
// ModTime is the modification time as a string; its format is not fixed here.
ModTime string `json:"mod_time"`
}
// GetFileManager returns the process-wide FileManager used by the SFTP-backed
// FileService methods.
func GetFileManager() *FileManager { return fm }
// GetFileClient returns the cached SFTP client for (assetId, accountId). When
// none is cached, it establishes a new SSH connection: it resolves the asset,
// account and gateway, sets up a tunnel proxy, dials, and starts an SFTP
// session on top. Every successful call refreshes the key's last-used time,
// which the idle sweeper reads. Failures are never cached.
//
// The manager lock is held for the whole call, including the network dial, so
// concurrent callers are serialized while a connection is being established.
func (fm *FileManager) GetFileClient(assetId, accountId int) (*sftp.Client, error) {
	fm.mtx.Lock()
	defer fm.mtx.Unlock()

	key := fmt.Sprintf("%d-%d", assetId, accountId)
	if cli, ok := fm.sftps[key]; ok && cli != nil {
		fm.lastTime[key] = time.Now()
		return cli, nil
	}

	asset, account, gateway, err := GetAAG(assetId, accountId)
	if err != nil {
		return nil, err
	}
	ip, port, err := tunneling.Proxy(false, uuid.New().String(), "sftp,ssh", asset, gateway)
	if err != nil {
		return nil, err
	}
	auth, err := GetAuth(account)
	if err != nil {
		return nil, err
	}
	// NOTE(review): host keys are not verified (InsecureIgnoreHostKey).
	sshCli, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", ip, port), &ssh.ClientConfig{
		User:            account.Account,
		Auth:            []ssh.AuthMethod{auth},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         time.Second,
	})
	if err != nil {
		return nil, err
	}
	cli, err := sftp.NewClient(sshCli)
	if err != nil {
		// Never cache a nil client: later callers would receive (nil, nil) and
		// panic. Also release the SSH connection that was just opened.
		_ = sshCli.Close()
		return nil, err
	}
	fm.sftps[key] = cli
	fm.lastTime[key] = time.Now()

	// The SFTP session may end because the idle sweeper closes it or because
	// the network drops. Either way, close the underlying SSH connection (the
	// SFTP client does not own it) and forget the cache entry, so the next
	// call reconnects instead of reusing a dead client.
	go func() {
		_ = cli.Wait()
		_ = sshCli.Close()
		fm.mtx.Lock()
		defer fm.mtx.Unlock()
		if fm.sftps[key] == cli {
			delete(fm.sftps, key)
			delete(fm.lastTime, key)
		}
	}()
	return cli, nil
}
// IFileService is the file-transfer API implemented by FileService. It
// provides SFTP operations addressed by (assetId, accountId), file-history
// persistence, and RDP-session operations addressed by sessionId.
type IFileService interface {
	// SFTP operations on an asset, performed as the given account.
	ReadDir(ctx context.Context, assetId, accountId int, dir string) ([]fs.FileInfo, error)
	Stat(ctx context.Context, assetId, accountId int, path string) (fs.FileInfo, error)
	MkdirAll(ctx context.Context, assetId, accountId int, dir string) error
	Open(ctx context.Context, assetId, accountId int, path string) (io.ReadCloser, error)
	Create(ctx context.Context, assetId, accountId int, path string) (io.WriteCloser, error)
	// DownloadMultiple returns a content stream, a suggested file name and a
	// size (FileService reports -1 when it streams a ZIP of unknown length).
	DownloadMultiple(ctx context.Context, assetId, accountId int, dir string, filenames []string) (io.ReadCloser, string, int64, error)

	// File transfer history.
	AddFileHistory(ctx context.Context, history *model.FileHistory) error
	GetFileHistory(ctx context.Context, filters map[string]interface{}) ([]*model.FileHistory, int64, error)

	// RDP file transfer, addressed by session ID.
	RDPReadDir(ctx context.Context, sessionId, dir string) ([]fs.FileInfo, error)
	RDPMkdirAll(ctx context.Context, sessionId, dir string) error
	RDPUploadFile(ctx context.Context, sessionId, filename string, content []byte) error
	RDPDownloadFile(ctx context.Context, sessionId, filename string) ([]byte, error)
}
// FileService is the default IFileService. SFTP work goes through the shared
// FileManager, and history records go through repo.
type FileService struct {
	repo repository.IFileRepository // persistence for file-transfer history
}
// NewFileService returns a FileService that stores file history in repo.
func NewFileService(repo repository.IFileRepository) IFileService {
	svc := new(FileService)
	svc.repo = repo
	return svc
}
// ReadDir lists dir on the asset using the cached SFTP client for
// (assetId, accountId). ctx is currently unused.
func (s *FileService) ReadDir(ctx context.Context, assetId, accountId int, dir string) ([]fs.FileInfo, error) {
	manager := GetFileManager()
	client, err := manager.GetFileClient(assetId, accountId)
	if err != nil {
		return nil, err
	}
	entries, err := client.ReadDir(dir)
	return entries, err
}
// MkdirAll creates dir, including any missing parents, on the asset using the
// cached SFTP client for (assetId, accountId). ctx is currently unused.
func (s *FileService) MkdirAll(ctx context.Context, assetId, accountId int, dir string) error {
	client, err := GetFileManager().GetFileClient(assetId, accountId)
	if err == nil {
		err = client.MkdirAll(dir)
	}
	return err
}
// Create creates path on the asset through the cached SFTP client for
// (assetId, accountId) and returns it for writing. ctx is currently unused.
//
// On failure the returned io.WriteCloser is an untyped nil, so callers may
// safely compare it against nil.
func (s *FileService) Create(ctx context.Context, assetId, accountId int, path string) (io.WriteCloser, error) {
	cli, err := GetFileManager().GetFileClient(assetId, accountId)
	if err != nil {
		return nil, err
	}
	f, err := cli.Create(path)
	if err != nil {
		// Returning f here would wrap a nil *sftp.File in a non-nil interface.
		return nil, err
	}
	return f, nil
}
// Open opens path on the asset for reading through the cached SFTP client for
// (assetId, accountId). ctx is currently unused.
//
// On failure the returned io.ReadCloser is an untyped nil, so callers may
// safely compare it against nil.
func (s *FileService) Open(ctx context.Context, assetId, accountId int, path string) (io.ReadCloser, error) {
	cli, err := GetFileManager().GetFileClient(assetId, accountId)
	if err != nil {
		return nil, err
	}
	f, err := cli.Open(path)
	if err != nil {
		// Returning f here would wrap a nil *sftp.File in a non-nil interface.
		return nil, err
	}
	return f, nil
}
// Stat returns information about path on the asset using the cached SFTP
// client for (assetId, accountId). ctx is currently unused.
func (s *FileService) Stat(ctx context.Context, assetId, accountId int, path string) (fs.FileInfo, error) {
	client, err := GetFileManager().GetFileClient(assetId, accountId)
	if err != nil {
		return nil, err
	}
	info, err := client.Stat(path)
	return info, err
}
// DownloadMultiple returns a stream for downloading filenames, each of which
// must be a plain name inside the remote directory dir. ctx is currently unused.
//
// A single regular file is streamed directly under its own name and with its
// real size. Anything else (several names, or a single directory) is streamed
// as a ZIP archive of unknown size (-1).
//
// All names are validated before any connection is made. An empty list is
// rejected.
func (s *FileService) DownloadMultiple(ctx context.Context, assetId, accountId int, dir string, filenames []string) (io.ReadCloser, string, int64, error) {
	if len(filenames) == 0 {
		return nil, "", 0, fmt.Errorf("no files specified for download")
	}

	// Validate and sanitize all filenames for security.
	sanitizedFilenames := make([]string, 0, len(filenames))
	for _, filename := range filenames {
		sanitized, err := sanitizeFilename(filename)
		if err != nil {
			return nil, "", 0, fmt.Errorf("invalid filename '%s': %w", filename, err)
		}
		sanitizedFilenames = append(sanitizedFilenames, sanitized)
	}

	cli, err := GetFileManager().GetFileClient(assetId, accountId)
	if err != nil {
		return nil, "", 0, err
	}

	// A single regular file is returned as-is, without zipping.
	if len(sanitizedFilenames) == 1 {
		// Remote SFTP paths are always '/'-separated, whatever the local OS.
		fullPath := cli.Join(dir, sanitizedFilenames[0])
		fileInfo, err := cli.Stat(fullPath)
		if err != nil {
			return nil, "", 0, err
		}
		if !fileInfo.IsDir() {
			reader, err := cli.Open(fullPath)
			if err != nil {
				return nil, "", 0, err
			}
			return reader, sanitizedFilenames[0], fileInfo.Size(), nil
		}
	}

	// Multiple entries, or a single directory: stream a ZIP.
	return s.createZipArchive(cli, dir, sanitizedFilenames)
}
// createZipArchive streams a ZIP archive of filenames, each resolved relative
// to baseDir on the remote host, through an io.Pipe. The archive is produced
// by a background goroutine as the caller reads, so the returned size is -1.
// The suggested name is "<name>.zip" for a single entry and "download.zip"
// otherwise.
//
// The caller must read the returned reader to EOF or Close it, because the
// producing goroutine blocks on pipe writes until one of those happens. Any
// failure is delivered to the reader as a read error instead of a silently
// truncated archive. This includes the final write of the ZIP central
// directory.
func (s *FileService) createZipArchive(cli *sftp.Client, baseDir string, filenames []string) (io.ReadCloser, string, int64, error) {
	zipName := "download.zip"
	if len(filenames) == 1 {
		zipName = filenames[0] + ".zip"
	}

	// Use a pipe for true streaming without buffering the archive in memory.
	pipeReader, pipeWriter := io.Pipe()
	go func() {
		zipWriter := zip.NewWriter(pipeWriter)
		for _, filename := range filenames {
			// Remote SFTP paths are always '/'-separated, whatever the local OS.
			fullPath := cli.Join(baseDir, filename)
			if err := s.addToZip(cli, zipWriter, baseDir, filename, fullPath); err != nil {
				logger.L().Error("Failed to add file to ZIP", zap.String("path", fullPath), zap.Error(err))
				pipeWriter.CloseWithError(err)
				return
			}
		}
		// Close flushes buffered data and writes the central directory. A nil
		// error makes CloseWithError behave like Close, so the reader sees EOF.
		pipeWriter.CloseWithError(zipWriter.Close())
	}()

	return pipeReader, zipName, -1, nil
}
// addToZip adds the remote entry at fullPath to zipWriter under relativePath.
// It recurses when the entry is a directory and copies its contents otherwise.
// The type check uses cli.Stat. baseDir is currently unused.
func (s *FileService) addToZip(cli *sftp.Client, zipWriter *zip.Writer, baseDir, relativePath, fullPath string) error {
	info, err := cli.Stat(fullPath)
	switch {
	case err != nil:
		return err
	case info.IsDir():
		return s.addDirToZip(cli, zipWriter, fullPath, relativePath)
	default:
		return s.addFileToZip(cli, zipWriter, fullPath, relativePath)
	}
}
// addFileToZip copies the remote file at fullPath into a new ZIP entry named
// relativePath. The remote file is closed before returning.
func (s *FileService) addFileToZip(cli *sftp.Client, zipWriter *zip.Writer, fullPath, relativePath string) error {
	src, err := cli.Open(fullPath)
	if err != nil {
		return err
	}
	defer src.Close()

	dst, err := zipWriter.Create(relativePath)
	if err != nil {
		return err
	}
	if _, err = io.Copy(dst, src); err != nil {
		return err
	}
	return nil
}
// addDirToZip adds the contents of the remote directory fullPath to zipWriter
// beneath relativePath, recursing into subdirectories.
//
// ZIP entry names are always built with '/', as the ZIP format requires,
// regardless of the local OS.
//
// Symbolic links that resolve to directories are skipped rather than
// followed; a link such as `loop -> .` would otherwise recurse without bound.
// Other symlinks are archived through addFileToZip.
//
// A directory that contributes no entries is recorded as an explicit
// "name/" entry, so it still appears when the archive is extracted.
func (s *FileService) addDirToZip(cli *sftp.Client, zipWriter *zip.Writer, fullPath, relativePath string) error {
	entries, err := cli.ReadDir(fullPath)
	if err != nil {
		return err
	}

	added := 0
	for _, entry := range entries {
		childFull := cli.Join(fullPath, entry.Name())
		childRel := relativePath + "/" + entry.Name()

		if entry.Mode()&fs.ModeSymlink != 0 {
			target, err := cli.Stat(childFull)
			if err != nil {
				return err
			}
			if target.IsDir() {
				continue // never descend through directory symlinks (cycle risk)
			}
			if err := s.addFileToZip(cli, zipWriter, childFull, childRel); err != nil {
				return err
			}
			added++
			continue
		}

		if err := s.addToZip(cli, zipWriter, fullPath, childRel, childFull); err != nil {
			return err
		}
		added++
	}

	if added == 0 {
		_, err := zipWriter.Create(relativePath + "/")
		return err
	}
	return nil
}
// AddFileHistory stores one file-transfer history record through the repository.
func (s *FileService) AddFileHistory(ctx context.Context, history *model.FileHistory) error {
	repo := s.repo
	return repo.AddFileHistory(ctx, history)
}
// GetFileHistory returns the history records matching filters, plus the count
// reported by the repository. filters is passed through to the repository
// unchanged.
func (s *FileService) GetFileHistory(ctx context.Context, filters map[string]interface{}) ([]*model.FileHistory, int64, error) {
	records, total, err := s.repo.GetFileHistory(ctx, filters)
	return records, total, err
}
// RDP file transfer methods implementation

// RDPReadDir is a placeholder for listing a directory in an RDP session. It
// reports an unknown session when no tunnel exists for sessionId, and
// otherwise always fails: RDP file transfer is expected to go through the
// Guacamole protocol instead.
func (s *FileService) RDPReadDir(ctx context.Context, sessionId, dir string) ([]fs.FileInfo, error) {
	if tunneling.GetTunnelBySessionId(sessionId) == nil {
		return nil, fmt.Errorf("session not found: %s", sessionId)
	}
	// TODO: drive support would need the session's configuration (e.g. whether
	// a drive is enabled). Until that exists, every request is rejected.
	return nil, fmt.Errorf("RDP file operations should be handled through Guacamole protocol")
}
// RDPMkdirAll is a placeholder for creating a directory in an RDP session. It
// fails with "session not found" when no tunnel exists for sessionId, and
// otherwise always fails, deferring to the Guacamole protocol.
func (s *FileService) RDPMkdirAll(ctx context.Context, sessionId, dir string) error {
	if tunneling.GetTunnelBySessionId(sessionId) == nil {
		return fmt.Errorf("session not found: %s", sessionId)
	}
	return fmt.Errorf("RDP file operations should be handled through Guacamole protocol")
}
// RDPUploadFile is a placeholder for uploading content to an RDP session. It
// fails with "session not found" when no tunnel exists for sessionId, and
// otherwise always fails, deferring to the Guacamole protocol.
func (s *FileService) RDPUploadFile(ctx context.Context, sessionId, filename string, content []byte) error {
	if tunneling.GetTunnelBySessionId(sessionId) == nil {
		return fmt.Errorf("session not found: %s", sessionId)
	}
	return fmt.Errorf("RDP file operations should be handled through Guacamole protocol")
}
// RDPDownloadFile is a placeholder for downloading a file from an RDP session.
// It fails with "session not found" when no tunnel exists for sessionId, and
// otherwise always fails, deferring to the Guacamole protocol.
func (s *FileService) RDPDownloadFile(ctx context.Context, sessionId, filename string) ([]byte, error) {
	if tunneling.GetTunnelBySessionId(sessionId) == nil {
		return nil, fmt.Errorf("session not found: %s", sessionId)
	}
	return nil, fmt.Errorf("RDP file operations should be handled through Guacamole protocol")
}
// sanitizeFilename checks that filename is a single, plain path component that
// is safe to join onto a directory. It returns filename unchanged when valid.
//
// The following are rejected:
//   - "" and ".", which would resolve to the directory itself;
//   - any name containing "..", "/" or "\", which could escape the directory
//     (this also conservatively rejects names like "a..b");
//   - any name containing a C0 control character, DEL, or a C1 control
//     character.
func sanitizeFilename(filename string) (string, error) {
	if filename == "" || filename == "." {
		return "", fmt.Errorf("invalid filename: empty or current directory")
	}
	// Remove any path traversal attempts.
	if strings.Contains(filename, "..") || strings.ContainsAny(filename, `/\`) {
		return "", fmt.Errorf("invalid filename: path traversal detected")
	}
	// Reject null bytes and all other control characters.
	for _, r := range filename {
		if r < 0x20 || (r >= 0x7f && r <= 0x9f) {
			return "", fmt.Errorf("invalid filename: control characters detected")
		}
	}
	return filename, nil
}