Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 99 additions & 10 deletions internal/app/agentmode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,86 @@ func TestRun_JSONConfigFailuresDoNotConnect(t *testing.T) {
}
}

func TestEmitCommandJSONContracts(t *testing.T) {
tests := []struct {
name string
result sshclient.ExecResult
errKind string
execErr error
wantExitCode float64
wantSuccess bool
wantErrorKind string
wantErrorText string
}{
{
name: "remote non-zero exit is structured without sshx error",
result: sshclient.ExecResult{ExitCode: 7, Stdout: "partial\n", Stderr: "warn\n"},
wantExitCode: 7,
wantSuccess: false,
},
{
name: "timeout is structured as sshx-level error",
result: sshclient.ExecResult{ExitCode: -1, Stdout: "before timeout\n"},
errKind: classifyError(sshclient.ErrCommandTimeout),
execErr: sshclient.ErrCommandTimeout,
wantExitCode: -1,
wantSuccess: false,
wantErrorKind: "timeout",
wantErrorText: "command execution timed out",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
raw := captureStdout(t, func() {
emitCommandJSON(
&sshclient.Config{Host: "127.0.0.1", Port: "22", User: "tester", Command: "uptime"},
sshclient.AuthMethodPassword,
tt.result,
1500*time.Millisecond,
tt.errKind,
tt.execErr,
)
})
var got map[string]any
if err := json.Unmarshal(raw, &got); err != nil {
t.Fatalf("invalid JSON output %q: %v", string(raw), err)
}
if got["success"] != tt.wantSuccess {
t.Errorf("expected success=%v, got %v", tt.wantSuccess, got["success"])
}
if got["exit_code"] != tt.wantExitCode {
t.Errorf("expected exit_code=%v, got %v", tt.wantExitCode, got["exit_code"])
}
if got["auth_method"] != string(sshclient.AuthMethodPassword) {
t.Errorf("expected password auth method, got %v", got["auth_method"])
}
if got["duration_ms"] != float64(1500) {
t.Errorf("expected duration_ms=1500, got %v", got["duration_ms"])
}
if tt.wantErrorKind == "" {
if _, exists := got["error_kind"]; exists {
t.Errorf("did not expect error_kind, got %v", got["error_kind"])
}
if _, exists := got["error"]; exists {
t.Errorf("did not expect error, got %v", got["error"])
}
return
}
if got["error_kind"] != tt.wantErrorKind {
t.Errorf("expected error_kind=%q, got %v", tt.wantErrorKind, got["error_kind"])
}
errText, ok := got["error"].(string)
if !ok {
t.Fatalf("expected error string, got %T", got["error"])
}
if !strings.Contains(errText, tt.wantErrorText) {
t.Errorf("expected error containing %q, got %q", tt.wantErrorText, errText)
}
})
}
}

func TestRun_DryRunJSONDoesNotConnect(t *testing.T) {
t.Setenv("HOME", t.TempDir())
result := runDryRunJSON(t, []string{"sshx", "-h=192.0.2.1", "--dry-run", "--json", "uptime"})
Expand Down Expand Up @@ -390,14 +470,31 @@ func runDryRunJSON(t *testing.T, args []string) map[string]any {
func runReportedJSON(t *testing.T, args []string) map[string]any {
t.Helper()

var runErr error
raw := captureStdout(t, func() {
runErr = Run(args)
})
if !errors.Is(runErr, ErrReported) {
t.Fatalf("expected ErrReported, got %v, output=%s", runErr, string(raw))
}
var result map[string]any
if jErr := json.Unmarshal(raw, &result); jErr != nil {
t.Fatalf("invalid JSON output %q: %v", string(raw), jErr)
}
return result
}

func captureStdout(t *testing.T, fn func()) []byte {
t.Helper()

old := os.Stdout
r, w, err := os.Pipe()
if err != nil {
t.Fatalf("failed to create pipe: %v", err)
}
os.Stdout = w

runErr := Run(args)
fn()

if closeErr := w.Close(); closeErr != nil {
t.Logf("failed to close pipe writer: %v", closeErr)
Expand All @@ -407,13 +504,5 @@ func runReportedJSON(t *testing.T, args []string) map[string]any {
if _, copyErr := io.Copy(&buf, r); copyErr != nil {
t.Logf("failed to copy pipe output: %v", copyErr)
}

if !errors.Is(runErr, ErrReported) {
t.Fatalf("expected ErrReported, got %v, output=%s", runErr, buf.String())
}
var result map[string]any
if jErr := json.Unmarshal(buf.Bytes(), &result); jErr != nil {
t.Fatalf("invalid JSON output %q: %v", buf.String(), jErr)
}
return result
return buf.Bytes()
}
31 changes: 31 additions & 0 deletions internal/sshclient/runcommand_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sshclient

import (
"bufio"
"crypto/ed25519"
"crypto/rand"
"encoding/binary"
Expand Down Expand Up @@ -116,6 +117,18 @@ func runFakeCommand(ch ssh.Channel, command string) uint32 {
writeAll(ch, "to-out\n")
writeAll(ch.Stderr(), "to-err\n")
return 0
case "sudo -S -p '' whoami":
stdin, err := bufio.NewReader(ch).ReadString('\n')
if err != nil && !errors.Is(err, io.EOF) {
writeAll(ch.Stderr(), "failed to read sudo stdin\n")
return 24
}
if stdin != "sudo-fixture\n" {
writeAll(ch.Stderr(), "unexpected sudo stdin\n")
return 25
}
writeAll(ch, "sudo-ok\n")
return 0
case "sleep":
time.Sleep(5 * time.Second)
return 0
Expand Down Expand Up @@ -204,6 +217,24 @@ func TestRunCommandTimeout(t *testing.T) {
assert.Less(t, elapsed, 4*time.Second, "timeout should fire well before the command finishes")
}

func TestRunCommandSudoFeedsPasswordOnStdin(t *testing.T) {
host, port := startTestSSHServer(t)
client := dialTestClient(t, host, port)
sudoPassword := "sudo-fixture" // #nosec G101 -- fake test password used only for stdin contract coverage.
client.config.Command = "sudo whoami"
client.config.Password = sudoPassword
client.config.Timeout = 2 * time.Second

res, err := client.RunCommand(true)

require.NoError(t, err)
assert.Equal(t, 0, res.ExitCode)
assert.Equal(t, "sudo-ok\n", res.Stdout)
assert.Empty(t, res.Stderr)
assert.NotContains(t, res.Stdout, sudoPassword)
assert.NotContains(t, res.Stderr, sudoPassword)
}

func TestCappedBufferTruncates(t *testing.T) {
buf := newCappedBuffer(8)
n, err := buf.Write([]byte("hello"))
Expand Down
Loading