fix(backend): close monitor session

This commit is contained in:
pycook
2025-05-16 21:04:11 +08:00
parent 2fb76ecda6
commit 2412fc0b14
7 changed files with 498 additions and 21 deletions

View File

@@ -22,6 +22,7 @@ func SetupRouter(r *gin.Engine) {
c := controller.Controller{}
v1 := r.Group("/api/oneterm/v1", middleware.Error2RespMiddleware(), middleware.AuthMiddleware())
v1AuthAbandoned := r.Group("/api/oneterm/v1", middleware.Error2RespMiddleware())
{
account := v1.Group("account")
{
@@ -108,9 +109,12 @@ func SetupRouter(r *gin.Engine) {
config := v1.Group("config")
{
config.GET("", c.GetConfig)
config.POST("", c.PostConfig)
}
config2 := v1AuthAbandoned.Group("config")
{
config2.GET("", c.GetConfig)
}
history := v1.Group("history")
{

View File

@@ -126,9 +126,8 @@ func ConnectMonitor(ctx *gin.Context) {
}
})
if err = g.Wait(); err != nil {
logger.L().Error("monitor failed", zap.Error(err))
}
g.Wait()
logger.L().Info("monitor exit", zap.String("sessionId", sess.SessionId))
}
func ConnectClose(ctx *gin.Context) {

View File

@@ -111,6 +111,9 @@ func HandleError(ctx *gin.Context, sess *gsession.Session, err error, ws *websoc
if sess != nil && sess.IsGuacd() && ws != nil {
ws.WriteMessage(websocket.TextMessage, NewInstruction("error", lo.Ternary(ok, (ae).MessageBase64(ctx), err.Error()), cast.ToString(myErrors.ErrAdminClose)).Bytes())
} else if sess != nil {
if ctx.Query("is_monitor") == "true" {
return
}
WriteErrMsg(sess, lo.Ternary(ok, ae.MessageWithCtx(ctx), err.Error()))
}
}

View File

@@ -3,12 +3,15 @@ package guacd
import (
"bufio"
"fmt"
"io"
"net"
"os"
"strings"
"time"
"github.com/samber/lo"
"github.com/spf13/cast"
"go.uber.org/zap"
"github.com/veops/oneterm/internal/model"
"github.com/veops/oneterm/internal/tunneling"
@@ -23,6 +26,15 @@ const (
IGNORE_CERT = "true"
)
// File transfer parameters
const (
DRIVE_ENABLE = "enable-drive"
DRIVE_PATH = "drive-path"
DRIVE_CREATE_PATH = "create-drive-path"
DRIVE_DISABLE_UPLOAD = "disable-upload"
DRIVE_DISABLE_DOWNLOAD = "disable-download"
)
type Configuration struct {
Protocol string
Parameters map[string]string
@@ -35,13 +47,15 @@ func NewConfiguration() (config *Configuration) {
}
type Tunnel struct {
SessionId string
ConnectionId string
conn net.Conn
reader *bufio.Reader
writer *bufio.Writer
Config *Configuration
gw *tunneling.GatewayTunnel
SessionId string
ConnectionId string
conn net.Conn
reader *bufio.Reader
writer *bufio.Writer
Config *Configuration
gw *tunneling.GatewayTunnel
transferManager *FileTransferManager
drivePath string
}
func NewTunnel(connectionId, sessionId string, w, h, dpi int, protocol string, asset *model.Asset, account *model.Account, gateway *model.Gateway) (t *Tunnel, err error) {
@@ -61,10 +75,11 @@ func NewTunnel(connectionId, sessionId string, w, h, dpi int, protocol string, a
protocol, port := ss[0], ss[1]
cfg := model.GlobalConfig.Load()
t = &Tunnel{
conn: conn,
reader: bufio.NewReader(conn),
writer: bufio.NewWriter(conn),
ConnectionId: connectionId,
conn: conn,
reader: bufio.NewReader(conn),
writer: bufio.NewWriter(conn),
ConnectionId: connectionId,
transferManager: DefaultFileTransferManager,
Config: &Configuration{
Protocol: protocol,
Parameters: lo.TernaryF(
@@ -85,6 +100,12 @@ func NewTunnel(connectionId, sessionId string, w, h, dpi int, protocol string, a
"password": account.Password,
"disable-copy": cast.ToString(lo.Ternary(strings.Contains(protocol, "rdp"), !cfg.RdpConfig.Copy, !cfg.VncConfig.Copy)),
"disable-paste": cast.ToString(lo.Ternary(strings.Contains(protocol, "rdp"), !cfg.RdpConfig.Paste, !cfg.VncConfig.Paste)),
// Set file transfer related parameters from config
DRIVE_ENABLE: cast.ToString(lo.Ternary(strings.Contains(protocol, "rdp"), cfg.RdpConfig.EnableDrive, false)),
DRIVE_PATH: cast.ToString(lo.Ternary(strings.Contains(protocol, "rdp"), cfg.RdpConfig.DrivePath, "")),
DRIVE_CREATE_PATH: cast.ToString(lo.Ternary(strings.Contains(protocol, "rdp"), cfg.RdpConfig.CreateDrivePath, false)),
DRIVE_DISABLE_UPLOAD: cast.ToString(lo.Ternary(strings.Contains(protocol, "rdp"), cfg.RdpConfig.DisableUpload, false)),
DRIVE_DISABLE_DOWNLOAD: cast.ToString(lo.Ternary(strings.Contains(protocol, "rdp"), cfg.RdpConfig.DisableDownload, false)),
}
}, func() map[string]string {
return map[string]string{
@@ -109,6 +130,21 @@ func NewTunnel(connectionId, sessionId string, w, h, dpi int, protocol string, a
t.Config.Parameters["port"] = cast.ToString(t.gw.LocalPort)
}
// If RDP protocol and file transfer is enabled
if strings.Contains(protocol, "rdp") && t.Config.Parameters[DRIVE_ENABLE] == "true" {
// Get drive path
t.drivePath = t.Config.Parameters[DRIVE_PATH]
// Create drive path if needed
if t.Config.Parameters[DRIVE_CREATE_PATH] == "true" && t.drivePath != "" {
if err := os.MkdirAll(t.drivePath, 0755); err != nil {
logger.L().Error("Failed to create RDP drive path", zap.Error(err))
// Don't terminate the connection, just disable file transfer
t.drivePath = ""
}
}
}
err = t.handshake()
return
@@ -196,15 +232,28 @@ func (t *Tunnel) Read() (p []byte, err error) {
return
}
func (t *Tunnel) ReadInstruction() (instruction *Instruction, err error) {
func (t *Tunnel) ReadInstruction() (*Instruction, error) {
data, err := t.Read()
if err != nil {
return
return nil, err
}
instruction = (&Instruction{}).Parse(string(data))
instruction := (&Instruction{}).Parse(string(data))
return
// Check if this is a file transfer instruction
if isFileInstruction(instruction.Opcode) {
return t.HandleFileInstruction(instruction)
}
return instruction, nil
}
// isFileInstruction checks if the instruction is related to file transfer
func isFileInstruction(opcode string) bool {
return opcode == INSTRUCTION_FILE_UPLOAD ||
opcode == INSTRUCTION_FILE_DOWNLOAD ||
opcode == INSTRUCTION_FILE_DATA ||
opcode == INSTRUCTION_FILE_COMPLETE
}
func (t *Tunnel) assert(opcode string) (instruction *Instruction, err error) {
@@ -228,3 +277,115 @@ func (t *Tunnel) Disconnect() {
logger.L().Debug("client disconnect")
t.WriteInstruction(NewInstruction("disconnect"))
}
// HandleFileUpload handles file upload request
func (t *Tunnel) HandleFileUpload(filename string, size int64) (string, error) {
if t.drivePath == "" || t.Config.Parameters[DRIVE_DISABLE_UPLOAD] == "true" {
return "", fmt.Errorf("file upload is disabled")
}
transfer, err := t.transferManager.CreateUpload(filename, t.drivePath)
if err != nil {
return "", err
}
return transfer.ID, nil
}
// HandleFileDownload handles file download request
func (t *Tunnel) HandleFileDownload(filename string) (string, int64, error) {
if t.drivePath == "" || t.Config.Parameters[DRIVE_DISABLE_DOWNLOAD] == "true" {
return "", 0, fmt.Errorf("file download is disabled")
}
transfer, err := t.transferManager.CreateDownload(filename, t.drivePath)
if err != nil {
return "", 0, err
}
return transfer.ID, transfer.Size, nil
}
// WriteFileData writes data to an upload file
func (t *Tunnel) WriteFileData(transferId string, data []byte) (int, error) {
transfer := t.transferManager.GetTransfer(transferId)
if transfer == nil {
return 0, fmt.Errorf("transfer not found: %s", transferId)
}
return transfer.Write(data)
}
// ReadFileData reads data from a download file
func (t *Tunnel) ReadFileData(transferId string, buffer []byte) (int, error) {
transfer := t.transferManager.GetTransfer(transferId)
if transfer == nil {
return 0, fmt.Errorf("transfer not found: %s", transferId)
}
return transfer.Read(buffer)
}
// CloseFileTransfer closes a file transfer
func (t *Tunnel) CloseFileTransfer(transferId string) error {
transfer := t.transferManager.GetTransfer(transferId)
if transfer == nil {
return fmt.Errorf("transfer not found: %s", transferId)
}
err := transfer.Close()
t.transferManager.RemoveTransfer(transferId)
return err
}
// SendDownloadData reads data from a file and sends to client
func (t *Tunnel) SendDownloadData(transferId string) error {
transfer := t.transferManager.GetTransfer(transferId)
if transfer == nil {
return fmt.Errorf("transfer not found: %s", transferId)
}
if transfer.IsUpload {
return fmt.Errorf("cannot download from upload transfer")
}
// Use 4KB buffer for file data
buffer := make([]byte, 4096)
for !transfer.Completed {
n, err := transfer.Read(buffer)
if err != nil && err != io.EOF {
return err
}
if n > 0 {
// Send file data to client
dataInstr := NewInstruction(INSTRUCTION_FILE_DATA, transferId, string(buffer[:n]))
if _, err := t.WriteInstruction(dataInstr); err != nil {
return err
}
// Read ACK from client
ack, err := t.ReadInstruction()
if err != nil {
return err
}
if ack.Opcode != INSTRUCTION_FILE_ACK {
return fmt.Errorf("expected ACK instruction, got: %s", ack.Opcode)
}
}
if err == io.EOF || transfer.Completed {
break
}
}
// Send complete instruction
completeInstr := NewInstruction(INSTRUCTION_FILE_COMPLETE, transferId)
if _, err := t.WriteInstruction(completeInstr); err != nil {
return err
}
return t.CloseFileTransfer(transferId)
}

View File

@@ -0,0 +1,110 @@
package guacd
import (
"fmt"
"strconv"
)
// File transfer instruction constants
const (
INSTRUCTION_FILE_UPLOAD = "file-upload"
INSTRUCTION_FILE_DOWNLOAD = "file-download"
INSTRUCTION_FILE_DATA = "file-data"
INSTRUCTION_FILE_ACK = "file-ack"
INSTRUCTION_FILE_COMPLETE = "file-complete"
INSTRUCTION_FILE_ERROR = "file-error"
)
// RDP file transfer related parameters
const (
RDP_ENABLE_DRIVE = "enable-drive"
RDP_DRIVE_PATH = "drive-path"
RDP_DRIVE_NAME = "drive-name"
RDP_DISABLE_DOWNLOAD = "disable-download"
RDP_DISABLE_UPLOAD = "disable-upload"
RDP_CREATE_DRIVE_PATH = "create-drive-path"
)
// HandleFileInstruction processes file transfer related instructions
func (t *Tunnel) HandleFileInstruction(instruction *Instruction) (*Instruction, error) {
switch instruction.Opcode {
case INSTRUCTION_FILE_UPLOAD:
if len(instruction.Args) < 2 {
return NewInstruction(INSTRUCTION_FILE_ERROR, "Invalid upload request"), nil
}
filename := instruction.Args[0]
size, err := strconv.ParseInt(instruction.Args[1], 10, 64)
if err != nil {
return NewInstruction(INSTRUCTION_FILE_ERROR, "Invalid file size"), nil
}
transferId, err := t.HandleFileUpload(filename, size)
if err != nil {
return NewInstruction(INSTRUCTION_FILE_ERROR, err.Error()), nil
}
return NewInstruction(INSTRUCTION_FILE_ACK, transferId), nil
case INSTRUCTION_FILE_DOWNLOAD:
if len(instruction.Args) < 1 {
return NewInstruction(INSTRUCTION_FILE_ERROR, "Invalid download request"), nil
}
filename := instruction.Args[0]
transferId, size, err := t.HandleFileDownload(filename)
if err != nil {
return NewInstruction(INSTRUCTION_FILE_ERROR, err.Error()), nil
}
// Send acknowledgement with transfer ID and file size
ackInstr := NewInstruction(INSTRUCTION_FILE_ACK, transferId, strconv.FormatInt(size, 10))
if _, err := t.WriteInstruction(ackInstr); err != nil {
return NewInstruction(INSTRUCTION_FILE_ERROR, fmt.Sprintf("Failed to send ACK: %s", err.Error())), nil
}
// Start file download process in a new goroutine
go func() {
if err := t.SendDownloadData(transferId); err != nil {
// Log error, but we can't send error instruction here as it would interfere with protocol
fmt.Printf("Download failed: %s\n", err.Error())
}
}()
// Return nil to avoid sending another response
return nil, nil
case INSTRUCTION_FILE_DATA:
if len(instruction.Args) < 2 {
return NewInstruction(INSTRUCTION_FILE_ERROR, "Invalid data request"), nil
}
transferId := instruction.Args[0]
data := []byte(instruction.Args[1])
// If this is an upload, write the data
n, err := t.WriteFileData(transferId, data)
if err != nil {
return NewInstruction(INSTRUCTION_FILE_ERROR, fmt.Sprintf("Write error: %s", err.Error())), nil
}
return NewInstruction(INSTRUCTION_FILE_ACK, transferId, strconv.Itoa(n)), nil
case INSTRUCTION_FILE_COMPLETE:
if len(instruction.Args) < 1 {
return NewInstruction(INSTRUCTION_FILE_ERROR, "Invalid complete request"), nil
}
transferId := instruction.Args[0]
err := t.CloseFileTransfer(transferId)
if err != nil {
return NewInstruction(INSTRUCTION_FILE_ERROR, fmt.Sprintf("Failed to complete transfer: %s", err.Error())), nil
}
return NewInstruction(INSTRUCTION_FILE_ACK, transferId, "complete"), nil
default:
return nil, fmt.Errorf("Unknown file instruction: %s", instruction.Opcode)
}
}

View File

@@ -0,0 +1,195 @@
package guacd
import (
"fmt"
"io"
"os"
"path/filepath"
"sync"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"github.com/veops/oneterm/pkg/logger"
)
// FileTransferManager manages RDP file transfers
type FileTransferManager struct {
transfers map[string]*FileTransfer
mutex sync.Mutex
}
// FileTransfer represents a single file transfer
type FileTransfer struct {
ID string
Filename string
Path string
Size int64
Offset int64
Created time.Time
Completed bool
IsUpload bool
file *os.File
mutex sync.Mutex
}
// Global file transfer manager instance
var (
DefaultFileTransferManager = NewFileTransferManager()
)
// NewFileTransferManager creates a new file transfer manager
func NewFileTransferManager() *FileTransferManager {
return &FileTransferManager{
transfers: make(map[string]*FileTransfer),
}
}
// CreateUpload creates an upload file transfer
func (m *FileTransferManager) CreateUpload(filename, drivePath string) (*FileTransfer, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
id := uuid.New().String()
fullPath := filepath.Join(drivePath, filename)
// Ensure directory exists
dir := filepath.Dir(fullPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create directory: %w", err)
}
// Create file
file, err := os.Create(fullPath)
if err != nil {
return nil, fmt.Errorf("failed to create file: %w", err)
}
transfer := &FileTransfer{
ID: id,
Filename: filename,
Path: fullPath,
Created: time.Now(),
IsUpload: true,
file: file,
}
m.transfers[id] = transfer
logger.L().Debug("Created file upload", zap.String("id", id), zap.String("filename", filename))
return transfer, nil
}
// CreateDownload creates a download file transfer
func (m *FileTransferManager) CreateDownload(filename, drivePath string) (*FileTransfer, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
id := uuid.New().String()
fullPath := filepath.Join(drivePath, filename)
// Open file
file, err := os.Open(fullPath)
if err != nil {
return nil, fmt.Errorf("failed to open file: %w", err)
}
// Get file info
stat, err := file.Stat()
if err != nil {
file.Close()
return nil, fmt.Errorf("failed to get file info: %w", err)
}
transfer := &FileTransfer{
ID: id,
Filename: filename,
Path: fullPath,
Size: stat.Size(),
Created: time.Now(),
IsUpload: false,
file: file,
}
m.transfers[id] = transfer
logger.L().Debug("Created file download", zap.String("id", id), zap.String("filename", filename))
return transfer, nil
}
// GetTransfer gets a transfer by ID
func (m *FileTransferManager) GetTransfer(id string) *FileTransfer {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.transfers[id]
}
// RemoveTransfer removes a transfer by ID
func (m *FileTransferManager) RemoveTransfer(id string) {
m.mutex.Lock()
defer m.mutex.Unlock()
if transfer, exists := m.transfers[id]; exists {
if transfer.file != nil {
transfer.file.Close()
}
delete(m.transfers, id)
}
}
// Write writes data to an upload file
func (t *FileTransfer) Write(data []byte) (int, error) {
t.mutex.Lock()
defer t.mutex.Unlock()
if t.Completed {
return 0, fmt.Errorf("transfer already completed")
}
if !t.IsUpload {
return 0, fmt.Errorf("cannot write to download transfer")
}
n, err := t.file.Write(data)
if err != nil {
return n, err
}
t.Offset += int64(n)
return n, nil
}
// Read reads data from a download file
func (t *FileTransfer) Read(p []byte) (int, error) {
t.mutex.Lock()
defer t.mutex.Unlock()
if t.IsUpload {
return 0, fmt.Errorf("cannot read from upload transfer")
}
n, err := t.file.Read(p)
if err != nil {
if err == io.EOF {
t.Completed = true
}
return n, err
}
t.Offset += int64(n)
if t.Offset >= t.Size {
t.Completed = true
}
return n, nil
}
// Close closes the file transfer
func (t *FileTransfer) Close() error {
t.mutex.Lock()
defer t.mutex.Unlock()
t.Completed = true
if t.file != nil {
return t.file.Close()
}
return nil
}

View File

@@ -16,8 +16,13 @@ type SshConfig struct {
Paste bool `json:"paste" gorm:"column:paste"`
}
type RdpConfig struct {
Copy bool `json:"copy" gorm:"column:copy"`
Paste bool `json:"paste" gorm:"column:paste"`
Copy bool `json:"copy" gorm:"column:copy"`
Paste bool `json:"paste" gorm:"column:paste"`
EnableDrive bool `json:"enable_drive" gorm:"column:enable_drive"`
DrivePath string `json:"drive_path" gorm:"column:drive_path"`
CreateDrivePath bool `json:"create_drive_path" gorm:"column:create_drive_path"`
DisableUpload bool `json:"disable_upload" gorm:"column:disable_upload"`
DisableDownload bool `json:"disable_download" gorm:"column:disable_download"`
}
type VncConfig struct {
Copy bool `json:"copy" gorm:"column:copy"`