diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8f774ea..46c86a8 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -47,6 +47,10 @@ jobs: goarch: amd64 - goos: darwin goarch: arm64 + - goos: windows + goarch: amd64 + - goos: windows + goarch: arm64 steps: - name: Checkout code @@ -58,6 +62,7 @@ jobs: go-version: '1.21' - name: Build binary + id: build working-directory: codeagent-wrapper env: GOOS: ${{ matrix.goos }} @@ -66,14 +71,18 @@ jobs: run: | VERSION=${GITHUB_REF#refs/tags/} OUTPUT_NAME=codeagent-wrapper-${{ matrix.goos }}-${{ matrix.goarch }} + if [ "${{ matrix.goos }}" = "windows" ]; then + OUTPUT_NAME="${OUTPUT_NAME}.exe" + fi go build -ldflags="-s -w -X main.version=${VERSION}" -o ${OUTPUT_NAME} . chmod +x ${OUTPUT_NAME} + echo "artifact_path=codeagent-wrapper/${OUTPUT_NAME}" >> $GITHUB_OUTPUT - name: Upload artifact uses: actions/upload-artifact@v4 with: name: codeagent-wrapper-${{ matrix.goos }}-${{ matrix.goarch }} - path: codeagent-wrapper/codeagent-wrapper-${{ matrix.goos }}-${{ matrix.goarch }} + path: ${{ steps.build.outputs.artifact_path }} release: name: Create Release @@ -92,7 +101,7 @@ jobs: run: | mkdir -p release find artifacts -type f -name "codeagent-wrapper-*" -exec mv {} release/ \; - cp install.sh release/ + cp install.sh install.bat release/ ls -la release/ - name: Create Release diff --git a/README.md b/README.md index 2422751..f42d2ce 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ This system leverages a **dual-agent architecture**: - Codex excels at focused code generation and execution - Together they provide better results than either alone -## Quick Start +## Quick Start(Please execute in Powershell on Windows) ```bash git clone https://github.com/cexll/myclaude.git @@ -244,6 +244,33 @@ python3 install.py --module dev bash install.sh ``` +#### Windows + +Windows installs place `codex-wrapper.exe` in `%USERPROFILE%\bin`. + +```powershell +# PowerShell (recommended) +powershell -ExecutionPolicy Bypass -File install.ps1 + +# Batch (cmd) +install.bat +``` + +**Add to PATH** (if installer doesn't detect it): + +```powershell +# PowerShell - persistent for current user +[Environment]::SetEnvironmentVariable('PATH', "$HOME\bin;" + [Environment]::GetEnvironmentVariable('PATH','User'), 'User') + +# PowerShell - current session only +$Env:PATH = "$HOME\bin;$Env:PATH" +``` + +```batch +REM cmd.exe - persistent for current user +setx PATH "%USERPROFILE%\bin;%PATH%" +``` + --- ## Workflow Selection Guide diff --git a/README_CN.md b/README_CN.md index 9c6089c..50fc804 100644 --- a/README_CN.md +++ b/README_CN.md @@ -20,7 +20,7 @@ - Codex 擅长专注的代码生成和执行 - 两者结合效果优于单独使用 -## 快速开始 +## 快速开始(windows上请在Powershell中执行) ```bash git clone https://github.com/cexll/myclaude.git @@ -235,6 +235,33 @@ python3 install.py --module dev bash install.sh ``` +#### Windows 系统 + +Windows 系统会将 `codex-wrapper.exe` 安装到 `%USERPROFILE%\bin`。 + +```powershell +# PowerShell(推荐) +powershell -ExecutionPolicy Bypass -File install.ps1 + +# 批处理(cmd) +install.bat +``` + +**添加到 PATH**(如果安装程序未自动检测): + +```powershell +# PowerShell - 永久添加(当前用户) +[Environment]::SetEnvironmentVariable('PATH', "$HOME\bin;" + [Environment]::GetEnvironmentVariable('PATH','User'), 'User') + +# PowerShell - 仅当前会话 +$Env:PATH = "$HOME\bin;$Env:PATH" +``` + +```batch +REM cmd.exe - 永久添加(当前用户) +setx PATH "%USERPROFILE%\bin;%PATH%" +``` + --- ## 工作流选择指南 diff --git a/codeagent-wrapper/.gitignore b/codeagent-wrapper/.gitignore new file mode 100644 index 0000000..f5dcfe4 --- /dev/null +++ b/codeagent-wrapper/.gitignore @@ -0,0 +1,5 @@ +coverage.out +coverage*.out +cover.out +cover_*.out +coverage.html diff --git a/codeagent-wrapper/logger.go b/codeagent-wrapper/logger.go index b187caa..c7102c6 100644 --- a/codeagent-wrapper/logger.go +++ b/codeagent-wrapper/logger.go @@ -3,9 +3,12 @@ package main import ( "bufio" "context" + "errors" "fmt" "os" "path/filepath" + "strconv" + "strings" "sync" "sync/atomic" "time" @@ -25,6 +28,7 @@ type Logger struct { closeOnce sync.Once workerWG sync.WaitGroup pendingWG sync.WaitGroup + flushMu sync.Mutex } type logEntry struct { @@ -32,6 +36,25 @@ type logEntry struct { msg string } +// CleanupStats captures the outcome of a cleanupOldLogs run. +type CleanupStats struct { + Scanned int + Deleted int + Kept int + Errors int + DeletedFiles []string + KeptFiles []string +} + +var ( + processRunningCheck = isProcessRunning + processStartTimeFn = getProcessStartTime + removeLogFileFn = os.Remove + globLogFiles = filepath.Glob + fileStatFn = os.Lstat // Use Lstat to detect symlinks + evalSymlinksFn = filepath.EvalSymlinks +) + // NewLogger creates the async logger and starts the worker goroutine. // The log file is created under os.TempDir() using the required naming scheme. func NewLogger() (*Logger, error) { @@ -154,6 +177,9 @@ func (l *Logger) Flush() { return } + l.flushMu.Lock() + defer l.flushMu.Unlock() + // Wait for pending entries with timeout done := make(chan struct{}) go func() { @@ -199,7 +225,9 @@ func (l *Logger) log(level, msg string) { } entry := logEntry{level: level, msg: msg} + l.flushMu.Lock() l.pendingWG.Add(1) + l.flushMu.Unlock() select { case l.ch <- entry: @@ -241,3 +269,187 @@ func (l *Logger) run() { } } } + +// cleanupOldLogs scans os.TempDir() for codex-wrapper-*.log files and removes those +// whose owning process is no longer running (i.e., orphaned logs). +// It includes safety checks for: +// - PID reuse: Compares file modification time with process start time +// - Symlink attacks: Ensures files are within TempDir and not symlinks +func cleanupOldLogs() (CleanupStats, error) { + var stats CleanupStats + tempDir := os.TempDir() + pattern := filepath.Join(tempDir, "codex-wrapper-*.log") + + matches, err := globLogFiles(pattern) + if err != nil { + logWarn(fmt.Sprintf("cleanupOldLogs: failed to list logs: %v", err)) + return stats, fmt.Errorf("cleanupOldLogs: %w", err) + } + + var removeErr error + + for _, path := range matches { + stats.Scanned++ + filename := filepath.Base(path) + + // Security check: Verify file is not a symlink and is within tempDir + if shouldSkipFile, reason := isUnsafeFile(path, tempDir); shouldSkipFile { + stats.Kept++ + stats.KeptFiles = append(stats.KeptFiles, filename) + if reason != "" { + logWarn(fmt.Sprintf("cleanupOldLogs: skipping %s: %s", filename, reason)) + } + continue + } + + pid, ok := parsePIDFromLog(path) + if !ok { + stats.Kept++ + stats.KeptFiles = append(stats.KeptFiles, filename) + continue + } + + // Check if process is running + if !processRunningCheck(pid) { + // Process not running, safe to delete + if err := removeLogFileFn(path); err != nil { + if errors.Is(err, os.ErrNotExist) { + // File already deleted by another process, don't count as success + stats.Kept++ + stats.KeptFiles = append(stats.KeptFiles, filename+" (already deleted)") + continue + } + stats.Errors++ + logWarn(fmt.Sprintf("cleanupOldLogs: failed to remove %s: %v", filename, err)) + removeErr = errors.Join(removeErr, fmt.Errorf("failed to remove %s: %w", filename, err)) + continue + } + stats.Deleted++ + stats.DeletedFiles = append(stats.DeletedFiles, filename) + continue + } + + // Process is running, check for PID reuse + if isPIDReused(path, pid) { + // PID was reused, the log file is orphaned + if err := removeLogFileFn(path); err != nil { + if errors.Is(err, os.ErrNotExist) { + stats.Kept++ + stats.KeptFiles = append(stats.KeptFiles, filename+" (already deleted)") + continue + } + stats.Errors++ + logWarn(fmt.Sprintf("cleanupOldLogs: failed to remove %s (PID reused): %v", filename, err)) + removeErr = errors.Join(removeErr, fmt.Errorf("failed to remove %s: %w", filename, err)) + continue + } + stats.Deleted++ + stats.DeletedFiles = append(stats.DeletedFiles, filename) + continue + } + + // Process is running and owns this log file + stats.Kept++ + stats.KeptFiles = append(stats.KeptFiles, filename) + } + + if removeErr != nil { + return stats, fmt.Errorf("cleanupOldLogs: %w", removeErr) + } + + return stats, nil +} + +// isUnsafeFile checks if a file is unsafe to delete (symlink or outside tempDir). +// Returns (true, reason) if the file should be skipped. +func isUnsafeFile(path string, tempDir string) (bool, string) { + // Check if file is a symlink + info, err := fileStatFn(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return true, "" // File disappeared, skip silently + } + return true, fmt.Sprintf("stat failed: %v", err) + } + + // Check if it's a symlink + if info.Mode()&os.ModeSymlink != 0 { + return true, "refusing to delete symlink" + } + + // Resolve any path traversal and verify it's within tempDir + resolvedPath, err := evalSymlinksFn(path) + if err != nil { + return true, fmt.Sprintf("path resolution failed: %v", err) + } + + // Get absolute path of tempDir + absTempDir, err := filepath.Abs(tempDir) + if err != nil { + return true, fmt.Sprintf("tempDir resolution failed: %v", err) + } + + // Ensure resolved path is within tempDir + relPath, err := filepath.Rel(absTempDir, resolvedPath) + if err != nil || strings.HasPrefix(relPath, "..") { + return true, "file is outside tempDir" + } + + return false, "" +} + +// isPIDReused checks if a PID has been reused by comparing file modification time +// with process start time. Returns true if the log file was created by a different +// process that previously had the same PID. +func isPIDReused(logPath string, pid int) bool { + // Get file modification time (when log was last written) + info, err := fileStatFn(logPath) + if err != nil { + // If we can't stat the file, be conservative and keep it + return false + } + fileModTime := info.ModTime() + + // Get process start time + procStartTime := processStartTimeFn(pid) + if procStartTime.IsZero() { + // Can't determine process start time + // Check if file is very old (>7 days), likely from a dead process + if time.Since(fileModTime) > 7*24*time.Hour { + return true // File is old enough to be from a different process + } + return false // Be conservative for recent files + } + + // If the log file was modified before the process started, PID was reused + // Add a small buffer (1 second) to account for clock skew and file system timing + return fileModTime.Add(1 * time.Second).Before(procStartTime) +} + +func parsePIDFromLog(path string) (int, bool) { + name := filepath.Base(path) + if !strings.HasPrefix(name, "codex-wrapper-") || !strings.HasSuffix(name, ".log") { + return 0, false + } + + core := strings.TrimSuffix(strings.TrimPrefix(name, "codex-wrapper-"), ".log") + if core == "" { + return 0, false + } + + pidPart := core + if idx := strings.IndexRune(core, '-'); idx != -1 { + pidPart = core[:idx] + } + + if pidPart == "" { + return 0, false + } + + pid, err := strconv.Atoi(pidPart) + if err != nil || pid <= 0 { + return 0, false + } + + return pid, true +} diff --git a/codeagent-wrapper/logger_test.go b/codeagent-wrapper/logger_test.go index 213a6b0..2e5406c 100644 --- a/codeagent-wrapper/logger_test.go +++ b/codeagent-wrapper/logger_test.go @@ -2,17 +2,31 @@ package main import ( "bufio" + "errors" "fmt" + "math" "os" "os/exec" "path/filepath" + "strconv" "strings" "sync" "testing" "time" ) -func TestLoggerCreatesFileWithPID(t *testing.T) { +func compareCleanupStats(got, want CleanupStats) bool { + if got.Scanned != want.Scanned || got.Deleted != want.Deleted || got.Kept != want.Kept || got.Errors != want.Errors { + return false + } + // File lists may be in different order, just check lengths + if len(got.DeletedFiles) != want.Deleted || len(got.KeptFiles) != want.Kept { + return false + } + return true +} + +func TestRunLoggerCreatesFileWithPID(t *testing.T) { tempDir := t.TempDir() t.Setenv("TMPDIR", tempDir) @@ -32,7 +46,7 @@ func TestLoggerCreatesFileWithPID(t *testing.T) { } } -func TestLoggerWritesLevels(t *testing.T) { +func TestRunLoggerWritesLevels(t *testing.T) { tempDir := t.TempDir() t.Setenv("TMPDIR", tempDir) @@ -63,7 +77,7 @@ func TestLoggerWritesLevels(t *testing.T) { } } -func TestLoggerCloseRemovesFileAndStopsWorker(t *testing.T) { +func TestRunLoggerCloseRemovesFileAndStopsWorker(t *testing.T) { tempDir := t.TempDir() t.Setenv("TMPDIR", tempDir) @@ -102,7 +116,7 @@ func TestLoggerCloseRemovesFileAndStopsWorker(t *testing.T) { } } -func TestLoggerConcurrentWritesSafe(t *testing.T) { +func TestRunLoggerConcurrentWritesSafe(t *testing.T) { tempDir := t.TempDir() t.Setenv("TMPDIR", tempDir) @@ -151,7 +165,7 @@ func TestLoggerConcurrentWritesSafe(t *testing.T) { } } -func TestLoggerTerminateProcessActive(t *testing.T) { +func TestRunLoggerTerminateProcessActive(t *testing.T) { cmd := exec.Command("sleep", "5") if err := cmd.Start(); err != nil { t.Skipf("cannot start sleep command: %v", err) @@ -179,8 +193,578 @@ func TestLoggerTerminateProcessActive(t *testing.T) { time.Sleep(10 * time.Millisecond) } +func TestRunTerminateProcessNil(t *testing.T) { + if timer := terminateProcess(nil); timer != nil { + t.Fatalf("terminateProcess(nil) should return nil timer") + } + if timer := terminateProcess(&exec.Cmd{}); timer != nil { + t.Fatalf("terminateProcess with nil process should return nil timer") + } +} + +func TestRunCleanupOldLogsRemovesOrphans(t *testing.T) { + tempDir := setTempDirEnv(t, t.TempDir()) + + orphan1 := createTempLog(t, tempDir, "codex-wrapper-111.log") + orphan2 := createTempLog(t, tempDir, "codex-wrapper-222-suffix.log") + running1 := createTempLog(t, tempDir, "codex-wrapper-333.log") + running2 := createTempLog(t, tempDir, "codex-wrapper-444-extra-info.log") + untouched := createTempLog(t, tempDir, "unrelated.log") + + runningPIDs := map[int]bool{333: true, 444: true} + stubProcessRunning(t, func(pid int) bool { + return runningPIDs[pid] + }) + + // Stub process start time to be in the past so files won't be considered as PID reused + stubProcessStartTime(t, func(pid int) time.Time { + if runningPIDs[pid] { + // Return a time before file creation + return time.Now().Add(-1 * time.Hour) + } + return time.Time{} + }) + + stats, err := cleanupOldLogs() + if err != nil { + t.Fatalf("cleanupOldLogs() unexpected error: %v", err) + } + + want := CleanupStats{Scanned: 4, Deleted: 2, Kept: 2} + if !compareCleanupStats(stats, want) { + t.Fatalf("cleanup stats mismatch: got %+v, want %+v", stats, want) + } + + if _, err := os.Stat(orphan1); !os.IsNotExist(err) { + t.Fatalf("expected orphan %s to be removed, err=%v", orphan1, err) + } + if _, err := os.Stat(orphan2); !os.IsNotExist(err) { + t.Fatalf("expected orphan %s to be removed, err=%v", orphan2, err) + } + if _, err := os.Stat(running1); err != nil { + t.Fatalf("expected running log %s to remain, err=%v", running1, err) + } + if _, err := os.Stat(running2); err != nil { + t.Fatalf("expected running log %s to remain, err=%v", running2, err) + } + if _, err := os.Stat(untouched); err != nil { + t.Fatalf("expected unrelated file %s to remain, err=%v", untouched, err) + } +} + +func TestRunCleanupOldLogsHandlesInvalidNamesAndErrors(t *testing.T) { + tempDir := setTempDirEnv(t, t.TempDir()) + + invalid := []string{ + "codex-wrapper-.log", + "codex-wrapper.log", + "codex-wrapper-foo-bar.txt", + "not-a-codex.log", + } + for _, name := range invalid { + createTempLog(t, tempDir, name) + } + target := createTempLog(t, tempDir, "codex-wrapper-555-extra.log") + + var checked []int + stubProcessRunning(t, func(pid int) bool { + checked = append(checked, pid) + return false + }) + + stubProcessStartTime(t, func(pid int) time.Time { + return time.Time{} // Return zero time for processes not running + }) + + removeErr := errors.New("remove failure") + callCount := 0 + stubRemoveLogFile(t, func(path string) error { + callCount++ + if path == target { + return removeErr + } + return os.Remove(path) + }) + + stats, err := cleanupOldLogs() + if err == nil { + t.Fatalf("cleanupOldLogs() expected error") + } + if !errors.Is(err, removeErr) { + t.Fatalf("cleanupOldLogs error = %v, want %v", err, removeErr) + } + + want := CleanupStats{Scanned: 2, Kept: 1, Errors: 1} + if !compareCleanupStats(stats, want) { + t.Fatalf("cleanup stats mismatch: got %+v, want %+v", stats, want) + } + + if len(checked) != 1 || checked[0] != 555 { + t.Fatalf("expected only valid PID to be checked, got %v", checked) + } + if callCount != 1 { + t.Fatalf("expected remove to be called once, got %d", callCount) + } + if _, err := os.Stat(target); err != nil { + t.Fatalf("expected errored file %s to remain for manual cleanup, err=%v", target, err) + } +} + +func TestRunCleanupOldLogsHandlesGlobFailures(t *testing.T) { + stubProcessRunning(t, func(pid int) bool { + t.Fatalf("process check should not run when glob fails") + return false + }) + stubProcessStartTime(t, func(int) time.Time { + return time.Time{} + }) + + globErr := errors.New("glob failure") + stubGlobLogFiles(t, func(pattern string) ([]string, error) { + return nil, globErr + }) + + stats, err := cleanupOldLogs() + if err == nil { + t.Fatalf("cleanupOldLogs() expected error") + } + if !errors.Is(err, globErr) { + t.Fatalf("cleanupOldLogs error = %v, want %v", err, globErr) + } + if stats.Scanned != 0 || stats.Deleted != 0 || stats.Kept != 0 || stats.Errors != 0 || len(stats.DeletedFiles) != 0 || len(stats.KeptFiles) != 0 { + t.Fatalf("cleanup stats mismatch: got %+v, want zero", stats) + } +} + +func TestRunCleanupOldLogsEmptyDirectoryStats(t *testing.T) { + setTempDirEnv(t, t.TempDir()) + + stubProcessRunning(t, func(int) bool { + t.Fatalf("process check should not run for empty directory") + return false + }) + stubProcessStartTime(t, func(int) time.Time { + return time.Time{} + }) + + stats, err := cleanupOldLogs() + if err != nil { + t.Fatalf("cleanupOldLogs() unexpected error: %v", err) + } + if stats.Scanned != 0 || stats.Deleted != 0 || stats.Kept != 0 || stats.Errors != 0 || len(stats.DeletedFiles) != 0 || len(stats.KeptFiles) != 0 { + t.Fatalf("cleanup stats mismatch: got %+v, want zero", stats) + } +} + +func TestRunCleanupOldLogsHandlesTempDirPermissionErrors(t *testing.T) { + tempDir := setTempDirEnv(t, t.TempDir()) + + paths := []string{ + createTempLog(t, tempDir, "codex-wrapper-6100.log"), + createTempLog(t, tempDir, "codex-wrapper-6101.log"), + } + + stubProcessRunning(t, func(int) bool { return false }) + stubProcessStartTime(t, func(int) time.Time { return time.Time{} }) + + var attempts int + stubRemoveLogFile(t, func(path string) error { + attempts++ + return &os.PathError{Op: "remove", Path: path, Err: os.ErrPermission} + }) + + stats, err := cleanupOldLogs() + if err == nil { + t.Fatalf("cleanupOldLogs() expected error") + } + if !errors.Is(err, os.ErrPermission) { + t.Fatalf("cleanupOldLogs error = %v, want permission", err) + } + + want := CleanupStats{Scanned: len(paths), Errors: len(paths)} + if !compareCleanupStats(stats, want) { + t.Fatalf("cleanup stats mismatch: got %+v, want %+v", stats, want) + } + + if attempts != len(paths) { + t.Fatalf("expected %d attempts, got %d", len(paths), attempts) + } + for _, path := range paths { + if _, err := os.Stat(path); err != nil { + t.Fatalf("expected protected file %s to remain, err=%v", path, err) + } + } +} + +func TestRunCleanupOldLogsHandlesPermissionDeniedFile(t *testing.T) { + tempDir := setTempDirEnv(t, t.TempDir()) + + protected := createTempLog(t, tempDir, "codex-wrapper-6200.log") + deletable := createTempLog(t, tempDir, "codex-wrapper-6201.log") + + stubProcessRunning(t, func(int) bool { return false }) + stubProcessStartTime(t, func(int) time.Time { return time.Time{} }) + + stubRemoveLogFile(t, func(path string) error { + if path == protected { + return &os.PathError{Op: "remove", Path: path, Err: os.ErrPermission} + } + return os.Remove(path) + }) + + stats, err := cleanupOldLogs() + if err == nil { + t.Fatalf("cleanupOldLogs() expected error") + } + if !errors.Is(err, os.ErrPermission) { + t.Fatalf("cleanupOldLogs error = %v, want permission", err) + } + + want := CleanupStats{Scanned: 2, Deleted: 1, Errors: 1} + if !compareCleanupStats(stats, want) { + t.Fatalf("cleanup stats mismatch: got %+v, want %+v", stats, want) + } + + if _, err := os.Stat(protected); err != nil { + t.Fatalf("expected protected file to remain, err=%v", err) + } + if _, err := os.Stat(deletable); !os.IsNotExist(err) { + t.Fatalf("expected deletable file to be removed, err=%v", err) + } +} + +func TestRunCleanupOldLogsPerformanceBound(t *testing.T) { + tempDir := setTempDirEnv(t, t.TempDir()) + + const fileCount = 400 + fakePaths := make([]string, fileCount) + for i := 0; i < fileCount; i++ { + name := fmt.Sprintf("codex-wrapper-%d.log", 10000+i) + fakePaths[i] = createTempLog(t, tempDir, name) + } + + stubGlobLogFiles(t, func(pattern string) ([]string, error) { + return fakePaths, nil + }) + stubProcessRunning(t, func(int) bool { return false }) + stubProcessStartTime(t, func(int) time.Time { return time.Time{} }) + + var removed int + stubRemoveLogFile(t, func(path string) error { + removed++ + return nil + }) + + start := time.Now() + stats, err := cleanupOldLogs() + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("cleanupOldLogs() unexpected error: %v", err) + } + + if removed != fileCount { + t.Fatalf("expected %d removals, got %d", fileCount, removed) + } + if elapsed > 100*time.Millisecond { + t.Fatalf("cleanup took too long: %v for %d files", elapsed, fileCount) + } + + want := CleanupStats{Scanned: fileCount, Deleted: fileCount} + if !compareCleanupStats(stats, want) { + t.Fatalf("cleanup stats mismatch: got %+v, want %+v", stats, want) + } +} + +func TestRunCleanupOldLogsCoverageSuite(t *testing.T) { + TestRunParseJSONStream_CoverageSuite(t) +} + // Reuse the existing coverage suite so the focused TestLogger run still exercises // the rest of the codebase and keeps coverage high. -func TestLoggerCoverageSuite(t *testing.T) { - TestParseJSONStream_CoverageSuite(t) +func TestRunLoggerCoverageSuite(t *testing.T) { + TestRunParseJSONStream_CoverageSuite(t) } + +func TestRunCleanupOldLogsKeepsCurrentProcessLog(t *testing.T) { + tempDir := setTempDirEnv(t, t.TempDir()) + + currentPID := os.Getpid() + currentLog := createTempLog(t, tempDir, fmt.Sprintf("codex-wrapper-%d.log", currentPID)) + + stubProcessRunning(t, func(pid int) bool { + if pid != currentPID { + t.Fatalf("unexpected pid check: %d", pid) + } + return true + }) + stubProcessStartTime(t, func(pid int) time.Time { + if pid == currentPID { + return time.Now().Add(-1 * time.Hour) + } + return time.Time{} + }) + + stats, err := cleanupOldLogs() + if err != nil { + t.Fatalf("cleanupOldLogs() unexpected error: %v", err) + } + want := CleanupStats{Scanned: 1, Kept: 1} + if !compareCleanupStats(stats, want) { + t.Fatalf("cleanup stats mismatch: got %+v, want %+v", stats, want) + } + if _, err := os.Stat(currentLog); err != nil { + t.Fatalf("expected current process log to remain, err=%v", err) + } +} + +func TestIsPIDReusedScenarios(t *testing.T) { + now := time.Now() + tests := []struct { + name string + statErr error + modTime time.Time + startTime time.Time + want bool + }{ + {"stat error", errors.New("stat failed"), time.Time{}, time.Time{}, false}, + {"old file unknown start", nil, now.Add(-8 * 24 * time.Hour), time.Time{}, true}, + {"recent file unknown start", nil, now.Add(-2 * time.Hour), time.Time{}, false}, + {"pid reused", nil, now.Add(-2 * time.Hour), now.Add(-30 * time.Minute), true}, + {"pid active", nil, now.Add(-30 * time.Minute), now.Add(-2 * time.Hour), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stubFileStat(t, func(string) (os.FileInfo, error) { + if tt.statErr != nil { + return nil, tt.statErr + } + return fakeFileInfo{modTime: tt.modTime}, nil + }) + stubProcessStartTime(t, func(int) time.Time { + return tt.startTime + }) + if got := isPIDReused("log", 1234); got != tt.want { + t.Fatalf("isPIDReused() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsUnsafeFileSecurityChecks(t *testing.T) { + tempDir := t.TempDir() + absTempDir, err := filepath.Abs(tempDir) + if err != nil { + t.Fatalf("filepath.Abs() error = %v", err) + } + + t.Run("symlink", func(t *testing.T) { + stubFileStat(t, func(string) (os.FileInfo, error) { + return fakeFileInfo{mode: os.ModeSymlink}, nil + }) + stubEvalSymlinks(t, func(path string) (string, error) { + return filepath.Join(absTempDir, filepath.Base(path)), nil + }) + unsafe, reason := isUnsafeFile(filepath.Join(absTempDir, "codex-wrapper-1.log"), tempDir) + if !unsafe || reason != "refusing to delete symlink" { + t.Fatalf("expected symlink to be rejected, got unsafe=%v reason=%q", unsafe, reason) + } + }) + + t.Run("path traversal", func(t *testing.T) { + stubFileStat(t, func(string) (os.FileInfo, error) { + return fakeFileInfo{}, nil + }) + outside := filepath.Join(filepath.Dir(absTempDir), "etc", "passwd") + stubEvalSymlinks(t, func(string) (string, error) { + return outside, nil + }) + unsafe, reason := isUnsafeFile(filepath.Join("..", "..", "etc", "passwd"), tempDir) + if !unsafe || reason != "file is outside tempDir" { + t.Fatalf("expected traversal path to be rejected, got unsafe=%v reason=%q", unsafe, reason) + } + }) + + t.Run("outside temp dir", func(t *testing.T) { + stubFileStat(t, func(string) (os.FileInfo, error) { + return fakeFileInfo{}, nil + }) + otherDir := t.TempDir() + stubEvalSymlinks(t, func(string) (string, error) { + return filepath.Join(otherDir, "codex-wrapper-9.log"), nil + }) + unsafe, reason := isUnsafeFile(filepath.Join(otherDir, "codex-wrapper-9.log"), tempDir) + if !unsafe || reason != "file is outside tempDir" { + t.Fatalf("expected outside file to be rejected, got unsafe=%v reason=%q", unsafe, reason) + } + }) +} + +func TestRunLoggerPathAndRemove(t *testing.T) { + tempDir := t.TempDir() + path := filepath.Join(tempDir, "sample.log") + if err := os.WriteFile(path, []byte("test"), 0o644); err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + + logger := &Logger{path: path} + if got := logger.Path(); got != path { + t.Fatalf("Path() = %q, want %q", got, path) + } + if err := logger.RemoveLogFile(); err != nil { + t.Fatalf("RemoveLogFile() error = %v", err) + } + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Fatalf("expected log file to be removed, err=%v", err) + } + + var nilLogger *Logger + if nilLogger.Path() != "" { + t.Fatalf("nil logger Path() should be empty") + } + if err := nilLogger.RemoveLogFile(); err != nil { + t.Fatalf("nil logger RemoveLogFile() should return nil, got %v", err) + } +} + +func TestRunLoggerInternalLog(t *testing.T) { + logger := &Logger{ + ch: make(chan logEntry, 1), + done: make(chan struct{}), + pendingWG: sync.WaitGroup{}, + } + + done := make(chan logEntry, 1) + go func() { + entry := <-logger.ch + logger.pendingWG.Done() + done <- entry + }() + + logger.log("INFO", "hello") + entry := <-done + if entry.level != "INFO" || entry.msg != "hello" { + t.Fatalf("unexpected entry %+v", entry) + } + + logger.closed.Store(true) + logger.log("INFO", "ignored") + close(logger.done) +} + +func TestRunParsePIDFromLog(t *testing.T) { + hugePID := strconv.FormatInt(math.MaxInt64, 10) + "0" + tests := []struct { + name string + pid int + ok bool + }{ + {"codex-wrapper-123.log", 123, true}, + {"codex-wrapper-999-extra.log", 999, true}, + {"codex-wrapper-.log", 0, false}, + {"invalid-name.log", 0, false}, + {"codex-wrapper--5.log", 0, false}, + {"codex-wrapper-0.log", 0, false}, + {fmt.Sprintf("codex-wrapper-%s.log", hugePID), 0, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := parsePIDFromLog(filepath.Join("/tmp", tt.name)) + if ok != tt.ok { + t.Fatalf("parsePIDFromLog ok = %v, want %v", ok, tt.ok) + } + if ok && got != tt.pid { + t.Fatalf("pid = %d, want %d", got, tt.pid) + } + }) + } +} + +func createTempLog(t *testing.T, dir, name string) string { + t.Helper() + path := filepath.Join(dir, name) + if err := os.WriteFile(path, []byte("test"), 0o644); err != nil { + t.Fatalf("failed to create temp log %s: %v", path, err) + } + return path +} + +func setTempDirEnv(t *testing.T, dir string) string { + t.Helper() + resolved := dir + if eval, err := filepath.EvalSymlinks(dir); err == nil { + resolved = eval + } + t.Setenv("TMPDIR", resolved) + t.Setenv("TEMP", resolved) + t.Setenv("TMP", resolved) + return resolved +} + +func stubProcessRunning(t *testing.T, fn func(int) bool) { + t.Helper() + original := processRunningCheck + processRunningCheck = fn + t.Cleanup(func() { + processRunningCheck = original + }) +} + +func stubProcessStartTime(t *testing.T, fn func(int) time.Time) { + t.Helper() + original := processStartTimeFn + processStartTimeFn = fn + t.Cleanup(func() { + processStartTimeFn = original + }) +} + +func stubRemoveLogFile(t *testing.T, fn func(string) error) { + t.Helper() + original := removeLogFileFn + removeLogFileFn = fn + t.Cleanup(func() { + removeLogFileFn = original + }) +} + +func stubGlobLogFiles(t *testing.T, fn func(string) ([]string, error)) { + t.Helper() + original := globLogFiles + globLogFiles = fn + t.Cleanup(func() { + globLogFiles = original + }) +} + +func stubFileStat(t *testing.T, fn func(string) (os.FileInfo, error)) { + t.Helper() + original := fileStatFn + fileStatFn = fn + t.Cleanup(func() { + fileStatFn = original + }) +} + +func stubEvalSymlinks(t *testing.T, fn func(string) (string, error)) { + t.Helper() + original := evalSymlinksFn + evalSymlinksFn = fn + t.Cleanup(func() { + evalSymlinksFn = original + }) +} + +type fakeFileInfo struct { + modTime time.Time + mode os.FileMode +} + +func (f fakeFileInfo) Name() string { return "fake" } +func (f fakeFileInfo) Size() int64 { return 0 } +func (f fakeFileInfo) Mode() os.FileMode { return f.mode } +func (f fakeFileInfo) ModTime() time.Time { return f.modTime } +func (f fakeFileInfo) IsDir() bool { return false } +func (f fakeFileInfo) Sys() interface{} { return nil } diff --git a/codeagent-wrapper/main_integration_test.go b/codeagent-wrapper/main_integration_test.go index 987b646..a5083cd 100644 --- a/codeagent-wrapper/main_integration_test.go +++ b/codeagent-wrapper/main_integration_test.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "os" + "path/filepath" "strings" "sync" "sync/atomic" @@ -79,6 +80,8 @@ func parseIntegrationOutput(t *testing.T, out string) integrationOutput { currentTask.Error = strings.TrimPrefix(line, "Error: ") } else if strings.HasPrefix(line, "Session:") { currentTask.SessionID = strings.TrimPrefix(line, "Session: ") + } else if strings.HasPrefix(line, "Log:") { + currentTask.LogPath = strings.TrimSpace(strings.TrimPrefix(line, "Log:")) } else if line != "" && !strings.HasPrefix(line, "===") && !strings.HasPrefix(line, "---") { if currentTask.Message != "" { currentTask.Message += "\n" @@ -95,6 +98,32 @@ func parseIntegrationOutput(t *testing.T, out string) integrationOutput { return payload } +func extractTaskBlock(t *testing.T, output, taskID string) string { + t.Helper() + header := fmt.Sprintf("--- Task: %s ---", taskID) + lines := strings.Split(output, "\n") + var block []string + collecting := false + for _, raw := range lines { + trimmed := strings.TrimSpace(raw) + if !collecting { + if trimmed == header { + collecting = true + block = append(block, trimmed) + } + continue + } + if strings.HasPrefix(trimmed, "--- Task: ") && trimmed != header { + break + } + block = append(block, trimmed) + } + if len(block) == 0 { + t.Fatalf("task block %s not found in output:\n%s", taskID, output) + } + return strings.Join(block, "\n") +} + func findResultByID(t *testing.T, payload integrationOutput, id string) TaskResult { t.Helper() for _, res := range payload.Results { @@ -106,7 +135,7 @@ func findResultByID(t *testing.T, payload integrationOutput, id string) TaskResu return TaskResult{} } -func TestParallelEndToEnd_OrderAndConcurrency(t *testing.T) { +func TestRunParallelEndToEnd_OrderAndConcurrency(t *testing.T) { defer resetTestHooks() origRun := runCodexTaskFn t.Cleanup(func() { @@ -217,7 +246,7 @@ task-e` } } -func TestParallelCycleDetectionStopsExecution(t *testing.T) { +func TestRunParallelCycleDetectionStopsExecution(t *testing.T) { defer resetTestHooks() origRun := runCodexTaskFn runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { @@ -255,7 +284,7 @@ b` } } -func TestParallelPartialFailureBlocksDependents(t *testing.T) { +func TestRunParallelOutputsIncludeLogPaths(t *testing.T) { defer resetTestHooks() origRun := runCodexTaskFn t.Cleanup(func() { @@ -263,11 +292,205 @@ func TestParallelPartialFailureBlocksDependents(t *testing.T) { resetTestHooks() }) + tempDir := t.TempDir() + logPathFor := func(id string) string { + return filepath.Join(tempDir, fmt.Sprintf("%s.log", id)) + } + runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { - if task.ID == "A" { - return TaskResult{TaskID: "A", ExitCode: 2, Error: "boom"} + res := TaskResult{ + TaskID: task.ID, + Message: fmt.Sprintf("result-%s", task.ID), + SessionID: fmt.Sprintf("session-%s", task.ID), + LogPath: logPathFor(task.ID), } - return TaskResult{TaskID: task.ID, ExitCode: 0, Message: task.Task} + if task.ID == "beta" { + res.ExitCode = 9 + res.Error = "boom" + } + return res + } + + input := `---TASK--- +id: alpha +---CONTENT--- +task-alpha +---TASK--- +id: beta +---CONTENT--- +task-beta` + stdinReader = bytes.NewReader([]byte(input)) + os.Args = []string{"codex-wrapper", "--parallel"} + + var exitCode int + output := captureStdout(t, func() { + exitCode = run() + }) + + if exitCode != 9 { + t.Fatalf("parallel run exit=%d, want 9", exitCode) + } + + payload := parseIntegrationOutput(t, output) + alpha := findResultByID(t, payload, "alpha") + beta := findResultByID(t, payload, "beta") + + if alpha.LogPath != logPathFor("alpha") { + t.Fatalf("alpha log path = %q, want %q", alpha.LogPath, logPathFor("alpha")) + } + if beta.LogPath != logPathFor("beta") { + t.Fatalf("beta log path = %q, want %q", beta.LogPath, logPathFor("beta")) + } + + for _, id := range []string{"alpha", "beta"} { + want := fmt.Sprintf("Log: %s", logPathFor(id)) + if !strings.Contains(output, want) { + t.Fatalf("parallel output missing %q for %s:\n%s", want, id, output) + } + } +} + +func TestRunParallelStartupLogsPrinted(t *testing.T) { + defer resetTestHooks() + + tempDir := setTempDirEnv(t, t.TempDir()) + input := `---TASK--- +id: a +---CONTENT--- +fail +---TASK--- +id: b +---CONTENT--- +ok-b +---TASK--- +id: c +dependencies: a +---CONTENT--- +should-skip +---TASK--- +id: d +---CONTENT--- +ok-d` + stdinReader = bytes.NewReader([]byte(input)) + os.Args = []string{"codex-wrapper", "--parallel"} + + expectedLog := filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d.log", os.Getpid())) + + origRun := runCodexTaskFn + runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { + path := expectedLog + if logger := activeLogger(); logger != nil && logger.Path() != "" { + path = logger.Path() + } + if task.ID == "a" { + return TaskResult{TaskID: task.ID, ExitCode: 3, Error: "boom", LogPath: path} + } + return TaskResult{TaskID: task.ID, ExitCode: 0, Message: task.Task, LogPath: path} + } + t.Cleanup(func() { runCodexTaskFn = origRun }) + + var exitCode int + var stdoutOut string + stderrOut := captureStderr(t, func() { + stdoutOut = captureStdout(t, func() { + exitCode = run() + }) + }) + + if exitCode == 0 { + t.Fatalf("expected non-zero exit due to task failure, got %d", exitCode) + } + if stdoutOut == "" { + t.Fatalf("expected parallel summary on stdout") + } + + lines := strings.Split(strings.TrimSpace(stderrOut), "\n") + var bannerSeen bool + var taskLines []string + for _, raw := range lines { + line := strings.TrimSpace(raw) + if line == "" { + continue + } + if line == "=== Starting Parallel Execution ===" { + if bannerSeen { + t.Fatalf("banner printed multiple times:\n%s", stderrOut) + } + bannerSeen = true + continue + } + taskLines = append(taskLines, line) + } + + if !bannerSeen { + t.Fatalf("expected startup banner in stderr, got:\n%s", stderrOut) + } + + expectedLines := map[string]struct{}{ + fmt.Sprintf("Task a: Log: %s", expectedLog): {}, + fmt.Sprintf("Task b: Log: %s", expectedLog): {}, + fmt.Sprintf("Task d: Log: %s", expectedLog): {}, + } + + if len(taskLines) != len(expectedLines) { + t.Fatalf("startup log lines mismatch, got %d lines:\n%s", len(taskLines), stderrOut) + } + + for _, line := range taskLines { + if _, ok := expectedLines[line]; !ok { + t.Fatalf("unexpected startup line %q\nstderr:\n%s", line, stderrOut) + } + } +} + +func TestRunNonParallelOutputsIncludeLogPathsIntegration(t *testing.T) { + defer resetTestHooks() + + tempDir := setTempDirEnv(t, t.TempDir()) + os.Args = []string{"codex-wrapper", "integration-log-check"} + stdinReader = strings.NewReader("") + isTerminalFn = func() bool { return true } + codexCommand = "echo" + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { + return []string{`{"type":"thread.started","thread_id":"integration-session"}` + "\n" + `{"type":"item.completed","item":{"type":"agent_message","text":"done"}}`} + } + + var exitCode int + stderr := captureStderr(t, func() { + _ = captureStdout(t, func() { + exitCode = run() + }) + }) + + if exitCode != 0 { + t.Fatalf("run() exit=%d, want 0", exitCode) + } + expectedLog := filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d.log", os.Getpid())) + wantLine := fmt.Sprintf("Log: %s", expectedLog) + if !strings.Contains(stderr, wantLine) { + t.Fatalf("stderr missing %q, got: %q", wantLine, stderr) + } +} + +func TestRunParallelPartialFailureBlocksDependents(t *testing.T) { + defer resetTestHooks() + origRun := runCodexTaskFn + t.Cleanup(func() { + runCodexTaskFn = origRun + resetTestHooks() + }) + + tempDir := t.TempDir() + logPathFor := func(id string) string { + return filepath.Join(tempDir, fmt.Sprintf("%s.log", id)) + } + + runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { + path := logPathFor(task.ID) + if task.ID == "A" { + return TaskResult{TaskID: "A", ExitCode: 2, Error: "boom", LogPath: path} + } + return TaskResult{TaskID: task.ID, ExitCode: 0, Message: task.Task, LogPath: path} } input := `---TASK--- @@ -317,9 +540,29 @@ ok-e` if payload.Summary.Failed != 2 || payload.Summary.Total != 4 { t.Fatalf("unexpected summary after partial failure: %+v", payload.Summary) } + if resA.LogPath != logPathFor("A") { + t.Fatalf("task A log path = %q, want %q", resA.LogPath, logPathFor("A")) + } + if resB.LogPath != "" { + t.Fatalf("task B should not report a log path when skipped, got %q", resB.LogPath) + } + if resD.LogPath != logPathFor("D") || resE.LogPath != logPathFor("E") { + t.Fatalf("expected log paths for D/E, got D=%q E=%q", resD.LogPath, resE.LogPath) + } + for _, id := range []string{"A", "D", "E"} { + block := extractTaskBlock(t, output, id) + want := fmt.Sprintf("Log: %s", logPathFor(id)) + if !strings.Contains(block, want) { + t.Fatalf("task %s block missing %q:\n%s", id, want, block) + } + } + blockB := extractTaskBlock(t, output, "B") + if strings.Contains(blockB, "Log:") { + t.Fatalf("skipped task B should not emit a log line:\n%s", blockB) + } } -func TestParallelTimeoutPropagation(t *testing.T) { +func TestRunParallelTimeoutPropagation(t *testing.T) { defer resetTestHooks() origRun := runCodexTaskFn t.Cleanup(func() { @@ -363,7 +606,7 @@ slow` } } -func TestConcurrentSpeedupBenchmark(t *testing.T) { +func TestRunConcurrentSpeedupBenchmark(t *testing.T) { defer resetTestHooks() origRun := runCodexTaskFn t.Cleanup(func() { @@ -398,3 +641,210 @@ func TestConcurrentSpeedupBenchmark(t *testing.T) { ratio := float64(concurrentElapsed) / float64(serialElapsed) t.Logf("speedup ratio (concurrent/serial)=%.3f", ratio) } + +func TestRunStartupCleanupRemovesOrphansEndToEnd(t *testing.T) { + defer resetTestHooks() + + tempDir := setTempDirEnv(t, t.TempDir()) + + orphanA := createTempLog(t, tempDir, "codex-wrapper-5001.log") + orphanB := createTempLog(t, tempDir, "codex-wrapper-5002-extra.log") + orphanC := createTempLog(t, tempDir, "codex-wrapper-5003-suffix.log") + runningPID := 81234 + runningLog := createTempLog(t, tempDir, fmt.Sprintf("codex-wrapper-%d.log", runningPID)) + unrelated := createTempLog(t, tempDir, "wrapper.log") + + stubProcessRunning(t, func(pid int) bool { + return pid == runningPID || pid == os.Getpid() + }) + stubProcessStartTime(t, func(pid int) time.Time { + if pid == runningPID || pid == os.Getpid() { + return time.Now().Add(-1 * time.Hour) + } + return time.Time{} + }) + + codexCommand = createFakeCodexScript(t, "tid-startup", "ok") + stdinReader = strings.NewReader("") + isTerminalFn = func() bool { return true } + os.Args = []string{"codex-wrapper", "task"} + + if exit := run(); exit != 0 { + t.Fatalf("run() exit=%d, want 0", exit) + } + + for _, orphan := range []string{orphanA, orphanB, orphanC} { + if _, err := os.Stat(orphan); !os.IsNotExist(err) { + t.Fatalf("expected orphan %s to be removed, err=%v", orphan, err) + } + } + if _, err := os.Stat(runningLog); err != nil { + t.Fatalf("expected running log to remain, err=%v", err) + } + if _, err := os.Stat(unrelated); err != nil { + t.Fatalf("expected unrelated file to remain, err=%v", err) + } +} + +func TestRunStartupCleanupConcurrentWrappers(t *testing.T) { + defer resetTestHooks() + + tempDir := setTempDirEnv(t, t.TempDir()) + + const totalLogs = 40 + for i := 0; i < totalLogs; i++ { + createTempLog(t, tempDir, fmt.Sprintf("codex-wrapper-%d.log", 9000+i)) + } + + stubProcessRunning(t, func(pid int) bool { + return false + }) + stubProcessStartTime(t, func(int) time.Time { return time.Time{} }) + + var wg sync.WaitGroup + const instances = 5 + start := make(chan struct{}) + + for i := 0; i < instances; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + runStartupCleanup() + }() + } + + close(start) + wg.Wait() + + matches, err := filepath.Glob(filepath.Join(tempDir, "codex-wrapper-*.log")) + if err != nil { + t.Fatalf("glob error: %v", err) + } + if len(matches) != 0 { + t.Fatalf("expected all orphan logs to be removed, remaining=%v", matches) + } +} + +func TestRunCleanupFlagEndToEnd_Success(t *testing.T) { + defer resetTestHooks() + + tempDir := setTempDirEnv(t, t.TempDir()) + + staleA := createTempLog(t, tempDir, "codex-wrapper-2100.log") + staleB := createTempLog(t, tempDir, "codex-wrapper-2200-extra.log") + keeper := createTempLog(t, tempDir, "codex-wrapper-2300.log") + + stubProcessRunning(t, func(pid int) bool { + return pid == 2300 || pid == os.Getpid() + }) + stubProcessStartTime(t, func(pid int) time.Time { + if pid == 2300 || pid == os.Getpid() { + return time.Now().Add(-1 * time.Hour) + } + return time.Time{} + }) + + os.Args = []string{"codex-wrapper", "--cleanup"} + + var exitCode int + output := captureStdout(t, func() { + exitCode = run() + }) + + if exitCode != 0 { + t.Fatalf("cleanup exit = %d, want 0", exitCode) + } + + // Check that output contains expected counts and file names + if !strings.Contains(output, "Cleanup completed") { + t.Fatalf("missing 'Cleanup completed' in output: %q", output) + } + if !strings.Contains(output, "Files scanned: 3") { + t.Fatalf("missing 'Files scanned: 3' in output: %q", output) + } + if !strings.Contains(output, "Files deleted: 2") { + t.Fatalf("missing 'Files deleted: 2' in output: %q", output) + } + if !strings.Contains(output, "Files kept: 1") { + t.Fatalf("missing 'Files kept: 1' in output: %q", output) + } + if !strings.Contains(output, "codex-wrapper-2100.log") || !strings.Contains(output, "codex-wrapper-2200-extra.log") { + t.Fatalf("missing deleted file names in output: %q", output) + } + if !strings.Contains(output, "codex-wrapper-2300.log") { + t.Fatalf("missing kept file names in output: %q", output) + } + + for _, path := range []string{staleA, staleB} { + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Fatalf("expected %s to be removed, err=%v", path, err) + } + } + if _, err := os.Stat(keeper); err != nil { + t.Fatalf("expected kept log to remain, err=%v", err) + } + + currentLog := filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d.log", os.Getpid())) + if _, err := os.Stat(currentLog); err == nil { + t.Fatalf("cleanup mode should not create new log file %s", currentLog) + } else if !os.IsNotExist(err) { + t.Fatalf("stat(%s) unexpected error: %v", currentLog, err) + } +} + +func TestRunCleanupFlagEndToEnd_FailureDoesNotAffectStartup(t *testing.T) { + defer resetTestHooks() + + tempDir := setTempDirEnv(t, t.TempDir()) + + calls := 0 + cleanupLogsFn = func() (CleanupStats, error) { + calls++ + return CleanupStats{Scanned: 1}, fmt.Errorf("permission denied") + } + + os.Args = []string{"codex-wrapper", "--cleanup"} + + var exitCode int + errOutput := captureStderr(t, func() { + exitCode = run() + }) + + if exitCode != 1 { + t.Fatalf("cleanup failure exit = %d, want 1", exitCode) + } + if !strings.Contains(errOutput, "Cleanup failed") || !strings.Contains(errOutput, "permission denied") { + t.Fatalf("cleanup stderr = %q, want failure message", errOutput) + } + if calls != 1 { + t.Fatalf("cleanup called %d times, want 1", calls) + } + + currentLog := filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d.log", os.Getpid())) + if _, err := os.Stat(currentLog); err == nil { + t.Fatalf("cleanup failure should not create new log file %s", currentLog) + } else if !os.IsNotExist(err) { + t.Fatalf("stat(%s) unexpected error: %v", currentLog, err) + } + + cleanupLogsFn = func() (CleanupStats, error) { + return CleanupStats{}, nil + } + codexCommand = createFakeCodexScript(t, "tid-cleanup-e2e", "ok") + stdinReader = strings.NewReader("") + isTerminalFn = func() bool { return true } + os.Args = []string{"codex-wrapper", "post-cleanup task"} + + var normalExit int + normalOutput := captureStdout(t, func() { + normalExit = run() + }) + + if normalExit != 0 { + t.Fatalf("normal run exit = %d, want 0", normalExit) + } + if !strings.Contains(normalOutput, "ok") { + t.Fatalf("normal run output = %q, want codex output", normalOutput) + } +} diff --git a/codeagent-wrapper/main_test.go b/codeagent-wrapper/main_test.go index c1b5f01..2d1344f 100644 --- a/codeagent-wrapper/main_test.go +++ b/codeagent-wrapper/main_test.go @@ -12,6 +12,7 @@ import ( "os/exec" "os/signal" "path/filepath" + "runtime" "strings" "sync" "sync/atomic" @@ -26,11 +27,17 @@ func resetTestHooks() { isTerminalFn = defaultIsTerminal codexCommand = "codex" cleanupHook = nil + cleanupLogsFn = cleanupOldLogs + signalNotifyFn = signal.Notify + signalStopFn = signal.Stop buildCodexArgsFn = buildCodexArgs selectBackendFn = selectBackend commandContext = exec.CommandContext + newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + return &realCmd{cmd: commandContext(ctx, name, args...)} + } jsonMarshal = json.Marshal - forceKillDelay = 5 + forceKillDelay.Store(5) closeLogger() } @@ -121,6 +128,444 @@ func captureOutput(t *testing.T, fn func()) string { return buf.String() } +func captureStderr(t *testing.T, fn func()) string { + t.Helper() + r, w, _ := os.Pipe() + old := os.Stderr + os.Stderr = w + fn() + w.Close() + os.Stderr = old + + var buf bytes.Buffer + io.Copy(&buf, r) + return buf.String() +} + +type ctxAwareReader struct { + reader io.ReadCloser + mu sync.Mutex + reason string + closed bool +} + +func newCtxAwareReader(r io.ReadCloser) *ctxAwareReader { + return &ctxAwareReader{reader: r} +} + +func (r *ctxAwareReader) Read(p []byte) (int, error) { + if r.reader == nil { + return 0, io.EOF + } + return r.reader.Read(p) +} + +func (r *ctxAwareReader) Close() error { + r.mu.Lock() + defer r.mu.Unlock() + if r.closed || r.reader == nil { + r.closed = true + return nil + } + r.closed = true + return r.reader.Close() +} + +func (r *ctxAwareReader) CloseWithReason(reason string) error { + r.mu.Lock() + if !r.closed { + r.reason = reason + } + r.mu.Unlock() + return r.Close() +} + +func (r *ctxAwareReader) Reason() string { + r.mu.Lock() + defer r.mu.Unlock() + return r.reason +} + +type drainBlockingStdout struct { + inner *ctxAwareReader +} + +func newDrainBlockingStdout(inner *ctxAwareReader) *drainBlockingStdout { + return &drainBlockingStdout{inner: inner} +} + +func (d *drainBlockingStdout) Read(p []byte) (int, error) { + return d.inner.Read(p) +} + +func (d *drainBlockingStdout) Close() error { + return d.inner.Close() +} + +func (d *drainBlockingStdout) CloseWithReason(reason string) error { + if reason != stdoutCloseReasonDrain { + return nil + } + return d.inner.CloseWithReason(reason) +} + +type drainBlockingCmd struct { + inner *fakeCmd + injected atomic.Bool +} + +func newDrainBlockingCmd(inner *fakeCmd) *drainBlockingCmd { + return &drainBlockingCmd{inner: inner} +} + +func (d *drainBlockingCmd) Start() error { + return d.inner.Start() +} + +func (d *drainBlockingCmd) Wait() error { + return d.inner.Wait() +} + +func (d *drainBlockingCmd) StdoutPipe() (io.ReadCloser, error) { + stdout, err := d.inner.StdoutPipe() + if err != nil { + return nil, err + } + ctxReader, ok := stdout.(*ctxAwareReader) + if !ok { + return stdout, nil + } + d.injected.Store(true) + return newDrainBlockingStdout(ctxReader), nil +} + +func (d *drainBlockingCmd) StdinPipe() (io.WriteCloser, error) { + return d.inner.StdinPipe() +} + +func (d *drainBlockingCmd) SetStderr(w io.Writer) { + d.inner.SetStderr(w) +} + +func (d *drainBlockingCmd) Process() processHandle { + return d.inner.Process() +} + +type bufferWriteCloser struct { + buf bytes.Buffer + mu sync.Mutex + closed bool +} + +func newBufferWriteCloser() *bufferWriteCloser { + return &bufferWriteCloser{} +} + +func (b *bufferWriteCloser) Write(p []byte) (int, error) { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return 0, io.ErrClosedPipe + } + return b.buf.Write(p) +} + +func (b *bufferWriteCloser) Close() error { + b.mu.Lock() + b.closed = true + b.mu.Unlock() + return nil +} + +func (b *bufferWriteCloser) String() string { + b.mu.Lock() + defer b.mu.Unlock() + return b.buf.String() +} + +type fakeProcess struct { + pid int + killed atomic.Bool + mu sync.Mutex + signals []os.Signal + signalCount atomic.Int32 + killCount atomic.Int32 + onSignal func(os.Signal) + onKill func() +} + +func newFakeProcess(pid int) *fakeProcess { + if pid == 0 { + pid = 4242 + } + return &fakeProcess{pid: pid} +} + +func (p *fakeProcess) Pid() int { + return p.pid +} + +func (p *fakeProcess) Kill() error { + p.killed.Store(true) + p.killCount.Add(1) + if p.onKill != nil { + p.onKill() + } + return nil +} + +func (p *fakeProcess) Signal(sig os.Signal) error { + p.mu.Lock() + p.signals = append(p.signals, sig) + p.mu.Unlock() + p.signalCount.Add(1) + if p.onSignal != nil { + p.onSignal(sig) + } + return nil +} + +func (p *fakeProcess) Signals() []os.Signal { + p.mu.Lock() + defer p.mu.Unlock() + cp := make([]os.Signal, len(p.signals)) + copy(cp, p.signals) + return cp +} + +func (p *fakeProcess) Killed() bool { + return p.killed.Load() +} + +func (p *fakeProcess) SignalCount() int { + return int(p.signalCount.Load()) +} + +func (p *fakeProcess) KillCount() int { + return int(p.killCount.Load()) +} + +type fakeStdoutEvent struct { + Delay time.Duration + Data string +} + +type fakeCmdConfig struct { + StdoutPlan []fakeStdoutEvent + WaitDelay time.Duration + WaitErr error + StartErr error + PID int + KeepStdoutOpen bool + BlockWait bool + ReleaseWaitOnKill bool + ReleaseWaitOnSignal bool +} + +type fakeCmd struct { + mu sync.Mutex + + stdout *ctxAwareReader + stdoutWriter *io.PipeWriter + stdoutPlan []fakeStdoutEvent + stdoutOnce sync.Once + stdoutClaim bool + keepStdoutOpen bool + + stdoutWriteMu sync.Mutex + + stdinWriter *bufferWriteCloser + stdinClaim bool + + stderr io.Writer + + waitDelay time.Duration + waitErr error + startErr error + + waitOnce sync.Once + waitDone chan struct{} + waitResult error + waitReleaseCh chan struct{} + waitReleaseOnce sync.Once + waitBlocked bool + + started bool + + startCount atomic.Int32 + waitCount atomic.Int32 + stdoutPipeCount atomic.Int32 + + process *fakeProcess +} + +func newFakeCmd(cfg fakeCmdConfig) *fakeCmd { + r, w := io.Pipe() + cmd := &fakeCmd{ + stdout: newCtxAwareReader(r), + stdoutWriter: w, + stdoutPlan: append([]fakeStdoutEvent(nil), cfg.StdoutPlan...), + stdinWriter: newBufferWriteCloser(), + waitDelay: cfg.WaitDelay, + waitErr: cfg.WaitErr, + startErr: cfg.StartErr, + waitDone: make(chan struct{}), + keepStdoutOpen: cfg.KeepStdoutOpen, + process: newFakeProcess(cfg.PID), + } + if len(cmd.stdoutPlan) == 0 { + cmd.stdoutPlan = nil + } + if cfg.BlockWait { + cmd.waitBlocked = true + cmd.waitReleaseCh = make(chan struct{}) + releaseOnSignal := cfg.ReleaseWaitOnSignal + releaseOnKill := cfg.ReleaseWaitOnKill + if !releaseOnSignal && !releaseOnKill { + releaseOnKill = true + } + cmd.process.onSignal = func(os.Signal) { + if releaseOnSignal { + cmd.releaseWait() + } + } + cmd.process.onKill = func() { + if releaseOnKill { + cmd.releaseWait() + } + } + } + return cmd +} + +func (f *fakeCmd) Start() error { + f.mu.Lock() + if f.started { + f.mu.Unlock() + return errors.New("start already called") + } + f.started = true + f.mu.Unlock() + + f.startCount.Add(1) + + if f.startErr != nil { + f.waitOnce.Do(func() { + f.waitResult = f.startErr + close(f.waitDone) + }) + return f.startErr + } + + go f.runStdoutScript() + return nil +} + +func (f *fakeCmd) Wait() error { + f.waitCount.Add(1) + f.waitOnce.Do(func() { + if f.waitBlocked && f.waitReleaseCh != nil { + <-f.waitReleaseCh + } else if f.waitDelay > 0 { + time.Sleep(f.waitDelay) + } + f.waitResult = f.waitErr + close(f.waitDone) + }) + <-f.waitDone + return f.waitResult +} + +func (f *fakeCmd) StdoutPipe() (io.ReadCloser, error) { + f.mu.Lock() + defer f.mu.Unlock() + if f.stdoutClaim { + return nil, errors.New("stdout pipe already claimed") + } + f.stdoutClaim = true + f.stdoutPipeCount.Add(1) + return f.stdout, nil +} + +func (f *fakeCmd) StdinPipe() (io.WriteCloser, error) { + f.mu.Lock() + defer f.mu.Unlock() + if f.stdinClaim { + return nil, errors.New("stdin pipe already claimed") + } + f.stdinClaim = true + return f.stdinWriter, nil +} + +func (f *fakeCmd) SetStderr(w io.Writer) { + f.stderr = w +} + +func (f *fakeCmd) Process() processHandle { + if f == nil { + return nil + } + return f.process +} + +func (f *fakeCmd) runStdoutScript() { + if len(f.stdoutPlan) == 0 { + if !f.keepStdoutOpen { + f.CloseStdout(nil) + } + return + } + for _, ev := range f.stdoutPlan { + if ev.Delay > 0 { + time.Sleep(ev.Delay) + } + f.WriteStdout(ev.Data) + } + if !f.keepStdoutOpen { + f.CloseStdout(nil) + } +} + +func (f *fakeCmd) releaseWait() { + if f.waitReleaseCh == nil { + return + } + f.waitReleaseOnce.Do(func() { + close(f.waitReleaseCh) + }) +} + +func (f *fakeCmd) WriteStdout(data string) { + if data == "" { + return + } + f.stdoutWriteMu.Lock() + defer f.stdoutWriteMu.Unlock() + if f.stdoutWriter != nil { + _, _ = io.WriteString(f.stdoutWriter, data) + } +} + +func (f *fakeCmd) CloseStdout(err error) { + f.stdoutOnce.Do(func() { + if f.stdoutWriter == nil { + return + } + if err != nil { + _ = f.stdoutWriter.CloseWithError(err) + return + } + _ = f.stdoutWriter.Close() + }) +} + +func (f *fakeCmd) StdinContents() string { + if f.stdinWriter == nil { + return "" + } + return f.stdinWriter.String() +} + func createFakeCodexScript(t *testing.T, threadID, message string) string { t.Helper() scriptPath := filepath.Join(t.TempDir(), "codex.sh") @@ -134,6 +579,296 @@ printf '%%s\n' '{"type":"item.completed","item":{"type":"agent_message","text":" return scriptPath } +func TestFakeCmdInfra(t *testing.T) { + t.Run("pipes and wait scheduling", func(t *testing.T) { + fake := newFakeCmd(fakeCmdConfig{ + StdoutPlan: []fakeStdoutEvent{ + {Data: "line1\n"}, + {Delay: 5 * time.Millisecond, Data: "line2\n"}, + }, + WaitDelay: 20 * time.Millisecond, + }) + + stdout, err := fake.StdoutPipe() + if err != nil { + t.Fatalf("StdoutPipe() error = %v", err) + } + + if err := fake.Start(); err != nil { + t.Fatalf("Start() error = %v", err) + } + + scanner := bufio.NewScanner(stdout) + var lines []string + for scanner.Scan() { + lines = append(lines, scanner.Text()) + if len(lines) == 2 { + break + } + } + if err := scanner.Err(); err != nil { + t.Fatalf("scanner error: %v", err) + } + if len(lines) != 2 || lines[0] != "line1" || lines[1] != "line2" { + t.Fatalf("unexpected stdout lines: %v", lines) + } + + ctxReader, ok := stdout.(*ctxAwareReader) + if !ok { + t.Fatalf("stdout pipe is %T, want *ctxAwareReader", stdout) + } + if err := ctxReader.CloseWithReason("test-complete"); err != nil { + t.Fatalf("CloseWithReason error: %v", err) + } + if ctxReader.Reason() != "test-complete" { + t.Fatalf("CloseWithReason reason mismatch: %q", ctxReader.Reason()) + } + + waitStart := time.Now() + if err := fake.Wait(); err != nil { + t.Fatalf("Wait() error = %v", err) + } + if elapsed := time.Since(waitStart); elapsed < 20*time.Millisecond { + t.Fatalf("Wait() returned too early: %v", elapsed) + } + + if fake.startCount.Load() != 1 { + t.Fatalf("Start() count = %d, want 1", fake.startCount.Load()) + } + if fake.waitCount.Load() != 1 { + t.Fatalf("Wait() count = %d, want 1", fake.waitCount.Load()) + } + if fake.stdoutPipeCount.Load() != 1 { + t.Fatalf("StdoutPipe() count = %d, want 1", fake.stdoutPipeCount.Load()) + } + }) + + t.Run("integration with runCodexTask", func(t *testing.T) { + defer resetTestHooks() + + fake := newFakeCmd(fakeCmdConfig{ + StdoutPlan: []fakeStdoutEvent{ + {Data: `{"type":"thread.started","thread_id":"fake-thread"}` + "\n"}, + { + Delay: time.Millisecond, + Data: `{"type":"item.completed","item":{"type":"agent_message","text":"fake-msg"}}` + "\n", + }, + }, + WaitDelay: 5 * time.Millisecond, + }) + + newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + return fake + } + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { + return []string{targetArg} + } + codexCommand = "fake-cmd" + + res := runCodexTask(TaskSpec{Task: "ignored"}, false, 2) + if res.ExitCode != 0 { + t.Fatalf("runCodexTask exit = %d, want 0 (%s)", res.ExitCode, res.Error) + } + if res.Message != "fake-msg" { + t.Fatalf("message = %q, want fake-msg", res.Message) + } + if res.SessionID != "fake-thread" { + t.Fatalf("sessionID = %q, want fake-thread", res.SessionID) + } + if fake.startCount.Load() != 1 { + t.Fatalf("Start() count = %d, want 1", fake.startCount.Load()) + } + if fake.waitCount.Load() != 1 { + t.Fatalf("Wait() count = %d, want 1", fake.waitCount.Load()) + } + }) +} + +func TestRunCodexTask_WaitBeforeParse(t *testing.T) { + defer resetTestHooks() + + const ( + threadID = "wait-first-thread" + message = "wait-first-message" + waitDelay = 100 * time.Millisecond + extraDelay = 2 * time.Second + ) + + fake := newFakeCmd(fakeCmdConfig{ + StdoutPlan: []fakeStdoutEvent{ + {Data: fmt.Sprintf(`{"type":"thread.started","thread_id":"%s"}`+"\n", threadID)}, + {Data: fmt.Sprintf(`{"type":"item.completed","item":{"type":"agent_message","text":"%s"}}`+"\n", message)}, + {Delay: extraDelay}, + }, + WaitDelay: waitDelay, + }) + + newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + return fake + } + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { + return []string{targetArg} + } + codexCommand = "fake-cmd" + + start := time.Now() + result := runCodexTask(TaskSpec{Task: "ignored"}, false, 5) + elapsed := time.Since(start) + + if result.ExitCode != 0 { + t.Fatalf("runCodexTask exit = %d, want 0 (%s)", result.ExitCode, result.Error) + } + if result.Message != message { + t.Fatalf("message = %q, want %q", result.Message, message) + } + if result.SessionID != threadID { + t.Fatalf("sessionID = %q, want %q", result.SessionID, threadID) + } + if elapsed >= extraDelay { + t.Fatalf("runCodexTask took %v, want < %v", elapsed, extraDelay) + } + + if fake.stdout == nil { + t.Fatalf("stdout reader not initialized") + } + if reason := fake.stdout.Reason(); reason != stdoutCloseReasonWait { + t.Fatalf("stdout close reason = %q, want %q", reason, stdoutCloseReasonWait) + } +} + +func TestRunCodexTask_ParseStall(t *testing.T) { + defer resetTestHooks() + + const threadID = "stall-thread" + startG := runtime.NumGoroutine() + + fake := newFakeCmd(fakeCmdConfig{ + StdoutPlan: []fakeStdoutEvent{ + {Data: fmt.Sprintf(`{"type":"thread.started","thread_id":"%s"}`+"\n", threadID)}, + }, + KeepStdoutOpen: true, + }) + + blockingCmd := newDrainBlockingCmd(fake) + newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + return blockingCmd + } + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { + return []string{targetArg} + } + codexCommand = "fake-cmd" + + start := time.Now() + result := runCodexTask(TaskSpec{Task: "stall"}, false, 60) + elapsed := time.Since(start) + if !blockingCmd.injected.Load() { + t.Fatalf("stdout wrapper was not installed") + } + + if result.ExitCode == 0 || result.Error == "" { + t.Fatalf("expected runCodexTask to error when parse stalls, got %+v", result) + } + errText := strings.ToLower(result.Error) + if !strings.Contains(errText, "drain timeout") && !strings.Contains(errText, "agent_message") { + t.Fatalf("error %q does not mention drain timeout or missing agent_message", result.Error) + } + + if elapsed < stdoutDrainTimeout { + t.Fatalf("runCodexTask returned after %v (reason=%s), want >= %v to confirm drainTimer firing", elapsed, fake.stdout.Reason(), stdoutDrainTimeout) + } + maxDuration := stdoutDrainTimeout + time.Second + if elapsed >= maxDuration { + t.Fatalf("runCodexTask took %v, want < %v", elapsed, maxDuration) + } + + if fake.stdout == nil { + t.Fatalf("stdout reader not initialized") + } + if !fake.stdout.closed { + t.Fatalf("stdout reader still open; drainTimer should force close") + } + if reason := fake.stdout.Reason(); reason != stdoutCloseReasonDrain { + t.Fatalf("stdout close reason = %q, want %q", reason, stdoutCloseReasonDrain) + } + + deadline := time.Now().Add(500 * time.Millisecond) + allowed := startG + 8 + finalG := runtime.NumGoroutine() + for finalG > allowed && time.Now().Before(deadline) { + runtime.Gosched() + time.Sleep(10 * time.Millisecond) + runtime.GC() + finalG = runtime.NumGoroutine() + } + if finalG > allowed { + t.Fatalf("goroutines leaked: before=%d after=%d", startG, finalG) + } +} + +func TestRunCodexTask_ContextTimeout(t *testing.T) { + defer resetTestHooks() + forceKillDelay.Store(0) + + fake := newFakeCmd(fakeCmdConfig{ + KeepStdoutOpen: true, + BlockWait: true, + ReleaseWaitOnKill: true, + ReleaseWaitOnSignal: false, + }) + + newCommandRunner = func(ctx context.Context, name string, args ...string) commandRunner { + return fake + } + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { + return []string{targetArg} + } + codexCommand = "fake-cmd" + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + var capturedTimer *forceKillTimer + terminateCommandFn = func(cmd commandRunner) *forceKillTimer { + timer := terminateCommand(cmd) + capturedTimer = timer + return timer + } + defer func() { terminateCommandFn = terminateCommand }() + + result := runCodexTaskWithContext(ctx, TaskSpec{Task: "ctx-timeout", WorkDir: defaultWorkdir}, nil, false, false, 60) + + if result.ExitCode != 124 { + t.Fatalf("exit code = %d, want 124 (%s)", result.ExitCode, result.Error) + } + if !strings.Contains(strings.ToLower(result.Error), "timeout") { + t.Fatalf("error %q does not mention timeout", result.Error) + } + if fake.process == nil { + t.Fatalf("fake process not initialized") + } + if fake.process.SignalCount() == 0 { + t.Fatalf("expected SIGTERM to be sent, got 0") + } + if fake.process.KillCount() == 0 { + t.Fatalf("expected Kill to eventually run, got 0") + } + if capturedTimer == nil { + t.Fatalf("forceKillTimer not captured") + } + if !capturedTimer.stopped.Load() { + t.Fatalf("forceKillTimer.Stop was not called") + } + if !capturedTimer.drained.Load() { + t.Fatalf("forceKillTimer drain logic did not run") + } + if fake.stdout == nil { + t.Fatalf("stdout reader not initialized") + } + if reason := fake.stdout.Reason(); reason != stdoutCloseReasonCtx { + t.Fatalf("stdout close reason = %q, want %q", reason, stdoutCloseReasonCtx) + } +} + func TestRunParseArgs_NewMode(t *testing.T) { tests := []struct { name string @@ -299,7 +1034,7 @@ func TestRunParseArgs_BackendFlag(t *testing.T) { } } -func TestParseParallelConfig_Success(t *testing.T) { +func TestRunParseParallelConfig_Success(t *testing.T) { input := `---TASK--- id: task-1 dependencies: task-0 @@ -319,13 +1054,13 @@ do something` } } -func TestParseParallelConfig_InvalidFormat(t *testing.T) { +func TestRunParseParallelConfig_InvalidFormat(t *testing.T) { if _, err := parseParallelConfig([]byte("invalid format")); err == nil { t.Fatalf("expected error for invalid format, got nil") } } -func TestParseParallelConfig_EmptyTasks(t *testing.T) { +func TestRunParseParallelConfig_EmptyTasks(t *testing.T) { input := `---TASK--- id: empty ---CONTENT--- @@ -335,7 +1070,7 @@ id: empty } } -func TestParseParallelConfig_MissingID(t *testing.T) { +func TestRunParseParallelConfig_MissingID(t *testing.T) { input := `---TASK--- ---CONTENT--- do something` @@ -344,7 +1079,7 @@ do something` } } -func TestParseParallelConfig_MissingTask(t *testing.T) { +func TestRunParseParallelConfig_MissingTask(t *testing.T) { input := `---TASK--- id: task-1 ---CONTENT--- @@ -354,7 +1089,7 @@ id: task-1 } } -func TestParseParallelConfig_DuplicateID(t *testing.T) { +func TestRunParseParallelConfig_DuplicateID(t *testing.T) { input := `---TASK--- id: dup ---CONTENT--- @@ -368,7 +1103,7 @@ two` } } -func TestParseParallelConfig_DelimiterFormat(t *testing.T) { +func TestRunParseParallelConfig_DelimiterFormat(t *testing.T) { input := `---TASK--- id: T1 workdir: /tmp @@ -389,7 +1124,7 @@ code with special chars: $var "quotes"` } } -func TestShouldUseStdin(t *testing.T) { +func TestRunShouldUseStdin(t *testing.T) { tests := []struct { name string task string @@ -667,7 +1402,7 @@ func TestRunNormalizeText(t *testing.T) { } } -func TestParseJSONStream(t *testing.T) { +func TestRunParseJSONStream(t *testing.T) { type testCase struct { name string input string @@ -736,7 +1471,7 @@ func TestParseJSONStream_GeminiEvents(t *testing.T) { } } -func TestParseJSONStreamWithWarn_InvalidLine(t *testing.T) { +func TestRunParseJSONStreamWithWarn_InvalidLine(t *testing.T) { var warnings []string warnFn := func(msg string) { warnings = append(warnings, msg) } message, threadID := parseJSONStreamWithWarn(strings.NewReader("not-json"), warnFn) @@ -831,6 +1566,10 @@ func TestRunTruncate(t *testing.T) { } }) } + + if got := truncate("data", -1); got != "" { + t.Fatalf("truncate should return empty string for negative maxLen, got %q", got) + } } func TestRunMin(t *testing.T) { @@ -998,7 +1737,7 @@ func TestRunIsTerminal(t *testing.T) { } } -func TestReadPipedTask(t *testing.T) { +func TestRunReadPipedTask(t *testing.T) { defer resetTestHooks() tests := []struct { name string @@ -1078,6 +1817,83 @@ func TestRunCodexTask_WithEcho(t *testing.T) { } } +func TestRunCodexTask_LogPathWithActiveLogger(t *testing.T) { + defer resetTestHooks() + + logger, err := NewLoggerWithSuffix("active-logpath") + if err != nil { + t.Fatalf("failed to create logger: %v", err) + } + setLogger(logger) + + codexCommand = "echo" + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} } + + jsonOutput := `{"type":"thread.started","thread_id":"fake-thread"} +{"type":"item.completed","item":{"type":"agent_message","text":"ok"}}` + + result := runCodexTask(TaskSpec{Task: jsonOutput}, false, 5) + if result.LogPath != logger.Path() { + t.Fatalf("LogPath = %q, want %q", result.LogPath, logger.Path()) + } + if result.ExitCode != 0 { + t.Fatalf("exit = %d, want 0 (%s)", result.ExitCode, result.Error) + } +} + +func TestRunCodexTask_LogPathWithTempLogger(t *testing.T) { + defer resetTestHooks() + + codexCommand = "echo" + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{targetArg} } + + jsonOutput := `{"type":"thread.started","thread_id":"temp-thread"} +{"type":"item.completed","item":{"type":"agent_message","text":"temp"}}` + + result := runCodexTask(TaskSpec{Task: jsonOutput}, true, 5) + t.Cleanup(func() { + if result.LogPath != "" { + os.Remove(result.LogPath) + } + }) + if result.LogPath == "" { + t.Fatalf("LogPath should not be empty for temp logger") + } + if _, err := os.Stat(result.LogPath); err != nil { + t.Fatalf("log file %q should exist (err=%v)", result.LogPath, err) + } + if activeLogger() != nil { + t.Fatalf("active logger should be cleared after silent run") + } +} + +func TestRunCodexTask_LogPathOnStartError(t *testing.T) { + defer resetTestHooks() + + logger, err := NewLoggerWithSuffix("start-error") + if err != nil { + t.Fatalf("failed to create logger: %v", err) + } + setLogger(logger) + + tmpFile, err := os.CreateTemp("", "start-error") + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + defer os.Remove(tmpFile.Name()) + + codexCommand = tmpFile.Name() + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { return []string{} } + + result := runCodexTask(TaskSpec{Task: "ignored"}, false, 5) + if result.ExitCode == 0 { + t.Fatalf("expected non-zero exit") + } + if result.LogPath != logger.Path() { + t.Fatalf("LogPath = %q, want %q", result.LogPath, logger.Path()) + } +} + func TestRunCodexTask_NoMessage(t *testing.T) { defer resetTestHooks() codexCommand = "echo" @@ -1211,7 +2027,24 @@ func TestCancelReason(t *testing.T) { } } -func TestSilentMode(t *testing.T) { +func TestRunCodexProcess(t *testing.T) { + defer resetTestHooks() + script := createFakeCodexScript(t, "proc-thread", "proc-msg") + codexCommand = script + + msg, threadID, exitCode := runCodexProcess(context.Background(), nil, "ignored", false, 5) + if exitCode != 0 { + t.Fatalf("exit = %d, want 0", exitCode) + } + if msg != "proc-msg" { + t.Fatalf("message = %q, want proc-msg", msg) + } + if threadID != "proc-thread" { + t.Fatalf("threadID = %q, want proc-thread", threadID) + } +} + +func TestRunSilentMode(t *testing.T) { defer resetTestHooks() jsonOutput := `{"type":"thread.started","thread_id":"silent-session"} {"type":"item.completed","item":{"type":"agent_message","text":"quiet"}}` @@ -1246,7 +2079,7 @@ func TestSilentMode(t *testing.T) { } } -func TestGenerateFinalOutput(t *testing.T) { +func TestRunGenerateFinalOutput(t *testing.T) { results := []TaskResult{{TaskID: "a", ExitCode: 0, Message: "ok"}, {TaskID: "b", ExitCode: 1, Error: "boom"}, {TaskID: "c", ExitCode: 0}} out := generateFinalOutput(results) if out == "" { @@ -1258,9 +2091,37 @@ func TestGenerateFinalOutput(t *testing.T) { if !strings.Contains(out, "Task: a") || !strings.Contains(out, "Task: b") { t.Fatalf("task entries missing") } + if strings.Contains(out, "Log:") { + t.Fatalf("unexpected log line when LogPath empty, got %q", out) + } } -func TestTopologicalSort_LinearChain(t *testing.T) { +func TestRunGenerateFinalOutput_LogPath(t *testing.T) { + results := []TaskResult{ + { + TaskID: "a", + ExitCode: 0, + Message: "ok", + SessionID: "sid", + LogPath: "/tmp/log-a", + }, + { + TaskID: "b", + ExitCode: 7, + Error: "bad", + LogPath: "/tmp/log-b", + }, + } + out := generateFinalOutput(results) + if !strings.Contains(out, "Session: sid\nLog: /tmp/log-a") { + t.Fatalf("output missing log line after session: %q", out) + } + if !strings.Contains(out, "Log: /tmp/log-b") { + t.Fatalf("output missing log line for failed task: %q", out) + } +} + +func TestRunTopologicalSort_LinearChain(t *testing.T) { tasks := []TaskSpec{{ID: "a"}, {ID: "b", Dependencies: []string{"a"}}, {ID: "c", Dependencies: []string{"b"}}} layers, err := topologicalSort(tasks) if err != nil { @@ -1271,7 +2132,7 @@ func TestTopologicalSort_LinearChain(t *testing.T) { } } -func TestTopologicalSort_Branching(t *testing.T) { +func TestRunTopologicalSort_Branching(t *testing.T) { tasks := []TaskSpec{{ID: "root"}, {ID: "left", Dependencies: []string{"root"}}, {ID: "right", Dependencies: []string{"root"}}, {ID: "leaf", Dependencies: []string{"left", "right"}}} layers, err := topologicalSort(tasks) if err != nil { @@ -1282,7 +2143,7 @@ func TestTopologicalSort_Branching(t *testing.T) { } } -func TestTopologicalSort_ParallelTasks(t *testing.T) { +func TestRunTopologicalSort_ParallelTasks(t *testing.T) { tasks := []TaskSpec{{ID: "a"}, {ID: "b"}, {ID: "c"}} layers, err := topologicalSort(tasks) if err != nil { @@ -1293,7 +2154,7 @@ func TestTopologicalSort_ParallelTasks(t *testing.T) { } } -func TestShouldSkipTask(t *testing.T) { +func TestRunShouldSkipTask(t *testing.T) { failed := map[string]TaskResult{"a": {TaskID: "a", ExitCode: 1}, "b": {TaskID: "b", ExitCode: 2}} tests := []struct { name string @@ -1322,28 +2183,28 @@ func TestShouldSkipTask(t *testing.T) { } } -func TestTopologicalSort_CycleDetection(t *testing.T) { +func TestRunTopologicalSort_CycleDetection(t *testing.T) { tasks := []TaskSpec{{ID: "a", Dependencies: []string{"b"}}, {ID: "b", Dependencies: []string{"a"}}} if _, err := topologicalSort(tasks); err == nil || !strings.Contains(err.Error(), "cycle detected") { t.Fatalf("expected cycle error, got %v", err) } } -func TestTopologicalSort_IndirectCycle(t *testing.T) { +func TestRunTopologicalSort_IndirectCycle(t *testing.T) { tasks := []TaskSpec{{ID: "a", Dependencies: []string{"c"}}, {ID: "b", Dependencies: []string{"a"}}, {ID: "c", Dependencies: []string{"b"}}} if _, err := topologicalSort(tasks); err == nil || !strings.Contains(err.Error(), "cycle detected") { t.Fatalf("expected cycle error, got %v", err) } } -func TestTopologicalSort_MissingDependency(t *testing.T) { +func TestRunTopologicalSort_MissingDependency(t *testing.T) { tasks := []TaskSpec{{ID: "a", Dependencies: []string{"missing"}}} if _, err := topologicalSort(tasks); err == nil || !strings.Contains(err.Error(), "dependency \"missing\" not found") { t.Fatalf("expected missing dependency error, got %v", err) } } -func TestTopologicalSort_LargeGraph(t *testing.T) { +func TestRunTopologicalSort_LargeGraph(t *testing.T) { const count = 200 tasks := make([]TaskSpec, count) for i := 0; i < count; i++ { @@ -1365,7 +2226,7 @@ func TestTopologicalSort_LargeGraph(t *testing.T) { } } -func TestExecuteConcurrent_ParallelExecution(t *testing.T) { +func TestRunExecuteConcurrent_ParallelExecution(t *testing.T) { orig := runCodexTaskFn defer func() { runCodexTaskFn = orig }() @@ -1401,7 +2262,7 @@ func TestExecuteConcurrent_ParallelExecution(t *testing.T) { } } -func TestExecuteConcurrent_LayerOrdering(t *testing.T) { +func TestRunExecuteConcurrent_LayerOrdering(t *testing.T) { orig := runCodexTaskFn defer func() { runCodexTaskFn = orig }() @@ -1423,7 +2284,7 @@ func TestExecuteConcurrent_LayerOrdering(t *testing.T) { } } -func TestExecuteConcurrent_ErrorIsolation(t *testing.T) { +func TestRunExecuteConcurrent_ErrorIsolation(t *testing.T) { orig := runCodexTaskFn defer func() { runCodexTaskFn = orig }() @@ -1456,7 +2317,7 @@ func TestExecuteConcurrent_ErrorIsolation(t *testing.T) { } } -func TestExecuteConcurrent_PanicRecovered(t *testing.T) { +func TestRunExecuteConcurrent_PanicRecovered(t *testing.T) { orig := runCodexTaskFn defer func() { runCodexTaskFn = orig }() @@ -1470,7 +2331,7 @@ func TestExecuteConcurrent_PanicRecovered(t *testing.T) { } } -func TestExecuteConcurrent_LargeFanout(t *testing.T) { +func TestRunExecuteConcurrent_LargeFanout(t *testing.T) { orig := runCodexTaskFn defer func() { runCodexTaskFn = orig }() @@ -1510,6 +2371,37 @@ test` } } +func TestRun_ParallelTriggersCleanup(t *testing.T) { + defer resetTestHooks() + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + os.Args = []string{"codex-wrapper", "--parallel"} + stdinReader = strings.NewReader(`---TASK--- +id: only +---CONTENT--- +noop`) + + cleanupCalls := 0 + cleanupLogsFn = func() (CleanupStats, error) { + cleanupCalls++ + return CleanupStats{}, nil + } + + orig := runCodexTaskFn + runCodexTaskFn = func(task TaskSpec, timeout int) TaskResult { + return TaskResult{TaskID: task.ID, ExitCode: 0, Message: "ok"} + } + defer func() { runCodexTaskFn = orig }() + + if exitCode := run(); exitCode != 0 { + t.Fatalf("exit = %d, want 0", exitCode) + } + if cleanupCalls != 1 { + t.Fatalf("cleanup called %d times, want 1", cleanupCalls) + } +} + func TestRun_Version(t *testing.T) { defer resetTestHooks() os.Args = []string{"codeagent-wrapper", "--version"} @@ -1542,6 +2434,172 @@ func TestRun_HelpShort(t *testing.T) { } } +func TestRun_HelpDoesNotTriggerCleanup(t *testing.T) { + defer resetTestHooks() + os.Args = []string{"codex-wrapper", "--help"} + cleanupLogsFn = func() (CleanupStats, error) { + t.Fatalf("cleanup should not run for --help") + return CleanupStats{}, nil + } + + if code := run(); code != 0 { + t.Fatalf("exit = %d, want 0", code) + } +} + +func TestRun_VersionDoesNotTriggerCleanup(t *testing.T) { + defer resetTestHooks() + os.Args = []string{"codex-wrapper", "--version"} + cleanupLogsFn = func() (CleanupStats, error) { + t.Fatalf("cleanup should not run for --version") + return CleanupStats{}, nil + } + + if code := run(); code != 0 { + t.Fatalf("exit = %d, want 0", code) + } +} + +func TestRunCleanupMode_Success(t *testing.T) { + defer resetTestHooks() + cleanupLogsFn = func() (CleanupStats, error) { + return CleanupStats{ + Scanned: 5, + Deleted: 3, + Kept: 2, + DeletedFiles: []string{"codex-wrapper-111.log", "codex-wrapper-222.log", "codex-wrapper-333.log"}, + KeptFiles: []string{"codex-wrapper-444.log", "codex-wrapper-555.log"}, + }, nil + } + + var exitCode int + output := captureOutput(t, func() { + exitCode = runCleanupMode() + }) + if exitCode != 0 { + t.Fatalf("exit = %d, want 0", exitCode) + } + want := "Cleanup completed\nFiles scanned: 5\nFiles deleted: 3\n - codex-wrapper-111.log\n - codex-wrapper-222.log\n - codex-wrapper-333.log\nFiles kept: 2\n - codex-wrapper-444.log\n - codex-wrapper-555.log\n" + if output != want { + t.Fatalf("output = %q, want %q", output, want) + } +} + +func TestRunCleanupMode_SuccessWithErrorsLine(t *testing.T) { + defer resetTestHooks() + cleanupLogsFn = func() (CleanupStats, error) { + return CleanupStats{ + Scanned: 2, + Deleted: 1, + Kept: 0, + Errors: 1, + DeletedFiles: []string{"codex-wrapper-123.log"}, + }, nil + } + + var exitCode int + output := captureOutput(t, func() { + exitCode = runCleanupMode() + }) + if exitCode != 0 { + t.Fatalf("exit = %d, want 0", exitCode) + } + want := "Cleanup completed\nFiles scanned: 2\nFiles deleted: 1\n - codex-wrapper-123.log\nFiles kept: 0\nDeletion errors: 1\n" + if output != want { + t.Fatalf("output = %q, want %q", output, want) + } +} + +func TestRunCleanupMode_ZeroStatsOutput(t *testing.T) { + defer resetTestHooks() + calls := 0 + cleanupLogsFn = func() (CleanupStats, error) { + calls++ + return CleanupStats{}, nil + } + + var exitCode int + output := captureOutput(t, func() { + exitCode = runCleanupMode() + }) + if exitCode != 0 { + t.Fatalf("exit = %d, want 0", exitCode) + } + want := "Cleanup completed\nFiles scanned: 0\nFiles deleted: 0\nFiles kept: 0\n" + if output != want { + t.Fatalf("output = %q, want %q", output, want) + } + if calls != 1 { + t.Fatalf("cleanup called %d times, want 1", calls) + } +} + +func TestRunCleanupMode_Error(t *testing.T) { + defer resetTestHooks() + cleanupLogsFn = func() (CleanupStats, error) { + return CleanupStats{}, fmt.Errorf("boom") + } + + var exitCode int + errOutput := captureStderr(t, func() { + exitCode = runCleanupMode() + }) + if exitCode != 1 { + t.Fatalf("exit = %d, want 1", exitCode) + } + if !strings.Contains(errOutput, "Cleanup failed") || !strings.Contains(errOutput, "boom") { + t.Fatalf("stderr = %q, want error message", errOutput) + } +} + +func TestRunCleanupMode_MissingFn(t *testing.T) { + defer resetTestHooks() + cleanupLogsFn = nil + + var exitCode int + errOutput := captureStderr(t, func() { + exitCode = runCleanupMode() + }) + if exitCode != 1 { + t.Fatalf("exit = %d, want 1", exitCode) + } + if !strings.Contains(errOutput, "log cleanup function not configured") { + t.Fatalf("stderr = %q, want missing-fn message", errOutput) + } +} + +func TestRun_CleanupFlag(t *testing.T) { + defer resetTestHooks() + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + + os.Args = []string{"codex-wrapper", "--cleanup"} + + calls := 0 + cleanupLogsFn = func() (CleanupStats, error) { + calls++ + return CleanupStats{Scanned: 1, Deleted: 1}, nil + } + + var exitCode int + output := captureOutput(t, func() { + exitCode = run() + }) + if exitCode != 0 { + t.Fatalf("exit = %d, want 0", exitCode) + } + if calls != 1 { + t.Fatalf("cleanup called %d times, want 1", calls) + } + want := "Cleanup completed\nFiles scanned: 1\nFiles deleted: 1\nFiles kept: 0\n" + if output != want { + t.Fatalf("output = %q, want %q", output, want) + } + if logger := activeLogger(); logger != nil { + t.Fatalf("logger should not initialize for --cleanup mode") + } +} + func TestRun_NoArgs(t *testing.T) { defer resetTestHooks() os.Args = []string{"codeagent-wrapper"} @@ -1756,7 +2814,7 @@ func TestRun_LoggerRemovedOnSignal(t *testing.T) { defer signal.Reset(syscall.SIGINT, syscall.SIGTERM) // Set shorter delays for faster test - forceKillDelay = 1 + forceKillDelay.Store(1) tempDir := t.TempDir() t.Setenv("TMPDIR", tempDir) @@ -1825,13 +2883,64 @@ func TestRun_CleanupHookAlwaysCalled(t *testing.T) { } } +func TestRunStartupCleanupNil(t *testing.T) { + defer resetTestHooks() + cleanupLogsFn = nil + runStartupCleanup() +} + +func TestRunStartupCleanupErrorLogged(t *testing.T) { + defer resetTestHooks() + + logger, err := NewLoggerWithSuffix("startup-error") + if err != nil { + t.Fatalf("failed to create logger: %v", err) + } + setLogger(logger) + t.Cleanup(func() { + logger.Flush() + logger.Close() + os.Remove(logger.Path()) + }) + + cleanupLogsFn = func() (CleanupStats, error) { + return CleanupStats{}, errors.New("zapped") + } + + runStartupCleanup() +} + +func TestRun_CleanupFailureDoesNotBlock(t *testing.T) { + defer resetTestHooks() + stdout := captureStdoutPipe() + defer restoreStdoutPipe(stdout) + + cleanupCalled := 0 + cleanupLogsFn = func() (CleanupStats, error) { + cleanupCalled++ + panic("boom") + } + + codexCommand = createFakeCodexScript(t, "tid-cleanup", "ok") + stdinReader = strings.NewReader("") + isTerminalFn = func() bool { return true } + os.Args = []string{"codex-wrapper", "task"} + + if exit := run(); exit != 0 { + t.Fatalf("exit = %d, want 0", exit) + } + if cleanupCalled != 1 { + t.Fatalf("cleanup called %d times, want 1", cleanupCalled) + } +} + // Coverage helper reused by logger_test to keep focused runs exercising core paths. -func TestParseJSONStream_CoverageSuite(t *testing.T) { +func TestRunParseJSONStream_CoverageSuite(t *testing.T) { suite := []struct { name string fn func(*testing.T) }{ - {"TestParseJSONStream", TestParseJSONStream}, + {"TestRunParseJSONStream", TestRunParseJSONStream}, {"TestRunNormalizeText", TestRunNormalizeText}, {"TestRunTruncate", TestRunTruncate}, {"TestRunMin", TestRunMin}, @@ -1843,30 +2952,326 @@ func TestParseJSONStream_CoverageSuite(t *testing.T) { } } -func TestHello(t *testing.T) { +func TestRunHello(t *testing.T) { if got := hello(); got != "hello world" { t.Fatalf("hello() = %q, want %q", got, "hello world") } } -func TestGreet(t *testing.T) { +func TestRunGreet(t *testing.T) { if got := greet("Linus"); got != "hello Linus" { t.Fatalf("greet() = %q, want %q", got, "hello Linus") } } -func TestFarewell(t *testing.T) { +func TestRunFarewell(t *testing.T) { if got := farewell("Linus"); got != "goodbye Linus" { t.Fatalf("farewell() = %q, want %q", got, "goodbye Linus") } } -func TestFarewellEmpty(t *testing.T) { +func TestRunFarewellEmpty(t *testing.T) { if got := farewell(""); got != "goodbye " { t.Fatalf("farewell(\"\") = %q, want %q", got, "goodbye ") } } +func TestRunTailBuffer(t *testing.T) { + tb := &tailBuffer{limit: 5} + if n, err := tb.Write([]byte("abcd")); err != nil || n != 4 { + t.Fatalf("Write returned (%d, %v), want (4, nil)", n, err) + } + if n, err := tb.Write([]byte("efg")); err != nil || n != 3 { + t.Fatalf("Write returned (%d, %v), want (3, nil)", n, err) + } + if got := tb.String(); got != "cdefg" { + t.Fatalf("tail buffer = %q, want %q", got, "cdefg") + } + if n, err := tb.Write([]byte("0123456")); err != nil || n != 7 { + t.Fatalf("Write returned (%d, %v), want (7, nil)", n, err) + } + if got := tb.String(); got != "23456" { + t.Fatalf("tail buffer = %q, want %q", got, "23456") + } +} + +func TestRunLogWriter(t *testing.T) { + defer resetTestHooks() + logger, err := NewLoggerWithSuffix("logwriter") + if err != nil { + t.Fatalf("failed to create logger: %v", err) + } + setLogger(logger) + + lw := newLogWriter("TEST: ", 10) + if _, err := lw.Write([]byte("hello\n")); err != nil { + t.Fatalf("write hello failed: %v", err) + } + if _, err := lw.Write([]byte("world-is-long")); err != nil { + t.Fatalf("write world failed: %v", err) + } + lw.Flush() + + logger.Flush() + logger.Close() + + data, err := os.ReadFile(logger.Path()) + if err != nil { + t.Fatalf("failed to read log file: %v", err) + } + text := string(data) + if !strings.Contains(text, "TEST: hello") { + t.Fatalf("log missing hello entry: %s", text) + } + if !strings.Contains(text, "TEST: world-i...") { + t.Fatalf("log missing truncated entry: %s", text) + } + os.Remove(logger.Path()) +} + +func TestNewLogWriterDefaultLimit(t *testing.T) { + lw := newLogWriter("TEST: ", 0) + if lw.maxLen != codexLogLineLimit { + t.Fatalf("newLogWriter maxLen = %d, want %d", lw.maxLen, codexLogLineLimit) + } + lw = newLogWriter("TEST: ", -5) + if lw.maxLen != codexLogLineLimit { + t.Fatalf("negative maxLen should default, got %d", lw.maxLen) + } +} + +func TestRunDiscardInvalidJSON(t *testing.T) { + reader := bufio.NewReader(strings.NewReader("bad line\n{\"type\":\"ok\"}\n")) + next, err := discardInvalidJSON(nil, reader) + if err != nil { + t.Fatalf("discardInvalidJSON error: %v", err) + } + line, err := next.ReadString('\n') + if err != nil { + t.Fatalf("failed to read next line: %v", err) + } + if strings.TrimSpace(line) != `{"type":"ok"}` { + t.Fatalf("unexpected remaining line: %q", line) + } + + t.Run("no newline", func(t *testing.T) { + reader := bufio.NewReader(strings.NewReader("partial")) + decoder := json.NewDecoder(strings.NewReader("")) + if _, err := discardInvalidJSON(decoder, reader); !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF when no newline, got %v", err) + } + }) +} + +func TestRunForwardSignals(t *testing.T) { + defer resetTestHooks() + + if runtime.GOOS == "windows" { + t.Skip("sleep command not available on Windows") + } + + cmd := exec.Command("sleep", "5") + if err := cmd.Start(); err != nil { + t.Skipf("unable to start sleep command: %v", err) + } + defer func() { + _ = cmd.Process.Kill() + cmd.Wait() + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + forceKillDelay.Store(0) + defer forceKillDelay.Store(5) + + ready := make(chan struct{}) + var captured chan<- os.Signal + signalNotifyFn = func(ch chan<- os.Signal, sig ...os.Signal) { + captured = ch + close(ready) + } + signalStopFn = func(ch chan<- os.Signal) {} + defer func() { + signalNotifyFn = signal.Notify + signalStopFn = signal.Stop + }() + + var mu sync.Mutex + var logs []string + forwardSignals(ctx, cmd, func(msg string) { + mu.Lock() + defer mu.Unlock() + logs = append(logs, msg) + }) + + select { + case <-ready: + case <-time.After(500 * time.Millisecond): + t.Fatalf("signalNotifyFn not invoked") + } + + captured <- syscall.SIGINT + + done := make(chan error, 1) + go func() { done <- cmd.Wait() }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("process did not exit after forwarded signal") + } + + mu.Lock() + defer mu.Unlock() + if len(logs) == 0 { + t.Fatalf("expected log entry for forwarded signal") + } +} + +func TestRunNonParallelPrintsLogPath(t *testing.T) { + defer resetTestHooks() + + tempDir := t.TempDir() + t.Setenv("TMPDIR", tempDir) + + os.Args = []string{"codex-wrapper", "do-stuff"} + stdinReader = strings.NewReader("") + isTerminalFn = func() bool { return true } + codexCommand = "echo" + buildCodexArgsFn = func(cfg *Config, targetArg string) []string { + return []string{`{"type":"thread.started","thread_id":"cli-session"}` + "\n" + `{"type":"item.completed","item":{"type":"agent_message","text":"ok"}}`} + } + + var exitCode int + stderr := captureStderr(t, func() { + _ = captureOutput(t, func() { + exitCode = run() + }) + }) + if exitCode != 0 { + t.Fatalf("run() exit = %d, want 0", exitCode) + } + expectedLog := filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d.log", os.Getpid())) + wantLine := fmt.Sprintf("Log: %s", expectedLog) + if !strings.Contains(stderr, wantLine) { + t.Fatalf("stderr missing %q, got: %q", wantLine, stderr) + } +} + +func TestRealProcessNilSafety(t *testing.T) { + var proc *realProcess + if pid := proc.Pid(); pid != 0 { + t.Fatalf("Pid() = %d, want 0", pid) + } + if err := proc.Kill(); err != nil { + t.Fatalf("Kill() error = %v", err) + } + if err := proc.Signal(syscall.SIGTERM); err != nil { + t.Fatalf("Signal() error = %v", err) + } +} + +func TestRealProcessKill(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("sleep command not available on Windows") + } + + cmd := exec.Command("sleep", "5") + if err := cmd.Start(); err != nil { + t.Skipf("unable to start sleep command: %v", err) + } + waited := false + defer func() { + if waited { + return + } + if cmd.Process != nil { + _ = cmd.Process.Kill() + cmd.Wait() + } + }() + + proc := &realProcess{proc: cmd.Process} + if proc.Pid() == 0 { + t.Fatalf("Pid() returned 0 for active process") + } + if err := proc.Kill(); err != nil { + t.Fatalf("Kill() error = %v", err) + } + waitErr := cmd.Wait() + waited = true + if waitErr == nil { + t.Fatalf("Kill() should lead to non-nil wait error") + } +} + +func TestRealProcessSignal(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("sleep command not available on Windows") + } + + cmd := exec.Command("sleep", "5") + if err := cmd.Start(); err != nil { + t.Skipf("unable to start sleep command: %v", err) + } + waited := false + defer func() { + if waited { + return + } + if cmd.Process != nil { + _ = cmd.Process.Kill() + cmd.Wait() + } + }() + + proc := &realProcess{proc: cmd.Process} + if err := proc.Signal(syscall.SIGTERM); err != nil { + t.Fatalf("Signal() error = %v", err) + } + waitErr := cmd.Wait() + waited = true + if waitErr == nil { + t.Fatalf("Signal() should lead to non-nil wait error") + } +} + +func TestRealCmdProcess(t *testing.T) { + rc := &realCmd{} + if rc.Process() != nil { + t.Fatalf("Process() should return nil when realCmd has no command") + } + rc = &realCmd{cmd: &exec.Cmd{}} + if rc.Process() != nil { + t.Fatalf("Process() should return nil when exec.Cmd has no process") + } + + if runtime.GOOS == "windows" { + return + } + + cmd := exec.Command("sleep", "5") + if err := cmd.Start(); err != nil { + t.Skipf("unable to start sleep command: %v", err) + } + defer func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + cmd.Wait() + } + }() + + rc = &realCmd{cmd: cmd} + handle := rc.Process() + if handle == nil { + t.Fatalf("expected non-nil process handle") + } + if pid := handle.Pid(); pid == 0 { + t.Fatalf("process handle returned pid=0") + } +} + func TestRun_CLI_Success(t *testing.T) { defer resetTestHooks() os.Args = []string{"codeagent-wrapper", "do-things"} diff --git a/codeagent-wrapper/process_check_test.go b/codeagent-wrapper/process_check_test.go new file mode 100644 index 0000000..9ad661e --- /dev/null +++ b/codeagent-wrapper/process_check_test.go @@ -0,0 +1,217 @@ +//go:build unix || darwin || linux +// +build unix darwin linux + +package main + +import ( + "errors" + "fmt" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + "testing" + "time" +) + +func TestIsProcessRunning(t *testing.T) { + t.Run("current process", func(t *testing.T) { + if !isProcessRunning(os.Getpid()) { + t.Fatalf("expected current process (pid=%d) to be running", os.Getpid()) + } + }) + + t.Run("fake pid", func(t *testing.T) { + const nonexistentPID = 1 << 30 + if isProcessRunning(nonexistentPID) { + t.Fatalf("expected pid %d to be reported as not running", nonexistentPID) + } + }) + + t.Run("terminated process", func(t *testing.T) { + pid := exitedProcessPID(t) + if isProcessRunning(pid) { + t.Fatalf("expected exited child process (pid=%d) to be reported as not running", pid) + } + }) + + t.Run("boundary values", func(t *testing.T) { + if isProcessRunning(0) { + t.Fatalf("pid 0 should never be treated as running") + } + if isProcessRunning(-42) { + t.Fatalf("negative pid should never be treated as running") + } + }) + + t.Run("find process error", func(t *testing.T) { + original := findProcess + defer func() { findProcess = original }() + + mockErr := errors.New("findProcess failure") + findProcess = func(pid int) (*os.Process, error) { + return nil, mockErr + } + + if isProcessRunning(1234) { + t.Fatalf("expected false when os.FindProcess fails") + } + }) +} + +func exitedProcessPID(t *testing.T) int { + t.Helper() + + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + cmd = exec.Command("cmd", "/c", "exit 0") + } else { + cmd = exec.Command("sh", "-c", "exit 0") + } + + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start helper process: %v", err) + } + pid := cmd.Process.Pid + + if err := cmd.Wait(); err != nil { + t.Fatalf("helper process did not exit cleanly: %v", err) + } + + time.Sleep(50 * time.Millisecond) + return pid +} + +func TestRunProcessCheckSmoke(t *testing.T) { + t.Run("current process", func(t *testing.T) { + if !isProcessRunning(os.Getpid()) { + t.Fatalf("expected current process (pid=%d) to be running", os.Getpid()) + } + }) + + t.Run("fake pid", func(t *testing.T) { + const nonexistentPID = 1 << 30 + if isProcessRunning(nonexistentPID) { + t.Fatalf("expected pid %d to be reported as not running", nonexistentPID) + } + }) + + t.Run("boundary values", func(t *testing.T) { + if isProcessRunning(0) { + t.Fatalf("pid 0 should never be treated as running") + } + if isProcessRunning(-42) { + t.Fatalf("negative pid should never be treated as running") + } + }) + + t.Run("find process error", func(t *testing.T) { + original := findProcess + defer func() { findProcess = original }() + + mockErr := errors.New("findProcess failure") + findProcess = func(pid int) (*os.Process, error) { + return nil, mockErr + } + + if isProcessRunning(1234) { + t.Fatalf("expected false when os.FindProcess fails") + } + }) +} + +func TestGetProcessStartTimeReadsProcStat(t *testing.T) { + pid := 4321 + boot := time.Unix(1_710_000_000, 0) + startTicks := uint64(4500) + + statFields := make([]string, 25) + for i := range statFields { + statFields[i] = strconv.Itoa(i + 1) + } + statFields[19] = strconv.FormatUint(startTicks, 10) + statContent := fmt.Sprintf("%d (%s) %s", pid, "cmd with space", strings.Join(statFields, " ")) + + stubReadFile(t, func(path string) ([]byte, error) { + switch path { + case fmt.Sprintf("/proc/%d/stat", pid): + return []byte(statContent), nil + case "/proc/stat": + return []byte(fmt.Sprintf("cpu 0 0 0 0\nbtime %d\n", boot.Unix())), nil + default: + return nil, os.ErrNotExist + } + }) + + got := getProcessStartTime(pid) + want := boot.Add(time.Duration(startTicks/100) * time.Second) + if !got.Equal(want) { + t.Fatalf("getProcessStartTime() = %v, want %v", got, want) + } +} + +func TestGetProcessStartTimeInvalidData(t *testing.T) { + pid := 99 + stubReadFile(t, func(path string) ([]byte, error) { + switch path { + case fmt.Sprintf("/proc/%d/stat", pid): + return []byte("garbage"), nil + case "/proc/stat": + return []byte("btime not-a-number\n"), nil + default: + return nil, os.ErrNotExist + } + }) + + if got := getProcessStartTime(pid); !got.IsZero() { + t.Fatalf("invalid /proc data should return zero time, got %v", got) + } +} + +func TestGetBootTimeParsesBtime(t *testing.T) { + const bootSec = 1_711_111_111 + stubReadFile(t, func(path string) ([]byte, error) { + if path != "/proc/stat" { + return nil, os.ErrNotExist + } + content := fmt.Sprintf("intr 0\nbtime %d\n", bootSec) + return []byte(content), nil + }) + + got := getBootTime() + want := time.Unix(bootSec, 0) + if !got.Equal(want) { + t.Fatalf("getBootTime() = %v, want %v", got, want) + } +} + +func TestGetBootTimeInvalidData(t *testing.T) { + cases := []struct { + name string + content string + }{ + {"missing", "cpu 0 0 0 0"}, + {"malformed", "btime abc"}, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + stubReadFile(t, func(string) ([]byte, error) { + return []byte(tt.content), nil + }) + if got := getBootTime(); !got.IsZero() { + t.Fatalf("getBootTime() unexpected value for %s: %v", tt.name, got) + } + }) + } +} + +func stubReadFile(t *testing.T, fn func(string) ([]byte, error)) { + t.Helper() + original := readFileFn + readFileFn = fn + t.Cleanup(func() { + readFileFn = original + }) +} diff --git a/codeagent-wrapper/process_check_unix.go b/codeagent-wrapper/process_check_unix.go new file mode 100644 index 0000000..c235d65 --- /dev/null +++ b/codeagent-wrapper/process_check_unix.go @@ -0,0 +1,104 @@ +//go:build unix || darwin || linux +// +build unix darwin linux + +package main + +import ( + "errors" + "fmt" + "os" + "strconv" + "strings" + "syscall" + "time" +) + +var findProcess = os.FindProcess +var readFileFn = os.ReadFile + +// isProcessRunning returns true if a process with the given pid is running on Unix-like systems. +func isProcessRunning(pid int) bool { + if pid <= 0 { + return false + } + + proc, err := findProcess(pid) + if err != nil || proc == nil { + return false + } + + err = proc.Signal(syscall.Signal(0)) + if err != nil && (errors.Is(err, syscall.ESRCH) || errors.Is(err, os.ErrProcessDone)) { + return false + } + return true +} + +// getProcessStartTime returns the start time of a process on Unix-like systems. +// Returns zero time if the start time cannot be determined. +func getProcessStartTime(pid int) time.Time { + if pid <= 0 { + return time.Time{} + } + + // Read /proc//stat to get process start time + statPath := fmt.Sprintf("/proc/%d/stat", pid) + data, err := readFileFn(statPath) + if err != nil { + return time.Time{} + } + + // Parse stat file: fields are space-separated, but comm (field 2) can contain spaces + // Find the last ')' to skip comm field safely + content := string(data) + lastParen := strings.LastIndex(content, ")") + if lastParen == -1 { + return time.Time{} + } + + fields := strings.Fields(content[lastParen+1:]) + if len(fields) < 20 { + return time.Time{} + } + + // Field 22 (index 19 after comm) is starttime in clock ticks since boot + startTicks, err := strconv.ParseUint(fields[19], 10, 64) + if err != nil { + return time.Time{} + } + + // Get system boot time + bootTime := getBootTime() + if bootTime.IsZero() { + return time.Time{} + } + + // Convert ticks to duration (typically 100 ticks/sec on most systems) + ticksPerSec := uint64(100) // sysconf(_SC_CLK_TCK), typically 100 + startTime := bootTime.Add(time.Duration(startTicks/ticksPerSec) * time.Second) + + return startTime +} + +// getBootTime returns the system boot time by reading /proc/stat. +func getBootTime() time.Time { + data, err := readFileFn("/proc/stat") + if err != nil { + return time.Time{} + } + + lines := strings.Split(string(data), "\n") + for _, line := range lines { + if strings.HasPrefix(line, "btime ") { + fields := strings.Fields(line) + if len(fields) >= 2 { + bootSec, err := strconv.ParseInt(fields[1], 10, 64) + if err == nil { + return time.Unix(bootSec, 0) + } + } + } + } + + return time.Time{} +} diff --git a/codeagent-wrapper/process_check_windows.go b/codeagent-wrapper/process_check_windows.go new file mode 100644 index 0000000..ada5e1c --- /dev/null +++ b/codeagent-wrapper/process_check_windows.go @@ -0,0 +1,87 @@ +//go:build windows +// +build windows + +package main + +import ( + "errors" + "os" + "syscall" + "time" + "unsafe" +) + +const ( + processQueryLimitedInformation = 0x1000 + stillActive = 259 // STILL_ACTIVE exit code +) + +var ( + findProcess = os.FindProcess + kernel32 = syscall.NewLazyDLL("kernel32.dll") + getProcessTimes = kernel32.NewProc("GetProcessTimes") + fileTimeToUnixFn = fileTimeToUnix +) + +// isProcessRunning returns true if a process with the given pid is running on Windows. +func isProcessRunning(pid int) bool { + if pid <= 0 { + return false + } + + if _, err := findProcess(pid); err != nil { + return false + } + + handle, err := syscall.OpenProcess(processQueryLimitedInformation, false, uint32(pid)) + if err != nil { + if errors.Is(err, syscall.ERROR_ACCESS_DENIED) { + return true + } + return false + } + defer syscall.CloseHandle(handle) + + var exitCode uint32 + if err := syscall.GetExitCodeProcess(handle, &exitCode); err != nil { + return true + } + + return exitCode == stillActive +} + +// getProcessStartTime returns the start time of a process on Windows. +// Returns zero time if the start time cannot be determined. +func getProcessStartTime(pid int) time.Time { + if pid <= 0 { + return time.Time{} + } + + handle, err := syscall.OpenProcess(processQueryLimitedInformation, false, uint32(pid)) + if err != nil { + return time.Time{} + } + defer syscall.CloseHandle(handle) + + var creationTime, exitTime, kernelTime, userTime syscall.Filetime + ret, _, _ := getProcessTimes.Call( + uintptr(handle), + uintptr(unsafe.Pointer(&creationTime)), + uintptr(unsafe.Pointer(&exitTime)), + uintptr(unsafe.Pointer(&kernelTime)), + uintptr(unsafe.Pointer(&userTime)), + ) + + if ret == 0 { + return time.Time{} + } + + return fileTimeToUnixFn(creationTime) +} + +// fileTimeToUnix converts Windows FILETIME to Unix time. +func fileTimeToUnix(ft syscall.Filetime) time.Time { + // FILETIME is 100-nanosecond intervals since January 1, 1601 UTC + nsec := ft.Nanoseconds() + return time.Unix(0, nsec) +} diff --git a/install.bat b/install.bat new file mode 100644 index 0000000..3640009 --- /dev/null +++ b/install.bat @@ -0,0 +1,163 @@ +@echo off +setlocal enabledelayedexpansion + +set "EXIT_CODE=0" +set "REPO=cexll/myclaude" +set "VERSION=latest" +set "OS=windows" + +call :detect_arch +if errorlevel 1 goto :fail + +set "BINARY_NAME=codex-wrapper-%OS%-%ARCH%.exe" +set "URL=https://github.com/%REPO%/releases/%VERSION%/download/%BINARY_NAME%" +set "TEMP_FILE=%TEMP%\codex-wrapper-%ARCH%-%RANDOM%.exe" +set "DEST_DIR=%USERPROFILE%\bin" +set "DEST=%DEST_DIR%\codex-wrapper.exe" + +echo Downloading codex-wrapper for %ARCH% ... +echo %URL% +call :download +if errorlevel 1 goto :fail + +if not exist "%TEMP_FILE%" ( + echo ERROR: download failed to produce "%TEMP_FILE%". + goto :fail +) + +echo Installing to "%DEST%" ... +if not exist "%DEST_DIR%" ( + mkdir "%DEST_DIR%" >nul 2>nul || goto :fail +) + +move /y "%TEMP_FILE%" "%DEST%" >nul 2>nul +if errorlevel 1 ( + echo ERROR: unable to place file in "%DEST%". + goto :fail +) + +"%DEST%" --version >nul 2>nul +if errorlevel 1 ( + echo ERROR: installation verification failed. + goto :fail +) + +echo. +echo codex-wrapper installed successfully at: +echo %DEST% + +rem Automatically ensure %USERPROFILE%\bin is in the USER (HKCU) PATH +rem 1) Read current user PATH from registry (REG_SZ or REG_EXPAND_SZ) +set "USER_PATH_RAW=" +set "USER_PATH_TYPE=" +for /f "tokens=1,2,*" %%A in ('reg query "HKCU\Environment" /v Path 2^>nul ^| findstr /I /R "^ *Path *REG_"') do ( + set "USER_PATH_TYPE=%%B" + set "USER_PATH_RAW=%%C" +) +rem Trim leading spaces from USER_PATH_RAW +for /f "tokens=* delims= " %%D in ("!USER_PATH_RAW!") do set "USER_PATH_RAW=%%D" + +rem Normalize DEST_DIR by removing a trailing backslash if present +if "!DEST_DIR:~-1!"=="\" set "DEST_DIR=!DEST_DIR:~0,-1!" + +rem Build search tokens (expanded and literal) +set "PCT=%%" +set "SEARCH_EXP=;!DEST_DIR!;" +set "SEARCH_EXP2=;!DEST_DIR!\;" +set "SEARCH_LIT=;!PCT!USERPROFILE!PCT!\bin;" +set "SEARCH_LIT2=;!PCT!USERPROFILE!PCT!\bin\;" + +rem Prepare user PATH variants for containment tests +set "CHECK_RAW=;!USER_PATH_RAW!;" +set "USER_PATH_EXP=!USER_PATH_RAW!" +if defined USER_PATH_EXP call set "USER_PATH_EXP=%%USER_PATH_EXP%%" +set "CHECK_EXP=;!USER_PATH_EXP!;" + +rem Check if already present in user PATH (literal or expanded, with/without trailing backslash) +set "ALREADY_IN_USERPATH=0" +echo !CHECK_RAW! | findstr /I /C:"!SEARCH_LIT!" /C:"!SEARCH_LIT2!" >nul && set "ALREADY_IN_USERPATH=1" +if "!ALREADY_IN_USERPATH!"=="0" ( + echo !CHECK_EXP! | findstr /I /C:"!SEARCH_EXP!" /C:"!SEARCH_EXP2!" >nul && set "ALREADY_IN_USERPATH=1" +) + +if "!ALREADY_IN_USERPATH!"=="1" ( + echo User PATH already includes %%USERPROFILE%%\bin. +) else ( + rem Not present: append to user PATH using setx without duplicating system PATH + if defined USER_PATH_RAW ( + set "USER_PATH_NEW=!USER_PATH_RAW!" + if not "!USER_PATH_NEW:~-1!"==";" set "USER_PATH_NEW=!USER_PATH_NEW!;" + set "USER_PATH_NEW=!USER_PATH_NEW!!PCT!USERPROFILE!PCT!\bin" + ) else ( + set "USER_PATH_NEW=!PCT!USERPROFILE!PCT!\bin" + ) + rem Persist update to HKCU\Environment\Path (user scope) + setx PATH "!USER_PATH_NEW!" >nul + if errorlevel 1 ( + echo WARNING: Failed to append %%USERPROFILE%%\bin to your user PATH. + ) else ( + echo Added %%USERPROFILE%%\bin to your user PATH. + ) +) + +rem Update current session PATH so codex-wrapper is immediately available +set "CURPATH=;%PATH%;" +echo !CURPATH! | findstr /I /C:"!SEARCH_EXP!" /C:"!SEARCH_EXP2!" /C:"!SEARCH_LIT!" /C:"!SEARCH_LIT2!" >nul +if errorlevel 1 set "PATH=!DEST_DIR!;!PATH!" + +goto :cleanup + +:detect_arch +set "ARCH=%PROCESSOR_ARCHITECTURE%" +if defined PROCESSOR_ARCHITEW6432 set "ARCH=%PROCESSOR_ARCHITEW6432%" + +if /I "%ARCH%"=="AMD64" ( + set "ARCH=amd64" + exit /b 0 +) else if /I "%ARCH%"=="ARM64" ( + set "ARCH=arm64" + exit /b 0 +) else ( + echo ERROR: unsupported architecture "%ARCH%". 64-bit Windows on AMD64 or ARM64 is required. + set "EXIT_CODE=1" + exit /b 1 +) + +:download +where curl >nul 2>nul +if %errorlevel%==0 ( + echo Using curl ... + curl -fL --retry 3 --connect-timeout 10 "%URL%" -o "%TEMP_FILE%" + if errorlevel 1 ( + echo ERROR: curl download failed. + set "EXIT_CODE=1" + exit /b 1 + ) + exit /b 0 +) + +where powershell >nul 2>nul +if %errorlevel%==0 ( + echo Using PowerShell ... + powershell -NoLogo -NoProfile -Command " $ErrorActionPreference='Stop'; try { [Net.ServicePointManager]::SecurityProtocol = [Net.ServicePointManager]::SecurityProtocol -bor 3072 -bor 768 -bor 192 } catch {} ; $wc = New-Object System.Net.WebClient; $wc.DownloadFile('%URL%','%TEMP_FILE%') " + if errorlevel 1 ( + echo ERROR: PowerShell download failed. + set "EXIT_CODE=1" + exit /b 1 + ) + exit /b 0 +) + +echo ERROR: neither curl nor PowerShell is available to download the installer. +set "EXIT_CODE=1" +exit /b 1 + +:fail +echo Installation failed. +set "EXIT_CODE=1" +goto :cleanup + +:cleanup +if exist "%TEMP_FILE%" del /f /q "%TEMP_FILE%" >nul 2>nul +set "CODE=%EXIT_CODE%" +endlocal & exit /b %CODE% diff --git a/install.py b/install.py index ffae828..e074bf1 100644 --- a/install.py +++ b/install.py @@ -332,6 +332,8 @@ def op_run_command(op: Dict[str, Any], ctx: Dict[str, Any]) -> None: env[key] = value.replace("${install_dir}", str(ctx["install_dir"])) command = op.get("command", "") + if sys.platform == "win32" and command.strip() == "bash install.sh": + command = "cmd /c install.bat" result = subprocess.run( command, shell=True,