diff --git a/app/api/routes_command.py b/app/api/routes_command.py index cf09667..a7835ce 100644 --- a/app/api/routes_command.py +++ b/app/api/routes_command.py @@ -2,15 +2,22 @@ from typing import Optional -from fastapi import APIRouter, Request +import asyncio + +from fastapi import APIRouter, HTTPException, Request from fastapi.responses import StreamingResponse from pydantic import BaseModel -from app.api.response import ok +from app.api.response import ok, fail from app.core.command_runner import CommandRunner router = APIRouter() +# 单进程:按 command 互斥执行 /commands/stream +# key: command +_command_locks: dict[str, asyncio.Lock] = {} +_command_active_run: dict[str, str] = {} + class CommandRequest(BaseModel): command: str @@ -48,8 +55,16 @@ async def execute_command(payload: CommandRequest): async def execute_command_stream(payload: CommandRequest, request: Request): runner = CommandRunner() + command_key = payload.command + lock = _command_locks.setdefault(command_key, asyncio.Lock()) + + # 不排队:如果同 command 正在执行,直接 409 + if lock.locked(): + return fail("Another command is still running", status_code=409) + async def gen(): run_id: Optional[str] = None + await lock.acquire() try: async for ev in runner.stream( payload.command, @@ -59,6 +74,7 @@ async def gen(): ): if ev.get("event") == "start": run_id = ev.get("run_id") + _command_active_run[command_key] = run_id if await request.is_disconnected(): if run_id: @@ -81,8 +97,14 @@ async def gen(): {"run_id": ev["run_id"], "exit_code": ev.get("exit_code"), "lines": ev.get("lines")}, ) finally: - if run_id: - await runner.cleanup(run_id) + try: + if run_id: + await runner.cleanup(run_id) + finally: + if _command_active_run.get(command_key) == run_id: + _command_active_run.pop(command_key, None) + if lock.locked(): + lock.release() return StreamingResponse( gen(), diff --git a/app/core/command_runner.py b/app/core/command_runner.py index 06c6255..4400598 100644 --- a/app/core/command_runner.py +++ b/app/core/command_runner.py @@ -5,6 +5,7 @@ import signal import sys import uuid +import shlex from dataclasses import dataclass from typing import Any, AsyncIterator, Dict, Optional @@ -32,6 +33,8 @@ class CommandRunner: def __init__(self) -> None: self._active: Dict[str, asyncio.subprocess.Process] = {} + # 存储真实业务进程PID(修复核心) + self._real_pids: Dict[str, int] = {} def decode_output(self, data: bytes) -> str: try: @@ -119,6 +122,27 @@ async def exec_json( except Exception as e: return CommandResult(exit_code=-1, stdout="", stderr="", error=f"Execution error: {e}") + async def _get_real_child_pid(self, shell_pid: int) -> Optional[int]: + """ + 核心修复:获取shell进程下的真实业务子进程PID + Linux/Unix专用,Windows直接返回shell PID + """ + if sys.platform == "win32": + return shell_pid + + try: + # 读取 /proc/{shell_pid}/task/{shell_pid}/children 获取直接子进程 + children_path = f"/proc/{shell_pid}/task/{shell_pid}/children" + if os.path.exists(children_path): + async with asyncio.Lock(): + with open(children_path, 'r') as f: + child_pids = f.read().strip().split() + if child_pids: + return int(child_pids[0]) + return None + except Exception: + return None + async def stream( self, command: str, @@ -163,7 +187,19 @@ async def stream( proc = await asyncio.create_subprocess_shell(command, **kwargs) self._active[run_id] = proc - yield {"event": "start", "run_id": run_id, "pid": proc.pid, "command": command} + # ===================== 修复核心:获取真实PID ===================== + real_pid = proc.pid + if not is_windows: + # 等待子进程创建(极短等待,不影响性能) + await asyncio.sleep(0.1) + child_pid = await self._get_real_child_pid(proc.pid) + if child_pid: + real_pid = child_pid + self._real_pids[run_id] = real_pid + # =============================================================== + + # 现在返回的pid就是真实业务进程PID,和ps命令完全一致 + yield {"event": "start", "run_id": run_id, "pid": real_pid, "command": command} line_count = 0 while True: @@ -259,15 +295,23 @@ async def kill(self, run_id: str) -> bool: return True async def cleanup(self, run_id: str) -> None: + # 清理真实PID缓存 + self._real_pids.pop(run_id, None) proc = self._active.pop(run_id, None) if proc is not None: await self._terminate_process(proc) def list_active(self) -> list[dict]: - return [ - {"run_id": rid, "pid": proc.pid, "returncode": proc.returncode} - for rid, proc in self._active.items() - ] + active_list = [] + for rid, proc in self._active.items(): + real_pid = self._real_pids.get(rid, proc.pid) + active_list.append({ + "run_id": rid, + "pid": real_pid, + "shell_pid": proc.pid, + "returncode": proc.returncode + }) + return active_list async def _terminate_process(self, proc: asyncio.subprocess.Process, timeout: float = 5.0) -> None: if proc is None or proc.returncode is not None: @@ -317,4 +361,4 @@ async def _terminate_process(self, proc: asyncio.subprocess.Process, timeout: fl try: await asyncio.wait_for(proc.wait(), timeout=timeout) except Exception: - pass + pass \ No newline at end of file diff --git a/app/core/request_id_middleware.py b/app/core/request_id_middleware.py index 85fd9e2..be9a68b 100644 --- a/app/core/request_id_middleware.py +++ b/app/core/request_id_middleware.py @@ -63,17 +63,38 @@ async def dispatch(self, request: Request, call_next): if self._create_instance_apikey and request.client and request.client.host not in {"127.0.0.1", "::1"}: auth = request.headers.get("Authorization") or "" expected = f"Bearer {self._create_instance_apikey}" + + # 没有传递 API Key + if not auth: + request_id_ctx.reset(token) + return JSONResponse( + status_code=401, + content={ + "error": { + "err_code": -10001, + "message": "Missing 302 Apikey", + "message_cn": "缺少 302 API 密钥", + "message_jp": "302 APIキーがありません", + "type": "api_error" + } + } + ) + + # 传递了,但密钥不正确 if auth != expected: request_id_ctx.reset(token) - return JSONResponse(status_code=401, content={ - "error": { - "err_code": -10001, - "message": "Missing 302 Apikey", - "message_cn": "缺少 302 API 密钥", - "message_jp": "302 APIキーがありません", - "type": "api_error" + return JSONResponse( + status_code=401, + content={ + "error": { + "err_code": -10002, + "message": "Invalid API Key, for details please view 302.AI", + "message_cn": "无效的API KEY,更多请访问 302.AI", + "message_jp": "無効なAPIキーです。詳細は 302.AI をご覧ください。", + "type": "api_error" + } } - }) + ) # Streaming endpoints: don't read body. is_streaming = False