diff --git a/internal/app/agentmode_test.go b/internal/app/agentmode_test.go index 21f46d2..99b4f4b 100644 --- a/internal/app/agentmode_test.go +++ b/internal/app/agentmode_test.go @@ -187,6 +187,86 @@ func TestRun_JSONConfigFailuresDoNotConnect(t *testing.T) { } } +func TestEmitCommandJSONContracts(t *testing.T) { + tests := []struct { + name string + result sshclient.ExecResult + errKind string + execErr error + wantExitCode float64 + wantSuccess bool + wantErrorKind string + wantErrorText string + }{ + { + name: "remote non-zero exit is structured without sshx error", + result: sshclient.ExecResult{ExitCode: 7, Stdout: "partial\n", Stderr: "warn\n"}, + wantExitCode: 7, + wantSuccess: false, + }, + { + name: "timeout is structured as sshx-level error", + result: sshclient.ExecResult{ExitCode: -1, Stdout: "before timeout\n"}, + errKind: classifyError(sshclient.ErrCommandTimeout), + execErr: sshclient.ErrCommandTimeout, + wantExitCode: -1, + wantSuccess: false, + wantErrorKind: "timeout", + wantErrorText: "command execution timed out", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + raw := captureStdout(t, func() { + emitCommandJSON( + &sshclient.Config{Host: "127.0.0.1", Port: "22", User: "tester", Command: "uptime"}, + sshclient.AuthMethodPassword, + tt.result, + 1500*time.Millisecond, + tt.errKind, + tt.execErr, + ) + }) + var got map[string]any + if err := json.Unmarshal(raw, &got); err != nil { + t.Fatalf("invalid JSON output %q: %v", string(raw), err) + } + if got["success"] != tt.wantSuccess { + t.Errorf("expected success=%v, got %v", tt.wantSuccess, got["success"]) + } + if got["exit_code"] != tt.wantExitCode { + t.Errorf("expected exit_code=%v, got %v", tt.wantExitCode, got["exit_code"]) + } + if got["auth_method"] != string(sshclient.AuthMethodPassword) { + t.Errorf("expected password auth method, got %v", got["auth_method"]) + } + if got["duration_ms"] != float64(1500) { + t.Errorf("expected duration_ms=1500, got %v", got["duration_ms"]) + } + if tt.wantErrorKind == "" { + if _, exists := got["error_kind"]; exists { + t.Errorf("did not expect error_kind, got %v", got["error_kind"]) + } + if _, exists := got["error"]; exists { + t.Errorf("did not expect error, got %v", got["error"]) + } + return + } + if got["error_kind"] != tt.wantErrorKind { + t.Errorf("expected error_kind=%q, got %v", tt.wantErrorKind, got["error_kind"]) + } + errText, ok := got["error"].(string) + if !ok { + t.Fatalf("expected error string, got %T", got["error"]) + } + if !strings.Contains(errText, tt.wantErrorText) { + t.Errorf("expected error containing %q, got %q", tt.wantErrorText, errText) + } + }) + } +} + func TestRun_DryRunJSONDoesNotConnect(t *testing.T) { t.Setenv("HOME", t.TempDir()) result := runDryRunJSON(t, []string{"sshx", "-h=192.0.2.1", "--dry-run", "--json", "uptime"}) @@ -390,6 +470,23 @@ func runDryRunJSON(t *testing.T, args []string) map[string]any { func runReportedJSON(t *testing.T, args []string) map[string]any { t.Helper() + var runErr error + raw := captureStdout(t, func() { + runErr = Run(args) + }) + if !errors.Is(runErr, ErrReported) { + t.Fatalf("expected ErrReported, got %v, output=%s", runErr, string(raw)) + } + var result map[string]any + if jErr := json.Unmarshal(raw, &result); jErr != nil { + t.Fatalf("invalid JSON output %q: %v", string(raw), jErr) + } + return result +} + +func captureStdout(t *testing.T, fn func()) []byte { + t.Helper() + old := os.Stdout r, w, err := os.Pipe() if err != nil { @@ -397,7 +494,7 @@ func runReportedJSON(t *testing.T, args []string) map[string]any { } os.Stdout = w - runErr := Run(args) + fn() if closeErr := w.Close(); closeErr != nil { t.Logf("failed to close pipe writer: %v", closeErr) @@ -407,13 +504,5 @@ func runReportedJSON(t *testing.T, args []string) map[string]any { if _, copyErr := io.Copy(&buf, r); copyErr != nil { t.Logf("failed to copy pipe output: %v", copyErr) } - - if !errors.Is(runErr, ErrReported) { - t.Fatalf("expected ErrReported, got %v, output=%s", runErr, buf.String()) - } - var result map[string]any - if jErr := json.Unmarshal(buf.Bytes(), &result); jErr != nil { - t.Fatalf("invalid JSON output %q: %v", buf.String(), jErr) - } - return result + return buf.Bytes() } diff --git a/internal/sshclient/runcommand_test.go b/internal/sshclient/runcommand_test.go index ef8c1d0..e9c4e67 100644 --- a/internal/sshclient/runcommand_test.go +++ b/internal/sshclient/runcommand_test.go @@ -1,6 +1,7 @@ package sshclient import ( + "bufio" "crypto/ed25519" "crypto/rand" "encoding/binary" @@ -116,6 +117,18 @@ func runFakeCommand(ch ssh.Channel, command string) uint32 { writeAll(ch, "to-out\n") writeAll(ch.Stderr(), "to-err\n") return 0 + case "sudo -S -p '' whoami": + stdin, err := bufio.NewReader(ch).ReadString('\n') + if err != nil && !errors.Is(err, io.EOF) { + writeAll(ch.Stderr(), "failed to read sudo stdin\n") + return 24 + } + if stdin != "sudo-fixture\n" { + writeAll(ch.Stderr(), "unexpected sudo stdin\n") + return 25 + } + writeAll(ch, "sudo-ok\n") + return 0 case "sleep": time.Sleep(5 * time.Second) return 0 @@ -204,6 +217,24 @@ func TestRunCommandTimeout(t *testing.T) { assert.Less(t, elapsed, 4*time.Second, "timeout should fire well before the command finishes") } +func TestRunCommandSudoFeedsPasswordOnStdin(t *testing.T) { + host, port := startTestSSHServer(t) + client := dialTestClient(t, host, port) + sudoPassword := "sudo-fixture" // #nosec G101 -- fake test password used only for stdin contract coverage. + client.config.Command = "sudo whoami" + client.config.Password = sudoPassword + client.config.Timeout = 2 * time.Second + + res, err := client.RunCommand(true) + + require.NoError(t, err) + assert.Equal(t, 0, res.ExitCode) + assert.Equal(t, "sudo-ok\n", res.Stdout) + assert.Empty(t, res.Stderr) + assert.NotContains(t, res.Stdout, sudoPassword) + assert.NotContains(t, res.Stderr, sudoPassword) +} + func TestCappedBufferTruncates(t *testing.T) { buf := newCappedBuffer(8) n, err := buf.Write([]byte("hello"))