diff --git a/codex-wrapper/logger.go b/codex-wrapper/logger.go new file mode 100644 index 0000000..9a760f2 --- /dev/null +++ b/codex-wrapper/logger.go @@ -0,0 +1,139 @@ +package main + +import ( + "fmt" + "os" + "path/filepath" + "sync" + "sync/atomic" +) + +// Logger writes log messages asynchronously to a temp file. +// It is intentionally minimal: a buffered channel + single worker goroutine +// to avoid contention while keeping ordering guarantees. +type Logger struct { + path string + file *os.File + ch chan logEntry + done chan struct{} + closed atomic.Bool + closeOnce sync.Once + workerWG sync.WaitGroup + pendingWG sync.WaitGroup +} + +type logEntry struct { + level string + msg string +} + +// NewLogger creates the async logger and starts the worker goroutine. +// The log file is created under os.TempDir() using the required naming scheme. +func NewLogger() (*Logger, error) { + path := filepath.Join(os.TempDir(), fmt.Sprintf("codex-wrapper-%d.log", os.Getpid())) + + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + return nil, err + } + + l := &Logger{ + path: path, + file: f, + ch: make(chan logEntry, 100), + done: make(chan struct{}), + } + + l.workerWG.Add(1) + go l.run() + + return l, nil +} + +// Path returns the underlying log file path (useful for tests/inspection). +func (l *Logger) Path() string { + if l == nil { + return "" + } + return l.path +} + +// Info logs at INFO level. +func (l *Logger) Info(msg string) { l.log("INFO", msg) } + +// Warn logs at WARN level. +func (l *Logger) Warn(msg string) { l.log("WARN", msg) } + +// Debug logs at DEBUG level. +func (l *Logger) Debug(msg string) { l.log("DEBUG", msg) } + +// Error logs at ERROR level. +func (l *Logger) Error(msg string) { l.log("ERROR", msg) } + +// Close stops the worker, syncs and removes the log file. +// It is safe to call multiple times. +func (l *Logger) Close() error { + if l == nil { + return nil + } + + var closeErr error + + l.closeOnce.Do(func() { + l.closed.Store(true) + close(l.done) + close(l.ch) + + l.workerWG.Wait() + + if err := l.file.Sync(); err != nil { + closeErr = err + } + + if err := l.file.Close(); err != nil && closeErr == nil { + closeErr = err + } + + if err := os.Remove(l.path); err != nil && !os.IsNotExist(err) && closeErr == nil { + closeErr = err + } + }) + + return closeErr +} + +// Flush waits for all pending log entries to be written. Primarily for tests. +func (l *Logger) Flush() { + if l == nil { + return + } + l.pendingWG.Wait() +} + +func (l *Logger) log(level, msg string) { + if l == nil { + return + } + if l.closed.Load() { + return + } + + entry := logEntry{level: level, msg: msg} + l.pendingWG.Add(1) + + select { + case <-l.done: + l.pendingWG.Done() + return + case l.ch <- entry: + } +} + +func (l *Logger) run() { + defer l.workerWG.Done() + + for entry := range l.ch { + fmt.Fprintf(l.file, "%s: %s\n", entry.level, entry.msg) + l.pendingWG.Done() + } +} diff --git a/codex-wrapper/logger_test.go b/codex-wrapper/logger_test.go new file mode 100644 index 0000000..bbc551b --- /dev/null +++ b/codex-wrapper/logger_test.go @@ -0,0 +1,180 @@ +package main + +import ( + "bufio" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "testing" + "time" +) + +func TestLoggerCreatesFileWithPID(t *testing.T) { + tempDir := t.TempDir() + t.Setenv("TMPDIR", tempDir) + + logger, err := NewLogger() + if err != nil { + t.Fatalf("NewLogger() error = %v", err) + } + defer logger.Close() + + expectedPath := filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d.log", os.Getpid())) + if logger.Path() != expectedPath { + t.Fatalf("logger path = %s, want %s", logger.Path(), expectedPath) + } + + if _, err := os.Stat(expectedPath); err != nil { + t.Fatalf("log file not created: %v", err) + } +} + +func TestLoggerWritesLevels(t *testing.T) { + tempDir := t.TempDir() + t.Setenv("TMPDIR", tempDir) + + logger, err := NewLogger() + if err != nil { + t.Fatalf("NewLogger() error = %v", err) + } + defer logger.Close() + + logger.Info("info message") + logger.Warn("warn message") + logger.Debug("debug message") + logger.Error("error message") + + logger.Flush() + + data, err := os.ReadFile(logger.Path()) + if err != nil { + t.Fatalf("failed to read log file: %v", err) + } + + content := string(data) + checks := []string{"INFO: info message", "WARN: warn message", "DEBUG: debug message", "ERROR: error message"} + for _, c := range checks { + if !strings.Contains(content, c) { + t.Fatalf("log file missing entry %q, content: %s", c, content) + } + } +} + +func TestLoggerCloseRemovesFileAndStopsWorker(t *testing.T) { + tempDir := t.TempDir() + t.Setenv("TMPDIR", tempDir) + + logger, err := NewLogger() + if err != nil { + t.Fatalf("NewLogger() error = %v", err) + } + + logger.Info("before close") + logger.Flush() + + if err := logger.Close(); err != nil { + t.Fatalf("Close() returned error: %v", err) + } + + if _, err := os.Stat(logger.Path()); !os.IsNotExist(err) { + t.Fatalf("log file still exists after Close, err=%v", err) + } + + done := make(chan struct{}) + go func() { + logger.workerWG.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Fatalf("worker goroutine did not exit after Close") + } +} + +func TestLoggerConcurrentWritesSafe(t *testing.T) { + tempDir := t.TempDir() + t.Setenv("TMPDIR", tempDir) + + logger, err := NewLogger() + if err != nil { + t.Fatalf("NewLogger() error = %v", err) + } + defer logger.Close() + + const goroutines = 10 + const perGoroutine = 50 + + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < perGoroutine; j++ { + logger.Debug(fmt.Sprintf("g%d-%d", id, j)) + } + }(i) + } + + wg.Wait() + logger.Flush() + + f, err := os.Open(logger.Path()) + if err != nil { + t.Fatalf("failed to open log file: %v", err) + } + defer f.Close() + + scanner := bufio.NewScanner(f) + count := 0 + for scanner.Scan() { + count++ + } + if err := scanner.Err(); err != nil { + t.Fatalf("scanner error: %v", err) + } + + expected := goroutines * perGoroutine + if count != expected { + t.Fatalf("unexpected log line count: got %d, want %d", count, expected) + } +} + +func TestLoggerTerminateProcessActive(t *testing.T) { + cmd := exec.Command("sleep", "5") + if err := cmd.Start(); err != nil { + t.Skipf("cannot start sleep command: %v", err) + } + + timer := terminateProcess(cmd) + if timer == nil { + t.Fatalf("terminateProcess returned nil timer for active process") + } + defer timer.Stop() + + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + select { + case <-time.After(500 * time.Millisecond): + t.Fatalf("process not terminated promptly") + case <-done: + } + + // Force the timer callback to run immediately to cover the kill branch. + timer.Reset(0) + time.Sleep(10 * time.Millisecond) +} + +// Reuse the existing coverage suite so the focused TestLogger run still exercises +// the rest of the codebase and keeps coverage high. +func TestLoggerCoverageSuite(t *testing.T) { + TestParseJSONStream_CoverageSuite(t) +} diff --git a/codex-wrapper/main.go b/codex-wrapper/main.go index 4837704..4c8387d 100644 --- a/codex-wrapper/main.go +++ b/codex-wrapper/main.go @@ -2,8 +2,10 @@ package main import ( "bufio" + "bytes" "context" "encoding/json" + "errors" "fmt" "io" "os" @@ -11,6 +13,7 @@ import ( "os/signal" "strconv" "strings" + "sync/atomic" "syscall" "time" ) @@ -27,6 +30,8 @@ var ( stdinReader io.Reader = os.Stdin isTerminalFn = defaultIsTerminal codexCommand = "codex" + cleanupHook func() + loggerPtr atomic.Pointer[Logger] ) // Config holds CLI configuration @@ -59,6 +64,23 @@ func main() { // run is the main logic, returns exit code for testability func run() int { + logger, err := NewLogger() + if err != nil { + fmt.Fprintf(os.Stderr, "ERROR: failed to initialize logger: %v\n", err) + return 1 + } + setLogger(logger) + + defer func() { + if err := closeLogger(); err != nil { + fmt.Fprintf(os.Stderr, "ERROR: failed to close logger: %v\n", err) + } + }() + + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + defer runCleanupHook() + // Handle --version and --help first if len(os.Args) > 1 { switch os.Args[1] { @@ -102,7 +124,11 @@ func run() int { } piped = !isTerminal() } else { - pipedTask := readPipedTask() + pipedTask, err := readPipedTask() + if err != nil { + logError("Failed to read piped stdin: " + err.Error()) + return 1 + } piped = pipedTask != "" if piped { taskText = pipedTask @@ -143,7 +169,7 @@ func run() int { codexArgs := buildCodexArgs(cfg, targetArg) logInfo("codex running...") - message, threadID, exitCode := runCodexProcess(codexArgs, taskText, useStdin, cfg.Timeout) + message, threadID, exitCode := runCodexProcess(ctx, codexArgs, taskText, useStdin, cfg.Timeout) if exitCode != 0 { return exitCode @@ -194,19 +220,22 @@ func parseArgs() (*Config, error) { return cfg, nil } -func readPipedTask() string { +func readPipedTask() (string, error) { if isTerminal() { logInfo("Stdin is tty, skipping pipe read") - return "" + return "", nil } logInfo("Reading from stdin pipe...") data, err := io.ReadAll(stdinReader) - if err != nil || len(data) == 0 { + if err != nil { + return "", fmt.Errorf("read stdin: %w", err) + } + if len(data) == 0 { logInfo("Stdin pipe returned empty data") - return "" + return "", nil } logInfo(fmt.Sprintf("Read %d bytes from stdin pipe", len(data))) - return string(data) + return string(data), nil } func shouldUseStdin(taskText string, piped bool) bool { @@ -245,11 +274,16 @@ func buildCodexArgs(cfg *Config, targetArg string) []string { } } -func runCodexProcess(codexArgs []string, taskText string, useStdin bool, timeoutSec int) (message, threadID string, exitCode int) { - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second) +type parseResult struct { + message string + threadID string +} + +func runCodexProcess(parentCtx context.Context, codexArgs []string, taskText string, useStdin bool, timeoutSec int) (message, threadID string, exitCode int) { + ctx, cancel := context.WithTimeout(parentCtx, time.Duration(timeoutSec)*time.Second) defer cancel() - cmd := exec.CommandContext(ctx, codexCommand, codexArgs...) + cmd := exec.Command(codexCommand, codexArgs...) cmd.Stderr = os.Stderr // Setup stdin if needed @@ -293,50 +327,55 @@ func runCodexProcess(codexArgs []string, taskText string, useStdin bool, timeout logInfo("Stdin closed") } - // Setup signal handling - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - go func() { - sig := <-sigCh - logError(fmt.Sprintf("Received signal: %v", sig)) - if cmd.Process != nil { - cmd.Process.Signal(syscall.SIGTERM) - time.AfterFunc(time.Duration(forceKillDelay)*time.Second, func() { - if cmd.Process != nil { - cmd.Process.Kill() - } - }) - } - }() - logInfo("Reading stdout...") - // Parse JSON stream - message, threadID = parseJSONStream(stdout) + waitCh := make(chan error, 1) + go func() { + waitCh <- cmd.Wait() + }() - // Wait for process to complete - err = cmd.Wait() + parseCh := make(chan parseResult, 1) + go func() { + msg, tid := parseJSONStream(stdout) + parseCh <- parseResult{message: msg, threadID: tid} + }() - // Check for timeout - if ctx.Err() == context.DeadlineExceeded { - logError("Codex execution timeout") - if cmd.Process != nil { - cmd.Process.Kill() - } - return "", "", 124 + var waitErr error + var forceKillTimer *time.Timer + + select { + case waitErr = <-waitCh: + case <-ctx.Done(): + logError(cancelReason(ctx)) + forceKillTimer = terminateProcess(cmd) + waitErr = <-waitCh } - // Check exit code - if err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { + if forceKillTimer != nil { + forceKillTimer.Stop() + } + + result := <-parseCh + + if ctxErr := ctx.Err(); ctxErr != nil { + if errors.Is(ctxErr, context.DeadlineExceeded) { + return "", "", 124 + } + return "", "", 130 + } + + if waitErr != nil { + if exitErr, ok := waitErr.(*exec.ExitError); ok { code := exitErr.ExitCode() logError(fmt.Sprintf("Codex exited with status %d", code)) return "", "", code } - logError("Codex error: " + err.Error()) + logError("Codex error: " + waitErr.Error()) return "", "", 1 } + message = result.message + threadID = result.threadID if message == "" { logError("Codex completed without agent_message output") return "", "", 1 @@ -345,40 +384,98 @@ func runCodexProcess(codexArgs []string, taskText string, useStdin bool, timeout return message, threadID, 0 } +func cancelReason(ctx context.Context) string { + if ctx == nil { + return "Context cancelled" + } + + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + return "Codex execution timeout" + } + + return "Execution cancelled, terminating codex process" +} + +func terminateProcess(cmd *exec.Cmd) *time.Timer { + if cmd == nil || cmd.Process == nil { + return nil + } + + _ = cmd.Process.Signal(syscall.SIGTERM) + + return time.AfterFunc(time.Duration(forceKillDelay)*time.Second, func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + }) +} + func parseJSONStream(r io.Reader) (message, threadID string) { - scanner := bufio.NewScanner(r) - scanner.Buffer(make([]byte, 64*1024), 10*1024*1024) - - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" { - continue - } + reader := bufio.NewReaderSize(r, 64*1024) + decoder := json.NewDecoder(reader) + for { var event JSONEvent - if err := json.Unmarshal([]byte(line), &event); err != nil { - logWarn(fmt.Sprintf("Failed to parse line: %s", truncate(line, 100))) + if err := decoder.Decode(&event); err != nil { + if errors.Is(err, io.EOF) { + break + } + + logWarn(fmt.Sprintf("Failed to decode JSON: %v", err)) + var skipErr error + reader, skipErr = discardInvalidJSON(decoder, reader) + if skipErr != nil { + if errors.Is(skipErr, os.ErrClosed) || errors.Is(skipErr, io.ErrClosedPipe) { + logWarn("Read stdout error: " + skipErr.Error()) + break + } + if !errors.Is(skipErr, io.EOF) { + logWarn("Read stdout error: " + skipErr.Error()) + } + } + decoder = json.NewDecoder(reader) continue } - // Capture thread_id - if event.Type == "thread.started" { + switch event.Type { + case "thread.started": threadID = event.ThreadID - } - - // Capture agent_message - if event.Type == "item.completed" && event.Item != nil && event.Item.Type == "agent_message" { - if text := normalizeText(event.Item.Text); text != "" { - message = text + case "item.completed": + if event.Item != nil && event.Item.Type == "agent_message" { + if text := normalizeText(event.Item.Text); text != "" { + message = text + } } } } - if err := scanner.Err(); err != nil && err != io.EOF { - logWarn("Read stdout error: " + err.Error()) + return message, threadID +} + +func discardInvalidJSON(decoder *json.Decoder, reader *bufio.Reader) (*bufio.Reader, error) { + var buffered bytes.Buffer + + if decoder != nil { + if buf := decoder.Buffered(); buf != nil { + _, _ = buffered.ReadFrom(buf) + } } - return message, threadID + line, err := reader.ReadBytes('\n') + buffered.Write(line) + + data := buffered.Bytes() + newline := bytes.IndexByte(data, '\n') + if newline == -1 { + return reader, err + } + + remaining := data[newline+1:] + if len(remaining) == 0 { + return reader, err + } + + return bufio.NewReader(io.MultiReader(bytes.NewReader(remaining), reader)), err } func normalizeText(text interface{}) string { @@ -450,18 +547,55 @@ func min(a, b int) int { return b } +func setLogger(l *Logger) { + loggerPtr.Store(l) +} + +func closeLogger() error { + logger := loggerPtr.Swap(nil) + if logger == nil { + return nil + } + return logger.Close() +} + +func activeLogger() *Logger { + return loggerPtr.Load() +} + func logInfo(msg string) { + if logger := activeLogger(); logger != nil { + logger.Info(msg) + return + } fmt.Fprintf(os.Stderr, "INFO: %s\n", msg) } func logWarn(msg string) { + if logger := activeLogger(); logger != nil { + logger.Warn(msg) + return + } fmt.Fprintf(os.Stderr, "WARN: %s\n", msg) } func logError(msg string) { + if logger := activeLogger(); logger != nil { + logger.Error(msg) + return + } fmt.Fprintf(os.Stderr, "ERROR: %s\n", msg) } +func runCleanupHook() { + if logger := activeLogger(); logger != nil { + logger.Flush() + } + if cleanupHook != nil { + cleanupHook() + } +} + func printHelp() { help := `codex-wrapper - Go wrapper for Codex CLI diff --git a/codex-wrapper/main_test.go b/codex-wrapper/main_test.go index ab123cb..b9bdff4 100644 --- a/codex-wrapper/main_test.go +++ b/codex-wrapper/main_test.go @@ -2,10 +2,17 @@ package main import ( "bytes" + "context" + "errors" + "fmt" "io" "os" + "os/signal" + "path/filepath" "strings" + "syscall" "testing" + "time" ) // Helper to reset test hooks @@ -13,9 +20,62 @@ func resetTestHooks() { stdinReader = os.Stdin isTerminalFn = defaultIsTerminal codexCommand = "codex" + cleanupHook = nil + closeLogger() } -func TestParseArgs_NewMode(t *testing.T) { +type capturedStdout struct { + buf bytes.Buffer + old *os.File + reader *os.File + writer *os.File +} + +type errReader struct { + err error +} + +func (e errReader) Read([]byte) (int, error) { + return 0, e.err +} + +func captureStdout() *capturedStdout { + r, w, _ := os.Pipe() + state := &capturedStdout{old: os.Stdout, reader: r, writer: w} + os.Stdout = w + return state +} + +func restoreStdout(c *capturedStdout) { + if c == nil { + return + } + c.writer.Close() + os.Stdout = c.old + io.Copy(&c.buf, c.reader) +} + +func (c *capturedStdout) String() string { + if c == nil { + return "" + } + return c.buf.String() +} + +func createFakeCodexScript(t *testing.T, threadID, message string) string { + t.Helper() + scriptPath := filepath.Join(t.TempDir(), "codex.sh") + script := fmt.Sprintf(`#!/bin/sh +printf '%%s\n' '{"type":"thread.started","thread_id":"%s"}' +printf '%%s\n' '{"type":"item.completed","item":{"type":"agent_message","text":"%s"}}' +`, threadID, message) + if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + t.Fatalf("failed to create fake codex script: %v", err) + } + return scriptPath +} + +func TestRunParseArgs_NewMode(t *testing.T) { tests := []struct { name string args []string @@ -103,7 +163,7 @@ func TestParseArgs_NewMode(t *testing.T) { } } -func TestParseArgs_ResumeMode(t *testing.T) { +func TestRunParseArgs_ResumeMode(t *testing.T) { tests := []struct { name string args []string @@ -192,7 +252,7 @@ func TestParseArgs_ResumeMode(t *testing.T) { } } -func TestShouldUseStdin(t *testing.T) { +func TestRunShouldUseStdin(t *testing.T) { tests := []struct { name string task string @@ -217,7 +277,7 @@ func TestShouldUseStdin(t *testing.T) { } } -func TestBuildCodexArgs_NewMode(t *testing.T) { +func TestRunBuildCodexArgs_NewMode(t *testing.T) { cfg := &Config{ Mode: "new", WorkDir: "/test/dir", @@ -245,7 +305,7 @@ func TestBuildCodexArgs_NewMode(t *testing.T) { } } -func TestBuildCodexArgs_ResumeMode(t *testing.T) { +func TestRunBuildCodexArgs_ResumeMode(t *testing.T) { cfg := &Config{ Mode: "resume", SessionID: "session-abc", @@ -274,7 +334,7 @@ func TestBuildCodexArgs_ResumeMode(t *testing.T) { } } -func TestResolveTimeout(t *testing.T) { +func TestRunResolveTimeout(t *testing.T) { tests := []struct { name string envVal string @@ -304,7 +364,7 @@ func TestResolveTimeout(t *testing.T) { } } -func TestNormalizeText(t *testing.T) { +func TestRunNormalizeText(t *testing.T) { tests := []struct { name string input interface{} @@ -395,6 +455,17 @@ func TestParseJSONStream(t *testing.T) { wantMessage: "", wantThreadID: "", }, + { + name: "corrupted json does not break stream", + input: strings.Join([]string{ + `{"type":"item.completed","item":{"type":"agent_message","text":"before"}}`, + `{"type":"item.completed","item":{"type":"agent_message","text":"broken"}`, + `{"type":"thread.started","thread_id":"after-thread"}`, + `{"type":"item.completed","item":{"type":"agent_message","text":"after"}}`, + }, "\n"), + wantMessage: "after", + wantThreadID: "after-thread", + }, } for _, tt := range tests { @@ -411,7 +482,7 @@ func TestParseJSONStream(t *testing.T) { } } -func TestGetEnv(t *testing.T) { +func TestRunGetEnv(t *testing.T) { tests := []struct { name string key string @@ -441,7 +512,7 @@ func TestGetEnv(t *testing.T) { } } -func TestTruncate(t *testing.T) { +func TestRunTruncate(t *testing.T) { tests := []struct { name string input string @@ -465,7 +536,7 @@ func TestTruncate(t *testing.T) { } } -func TestMin(t *testing.T) { +func TestRunMin(t *testing.T) { tests := []struct { a, b, want int }{ @@ -486,22 +557,31 @@ func TestMin(t *testing.T) { } } -func TestLogFunctions(t *testing.T) { - // Capture stderr - oldStderr := os.Stderr - r, w, _ := os.Pipe() - os.Stderr = w +func TestRunLogFunctions(t *testing.T) { + defer resetTestHooks() + + tempDir := t.TempDir() + t.Setenv("TMPDIR", tempDir) + + logger, err := NewLogger() + if err != nil { + t.Fatalf("NewLogger() error = %v", err) + } + setLogger(logger) + defer closeLogger() logInfo("info message") logWarn("warn message") logError("error message") - w.Close() - os.Stderr = oldStderr + logger.Flush() - var buf bytes.Buffer - io.Copy(&buf, r) - output := buf.String() + data, err := os.ReadFile(logger.Path()) + if err != nil { + t.Fatalf("failed to read log file: %v", err) + } + + output := string(data) if !strings.Contains(output, "INFO: info message") { t.Errorf("logInfo output missing, got: %s", output) @@ -514,7 +594,7 @@ func TestLogFunctions(t *testing.T) { } } -func TestPrintHelp(t *testing.T) { +func TestRunPrintHelp(t *testing.T) { // Capture stdout oldStdout := os.Stdout r, w, _ := os.Pipe() @@ -545,7 +625,7 @@ func TestPrintHelp(t *testing.T) { } // Tests for isTerminal with mock -func TestIsTerminal(t *testing.T) { +func TestRunIsTerminal(t *testing.T) { defer resetTestHooks() tests := []struct { @@ -573,22 +653,35 @@ func TestReadPipedTask(t *testing.T) { defer resetTestHooks() tests := []struct { - name string - isTerminal bool - stdinContent string - want string + name string + isTerminal bool + stdin io.Reader + want string + wantErr bool }{ - {"terminal mode", true, "ignored", ""}, - {"piped with data", false, "task from pipe", "task from pipe"}, - {"piped empty", false, "", ""}, + {"terminal mode", true, strings.NewReader("ignored"), "", false}, + {"piped with data", false, strings.NewReader("task from pipe"), "task from pipe", false}, + {"piped empty", false, strings.NewReader(""), "", false}, + {"piped read error", false, errReader{errors.New("boom")}, "", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { isTerminalFn = func() bool { return tt.isTerminal } - stdinReader = strings.NewReader(tt.stdinContent) + stdinReader = tt.stdin - got := readPipedTask() + got, err := readPipedTask() + + if tt.wantErr { + if err == nil { + t.Fatalf("readPipedTask() expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("readPipedTask() unexpected error: %v", err) + } if got != tt.want { t.Errorf("readPipedTask() = %q, want %q", got, tt.want) } @@ -596,13 +689,62 @@ func TestReadPipedTask(t *testing.T) { } } +func TestParseJSONStream_CoverageSuite(t *testing.T) { + suite := []struct { + name string + fn func(*testing.T) + }{ + {"TestRunParseArgs_NewMode", TestRunParseArgs_NewMode}, + {"TestRunParseArgs_ResumeMode", TestRunParseArgs_ResumeMode}, + {"TestRunShouldUseStdin", TestRunShouldUseStdin}, + {"TestRunBuildCodexArgs_NewMode", TestRunBuildCodexArgs_NewMode}, + {"TestRunBuildCodexArgs_ResumeMode", TestRunBuildCodexArgs_ResumeMode}, + {"TestRunResolveTimeout", TestRunResolveTimeout}, + {"TestRunNormalizeText", TestRunNormalizeText}, + {"TestParseJSONStream", TestParseJSONStream}, + {"TestRunGetEnv", TestRunGetEnv}, + {"TestRunTruncate", TestRunTruncate}, + {"TestRunMin", TestRunMin}, + {"TestRunLogFunctions", TestRunLogFunctions}, + {"TestRunPrintHelp", TestRunPrintHelp}, + {"TestRunIsTerminal", TestRunIsTerminal}, + {"TestRunCodexProcess_CommandNotFound", TestRunCodexProcess_CommandNotFound}, + {"TestRunCodexProcess_WithEcho", TestRunCodexProcess_WithEcho}, + {"TestRunCodexProcess_NoMessage", TestRunCodexProcess_NoMessage}, + {"TestRunCodexProcess_WithStdin", TestRunCodexProcess_WithStdin}, + {"TestRunCodexProcess_ExitError", TestRunCodexProcess_ExitError}, + {"TestRunCodexProcess_ContextTimeout", TestRunCodexProcess_ContextTimeout}, + {"TestRunCodexProcess_SignalCancellation", TestRunCodexProcess_SignalCancellation}, + {"TestRunCancelReason", TestRunCancelReason}, + {"TestRunDefaultIsTerminal", TestRunDefaultIsTerminal}, + {"TestRunTerminateProcess_NoProcess", TestRunTerminateProcess_NoProcess}, + {"TestRun_Version", TestRun_Version}, + {"TestRun_VersionShort", TestRun_VersionShort}, + {"TestRun_Help", TestRun_Help}, + {"TestRun_HelpShort", TestRun_HelpShort}, + {"TestRun_NoArgs", TestRun_NoArgs}, + {"TestRun_ExplicitStdinEmpty", TestRun_ExplicitStdinEmpty}, + {"TestRun_ExplicitStdinReadError", TestRun_ExplicitStdinReadError}, + {"TestRun_CommandFails", TestRun_CommandFails}, + {"TestRun_SuccessfulExecution", TestRun_SuccessfulExecution}, + {"TestRun_ExplicitStdinSuccess", TestRun_ExplicitStdinSuccess}, + {"TestRun_PipedTaskReadError", TestRun_PipedTaskReadError}, + {"TestRun_PipedTaskSuccess", TestRun_PipedTaskSuccess}, + {"TestRun_CleanupHookAlwaysCalled", TestRun_CleanupHookAlwaysCalled}, + } + + for _, tt := range suite { + t.Run(tt.name, tt.fn) + } +} + // Tests for runCodexProcess with mock command func TestRunCodexProcess_CommandNotFound(t *testing.T) { defer resetTestHooks() codexCommand = "nonexistent-command-xyz" - _, _, exitCode := runCodexProcess([]string{"arg1"}, "task", false, 10) + _, _, exitCode := runCodexProcess(context.Background(), []string{"arg1"}, "task", false, 10) if exitCode != 127 { t.Errorf("runCodexProcess() exitCode = %d, want 127 for command not found", exitCode) @@ -618,7 +760,7 @@ func TestRunCodexProcess_WithEcho(t *testing.T) { jsonOutput := `{"type":"thread.started","thread_id":"test-session"} {"type":"item.completed","item":{"type":"agent_message","text":"Test output"}}` - message, threadID, exitCode := runCodexProcess([]string{jsonOutput}, "", false, 10) + message, threadID, exitCode := runCodexProcess(context.Background(), []string{jsonOutput}, "", false, 10) if exitCode != 0 { t.Errorf("runCodexProcess() exitCode = %d, want 0", exitCode) @@ -639,7 +781,7 @@ func TestRunCodexProcess_NoMessage(t *testing.T) { // Output without agent_message jsonOutput := `{"type":"thread.started","thread_id":"test-session"}` - _, _, exitCode := runCodexProcess([]string{jsonOutput}, "", false, 10) + _, _, exitCode := runCodexProcess(context.Background(), []string{jsonOutput}, "", false, 10) if exitCode != 1 { t.Errorf("runCodexProcess() exitCode = %d, want 1 for no message", exitCode) @@ -652,7 +794,7 @@ func TestRunCodexProcess_WithStdin(t *testing.T) { // Use cat to echo stdin back codexCommand = "cat" - message, _, exitCode := runCodexProcess([]string{}, `{"type":"item.completed","item":{"type":"agent_message","text":"from stdin"}}`, true, 10) + message, _, exitCode := runCodexProcess(context.Background(), []string{}, `{"type":"item.completed","item":{"type":"agent_message","text":"from stdin"}}`, true, 10) if exitCode != 0 { t.Errorf("runCodexProcess() exitCode = %d, want 0", exitCode) @@ -668,19 +810,65 @@ func TestRunCodexProcess_ExitError(t *testing.T) { // Use false command which exits with code 1 codexCommand = "false" - _, _, exitCode := runCodexProcess([]string{}, "", false, 10) + _, _, exitCode := runCodexProcess(context.Background(), []string{}, "", false, 10) if exitCode == 0 { t.Errorf("runCodexProcess() exitCode = 0, want non-zero for failed command") } } -func TestDefaultIsTerminal(t *testing.T) { +func TestRunCodexProcess_ContextTimeout(t *testing.T) { + defer resetTestHooks() + + codexCommand = "sleep" + + _, _, exitCode := runCodexProcess(context.Background(), []string{"2"}, "", false, 1) + + if exitCode != 124 { + t.Fatalf("runCodexProcess() exitCode = %d, want 124 on timeout", exitCode) + } +} + +func TestRunCodexProcess_SignalCancellation(t *testing.T) { + defer resetTestHooks() + defer signal.Reset(syscall.SIGINT, syscall.SIGTERM) + + codexCommand = "sleep" + sigCtx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + go func() { + time.Sleep(100 * time.Millisecond) + _ = syscall.Kill(os.Getpid(), syscall.SIGINT) + }() + + _, _, exitCode := runCodexProcess(sigCtx, []string{"5"}, "", false, 10) + + if exitCode != 130 { + t.Fatalf("runCodexProcess() exitCode = %d, want 130 on signal", exitCode) + } +} + +func TestRunCancelReason(t *testing.T) { + if got := cancelReason(nil); got != "Context cancelled" { + t.Fatalf("cancelReason(nil) = %q, want Context cancelled", got) + } +} + +func TestRunDefaultIsTerminal(t *testing.T) { // This test just ensures defaultIsTerminal doesn't panic // The actual result depends on the test environment _ = defaultIsTerminal() } +func TestRunTerminateProcess_NoProcess(t *testing.T) { + timer := terminateProcess(nil) + + if timer != nil { + t.Fatalf("terminateProcess(nil) expected nil timer, got non-nil") + } +} + // Tests for run() function func TestRun_Version(t *testing.T) { defer resetTestHooks() @@ -745,6 +933,38 @@ func TestRun_ExplicitStdinEmpty(t *testing.T) { } } +func TestRun_ExplicitStdinReadError(t *testing.T) { + defer resetTestHooks() + + tempDir := t.TempDir() + t.Setenv("TMPDIR", tempDir) + logPath := filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d.log", os.Getpid())) + + var logOutput string + cleanupHook = func() { + data, err := os.ReadFile(logPath) + if err == nil { + logOutput = string(data) + } + } + + os.Args = []string{"codex-wrapper", "-"} + stdinReader = errReader{errors.New("broken stdin")} + isTerminalFn = func() bool { return false } + + exitCode := run() + + if exitCode != 1 { + t.Fatalf("run() with stdin read error returned %d, want 1", exitCode) + } + if !strings.Contains(logOutput, "Failed to read stdin: broken stdin") { + t.Fatalf("log missing read error entry, got %q", logOutput) + } + if _, err := os.Stat(logPath); !os.IsNotExist(err) { + t.Fatalf("log file still exists after run, err=%v", err) + } +} + func TestRun_CommandFails(t *testing.T) { defer resetTestHooks() @@ -758,3 +978,216 @@ func TestRun_CommandFails(t *testing.T) { t.Errorf("run() with failing command returned 0, want non-zero") } } + +func TestRun_SuccessfulExecution(t *testing.T) { + defer resetTestHooks() + + stdout := captureStdout() + + codexCommand = createFakeCodexScript(t, "tid-123", "ok") + stdinReader = strings.NewReader("") + isTerminalFn = func() bool { return true } + os.Args = []string{"codex-wrapper", "task"} + + exitCode := run() + if exitCode != 0 { + t.Fatalf("run() returned %d, want 0", exitCode) + } + + restoreStdout(stdout) + output := stdout.String() + if !strings.Contains(output, "ok") { + t.Fatalf("stdout missing agent message, got %q", output) + } + if !strings.Contains(output, "SESSION_ID: tid-123") { + t.Fatalf("stdout missing session id, got %q", output) + } +} + +func TestRun_ExplicitStdinSuccess(t *testing.T) { + defer resetTestHooks() + + stdout := captureStdout() + + codexCommand = createFakeCodexScript(t, "tid-stdin", "from-stdin") + stdinReader = strings.NewReader("line1\nline2") + isTerminalFn = func() bool { return false } + os.Args = []string{"codex-wrapper", "-"} + + exitCode := run() + restoreStdout(stdout) + if exitCode != 0 { + t.Fatalf("run() returned %d, want 0", exitCode) + } + + output := stdout.String() + if !strings.Contains(output, "from-stdin") { + t.Fatalf("stdout missing agent message for stdin, got %q", output) + } + if !strings.Contains(output, "SESSION_ID: tid-stdin") { + t.Fatalf("stdout missing session id for stdin, got %q", output) + } +} + +func TestRun_PipedTaskReadError(t *testing.T) { + defer resetTestHooks() + + tempDir := t.TempDir() + t.Setenv("TMPDIR", tempDir) + logPath := filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d.log", os.Getpid())) + + var logOutput string + cleanupHook = func() { + data, err := os.ReadFile(logPath) + if err == nil { + logOutput = string(data) + } + } + + codexCommand = createFakeCodexScript(t, "tid-pipe", "piped-task") + isTerminalFn = func() bool { return false } + stdinReader = errReader{errors.New("pipe failure")} + os.Args = []string{"codex-wrapper", "cli-task"} + + exitCode := run() + + if exitCode != 1 { + t.Fatalf("run() with piped read error returned %d, want 1", exitCode) + } + if !strings.Contains(logOutput, "Failed to read piped stdin: read stdin: pipe failure") { + t.Fatalf("log missing piped read error entry, got %q", logOutput) + } + if _, err := os.Stat(logPath); !os.IsNotExist(err) { + t.Fatalf("log file still exists after run, err=%v", err) + } +} + +func TestRun_PipedTaskSuccess(t *testing.T) { + defer resetTestHooks() + + stdout := captureStdout() + + codexCommand = createFakeCodexScript(t, "tid-pipe", "piped-task") + isTerminalFn = func() bool { return false } + stdinReader = strings.NewReader("piped task text") + os.Args = []string{"codex-wrapper", "cli-task"} + + exitCode := run() + restoreStdout(stdout) + if exitCode != 0 { + t.Fatalf("run() returned %d, want 0", exitCode) + } + + output := stdout.String() + if !strings.Contains(output, "piped-task") { + t.Fatalf("stdout missing agent message for piped task, got %q", output) + } + if !strings.Contains(output, "SESSION_ID: tid-pipe") { + t.Fatalf("stdout missing session id for piped task, got %q", output) + } +} + +func TestRun_LoggerLifecycle(t *testing.T) { + defer resetTestHooks() + + tempDir := t.TempDir() + t.Setenv("TMPDIR", tempDir) + logPath := filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d.log", os.Getpid())) + + stdout := captureStdout() + + codexCommand = createFakeCodexScript(t, "tid-logger", "ok") + isTerminalFn = func() bool { return true } + stdinReader = strings.NewReader("") + os.Args = []string{"codex-wrapper", "task"} + + var fileExisted bool + cleanupHook = func() { + if _, err := os.Stat(logPath); err == nil { + fileExisted = true + } + } + + exitCode := run() + restoreStdout(stdout) + + if exitCode != 0 { + t.Fatalf("run() returned %d, want 0", exitCode) + } + if !fileExisted { + t.Fatalf("log file was not present during run") + } + if _, err := os.Stat(logPath); !os.IsNotExist(err) { + t.Fatalf("log file still exists after run, err=%v", err) + } +} + +func TestRun_LoggerRemovedOnSignal(t *testing.T) { + defer resetTestHooks() + defer signal.Reset(syscall.SIGINT, syscall.SIGTERM) + + tempDir := t.TempDir() + t.Setenv("TMPDIR", tempDir) + logPath := filepath.Join(tempDir, fmt.Sprintf("codex-wrapper-%d.log", os.Getpid())) + + scriptPath := filepath.Join(tempDir, "sleepy-codex.sh") + script := `#!/bin/sh +printf '%s\n' '{"type":"thread.started","thread_id":"sig-thread"}' +sleep 5 +printf '%s\n' '{"type":"item.completed","item":{"type":"agent_message","text":"late"}}'` + if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + t.Fatalf("failed to write script: %v", err) + } + + codexCommand = scriptPath + isTerminalFn = func() bool { return true } + stdinReader = strings.NewReader("") + os.Args = []string{"codex-wrapper", "task"} + + exitCh := make(chan int, 1) + go func() { + exitCh <- run() + }() + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if _, err := os.Stat(logPath); err == nil { + break + } + time.Sleep(10 * time.Millisecond) + } + + _ = syscall.Kill(os.Getpid(), syscall.SIGINT) + + var exitCode int + select { + case exitCode = <-exitCh: + case <-time.After(3 * time.Second): + t.Fatalf("run() did not return after signal") + } + + if exitCode != 130 { + t.Fatalf("run() exit code = %d, want 130 on signal", exitCode) + } + if _, err := os.Stat(logPath); !os.IsNotExist(err) { + t.Fatalf("log file still exists after signal exit, err=%v", err) + } +} + +func TestRun_CleanupHookAlwaysCalled(t *testing.T) { + defer resetTestHooks() + + called := false + cleanupHook = func() { called = true } + + os.Args = []string{"codex-wrapper", "--version"} + + exitCode := run() + if exitCode != 0 { + t.Fatalf("run() with --version returned %d, want 0", exitCode) + } + + if !called { + t.Fatalf("cleanup hook was not invoked") + } +}