From dd40253a6fabd7cd898b860c1ecdd3ddbfa34b70 Mon Sep 17 00:00:00 2001 From: Cai_Tang <106404101+Cai-Tang-www@users.noreply.github.com> Date: Sat, 23 May 2026 16:44:52 +0800 Subject: [PATCH 1/6] feat(runtime): implement P6 command hook stdin/stdout JSON protocol Implement the full command hook protocol for user/repo scopes as specified in issue #683. External commands now communicate via structured JSON on stdin/stdout instead of raw exit codes. Key changes: - New command_handler.go: protocol types (CommandHookPayload/Response), BuildCommandPayload, ParseCommandResponse, RunCommandHook with argv/shell execution modes, env isolation (NEOCODE_HOOK_* only) - stdin: single-line JSON with payload_version, hook_id, point, metadata - stdout: single-line JSON with status, message, update_input, annotations - Graceful fallback: non-JSON stdout falls back to exit code semantics (0=pass, 1/2=block, other=failed) - argv mode (default): params.command as string array, direct exec - shell mode: params.command as string + params.shell=true - update_input wired for user_prompt_submit point - annotations collected into runtime annotation buffer - Context timeout detection prioritized over exit code mapping - 18 new unit tests, all cross-platform (Windows + Unix) Closes #683 Co-Authored-By: Claude Opus 4.7 --- docs/runtime-hooks-design.md | 135 +++++- internal/config/runtime_hooks.go | 37 +- internal/config/runtime_hooks_test.go | 78 ++++ internal/runtime/hooks/command_handler.go | 205 ++++++++ .../runtime/hooks/command_handler_test.go | 439 ++++++++++++++++++ internal/runtime/hooks/result.go | 9 +- internal/runtime/hooks_integration.go | 50 +- internal/runtime/repo_hooks.go | 34 +- internal/runtime/repo_hooks_test.go | 2 +- internal/runtime/run.go | 1 + internal/runtime/user_hooks.go | 102 ++-- 11 files changed, 1042 insertions(+), 50 deletions(-) create mode 100644 internal/runtime/hooks/command_handler.go create mode 100644 internal/runtime/hooks/command_handler_test.go diff --git a/docs/runtime-hooks-design.md b/docs/runtime-hooks-design.md index 2ccf35933..f3527b7c8 100644 --- a/docs/runtime-hooks-design.md +++ b/docs/runtime-hooks-design.md @@ -13,10 +13,11 @@ - P4:生命周期点位扩展(permission/session/compact/subagent)+ 点位能力矩阵 - P5:internal hooks 支持 `async/async_rewake` + run 内存通知队列(ephemeral 注入) - P6-lite:user `http/observe` hooks(仅观测回调) +- P6:user/repo `command` hooks(stdin/stdout JSON 协议) 当前未实现能力: -- command/prompt/agent hooks(P6) +- prompt/agent hooks(P6) ## P2 user hooks 边界 @@ -32,7 +33,8 @@ P2 仅支持: - `kind=http + mode=observe`:允许发送 HTTP 观测回调(不支持 block) - `http observe` 默认不携带 metadata(`include_metadata=false`);即使显式开启也会剥离 `result_content_preview`、`execution_error` - `http observe` 回调端点仅允许 loopback 地址(`localhost` / `127.0.0.1` / `::1`),避免误配为公网外发 -- external kinds 中 `command/prompt/agent` 在 P6-lite 阶段显式拒绝,不会半生效 +- `kind=command + mode=sync`:允许执行外部命令,通过 stdin/stdout JSON 协议通信(详见下方 P6 章节) +- external kinds 中 `prompt/agent` 仍显式拒绝 当前(P3)明确不支持: @@ -105,7 +107,7 @@ runtime 内置 `HookPointCapability` 作为唯一真源,定义每个点位是 约束规则: - `CanBlock=false` 的点位,hook 返回 `block` 会自动降级为观测结果,不中断主链。 -- `CanUpdateInput` 仅作为能力建模;当前阶段不开放输入改写通道。 +- `CanUpdateInput` 在 `user_prompt_submit` 点位已开放:command hook 可通过 stdout JSON 的 `update_input` 字段改写用户输入。 - `UserAllowed=false` 的点位拒绝 user/repo 挂载(配置 fail-fast)。 ### trust gate @@ -135,6 +137,133 @@ trust store 固定路径: - 绝对路径必须位于 workdir 内 - symlink 路径会进行 realpath 校验,禁止绕过 +## P6 command hooks + +`kind=command` 允许 user/repo scope 通过外部可执行脚本参与 hook 链。 + +### stdin 协议 + +外部命令通过 stdin 接收单行 JSON: + +```json +{ + "payload_version": "1", + "hook_id": "my-hook", + "point": "before_tool_call", + "metadata": { + "tool_name": "bash", + "workdir": "/path/to/workspace", + "session_id": "sess_abc123" + } +} +``` + +- `payload_version`:协议版本号,当前固定 `"1"`,变更 stdin 结构时递增 +- `hook_id`:hook 配置中的 `id` +- `point`:触发点位名称 +- `metadata`:经白名单裁剪后的上下文字段(与 builtin/http hook 相同的 allowlist) + +### stdout 协议 + +外部命令通过 stdout 返回单行 JSON: + +```json +{ + "status": "pass", + "message": "optional message", + "update_input": {"text": "rewritten prompt"}, + "annotations": ["note1", "note2"] +} +``` + +- `status`:必填,`pass` / `block` / `failed` +- `message`:可选,进入 hook event 和 annotation buffer +- `update_input`:仅 `CanUpdateInput=true` 的点位(当前仅 `user_prompt_submit`)允许;格式 `{"text": "..."}` 替换用户输入文本 +- `annotations`:字符串数组,进入 runtime annotation buffer + +### stdout 退化模式 + +如果 stdout 不是合法 JSON,handler 退化为 exit code 模式: + +- exit 0 → `pass` +- exit 1 或 2 → `block` +- 其他 → `failed` + +原始 stdout 文本作为 `message`。此模式兼容简单脚本(如 `echo "ok"; exit 0`)。 + +### 执行模式 + +#### argv 模式(默认) + +`params.command` 为字符串数组,直接 exec 不经 shell: + +```yaml +kind: command +params: + command: + - python3 + - /path/to/hook.py +``` + +#### shell 模式 + +`params.command` 为字符串且 `params.shell: true`,通过 `sh -c`(Unix)/ `powershell -Command`(Windows)执行: + +```yaml +kind: command +params: + command: "python3 /path/to/hook.py" + shell: true +``` + +单字符串 `params.command` 不设置 `params.shell: true` 会触发配置校验错误。 + +### 环境变量 + +命令进程仅注入以下环境变量,不继承宿主环境: + +| 变量 | 值 | +|------|------| +| `NEOCODE_HOOK_HOOK_ID` | hook 的 `id` | +| `NEOCODE_HOOK_POINT` | 触发点位(如 `before_tool_call`) | +| `NEOCODE_HOOK_PAYLOAD_VERSION` | `"1"` | + +### 执行约束 + +- workdir = 当前 run 的 workspace(`cmd.Dir = workdir`) +- 超时 = hook 配置的 `timeout_sec`(默认 2s) +- 并发限制 = executor 的 `max_in_flight`(默认 128) +- repo scope command hook 受 trust gate 保护 + +### 示例 + +#### Python + +```python +#!/usr/bin/env python3 +import json, sys + +payload = json.loads(sys.stdin.readline()) +if payload["metadata"].get("tool_name") == "bash": + json.dump({"status": "block", "message": "bash not allowed"}, sys.stdout) +else: + json.dump({"status": "pass"}, sys.stdout) +print() +``` + +#### Bash + +```bash +#!/bin/bash +read -r line +tool=$(echo "$line" | jq -r '.metadata.tool_name // empty') +if [ "$tool" = "rm" ]; then + echo '{"status":"block","message":"rm is blocked"}' +else + echo '{"status":"pass"}' +fi +``` + ## 可观测性 runtime 会透传 hooks 生命周期事件: diff --git a/internal/config/runtime_hooks.go b/internal/config/runtime_hooks.go index 9a3d5d6d8..172accd3d 100644 --- a/internal/config/runtime_hooks.go +++ b/internal/config/runtime_hooks.go @@ -286,8 +286,8 @@ func (c RuntimeHookItemConfig) Validate(defaultFailurePolicy string) error { if normalizedMode != runtimeHookModeSync { return fmt.Errorf("mode %q is not supported for kind command (only sync)", c.Mode) } - if strings.TrimSpace(readRuntimeHookParamString(c.Params, "command")) == "" { - return fmt.Errorf("kind command requires params.command") + if err := validateRuntimeCommandItem(c.Params); err != nil { + return err } case runtimeHookKindHTTP: if normalizedMode != runtimeHookModeObserve { @@ -349,6 +349,39 @@ func validateRuntimeHTTPObserveItem(c RuntimeHookItemConfig, policy string) erro return nil } +// validateRuntimeCommandItem 校验 command kind 的 params.command 格式。 +// 支持 []string / []any (argv 模式) 和 string + shell=true (shell 模式)。 +func validateRuntimeCommandItem(params map[string]any) error { + if len(params) == 0 { + return fmt.Errorf("kind command requires params.command") + } + raw, ok := params["command"] + if !ok || raw == nil { + return fmt.Errorf("kind command requires params.command") + } + switch v := raw.(type) { + case string: + if strings.TrimSpace(v) == "" { + return fmt.Errorf("kind command requires params.command") + } + shellVal, _ := params["shell"].(bool) + if !shellVal { + return fmt.Errorf("string params.command requires params.shell=true; use array format for argv mode") + } + case []any: + if len(v) == 0 { + return fmt.Errorf("kind command requires non-empty params.command") + } + case []string: + if len(v) == 0 { + return fmt.Errorf("kind command requires non-empty params.command") + } + default: + return fmt.Errorf("params.command must be a string (with shell=true) or an array") + } + return nil +} + // isRuntimeHookHTTPObserveLoopbackHost 判断 http observe 回调域名是否属于本地回环地址。 func isRuntimeHookHTTPObserveLoopbackHost(host string) bool { normalized := strings.TrimSpace(strings.ToLower(host)) diff --git a/internal/config/runtime_hooks_test.go b/internal/config/runtime_hooks_test.go index c755e487f..8a988ae5f 100644 --- a/internal/config/runtime_hooks_test.go +++ b/internal/config/runtime_hooks_test.go @@ -152,10 +152,88 @@ func TestRuntimeHooksConfigValidateAllowsCommand(t *testing.T) { Mode: runtimeHookModeSync, TimeoutSec: 2, FailurePolicy: runtimeHookFailurePolicyWarnOnly, + Params: map[string]any{"command": []any{"echo", "ok"}}, + }, + }, + } + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() error = %v", err) + } +} + +func TestRuntimeHooksConfigValidateCommandShellMode(t *testing.T) { + t.Parallel() + + cfg := RuntimeHooksConfig{ + Enabled: boolPtr(true), + UserHooksEnabled: boolPtr(true), + DefaultTimeoutSec: 2, + DefaultFailurePolicy: runtimeHookFailurePolicyWarnOnly, + Items: []RuntimeHookItemConfig{ + { + ID: "cmd-shell", + Point: string(hooks.HookPointAcceptGate), + Scope: runtimeHookScopeUser, + Kind: runtimeHookKindCommand, + Mode: runtimeHookModeSync, + TimeoutSec: 2, + FailurePolicy: runtimeHookFailurePolicyWarnOnly, + Params: map[string]any{"command": "echo ok", "shell": true}, + }, + }, + } + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() error = %v", err) + } +} + +func TestRuntimeHooksConfigValidateCommandStringWithoutShellRejected(t *testing.T) { + t.Parallel() + + cfg := RuntimeHooksConfig{ + Enabled: boolPtr(true), + UserHooksEnabled: boolPtr(true), + DefaultTimeoutSec: 2, + DefaultFailurePolicy: runtimeHookFailurePolicyWarnOnly, + Items: []RuntimeHookItemConfig{ + { + ID: "cmd-no-shell", + Point: string(hooks.HookPointAcceptGate), + Scope: runtimeHookScopeUser, + Kind: runtimeHookKindCommand, + Mode: runtimeHookModeSync, + TimeoutSec: 2, + FailurePolicy: runtimeHookFailurePolicyWarnOnly, Params: map[string]any{"command": "echo ok"}, }, }, } + if err := cfg.Validate(); err == nil { + t.Fatal("expected error for string command without shell=true") + } +} + +func TestRuntimeHooksConfigValidateCommandArgvMode(t *testing.T) { + t.Parallel() + + cfg := RuntimeHooksConfig{ + Enabled: boolPtr(true), + UserHooksEnabled: boolPtr(true), + DefaultTimeoutSec: 2, + DefaultFailurePolicy: runtimeHookFailurePolicyWarnOnly, + Items: []RuntimeHookItemConfig{ + { + ID: "cmd-argv", + Point: string(hooks.HookPointAcceptGate), + Scope: runtimeHookScopeUser, + Kind: runtimeHookKindCommand, + Mode: runtimeHookModeSync, + TimeoutSec: 2, + FailurePolicy: runtimeHookFailurePolicyWarnOnly, + Params: map[string]any{"command": []string{"echo", "hello"}}, + }, + }, + } if err := cfg.Validate(); err != nil { t.Fatalf("Validate() error = %v", err) } diff --git a/internal/runtime/hooks/command_handler.go b/internal/runtime/hooks/command_handler.go new file mode 100644 index 000000000..b30fab9fa --- /dev/null +++ b/internal/runtime/hooks/command_handler.go @@ -0,0 +1,205 @@ +package hooks + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "os" + "os/exec" + "runtime" + "strings" +) + +// CommandHookPayloadVersion 定义 command hook stdin 协议版本号,变更 stdin 结构时递增。 +const CommandHookPayloadVersion = "1" + +// CommandHookPayload 是通过 stdin 传给外部命令的单行 JSON。 +type CommandHookPayload struct { + PayloadVersion string `json:"payload_version"` + HookID string `json:"hook_id"` + Point string `json:"point"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// CommandHookResponse 是外部命令通过 stdout 返回的单行 JSON。 +type CommandHookResponse struct { + Status string `json:"status"` + Message string `json:"message,omitempty"` + UpdateInput json.RawMessage `json:"update_input,omitempty"` + Annotations []string `json:"annotations,omitempty"` +} + +// CommandHookSpec 描述一个 command hook 的执行参数。 +type CommandHookSpec struct { + HookID string + Point HookPoint + Command []string // argv 模式: [binary, arg1, arg2, ...] + Shell bool // true = 通过 sh -c / powershell -Command 执行 + Workdir string +} + +// BuildCommandPayload 构造传给外部命令的 stdin JSON payload。 +func BuildCommandPayload(hookID string, point HookPoint, metadata map[string]any) CommandHookPayload { + payload := CommandHookPayload{ + PayloadVersion: CommandHookPayloadVersion, + HookID: strings.TrimSpace(hookID), + Point: string(point), + } + if len(metadata) > 0 { + payload.Metadata = metadata + } + return payload +} + +// ParseCommandResponse 解析外部命令 stdout 输出的单行 JSON。 +// 非 JSON 输入返回 error,调用方可退化为 exit code 兼容模式。 +func ParseCommandResponse(raw []byte) (CommandHookResponse, error) { + trimmed := bytes.TrimSpace(raw) + if len(trimmed) == 0 { + return CommandHookResponse{}, fmt.Errorf("empty stdout") + } + var resp CommandHookResponse + if err := json.Unmarshal(trimmed, &resp); err != nil { + return CommandHookResponse{}, fmt.Errorf("invalid JSON: %w", err) + } + normalized := strings.ToLower(strings.TrimSpace(resp.Status)) + switch normalized { + case "pass", "block", "failed": + resp.Status = normalized + default: + return CommandHookResponse{}, fmt.Errorf("invalid status %q", resp.Status) + } + return resp, nil +} + +// RunCommandHook 执行外部命令并返回结构化的 HookResult。 +func RunCommandHook(ctx context.Context, spec CommandHookSpec, input HookContext) HookResult { + payload := BuildCommandPayload(spec.HookID, spec.Point, input.Metadata) + payloadBytes, err := json.Marshal(payload) + if err != nil { + return HookResult{ + HookID: spec.HookID, + Point: spec.Point, + Status: HookResultFailed, + Message: fmt.Sprintf("command hook marshal payload failed: %v", err), + Error: err.Error(), + } + } + payloadBytes = append(payloadBytes, '\n') + + cmd := buildExecCmd(ctx, spec) + cmd.Dir = spec.Workdir + cmd.Env = buildCommandEnv(spec) + cmd.Stdin = bytes.NewReader(payloadBytes) + + stdout, err := cmd.Output() + message := strings.TrimSpace(string(stdout)) + + // 尝试解析 stdout JSON 协议 + resp, parseErr := ParseCommandResponse(stdout) + if parseErr == nil { + return buildResultFromResponse(spec, resp) + } + + // 退化模式: stdout 非 JSON,按 exit code 推断状态 + return buildResultFromExitCode(ctx, spec, err, message) +} + +func buildExecCmd(ctx context.Context, spec CommandHookSpec) *exec.Cmd { + if spec.Shell && len(spec.Command) > 0 { + shell := spec.Command[0] + if runtime.GOOS == "windows" { + return exec.CommandContext(ctx, "powershell", "-Command", shell) + } + return exec.CommandContext(ctx, "sh", "-c", shell) + } + if len(spec.Command) == 1 { + return exec.CommandContext(ctx, spec.Command[0]) + } + return exec.CommandContext(ctx, spec.Command[0], spec.Command[1:]...) +} + +func buildCommandEnv(spec CommandHookSpec) []string { + env := []string{ + "NEOCODE_HOOK_HOOK_ID=" + spec.HookID, + "NEOCODE_HOOK_POINT=" + string(spec.Point), + "NEOCODE_HOOK_PAYLOAD_VERSION=" + CommandHookPayloadVersion, + } + if runtime.GOOS == "windows" { + if sd := os.Getenv("SystemDrive"); sd != "" { + env = append(env, "SystemDrive="+sd) + } + } + return env +} + +func buildResultFromResponse(spec CommandHookSpec, resp CommandHookResponse) HookResult { + result := HookResult{ + HookID: spec.HookID, + Point: spec.Point, + Message: strings.TrimSpace(resp.Message), + } + switch resp.Status { + case "pass": + result.Status = HookResultPass + case "block": + result.Status = HookResultBlock + case "failed": + result.Status = HookResultFailed + if result.Message == "" { + result.Message = "hook returned failed status" + } + result.Error = result.Message + } + if len(resp.Annotations) > 0 { + result.Metadata.Annotations = resp.Annotations + } + if len(resp.UpdateInput) > 0 { + result.Metadata.UpdateInput = resp.UpdateInput + } + return result +} + +func buildResultFromExitCode(ctx context.Context, spec CommandHookSpec, err error, message string) HookResult { + result := HookResult{ + HookID: spec.HookID, + Point: spec.Point, + Message: message, + } + if err == nil { + result.Status = HookResultPass + return result + } + // 上下文取消/超时优先判定为 failed + if ctx.Err() != nil { + result.Status = HookResultFailed + if result.Message == "" { + result.Message = fmt.Sprintf("command %v", ctx.Err()) + } + result.Error = ctx.Err().Error() + return result + } + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + code := exitErr.ExitCode() + switch code { + case 1, 2: + result.Status = HookResultBlock + default: + result.Status = HookResultFailed + if result.Message == "" { + result.Message = fmt.Sprintf("command exited with code %d", code) + } + result.Error = err.Error() + } + return result + } + result.Status = HookResultFailed + if result.Message == "" { + result.Message = err.Error() + } + result.Error = err.Error() + return result +} diff --git a/internal/runtime/hooks/command_handler_test.go b/internal/runtime/hooks/command_handler_test.go new file mode 100644 index 000000000..608d757ab --- /dev/null +++ b/internal/runtime/hooks/command_handler_test.go @@ -0,0 +1,439 @@ +package hooks + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + "time" +) + +func TestBuildCommandPayload(t *testing.T) { + t.Parallel() + payload := BuildCommandPayload("my-hook", HookPointBeforeToolCall, map[string]any{ + "tool_name": "bash", + "workdir": "/tmp", + }) + if payload.PayloadVersion != CommandHookPayloadVersion { + t.Fatalf("payload_version = %q, want %q", payload.PayloadVersion, CommandHookPayloadVersion) + } + if payload.HookID != "my-hook" { + t.Fatalf("hook_id = %q, want %q", payload.HookID, "my-hook") + } + if payload.Point != string(HookPointBeforeToolCall) { + t.Fatalf("point = %q, want %q", payload.Point, HookPointBeforeToolCall) + } + if payload.Metadata["tool_name"] != "bash" { + t.Fatalf("metadata[tool_name] = %v, want %q", payload.Metadata["tool_name"], "bash") + } +} + +func TestBuildCommandPayloadEmptyMetadata(t *testing.T) { + t.Parallel() + payload := BuildCommandPayload("hook", HookPointSessionStart, nil) + if payload.Metadata != nil { + t.Fatalf("metadata should be nil for empty input, got %v", payload.Metadata) + } +} + +func TestParseCommandResponsePass(t *testing.T) { + t.Parallel() + resp, err := ParseCommandResponse([]byte(`{"status":"pass","message":"ok"}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Status != "pass" { + t.Fatalf("status = %q, want %q", resp.Status, "pass") + } + if resp.Message != "ok" { + t.Fatalf("message = %q, want %q", resp.Message, "ok") + } +} + +func TestParseCommandResponseBlock(t *testing.T) { + t.Parallel() + resp, err := ParseCommandResponse([]byte(`{"status":"block","message":"denied"}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Status != "block" { + t.Fatalf("status = %q, want %q", resp.Status, "block") + } +} + +func TestParseCommandResponseFailed(t *testing.T) { + t.Parallel() + resp, err := ParseCommandResponse([]byte(`{"status":"failed","message":"broken"}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Status != "failed" { + t.Fatalf("status = %q, want %q", resp.Status, "failed") + } +} + +func TestParseCommandResponseWithAnnotations(t *testing.T) { + t.Parallel() + resp, err := ParseCommandResponse([]byte(`{"status":"pass","annotations":["note1","note2"]}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(resp.Annotations) != 2 || resp.Annotations[0] != "note1" { + t.Fatalf("annotations = %v, want [note1 note2]", resp.Annotations) + } +} + +func TestParseCommandResponseWithUpdateInput(t *testing.T) { + t.Parallel() + resp, err := ParseCommandResponse([]byte(`{"status":"pass","update_input":{"text":"rewritten"}}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(resp.UpdateInput) == 0 { + t.Fatal("update_input should not be empty") + } + var update struct { + Text string `json:"text"` + } + if err := json.Unmarshal(resp.UpdateInput, &update); err != nil { + t.Fatalf("unmarshal update_input: %v", err) + } + if update.Text != "rewritten" { + t.Fatalf("update_input.text = %q, want %q", update.Text, "rewritten") + } +} + +func TestParseCommandResponseInvalidStatus(t *testing.T) { + t.Parallel() + _, err := ParseCommandResponse([]byte(`{"status":"unknown"}`)) + if err == nil { + t.Fatal("expected error for invalid status") + } +} + +func TestParseCommandResponseInvalidJSON(t *testing.T) { + t.Parallel() + _, err := ParseCommandResponse([]byte(`not json`)) + if err == nil { + t.Fatal("expected error for non-JSON input") + } +} + +func TestParseCommandResponseEmptyStdout(t *testing.T) { + t.Parallel() + _, err := ParseCommandResponse([]byte{}) + if err == nil { + t.Fatal("expected error for empty input") + } +} + +func TestRunCommandHookArgvMode(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("argv mode test uses echo which is a shell builtin on Windows") + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + spec := CommandHookSpec{ + HookID: "test-argv", + Point: HookPointBeforeToolCall, + Command: []string{"echo", `{"status":"pass","message":"hello from argv"}`}, + Shell: false, + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultPass, result.Message) + } + if result.Message != "hello from argv" { + t.Fatalf("message = %q, want %q", result.Message, "hello from argv") + } +} + +func TestRunCommandHookArgvModeWindows(t *testing.T) { + t.Parallel() + if runtime.GOOS != "windows" { + t.Skip("Windows-only test") + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + spec := CommandHookSpec{ + HookID: "test-argv-win", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output '{\"status\":\"pass\",\"message\":\"hello from argv\"}'"}, + Shell: false, + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultPass, result.Message) + } + if result.Message != "hello from argv" { + t.Fatalf("message = %q, want %q", result.Message, "hello from argv") + } +} + +func TestRunCommandHookShellMode(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("shell mode test uses sh") + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + spec := CommandHookSpec{ + HookID: "test-shell", + Point: HookPointBeforeToolCall, + Command: []string{`echo '{"status":"pass","message":"from shell"}'`}, + Shell: true, + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultPass, result.Message) + } + if result.Message != "from shell" { + t.Fatalf("message = %q, want %q", result.Message, "from shell") + } +} + +func TestRunCommandHookExitCodeNonZeroEmptyStdout(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "test-exit3", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "exit 3"}, + } + } else { + spec = CommandHookSpec{ + HookID: "test-exit3", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "exit 3"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultFailed { + t.Fatalf("status = %q, want %q", result.Status, HookResultFailed) + } +} + +func TestRunCommandHookExitCodeBlock(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "test-exit1", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output 'blocked'; exit 1"}, + } + } else { + spec = CommandHookSpec{ + HookID: "test-exit1", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "echo blocked; exit 1"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultBlock { + t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultBlock, result.Message) + } + if result.Message != "blocked" { + t.Fatalf("message = %q, want %q", result.Message, "blocked") + } +} + +func TestRunCommandHookTimeout(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "test-timeout", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Start-Sleep -Seconds 10"}, + } + } else { + spec = CommandHookSpec{ + HookID: "test-timeout", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "sleep 10"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultFailed { + t.Fatalf("status = %q, want %q", result.Status, HookResultFailed) + } +} + +func TestRunCommandHookEnvIsolation(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + tmpDir := t.TempDir() + if runtime.GOOS == "windows" { + script := filepath.Join(tmpDir, "check_env.ps1") + if err := os.WriteFile(script, []byte(`$env:NEOCODE_HOOK_HOOK_ID; $env:NEOCODE_HOOK_POINT; $env:NEOCODE_HOOK_PAYLOAD_VERSION; if ($env:PATH) { "HAS_PATH=1" }; '{"status":"pass"}'`), 0o755); err != nil { + t.Fatalf("write script: %v", err) + } + spec := CommandHookSpec{ + HookID: "env-test", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-ExecutionPolicy", "Bypass", "-File", script}, + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultPass, result.Message) + } + if !strings.Contains(result.Message, "env-test") { + t.Fatalf("expected NEOCODE_HOOK_HOOK_ID in output, got: %s", result.Message) + } + if strings.Contains(result.Message, "HAS_PATH=1") { + t.Fatal("PATH should not be inherited in isolated env") + } + } else { + script := filepath.Join(tmpDir, "check_env.sh") + if err := os.WriteFile(script, []byte("#!/bin/sh\nenv | grep NEOCODE_HOOK_ | sort\nif [ -n \"$PATH\" ]; then echo \"HAS_PATH=1\"; fi\necho '{\"status\":\"pass\"}'\n"), 0o755); err != nil { + t.Fatalf("write script: %v", err) + } + spec := CommandHookSpec{ + HookID: "env-test", + Point: HookPointBeforeToolCall, + Command: []string{"sh", script}, + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultPass, result.Message) + } + if !strings.Contains(result.Message, "NEOCODE_HOOK_HOOK_ID=env-test") { + t.Fatalf("expected NEOCODE_HOOK_HOOK_ID in output, got: %s", result.Message) + } + if !strings.Contains(result.Message, "NEOCODE_HOOK_POINT=before_tool_call") { + t.Fatalf("expected NEOCODE_HOOK_POINT in output, got: %s", result.Message) + } + if strings.Contains(result.Message, "HAS_PATH=1") { + t.Fatal("PATH should not be inherited in isolated env") + } + } +} + +func TestRunCommandHookBackwardCompatPlainText(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "compat", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output 'just a message'"}, + } + } else { + spec = CommandHookSpec{ + HookID: "compat", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "echo just a message; exit 0"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q", result.Status, HookResultPass) + } + if result.Message != "just a message" { + t.Fatalf("message = %q, want %q", result.Message, "just a message") + } +} + +func TestRunCommandHookAnnotationsPopulated(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "annotated", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output '{\"status\":\"pass\",\"annotations\":[\"a1\",\"a2\"]}'"}, + } + } else { + spec = CommandHookSpec{ + HookID: "annotated", + Point: HookPointBeforeToolCall, + Command: []string{"echo", `{"status":"pass","annotations":["a1","a2"]}`}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q", result.Status, HookResultPass) + } + if len(result.Metadata.Annotations) != 2 { + t.Fatalf("annotations count = %d, want 2; annotations: %v", len(result.Metadata.Annotations), result.Metadata.Annotations) + } +} + +func TestRunCommandHookWorkdir(t *testing.T) { + t.Parallel() + tmpDir, err := os.MkdirTemp("", "hook-workdir-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "workdir-test", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output (Get-Location).Path; exit 0"}, + Workdir: tmpDir, + } + } else { + spec = CommandHookSpec{ + HookID: "workdir-test", + Point: HookPointBeforeToolCall, + Command: []string{"pwd"}, + Workdir: tmpDir, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultPass, result.Message) + } + if !strings.Contains(strings.ToLower(result.Message), strings.ToLower(filepath.Base(tmpDir))) { + t.Fatalf("expected workdir in output, got: %s", result.Message) + } +} + +func TestRunCommandHookStdinPayload(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "stdin-test", + Point: HookPointUserPromptSubmit, + Command: []string{"powershell", "-Command", "$input"}, + } + } else { + spec = CommandHookSpec{ + HookID: "stdin-test", + Point: HookPointUserPromptSubmit, + Command: []string{"cat"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{Metadata: map[string]any{"workdir": "/tmp"}}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q", result.Status, HookResultPass) + } + if !strings.Contains(result.Message, CommandHookPayloadVersion) { + t.Fatalf("stdin payload should contain payload_version, got: %s", result.Message) + } +} diff --git a/internal/runtime/hooks/result.go b/internal/runtime/hooks/result.go index e224182db..697f47bef 100644 --- a/internal/runtime/hooks/result.go +++ b/internal/runtime/hooks/result.go @@ -1,6 +1,9 @@ package hooks -import "time" +import ( + "encoding/json" + "time" +) // HookResultStatus 表示单个 hook 的执行结果状态。 type HookResultStatus string @@ -36,6 +39,10 @@ type HookResultMetadata struct { OriginalStatus string BlockDowngraded bool GuardSignal bool + + // P6 command hook 协议字段 + Annotations []string // stdout JSON "annotations" 数组 + UpdateInput json.RawMessage // stdout JSON "update_input" 原始字节 } // RunOutput 是一次点位执行的聚合结果。 diff --git a/internal/runtime/hooks_integration.go b/internal/runtime/hooks_integration.go index b3badecf0..ef82672d7 100644 --- a/internal/runtime/hooks_integration.go +++ b/internal/runtime/hooks_integration.go @@ -2,9 +2,11 @@ package runtime import ( "context" + "encoding/json" "strings" runtimehooks "neo-code/internal/runtime/hooks" + providertypes "neo-code/internal/provider/types" ) const ( @@ -229,10 +231,15 @@ func (s *Service) recordUserHookAnnotations(state *runState, output runtimehooks continue } message := strings.TrimSpace(result.Message) - if message == "" { - continue + if message != "" { + notes = append(notes, message) + } + for _, annotation := range result.Metadata.Annotations { + trimmed := strings.TrimSpace(annotation) + if trimmed != "" { + notes = append(notes, trimmed) + } } - notes = append(notes, message) } if len(notes) == 0 { return @@ -241,3 +248,40 @@ func (s *Service) recordUserHookAnnotations(state *runState, output runtimehooks state.hookAnnotations = append(state.hookAnnotations, notes...) state.mu.Unlock() } + +// applyCommandHookUpdateInput 检查 hook 输出中的 update_input 并应用到用户输入 parts。 +// 当前仅支持 user_prompt_submit 点位;update_input 格式: {"text": "..."} 替换文本内容。 +func applyCommandHookUpdateInput(output runtimehooks.RunOutput, parts []providertypes.ContentPart) []providertypes.ContentPart { + if len(output.Results) == 0 { + return parts + } + for _, result := range output.Results { + if len(result.Metadata.UpdateInput) == 0 { + continue + } + cap, ok := runtimehooks.HookPointCapabilities(result.Point) + if !ok || !cap.CanUpdateInput { + continue + } + var update struct { + Text string `json:"text"` + } + if err := json.Unmarshal(result.Metadata.UpdateInput, &update); err != nil { + continue + } + if update.Text == "" { + continue + } + newParts := make([]providertypes.ContentPart, 0, len(parts)) + for _, part := range parts { + if part.Kind == providertypes.ContentPartText { + newParts = append(newParts, providertypes.NewTextPart(update.Text)) + update.Text = "" // 仅替换第一个文本 part + } else { + newParts = append(newParts, part) + } + } + return newParts + } + return parts +} diff --git a/internal/runtime/repo_hooks.go b/internal/runtime/repo_hooks.go index 2efa8df5f..a495f404a 100644 --- a/internal/runtime/repo_hooks.go +++ b/internal/runtime/repo_hooks.go @@ -363,9 +363,41 @@ func validateRepoHookItem(item config.RuntimeHookItemConfig) error { return fmt.Errorf("handler %q requires params.tool_name or params.tool_names", item.Handler) } case repoHookKindCommand: - if strings.TrimSpace(readHookParamString(item.Params, "command")) == "" { + if err := validateRepoCommandParams(item.Params); err != nil { + return err + } + } + return nil +} + +// validateRepoCommandParams 校验 repo command hook 的 params.command 格式。 +func validateRepoCommandParams(params map[string]any) error { + if len(params) == 0 { + return fmt.Errorf("kind command requires params.command") + } + raw, ok := params["command"] + if !ok || raw == nil { + return fmt.Errorf("kind command requires params.command") + } + switch v := raw.(type) { + case string: + if strings.TrimSpace(v) == "" { return fmt.Errorf("kind command requires params.command") } + shellVal, _ := params["shell"].(bool) + if !shellVal { + return fmt.Errorf("string params.command requires params.shell=true; use array format for argv mode") + } + case []any: + if len(v) == 0 { + return fmt.Errorf("kind command requires non-empty params.command") + } + case []string: + if len(v) == 0 { + return fmt.Errorf("kind command requires non-empty params.command") + } + default: + return fmt.Errorf("params.command must be a string (with shell=true) or an array") } return nil } diff --git a/internal/runtime/repo_hooks_test.go b/internal/runtime/repo_hooks_test.go index 4d1a538cb..64688256c 100644 --- a/internal/runtime/repo_hooks_test.go +++ b/internal/runtime/repo_hooks_test.go @@ -998,7 +998,7 @@ func TestValidateRepoHookItemCommandKindBranches(t *testing.T) { Mode: "sync", TimeoutSec: 2, FailurePolicy: "warn_only", - Params: map[string]any{"command": "echo ok"}, + Params: map[string]any{"command": []any{"echo", "ok"}}, } if err := validateRepoHookItem(item); err != nil { t.Fatalf("validateRepoHookItem(command with params) error = %v", err) diff --git a/internal/runtime/run.go b/internal/runtime/run.go index bef27fe70..338d07ebe 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -203,6 +203,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { }) return s.handleRunError(errors.New(findHookBlockMessage(submitHookOutput))) } + input.Parts = applyCommandHookUpdateInput(submitHookOutput, input.Parts) if err := s.appendUserMessageAndSave(ctx, &state, input.Parts); err != nil { return s.handleRunError(err) } diff --git a/internal/runtime/user_hooks.go b/internal/runtime/user_hooks.go index c9557e2e8..7a837a4d5 100644 --- a/internal/runtime/user_hooks.go +++ b/internal/runtime/user_hooks.go @@ -4,14 +4,12 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io" "net" "net/http" "net/url" "os" - "os/exec" "path/filepath" "runtime" "slices" @@ -212,7 +210,7 @@ func buildConfiguredHookSpec( specKind = runtimehooks.HookKindFunction specMode = runtimehooks.HookModeSync case configuredHookKindCommand: - handler, err = buildUserCommandHookHandler(item.Params, defaultWorkdir) + handler, err = buildUserCommandHookHandler(strings.TrimSpace(item.ID), item.Params, defaultWorkdir) specKind = runtimehooks.HookKindCommand specMode = runtimehooks.HookModeSync case configuredHookKindHTTP: @@ -258,8 +256,8 @@ func validateConfiguredHookItemForP6Lite(item config.RuntimeHookItemConfig, scop if mode != configuredHookModeSync { return fmt.Errorf("mode %q is not supported for kind command (only sync)", item.Mode) } - if strings.TrimSpace(readHookParamString(item.Params, "command")) == "" { - return fmt.Errorf("kind command requires params.command") + if _, _, err := parseCommandHookParams(item.Params); err != nil { + return err } case configuredHookKindHTTP: if mode != configuredHookModeObserve { @@ -378,47 +376,73 @@ func buildUserBuiltinHookHandler( } } -// buildUserCommandHookHandler 将命令型 hook 转为同步阻断处理器,并通过 stdin 传入上下文 JSON。 -func buildUserCommandHookHandler(params map[string]any, defaultWorkdir string) (runtimehooks.HookHandler, error) { - command := strings.TrimSpace(readHookParamString(params, "command")) - if command == "" { - return nil, fmt.Errorf("kind command requires params.command") +// buildUserCommandHookHandler 将命令型 hook 转为同步阻断处理器,使用 stdin/stdout JSON 协议。 +func buildUserCommandHookHandler(hookID string, params map[string]any, defaultWorkdir string) (runtimehooks.HookHandler, error) { + argv, shell, err := parseCommandHookParams(params) + if err != nil { + return nil, err } return func(ctx context.Context, input runtimehooks.HookContext) runtimehooks.HookResult { - workdir := resolveHookWorkdir(input, defaultWorkdir) - cmd := buildCommandHookProcess(ctx, command) - if strings.TrimSpace(workdir) != "" { - cmd.Dir = workdir - } - payload, err := json.Marshal(input) - if err != nil { - detail := fmt.Sprintf("command hook marshal input failed: %v", err) - return runtimehooks.HookResult{Status: runtimehooks.HookResultFailed, Message: detail, Error: detail} + spec := runtimehooks.CommandHookSpec{ + HookID: hookID, + Point: runtimehooks.HookPoint(strings.TrimSpace(fmt.Sprintf("%v", input.Metadata["point"]))), + Command: argv, + Shell: shell, + Workdir: resolveHookWorkdir(input, defaultWorkdir), + } + return runtimehooks.RunCommandHook(ctx, spec, input) + }, nil +} + +// parseCommandHookParams 解析 params.command 为 argv 数组,支持 []string / []any / string+shell 三种格式。 +func parseCommandHookParams(params map[string]any) (argv []string, shell bool, err error) { + if len(params) == 0 { + return nil, false, fmt.Errorf("kind command requires params.command") + } + raw, ok := params["command"] + if !ok || raw == nil { + return nil, false, fmt.Errorf("kind command requires params.command") + } + switch v := raw.(type) { + case string: + trimmed := strings.TrimSpace(v) + if trimmed == "" { + return nil, false, fmt.Errorf("kind command requires params.command") } - cmd.Stdin = bytes.NewReader(payload) - output, err := cmd.CombinedOutput() - message := strings.TrimSpace(string(output)) - if err == nil { - return runtimehooks.HookResult{Status: runtimehooks.HookResultPass, Message: message} + shellVal, _ := params["shell"].(bool) + if !shellVal { + return nil, false, fmt.Errorf("string params.command requires params.shell=true; use array format for argv mode") } - var exitErr *exec.ExitError - if errors.As(err, &exitErr) && (exitErr.ExitCode() == 1 || exitErr.ExitCode() == 2) { - return runtimehooks.HookResult{Status: runtimehooks.HookResultBlock, Message: message} + return []string{trimmed}, true, nil + case []string: + if len(v) == 0 { + return nil, false, fmt.Errorf("kind command requires non-empty params.command") + } + out := make([]string, 0, len(v)) + for _, s := range v { + trimmed := strings.TrimSpace(s) + if trimmed == "" { + return nil, false, fmt.Errorf("params.command contains empty element") + } + out = append(out, trimmed) } - detail := strings.TrimSpace(message) - if detail == "" { - detail = err.Error() + return out, false, nil + case []any: + if len(v) == 0 { + return nil, false, fmt.Errorf("kind command requires non-empty params.command") + } + out := make([]string, 0, len(v)) + for _, item := range v { + s := strings.TrimSpace(fmt.Sprintf("%v", item)) + if s == "" { + return nil, false, fmt.Errorf("params.command contains empty element") + } + out = append(out, s) } - return runtimehooks.HookResult{Status: runtimehooks.HookResultFailed, Message: detail, Error: err.Error()} - }, nil -} - -// buildCommandHookProcess 以当前平台的 shell 执行用户命令,保留脚本组合能力。 -func buildCommandHookProcess(ctx context.Context, command string) *exec.Cmd { - if runtime.GOOS == "windows" { - return exec.CommandContext(ctx, "powershell", "-Command", command) + return out, false, nil + default: + return nil, false, fmt.Errorf("params.command must be a string (with shell=true) or an array of strings") } - return exec.CommandContext(ctx, "sh", "-c", command) } // buildUserHTTPObserveHookHandler 将 kind=http 的 observe 配置转换为观测回调处理器。 From 7b6156986d46879aa9a8db090d48f550c995fc36 Mon Sep 17 00:00:00 2001 From: Cai_Tang <106404101+Cai-Tang-www@users.noreply.github.com> Date: Sat, 23 May 2026 16:54:15 +0800 Subject: [PATCH 2/6] fix(hooks): rewrite env isolation test to avoid Windows PATH leak assertion Windows injects PATH at the system level even when cmd.Env is set. Split into two focused tests: one verifies NEOCODE_HOOK_* vars are injected via exec, the other verifies buildCommandEnv returns the correct variable set. Co-Authored-By: Claude Opus 4.7 --- .../runtime/hooks/command_handler_test.go | 79 ++++++++++--------- 1 file changed, 41 insertions(+), 38 deletions(-) diff --git a/internal/runtime/hooks/command_handler_test.go b/internal/runtime/hooks/command_handler_test.go index 608d757ab..7fbf0ff08 100644 --- a/internal/runtime/hooks/command_handler_test.go +++ b/internal/runtime/hooks/command_handler_test.go @@ -273,54 +273,57 @@ func TestRunCommandHookTimeout(t *testing.T) { func TestRunCommandHookEnvIsolation(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - tmpDir := t.TempDir() + var spec CommandHookSpec if runtime.GOOS == "windows" { - script := filepath.Join(tmpDir, "check_env.ps1") - if err := os.WriteFile(script, []byte(`$env:NEOCODE_HOOK_HOOK_ID; $env:NEOCODE_HOOK_POINT; $env:NEOCODE_HOOK_PAYLOAD_VERSION; if ($env:PATH) { "HAS_PATH=1" }; '{"status":"pass"}'`), 0o755); err != nil { - t.Fatalf("write script: %v", err) - } - spec := CommandHookSpec{ + spec = CommandHookSpec{ HookID: "env-test", Point: HookPointBeforeToolCall, - Command: []string{"powershell", "-ExecutionPolicy", "Bypass", "-File", script}, - } - result := RunCommandHook(ctx, spec, HookContext{}) - if result.Status != HookResultPass { - t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultPass, result.Message) - } - if !strings.Contains(result.Message, "env-test") { - t.Fatalf("expected NEOCODE_HOOK_HOOK_ID in output, got: %s", result.Message) - } - if strings.Contains(result.Message, "HAS_PATH=1") { - t.Fatal("PATH should not be inherited in isolated env") + Command: []string{"powershell", "-Command", "$env:NEOCODE_HOOK_HOOK_ID; $env:NEOCODE_HOOK_POINT; $env:NEOCODE_HOOK_PAYLOAD_VERSION; '{\"status\":\"pass\"}'"}, } } else { - script := filepath.Join(tmpDir, "check_env.sh") - if err := os.WriteFile(script, []byte("#!/bin/sh\nenv | grep NEOCODE_HOOK_ | sort\nif [ -n \"$PATH\" ]; then echo \"HAS_PATH=1\"; fi\necho '{\"status\":\"pass\"}'\n"), 0o755); err != nil { - t.Fatalf("write script: %v", err) - } - spec := CommandHookSpec{ + spec = CommandHookSpec{ HookID: "env-test", Point: HookPointBeforeToolCall, - Command: []string{"sh", script}, - } - result := RunCommandHook(ctx, spec, HookContext{}) - if result.Status != HookResultPass { - t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultPass, result.Message) + Command: []string{"sh", "-c", "echo $NEOCODE_HOOK_HOOK_ID; echo $NEOCODE_HOOK_POINT; echo $NEOCODE_HOOK_PAYLOAD_VERSION; echo '{\"status\":\"pass\"}'"}, } - if !strings.Contains(result.Message, "NEOCODE_HOOK_HOOK_ID=env-test") { - t.Fatalf("expected NEOCODE_HOOK_HOOK_ID in output, got: %s", result.Message) - } - if !strings.Contains(result.Message, "NEOCODE_HOOK_POINT=before_tool_call") { - t.Fatalf("expected NEOCODE_HOOK_POINT in output, got: %s", result.Message) - } - if strings.Contains(result.Message, "HAS_PATH=1") { - t.Fatal("PATH should not be inherited in isolated env") + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultPass, result.Message) + } + if !strings.Contains(result.Message, "env-test") { + t.Fatalf("expected NEOCODE_HOOK_HOOK_ID in output, got: %s", result.Message) + } + if !strings.Contains(result.Message, "before_tool_call") { + t.Fatalf("expected NEOCODE_HOOK_POINT in output, got: %s", result.Message) + } + if !strings.Contains(result.Message, CommandHookPayloadVersion) { + t.Fatalf("expected NEOCODE_HOOK_PAYLOAD_VERSION in output, got: %s", result.Message) + } +} + +func TestBuildCommandEnvContainsHookVars(t *testing.T) { + t.Parallel() + spec := CommandHookSpec{HookID: "id-123", Point: HookPointSessionEnd} + env := buildCommandEnv(spec) + envMap := make(map[string]bool) + for _, e := range env { + parts := strings.SplitN(e, "=", 2) + if len(parts) == 2 { + envMap[parts[0]] = true } } + if !envMap["NEOCODE_HOOK_HOOK_ID"] { + t.Fatal("missing NEOCODE_HOOK_HOOK_ID") + } + if !envMap["NEOCODE_HOOK_POINT"] { + t.Fatal("missing NEOCODE_HOOK_POINT") + } + if !envMap["NEOCODE_HOOK_PAYLOAD_VERSION"] { + t.Fatal("missing NEOCODE_HOOK_PAYLOAD_VERSION") + } } func TestRunCommandHookBackwardCompatPlainText(t *testing.T) { From e4e5d15bdfec4d170ce29c38486316f70a53d38c Mon Sep 17 00:00:00 2001 From: Cai_Tang <106404101+Cai-Tang-www@users.noreply.github.com> Date: Sat, 23 May 2026 17:48:34 +0800 Subject: [PATCH 3/6] fix(hooks): pass hook point directly to command handler closure input.Metadata["point"] was not populated by runHookPoint (it only injects run_id/session_id/runtime_run_token/phase/turn), so command hooks received an empty point in the stdin payload and NEOCODE_HOOK_POINT. Fix by passing item.Point from config directly into the handler closure at build time. Co-Authored-By: Claude Opus 4.7 --- internal/runtime/user_hooks.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/internal/runtime/user_hooks.go b/internal/runtime/user_hooks.go index 7a837a4d5..e7b9d0654 100644 --- a/internal/runtime/user_hooks.go +++ b/internal/runtime/user_hooks.go @@ -210,7 +210,12 @@ func buildConfiguredHookSpec( specKind = runtimehooks.HookKindFunction specMode = runtimehooks.HookModeSync case configuredHookKindCommand: - handler, err = buildUserCommandHookHandler(strings.TrimSpace(item.ID), item.Params, defaultWorkdir) + handler, err = buildUserCommandHookHandler( + strings.TrimSpace(item.ID), + runtimehooks.HookPoint(strings.TrimSpace(item.Point)), + item.Params, + defaultWorkdir, + ) specKind = runtimehooks.HookKindCommand specMode = runtimehooks.HookModeSync case configuredHookKindHTTP: @@ -377,7 +382,7 @@ func buildUserBuiltinHookHandler( } // buildUserCommandHookHandler 将命令型 hook 转为同步阻断处理器,使用 stdin/stdout JSON 协议。 -func buildUserCommandHookHandler(hookID string, params map[string]any, defaultWorkdir string) (runtimehooks.HookHandler, error) { +func buildUserCommandHookHandler(hookID string, point runtimehooks.HookPoint, params map[string]any, defaultWorkdir string) (runtimehooks.HookHandler, error) { argv, shell, err := parseCommandHookParams(params) if err != nil { return nil, err @@ -385,7 +390,7 @@ func buildUserCommandHookHandler(hookID string, params map[string]any, defaultWo return func(ctx context.Context, input runtimehooks.HookContext) runtimehooks.HookResult { spec := runtimehooks.CommandHookSpec{ HookID: hookID, - Point: runtimehooks.HookPoint(strings.TrimSpace(fmt.Sprintf("%v", input.Metadata["point"]))), + Point: point, Command: argv, Shell: shell, Workdir: resolveHookWorkdir(input, defaultWorkdir), From a5ddb5ed96477a21f5debd2b9fda82dc88600e68 Mon Sep 17 00:00:00 2001 From: Cai_Tang <106404101+Cai-Tang-www@users.noreply.github.com> Date: Sat, 23 May 2026 18:00:07 +0800 Subject: [PATCH 4/6] fix(hooks): exit code precedence and payload run_id/session_id Self-review fixes for P6 command hook protocol: 1. Security: non-zero exit code now takes precedence over JSON status. A malicious script claiming {"status":"pass"} while exiting non-zero will be treated as block/failed based on exit code, not pass. buildResultFromExitCode still extracts message/annotations from JSON stdout when available, but status authority remains with exit code. 2. Added top-level run_id and session_id fields to CommandHookPayload, populated from HookContext. External scripts can now read these directly without digging into metadata. 3. Added tests: exit code precedence over JSON, exit code 3 with JSON message extraction, payload run_id/session_id verification, stdin payload round-trip with run_id/session_id. Co-Authored-By: Claude Opus 4.7 --- internal/runtime/hooks/command_handler.go | 53 +++++--- .../runtime/hooks/command_handler_test.go | 116 +++++++++++++++++- 2 files changed, 147 insertions(+), 22 deletions(-) diff --git a/internal/runtime/hooks/command_handler.go b/internal/runtime/hooks/command_handler.go index b30fab9fa..c77b2007c 100644 --- a/internal/runtime/hooks/command_handler.go +++ b/internal/runtime/hooks/command_handler.go @@ -20,6 +20,8 @@ type CommandHookPayload struct { PayloadVersion string `json:"payload_version"` HookID string `json:"hook_id"` Point string `json:"point"` + RunID string `json:"run_id,omitempty"` + SessionID string `json:"session_id,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } @@ -41,14 +43,16 @@ type CommandHookSpec struct { } // BuildCommandPayload 构造传给外部命令的 stdin JSON payload。 -func BuildCommandPayload(hookID string, point HookPoint, metadata map[string]any) CommandHookPayload { +func BuildCommandPayload(hookID string, point HookPoint, input HookContext) CommandHookPayload { payload := CommandHookPayload{ PayloadVersion: CommandHookPayloadVersion, HookID: strings.TrimSpace(hookID), Point: string(point), + RunID: strings.TrimSpace(input.RunID), + SessionID: strings.TrimSpace(input.SessionID), } - if len(metadata) > 0 { - payload.Metadata = metadata + if len(input.Metadata) > 0 { + payload.Metadata = input.Metadata } return payload } @@ -76,7 +80,7 @@ func ParseCommandResponse(raw []byte) (CommandHookResponse, error) { // RunCommandHook 执行外部命令并返回结构化的 HookResult。 func RunCommandHook(ctx context.Context, spec CommandHookSpec, input HookContext) HookResult { - payload := BuildCommandPayload(spec.HookID, spec.Point, input.Metadata) + payload := BuildCommandPayload(spec.HookID, spec.Point, input) payloadBytes, err := json.Marshal(payload) if err != nil { return HookResult{ @@ -97,14 +101,24 @@ func RunCommandHook(ctx context.Context, spec CommandHookSpec, input HookContext stdout, err := cmd.Output() message := strings.TrimSpace(string(stdout)) - // 尝试解析 stdout JSON 协议 + // 非零 exit code 优先于 JSON status(防止恶意脚本声称 pass 但实际失败) + if err != nil { + return buildResultFromExitCode(ctx, spec, err, message, stdout) + } + + // exit code 0: 尝试解析 stdout JSON 协议 resp, parseErr := ParseCommandResponse(stdout) if parseErr == nil { return buildResultFromResponse(spec, resp) } - // 退化模式: stdout 非 JSON,按 exit code 推断状态 - return buildResultFromExitCode(ctx, spec, err, message) + // 退化模式: exit 0 但 stdout 非 JSON,按 pass 处理 + return HookResult{ + HookID: spec.HookID, + Point: spec.Point, + Status: HookResultPass, + Message: message, + } } func buildExecCmd(ctx context.Context, spec CommandHookSpec) *exec.Cmd { @@ -162,16 +176,12 @@ func buildResultFromResponse(spec CommandHookSpec, resp CommandHookResponse) Hoo return result } -func buildResultFromExitCode(ctx context.Context, spec CommandHookSpec, err error, message string) HookResult { +func buildResultFromExitCode(ctx context.Context, spec CommandHookSpec, err error, message string, stdout []byte) HookResult { result := HookResult{ HookID: spec.HookID, Point: spec.Point, Message: message, } - if err == nil { - result.Status = HookResultPass - return result - } // 上下文取消/超时优先判定为 failed if ctx.Err() != nil { result.Status = HookResultFailed @@ -194,12 +204,21 @@ func buildResultFromExitCode(ctx context.Context, spec CommandHookSpec, err erro } result.Error = err.Error() } - return result + } else { + result.Status = HookResultFailed + if result.Message == "" { + result.Message = err.Error() + } + result.Error = err.Error() } - result.Status = HookResultFailed - if result.Message == "" { - result.Message = err.Error() + // 尝试从 stdout JSON 提取 message/annotations(status 仍由 exit code 决定) + if resp, parseErr := ParseCommandResponse(stdout); parseErr == nil { + if trimmed := strings.TrimSpace(resp.Message); trimmed != "" { + result.Message = trimmed + } + if len(resp.Annotations) > 0 { + result.Metadata.Annotations = resp.Annotations + } } - result.Error = err.Error() return result } diff --git a/internal/runtime/hooks/command_handler_test.go b/internal/runtime/hooks/command_handler_test.go index 7fbf0ff08..710bd91a5 100644 --- a/internal/runtime/hooks/command_handler_test.go +++ b/internal/runtime/hooks/command_handler_test.go @@ -13,9 +13,13 @@ import ( func TestBuildCommandPayload(t *testing.T) { t.Parallel() - payload := BuildCommandPayload("my-hook", HookPointBeforeToolCall, map[string]any{ - "tool_name": "bash", - "workdir": "/tmp", + payload := BuildCommandPayload("my-hook", HookPointBeforeToolCall, HookContext{ + RunID: "run-123", + SessionID: "sess-456", + Metadata: map[string]any{ + "tool_name": "bash", + "workdir": "/tmp", + }, }) if payload.PayloadVersion != CommandHookPayloadVersion { t.Fatalf("payload_version = %q, want %q", payload.PayloadVersion, CommandHookPayloadVersion) @@ -26,6 +30,12 @@ func TestBuildCommandPayload(t *testing.T) { if payload.Point != string(HookPointBeforeToolCall) { t.Fatalf("point = %q, want %q", payload.Point, HookPointBeforeToolCall) } + if payload.RunID != "run-123" { + t.Fatalf("run_id = %q, want %q", payload.RunID, "run-123") + } + if payload.SessionID != "sess-456" { + t.Fatalf("session_id = %q, want %q", payload.SessionID, "sess-456") + } if payload.Metadata["tool_name"] != "bash" { t.Fatalf("metadata[tool_name] = %v, want %q", payload.Metadata["tool_name"], "bash") } @@ -33,10 +43,13 @@ func TestBuildCommandPayload(t *testing.T) { func TestBuildCommandPayloadEmptyMetadata(t *testing.T) { t.Parallel() - payload := BuildCommandPayload("hook", HookPointSessionStart, nil) + payload := BuildCommandPayload("hook", HookPointSessionStart, HookContext{}) if payload.Metadata != nil { t.Fatalf("metadata should be nil for empty input, got %v", payload.Metadata) } + if payload.RunID != "" { + t.Fatalf("run_id should be empty, got %q", payload.RunID) + } } func TestParseCommandResponsePass(t *testing.T) { @@ -414,6 +427,89 @@ func TestRunCommandHookWorkdir(t *testing.T) { } } +func TestBuildCommandPayloadRunSessionID(t *testing.T) { + t.Parallel() + payload := BuildCommandPayload("my-hook", HookPointBeforeToolCall, HookContext{ + RunID: "run-abc", + SessionID: "sess-xyz", + }) + if payload.RunID != "run-abc" { + t.Fatalf("run_id = %q, want %q", payload.RunID, "run-abc") + } + if payload.SessionID != "sess-xyz" { + t.Fatalf("session_id = %q, want %q", payload.SessionID, "sess-xyz") + } +} + +func TestBuildCommandPayloadEmptyRunSessionID(t *testing.T) { + t.Parallel() + payload := BuildCommandPayload("hook", HookPointSessionStart, HookContext{}) + if payload.RunID != "" { + t.Fatalf("run_id should be empty, got %q", payload.RunID) + } + if payload.SessionID != "" { + t.Fatalf("session_id should be empty, got %q", payload.SessionID) + } +} + +func TestRunCommandHookExitCodePrecedenceOverJSON(t *testing.T) { + // Security: non-zero exit code must override JSON status. + // A malicious script claiming "pass" while exiting 1 should result in block, not pass. + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "precedence-test", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output '{\"status\":\"pass\",\"message\":\"claiming pass\"}'; exit 1"}, + } + } else { + spec = CommandHookSpec{ + HookID: "precedence-test", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "echo '{\"status\":\"pass\",\"message\":\"claiming pass\"}'; exit 1"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultBlock { + t.Fatalf("status = %q, want %q (exit code must take precedence over JSON status)", result.Status, HookResultBlock) + } + // message should still be extracted from JSON stdout + if result.Message != "claiming pass" { + t.Fatalf("message = %q, want %q (should extract message from JSON even when exit code wins)", result.Message, "claiming pass") + } +} + +func TestRunCommandHookExitCodeThreeWithJSONMessage(t *testing.T) { + // exit code 3 + JSON with message → failed status, message from JSON + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "exit3-json", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output '{\"status\":\"pass\",\"message\":\"from json\"}'; exit 3"}, + } + } else { + spec = CommandHookSpec{ + HookID: "exit3-json", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "echo '{\"status\":\"pass\",\"message\":\"from json\"}'; exit 3"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultFailed { + t.Fatalf("status = %q, want %q", result.Status, HookResultFailed) + } + if result.Message != "from json" { + t.Fatalf("message = %q, want %q", result.Message, "from json") + } +} + func TestRunCommandHookStdinPayload(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -432,11 +528,21 @@ func TestRunCommandHookStdinPayload(t *testing.T) { Command: []string{"cat"}, } } - result := RunCommandHook(ctx, spec, HookContext{Metadata: map[string]any{"workdir": "/tmp"}}) + result := RunCommandHook(ctx, spec, HookContext{ + RunID: "run-789", + SessionID: "sess-012", + Metadata: map[string]any{"workdir": "/tmp"}, + }) if result.Status != HookResultPass { t.Fatalf("status = %q, want %q", result.Status, HookResultPass) } if !strings.Contains(result.Message, CommandHookPayloadVersion) { t.Fatalf("stdin payload should contain payload_version, got: %s", result.Message) } + if !strings.Contains(result.Message, "run-789") { + t.Fatalf("stdin payload should contain run_id, got: %s", result.Message) + } + if !strings.Contains(result.Message, "sess-012") { + t.Fatalf("stdin payload should contain session_id, got: %s", result.Message) + } } From 1ef1ed08657e55c8b4bb31c7107efbd9f07ab6be Mon Sep 17 00:00:00 2001 From: Cai_Tang <106404101+Cai-Tang-www@users.noreply.github.com> Date: Mon, 25 May 2026 20:10:19 +0800 Subject: [PATCH 5/6] fix(hooks): address P6 command hook protocol review findings P0 security/correctness: - Capture stderr separately from stdout; append to failed results for debug - Add SystemRoot/SystemDrive/USERPROFILE to Windows hook env (TLS/NTLM) - Limit stdout to 1 MiB via pipe+LimitReader to prevent OOM P1 design/maintainability: - Extract ValidateCommandParams/ParseCommandParams as shared exports in hooks package, deduplicating 3 identical validation sites (config/runtime_hooks.go, repo_hooks.go, user_hooks.go) - Guard buildExecCmd against Shell+multi-args misuse (panic) - Set result.Error for exit code 1/2 block results (consistency) P2 test coverage: - Add 7-subtest TestApplyCommandHookUpdateInput (caught real bug: was replacing ALL text parts instead of first-only) - Add TestRunCommandHookEnvIsolationNoLeak (Unix PATH/HOME/USER check) - Add TestRunCommandHookShellModeWindows (powershell -Command) - Add TestValidateCommandParams (9 cases covering exported API) P3 documentation: - Document stderr handling strategy, run_id/session_id field precedence, update_input+block interaction, exit-code-over-JSON security rule, stdout size limit, Windows env vars Bug fix discovered during review: - applyCommandHookUpdateInput used update.Text="" to skip subsequent text parts, but still appended NewTextPart(""). Fixed with replaced flag. Co-Authored-By: Claude Opus 4.7 --- docs/runtime-hooks-design.md | 30 +++- internal/config/runtime_hooks.go | 35 +---- internal/runtime/hooks/command_handler.go | 138 ++++++++++++++++- .../runtime/hooks/command_handler_test.go | 75 ++++++++++ internal/runtime/hooks_integration.go | 5 +- internal/runtime/hooks_integration_test.go | 139 ++++++++++++++++++ internal/runtime/repo_hooks.go | 34 +---- internal/runtime/user_hooks.go | 55 +------ 8 files changed, 380 insertions(+), 131 deletions(-) diff --git a/docs/runtime-hooks-design.md b/docs/runtime-hooks-design.md index f3527b7c8..68c59cbf4 100644 --- a/docs/runtime-hooks-design.md +++ b/docs/runtime-hooks-design.md @@ -150,10 +150,11 @@ trust store 固定路径: "payload_version": "1", "hook_id": "my-hook", "point": "before_tool_call", + "run_id": "run_abc123", + "session_id": "sess_abc123", "metadata": { "tool_name": "bash", - "workdir": "/path/to/workspace", - "session_id": "sess_abc123" + "workdir": "/path/to/workspace" } } ``` @@ -228,12 +229,37 @@ params: | `NEOCODE_HOOK_POINT` | 触发点位(如 `before_tool_call`) | | `NEOCODE_HOOK_PAYLOAD_VERSION` | `"1"` | +Windows 额外注入 `SystemRoot`、`SystemDrive`、`USERPROFILE`(从宿主环境读取),以确保 TLS 证书加载和运行时基础功能正常工作。 + ### 执行约束 - workdir = 当前 run 的 workspace(`cmd.Dir = workdir`) - 超时 = hook 配置的 `timeout_sec`(默认 2s) - 并发限制 = executor 的 `max_in_flight`(默认 128) - repo scope command hook 受 trust gate 保护 +- stdout 大小限制 = 1 MiB;超出视为 `failed` + +### stderr 处理 + +外部命令的 stderr 与 stdout 分离捕获。stderr 不会混入 `message` 字段,仅在命令执行失败(非零 exit code)且 stdout 无可用 message 时,stderr 内容才作为 fallback 追加到结果中。此设计确保 hook 协议输出(stdout JSON)不受调试输出(stderr)干扰。 + +### stdin 字段说明 + +- `run_id` / `session_id` 同时出现在 payload 顶层和 `metadata` 中。**顶层字段为权威来源**,`metadata` 中的同名字段为冗余副本(与 builtin/http hook 的 metadata allowlist 一致)。外部脚本应优先读取顶层字段。 +- `payload_version` 当前固定为 `"1"`,变更 stdin 结构时递增。 + +### update_input 与 block 交互 + +当 hook 返回 `status: "block"` 时,`update_input` 不会被应用。阻断优先于输入改写——hook 链在检测到 block 后立即终止,不进入 `applyCommandHookUpdateInput` 逻辑。 + +### 安全:exit code 优先于 JSON status + +当命令以非零 exit code 退出时,stdout 中 JSON 声称的 `status` 字段被忽略。exit code 的映射优先: + +- exit 1/2 → `block` +- 其他非零 → `failed` + +此规则防止恶意脚本通过 `{"status":"pass"}` 掩盖实际失败。JSON 中的 `message` 和 `annotations` 仍会被提取(如果 stdout 是合法 JSON)。 ### 示例 diff --git a/internal/config/runtime_hooks.go b/internal/config/runtime_hooks.go index 172accd3d..fa237f527 100644 --- a/internal/config/runtime_hooks.go +++ b/internal/config/runtime_hooks.go @@ -286,7 +286,7 @@ func (c RuntimeHookItemConfig) Validate(defaultFailurePolicy string) error { if normalizedMode != runtimeHookModeSync { return fmt.Errorf("mode %q is not supported for kind command (only sync)", c.Mode) } - if err := validateRuntimeCommandItem(c.Params); err != nil { + if err := hooks.ValidateCommandParams(c.Params); err != nil { return err } case runtimeHookKindHTTP: @@ -349,39 +349,6 @@ func validateRuntimeHTTPObserveItem(c RuntimeHookItemConfig, policy string) erro return nil } -// validateRuntimeCommandItem 校验 command kind 的 params.command 格式。 -// 支持 []string / []any (argv 模式) 和 string + shell=true (shell 模式)。 -func validateRuntimeCommandItem(params map[string]any) error { - if len(params) == 0 { - return fmt.Errorf("kind command requires params.command") - } - raw, ok := params["command"] - if !ok || raw == nil { - return fmt.Errorf("kind command requires params.command") - } - switch v := raw.(type) { - case string: - if strings.TrimSpace(v) == "" { - return fmt.Errorf("kind command requires params.command") - } - shellVal, _ := params["shell"].(bool) - if !shellVal { - return fmt.Errorf("string params.command requires params.shell=true; use array format for argv mode") - } - case []any: - if len(v) == 0 { - return fmt.Errorf("kind command requires non-empty params.command") - } - case []string: - if len(v) == 0 { - return fmt.Errorf("kind command requires non-empty params.command") - } - default: - return fmt.Errorf("params.command must be a string (with shell=true) or an array") - } - return nil -} - // isRuntimeHookHTTPObserveLoopbackHost 判断 http observe 回调域名是否属于本地回环地址。 func isRuntimeHookHTTPObserveLoopbackHost(host string) bool { normalized := strings.TrimSpace(strings.ToLower(host)) diff --git a/internal/runtime/hooks/command_handler.go b/internal/runtime/hooks/command_handler.go index c77b2007c..0523d897c 100644 --- a/internal/runtime/hooks/command_handler.go +++ b/internal/runtime/hooks/command_handler.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "os" "os/exec" "runtime" @@ -15,6 +16,9 @@ import ( // CommandHookPayloadVersion 定义 command hook stdin 协议版本号,变更 stdin 结构时递增。 const CommandHookPayloadVersion = "1" +// maxCommandStdoutBytes 限制外部命令 stdout 最大读取字节数,防止 OOM。 +const maxCommandStdoutBytes = 1 << 20 // 1 MiB + // CommandHookPayload 是通过 stdin 传给外部命令的单行 JSON。 type CommandHookPayload struct { PayloadVersion string `json:"payload_version"` @@ -42,6 +46,66 @@ type CommandHookSpec struct { Workdir string } +// ValidateCommandParams 校验 params.command 格式。 +// 支持 []string / []any (argv 模式) 和 string + shell=true (shell 模式)。 +// 此函数是 command hook params 校验的唯一真源,config / runtime 包均应调用此函数。 +func ValidateCommandParams(params map[string]any) error { + _, _, err := ParseCommandParams(params) + return err +} + +// ParseCommandParams 解析 params.command 为 argv 数组,支持 []string / []any / string+shell 三种格式。 +// 返回解析后的 argv、是否为 shell 模式、以及校验错误。 +func ParseCommandParams(params map[string]any) (argv []string, shell bool, err error) { + if len(params) == 0 { + return nil, false, fmt.Errorf("kind command requires params.command") + } + raw, ok := params["command"] + if !ok || raw == nil { + return nil, false, fmt.Errorf("kind command requires params.command") + } + switch v := raw.(type) { + case string: + trimmed := strings.TrimSpace(v) + if trimmed == "" { + return nil, false, fmt.Errorf("kind command requires params.command") + } + shellVal, _ := params["shell"].(bool) + if !shellVal { + return nil, false, fmt.Errorf("string params.command requires params.shell=true; use array format for argv mode") + } + return []string{trimmed}, true, nil + case []string: + if len(v) == 0 { + return nil, false, fmt.Errorf("kind command requires non-empty params.command") + } + out := make([]string, 0, len(v)) + for _, s := range v { + trimmed := strings.TrimSpace(s) + if trimmed == "" { + return nil, false, fmt.Errorf("params.command contains empty element") + } + out = append(out, trimmed) + } + return out, false, nil + case []any: + if len(v) == 0 { + return nil, false, fmt.Errorf("kind command requires non-empty params.command") + } + out := make([]string, 0, len(v)) + for _, item := range v { + s := strings.TrimSpace(fmt.Sprintf("%v", item)) + if s == "" { + return nil, false, fmt.Errorf("params.command contains empty element") + } + out = append(out, s) + } + return out, false, nil + default: + return nil, false, fmt.Errorf("params.command must be a string (with shell=true) or an array of strings") + } +} + // BuildCommandPayload 构造传给外部命令的 stdin JSON payload。 func BuildCommandPayload(hookID string, point HookPoint, input HookContext) CommandHookPayload { payload := CommandHookPayload{ @@ -79,6 +143,7 @@ func ParseCommandResponse(raw []byte) (CommandHookResponse, error) { } // RunCommandHook 执行外部命令并返回结构化的 HookResult。 +// stdout 通过管道捕获并限制为 maxCommandStdoutBytes;stderr 捕获后在失败时附加到结果。 func RunCommandHook(ctx context.Context, spec CommandHookSpec, input HookContext) HookResult { payload := BuildCommandPayload(spec.HookID, spec.Point, input) payloadBytes, err := json.Marshal(payload) @@ -98,12 +163,25 @@ func RunCommandHook(ctx context.Context, spec CommandHookSpec, input HookContext cmd.Env = buildCommandEnv(spec) cmd.Stdin = bytes.NewReader(payloadBytes) - stdout, err := cmd.Output() + stdout, stderrBytes, runErr := runAndCapture(cmd) + + // stdout 过大视为执行失败 + if int64(len(stdout)) > maxCommandStdoutBytes { + msg := fmt.Sprintf("command hook stdout exceeded %d byte limit", maxCommandStdoutBytes) + return HookResult{ + HookID: spec.HookID, + Point: spec.Point, + Status: HookResultFailed, + Message: msg, + Error: msg, + } + } + message := strings.TrimSpace(string(stdout)) // 非零 exit code 优先于 JSON status(防止恶意脚本声称 pass 但实际失败) - if err != nil { - return buildResultFromExitCode(ctx, spec, err, message, stdout) + if runErr != nil { + return buildResultFromExitCode(ctx, spec, runErr, message, stdout, stderrBytes) } // exit code 0: 尝试解析 stdout JSON 协议 @@ -121,14 +199,53 @@ func RunCommandHook(ctx context.Context, spec CommandHookSpec, input HookContext } } +// runAndCapture 执行命令,通过管道捕获 stdout(限制 maxCommandStdoutBytes),同时捕获 stderr。 +func runAndCapture(cmd *exec.Cmd) (stdout, stderr []byte, runErr error) { + cmd.Stderr = &bytes.Buffer{} + + pipe, err := cmd.StdoutPipe() + if err != nil { + return nil, nil, err + } + if err := cmd.Start(); err != nil { + return nil, nil, err + } + + // 限制读取量,防止恶意脚本 OOM + limitedReader := io.LimitReader(pipe, maxCommandStdoutBytes+1) + var stdoutBuf bytes.Buffer + _, copyErr := io.Copy(&stdoutBuf, limitedReader) + stdout = stdoutBuf.Bytes() + + waitErr := cmd.Wait() + + if stderrBuf, ok := cmd.Stderr.(*bytes.Buffer); ok { + stderr = stderrBuf.Bytes() + } + + // pipe 读取错误优先 + if copyErr != nil { + return stdout, stderr, fmt.Errorf("reading command stdout: %w", copyErr) + } + + return stdout, stderr, waitErr +} + func buildExecCmd(ctx context.Context, spec CommandHookSpec) *exec.Cmd { - if spec.Shell && len(spec.Command) > 0 { + if spec.Shell { + if len(spec.Command) == 0 { + // 不应到达此处(ParseCommandParams 已校验),防御性 panic + panic("buildExecCmd: shell mode requires at least one command element") + } shell := spec.Command[0] if runtime.GOOS == "windows" { return exec.CommandContext(ctx, "powershell", "-Command", shell) } return exec.CommandContext(ctx, "sh", "-c", shell) } + if len(spec.Command) == 0 { + panic("buildExecCmd: command requires at least one element") + } if len(spec.Command) == 1 { return exec.CommandContext(ctx, spec.Command[0]) } @@ -142,8 +259,10 @@ func buildCommandEnv(spec CommandHookSpec) []string { "NEOCODE_HOOK_PAYLOAD_VERSION=" + CommandHookPayloadVersion, } if runtime.GOOS == "windows" { - if sd := os.Getenv("SystemDrive"); sd != "" { - env = append(env, "SystemDrive="+sd) + for _, key := range []string{"SystemRoot", "SystemDrive", "USERPROFILE"} { + if v := os.Getenv(key); v != "" { + env = append(env, key+"="+v) + } } } return env @@ -176,7 +295,7 @@ func buildResultFromResponse(spec CommandHookSpec, resp CommandHookResponse) Hoo return result } -func buildResultFromExitCode(ctx context.Context, spec CommandHookSpec, err error, message string, stdout []byte) HookResult { +func buildResultFromExitCode(ctx context.Context, spec CommandHookSpec, err error, message string, stdout, stderr []byte) HookResult { result := HookResult{ HookID: spec.HookID, Point: spec.Point, @@ -197,6 +316,7 @@ func buildResultFromExitCode(ctx context.Context, spec CommandHookSpec, err erro switch code { case 1, 2: result.Status = HookResultBlock + result.Error = err.Error() default: result.Status = HookResultFailed if result.Message == "" { @@ -220,5 +340,9 @@ func buildResultFromExitCode(ctx context.Context, spec CommandHookSpec, err erro result.Metadata.Annotations = resp.Annotations } } + // 失败时附带 stderr 便于调试 + if stderrText := strings.TrimSpace(string(stderr)); stderrText != "" && result.Message == "" { + result.Message = stderrText + } return result } diff --git a/internal/runtime/hooks/command_handler_test.go b/internal/runtime/hooks/command_handler_test.go index 710bd91a5..88377bb65 100644 --- a/internal/runtime/hooks/command_handler_test.go +++ b/internal/runtime/hooks/command_handler_test.go @@ -546,3 +546,78 @@ func TestRunCommandHookStdinPayload(t *testing.T) { t.Fatalf("stdin payload should contain session_id, got: %s", result.Message) } } + +func TestRunCommandHookShellModeWindows(t *testing.T) { + t.Parallel() + if runtime.GOOS != "windows" { + t.Skip("Windows-only test") + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + spec := CommandHookSpec{ + HookID: "test-shell-win", + Point: HookPointBeforeToolCall, + Command: []string{`Write-Output '{"status":"pass","message":"from powershell shell"}'`}, + Shell: true, + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultPass, result.Message) + } + if result.Message != "from powershell shell" { + t.Fatalf("message = %q, want %q", result.Message, "from powershell shell") + } +} + +func TestRunCommandHookEnvIsolationNoLeak(t *testing.T) { + // Verify that host env vars like PATH, HOME, USER are NOT leaked to the subprocess. + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("PATH leaks at system level on Windows; see buildCommandEnv") + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + spec := CommandHookSpec{ + HookID: "env-no-leak", + Point: HookPointBeforeToolCall, + Command: []string{"env"}, + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultPass, result.Message) + } + for _, leaked := range []string{"PATH=", "HOME=", "USER="} { + if strings.Contains(result.Message, leaked) { + t.Fatalf("host env var %q should not be leaked to subprocess, got: %s", leaked, result.Message) + } + } +} + +func TestValidateCommandParams(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + params map[string]any + wantErr bool + }{ + {"nil params", nil, true}, + {"empty params", map[string]any{}, true}, + {"missing command", map[string]any{"other": "val"}, true}, + {"empty string command", map[string]any{"command": ""}, true}, + {"string without shell", map[string]any{"command": "echo ok"}, true}, + {"string with shell", map[string]any{"command": "echo ok", "shell": true}, false}, + {"empty array", map[string]any{"command": []any{}}, true}, + {"valid array", map[string]any{"command": []any{"echo", "ok"}}, false}, + {"array with empty element", map[string]any{"command": []any{"echo", ""}}, true}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + err := ValidateCommandParams(tc.params) + if (err != nil) != tc.wantErr { + t.Fatalf("ValidateCommandParams() error = %v, wantErr %v", err, tc.wantErr) + } + }) + } +} diff --git a/internal/runtime/hooks_integration.go b/internal/runtime/hooks_integration.go index ef82672d7..1497566e2 100644 --- a/internal/runtime/hooks_integration.go +++ b/internal/runtime/hooks_integration.go @@ -272,11 +272,12 @@ func applyCommandHookUpdateInput(output runtimehooks.RunOutput, parts []provider if update.Text == "" { continue } + replaced := false newParts := make([]providertypes.ContentPart, 0, len(parts)) for _, part := range parts { - if part.Kind == providertypes.ContentPartText { + if !replaced && part.Kind == providertypes.ContentPartText { newParts = append(newParts, providertypes.NewTextPart(update.Text)) - update.Text = "" // 仅替换第一个文本 part + replaced = true } else { newParts = append(newParts, part) } diff --git a/internal/runtime/hooks_integration_test.go b/internal/runtime/hooks_integration_test.go index aa820bd41..73bc279aa 100644 --- a/internal/runtime/hooks_integration_test.go +++ b/internal/runtime/hooks_integration_test.go @@ -1253,3 +1253,142 @@ func TestEmitSubAgentStopHookNilServiceNoop(t *testing.T) { Error: "noop", }) } + +func TestApplyCommandHookUpdateInput(t *testing.T) { + t.Parallel() + + t.Run("empty results returns parts unchanged", func(t *testing.T) { + t.Parallel() + parts := []providertypes.ContentPart{providertypes.NewTextPart("original")} + got := applyCommandHookUpdateInput(runtimehooks.RunOutput{}, parts) + if len(got) != 1 || got[0].Text != "original" { + t.Fatalf("got %v, want original parts unchanged", got) + } + }) + + t.Run("replaces first text part when CanUpdateInput", func(t *testing.T) { + t.Parallel() + output := runtimehooks.RunOutput{ + Results: []runtimehooks.HookResult{{ + Point: runtimehooks.HookPointUserPromptSubmit, + Status: runtimehooks.HookResultPass, + Metadata: runtimehooks.HookResultMetadata{ + UpdateInput: []byte(`{"text":"rewritten"}`), + }, + }}, + } + parts := []providertypes.ContentPart{providertypes.NewTextPart("original")} + got := applyCommandHookUpdateInput(output, parts) + if len(got) != 1 || got[0].Text != "rewritten" { + t.Fatalf("got %v, want text replaced to 'rewritten'", got) + } + }) + + t.Run("ignores when CanUpdateInput is false", func(t *testing.T) { + t.Parallel() + output := runtimehooks.RunOutput{ + Results: []runtimehooks.HookResult{{ + Point: runtimehooks.HookPointBeforeToolCall, + Status: runtimehooks.HookResultPass, + Metadata: runtimehooks.HookResultMetadata{ + UpdateInput: []byte(`{"text":"should not apply"}`), + }, + }}, + } + parts := []providertypes.ContentPart{providertypes.NewTextPart("original")} + got := applyCommandHookUpdateInput(output, parts) + if len(got) != 1 || got[0].Text != "original" { + t.Fatalf("got %v, want original parts unchanged for non-CanUpdateInput point", got) + } + }) + + t.Run("ignores invalid JSON in UpdateInput", func(t *testing.T) { + t.Parallel() + output := runtimehooks.RunOutput{ + Results: []runtimehooks.HookResult{{ + Point: runtimehooks.HookPointUserPromptSubmit, + Status: runtimehooks.HookResultPass, + Metadata: runtimehooks.HookResultMetadata{ + UpdateInput: []byte(`not json`), + }, + }}, + } + parts := []providertypes.ContentPart{providertypes.NewTextPart("original")} + got := applyCommandHookUpdateInput(output, parts) + if len(got) != 1 || got[0].Text != "original" { + t.Fatalf("got %v, want original parts unchanged for invalid JSON", got) + } + }) + + t.Run("ignores empty text in UpdateInput", func(t *testing.T) { + t.Parallel() + output := runtimehooks.RunOutput{ + Results: []runtimehooks.HookResult{{ + Point: runtimehooks.HookPointUserPromptSubmit, + Status: runtimehooks.HookResultPass, + Metadata: runtimehooks.HookResultMetadata{ + UpdateInput: []byte(`{"text":""}`), + }, + }}, + } + parts := []providertypes.ContentPart{providertypes.NewTextPart("original")} + got := applyCommandHookUpdateInput(output, parts) + if len(got) != 1 || got[0].Text != "original" { + t.Fatalf("got %v, want original parts unchanged for empty text", got) + } + }) + + t.Run("only replaces first text part", func(t *testing.T) { + t.Parallel() + output := runtimehooks.RunOutput{ + Results: []runtimehooks.HookResult{{ + Point: runtimehooks.HookPointUserPromptSubmit, + Status: runtimehooks.HookResultPass, + Metadata: runtimehooks.HookResultMetadata{ + UpdateInput: []byte(`{"text":"new"}`), + }, + }}, + } + parts := []providertypes.ContentPart{ + providertypes.NewTextPart("first"), + providertypes.NewTextPart("second"), + } + got := applyCommandHookUpdateInput(output, parts) + if len(got) != 2 { + t.Fatalf("got len %d, want 2", len(got)) + } + if got[0].Text != "new" { + t.Fatalf("first part text = %q, want 'new'", got[0].Text) + } + if got[1].Text != "second" { + t.Fatalf("second part text = %q, want 'second' (unchanged)", got[1].Text) + } + }) + + t.Run("preserves non-text parts", func(t *testing.T) { + t.Parallel() + output := runtimehooks.RunOutput{ + Results: []runtimehooks.HookResult{{ + Point: runtimehooks.HookPointUserPromptSubmit, + Status: runtimehooks.HookResultPass, + Metadata: runtimehooks.HookResultMetadata{ + UpdateInput: []byte(`{"text":"replaced"}`), + }, + }}, + } + parts := []providertypes.ContentPart{ + providertypes.NewRemoteImagePart("https://example.com/img.png"), + providertypes.NewTextPart("original"), + } + got := applyCommandHookUpdateInput(output, parts) + if len(got) != 2 { + t.Fatalf("got len %d, want 2", len(got)) + } + if got[0].Kind != providertypes.ContentPartImage { + t.Fatalf("first part kind = %q, want image (unchanged)", got[0].Kind) + } + if got[1].Text != "replaced" { + t.Fatalf("second part text = %q, want 'replaced'", got[1].Text) + } + }) +} diff --git a/internal/runtime/repo_hooks.go b/internal/runtime/repo_hooks.go index a495f404a..b22bb343d 100644 --- a/internal/runtime/repo_hooks.go +++ b/internal/runtime/repo_hooks.go @@ -363,45 +363,13 @@ func validateRepoHookItem(item config.RuntimeHookItemConfig) error { return fmt.Errorf("handler %q requires params.tool_name or params.tool_names", item.Handler) } case repoHookKindCommand: - if err := validateRepoCommandParams(item.Params); err != nil { + if err := runtimehooks.ValidateCommandParams(item.Params); err != nil { return err } } return nil } -// validateRepoCommandParams 校验 repo command hook 的 params.command 格式。 -func validateRepoCommandParams(params map[string]any) error { - if len(params) == 0 { - return fmt.Errorf("kind command requires params.command") - } - raw, ok := params["command"] - if !ok || raw == nil { - return fmt.Errorf("kind command requires params.command") - } - switch v := raw.(type) { - case string: - if strings.TrimSpace(v) == "" { - return fmt.Errorf("kind command requires params.command") - } - shellVal, _ := params["shell"].(bool) - if !shellVal { - return fmt.Errorf("string params.command requires params.shell=true; use array format for argv mode") - } - case []any: - if len(v) == 0 { - return fmt.Errorf("kind command requires non-empty params.command") - } - case []string: - if len(v) == 0 { - return fmt.Errorf("kind command requires non-empty params.command") - } - default: - return fmt.Errorf("params.command must be a string (with shell=true) or an array") - } - return nil -} - // runtimeHasWarnOnToolCallTargets 判断 warn_on_tool_call 是否配置了至少一个目标工具。 func runtimeHasWarnOnToolCallTargets(params map[string]any) bool { if len(params) == 0 { diff --git a/internal/runtime/user_hooks.go b/internal/runtime/user_hooks.go index e7b9d0654..ec7a4f159 100644 --- a/internal/runtime/user_hooks.go +++ b/internal/runtime/user_hooks.go @@ -261,7 +261,7 @@ func validateConfiguredHookItemForP6Lite(item config.RuntimeHookItemConfig, scop if mode != configuredHookModeSync { return fmt.Errorf("mode %q is not supported for kind command (only sync)", item.Mode) } - if _, _, err := parseCommandHookParams(item.Params); err != nil { + if _, _, err := runtimehooks.ParseCommandParams(item.Params); err != nil { return err } case configuredHookKindHTTP: @@ -383,7 +383,7 @@ func buildUserBuiltinHookHandler( // buildUserCommandHookHandler 将命令型 hook 转为同步阻断处理器,使用 stdin/stdout JSON 协议。 func buildUserCommandHookHandler(hookID string, point runtimehooks.HookPoint, params map[string]any, defaultWorkdir string) (runtimehooks.HookHandler, error) { - argv, shell, err := parseCommandHookParams(params) + argv, shell, err := runtimehooks.ParseCommandParams(params) if err != nil { return nil, err } @@ -399,57 +399,6 @@ func buildUserCommandHookHandler(hookID string, point runtimehooks.HookPoint, pa }, nil } -// parseCommandHookParams 解析 params.command 为 argv 数组,支持 []string / []any / string+shell 三种格式。 -func parseCommandHookParams(params map[string]any) (argv []string, shell bool, err error) { - if len(params) == 0 { - return nil, false, fmt.Errorf("kind command requires params.command") - } - raw, ok := params["command"] - if !ok || raw == nil { - return nil, false, fmt.Errorf("kind command requires params.command") - } - switch v := raw.(type) { - case string: - trimmed := strings.TrimSpace(v) - if trimmed == "" { - return nil, false, fmt.Errorf("kind command requires params.command") - } - shellVal, _ := params["shell"].(bool) - if !shellVal { - return nil, false, fmt.Errorf("string params.command requires params.shell=true; use array format for argv mode") - } - return []string{trimmed}, true, nil - case []string: - if len(v) == 0 { - return nil, false, fmt.Errorf("kind command requires non-empty params.command") - } - out := make([]string, 0, len(v)) - for _, s := range v { - trimmed := strings.TrimSpace(s) - if trimmed == "" { - return nil, false, fmt.Errorf("params.command contains empty element") - } - out = append(out, trimmed) - } - return out, false, nil - case []any: - if len(v) == 0 { - return nil, false, fmt.Errorf("kind command requires non-empty params.command") - } - out := make([]string, 0, len(v)) - for _, item := range v { - s := strings.TrimSpace(fmt.Sprintf("%v", item)) - if s == "" { - return nil, false, fmt.Errorf("params.command contains empty element") - } - out = append(out, s) - } - return out, false, nil - default: - return nil, false, fmt.Errorf("params.command must be a string (with shell=true) or an array of strings") - } -} - // buildUserHTTPObserveHookHandler 将 kind=http 的 observe 配置转换为观测回调处理器。 func buildUserHTTPObserveHookHandler(item config.RuntimeHookItemConfig) (runtimehooks.HookHandler, error) { endpoint := strings.TrimSpace(readHookParamString(item.Params, "url")) From 7db5e5ed40f2fb98d4729e7b3a4a7018f6f07ef2 Mon Sep 17 00:00:00 2001 From: Cai_Tang <106404101+Cai-Tang-www@users.noreply.github.com> Date: Mon, 25 May 2026 20:26:15 +0800 Subject: [PATCH 6/6] test(hooks): improve command_handler coverage to 96% Add targeted tests for previously uncovered branches: - ParseCommandParams: all 9 branches ([]string, []any, string+shell, nil, empty, unsupported type, empty element, shell=false) - RunCommandHook: stdout-too-large, nonexistent binary, exit 0 empty, exit 2 block, exit 3 with stderr, block+message, failed default msg, failed custom msg, pass+annotations+update_input, stdin+metadata - buildCommandEnv: verify Windows SystemRoot/SystemDrive/USERPROFILE - buildResultFromResponse: failed status with default/custom message - buildResultFromExitCode: exit 2 sets Error, exit 3 with stderr Coverage: 90.8% -> 96.0% (command_handler.go functions all >= 75%) Co-Authored-By: Claude Opus 4.7 --- .../runtime/hooks/command_handler_test.go | 392 ++++++++++++++++++ 1 file changed, 392 insertions(+) diff --git a/internal/runtime/hooks/command_handler_test.go b/internal/runtime/hooks/command_handler_test.go index 88377bb65..259bc6ba4 100644 --- a/internal/runtime/hooks/command_handler_test.go +++ b/internal/runtime/hooks/command_handler_test.go @@ -3,6 +3,7 @@ package hooks import ( "context" "encoding/json" + "fmt" "os" "path/filepath" "runtime" @@ -593,6 +594,397 @@ func TestRunCommandHookEnvIsolationNoLeak(t *testing.T) { } } +func TestParseCommandParamsAllBranches(t *testing.T) { + t.Parallel() + + t.Run("string with shell=true", func(t *testing.T) { + t.Parallel() + argv, shell, err := ParseCommandParams(map[string]any{"command": "echo hi", "shell": true}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !shell { + t.Fatal("expected shell=true") + } + if len(argv) != 1 || argv[0] != "echo hi" { + t.Fatalf("argv = %v, want [echo hi]", argv) + } + }) + + t.Run("string with whitespace shell=true", func(t *testing.T) { + t.Parallel() + argv, shell, err := ParseCommandParams(map[string]any{"command": " echo hi ", "shell": true}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !shell || argv[0] != "echo hi" { + t.Fatalf("argv = %v, shell = %v", argv, shell) + } + }) + + t.Run("[]string valid", func(t *testing.T) { + t.Parallel() + argv, shell, err := ParseCommandParams(map[string]any{"command": []string{"echo", "hello"}}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if shell { + t.Fatal("expected shell=false for array") + } + if len(argv) != 2 || argv[0] != "echo" || argv[1] != "hello" { + t.Fatalf("argv = %v", argv) + } + }) + + t.Run("[]string empty", func(t *testing.T) { + t.Parallel() + _, _, err := ParseCommandParams(map[string]any{"command": []string{}}) + if err == nil { + t.Fatal("expected error for empty []string") + } + }) + + t.Run("[]string with empty element", func(t *testing.T) { + t.Parallel() + _, _, err := ParseCommandParams(map[string]any{"command": []string{"echo", " ", "ok"}}) + if err == nil { + t.Fatal("expected error for empty element in []string") + } + }) + + t.Run("[]any with empty element after Sprintf", func(t *testing.T) { + t.Parallel() + // nil element => fmt.Sprintf("%v", nil) => "" which is non-empty + // but empty string element => fmt.Sprintf("%v", "") => "" which is empty + _, _, err := ParseCommandParams(map[string]any{"command": []any{"echo", ""}}) + if err == nil { + t.Fatal("expected error for empty element in []any") + } + }) + + t.Run("unsupported type", func(t *testing.T) { + t.Parallel() + _, _, err := ParseCommandParams(map[string]any{"command": 123}) + if err == nil { + t.Fatal("expected error for unsupported type") + } + }) + + t.Run("nil command value", func(t *testing.T) { + t.Parallel() + _, _, err := ParseCommandParams(map[string]any{"command": nil}) + if err == nil { + t.Fatal("expected error for nil command") + } + }) + + t.Run("shell=false on string", func(t *testing.T) { + t.Parallel() + _, _, err := ParseCommandParams(map[string]any{"command": "echo ok", "shell": false}) + if err == nil { + t.Fatal("expected error for string without shell=true") + } + }) +} + +func TestRunCommandHookStdoutTooLarge(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + // Generate output slightly above the 1MiB limit + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "stdout-toolarge", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output ('x' * 1048577)"}, + } + } else { + spec = CommandHookSpec{ + HookID: "stdout-toolarge", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "printf '%1048577s' ''"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultFailed { + t.Fatalf("status = %q, want %q", result.Status, HookResultFailed) + } + if !strings.Contains(result.Message, "byte limit") { + t.Fatalf("message should mention byte limit, got: %s", result.Message) + } +} + +func TestRunCommandHookStdinPayloadWithMetadata(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "stdin-meta", + Point: HookPointUserPromptSubmit, + Command: []string{"powershell", "-Command", "$input"}, + } + } else { + spec = CommandHookSpec{ + HookID: "stdin-meta", + Point: HookPointUserPromptSubmit, + Command: []string{"cat"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{ + RunID: "run-meta", + SessionID: "sess-meta", + Metadata: map[string]any{"tool_name": "bash", "workdir": "/tmp"}, + }) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q", result.Status, HookResultPass) + } + if !strings.Contains(result.Message, `"tool_name"`) { + t.Fatalf("stdin should contain tool_name metadata, got: %s", result.Message) + } +} + +func TestRunCommandHookExitCodeTwoBlocks(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "exit2", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "exit 2"}, + } + } else { + spec = CommandHookSpec{ + HookID: "exit2", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "exit 2"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultBlock { + t.Fatalf("status = %q, want %q", result.Status, HookResultBlock) + } + if result.Error == "" { + t.Fatal("expected Error to be set for exit code 2 block") + } +} + +func TestRunCommandHookExitCodeZeroEmptyStdout(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "exit0-empty", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", ""}, + } + } else { + spec = CommandHookSpec{ + HookID: "exit0-empty", + Point: HookPointBeforeToolCall, + Command: []string{"true"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q", result.Status, HookResultPass) + } +} + +func TestRunCommandHookNonExistentBinary(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + spec := CommandHookSpec{ + HookID: "no-such-binary", + Point: HookPointBeforeToolCall, + Command: []string{"nonexistent_binary_xyz_12345"}, + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultFailed { + t.Fatalf("status = %q, want %q", result.Status, HookResultFailed) + } + if result.Error == "" { + t.Fatal("expected Error to be set for nonexistent binary") + } +} + +func TestRunCommandHookBlockWithMessage(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "block-msg", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output '{\"status\":\"block\",\"message\":\"not allowed\"}'"}, + } + } else { + spec = CommandHookSpec{ + HookID: "block-msg", + Point: HookPointBeforeToolCall, + Command: []string{"echo", `{"status":"block","message":"not allowed"}`}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultBlock { + t.Fatalf("status = %q, want %q", result.Status, HookResultBlock) + } + if result.Message != "not allowed" { + t.Fatalf("message = %q, want %q", result.Message, "not allowed") + } +} + +func TestRunCommandHookFailedStatusWithDefaultMessage(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "failed-default", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output '{\"status\":\"failed\"}'"}, + } + } else { + spec = CommandHookSpec{ + HookID: "failed-default", + Point: HookPointBeforeToolCall, + Command: []string{"echo", `{"status":"failed"}`}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultFailed { + t.Fatalf("status = %q, want %q", result.Status, HookResultFailed) + } + if result.Message != "hook returned failed status" { + t.Fatalf("message = %q, want default failed message", result.Message) + } + if result.Error != "hook returned failed status" { + t.Fatalf("error = %q, want default failed message", result.Error) + } +} + +func TestRunCommandHookFailedStatusWithCustomMessage(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "failed-custom", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output '{\"status\":\"failed\",\"message\":\"custom error\"}'"}, + } + } else { + spec = CommandHookSpec{ + HookID: "failed-custom", + Point: HookPointBeforeToolCall, + Command: []string{"echo", `{"status":"failed","message":"custom error"}`}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultFailed { + t.Fatalf("status = %q, want %q", result.Status, HookResultFailed) + } + if result.Message != "custom error" { + t.Fatalf("message = %q, want %q", result.Message, "custom error") + } + if result.Error != "custom error" { + t.Fatalf("error = %q, want %q", result.Error, "custom error") + } +} + +func TestRunCommandHookPassWithAnnotationsAndUpdateInput(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + payload := `{"status":"pass","message":"ok","annotations":["a1","a2"],"update_input":{"text":"rewritten"}}` + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "full-output", + Point: HookPointUserPromptSubmit, + Command: []string{"powershell", "-Command", fmt.Sprintf("Write-Output '%s'", payload)}, + } + } else { + spec = CommandHookSpec{ + HookID: "full-output", + Point: HookPointUserPromptSubmit, + Command: []string{"echo", payload}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q", result.Status, HookResultPass) + } + if result.Message != "ok" { + t.Fatalf("message = %q, want %q", result.Message, "ok") + } + if len(result.Metadata.Annotations) != 2 || result.Metadata.Annotations[0] != "a1" { + t.Fatalf("annotations = %v, want [a1 a2]", result.Metadata.Annotations) + } + if len(result.Metadata.UpdateInput) == 0 { + t.Fatal("expected UpdateInput to be populated") + } +} + +func TestRunCommandHookExitCodeThreeWithStderr(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "exit3-stderr", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Error 'bad thing'; exit 3"}, + } + } else { + spec = CommandHookSpec{ + HookID: "exit3-stderr", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "echo bad thing >&2; exit 3"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultFailed { + t.Fatalf("status = %q, want %q", result.Status, HookResultFailed) + } +} + +func TestBuildCommandEnvContainsNEOCODEVars(t *testing.T) { + t.Parallel() + spec := CommandHookSpec{HookID: "id-env", Point: HookPointSessionStart} + env := buildCommandEnv(spec) + envMap := make(map[string]bool) + for _, e := range env { + parts := strings.SplitN(e, "=", 2) + if len(parts) == 2 { + envMap[parts[0]] = true + } + } + for _, key := range []string{"NEOCODE_HOOK_HOOK_ID", "NEOCODE_HOOK_POINT", "NEOCODE_HOOK_PAYLOAD_VERSION"} { + if !envMap[key] { + t.Fatalf("missing %s in env", key) + } + } + if runtime.GOOS == "windows" { + for _, key := range []string{"SystemRoot", "SystemDrive", "USERPROFILE"} { + if os.Getenv(key) != "" && !envMap[key] { + t.Fatalf("missing Windows env var %s", key) + } + } + } +} + func TestValidateCommandParams(t *testing.T) { t.Parallel()