mirror of
				https://github.com/veops/oneterm.git
				synced 2025-11-01 03:12:39 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			413 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			413 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package service
 | |
| 
 | |
| import (
 | |
| 	"archive/zip"
 | |
| 	"context"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"io/fs"
 | |
| 	"path/filepath"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/google/uuid"
 | |
| 	"github.com/pkg/sftp"
 | |
| 	"github.com/veops/oneterm/internal/model"
 | |
| 	"github.com/veops/oneterm/internal/repository"
 | |
| 	"github.com/veops/oneterm/internal/tunneling"
 | |
| 	dbpkg "github.com/veops/oneterm/pkg/db"
 | |
| 	"github.com/veops/oneterm/pkg/logger"
 | |
| 	"go.uber.org/zap"
 | |
| 	"golang.org/x/crypto/ssh"
 | |
| )
 | |
| 
 | |
| var (
 | |
| 	fm = &FileManager{
 | |
| 		sftps:    map[string]*sftp.Client{},
 | |
| 		lastTime: map[string]time.Time{},
 | |
| 		mtx:      sync.Mutex{},
 | |
| 	}
 | |
| 
 | |
| 	// Global file service instance
 | |
| 	DefaultFileService IFileService
 | |
| )
 | |
| 
 | |
| // InitFileService initializes the global file service
 | |
| func InitFileService() {
 | |
| 	repo := repository.NewFileRepository(dbpkg.DB)
 | |
| 	DefaultFileService = NewFileService(repo)
 | |
| }
 | |
| 
 | |
| func init() {
 | |
| 	go func() {
 | |
| 		tk := time.NewTicker(time.Minute)
 | |
| 		for {
 | |
| 			<-tk.C
 | |
| 			func() {
 | |
| 				fm.mtx.Lock()
 | |
| 				defer fm.mtx.Unlock()
 | |
| 				for k, v := range fm.lastTime {
 | |
| 					if v.Before(time.Now().Add(time.Minute * 10)) {
 | |
| 						delete(fm.sftps, k)
 | |
| 						delete(fm.lastTime, k)
 | |
| 					}
 | |
| 				}
 | |
| 			}()
 | |
| 		}
 | |
| 	}()
 | |
| }
 | |
| 
 | |
| type FileManager struct {
 | |
| 	sftps    map[string]*sftp.Client
 | |
| 	lastTime map[string]time.Time
 | |
| 	mtx      sync.Mutex
 | |
| }
 | |
| 
 | |
// FileInfo is the JSON shape of a single remote directory entry.
//
// NOTE(review): how Mode, Target and ModTime are populated is decided by the
// code that fills this struct (not visible here); confirm formats there.
type FileInfo struct {
	// Name is the entry's name.
	Name string `json:"name"`
	// IsDir reports whether the entry is a directory.
	IsDir bool `json:"is_dir"`
	// Size is the entry size (presumably bytes).
	Size int64 `json:"size"`
	// Mode is the entry's mode rendered as text.
	Mode string `json:"mode"`
	// IsLink reports whether the entry is a symbolic link.
	IsLink bool `json:"is_link"`
	// Target is the link target; presumably empty when IsLink is false.
	Target string `json:"target"`
	// ModTime is the modification time rendered as text.
	ModTime string `json:"mod_time"`
}
 | |
| 
 | |
| func GetFileManager() *FileManager {
 | |
| 	return fm
 | |
| }
 | |
| 
 | |
| func (fm *FileManager) GetFileClient(assetId, accountId int) (cli *sftp.Client, err error) {
 | |
| 	fm.mtx.Lock()
 | |
| 	defer fm.mtx.Unlock()
 | |
| 
 | |
| 	key := fmt.Sprintf("%d-%d", assetId, accountId)
 | |
| 	defer func() {
 | |
| 		fm.lastTime[key] = time.Now()
 | |
| 	}()
 | |
| 
 | |
| 	cli, ok := fm.sftps[key]
 | |
| 	if ok {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	asset, account, gateway, err := GetAAG(assetId, accountId)
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	ip, port, err := tunneling.Proxy(false, uuid.New().String(), "sftp,ssh", asset, gateway)
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	auth, err := GetAuth(account)
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	sshCli, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", ip, port), &ssh.ClientConfig{
 | |
| 		User:            account.Account,
 | |
| 		Auth:            []ssh.AuthMethod{auth},
 | |
| 		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
 | |
| 		Timeout:         time.Second,
 | |
| 	})
 | |
