Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions app/api/routes_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,22 @@

from typing import Optional

from fastapi import APIRouter, Request
import asyncio

from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from app.api.response import ok
from app.api.response import ok, fail
from app.core.command_runner import CommandRunner

router = APIRouter()

# 单进程:按 command 互斥执行 /commands/stream
# key: command
_command_locks: dict[str, asyncio.Lock] = {}
_command_active_run: dict[str, str] = {}


class CommandRequest(BaseModel):
command: str
Expand Down Expand Up @@ -48,8 +55,16 @@ async def execute_command(payload: CommandRequest):
async def execute_command_stream(payload: CommandRequest, request: Request):
runner = CommandRunner()

command_key = payload.command
lock = _command_locks.setdefault(command_key, asyncio.Lock())

# 不排队:如果同 command 正在执行,直接 409
if lock.locked():
return fail("Another command is still running", status_code=409)

async def gen():
run_id: Optional[str] = None
await lock.acquire()
try:
async for ev in runner.stream(
payload.command,
Expand All @@ -59,6 +74,7 @@ async def gen():
):
if ev.get("event") == "start":
run_id = ev.get("run_id")
_command_active_run[command_key] = run_id

if await request.is_disconnected():
if run_id:
Expand All @@ -81,8 +97,14 @@ async def gen():
{"run_id": ev["run_id"], "exit_code": ev.get("exit_code"), "lines": ev.get("lines")},
)
finally:
if run_id:
await runner.cleanup(run_id)
try:
if run_id:
await runner.cleanup(run_id)
finally:
if _command_active_run.get(command_key) == run_id:
_command_active_run.pop(command_key, None)
if lock.locked():
lock.release()

return StreamingResponse(
gen(),
Expand Down
56 changes: 50 additions & 6 deletions app/core/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import signal
import sys
import uuid
import shlex
from dataclasses import dataclass
from typing import Any, AsyncIterator, Dict, Optional

Expand Down Expand Up @@ -32,6 +33,8 @@ class CommandRunner:

def __init__(self) -> None:
self._active: Dict[str, asyncio.subprocess.Process] = {}
# 存储真实业务进程PID(修复核心)
self._real_pids: Dict[str, int] = {}

def decode_output(self, data: bytes) -> str:
try:
Expand Down Expand Up @@ -119,6 +122,27 @@ async def exec_json(
except Exception as e:
return CommandResult(exit_code=-1, stdout="", stderr="", error=f"Execution error: {e}")

async def _get_real_child_pid(self, shell_pid: int) -> Optional[int]:
"""
核心修复:获取shell进程下的真实业务子进程PID
Linux/Unix专用,Windows直接返回shell PID
"""
if sys.platform == "win32":
return shell_pid

try:
# 读取 /proc/{shell_pid}/task/{shell_pid}/children 获取直接子进程
children_path = f"/proc/{shell_pid}/task/{shell_pid}/children"
if os.path.exists(children_path):
async with asyncio.Lock():
with open(children_path, 'r') as f:
child_pids = f.read().strip().split()
if child_pids:
return int(child_pids[0])
return None
except Exception:
return None

async def stream(
self,
command: str,
Expand Down Expand Up @@ -163,7 +187,19 @@ async def stream(
proc = await asyncio.create_subprocess_shell(command, **kwargs)
self._active[run_id] = proc

yield {"event": "start", "run_id": run_id, "pid": proc.pid, "command": command}
# ===================== 修复核心:获取真实PID =====================
real_pid = proc.pid
if not is_windows:
# 等待子进程创建(极短等待,不影响性能)
await asyncio.sleep(0.1)
child_pid = await self._get_real_child_pid(proc.pid)
if child_pid:
real_pid = child_pid
self._real_pids[run_id] = real_pid
# ===============================================================

# 现在返回的pid就是真实业务进程PID,和ps命令完全一致
yield {"event": "start", "run_id": run_id, "pid": real_pid, "command": command}

line_count = 0
while True:
Expand Down Expand Up @@ -259,15 +295,23 @@ async def kill(self, run_id: str) -> bool:
return True

async def cleanup(self, run_id: str) -> None:
# 清理真实PID缓存
self._real_pids.pop(run_id, None)
proc = self._active.pop(run_id, None)
if proc is not None:
await self._terminate_process(proc)

def list_active(self) -> list[dict]:
return [
{"run_id": rid, "pid": proc.pid, "returncode": proc.returncode}
for rid, proc in self._active.items()
]
active_list = []
for rid, proc in self._active.items():
real_pid = self._real_pids.get(rid, proc.pid)
active_list.append({
"run_id": rid,
"pid": real_pid,
"shell_pid": proc.pid,
"returncode": proc.returncode
})
return active_list

async def _terminate_process(self, proc: asyncio.subprocess.Process, timeout: float = 5.0) -> None:
if proc is None or proc.returncode is not None:
Expand Down Expand Up @@ -317,4 +361,4 @@ async def _terminate_process(self, proc: asyncio.subprocess.Process, timeout: fl
try:
await asyncio.wait_for(proc.wait(), timeout=timeout)
except Exception:
pass
pass
37 changes: 29 additions & 8 deletions app/core/request_id_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,38 @@ async def dispatch(self, request: Request, call_next):
if self._create_instance_apikey and request.client and request.client.host not in {"127.0.0.1", "::1"}:
auth = request.headers.get("Authorization") or ""
expected = f"Bearer {self._create_instance_apikey}"

# 没有传递 API Key
if not auth:
request_id_ctx.reset(token)
return JSONResponse(
status_code=401,
content={
"error": {
"err_code": -10001,
"message": "Missing 302 Apikey",
"message_cn": "缺少 302 API 密钥",
"message_jp": "302 APIキーがありません",
"type": "api_error"
}
}
)

# 传递了,但密钥不正确
if auth != expected:
request_id_ctx.reset(token)
return JSONResponse(status_code=401, content={
"error": {
"err_code": -10001,
"message": "Missing 302 Apikey",
"message_cn": "缺少 302 API 密钥",
"message_jp": "302 APIキーがありません",
"type": "api_error"
return JSONResponse(
status_code=401,
content={
"error": {
"err_code": -10002,
"message": "Invalid API Key, for details please view 302.AI",
"message_cn": "无效的API KEY,更多请访问 302.AI",
"message_jp": "無効なAPIキーです。詳細は 302.AI をご覧ください。",
"type": "api_error"
}
}
})
)

# Streaming endpoints: don't read body.
is_streaming = False
Expand Down
Loading