From 8e7610c77b5298b0c24571ba6090743d593969db Mon Sep 17 00:00:00 2001 From: jettwang Date: Sat, 20 Jun 2026 23:57:31 +0800 Subject: [PATCH] test: harden security-sensitive path coverage --- internal/app/agentmode_test.go | 102 ++++++++++++++++++++++-------- internal/sshclient/client_test.go | 39 ++++++++++++ 2 files changed, 115 insertions(+), 26 deletions(-) diff --git a/internal/app/agentmode_test.go b/internal/app/agentmode_test.go index 7146523..21f46d2 100644 --- a/internal/app/agentmode_test.go +++ b/internal/app/agentmode_test.go @@ -6,6 +6,7 @@ import ( "errors" "io" "os" + "strings" "testing" "time" @@ -130,35 +131,11 @@ func TestClassifyError(t *testing.T) { // even though the host is never reachable. func TestRun_BlockedCommandShortCircuits(t *testing.T) { t.Setenv("HOME", t.TempDir()) - old := os.Stdout - r, w, err := os.Pipe() - if err != nil { - t.Fatalf("failed to create pipe: %v", err) - } - os.Stdout = w - // 192.0.2.1 is RFC 5737 TEST-NET-1: if validation did not short-circuit, // the dial would block instead of returning instantly. - runErr := Run([]string{"sshx", "-h=192.0.2.1", "--json", "rm -rf /"}) - - if closeErr := w.Close(); closeErr != nil { - t.Logf("failed to close pipe writer: %v", closeErr) - } - os.Stdout = old - var buf bytes.Buffer - if _, copyErr := io.Copy(&buf, r); copyErr != nil { - t.Logf("failed to copy pipe output: %v", copyErr) - } - - if !errors.Is(runErr, ErrReported) { - t.Fatalf("expected ErrReported, got %v", runErr) - } - var result map[string]any - if jErr := json.Unmarshal(buf.Bytes(), &result); jErr != nil { - t.Fatalf("invalid JSON output %q: %v", buf.String(), jErr) - } + result := runReportedJSON(t, []string{"sshx", "-h=192.0.2.1", "--json", "rm -rf /"}) if result["error_kind"] != "blocked" { - t.Errorf("expected error_kind=blocked, got %v (full: %s)", result["error_kind"], buf.String()) + t.Errorf("expected error_kind=blocked, got %v", result["error_kind"]) } if code, ok := result["exit_code"].(float64); !ok || code != -1 { t.Errorf("expected exit_code=-1, got %v", result["exit_code"]) @@ -168,6 +145,48 @@ func TestRun_BlockedCommandShortCircuits(t *testing.T) { } } +func TestRun_JSONConfigFailuresDoNotConnect(t *testing.T) { + tests := []struct { + name string + args []string + wantErrorText string + }{ + { + name: "pty conflicts with json", + args: []string{"sshx", "-h=192.0.2.1", "--json", "--pty", "--no-audit", "uptime"}, + wantErrorText: "--pty cannot be combined with --json", + }, + { + name: "invalid timeout", + args: []string{"sshx", "-h=192.0.2.1", "--json", "--timeout=banana", "--no-audit", "uptime"}, + wantErrorText: "invalid --timeout value", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + result := runReportedJSON(t, tt.args) + if result["error_kind"] != "config" { + t.Fatalf("expected error_kind=config, got %v", result["error_kind"]) + } + if result["success"] != false { + t.Errorf("expected success=false, got %v", result["success"]) + } + if code, ok := result["exit_code"].(float64); !ok || code != -1 { + t.Errorf("expected exit_code=-1, got %v", result["exit_code"]) + } + errText, ok := result["error"].(string) + if !ok { + t.Fatalf("expected error string, got %T", result["error"]) + } + if !strings.Contains(errText, tt.wantErrorText) { + t.Errorf("expected error containing %q, got %q", tt.wantErrorText, errText) + } + }) + } +} + func TestRun_DryRunJSONDoesNotConnect(t *testing.T) { t.Setenv("HOME", t.TempDir()) result := runDryRunJSON(t, []string{"sshx", "-h=192.0.2.1", "--dry-run", "--json", "uptime"}) @@ -367,3 +386,34 @@ func runDryRunJSON(t *testing.T, args []string) map[string]any { } return result } + +func runReportedJSON(t *testing.T, args []string) map[string]any { + t.Helper() + + old := os.Stdout + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("failed to create pipe: %v", err) + } + os.Stdout = w + + runErr := Run(args) + + if closeErr := w.Close(); closeErr != nil { + t.Logf("failed to close pipe writer: %v", closeErr) + } + os.Stdout = old + var buf bytes.Buffer + if _, copyErr := io.Copy(&buf, r); copyErr != nil { + t.Logf("failed to copy pipe output: %v", copyErr) + } + + if !errors.Is(runErr, ErrReported) { + t.Fatalf("expected ErrReported, got %v, output=%s", runErr, buf.String()) + } + var result map[string]any + if jErr := json.Unmarshal(buf.Bytes(), &result); jErr != nil { + t.Fatalf("invalid JSON output %q: %v", buf.String(), jErr) + } + return result +} diff --git a/internal/sshclient/client_test.go b/internal/sshclient/client_test.go index ff7ca09..d1647be 100644 --- a/internal/sshclient/client_test.go +++ b/internal/sshclient/client_test.go @@ -326,6 +326,45 @@ func TestGetHostKeyCallbackStrictModeRejectsUnknownHost(t *testing.T) { assert.Equal(t, "", string(data)) } +func TestGetHostKeyCallbackRejectsChangedKnownHostKey(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + knownHostsPath := filepath.Join(home, ".ssh", "known_hosts") + hostWithPort := net.JoinHostPort("changed-host", "22") + remote := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22} + oldKey := generateTestPublicKey(t) + newKey := generateTestPublicKey(t) + + require.NoError(t, ensureKnownHostsFile(knownHostsPath)) + require.NoError(t, appendHostKey(knownHostsPath, normalizeHostPatterns(hostWithPort, remote), oldKey)) + callback, err := getHostKeyCallback(&Config{KnownHostsPath: knownHostsPath}) + require.NoError(t, err) + + err = callback(hostWithPort, remote, newKey) + require.Error(t, err) + assert.Contains(t, err.Error(), "HOST KEY VERIFICATION FAILED") + assert.Contains(t, err.Error(), "man-in-the-middle") +} + +func TestGetHostKeyCallbackInsecureFallbackRequiresExplicitOptIn(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + knownHostsPath := filepath.Join(home, ".ssh") + require.NoError(t, os.MkdirAll(knownHostsPath, 0o700)) + + _, err := getHostKeyCallback(&Config{KnownHostsPath: knownHostsPath}) + require.Error(t, err) + assert.Contains(t, err.Error(), "known_hosts path") + + callback, err := getHostKeyCallback(&Config{KnownHostsPath: knownHostsPath, AllowInsecureHostKey: true}) + require.NoError(t, err) + require.NotNil(t, callback) + + remote := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22} + key := generateTestPublicKey(t) + require.NoError(t, callback(net.JoinHostPort("insecure-host", "22"), remote, key)) +} + func generateTestPublicKey(t *testing.T) ssh.PublicKey { t.Helper() _, priv, err := ed25519.GenerateKey(rand.Reader)