| 	if err != nil {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	cli, err = sftp.NewClient(sshCli)
 | |
| 	fm.sftps[key] = cli
 | |
| 
 | |
| 	return
 | |
| }
 | |
| 
 | |
| // File service interface
 | |
| type IFileService interface {
 | |
| 	ReadDir(ctx context.Context, assetId, accountId int, dir string) ([]fs.FileInfo, error)
 | |
| 	MkdirAll(ctx context.Context, assetId, accountId int, dir string) error
 | |
| 	Create(ctx context.Context, assetId, accountId int, path string) (io.WriteCloser, error)
 | |
| 	Open(ctx context.Context, assetId, accountId int, path string) (io.ReadCloser, error)
 | |
| 	Stat(ctx context.Context, assetId, accountId int, path string) (fs.FileInfo, error)
 | |
| 	DownloadMultiple(ctx context.Context, assetId, accountId int, dir string, filenames []string) (io.ReadCloser, string, int64, error)
 | |
| 	AddFileHistory(ctx context.Context, history *model.FileHistory) error
 | |
| 	GetFileHistory(ctx context.Context, filters map[string]interface{}) ([]*model.FileHistory, int64, error)
 | |
| 
 | |
| 	// RDP file transfer methods
 | |
| 	RDPReadDir(ctx context.Context, sessionId, dir string) ([]fs.FileInfo, error)
 | |
| 	RDPMkdirAll(ctx context.Context, sessionId, dir string) error
 | |
| 	RDPUploadFile(ctx context.Context, sessionId, filename string, content []byte) error
 | |
| 	RDPDownloadFile(ctx context.Context, sessionId, filename string) ([]byte, error)
 | |
| }
 | |
| 
 | |
| // File service implementation
 | |
| type FileService struct {
 | |
| 	repo repository.IFileRepository
 | |
| }
 | |
| 
 | |
| func NewFileService(repo repository.IFileRepository) IFileService {
 | |
| 	return &FileService{
 | |
| 		repo: repo,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // ReadDir gets directory listing
 | |
| func (s *FileService) ReadDir(ctx context.Context, assetId, accountId int, dir string) ([]fs.FileInfo, error) {
 | |
| 	cli, err := GetFileManager().GetFileClient(assetId, accountId)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return cli.ReadDir(dir)
 | |
| }
 | |
| 
 | |
| // MkdirAll creates a directory
 | |
| func (s *FileService) MkdirAll(ctx context.Context, assetId, accountId int, dir string) error {
 | |
| 	cli, err := GetFileManager().GetFileClient(assetId, accountId)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	return cli.MkdirAll(dir)
 | |
| }
 | |
| 
 | |
| // Create creates a file
 | |
| func (s *FileService) Create(ctx context.Context, assetId, accountId int, path string) (io.WriteCloser, error) {
 | |
| 	cli, err := GetFileManager().GetFileClient(assetId, accountId)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return cli.Create(path)
 | |
| }
 | |
| 
 | |
| // Open opens a file
 | |
| func (s *FileService) Open(ctx context.Context, assetId, accountId int, path string) (io.ReadCloser, error) {
 | |
| 	cli, err := GetFileManager().GetFileClient(assetId, accountId)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return cli.Open(path)
 | |
| }
 | |
| 
 | |
| // Stat gets file/directory information
 | |
| func (s *FileService) Stat(ctx context.Context, assetId, accountId int, path string) (fs.FileInfo, error) {
 | |
| 	cli, err := GetFileManager().GetFileClient(assetId, accountId)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return cli.Stat(path)
 | |
| }
 | |
| 
 | |
| // DownloadMultiple handles downloading single file or multiple files/directories as ZIP
 | |
| func (s *FileService) DownloadMultiple(ctx context.Context, assetId, accountId int, dir string, filenames []string) (io.ReadCloser, string, int64, error) {
 | |
| 	cli, err := GetFileManager().GetFileClient(assetId, accountId)
 | |
| 	if err != nil {
 | |
| 		return nil, "", 0, err
 | |
| 	}
 | |
| 
 | |
| 	// Validate and sanitize all filenames for security
 | |
| 	var sanitizedFilenames []string
 | |
| 	for _, filename := range filenames {
 | |
| 		sanitized, err := sanitizeFilename(filename)
 | |
| 		if err != nil {
 | |
| 			return nil, "", 0, fmt.Errorf("invalid filename '%s': %v", filename, err)
 | |
| 		}
 | |
| 
 | |
| 		sanitizedFilenames = append(sanitizedFilenames, sanitized)
 | |
| 	}
 | |
| 
 | |
| 	// If only one file, check if it's a regular file first
 | |
| 	if len(sanitizedFilenames) == 1 {
 | |
| 		fullPath := filepath.Join(dir, sanitizedFilenames[0])
 | |
| 		fileInfo, err := cli.Stat(fullPath)
 | |
| 		if err != nil {
 | |
| 			return nil, "", 0, err
 | |
| 		}
 | |
| 
 | |
| 		// If it's a regular file, return directly
 | |
| 		if !fileInfo.IsDir() {
 | |
| 			reader, err := cli.Open(fullPath)
 | |
| 			if err != nil {
 | |
| 				return nil, "", 0, err
 | |
| 			}
 | |
| 			return reader, sanitizedFilenames[0], fileInfo.Size(), nil
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Multiple files or contains directory, create ZIP
 | |
| 	return s.createZipArchive(cli, dir, sanitizedFilenames)
 | |
| }
 | |
| 
 | |
| // createZipArchive creates a ZIP archive containing the specified files/directories
 | |
| func (s *FileService) createZipArchive(cli *sftp.Client, baseDir string, filenames []string) (io.ReadCloser, string, int64, error) {
 | |
| 	// Generate ZIP filename
 | |
| 	var zipName string
 | |
| 	if len(filenames) == 1 {
 | |
| 		zipName = filenames[0] + ".zip"
 | |
| 	} else {
 | |
| 		zipName = "download.zip"
 | |
| 	}
 | |
| 
 | |
| 	// Use pipe for true streaming without memory buffering
 | |
| 	pipeReader, pipeWriter := io.Pipe()
 | |
| 
 | |
| 	// Create ZIP in a separate goroutine
 | |
| 	go func() {
 | |
| 		defer pipeWriter.Close()
 | |
| 
 | |
| 		zipWriter := zip.NewWriter(pipeWriter)
 | |
| 		defer zipWriter.Close()
 | |
| 
 | |
| 		// Add each file/directory to ZIP
 | |
| 		for _, filename := range filenames {
 | |
| 			fullPath := filepath.Join(baseDir, filename)
 | |
| 
 | |
| 			if err := s.addToZip(cli, zipWriter, baseDir, filename, fullPath); err != nil {
 | |
| 				logger.L().Error("Failed to add file to ZIP", zap.String("path", fullPath), zap.Error(err))
 | |
| 				pipeWriter.CloseWithError(err)
 | |
| 				return
 | |
| 			}
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	// Return pipe reader for streaming, size unknown (-1)
 | |
| 	return pipeReader, zipName, -1, nil
 | |
| }
 | |
| 
 | |
| // addToZip recursively adds files/directories to ZIP archive
 | |
| func (s *FileService) addToZip(cli *sftp.Client, zipWriter *zip.Writer, baseDir, relativePath, fullPath string) error {
 | |
| 	fileInfo, err := cli.Stat(fullPath)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if fileInfo.IsDir() {
 | |
| 		// Add directory
 | |
| 		return s.addDirToZip(cli, zipWriter, fullPath, relativePath)
 | |
| 	} else {
 | |
| 		// Add file
 | |
| 		return s.addFileToZip(cli, zipWriter, fullPath, relativePath)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // addFileToZip adds a single file to ZIP archive
 | |
| func (s *FileService) addFileToZip(cli *sftp.Client, zipWriter *zip.Writer, fullPath, relativePath string) error {
 | |
| 	// Open remote file
 | |
| 	file, err := cli.Open(fullPath)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	defer file.Close()
 | |
| 
 | |
| 	// Create file in ZIP
 | |
| 	zipFile, err := zipWriter.Create(relativePath)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// Copy file content
 | |
| 	_, err = io.Copy(zipFile, file)
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| // addDirToZip recursively adds a directory to ZIP archive
 | |
| func (s *FileService) addDirToZip(cli *sftp.Client, zipWriter *zip.Writer, fullPath, relativePath string) error {
 | |
| 	// Read directory contents
 | |
| 	entries, err := cli.ReadDir(fullPath)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// If directory is empty, create directory entry
 | |
| 	if len(entries) == 0 {
 | |
| 		_, err := zipWriter.Create(relativePath + "/")
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// Recursively add each entry in the directory
 | |
| 	for _, entry := range entries {
 | |
| 		entryFullPath := filepath.Join(fullPath, entry.Name())
 | |
| 		entryRelativePath := filepath.Join(relativePath, entry.Name())
 | |
| 
 | |
| 		if err := s.addToZip(cli, zipWriter, fullPath, entryRelativePath, entryFullPath); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // AddFileHistory adds a file history record
 | |
| func (s *FileService) AddFileHistory(ctx context.Context, history *model.FileHistory) error {
 | |
| 	return s.repo.AddFileHistory(ctx, history)
 | |
| }
 | |
| 
 | |
| // GetFileHistory gets file history records
 | |
| func (s *FileService) GetFileHistory(ctx context.Context, filters map[string]interface{}) ([]*model.FileHistory, int64, error) {
 | |
| 	return s.repo.GetFileHistory(ctx, filters)
 | |
| }
 | |
| 
 | |
| // RDP file transfer methods implementation
 | |
| 
 | |
| // RDPReadDir reads directory contents for RDP session
 | |
| func (s *FileService) RDPReadDir(ctx context.Context, sessionId, dir string) ([]fs.FileInfo, error) {
 | |
| 	// Get session tunnel to access file transfer manager
 | |
| 	tunnel := tunneling.GetTunnelBySessionId(sessionId)
 | |
| 	if tunnel == nil {
 | |
| 		return nil, fmt.Errorf("session not found: %s", sessionId)
 | |
| 	}
 | |
| 
 | |
| 	// For RDP sessions, we need to check if drive is enabled
 | |
| 	// This would need to be implemented based on the actual session configuration
 | |
| 	// For now, return an error indicating RDP file operations need to be handled differently
 | |
| 	return nil, fmt.Errorf("RDP file operations should be handled through Guacamole protocol")
 | |
| }
 | |
| 
 | |
| // RDPMkdirAll creates directory for RDP session
 | |
| func (s *FileService) RDPMkdirAll(ctx context.Context, sessionId, dir string) error {
 | |
| 	tunnel := tunneling.GetTunnelBySessionId(sessionId)
 | |
| 	if tunnel == nil {
 | |
| 		return fmt.Errorf("session not found: %s", sessionId)
 | |
| 	}
 | |
| 
 | |
| 	return fmt.Errorf("RDP file operations should be handled through Guacamole protocol")
 | |
| }
 | |
| 
 | |
| // RDPUploadFile uploads file for RDP session
 | |
| func (s *FileService) RDPUploadFile(ctx context.Context, sessionId, filename string, content []byte) error {
 | |
| 	tunnel := tunneling.GetTunnelBySessionId(sessionId)
 | |
| 	if tunnel == nil {
 | |
| 		return fmt.Errorf("session not found: %s", sessionId)
 | |
| 	}
 | |
| 
 | |
| 	return fmt.Errorf("RDP file operations should be handled through Guacamole protocol")
 | |
| }
 | |
| 
 | |
| // RDPDownloadFile downloads file for RDP session
 | |
| func (s *FileService) RDPDownloadFile(ctx context.Context, sessionId, filename string) ([]byte, error) {
 | |
| 	tunnel := tunneling.GetTunnelBySessionId(sessionId)
 | |
| 	if tunnel == nil {
 | |
| 		return nil, fmt.Errorf("session not found: %s", sessionId)
 | |
| 	}
 | |
| 
 | |
| 	return nil, fmt.Errorf("RDP file operations should be handled through Guacamole protocol")
 | |
| }
 | |
| 
 | |
// sanitizeFilename validates a single client-supplied file name and returns it
// unchanged when it is acceptable.
//
// A name is rejected when it:
//   - is empty;
//   - contains "..", "/" or "\" (it must be one path component that cannot
//     escape the target directory);
//   - contains any ASCII control character (0x00-0x1F) or DEL (0x7F).
//
// The validated name is also handed back to callers as a download name, so
// it must be safe to display and embed as-is.
func sanitizeFilename(filename string) (string, error) {
	if filename == "" {
		return "", fmt.Errorf("invalid filename: empty name")
	}

	// Reject path traversal and anything that is not a single path component.
	if strings.Contains(filename, "..") || strings.ContainsAny(filename, `/\`) {
		return "", fmt.Errorf("invalid filename: path traversal detected")
	}

	// Reject NUL and every other C0 control character, plus DEL.
	isControl := func(r rune) bool { return r < 0x20 || r == 0x7f }
	if strings.IndexFunc(filename, isControl) >= 0 {
		return "", fmt.Errorf("invalid filename: control characters detected")
	}

	return filename, nil
}
 | 
