diff --git a/.flocks/plugins/tools/device/onesig_v2_5_3_D20260321/_provider.yaml b/.flocks/flockshub/plugins/tools/device/onesig_v2_5_3_D20260321/_provider.yaml similarity index 100% rename from .flocks/plugins/tools/device/onesig_v2_5_3_D20260321/_provider.yaml rename to .flocks/flockshub/plugins/tools/device/onesig_v2_5_3_D20260321/_provider.yaml diff --git a/.flocks/plugins/tools/device/onesig_v2_5_3_D20260321/_test.yaml b/.flocks/flockshub/plugins/tools/device/onesig_v2_5_3_D20260321/_test.yaml similarity index 100% rename from .flocks/plugins/tools/device/onesig_v2_5_3_D20260321/_test.yaml rename to .flocks/flockshub/plugins/tools/device/onesig_v2_5_3_D20260321/_test.yaml diff --git a/.flocks/plugins/tools/device/onesig_v2_5_3_D20260321/onesig.handler.py b/.flocks/flockshub/plugins/tools/device/onesig_v2_5_3_D20260321/onesig.handler.py similarity index 100% rename from .flocks/plugins/tools/device/onesig_v2_5_3_D20260321/onesig.handler.py rename to .flocks/flockshub/plugins/tools/device/onesig_v2_5_3_D20260321/onesig.handler.py diff --git a/.flocks/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_assets.yaml b/.flocks/flockshub/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_assets.yaml similarity index 100% rename from .flocks/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_assets.yaml rename to .flocks/flockshub/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_assets.yaml diff --git a/.flocks/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_device.yaml b/.flocks/flockshub/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_device.yaml similarity index 100% rename from .flocks/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_device.yaml rename to .flocks/flockshub/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_device.yaml diff --git a/.flocks/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_helper.yaml b/.flocks/flockshub/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_helper.yaml similarity index 100% rename from .flocks/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_helper.yaml rename to .flocks/flockshub/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_helper.yaml diff --git a/.flocks/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_login.yaml b/.flocks/flockshub/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_login.yaml similarity index 100% rename from .flocks/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_login.yaml rename to .flocks/flockshub/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_login.yaml diff --git a/.flocks/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_monitoring.yaml b/.flocks/flockshub/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_monitoring.yaml similarity index 100% rename from .flocks/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_monitoring.yaml rename to .flocks/flockshub/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_monitoring.yaml diff --git a/.flocks/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_strategy.yaml b/.flocks/flockshub/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_strategy.yaml similarity index 100% rename from .flocks/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_strategy.yaml rename to .flocks/flockshub/plugins/tools/device/onesig_v2_5_3_D20260321/onesig_strategy.yaml diff --git a/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/_provider.yaml b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/_provider.yaml new file mode 100644 index 000000000..5be2b3316 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/_provider.yaml @@ -0,0 +1,62 @@ +name: "SecGate 3600" +vendor: qianxin +service_id: secgate3600_api +version: "3.6.6.0" +integration_type: device +description: > + Qi-Anxin SecGate 3600 firewall RESTful API integration for V3.6.6.0 + family builds. It supports login validation, system/resource status, + interface listing, security-policy queries, dashboard statistics, and + documented read-only REST calls. +description_cn: > + 奇安信网神 SecGate 3600 防火墙 V3.6.6.0 系列 RESTful API 接入。 + 支持登录校验、系统资源、接口列表、安全策略、首页统计以及文档内只读 + REST 调用。设备侧需要先开启 RESTful API 服务、配置 RESTful API 管理员 + 账号,并将 Flocks 服务所在机器加入 RESTful API 可信主机。 +auth: + type: custom + flow: login_cookie_token + login_path: /v1.0/login +credential_fields: + - key: base_url + label: 设备 API 地址 + storage: config + config_key: base_url + input_type: url + required: true + placeholder: "https://secgate3600.example.com:8443" + - key: username + label: 用户名 + storage: secret + config_key: username + secret_id: secgate3600_username + input_type: text + required: true + - key: password + label: 密码 + storage: secret + config_key: password + secret_id: secgate3600_password + input_type: password + required: true +defaults: + timeout: 30 + category: custom + product_version: "V3.6.6.0" + verify_ssl: false +notes: | + URL 规则来自《网神 SecGate3600 防火墙 RESTful API 使用指南 V1.1》: + - 登录:https://{IP地址}:{端口号}/v1.0/login + - 注销:https://{IP地址}:{端口号}/v1.0/out + - 数据:https://{IP地址}:{端口号}/v1.0/rest/ + + 设备侧配置要求: + 1. 接口管理方式勾选 HTTPS。 + 2. 创建角色为“RESTful API 管理员”的管理账号。 + 3. 在管理主机中添加 Flocks 服务所在机器 IP,服务选择 RESTful API。 + 4. 在本机设置中启用 RESTful API 服务并配置通信端口。 + + 登录成功后,设备返回 PHPSESSID cookie 和 token;后续 /v1.0/rest/ + 请求会携带 Cookie: PHPSESSID=...; token=...。 + base_url 可以填写设备根地址,例如 https://10.0.0.1:8443;如果误填了 + /v1.0、/v1.0/login 或 /v1.0/rest,handler 会自动归一化为根地址。 diff --git a/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/_test.yaml b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/_test.yaml new file mode 100644 index 000000000..da7c473f0 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/_test.yaml @@ -0,0 +1,65 @@ +schema_version: 1 +provider: secgate3600_api + +connectivity: + tool: secgate3600_system + params: + action: check_login + +fixtures: + secgate3600_system: + - label: Check login + tags: [smoke, auth] + params: + action: check_login + assert: + success: true + - label: Get system resource + tags: [smoke, system] + params: + action: system_resource + assert: + success: true + secgate3600_dashboard: + - label: Get last-day notice count + tags: [smoke, dashboard] + params: + action: notice_num_day + assert: + success: true + secgate3600_network: + - label: List physical interfaces + tags: [smoke, network] + params: + action: interface_list + assert: + success: true + secgate3600_policy: + - label: List security policies + tags: [smoke, policy] + params: + action: security_policy_list + assert: + success: true + secgate3600_api_readonly: + - label: Show API catalog + tags: [smoke, api] + params: + action: api_catalog + assert: + success: true + - label: Raw readonly system resource + tags: [api] + params: + action: rest_call_readonly + module: dashboard + function: get_system_resource + assert: + success: true + secgate3600_api_mutation: + - label: Show API catalog before confirmed mutation + tags: [api, mutation] + params: + action: api_catalog + assert: + success: true diff --git a/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/manifest.json b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/manifest.json new file mode 100644 index 000000000..939ce9ceb --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/manifest.json @@ -0,0 +1,63 @@ +{ + "schemaVersion": "hub.plugin.v1", + "id": "secgate3600_v3_6_6_0", + "type": "device", + "name": "SecGate 3600", + "description": "Qi-Anxin SecGate 3600 firewall RESTful API integration for V3.6.6.0 family builds.", + "descriptionCn": "奇安信网神 SecGate 3600 防火墙 V3.6.6.0 系列 RESTful API 接入。", + "version": "3.6.6.0", + "author": "Flocks Team", + "license": "MIT", + "category": "security", + "tags": [ + "firewall", + "device" + ], + "useCases": [ + "integration" + ], + "domains": [ + "security-ops" + ], + "capabilities": [ + "device-integration", + "rest-api" + ], + "trust": "official", + "source": { + "kind": "bundled", + "path": "plugins/tools/device/secgate3600_v3_6_6_0" + }, + "compatibility": { + "flocks": ">=0.8.0", + "os": [ + "darwin", + "linux", + "windows" + ] + }, + "dependencies": { + "skills": [], + "tools": [], + "python": [], + "external": [] + }, + "permissions": { + "tools": [], + "network": true, + "shell": false, + "filesystem": "none" + }, + "risk": { + "level": "low", + "reasons": [] + }, + "entrypoints": [ + "_provider.yaml", + "secgate3600.handler.py", + "secgate3600_api_catalog.json", + "secgate3600_api_readonly.yaml", + "secgate3600_api_mutation.yaml" + ], + "checksums": {} +} diff --git a/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600.handler.py b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600.handler.py new file mode 100644 index 000000000..1a4ffc1dd --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600.handler.py @@ -0,0 +1,460 @@ +from __future__ import annotations + +import asyncio +import json +import os +from pathlib import Path +from typing import Any, Callable + +import requests + +from flocks.config.config_writer import ConfigWriter +from flocks.security import get_secret_manager +from flocks.tool.registry import ToolContext, ToolResult + + +SERVICE_ID = "secgate3600_api" +STORAGE_KEY = "secgate3600_api_v3_6_6_0" +PRODUCT_VERSION = "V3.6.6.0" +DEFAULT_TIMEOUT = 30 +DEFAULT_VERIFY_SSL = False +CATALOG_FILE = Path(__file__).with_name("secgate3600_api_catalog.json") + + +class SecGateError(RuntimeError): + pass + + +class RuntimeConfig: + def __init__( + self, + *, + base_url: str, + username: str, + password: str, + verify_ssl: bool, + timeout: int, + ) -> None: + self.base_url = base_url + self.username = username + self.password = password + self.verify_ssl = verify_ssl + self.timeout = timeout + + +def _resolve_ref(value: Any) -> str: + if value is None: + return "" + if not isinstance(value, str): + return str(value) + if value.startswith("{secret:") and value.endswith("}"): + return get_secret_manager().get(value[len("{secret:") : -1]) or "" + if value.startswith("{env:") and value.endswith("}"): + return os.getenv(value[len("{env:") : -1], "") + return value + + +def _raw_service_config() -> dict[str, Any]: + raw = ConfigWriter.get_api_service_raw(SERVICE_ID) + if not isinstance(raw, dict): + raw = ConfigWriter.get_api_service_raw(STORAGE_KEY) + return raw if isinstance(raw, dict) else {} + + +def _as_bool(value: Any, default: bool) -> bool: + if value is None: + return default + if isinstance(value, bool): + return value + if isinstance(value, str): + text = value.strip().lower() + if text in {"1", "true", "yes", "on"}: + return True + if text in {"0", "false", "no", "off"}: + return False + return bool(value) + + +def _config_value(raw: dict[str, Any], *keys: str) -> Any: + for key in keys: + if raw.get(key) is not None: + return raw[key] + custom_settings = raw.get("custom_settings") + if isinstance(custom_settings, dict): + for key in keys: + if custom_settings.get(key) is not None: + return custom_settings[key] + return None + + +def _normalize_base_url(base_url: str) -> str: + text = base_url.strip().rstrip("/") + for suffix in ("/v1.0/rest", "/v1.0/login", "/v1.0/out", "/v1.0"): + if text.endswith(suffix): + text = text[: -len(suffix)] + break + return text.rstrip("/") + + +def resolve_config() -> RuntimeConfig: + raw = _raw_service_config() + base_url = _resolve_ref(_config_value(raw, "base_url", "baseUrl")) or os.getenv("SECGATE3600_BASE_URL", "") + if not base_url: + raise SecGateError("SecGate 3600 base_url is not configured") + + username = ( + _resolve_ref(_config_value(raw, "username")) + or get_secret_manager().get("secgate3600_username") + or get_secret_manager().get(f"{SERVICE_ID}_username") + or os.getenv("SECGATE3600_USERNAME", "") + ) + password = ( + _resolve_ref(_config_value(raw, "password")) + or get_secret_manager().get("secgate3600_password") + or get_secret_manager().get(f"{SERVICE_ID}_password") + or os.getenv("SECGATE3600_PASSWORD", "") + ) + if not username or not password: + raise SecGateError("SecGate 3600 username/password is not configured") + + verify_ssl = _as_bool( + _config_value(raw, "verify_ssl", "ssl_verify", "verifySsl") + if _config_value(raw, "verify_ssl", "ssl_verify", "verifySsl") is not None + else os.getenv("SECGATE3600_VERIFY_SSL"), + DEFAULT_VERIFY_SSL, + ) + try: + timeout = int(_config_value(raw, "timeout") or DEFAULT_TIMEOUT) + except (TypeError, ValueError): + timeout = DEFAULT_TIMEOUT + + return RuntimeConfig( + base_url=_normalize_base_url(base_url), + username=username, + password=password, + verify_ssl=verify_ssl, + timeout=timeout, + ) + + +class SecGateClient: + def __init__(self, config: RuntimeConfig) -> None: + self.config = config + self.session = requests.Session() + self.token = "" + + def _url(self, path: str) -> str: + return f"{self.config.base_url}{path}" + + def login(self) -> dict[str, Any]: + response = self.session.post( + self._url("/v1.0/login"), + json={"username": self.config.username, "password": self.config.password}, + timeout=self.config.timeout, + verify=self.config.verify_ssl, + headers={"Content-type": "application/json"}, + ) + payload = _json_response(response) + if not _is_success(payload): + raise SecGateError(_error_message(payload, fallback=f"Login failed with HTTP {response.status_code}")) + result = payload.get("result") if isinstance(payload.get("result"), dict) else {} + token = str(result.get("token") or payload.get("token") or "").strip() + if not token: + raise SecGateError("Login succeeded but response did not include token") + self.token = token + self.session.cookies.set("token", token) + return { + "success": True, + "username": result.get("username") or self.config.username, + "base_url": self.config.base_url, + "token_present": True, + } + + def rest(self, *, module: str, function: str, body: dict[str, Any] | None = None, page_index: int = 1, page_size: int = 20) -> dict[str, Any]: + if not self.token: + self.login() + request_body = [ + { + "head": { + "module": module, + "function": function, + "page_index": page_index, + "page_size": page_size, + }, + "body": body or {}, + } + ] + response = self.session.post( + self._url("/v1.0/rest/"), + json=request_body, + timeout=self.config.timeout, + verify=self.config.verify_ssl, + headers={"Content-type": "application/json"}, + ) + payload = _json_response(response) + if not _is_success(payload): + raise SecGateError(_error_message(payload, fallback=f"REST call failed with HTTP {response.status_code}")) + return payload + + +def _json_response(response: requests.Response) -> dict[str, Any]: + try: + payload = response.json() + except ValueError as exc: + raise SecGateError(f"Invalid JSON response: HTTP {response.status_code}") from exc + if not isinstance(payload, dict): + raise SecGateError("Unexpected response shape: expected object") + return payload + + +def _is_success(payload: dict[str, Any]) -> bool: + if payload.get("success") is True: + return True + head = payload.get("head") + if isinstance(head, dict) and str(head.get("error_code")) in {"0", "success"}: + return True + if str(payload.get("error_code")) in {"0", "success"}: + return True + result = payload.get("result") + if isinstance(result, dict) and str(result.get("error_code")) in {"0", "success"}: + return True + return False + + +def _error_message(payload: dict[str, Any], *, fallback: str) -> str: + for container in (payload, payload.get("head"), payload.get("result")): + if isinstance(container, dict): + for key in ("error_string", "error_msg", "message", "error_code"): + value = container.get(key) + if value not in (None, ""): + return str(value) + return fallback + + +def _ok(data: Any, *, action: str) -> ToolResult: + return ToolResult(success=True, output=data, metadata={"source": "SecGate 3600", "version": PRODUCT_VERSION, "action": action}) + + +def get_client() -> SecGateClient: + return SecGateClient(resolve_config()) + + +def check_login(args: dict[str, Any]) -> ToolResult: + del args + return _ok(get_client().login(), action="check_login") + + +def _load_api_catalog() -> list[dict[str, str]]: + try: + data = json.loads(CATALOG_FILE.read_text(encoding="utf-8")) + except FileNotFoundError: + return [] + entries = data.get("entries") + if not isinstance(entries, list): + return [] + catalog: list[dict[str, str]] = [] + for entry in entries: + if not isinstance(entry, dict): + continue + module = str(entry.get("module") or "").strip() + function = str(entry.get("function") or "").strip() + kind = str(entry.get("kind") or "").strip() + if module and function and kind in {"readonly", "mutation"}: + catalog.append({"module": module, "function": function, "kind": kind}) + return catalog + + +def _catalog_pairs(kind: str) -> set[tuple[str, str]]: + return { + (entry["module"], entry["function"]) + for entry in _load_api_catalog() + if entry.get("kind") == kind + } + + +READONLY_ACTIONS: dict[str, dict[str, Any]] = { + "notice_num_day": {"module": "notice", "function": "get_notice_num_day", "body": {}, "page_size": 20}, + "threats_last_day": { + "module": "statistics", + "function": "get_threat_threats", + "body": {"data": {"time": "last-1-day"}}, + "page_size": 20, + }, + "focus_last_day": { + "module": "statistics", + "function": "get_focus_focus", + "body": {"data": {"time": "last-1-day"}}, + "page_size": 20, + }, + "cpu_usage": { + "module": "statistics", + "function": "get_cpu_usage", + "body": {"data": {"time": "last-1-hours", "group_by": "cpu"}}, + "page_size": 2000, + }, + "memory_usage": { + "module": "statistics", + "function": "get_memory_usage", + "body": {"data": {"time": "last-1-hours", "group_by": "memory"}}, + "page_size": 2000, + }, + "disk_usage": { + "module": "statistics", + "function": "get_disk_usage", + "body": {"data": {"time": "last-1-hours", "group_by": "disk"}}, + "page_size": 2000, + }, + "connection_monitor": { + "module": "statistics", + "function": "get_connection_monitor", + "body": {"data": {"time": "last-1-hours"}}, + "page_size": 2000, + }, + "system_info": {"module": "dashboard", "function": "get_system_info", "body": {}, "page_size": 20}, + "system_resource": {"module": "dashboard", "function": "get_system_resource", "body": {}, "page_size": 20}, + "interface_info": {"module": "dashboard", "function": "get_interface_info", "body": {}, "page_size": 20}, + "interface_list": { + "module": "inter_face", + "function": "show_all_interface_web", + "body": {"info": {"interface": {}, "filter": {"inf_type": "physical", "inf_desc": "", "inf_name": "", "inf_zone": ""}}}, + "page_size": 50, + }, + "security_policy_list": { + "module": "sec_policy", + "function": "get_sec_policy", + "body": {"sec_policy": [{"name": "", "is_detail": False}]}, + "page_size": 20, + }, +} + + +SYSTEM_ACTIONS = {"check_login", "system_info", "system_resource", "cpu_usage", "memory_usage", "disk_usage"} +DASHBOARD_ACTIONS = {"notice_num_day", "threats_last_day", "focus_last_day", "connection_monitor", "interface_info"} +NETWORK_ACTIONS = {"interface_list"} +POLICY_ACTIONS = {"security_policy_list"} + + +def call_readonly(action: str, args: dict[str, Any]) -> ToolResult: + spec = READONLY_ACTIONS[action] + body = args.get("body") if isinstance(args.get("body"), dict) else spec.get("body", {}) + page_index = int(args.get("page_index") or 1) + page_size = int(args.get("page_size") or spec.get("page_size") or 20) + return _ok( + get_client().rest( + module=spec["module"], + function=spec["function"], + body=body, + page_index=page_index, + page_size=page_size, + ), + action=action, + ) + + +def api_catalog(args: dict[str, Any]) -> ToolResult: + del args + catalog = _load_api_catalog() + readonly_entries = [entry for entry in catalog if entry.get("kind") == "readonly"] + mutation_entries = [entry for entry in catalog if entry.get("kind") == "mutation"] + return _ok( + { + "login": "/v1.0/login", + "rest": "/v1.0/rest/", + "catalog_counts": { + "total": len(catalog), + "readonly": len(readonly_entries), + "mutation": len(mutation_entries), + }, + "documented_api_catalog": catalog, + "readonly_actions": READONLY_ACTIONS, + "groups": { + "system": sorted(SYSTEM_ACTIONS), + "dashboard": sorted(DASHBOARD_ACTIONS), + "network": sorted(NETWORK_ACTIONS), + "policy": sorted(POLICY_ACTIONS), + }, + }, + action="api_catalog", + ) + + +def rest_call_readonly(args: dict[str, Any]) -> ToolResult: + module = str(args.get("module") or "").strip() + function = str(args.get("function") or "").strip() + if not module or not function: + raise SecGateError("module and function are required") + allowed = _catalog_pairs("readonly") | { + (spec["module"], spec["function"]) for spec in READONLY_ACTIONS.values() + } + if (module, function) not in allowed: + raise SecGateError("Only documented read-only module/function pairs in secgate3600_api_catalog are allowed") + body = args.get("body") if isinstance(args.get("body"), dict) else {} + page_index = int(args.get("page_index") or 1) + page_size = int(args.get("page_size") or 20) + return _ok(get_client().rest(module=module, function=function, body=body, page_index=page_index, page_size=page_size), action="rest_call_readonly") + + +def rest_call_mutation(args: dict[str, Any]) -> ToolResult: + module = str(args.get("module") or "").strip() + function = str(args.get("function") or "").strip() + if not module or not function: + raise SecGateError("module and function are required") + if (module, function) not in _catalog_pairs("mutation"): + raise SecGateError("Only documented mutation module/function pairs in secgate3600_api_catalog are allowed") + body = args.get("body") if isinstance(args.get("body"), dict) else {} + page_index = int(args.get("page_index") or 1) + page_size = int(args.get("page_size") or 20) + return _ok(get_client().rest(module=module, function=function, body=body, page_index=page_index, page_size=page_size), action="rest_call_mutation") + + +ACTION_HANDLERS: dict[str, Callable[[dict[str, Any]], ToolResult]] = { + "check_login": check_login, + "api_catalog": api_catalog, + "rest_call_readonly": rest_call_readonly, + "rest_call_mutation": rest_call_mutation, +} +for _action in READONLY_ACTIONS: + ACTION_HANDLERS[_action] = lambda args, action=_action: call_readonly(action, args) + + +async def _dispatch(ctx: ToolContext, allowed: set[str], action: str, **params: Any) -> ToolResult: + del ctx + if action == "test": + try: + return await asyncio.to_thread(check_login, params) + except SecGateError as exc: + return ToolResult(success=False, error=str(exc), metadata={"source": "SecGate 3600", "version": PRODUCT_VERSION, "action": action}) + except Exception as exc: + return ToolResult(success=False, error=f"Unexpected SecGate 3600 error: {exc}", metadata={"source": "SecGate 3600", "version": PRODUCT_VERSION, "action": action}) + if action not in allowed: + return ToolResult(success=False, error=f"Unsupported SecGate 3600 action: {action}. Available: {', '.join(sorted(allowed))}") + try: + return await asyncio.to_thread(ACTION_HANDLERS[action], params) + except SecGateError as exc: + return ToolResult(success=False, error=str(exc), metadata={"source": "SecGate 3600", "version": PRODUCT_VERSION, "action": action}) + except Exception as exc: + return ToolResult(success=False, error=f"Unexpected SecGate 3600 error: {exc}", metadata={"source": "SecGate 3600", "version": PRODUCT_VERSION, "action": action}) + + +async def system(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch(ctx, SYSTEM_ACTIONS, action, **params) + + +async def dashboard(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch(ctx, DASHBOARD_ACTIONS, action, **params) + + +async def network(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch(ctx, NETWORK_ACTIONS, action, **params) + + +async def policy(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch(ctx, POLICY_ACTIONS, action, **params) + + +async def api_readonly(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch(ctx, {"api_catalog", "rest_call_readonly", *READONLY_ACTIONS.keys()}, action, **params) + + +async def api_mutation(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch(ctx, {"api_catalog", "rest_call_mutation"}, action, **params) diff --git a/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_api_catalog.json b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_api_catalog.json new file mode 100644 index 000000000..b6ec3d01a --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_api_catalog.json @@ -0,0 +1,2006 @@ +{ + "schema_version": 1, + "source": "SecGate3600 RESTful API guide V1.1", + "entries": [ + { + "module": "HA", + "function": "get_obj_lkdt_ref_list", + "kind": "readonly" + }, + { + "module": "aaa", + "function": "get_authsrv_config_list", + "kind": "readonly" + }, + { + "module": "aaa", + "function": "get_authsrv_list", + "kind": "readonly" + }, + { + "module": "aaa", + "function": "get_authsrv_member_group", + "kind": "readonly" + }, + { + "module": "aaa", + "function": "get_authsrv_member_role", + "kind": "readonly" + }, + { + "module": "aaa", + "function": "get_authsrv_member_user", + "kind": "readonly" + }, + { + "module": "addr_blacklist", + "function": "add_batch_blacklist_cfg", + "kind": "mutation" + }, + { + "module": "addr_blacklist", + "function": "add_blacklist_ip", + "kind": "mutation" + }, + { + "module": "addr_blacklist", + "function": "add_blacklist_ip_web", + "kind": "mutation" + }, + { + "module": "addr_blacklist", + "function": "add_blacklist_mac_web", + "kind": "mutation" + }, + { + "module": "addr_blacklist", + "function": "clear_batch_domain_blacklist_hitnum", + "kind": "mutation" + }, + { + "module": "addr_blacklist", + "function": "clear_batch_domain_blacklist_hitnum_all", + "kind": "mutation" + }, + { + "module": "addr_blacklist", + "function": "clear_batch_ip_blacklist_hitnum", + "kind": "mutation" + }, + { + "module": "addr_blacklist", + "function": "clear_batch_ip_blacklist_hitnum_all", + "kind": "mutation" + }, + { + "module": "addr_blacklist", + "function": "clear_blacklist_hitnum", + "kind": "mutation" + }, + { + "module": "addr_blacklist", + "function": "clear_blacklist_hitnum_one", + "kind": "mutation" + }, + { + "module": "addr_blacklist", + "function": "del_batch_blacklist_all_cfg", + "kind": "mutation" + }, + { + "module": "addr_blacklist", + "function": "del_batch_blacklist_cfg", + "kind": "mutation" + }, + { + "module": "addr_blacklist", + "function": "del_blacklist_by_id", + "kind": "mutation" + }, + { + "module": "addr_blacklist", + "function": "export_batch_blacklist_cfg", + "kind": "mutation" + }, + { + "module": "addr_blacklist", + "function": "get_batch_domain_blacklist_commit_status", + "kind": "readonly" + }, + { + "module": "addr_blacklist", + "function": "get_blacklist_config", + "kind": "readonly" + }, + { + "module": "addr_blacklist", + "function": "import_batch_blacklist_cfg", + "kind": "mutation" + }, + { + "module": "addr_blacklist", + "function": "set_batch_blacklist_match_model", + "kind": "mutation" + }, + { + "module": "addr_blacklist", + "function": "set_batch_domain_blacklist_commit", + "kind": "mutation" + }, + { + "module": "addr_blacklist", + "function": "set_blacklist_match_model", + "kind": "mutation" + }, + { + "module": "addr_blacklist", + "function": "show_batch_blacklist_cfg", + "kind": "readonly" + }, + { + "module": "addr_blacklist", + "function": "show_batch_blacklist_match_model", + "kind": "readonly" + }, + { + "module": "addr_blacklist", + "function": "show_blacklist_match_model", + "kind": "readonly" + }, + { + "module": "addr_whitelist", + "function": "add_whitelist_ip", + "kind": "mutation" + }, + { + "module": "addr_whitelist", + "function": "clear_whitelist_hitnum", + "kind": "mutation" + }, + { + "module": "addr_whitelist", + "function": "clear_whitelist_hitnum_one", + "kind": "mutation" + }, + { + "module": "addr_whitelist", + "function": "del_whitelist_by_id", + "kind": "mutation" + }, + { + "module": "addr_whitelist", + "function": "get_whitelist_config", + "kind": "readonly" + }, + { + "module": "addr_whitelist", + "function": "get_whitelist_config_by_id", + "kind": "readonly" + }, + { + "module": "admin_host", + "function": "add_admin_host_addr", + "kind": "mutation" + }, + { + "module": "admin_host", + "function": "add_admin_mac", + "kind": "mutation" + }, + { + "module": "admin_host", + "function": "del_admin_host_addr", + "kind": "mutation" + }, + { + "module": "admin_host", + "function": "del_admin_mac", + "kind": "mutation" + }, + { + "module": "admin_host", + "function": "get_admin_host", + "kind": "readonly" + }, + { + "module": "admin_host", + "function": "get_admin_mac", + "kind": "readonly" + }, + { + "module": "admin_host", + "function": "get_admin_port", + "kind": "readonly" + }, + { + "module": "admin_host", + "function": "set_admin_port", + "kind": "mutation" + }, + { + "module": "admin_user", + "function": "add_admin_user", + "kind": "mutation" + }, + { + "module": "admin_user", + "function": "del_admin_user", + "kind": "mutation" + }, + { + "module": "admin_user", + "function": "get_admin_user", + "kind": "readonly" + }, + { + "module": "admin_user", + "function": "get_admin_user_detail", + "kind": "readonly" + }, + { + "module": "admin_user", + "function": "set_admin_user", + "kind": "mutation" + }, + { + "module": "alg", + "function": "get_alg_list", + "kind": "readonly" + }, + { + "module": "alg", + "function": "set_alg_conf", + "kind": "mutation" + }, + { + "module": "all", + "function": "get_ha_compare_cfg", + "kind": "readonly" + }, + { + "module": "app", + "function": "set_auto_upgrade_immediately", + "kind": "mutation" + }, + { + "module": "area", + "function": "set_auto_upgrade_immediately", + "kind": "mutation" + }, + { + "module": "asset", + "function": "set_auto_upgrade_immediately", + "kind": "mutation" + }, + { + "module": "auto_upgrade", + "function": "set_auto_upgrade_conf", + "kind": "mutation" + }, + { + "module": "auto_upgrade", + "function": "set_auto_upgrade_immediately", + "kind": "mutation" + }, + { + "module": "av", + "function": "set_auto_upgrade_immediately", + "kind": "mutation" + }, + { + "module": "av", + "function": "set_av_add_profile", + "kind": "mutation" + }, + { + "module": "av", + "function": "set_av_custom_add", + "kind": "mutation" + }, + { + "module": "av", + "function": "set_av_custom_del", + "kind": "mutation" + }, + { + "module": "av", + "function": "set_av_del_profile", + "kind": "mutation" + }, + { + "module": "av", + "function": "set_av_edit_profile", + "kind": "mutation" + }, + { + "module": "av", + "function": "show_av_md5_custom", + "kind": "readonly" + }, + { + "module": "av", + "function": "show_av_profile", + "kind": "readonly" + }, + { + "module": "cloud_basic", + "function": "del_cloud_basic_intelligence_info", + "kind": "mutation" + }, + { + "module": "cloud_basic", + "function": "show_cld_cloud_basic_abnormal_host_statistics", + "kind": "readonly" + }, + { + "module": "cloud_basic", + "function": "show_cld_cloud_basic_intelligence_statistics", + "kind": "readonly" + }, + { + "module": "cloud_basic", + "function": "show_cloud_basic_intelligence_info", + "kind": "readonly" + }, + { + "module": "cloud_notice", + "function": "get_cloud_notice_msg_num", + "kind": "readonly" + }, + { + "module": "ctrl_agent", + "function": "get_controller_cfg", + "kind": "readonly" + }, + { + "module": "ctrl_agent", + "function": "get_controller_online", + "kind": "readonly" + }, + { + "module": "ctrl_agent", + "function": "set_controller_cfg", + "kind": "mutation" + }, + { + "module": "ctrl_agent", + "function": "set_controller_conf", + "kind": "mutation" + }, + { + "module": "dashboard", + "function": "get_interface_info", + "kind": "readonly" + }, + { + "module": "dashboard", + "function": "get_system_info", + "kind": "readonly" + }, + { + "module": "dashboard", + "function": "get_system_resource", + "kind": "readonly" + }, + { + "module": "decrypt", + "function": "add_decrypt_bypass_sni_domain", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "add_decrypt_policy", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "add_ssl_decrypt_profile", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "add_ssl_ins_profile", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "clear_decrypt_bypass_sni_domain_hitnum_all", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "clear_decrypt_bypass_sni_domain_hitnum_one", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "del_decrypt_bypass_sni_domain", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "del_decrypt_cert_signer", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "del_decrypt_policy", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "del_ssl_decrypt_profile", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "del_ssl_ins_profile", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "get_decrypt_cert_issue_addr_list", + "kind": "readonly" + }, + { + "module": "decrypt", + "function": "get_decrypt_certlearn_info", + "kind": "readonly" + }, + { + "module": "decrypt", + "function": "get_ssl_decrypt_profile", + "kind": "readonly" + }, + { + "module": "decrypt", + "function": "get_ssl_ins_cert", + "kind": "readonly" + }, + { + "module": "decrypt", + "function": "get_ssl_ins_profile", + "kind": "readonly" + }, + { + "module": "decrypt", + "function": "set_decrypt_bypass_sni_domain", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "set_decrypt_cert_issue", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "set_decrypt_cert_trustCA_default", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "set_decrypt_mirror_dport_offset", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "set_decrypt_policy", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "set_decrypt_policy_pri", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "set_ssl_decrypt_profile", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "set_ssl_ins_profile", + "kind": "mutation" + }, + { + "module": "decrypt", + "function": "show_decrypt_bypass_sni", + "kind": "readonly" + }, + { + "module": "decrypt", + "function": "show_decrypt_cert_issue", + "kind": "readonly" + }, + { + "module": "decrypt", + "function": "show_decrypt_cert_signer", + "kind": "readonly" + }, + { + "module": "decrypt", + "function": "show_decrypt_cert_trustCA", + "kind": "readonly" + }, + { + "module": "decrypt", + "function": "show_decrypt_mirror_config", + "kind": "readonly" + }, + { + "module": "decrypt", + "function": "show_decrypt_policy", + "kind": "readonly" + }, + { + "module": "dns", + "function": "add_dns_domain_blacklist", + "kind": "mutation" + }, + { + "module": "dns", + "function": "add_dns_domain_whitelist", + "kind": "mutation" + }, + { + "module": "dns", + "function": "add_ioc_honeypot_domain", + "kind": "mutation" + }, + { + "module": "dns", + "function": "clear_dns_domain_hit_num", + "kind": "mutation" + }, + { + "module": "dns", + "function": "del_dns_domain", + "kind": "mutation" + }, + { + "module": "dns", + "function": "del_ioc_honeypot_domain_one", + "kind": "mutation" + }, + { + "module": "dns", + "function": "get_dns_domain_list", + "kind": "readonly" + }, + { + "module": "dns", + "function": "set_dns_domain_blacklist", + "kind": "mutation" + }, + { + "module": "dns", + "function": "set_dns_domain_whitelist", + "kind": "mutation" + }, + { + "module": "dns", + "function": "set_ioc_honeypot", + "kind": "mutation" + }, + { + "module": "dns", + "function": "show_ioc_honeypot_domain", + "kind": "readonly" + }, + { + "module": "exploit", + "function": "set_auto_upgrade_immediately", + "kind": "mutation" + }, + { + "module": "filetype", + "function": "set_auto_upgrade_immediately", + "kind": "mutation" + }, + { + "module": "honeypot_policy", + "function": "add_honeypot_policy", + "kind": "mutation" + }, + { + "module": "honeypot_policy", + "function": "del_honeypot_policy", + "kind": "mutation" + }, + { + "module": "honeypot_policy", + "function": "del_honeypot_policy_hitnum", + "kind": "mutation" + }, + { + "module": "honeypot_policy", + "function": "get_honeypot_policy", + "kind": "readonly" + }, + { + "module": "honeypot_policy", + "function": "get_honeypot_policy_filter", + "kind": "readonly" + }, + { + "module": "honeypot_policy", + "function": "get_honeypot_policy_name", + "kind": "readonly" + }, + { + "module": "honeypot_policy", + "function": "set_honeypot_policy", + "kind": "mutation" + }, + { + "module": "honeypot_policy", + "function": "set_move_honeypot_policy_pri", + "kind": "mutation" + }, + { + "module": "inter_face", + "function": "add_bridge", + "kind": "mutation" + }, + { + "module": "inter_face", + "function": "add_inf_phy", + "kind": "mutation" + }, + { + "module": "inter_face", + "function": "add_vlan_conf", + "kind": "mutation" + }, + { + "module": "inter_face", + "function": "del_bridge", + "kind": "mutation" + }, + { + "module": "inter_face", + "function": "del_inf_phy", + "kind": "mutation" + }, + { + "module": "inter_face", + "function": "del_vlan_conf", + "kind": "mutation" + }, + { + "module": "inter_face", + "function": "get_bridge_conf_list", + "kind": "readonly" + }, + { + "module": "inter_face", + "function": "get_channel_inf_conf", + "kind": "readonly" + }, + { + "module": "inter_face", + "function": "get_interface_conf", + "kind": "readonly" + }, + { + "module": "inter_face", + "function": "get_interface_select_list", + "kind": "readonly" + }, + { + "module": "inter_face", + "function": "get_vlan_conf", + "kind": "readonly" + }, + { + "module": "inter_face", + "function": "get_vlan_conf_list", + "kind": "readonly" + }, + { + "module": "inter_face", + "function": "get_vlan_ref_modle", + "kind": "readonly" + }, + { + "module": "inter_face", + "function": "set_bridge_conf", + "kind": "mutation" + }, + { + "module": "inter_face", + "function": "set_if_config_web", + "kind": "mutation" + }, + { + "module": "inter_face", + "function": "set_vlan_conf", + "kind": "mutation" + }, + { + "module": "inter_face", + "function": "show_all_interface_web", + "kind": "readonly" + }, + { + "module": "ips", + "function": "add_ips_profile", + "kind": "mutation" + }, + { + "module": "ips", + "function": "add_ips_profile_scene", + "kind": "mutation" + }, + { + "module": "ips", + "function": "del_ips_profile", + "kind": "mutation" + }, + { + "module": "ips", + "function": "del_ips_profile_scene", + "kind": "mutation" + }, + { + "module": "ips", + "function": "get_app_ctl", + "kind": "readonly" + }, + { + "module": "ips", + "function": "get_app_proto_scene", + "kind": "readonly" + }, + { + "module": "ips", + "function": "get_defence_scene_class", + "kind": "readonly" + }, + { + "module": "ips", + "function": "get_defence_type_info", + "kind": "readonly" + }, + { + "module": "ips", + "function": "get_profile_detail", + "kind": "readonly" + }, + { + "module": "ips", + "function": "get_profile_info", + "kind": "readonly" + }, + { + "module": "ips", + "function": "get_profile_info_scene", + "kind": "readonly" + }, + { + "module": "ips", + "function": "get_scene_class_rules", + "kind": "readonly" + }, + { + "module": "ips", + "function": "set_app_ctl_proto", + "kind": "mutation" + }, + { + "module": "ips", + "function": "set_auto_upgrade_immediately", + "kind": "mutation" + }, + { + "module": "ips", + "function": "set_ips_profile", + "kind": "mutation" + }, + { + "module": "ips", + "function": "set_ips_profile_scene", + "kind": "mutation" + }, + { + "module": "ips_sig", + "function": "add_ips_sig", + "kind": "mutation" + }, + { + "module": "ips_sig", + "function": "del_ips_sig", + "kind": "mutation" + }, + { + "module": "ips_sig", + "function": "get_ips_sig", + "kind": "readonly" + }, + { + "module": "ips_sig", + "function": "get_ips_sig_all", + "kind": "readonly" + }, + { + "module": "ips_sig", + "function": "set_ips_sig", + "kind": "mutation" + }, + { + "module": "ips_sig", + "function": "set_ips_sig_commit", + "kind": "mutation" + }, + { + "module": "ipv6", + "function": "get_nat64_conf", + "kind": "readonly" + }, + { + "module": "ipv6", + "function": "set_add_nat64_prefix", + "kind": "mutation" + }, + { + "module": "ipv6", + "function": "set_del_nat64_prefix", + "kind": "mutation" + }, + { + "module": "isp", + "function": "set_auto_upgrade_immediately", + "kind": "mutation" + }, + { + "module": "log", + "function": "get_log_send_template_list", + "kind": "readonly" + }, + { + "module": "log", + "function": "get_log_syslog_config", + "kind": "readonly" + }, + { + "module": "log", + "function": "set_log_syslog_config", + "kind": "mutation" + }, + { + "module": "multicast_snooping", + "function": "set_vlan_igmp_snooping", + "kind": "mutation" + }, + { + "module": "nat", + "function": "add_dnat_conf", + "kind": "mutation" + }, + { + "module": "nat", + "function": "add_snat_conf", + "kind": "mutation" + }, + { + "module": "nat", + "function": "clear_dnat_hitnum", + "kind": "mutation" + }, + { + "module": "nat", + "function": "clear_snat_hitnum", + "kind": "mutation" + }, + { + "module": "nat", + "function": "del_dnat_conf", + "kind": "mutation" + }, + { + "module": "nat", + "function": "del_snat_conf", + "kind": "mutation" + }, + { + "module": "nat", + "function": "get_dnat_conf", + "kind": "readonly" + }, + { + "module": "nat", + "function": "get_dnat_list", + "kind": "readonly" + }, + { + "module": "nat", + "function": "get_dnat_list_by_filter", + "kind": "readonly" + }, + { + "module": "nat", + "function": "get_snat_conf", + "kind": "readonly" + }, + { + "module": "nat", + "function": "get_snat_list", + "kind": "readonly" + }, + { + "module": "nat", + "function": "get_snat_list_by_filter", + "kind": "readonly" + }, + { + "module": "nat", + "function": "get_snat_trans_ip_state", + "kind": "readonly" + }, + { + "module": "nat", + "function": "get_snat_trans_ip_state_by_filter", + "kind": "readonly" + }, + { + "module": "nat", + "function": "set_dnat_conf", + "kind": "mutation" + }, + { + "module": "nat", + "function": "set_dnat_move", + "kind": "mutation" + }, + { + "module": "nat", + "function": "set_dnat_status", + "kind": "mutation" + }, + { + "module": "nat", + "function": "set_snat_conf", + "kind": "mutation" + }, + { + "module": "nat", + "function": "set_snat_move", + "kind": "mutation" + }, + { + "module": "nat", + "function": "set_snat_status", + "kind": "mutation" + }, + { + "module": "notice", + "function": "get_notice_info", + "kind": "readonly" + }, + { + "module": "notice", + "function": "get_notice_num_day", + "kind": "readonly" + }, + { + "module": "ntp", + "function": "set_ntp_conf", + "kind": "mutation" + }, + { + "module": "obj_address", + "function": "add_obj_addr_conf", + "kind": "mutation" + }, + { + "module": "obj_address", + "function": "add_obj_addr_group", + "kind": "mutation" + }, + { + "module": "obj_address", + "function": "del_obj_addr_conf", + "kind": "mutation" + }, + { + "module": "obj_address", + "function": "del_obj_addr_group", + "kind": "mutation" + }, + { + "module": "obj_address", + "function": "get_obj_addr_group", + "kind": "readonly" + }, + { + "module": "obj_address", + "function": "get_obj_addr_list", + "kind": "readonly" + }, + { + "module": "obj_address", + "function": "set_obj_addr_conf", + "kind": "mutation" + }, + { + "module": "obj_address", + "function": "set_obj_addr_group", + "kind": "mutation" + }, + { + "module": "obj_app", + "function": "get_search_app_list_for_rule", + "kind": "readonly" + }, + { + "module": "obj_area", + "function": "add_obj_area_custom", + "kind": "mutation" + }, + { + "module": "obj_area", + "function": "del_obj_area_custom", + "kind": "mutation" + }, + { + "module": "obj_area", + "function": "export_area_custom", + "kind": "mutation" + }, + { + "module": "obj_area", + "function": "get_area_list", + "kind": "readonly" + }, + { + "module": "obj_area", + "function": "import_area_custom", + "kind": "mutation" + }, + { + "module": "obj_area", + "function": "set_obj_area_custom", + "kind": "mutation" + }, + { + "module": "obj_honeypot", + "function": "add_obj_honeypot", + "kind": "mutation" + }, + { + "module": "obj_honeypot", + "function": "del_obj_honeypot", + "kind": "mutation" + }, + { + "module": "obj_honeypot", + "function": "get_obj_honeypot", + "kind": "readonly" + }, + { + "module": "obj_honeypot", + "function": "set_obj_honeypot", + "kind": "mutation" + }, + { + "module": "obj_lkdt", + "function": "add_obj_lkdt", + "kind": "mutation" + }, + { + "module": "obj_lkdt", + "function": "del_obj_lkdt", + "kind": "mutation" + }, + { + "module": "obj_lkdt", + "function": "get_obj_lkdt", + "kind": "readonly" + }, + { + "module": "obj_lkdt", + "function": "get_obj_lkdt_list", + "kind": "readonly" + }, + { + "module": "obj_lkdt", + "function": "get_obj_lkdt_ref_list", + "kind": "readonly" + }, + { + "module": "obj_lkdt", + "function": "get_obj_lkdt_textvalue_list", + "kind": "readonly" + }, + { + "module": "obj_lkdt", + "function": "set_obj_lkdt", + "kind": "mutation" + }, + { + "module": "obj_sche", + "function": "add_schedule_conf", + "kind": "mutation" + }, + { + "module": "obj_sche", + "function": "del_schedule_conf", + "kind": "mutation" + }, + { + "module": "obj_sche", + "function": "get_obj_sche_list", + "kind": "readonly" + }, + { + "module": "obj_sche", + "function": "set_obj_schedule", + "kind": "mutation" + }, + { + "module": "obj_server", + "function": "add_server_conf", + "kind": "mutation" + }, + { + "module": "obj_server", + "function": "del_server", + "kind": "mutation" + }, + { + "module": "obj_server", + "function": "get_server_conf", + "kind": "readonly" + }, + { + "module": "obj_server", + "function": "get_server_list", + "kind": "readonly" + }, + { + "module": "obj_server", + "function": "get_server_search", + "kind": "readonly" + }, + { + "module": "obj_server", + "function": "set_server_conf", + "kind": "mutation" + }, + { + "module": "obj_service", + "function": "add_service_custom", + "kind": "mutation" + }, + { + "module": "obj_service", + "function": "add_service_grp", + "kind": "mutation" + }, + { + "module": "obj_service", + "function": "del_service_custom", + "kind": "mutation" + }, + { + "module": "obj_service", + "function": "del_service_grp", + "kind": "mutation" + }, + { + "module": "obj_service", + "function": "get_service_custom", + "kind": "readonly" + }, + { + "module": "obj_service", + "function": "get_service_grp", + "kind": "readonly" + }, + { + "module": "obj_service", + "function": "get_service_predef", + "kind": "readonly" + }, + { + "module": "obj_service", + "function": "set_service_custom", + "kind": "mutation" + }, + { + "module": "obj_service", + "function": "set_service_grp", + "kind": "mutation" + }, + { + "module": "pki", + "function": "get_pki_gwcert", + "kind": "readonly" + }, + { + "module": "pnf", + "function": "add_pnf_group", + "kind": "mutation" + }, + { + "module": "pnf", + "function": "add_pnf_object", + "kind": "mutation" + }, + { + "module": "pnf", + "function": "del_pnf_group", + "kind": "mutation" + }, + { + "module": "pnf", + "function": "del_pnf_object", + "kind": "mutation" + }, + { + "module": "pnf", + "function": "get_pnf_group_list", + "kind": "readonly" + }, + { + "module": "pnf", + "function": "get_pnf_list", + "kind": "readonly" + }, + { + "module": "pnf", + "function": "get_service_chain_monitor", + "kind": "readonly" + }, + { + "module": "pnf", + "function": "get_sfc_name_list_bypass", + "kind": "readonly" + }, + { + "module": "pnf", + "function": "set_pnf_group", + "kind": "mutation" + }, + { + "module": "pnf", + "function": "set_pnf_group_bypass", + "kind": "mutation" + }, + { + "module": "pnf", + "function": "set_pnf_object", + "kind": "mutation" + }, + { + "module": "profile_group", + "function": "add_profile_group", + "kind": "mutation" + }, + { + "module": "profile_group", + "function": "del_profile_group", + "kind": "mutation" + }, + { + "module": "profile_group", + "function": "get_profile_group", + "kind": "readonly" + }, + { + "module": "profile_group", + "function": "set_profile_group", + "kind": "mutation" + }, + { + "module": "qos", + "function": "add_qos_class_conf", + "kind": "mutation" + }, + { + "module": "qos", + "function": "add_qos_line_conf", + "kind": "mutation" + }, + { + "module": "qos", + "function": "add_qos_policy_conf", + "kind": "mutation" + }, + { + "module": "qos", + "function": "add_qos_root_class_conf", + "kind": "mutation" + }, + { + "module": "qos", + "function": "del_qos_class_tree_conf", + "kind": "mutation" + }, + { + "module": "qos", + "function": "del_qos_line_conf", + "kind": "mutation" + }, + { + "module": "qos", + "function": "del_qos_policy_conf", + "kind": "mutation" + }, + { + "module": "qos", + "function": "get_qos_class_quota_use", + "kind": "readonly" + }, + { + "module": "qos", + "function": "get_qos_policy_list", + "kind": "readonly" + }, + { + "module": "qos", + "function": "set_qos_class", + "kind": "mutation" + }, + { + "module": "qos", + "function": "set_qos_line_conf", + "kind": "mutation" + }, + { + "module": "qos", + "function": "set_qos_line_interface", + "kind": "mutation" + }, + { + "module": "qos", + "function": "set_qos_policy_conf", + "kind": "mutation" + }, + { + "module": "qos", + "function": "show_web_qos_class_one", + "kind": "readonly" + }, + { + "module": "qos", + "function": "show_web_qos_class_tree", + "kind": "readonly" + }, + { + "module": "qos", + "function": "show_web_qos_line_tree", + "kind": "readonly" + }, + { + "module": "route_policy", + "function": "add_route_policy", + "kind": "mutation" + }, + { + "module": "route_policy", + "function": "del_route_policy", + "kind": "mutation" + }, + { + "module": "route_policy", + "function": "del_route_policy_hitnum", + "kind": "mutation" + }, + { + "module": "route_policy", + "function": "get_route_policy_list", + "kind": "readonly" + }, + { + "module": "route_policy", + "function": "set_route_policy_move", + "kind": "mutation" + }, + { + "module": "route_static", + "function": "add_route_static", + "kind": "mutation" + }, + { + "module": "route_static", + "function": "del_route_static", + "kind": "mutation" + }, + { + "module": "route_static", + "function": "get_route_all_list", + "kind": "readonly" + }, + { + "module": "route_static", + "function": "get_route_single_static", + "kind": "readonly" + }, + { + "module": "route_static", + "function": "set_route_static", + "kind": "mutation" + }, + { + "module": "sec_policy", + "function": "add_sec_policy", + "kind": "mutation" + }, + { + "module": "sec_policy", + "function": "add_sec_policy_group", + "kind": "mutation" + }, + { + "module": "sec_policy", + "function": "add_sec_policy_learn_task", + "kind": "mutation" + }, + { + "module": "sec_policy", + "function": "del_sec_policy", + "kind": "mutation" + }, + { + "module": "sec_policy", + "function": "del_sec_policy_group", + "kind": "mutation" + }, + { + "module": "sec_policy", + "function": "del_sec_policy_hitnum", + "kind": "mutation" + }, + { + "module": "sec_policy", + "function": "del_sec_policy_learn_task", + "kind": "mutation" + }, + { + "module": "sec_policy", + "function": "get_sec_learn_policy_name", + "kind": "readonly" + }, + { + "module": "sec_policy", + "function": "get_sec_policy", + "kind": "readonly" + }, + { + "module": "sec_policy", + "function": "get_sec_policy_adjacent_merged_info", + "kind": "readonly" + }, + { + "module": "sec_policy", + "function": "get_sec_policy_export", + "kind": "readonly" + }, + { + "module": "sec_policy", + "function": "get_sec_policy_filter", + "kind": "readonly" + }, + { + "module": "sec_policy", + "function": "get_sec_policy_group", + "kind": "readonly" + }, + { + "module": "sec_policy", + "function": "get_sec_policy_group_filter", + "kind": "readonly" + }, + { + "module": "sec_policy", + "function": "get_sec_policy_group_name", + "kind": "readonly" + }, + { + "module": "sec_policy", + "function": "get_sec_policy_hitinfo_filter", + "kind": "readonly" + }, + { + "module": "sec_policy", + "function": "get_sec_policy_include", + "kind": "readonly" + }, + { + "module": "sec_policy", + "function": "get_sec_policy_name", + "kind": "readonly" + }, + { + "module": "sec_policy", + "function": "get_sec_policy_part_include", + "kind": "readonly" + }, + { + "module": "sec_policy", + "function": "get_sec_policy_part_useless", + "kind": "readonly" + }, + { + "module": "sec_policy", + "function": "get_sec_policy_region_merged_info", + "kind": "readonly" + }, + { + "module": "sec_policy", + "function": "get_sec_policy_task", + "kind": "readonly" + }, + { + "module": "sec_policy", + "function": "get_sec_policy_useless", + "kind": "readonly" + }, + { + "module": "sec_policy", + "function": "get_system_hostname", + "kind": "readonly" + }, + { + "module": "sec_policy", + "function": "set_move_sec_policy_group_member_pri", + "kind": "mutation" + }, + { + "module": "sec_policy", + "function": "set_move_sec_policy_pri", + "kind": "mutation" + }, + { + "module": "sec_policy", + "function": "set_sec_policy", + "kind": "mutation" + }, + { + "module": "sec_policy", + "function": "set_sec_policy_group", + "kind": "mutation" + }, + { + "module": "sec_policy", + "function": "set_sec_policy_group_state", + "kind": "mutation" + }, + { + "module": "sec_policy", + "function": "set_sec_policy_import", + "kind": "mutation" + }, + { + "module": "sec_policy", + "function": "set_sec_policy_learn_task", + "kind": "mutation" + }, + { + "module": "sec_policy", + "function": "set_sec_policy_learn_task_state", + "kind": "mutation" + }, + { + "module": "sec_policy", + "function": "set_sec_policy_state", + "kind": "mutation" + }, + { + "module": "sec_policy", + "function": "set_sec_policy_task_generate_policy", + "kind": "mutation" + }, + { + "module": "sfc", + "function": "add_sfc", + "kind": "mutation" + }, + { + "module": "sfc", + "function": "del_sfc", + "kind": "mutation" + }, + { + "module": "sfc", + "function": "set_sfc", + "kind": "mutation" + }, + { + "module": "sfc", + "function": "set_sfc_bypass", + "kind": "mutation" + }, + { + "module": "sfc", + "function": "show_sfc", + "kind": "readonly" + }, + { + "module": "sfc_policy", + "function": "add_sfc_policy", + "kind": "mutation" + }, + { + "module": "sfc_policy", + "function": "del_sfc_policy", + "kind": "mutation" + }, + { + "module": "sfc_policy", + "function": "del_sfc_policy_hitnum", + "kind": "mutation" + }, + { + "module": "sfc_policy", + "function": "get_sfc_policy", + "kind": "readonly" + }, + { + "module": "sfc_policy", + "function": "get_sfc_policy_filter", + "kind": "readonly" + }, + { + "module": "sfc_policy", + "function": "get_sfc_policy_name", + "kind": "readonly" + }, + { + "module": "sfc_policy", + "function": "set_move_sfc_policy_pri", + "kind": "mutation" + }, + { + "module": "sfc_policy", + "function": "set_sfc_policy", + "kind": "mutation" + }, + { + "module": "sfc_policy", + "function": "set_sfc_policy_state", + "kind": "mutation" + }, + { + "module": "sgrp", + "function": "del_ha_group_conf", + "kind": "mutation" + }, + { + "module": "sgrp", + "function": "get_ha_compare_cfg", + "kind": "readonly" + }, + { + "module": "sgrp", + "function": "get_ha_global_config", + "kind": "readonly" + }, + { + "module": "sgrp", + "function": "get_ha_group_id_list", + "kind": "readonly" + }, + { + "module": "sgrp", + "function": "get_ha_group_list", + "kind": "readonly" + }, + { + "module": "sgrp", + "function": "get_ha_track_bfd_list", + "kind": "readonly" + }, + { + "module": "sgrp", + "function": "get_ha_track_if_list", + "kind": "readonly" + }, + { + "module": "sgrp", + "function": "get_ha_track_ip_list", + "kind": "readonly" + }, + { + "module": "sgrp", + "function": "get_ha_track_pnf_list", + "kind": "readonly" + }, + { + "module": "sgrp", + "function": "set_ha_global_config", + "kind": "mutation" + }, + { + "module": "sgrp", + "function": "set_ha_group_conf", + "kind": "mutation" + }, + { + "module": "sgrp", + "function": "set_ha_group_demotion_force", + "kind": "mutation" + }, + { + "module": "sgrp", + "function": "set_ha_group_trackbfd", + "kind": "mutation" + }, + { + "module": "sgrp", + "function": "set_ha_group_trackif", + "kind": "mutation" + }, + { + "module": "sgrp", + "function": "set_ha_group_trackip", + "kind": "mutation" + }, + { + "module": "sgrp", + "function": "set_ha_group_trackpnf", + "kind": "mutation" + }, + { + "module": "sgrp", + "function": "set_ha_sync_cfg_force", + "kind": "mutation" + }, + { + "module": "smac", + "function": "get_smac_conf", + "kind": "readonly" + }, + { + "module": "smac", + "function": "set_smac_conf", + "kind": "mutation" + }, + { + "module": "snmpagent", + "function": "get_snmp_agent", + "kind": "readonly" + }, + { + "module": "snmpagent", + "function": "set_snmp_agent", + "kind": "mutation" + }, + { + "module": "snmpv3usr", + "function": "add_snmpv3usr", + "kind": "mutation" + }, + { + "module": "snmpv3usr", + "function": "del_snmpv3usr", + "kind": "mutation" + }, + { + "module": "snmpv3usr", + "function": "get_snmpv3usr", + "kind": "readonly" + }, + { + "module": "snmpv3usr", + "function": "set_snmpv3usr", + "kind": "mutation" + }, + { + "module": "statistics", + "function": "get_connection_monitor", + "kind": "readonly" + }, + { + "module": "statistics", + "function": "get_cpu_usage", + "kind": "readonly" + }, + { + "module": "statistics", + "function": "get_disk_usage", + "kind": "readonly" + }, + { + "module": "statistics", + "function": "get_focus_focus", + "kind": "readonly" + }, + { + "module": "statistics", + "function": "get_memory_usage", + "kind": "readonly" + }, + { + "module": "statistics", + "function": "get_network_monitor", + "kind": "readonly" + }, + { + "module": "statistics", + "function": "get_threat_monitor", + "kind": "readonly" + }, + { + "module": "statistics", + "function": "get_threat_threats", + "kind": "readonly" + }, + { + "module": "sysconfig", + "function": "export_sysconfig_ftp", + "kind": "mutation" + }, + { + "module": "sysconfig", + "function": "import_sysconfig_ftp", + "kind": "mutation" + }, + { + "module": "sysconfig", + "function": "set_save_sysconfig", + "kind": "mutation" + }, + { + "module": "syslog", + "function": "add_syslog_server_config", + "kind": "mutation" + }, + { + "module": "syslog", + "function": "del_syslog_server_config", + "kind": "mutation" + }, + { + "module": "syslog", + "function": "get_syslog_server_config_all", + "kind": "readonly" + }, + { + "module": "syslog", + "function": "get_syslog_server_config_one", + "kind": "readonly" + }, + { + "module": "syslog", + "function": "get_syslog_server_list", + "kind": "readonly" + }, + { + "module": "syslog", + "function": "set_syslog_server_config", + "kind": "mutation" + }, + { + "module": "system", + "function": "get_sec_policy", + "kind": "readonly" + }, + { + "module": "system", + "function": "get_system_hostname", + "kind": "readonly" + }, + { + "module": "system", + "function": "set_system_time", + "kind": "mutation" + }, + { + "module": "system_advanced", + "function": "get_sys_advanced_status", + "kind": "readonly" + }, + { + "module": "system_advanced", + "function": "set_sys_advanced_status", + "kind": "mutation" + }, + { + "module": "threat_disposal", + "function": "add_threat_disposal", + "kind": "mutation" + }, + { + "module": "threat_disposal", + "function": "del_threat_disposal", + "kind": "mutation" + }, + { + "module": "threat_disposal", + "function": "get_threat_disposal", + "kind": "readonly" + }, + { + "module": "threat_disposal", + "function": "get_threat_disposal_statitics", + "kind": "readonly" + }, + { + "module": "threat_disposal", + "function": "set_threat_disposal", + "kind": "mutation" + }, + { + "module": "trap", + "function": "add_trap_server", + "kind": "mutation" + }, + { + "module": "trap", + "function": "del_trap_server", + "kind": "mutation" + }, + { + "module": "trap", + "function": "get_trap_server", + "kind": "readonly" + }, + { + "module": "trap", + "function": "set_trap_server", + "kind": "mutation" + }, + { + "module": "upgrade", + "function": "get_notice_info", + "kind": "readonly" + }, + { + "module": "url", + "function": "set_auto_upgrade_immediately", + "kind": "mutation" + }, + { + "module": "vsys", + "function": "get_vsys_info", + "kind": "readonly" + }, + { + "module": "vsys", + "function": "get_vsys_info_specify", + "kind": "readonly" + }, + { + "module": "vsys", + "function": "set_reset_vsys", + "kind": "mutation" + }, + { + "module": "vsys", + "function": "set_vsys", + "kind": "mutation" + }, + { + "module": "waf", + "function": "set_auto_upgrade_immediately", + "kind": "mutation" + }, + { + "module": "weakpwd", + "function": "set_auto_upgrade_immediately", + "kind": "mutation" + }, + { + "module": "zone", + "function": "add_zone", + "kind": "mutation" + }, + { + "module": "zone", + "function": "del_zone", + "kind": "mutation" + }, + { + "module": "zone", + "function": "get_zone_list", + "kind": "readonly" + }, + { + "module": "zone", + "function": "set_zone", + "kind": "mutation" + } + ] +} diff --git a/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_api_mutation.yaml b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_api_mutation.yaml new file mode 100644 index 000000000..2d4088f8a --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_api_mutation.yaml @@ -0,0 +1,37 @@ +name: secgate3600_api_mutation +description: SecGate 3600 documented mutation REST caller with confirmation. +description_cn: SecGate 3600 文档内变更类 REST 调用工具。使用 api_catalog 查看已收录 API,再用 rest_call_mutation 调用 kind=mutation 的 module/function。所有调用都需要确认。 +category: custom +enabled: true +requires_confirmation: true +provider: secgate3600_api +version: "3.6.6.0" +inputSchema: + type: object + properties: + action: + type: string + enum: + - api_catalog + - rest_call_mutation + module: + type: string + description: rest_call_mutation 使用的模块名。必须属于 api_catalog 中 kind=mutation 的条目。 + function: + type: string + description: rest_call_mutation 使用的方法名。必须属于 api_catalog 中 kind=mutation 的条目。 + page_index: + type: integer + default: 1 + page_size: + type: integer + default: 20 + body: + type: object + description: REST 请求 body。结构按 SecGate 3600 官方文档对应接口填写。 + required: + - action +handler: + type: script + script_file: secgate3600.handler.py + function: api_mutation diff --git a/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_api_readonly.yaml b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_api_readonly.yaml new file mode 100644 index 000000000..fbd551f14 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_api_readonly.yaml @@ -0,0 +1,48 @@ +name: secgate3600_api_readonly +description: SecGate 3600 documented read-only REST caller. +description_cn: SecGate 3600 文档内只读 REST 调用工具。使用 api_catalog 查看已收录 API,再用 rest_call_readonly 调用任意 kind=readonly 的 module/function。 +category: custom +enabled: true +requires_confirmation: false +provider: secgate3600_api +version: "3.6.6.0" +inputSchema: + type: object + properties: + action: + type: string + enum: + - api_catalog + - rest_call_readonly + - notice_num_day + - threats_last_day + - focus_last_day + - cpu_usage + - memory_usage + - disk_usage + - connection_monitor + - system_info + - system_resource + - interface_info + - interface_list + - security_policy_list + - test + module: + type: string + description: rest_call_readonly 使用的模块名,例如 dashboard。必须属于 api_catalog 中 kind=readonly 的条目。 + function: + type: string + description: rest_call_readonly 使用的方法名,例如 get_system_resource。必须属于 api_catalog 中 kind=readonly 的条目。 + page_index: + type: integer + page_size: + type: integer + body: + type: object + description: REST 请求 body。 + required: + - action +handler: + type: script + script_file: secgate3600.handler.py + function: api_readonly diff --git a/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_dashboard.yaml b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_dashboard.yaml new file mode 100644 index 000000000..da4376e05 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_dashboard.yaml @@ -0,0 +1,32 @@ +name: secgate3600_dashboard +description: SecGate 3600 dashboard statistics and monitor queries. +description_cn: SecGate 3600 首页统计和监控查询工具。 +category: custom +enabled: true +requires_confirmation: false +provider: secgate3600_api +version: "3.6.6.0" +inputSchema: + type: object + properties: + action: + type: string + enum: + - notice_num_day + - threats_last_day + - focus_last_day + - connection_monitor + - interface_info + - test + page_index: + type: integer + page_size: + type: integer + body: + type: object + required: + - action +handler: + type: script + script_file: secgate3600.handler.py + function: dashboard diff --git a/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_network.yaml b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_network.yaml new file mode 100644 index 000000000..364f8c5e1 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_network.yaml @@ -0,0 +1,29 @@ +name: secgate3600_network +description: SecGate 3600 network interface read-only queries. +description_cn: SecGate 3600 网络接口只读查询工具。 +category: custom +enabled: true +requires_confirmation: false +provider: secgate3600_api +version: "3.6.6.0" +inputSchema: + type: object + properties: + action: + type: string + enum: + - interface_list + - test + page_index: + type: integer + page_size: + type: integer + body: + type: object + description: 可选覆盖请求 body,用于传入接口过滤条件。 + required: + - action +handler: + type: script + script_file: secgate3600.handler.py + function: network diff --git a/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_policy.yaml b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_policy.yaml new file mode 100644 index 000000000..f0e7856c2 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_policy.yaml @@ -0,0 +1,29 @@ +name: secgate3600_policy +description: SecGate 3600 security policy read-only queries. +description_cn: SecGate 3600 安全策略只读查询工具。 +category: custom +enabled: true +requires_confirmation: false +provider: secgate3600_api +version: "3.6.6.0" +inputSchema: + type: object + properties: + action: + type: string + enum: + - security_policy_list + - test + page_index: + type: integer + page_size: + type: integer + body: + type: object + description: 可选覆盖请求 body,用于传入策略名称或 is_detail。 + required: + - action +handler: + type: script + script_file: secgate3600.handler.py + function: policy diff --git a/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_system.yaml b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_system.yaml new file mode 100644 index 000000000..82c6cfccf --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/secgate3600_v3_6_6_0/secgate3600_system.yaml @@ -0,0 +1,33 @@ +name: secgate3600_system +description: SecGate 3600 system status and login validation. +description_cn: SecGate 3600 系统状态和登录校验工具。 +category: custom +enabled: true +requires_confirmation: false +provider: secgate3600_api +version: "3.6.6.0" +inputSchema: + type: object + properties: + action: + type: string + enum: + - check_login + - system_info + - system_resource + - cpu_usage + - memory_usage + - disk_usage + - test + page_index: + type: integer + page_size: + type: integer + body: + type: object + required: + - action +handler: + type: script + script_file: secgate3600.handler.py + function: system diff --git a/.flocks/plugins/skills/device-integration-guide/SKILL.md b/.flocks/plugins/skills/device-integration-guide/SKILL.md new file mode 100644 index 000000000..010594ca3 --- /dev/null +++ b/.flocks/plugins/skills/device-integration-guide/SKILL.md @@ -0,0 +1,154 @@ +--- +name: device-integration-guide +description: 指导 Flocks 新建、添加和接入安全设备。Use when the user asks to create, add, onboard, or connect a new security device. +--- + +# Device Integration Guide + +用于处理 Flocks 设备接入相关对话。目标是把用户带到正确路径:设备创建、配置写入和敏感凭证走设备接入页面。 + +## 适用场景 + +当用户提到以下意图时使用本 skill: + +- 新建设备实例。 +- 添加安全设备到 Flocks。 +- 接入一个还没有出现在设备列表里的安全设备。 +- 用户想把一个没有现成模板的安全设备做成 Flocks 可用设备。 + +## 核心原则 + +- 先确认用户是在**新建设备实例**、**整理配置草稿**,还是**测试连通性**。 +- 不要要求用户在聊天里粘贴密码、Token、Cookie、API Key 等敏感凭证。 +- 不要在 skill 中优先引导通过工具写入设备配置;设备接入页面表单和 JSON 草稿是配置写入的主路径。 +- 每次修改后,优先用标准连通性测试验证结果。 +- 保持回答简短,给出当前动作、结果和下一步。 + +## 决策流程 + +1. 用户已经在设备列表里有目标设备,并提供了 `device_id`: + +- 如果是配置变更,回到设备接入页面表单或输出页面可回填 JSON 草稿。 +- 如果是测试或排障,先走 `device_manage(action="connectivity_test")`。 + +2. 用户说“添加设备”“接入设备”“新建设备”,但还没有设备实例: + +- 如果有已安装模板,引导用户在设备接入页面填写表单。 +- 如果没有合适模板,进入自定义设备接入路径。 +- 如果涉及密钥、密码、Token、Cookie 或浏览器登录态,只说明应该填到页面表单,不要在聊天中收集真实值。 + +3. 用户只描述产品、厂商、控制台地址或 API 文档: + +- 先判断已有模板是否可用。 +- 未安装模板需要先去 FlockHub 安装。 +- 没有合适模板时,按“自定义接入路由”选择 API、浏览器或 Workflow。 + +## 设备列表与目标确认 + +如果用户没有给出 `device_id`,先调用: + +```python +device_manage(action="list") +``` + +从返回结果里确认目标设备、机房、工具集和 `device_id`。 + +如果设备不存在,提醒用户前往「设备接入」页面添加设备;不要伪造 `device_id` 或直接调用业务工具。 + +## 新建设备与页面回填 + +用户在设备接入页面创建或配置设备时,目标是帮助页面得到清晰的表单信息。 + +需要收集的信息: + +- 设备名称。 +- 已安装模板的 `storage_key`。 +- Base URL、Host、端口、协议、租户或区域等非敏感字段。 +- SSL 证书验证偏好:`verify_ssl=true/false`。 +- 需要填写哪些敏感字段,但不收集真实值。 + +当信息足够,并且当前任务是在设备接入页面生成配置草稿时,在回复末尾输出 JSON 代码块供页面一键回填: + +```json +{"storage_key":"","device_name":"<设备名称>","fields":{"base_url":"https://example.local"},"verify_ssl":false} +``` + +JSON 草稿规则: + +- 只包含非敏感字段。 +- 不写真实 API Key、Secret、Token、Cookie、密码。 +- 敏感字段留空或省略,并提示用户稍后在页面表单中填写。 +- 如果没有合适模板,不要输出设备配置 JSON,先进入自定义接入路径。 + +## 自定义接入路由 + +没有合适已安装模板时,按用户描述选择路径: + +- 设备提供 API 文档或开放接口:选择「API 接入」。需要创建 device 插件时,先使用 `tool-builder`,目标是设备插件,不是普通 API 服务。 +- 设备主要通过 Web 控制台操作,没有开放 API:选择「浏览器接入」。需要捕获页面能力时,先使用 `web2cli`,生成可维护的设备能力。 +- 数据通过 Syslog、Kafka 或 Webhook 上报:选择「Workflow 接入」。不要创建 device 插件,引导用户走工作流发布/接入配置。 + +如果用户已经明确选择 API、浏览器或 Workflow,不要重复询问接入方式。只有无法判断时,才用一句话澄清。 + +## 配置草稿与表单更新 + +如果用户要写入或更新设备配置,应根据当前页面上下文整理表单草稿,让设备接入页面负责落库。 + +可以整理的非敏感字段包括: + +- `base_url` +- `host` +- `port` +- `scheme` +- `timeout` +- `tenant` +- `region` + +不要在 JSON 草稿或聊天中写入敏感字段: + +- `api_key` +- `secret` +- `password` +- `token` +- `cookie` +- `auth_state` + +如果用户的目标是补填密钥、修改密码、刷新 Token 或重新登录,只说明应该在设备接入页面对应字段中处理。 + +当需要给页面回填时,使用“新建设备与页面回填”中的 JSON 草稿格式。对于已有设备编辑,也只输出非敏感字段和 `verify_ssl`,并说明敏感字段在页面表单内填写。 + +## 连通性与冒烟验证 + +配置在设备接入页面保存后,除非用户明确不需要,继续调用: + +```python +device_manage(action="connectivity_test", device_id="") +``` + +连通性测试成功后,再选择少量只读、低风险的设备工具做基础冒烟验证。必须继续使用同一个 `device_id`。不要为了验证而执行写操作或高风险操作。 + +完成后汇报: + +- 目标设备和 `device_id`。 +- 页面中已整理或保存的字段名,不回显敏感值。 +- 标准连通性测试结果。 +- 只读冒烟验证结果。 + +## 失败排查顺序 + +连通性或冒烟失败时,按最小排查顺序给建议: + +1. 地址或端口是否正确,Base URL 是否包含协议。 +2. 设备侧网络、代理、防火墙或白名单是否允许 Flocks 访问。 +3. `verify_ssl` 是否与设备证书状态匹配。 +4. 页面里的凭证字段是否已填写且权限足够。 +5. 设备版本、模板版本或工具集是否匹配。 +6. 如果是浏览器接入,登录态是否过期,是否需要用户重新完成验证码、MFA 或人工确认。 + +## 不要做 + +- 不要在聊天中索要、保存或复述真实密钥。 +- 不要把自定义设备误做成普通 API 服务。 +- 不要对未安装模板输出可回填 JSON。 +- 不要跳过 `device_manage(action="connectivity_test")` 就声称设备已可用。 +- 不要把卡片状态建立在普通业务工具结果上;卡片状态以标准连通性测试写入结果为准。 diff --git a/.flocks/plugins/skills/web2cli/references/cli-in-device.md b/.flocks/plugins/skills/web2cli/references/cli-in-device.md deleted file mode 100644 index 56a12af07..000000000 --- a/.flocks/plugins/skills/web2cli/references/cli-in-device.md +++ /dev/null @@ -1,280 +0,0 @@ -# 生成后的 WebCLI 如何接入 Device 插件 - -> 本文说明:`web2cli` 已经抓到页面请求、并整理出可复用调用逻辑后,怎样把它沉淀成可在设备页识别、配置和调用的 device 插件。 - -## 结论 - -`cli-in-device.md` 不是 `cli-in-skill.md` 的替代物,而是安全设备场景下的进一步封装: - -- 所有 `web2cli` 结果都必须先完成 skill 集成 -- 如果目标是安全设备接入,再继续按本文档额外生成 device 插件 -- 最终交付关系是:`skill` 必选,`device 插件` 为安全设备场景下的额外交付 - -## 何时使用 - -在以下场景调用本文档: - -- 当前任务明确来自“设备接入”页面,目标是把某个安全设备或安全产品接入到设备管理体系 -- 最终产物需要出现在设备页,并允许用户填写实例配置、刷新模板、按 `device_id` 调用 -- 当前 WebCLI 抓到的能力属于安全设备能力,而不是单纯给 skill 复用的站点操作脚本 - -不优先使用本文档的场景: - -- 只是想保留一个可复用 CLI 供 agent 在 skill 中调用 -- 目标不是设备接入,而是某个通用网站的操作自动化、查询脚本或内部工具 -- 暂时只需要沉淀浏览器经验、CLI 参数和认证恢复流程,不需要设备页识别 - -如果当前任务来自“设备接入”页面,并且目标是安全设备接入,WebCLI 在完成 skill 集成后,还应当额外生成标准 device 插件: - -```text -$HOME/.flocks/plugins/tools/device// -├── _provider.yaml -├── .yaml -├── .handler.py -├── _cli.py # 可选,仅用于调试/回归 -└── _test.yaml # 可选,最小验证样例 -``` - -其中: - -- `_provider.yaml`:决定设备页是否能识别该模板,以及用户创建实例时需要填写哪些字段 -- `.yaml`:定义可调用工具、参数和 action -- `.handler.py`:设备运行时入口,负责读取配置、认证、发请求、清洗结果 -- `_cli.py`:只作为调试入口保留,不作为设备运行时主路径 - -认证默认规则: - -- 自定义 CLI / WebCLI 默认认证方式为 `cookie/auth-state`:优先复用浏览器保存的 `auth-state.json`,从中按请求域名/path/secure 规则选择 Cookie,并在需要时读取 localStorage -- 默认认证状态文件:`~/.flocks/browser//auth-state.json` -- 优先使用 `auth_state_path` 指向 `~/.flocks/browser//auth-state.json` -- 可以额外暴露可选 `username` / `password`,但它们只用于 cookie 失效后的认证恢复,不替代默认的 `auth_state_path` -- 不要生成或使用 `auth_state_json` / `Legacy Auth State JSON` 这类内联 JSON 字段;设备配置只保存 state 文件路径,不粘贴 state 文件内容 -- 只有在目标站点确实还依赖额外字段时,才补充 `cookie`、`csrf_token`、`access_token` 或特定认证头;这些字段是 `auth_state_path` 之外的补充,不替代默认的 cookie/auth-state -- 不要把 `cookie` 或 `token` 设计成和 `auth-state` 并列的多个默认入口;如果用户提供的是 state 文件路径,必须写入 `auth_state_path` - -## 命名约定 - -- 插件目录:`$HOME/.flocks/plugins/tools/device//` -- `plugin_id`:推荐使用稳定产品名加版本,例如 `_v1_0_0` -- `service_id`:推荐使用稳定能力标识,例如 `_device` -- handler 文件:`.handler.py` -- 可选 CLI 文件:`_cli.py` - -约定说明: - -- `` 用产品或系统的稳定标识,不用一次性任务名 -- 目录名可以带版本;`service_id` 要尽量稳定,避免和临时抓包任务绑定 -- Python 文件名统一用 `_` - -## 最小 `_provider.yaml` - -至少包含以下字段: - -```yaml -name: Acme Portal -vendor: acme_security -service_id: acme_portal_device -version: "1.0.0" -integration_type: device -description: > - Acme Portal WebCLI-backed device integration for alert listing and asset - detail queries. Configure Base URL and the required login state fields - separately in the credentials form. -description_cn: > - Acme Portal 的 WebCLI 设备接入模板,支持告警列表和资产详情查询。 - 请在设备配置中分别填写 Base URL 与所需登录态字段。 -credential_fields: - - key: base_url - label: Base URL - storage: config - config_key: base_url - input_type: url - required: true - - key: auth_state_path - label: Auth State Path - storage: config - config_key: auth_state_path - input_type: text - default: "~/.flocks/browser/acme-portal/auth-state.json" - - key: username - label: Username - storage: config - config_key: username - input_type: text - required: false - description: 仅在 cookie 失效后需要 Agent 辅助登录刷新 state 时填写 - - key: password - label: Password - storage: secret - config_key: password - secret_id: acme_portal_password - input_type: password - required: false - description: 仅在 cookie 失效后需要 Agent 辅助登录刷新 state 时填写 - - key: cookie - label: Cookie - storage: secret - config_key: cookie - secret_id: acme_portal_cookie - input_type: password - - key: csrf_token - label: CSRF Token - storage: secret - config_key: csrf_token - secret_id: acme_portal_csrf_token - input_type: password -defaults: - timeout: 30 - category: custom -notes: | - WebCLI 设备建议优先复用稳定隐藏接口,不建议把浏览器自动化作为默认运行时。 - 若返回 401/403、跳转登录页或 CSRF 失效,应先按认证失效处理。 -``` - -注意: - -- 必须包含 `integration_type: device` -- `description` 用英文,`description_cn` 用中文 -- 只把运行时真正需要用户填写的字段放进 `credential_fields` -- 不要把真实 cookie、token、密码、auth state JSON 写进插件文件 -- 默认先放 `auth_state_path`,并指向 `~/.flocks/browser//auth-state.json`;不要添加 `auth_state_json` / `Legacy Auth State JSON` -- 可以补充可选 `username` / `password`,但必须标注它们仅用于认证恢复或浏览器辅助登录,不得作为默认运行时认证入口 -- `cookie`、`csrf_token`、`access_token` 或特定认证头只有在实际站点需要时再补,并在 handler 中明确说明来源与刷新方式 - -## 最小工具 YAML - -MVP 阶段推荐一个分组工具 + 多个 action: - -```yaml -name: acme_portal_ops -description: > - Acme Portal grouped device tool. Use the action parameter to query alerts, - assets, and other WebCLI-backed operations. -description_cn: > - Acme Portal 分组设备工具。通过 action 参数调用告警、资产和其他 WebCLI 能力。 -category: custom -enabled: true -requires_confirmation: false -provider: acme_portal_device -inputSchema: - type: object - properties: - action: - type: string - enum: [list_alerts, get_asset_detail] - description: 统一业务动作名,不要暴露内部实现来源。 - alert_id: - type: string - description: 查询资产详情时可选使用的关联标识。 - required: [action] -handler: - type: script - script_file: acme_portal.handler.py - function: handle -``` - -规则: - -- `provider` 必须与 `_provider.yaml.service_id` 一致 -- 高风险写操作必须设置 `requires_confirmation: true` -- 对外 action 用统一业务语义,不要命名成 `webcli_get_alerts`、`api_get_alerts` - -## 最小 handler 结构 - -MVP 阶段优先单文件 handler,不强制拆 client 模块: - -```python -from __future__ import annotations - -from typing import Any - -from flocks.config.config_writer import ConfigWriter -from flocks.tool.registry import ToolContext, ToolResult - -SERVICE_ID = "acme_portal_device" - - -def _service_config() -> dict[str, Any]: - raw = ConfigWriter.get_api_service_raw(SERVICE_ID) - return raw if isinstance(raw, dict) else {} - - -async def handle(ctx: ToolContext, action: str, **params: Any) -> ToolResult: - cfg = _service_config() - if action == "list_alerts": - return ToolResult(success=True, output={"items": [], "source": "webcli_api"}) - if action == "get_asset_detail": - return ToolResult(success=True, output={"item": None, "source": "webcli_api"}) - return ToolResult(success=False, error=f"Unsupported action: {action}") -``` - -要求: - -- 通过 `ConfigWriter.get_api_service_raw(SERVICE_ID)` 读取当前设备实例配置 -- handler 内部负责认证头构造、分页、超时、重试和响应归一化 -- handler 默认只读取 `auth_state_path` 指向的 `auth-state.json`;如果文件缺失、不是合法 JSON,或没有匹配当前 Base URL 的 Cookie,应返回明确错误并提示重新登录/保存 state -- handler 不要 fallback 到内联 `auth_state_json`;这会把路径字符串、占位文本或过期内容误当 JSON 解析,导致设备测试报错不清晰 -- 如果模板提供了 `username` / `password`,handler 也不要在普通 tool 调用里静默自动登录;这些字段只用于后续由 Rex 进入浏览器认证恢复流程时辅助填表 -- CLI 可选保留,但不要让设备运行时通过 subprocess 调 CLI - -## 组合 API / WebCLI / 处理逻辑 - -同一设备可以混合多种能力来源,但对外仍然是统一 action: - -- `api`:正式 API,可直接调用 -- `webcli_api`:WebCLI 抓到的隐藏接口 -- `process`:本地字段归一化、过滤、聚合、补全 -- `composed`:先调一种来源,再补另一种来源,最后统一输出 - -推荐选择顺序: - -1. 正式 API 稳定可用时,优先正式 API -2. 正式 API 缺能力但 WebCLI 接口稳定时,用 `webcli_api` -3. 需要字段清洗、补全、排序、聚合时,在 handler 内增加 `process` -4. 需要多个来源补齐同一业务结果时,用 `composed` -5. 必须验证码、强动态页面或人工交互时,只记录为 browser fallback,不放进默认设备运行时 -6. 如果某个隐藏接口依赖 `Authorization`、`Tdp-Authentication`、CSRF 等临时头,只有在 handler 已实现可靠的恢复/刷新逻辑时才暴露为默认 action;否则保留在 CLI 或文档中,不放进设备默认动作 - -示例 action 映射: - -```yaml -list_alerts: webcli_api -get_asset_detail: composed -list_users: api -normalize_alert: process -``` - -这里的映射可以写进 handler 常量、注释、`notes` 或单独的设计文档,但不要把“来源类型”直接暴露给最终用户。 - -## 认证失败处理 - -出现以下情况时,优先按认证失效处理: - -- 返回 `401` 或 `403` -- 返回内容出现 `Unauthorized`、`login`、未登录、无权限 -- Cookie / CSRF / access token 明显过期 -- `auth_state_path` 已存在,但接口仍跳转登录页 - -处理原则: - -1. 不要无限重试 -2. 优先返回明确话术,提示 Rex 使用 `flocks browser` 和对应 skill 的认证失败处理去恢复登录态 -3. 如果设备已配置可选 `username` / `password`,Rex 可以在浏览器恢复流程中读取它们辅助登录;如遇验证码、MFA、短信码或人工确认,立即停下并让用户接管 -4. 登录成功后执行 `flocks browser state save ` 更新 cookie/state 文件 -5. 如仍失败,再提示用户重新登录或更新设备配置中的认证字段 -6. 如果保留了 CLI,可用 CLI 做一次最小验证 -7. 验证通过后,再让用户回到设备页点击“刷新设备模板” - -## `_test.yaml` 建议 - -如果该 WebCLI 设备已经有最小可验证动作,建议补一个 `_test.yaml`,至少覆盖: - -- 一个低风险读操作 -- 最小必填参数 -- 成功时的关键字段断言 - -这样后续更新 handler 或认证逻辑时更容易回归验证。 - -## 一句话原则 - -`web2cli` 生成的 CLI 是中间产物;只有在“安全设备接入”场景下,才把它整理成标准 device 插件,让设备页能识别、配置并调用。 diff --git a/.flocks/plugins/skills/web2cli/references/device-tool-requirements.md b/.flocks/plugins/skills/web2cli/references/device-tool-requirements.md index fd1f2658a..716e396e0 100644 --- a/.flocks/plugins/skills/web2cli/references/device-tool-requirements.md +++ b/.flocks/plugins/skills/web2cli/references/device-tool-requirements.md @@ -48,7 +48,8 @@ $HOME/.flocks/plugins/tools/device// - 自定义 CLI / WebCLI 默认认证方式为 `cookie/auth-state`:优先复用浏览器保存的 `auth-state.json`,从中按请求域名/path/secure 规则选择 Cookie,并在需要时读取 localStorage - 默认认证状态文件:`~/.flocks/browser//auth-state.json` - 优先使用 `auth_state_path` 指向 `~/.flocks/browser//auth-state.json` -- 可以额外暴露可选 `username` / `password`,但它们只用于 cookie 失效后的认证恢复,不替代默认的 `auth_state_path` +- 可以额外暴露可选 `username` / `password`,但它们只用于 cookie 失效后的认证恢复,不替代默认的 `auth_state_path`;两者都必须声明为 `storage: secret` +- 如果模板需要保存内联登录态,只能使用 `auth_state`,并且必须声明 `storage: secret` 与 `internal: true`;不要在配置表单里展示 Cookie、localStorage、token 明文 - 不要生成或使用 `auth_state_json` / `Legacy Auth State JSON` 这类内联 JSON 字段;设备配置只保存 state 文件路径,不粘贴 state 文件内容 - 只有在目标站点确实还依赖额外字段时,才补充 `cookie`、`csrf_token`、`access_token` 或特定认证头;这些字段是 `auth_state_path` 之外的补充,不替代默认的 cookie/auth-state - 不要把 `cookie` 或 `token` 设计成和 `auth-state` 并列的多个默认入口;如果用户提供的是 state 文件路径,必须写入 `auth_state_path` @@ -99,11 +100,11 @@ credential_fields: default: "~/.flocks/browser/acme-portal/auth-state.json" - key: username label: Username - storage: config + storage: secret config_key: username - input_type: text + input_type: password required: false - description: 仅在 cookie 失效后需要 Agent 辅助登录刷新 state 时填写 + description: 仅在 cookie 失效后需要 Agent 辅助登录刷新 state 时填写;不会明文写入数据库 - key: password label: Password storage: secret @@ -112,6 +113,14 @@ credential_fields: input_type: password required: false description: 仅在 cookie 失效后需要 Agent 辅助登录刷新 state 时填写 + - key: auth_state + label: Auth State + storage: secret + config_key: auth_state + input_type: password + required: false + internal: true + description: 内部登录态字段;不要在表单中展示 Cookie/localStorage/token 明文 - key: cookie label: Cookie storage: secret @@ -127,6 +136,14 @@ credential_fields: defaults: timeout: 30 category: custom +browser_auth: + login_url: "/login" + username_selector: "input[name=username]" + password_selector: "input[name=password]" + submit_selector: "button[type=submit]" + success_check: + url_not_contains: "/login" + selector: ".main-layout" notes: | WebCLI 设备建议优先复用稳定隐藏接口,不建议把浏览器自动化作为默认运行时。 若返回 401/403、跳转登录页或 CSRF 失效,应先按认证失效处理。 @@ -259,9 +276,9 @@ normalize_alert: process 处理原则: 1. 不要无限重试 -2. 优先返回明确话术,提示 Rex 使用 `flocks browser` 和对应 skill 的认证失败处理去恢复登录态 -3. 如果设备已配置可选 `username` / `password`,Rex 可以在浏览器恢复流程中读取它们辅助登录;如遇验证码、MFA、短信码或人工确认,立即停下并让用户接管 -4. 登录成功后执行 `flocks browser state save ` 更新 cookie/state 文件 +2. 优先调用 `flocks.browser.device_auth.ensure_browser_auth_state(...)` 尝试恢复登录态 +3. 如果设备已配置可选 `username` / `password`,该 helper 可以用它们辅助登录;如遇验证码、MFA、短信码或人工确认,立即停下并让用户接管 +4. 登录成功后执行 `flocks browser state save ` 或由 helper 调用 `save_state(...)` 更新 cookie/state 文件 5. 如仍失败,再提示用户重新登录或更新设备配置中的认证字段 6. 如果保留了 CLI,可用 CLI 做一次最小验证 7. 验证通过后,再让用户回到设备页点击“刷新设备模板” diff --git a/.flocks/plugins/skills/user-defined-page-builder/SKILL.md b/.flocks/plugins/skills/webui-page-builder/SKILL.md similarity index 68% rename from .flocks/plugins/skills/user-defined-page-builder/SKILL.md rename to .flocks/plugins/skills/webui-page-builder/SKILL.md index d780e404f..c2342dbf0 100644 --- a/.flocks/plugins/skills/user-defined-page-builder/SKILL.md +++ b/.flocks/plugins/skills/webui-page-builder/SKILL.md @@ -1,40 +1,40 @@ --- -name: user-defined-page-builder +name: webui-page-builder category: system -description: Guide users to create, develop, hide, or delete user-defined custom pages that appear in the WebUI left navigation under Home, with live preview and no restart required. Also guide development of page-scoped backend APIs through the User Defined Page Backend API Runtime when built-in APIs are insufficient. Trigger when the user asks to create, remove, or delete a custom page, user-defined page, dashboard, navigation tab, integrate custom APIs for a page, or sends messages such as "create a custom page", "delete custom page", "remove user-defined page", "创建自定义页面", "删除自定义页面", "用户自定义页面", "自定义页面", "左侧导航页面", "首页下面的页面", "页面数据来源", "自定义 API", or wants help understanding how custom pages work in Flocks. +description: Guide users to create, develop, hide, or delete WebUI page plugins that appear in the WebUI left navigation under Home, with live preview and no restart required. Also guide development of page-scoped backend APIs through the WebUI Page Backend API Runtime when built-in APIs are insufficient. Trigger when the user asks to create, remove, or delete a WebUI contract page, WebUI page, dashboard, navigation tab, integrate custom APIs for a page, or sends messages such as "create a WebUI contract page", "delete WebUI contract page", "remove WebUI page", "创建WebUI 契约页面", "删除WebUI 契约页面", "用户WebUI 契约页面", "WebUI 契约页面", "左侧导航页面", "首页下面的页面", "页面数据来源", "自定义 API", or wants help understanding how WebUI contract pages work in Flocks. --- -# User Defined Page Builder +# WebUI Page Builder -When the user wants to create **user-defined custom pages** (shown in the WebUI left navigation under **Home**), first explain the feature clearly, then guide them through creation and development. +When the user wants to create **WebUI page plugins** (shown in the WebUI left navigation under **Home**), first explain the feature clearly, then guide them through creation and development. ## Core Principles - **Language**: Detect the user's language from their messages or UI locale. Conduct the **entire conversation in the user's language** (Chinese or English). Do not switch languages mid-session. -- **Admin-required notice**: Creating, editing, hiding, deleting, importing, or exporting user-defined pages requires administrator privileges. Before starting any write workflow, remind the user that the operation must be performed by an admin. This skill does **not** verify the user's role; WebUI visibility and backend APIs enforce authorization. +- **Admin-required notice**: Creating, editing, hiding, deleting, importing, or exporting WebUI pages requires administrator privileges. Before starting any write workflow, remind the user that the operation must be performed by an admin. This skill does **not** verify the user's role; WebUI visibility and backend APIs enforce authorization. - **Explain before acting**: If the user only asks what the feature is, explain fully before creating anything. - **Confirm once**: Before creating, confirm `pageId` (lowercase English + hyphens), `title` (navigation label in the user's language), and optional `icon` (Lucide icon name). -- **User space only**: Read and write only under `~/.flocks/plugins/user_defined_pages/`. -- **Final location check**: After finishing any page development, verify that all user-defined page files are stored under `~/.flocks/plugins/user_defined_pages//`. They must **not** remain in the project code directories such as `webui/`, `flocks/`, `tests/`, or `docs/`. -- **SDK only**: Page code may import only `react` and `@flocks/user-defined-page-sdk` (`Card`, `api`, `useCurrentUser`). +- **WebUI plugin space only**: Read and write only under `~/.flocks/plugins/contracts/webui/`. +- **Final location check**: After finishing any page development, verify that all WebUI page files are stored under `~/.flocks/plugins/contracts/webui//`. They must **not** remain in the project code directories such as `webui/`, `flocks/`, `tests/`, or `docs/`. +- **SDK only**: Page code may import only `react` and `@flocks/webui-contract-sdk` (`Card`, `api`, `useCurrentUser`). - **Never write `dist/`**: Build artifacts are generated automatically. -- **Auth-aware**: All `/api/user-defined-pages/*` routes require authentication. Prefer **direct file writes** for Rex; use API Token only when calling HTTP from non-browser clients. Never embed tokens in page source. -- **Page-scoped backend**: When built-in `/api/*` is insufficient, use the User Defined Page Backend API Runtime design: page APIs live under the page directory and are exposed only at `/api/user-defined-pages//api/*`. +- **Auth-aware**: All `/api/contracts/webui/pages/*` routes require authentication. Prefer **direct file writes** for Rex; use API Token only when calling HTTP from non-browser clients. Never embed tokens in page source. +- **Page-scoped backend**: When built-in `/api/*` is insufficient, use the WebUI Page Backend API Runtime design: page APIs live under the page directory and are exposed only at `/api/contracts/webui/pages//api/*`. ## Authentication -Flocks protects **all HTTP API paths by default** (including `/api/user-defined-pages/*`). Only bootstrap, static assets, and a few public endpoints are exempt. Understand who needs what credential: +Flocks protects **all HTTP API paths by default** (including `/api/contracts/webui/pages/*`). Only bootstrap, static assets, and a few public endpoints are exempt. Understand who needs what credential: ### WebUI (browser) - User must be **logged in** (session cookie `flocks_session`). - The WebUI axios client sends cookies automatically (`withCredentials: true`). - Navigation, page host, bundle loading, and in-page `api` calls all reuse this session — **no extra token setup** for end users. -- If the user is not logged in or the session expired, user-defined pages and related APIs return **401**. +- If the user is not logged in or the session expired, WebUI pages and related APIs return **401**. ### Rex / Agent (recommended: file writes, no HTTP auth) -When creating or editing pages, first remind the user that the operation requires admin privileges, then **write files directly** under `~/.flocks/plugins/user_defined_pages//`: +When creating or editing pages, first remind the user that the operation requires admin privileges, then **write files directly** under `~/.flocks/plugins/contracts/webui//`: - No HTTP request → no API Token needed. - The file watcher detects changes, rebuilds, and publishes SSE events automatically. @@ -79,7 +79,7 @@ All `curl` examples in this skill use `Authorization: Bearer ` — substi API Token authenticates as a synthetic **admin service identity** (`api-token-service`). It is for automation, not for end-user page rendering. -### Custom page code (`@flocks/user-defined-page-sdk` `api`) +### WebUI contract page code (`@flocks/webui-contract-sdk` `api`) - The SDK `api` helper is the WebUI axios client — it sends the **logged-in user's session cookie**, not an API Token. - Page code may call other `/api/*` endpoints (alerts, sessions, etc.) while the user is logged in. @@ -89,25 +89,25 @@ API Token authenticates as a synthetic **admin service identity** (`api-token-se **Chinese example**: -> 自定义页面相关接口都需要登录鉴权。普通用户可以查看和使用已发布页面,但创建、修改、隐藏、删除、导入或导出页面需要管理员权限。我(Rex)在开始这类写操作前会提醒需要管理员操作,通常直接读写 `~/.flocks/plugins/user_defined_pages/` 目录,不经过 HTTP。若用脚本调管理 API,需在服务端配置 `server_api_token` 并在请求头携带 Bearer Token。 +> WebUI 契约页面相关接口都需要登录鉴权。普通用户可以查看和使用已发布页面,但创建、修改、隐藏、删除、导入或导出页面需要管理员权限。我(Rex)在开始这类写操作前会提醒需要管理员操作,通常直接读写 `~/.flocks/plugins/contracts/webui/` 目录,不经过 HTTP。若用脚本调管理 API,需在服务端配置 `server_api_token` 并在请求头携带 Bearer Token。 **English example**: -> User-defined page APIs require authentication. Regular users can view and use published pages, but creating, editing, hiding, deleting, importing, or exporting pages requires admin privileges. I (Rex) remind the user before starting these write operations and usually read/write `~/.flocks/plugins/user_defined_pages/` directly without HTTP. Non-browser management API clients must send a Bearer API Token from `server_api_token` in `~/.flocks/config/.secret.json`. +> WebUI page APIs require authentication. Regular users can view and use published pages, but creating, editing, hiding, deleting, importing, or exporting pages requires admin privileges. I (Rex) remind the user before starting these write operations and usually read/write `~/.flocks/plugins/contracts/webui/` directly without HTTP. Non-browser management API clients must send a Bearer API Token from `server_api_token` in `~/.flocks/config/.secret.json`. ## First Reply Must Cover Explain these points in the user's language: 1. **What it is**: Custom React pages under the Home section of the left navigation — for alert dashboards, asset views, duty screens, etc. -2. **Where files live**: `~/.flocks/plugins/user_defined_pages//` in the user space, **not** in the project code directory. -3. **How it appears**: After creation, a nav item shows under Home; route is `/user-defined-pages/`. +2. **Where files live**: `~/.flocks/plugins/contracts/webui//` in the user space, **not** in the project code directory. +3. **How it appears**: After creation, a nav item shows under Home; route is `/contracts/webui/`. 4. **How to develop**: Describe requirements in chat; you write `src/Page.tsx`; saving triggers auto-build; **no restart** required. 5. **Live updates**: Source changes rebuild automatically; open pages and navigation refresh via SSE. 6. **How to remove**: Tell the user both options below — hiding from nav (reversible) and permanently deleting the page directory. 7. **Authentication and authorization**: WebUI uses login session automatically. Regular users can use published pages. Creating/modifying pages requires admin privileges; Rex should remind the user before write operations but does not verify roles in this skill; scripts calling management APIs need `server_api_token` (see **Authentication** above). -8. **Data sources**: Built-in `/api/*` endpoints, page-scoped backend APIs (`/api/user-defined-pages//api/*`), or workflows (`/api/workflow/{id}/run`) (see **Backend Data & API Extension** below). -9. **Backup + restart/upgrade continuity**: Back up the full page directory and explain that Flocks scans/rebuilds pages from `~/.flocks/plugins/user_defined_pages/` after restart or upgrade. +8. **Data sources**: Built-in `/api/*` endpoints, page-scoped backend APIs (`/api/contracts/webui/pages//api/*`), or workflows (`/api/workflow/{id}/run`) (see **Backend Data & API Extension** below). +9. **Backup + restart/upgrade continuity**: Back up the full page directory and explain that Flocks scans/rebuilds pages from `~/.flocks/plugins/contracts/webui/` after restart or upgrade. Then ask whether the user already has a page idea. If they have an idea, remind them that creation requires admin privileges, then start creation. If they do not have an idea, offer 2–3 example scenarios. @@ -120,7 +120,7 @@ Then ask whether the user already has a page idea. If they have an idea, remind ## Directory Layout ```text -~/.flocks/plugins/user_defined_pages// +~/.flocks/plugins/contracts/webui// manifest.json src/index.tsx src/Page.tsx @@ -136,7 +136,7 @@ Then ask whether the user already has a page idea. If they have an idea, remind Requires a valid `server_api_token` (see **Authentication**). Rex should prefer Option B unless API is explicitly required. ```bash -curl -s -X POST http://127.0.0.1:8000/api/user-defined-pages \ +curl -s -X POST http://127.0.0.1:8000/api/contracts/webui/pages \ -H "Content-Type: application/json" \ -H "Authorization: Bearer " \ -d '{"id":"alert-dashboard","title":"Alert Dashboard","icon":"BarChart3","order":100}' @@ -146,7 +146,7 @@ Chinese example title: `"title":"告警看板"`. ### Option B — Write files directly (preferred for Rex) -Create under `~/.flocks/plugins/user_defined_pages//`: +Create under `~/.flocks/plugins/contracts/webui//`: **manifest.json** @@ -154,7 +154,7 @@ Create under `~/.flocks/plugins/user_defined_pages//`: { "id": "alert-dashboard", "title": "Alert Dashboard", - "route": "/user-defined-pages/alert-dashboard", + "route": "/contracts/webui/alert-dashboard", "icon": "BarChart3", "order": 100, "enabled": true, @@ -179,7 +179,7 @@ Use the manifest `title` for the card heading. Keep in-page status text in the u ```tsx import { useEffect, useState } from 'react'; -import { Card } from '@flocks/user-defined-page-sdk'; +import { Card } from '@flocks/webui-contract-sdk'; export default function Page() { const [ready, setReady] = useState(false); @@ -204,22 +204,22 @@ For Chinese pages, use Chinese copy inside the component, e.g. `{ready ? '页面 2. Identify data sources — built-in `/api/*`, page-scoped backend APIs, workflows, or external systems that need server-side proxying. 3. Edit `src/Page.tsx` based on requirements (add more files under `src/` if needed). 4. On save, the system rebuilds automatically. If build fails, read `dist/meta.json` → `error` and fix. -5. Manual rebuild: `POST /api/user-defined-pages//build` -6. Before wrapping up, run a final location check: every page file created for the user must be under `~/.flocks/plugins/user_defined_pages//`; do not leave page source, API handlers, assets, or drafts in the repository code directories. +5. Manual rebuild: `POST /api/contracts/webui/pages//build` +6. Before wrapping up, run a final location check: every page file created for the user must be under `~/.flocks/plugins/contracts/webui//`; do not leave page source, API handlers, assets, or drafts in the repository code directories. ## Backup and Restore Always provide this backup command in the first explanation and in the final wrap-up: ```bash -cp -a ~/.flocks/plugins/user_defined_pages/ ~/.flocks/workspace/outputs//-backup +cp -a ~/.flocks/plugins/contracts/webui/ ~/.flocks/workspace/outputs//-backup ``` -Restore by copying the backup directory back to `~/.flocks/plugins/user_defined_pages//`. After restart (or immediately if watcher is active), the page will be scanned and available again. +Restore by copying the backup directory back to `~/.flocks/plugins/contracts/webui//`. After restart (or immediately if watcher is active), the page will be scanned and available again. ## Backend Data & API Extension -When a custom page needs backend logic or external data that built-in APIs do not provide, use a **page-scoped backend API runtime**. +When a WebUI contract page needs backend logic or external data that built-in APIs do not provide, use a **page-scoped backend API runtime**. ### Design Principle @@ -228,7 +228,7 @@ Do **not** register arbitrary global FastAPI routes such as `/api/my-dashboard/s The page backend should be scoped to the page namespace: ```text -/api/user-defined-pages//api/{path:path} +/api/contracts/webui/pages//api/{path:path} ``` This keeps page APIs tied to page lifecycle, permissions, logs, hot reload, deletion, and future UI management. @@ -236,8 +236,8 @@ This keeps page APIs tied to page lifecycle, permissions, logs, hot reload, dele ### Architecture ```text -User Defined Page (src/Page.tsx) - └─ SDK api ──► /api/user-defined-pages//api/* (page-scoped backend) +WebUI Page (src/Page.tsx) + └─ SDK api ──► /api/contracts/webui/pages//api/* (page-scoped backend) ├─► /api/workflow/{id}/run (multi-step workflows) └─► /api/* (built-in Flocks APIs) ``` @@ -247,7 +247,7 @@ User Defined Page (src/Page.tsx) When a page needs backend code, add an `api/` directory inside that page: ```text -~/.flocks/plugins/user_defined_pages// +~/.flocks/plugins/contracts/webui// manifest.json src/Page.tsx api/ @@ -277,7 +277,7 @@ routes: Rules: -- `path` must start with `/` and is always scoped under `/api/user-defined-pages//api`. +- `path` must start with `/` and is always scoped under `/api/contracts/webui/pages//api`. - `handler` points to a callable in `api/handlers.py`. - Keep route count small and page-specific. - Prefer read-only `GET` for dashboards; use `POST` for actions. @@ -318,22 +318,24 @@ Implementation expectations for the runtime: The SDK `api` helper sends the logged-in user's session cookie: ```tsx -const res = await api.get('/api/user-defined-pages/alert-dashboard/api/stats'); +const res = await api.page.get('/stats'); ``` -If the SDK later provides a page helper, prefer: +For data access contracts, use the contract helper instead of hand-writing operation URLs: ```tsx -const res = await api.page.get('/stats'); +const res = await api + .contract('soc/alerts', 'soc.alerts.operations') + .operation('list', { params: { limit: 100 } }); ``` -Until `api.page` exists, use explicit `/api/user-defined-pages//api/*` paths. +This maps to `/api/contracts/webui/pages//access//operations/`. ### Other extension paths | Need | Approach | Call from page | |------|----------|----------------| -| Page-specific backend data | `api/routes.yaml` + `api/handlers.py` | `/api/user-defined-pages//api/*` | +| Page-specific backend data | `api/routes.yaml` + `api/handlers.py` | `/api/contracts/webui/pages//api/*` | | External REST API needed by one page | Page handler proxies it server-side | same | | Local compute / file transform for one page | Page handler | same | | Multi-step orchestration | Workflow under `~/.flocks/plugins/workflows//` | `POST /api/workflow//run` | @@ -344,10 +346,10 @@ Until `api.page` exists, use explicit `/api/user-defined-pages//api/*` p | Action | Method | |--------|--------| -| List page API routes | `GET /api/user-defined-pages//api` | -| Call page API | `GET/POST/... /api/user-defined-pages//api/` | -| Reload page API | `POST /api/user-defined-pages//api/reload` | -| Read page detail/build info | `GET /api/user-defined-pages/` | +| List page API routes | `GET /api/contracts/webui/pages//api` | +| Call page API | `GET/POST/... /api/contracts/webui/pages//api/` | +| Reload page API | `POST /api/contracts/webui/pages//api/reload` | +| Read page detail/build info | `GET /api/contracts/webui/pages/` | If these endpoints are not implemented yet, treat this section as the target design and implement the backend runtime before promising the feature as available. @@ -355,30 +357,30 @@ If these endpoints are not implemented yet, treat this section as the target des - Do not register user routes globally under `/api/custom/...`. - Page API code is trusted local plugin code, not sandboxed untrusted code. -- Page code may only import `react` and `@flocks/user-defined-page-sdk` — backend logic lives in `api/handlers.py`, not in page TSX. +- Page code may only import `react` and `@flocks/webui-contract-sdk` — backend logic lives in `api/handlers.py`, not in page TSX. - Page API routes are page-scoped; deleting the page should remove its backend routes as well. ### Explain to users (when page needs custom data) **Chinese example**: -> 如果内置 API 不够用,我们按页面专属后端 API 的设计来做:在 `~/.flocks/plugins/user_defined_pages//api/` 下定义 `routes.yaml` 和 `handlers.py`,后端统一暴露到 `/api/user-defined-pages//api/*`。密钥只在服务端读取,不会写进页面代码。 +> 如果内置 API 不够用,我们按页面专属后端 API 的设计来做:在 `~/.flocks/plugins/contracts/webui//api/` 下定义 `routes.yaml` 和 `handlers.py`,后端统一暴露到 `/api/contracts/webui/pages//api/*`。密钥只在服务端读取,不会写进页面代码。 **English example**: -> When built-in APIs are insufficient, use the page-scoped backend API design: define `api/routes.yaml` and `api/handlers.py` under `~/.flocks/plugins/user_defined_pages//`, then expose them through `/api/user-defined-pages//api/*`. Secrets stay server-side. +> When built-in APIs are insufficient, use the page-scoped backend API design: define `api/routes.yaml` and `api/handlers.py` under `~/.flocks/plugins/contracts/webui//`, then expose them through `/api/contracts/webui/pages//api/*`. Secrets stay server-side. ## Useful APIs | Action | Method | |--------|--------| -| List | `GET /api/user-defined-pages?enabledOnly=true` | -| Detail | `GET /api/user-defined-pages/` | -| Save source | `PUT /api/user-defined-pages/` with `{"sourcePath":"src/Page.tsx","sourceContent":"..."}` | -| Update manifest | `PUT /api/user-defined-pages/` with `{"manifest":{"title":"New Title","order":50}}` | -| Rebuild | `POST /api/user-defined-pages//build` | -| Hide from nav | `PUT /api/user-defined-pages/` with `{"manifest":{"enabled":false}}` | -| Delete permanently | Remove `~/.flocks/plugins/user_defined_pages//` (see below) | +| List | `GET /api/contracts/webui/pages?enabledOnly=true` | +| Detail | `GET /api/contracts/webui/pages/` | +| Save source | `PUT /api/contracts/webui/pages/` with `{"sourcePath":"src/Page.tsx","sourceContent":"..."}` | +| Update manifest | `PUT /api/contracts/webui/pages/` with `{"manifest":{"title":"New Title","order":50}}` | +| Rebuild | `POST /api/contracts/webui/pages//build` | +| Hide from nav | `PUT /api/contracts/webui/pages/` with `{"manifest":{"enabled":false}}` | +| Delete permanently | Remove `~/.flocks/plugins/contracts/webui//` (see below) | ## Remove or Delete a Page @@ -389,7 +391,7 @@ Always explain both approaches when the user asks how to delete, or proactively Update manifest so the page no longer appears in the left nav, but files are kept: ```bash -curl -s -X PUT http://127.0.0.1:8000/api/user-defined-pages/ \ +curl -s -X PUT http://127.0.0.1:8000/api/contracts/webui/pages/ \ -H "Content-Type: application/json" \ -H "Authorization: Bearer " \ -d '{"manifest":{"enabled":false}}' @@ -404,7 +406,7 @@ To restore later, set `"enabled": true` again. Remove the entire page directory under user space: ```bash -rm -rf ~/.flocks/plugins/user_defined_pages/ +rm -rf ~/.flocks/plugins/contracts/webui/ ``` Only do this **after confirming** with the user — this cannot be undone (unless they have backups). After deletion, the nav item disappears automatically; no restart required. @@ -413,13 +415,13 @@ Only do this **after confirming** with the user — this cannot be undone (unles > 如果不再需要这个页面,有两种方式: > 1. **从导航隐藏**:把 `manifest.json` 里的 `enabled` 设为 `false`,页面文件仍保留,以后可恢复; -> 2. **彻底删除**:删除目录 `~/.flocks/plugins/user_defined_pages//`,导航标签会消失且无法恢复,请先确认再操作。 +> 2. **彻底删除**:删除目录 `~/.flocks/plugins/contracts/webui//`,导航标签会消失且无法恢复,请先确认再操作。 **English phrasing example**: > To remove a page you have two options: > 1. **Hide from navigation** — set `"enabled": false` in the manifest (files kept, reversible); -> 2. **Delete permanently** — remove `~/.flocks/plugins/user_defined_pages//` (irreversible; confirm with the user first). +> 2. **Delete permanently** — remove `~/.flocks/plugins/contracts/webui//` (irreversible; confirm with the user first). When the user explicitly asks to delete a page, confirm which option they want before acting. @@ -443,14 +445,14 @@ After confirming `pageId` and `title`, create the page and report nav name, rout Iterate on `src/Page.tsx`: - Match WebUI Tailwind styling - Use `api` for built-in `/api/*` data when available -- If built-in APIs are insufficient, design or implement page-scoped APIs under `api/routes.yaml` + `api/handlers.py`, then call `/api/user-defined-pages//api/*` from the page +- If built-in APIs are insufficient, design or implement page-scoped APIs under `api/routes.yaml` + `api/handlers.py`, then call `/api/contracts/webui/pages//api/*` from the page - Tell the user to wait for hot reload after each save ### Step 4 — Wrap up -Before responding, perform a final location check and explicitly confirm that the page files are under `~/.flocks/plugins/user_defined_pages//`, not in the project code directory. +Before responding, perform a final location check and explicitly confirm that the page files are under `~/.flocks/plugins/contracts/webui//`, not in the project code directory. -Summarize page ID, nav label, route, data sources used (built-in API / page API / workflow), the verified storage directory, how to keep editing via chat, how to **hide** (`"enabled": false`) or **permanently delete** (remove `~/.flocks/plugins/user_defined_pages//`), and how to add another page or extend data with page-scoped backend APIs. +Summarize page ID, nav label, route, data sources used (built-in API / page API / workflow), the verified storage directory, how to keep editing via chat, how to **hide** (`"enabled": false`) or **permanently delete** (remove `~/.flocks/plugins/contracts/webui//`), and how to add another page or extend data with page-scoped backend APIs. Also include one concrete backup command and remind the user that restart/upgrade keeps pages because files are stored in user home and startup reconciliation rebuilds missing/old bundles. ## Rex-User Collaboration Loop @@ -476,12 +478,12 @@ For each iteration, Rex should: Do not declare "done" until all are true: - Nav item visible under Home with expected title/icon/order. -- Route works: `/user-defined-pages/`. +- Route works: `/contracts/webui/`. - Frontend behavior matches user requirements. -- If custom backend is used, page API routes work under `/api/user-defined-pages//api/*`. +- If custom backend is used, page API routes work under `/api/contracts/webui/pages//api/*`. - Backup command provided. - Hide/delete options provided. -- Final location check explicitly confirmed (`~/.flocks/plugins/user_defined_pages//` only). +- Final location check explicitly confirmed (`~/.flocks/plugins/contracts/webui//` only). ### If the user is unsure what to provide @@ -498,8 +500,8 @@ Rex should ask only the minimum needed in this order: | Symptom | Action | |---------|--------| | Nav item missing | Check `manifest.enabled`; wait for build | -| Blank / error page | Check `GET /api/user-defined-pages/` → `build.error` | -| Build failed | Fix TSX syntax; only import `react` and `@flocks/user-defined-page-sdk` | +| Blank / error page | Check `GET /api/contracts/webui/pages/` → `build.error` | +| Build failed | Fix TSX syntax; only import `react` and `@flocks/webui-contract-sdk` | | Changes not visible | Confirm file saved under `src/`; try `POST .../build` | | 401 Unauthorized (WebUI) | User not logged in or session expired — re-login | | 401 Unauthorized (curl/script) | Add `Authorization: Bearer `; run `flocks admin generate-api-token` if missing | @@ -507,12 +509,12 @@ Rex should ask only the minimum needed in this order: | Page API 404 | Confirm `api/routes.yaml` path and page ID; reload page API runtime | | Page API 500 | Check handler traceback / page API diagnostics; validate handler return shape | | External API from page fails | Do not call third-party URLs directly from page code — proxy through page-scoped backend | -| Need custom `/api/foo` route | Do not add global routes — use `/api/user-defined-pages//api/foo` | +| Need custom `/api/foo` route | Do not add global routes — use `/api/contracts/webui/pages//api/foo` | ## Do Not - Write pages into `webui/` or `flocks/` code directories -- Leave generated user page source, API handlers, assets, or drafts anywhere outside `~/.flocks/plugins/user_defined_pages//` +- Leave generated user page source, API handlers, assets, or drafts anywhere outside `~/.flocks/plugins/contracts/webui//` - Modify files under `dist/` - Import non-whitelisted npm packages into page code - Skip `pageId` format validation @@ -520,4 +522,4 @@ Rex should ask only the minimum needed in this order: - Ask users to paste API tokens into chat - Register global custom FastAPI routes or write backend logic into `src/Page.tsx` - Call third-party APIs directly from page code with embedded secrets -- Write page backend files outside `~/.flocks/plugins/user_defined_pages//api/` +- Write page backend files outside `~/.flocks/plugins/contracts/webui//api/` diff --git a/.flocks/plugins/skills/workflow-builder/SKILL.md b/.flocks/plugins/skills/workflow-builder/SKILL.md index dfe3947c1..ad3b62e09 100644 --- a/.flocks/plugins/skills/workflow-builder/SKILL.md +++ b/.flocks/plugins/skills/workflow-builder/SKILL.md @@ -237,13 +237,17 @@ flowchart TD **Edge 约束:** - JSON 中用 `"from"` 而非 `"from_"`;`from`/`to` 引用存在的 node id;`order` ≥ 0。 +- **新建 workflow 的每条 edge 必须包含非空 `mapping` 对象**。不要生成无 `mapping` 的 edge;不要只写 `const` 而省略 `mapping`。新建 workflow 默认启用 strict edge mapping,无 `mapping` 会在创建或运行时失败。 ### 3.2 映射规则 - `workflow.md` 每步对应一个节点,`id` 用 snake_case。 - md 中写的输出字段,必须在 `outputs[...]` 中体现。 - md 中 `Tool: xxx` 标记 → 对应节点 `description` 保留。 -- 下游节点如需 `tool.run(..., **inputs)`,用 `edge.mapping`/`edge.const` 规整输入到匹配工具参数形状。 +- 所有 edge 都必须写 `edge.mapping`,只映射下游节点实际需要的字段,避免全量 payload 传递。 +- 下游节点如需 `tool.run(..., **inputs)`,用 `edge.mapping`/`edge.const` 规整输入到匹配工具参数形状;其中 `edge.mapping` 仍然必须非空。 +- 若某条边只是控制流、下游不需要业务字段,也必须映射一个确定存在的小字段(如 `case_id`、`has_results`、`status`);如果没有合适字段,让上游节点写出 `outputs["_edge_context"] = True`,并在该边映射 `{ "_edge_context": "_edge_context" }`。 +- `branch`/`loop` 出边同样必须写 `mapping`。映射源可以来自该分支节点收到的输入 payload(例如 `search_text`、`case_id`、`has_results`),不要求来自 branch 节点自身输出。 - 详细 Mapping 指南见 [reference.md § Edge Mapping](references/reference.md#4-edge-mapping-详细指南)。 ### 3.3 分支/循环与 Join @@ -322,8 +326,9 @@ elif isinstance(obj, str): 1. 用 `json.load` 确认 JSON 格式正确。 2. 对每个 `type="python"` 节点的 `code` 执行 `compile(code, "", "exec")` 确认 Python 语法正确。 -3. 若格式或语法报错,修复后重新写入 `workflow.json` 并再次验证。 -4. 将阶段 1 收集的样例数据保存到 `POST /api/workflow/{id}/sample-inputs`,body 为 `{ "sampleInputs": <样例 JSON 对象> }`。 +3. 检查每条 `edges[]` 都有非空 `mapping` 对象;发现缺失时必须补齐映射并重写 `workflow.json`。 +4. 若格式、语法或 edge mapping 校验报错,修复后重新写入 `workflow.json` 并再次验证。 +5. 将阶段 1 收集的样例数据保存到 `POST /api/workflow/{id}/sample-inputs`,body 为 `{ "sampleInputs": <样例 JSON 对象> }`。 只有以上步骤全部通过后,才能进入第五阶段逐节点测试。 diff --git a/.flocks/plugins/skills/workflow-builder/references/reference.md b/.flocks/plugins/skills/workflow-builder/references/reference.md index 2fa7458aa..59dfa6ae5 100644 --- a/.flocks/plugins/skills/workflow-builder/references/reference.md +++ b/.flocks/plugins/skills/workflow-builder/references/reference.md @@ -69,8 +69,24 @@ ``` 出边示例: ```json -{ "from": "check_risk", "to": "handle_high_risk", "label": "High" }, -{ "from": "check_risk", "to": "handle_normal", "label": "" } +{ + "from": "check_risk", + "to": "handle_high_risk", + "label": "High", + "mapping": { + "case_id": "case_id", + "risk_level": "risk_level" + } +}, +{ + "from": "check_risk", + "to": "handle_normal", + "label": "", + "mapping": { + "case_id": "case_id", + "risk_level": "risk_level" + } +} ``` ### Join 节点 @@ -95,15 +111,23 @@ - **点路径支持**:`mapping: { "user_id": "data.user.id" }` - **根路径引用**:`mapping: { "full_data": "$" }` -### 何时写 mapping +### 新建 workflow 的强制规则 -- 下游只需上游 payload 的一部分字段 -- 字段需要重命名(上游 key 与下游期望 key 不一致) -- 下游要 `tool.run(..., **inputs)`,需把 inputs 规整到匹配工具参数形状 +- **每条 edge 都必须写非空 `mapping` 对象**。新建 workflow 默认启用 strict edge mapping;无 `mapping` 的 edge 会被创建/运行校验拒绝。 +- `const` 只能补充常量,不能替代 `mapping`。即使只需要常量参数,也要映射一个确定存在的小字段作为显式数据契约。 +- 默认只映射下游节点实际读取的字段,不要用 `$` 传递完整 payload,除非下游确实需要完整对象且该对象已经是上游刻意裁剪过的小对象。 +- 字段需要重命名时,用 `mapping` 把下游 key 映射到上游 payload 路径。 +- 下游要 `tool.run(..., **inputs)` 时,用 `edge.mapping`/`edge.const` 把 inputs 规整到工具参数形状。 +- 控制流边如果下游不需要业务字段,也必须映射一个确定存在的小字段(如 `case_id`、`has_results`、`status`);如果没有合适字段,让上游节点写出 `outputs["_edge_context"] = True`,并映射 `{ "_edge_context": "_edge_context" }`。 +- `branch`/`loop` 出边同样必须写 `mapping`。映射源可以来自分支节点收到的输入 payload,例如 `{ "search_text": "search_text" }`。 -### 何时不写 mapping +### 禁止生成的 edge + +```json +{ "from": "step_1", "to": "step_2" } +``` -- 下游可直接消费完整 payload(引擎浅合并 `payload = {**inputs, **outputs}`),且不会造成字段冲突 +上面这种无 `mapping` 的 edge 禁止出现在新建 workflow 中。 ### 避免脆弱映射 @@ -228,9 +252,30 @@ } ], "edges": [ - { "from": "step_1", "to": "check_results" }, - { "from": "check_results", "to": "summarize", "label": "true" }, - { "from": "check_results", "to": "fallback", "label": "false" } + { + "from": "step_1", + "to": "check_results", + "mapping": { + "search_text": "search_text", + "has_results": "has_results" + } + }, + { + "from": "check_results", + "to": "summarize", + "label": "true", + "mapping": { + "search_text": "search_text" + } + }, + { + "from": "check_results", + "to": "fallback", + "label": "false", + "mapping": { + "has_results": "has_results" + } + } ] } ``` @@ -298,7 +343,7 @@ else: ### 工具参数对齐最佳实践 1. 在 `workflow.md` 的输入中直接使用工具参数名 -2. 用 `edge.mapping` 完成上游字段到工具参数名的转换 +2. 每条 edge 都写非空 `edge.mapping`,完成上游字段到下游节点或工具参数名的转换 3. python 节点中 `result = tool.run_safe("xxx", **inputs)` 或按需取参数 4. 仅快速原型时使用 `logic` 节点 5. **默认用 `result["text"]` 取结果**,仅在明确需要结构化数据且已做类型检查时才用 `result["obj"]` diff --git a/.flocks/plugins/skills/workflow-builder/references/workflow_en.md b/.flocks/plugins/skills/workflow-builder/references/workflow_en.md index 18b9d666d..a9f1c1e15 100644 --- a/.flocks/plugins/skills/workflow-builder/references/workflow_en.md +++ b/.flocks/plugins/skills/workflow-builder/references/workflow_en.md @@ -183,6 +183,7 @@ Acceptance checklist: - [ ] Inputs are correctly recognized and parsed. - [ ] Each node has a clear responsibility and outputs fields downstream nodes can read. +- [ ] Every edge in `workflow.json` has a non-empty `mapping` and maps only fields the downstream node needs. - [ ] Branch, filtering, aggregation, or analysis logic matches expectations. - [ ] Output fields and file formats are clear. - [ ] `workflow.md` and `workflow.json` describe the same flow. diff --git a/.flocks/plugins/skills/workflow-builder/references/workflow_template/workflow.md b/.flocks/plugins/skills/workflow-builder/references/workflow_template/workflow.md index 5cb9fb87d..a024dd2a1 100644 --- a/.flocks/plugins/skills/workflow-builder/references/workflow_template/workflow.md +++ b/.flocks/plugins/skills/workflow-builder/references/workflow_template/workflow.md @@ -67,7 +67,8 @@ List thresholds, switches, timeouts, file paths, concurrency settings, and rollb Generation notes for Flocks: - Keep node IDs stable after users start configuring publish modes. -- When adding or renaming outputs, update downstream edges and the runtime contract. +- When adding or renaming outputs, update downstream edges, explicit edge mappings, and the runtime contract. +- Every generated edge in `workflow.json` must contain a non-empty `mapping` object. - Do not store plaintext secrets in this directory. ## 6. Data Flow And Field Contract @@ -106,12 +107,13 @@ Workflow configuration guidance lives in `guide.md`. - `workflow.md` describes intent, module boundaries, field contracts, and validation. - `workflow.json` describes executable nodes, edges, code, triggers, and metadata. - Regeneration should preserve node IDs unless the user explicitly requests a graph change. -- Deleting or renaming a node requires updating edges, mappings, samples, and tests. +- Deleting or renaming a node requires updating edges, explicit mappings, samples, and tests. ## 10. Validation Checklist - [ ] `workflow.md` and `workflow.json` describe the same flow. - [ ] A representative sample input runs successfully. +- [ ] Every edge in `workflow.json` has a non-empty `mapping`. - [ ] At least one edge or error case is documented. - [ ] Publish page only shows capabilities enabled by `config.json`. - [ ] No plaintext secrets are stored in the workflow directory. diff --git a/.flocks/plugins/skills/workflow-builder/references/workflow_zh.md b/.flocks/plugins/skills/workflow-builder/references/workflow_zh.md index 5adb56027..ce0d0c787 100644 --- a/.flocks/plugins/skills/workflow-builder/references/workflow_zh.md +++ b/.flocks/plugins/skills/workflow-builder/references/workflow_zh.md @@ -183,6 +183,7 @@ - [ ] 输入能被正确识别和解析。 - [ ] 每个节点的职责清晰且输出可被下游读取。 +- [ ] `workflow.json` 中每条 edge 都包含非空 `mapping`,且只映射下游需要的字段。 - [ ] 分支、过滤、聚合或分析逻辑符合预期。 - [ ] 输出字段和文件格式清晰。 - [ ] `workflow.md` 和 `workflow.json` 描述同一个流程。 diff --git a/.flocks/plugins/tools/device/ngtip_v5_1_5/_provider.yaml b/.flocks/plugins/tools/device/ngtip_v5_1_5/_provider.yaml index c75b81295..03b92cfa8 100644 --- a/.flocks/plugins/tools/device/ngtip_v5_1_5/_provider.yaml +++ b/.flocks/plugins/tools/device/ngtip_v5_1_5/_provider.yaml @@ -11,6 +11,7 @@ description: > description_cn: > NGTIP 威胁情报平台 API 服务,覆盖情报查询(端口 8090)和平台功能(管理写入)两类接口。 可分别配置情报查询 APIKEY、平台功能 APIKEY、平台功能 Base URL 和情报查询 Base URL。 +docs_url: "https://agentflocks.github.io/flocks-docs/md/modules/devices/ngtip-integration" auth: type: custom secret: ngtip_query_apikey diff --git a/.flocks/plugins/tools/device/onesandbox_v3/_provider.yaml b/.flocks/plugins/tools/device/onesandbox_v3/_provider.yaml new file mode 100644 index 000000000..f3123192a --- /dev/null +++ b/.flocks/plugins/tools/device/onesandbox_v3/_provider.yaml @@ -0,0 +1,45 @@ +name: onesandbox +vendor: threatbook +service_id: onesandbox_api +version: "3" +integration_type: device +description: > + OneSandbox v3 API integration for sandbox sample submission, file report + lookup and download, hash reputation, SafeSkill, and CheckURL operations. + Authentication uses an apikey query parameter. +description_cn: > + OneSandbox v3 API 接入模板,支持沙箱样本提交、文件报告查询/下载、 + Hash 信誉、SafeSkill 技能检测和 CheckURL URL 检测能力。认证方式为 + query 参数 apikey。 +auth: + type: api_key + key: apikey +credential_fields: + - key: base_url + label: Base URL + storage: config + config_key: base_url + input_type: url + required: true + placeholder: "https://sandbox.example.com" + - key: apikey + label: API Key + storage: secret + config_key: apikey + secret_id: onesandbox_apikey + input_type: password + required: true +defaults: + timeout: 120 + category: custom + product_version: "v3" +notes: | + OneSandbox API 手册说明 apikey 认证优先使用 query 参数,尤其文件上传等大表单 + 场景不要把 apikey 放在 form-data 中,避免表单解析失败导致 bad apikey。 + + Base URL 示例: + - https://sandbox.example.com + - http://10.0.0.10:8080 + + 文件下载类 action 会把返回的二进制文件保存到 + `~/.flocks/workspace/outputs//` 下,并在结果中返回 file_path。 diff --git a/.flocks/plugins/tools/device/onesandbox_v3/_test.yaml b/.flocks/plugins/tools/device/onesandbox_v3/_test.yaml new file mode 100644 index 000000000..af9e61a2b --- /dev/null +++ b/.flocks/plugins/tools/device/onesandbox_v3/_test.yaml @@ -0,0 +1,13 @@ +provider: onesandbox_api +tests: + onesandbox_query: + - label: "Readiness check" + input: + action: readyz + - label: "Get API version" + input: + action: api_version + - label: "Query file report" + input: + action: file_report + sha256: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" diff --git a/.flocks/plugins/tools/device/onesandbox_v3/onesandbox.handler.py b/.flocks/plugins/tools/device/onesandbox_v3/onesandbox.handler.py new file mode 100644 index 000000000..09665b5a4 --- /dev/null +++ b/.flocks/plugins/tools/device/onesandbox_v3/onesandbox.handler.py @@ -0,0 +1,303 @@ +from __future__ import annotations + +import datetime as dt +import os +from pathlib import Path +from typing import Any + +import aiohttp + +from flocks.config.config_writer import ConfigWriter +from flocks.tool.registry import ToolContext, ToolResult + +SERVICE_ID = "onesandbox_api" +DEFAULT_TIMEOUT = 120 + + +def _get_secret_manager(): + from flocks.security import get_secret_manager + + return get_secret_manager() + + +def _resolve_ref(value: Any) -> str | None: + if value is None: + return None + if not isinstance(value, str): + return str(value) + if value.startswith("{secret:") and value.endswith("}"): + return _get_secret_manager().get(value[len("{secret:") : -1]) + if value.startswith("{env:") and value.endswith("}"): + return os.getenv(value[len("{env:") : -1]) + return value + + +def _service_config() -> dict[str, Any]: + raw = ConfigWriter.get_api_service_raw(SERVICE_ID) + return raw if isinstance(raw, dict) else {} + + +def _ensure_scheme(value: str) -> str: + if value and not value.startswith(("http://", "https://")): + return "https://" + value + return value + + +def _bool_value(value: Any, default: bool = False) -> bool: + if value is None: + return default + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.strip().lower() in {"1", "true", "yes", "on"} + return bool(value) + + +def _runtime_config() -> tuple[str, str, int, bool]: + raw = _service_config() + base_url = _ensure_scheme( + ( + _resolve_ref(raw.get("base_url")) + or _resolve_ref(raw.get("baseUrl")) + or os.getenv("ONESANDBOX_BASE_URL") + or "" + ).rstrip("/") + ) + apikey = ( + _resolve_ref(raw.get("apikey")) + or _resolve_ref(raw.get("api_key")) + or _resolve_ref(raw.get("apiKey")) + or _get_secret_manager().get("onesandbox_apikey") + or os.getenv("ONESANDBOX_APIKEY") + or "" + ) + timeout = raw.get("timeout", DEFAULT_TIMEOUT) + try: + timeout = int(timeout) + except (TypeError, ValueError): + timeout = DEFAULT_TIMEOUT + verify_ssl = _bool_value(raw.get("verify_ssl"), False) + custom_settings = raw.get("custom_settings") + if isinstance(custom_settings, dict): + verify_ssl = _bool_value(custom_settings.get("verify_ssl"), verify_ssl) + return base_url, apikey, timeout, verify_ssl + + +def _output_dir() -> Path: + today = dt.datetime.now().strftime("%Y-%m-%d") + path = Path.home() / ".flocks" / "workspace" / "outputs" / today + path.mkdir(parents=True, exist_ok=True) + return path + + +def _safe_filename(value: str) -> str: + keep = [] + for char in value: + keep.append(char if char.isalnum() or char in {".", "-", "_"} else "_") + return "".join(keep).strip("._") or "onesandbox_download" + + +def _query_params(apikey: str, params: dict[str, Any], *keys: str) -> dict[str, Any]: + query: dict[str, Any] = {"apikey": apikey} + for key in keys: + value = params.get(key) + if value is not None and value != "": + query[key] = value + return query + + +def _body_from_params(params: dict[str, Any], *keys: str) -> dict[str, Any]: + body = params.get("body") + if isinstance(body, dict): + return body + return { + key: params[key] + for key in keys + if key in params and params[key] is not None + } + + +def _json_result(action: str, data: Any) -> ToolResult: + metadata = {"source": "OneSandbox", "action": action} + if isinstance(data, dict): + code = data.get("response_code") + if code is not None and code not in (0, "0"): + return ToolResult( + success=False, + error=f"OneSandbox API error (code={code}): {data.get('verbose_msg') or data}", + metadata=metadata, + ) + return ToolResult(success=True, output=data.get("data", data), metadata=metadata) + return ToolResult(success=True, output=data, metadata=metadata) + + +async def _json_request( + action: str, + method: str, + path: str, + query: dict[str, Any] | None = None, + body: dict[str, Any] | None = None, + auth_required: bool = True, +) -> ToolResult: + base_url, apikey, timeout, verify_ssl = _runtime_config() + if not base_url: + return ToolResult(success=False, error="OneSandbox base_url is not configured.") + if auth_required and not apikey: + return ToolResult(success=False, error="OneSandbox apikey is required.") + url = f"{base_url}{path}" + params = dict(query or {}) + if auth_required: + params.setdefault("apikey", apikey) + try: + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session: + async with session.request( + method, + url, + params=params, + json=body if method not in {"GET", "DELETE"} else None, + headers={"Content-Type": "application/json"}, + ssl=verify_ssl, + ) as resp: + text = await resp.text() + if resp.status >= 400: + return ToolResult(success=False, error=f"HTTP {resp.status}: {text[:500]}") + if text.strip() == "OK": + return ToolResult(success=True, output={"status": "OK"}, metadata={"source": "OneSandbox", "action": action}) + try: + data = await resp.json(content_type=None) + except Exception: + return ToolResult(success=True, output=text, metadata={"source": "OneSandbox", "action": action}) + except aiohttp.ClientError as exc: + return ToolResult(success=False, error=f"Request failed: {exc}") + return _json_result(action, data) + + +async def _download_request(action: str, params: dict[str, Any]) -> ToolResult: + base_url, apikey, timeout, verify_ssl = _runtime_config() + if not base_url: + return ToolResult(success=False, error="OneSandbox base_url is not configured.") + if not apikey: + return ToolResult(success=False, error="OneSandbox apikey is required.") + query = _query_params(apikey, params, "sha256", "md5", "sha1", "sandbox_type", "type") + if len(query) == 1: + return ToolResult(success=False, error="download_file_report requires at least one of sha256, md5, or sha1.") + report_type = str(params.get("type") or "report") + target_name = _safe_filename(str(params.get("output_filename") or f"onesandbox_{report_type}_{query.get('sha256') or query.get('md5') or query.get('sha1')}.bin")) + target = _output_dir() / target_name + try: + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session: + async with session.get(f"{base_url}/v3/file/report", params=query, ssl=verify_ssl) as resp: + content = await resp.read() + if resp.status >= 400: + return ToolResult(success=False, error=f"HTTP {resp.status}: {content[:500].decode(errors='replace')}") + content_type = resp.headers.get("Content-Type", "") + if "json" in content_type.lower(): + try: + return _json_result(action, await resp.json(content_type=None)) + except Exception: + pass + target.write_bytes(content) + except aiohttp.ClientError as exc: + return ToolResult(success=False, error=f"Request failed: {exc}") + return ToolResult( + success=True, + output={"file_path": str(target), "bytes": target.stat().st_size}, + metadata={"source": "OneSandbox", "action": action}, + ) + + +async def _upload_file(params: dict[str, Any]) -> ToolResult: + base_url, apikey, timeout, verify_ssl = _runtime_config() + if not base_url: + return ToolResult(success=False, error="OneSandbox base_url is not configured.") + if not apikey: + return ToolResult(success=False, error="OneSandbox apikey is required.") + file_path = params.get("file_path") + if not file_path: + return ToolResult(success=False, error="upload_file requires file_path.") + path = Path(str(file_path)).expanduser() + if not path.exists() or not path.is_file(): + return ToolResult(success=False, error=f"file_path does not exist: {path}") + + query = _query_params(apikey, params, "filename", "password", "run_time", "mode") + form = aiohttp.FormData() + form.add_field( + "file", + path.open("rb"), + filename=str(params.get("filename") or path.name), + content_type="application/octet-stream", + ) + try: + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session: + async with session.post(f"{base_url}/v3/file/upload", params=query, data=form, ssl=verify_ssl) as resp: + text = await resp.text() + if resp.status >= 400: + return ToolResult(success=False, error=f"HTTP {resp.status}: {text[:500]}") + data = await resp.json(content_type=None) + except aiohttp.ClientError as exc: + return ToolResult(success=False, error=f"Request failed: {exc}") + finally: + for field in getattr(form, "_fields", []): + value = field[2] + if hasattr(value, "close"): + value.close() + return _json_result("upload_file", data) + + +async def query(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + del ctx + _, apikey, _, _ = _runtime_config() + if action == "readyz": + return await _json_request(action, "GET", "/health/readyz", auth_required=False) + if action == "livez": + return await _json_request(action, "GET", "/health/livez", auth_required=False) + if action == "api_version": + return await _json_request(action, "GET", "/api/version", auth_required=False) + if action == "file_report": + return await _json_request(action, "GET", "/v3/file/report", _query_params(apikey, params, "sha256", "md5", "sha1", "query_fields")) + if action == "file_queue": + return await _json_request(action, "GET", "/v3/file/queue", _query_params(apikey, params)) + if action == "download_file_report": + return await _download_request(action, params) + if action == "hash_reputation": + return await _json_request(action, "GET", "/v3/hash/reputation", _query_params(apikey, params, "sha256", "md5", "sha1")) + if action == "sdk_policy_get": + return await _json_request(action, "GET", "/v3/sdk/policy", _query_params(apikey, params)) + if action == "safeskill_report": + return await _json_request(action, "GET", "/v3/safeskill/report", _query_params(apikey, params, "sha256", "md5", "sha1", "task_id")) + if action == "checkurl_result": + task_id = params.get("task_id") + if not task_id: + return ToolResult(success=False, error="checkurl_result requires task_id.") + return await _json_request(action, "GET", f"/v3/checkurl/result/{task_id}", _query_params(apikey, params)) + if action in {"checkurl_report", "checkurl_summary"}: + uuid = params.get("uuid") + if not uuid: + return ToolResult(success=False, error=f"{action} requires uuid.") + path = f"/v3/checkurl/{'summary' if action == 'checkurl_summary' else 'report'}/{uuid}" + return await _json_request(action, "GET", path, _query_params(apikey, params, "summary")) + return ToolResult(success=False, error=f"Unsupported OneSandbox query action: {action}") + + +async def ops(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + del ctx + _, apikey, _, _ = _runtime_config() + if action == "upload_file": + return await _upload_file(params) + if action == "delete_file_report": + return await _json_request(action, "DELETE", "/v3/file/delete", _query_params(apikey, params, "sha256", "md5", "sha1")) + if action == "upload_event": + body = _body_from_params(params, "events", "file_name", "file_size", "sha256", "md5", "sha1", "source", "event_time") + return await _json_request(action, "POST", "/v3/event/upload", _query_params(apikey, params), body) + if action == "hash_reputation_set": + body = _body_from_params(params, "hash", "white", "threat", "classify", "description") + return await _json_request(action, "POST", "/v3/hash/reputation", _query_params(apikey, params), body) + if action == "hash_reputation_delete": + return await _json_request(action, "DELETE", "/v3/hash/reputation", _query_params(apikey, params, "sha256", "md5", "sha1")) + if action == "safeskill_scan": + body = _body_from_params(params, "sha256", "md5", "sha1", "url", "file_name", "file_type") + return await _json_request(action, "POST", "/v3/safeskill/scan", _query_params(apikey, params), body) + if action == "checkurl_scan": + body = _body_from_params(params, "url") + return await _json_request(action, "POST", "/v3/checkurl/scan", _query_params(apikey, params), body) + return ToolResult(success=False, error=f"Unsupported OneSandbox ops action: {action}") diff --git a/.flocks/plugins/tools/device/onesandbox_v3/onesandbox_ops.yaml b/.flocks/plugins/tools/device/onesandbox_v3/onesandbox_ops.yaml new file mode 100644 index 000000000..d05149b02 --- /dev/null +++ b/.flocks/plugins/tools/device/onesandbox_v3/onesandbox_ops.yaml @@ -0,0 +1,77 @@ +name: onesandbox_ops +description: > + OneSandbox v3 mutation grouped tool for sample upload, report deletion, + event upload, hash reputation update/delete, SafeSkill scan, and CheckURL + scan operations. +description_cn: > + OneSandbox v3 写操作分组工具,支持样本上传、报告删除、事件上报、 + Hash 信誉新增/删除、SafeSkill 扫描和 CheckURL 扫描。 +category: custom +enabled: true +requires_confirmation: true +provider: onesandbox_api +inputSchema: + type: object + properties: + action: + type: string + enum: + - upload_file + - delete_file_report + - upload_event + - hash_reputation_set + - hash_reputation_delete + - safeskill_scan + - checkurl_scan + description: 写操作动作名。调用前请确认影响范围。 + file_path: + type: string + description: upload_file 本地样本文件路径。 + filename: + type: string + description: 上传文件名。 + password: + type: string + description: 加密压缩包密码。 + run_time: + type: integer + description: 沙箱运行时间,单位秒,通常 30-300。 + mode: + type: string + description: 上传模式:standard/rapid/custom。 + sha256: + type: string + description: 文件 SHA256。 + md5: + type: string + description: 文件 MD5。 + sha1: + type: string + description: 文件 SHA1。 + hash: + type: string + description: hash_reputation_set 的哈希值。 + white: + type: boolean + description: hash_reputation_set 是否白名单。 + threat: + type: string + description: hash_reputation_set 威胁类型。 + classify: + type: string + description: hash_reputation_set 分类。 + description: + type: string + description: hash_reputation_set 描述。 + url: + type: string + description: checkurl_scan 待检测 URL,或 SafeSkill 相关 URL。 + body: + type: object + description: upload_event 或扩展接口的原始请求体。 + required: + - action +handler: + type: script + script_file: onesandbox.handler.py + function: ops diff --git a/.flocks/plugins/tools/device/onesandbox_v3/onesandbox_query.yaml b/.flocks/plugins/tools/device/onesandbox_v3/onesandbox_query.yaml new file mode 100644 index 000000000..c7e6cfa1f --- /dev/null +++ b/.flocks/plugins/tools/device/onesandbox_v3/onesandbox_query.yaml @@ -0,0 +1,67 @@ +name: onesandbox_query +description: > + OneSandbox v3 read-only grouped tool for health checks, API version, file + reports, task queue, report download, hash reputation, SDK policy, + SafeSkill report, and CheckURL result/report queries. +description_cn: > + OneSandbox v3 只读分组工具,支持健康检查、API 版本、文件报告、任务队列、 + 报告下载、Hash 信誉、SDK 策略、SafeSkill 报告和 CheckURL 结果/报告查询。 +category: custom +enabled: true +requires_confirmation: false +provider: onesandbox_api +inputSchema: + type: object + properties: + action: + type: string + enum: + - readyz + - livez + - api_version + - file_report + - file_queue + - download_file_report + - hash_reputation + - sdk_policy_get + - safeskill_report + - checkurl_result + - checkurl_report + - checkurl_summary + description: 查询动作名。 + sha256: + type: string + description: 文件 SHA256。 + md5: + type: string + description: 文件 MD5。 + sha1: + type: string + description: 文件 SHA1。 + query_fields: + type: string + description: file_report 查询字段,如 summary/multiengines/static/sandbox/email。 + sandbox_type: + type: string + description: download_file_report 沙箱环境。 + type: + type: string + description: download_file_report 下载类型:sample/static/report/pcap/drop/buffer/trace。 + output_filename: + type: string + description: 下载文件保存名。 + task_id: + type: string + description: CheckURL 任务 ID。 + uuid: + type: string + description: CheckURL 报告 UUID。 + summary: + type: boolean + description: 查询 CheckURL report 时是否只返回摘要。 + required: + - action +handler: + type: script + script_file: onesandbox.handler.py + function: query diff --git a/.flocks/plugins/tools/device/onesec_v2_8_2/_provider.yaml b/.flocks/plugins/tools/device/onesec_v2_8_2/_provider.yaml index 991c341d1..4e1b654ff 100644 --- a/.flocks/plugins/tools/device/onesec_v2_8_2/_provider.yaml +++ b/.flocks/plugins/tools/device/onesec_v2_8_2/_provider.yaml @@ -10,6 +10,7 @@ description: > description_cn: > OneSEC 终端安全平台 API 服务。当前配置页需要分别填写 API Key、Secret 和 Base URL;其中 Base URL 可留空,默认使用 `https://console.onesec.net`。 +docs_url: "https://agentflocks.github.io/flocks-docs/md/modules/devices/onesec-integration" auth: type: custom secret: onesec_api_key diff --git a/.flocks/plugins/tools/device/onesig_v2_5_3/_provider.yaml b/.flocks/plugins/tools/device/onesig_v2_5_3/_provider.yaml new file mode 100644 index 000000000..29dd25c7c --- /dev/null +++ b/.flocks/plugins/tools/device/onesig_v2_5_3/_provider.yaml @@ -0,0 +1,54 @@ +name: onesig +vendor: threatbook +service_id: onesig_api +version: "2.5.3" +integration_type: device +description: > + OneSIG strategy API integration using ApiKey + Secret HMAC-SHA1 request + signing. This template targets the third-party strategy API documented in + api-onesig-2.5.3, including device status, global whitelist, global + blacklist, banned whitelist, HTTP blacklist, protection policy, and asset + operations. +description_cn: > + OneSIG 策略 API 接入模板,使用 ApiKey + Secret 的 HMAC-SHA1 签名机制。 + 适用于 api-onesig-2.5.3 文档中的第三方策略 API,覆盖设备状态、全局白名单、 + 全局黑名单、封禁白名单、HTTP 防护、策略配置和资产相关接口。 +auth: + type: custom + secret: onesig_v2_5_3_secret +credential_fields: + - key: base_url + label: Base URL + storage: config + config_key: base_url + input_type: url + required: true + placeholder: "https://device.example.com" + - key: api_key + label: ApiKey + storage: secret + config_key: api_key + secret_id: onesig_v2_5_3_api_key + input_type: password + required: true + - key: secret + label: Secret + storage: secret + config_key: secret + secret_id: onesig_v2_5_3_secret + input_type: password + required: true +defaults: + timeout: 60 + category: custom + product_version: "2.5.3" +notes: | + 本模板对应 OneSIG 第三方策略 API:请求 URL 自动追加 + `apikey=×tamp=&sign=`。 + + 与已有 `onesig_v2_5_3_D20260321` 模板不同: + - 已有 onesig 模板走控制台 Cookie 登录 + `/v3/...` 接口; + - 本模板走第三方 ApiKey/Secret 签名 + `/api/v3/...` 接口。 + + 若只需要控制台完整能力,优先使用已有 onesig 模板;若客户只开放策略 API, + 使用本模板。 diff --git a/.flocks/plugins/tools/device/onesig_v2_5_3/_test.yaml b/.flocks/plugins/tools/device/onesig_v2_5_3/_test.yaml new file mode 100644 index 000000000..bbce0aaf0 --- /dev/null +++ b/.flocks/plugins/tools/device/onesig_v2_5_3/_test.yaml @@ -0,0 +1,12 @@ +provider: onesig_api +tests: + onesig_strategy_api_query: + - label: "Query platform status" + input: + action: platform_status + - label: "List global whitelist" + input: + action: whitelist_list + body: + pageNo: 1 + pageSize: 20 diff --git a/.flocks/plugins/tools/device/onesig_v2_5_3/onesig_strategy_api.handler.py b/.flocks/plugins/tools/device/onesig_v2_5_3/onesig_strategy_api.handler.py new file mode 100644 index 000000000..61afc6855 --- /dev/null +++ b/.flocks/plugins/tools/device/onesig_v2_5_3/onesig_strategy_api.handler.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +import base64 +import hashlib +import hmac +import os +import time +from typing import Any + +import aiohttp + +from flocks.config.config_writer import ConfigWriter +from flocks.tool.registry import ToolContext, ToolResult + +SERVICE_ID = "onesig_api" +DEFAULT_TIMEOUT = 60 + + +def _get_secret_manager(): + from flocks.security import get_secret_manager + + return get_secret_manager() + + +def _resolve_ref(value: Any) -> str | None: + if value is None: + return None + if not isinstance(value, str): + return str(value) + if value.startswith("{secret:") and value.endswith("}"): + return _get_secret_manager().get(value[len("{secret:") : -1]) + if value.startswith("{env:") and value.endswith("}"): + return os.getenv(value[len("{env:") : -1]) + return value + + +def _service_config() -> dict[str, Any]: + raw = ConfigWriter.get_api_service_raw(SERVICE_ID) + return raw if isinstance(raw, dict) else {} + + +def _ensure_scheme(value: str) -> str: + if value and not value.startswith(("http://", "https://")): + return "https://" + value + return value + + +def _bool_value(value: Any, default: bool = False) -> bool: + if value is None: + return default + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.strip().lower() in {"1", "true", "yes", "on"} + return bool(value) + + +def _runtime_config() -> tuple[str, str, str, int, bool]: + raw = _service_config() + base_url = _ensure_scheme( + ( + _resolve_ref(raw.get("base_url")) + or _resolve_ref(raw.get("baseUrl")) + or os.getenv("ONESIG_V2_5_3_BASE_URL") + or "" + ).rstrip("/") + ) + api_key = ( + _resolve_ref(raw.get("api_key")) + or _resolve_ref(raw.get("apiKey")) + or _get_secret_manager().get("onesig_v2_5_3_api_key") + or os.getenv("ONESIG_V2_5_3_API_KEY") + or "" + ) + secret = ( + _resolve_ref(raw.get("secret")) + or _get_secret_manager().get("onesig_v2_5_3_secret") + or os.getenv("ONESIG_V2_5_3_SECRET") + or "" + ) + timeout = raw.get("timeout", DEFAULT_TIMEOUT) + try: + timeout = int(timeout) + except (TypeError, ValueError): + timeout = DEFAULT_TIMEOUT + verify_ssl = _bool_value(raw.get("verify_ssl"), False) + custom_settings = raw.get("custom_settings") + if isinstance(custom_settings, dict): + verify_ssl = _bool_value(custom_settings.get("verify_ssl"), verify_ssl) + return base_url, api_key, secret, timeout, verify_ssl + + +def _signed_query(api_key: str, secret: str) -> dict[str, str]: + timestamp = str(int(time.time())) + sign_data = f"{api_key}{timestamp}".encode() + digest = hmac.new(secret.encode(), sign_data, hashlib.sha1).digest() + return { + "apikey": api_key, + "timestamp": timestamp, + "sign": base64.b64encode(digest).decode(), + } + + +def _json_result(action: str, data: Any) -> ToolResult: + metadata = {"source": "OneSIG Strategy API", "action": action} + if isinstance(data, dict): + code = data.get("response_code") + if code is not None and code not in (0, "0"): + return ToolResult( + success=False, + error=f"OneSIG Strategy API error (code={code}): {data.get('verbose_msg') or data}", + metadata=metadata, + ) + return ToolResult(success=True, output=data.get("data", data), metadata=metadata) + return ToolResult(success=True, output=data, metadata=metadata) + + +def _payload_from_params(params: dict[str, Any]) -> dict[str, Any]: + body = params.get("body") + if isinstance(body, dict): + return body + return { + key: value + for key, value in params.items() + if key not in {"action", "body"} and value is not None + } + + +async def _request( + action: str, + method: str, + path: str, + params: dict[str, Any], +) -> ToolResult: + base_url, api_key, secret, timeout, verify_ssl = _runtime_config() + if not base_url: + return ToolResult(success=False, error="OneSIG Strategy API base_url is not configured.") + if not api_key or not secret: + return ToolResult(success=False, error="OneSIG Strategy API ApiKey and Secret are required.") + + url = f"{base_url}{path}" + query = _signed_query(api_key, secret) + payload = _payload_from_params(params) + headers = {"Content-Type": "application/json"} + try: + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session: + request = session.request( + method, + url, + params=query, + json=payload if method != "GET" else None, + headers=headers, + ssl=verify_ssl, + ) + async with request as resp: + text = await resp.text() + if resp.status >= 400: + return ToolResult(success=False, error=f"HTTP {resp.status}: {text[:500]}") + try: + data = await resp.json(content_type=None) + except Exception: + return ToolResult(success=True, output=text, metadata={"source": "OneSIG Strategy API", "action": action}) + except aiohttp.ClientError as exc: + return ToolResult(success=False, error=f"Request failed: {exc}") + return _json_result(action, data) + + +QUERY_ACTIONS: dict[str, tuple[str, str]] = { + "platform_status": ("POST", "/api/v3/device/platformStatus"), + "system_status": ("POST", "/api/v3/device/systemStatus"), + "network_status": ("POST", "/api/v3/device/networkStatus"), + "asset_group_list": ("GET", "/api/v3/asset/group"), + "asset_list": ("POST", "/api/v3/asset/list"), + "asset_type_list": ("GET", "/api/v3/asset/type"), + "protection_policy_list": ("POST", "/api/v3/protection/policy"), + "whitelist_list": ("POST", "/api/v3/globalWhitelist/list"), + "blacklist_list": ("POST", "/api/v3/globalBlacklist/list"), + "banned_whitelist_list": ("POST", "/api/v3/bannedWhitelist/list"), + "http_blacklist_list": ("POST", "/api/v3/httpBlacklist/list"), +} + +OPS_ACTIONS: dict[str, tuple[str, str]] = { + "asset_group_create": ("POST", "/api/v3/asset/group/create"), + "asset_group_update": ("POST", "/api/v3/asset/group/update"), + "asset_group_delete": ("POST", "/api/v3/asset/group/delete"), + "asset_create": ("POST", "/api/v3/asset/create"), + "asset_update": ("POST", "/api/v3/asset/update"), + "asset_delete": ("POST", "/api/v3/asset/delete"), + "protection_policy_update": ("POST", "/api/v3/protection/policy/update"), + "protection_policy_delete": ("POST", "/api/v3/protection/policy/delete"), + "whitelist_create": ("POST", "/api/v3/globalWhitelist/create"), + "whitelist_update": ("POST", "/api/v3/globalWhitelist/update"), + "whitelist_delete": ("POST", "/api/v3/globalWhitelist/delete"), + "whitelist_remove": ("POST", "/api/v3/globalWhitelist/remove"), + "blacklist_create": ("POST", "/api/v3/globalBlacklist/create"), + "blacklist_update": ("POST", "/api/v3/globalBlacklist/update"), + "blacklist_delete": ("POST", "/api/v3/globalBlacklist/delete"), + "blacklist_remove": ("POST", "/api/v3/globalBlacklist/remove"), + "banned_whitelist_create": ("POST", "/api/v3/bannedWhitelist/create"), + "banned_whitelist_update": ("POST", "/api/v3/bannedWhitelist/update"), + "banned_whitelist_delete": ("POST", "/api/v3/bannedWhitelist/delete"), + "http_blacklist_create": ("POST", "/api/v3/httpBlacklist/create"), + "http_blacklist_update": ("POST", "/api/v3/httpBlacklist/update"), + "http_blacklist_enable": ("POST", "/api/v3/httpBlacklist/enable"), + "http_blacklist_delete": ("POST", "/api/v3/httpBlacklist/delete"), +} + + +async def query(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + del ctx + spec = QUERY_ACTIONS.get(action) + if not spec: + return ToolResult(success=False, error=f"Unsupported OneSIG Strategy API query action: {action}") + return await _request(action, spec[0], spec[1], params) + + +async def ops(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + del ctx + spec = OPS_ACTIONS.get(action) + if not spec: + return ToolResult(success=False, error=f"Unsupported OneSIG Strategy API ops action: {action}") + return await _request(action, spec[0], spec[1], params) diff --git a/.flocks/plugins/tools/device/onesig_v2_5_3/onesig_strategy_api_ops.yaml b/.flocks/plugins/tools/device/onesig_v2_5_3/onesig_strategy_api_ops.yaml new file mode 100644 index 000000000..2a15c5570 --- /dev/null +++ b/.flocks/plugins/tools/device/onesig_v2_5_3/onesig_strategy_api_ops.yaml @@ -0,0 +1,51 @@ +name: onesig_strategy_api_ops +description: > + OneSIG Strategy API mutation grouped tool. Uses ApiKey + Secret HMAC-SHA1 + signing and calls /api/v3 mutation endpoints for assets, policies, + whitelist, blacklist, banned whitelist, and HTTP blacklist operations. +description_cn: > + OneSIG 策略 API 写操作分组工具。使用 ApiKey + Secret HMAC-SHA1 签名, + 调用资产、策略、白名单、黑名单、封禁白名单和 HTTP 防护相关写接口。 +category: custom +enabled: true +requires_confirmation: true +provider: onesig_api +inputSchema: + type: object + properties: + action: + type: string + enum: + - asset_group_create + - asset_group_update + - asset_group_delete + - asset_create + - asset_update + - asset_delete + - protection_policy_update + - protection_policy_delete + - whitelist_create + - whitelist_update + - whitelist_delete + - whitelist_remove + - blacklist_create + - blacklist_update + - blacklist_delete + - blacklist_remove + - banned_whitelist_create + - banned_whitelist_update + - banned_whitelist_delete + - http_blacklist_create + - http_blacklist_update + - http_blacklist_enable + - http_blacklist_delete + description: 写操作动作名。调用前必须确认影响范围。 + body: + type: object + description: 请求体,按 api-onesig-2.5.3 文档传入对应字段。 + required: + - action +handler: + type: script + script_file: onesig_strategy_api.handler.py + function: ops diff --git a/.flocks/plugins/tools/device/onesig_v2_5_3/onesig_strategy_api_query.yaml b/.flocks/plugins/tools/device/onesig_v2_5_3/onesig_strategy_api_query.yaml new file mode 100644 index 000000000..120e12d3f --- /dev/null +++ b/.flocks/plugins/tools/device/onesig_v2_5_3/onesig_strategy_api_query.yaml @@ -0,0 +1,39 @@ +name: onesig_strategy_api_query +description: > + OneSIG Strategy API read-only grouped tool. Uses ApiKey + Secret HMAC-SHA1 + signing and calls /api/v3 endpoints for status, asset, policy, whitelist, + blacklist, banned whitelist, and HTTP blacklist queries. +description_cn: > + OneSIG 策略 API 只读分组工具。使用 ApiKey + Secret HMAC-SHA1 签名, + 调用 /api/v3 下的设备状态、资产、策略、白名单、黑名单、封禁白名单和 HTTP 防护查询接口。 +category: custom +enabled: true +requires_confirmation: false +provider: onesig_api +inputSchema: + type: object + properties: + action: + type: string + enum: + - platform_status + - system_status + - network_status + - asset_group_list + - asset_list + - asset_type_list + - protection_policy_list + - whitelist_list + - blacklist_list + - banned_whitelist_list + - http_blacklist_list + description: 只读动作名。 + body: + type: object + description: 请求体。列表类接口可传 pageNo、pageSize、search、direction 等文档字段。 + required: + - action +handler: + type: script + script_file: onesig_strategy_api.handler.py + function: query diff --git a/.flocks/plugins/tools/device/sangfor_sip_v92/_provider.yaml b/.flocks/plugins/tools/device/sangfor_sip_v92/_provider.yaml index f99ce891b..7b7806f8b 100644 --- a/.flocks/plugins/tools/device/sangfor_sip_v92/_provider.yaml +++ b/.flocks/plugins/tools/device/sangfor_sip_v92/_provider.yaml @@ -13,6 +13,7 @@ description_cn: > 支持拉取受监控IP、服务器资产、终端资产、安全事件、风险业务/终端 及脆弱性(弱密码、漏洞、明文传输)等数据。 使用 auth3(SHA1)签名认证,Token 有有效期,过期后自动重新认证。 +docs_url: "https://agentflocks.github.io/flocks-docs/md/modules/devices/sangfor-sip-integration" auth: type: custom secret: sangfor_sip_password diff --git a/.flocks/plugins/tools/device/skyeye_v4_0_14_0_SP2/_provider.yaml b/.flocks/plugins/tools/device/skyeye_v4_0_14_0_SP2/_provider.yaml index 938a8c1c1..07f03850d 100644 --- a/.flocks/plugins/tools/device/skyeye_v4_0_14_0_SP2/_provider.yaml +++ b/.flocks/plugins/tools/device/skyeye_v4_0_14_0_SP2/_provider.yaml @@ -5,6 +5,7 @@ version: "4.0.14.0.SP2" integration_type: device description: SkyEye monitoring platform API for dashboard views, alarm retrieval, and related downloads. description_cn: SkyEye 威胁监测平台接口,提供仪表板视图、告警检索、枚举值查询和相关下载能力。 +docs_url: "https://agentflocks.github.io/flocks-docs/md/modules/devices/skyeye-integration" auth: type: custom secret: skyeye_api_key diff --git a/.flocks/plugins/tools/device/skyeye_v4_0_14_0_SP2/skyeye.handler.py b/.flocks/plugins/tools/device/skyeye_v4_0_14_0_SP2/skyeye.handler.py index 2ebd5e969..e6f8f94d4 100644 --- a/.flocks/plugins/tools/device/skyeye_v4_0_14_0_SP2/skyeye.handler.py +++ b/.flocks/plugins/tools/device/skyeye_v4_0_14_0_SP2/skyeye.handler.py @@ -5,7 +5,7 @@ import re import time from typing import Any -from urllib.parse import urljoin +from urllib.parse import urljoin, urlsplit, urlunsplit import aiohttp @@ -32,6 +32,20 @@ def _get_custom_setting(raw_service: dict[str, Any], key: str, default: Any = No return custom_settings.get(key, default) +def _ensure_skyeye_base_path(base_url: str) -> str: + cleaned = base_url.strip().rstrip("/") + if not cleaned: + return "" + + parts = urlsplit(cleaned) + path = parts.path.rstrip("/") + if path.lower() == "/skyeye" or path.lower().endswith("/skyeye"): + return urlunsplit(parts._replace(path=path)) + + next_path = f"{path}/skyeye" if path else "/skyeye" + return urlunsplit(parts._replace(path=next_path)) + + def _resolve_login_key(raw_service: dict[str, Any]) -> str: secret_manager = security.get_secret_manager() api_key_ref = raw_service.get("apiKey") or _get_custom_setting(raw_service, "login_key") @@ -59,15 +73,15 @@ def _resolve_base_url(raw_service: dict[str, Any]) -> str: if base_url: resolved = security.resolve_value(base_url) if isinstance(resolved, str) and resolved.strip(): - return resolved.rstrip("/") + return _ensure_skyeye_base_path(resolved) secret_manager = security.get_secret_manager() host = secret_manager.get("skyeye_host") or security.resolve_value("{env:SKYEYE_HOST}") if isinstance(host, str) and host.strip(): host = host.strip().rstrip("/") if host.startswith("http://") or host.startswith("https://"): - return host - return f"https://{host}:443" + return _ensure_skyeye_base_path(host) + return _ensure_skyeye_base_path(f"https://{host}:443") return "" diff --git a/.flocks/plugins/tools/device/tdp_v3_3_10/_provider.yaml b/.flocks/plugins/tools/device/tdp_v3_3_10/_provider.yaml index 678de8e6e..8f7fc28e9 100644 --- a/.flocks/plugins/tools/device/tdp_v3_3_10/_provider.yaml +++ b/.flocks/plugins/tools/device/tdp_v3_3_10/_provider.yaml @@ -9,6 +9,7 @@ description: > description_cn: > TDP 监测平台接口,提供监控看板、威胁调查、资产风险、日志检索、MDR、告警调查 研判和系统运行状态查询能力。 +docs_url: "https://agentflocks.github.io/flocks-docs/md/modules/devices/tdp-integration" auth: type: custom secret: tdp_api_key diff --git a/.github/workflows/dispatch-autotest-windows-upgrade.yml b/.github/workflows/dispatch-autotest-windows-upgrade.yml new file mode 100644 index 000000000..b9c08c6ef --- /dev/null +++ b/.github/workflows/dispatch-autotest-windows-upgrade.yml @@ -0,0 +1,343 @@ +name: Dispatch autotest Windows upgrade + +on: + push: + branches: + - main + workflow_dispatch: + inputs: + release_tag: + description: "Optional Flocks version tag. Defaults to pyproject.toml version." + required: false + type: string + release_url: + description: "Optional source URL. Defaults to the release or commit URL." + required: false + type: string + rollback_version: + description: "Optional old version for flocks_autotest to prepare an upgradeable baseline." + required: false + type: string + run_force_fallback: + description: "Also ask flocks_autotest to run the force fallback reinstall case." + required: false + default: false + type: boolean + +permissions: + contents: read + +concurrency: + group: dispatch-autotest-windows-upgrade-${{ github.event_name }}-${{ github.ref_name || github.run_id }} + cancel-in-progress: false + +jobs: + dispatch: + name: Dispatch flocks_autotest Windows upgrade + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v6 + + - name: Resolve dispatch payload + id: payload + env: + EVENT_NAME: ${{ github.event_name }} + INPUT_RELEASE_TAG: ${{ inputs.release_tag }} + INPUT_RELEASE_URL: ${{ inputs.release_url }} + INPUT_ROLLBACK_VERSION: ${{ inputs.rollback_version }} + INPUT_RUN_FORCE_FALLBACK: ${{ inputs.run_force_fallback }} + GITHUB_SERVER_URL_VALUE: ${{ github.server_url }} + GITHUB_REPOSITORY_VALUE: ${{ github.repository }} + GITHUB_SHA_VALUE: ${{ github.sha }} + GITHUB_REF_NAME_VALUE: ${{ github.ref_name }} + GITHUB_RUN_ID_VALUE: ${{ github.run_id }} + run: | + set -euo pipefail + + pyproject_version="$(python3 - <<'PY' + import tomllib + with open("pyproject.toml", "rb") as f: + print(tomllib.load(f)["project"]["version"]) + PY + )" + + release_tag="${INPUT_RELEASE_TAG:-}" + release_url="${INPUT_RELEASE_URL:-}" + + if [ -z "$release_tag" ]; then + release_tag="$pyproject_version" + fi + if [ -z "$release_url" ]; then + release_url="${GITHUB_SERVER_URL_VALUE}/${GITHUB_REPOSITORY_VALUE}/commit/${GITHUB_SHA_VALUE}" + fi + + if [ -z "$release_tag" ]; then + echo "release_tag is empty" >&2 + exit 1 + fi + + { + echo "release_tag=$release_tag" + echo "release_url=$release_url" + echo "rollback_version=${INPUT_ROLLBACK_VERSION:-}" + echo "run_force_fallback=${INPUT_RUN_FORCE_FALLBACK:-false}" + } >> "$GITHUB_OUTPUT" + + { + echo "### flocks_autotest dispatch payload" + echo "" + echo "- Event: ${EVENT_NAME}" + echo "- Release tag: ${release_tag}" + echo "- Release URL: ${release_url}" + echo "- Source ref: ${GITHUB_REF_NAME_VALUE}" + echo "- Source sha: ${GITHUB_SHA_VALUE}" + echo "- Rollback version: ${INPUT_ROLLBACK_VERSION:-}" + echo "- Run force fallback: ${INPUT_RUN_FORCE_FALLBACK:-false}" + echo "- Expected autotest artifact prefix: flocks-windows-upgrade-${release_tag}-time" + echo "- Autotest workflow: https://github.com/AgentFlocks/flocks_autotest/actions/workflows/flocks-release-dispatch.yml" + } >> "$GITHUB_STEP_SUMMARY" + + - name: Create repository_dispatch payload + env: + RELEASE_TAG: ${{ steps.payload.outputs.release_tag }} + RELEASE_URL: ${{ steps.payload.outputs.release_url }} + ROLLBACK_VERSION: ${{ steps.payload.outputs.rollback_version }} + RUN_FORCE_FALLBACK: ${{ steps.payload.outputs.run_force_fallback }} + SOURCE_REPOSITORY: ${{ github.repository }} + SOURCE_RUN_ID: ${{ github.run_id }} + SOURCE_SHA: ${{ github.sha }} + SOURCE_REF: ${{ github.ref_name }} + run: | + set -euo pipefail + + python3 - <<'PY' > "$RUNNER_TEMP/flocks-autotest-dispatch.json" + import json + import os + + payload = { + "event_type": "flocks_main_updated", + "client_payload": { + "release_tag": os.environ["RELEASE_TAG"], + "release_url": os.environ["RELEASE_URL"], + "installer_asset_name": "", + "source_repository": os.environ["SOURCE_REPOSITORY"], + "source_run_id": os.environ["SOURCE_RUN_ID"], + "source_sha": os.environ["SOURCE_SHA"], + "source_ref": os.environ["SOURCE_REF"], + "rollback_version": os.environ.get("ROLLBACK_VERSION", ""), + "run_force_fallback": os.environ.get("RUN_FORCE_FALLBACK", "false"), + }, + } + print(json.dumps(payload, ensure_ascii=False)) + PY + + cat "$RUNNER_TEMP/flocks-autotest-dispatch.json" + + - name: Dispatch flocks_autotest + env: + GH_TOKEN: ${{ secrets.AUTOTEST_DISPATCH_TOKEN }} + run: | + set -euo pipefail + + if [ -z "${GH_TOKEN:-}" ]; then + echo "Missing AUTOTEST_DISPATCH_TOKEN secret" >&2 + exit 1 + fi + + curl --fail-with-body -L -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${GH_TOKEN}" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ + --data-binary "@${RUNNER_TEMP}/flocks-autotest-dispatch.json" \ + "https://api.github.com/repos/AgentFlocks/flocks_autotest/dispatches" + + - name: Wait for flocks_autotest result + id: autotest + timeout-minutes: 210 + env: + GH_TOKEN: ${{ secrets.AUTOTEST_DISPATCH_TOKEN }} + RELEASE_TAG: ${{ steps.payload.outputs.release_tag }} + SOURCE_RUN_ID: ${{ github.run_id }} + AUTOTEST_REPOSITORY: AgentFlocks/flocks_autotest + AUTOTEST_WORKFLOW_FILE: flocks-release-dispatch.yml + AUTOTEST_POLL_INTERVAL_SECONDS: "30" + AUTOTEST_WAIT_TIMEOUT_SECONDS: "10800" + run: | + set -euo pipefail + + python3 - <<'PY' + import json + import os + import sys + import time + import urllib.error + import urllib.parse + import urllib.request + + api_root = "https://api.github.com" + token = os.environ["GH_TOKEN"] + release_tag = os.environ["RELEASE_TAG"] + source_run_id = os.environ["SOURCE_RUN_ID"] + autotest_repo = os.environ["AUTOTEST_REPOSITORY"] + workflow_file = os.environ["AUTOTEST_WORKFLOW_FILE"] + poll_interval = int(os.environ["AUTOTEST_POLL_INTERVAL_SECONDS"]) + wait_timeout = int(os.environ["AUTOTEST_WAIT_TIMEOUT_SECONDS"]) + summary_path = os.environ["GITHUB_STEP_SUMMARY"] + output_path = os.environ["GITHUB_OUTPUT"] + runner_temp = os.environ["RUNNER_TEMP"] + + headers = { + "Accept": "application/vnd.github+json", + "Authorization": f"Bearer {token}", + "X-GitHub-Api-Version": "2022-11-28", + } + + def api_json(url): + request = urllib.request.Request(url, headers=headers) + with urllib.request.urlopen(request, timeout=30) as response: + return json.loads(response.read().decode("utf-8")) + + def download(url, destination): + request = urllib.request.Request(url, headers=headers) + with urllib.request.urlopen(request, timeout=300) as response: + with open(destination, "wb") as f: + f.write(response.read()) + + def write_summary(lines): + with open(summary_path, "a", encoding="utf-8") as f: + f.write("\n".join(lines)) + f.write("\n") + + def write_outputs(values): + with open(output_path, "a", encoding="utf-8") as f: + for key, value in values.items(): + f.write(f"{key}={value}\n") + + workflow_ref = urllib.parse.quote(workflow_file, safe="") + runs_url = ( + f"{api_root}/repos/{autotest_repo}/actions/workflows/" + f"{workflow_ref}/runs?event=repository_dispatch&per_page=30" + ) + artifact_prefix = f"flocks-windows-upgrade-{release_tag}-time" + deadline = time.time() + wait_timeout + matched_run = None + + while time.time() < deadline: + runs = api_json(runs_url).get("workflow_runs", []) + for run in runs: + title = run.get("display_title") or run.get("name") or "" + if f"#{source_run_id}" in title: + matched_run = run + break + if matched_run: + break + print( + f"Waiting for flocks_autotest workflow run with source_run_id={source_run_id}...", + flush=True, + ) + time.sleep(poll_interval) + + if not matched_run: + write_summary( + [ + "### flocks_autotest result", + "", + f"- Status: timed out waiting for workflow run", + f"- Source run id: {source_run_id}", + f"- Expected artifact prefix: {artifact_prefix}", + ] + ) + raise SystemExit("Timed out waiting for flocks_autotest workflow run") + + run_id = matched_run["id"] + run_api_url = matched_run["url"] + run_html_url = matched_run["html_url"] + print(f"Matched flocks_autotest run: {run_html_url}", flush=True) + + while time.time() < deadline: + matched_run = api_json(run_api_url) + status = matched_run.get("status") + conclusion = matched_run.get("conclusion") or "" + print( + f"flocks_autotest run {run_id}: status={status} conclusion={conclusion}", + flush=True, + ) + if status == "completed": + break + time.sleep(poll_interval) + + if matched_run.get("status") != "completed": + write_summary( + [ + "### flocks_autotest result", + "", + f"- Status: timed out waiting for completion", + f"- Run: {run_html_url}", + f"- Expected artifact prefix: {artifact_prefix}", + ] + ) + raise SystemExit("Timed out waiting for flocks_autotest completion") + + conclusion = matched_run.get("conclusion") or "unknown" + artifacts_url = f"{api_root}/repos/{autotest_repo}/actions/runs/{run_id}/artifacts?per_page=100" + artifacts = api_json(artifacts_url).get("artifacts", []) + matching_artifacts = [ + artifact + for artifact in artifacts + if not artifact.get("expired") and (artifact.get("name") or "").startswith(artifact_prefix) + ] + matching_artifacts.sort(key=lambda artifact: artifact.get("created_at") or "", reverse=True) + + artifact_name = "" + artifact_html_url = "" + artifact_path = "" + if matching_artifacts: + artifact = matching_artifacts[0] + artifact_name = artifact["name"] + artifact_html_url = ( + f"https://github.com/{autotest_repo}/actions/runs/{run_id}/artifacts/{artifact['id']}" + ) + artifact_path = os.path.join(runner_temp, f"{artifact_name}.zip") + download(artifact["archive_download_url"], artifact_path) + print(f"Downloaded flocks_autotest artifact: {artifact_path}", flush=True) + + write_outputs( + { + "run_id": str(run_id), + "run_url": run_html_url, + "conclusion": conclusion, + "artifact_name": artifact_name, + "artifact_url": artifact_html_url, + "artifact_path": artifact_path, + } + ) + + summary_lines = [ + "### flocks_autotest result", + "", + f"- Run: {run_html_url}", + f"- Conclusion: {conclusion}", + f"- Expected artifact prefix: {artifact_prefix}", + ] + if artifact_name: + summary_lines.append(f"- Artifact: [{artifact_name}]({artifact_html_url})") + summary_lines.append("- Artifact copy: uploaded to this flocks workflow run") + else: + summary_lines.append("- Artifact: not found") + write_summary(summary_lines) + + if conclusion != "success": + raise SystemExit(f"flocks_autotest concluded with {conclusion}") + if not artifact_name: + raise SystemExit("flocks_autotest artifact was not found") + PY + + - name: Upload flocks_autotest artifact copy + if: ${{ always() && steps.autotest.outputs.artifact_path != '' }} + uses: actions/upload-artifact@v4 + with: + name: ${{ steps.autotest.outputs.artifact_name }} + path: ${{ steps.autotest.outputs.artifact_path }} + if-no-files-found: error diff --git a/flocks/auth/context.py b/flocks/auth/context.py index 0e69f82c6..c95d7b6af 100644 --- a/flocks/auth/context.py +++ b/flocks/auth/context.py @@ -18,6 +18,8 @@ class AuthUser(BaseModel): role: str = Field(..., description="admin or member") status: str = Field("active", description="active or disabled") must_reset_password: bool = False + tenant_ids: tuple[str, ...] = Field(default_factory=tuple) + asset_groups: tuple[str, ...] = Field(default_factory=tuple) _current_auth_user: contextvars.ContextVar[Optional[AuthUser]] = contextvars.ContextVar( diff --git a/flocks/auth/service.py b/flocks/auth/service.py index ba74aef53..3c5d761c0 100644 --- a/flocks/auth/service.py +++ b/flocks/auth/service.py @@ -7,9 +7,10 @@ import base64 import hashlib import hmac +import json import secrets from datetime import UTC, datetime, timedelta -from typing import Dict, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple import aiosqlite from pydantic import BaseModel, Field @@ -45,12 +46,44 @@ def _parse_iso(ts: str) -> datetime: return parsed.astimezone(UTC) +def _clean_scope_values(values: Iterable[str]) -> tuple[str, ...]: + cleaned: list[str] = [] + seen: set[str] = set() + for value in values: + if not isinstance(value, str): + continue + normalized = value.strip() + if not normalized or normalized in seen: + continue + cleaned.append(normalized) + seen.add(normalized) + return tuple(cleaned) + + +def _decode_scope_values(raw: Optional[str]) -> tuple[str, ...]: + if not raw: + return () + try: + values = json.loads(raw) + except json.JSONDecodeError: + return () + if not isinstance(values, list): + return () + return _clean_scope_values(values) + + +def _encode_scope_values(values: Iterable[str]) -> str: + return json.dumps(list(_clean_scope_values(values)), ensure_ascii=False) + + class LocalUser(BaseModel): id: str username: str role: str status: str must_reset_password: bool + tenant_ids: tuple[str, ...] = Field(default_factory=tuple) + asset_groups: tuple[str, ...] = Field(default_factory=tuple) created_at: str updated_at: str last_login_at: Optional[str] = None @@ -62,6 +95,8 @@ def to_auth_user(self) -> AuthUser: role=self.role, status=self.status, must_reset_password=self.must_reset_password, + tenant_ids=self.tenant_ids, + asset_groups=self.asset_groups, ) @@ -95,6 +130,8 @@ async def init(cls) -> None: role TEXT NOT NULL DEFAULT 'member', status TEXT NOT NULL DEFAULT 'active', must_reset_password INTEGER NOT NULL DEFAULT 0, + tenant_ids TEXT NOT NULL DEFAULT '[]', + asset_groups TEXT NOT NULL DEFAULT '[]', temp_password_expires_at TEXT, created_at TEXT NOT NULL, updated_at TEXT NOT NULL, @@ -115,6 +152,7 @@ async def init(cls) -> None: """ ) + await cls._ensure_user_scope_columns(db) await cls._drop_legacy_tables(db) await db.commit() @@ -128,6 +166,15 @@ async def init(cls) -> None: # enumerate every historical table name. _LEGACY_TABLE_PATTERNS: Tuple[str, ...] = ("cloud\\_%",) + @classmethod + async def _ensure_user_scope_columns(cls, db: aiosqlite.Connection) -> None: + async with db.execute("PRAGMA table_info(users)") as cursor: + columns = {row[1] for row in await cursor.fetchall()} + if "tenant_ids" not in columns: + await db.execute("ALTER TABLE users ADD COLUMN tenant_ids TEXT NOT NULL DEFAULT '[]'") + if "asset_groups" not in columns: + await db.execute("ALTER TABLE users ADD COLUMN asset_groups TEXT NOT NULL DEFAULT '[]'") + @classmethod async def _drop_legacy_tables(cls, db: aiosqlite.Connection) -> None: for pattern in cls._LEGACY_TABLE_PATTERNS: @@ -200,6 +247,8 @@ async def _create_user_internal( role: str = "member", must_reset_password: bool = False, temp_expires_at: Optional[str] = None, + tenant_ids: Iterable[str] = (), + asset_groups: Iterable[str] = (), ) -> LocalUser: await cls.init() if role not in {"admin", "member"}: @@ -219,9 +268,9 @@ async def _create_user_internal( """ INSERT INTO users ( id, username, password_hash, role, status, must_reset_password, - temp_password_expires_at, created_at, updated_at + tenant_ids, asset_groups, temp_password_expires_at, created_at, updated_at ) - VALUES (?, ?, ?, ?, 'active', ?, ?, ?, ?) + VALUES (?, ?, ?, ?, 'active', ?, ?, ?, ?, ?, ?) """, ( user_id, @@ -229,6 +278,8 @@ async def _create_user_internal( password_hash, role, 1 if must_reset_password else 0, + _encode_scope_values(tenant_ids), + _encode_scope_values(asset_groups), temp_expires_at, now, now, @@ -245,7 +296,7 @@ async def get_user_by_id(cls, user_id: str) -> Optional[LocalUser]: async with Storage.connect(db_path) as db: async with db.execute( """ - SELECT id, username, role, status, must_reset_password, + SELECT id, username, role, status, must_reset_password, tenant_ids, asset_groups, created_at, updated_at, last_login_at FROM users WHERE id = ? """, @@ -260,9 +311,11 @@ async def get_user_by_id(cls, user_id: str) -> Optional[LocalUser]: role=row[2], status=row[3], must_reset_password=bool(row[4]), - created_at=row[5], - updated_at=row[6], - last_login_at=row[7], + tenant_ids=_decode_scope_values(row[5]), + asset_groups=_decode_scope_values(row[6]), + created_at=row[7], + updated_at=row[8], + last_login_at=row[9], ) @classmethod @@ -272,7 +325,8 @@ async def get_user_by_username(cls, username: str) -> Optional[Tuple[LocalUser, async with Storage.connect(db_path) as db: async with db.execute( """ - SELECT id, username, role, status, must_reset_password, created_at, updated_at, last_login_at, + SELECT id, username, role, status, must_reset_password, tenant_ids, asset_groups, + created_at, updated_at, last_login_at, password_hash, temp_password_expires_at FROM users WHERE username = ? """, @@ -287,11 +341,13 @@ async def get_user_by_username(cls, username: str) -> Optional[Tuple[LocalUser, role=row[2], status=row[3], must_reset_password=bool(row[4]), - created_at=row[5], - updated_at=row[6], - last_login_at=row[7], + tenant_ids=_decode_scope_values(row[5]), + asset_groups=_decode_scope_values(row[6]), + created_at=row[7], + updated_at=row[8], + last_login_at=row[9], ) - return user, row[8], row[9] + return user, row[10], row[11] @classmethod async def list_users(cls) -> List[LocalUser]: @@ -301,7 +357,8 @@ async def list_users(cls) -> List[LocalUser]: async with Storage.connect(db_path) as db: async with db.execute( """ - SELECT id, username, role, status, must_reset_password, created_at, updated_at, last_login_at + SELECT id, username, role, status, must_reset_password, tenant_ids, asset_groups, + created_at, updated_at, last_login_at FROM users ORDER BY created_at ASC """ @@ -315,9 +372,11 @@ async def list_users(cls) -> List[LocalUser]: role=row[2], status=row[3], must_reset_password=bool(row[4]), - created_at=row[5], - updated_at=row[6], - last_login_at=row[7], + tenant_ids=_decode_scope_values(row[5]), + asset_groups=_decode_scope_values(row[6]), + created_at=row[7], + updated_at=row[8], + last_login_at=row[9], ) ) return users @@ -347,7 +406,8 @@ async def get_user_by_session_id(cls, session_id: str) -> Optional[LocalUser]: async with Storage.connect(db_path) as db: async with db.execute( """ - SELECT u.id, u.username, u.role, u.status, u.must_reset_password, u.created_at, u.updated_at, u.last_login_at, + SELECT u.id, u.username, u.role, u.status, u.must_reset_password, + u.tenant_ids, u.asset_groups, u.created_at, u.updated_at, u.last_login_at, s.expires_at FROM user_sessions s JOIN users u ON s.user_id = u.id @@ -358,7 +418,7 @@ async def get_user_by_session_id(cls, session_id: str) -> Optional[LocalUser]: row = await cursor.fetchone() if not row: return None - expires_at = _parse_iso(row[8]) + expires_at = _parse_iso(row[10]) if _utc_now() >= expires_at: await cls.revoke_session(session_id) return None @@ -368,9 +428,11 @@ async def get_user_by_session_id(cls, session_id: str) -> Optional[LocalUser]: role=row[2], status=row[3], must_reset_password=bool(row[4]), - created_at=row[5], - updated_at=row[6], - last_login_at=row[7], + tenant_ids=_decode_scope_values(row[5]), + asset_groups=_decode_scope_values(row[6]), + created_at=row[7], + updated_at=row[8], + last_login_at=row[9], ) if user.status != "active": return None @@ -478,6 +540,39 @@ async def set_password( await db.execute("DELETE FROM user_sessions WHERE user_id = ?", (target_user_id,)) await db.commit() + @classmethod + async def set_user_contract_scope( + cls, + *, + target_user_id: str, + tenant_ids: Iterable[str], + asset_groups: Iterable[str], + ) -> LocalUser: + await cls.init() + now = _iso_now() + db_path = Storage.get_db_path() + async with Storage.connect(db_path) as db: + cursor = await db.execute( + """ + UPDATE users + SET tenant_ids = ?, asset_groups = ?, updated_at = ? + WHERE id = ? + """, + ( + _encode_scope_values(tenant_ids), + _encode_scope_values(asset_groups), + now, + target_user_id, + ), + ) + await db.commit() + if cursor.rowcount == 0: + raise ValueError("用户不存在") + user = await cls.get_user_by_id(target_user_id) + if not user: + raise ValueError("用户不存在") + return user + @classmethod async def generate_admin_temp_password( cls, diff --git a/flocks/contracts/__init__.py b/flocks/contracts/__init__.py new file mode 100644 index 000000000..acb888015 --- /dev/null +++ b/flocks/contracts/__init__.py @@ -0,0 +1 @@ +"""Contract runtimes and adapters.""" diff --git a/flocks/contracts/access/__init__.py b/flocks/contracts/access/__init__.py new file mode 100644 index 000000000..11f54c1d9 --- /dev/null +++ b/flocks/contracts/access/__init__.py @@ -0,0 +1,12 @@ +"""Page data access contract runtime package.""" + +from flocks.contracts.access.discovery import discover_contract_plugins +from flocks.contracts.access.models import ContractRuntimeError, WebUIContractPlugin +from flocks.contracts.access.runtime import OperationRuntime + +__all__ = [ + "ContractRuntimeError", + "OperationRuntime", + "WebUIContractPlugin", + "discover_contract_plugins", +] diff --git a/flocks/contracts/access/discovery.py b/flocks/contracts/access/discovery.py new file mode 100644 index 000000000..de5fecba9 --- /dev/null +++ b/flocks/contracts/access/discovery.py @@ -0,0 +1,52 @@ +"""Plugin discovery for page data access contract providers.""" + +from __future__ import annotations + +from pathlib import Path + +from flocks.plugin import ExtensionPoint, PluginLoader +from flocks.contracts.access.models import WebUIContractPlugin + +CONTRACTS_ATTR = "CONTRACTS" + + +def discover_contract_plugins(project_dir: Path | None = None) -> tuple[WebUIContractPlugin, ...]: + plugins: list[WebUIContractPlugin] = [] + seen_plugin_ids: set[str] = set() + seen_contract_ids: set[tuple[str, str]] = set() + + def collect(items: list[WebUIContractPlugin], source: str) -> None: + for item in items: + contract_ids = {(contract.contract_id, contract.version) for contract in item.contracts} + if item.plugin_id in seen_plugin_ids: + continue + if seen_contract_ids.intersection(contract_ids): + continue + seen_plugin_ids.add(item.plugin_id) + seen_contract_ids.update(contract_ids) + plugins.append( + WebUIContractPlugin( + plugin_id=item.plugin_id, + contracts=item.contracts, + binding_resolver=item.binding_resolver, + adapter=item.adapter, + response_pipeline=item.response_pipeline, + overlay_store=item.overlay_store, + version=item.version, + source=source, + ) + ) + + PluginLoader.register_extension_point( + ExtensionPoint( + attr_name=CONTRACTS_ATTR, + subdir="contracts/access", + consumer=collect, + item_type=WebUIContractPlugin, + dedup_key=lambda plugin: plugin.plugin_id, + recursive=True, + max_depth=2, + ) + ) + PluginLoader.load_extension(CONTRACTS_ATTR, project_dir=project_dir or Path.cwd()) + return tuple(plugins) diff --git a/flocks/contracts/access/driver.py b/flocks/contracts/access/driver.py new file mode 100644 index 000000000..c9f1a3a9e --- /dev/null +++ b/flocks/contracts/access/driver.py @@ -0,0 +1,328 @@ +"""Driver proxy and builtin drivers for data access contracts.""" + +from __future__ import annotations + +import json +import re +import sqlite3 +from collections.abc import Iterable +from datetime import datetime +from pathlib import Path +from typing import Any + +from flocks.contracts.access.models import DriverResult, Predicate, QueryPlan, ContractRuntimeError + +DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$") +JSON_DATA_SUFFIXES = frozenset({".jsonl", ".json"}) + + +class DriverProxy: + def __init__( + self, + jsonl_executor: "JsonlDriverExecutor | None" = None, + sqlite_executor: "SqliteJsonDriverExecutor | None" = None, + ) -> None: + self._jsonl_executor = jsonl_executor or JsonlDriverExecutor() + self._sqlite_executor = sqlite_executor or SqliteJsonDriverExecutor() + + def execute(self, plan: QueryPlan) -> DriverResult: + self._validate_plan(plan) + if plan.binding.adapter_kind == "builtin-sqlite-json": + return self._sqlite_executor.execute(plan) + return self._jsonl_executor.execute(plan) + + def _validate_plan(self, plan: QueryPlan) -> None: + if plan.binding.adapter_kind not in {"builtin-jsonl", "builtin-sqlite-json"}: + raise ContractRuntimeError( + "adapter_sandbox_unavailable", + status_code=400, + user_message="WebUI contract adapter is not available.", + admin_message=f"Unsupported adapter kind: {plan.binding.adapter_kind}", + ) + + missing = plan.driver_projection - plan.binding.driver_available_fields + if missing: + raise ContractRuntimeError( + "policy_filter_not_enforceable", + status_code=400, + user_message="WebUI contract data source cannot provide required fields.", + admin_message=f"Driver projection contains unavailable fields: {sorted(missing)}", + ) + + +class SqliteJsonDriverExecutor: + def execute(self, plan: QueryPlan) -> DriverResult: + db_path = plan.binding.source_root + self._assert_allowed(db_path, plan.binding.driver_allowlist_roots) + if not db_path.is_file(): + raise ContractRuntimeError( + "data_source_unavailable", + status_code=404, + user_message="WebUI contract SQLite database is not available.", + admin_message=f"SQLite source does not exist: {db_path}", + ) + + options = plan.binding.driver_options + table = _sqlite_identifier(options.get("table"), "records") + record_column = _sqlite_identifier(options.get("recordColumn"), "record_json") + date_column = _sqlite_identifier(options.get("dateColumn"), "record_date") + query = f"SELECT {record_column} FROM {table}" + query_params: list[Any] = [] + start_date, end_date = JsonlDriverExecutor()._request_date_range(plan.params) + if start_date and end_date and date_column: + query += f" WHERE {date_column} BETWEEN ? AND ?" + query_params.extend([start_date, end_date]) + query += " ORDER BY rowid" + + rows: list[dict[str, Any]] = [] + seen_record_ids: set[str] = set() + total_raw = 0 + duplicates = 0 + filtered_unique = 0 + parse_errors = 0 + try: + connection = sqlite3.connect(db_path) + cursor = connection.execute(query, query_params) + raw_records = cursor.fetchall() + connection.close() + except sqlite3.Error as exc: + raise ContractRuntimeError( + "data_source_unavailable", + status_code=500, + user_message="WebUI contract SQLite database cannot be queried.", + admin_message=f"SQLite query failed for {db_path}: {exc}", + ) from exc + + for (record_value,) in raw_records: + try: + record = json.loads(record_value) if isinstance(record_value, str) else None + except json.JSONDecodeError: + record = None + if not isinstance(record, dict): + parse_errors += 1 + continue + if record.get("_type") == "file_header": + continue + + total_raw += 1 + if record.get("is_duplicate") is True: + duplicates += 1 + continue + if not self._matches_predicates(record, plan.policy_plan.driver_predicates): + continue + + record_id = _read_string(record.get("id"), "") + if record_id: + if record_id in seen_record_ids: + duplicates += 1 + continue + seen_record_ids.add(record_id) + + filtered_unique += 1 + if len(rows) < plan.limit: + rows.append( + { + field: record[field] + for field in plan.driver_projection + if field in record + } + ) + + return DriverResult( + rows=rows, + source_files=(db_path,), + total_raw=total_raw, + total_unique=max(total_raw - duplicates, 0), + duplicates=duplicates, + filtered_unique=filtered_unique, + parse_errors=parse_errors, + ) + + def _assert_allowed(self, path: Path, allowlist_roots: tuple[Path, ...]) -> None: + JsonlDriverExecutor()._assert_allowed(path, allowlist_roots) + + def _matches_predicates(self, record: dict[str, Any], predicates: tuple[Predicate, ...]) -> bool: + return JsonlDriverExecutor()._matches_predicates(record, predicates) + + +class JsonlDriverExecutor: + def execute(self, plan: QueryPlan) -> DriverResult: + files = tuple(self._resolve_source_files(plan)) + rows: list[dict[str, Any]] = [] + seen_record_ids: set[str] = set() + total_raw = 0 + duplicates = 0 + filtered_unique = 0 + parse_errors = 0 + for path in files: + self._assert_allowed(path, plan.binding.driver_allowlist_roots) + for record in self._iter_records(path): + if record is None: + parse_errors += 1 + continue + if record.get("_type") == "file_header": + continue + + total_raw += 1 + if record.get("is_duplicate") is True: + duplicates += 1 + continue + if not self._matches_predicates(record, plan.policy_plan.driver_predicates): + continue + + record_id = _read_string(record.get("id"), "") + if record_id: + if record_id in seen_record_ids: + duplicates += 1 + continue + seen_record_ids.add(record_id) + + filtered_unique += 1 + if len(rows) < plan.limit: + rows.append( + { + field: record[field] + for field in plan.driver_projection + if field in record + } + ) + + return DriverResult( + rows=rows, + source_files=files, + total_raw=total_raw, + total_unique=max(total_raw - duplicates, 0), + duplicates=duplicates, + filtered_unique=filtered_unique, + parse_errors=parse_errors, + ) + + def _resolve_source_files(self, plan: QueryPlan) -> Iterable[Path]: + root = plan.binding.source_root + if not root.is_dir(): + raise ContractRuntimeError( + "data_source_unavailable", + status_code=404, + user_message="WebUI contract source files are not available.", + admin_message=f"Source root does not exist: {root}", + ) + + all_files = sorted( + path + for path in root.rglob("*") + if path.is_file() and path.suffix.lower() in JSON_DATA_SUFFIXES and not path.name.startswith(".") + ) + if not all_files: + return () + + start_date, end_date = self._request_date_range(plan.params) + if start_date and end_date: + matched = [ + path + for path in all_files + if (file_date := _data_file_date(root, path)) and start_date <= file_date <= end_date + ] + return matched + + dated_files = [(date, path) for path in all_files if (date := _data_file_date(root, path))] + if dated_files: + latest_date = max(date for date, _path in dated_files) + return [path for date, path in dated_files if date == latest_date] + return all_files + + def _request_date_range(self, params: dict[str, Any]) -> tuple[str, str] | tuple[None, None]: + from_date = _date_from_value(params.get("from") or params.get("startDate") or params.get("date")) + to_date = _date_from_value(params.get("to") or params.get("endDate") or params.get("date")) + if from_date and not to_date: + to_date = from_date + if to_date and not from_date: + from_date = to_date + if from_date and to_date and from_date > to_date: + from_date, to_date = to_date, from_date + if from_date and to_date: + return from_date, to_date + return None, None + + def _assert_allowed(self, path: Path, allowlist_roots: tuple[Path, ...]) -> None: + resolved = path.resolve() + for root in allowlist_roots: + try: + resolved.relative_to(root.resolve()) + return + except ValueError: + continue + raise ContractRuntimeError( + "data_source_unavailable", + status_code=403, + user_message="WebUI contract data source path is not allowed.", + admin_message=f"Driver rejected path outside allowlist: {resolved}", + ) + + def _iter_records(self, path: Path) -> Iterable[dict[str, Any] | None]: + with path.open("r", encoding="utf-8") as handle: + for line in handle: + stripped = line.strip() + if not stripped: + continue + try: + value = json.loads(stripped) + except json.JSONDecodeError: + yield None + continue + yield value if isinstance(value, dict) else None + + def _matches_predicates(self, record: dict[str, Any], predicates: tuple[Predicate, ...]) -> bool: + for predicate in predicates: + value = record.get(predicate.field) + if predicate.operator == "in": + allowed = {_normalize_compare(item) for item in predicate.values} + if _normalize_compare(value) not in allowed: + return False + else: + return False + return True + + +def _data_file_date(root: Path, path: Path) -> str: + try: + parts = path.resolve().relative_to(root.resolve()).parts + except ValueError: + parts = path.parts + for part in parts[:-1]: + if DATE_RE.fullmatch(part): + return part + return "" + + +def _date_from_value(value: Any) -> str: + if value is None: + return "" + text = str(value).strip() + if DATE_RE.fullmatch(text[:10]): + return text[:10] + try: + return datetime.fromisoformat(text.replace("Z", "+00:00")).date().isoformat() + except ValueError: + return "" + + +def _normalize_compare(value: Any) -> str: + if isinstance(value, float) and value.is_integer(): + return str(int(value)) + return str(value).strip().lower() + + +def _read_string(value: Any, fallback: str) -> str: + return value if isinstance(value, str) and value else fallback + + +def _sqlite_identifier(value: Any, fallback: str) -> str: + text = str(value or fallback).strip() + if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", text): + raise ContractRuntimeError( + "data_source_unavailable", + status_code=400, + user_message="WebUI contract SQLite source is misconfigured.", + admin_message=f"Invalid SQLite identifier: {text}", + ) + return text diff --git a/flocks/contracts/access/models.py b/flocks/contracts/access/models.py new file mode 100644 index 000000000..fcf471ef7 --- /dev/null +++ b/flocks/contracts/access/models.py @@ -0,0 +1,218 @@ +"""Shared models for page data access contract operations.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Literal + +OperationType = Literal["query", "mutation"] +PredicateEnforcement = Literal["driver-required", "native-or-post-filter"] +FilterStage = Literal["driver-native", "post-filter"] + + +@dataclass(frozen=True) +class RuntimeContext: + workspace_id: str + page_id: str + slot_id: str + contract_id: str + contract_version: str + operation: str + operation_type: OperationType + request_id: str + principal_ref: str + policy_context: "PolicyContext" + binding_id: str + binding_version: int + test_mode: bool = False + + +@dataclass(frozen=True) +class PolicyContext: + tenant_ids: tuple[str, ...] = () + asset_groups: tuple[str, ...] = () + + +@dataclass(frozen=True) +class ContractOperation: + name: str + operation_type: OperationType + adapter_required_fields: frozenset[str] + identity_fields: frozenset[str] + public_fields: frozenset[str] + filter_fields: frozenset[str] = field(default_factory=frozenset) + filter_param_fields: dict[str, str] = field(default_factory=dict) + tenant_policy_field: str | None = None + asset_group_policy_field: str | None = None + cursor_fields: frozenset[str] = field(default_factory=frozenset) + sort_fields: frozenset[str] = field(default_factory=frozenset) + default_limit: int = 100 + max_limit: int = 1000 + requires_idempotency_key: bool = False + requires_expected_overlay_version: bool = False + mutation_entity_types: frozenset[str] = field(default_factory=frozenset) + + +@dataclass(frozen=True) +class Contract: + contract_id: str + version: str + page_id: str + operations: dict[str, ContractOperation] + + +@dataclass(frozen=True) +class Binding: + binding_id: str + binding_version: int + page_id: str + slot_id: str + contract_id: str + contract_version: str + adapter_kind: str + source_page_id: str + source_root: Path + driver_available_fields: frozenset[str] + driver_allowlist_roots: tuple[Path, ...] + driver_options: dict[str, Any] = field(default_factory=dict) + capabilities: frozenset[str] = frozenset({"query"}) + + +@dataclass(frozen=True) +class Predicate: + field: str + operator: str + values: tuple[Any, ...] + source: str + enforcement: PredicateEnforcement + filter_stage: FilterStage + + +@dataclass(frozen=True) +class PolicyEnforcementPlan: + policy_predicates: tuple[Predicate, ...] + frontend_predicates: tuple[Predicate, ...] + + @property + def driver_predicates(self) -> tuple[Predicate, ...]: + return tuple( + predicate + for predicate in (*self.policy_predicates, *self.frontend_predicates) + if predicate.filter_stage == "driver-native" + ) + + @property + def filter_stages_applied(self) -> list[dict[str, str]]: + return [ + { + "field": predicate.field, + "source": predicate.source, + "stage": predicate.filter_stage, + "enforcement": predicate.enforcement, + } + for predicate in (*self.policy_predicates, *self.frontend_predicates) + ] + + +@dataclass(frozen=True) +class FieldDependencyPlan: + driver_required_fields: frozenset[str] + internal_fields: frozenset[str] + identity_fields: frozenset[str] + policy_fields: frozenset[str] + cursor_fields: frozenset[str] + sort_fields: frozenset[str] + filter_fields: frozenset[str] + public_fields: frozenset[str] + + +@dataclass(frozen=True) +class QueryPlan: + context: RuntimeContext + binding: Binding + operation: ContractOperation + params: dict[str, Any] + policy_plan: PolicyEnforcementPlan + field_plan: FieldDependencyPlan + limit: int + + @property + def driver_projection(self) -> frozenset[str]: + return self.field_plan.driver_required_fields + + +@dataclass(frozen=True) +class MutationPlan: + context: RuntimeContext + binding: Binding + operation: ContractOperation + params: dict[str, Any] + entity_type: str + entity_id: str + idempotency_key: str + expected_overlay_version: int | None + write_through_enabled: bool = False + + +@dataclass(frozen=True) +class DriverResult: + rows: list[dict[str, Any]] + source_files: tuple[Path, ...] + total_raw: int + total_unique: int + duplicates: int + filtered_unique: int + parse_errors: int = 0 + + +@dataclass(frozen=True) +class InternalDataRow: + raw: dict[str, Any] + identity: dict[str, Any] + + +@dataclass(frozen=True) +class OperationResponse: + status_code: int + body: dict[str, Any] + + +@dataclass(frozen=True) +class WebUIContractPlugin: + plugin_id: str + contracts: tuple[Contract, ...] + binding_resolver: Any + adapter: Any + response_pipeline: Any + overlay_store: Any | None = None + version: str = "1.0" + source: str = "" + + +class ContractRuntimeError(Exception): + """Structured contract error that can be rendered by HTTP routes.""" + + def __init__( + self, + code: str, + *, + status_code: int = 400, + user_message: str | None = None, + admin_message: str | None = None, + request_id: str | None = None, + ) -> None: + super().__init__(admin_message or user_message or code) + self.code = code + self.status_code = status_code + self.user_message = user_message or code + self.admin_message = admin_message or self.user_message + self.request_id = request_id + + def to_detail(self) -> dict[str, Any]: + return { + "code": self.code, + "userMessage": self.user_message, + "adminMessage": self.admin_message, + "requestId": self.request_id, + } diff --git a/flocks/contracts/access/pipeline.py b/flocks/contracts/access/pipeline.py new file mode 100644 index 000000000..d60b812ba --- /dev/null +++ b/flocks/contracts/access/pipeline.py @@ -0,0 +1,143 @@ +"""Shared overlay, idempotency, and mutation pipelines.""" + +from __future__ import annotations + +import hashlib +import json +from dataclasses import dataclass +from typing import Any + +from flocks.contracts.access.models import InternalDataRow, MutationPlan, RuntimeContext, ContractRuntimeError + + +@dataclass +class OverlayEntry: + version: int + fields: dict[str, Any] + + +class OverlayStore: + def __init__(self) -> None: + self._entries: dict[tuple[str, str, str, str, str], OverlayEntry] = {} + + def merge(self, rows: list[InternalDataRow], context: RuntimeContext) -> list[InternalDataRow]: + merged: list[InternalDataRow] = [] + for row in rows: + entity_type = row.identity.get("entityType") + entity_id = row.identity.get("entityId") + if not isinstance(entity_type, str) or not isinstance(entity_id, str): + merged.append(row) + continue + entry = self._entries.get( + ( + context.page_id, + context.contract_id, + context.contract_version, + entity_type, + entity_id, + ) + ) + if entry is None: + merged.append(row) + continue + merged.append( + InternalDataRow( + raw={ + **row.raw, + **entry.fields, + "_overlay_version": entry.version, + }, + identity=row.identity, + ) + ) + return merged + + def transaction(self, plan: MutationPlan) -> OverlayEntry: + key = ( + plan.context.page_id, + plan.context.contract_id, + plan.context.contract_version, + plan.entity_type, + plan.entity_id, + ) + current = self._entries.get(key) + current_version = current.version if current else 0 + expected = plan.expected_overlay_version + if current_version == 0: + if expected not in (None, 0): + raise ContractRuntimeError("conflict", status_code=409, user_message="Overlay version conflict.") + elif expected != current_version: + raise ContractRuntimeError("conflict", status_code=409, user_message="Overlay version conflict.") + + fields = dict(current.fields) if current else {} + fields.update( + { + key: value + for key, value in plan.params.items() + if key not in {"entityType", "entityId"} + } + ) + entry = OverlayEntry(version=current_version + 1, fields=fields) + self._entries[key] = entry + return entry + + +class IdempotencyService: + def __init__(self) -> None: + self._entries: dict[tuple[str, ...], tuple[str, dict[str, Any]]] = {} + + def replay_or_conflict(self, scope: tuple[str, ...], request_hash: str) -> dict[str, Any] | None: + entry = self._entries.get(scope) + if entry is None: + return None + stored_hash, response = entry + if stored_hash != request_hash: + raise ContractRuntimeError( + "conflict", + status_code=409, + user_message="idempotencyKey was already used with different content.", + ) + return response + + def store(self, scope: tuple[str, ...], request_hash: str, response: dict[str, Any]) -> None: + self._entries[scope] = (request_hash, response) + + +class MutationPipeline: + def __init__( + self, + overlay_store: OverlayStore | None = None, + idempotency_service: IdempotencyService | None = None, + ) -> None: + self._overlay_store = overlay_store or OverlayStore() + self._idempotency_service = idempotency_service or IdempotencyService() + + def run(self, plan: MutationPlan) -> dict[str, Any]: + request_hash = hashlib.sha256( + json.dumps(plan.params, ensure_ascii=False, sort_keys=True, default=str).encode("utf-8") + ).hexdigest() + scope = ( + plan.context.page_id, + plan.context.contract_id, + plan.context.contract_version, + plan.entity_type, + plan.entity_id, + plan.idempotency_key, + ) + replay = self._idempotency_service.replay_or_conflict(scope, request_hash) + if replay is not None: + return replay + + entry = self._overlay_store.transaction(plan) + response = { + "ok": True, + "entityType": plan.entity_type, + "entityId": plan.entity_id, + "overlayVersion": entry.version, + "writeThrough": { + "enabled": plan.write_through_enabled, + "status": "not_configured", + }, + } + self._idempotency_service.store(scope, request_hash, response) + return response diff --git a/flocks/contracts/access/plans.py b/flocks/contracts/access/plans.py new file mode 100644 index 000000000..c04bbd334 --- /dev/null +++ b/flocks/contracts/access/plans.py @@ -0,0 +1,215 @@ +"""Plan compilers for page data access contract operations.""" + +from __future__ import annotations + +from typing import Any + +from flocks.contracts.access.models import ( + Binding, + ContractOperation, + FieldDependencyPlan, + MutationPlan, + PolicyContext, + PolicyEnforcementPlan, + Predicate, + QueryPlan, + RuntimeContext, + ContractRuntimeError, +) + + +class PolicyPlanCompiler: + def compile( + self, + *, + operation: ContractOperation, + binding: Binding, + policy_context: PolicyContext, + params: dict[str, Any], + ) -> PolicyEnforcementPlan: + policy_predicates: list[Predicate] = [] + if policy_context.tenant_ids and operation.tenant_policy_field: + policy_predicates.append( + Predicate( + field=operation.tenant_policy_field, + operator="in", + values=policy_context.tenant_ids, + source="policy.tenantIds", + enforcement="driver-required", + filter_stage="driver-native", + ) + ) + if policy_context.asset_groups and operation.asset_group_policy_field: + policy_predicates.append( + Predicate( + field=operation.asset_group_policy_field, + operator="in", + values=policy_context.asset_groups, + source="policy.assetGroups", + enforcement="driver-required", + filter_stage="driver-native", + ) + ) + + frontend_predicates = self._compile_frontend_predicates(operation, params) + for predicate in (*policy_predicates, *frontend_predicates): + if predicate.filter_stage != "driver-native": + continue + if predicate.field not in binding.driver_available_fields or predicate.field not in operation.filter_fields: + raise ContractRuntimeError( + "policy_filter_not_enforceable", + status_code=400, + user_message="WebUI contract data source cannot enforce one of the requested filters.", + admin_message=f"Binding {binding.binding_id} cannot enforce {predicate.field} before adapter execution.", + ) + + return PolicyEnforcementPlan( + policy_predicates=tuple(policy_predicates), + frontend_predicates=tuple(frontend_predicates), + ) + + def _compile_frontend_predicates( + self, + operation: ContractOperation, + params: dict[str, Any], + ) -> tuple[Predicate, ...]: + raw_filters = params.get("filters") + if not isinstance(raw_filters, dict): + return () + + predicates: list[Predicate] = [] + for param_name, field_name in operation.filter_param_fields.items(): + values = _coerce_values(raw_filters.get(param_name)) + if not values: + continue + predicates.append( + Predicate( + field=field_name, + operator="in", + values=values, + source=f"params.filters.{param_name}", + enforcement="native-or-post-filter", + filter_stage="driver-native", + ) + ) + return tuple(predicates) + + +class FieldDependencyPlanCompiler: + def compile(self, *, operation: ContractOperation, policy_plan: PolicyEnforcementPlan) -> FieldDependencyPlan: + policy_fields = frozenset(predicate.field for predicate in policy_plan.policy_predicates) + frontend_filter_fields = frozenset(predicate.field for predicate in policy_plan.frontend_predicates) + driver_required_fields = frozenset( + operation.adapter_required_fields + | operation.identity_fields + | policy_fields + | operation.cursor_fields + | operation.sort_fields + | frontend_filter_fields + ) + return FieldDependencyPlan( + driver_required_fields=driver_required_fields, + internal_fields=driver_required_fields, + identity_fields=operation.identity_fields, + policy_fields=policy_fields, + cursor_fields=operation.cursor_fields, + sort_fields=operation.sort_fields, + filter_fields=frontend_filter_fields, + public_fields=operation.public_fields, + ) + + +class QueryPlanCompiler: + def compile( + self, + *, + context: RuntimeContext, + binding: Binding, + operation: ContractOperation, + params: dict[str, Any], + policy_plan: PolicyEnforcementPlan, + field_plan: FieldDependencyPlan, + ) -> QueryPlan: + limit = _read_int(params.get("limit"), operation.default_limit) + limit = max(1, min(limit, operation.max_limit)) + return QueryPlan( + context=context, + binding=binding, + operation=operation, + params=params, + policy_plan=policy_plan, + field_plan=field_plan, + limit=limit, + ) + + +class MutationPlanCompiler: + def compile( + self, + *, + context: RuntimeContext, + binding: Binding, + operation: ContractOperation, + payload: dict[str, Any], + ) -> MutationPlan: + params = payload.get("params") + if not isinstance(params, dict): + raise ContractRuntimeError("invalid_request", user_message="Mutation params are required.") + + idempotency_key = payload.get("idempotencyKey") + if operation.requires_idempotency_key and not isinstance(idempotency_key, str): + raise ContractRuntimeError("idempotency_key_required", user_message="idempotencyKey is required for mutations.") + + expected_overlay_version = payload.get("expectedOverlayVersion") + if operation.requires_expected_overlay_version and "expectedOverlayVersion" not in payload: + raise ContractRuntimeError( + "overlay_version_required", + status_code=409, + user_message="expectedOverlayVersion is required for this mutation.", + ) + if expected_overlay_version is not None and not isinstance(expected_overlay_version, int): + raise ContractRuntimeError("invalid_request", user_message="expectedOverlayVersion must be an integer or null.") + + entity_type = params.get("entityType") + entity_id = params.get("entityId") + if not isinstance(entity_type, str) or not isinstance(entity_id, str): + raise ContractRuntimeError("invalid_request", user_message="entityType and entityId are required.") + if operation.mutation_entity_types and entity_type not in operation.mutation_entity_types: + raise ContractRuntimeError( + "invalid_request", + user_message="Mutation entity type is not allowed for this operation.", + ) + + return MutationPlan( + context=context, + binding=binding, + operation=operation, + params=params, + entity_type=entity_type, + entity_id=entity_id, + idempotency_key=idempotency_key or "", + expected_overlay_version=expected_overlay_version, + ) + + +def _coerce_values(value: Any) -> tuple[Any, ...]: + if value is None: + return () + if isinstance(value, (list, tuple, set)): + return tuple(item for item in value if item not in (None, "")) + if value == "": + return () + return (value,) + + +def _read_int(value: Any, fallback: int) -> int: + if isinstance(value, bool): + return fallback + if isinstance(value, int): + return value + if isinstance(value, str): + try: + return int(value) + except ValueError: + return fallback + return fallback diff --git a/flocks/contracts/access/registry.py b/flocks/contracts/access/registry.py new file mode 100644 index 000000000..895b65892 --- /dev/null +++ b/flocks/contracts/access/registry.py @@ -0,0 +1,47 @@ +"""Contract registry for discovered WebUI contract plugins.""" + +from __future__ import annotations + +from flocks.contracts.access.models import Contract, ContractRuntimeError, WebUIContractPlugin + +DEFAULT_CONTRACT_VERSION = "1.0" +DEFAULT_SLOT_ID = "primary" + +FORBIDDEN_REQUEST_FIELDS = frozenset( + { + "bindingId", + "driver", + "adapterId", + "connectionRef", + "table", + "sql", + "index", + "secret", + } +) + +QUERY_FORBIDDEN_REQUEST_FIELDS = FORBIDDEN_REQUEST_FIELDS | {"idempotencyKey"} + + +class ContractRegistry: + def __init__(self, plugins: tuple[WebUIContractPlugin, ...]) -> None: + self._contracts: dict[tuple[str, str], Contract] = {} + self._providers: dict[tuple[str, str], WebUIContractPlugin] = {} + for plugin in plugins: + for contract in plugin.contracts: + key = (contract.contract_id, contract.version) + if key in self._contracts: + raise ContractRuntimeError( + "duplicate_contract", + status_code=500, + user_message="Duplicate WebUI contract registration.", + admin_message=f"Duplicate contract {contract.contract_id}@{contract.version}", + ) + self._contracts[key] = contract + self._providers[key] = plugin + + def get(self, contract_id: str, version: str = DEFAULT_CONTRACT_VERSION) -> Contract | None: + return self._contracts.get((contract_id, version)) + + def provider_for(self, contract_id: str, version: str = DEFAULT_CONTRACT_VERSION) -> WebUIContractPlugin | None: + return self._providers.get((contract_id, version)) diff --git a/flocks/contracts/access/runtime.py b/flocks/contracts/access/runtime.py new file mode 100644 index 000000000..805cb7c3d --- /dev/null +++ b/flocks/contracts/access/runtime.py @@ -0,0 +1,291 @@ +"""Operation runtime for page data access contracts.""" + +from __future__ import annotations + +import uuid +from pathlib import Path +from typing import Any + +from flocks.auth.context import AuthUser +from flocks.contracts.access.discovery import discover_contract_plugins +from flocks.contracts.access.driver import DriverProxy +from flocks.contracts.access.models import ( + OperationResponse, + PolicyContext, + RuntimeContext, + ContractRuntimeError, + WebUIContractPlugin, +) +from flocks.contracts.access.pipeline import IdempotencyService, MutationPipeline, OverlayStore +from flocks.contracts.access.plans import ( + FieldDependencyPlanCompiler, + MutationPlanCompiler, + PolicyPlanCompiler, + QueryPlanCompiler, +) +from flocks.contracts.access.registry import ( + DEFAULT_CONTRACT_VERSION, + DEFAULT_SLOT_ID, + FORBIDDEN_REQUEST_FIELDS, + QUERY_FORBIDDEN_REQUEST_FIELDS, + ContractRegistry, +) + + +NO_POLICY_SCOPE = "__flocks_no_policy_scope__" + + +class PolicyContextResolver: + def resolve(self, principal: AuthUser | None) -> PolicyContext: + if principal is not None and principal.role == "admin": + return PolicyContext() + if principal is None: + return PolicyContext(tenant_ids=(NO_POLICY_SCOPE,), asset_groups=(NO_POLICY_SCOPE,)) + + tenant_ids = _clean_policy_values(principal.tenant_ids) + asset_groups = _clean_policy_values(principal.asset_groups) + return PolicyContext( + tenant_ids=tenant_ids or (NO_POLICY_SCOPE,), + asset_groups=asset_groups or (NO_POLICY_SCOPE,), + ) + + +class OperationRuntime: + def __init__( + self, + *, + plugins: tuple[WebUIContractPlugin, ...] | None = None, + registry: ContractRegistry | None = None, + policy_context_resolver: PolicyContextResolver | None = None, + driver_proxy: DriverProxy | None = None, + overlay_store: OverlayStore | None = None, + idempotency_service: IdempotencyService | None = None, + project_dir: Path | None = None, + ) -> None: + discovered = plugins if plugins is not None else discover_contract_plugins(project_dir=project_dir) + self._registry = registry or ContractRegistry(discovered) + self._policy_context_resolver = policy_context_resolver or PolicyContextResolver() + self._policy_plan_compiler = PolicyPlanCompiler() + self._field_plan_compiler = FieldDependencyPlanCompiler() + self._query_plan_compiler = QueryPlanCompiler() + self._mutation_plan_compiler = MutationPlanCompiler() + self._driver_proxy = driver_proxy or DriverProxy() + self._overlay_store = overlay_store or OverlayStore() + self._idempotency_service = idempotency_service or IdempotencyService() + + def execute( + self, + *, + page_id: str, + contract_id: str, + operation_name: str, + payload: dict[str, Any] | None, + principal: AuthUser | None, + contract_version: str = DEFAULT_CONTRACT_VERSION, + slot_id: str = DEFAULT_SLOT_ID, + test_mode: bool = False, + ) -> OperationResponse: + request_id = f"req-{uuid.uuid4().hex}" + contract = self._registry.get(contract_id, contract_version) + provider = self._registry.provider_for(contract_id, contract_version) + if contract is None or contract.page_id != page_id: + raise ContractRuntimeError( + "contract_not_found", + status_code=404, + user_message="WebUI contract is not available.", + request_id=request_id, + ) + if provider is None: + raise ContractRuntimeError( + "contract_provider_not_found", + status_code=404, + user_message="WebUI contract provider is not available.", + request_id=request_id, + ) + operation = contract.operations.get(operation_name) + if operation is None: + raise ContractRuntimeError( + "operation_not_found", + status_code=404, + user_message="WebUI contract operation is not available.", + request_id=request_id, + ) + + body = payload or {} + if not isinstance(body, dict): + raise ContractRuntimeError("invalid_request", user_message="Operation body must be an object.", request_id=request_id) + + forbidden = QUERY_FORBIDDEN_REQUEST_FIELDS if operation.operation_type == "query" else FORBIDDEN_REQUEST_FIELDS + forbidden_path = _find_forbidden_field(body, forbidden) + if forbidden_path: + raise ContractRuntimeError( + "forbidden_request_field", + status_code=400, + user_message="Request contains fields that pages are not allowed to submit.", + admin_message=f"Forbidden request field: {forbidden_path}", + request_id=request_id, + ) + + params = body.get("params", {}) + if operation.operation_type == "query" and not isinstance(params, dict): + raise ContractRuntimeError("invalid_request", user_message="Query params must be an object.", request_id=request_id) + + binding = provider.binding_resolver.resolve( + page_id=page_id, + slot_id=slot_id, + contract_id=contract.contract_id, + contract_version=contract.version, + ) + if operation.operation_type not in binding.capabilities: + raise ContractRuntimeError( + "operation_not_supported", + status_code=400, + user_message="WebUI contract data source does not support this operation.", + admin_message=( + f"Binding {binding.binding_id} capabilities do not include " + f"{operation.operation_type}." + ), + request_id=request_id, + ) + policy_context = self._policy_context_resolver.resolve(principal) + context = RuntimeContext( + workspace_id="default", + page_id=page_id, + slot_id=slot_id, + contract_id=contract.contract_id, + contract_version=contract.version, + operation=operation.name, + operation_type=operation.operation_type, + request_id=request_id, + principal_ref=_principal_ref(principal), + policy_context=policy_context, + binding_id=binding.binding_id, + binding_version=binding.binding_version, + test_mode=test_mode, + ) + + try: + if operation.operation_type == "query": + return self._execute_query( + context=context, + binding=binding, + operation=operation, + params=params, + provider=provider, + ) + return self._execute_mutation( + context=context, + binding=binding, + operation=operation, + payload=body, + provider=provider, + ) + except ContractRuntimeError as exc: + if exc.request_id is None: + exc.request_id = request_id + raise + + def _execute_query( + self, + *, + context, + binding, + operation, + params: dict[str, Any], + provider: WebUIContractPlugin, + ) -> OperationResponse: + policy_plan = self._policy_plan_compiler.compile( + operation=operation, + binding=binding, + policy_context=context.policy_context, + params=params, + ) + field_plan = self._field_plan_compiler.compile(operation=operation, policy_plan=policy_plan) + query_plan = self._query_plan_compiler.compile( + context=context, + binding=binding, + operation=operation, + params=params, + policy_plan=policy_plan, + field_plan=field_plan, + ) + driver_result = self._driver_proxy.execute(query_plan) + internal_rows = provider.adapter.normalize(driver_result) + body = provider.response_pipeline.run_query( + context=context, + binding_source_page_id=binding.source_page_id, + driver_result=driver_result, + rows=internal_rows, + filter_stages_applied=policy_plan.filter_stages_applied, + ) + return OperationResponse(status_code=200, body=body) + + def _execute_mutation( + self, + *, + context, + binding, + operation, + payload: dict[str, Any], + provider: WebUIContractPlugin, + ) -> OperationResponse: + mutation_plan = self._mutation_plan_compiler.compile( + context=context, + binding=binding, + operation=operation, + payload=payload, + ) + pipeline = MutationPipeline( + provider.overlay_store or self._overlay_store, + self._idempotency_service, + ) + return OperationResponse(status_code=200, body=pipeline.run(mutation_plan)) + + +class BindingTestHarness: + def __init__(self, runtime: OperationRuntime | None = None) -> None: + self._runtime = runtime or OperationRuntime() + + def run(self, *, page_id: str, contract_id: str, operation_name: str, profiles: tuple[AuthUser | None, ...]) -> list[dict[str, Any]]: + results: list[dict[str, Any]] = [] + for profile in profiles: + try: + response = self._runtime.execute( + page_id=page_id, + contract_id=contract_id, + operation_name=operation_name, + payload={"params": {"limit": 1}}, + principal=profile, + test_mode=True, + ) + results.append({"ok": True, "statusCode": response.status_code}) + except ContractRuntimeError as exc: + results.append({"ok": False, "statusCode": exc.status_code, "error": exc.to_detail()}) + return results + + +def _find_forbidden_field(value: Any, forbidden: frozenset[str], path: str = "") -> str: + if isinstance(value, dict): + for key, item in value.items(): + current = f"{path}.{key}" if path else str(key) + if key in forbidden: + return current + nested = _find_forbidden_field(item, forbidden, current) + if nested: + return nested + elif isinstance(value, list): + for index, item in enumerate(value): + nested = _find_forbidden_field(item, forbidden, f"{path}[{index}]") + if nested: + return nested + return "" + + +def _principal_ref(principal: AuthUser | None) -> str: + if principal is None: + return "principal:anonymous" + return f"principal:user:{principal.id}" + + +def _clean_policy_values(values: tuple[str, ...]) -> tuple[str, ...]: + return tuple(value.strip() for value in values if isinstance(value, str) and value.strip()) diff --git a/flocks/contracts/access/sources.py b/flocks/contracts/access/sources.py new file mode 100644 index 000000000..e926c3475 --- /dev/null +++ b/flocks/contracts/access/sources.py @@ -0,0 +1,30 @@ +"""Reusable source helpers for page data access contract bindings.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + +from flocks.contracts.webui.store import WebUIPagesStore + + +@dataclass(frozen=True) +class WebUIPageAssetSource: + page_id: str + root: Path + allowlist_roots: tuple[Path, ...] + + +class WebUIPageAssetSourceResolver: + """Resolve a WebUI page assets directory as a contract data source.""" + + def __init__(self, store: WebUIPagesStore | None = None) -> None: + self._store = store or WebUIPagesStore() + + def resolve(self, page_id: str) -> WebUIPageAssetSource: + root = self._store.asset_path(page_id, "") + return WebUIPageAssetSource( + page_id=page_id, + root=root, + allowlist_roots=(root,), + ) diff --git a/flocks/contracts/webui/__init__.py b/flocks/contracts/webui/__init__.py new file mode 100644 index 000000000..8957b9ff4 --- /dev/null +++ b/flocks/contracts/webui/__init__.py @@ -0,0 +1,6 @@ +"""User-space WebUI page runtime.""" + +from flocks.contracts.webui.store import WebUIPagesStore +from flocks.contracts.webui.watcher import WebUIPagesWatcher + +__all__ = ["WebUIPagesStore", "WebUIPagesWatcher"] diff --git a/flocks/user_defined_pages/api_runtime.py b/flocks/contracts/webui/api_runtime.py similarity index 93% rename from flocks/user_defined_pages/api_runtime.py rename to flocks/contracts/webui/api_runtime.py index 0b04b60cd..70cedbe97 100644 --- a/flocks/user_defined_pages/api_runtime.py +++ b/flocks/contracts/webui/api_runtime.py @@ -1,4 +1,4 @@ -"""Page-scoped API runtime for user-defined pages.""" +"""Page-scoped API runtime for WebUI page plugins.""" from __future__ import annotations @@ -20,11 +20,11 @@ from fastapi.responses import JSONResponse, Response from starlette.requests import ClientDisconnect -from flocks.user_defined_pages.models import UserDefinedPageApiMeta -from flocks.user_defined_pages.store import UserDefinedPagesStore +from flocks.contracts.webui.models import WebUIPageApiMeta +from flocks.contracts.webui.store import WebUIPagesStore from flocks.utils.log import Log -log = Log.create(service="user-defined-pages-api-runtime") +log = Log.create(service="webui-page-api-runtime") _ALLOWED_METHODS = {"GET", "POST", "PUT", "PATCH", "DELETE"} _DEFAULT_TIMEOUT_MS = 5000 @@ -60,11 +60,11 @@ class _PageRuntime: loaded_at: int -class UserDefinedPageApiRuntime: +class WebUIPageApiRuntime: """Load and dispatch api/routes.yaml + api/handlers.py for a page.""" - def __init__(self, store: Optional[UserDefinedPagesStore] = None) -> None: - self._store = store or UserDefinedPagesStore() + def __init__(self, store: Optional[WebUIPagesStore] = None) -> None: + self._store = store or WebUIPagesStore() self._cache: dict[str, _PageRuntime] = {} self._lock = asyncio.Lock() @@ -104,7 +104,7 @@ async def dispatch(self, page_id: str, api_path: str, request: Request, user: An await self._guard_request_size(request) except ClientDisconnect: log.info( - "user_defined_pages.api.client_disconnected", + "webui_pages.api.client_disconnected", {"pageId": page_id, "method": request.method, "path": request.url.path}, ) return Response(status_code=_CLIENT_CLOSED_REQUEST_STATUS) @@ -129,7 +129,7 @@ async def dispatch(self, page_id: str, api_path: str, request: Request, user: An raise except Exception as exc: await self._mark_failed(page_id, f"handler execution failed: {exc}") - log.warning("user_defined_pages.api.handler_failed", {"pageId": page_id, "error": str(exc)}) + log.warning("webui_pages.api.handler_failed", {"pageId": page_id, "error": str(exc)}) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="page api execution failed") from exc try: @@ -187,7 +187,7 @@ async def _load_page_runtime(self, page_id: str, *, force_reload: bool) -> _Page self._cache[page_id] = runtime self._store.write_api_meta( page_id, - UserDefinedPageApiMeta( + WebUIPageApiMeta( status="ready", loadedAt=runtime.loaded_at, error=None, @@ -221,7 +221,7 @@ def _compile_runtime( if not isinstance(route_items, list): raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="routes.yaml must contain a routes list") - module_name = f"flocks_user_defined_page_{page_id}_{int(time.time() * 1000)}" + module_name = f"flocks_webui_page_{page_id}_{int(time.time() * 1000)}" spec = importlib.util.spec_from_file_location(module_name, handlers_path) if spec is None or spec.loader is None: raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="failed to load handlers.py") @@ -345,7 +345,7 @@ def _normalize_response(self, result: Any) -> Response: return JSONResponse(content=result) def _create_context(self, page_id: str, user: Any) -> Any: - logger = Log.create(service=f"user-defined-page-api:{page_id}") + logger = Log.create(service=f"webui-page-api:{page_id}") return SimpleNamespace( page_id=page_id, user=user, @@ -357,7 +357,7 @@ def _create_context(self, page_id: str, user: Any) -> Any: async def _mark_failed(self, page_id: str, error: str) -> None: self._store.write_api_meta( page_id, - UserDefinedPageApiMeta( + WebUIPageApiMeta( status="failed", loadedAt=int(time.time() * 1000), error=(error or "page api runtime failed")[:2000], @@ -374,11 +374,11 @@ def get(self, key: str, default: Any = None) -> Any: return default if value is None else value -_runtime: Optional[UserDefinedPageApiRuntime] = None +_runtime: Optional[WebUIPageApiRuntime] = None -def get_api_runtime() -> UserDefinedPageApiRuntime: +def get_api_runtime() -> WebUIPageApiRuntime: global _runtime if _runtime is None: - _runtime = UserDefinedPageApiRuntime() + _runtime = WebUIPageApiRuntime() return _runtime diff --git a/flocks/user_defined_pages/bootstrap.py b/flocks/contracts/webui/bootstrap.py similarity index 60% rename from flocks/user_defined_pages/bootstrap.py rename to flocks/contracts/webui/bootstrap.py index 76fc3100c..2395c6a37 100644 --- a/flocks/user_defined_pages/bootstrap.py +++ b/flocks/contracts/webui/bootstrap.py @@ -1,29 +1,29 @@ -"""Startup reconciliation for user-defined pages.""" +"""Startup reconciliation for WebUI page plugins.""" from __future__ import annotations from pathlib import Path from typing import Optional -from flocks.user_defined_pages.api_runtime import UserDefinedPageApiRuntime -from flocks.user_defined_pages.builder import RUNTIME_NAME, RUNTIME_VERSION, UserDefinedPagesBuilder -from flocks.user_defined_pages.store import UserDefinedPagesStore +from flocks.contracts.webui.api_runtime import WebUIPageApiRuntime +from flocks.contracts.webui.builder import RUNTIME_NAME, RUNTIME_VERSION, SDK_IMPORT_NAME, WebUIPageBuilder +from flocks.contracts.webui.store import WebUIPagesStore from flocks.utils.log import Log -log = Log.create(service="user-defined-pages-bootstrap") +log = Log.create(service="webui-pages-bootstrap") _SOURCE_SUFFIXES = {".ts", ".tsx", ".js", ".jsx", ".css", ".json"} -async def reconcile_user_defined_pages( +async def reconcile_webui_pages( *, - store: Optional[UserDefinedPagesStore] = None, - builder: Optional[UserDefinedPagesBuilder] = None, - runtime: Optional[UserDefinedPageApiRuntime] = None, + store: Optional[WebUIPagesStore] = None, + builder: Optional[WebUIPageBuilder] = None, + runtime: Optional[WebUIPageApiRuntime] = None, ) -> None: - store = store or UserDefinedPagesStore() - builder = builder or UserDefinedPagesBuilder(store) - runtime = runtime or UserDefinedPageApiRuntime(store) + store = store or WebUIPagesStore() + builder = builder or WebUIPageBuilder(store) + runtime = runtime or WebUIPageApiRuntime(store) store.ensure_root() for page in store.list_pages(enabled_only=False): @@ -35,7 +35,7 @@ async def reconcile_user_defined_pages( try: manifest = store.get_page(page_id).manifest except Exception as exc: - log.warning("user_defined_pages.bootstrap.skip_invalid_manifest", {"pageId": page_id, "error": str(exc)}) + log.warning("webui_pages.bootstrap.skip_invalid_manifest", {"pageId": page_id, "error": str(exc)}) continue if not manifest.enabled: continue @@ -45,29 +45,31 @@ async def reconcile_user_defined_pages( meta = builder.build(page_id) if meta.status != "ready": log.warning( - "user_defined_pages.bootstrap.rebuild_failed", + "webui_pages.bootstrap.rebuild_failed", {"pageId": page_id, "error": meta.error or "build failed"}, ) except Exception as exc: - log.warning("user_defined_pages.bootstrap.rebuild_error", {"pageId": page_id, "error": str(exc)}) + log.warning("webui_pages.bootstrap.rebuild_error", {"pageId": page_id, "error": str(exc)}) try: if store.routes_path(page_id).is_file(): # Warm up page API runtime so restart/upgrade immediately serves APIs. await runtime.reload_page(page_id) except Exception as exc: - log.warning("user_defined_pages.bootstrap.api_preload_failed", {"pageId": page_id, "error": str(exc)}) + log.warning("webui_pages.bootstrap.api_preload_failed", {"pageId": page_id, "error": str(exc)}) -def _should_rebuild_page(store: UserDefinedPagesStore, page_id: str) -> bool: +def _should_rebuild_page(store: WebUIPagesStore, page_id: str) -> bool: bundle_path = store.bundle_path(page_id) build_meta = store.read_build_meta(page_id) if not bundle_path.is_file(): return True - if build_meta.status == "failed": + if build_meta.status != "ready": return True if build_meta.runtime != RUNTIME_NAME or build_meta.runtimeVersion != RUNTIME_VERSION: return True + if build_meta.sdkImport != SDK_IMPORT_NAME: + return True return _sources_newer_than_bundle(store.page_dir(page_id), bundle_path) diff --git a/flocks/user_defined_pages/builder.py b/flocks/contracts/webui/builder.py similarity index 77% rename from flocks/user_defined_pages/builder.py rename to flocks/contracts/webui/builder.py index 895d7486c..1c6b02c9d 100644 --- a/flocks/user_defined_pages/builder.py +++ b/flocks/contracts/webui/builder.py @@ -1,4 +1,4 @@ -"""Build user-defined page TSX sources into browser-loadable ESM bundles.""" +"""Build WebUI page TSX sources into browser-loadable ESM bundles.""" from __future__ import annotations @@ -10,23 +10,23 @@ from pathlib import Path from typing import Optional -from flocks.user_defined_pages.models import UserDefinedPageBuildMeta -from flocks.user_defined_pages.store import UserDefinedPagesStore +from flocks.contracts.webui.models import WebUIPageBuildMeta +from flocks.contracts.webui.store import WEBUI_CONTRACT_SDK_IMPORT, WebUIPagesStore from flocks.utils.log import Log -log = Log.create(service="user-defined-pages-builder") +log = Log.create(service="webui-pages-builder") MAX_OUTPUT_BYTES = 2_000_000 BUILD_TIMEOUT_SECONDS = 30 _SHIMS_DIR = Path(__file__).resolve().parent / "shims" -RUNTIME_NAME = "user_defined_page" +RUNTIME_NAME = "webui_page" RUNTIME_VERSION = 1 -SDK_IMPORT_NAME = "@flocks/user-defined-page-sdk" +SDK_IMPORT_NAME = WEBUI_CONTRACT_SDK_IMPORT def _repo_root() -> Path: - # flocks/user_defined_pages/builder.py -> repo root is parents[2] - return Path(__file__).resolve().parents[2] + # flocks/contracts/webui/builder.py -> repo root is parents[3]. + return Path(__file__).resolve().parents[3] def resolve_esbuild_bin() -> Optional[Path]: @@ -42,16 +42,16 @@ def resolve_esbuild_bin() -> Optional[Path]: return None -class UserDefinedPagesBuilder: +class WebUIPageBuilder: """Compile a page entry file into dist/page.js.""" - def __init__(self, store: Optional[UserDefinedPagesStore] = None) -> None: - self._store = store or UserDefinedPagesStore() + def __init__(self, store: Optional[WebUIPagesStore] = None) -> None: + self._store = store or WebUIPagesStore() - def build(self, page_id: str) -> UserDefinedPageBuildMeta: + def build(self, page_id: str) -> WebUIPageBuildMeta: page_id = self._store.validate_page_id(page_id) detail = self._store.get_page(page_id) - page_dir = self._store.page_dir(page_id) + page_dir = self._store.writable_page_dir(page_id) entry = detail.manifest.entry.replace("\\", "/") entry_path = (page_dir / entry).resolve() try: @@ -69,7 +69,7 @@ def build(self, page_id: str) -> UserDefinedPageBuildMeta: dist_dir.mkdir(parents=True, exist_ok=True) outfile = dist_dir / "page.js" - building = UserDefinedPageBuildMeta(status="building", hash="", builtAt=0, error=None) + building = WebUIPageBuildMeta(status="building", hash="", builtAt=0, error=None) self._store.write_build_meta(page_id, building) cmd = [ @@ -83,7 +83,7 @@ def build(self, page_id: str) -> UserDefinedPageBuildMeta: "--jsx=automatic", f"--alias:react={_SHIMS_DIR / 'react.js'}", f"--alias:react/jsx-runtime={_SHIMS_DIR / 'jsx-runtime.js'}", - f"--alias:@flocks/user-defined-page-sdk={_SHIMS_DIR / 'sdk.js'}", + f"--alias:{WEBUI_CONTRACT_SDK_IMPORT}={_SHIMS_DIR / 'sdk.js'}", ] env = os.environ.copy() @@ -98,7 +98,7 @@ def build(self, page_id: str) -> UserDefinedPageBuildMeta: check=False, ) except subprocess.TimeoutExpired as exc: - meta = UserDefinedPageBuildMeta( + meta = WebUIPageBuildMeta( status="failed", hash="", builtAt=int(time.time() * 1000), @@ -112,7 +112,7 @@ def build(self, page_id: str) -> UserDefinedPageBuildMeta: if result.returncode != 0: stderr = (result.stderr or result.stdout or "esbuild failed").strip() - meta = UserDefinedPageBuildMeta( + meta = WebUIPageBuildMeta( status="failed", hash="", builtAt=int(time.time() * 1000), @@ -122,11 +122,11 @@ def build(self, page_id: str) -> UserDefinedPageBuildMeta: sdkImport=SDK_IMPORT_NAME, ) self._store.write_build_meta(page_id, meta) - log.warning("user_defined_pages.build.failed", {"pageId": page_id, "error": stderr[:500]}) + log.warning("webui_pages.build.failed", {"pageId": page_id, "error": stderr[:500]}) return meta if not outfile.is_file(): - meta = UserDefinedPageBuildMeta( + meta = WebUIPageBuildMeta( status="failed", hash="", builtAt=int(time.time() * 1000), @@ -141,7 +141,7 @@ def build(self, page_id: str) -> UserDefinedPageBuildMeta: content = outfile.read_bytes() if len(content) > MAX_OUTPUT_BYTES: outfile.unlink(missing_ok=True) - meta = UserDefinedPageBuildMeta( + meta = WebUIPageBuildMeta( status="failed", hash="", builtAt=int(time.time() * 1000), @@ -154,7 +154,7 @@ def build(self, page_id: str) -> UserDefinedPageBuildMeta: return meta digest = hashlib.sha256(content).hexdigest()[:16] - meta = UserDefinedPageBuildMeta( + meta = WebUIPageBuildMeta( status="ready", hash=digest, builtAt=int(time.time() * 1000), @@ -164,5 +164,5 @@ def build(self, page_id: str) -> UserDefinedPageBuildMeta: sdkImport=SDK_IMPORT_NAME, ) self._store.write_build_meta(page_id, meta) - log.info("user_defined_pages.build.ready", {"pageId": page_id, "hash": digest}) + log.info("webui_pages.build.ready", {"pageId": page_id, "hash": digest}) return meta diff --git a/flocks/contracts/webui/models.py b/flocks/contracts/webui/models.py new file mode 100644 index 000000000..863745fc0 --- /dev/null +++ b/flocks/contracts/webui/models.py @@ -0,0 +1,131 @@ +"""Pydantic models for WebUI page plugins.""" + +from __future__ import annotations + +from typing import Literal, Optional + +from pydantic import BaseModel, ConfigDict, Field + + +class WebUIPageManifest(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + id: str = Field(..., description="Stable page identifier") + title: str = Field(..., description="Navigation label") + route: str = Field(..., description="WebUI route path") + icon: str = Field("LayoutDashboard", description="Lucide icon name") + order: int = Field(100, description="Sort order in navigation") + enabled: bool = Field(True, description="Whether page appears in navigation") + placement: Literal["home.after"] = Field( + "home.after", + description="Where to insert the nav item", + ) + entry: str = Field("src/index.tsx", description="Source entry relative to page dir") + updatedAt: int = Field(0, description="Last manifest update timestamp (ms)") + + +class WebUIWorkspaceManifest(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + id: str = Field(..., description="Stable workspace identifier") + title: str = Field(..., description="Navigation label") + icon: str = Field("LayoutDashboard", description="Lucide icon name") + order: int = Field(100, description="Sort order in navigation") + enabled: bool = Field(True, description="Whether workspace appears in navigation") + placement: Literal["sceneWorkspace", "aiWorkbench"] = Field( + "sceneWorkspace", + description="Where to insert the nav item", + ) + defaultPageId: Optional[str] = Field(None, description="Preferred default page id", alias="defaultPageId") + sections: list["WebUIWorkspaceSectionManifest"] = Field( + default_factory=list, + description="Workspace navigation sections", + ) + + +class WebUIWorkspaceSectionManifest(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + id: str = Field(..., description="Stable section identifier") + label: str = Field(..., description="Section label") + pageIds: list[str] = Field( + default_factory=list, + description="Page ids in this section", + alias="pageIds", + ) + defaultPageId: Optional[str] = Field( + None, + description="Preferred default page id for this section", + alias="defaultPageId", + ) + contentPadding: Literal["comfortable", "none"] = Field( + "comfortable", + description="Whether the host should add standard page padding", + alias="contentPadding", + ) + themeOverride: Optional[Literal["light", "dark"]] = Field( + None, + description="Temporary theme override while viewing pages in this section", + alias="themeOverride", + ) + + +class WebUIPageBuildMeta(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + hash: str = Field("", description="Content hash for cache busting") + builtAt: int = Field(0, description="Build timestamp (ms)") + status: Literal["idle", "building", "ready", "failed"] = Field("idle") + error: Optional[str] = Field(None, description="Last build error message") + runtime: str = Field("webui_page", description="Builder runtime marker") + runtimeVersion: int = Field(1, description="Builder runtime version") + sdkImport: str = Field("@flocks/webui-contract-sdk", description="SDK import marker") + + +class WebUIPageApiMeta(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + status: Literal["idle", "ready", "failed"] = Field("idle") + loadedAt: int = Field(0, description="Runtime load timestamp (ms)") + error: Optional[str] = Field(None, description="Last API runtime error") + routes: list[dict[str, str]] = Field(default_factory=list, description="Loaded route descriptors") + + +class WebUIPageListItem(BaseModel): + model_config = ConfigDict(populate_by_name=True, by_alias=True) + + id: str + title: str + route: str + icon: str + order: int + enabled: bool + placement: str + buildHash: str = Field("", alias="buildHash") + buildStatus: str = Field("idle", alias="buildStatus") + workspaceId: Optional[str] = Field(None, alias="workspaceId") + workspaceTitle: Optional[str] = Field(None, alias="workspaceTitle") + workspaceRoute: Optional[str] = Field(None, alias="workspaceRoute") + + +class WebUIWorkspaceListItem(BaseModel): + model_config = ConfigDict(populate_by_name=True, by_alias=True) + + id: str + title: str + route: str + icon: str + order: int + enabled: bool + placement: str + defaultPageId: Optional[str] = Field(None, alias="defaultPageId") + sections: list[WebUIWorkspaceSectionManifest] = Field(default_factory=list) + pages: list[WebUIPageListItem] = Field(default_factory=list) + + +class WebUIPageDetail(BaseModel): + model_config = ConfigDict(populate_by_name=True, by_alias=True) + + manifest: WebUIPageManifest + build: WebUIPageBuildMeta + sourceFiles: list[str] = Field(default_factory=list, alias="sourceFiles") diff --git a/flocks/contracts/webui/shims/jsx-runtime.js b/flocks/contracts/webui/shims/jsx-runtime.js new file mode 100644 index 000000000..715c020fd --- /dev/null +++ b/flocks/contracts/webui/shims/jsx-runtime.js @@ -0,0 +1,7 @@ +const runtime = globalThis.__FLOCKS_WEBUI_CONTRACT_SDK__; +if (!runtime?.jsx || !runtime?.jsxs) { + throw new Error('Flocks WebUI page runtime is not initialized (missing jsx runtime).'); +} +export const jsx = runtime.jsx; +export const jsxs = runtime.jsxs; +export const Fragment = runtime.React.Fragment; diff --git a/flocks/user_defined_pages/shims/react.js b/flocks/contracts/webui/shims/react.js similarity index 77% rename from flocks/user_defined_pages/shims/react.js rename to flocks/contracts/webui/shims/react.js index 94ddbe5eb..ada1ba04d 100644 --- a/flocks/user_defined_pages/shims/react.js +++ b/flocks/contracts/webui/shims/react.js @@ -1,6 +1,6 @@ -const React = globalThis.__FLOCKS_USER_DEFINED_PAGE_SDK__?.React; +const React = globalThis.__FLOCKS_WEBUI_CONTRACT_SDK__?.React; if (!React) { - throw new Error('Flocks user-defined page runtime is not initialized (missing React).'); + throw new Error('Flocks WebUI page runtime is not initialized (missing React).'); } export default React; export const { diff --git a/flocks/contracts/webui/shims/sdk.js b/flocks/contracts/webui/shims/sdk.js new file mode 100644 index 000000000..61291307b --- /dev/null +++ b/flocks/contracts/webui/shims/sdk.js @@ -0,0 +1,8 @@ +const sdk = globalThis.__FLOCKS_WEBUI_CONTRACT_SDK__; +if (!sdk) { + throw new Error('Flocks WebUI page runtime is not initialized (missing SDK).'); +} +export const api = sdk.api; +export const contract = sdk.api.contract; +export const Card = sdk.Card; +export const useCurrentUser = sdk.useCurrentUser; diff --git a/flocks/contracts/webui/store.py b/flocks/contracts/webui/store.py new file mode 100644 index 000000000..cec1111bc --- /dev/null +++ b/flocks/contracts/webui/store.py @@ -0,0 +1,767 @@ +"""Filesystem store for WebUI pages.""" + +from __future__ import annotations + +import json +import os +import re +import shutil +import time +from pathlib import Path +from typing import Any, Optional + +from flocks.contracts.webui.models import ( + WebUIPageApiMeta, + WebUIPageBuildMeta, + WebUIPageDetail, + WebUIPageListItem, + WebUIPageManifest, + WebUIWorkspaceListItem, + WebUIWorkspaceManifest, +) +from flocks.utils.log import Log + +log = Log.create(service="webui-pages-store") + +PAGE_ID_RE = re.compile(r"^[a-z0-9][a-z0-9-]*$") +WORKSPACE_ID_RE = re.compile(r"^[a-z0-9][a-z0-9_]*$") +MAX_SOURCE_FILE_BYTES = 512_000 +ALLOWED_WRITE_PREFIXES = ("src/", "assets/", "api/") +ALLOWED_WRITE_FILES = frozenset({"manifest.json"}) +WORKSPACE_MANIFEST_FILE = "workspace.json" +_SOURCE_SUFFIXES = {".tsx", ".ts", ".jsx", ".js", ".css", ".json"} +_API_SUFFIXES = {".py", ".yaml", ".yml"} +_MIGRATION_TEXT_SUFFIXES = _SOURCE_SUFFIXES | _API_SUFFIXES +_PROJECT_ROOT_UNSET = object() +_LEGACY_ROOT_UNSET = object() +WEBUI_CONTRACT_ROUTE_PREFIX = "/contracts/webui" +WEBUI_CONTRACT_SDK_IMPORT = "@flocks/webui-contract-sdk" +LEGACY_WEBUI_PAGE_ROUTE_PREFIX = "/user-defined-pages" +LEGACY_WEBUI_PAGE_SDK_IMPORT = "@flocks/user-defined-page-sdk" +LEGACY_WEBUI_PAGE_SDK_GLOBAL = "__FLOCKS_USER_DEFINED_PAGE_SDK__" +WEBUI_CONTRACT_SDK_GLOBAL = "__FLOCKS_WEBUI_CONTRACT_SDK__" + + +def _default_page_tsx(title: str) -> str: + safe_title = title.replace("\\", "\\\\").replace('"', '\\"') + return f"""import {{ useEffect, useState }} from 'react'; +import {{ Card }} from '{WEBUI_CONTRACT_SDK_IMPORT}'; + +export default function Page() {{ + const [ready, setReady] = useState(false); + + useEffect(() => {{ + setReady(true); + }}, []); + + return ( + + {{ready ? 'Ready' : 'Loading...'}} + + ); +}} +""" + +_DEFAULT_INDEX_TSX = """import Page from './Page'; + +export default Page; +""" + + +def webui_contract_page_route(page_id: str) -> str: + return f"{WEBUI_CONTRACT_ROUTE_PREFIX}/{page_id}" + + +def webui_contract_workspace_route(workspace_id: str, page_id: Optional[str] = None) -> str: + base = f"{WEBUI_CONTRACT_ROUTE_PREFIX}/workspaces/{workspace_id}" + return f"{base}/{page_id}" if page_id else base + + +def get_webui_pages_root() -> Path: + """Return the canonical user-space write root for WebUI pages.""" + override = os.environ.get("FLOCKS_CONTRACTS_WEBUI_ROOT") + if override: + return Path(override).expanduser().resolve() + return (Path.home() / ".flocks" / "plugins" / "contracts" / "webui").resolve() + + +def get_legacy_webui_pages_root() -> Path: + """Return the legacy user-space root used before WebUI contracts.""" + override = os.environ.get("FLOCKS_USER_DEFINED_PAGES_ROOT") + if override: + return Path(override).expanduser().resolve() + return (Path.home() / ".flocks" / "plugins" / "user_defined_pages").resolve() + + +def get_project_webui_pages_root(project_dir: Optional[Path] = None) -> Path: + """Return the project-space read root for checked-in WebUI pages.""" + base = project_dir or Path.cwd() + return (base / ".flocks" / "plugins" / "contracts" / "webui").resolve() + + +class WebUIPagesStore: + """CRUD and scan helpers for user-space WebUI pages.""" + + def __init__( + self, + root: Optional[Path] = None, + *, + project_root: Optional[Path] | object = _PROJECT_ROOT_UNSET, + legacy_root: Optional[Path] | object = _LEGACY_ROOT_UNSET, + project_dir: Optional[Path] = None, + ) -> None: + env_override = os.environ.get("FLOCKS_CONTRACTS_WEBUI_ROOT") + self._root = (root or get_webui_pages_root()).resolve() + if project_root is _PROJECT_ROOT_UNSET: + project_root = None if root is not None or env_override else get_project_webui_pages_root(project_dir) + if legacy_root is _LEGACY_ROOT_UNSET: + legacy_env = os.environ.get("FLOCKS_USER_DEFINED_PAGES_ROOT") + legacy_root = get_legacy_webui_pages_root() if legacy_env or (root is None and not env_override) else None + self._project_root = project_root.resolve() if isinstance(project_root, Path) else None + self._legacy_root = legacy_root.resolve() if isinstance(legacy_root, Path) else None + self._read_roots = self._dedupe_roots(self._root, self._legacy_root, self._project_root) + self._legacy_migration_done = False + + @property + def root(self) -> Path: + return self._root + + @property + def read_roots(self) -> tuple[Path, ...]: + return self._read_roots + + def ensure_root(self) -> Path: + self._root.mkdir(parents=True, exist_ok=True) + self._migrate_legacy_pages() + return self._root + + @staticmethod + def validate_page_id(page_id: str) -> str: + normalized = (page_id or "").strip().lower() + if not PAGE_ID_RE.fullmatch(normalized): + raise ValueError("invalid page id: use lowercase letters, numbers, and hyphens") + return normalized + + @staticmethod + def validate_workspace_id(workspace_id: str) -> str: + normalized = (workspace_id or "").strip().lower() + if not WORKSPACE_ID_RE.fullmatch(normalized): + raise ValueError("invalid workspace id: use lowercase letters, numbers, and underscores") + return normalized + + def page_dir(self, page_id: str) -> Path: + page_id = self.validate_page_id(page_id) + existing = self._find_page_dir(page_id) + if existing is not None: + return existing + return self._page_dir_in_root(self._root, page_id) + + def root_page_dir(self, page_id: str) -> Path: + """Return the canonical user-root path for a page without copying.""" + return self._page_dir_in_root(self._root, self.validate_page_id(page_id)) + + def writable_page_dir(self, page_id: str) -> Path: + """Return a writable page directory, copying read-only pages on write.""" + page_id = self.validate_page_id(page_id) + self.ensure_root() + target = self._page_dir_in_root(self._root, page_id) + if target.is_dir(): + return target + + source = self._find_page_dir(page_id) + if source is not None: + try: + source.resolve().relative_to(self._root.resolve()) + return source + except ValueError: + pass + if source is not None and source != target: + shutil.copytree(source, target) + self._normalize_migrated_page(target, page_id) + log.info("webui_pages.materialized", {"pageId": page_id, "source": str(source), "target": str(target)}) + return target + + def page_exists(self, page_id: str) -> bool: + return self._find_page_dir(self.validate_page_id(page_id)) is not None + + def _assert_writable_relative(self, relative_path: str) -> Path: + if not relative_path or Path(relative_path).is_absolute(): + raise ValueError("absolute path is not allowed") + rel = relative_path.replace("\\", "/").lstrip("/") + if rel in ALLOWED_WRITE_FILES: + return Path(rel) + if any(rel.startswith(prefix) for prefix in ALLOWED_WRITE_PREFIXES): + parts = rel.split("/") + if ".." in parts: + raise ValueError("path traversal is not allowed") + if any(part.startswith(".") for part in parts if part): + raise ValueError("hidden path is not allowed") + return Path(rel) + raise ValueError(f"writes are not allowed for path: {relative_path}") + + def list_pages(self, *, enabled_only: bool = False) -> list[WebUIPageListItem]: + self.ensure_root() + items: list[WebUIPageListItem] = [] + seen_keys: set[str] = set() + for root in self._read_roots: + if not root.is_dir(): + continue + for page_dir, page_id in self._iter_page_dirs(root): + if page_id in seen_keys: + continue + manifest = self._read_manifest_at(page_dir, page_id) + if manifest is None: + continue + if page_id in seen_keys or manifest.id in seen_keys: + continue + seen_keys.update({page_id, manifest.id}) + if enabled_only and not manifest.enabled: + continue + build = self._read_build_meta_at(page_dir) + workspace = self._workspace_for_page_dir(root, page_dir) + items.append( + WebUIPageListItem( + id=manifest.id, + title=manifest.title, + route=manifest.route, + icon=manifest.icon, + order=manifest.order, + enabled=manifest.enabled, + placement=manifest.placement, + buildHash=build.hash, + buildStatus=build.status, + workspaceId=workspace.id if workspace else None, + workspaceTitle=workspace.title if workspace else None, + workspaceRoute=webui_contract_workspace_route(workspace.id) if workspace else None, + ) + ) + items.sort(key=lambda item: (item.order, item.title)) + return items + + def list_workspaces(self, *, enabled_only: bool = False) -> list[WebUIWorkspaceListItem]: + self.ensure_root() + workspaces: list[WebUIWorkspaceListItem] = [] + seen_workspace_ids: set[str] = set() + for root in self._read_roots: + if not root.is_dir(): + continue + for workspace_dir, manifest in self._iter_workspace_dirs(root): + if manifest.id in seen_workspace_ids: + continue + seen_workspace_ids.add(manifest.id) + if enabled_only and not manifest.enabled: + continue + + pages: list[WebUIPageListItem] = [] + seen_page_ids: set[str] = set() + for page_dir, page_id in self._iter_page_dirs(workspace_dir): + if page_id in seen_page_ids: + continue + page_manifest = self._read_manifest_at(page_dir, page_id) + if page_manifest is None: + continue + seen_page_ids.add(page_manifest.id) + if enabled_only and not page_manifest.enabled: + continue + build = self._read_build_meta_at(page_dir) + pages.append( + WebUIPageListItem( + id=page_manifest.id, + title=page_manifest.title, + route=page_manifest.route, + icon=page_manifest.icon, + order=page_manifest.order, + enabled=page_manifest.enabled, + placement=page_manifest.placement, + buildHash=build.hash, + buildStatus=build.status, + workspaceId=manifest.id, + workspaceTitle=manifest.title, + workspaceRoute=webui_contract_workspace_route(manifest.id), + ) + ) + pages.sort(key=lambda item: (item.order, item.title)) + workspaces.append( + WebUIWorkspaceListItem( + id=manifest.id, + title=manifest.title, + route=webui_contract_workspace_route(manifest.id), + icon=manifest.icon, + order=manifest.order, + enabled=manifest.enabled, + placement=manifest.placement, + defaultPageId=manifest.defaultPageId, + sections=manifest.sections, + pages=pages, + ) + ) + workspaces.sort(key=lambda item: (item.order, item.title)) + return workspaces + + def get_page(self, page_id: str) -> WebUIPageDetail: + self.ensure_root() + page_dir = self.page_dir(page_id) + if not page_dir.is_dir(): + raise FileNotFoundError(f"page not found: {page_id}") + manifest = self._read_manifest(page_id) + if manifest is None: + raise FileNotFoundError(f"manifest missing for page: {page_id}") + build = self._read_build_meta(page_id) + source_files = sorted( + str(path.relative_to(page_dir)).replace("\\", "/") + for path in page_dir.rglob("*") + if path.is_file() and "dist/" not in str(path.relative_to(page_dir)).replace("\\", "/") + ) + return WebUIPageDetail(manifest=manifest, build=build, sourceFiles=source_files) + + def create_page( + self, + *, + page_id: str, + title: str, + icon: str = "LayoutDashboard", + order: int = 100, + ) -> WebUIPageDetail: + page_id = self.validate_page_id(page_id) + if self.page_exists(page_id): + raise FileExistsError(f"page already exists: {page_id}") + page_dir = self._page_dir_in_root(self._root, page_id) + + now_ms = int(time.time() * 1000) + manifest = WebUIPageManifest( + id=page_id, + title=title.strip() or page_id, + route=webui_contract_page_route(page_id), + icon=icon, + order=order, + enabled=True, + placement="home.after", + entry="src/index.tsx", + updatedAt=now_ms, + ) + + page_dir.mkdir(parents=True, exist_ok=False) + (page_dir / "src").mkdir(parents=True, exist_ok=True) + (page_dir / "api").mkdir(parents=True, exist_ok=True) + (page_dir / "assets").mkdir(parents=True, exist_ok=True) + (page_dir / "dist").mkdir(parents=True, exist_ok=True) + + self._write_manifest(page_id, manifest) + self._write_source_file(page_id, "src/Page.tsx", _default_page_tsx(manifest.title)) + self._write_source_file(page_id, "src/index.tsx", _DEFAULT_INDEX_TSX) + self._write_build_meta( + page_id, + WebUIPageBuildMeta(status="idle", hash="", builtAt=0, error=None), + ) + log.info("webui_pages.created", {"pageId": page_id}) + return self.get_page(page_id) + + def save_manifest(self, page_id: str, manifest_data: dict[str, Any]) -> WebUIPageManifest: + page_id = self.validate_page_id(page_id) + existing = self._read_manifest(page_id) + if existing is None: + raise FileNotFoundError(f"page not found: {page_id}") + + merged = existing.model_dump() + merged.update(manifest_data) + merged["id"] = page_id + merged["route"] = webui_contract_page_route(page_id) + merged["updatedAt"] = int(time.time() * 1000) + manifest = WebUIPageManifest.model_validate(merged) + self._write_manifest(page_id, manifest) + return manifest + + def save_source_file(self, page_id: str, relative_path: str, content: str) -> None: + rel = self._assert_writable_relative(relative_path) + rel_str = str(rel).replace("\\", "/") + if rel_str.startswith("api/"): + allowed_suffixes = _API_SUFFIXES + else: + allowed_suffixes = _SOURCE_SUFFIXES + if rel.suffix not in allowed_suffixes: + raise ValueError("unsupported source file type") + encoded = content.encode("utf-8") + if len(encoded) > MAX_SOURCE_FILE_BYTES: + raise ValueError("source file is too large") + self._write_source_file(page_id, rel_str, content) + + def read_source_file(self, page_id: str, relative_path: str) -> str: + self.ensure_root() + rel = self._assert_writable_relative(relative_path) + path = self.page_dir(page_id) / rel + if not path.is_file(): + raise FileNotFoundError(relative_path) + return path.read_text(encoding="utf-8") + + def bundle_path(self, page_id: str) -> Path: + self.ensure_root() + return self.page_dir(page_id) / "dist" / "page.js" + + def asset_path(self, page_id: str, relative_path: str) -> Path: + self.ensure_root() + rel = relative_path.replace("\\", "/").lstrip("/") + if ".." in rel.split("/"): + raise ValueError("path traversal is not allowed") + path = (self.page_dir(page_id) / "assets" / rel).resolve() + assets_root = (self.page_dir(page_id) / "assets").resolve() + try: + path.relative_to(assets_root) + except ValueError: + raise ValueError("invalid asset path") + return path + + def write_build_meta(self, page_id: str, meta: WebUIPageBuildMeta) -> None: + self._write_build_meta(page_id, meta) + + def read_build_meta(self, page_id: str) -> WebUIPageBuildMeta: + return self._read_build_meta(page_id) + + def routes_path(self, page_id: str) -> Path: + self.ensure_root() + return self.page_dir(page_id) / "api" / "routes.yaml" + + def api_handlers_path(self, page_id: str) -> Path: + self.ensure_root() + return self.page_dir(page_id) / "api" / "handlers.py" + + def read_api_routes(self, page_id: str) -> Optional[str]: + path = self.routes_path(page_id) + if not path.is_file(): + return None + return path.read_text(encoding="utf-8") + + def write_api_meta(self, page_id: str, meta: WebUIPageApiMeta) -> None: + path = self._api_meta_path(page_id) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps(meta.model_dump(), ensure_ascii=False, indent=2), + encoding="utf-8", + ) + + def read_api_meta(self, page_id: str) -> WebUIPageApiMeta: + path = self._api_meta_path(page_id) + if not path.is_file(): + return WebUIPageApiMeta() + try: + raw = json.loads(path.read_text(encoding="utf-8")) + return WebUIPageApiMeta.model_validate(raw) + except Exception: + return WebUIPageApiMeta() + + def _manifest_path(self, page_id: str) -> Path: + return self.page_dir(page_id) / "manifest.json" + + def _build_meta_path(self, page_id: str) -> Path: + return self.page_dir(page_id) / "dist" / "meta.json" + + def _api_meta_path(self, page_id: str) -> Path: + return self.page_dir(page_id) / "dist" / "api-meta.json" + + def _read_manifest(self, page_id: str) -> Optional[WebUIPageManifest]: + return self._read_manifest_at(self.page_dir(page_id), page_id) + + def _read_manifest_at(self, page_dir: Path, page_id: str) -> Optional[WebUIPageManifest]: + path = page_dir / "manifest.json" + if not path.is_file(): + return None + try: + raw = json.loads(path.read_text(encoding="utf-8")) + manifest = WebUIPageManifest.model_validate(raw) + expected_route = webui_contract_page_route(page_id) + if manifest.id != page_id or manifest.route != expected_route: + return manifest.model_copy(update={"id": page_id, "route": expected_route}) + return manifest + except Exception as exc: + log.warning("webui_pages.manifest.invalid", {"pageId": page_id, "error": str(exc)}) + return None + + def _write_manifest(self, page_id: str, manifest: WebUIPageManifest) -> None: + self._write_manifest_at(self.writable_page_dir(page_id), manifest) + + def _read_build_meta(self, page_id: str) -> WebUIPageBuildMeta: + return self._read_build_meta_at(self.page_dir(page_id)) + + @staticmethod + def _read_build_meta_at(page_dir: Path) -> WebUIPageBuildMeta: + path = page_dir / "dist" / "meta.json" + if not path.is_file(): + return WebUIPageBuildMeta() + try: + raw = json.loads(path.read_text(encoding="utf-8")) + return WebUIPageBuildMeta.model_validate(raw) + except Exception: + return WebUIPageBuildMeta() + + def _write_build_meta(self, page_id: str, meta: WebUIPageBuildMeta) -> None: + self._write_build_meta_at(self.writable_page_dir(page_id), meta) + + @staticmethod + def _write_manifest_at(page_dir: Path, manifest: WebUIPageManifest) -> None: + path = page_dir / "manifest.json" + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps(manifest.model_dump(), ensure_ascii=False, indent=2), + encoding="utf-8", + ) + + @staticmethod + def _write_build_meta_at(page_dir: Path, meta: WebUIPageBuildMeta) -> None: + path = page_dir / "dist" / "meta.json" + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(meta.model_dump(), ensure_ascii=False, indent=2), encoding="utf-8") + + def _write_source_file(self, page_id: str, relative_path: str, content: str) -> None: + rel = self._assert_writable_relative(relative_path) + target = self.writable_page_dir(page_id) / rel + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text(content, encoding="utf-8") + + def _migrate_legacy_pages(self) -> None: + if self._legacy_migration_done: + return + self._legacy_migration_done = True + legacy_root = self._legacy_root + if legacy_root is None or not legacy_root.is_dir() or legacy_root == self._root: + return + + for child in sorted(legacy_root.iterdir()): + if not child.is_dir(): + continue + try: + page_id = self.validate_page_id(child.name) + target = self._page_dir_in_root(self._root, page_id) + except ValueError: + continue + if target.exists() or self._find_page_dir_in_root(self._root, page_id) is not None: + continue + try: + shutil.copytree(child, target) + self._normalize_migrated_page(target, page_id) + log.info("webui_pages.legacy_migrated", {"pageId": page_id, "source": str(child), "target": str(target)}) + except Exception as exc: + log.warning("webui_pages.legacy_migration_failed", {"pageId": child.name, "error": str(exc)}) + + def _normalize_migrated_page(self, page_dir: Path, page_id: str) -> None: + manifest_path = page_dir / "manifest.json" + if manifest_path.is_file(): + try: + raw = json.loads(manifest_path.read_text(encoding="utf-8")) + manifest = WebUIPageManifest.model_validate(raw).model_copy( + update={ + "id": page_id, + "route": webui_contract_page_route(page_id), + "updatedAt": int(time.time() * 1000), + } + ) + self._write_manifest_at(page_dir, manifest) + except Exception as exc: + log.warning("webui_pages.legacy_manifest_normalize_failed", {"pageId": page_id, "error": str(exc)}) + + replacements = { + LEGACY_WEBUI_PAGE_SDK_IMPORT: WEBUI_CONTRACT_SDK_IMPORT, + LEGACY_WEBUI_PAGE_SDK_GLOBAL: WEBUI_CONTRACT_SDK_GLOBAL, + f"{LEGACY_WEBUI_PAGE_ROUTE_PREFIX}/": f"{WEBUI_CONTRACT_ROUTE_PREFIX}/", + "/api/user-defined-pages/": "/api/contracts/webui/pages/", + } + for path in page_dir.rglob("*"): + if not path.is_file() or path.suffix not in _MIGRATION_TEXT_SUFFIXES: + continue + rel = str(path.relative_to(page_dir)).replace("\\", "/") + if not (rel.startswith("src/") or rel.startswith("api/") or rel == "dist/page.js"): + continue + try: + content = path.read_text(encoding="utf-8") + except UnicodeDecodeError: + continue + updated = content + for old, new in replacements.items(): + updated = updated.replace(old, new) + if updated != content: + path.write_text(updated, encoding="utf-8") + + self._write_build_meta_at(page_dir, WebUIPageBuildMeta(status="idle", hash="", builtAt=0, error=None)) + + @staticmethod + def _dedupe_roots(*roots: Optional[Path]) -> tuple[Path, ...]: + result: list[Path] = [] + seen: set[Path] = set() + for root in roots: + if root is None: + continue + resolved = root.resolve() + if resolved in seen: + continue + seen.add(resolved) + result.append(resolved) + return tuple(result) + + def _find_page_dir(self, page_id: str) -> Optional[Path]: + for root in self._read_roots: + page_dir = self._find_page_dir_in_root(root, page_id) + if page_dir is not None: + return page_dir + return None + + def _find_page_dir_in_root(self, root: Path, page_id: str) -> Optional[Path]: + candidate = self._page_dir_in_root(root, page_id) + if candidate.is_dir(): + return candidate + if not root.is_dir(): + return None + for page_dir, manifest_page_id in self._iter_page_dirs(root): + if manifest_page_id == page_id: + return page_dir + return None + + def page_id_for_path(self, path: Path) -> Optional[str]: + resolved_path = path.resolve(strict=False) + for root in self._read_roots: + if not root.is_dir(): + continue + resolved_root = root.resolve() + try: + resolved_path.relative_to(resolved_root) + except ValueError: + continue + + probe = resolved_path if resolved_path.is_dir() else resolved_path.parent + while True: + try: + probe.relative_to(resolved_root) + except ValueError: + break + page_id = self._manifest_page_id_at(probe / "manifest.json") + if page_id is not None: + return page_id + if probe == resolved_root: + break + probe = probe.parent + return None + + def workspace_id_for_path(self, path: Path) -> Optional[str]: + resolved_path = path.resolve(strict=False) + for root in self._read_roots: + if not root.is_dir(): + continue + resolved_root = root.resolve() + try: + resolved_path.relative_to(resolved_root) + except ValueError: + continue + + probe = resolved_path if resolved_path.is_dir() else resolved_path.parent + while True: + try: + probe.relative_to(resolved_root) + except ValueError: + break + manifest = self._read_workspace_manifest_at(probe) + if manifest is not None: + return manifest.id + if probe == resolved_root: + break + probe = probe.parent + return None + + def _iter_page_dirs(self, root: Path) -> list[tuple[Path, str]]: + page_dirs: list[tuple[Path, str]] = [] + for manifest_path in sorted( + root.rglob("manifest.json"), + key=lambda path: (len(path.relative_to(root).parts), str(path.relative_to(root))), + ): + page_dir = manifest_path.parent + if page_dir == root: + continue + page_id = self._manifest_page_id_at(manifest_path) + if page_id is None: + continue + page_dirs.append((page_dir, page_id)) + return page_dirs + + def _manifest_page_id_at(self, manifest_path: Path) -> Optional[str]: + if not manifest_path.is_file(): + return None + try: + raw = json.loads(manifest_path.read_text(encoding="utf-8")) + return self.validate_page_id(str(raw.get("id", ""))) + except Exception: + return None + + def _iter_workspace_dirs(self, root: Path) -> list[tuple[Path, WebUIWorkspaceManifest]]: + workspaces: list[tuple[Path, WebUIWorkspaceManifest]] = [] + for manifest_path in sorted( + root.rglob(WORKSPACE_MANIFEST_FILE), + key=lambda path: (len(path.relative_to(root).parts), str(path.relative_to(root))), + ): + workspace_dir = manifest_path.parent + if workspace_dir == root: + continue + manifest = self._read_workspace_manifest_at(workspace_dir) + if manifest is None: + continue + workspaces.append((workspace_dir, manifest)) + return workspaces + + def _workspace_for_page_dir(self, root: Path, page_dir: Path) -> Optional[WebUIWorkspaceManifest]: + resolved_root = root.resolve() + probe = page_dir.resolve().parent + while True: + try: + probe.relative_to(resolved_root) + except ValueError: + return None + manifest = self._read_workspace_manifest_at(probe) + if manifest is not None: + return manifest + if probe == resolved_root: + return None + probe = probe.parent + + def _read_workspace_manifest_at(self, workspace_dir: Path) -> Optional[WebUIWorkspaceManifest]: + path = workspace_dir / WORKSPACE_MANIFEST_FILE + if not path.is_file(): + return None + try: + raw = json.loads(path.read_text(encoding="utf-8")) + manifest = WebUIWorkspaceManifest.model_validate(raw) + workspace_id = self.validate_workspace_id(manifest.id) + default_page_id = self.validate_page_id(manifest.defaultPageId) if manifest.defaultPageId else None + sections = [] + sections_changed = False + for section in manifest.sections: + section_id = self.validate_workspace_id(section.id) + page_ids = [] + for page_id in section.pageIds: + normalized_page_id = self.validate_page_id(page_id) + if normalized_page_id not in page_ids: + page_ids.append(normalized_page_id) + section_default_page_id = self.validate_page_id(section.defaultPageId) if section.defaultPageId else None + if section_default_page_id and section_default_page_id not in page_ids: + page_ids.insert(0, section_default_page_id) + updated_section = section.model_copy( + update={ + "id": section_id, + "pageIds": page_ids, + "defaultPageId": section_default_page_id, + } + ) + sections.append(updated_section) + sections_changed = sections_changed or updated_section != section + if manifest.id != workspace_id or manifest.defaultPageId != default_page_id or sections_changed: + return manifest.model_copy(update={"id": workspace_id, "defaultPageId": default_page_id, "sections": sections}) + return manifest + except Exception as exc: + log.warning("webui_pages.workspace_manifest.invalid", {"path": str(path), "error": str(exc)}) + return None + + def _page_dir_in_root(self, root: Path, page_id: str) -> Path: + page_path = (root / page_id).resolve() + self._assert_inside_root(page_path, root) + return page_path + + @staticmethod + def _assert_inside_root(path: Path, root: Path) -> None: + try: + path.relative_to(root) + except ValueError: + raise ValueError("invalid page path") diff --git a/flocks/user_defined_pages/watcher.py b/flocks/contracts/webui/watcher.py similarity index 64% rename from flocks/user_defined_pages/watcher.py rename to flocks/contracts/webui/watcher.py index 49769cafc..9cbb413f0 100644 --- a/flocks/user_defined_pages/watcher.py +++ b/flocks/contracts/webui/watcher.py @@ -1,4 +1,4 @@ -"""Watch ~/.flocks/plugins/user_defined_pages for changes and trigger rebuilds.""" +"""Watch user-space WebUI pages for changes and trigger rebuilds.""" from __future__ import annotations @@ -9,13 +9,13 @@ from pathlib import Path from typing import Any, Callable, Coroutine, Optional -from flocks.user_defined_pages.api_runtime import UserDefinedPageApiRuntime -from flocks.user_defined_pages.builder import UserDefinedPagesBuilder -from flocks.user_defined_pages.store import UserDefinedPagesStore +from flocks.contracts.webui.api_runtime import WebUIPageApiRuntime +from flocks.contracts.webui.builder import WebUIPageBuilder +from flocks.contracts.webui.store import WORKSPACE_MANIFEST_FILE, WebUIPagesStore from flocks.server.routes.event import publish_event from flocks.utils.log import Log -log = Log.create(service="user-defined-pages-watcher") +log = Log.create(service="webui-pages-watcher") _DEBOUNCE_SECONDS = 0.8 _RELOAD_EVENT_TYPES = frozenset({"modified", "created", "deleted", "moved"}) @@ -38,7 +38,7 @@ def _publish_event_sync(event_type: str, properties: dict) -> None: _main_loop, ) except Exception as exc: - log.warning("user_defined_pages.event.publish_failed", {"type": event_type, "error": str(exc)}) + log.warning("webui_pages.event.publish_failed", {"type": event_type, "error": str(exc)}) def _run_on_main_loop_sync(coro: Coroutine[Any, Any, Any], *, timeout_seconds: float = 5.0) -> Any: @@ -60,20 +60,20 @@ class _PendingAction: page_removed: bool = False -class UserDefinedPagesWatcher: - """Debounced filesystem watcher for user-defined pages.""" +class WebUIPagesWatcher: + """Debounced filesystem watcher for WebUI page plugins.""" def __init__( self, *, - store: Optional[UserDefinedPagesStore] = None, - builder: Optional[UserDefinedPagesBuilder] = None, - api_runtime: Optional[UserDefinedPageApiRuntime] = None, + store: Optional[WebUIPagesStore] = None, + builder: Optional[WebUIPageBuilder] = None, + api_runtime: Optional[WebUIPageApiRuntime] = None, on_build_complete: Optional[Callable[[str, bool, Optional[str]], None]] = None, ) -> None: - self._store = store or UserDefinedPagesStore() - self._builder = builder or UserDefinedPagesBuilder(self._store) - self._api_runtime = api_runtime or UserDefinedPageApiRuntime(self._store) + self._store = store or WebUIPagesStore() + self._builder = builder or WebUIPageBuilder(self._store) + self._api_runtime = api_runtime or WebUIPageApiRuntime(self._store) self._on_build_complete = on_build_complete self._observer: Optional[object] = None self._debounce_timer: Optional[threading.Timer] = None @@ -86,8 +86,8 @@ def start(self) -> None: from watchdog.observers import Observer except ImportError: log.warning( - "user_defined_pages.watcher.watchdog_missing", - {"msg": "watchdog not installed, user defined pages watcher disabled"}, + "webui_pages.watcher.watchdog_missing", + {"msg": "watchdog not installed, WebUI page watcher disabled"}, ) return @@ -112,7 +112,7 @@ def on_any_event(self, event: FileSystemEvent) -> None: observer.daemon = True observer.start() self._observer = observer - log.info("user_defined_pages.watcher.started", {"directory": str(root)}) + log.info("webui_pages.watcher.started", {"directory": str(root)}) def stop(self) -> None: with self._lock: @@ -127,7 +127,7 @@ def stop(self) -> None: except Exception: pass self._observer = None - log.info("user_defined_pages.watcher.stopped") + log.info("webui_pages.watcher.stopped") def _classify_event( self, @@ -143,12 +143,33 @@ def _classify_event( return None if not rel.parts: return None - page_id = rel.parts[0] - if is_directory and len(rel.parts) == 1 and event_type == "deleted": - return page_id, _PendingAction(page_removed=True) - if len(rel.parts) < 2: + + if rel.name == WORKSPACE_MANIFEST_FILE: + workspace_id = self._store.workspace_id_for_path(src) + if workspace_id is not None: + return workspace_id, _PendingAction(manifest_changed=True) + + page_id = self._store.page_id_for_path(src) + if page_id is None: + page_id = self._page_id_from_deleted_path(rel) + if page_id is None: + return None + if is_directory and event_type == "deleted": + return page_id, _PendingAction(page_removed=True) + if rel.name == "manifest.json" and event_type == "deleted": + return page_id, _PendingAction(manifest_changed=True) + return None + + try: + page_dir = self._store.page_dir(page_id).resolve() + page_rel = src.resolve(strict=False).relative_to(page_dir) + except Exception: + return None + if not page_rel.parts: + if is_directory and event_type == "deleted": + return page_id, _PendingAction(page_removed=True) return None - rel_str = str(Path(*rel.parts[1:])).replace("\\", "/") + rel_str = str(Path(*page_rel.parts)).replace("\\", "/") if rel_str == "manifest.json": return page_id, _PendingAction(manifest_changed=True) if rel_str.startswith("src/") and rel.suffix in {".ts", ".tsx", ".js", ".jsx", ".css"}: @@ -157,6 +178,15 @@ def _classify_event( return page_id, _PendingAction(api_changed=True) return None + def _page_id_from_deleted_path(self, rel: Path) -> Optional[str]: + if not rel.parts: + return None + page_dir_name = rel.parent.name if rel.name == "manifest.json" else rel.name + try: + return self._store.validate_page_id(page_dir_name.replace("_", "-")) + except ValueError: + return None + def _schedule(self, page_id: str, update: _PendingAction) -> None: with self._lock: pending = self._pending_pages.get(page_id, _PendingAction()) @@ -179,48 +209,48 @@ def _run_pending_builds(self) -> None: for page_id, pending in pages.items(): if pending.page_removed: self._api_runtime.clear_page(page_id) - _publish_event_sync("user_defined_pages.nav_changed", {"id": page_id}) + _publish_event_sync("contracts.webui.pages.nav_changed", {"id": page_id}) continue if pending.source_changed: try: meta = self._builder.build(page_id) if meta.status == "ready": - _publish_event_sync("user_defined_pages.updated", {"id": page_id, "hash": meta.hash}) - _publish_event_sync("user_defined_pages.nav_changed", {"id": page_id}) + _publish_event_sync("contracts.webui.pages.updated", {"id": page_id, "hash": meta.hash}) + _publish_event_sync("contracts.webui.pages.nav_changed", {"id": page_id}) else: _publish_event_sync( - "user_defined_pages.build_failed", + "contracts.webui.pages.build_failed", {"id": page_id, "error": meta.error or "build failed"}, ) if self._on_build_complete: self._on_build_complete(page_id, meta.status == "ready", meta.error) except Exception as exc: _publish_event_sync( - "user_defined_pages.build_failed", + "contracts.webui.pages.build_failed", {"id": page_id, "error": str(exc)}, ) - log.warning("user_defined_pages.watcher.build_failed", {"pageId": page_id, "error": str(exc)}) + log.warning("webui_pages.watcher.build_failed", {"pageId": page_id, "error": str(exc)}) if pending.api_changed: try: routes = _run_on_main_loop_sync(self._api_runtime.reload_page(page_id)) - _publish_event_sync("user_defined_pages.api_changed", {"id": page_id, "routes": routes}) + _publish_event_sync("contracts.webui.pages.api_changed", {"id": page_id, "routes": routes}) except Exception as exc: - _publish_event_sync("user_defined_pages.api_failed", {"id": page_id, "error": str(exc)}) - log.warning("user_defined_pages.watcher.api_reload_failed", {"pageId": page_id, "error": str(exc)}) + _publish_event_sync("contracts.webui.pages.api_failed", {"id": page_id, "error": str(exc)}) + log.warning("webui_pages.watcher.api_reload_failed", {"pageId": page_id, "error": str(exc)}) if pending.manifest_changed and not pending.source_changed: - _publish_event_sync("user_defined_pages.nav_changed", {"id": page_id}) + _publish_event_sync("contracts.webui.pages.nav_changed", {"id": page_id}) -_watcher: Optional[UserDefinedPagesWatcher] = None +_watcher: Optional[WebUIPagesWatcher] = None -def get_watcher() -> UserDefinedPagesWatcher: +def get_watcher() -> WebUIPagesWatcher: global _watcher if _watcher is None: - _watcher = UserDefinedPagesWatcher() + _watcher = WebUIPagesWatcher() return _watcher diff --git a/flocks/hub/catalog.py b/flocks/hub/catalog.py index 2f681cc74..5378d3aae 100644 --- a/flocks/hub/catalog.py +++ b/flocks/hub/catalog.py @@ -157,6 +157,7 @@ def _base_manifest( name: str, description: str, category: str, + version: str = "1.0.0", description_cn: Optional[str] = None, tags: Optional[list[str]] = None, use_cases: Optional[list[str]] = None, @@ -176,7 +177,7 @@ def _base_manifest( name=name or plugin_id, description=description or "", descriptionCn=description_cn, - version="1.0.0", + version=version or "1.0.0", author="Flocks Team", license="MIT", homepage="", @@ -382,6 +383,7 @@ def _tool_manifest(plugin_id: str, root: Path) -> Optional[HubPluginManifest]: name=str(provider.get("name") or first_tool.get("name") or plugin_id), description=description, category="integration", + version=str(provider.get("version") or "1.0.0"), description_cn=description_cn or None, tags=_tool_tags(plugin_id, f"{description} {description_cn}"), use_cases=["integration", "threat-intelligence"], diff --git a/flocks/ingest/kafka/manager.py b/flocks/ingest/kafka/manager.py index 34a30a07f..16acca4b6 100644 --- a/flocks/ingest/kafka/manager.py +++ b/flocks/ingest/kafka/manager.py @@ -27,7 +27,6 @@ from dataclasses import dataclass from typing import Any, Dict, Iterable, List, Optional -from flocks.storage.storage import Storage from flocks.utils.log import Log from flocks.workflow.execution_store import ( DEFAULT_LARGE_LIST_KEYS, @@ -38,13 +37,20 @@ record_execution_result, resolve_execution_outcome, ) +from flocks.workflow.execution_plan import build_workflow_execution_plan from flocks.workflow.fs_store import read_workflow_from_fs +from flocks.workflow.models import Workflow from flocks.workflow.runner import run_workflow +from flocks.workflow.store import WorkflowStore from flocks.ingest.kafka.constants import WORKFLOW_KAFKA_CONFIG_PREFIX from flocks.workflow.triggers.compat import legacy_kafka_trigger_from_config from flocks.workflow.triggers.dispatcher import EventDispatcher, TriggerDispatchError, build_trigger_event -from flocks.workflow.triggers.models import TriggerDefinition, workflow_json_declares_triggers, workflow_trigger_definitions_from_json +from flocks.workflow.triggers.models import ( + TriggerDefinition, + workflow_json_declares_triggers, + workflow_trigger_definitions_from_json, +) log = Log.create(service="kafka.manager") @@ -176,9 +182,7 @@ def _compact_for_kafka_storage(outputs: Any) -> Dict[str, Any]: size_threshold=0, ) for key, value in list(compacted.items()): - if key == "kafka_output" or ( - isinstance(value, str) and len(value) > _STORAGE_PREVIEW_CHARS - ): + if key == "kafka_output" or (isinstance(value, str) and len(value) > _STORAGE_PREVIEW_CHARS): compacted[key] = _summarize_large_value(value) return compacted @@ -203,8 +207,10 @@ def _compact_history_for_kafka_storage( if not isinstance(payload, dict): continue for key, value in list(payload.items()): - if key in raw_input_keys or key == "kafka_output" or ( - isinstance(value, str) and len(value) > _STORAGE_PREVIEW_CHARS + if ( + key in raw_input_keys + or key == "kafka_output" + or (isinstance(value, str) and len(value) > _STORAGE_PREVIEW_CHARS) ): payload[key] = _summarize_large_value(value) return compacted @@ -282,22 +288,14 @@ def _resolve_active_trigger(self, workflow_json: Dict[str, Any], data: Dict[str, async def start_all(self) -> None: try: - keys = await Storage.list_keys(WORKFLOW_KAFKA_CONFIG_PREFIX) + configs = await WorkflowStore.list_configs(kind="workflow_kafka_config") except Exception as exc: - log.warning("kafka.list_keys_failed", {"error": str(exc)}) + log.warning("kafka.list_configs_failed", {"error": str(exc)}) return - for key in keys: - if not key.startswith(WORKFLOW_KAFKA_CONFIG_PREFIX): - continue - workflow_id = key[len(WORKFLOW_KAFKA_CONFIG_PREFIX):] + for workflow_id, data in configs: if not workflow_id: continue - try: - data = await Storage.read(key) - except Exception as exc: - log.warning("kafka.config_read_failed", {"key": key, "error": str(exc)}) - continue if isinstance(data, dict) and data.get("enabled"): await self.restart_workflow(workflow_id) @@ -370,9 +368,8 @@ async def restart_workflow(self, workflow_id: str) -> Dict[str, Any]: surface connection errors to the user. """ await self.stop_workflow(workflow_id) - key = self._config_key(workflow_id) try: - data = await Storage.read(key) + data = await WorkflowStore.get_config(workflow_id, kind="workflow_kafka_config") except Exception as exc: log.warning("kafka.restart_read_failed", {"workflow_id": workflow_id, "error": str(exc)}) return {"state": "failed", "error": str(exc)} @@ -403,10 +400,15 @@ async def restart_workflow(self, workflow_id: str) -> Dict[str, Any]: return {"state": "failed", "error": err} trigger = self._resolve_active_trigger(workflow_json, data) + try: + workflow_plan = build_workflow_execution_plan(Workflow.from_dict(workflow_json)) + except Exception as exc: + err = f"workflow_plan_failed: {exc}" + self._status[workflow_id] = {"state": "failed", "error": err} + log.warning("kafka.workflow_plan_failed", {"workflow_id": workflow_id, "error": str(exc)}) + return self.get_consumer_status(workflow_id) group_id = str(data.get("inputGroupId") or "").strip() or f"flocks-consumer-{workflow_id}" - configured_inputs = _strip_execution_only_comments( - trigger.inputs if isinstance(trigger.inputs, dict) else {} - ) + configured_inputs = _strip_execution_only_comments(trigger.inputs if isinstance(trigger.inputs, dict) else {}) queue: asyncio.Queue = asyncio.Queue(maxsize=_MAX_QUEUE_SIZE) self._queues[workflow_id] = queue @@ -432,7 +434,13 @@ async def restart_workflow(self, workflow_id: str) -> Dict[str, Any]: workers.append( asyncio.create_task( self._worker_loop( - workflow_id, workflow_json, trigger, configured_inputs, queue, abort, input_topic, + workflow_id, + workflow_plan, + trigger, + configured_inputs, + queue, + abort, + input_topic, ), name=f"kafka-worker-{workflow_id}-{i}", ) @@ -441,9 +449,14 @@ async def restart_workflow(self, workflow_id: str) -> Dict[str, Any]: task = asyncio.create_task( self._consumer_loop( - workflow_id, input_broker, input_topic, group_id, + workflow_id, + input_broker, + input_topic, + group_id, str(data.get("autoOffsetReset") or "latest"), - queue, abort, ready, + queue, + abort, + ready, ), name=f"kafka-{workflow_id}", ) @@ -592,7 +605,7 @@ async def _consumer_loop( async def _worker_loop( self, workflow_id: str, - workflow_json: Any, + workflow_plan: Any, trigger: TriggerDefinition, configured_inputs: Dict[str, Any], queue: asyncio.Queue, @@ -611,7 +624,7 @@ async def _worker_loop( msg = _decode_message(msg.raw_value) await self._trigger_workflow( workflow_id, - workflow_json, + workflow_plan, msg, next(iter(trigger.mapping or {}), "kafka_message"), configured_inputs, @@ -629,7 +642,7 @@ async def _worker_loop( async def _trigger_workflow( self, workflow_id: str, - workflow_json: Any, + workflow_plan: Any, message: Any, input_key: str, configured_inputs: Optional[Dict[str, Any]] = None, @@ -688,33 +701,36 @@ async def _executor(mapped_inputs: Dict[str, Any]) -> Dict[str, Any]: try: result = await asyncio.to_thread( run_workflow, - workflow=workflow_json, + workflow=workflow_plan, inputs=mapped_inputs, + run_id=exec_id, trace=False, - history_mode="summary", + execution_profile="high_frequency", on_step_complete=step_recorder.on_step_complete, ) status, error_msg = resolve_execution_outcome(result) duration = time.time() - start_time step_count = step_recorder.step_count or result.steps exec_data.update(step_recorder.summary) - exec_data.update({ - "status": status, - "outputResults": _compact_for_kafka_storage(result.outputs), - "finishedAt": int(time.time() * 1000), - "duration": duration, - "errorMessage": error_msg, - "executionLog": [], - "stepCount": step_count, - "currentNodeId": result.last_node_id, - "currentPhase": status, - "currentStepIndex": step_count, - "triggerId": trigger.id, - "triggerType": trigger.type, - "deliveryId": trigger_meta.get("deliveryId"), - "attempt": trigger_meta.get("attempt"), - "triggerSource": trigger_meta.get("source"), - }) + exec_data.update( + { + "status": status, + "outputResults": _compact_for_kafka_storage(result.outputs), + "finishedAt": int(time.time() * 1000), + "duration": duration, + "errorMessage": error_msg, + "executionLog": [], + "stepCount": step_count, + "currentNodeId": result.last_node_id, + "currentPhase": status, + "currentStepIndex": step_count, + "triggerId": trigger.id, + "triggerType": trigger.type, + "deliveryId": trigger_meta.get("deliveryId"), + "attempt": trigger_meta.get("attempt"), + "triggerSource": trigger_meta.get("source"), + } + ) except Exception as exc: duration = time.time() - start_time log.error( @@ -722,19 +738,21 @@ async def _executor(mapped_inputs: Dict[str, Any]) -> Dict[str, Any]: {"workflow_id": workflow_id, "exec_id": exec_id, "error": str(exc)}, ) exec_data.update(step_recorder.summary) - exec_data.update({ - "status": "error", - "errorMessage": str(exc), - "finishedAt": int(time.time() * 1000), - "duration": duration, - "executionLog": [], - "currentPhase": "error", - "triggerId": trigger.id, - "triggerType": trigger.type, - "deliveryId": trigger_meta.get("deliveryId"), - "attempt": trigger_meta.get("attempt"), - "triggerSource": trigger_meta.get("source"), - }) + exec_data.update( + { + "status": "error", + "errorMessage": str(exc), + "finishedAt": int(time.time() * 1000), + "duration": duration, + "executionLog": [], + "currentPhase": "error", + "triggerId": trigger.id, + "triggerType": trigger.type, + "deliveryId": trigger_meta.get("deliveryId"), + "attempt": trigger_meta.get("attempt"), + "triggerSource": trigger_meta.get("source"), + } + ) finally: try: await record_execution_result(workflow_id, exec_id, exec_data) diff --git a/flocks/ingest/syslog/manager.py b/flocks/ingest/syslog/manager.py index 5c45b8d51..3fccaed88 100644 --- a/flocks/ingest/syslog/manager.py +++ b/flocks/ingest/syslog/manager.py @@ -7,7 +7,6 @@ import uuid from typing import Any, Dict, List -from flocks.storage.storage import Storage from flocks.utils.log import Log from flocks.workflow.execution_store import ( compact_outputs_for_storage, @@ -16,14 +15,21 @@ record_execution_result, resolve_execution_outcome, ) +from flocks.workflow.execution_plan import build_workflow_execution_plan from flocks.workflow.fs_store import read_workflow_from_fs +from flocks.workflow.models import Workflow from flocks.workflow.runner import run_workflow +from flocks.workflow.store import WorkflowStore from flocks.ingest.syslog.constants import WORKFLOW_SYSLOG_CONFIG_PREFIX from flocks.ingest.syslog.listener import run_tcp_syslog_server, run_udp_syslog_server from flocks.workflow.triggers.compat import legacy_syslog_trigger_from_config from flocks.workflow.triggers.dispatcher import EventDispatcher, TriggerDispatchError, build_trigger_event -from flocks.workflow.triggers.models import TriggerDefinition, workflow_json_declares_triggers, workflow_trigger_definitions_from_json +from flocks.workflow.triggers.models import ( + TriggerDefinition, + workflow_json_declares_triggers, + workflow_trigger_definitions_from_json, +) log = Log.create(service="syslog.manager") @@ -88,13 +94,16 @@ def flush_remaining(self, trigger: str = "shutdown") -> None: self._flush(trigger=trigger) def _flush(self, *, trigger: str) -> None: - log.warning("syslog.queue_full_dropped", { - "workflow_id": self._workflow_id, - "queue_size": self._queue.qsize(), - "queue_capacity": self._queue.maxsize, - "dropped_in_window": int(self._count), - "trigger": trigger, - }) + log.warning( + "syslog.queue_full_dropped", + { + "workflow_id": self._workflow_id, + "queue_size": self._queue.qsize(), + "queue_capacity": self._queue.maxsize, + "dropped_in_window": int(self._count), + "trigger": trigger, + }, + ) self._count = 0 self._last_log = time.monotonic() @@ -168,22 +177,14 @@ def _resolve_active_trigger(self, workflow_json: Dict[str, Any], data: Dict[str, async def start_all(self) -> None: try: - keys = await Storage.list_keys(WORKFLOW_SYSLOG_CONFIG_PREFIX) + configs = await WorkflowStore.list_configs(kind="workflow_syslog_config") except Exception as exc: - log.warning("syslog.list_keys_failed", {"error": str(exc)}) + log.warning("syslog.list_configs_failed", {"error": str(exc)}) return - for key in keys: - if not key.startswith(WORKFLOW_SYSLOG_CONFIG_PREFIX): - continue - workflow_id = key[len(WORKFLOW_SYSLOG_CONFIG_PREFIX) :] + for workflow_id, data in configs: if not workflow_id: continue - try: - data = await Storage.read(key) - except Exception as exc: - log.warning("syslog.config_read_failed", {"key": key, "error": str(exc)}) - continue if isinstance(data, dict) and data.get("enabled"): await self.restart_workflow(workflow_id) @@ -254,9 +255,8 @@ async def restart_workflow(self, workflow_id: str) -> Dict[str, Any]: user instead of silently leaving the listener in a failed state. """ await self.stop_workflow(workflow_id) - key = self._config_key(workflow_id) try: - data = await Storage.read(key) + data = await WorkflowStore.get_config(workflow_id, kind="workflow_syslog_config") except Exception as exc: log.warning("syslog.restart_read_failed", {"workflow_id": workflow_id, "error": str(exc)}) return {"state": "failed", "error": str(exc)} @@ -279,6 +279,13 @@ async def restart_workflow(self, workflow_id: str) -> Dict[str, Any]: return {"state": "failed", "error": err} trigger = self._resolve_active_trigger(workflow_json, data) + try: + workflow_plan = build_workflow_execution_plan(Workflow.from_dict(workflow_json)) + except Exception as exc: + err = f"workflow_plan_failed: {exc}" + self._listener_status[workflow_id] = {"state": "failed", "error": err} + log.warning("syslog.workflow_plan_failed", {"workflow_id": workflow_id, "error": str(exc)}) + return self.get_listener_status(workflow_id) queue: asyncio.Queue = asyncio.Queue(maxsize=_MAX_QUEUE_SIZE) self._queues[workflow_id] = queue @@ -306,7 +313,7 @@ async def restart_workflow(self, workflow_id: str) -> Dict[str, Any]: for i in range(_MAX_CONCURRENT_EXECUTIONS): workers.append( asyncio.create_task( - self._worker_loop(workflow_id, workflow_json, trigger, queue, abort), + self._worker_loop(workflow_id, workflow_plan, trigger, queue, abort), name=f"syslog-worker-{workflow_id}-{i}", ) ) @@ -451,7 +458,7 @@ async def _mark_ready_after_bind() -> None: async def _worker_loop( self, workflow_id: str, - workflow_json: Any, + workflow_plan: Any, trigger: TriggerDefinition, queue: asyncio.Queue, abort: asyncio.Event, @@ -472,7 +479,7 @@ async def _worker_loop( try: await self._trigger_workflow( workflow_id, - workflow_json, + workflow_plan, msg, next(iter(trigger.mapping or {}), "syslog_message"), trigger=trigger, @@ -489,7 +496,7 @@ async def _worker_loop( async def _trigger_workflow( self, workflow_id: str, - workflow_json: Any, + workflow_plan: Any, syslog_msg: dict, input_key: str, *, @@ -534,32 +541,36 @@ async def _executor(mapped_inputs: Dict[str, Any]) -> Dict[str, Any]: try: result = await asyncio.to_thread( run_workflow, - workflow=workflow_json, + workflow=workflow_plan, inputs=mapped_inputs, + run_id=exec_id, trace=False, + execution_profile="high_frequency", on_step_complete=step_recorder.on_step_complete, ) status, error_msg = resolve_execution_outcome(result) duration = time.time() - start_time step_count = step_recorder.step_count or result.steps exec_data.update(step_recorder.summary) - exec_data.update({ - "status": status, - "outputResults": compact_outputs_for_storage(result.outputs), - "finishedAt": int(time.time() * 1000), - "duration": duration, - "errorMessage": error_msg, - "executionLog": [], - "stepCount": step_count, - "currentNodeId": result.last_node_id, - "currentPhase": status, - "currentStepIndex": step_count, - "triggerId": trigger.id, - "triggerType": trigger.type, - "deliveryId": trigger_meta.get("deliveryId"), - "attempt": trigger_meta.get("attempt"), - "triggerSource": trigger_meta.get("source"), - }) + exec_data.update( + { + "status": status, + "outputResults": compact_outputs_for_storage(result.outputs), + "finishedAt": int(time.time() * 1000), + "duration": duration, + "errorMessage": error_msg, + "executionLog": [], + "stepCount": step_count, + "currentNodeId": result.last_node_id, + "currentPhase": status, + "currentStepIndex": step_count, + "triggerId": trigger.id, + "triggerType": trigger.type, + "deliveryId": trigger_meta.get("deliveryId"), + "attempt": trigger_meta.get("attempt"), + "triggerSource": trigger_meta.get("source"), + } + ) except Exception as exc: duration = time.time() - start_time log.error( @@ -567,19 +578,21 @@ async def _executor(mapped_inputs: Dict[str, Any]) -> Dict[str, Any]: {"workflow_id": workflow_id, "exec_id": exec_id, "error": str(exc)}, ) exec_data.update(step_recorder.summary) - exec_data.update({ - "status": "error", - "errorMessage": str(exc), - "finishedAt": int(time.time() * 1000), - "duration": duration, - "executionLog": [], - "currentPhase": "error", - "triggerId": trigger.id, - "triggerType": trigger.type, - "deliveryId": trigger_meta.get("deliveryId"), - "attempt": trigger_meta.get("attempt"), - "triggerSource": trigger_meta.get("source"), - }) + exec_data.update( + { + "status": "error", + "errorMessage": str(exc), + "finishedAt": int(time.time() * 1000), + "duration": duration, + "executionLog": [], + "currentPhase": "error", + "triggerId": trigger.id, + "triggerType": trigger.type, + "deliveryId": trigger_meta.get("deliveryId"), + "attempt": trigger_meta.get("attempt"), + "triggerSource": trigger_meta.get("source"), + } + ) finally: try: await record_execution_result(workflow_id, exec_id, exec_data) diff --git a/flocks/server/app.py b/flocks/server/app.py index a605d33b4..8c7cb1b24 100644 --- a/flocks/server/app.py +++ b/flocks/server/app.py @@ -400,27 +400,27 @@ def _start_tool_watcher() -> None: except Exception as e: log.warning("tool.watcher.init_failed", {"error": str(e)}) - # Start user-defined pages watcher (auto-build user custom pages) + # Start WebUI page watcher (auto-build user custom pages) try: - from flocks.user_defined_pages.bootstrap import reconcile_user_defined_pages - from flocks.user_defined_pages.watcher import set_event_loop, start_watcher + from flocks.contracts.webui.bootstrap import reconcile_webui_pages + from flocks.contracts.webui.watcher import set_event_loop, start_watcher set_event_loop(asyncio.get_running_loop()) _schedule_startup_phase( app, log, - "user_defined_pages.bootstrap", - reconcile_user_defined_pages, + "webui_pages.bootstrap", + reconcile_webui_pages, ) - def _start_user_defined_pages_watcher() -> None: + def _start_webui_pages_watcher() -> None: start_watcher() - log.info("user_defined_pages.watcher.initialized") + log.info("webui_pages.watcher.initialized") - _schedule_startup_phase(app, log, "user_defined_pages.watcher.start", _start_user_defined_pages_watcher) + _schedule_startup_phase(app, log, "webui_pages.watcher.start", _start_webui_pages_watcher) except Exception as e: - log.warning("user_defined_pages.watcher.init_failed", {"error": str(e)}) + log.warning("webui_pages.watcher.init_failed", {"error": str(e)}) # Start Channel Gateway (connect enabled IM channels) try: @@ -539,12 +539,12 @@ async def _delayed_trigger_runtime_start() -> None: except Exception as e: log.warning("skill.watcher.stop_failed", {"error": str(e)}) - # Stop user-defined pages watcher + # Stop WebUI page watcher try: - from flocks.user_defined_pages.watcher import stop_watcher + from flocks.contracts.webui.watcher import stop_watcher stop_watcher() except Exception as e: - log.warning("user_defined_pages.watcher.stop_failed", {"error": str(e)}) + log.warning("webui_pages.watcher.stop_failed", {"error": str(e)}) # Shutdown MCP connections try: @@ -1013,7 +1013,9 @@ async def general_exception_handler(request: Request, exc: Exception): from flocks.server.routes.notifications import router as notifications_router from flocks.server.routes.device import router as device_router from flocks.server.routes.console_upgrade import router as console_upgrade_router -from flocks.server.routes.user_defined_pages import router as user_defined_pages_router +from flocks.server.routes.flockspro_license import router as flockspro_license_router +from flocks.server.routes.webui import router as webui_pages_router +from flocks.server.routes.contracts import router as contracts_router # Original routes with /api/ prefix app.include_router(health_router, prefix="/api", tags=["Health"]) app.include_router(session_router, prefix="/api/session", tags=["Session"]) @@ -1073,7 +1075,8 @@ async def general_exception_handler(request: Request, exc: Exception): # Device integration (named instances, SQL-backed) app.include_router(device_router, prefix="/api/devices", tags=["Device"]) app.include_router(console_upgrade_router, prefix="/api/console", tags=["ConsoleUpgrade"]) -app.include_router(user_defined_pages_router, prefix="/api", tags=["UserDefinedPages"]) +app.include_router(webui_pages_router, prefix="/api", tags=["WebUI"]) +app.include_router(contracts_router, prefix="/api", tags=["AccessContracts"]) # ============================================================ # TUI Compatible Routes (without /api/ prefix) @@ -1149,7 +1152,32 @@ def _load_installed_package_plugins() -> None: log.warning("plugins.installed.load_failed", {"error": str(e)}) +def _route_registered(path: str, method: str) -> bool: + target_method = method.upper() + for route in app.routes: + if getattr(route, "path", None) != path: + continue + methods = getattr(route, "methods", None) or set() + if target_method in methods: + return True + return False + + +def _install_flockspro_license_fallback() -> None: + status_registered = _route_registered("/api/flockspro/license/status", "GET") + refresh_registered = _route_registered("/api/flockspro/license/refresh", "POST") + if status_registered and refresh_registered: + log.info("flockspro.license.fallback.skipped", { + "status_registered": status_registered, + "refresh_registered": refresh_registered, + }) + return + app.include_router(flockspro_license_router, prefix="/api/flockspro/license", tags=["FlocksProLicense"]) + log.info("flockspro.license.fallback.installed") + + _load_installed_package_plugins() +_install_flockspro_license_fallback() @app.get("/", tags=["Root"]) diff --git a/flocks/server/routes/admin_users.py b/flocks/server/routes/admin_users.py index 921c08ca4..913bb7845 100644 --- a/flocks/server/routes/admin_users.py +++ b/flocks/server/routes/admin_users.py @@ -21,6 +21,8 @@ class UserResponse(BaseModel): role: str status: str must_reset_password: bool + tenant_ids: tuple[str, ...] = Field(default_factory=tuple) + asset_groups: tuple[str, ...] = Field(default_factory=tuple) created_at: str updated_at: str last_login_at: Optional[str] = None @@ -31,6 +33,11 @@ class ResetPasswordRequest(BaseModel): force_reset: bool = True +class ContractScopeRequest(BaseModel): + tenant_ids: list[str] = Field(default_factory=list) + asset_groups: list[str] = Field(default_factory=list) + + @router.get("/users", response_model=List[UserResponse], summary="管理员获取用户列表") async def list_users(request: Request) -> List[UserResponse]: _admin = require_admin(request) @@ -38,6 +45,23 @@ async def list_users(request: Request) -> List[UserResponse]: return [UserResponse(**u.model_dump()) for u in users] +@router.put("/users/{user_id}/contract-scope", response_model=UserResponse, summary="管理员更新用户契约范围") +async def update_user_contract_scope(user_id: str, payload: ContractScopeRequest, request: Request) -> UserResponse: + require_admin(request) + setter = getattr(AuthService.get_backend(), "set_user_contract_scope", None) + if setter is None: + raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="当前账号后端不支持契约范围管理") + try: + user = await setter( + target_user_id=user_id, + tenant_ids=payload.tenant_ids, + asset_groups=payload.asset_groups, + ) + except ValueError as exc: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + return UserResponse(**user.model_dump()) + + @router.post("/users/{user_id}/reset-password", summary="管理员重置密码") async def reset_user_password(user_id: str, payload: ResetPasswordRequest, request: Request) -> dict: require_admin(request) diff --git a/flocks/server/routes/console_upgrade.py b/flocks/server/routes/console_upgrade.py index 33324c722..9539bf71a 100644 --- a/flocks/server/routes/console_upgrade.py +++ b/flocks/server/routes/console_upgrade.py @@ -316,6 +316,7 @@ def _enrich_record_from_install_marker(record: dict[str, Any]) -> dict[str, Any] details = record.setdefault("details", {}) marker = _read_pro_bundle_install_marker() if marker: + details.setdefault("auto_install_release_id", marker.get("release_id") or marker.get("bundle_release_id")) details.setdefault("auto_install_version", marker.get("installed_version")) details.setdefault("auto_install_pro_version", marker.get("flockspro_component_version")) details.setdefault("flockspro_component_version", marker.get("flockspro_component_version")) @@ -333,7 +334,9 @@ def _enrich_record_from_install_marker(record: dict[str, Any]) -> dict[str, Any] return record -async def _maybe_activate_pro_license(record: dict[str, Any], *, force: bool = False) -> None: +async def _maybe_activate_pro_license( + record: dict[str, Any], *, force: bool = False, allow_fallback: bool = True +) -> None: activate_key = str(record.get("activate_key") or "").strip() if not activate_key: return @@ -355,7 +358,7 @@ async def _maybe_activate_pro_license(record: dict[str, Any], *, force: bool = F details.pop("license_activate_error", None) except Exception as exc: details["license_activate_error"] = str(exc) - if not _is_pro_component_installed(): + if not allow_fallback or not _is_pro_component_installed(): return _fallback_write_pro_license_state(record, activate_key, str(exc)) @@ -418,16 +421,104 @@ async def _maybe_refresh_pro_license(record: dict[str, Any]) -> None: details["license_refresh_error"] = str(exc) +def _clean_bundle_value(value: Any) -> str: + return str(value or "").strip() + + +def _clean_version_value(value: Any) -> str: + return _clean_bundle_value(value).removeprefix("v") + + +def _record_target_bundle(record: dict[str, Any]) -> dict[str, str]: + details = record.get("details") if isinstance(record.get("details"), dict) else {} + latest_bundle = details.get("latest_pro_bundle") if isinstance(details.get("latest_pro_bundle"), dict) else {} + release_id = _clean_bundle_value( + record.get("approved_bundle_release_id") + or details.get("approved_bundle_release_id") + or details.get("bundle_release_id") + or latest_bundle.get("release_id") + ) + target = { + "release_id": release_id, + "bundle_release_id": _clean_bundle_value(details.get("bundle_release_id") or release_id), + "build_id": _clean_bundle_value(details.get("target_build_id") or latest_bundle.get("build_id")), + "display_version": _clean_bundle_value( + details.get("target_display_version") + or details.get("auto_install_target") + or latest_bundle.get("display_version") + ), + "core_version": _clean_bundle_value( + details.get("target_core_version") + or details.get("target_oss_version") + or latest_bundle.get("core_version") + or latest_bundle.get("oss_version") + ), + "flockspro_component_version": _clean_bundle_value( + details.get("target_flockspro_component_version") + or latest_bundle.get("flockspro_component_version") + ), + } + return {key: value for key, value in target.items() if value} + + +def _target_bundle_fingerprint_matches(target: dict[str, str], marker: dict[str, Any]) -> bool | None: + build_id = target.get("build_id") + marker_build_id = _clean_bundle_value(marker.get("build_id")) + if build_id and marker_build_id: + return marker_build_id == build_id + + pro_version = target.get("flockspro_component_version") + marker_pro_version = _clean_bundle_value(marker.get("flockspro_component_version")) + if pro_version and marker_pro_version: + return marker_pro_version == pro_version + + display_version = target.get("display_version") + marker_display_version = _clean_bundle_value(marker.get("installed_version") or marker.get("display_version")) + if display_version and marker_display_version: + return _clean_version_value(marker_display_version) == _clean_version_value(display_version) + + core_version = target.get("core_version") or target.get("oss_version") + marker_core_version = _clean_bundle_value(marker.get("core_version") or marker.get("oss_version")) + if core_version and marker_core_version: + return _clean_version_value(marker_core_version) == _clean_version_value(core_version) + + return None + + +def _marker_matches_target_bundle(marker: dict[str, Any], record: dict[str, Any]) -> bool: + if not marker: + return False + target = _record_target_bundle(record) + target_release_id = _clean_bundle_value(target.get("release_id") or target.get("bundle_release_id")) + marker_release_id = _clean_bundle_value(marker.get("release_id") or marker.get("bundle_release_id")) + if target_release_id: + if marker_release_id: + return marker_release_id == target_release_id + fingerprint_match = _target_bundle_fingerprint_matches(target, marker) + return fingerprint_match is True + + fingerprint_match = _target_bundle_fingerprint_matches(target, marker) + if fingerprint_match is not None: + return fingerprint_match + return True + + async def _run_auto_upgrade_install(record: dict[str, Any]) -> dict[str, Any]: details = record.setdefault("details", {}) details["auto_install_result"] = "running" details["auto_install_started_at"] = datetime.now(UTC).isoformat() marker = _read_pro_bundle_install_marker() - if _is_pro_component_installed() and marker: - details["auto_install_result"] = "already_latest" + if _is_pro_component_installed() and _marker_matches_target_bundle(marker, record): + details["auto_install_release_id"] = marker.get("release_id") or marker.get("bundle_release_id") details["auto_install_version"] = marker.get("installed_version") + await _maybe_activate_pro_license(record, allow_fallback=False) + await _maybe_refresh_pro_license(record) + capability = _record_pro_capability(details) + if not capability.get("pro_enabled"): + details["auto_install_result"] = "license_inactive" + raise ValueError("Flocks Pro component is installed but license activation is inactive") + details["auto_install_result"] = "already_latest" details["auto_install_completed_at"] = datetime.now(UTC).isoformat() - _record_pro_capability(details) await _report_pro_bundle_installation(record, install_result="success") return record @@ -439,18 +530,23 @@ async def _run_auto_upgrade_install(record: dict[str, Any]) -> dict[str, Any]: if progress.stage == "error": raise ValueError(progress.message) - await _maybe_activate_pro_license(record) + await _maybe_activate_pro_license(record, allow_fallback=False) await _maybe_refresh_pro_license(record) capability = _record_pro_capability(details) marker = _read_pro_bundle_install_marker() details["auto_install_result"] = ( "done" if final_stage == "done" and capability.get("pro_enabled") else "license_inactive" ) + details["auto_install_release_id"] = marker.get("release_id") or marker.get("bundle_release_id") details["auto_install_version"] = marker.get("installed_version") details["auto_install_pro_version"] = marker.get("flockspro_component_version") details["auto_install_completed_at"] = datetime.now(UTC).isoformat() details["auto_install_message"] = final_message _enrich_record_from_install_marker(record) + if final_stage != "done": + raise ValueError(final_message or "Flocks Pro bundle installation did not complete") + if not capability.get("pro_enabled"): + raise ValueError("Flocks Pro bundle installed but license activation is inactive") await _report_pro_bundle_installation(record, install_result="success") return record @@ -464,6 +560,15 @@ def _read_pro_bundle_install_marker() -> dict[str, Any]: return payload if isinstance(payload, dict) else {} +def _marker_indicates_pro_bundle_installed(marker: dict[str, Any]) -> bool: + if not marker: + return False + return any( + str(marker.get(key) or "").strip() + for key in ("installed_at", "installed_version", "bundle_version", "flockspro_component_version", "build_id") + ) + + async def _report_pro_bundle_installation( record: dict[str, Any], *, @@ -477,14 +582,34 @@ async def _report_pro_bundle_installation( details["install_receipt_error"] = str(exc) return marker = _read_pro_bundle_install_marker() + target = _record_target_bundle(record) + marker_matches_target = _marker_matches_target_bundle(marker, record) + use_marker = install_result == "success" and marker_matches_target + source = marker if use_marker else target + release_id = _clean_bundle_value( + source.get("release_id") + or source.get("bundle_release_id") + or target.get("release_id") + or target.get("bundle_release_id") + ) + bundle_release_id = _clean_bundle_value(source.get("bundle_release_id") or source.get("release_id") or release_id) payload = { + "request_id": record.get("request_id"), + "release_id": release_id or None, + "bundle_release_id": bundle_release_id or None, "license_id": _record_license_id(record), "fingerprint": console_session.get("fingerprint"), "install_id": console_session.get("install_id"), - "installed_version": marker.get("installed_version") or details.get("auto_install_target") or details.get("auto_install_version") or "", - "oss_version": marker.get("oss_version"), - "flockspro_component_version": marker.get("flockspro_component_version"), - "build_id": marker.get("build_id"), + "installed_version": source.get("installed_version") + or source.get("display_version") + or target.get("display_version") + or details.get("auto_install_target") + or details.get("auto_install_version") + or "", + "core_version": source.get("core_version") or source.get("oss_version") or target.get("core_version"), + "oss_version": source.get("core_version") or source.get("oss_version") or target.get("core_version"), + "flockspro_component_version": source.get("flockspro_component_version") or target.get("flockspro_component_version"), + "build_id": source.get("build_id") or target.get("build_id"), "install_result": install_result, "error_message": error_message, "reported_at": datetime.now(UTC).isoformat(), @@ -501,6 +626,7 @@ async def _report_pro_bundle_installation( headers={"Authorization": f"Bearer {console_session['console_session_token']}"}, ) resp.raise_for_status() + details.pop("install_receipt_error", None) details["install_receipt_reported_at"] = datetime.now(UTC).isoformat() except Exception as exc: details["install_receipt_error"] = str(exc) @@ -529,11 +655,12 @@ async def _maybe_auto_activate_upgrade(record: dict[str, Any]) -> dict[str, Any] if not _is_approved(record): return record details = record.setdefault("details", {}) - if details.get("auto_install_result") in {"done", "already_latest"}: + if details.get("auto_install_result") in {"done", "already_latest"} and _marker_matches_target_bundle( + _read_pro_bundle_install_marker(), + record, + ): return record try: - await _maybe_activate_pro_license(record) - await _maybe_refresh_pro_license(record) await _run_auto_upgrade_install(record) capability = _record_pro_capability(details) if capability.get("pro_enabled"): @@ -549,6 +676,38 @@ async def _maybe_auto_activate_upgrade(record: dict[str, Any]) -> dict[str, Any] return record +async def _finalize_restarting_upgrade_if_installed(record: dict[str, Any]) -> dict[str, Any]: + details = record.setdefault("details", {}) + if details.get("auto_install_result") != "restarting": + return record + marker = _read_pro_bundle_install_marker() + if not _marker_matches_target_bundle(marker, record): + return record + + await _maybe_activate_pro_license(record) + await _maybe_refresh_pro_license(record) + capability = _record_pro_capability(details) + details["auto_install_result"] = "done" if capability.get("pro_enabled") else "license_inactive" + details["auto_install_version"] = marker.get("installed_version") or marker.get("display_version") + details["auto_install_pro_version"] = marker.get("flockspro_component_version") + details["auto_install_completed_at"] = datetime.now(UTC).isoformat() + details["auto_install_message"] = "Upgrade completed after service restart" + _enrich_record_from_install_marker(record) + if capability.get("pro_enabled"): + record["status"] = "activated" + record["updated_at"] = datetime.now(UTC).isoformat() + + previous_reported_at = details.get("install_receipt_reported_at") + await _report_pro_bundle_installation(record, install_result="success") + if details.get("install_receipt_reported_at") == previous_reported_at and details.get("install_receipt_error"): + details["auto_install_result"] = "restarting" + record["updated_at"] = datetime.now(UTC).isoformat() + return record + if capability.get("pro_enabled"): + await _mark_console_upgrade_activated(record) + return record + + async def _run_auto_activate_upgrade_task(request_id: str, record: dict[str, Any]) -> None: try: updated = await _maybe_auto_activate_upgrade(record) @@ -566,7 +725,12 @@ def _schedule_auto_activate_upgrade(request_id: str, record: dict[str, Any]) -> if not _is_approved(record): return details = record.setdefault("details", {}) - if details.get("auto_install_result") in {"running", "done", "already_latest"}: + if details.get("auto_install_result") == "running": + return + if details.get("auto_install_result") in {"done", "already_latest"} and _marker_matches_target_bundle( + _read_pro_bundle_install_marker(), + record, + ): return if request_id in _AUTO_UPGRADE_REQUEST_IDS: return @@ -584,8 +748,12 @@ def _raise_console_service_error(exc: Exception) -> None: if isinstance(payload, dict): detail = str(payload.get("detail") or payload.get("message") or detail) except Exception: - if exc.response.text: - detail = exc.response.text + response_text = str(exc.response.text or "").strip() + content_type = str(exc.response.headers.get("content-type", "")).lower() + if response_text and "html" not in content_type and not response_text.lower().startswith(" list[UpgradeRequestStatus]: for request_id in reversed(await _list_request_ids()): raw = await Storage.get(_request_key(request_id)) if raw: + raw = await _finalize_restarting_upgrade_if_installed(raw) + await Storage.set(_request_key(request_id), raw, "json") result.append(UpgradeRequestStatus(**_enrich_record_from_install_marker(raw))) return result @@ -700,9 +870,13 @@ async def get_pro_package_status(request: Request) -> dict[str, Any]: require_user(request) marker = _read_pro_bundle_install_marker() capability = _get_pro_capability_status() - installed = _is_pro_component_installed() + runtime_importable = _is_pro_component_installed() + install_marker_present = _marker_indicates_pro_bundle_installed(marker) + installed = runtime_importable or install_marker_present return { "installed": installed, + "runtime_importable": runtime_importable, + "install_marker_present": install_marker_present, "installed_version": marker.get("installed_version"), "flockspro_component_version": marker.get("flockspro_component_version"), "build_id": marker.get("build_id"), @@ -713,106 +887,14 @@ async def get_pro_package_status(request: Request) -> dict[str, Any]: } -@router.post("/licenses/sync-revocations") -async def sync_console_license_revocations(request: Request) -> dict[str, Any]: - require_admin(request) - console_base = _console_base_url() - if not console_base: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="FLOCKS_CONSOLE_BASE_URL 未配置") - try: - console_session = await ConsoleLoginService.require_console_session() - except ValueError as exc: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc - - headers = {"Authorization": f"Bearer {console_session['console_session_token']}"} - account_key = _console_session_account_key(console_session) - synced_license_ids: list[str] = [] - try: - async with httpx.AsyncClient(timeout=10) as client: - resp = await client.get(f"{console_base}/v1/licenses/revocations", headers=headers) - resp.raise_for_status() - data = resp.json() - - for request_id in await _list_request_ids(): - raw = await Storage.get(_request_key(request_id)) - if not isinstance(raw, dict): - continue - record_account_key = _record_account_key(raw) - if account_key and record_account_key and record_account_key != account_key: - continue - license_id = _record_license_id(raw) - if not license_id: - continue - license_resp = await client.get(f"{console_base}/v1/licenses/{license_id}", headers=headers) - if license_resp.status_code == status.HTTP_404_NOT_FOUND: - continue - license_resp.raise_for_status() - license_data = license_resp.json() - if isinstance(license_data, dict): - _apply_console_license_data(raw, license_data) - synced_license_ids.append(license_id) - await Storage.set(_request_key(request_id), raw, "json") - except httpx.HTTPError as exc: - _raise_console_service_error(exc) - - revoked_license_ids = data.get("revoked_license_ids", []) - if not isinstance(revoked_license_ids, list): - revoked_license_ids = [] - - imported = False - activated_license_id: str | None = None - refreshed_license_id: str | None = None - if not _is_pro_component_installed(): - return { - "revoked_license_ids": [str(item) for item in revoked_license_ids], - "imported": imported, - "synced_license_ids": synced_license_ids, - "activated_license_id": activated_license_id, - "refreshed_license_id": refreshed_license_id, - "inactive_reason": "flockspro_not_installed", - } - try: - from flockspro.license.runtime import get_license_checker # type: ignore[import-not-found] - - checker = get_license_checker() - import_fn = getattr(checker, "import_revocation", None) - if callable(import_fn): - import_fn([str(item) for item in revoked_license_ids]) - imported = True - - revoked_set = {str(item) for item in revoked_license_ids} - current_status = _get_pro_capability_status() - current_license_id = str(current_status.get("license_id") or "") - target = await _latest_usable_issued_record( - revoked_set, - account_key=_console_session_account_key(console_session), - ) - target_license_id = _record_license_id(target) if target else "" - if target and target_license_id: - if target_license_id != current_license_id: - await _maybe_activate_pro_license(target, force=True) - activated_license_id = target_license_id - await _maybe_refresh_pro_license(target) - refreshed_license_id = target_license_id - await Storage.set(_request_key(str(target["request_id"])), target, "json") - except Exception as exc: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc - - return { - "revoked_license_ids": [str(item) for item in revoked_license_ids], - "imported": imported, - "synced_license_ids": synced_license_ids, - "activated_license_id": activated_license_id, - "refreshed_license_id": refreshed_license_id, - } - - @router.get("/upgrade-requests/{request_id}", response_model=UpgradeRequestStatus) async def get_upgrade_request(request_id: str, request: Request) -> UpgradeRequestStatus: require_admin(request) raw = await Storage.get(_request_key(request_id)) if not raw: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="升级申请不存在") + raw = await _finalize_restarting_upgrade_if_installed(raw) + await Storage.set(_request_key(request_id), raw, "json") return UpgradeRequestStatus(**_enrich_record_from_install_marker(raw)) @@ -841,6 +923,9 @@ async def refresh_upgrade_request(request_id: str, request: Request) -> UpgradeR except httpx.HTTPError as exc: _raise_console_service_error(exc) else: + remote_details = data.get("form_data") if isinstance(data.get("form_data"), dict) else {} + local_details = raw.get("details") if isinstance(raw.get("details"), dict) else {} + refreshed_at = datetime.now(UTC).isoformat() raw.update( { "status": data.get("status", raw["status"]), @@ -853,13 +938,14 @@ async def refresh_upgrade_request(request_id: str, request: Request) -> UpgradeR "max_admins": data.get("max_admins", raw.get("max_admins")), "max_members": data.get("max_members", raw.get("max_members")), "expires_at": data.get("expires_at", raw.get("expires_at")), - "details": data.get("form_data", raw.get("details", {})), - "updated_at": datetime.now(UTC).isoformat(), + "details": {**local_details, **remote_details, "license_refreshed_at": refreshed_at}, + "updated_at": refreshed_at, } ) else: raw["updated_at"] = datetime.now(UTC).isoformat() + raw = await _finalize_restarting_upgrade_if_installed(raw) _enrich_record_from_install_marker(raw) await Storage.set(_request_key(request_id), raw, "json") return UpgradeRequestStatus(**raw) diff --git a/flocks/server/routes/contracts.py b/flocks/server/routes/contracts.py new file mode 100644 index 000000000..142e9d188 --- /dev/null +++ b/flocks/server/routes/contracts.py @@ -0,0 +1,93 @@ +"""Page data access contract routes.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from fastapi import APIRouter, Body, Depends +from fastapi.responses import JSONResponse + +from flocks.auth.context import AuthUser +from flocks.server.auth import require_user +from flocks.contracts.access.models import ContractRuntimeError +from flocks.contracts.access.runtime import OperationRuntime + +router = APIRouter() +_runtime: OperationRuntime | None = None +_runtime_override = False +_runtime_signature: tuple[tuple[str, int, int], ...] = () + + +@router.post("/contracts/webui/pages/{page_path:path}/access/{contract_id}/operations/{operation_name}") +async def execute_webui_contract_operation( + page_path: str, + contract_id: str, + operation_name: str, + body: dict[str, Any] | None = Body(default=None), + user: AuthUser = Depends(require_user), +): + try: + runtime = _get_runtime() + response = runtime.execute( + page_id=page_path, + contract_id=contract_id, + operation_name=operation_name, + payload=body, + principal=user, + ) + except ContractRuntimeError as exc: + return JSONResponse(status_code=exc.status_code, content={"error": exc.to_detail()}) + return JSONResponse(status_code=response.status_code, content=response.body) + + +def reset_route_dependencies(*, runtime: OperationRuntime | None = None) -> None: + """Test helper to inject isolated route dependencies.""" + global _runtime, _runtime_override, _runtime_signature + _runtime = runtime + _runtime_override = runtime is not None + _runtime_signature = _contract_plugin_signature() + + +def _get_runtime() -> OperationRuntime: + global _runtime, _runtime_signature + if _runtime_override and _runtime is not None: + return _runtime + + signature = _contract_plugin_signature() + if _runtime is None or signature != _runtime_signature: + _runtime = OperationRuntime() + _runtime_signature = signature + return _runtime + + +def _contract_plugin_signature() -> tuple[tuple[str, int, int], ...]: + roots = ( + Path.home() / ".flocks" / "plugins" / "contracts" / "access", + Path.cwd() / ".flocks" / "plugins" / "contracts" / "access", + ) + entries: list[tuple[str, int, int]] = [] + seen: set[Path] = set() + for root in roots: + resolved_root = root.resolve() + if resolved_root in seen or not resolved_root.is_dir(): + continue + seen.add(resolved_root) + plugin_files = [ + path + for suffix in ("*.py", "*.json", "*.yaml", "*.yml") + for path in resolved_root.rglob(suffix) + ] + for path in sorted(plugin_files): + try: + relative_parts = path.relative_to(resolved_root).parts + except ValueError: + continue + if any(part.startswith(".") or part == "__pycache__" for part in relative_parts): + continue + try: + stat = path.stat() + except OSError: + continue + entries.append((str(path.resolve()), stat.st_mtime_ns, stat.st_size)) + return tuple(entries) diff --git a/flocks/server/routes/flockspro_license.py b/flocks/server/routes/flockspro_license.py new file mode 100644 index 000000000..10b132273 --- /dev/null +++ b/flocks/server/routes/flockspro_license.py @@ -0,0 +1,64 @@ +"""OSS fallback routes for Flocks Pro license status. + +When the Pro package is installed it owns the actual license runtime. These +routes keep the WebUI status calls deterministic before that runtime is +available and delegate to the Pro checker whenever it can be imported. +""" + +from __future__ import annotations + +from typing import Any + +from fastapi import APIRouter, Request + +from flocks.server.auth import require_user +from flocks.server.routes.console_upgrade import _get_pro_capability_status, _is_pro_component_installed + +router = APIRouter() + + +def _inactive_status(reason: str, **extra: Any) -> dict[str, Any]: + return { + "activated": False, + "active": False, + "pro_enabled": False, + "license_status": "uninstalled" if reason == "flockspro_not_installed" else "unknown", + "inactive_reason": reason, + **extra, + } + + +@router.get("/status") +async def get_flockspro_license_status(request: Request) -> dict[str, Any]: + require_user(request) + if not _is_pro_component_installed(): + return _inactive_status("flockspro_not_installed") + + status = _get_pro_capability_status() + if not status: + return _inactive_status("capability_check_failed") + status.setdefault("activated", bool(status.get("active") or status.get("pro_enabled"))) + status.setdefault("active", bool(status.get("pro_enabled"))) + status.setdefault("pro_enabled", bool(status.get("active"))) + return status + + +@router.post("/refresh") +async def refresh_flockspro_license_status(request: Request) -> dict[str, Any]: + require_user(request) + if not _is_pro_component_installed(): + return _inactive_status("flockspro_not_installed") + + try: + from flockspro.license.runtime import get_license_checker # type: ignore[import-not-found] + + checker = get_license_checker() + refresh_fn = getattr(checker, "refresh", None) + if callable(refresh_fn): + result = refresh_fn() + if hasattr(result, "__await__"): + await result + except Exception as exc: + return _inactive_status("capability_check_failed", error=str(exc)) + + return await get_flockspro_license_status(request) diff --git a/flocks/server/routes/session.py b/flocks/server/routes/session.py index 074f7c123..7d7858b2d 100644 --- a/flocks/server/routes/session.py +++ b/flocks/server/routes/session.py @@ -36,6 +36,7 @@ # Default agent name constant DEFAULT_AGENT = "rex" DEFAULT_MESSAGE_PAGE_LIMIT = 50 +_DESCENDANT_ABORT_SCAN_LIMIT = 3 # File extensions that are safe to persist when materialising data-URL uploads. # Intentionally narrow: any extension outside this set is rejected to prevent @@ -710,6 +711,14 @@ async def delete_session(sessionID: str, request: Request) -> bool: if not SessionPolicy.can_delete(session, current_user): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="仅会话所有者可删除会话") + from flocks.session.goal import GoalManager + from flocks.session.interaction_queue import InteractionQueue + + await _abort_session_processing(sessionID) + await InteractionQueue.clear(sessionID) + await GoalManager.clear(sessionID) + await _wait_for_sessions_idle([sessionID]) + await _abort_and_wait_descendant_sessions(session.project_id, sessionID) await Session.delete(session.project_id, sessionID) # Best-effort cleanup of any image/file uploads materialised for this @@ -756,6 +765,56 @@ async def delete_session(sessionID: str, request: Request) -> bool: return True +async def _collect_descendant_session_ids(project_id: str, session_id: str) -> List[str]: + """Return child session IDs in the same order Session.delete will recurse.""" + sessions = await Session.list(project_id) + children_by_parent: Dict[str, List[str]] = {} + for child in sessions: + if child.parent_id is None: + continue + children_by_parent.setdefault(child.parent_id, []).append(child.id) + + descendants: List[str] = [] + seen: set[str] = set() + + def visit(parent_id: str) -> None: + for child_id in children_by_parent.get(parent_id, []): + if child_id in seen: + continue + seen.add(child_id) + descendants.append(child_id) + visit(child_id) + + visit(session_id) + return descendants + + +async def _abort_and_wait_descendant_sessions(project_id: str, session_id: str) -> None: + """Abort descendants that exist after the parent stops and wait as a batch.""" + known: set[str] = set() + latest: List[str] = [] + + for _ in range(_DESCENDANT_ABORT_SCAN_LIMIT): + latest = await _collect_descendant_session_ids(project_id, session_id) + new_ids = [sid for sid in latest if sid not in known] + for descendant_id in new_ids: + await _abort_session_processing(descendant_id) + known.update(new_ids) + + if latest: + await _wait_for_sessions_idle(latest) + + refreshed = await _collect_descendant_session_ids(project_id, session_id) + if set(refreshed).issubset(known): + return + latest = refreshed + + log.warn("session.delete.descendants_unstable", { + "session_id": session_id, + "descendants": latest, + }) + + class SessionUpdateRequest(BaseModel): """Request to update session""" model_config = ConfigDict(populate_by_name=True) @@ -3315,11 +3374,30 @@ async def _run() -> None: async def _wait_for_session_idle(session_id: str, timeout_s: float = 5.0) -> None: + await _wait_for_sessions_idle([session_id], timeout_s=timeout_s) + + +async def _wait_for_sessions_idle(session_ids: List[str], timeout_s: float = 5.0) -> None: from flocks.session.session_loop import SessionLoop + pending = set(session_ids) + if not pending: + return + deadline = time.time() + timeout_s - while SessionLoop.is_running(session_id) and time.time() < deadline: - await asyncio.sleep(0.05) + while pending: + running = {sid for sid in pending if SessionLoop.is_running(sid)} + if not running: + return + now = time.time() + if now >= deadline: + log.warn("session.wait_idle.timeout", { + "session_ids": sorted(running), + "timeout_s": timeout_s, + }) + return + pending = running + await asyncio.sleep(min(0.05, max(0.0, deadline - now))) def _build_prompt_request_from_event(event, prompt_text: str, display_text: Optional[str] = None): diff --git a/flocks/server/routes/user_defined_pages.py b/flocks/server/routes/webui.py similarity index 62% rename from flocks/server/routes/user_defined_pages.py rename to flocks/server/routes/webui.py index 924422064..0271b4a85 100644 --- a/flocks/server/routes/user_defined_pages.py +++ b/flocks/server/routes/webui.py @@ -1,4 +1,4 @@ -"""user-defined custom pages API routes.""" +"""WebUI page plugin API routes.""" from __future__ import annotations @@ -17,15 +17,21 @@ from pydantic import BaseModel, ConfigDict, Field from flocks.server.auth import require_admin, require_user -from flocks.user_defined_pages.builder import UserDefinedPagesBuilder -from flocks.user_defined_pages.api_runtime import UserDefinedPageApiRuntime -from flocks.user_defined_pages.models import UserDefinedPageBuildMeta, UserDefinedPageDetail, UserDefinedPageListItem, UserDefinedPageManifest -from flocks.user_defined_pages.store import UserDefinedPagesStore +from flocks.contracts.webui.builder import WebUIPageBuilder +from flocks.contracts.webui.api_runtime import WebUIPageApiRuntime +from flocks.contracts.webui.models import ( + WebUIPageBuildMeta, + WebUIPageDetail, + WebUIPageListItem, + WebUIPageManifest, + WebUIWorkspaceListItem, +) +from flocks.contracts.webui.store import WebUIPagesStore, webui_contract_page_route from flocks.server.routes.event import publish_event from flocks.utils.log import Log router = APIRouter() -log = Log.create(service="user-defined-pages-routes") +log = Log.create(service="webui-pages-routes") MAX_IMPORT_ARCHIVE_BYTES = 10_000_000 MAX_IMPORT_FILES = 500 @@ -35,12 +41,12 @@ _IMPORT_API_SUFFIXES = {".py", ".yaml", ".yml"} _IMPORT_DIST_FILES = {"dist/page.js", "dist/meta.json", "dist/api-meta.json"} -_store = UserDefinedPagesStore() -_builder = UserDefinedPagesBuilder(_store) -_api_runtime = UserDefinedPageApiRuntime(_store) +_store = WebUIPagesStore() +_builder = WebUIPageBuilder(_store) +_api_runtime = WebUIPageApiRuntime(_store) -class UserDefinedPageCreateRequest(BaseModel): +class WebUIPageCreateRequest(BaseModel): model_config = ConfigDict(populate_by_name=True) id: str = Field(..., description="Page identifier") @@ -49,7 +55,7 @@ class UserDefinedPageCreateRequest(BaseModel): order: int = Field(100, description="Navigation sort order") -class UserDefinedPageSaveRequest(BaseModel): +class WebUIPageSaveRequest(BaseModel): model_config = ConfigDict(populate_by_name=True) manifest: Optional[dict[str, Any]] = Field(None, description="Manifest fields to merge") @@ -57,11 +63,11 @@ class UserDefinedPageSaveRequest(BaseModel): sourceContent: Optional[str] = Field(None, description="Source file content") -class UserDefinedPageSaveResponse(BaseModel): +class WebUIPageSaveResponse(BaseModel): model_config = ConfigDict(populate_by_name=True, by_alias=True) - manifest: UserDefinedPageManifest - build: UserDefinedPageBuildMeta + manifest: WebUIPageManifest + build: WebUIPageBuildMeta async def _read_limited_upload(file: UploadFile) -> bytes: @@ -123,7 +129,7 @@ def _normalize_import_manifest(extracted_root: Path, page_id: str) -> None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="manifest.json is required") try: raw = json.loads(manifest_path.read_text(encoding="utf-8")) - manifest = UserDefinedPageManifest.model_validate(raw) + manifest = WebUIPageManifest.model_validate(raw) except Exception as exc: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"invalid manifest.json: {exc}") from exc @@ -131,7 +137,7 @@ def _normalize_import_manifest(extracted_root: Path, page_id: str) -> None: normalized = manifest.model_copy( update={ "id": page_id, - "route": f"/user-defined-pages/{page_id}", + "route": webui_contract_page_route(page_id), "entry": entry, "updatedAt": int(time.time() * 1000), } @@ -142,13 +148,18 @@ def _normalize_import_manifest(extracted_root: Path, page_id: str) -> None: ) -@router.get("/user-defined-pages", response_model=list[UserDefinedPageListItem]) -async def list_user_defined_pages(enabled_only: bool = Query(False, alias="enabledOnly")): +@router.get("/contracts/webui/pages", response_model=list[WebUIPageListItem]) +async def list_webui_pages(enabled_only: bool = Query(False, alias="enabledOnly")): return _store.list_pages(enabled_only=enabled_only) -@router.post("/user-defined-pages", response_model=UserDefinedPageDetail, status_code=status.HTTP_201_CREATED) -async def create_user_defined_page(req: UserDefinedPageCreateRequest, _admin: object = Depends(require_admin)): +@router.get("/contracts/webui/workspaces", response_model=list[WebUIWorkspaceListItem]) +async def list_webui_workspaces(enabled_only: bool = Query(False, alias="enabledOnly")): + return _store.list_workspaces(enabled_only=enabled_only) + + +@router.post("/contracts/webui/pages", response_model=WebUIPageDetail, status_code=status.HTTP_201_CREATED) +async def create_webui_page(req: WebUIPageCreateRequest, _admin: object = Depends(require_admin)): try: detail = _store.create_page( page_id=req.id, @@ -164,29 +175,29 @@ async def create_user_defined_page(req: UserDefinedPageCreateRequest, _admin: ob try: build = _builder.build(detail.manifest.id) if build.status == "ready": - await publish_event("user_defined_pages.updated", {"id": detail.manifest.id, "hash": build.hash}) + await publish_event("contracts.webui.pages.updated", {"id": detail.manifest.id, "hash": build.hash}) elif build.status == "failed": await publish_event( - "user_defined_pages.build_failed", + "contracts.webui.pages.build_failed", {"id": detail.manifest.id, "error": build.error or "build failed"}, ) except Exception as exc: - log.warning("user_defined_pages.create.build_failed", {"pageId": detail.manifest.id, "error": str(exc)}) + log.warning("webui_pages.create.build_failed", {"pageId": detail.manifest.id, "error": str(exc)}) - await publish_event("user_defined_pages.nav_changed", {"id": detail.manifest.id}) + await publish_event("contracts.webui.pages.nav_changed", {"id": detail.manifest.id}) return _store.get_page(detail.manifest.id) -@router.get("/user-defined-pages/{page_id}", response_model=UserDefinedPageDetail) -async def get_user_defined_page(page_id: str): +@router.get("/contracts/webui/pages/{page_id}", response_model=WebUIPageDetail) +async def get_webui_page(page_id: str): try: return _store.get_page(page_id) except FileNotFoundError as exc: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc -@router.put("/user-defined-pages/{page_id}", response_model=UserDefinedPageSaveResponse) -async def save_user_defined_page(page_id: str, req: UserDefinedPageSaveRequest, _admin: object = Depends(require_admin)): +@router.put("/contracts/webui/pages/{page_id}", response_model=WebUIPageSaveResponse) +async def save_webui_page(page_id: str, req: WebUIPageSaveRequest, _admin: object = Depends(require_admin)): nav_changed = False try: if req.manifest is not None: @@ -204,44 +215,44 @@ async def save_user_defined_page(page_id: str, req: UserDefinedPageSaveRequest, except ValueError as exc: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc - build = UserDefinedPageBuildMeta(status="idle") + build = WebUIPageBuildMeta(status="idle") if req.sourcePath is not None: rel = req.sourcePath.replace("\\", "/").lstrip("/") if rel.startswith("api/"): try: routes = await _api_runtime.reload_page(page_id) - await publish_event("user_defined_pages.api_changed", {"id": page_id, "routes": routes}) + await publish_event("contracts.webui.pages.api_changed", {"id": page_id, "routes": routes}) except HTTPException as exc: await publish_event( - "user_defined_pages.api_failed", + "contracts.webui.pages.api_failed", {"id": page_id, "error": str(exc.detail)}, ) raise except Exception as exc: await publish_event( - "user_defined_pages.api_failed", + "contracts.webui.pages.api_failed", {"id": page_id, "error": str(exc)}, ) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc else: build = _builder.build(page_id) if build.status == "ready": - await publish_event("user_defined_pages.updated", {"id": page_id, "hash": build.hash}) + await publish_event("contracts.webui.pages.updated", {"id": page_id, "hash": build.hash}) nav_changed = True else: await publish_event( - "user_defined_pages.build_failed", + "contracts.webui.pages.build_failed", {"id": page_id, "error": build.error or "build failed"}, ) elif nav_changed: - await publish_event("user_defined_pages.nav_changed", {"id": page_id}) + await publish_event("contracts.webui.pages.nav_changed", {"id": page_id}) manifest = _store.get_page(page_id).manifest - return UserDefinedPageSaveResponse(manifest=manifest, build=build) + return WebUIPageSaveResponse(manifest=manifest, build=build) -@router.post("/user-defined-pages/{page_id}/build", response_model=UserDefinedPageBuildMeta) -async def build_user_defined_page(page_id: str, _admin: object = Depends(require_admin)): +@router.post("/contracts/webui/pages/{page_id}/build", response_model=WebUIPageBuildMeta) +async def build_webui_page(page_id: str, _admin: object = Depends(require_admin)): try: build = _builder.build(page_id) except FileNotFoundError as exc: @@ -250,18 +261,18 @@ async def build_user_defined_page(page_id: str, _admin: object = Depends(require raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc if build.status == "ready": - await publish_event("user_defined_pages.updated", {"id": page_id, "hash": build.hash}) - await publish_event("user_defined_pages.nav_changed", {"id": page_id}) + await publish_event("contracts.webui.pages.updated", {"id": page_id, "hash": build.hash}) + await publish_event("contracts.webui.pages.nav_changed", {"id": page_id}) else: await publish_event( - "user_defined_pages.build_failed", + "contracts.webui.pages.build_failed", {"id": page_id, "error": build.error or "build failed"}, ) return build -@router.get("/user-defined-pages/{page_id}/bundle.js") -async def get_user_defined_page_bundle(page_id: str, v: Optional[str] = Query(None)): +@router.get("/contracts/webui/pages/{page_id}/bundle.js") +async def get_webui_page_bundle(page_id: str, v: Optional[str] = Query(None)): try: bundle_path = _store.bundle_path(page_id) if not bundle_path.is_file(): @@ -276,8 +287,8 @@ async def get_user_defined_page_bundle(page_id: str, v: Optional[str] = Query(No raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc -@router.get("/user-defined-pages/{page_id}/assets/{asset_path:path}") -async def get_user_defined_page_asset(page_id: str, asset_path: str): +@router.get("/contracts/webui/pages/{page_id}/assets/{asset_path:path}") +async def get_webui_page_asset(page_id: str, asset_path: str): try: path = _store.asset_path(page_id, asset_path) if not path.is_file(): @@ -287,34 +298,34 @@ async def get_user_defined_page_asset(page_id: str, asset_path: str): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc -@router.get("/user-defined-pages/{page_id}/api") -async def list_user_defined_page_api_routes(page_id: str): +@router.get("/contracts/webui/pages/{page_id}/api") +async def list_webui_page_api_routes(page_id: str): return await _api_runtime.list_routes(page_id) -@router.post("/user-defined-pages/{page_id}/api/reload") -async def reload_user_defined_page_api(page_id: str, _admin: object = Depends(require_admin)): +@router.post("/contracts/webui/pages/{page_id}/api/reload") +async def reload_webui_page_api(page_id: str, _admin: object = Depends(require_admin)): routes = await _api_runtime.reload_page(page_id) - await publish_event("user_defined_pages.api_changed", {"id": page_id, "routes": routes}) + await publish_event("contracts.webui.pages.api_changed", {"id": page_id, "routes": routes}) return {"routes": routes} @router.api_route( - "/user-defined-pages/{page_id}/api/{api_path:path}", + "/contracts/webui/pages/{page_id}/api/{api_path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE"], ) -async def dispatch_user_defined_page_api(page_id: str, api_path: str, request: Request): +async def dispatch_webui_page_api(page_id: str, api_path: str, request: Request): user = require_user(request) return await _api_runtime.dispatch(page_id, api_path, request, user) -@router.get("/user-defined-pages/{page_id}/export") -async def export_user_defined_page(page_id: str, _admin: object = Depends(require_admin)): +@router.get("/contracts/webui/pages/{page_id}/export") +async def export_webui_page(page_id: str, _admin: object = Depends(require_admin)): page_path = _store.page_dir(page_id) if not page_path.is_dir(): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"page not found: {page_id}") - fd, archive_path = tempfile.mkstemp(prefix=f"user-defined-page-{page_id}-", suffix=".zip") + fd, archive_path = tempfile.mkstemp(prefix=f"webui-page-{page_id}-", suffix=".zip") os.close(fd) try: with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as zf: @@ -335,8 +346,8 @@ async def export_user_defined_page(page_id: str, _admin: object = Depends(requir ) -@router.post("/user-defined-pages/import") -async def import_user_defined_page( +@router.post("/contracts/webui/pages/import") +async def import_webui_page( file: UploadFile = File(...), overwrite: bool = Query(False), _admin: object = Depends(require_admin), @@ -345,7 +356,7 @@ async def import_user_defined_page( if not data: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="empty archive") try: - with tempfile.TemporaryDirectory(prefix="udp-import-") as tmpdir: + with tempfile.TemporaryDirectory(prefix="webui-import-") as tmpdir: temp_root = Path(tmpdir) / "extract" temp_root.mkdir(parents=True, exist_ok=True) archive_path = Path(tmpdir) / "archive.zip" @@ -381,32 +392,32 @@ async def import_user_defined_page( target.parent.mkdir(parents=True, exist_ok=True) target.write_bytes(zf.read(member)) _normalize_import_manifest(extracted_root, page_id) - target = _store.page_dir(page_id) + if _store.page_exists(page_id) and not overwrite: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=f"page already exists: {page_id}") + target = _store.root_page_dir(page_id) if target.exists(): - if not overwrite: - raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=f"page already exists: {page_id}") shutil.rmtree(target) shutil.move(str((temp_root / page_id).resolve()), str(target)) try: build = _builder.build(page_id) if build.status == "ready": - await publish_event("user_defined_pages.updated", {"id": page_id, "hash": build.hash}) + await publish_event("contracts.webui.pages.updated", {"id": page_id, "hash": build.hash}) else: await publish_event( - "user_defined_pages.build_failed", + "contracts.webui.pages.build_failed", {"id": page_id, "error": build.error or "build failed"}, ) except Exception as exc: - log.warning("user_defined_pages.import.build_failed", {"pageId": page_id, "error": str(exc)}) - await publish_event("user_defined_pages.build_failed", {"id": page_id, "error": str(exc)}) + log.warning("webui_pages.import.build_failed", {"pageId": page_id, "error": str(exc)}) + await publish_event("contracts.webui.pages.build_failed", {"id": page_id, "error": str(exc)}) if _store.routes_path(page_id).is_file(): try: routes = await _api_runtime.reload_page(page_id) - await publish_event("user_defined_pages.api_changed", {"id": page_id, "routes": routes}) + await publish_event("contracts.webui.pages.api_changed", {"id": page_id, "routes": routes}) except Exception as exc: - log.warning("user_defined_pages.import.api_reload_failed", {"pageId": page_id, "error": str(exc)}) - await publish_event("user_defined_pages.api_failed", {"id": page_id, "error": str(exc)}) - await publish_event("user_defined_pages.nav_changed", {"id": page_id}) + log.warning("webui_pages.import.api_reload_failed", {"pageId": page_id, "error": str(exc)}) + await publish_event("contracts.webui.pages.api_failed", {"id": page_id, "error": str(exc)}) + await publish_event("contracts.webui.pages.nav_changed", {"id": page_id}) return _store.get_page(page_id) except HTTPException: raise @@ -414,14 +425,97 @@ async def import_user_defined_page( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="invalid zip archive") from exc +@router.get("/user-defined-pages", response_model=list[WebUIPageListItem]) +async def list_legacy_user_defined_pages(enabled_only: bool = Query(False, alias="enabledOnly")): + """Compatibility alias for pre-contract WebUI page clients.""" + return await list_webui_pages(enabled_only=enabled_only) + + +@router.post("/user-defined-pages", response_model=WebUIPageDetail, status_code=status.HTTP_201_CREATED) +async def create_legacy_user_defined_page(req: WebUIPageCreateRequest, _admin: object = Depends(require_admin)): + """Compatibility alias for pre-contract WebUI page clients.""" + return await create_webui_page(req, _admin) + + +@router.get("/user-defined-pages/{page_id}", response_model=WebUIPageDetail) +async def get_legacy_user_defined_page(page_id: str): + """Compatibility alias for pre-contract WebUI page clients.""" + return await get_webui_page(page_id) + + +@router.put("/user-defined-pages/{page_id}", response_model=WebUIPageSaveResponse) +async def save_legacy_user_defined_page( + page_id: str, + req: WebUIPageSaveRequest, + _admin: object = Depends(require_admin), +): + """Compatibility alias for pre-contract WebUI page clients.""" + return await save_webui_page(page_id, req, _admin) + + +@router.post("/user-defined-pages/{page_id}/build", response_model=WebUIPageBuildMeta) +async def build_legacy_user_defined_page(page_id: str, _admin: object = Depends(require_admin)): + """Compatibility alias for pre-contract WebUI page clients.""" + return await build_webui_page(page_id, _admin) + + +@router.get("/user-defined-pages/{page_id}/bundle.js") +async def get_legacy_user_defined_page_bundle(page_id: str, v: Optional[str] = Query(None)): + """Compatibility alias for pre-contract WebUI page clients.""" + return await get_webui_page_bundle(page_id, v) + + +@router.get("/user-defined-pages/{page_id}/assets/{asset_path:path}") +async def get_legacy_user_defined_page_asset(page_id: str, asset_path: str): + """Compatibility alias for pre-contract WebUI page clients.""" + return await get_webui_page_asset(page_id, asset_path) + + +@router.get("/user-defined-pages/{page_id}/api") +async def list_legacy_user_defined_page_api_routes(page_id: str): + """Compatibility alias for pre-contract WebUI page clients.""" + return await list_webui_page_api_routes(page_id) + + +@router.post("/user-defined-pages/{page_id}/api/reload") +async def reload_legacy_user_defined_page_api(page_id: str, _admin: object = Depends(require_admin)): + """Compatibility alias for pre-contract WebUI page clients.""" + return await reload_webui_page_api(page_id, _admin) + + +@router.api_route( + "/user-defined-pages/{page_id}/api/{api_path:path}", + methods=["GET", "POST", "PUT", "PATCH", "DELETE"], +) +async def dispatch_legacy_user_defined_page_api(page_id: str, api_path: str, request: Request): + """Compatibility alias for pre-contract WebUI page clients.""" + return await dispatch_webui_page_api(page_id, api_path, request) + + +@router.get("/user-defined-pages/{page_id}/export") +async def export_legacy_user_defined_page(page_id: str, _admin: object = Depends(require_admin)): + """Compatibility alias for pre-contract WebUI page clients.""" + return await export_webui_page(page_id, _admin) + + +@router.post("/user-defined-pages/import") +async def import_legacy_user_defined_page( + file: UploadFile = File(...), + overwrite: bool = Query(False), + _admin: object = Depends(require_admin), +): + """Compatibility alias for pre-contract WebUI page clients.""" + return await import_webui_page(file, overwrite, _admin) + + def reset_route_dependencies( *, - store: Optional[UserDefinedPagesStore] = None, - builder: Optional[UserDefinedPagesBuilder] = None, - api_runtime: Optional[UserDefinedPageApiRuntime] = None, + store: Optional[WebUIPagesStore] = None, + builder: Optional[WebUIPageBuilder] = None, + api_runtime: Optional[WebUIPageApiRuntime] = None, ) -> None: """Test helper to inject isolated store/builder instances.""" global _store, _builder, _api_runtime - _store = store or UserDefinedPagesStore() - _builder = builder or UserDefinedPagesBuilder(_store) - _api_runtime = api_runtime or UserDefinedPageApiRuntime(_store) + _store = store or WebUIPagesStore() + _builder = builder or WebUIPageBuilder(_store) + _api_runtime = api_runtime or WebUIPageApiRuntime(_store) diff --git a/flocks/server/routes/workflow.py b/flocks/server/routes/workflow.py index b3f4bc1b9..c480e0b1f 100644 --- a/flocks/server/routes/workflow.py +++ b/flocks/server/routes/workflow.py @@ -3,6 +3,7 @@ Provides API endpoints for workflow CRUD, execution, history, and AI generation. """ + from __future__ import annotations import asyncio @@ -21,7 +22,7 @@ import uuid from flocks.workflow.models import Workflow, Node, Edge -from flocks.workflow.runner import run_workflow, RunWorkflowResult +from flocks.workflow.runner import RunWorkflowResult, run_workflow from flocks.workflow.center import ( WorkflowCenterError, WorkflowNotFoundError, @@ -61,6 +62,7 @@ workflow_execution_step_prefix as _workflow_execution_step_prefix, ) from flocks.workflow.io import load_workflow, dump_workflow +from flocks.workflow.store import WorkflowStore from flocks.workflow.tool_context import build_workflow_tool_context from flocks.workflow.tools import get_tool_registry from flocks.workflow.visibility import is_hidden_workflow_data @@ -105,21 +107,23 @@ _WORKFLOW_CENTER_RUNTIME_PREFIX = "workflow_runtime/" _WORKFLOW_CENTER_LOCAL_PID_PREFIX = "workflow_local_pid/" _WORKFLOW_POLLER_CONFIG_PREFIX = "workflow_poller_config/" -_WORKFLOW_CONFIG_TRIGGER_TYPES = frozenset({ - "manual", - "schedule", - "webhook", - "syslog", - "kafka", - "internal_event", - "custom_webhook", - "custom_adapter", - "plugin", - "api", - "publish", - "api_service", - "service", -}) +_WORKFLOW_CONFIG_TRIGGER_TYPES = frozenset( + { + "manual", + "schedule", + "webhook", + "syslog", + "kafka", + "internal_event", + "custom_webhook", + "custom_adapter", + "plugin", + "api", + "publish", + "api_service", + "service", + } +) _WORKFLOW_CONFIG_SECRET_KEYS = frozenset({"apikey", "password", "token", "secret"}) _WORKFLOW_CONFIG_SECRET_REF_KEYS = frozenset({"secretref", "secretreference"}) @@ -127,6 +131,7 @@ @dataclass class ActiveWorkflowExecution: """Tracks an in-flight workflow execution that can be cancelled.""" + workflow_id: str task: asyncio.Task[Any] cancel_event: threading.Event @@ -139,10 +144,12 @@ class ActiveWorkflowExecution: # Request/Response Models # ============================================================================= + class WorkflowCreateRequest(BaseModel): """Request to create a workflow""" + model_config = ConfigDict(populate_by_name=True) - + name: str = Field(..., description="Workflow name") name_i18n: Optional[Dict[str, str]] = Field(None, alias="nameI18n", description="Localized workflow display names") description: Optional[str] = Field(None, description="Workflow description") @@ -157,8 +164,9 @@ class WorkflowCreateRequest(BaseModel): class WorkflowUpdateRequest(BaseModel): """Request to update a workflow""" + model_config = ConfigDict(populate_by_name=True) - + name: Optional[str] = Field(None, description="Workflow name") name_i18n: Optional[Dict[str, str]] = Field(None, alias="nameI18n", description="Localized workflow display names") description: Optional[str] = Field(None, description="Workflow description") @@ -179,8 +187,9 @@ class WorkflowUpdateRequest(BaseModel): class WorkflowResponse(BaseModel): """Workflow response""" + model_config = ConfigDict(populate_by_name=True, by_alias=True) - + id: str = Field(..., description="Workflow ID") name: str = Field(..., description="Workflow name") nameI18n: Optional[Dict[str, str]] = Field(None, description="Localized workflow display names") @@ -199,8 +208,9 @@ class WorkflowResponse(BaseModel): class WorkflowRunRequest(BaseModel): """Request to run a workflow""" + model_config = ConfigDict(populate_by_name=True) - + inputs: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Input parameters") timeout_s: Optional[float] = Field(None, alias="timeoutS", description="Timeout in seconds") trace: bool = Field(False, description="Enable tracing") @@ -211,8 +221,9 @@ class WorkflowRunRequest(BaseModel): class WorkflowExecutionResponse(BaseModel): """Workflow execution response""" + model_config = ConfigDict(populate_by_name=True, by_alias=True) - + id: str = Field(..., description="Execution ID") workflowId: str = Field(..., description="Workflow ID") inputParams: Dict[str, Any] = Field(default_factory=dict, description="Input parameters") @@ -260,8 +271,9 @@ class WorkflowCenterInvokeRequest(BaseModel): class WorkflowStatsResponse(BaseModel): """Workflow statistics response""" + model_config = ConfigDict(populate_by_name=True, by_alias=True) - + workflowId: Optional[str] = Field(None, description="Workflow ID (null for aggregate)") callCount: int = Field(0, description="Total calls") successCount: int = Field(0, description="Successful calls") @@ -381,6 +393,45 @@ def _write_workflow_to_fs( legacy_edit_file.unlink() +def _apply_new_workflow_runtime_defaults(workflow_json: Dict[str, Any]) -> Dict[str, Any]: + """Return workflow JSON with runtime defaults for newly created workflows.""" + normalized = dict(workflow_json) + metadata = normalized.get("metadata") + if not isinstance(metadata, dict): + metadata = {} + else: + metadata = dict(metadata) + runtime = metadata.get("runtime") + if not isinstance(runtime, dict): + runtime = {} + else: + runtime = dict(runtime) + + runtime.setdefault("strict_edge_mapping", True) + runtime.setdefault("dataflow_mode", "vertex_cache") + metadata["runtime"] = runtime + normalized["metadata"] = metadata + return normalized + + +def _strict_edge_mapping_lint_errors(workflow: Workflow) -> List[Dict[str, Any]]: + """Return strict edge-mapping lint errors that would fail execution.""" + return [ + item + for item in lint_workflow(workflow) + if item.get("severity") == "error" and item.get("kind") == "implicit_full_payload_edge" + ] + + +def _schema_lint_errors(workflow: Workflow) -> List[Dict[str, Any]]: + """Return lightweight schema lint errors that would fail execution.""" + return [ + item + for item in lint_workflow(workflow) + if item.get("severity") == "error" and str(item.get("kind", "")).startswith("schema_") + ] + + def _delete_workflow_from_fs(workflow_id: str) -> bool: """Remove a workflow directory from all known locations (primary + legacy plugins). @@ -438,38 +489,23 @@ async def _stop_workflow_runtime_resources(workflow_id: str) -> None: async def _cleanup_workflow_storage(workflow_id: str) -> None: - await _remove_storage_key_if_exists(_workflow_stats_key(workflow_id)) - await _remove_storage_key_if_exists(_workflow_integration_config_key(workflow_id)) - await _remove_storage_key_if_exists(_api_service_key(workflow_id)) - await _remove_storage_key_if_exists(_syslog_config_key(workflow_id)) - await _remove_storage_key_if_exists(_kafka_config_key(workflow_id)) - await _remove_storage_key_if_exists(f"{_WORKFLOW_POLLER_CONFIG_PREFIX}{workflow_id}") - await _remove_storage_key_if_exists(f"{_WORKFLOW_CENTER_REGISTRY_PREFIX}{workflow_id}") - await _remove_storage_key_if_exists(f"{_WORKFLOW_CENTER_RUNTIME_PREFIX}{workflow_id}") - await _remove_storage_key_if_exists(f"{_WORKFLOW_CENTER_LOCAL_PID_PREFIX}{workflow_id}") - await _remove_storage_prefix(f"{_WORKFLOW_CENTER_RELEASE_PREFIX}{workflow_id}/") - await _remove_storage_prefix(f"workflow_execution_index/{workflow_id}/") - try: - exec_keys = await Storage.list_keys("workflow_execution/") - for key in exec_keys: - try: - exec_data = await Storage.read(key) - if isinstance(exec_data, dict) and exec_data.get("workflowId") == workflow_id: - exec_id = key.rsplit("/", 1)[-1] - await Storage.clear(_workflow_execution_step_prefix(exec_id)) - await Storage.remove(key) - except Exception as exc: - log.warning("workflow.delete.execution_cleanup_failed", { - "workflow_id": workflow_id, - "key": key, - "error": str(exc), - }) + await WorkflowStore.delete_stats(workflow_id) + await WorkflowStore.delete_config(workflow_id) + await WorkflowStore.delete_executions_for_workflow(workflow_id) + await WorkflowStore.kv_remove(_api_service_key(workflow_id)) + await WorkflowStore.kv_remove(f"{_WORKFLOW_CENTER_REGISTRY_PREFIX}{workflow_id}") + await WorkflowStore.kv_remove(f"{_WORKFLOW_CENTER_RUNTIME_PREFIX}{workflow_id}") + await WorkflowStore.kv_remove(f"{_WORKFLOW_CENTER_LOCAL_PID_PREFIX}{workflow_id}") + await WorkflowStore.kv_clear(f"{_WORKFLOW_CENTER_RELEASE_PREFIX}{workflow_id}/") except Exception as exc: - log.warning("workflow.delete.execution_scan_failed", { - "workflow_id": workflow_id, - "error": str(exc), - }) + log.warning( + "workflow.delete.workflow_store_cleanup_failed", + { + "workflow_id": workflow_id, + "error": str(exc), + }, + ) service_dir = Config.get_data_path() / "workflow-services" / "workflows" / workflow_id if service_dir.is_dir(): @@ -583,6 +619,7 @@ async def _migrate_storage_to_filesystem() -> None: # Storage Helpers (Stats & Execution only) # ============================================================================= + def _workflow_stats_key(workflow_id: str) -> str: return f"workflow/{workflow_id}/stats" @@ -593,13 +630,13 @@ def _syslog_config_key(workflow_id: str) -> str: async def _read_legacy_trigger_defs(workflow_id: str) -> List[TriggerDefinition]: triggers: List[TriggerDefinition] = [] - for key, converter in ( - (_kafka_config_key(workflow_id), legacy_kafka_trigger_from_config), - (f"workflow_poller_config/{workflow_id}", legacy_schedule_trigger_from_config), - (_syslog_config_key(workflow_id), legacy_syslog_trigger_from_config), + for kind, converter in ( + ("workflow_kafka_config", legacy_kafka_trigger_from_config), + ("workflow_poller_config", legacy_schedule_trigger_from_config), + ("workflow_syslog_config", legacy_syslog_trigger_from_config), ): try: - config = await Storage.read(key) + config = await WorkflowStore.get_config(workflow_id, kind=kind) except Exception: config = None trigger = converter(config) @@ -786,7 +823,7 @@ async def _build_workflow_integration_config( if trigger_defs is None: trigger_defs = await _get_workflow_trigger_defs(workflow_id, workflow_data) if service is None: - service = await Storage.read(_api_service_key(workflow_id)) + service = await WorkflowStore.kv_get(_api_service_key(workflow_id)) now_ms = int(time.time() * 1000) return { "version": _WORKFLOW_INTEGRATION_CONFIG_VERSION, @@ -810,7 +847,7 @@ async def _build_workflow_integration_runtime( workflow_data: Dict[str, Any], ) -> Dict[str, Any]: triggers = await _get_workflow_trigger_defs(workflow_id, workflow_data) - service = await Storage.read(_api_service_key(workflow_id)) + service = await WorkflowStore.kv_get(_api_service_key(workflow_id)) statuses: Dict[str, Dict[str, Any]] = {} try: statuses = { @@ -822,10 +859,13 @@ async def _build_workflow_integration_runtime( if item.get("triggerId") } except Exception as exc: - log.warning("workflow.config.runtime_status_failed", { - "id": workflow_id, - "error": str(exc), - }) + log.warning( + "workflow.config.runtime_status_failed", + { + "id": workflow_id, + "error": str(exc), + }, + ) return { "publish": _publish_for_config(service), @@ -863,7 +903,7 @@ async def _write_workflow_integration_config( async def _read_stored_workflow_integration_config(workflow_id: str) -> Optional[Dict[str, Any]]: - stored = await Storage.read(_workflow_integration_config_key(workflow_id)) + stored = await WorkflowStore.get_config(workflow_id) return stored if isinstance(stored, dict) else None @@ -892,12 +932,15 @@ async def _load_workflow_integration_config_template( file_config = await _read_file_workflow_integration_config(workflow_id, workflow_data, config_path) if file_config is not None: - await Storage.write(_workflow_integration_config_key(workflow_id), file_config) - log.info("workflow.config.migrated_from_file", { - "id": workflow_id, - "path": str(config_path), - "storage_key": _workflow_integration_config_key(workflow_id), - }) + await WorkflowStore.put_config(workflow_id, file_config) + log.info( + "workflow.config.migrated_from_file", + { + "id": workflow_id, + "path": str(config_path), + "storage_key": _workflow_integration_config_key(workflow_id), + }, + ) return file_config, "file_migrated" return None, "missing" @@ -938,19 +981,19 @@ def _disable_legacy_trigger_of_type( async def _sync_trigger_legacy_state(workflow_id: str, trigger: TriggerDefinition) -> Optional[Dict[str, Any]]: if trigger.type == "kafka": config = kafka_trigger_to_legacy_config(workflow_id, trigger) - await Storage.write(_kafka_config_key(workflow_id), config) + await WorkflowStore.put_config(workflow_id, config, kind="workflow_kafka_config") from flocks.ingest.kafka.manager import default_manager as _kafka_default_manager return await _kafka_default_manager.restart_workflow(workflow_id) if trigger.type == "schedule": config = schedule_trigger_to_legacy_config(workflow_id, trigger) - await Storage.write(f"workflow_poller_config/{workflow_id}", config) + await WorkflowStore.put_config(workflow_id, config, kind="workflow_poller_config") from flocks.workflow.poller_manager import default_manager as _poller_default_manager return await _poller_default_manager.restart_workflow(workflow_id) if trigger.type == "syslog": config = syslog_trigger_to_legacy_config(workflow_id, trigger) - await Storage.write(_syslog_config_key(workflow_id), config) + await WorkflowStore.put_config(workflow_id, config, kind="workflow_syslog_config") from flocks.ingest.syslog.manager import default_manager as _syslog_default_manager return await _syslog_default_manager.restart_workflow(workflow_id) @@ -966,10 +1009,7 @@ async def _remove_legacy_trigger_state(workflow_id: str, trigger: TriggerDefinit await _kafka_default_manager.stop_workflow(workflow_id) except Exception: pass - try: - await Storage.remove(_kafka_config_key(workflow_id)) - except Storage.NotFoundError: - pass + await WorkflowStore.delete_config(workflow_id, kind="workflow_kafka_config") return if trigger.type == "schedule": try: @@ -978,10 +1018,7 @@ async def _remove_legacy_trigger_state(workflow_id: str, trigger: TriggerDefinit await _poller_default_manager.stop_workflow(workflow_id) except Exception: pass - try: - await Storage.remove(f"workflow_poller_config/{workflow_id}") - except Storage.NotFoundError: - pass + await WorkflowStore.delete_config(workflow_id, kind="workflow_poller_config") return if trigger.type == "syslog": try: @@ -990,10 +1027,7 @@ async def _remove_legacy_trigger_state(workflow_id: str, trigger: TriggerDefinit await _syslog_default_manager.stop_workflow(workflow_id) except Exception: pass - try: - await Storage.remove(_syslog_config_key(workflow_id)) - except Storage.NotFoundError: - pass + await WorkflowStore.delete_config(workflow_id, kind="workflow_syslog_config") async def _persist_workflow_triggers( @@ -1027,7 +1061,6 @@ async def _run_workflow_execution_task( tool_context: Optional[ToolContext] = None, ) -> None: """Execute a workflow in the background and keep the execution record updated.""" - exec_key = _workflow_execution_key(exec_id) start_time = time.time() step_count = 0 loop = asyncio.get_running_loop() @@ -1049,13 +1082,16 @@ def _write_progress(update_fields: Dict[str, Any]) -> None: try: execution_summary.update(update_fields) asyncio.run_coroutine_threadsafe( - Storage.write(exec_key, compact_execution_summary(execution_summary)), loop + WorkflowStore.upsert_execution(compact_execution_summary(execution_summary)), loop ).result(timeout=5) except Exception as exc: - log.warning("workflow.step_progress.write_failed", { - "exec_id": exec_id, - "error": str(exc), - }) + log.warning( + "workflow.step_progress.write_failed", + { + "exec_id": exec_id, + "error": str(exc), + }, + ) def _on_step_start(_run_id, step_index, node, _inputs): nonlocal pending_step_index, pending_step @@ -1068,21 +1104,25 @@ def _on_step_start(_run_id, step_index, node, _inputs): outputs=None, ) pending_step_index = step_index - pending_step = { - "node_id": node_id, - "node_type": node_type, - "inputs": _inputs if isinstance(_inputs, dict) else {}, - "outputs": {}, - "error": "Run cancelled before node completed", - } - execution_summary.update({ - "currentNodeId": node_id, - "currentNodeType": node_type, - "currentPhase": "running", - "currentStepIndex": step_index, - "loopProgress": loop_progress, - "updatedAt": int(time.time() * 1000), - }) + pending_step = compact_step_for_storage( + { + "node_id": node_id, + "node_type": node_type, + "inputs": _inputs if isinstance(_inputs, dict) else {}, + "outputs": {}, + "error": "Run cancelled before node completed", + } + ) + _write_progress( + { + "currentNodeId": node_id, + "currentNodeType": node_type, + "currentPhase": "running", + "currentStepIndex": step_index, + "loopProgress": loop_progress, + "updatedAt": int(time.time() * 1000), + } + ) return step_index def _on_step_complete(step_result) -> None: @@ -1097,28 +1137,8 @@ def _on_step_complete(step_result) -> None: inputs=step_dict.get("inputs"), outputs=step_dict.get("outputs"), ) - execution_summary.update({ - "stepCount": step_count, - "currentNodeId": step_dict.get("node_id"), - "currentNodeType": step_dict.get("node_type") or step_dict.get("type"), - "currentPhase": "running", - "currentStepIndex": step_count, - "loopProgress": loop_progress, - "updatedAt": int(time.time() * 1000), - }) - try: - asyncio.run_coroutine_threadsafe( - record_execution_step(exec_id, step_count, step_dict), - loop, - ).result(timeout=5) - except Exception as exc: - log.warning("workflow.execution_step.write_failed", { - "exec_id": exec_id, - "step_index": step_count, - "error": str(exc), - }) - if step_count % _PROGRESS_FLUSH_EVERY_STEPS == 0: - _write_progress({ + execution_summary.update( + { "stepCount": step_count, "currentNodeId": step_dict.get("node_id"), "currentNodeType": step_dict.get("node_type") or step_dict.get("type"), @@ -1126,7 +1146,34 @@ def _on_step_complete(step_result) -> None: "currentStepIndex": step_count, "loopProgress": loop_progress, "updatedAt": int(time.time() * 1000), - }) + } + ) + try: + asyncio.run_coroutine_threadsafe( + record_execution_step(exec_id, step_count, step_dict), + loop, + ).result(timeout=5) + except Exception as exc: + log.warning( + "workflow.execution_step.write_failed", + { + "exec_id": exec_id, + "step_index": step_count, + "error": str(exc), + }, + ) + if step_count % _PROGRESS_FLUSH_EVERY_STEPS == 0: + _write_progress( + { + "stepCount": step_count, + "currentNodeId": step_dict.get("node_id"), + "currentNodeType": step_dict.get("node_type") or step_dict.get("type"), + "currentPhase": "running", + "currentStepIndex": step_count, + "loopProgress": loop_progress, + "updatedAt": int(time.time() * 1000), + } + ) async def _flush_pending_step() -> None: if pending_step_index is None or pending_step is None: @@ -1134,11 +1181,14 @@ async def _flush_pending_step() -> None: try: await record_execution_step(exec_id, pending_step_index, pending_step) except Exception as exc: - log.warning("workflow.pending_step.write_failed", { - "exec_id": exec_id, - "step_index": pending_step_index, - "error": str(exc), - }) + log.warning( + "workflow.pending_step.write_failed", + { + "exec_id": exec_id, + "step_index": pending_step_index, + "error": str(exc), + }, + ) try: result: RunWorkflowResult = await asyncio.to_thread( @@ -1167,47 +1217,57 @@ async def _flush_pending_step() -> None: final_steps = result.steps if pending_step_index is not None: final_steps = max(final_steps, pending_step_index) - current_data.update({ - "outputResults": compact_outputs_for_storage(result.outputs), - "status": status_value, - "finishedAt": int(time.time() * 1000), - "duration": duration, - "executionLog": final_history, - "stepCount": final_steps, - "errorMessage": error_message, - "currentNodeId": result.last_node_id, - "currentNodeType": current_data.get("currentNodeType"), - "currentPhase": status_value, - "currentStepIndex": final_steps, - "updatedAt": int(time.time() * 1000), - }) + current_data.update( + { + "outputResults": compact_outputs_for_storage(result.outputs), + "status": status_value, + "finishedAt": int(time.time() * 1000), + "duration": duration, + "executionLog": final_history, + "stepCount": final_steps, + "errorMessage": error_message, + "currentNodeId": result.last_node_id, + "currentNodeType": current_data.get("currentNodeType"), + "currentPhase": status_value, + "currentStepIndex": final_steps, + "updatedAt": int(time.time() * 1000), + } + ) await _record_execution_result(workflow_id, exec_id, current_data) - log.info("workflow.executed", { - "id": workflow_id, - "exec_id": exec_id, - "status": status_value, - "duration": duration, - }) + log.info( + "workflow.executed", + { + "id": workflow_id, + "exec_id": exec_id, + "status": status_value, + "duration": duration, + }, + ) except Exception as exc: duration = time.time() - start_time current_data = dict(execution_summary) - current_data.update({ - "status": "cancelled" if cancel_event.is_set() else "error", - "finishedAt": int(time.time() * 1000), - "duration": duration, - "errorMessage": str(exc), - "executionLog": [], - "stepCount": step_count, - "currentPhase": "cancelled" if cancel_event.is_set() else "error", - "updatedAt": int(time.time() * 1000), - }) + current_data.update( + { + "status": "cancelled" if cancel_event.is_set() else "error", + "finishedAt": int(time.time() * 1000), + "duration": duration, + "errorMessage": str(exc), + "executionLog": [], + "stepCount": step_count, + "currentPhase": "cancelled" if cancel_event.is_set() else "error", + "updatedAt": int(time.time() * 1000), + } + ) await _record_execution_result(workflow_id, exec_id, current_data) - log.error("workflow.execute.error", { - "id": workflow_id, - "exec_id": exec_id, - "error": str(exc), - }) + log.error( + "workflow.execute.error", + { + "id": workflow_id, + "exec_id": exec_id, + "error": str(exc), + }, + ) finally: _active_workflow_executions.pop(exec_id, None) @@ -1234,7 +1294,7 @@ def _compute_avg_runtime(stats: Dict[str, Any]) -> Dict[str, Any]: async def _get_workflow_stats(workflow_id: str) -> Dict[str, Any]: """Get workflow statistics""" try: - data = await Storage.read(_workflow_stats_key(workflow_id)) + data = await WorkflowStore.get_stats(workflow_id) if data is None: return dict(_DEFAULT_STATS) return _compute_avg_runtime(data) @@ -1246,11 +1306,14 @@ async def _get_workflow_stats(workflow_id: str) -> Dict[str, Any]: # API Endpoints - Workflow CRUD # ============================================================================= + @router.get("/workflow", response_model=List[WorkflowResponse]) async def list_workflows( category: Optional[str] = Query(None, description="Filter by category"), status: Optional[str] = Query(None, description="Filter by status"), - exclude_id: Optional[str] = Query(None, alias="excludeId", description="Exclude workflow by ID (e.g. exclude self when selecting sub-workflows)"), + exclude_id: Optional[str] = Query( + None, alias="excludeId", description="Exclude workflow by ID (e.g. exclude self when selecting sub-workflows)" + ), ): """ Get workflow list @@ -1284,7 +1347,9 @@ async def list_workflows( workflows.sort(key=lambda w: w.updatedAt, reverse=True) - log.info("workflow.list", {"count": len(workflows), "category": category, "status": status, "exclude_id": exclude_id}) + log.info( + "workflow.list", {"count": len(workflows), "category": category, "status": status, "exclude_id": exclude_id} + ) return workflows except Exception as e: log.error("workflow.list.error", {"error": str(e)}) @@ -1300,10 +1365,29 @@ async def create_workflow(req: WorkflowCreateRequest): of truth. Stats are initialised in Storage on first access. """ try: + workflow_json = _apply_new_workflow_runtime_defaults(req.workflow_json) try: - Workflow.from_dict(req.workflow_json) + workflow_model = Workflow.from_dict(workflow_json) except Exception as e: raise HTTPException(status_code=400, detail=f"Invalid workflow JSON: {str(e)}") + strict_mapping_errors = _strict_edge_mapping_lint_errors(workflow_model) + if strict_mapping_errors: + raise HTTPException( + status_code=400, + detail=( + "Workflow strict edge mapping failed: " + f"{strict_mapping_errors[:5]}" + ), + ) + schema_errors = _schema_lint_errors(workflow_model) + if schema_errors: + raise HTTPException( + status_code=400, + detail=( + "Workflow schema lint failed: " + f"{schema_errors[:5]}" + ), + ) workflow_id = str(uuid.uuid4()) now_ms = int(time.time() * 1000) @@ -1321,12 +1405,12 @@ async def create_workflow(req: WorkflowCreateRequest): "updatedAt": now_ms, } - _write_workflow_to_fs(workflow_id, req.workflow_json, meta, global_store=(source == "global")) + _write_workflow_to_fs(workflow_id, workflow_json, meta, global_store=(source == "global")) stats = await _get_workflow_stats(workflow_id) data = { **meta, - "workflowJson": req.workflow_json, + "workflowJson": workflow_json, "markdownContent": None, "editMarkdownContent": None, "stats": stats, @@ -1396,9 +1480,29 @@ async def update_workflow(workflow_id: str, req: WorkflowUpdateRequest): data["status"] = req.status if req.workflow_json is not None: try: - Workflow.from_dict(req.workflow_json) + workflow_model = Workflow.from_dict(req.workflow_json) + strict_mapping_errors = _strict_edge_mapping_lint_errors(workflow_model) + if strict_mapping_errors: + raise HTTPException( + status_code=400, + detail=( + "Workflow strict edge mapping failed: " + f"{strict_mapping_errors[:5]}" + ), + ) + schema_errors = _schema_lint_errors(workflow_model) + if schema_errors: + raise HTTPException( + status_code=400, + detail=( + "Workflow schema lint failed: " + f"{schema_errors[:5]}" + ), + ) workflow_json = req.workflow_json except Exception as e: + if isinstance(e, HTTPException): + raise raise HTTPException(status_code=400, detail=f"Invalid workflow JSON: {str(e)}") if req.markdown_content is not None: markdown_content = req.markdown_content @@ -1473,11 +1577,12 @@ async def delete_workflow(workflow_id: str): # API Endpoints - Workflow Operations # ============================================================================= + @router.post("/workflow/{workflow_id}/run", response_model=WorkflowExecutionResponse) async def run_workflow_endpoint(workflow_id: str, req: WorkflowRunRequest): """ Execute workflow - + Runs the workflow with provided inputs and returns execution results. """ try: @@ -1499,7 +1604,7 @@ async def run_workflow_endpoint(workflow_id: str, req: WorkflowRunRequest): input_params=req.inputs or {}, ) exec_id = str(exec_data["id"]) - + cancel_event = threading.Event() task = asyncio.create_task( _run_workflow_execution_task( @@ -1517,18 +1622,23 @@ async def run_workflow_endpoint(workflow_id: str, req: WorkflowRunRequest): task=task, cancel_event=cancel_event, ) + # Guarantee cleanup of the registry entry even when the task is # cancelled or fails before reaching its own ``finally`` block (e.g. # if the event loop is shutting down). This prevents the ``Active*`` # map from growing forever when tasks are abandoned. def _cleanup_active(_t: asyncio.Task, _eid: str = exec_id) -> None: _active_workflow_executions.pop(_eid, None) + task.add_done_callback(_cleanup_active) - log.info("workflow.execution.started", { - "id": workflow_id, - "exec_id": exec_id, - }) + log.info( + "workflow.execution.started", + { + "id": workflow_id, + "exec_id": exec_id, + }, + ) return WorkflowExecutionResponse(**exec_data) except HTTPException: raise @@ -1541,7 +1651,9 @@ def _cleanup_active(_t: asyncio.Task, _eid: str = exec_id) -> None: async def cancel_workflow_execution(workflow_id: str, exec_id: str): """Request cooperative cancellation of a running workflow execution.""" try: - exec_data = await Storage.read(_workflow_execution_key(exec_id)) + exec_data = await WorkflowStore.get_execution(exec_id) + if not exec_data: + raise HTTPException(status_code=404, detail=f"Execution not found: {exec_id}") if exec_data.get("workflowId") != workflow_id: raise HTTPException(status_code=404, detail="Execution not found for this workflow") @@ -1557,30 +1669,36 @@ async def cancel_workflow_execution(workflow_id: str, exec_id: str): raise HTTPException(status_code=404, detail="Execution not found for this workflow") active.cancel_event.set() - exec_data.update({ - "currentPhase": "cancelling", - "errorMessage": exec_data.get("errorMessage") or "Cancellation requested", - }) - await Storage.write(_workflow_execution_key(exec_id), exec_data) - log.info("workflow.execution.cancel_requested", { - "id": workflow_id, - "exec_id": exec_id, - }) + exec_data.update( + { + "currentPhase": "cancelling", + "errorMessage": exec_data.get("errorMessage") or "Cancellation requested", + } + ) + await WorkflowStore.upsert_execution(exec_data) + log.info( + "workflow.execution.cancel_requested", + { + "id": workflow_id, + "exec_id": exec_id, + }, + ) return { "status": "accepted", "message": f"Cancellation requested for execution {exec_id}", "executionId": exec_id, } - except Storage.NotFoundError: - raise HTTPException(status_code=404, detail=f"Execution not found: {exec_id}") except HTTPException: raise except Exception as e: - log.error("workflow.execution.cancel.error", { - "id": workflow_id, - "exec_id": exec_id, - "error": str(e), - }) + log.error( + "workflow.execution.cancel.error", + { + "id": workflow_id, + "exec_id": exec_id, + "error": str(e), + }, + ) raise HTTPException(status_code=500, detail=f"Failed to cancel execution: {str(e)}") @@ -1588,7 +1706,7 @@ async def cancel_workflow_execution(workflow_id: str, exec_id: str): async def validate_workflow(workflow_id: str): """ Validate workflow - + Lints the workflow and returns validation errors/warnings. """ try: @@ -1621,12 +1739,11 @@ async def validate_workflow(workflow_id: str): raise HTTPException(status_code=500, detail=f"Failed to validate workflow: {str(e)}") - - # ============================================================================= # API Endpoints - Workflow Center (Skill -> Register -> Publish Service) # ============================================================================= + @router.post("/workflow-center/scan-workflows") async def workflow_center_scan_workflows(): """Scan .flocks/workflow and register discovered workflows.""" @@ -1718,34 +1835,39 @@ async def workflow_center_invoke(workflow_id: str, req: WorkflowCenterInvokeRequ # step callbacks run locally so executionLog stays as the empty list # set by create_execution_record. We still run compact_history here # as a forward-compatible guard in case a future code path populates it. - exec_data.update({ - "outputResults": compact_outputs_for_storage( - result.get("outputs", {}) if isinstance(result, dict) else {} - ), - "executionLog": compact_history_for_storage(exec_data.get("executionLog")), - "status": status_value, - "finishedAt": int(time.time() * 1000), - "duration": duration, - "currentPhase": status_value, - }) + exec_data.update( + { + "outputResults": compact_outputs_for_storage( + result.get("outputs", {}) if isinstance(result, dict) else {} + ), + "executionLog": compact_history_for_storage(exec_data.get("executionLog")), + "status": status_value, + "finishedAt": int(time.time() * 1000), + "duration": duration, + "currentPhase": status_value, + } + ) await _record_execution_result(workflow_id, exec_id, exec_data) return result except (WorkflowNotFoundError, WorkflowNotPublishedError) as e: duration = time.time() - started - exec_data.update({"status": "error", "finishedAt": int(time.time() * 1000), - "duration": duration, "errorMessage": str(e)}) + exec_data.update( + {"status": "error", "finishedAt": int(time.time() * 1000), "duration": duration, "errorMessage": str(e)} + ) await _record_execution_result(workflow_id, exec_id, exec_data) raise HTTPException(status_code=404, detail=str(e)) except WorkflowCenterError as e: duration = time.time() - started - exec_data.update({"status": "error", "finishedAt": int(time.time() * 1000), - "duration": duration, "errorMessage": str(e)}) + exec_data.update( + {"status": "error", "finishedAt": int(time.time() * 1000), "duration": duration, "errorMessage": str(e)} + ) await _record_execution_result(workflow_id, exec_id, exec_data) raise HTTPException(status_code=400, detail=str(e)) except Exception as e: duration = time.time() - started - exec_data.update({"status": "error", "finishedAt": int(time.time() * 1000), - "duration": duration, "errorMessage": str(e)}) + exec_data.update( + {"status": "error", "finishedAt": int(time.time() * 1000), "duration": duration, "errorMessage": str(e)} + ) await _record_execution_result(workflow_id, exec_id, exec_data) log.error("workflow.center.invoke.error", {"workflow_id": workflow_id, "error": str(e)}) raise HTTPException(status_code=500, detail=f"Failed to invoke workflow service: {str(e)}") @@ -1784,6 +1906,7 @@ async def workflow_center_releases(workflow_id: str): # API Endpoints - Workflow History # ============================================================================= + @router.get("/workflow/{workflow_id}/history", response_model=List[WorkflowExecutionResponse]) async def get_workflow_history( workflow_id: str, @@ -1803,33 +1926,17 @@ async def get_workflow_history( # Keep the list endpoint on summary rows only. Do not materialize # append-only step logs here; details load them separately. - all_entries = await Storage.list_raw("workflow_execution/") + rows = await WorkflowStore.list_executions( + workflow_id, + limit=limit, + trigger_id=trigger_id, + trigger_type=trigger_type, + ) executions = [] - workflow_marker = f'"workflowId": "{workflow_id}"' - compact_marker = f'"workflowId":"{workflow_id}"' - for _key, raw_value in all_entries: - try: - head = raw_value[:500] - if workflow_marker not in head and compact_marker not in head: - continue - exec_data = json.loads(raw_value) - if not isinstance(exec_data, dict): - continue - if exec_data.get("workflowId") != workflow_id: - continue - if trigger_id and exec_data.get("triggerId") != trigger_id: - continue - if trigger_type and exec_data.get("triggerType") != trigger_type: - continue - exec_data["executionLog"] = [] - executions.append(WorkflowExecutionResponse(**exec_data)) - except Exception as e: - log.warning("workflow.history.skip", {"key": _key, "error": str(e)}) - continue - - # Sort by start time (newest first) and limit - executions.sort(key=lambda e: e.startedAt, reverse=True) - executions = executions[:limit] + for exec_data in rows: + item = dict(exec_data) + item["executionLog"] = [] + executions.append(WorkflowExecutionResponse(**item)) log.info("workflow.history", {"id": workflow_id, "count": len(executions)}) return executions @@ -1849,16 +1956,18 @@ async def get_execution_details( ): """ Get execution details - + Returns detailed information about a specific workflow execution. """ try: - exec_data = await Storage.read(_workflow_execution_key(exec_id)) - + exec_data = await WorkflowStore.get_execution(exec_id) + if not exec_data: + raise HTTPException(status_code=404, detail=f"Execution not found: {exec_id}") + # Verify workflow ID matches if exec_data.get("workflowId") != workflow_id: raise HTTPException(status_code=404, detail="Execution not found for this workflow") - + if step_limit == 0: inline_log = exec_data.get("executionLog") inline_count = len(inline_log) if isinstance(inline_log, list) else 0 @@ -1872,7 +1981,7 @@ async def get_execution_details( if total_steps == 0: legacy_steps = compact_history_for_storage(exec_data.get("executionLog")) total_steps = len(legacy_steps) - steps = legacy_steps[step_offset:step_offset + step_limit] + steps = legacy_steps[step_offset : step_offset + step_limit] exec_data = dict(exec_data) exec_data["executionLog"] = steps exec_data["stepLogOffset"] = step_offset @@ -1880,8 +1989,6 @@ async def get_execution_details( exec_data["stepLogTotal"] = total_steps exec_data["stepCount"] = exec_data.get("stepCount") or total_steps return WorkflowExecutionResponse(**exec_data) - except Storage.NotFoundError: - raise HTTPException(status_code=404, detail=f"Execution not found: {exec_id}") except HTTPException: raise except Exception as e: @@ -1893,11 +2000,12 @@ async def get_execution_details( # API Endpoints - Workflow Statistics # ============================================================================= + @router.get("/workflow/stats", response_model=WorkflowStatsResponse) async def get_aggregate_stats(): """ Get aggregate workflow statistics - + Returns statistics across all workflows. """ try: @@ -1941,7 +2049,7 @@ async def get_aggregate_stats(): async def get_workflow_stats_endpoint(workflow_id: str): """ Get workflow statistics - + Returns statistics for a specific workflow. """ try: @@ -1949,12 +2057,12 @@ async def get_workflow_stats_endpoint(workflow_id: str): raise HTTPException(status_code=404, detail=f"Workflow not found: {workflow_id}") stats = await _get_workflow_stats(workflow_id) - + # Calculate average runtime avg_runtime = 0.0 if stats["callCount"] > 0: avg_runtime = stats["totalRuntime"] / stats["callCount"] - + result = { "workflowId": workflow_id, "callCount": stats["callCount"], @@ -1965,7 +2073,7 @@ async def get_workflow_stats_endpoint(workflow_id: str): "thumbsUp": stats["thumbsUp"], "thumbsDown": stats["thumbsDown"], } - + return WorkflowStatsResponse(**result) except HTTPException: raise @@ -1978,6 +2086,7 @@ async def get_workflow_stats_endpoint(workflow_id: str): # API Endpoints - Import/Export # ============================================================================= + @router.post("/workflow/import", response_model=WorkflowResponse, status_code=status.HTTP_201_CREATED) async def import_workflow(workflow_json: Dict[str, Any]): """ @@ -2035,7 +2144,7 @@ async def import_workflow(workflow_json: Dict[str, Any]): async def export_workflow(workflow_id: str): """ Export workflow - + Exports workflow as JSON for download/sharing. """ try: @@ -2050,7 +2159,7 @@ async def export_workflow(workflow_id: str): workflow_json["metadata"]["exportedFrom"] = "flocks" workflow_json["metadata"]["exportedAt"] = int(time.time() * 1000) workflow_json["name"] = data["name"] - + log.info("workflow.exported", {"id": workflow_id}) return workflow_json except HTTPException: @@ -2101,7 +2210,7 @@ async def _prepare_workflow_api_registry(workflow_id: str) -> tuple[Dict[str, An fp = hashlib.sha256(workflow_path.read_bytes()).hexdigest() now_ms = int(time.time() * 1000) - existing_registry = await Storage.read(f"{_REGISTRY_PREFIX_MAIN}{workflow_id}") or {} + existing_registry = await WorkflowStore.kv_get(f"{_REGISTRY_PREFIX_MAIN}{workflow_id}") or {} registry_entry = { "workflowId": workflow_id, "name": data["name"], @@ -2112,7 +2221,7 @@ async def _prepare_workflow_api_registry(workflow_id: str) -> tuple[Dict[str, An "registeredAt": existing_registry.get("registeredAt", now_ms), "updatedAt": now_ms, } - await Storage.write(f"{_REGISTRY_PREFIX_MAIN}{workflow_id}", registry_entry) + await WorkflowStore.kv_put(f"{_REGISTRY_PREFIX_MAIN}{workflow_id}", registry_entry) return data, now_ms @@ -2139,13 +2248,15 @@ async def _normalize_listed_api_service(key: Any, entry: Any) -> Optional[Dict[s service = dict(entry) workflow_id = str(service.get("workflowId") or _workflow_id_from_api_service_key(key)) service["workflowId"] = workflow_id - runtime = await Storage.read(_runtime_key_main(workflow_id)) + runtime = await WorkflowStore.kv_get(_runtime_key_main(workflow_id)) if isinstance(runtime, dict) and runtime: service_url = runtime.get("serviceUrl") or service.get("serviceUrl") or "" service["serviceUrl"] = service_url service["invokeUrl"] = f"{service_url}/invoke" if service_url else service.get("invokeUrl", "") - service["status"] = "running" if runtime.get("status") in {"active", "running"} else service.get("status", "running") + service["status"] = ( + "running" if runtime.get("status") in {"active", "running"} else service.get("status", "running") + ) service["driver"] = runtime.get("driver") or service.get("driver") service["containerName"] = runtime.get("containerName") or service.get("containerName", "") service["image"] = runtime.get("image") or service.get("image") @@ -2169,9 +2280,9 @@ async def reconcile_published_workflow_api_services() -> Dict[str, int]: if not _workflow_api_autostart_enabled(): return stats - keys = await Storage.list_keys(_API_SERVICE_PREFIX) + keys = await WorkflowStore.kv_list_keys(_API_SERVICE_PREFIX) for key in keys: - service = await Storage.read(key) + service = await WorkflowStore.kv_get(key) if not isinstance(service, dict): continue @@ -2189,7 +2300,7 @@ async def reconcile_published_workflow_api_services() -> Dict[str, int]: if health.get("ok"): service["status"] = "running" service["health"] = health - await Storage.write(_api_service_key(workflow_id), service) + await WorkflowStore.kv_put(_api_service_key(workflow_id), service) stats["healthy"] += 1 continue @@ -2205,27 +2316,29 @@ async def reconcile_published_workflow_api_services() -> Dict[str, int]: ) service_url = active_record.get("serviceUrl", "") - service.update({ - "workflowId": workflow_id, - "workflowName": service.get("workflowName") or data["name"], - "serviceUrl": service_url, - "invokeUrl": f"{service_url}/invoke", - "apiKey": service.get("apiKey") or active_record.get("apiKey"), - "status": "running", - "containerName": active_record.get("containerName", ""), - "driver": active_record.get("driver") or service.get("driver"), - "image": active_record.get("image") or service.get("image"), - "restartedAt": int(time.time() * 1000), - }) + service.update( + { + "workflowId": workflow_id, + "workflowName": service.get("workflowName") or data["name"], + "serviceUrl": service_url, + "invokeUrl": f"{service_url}/invoke", + "apiKey": service.get("apiKey") or active_record.get("apiKey"), + "status": "running", + "containerName": active_record.get("containerName", ""), + "driver": active_record.get("driver") or service.get("driver"), + "image": active_record.get("image") or service.get("image"), + "restartedAt": int(time.time() * 1000), + } + ) service.pop("lastStartError", None) service["health"] = {"ok": True, "restarted": True} - await Storage.write(_api_service_key(workflow_id), service) + await WorkflowStore.kv_put(_api_service_key(workflow_id), service) stats["restarted"] += 1 except Exception as exc: service["status"] = "error" service["health"] = health service["lastStartError"] = str(exc) - await Storage.write(_api_service_key(workflow_id), service) + await WorkflowStore.kv_put(_api_service_key(workflow_id), service) log.warning("workflow.api.autostart_failed", {"id": workflow_id, "error": str(exc)}) stats["failed"] += 1 return stats @@ -2341,7 +2454,7 @@ async def publish_workflow_as_api( # Preserve existing API key across re-publishes so callers don't break. # The runtime must receive the same key before it starts so /invoke can # enforce the key returned to callers. - existing_service = await Storage.read(_api_service_key(workflow_id)) or {} + existing_service = await WorkflowStore.kv_get(_api_service_key(workflow_id)) or {} api_key = existing_service.get("apiKey") or (uuid.uuid4().hex + uuid.uuid4().hex) # Use center.py to publish the selected runtime. @@ -2370,7 +2483,7 @@ async def publish_workflow_as_api( "driver": driver, "image": image, } - await Storage.write(_api_service_key(workflow_id), service_info) + await WorkflowStore.kv_put(_api_service_key(workflow_id), service_info) log.info("workflow.api.published", {"id": workflow_id, "url": service_url}) return service_info @@ -2392,7 +2505,7 @@ async def unpublish_workflow_api(workflow_id: str): Stop a published workflow API service. """ try: - existing = await Storage.read(_api_service_key(workflow_id)) + existing = await WorkflowStore.kv_get(_api_service_key(workflow_id)) if not existing: raise HTTPException(status_code=404, detail="No published service found for this workflow") @@ -2403,7 +2516,7 @@ async def unpublish_workflow_api(workflow_id: str): existing["status"] = "stopped" existing["stoppedAt"] = int(time.time() * 1000) - await Storage.write(_api_service_key(workflow_id), existing) + await WorkflowStore.kv_put(_api_service_key(workflow_id), existing) log.info("workflow.api.unpublished", {"id": workflow_id}) return {"ok": True} @@ -2421,7 +2534,7 @@ async def get_workflow_service(workflow_id: str): Returns null if not published. """ try: - return await Storage.read(_api_service_key(workflow_id)) # None / null if not found + return await WorkflowStore.kv_get(_api_service_key(workflow_id)) # None / null if not found except Exception as e: log.error("workflow.service.get.error", {"id": workflow_id, "error": str(e)}) raise HTTPException(status_code=500, detail=f"Failed to get service info: {str(e)}") @@ -2431,7 +2544,7 @@ async def get_workflow_service(workflow_id: str): async def delete_workflow_service(workflow_id: str): """Delete the stored API service configuration for a workflow.""" try: - existing = await Storage.read(_api_service_key(workflow_id)) + existing = await WorkflowStore.kv_get(_api_service_key(workflow_id)) if not existing: raise HTTPException(status_code=404, detail="No published service found for this workflow") @@ -2441,8 +2554,8 @@ async def delete_workflow_service(workflow_id: str): pass try: - await Storage.remove(_api_service_key(workflow_id)) - except Storage.NotFoundError: + await WorkflowStore.kv_remove(_api_service_key(workflow_id)) + except Exception: pass log.info("workflow.api.service_deleted", {"id": workflow_id}) @@ -2504,11 +2617,14 @@ async def update_workflow_config( try: normalized_config = _normalize_workflow_integration_config_template(workflow_id, data, config) config_path = _workflow_config_dir(workflow_id, data) / "config.json" - await Storage.write(_workflow_integration_config_key(workflow_id), normalized_config) - log.info("workflow.config.updated", { - "id": workflow_id, - "storage_key": _workflow_integration_config_key(workflow_id), - }) + await WorkflowStore.put_config(workflow_id, normalized_config) + log.info( + "workflow.config.updated", + { + "id": workflow_id, + "storage_key": _workflow_integration_config_key(workflow_id), + }, + ) return { "ok": True, "exists": True, @@ -2546,11 +2662,14 @@ async def sync_workflow_config(workflow_id: str): } config = await _build_workflow_integration_config(workflow_id, data) - await Storage.write(_workflow_integration_config_key(workflow_id), config) - log.info("workflow.config.synced", { - "id": workflow_id, - "storage_key": _workflow_integration_config_key(workflow_id), - }) + await WorkflowStore.put_config(workflow_id, config) + log.info( + "workflow.config.synced", + { + "id": workflow_id, + "storage_key": _workflow_integration_config_key(workflow_id), + }, + ) return { "ok": True, "path": str(config_path), @@ -2571,10 +2690,10 @@ async def list_workflow_services(): List all published workflow API services. """ try: - keys = await Storage.list_keys(_API_SERVICE_PREFIX) + keys = await WorkflowStore.kv_list_keys(_API_SERVICE_PREFIX) services = [] for key in keys: - entry = await Storage.read(key) + entry = await WorkflowStore.kv_get(key) service = await _normalize_listed_api_service(key, entry) if service: services.append(service) @@ -2600,9 +2719,7 @@ def _validate_trigger_type_constraints(triggers: List[TriggerDefinition]) -> Non singleton_ids_by_type.setdefault(trigger.type, []).append(trigger.id or "") duplicates = { - trigger_type: trigger_ids - for trigger_type, trigger_ids in singleton_ids_by_type.items() - if len(trigger_ids) > 1 + trigger_type: trigger_ids for trigger_type, trigger_ids in singleton_ids_by_type.items() if len(trigger_ids) > 1 } if not duplicates: return @@ -2897,7 +3014,7 @@ async def save_kafka_config(workflow_id: str, req: KafkaConfigRequest): "inputs": _strip_execution_only_comments(req.inputs), "updatedAt": int(time.time() * 1000), } - await Storage.write(_kafka_config_key(workflow_id), config) + await WorkflowStore.put_config(workflow_id, config, kind="workflow_kafka_config") unified_trigger = TriggerDefinition.model_validate( { "id": "kafka-default", @@ -2949,7 +3066,7 @@ async def get_kafka_config(workflow_id: str): Get saved Kafka configuration for a workflow. """ try: - config = await Storage.read(_kafka_config_key(workflow_id)) + config = await WorkflowStore.get_config(workflow_id, kind="workflow_kafka_config") if config is None: data = _read_workflow_from_fs(workflow_id) if data: @@ -3012,7 +3129,7 @@ async def save_workflow_poller_config(workflow_id: str, req: WorkflowPollerConfi "inputs": req.inputs, "updatedAt": int(time.time() * 1000), } - await Storage.write(f"workflow_poller_config/{workflow_id}", config) + await WorkflowStore.put_config(workflow_id, config, kind="workflow_poller_config") unified_trigger = TriggerDefinition.model_validate( { "id": "schedule-default", @@ -3057,7 +3174,7 @@ async def save_workflow_poller_config(workflow_id: str, req: WorkflowPollerConfi async def get_workflow_poller_config(workflow_id: str): """Get saved poller configuration for a workflow.""" try: - config = await Storage.read(f"workflow_poller_config/{workflow_id}") + config = await WorkflowStore.get_config(workflow_id, kind="workflow_poller_config") if config is None: data = _read_workflow_from_fs(workflow_id) if data: @@ -3128,7 +3245,7 @@ async def save_syslog_config(workflow_id: str, req: SyslogConfigRequest): "inputKey": req.input_key, "updatedAt": int(time.time() * 1000), } - await Storage.write(_syslog_config_key(workflow_id), config) + await WorkflowStore.put_config(workflow_id, config, kind="workflow_syslog_config") unified_trigger = TriggerDefinition.model_validate( { "id": "syslog-default", @@ -3177,7 +3294,7 @@ async def save_syslog_config(workflow_id: str, req: SyslogConfigRequest): async def get_syslog_config(workflow_id: str): """Get saved syslog configuration for a workflow.""" try: - config = await Storage.read(_syslog_config_key(workflow_id)) + config = await WorkflowStore.get_config(workflow_id, kind="workflow_syslog_config") if config is None: data = _read_workflow_from_fs(workflow_id) if data: @@ -3214,8 +3331,10 @@ async def get_syslog_status(workflow_id: str): # API Endpoints - Run Single Node # ============================================================================= + class RunNodeRequest(BaseModel): """Request to execute a single workflow node.""" + model_config = ConfigDict(populate_by_name=True) node_id: str = Field(..., description="Node ID to execute") @@ -3227,6 +3346,7 @@ class RunNodeRequest(BaseModel): class RunNodeResponse(BaseModel): """Response from executing a single workflow node.""" + model_config = ConfigDict(populate_by_name=True) node_id: str @@ -3273,12 +3393,15 @@ async def run_single_node(workflow_id: str, req: RunNodeRequest): step_result = await asyncio.to_thread(engine.run_node, req.node_id, req.inputs) - log.info("workflow.run_node", { - "workflow_id": workflow_id, - "node_id": req.node_id, - "success": step_result.error is None, - "duration_ms": step_result.duration_ms, - }) + log.info( + "workflow.run_node", + { + "workflow_id": workflow_id, + "node_id": req.node_id, + "success": step_result.error is None, + "duration_ms": step_result.duration_ms, + }, + ) return RunNodeResponse( node_id=step_result.node_id, @@ -3310,8 +3433,10 @@ async def run_single_node(workflow_id: str, req: RunNodeRequest): # API Endpoints - Sample Inputs # ============================================================================= + class SampleInputsRequest(BaseModel): """Request to save sample inputs for a workflow.""" + model_config = ConfigDict(populate_by_name=True) sampleInputs: Dict[str, Any] = Field(default_factory=dict, description="Sample input data") diff --git a/flocks/server/routes/workspace.py b/flocks/server/routes/workspace.py index 51aed456d..642e56ce4 100644 --- a/flocks/server/routes/workspace.py +++ b/flocks/server/routes/workspace.py @@ -20,13 +20,17 @@ GET /api/workspace/file read text file content PUT /api/workspace/file write / update text file content DELETE /api/workspace/file delete file + GET /api/workspace/preview preview single file inline GET /api/workspace/download download single file POST /api/workspace/download/zip batch download as zip POST /api/workspace/move move / rename + POST /api/workspace/reveal open containing folder in system file manager Memory view (read-only, points to data/memory/) - GET /api/workspace/memory/list list memory files - GET /api/workspace/memory/file read memory file content + GET /api/workspace/memory/list list memory files + GET /api/workspace/memory/file read memory file content + GET /api/workspace/memory/preview preview single memory file inline + GET /api/workspace/memory/download download single memory file Stats GET /api/workspace/stats workspace + memory totals @@ -36,9 +40,12 @@ import asyncio import io +import mimetypes import os import shutil import stat as stat_module +import subprocess +import sys import zipfile from pathlib import Path from typing import List, Optional, Literal @@ -65,6 +72,16 @@ _ALLOWED_UPLOAD_LABEL = ( "txt, md, json, yaml, yml, xml, csv, pdf, doc, docx, html, htm, ppt, pptx, xls, xlsx" ) +_ALLOWED_PREVIEW_MEDIA_TYPES = { + "application/pdf", + "image/png", + "image/jpeg", + "image/gif", + "image/webp", + "image/svg+xml", +} + + def _max_upload_bytes() -> int: return int(os.getenv("FLOCKS_WORKSPACE_MAX_UPLOAD_MB", str(_DEFAULT_MAX_UPLOAD_MB))) * 1024 * 1024 @@ -159,6 +176,69 @@ def _read_text_preview_sync(path: Path, max_bytes: int) -> tuple[str, bool]: return data.decode("utf-8", errors="replace"), truncated +def _inline_preview_response(target: Path) -> FileResponse: + media_type = mimetypes.guess_type(target.name)[0] or "application/octet-stream" + if media_type not in _ALLOWED_PREVIEW_MEDIA_TYPES: + raise HTTPException( + status_code=415, + detail="File type is not supported for inline preview", + ) + headers = { + "Content-Disposition": "inline", + "X-Content-Type-Options": "nosniff", + } + if media_type == "image/svg+xml": + headers["Content-Security-Policy"] = ( + "sandbox; default-src 'none'; script-src 'none'; " + "object-src 'none'; base-uri 'none'; img-src data: blob:; " + "style-src 'unsafe-inline'" + ) + return FileResponse( + path=str(target), + media_type=media_type, + headers=headers, + ) + + +def _download_response(target: Path) -> FileResponse: + return FileResponse( + path=str(target), + filename=target.name, + media_type="application/octet-stream", + ) + + +def _reveal_in_file_manager(target: Path) -> str: + """Open the OS file manager for a workspace file or directory.""" + if sys.platform == "win32": + args = ["explorer", str(target)] if target.is_dir() else ["explorer", f"/select,{target}"] + subprocess.Popen(args) + return "reveal" if target.is_file() else "open" + + if sys.platform == "darwin": + args = ["open", str(target)] if target.is_dir() else ["open", "-R", str(target)] + subprocess.Popen(args) + return "reveal" if target.is_file() else "open" + + directory = target if target.is_dir() else target.parent + opener = shutil.which("xdg-open") + if opener: + subprocess.Popen([opener, str(directory)]) + return "open" + + gio = shutil.which("gio") + if gio: + subprocess.Popen([gio, "open", str(directory)]) + return "open" + + kde_open = shutil.which("kde-open") + if kde_open: + subprocess.Popen([kde_open, str(directory)]) + return "open" + + raise RuntimeError("No file manager opener found for this platform") + + # ─── directory operations ─────────────────────────────────────────────────── @router.get("/tree", response_model=WorkspaceNode, summary="List directory tree") @@ -386,6 +466,22 @@ async def delete_file( return {"path": path, "deleted": True} +@router.get("/preview", summary="Preview single file inline") +async def preview_file( + path: str = Query(..., description="Relative path to file"), +): + mgr = _get_manager() + try: + target = mgr.resolve_workspace_path(path) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + if not target.exists(): + raise HTTPException(status_code=404, detail=f"File not found: {path}") + if not target.is_file(): + raise HTTPException(status_code=400, detail=f"Not a file: {path}") + return _inline_preview_response(target) + + @router.get("/download", summary="Download single file") async def download_file( path: str = Query(..., description="Relative path to file"), @@ -399,11 +495,7 @@ async def download_file( raise HTTPException(status_code=404, detail=f"File not found: {path}") if not target.is_file(): raise HTTPException(status_code=400, detail=f"Not a file: {path}") - return FileResponse( - path=str(target), - filename=target.name, - media_type="application/octet-stream", - ) + return _download_response(target) class ZipDownloadRequest(BaseModel): @@ -436,6 +528,10 @@ class MoveRequest(BaseModel): dst: str +class RevealRequest(BaseModel): + path: str + + @router.post("/move", summary="Move / rename file or directory") async def move_item(body: MoveRequest): mgr = _get_manager() @@ -454,6 +550,26 @@ async def move_item(body: MoveRequest): return {"src": body.src, "dst": body.dst, "moved": True} +@router.post("/reveal", summary="Open containing folder in system file manager") +async def reveal_item(body: RevealRequest): + mgr = _get_manager() + try: + target = mgr.resolve_workspace_path(body.path) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + if not target.exists(): + raise HTTPException(status_code=404, detail=f"Path not found: {body.path}") + target_type = "directory" if target.is_dir() else "file" + try: + mode = await asyncio.to_thread(_reveal_in_file_manager, target) + except RuntimeError as e: + raise HTTPException(status_code=500, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to open file manager: {e}") + log.info("workspace.item.revealed", {"path": body.path, "target": target_type, "mode": mode}) + return {"path": body.path, "opened": True, "target": target_type, "mode": mode} + + # ─── memory view (read-only) ──────────────────────────────────────────────── def _list_memory_sync(memory_dir: Path) -> List[WorkspaceNode]: @@ -496,6 +612,11 @@ async def read_memory_file( raise HTTPException(status_code=404, detail=f"Memory file not found: {path}") if not target.is_file(): raise HTTPException(status_code=400, detail=f"Not a file: {path}") + if not WorkspaceManager.is_text_file(target): + raise HTTPException( + status_code=400, + detail="Binary file — use /memory/download endpoint instead", + ) max_read_bytes = _max_read_bytes() try: content, truncated = await asyncio.to_thread( @@ -514,6 +635,38 @@ async def read_memory_file( } +@router.get("/memory/preview", summary="Preview single memory file inline") +async def preview_memory_file( + path: str = Query(..., description="Relative path inside memory directory"), +): + mgr = _get_manager() + try: + target = mgr.resolve_memory_path(path) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + if not target.exists(): + raise HTTPException(status_code=404, detail=f"Memory file not found: {path}") + if not target.is_file(): + raise HTTPException(status_code=400, detail=f"Not a file: {path}") + return _inline_preview_response(target) + + +@router.get("/memory/download", summary="Download single memory file") +async def download_memory_file( + path: str = Query(..., description="Relative path inside memory directory"), +): + mgr = _get_manager() + try: + target = mgr.resolve_memory_path(path) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + if not target.exists(): + raise HTTPException(status_code=404, detail=f"Memory file not found: {path}") + if not target.is_file(): + raise HTTPException(status_code=400, detail=f"Not a file: {path}") + return _download_response(target) + + # ─── stats ────────────────────────────────────────────────────────────────── @router.get("/stats", response_model=WorkspaceStats, summary="Workspace statistics") diff --git a/flocks/session/callable_schema.py b/flocks/session/callable_schema.py index 914bd3a60..d54d1639b 100644 --- a/flocks/session/callable_schema.py +++ b/flocks/session/callable_schema.py @@ -42,17 +42,22 @@ def resolve_callable_tool_infos(tool_names: Iterable[str]) -> tuple[List[Any], i async def _resolve_dynamic_always_load_tool_names() -> Set[str]: """Return runtime-only always-load tools. - Device discovery should be available without an extra ``tool_search`` hop + Device management should be available without an extra ``tool_search`` hop when the workspace has at least one enabled device, but we do not want to - expose ``device_context`` in sessions that have no security devices. + expose it in sessions that have no security devices. """ dynamic_names: Set[str] = set() + candidate_names = ("device_manage",) - try: - device_context_tool = ToolRegistry.get("device_context") - if device_context_tool is None or not getattr(device_context_tool.info, "enabled", True): - return dynamic_names - except Exception: + for name in candidate_names: + try: + tool = ToolRegistry.get(name) + except Exception: + continue + if tool is not None and getattr(tool.info, "enabled", True): + dynamic_names.add(name) + + if not dynamic_names: return dynamic_names try: @@ -62,10 +67,7 @@ async def _resolve_dynamic_always_load_tool_names() -> Set[str]: except Exception: return dynamic_names - if any(device.enabled for device in devices): - dynamic_names.add("device_context") - - return dynamic_names + return dynamic_names if any(device.enabled for device in devices) else set() async def list_session_callable_tool_infos( diff --git a/flocks/session/core/status.py b/flocks/session/core/status.py index a76ace596..42c0d43c8 100644 --- a/flocks/session/core/status.py +++ b/flocks/session/core/status.py @@ -139,8 +139,8 @@ def clear_all(cls) -> None: def get_busy_session_ids(cls) -> List[str]: """Return IDs of all sessions that are busy or compacting (across all instances).""" result: List[str] = [] - for _inst_id, statuses in cls._state.items(): - for sid, info in statuses.items(): + for _inst_id, statuses in list(cls._state.items()): + for sid, info in list(statuses.items()): if info.type in ("busy", "compacting"): result.append(sid) return result diff --git a/flocks/session/message.py b/flocks/session/message.py index 075c19248..7fd6dcbe2 100644 --- a/flocks/session/message.py +++ b/flocks/session/message.py @@ -25,6 +25,10 @@ log = Log.create(service="message") +class MessageCacheInvalidatedError(RuntimeError): + """Raised when a full cache load cannot stabilize after invalidation.""" + + class _SessionLockManager: """Per-session asyncio.Lock manager with LRU eviction. @@ -441,15 +445,20 @@ class Message: _parts_flush_tasks: Dict[str, asyncio.Task] = {} _parts_fully_loaded: set[str] = set() _lru: OrderedDict[str, bool] = OrderedDict() # LRU tracker: move_to_end() is O(1) + _cache_epoch: int = 0 + _session_cache_generations: OrderedDict[str, int] = OrderedDict() + _cache_loading_sessions: set[str] = set() # Maximum number of sessions to keep in cache before evicting oldest _MAX_CACHED_SESSIONS = 50 + _MAX_CACHE_GENERATIONS = _MAX_CACHED_SESSIONS * 4 # Storage key prefixes _MESSAGE_PREFIX = "message" _PARTS_PREFIX = "message_parts" _PARTS_ITEM_PREFIX = "message_parts" _PARTS_PERSIST_DEBOUNCE_MS = 75 + _CACHE_RELOAD_RETRY_LIMIT = 3 @classmethod def _cancel_parts_flush_task(cls, session_id: str) -> None: @@ -457,6 +466,70 @@ def _cancel_parts_flush_task(cls, session_id: str) -> None: if task and not task.done(): task.cancel() + @classmethod + def _cache_token(cls, session_id: str) -> tuple[int, int]: + return cls._cache_epoch, cls._session_cache_generations.get(session_id, 0) + + @classmethod + def _bump_cache_epoch(cls) -> None: + cls._cache_epoch += 1 + cls._session_cache_generations.clear() + + @classmethod + def _bump_session_cache_generation(cls, session_id: str) -> None: + generation = cls._session_cache_generations.get(session_id, 0) + 1 + cls._session_cache_generations[session_id] = generation + cls._session_cache_generations.move_to_end(session_id) + cls._prune_session_cache_generations() + + @classmethod + def _prune_session_cache_generations(cls) -> None: + while len(cls._session_cache_generations) > cls._MAX_CACHE_GENERATIONS: + for stale_id in list(cls._session_cache_generations): + if stale_id in cls._cache_loading_sessions: + continue + cls._session_cache_generations.pop(stale_id, None) + break + else: + break + + @classmethod + def _has_message_cache(cls, session_id: str) -> bool: + return session_id in cls._lru and session_id in cls._messages_cache + + @classmethod + def _drop_cached_session( + cls, + session_id: str, + *, + discard_lock: bool = True, + ) -> None: + cls._bump_session_cache_generation(session_id) + cls._cancel_parts_flush_task(session_id) + cls._lru.pop(session_id, None) + cls._messages_cache.pop(session_id, None) + cls._msg_id_index.pop(session_id, None) + cls._parts_cache.pop(session_id, None) + cls._parts_revision_cache.pop(session_id, None) + cls._parts_serialized_cache.pop(session_id, None) + cls._parts_storage_format.pop(session_id, None) + cls._parts_persisted_mids.pop(session_id, None) + cls._parts_fully_loaded.discard(session_id) + if discard_lock: + _session_locks.discard(session_id) + + @classmethod + def _touch_lru(cls, session_id: str) -> None: + if session_id in cls._lru: + cls._lru.move_to_end(session_id) + return + + while len(cls._lru) >= cls._MAX_CACHED_SESSIONS: + evict_id, _ = cls._lru.popitem(last=False) + cls._drop_cached_session(evict_id) + log.debug("message.cache.evicted", {"session_id": evict_id}) + cls._lru[session_id] = True + @classmethod def _serialize_message_parts(cls, parts: List[PartType]) -> List[Dict[str, Any]]: return [p.model_dump() for p in parts] @@ -571,26 +644,59 @@ async def _ensure_cache(cls, session_id: str) -> None: Uses a per-session lock so operations on different sessions are fully concurrent. """ - await cls._ensure_message_cache(session_id) - if session_id in cls._parts_fully_loaded: - cls._lru.move_to_end(session_id) - return - - lock = _session_locks.get(session_id) - async with lock: + reload_attempts = 0 + while True: + await cls._ensure_message_cache(session_id) if session_id in cls._parts_fully_loaded: - cls._lru.move_to_end(session_id) + if cls._has_message_cache(session_id): + cls._touch_lru(session_id) + return + cls._parts_fully_loaded.discard(session_id) + + lock = _session_locks.get(session_id) + async with lock: + if session_id in cls._parts_fully_loaded: + if cls._has_message_cache(session_id): + cls._touch_lru(session_id) + return + cls._parts_fully_loaded.discard(session_id) + continue + + cls._cache_loading_sessions.add(session_id) + try: + token = cls._cache_token(session_id) + message_times = { + message.id: message.time + for message in cls._messages_cache.get(session_id, []) + } + await cls._load_all_parts_locked(session_id, message_times=message_times) + cache_invalidated = ( + token != cls._cache_token(session_id) + or not cls._has_message_cache(session_id) + ) + finally: + cls._cache_loading_sessions.discard(session_id) + cls._prune_session_cache_generations() + + if cache_invalidated: + reload_attempts += 1 + cls._drop_cached_session(session_id, discard_lock=False) + if reload_attempts >= cls._CACHE_RELOAD_RETRY_LIMIT: + log.warn("message.cache.reload_aborted_after_invalidation", { + "session_id": session_id, + "attempts": reload_attempts, + }) + raise MessageCacheInvalidatedError( + f"Message cache for {session_id} was invalidated during load" + ) + log.debug("message.cache.reload_after_invalidation", {"session_id": session_id}) + continue + + cls._parts_fully_loaded.add(session_id) + cls._touch_lru(session_id) + log.debug("message.cache.loaded", {"session_id": session_id, "parts": "all"}) return - message_times = { - message.id: message.time - for message in cls._messages_cache.get(session_id, []) - } - await cls._load_all_parts_locked(session_id, message_times=message_times) - cls._parts_fully_loaded.add(session_id) - cls._lru.move_to_end(session_id) - log.debug("message.cache.loaded", {"session_id": session_id, "parts": "all"}) - @classmethod async def _ensure_message_cache(cls, session_id: str) -> None: """Ensure message metadata is cached without loading every part. @@ -599,28 +705,23 @@ async def _ensure_message_cache(cls, session_id: str) -> None: parts here would deserialize large tool outputs even when only the newest page is requested. """ - if session_id in cls._lru: - cls._lru.move_to_end(session_id) + if cls._has_message_cache(session_id): + cls._touch_lru(session_id) return + if session_id in cls._lru: + cls._drop_cached_session(session_id, discard_lock=False) lock = _session_locks.get(session_id) async with lock: - if session_id in cls._lru: - cls._lru.move_to_end(session_id) + if cls._has_message_cache(session_id): + cls._touch_lru(session_id) return + if session_id in cls._lru: + cls._drop_cached_session(session_id, discard_lock=False) while len(cls._lru) >= cls._MAX_CACHED_SESSIONS: evict_id, _ = cls._lru.popitem(last=False) - cls._cancel_parts_flush_task(evict_id) - cls._messages_cache.pop(evict_id, None) - cls._msg_id_index.pop(evict_id, None) - cls._parts_cache.pop(evict_id, None) - cls._parts_revision_cache.pop(evict_id, None) - cls._parts_serialized_cache.pop(evict_id, None) - cls._parts_storage_format.pop(evict_id, None) - cls._parts_persisted_mids.pop(evict_id, None) - cls._parts_fully_loaded.discard(evict_id) - _session_locks.discard(evict_id) + cls._drop_cached_session(evict_id) log.debug("message.cache.evicted", {"session_id": evict_id}) storage_key = f"{cls._MESSAGE_PREFIX}:{session_id}" @@ -1204,7 +1305,7 @@ async def _persist_parts(cls, session_id: str, *, message_id: Optional[str] = No else: serialized = { mid: cls._serialize_message_parts(mparts) - for mid, mparts in all_parts.items() + for mid, mparts in list(all_parts.items()) } cls._parts_serialized_cache[session_id] = serialized cls._parts_persisted_mids[session_id] = set(serialized.keys()) @@ -1224,7 +1325,7 @@ async def _persist_parts(cls, session_id: str, *, message_id: Optional[str] = No persisted_mids = cls._parts_persisted_mids.setdefault(session_id, set()) current_mids = set(all_parts.keys()) - for mid, mparts in all_parts.items(): + for mid, mparts in list(all_parts.items()): serialized_one = cls._serialize_message_parts(mparts) if serialized.get(mid) != serialized_one or mid not in persisted_mids: await Storage.set( @@ -1413,7 +1514,7 @@ async def parts(cls, message_id: str, session_id: Optional[str] = None) -> List[ else: # Search all sessions for the message parts = [] - for sid in cls._parts_cache: + for sid in list(cls._parts_cache): await cls._ensure_cache(sid) if message_id in cls._parts_cache.get(sid, {}): parts = cls._parts_cache[sid][message_id] @@ -1822,7 +1923,7 @@ async def clear(cls, session_id: str) -> int: Returns: Number of messages cleared """ - await cls._ensure_cache(session_id) + await cls._ensure_message_cache(session_id) async with _session_locks.get(session_id): count = len(cls._messages_cache.get(session_id, [])) @@ -1856,20 +1957,11 @@ def invalidate_cache(cls, session_id: Optional[str] = None) -> None: session_id: Optional session ID, if None invalidates all """ if session_id: - cls._cancel_parts_flush_task(session_id) - cls._lru.pop(session_id, None) - cls._messages_cache.pop(session_id, None) - cls._msg_id_index.pop(session_id, None) - cls._parts_cache.pop(session_id, None) - cls._parts_revision_cache.pop(session_id, None) - cls._parts_serialized_cache.pop(session_id, None) - cls._parts_storage_format.pop(session_id, None) - cls._parts_persisted_mids.pop(session_id, None) - cls._parts_fully_loaded.discard(session_id) - _session_locks.discard(session_id) + cls._drop_cached_session(session_id) else: for sid in list(cls._parts_flush_tasks): cls._cancel_parts_flush_task(sid) + cls._bump_cache_epoch() cls._lru.clear() cls._messages_cache.clear() cls._msg_id_index.clear() @@ -2349,7 +2441,7 @@ def parts(cls, message_id: str, session_id: Optional[str] = None) -> List[PartTy parts = Message._parts_cache.get(session_id, {}).get(message_id, []) else: parts = [] - for sid in Message._parts_cache: + for sid in list(Message._parts_cache): if message_id in Message._parts_cache.get(sid, {}): parts = Message._parts_cache[sid][message_id] break diff --git a/flocks/session/runner.py b/flocks/session/runner.py index 1e6646541..e9c10b5ae 100644 --- a/flocks/session/runner.py +++ b/flocks/session/runner.py @@ -753,7 +753,7 @@ def cancel_children(cls, parent_session_id: str) -> int: cancelled = 0 child_ids = [ - sid for sid, runner in cls._active_sessions.items() + sid for sid, runner in list(cls._active_sessions.items()) if getattr(runner.session, 'parent_id', None) == parent_session_id ] for sid in child_ids: @@ -1516,7 +1516,7 @@ async def _build_device_asset_hint(self) -> Optional[str]: return ( "## 安全设备使用\n\n" f"{summary}\n\n" - "当用户要操作特定机房、设备或产品时,先调用 `device_context` 获取 `device_id` 等相关信息。" + "当用户要操作特定机房、设备或产品时,先调用 `device_manage(action='list')` 获取 `device_id` 等相关信息。" "如果当前无已接入设备,请提示用户前往「设备接入」页面添加设备。" "使用 `tool_search` 搜索工具名称查看用法;执行设备工具时必须传入目标 `device_id`。" "如果同类设备有多个候选,不要猜测,先询问用户选择。" @@ -2939,7 +2939,7 @@ def _build_llm_response_payload( name=tc_state.name, arguments=tc_state.input, ) - for tc_state in processor.tool_calls.values() + for tc_state in list(processor.tool_calls.values()) ] result_action = "continue" if tool_calls_for_result else "stop" response_payload = _build_llm_response_payload( diff --git a/flocks/session/session_loop.py b/flocks/session/session_loop.py index d09c7a14a..35faa5678 100644 --- a/flocks/session/session_loop.py +++ b/flocks/session/session_loop.py @@ -162,7 +162,7 @@ def abort_children(cls, parent_session_id: str) -> int: """Abort all child loops whose session.parent_id matches, recursively.""" aborted = 0 child_ids = [ - sid for sid, ctx in cls._active_loops.items() + sid for sid, ctx in list(cls._active_loops.items()) if getattr(ctx.session, 'parent_id', None) == parent_session_id ] for sid in child_ids: @@ -190,6 +190,28 @@ async def _publish_runtime_event( "error": str(exc), }) + @classmethod + async def _publish_turn_stopped( + cls, + callbacks: "LoopCallbacks", + session_id: str, + *, + step: int, + stop_reason: str, + ) -> None: + turn_state = set_turn_state( + session_id, + step=step, + status="stopped", + stop_reason=stop_reason, + queued_message_detected=False, + ) + await cls._publish_runtime_event( + callbacks, + "turn.stopped", + turn_state.model_dump(by_alias=True), + ) + @classmethod async def _publish_session_status( cls, @@ -592,6 +614,12 @@ async def _run_loop( }) if not messages: log.info("loop.no_messages", {"session_id": ctx.session.id}) + await cls._publish_turn_stopped( + callbacks, + ctx.session.id, + step=ctx.step, + stop_reason="no_messages", + ) break # Analyze messages (matching TUI lines 277-292) @@ -635,7 +663,17 @@ async def _run_loop( # Check if we have a user message if not last_user: - log.error("loop.no_user_message", {"session_id": ctx.session.id}) + log.info("loop.no_user_message", { + "session_id": ctx.session.id, + "message_count": len(messages), + "roles": [str(getattr(msg, "role", "")) for msg in messages[-5:]], + }) + await cls._publish_turn_stopped( + callbacks, + ctx.session.id, + step=ctx.step, + stop_reason="no_user_message", + ) break last_assistant_parts = ( diff --git a/flocks/storage/storage.py b/flocks/storage/storage.py index 77bcc5563..bad02354f 100644 --- a/flocks/storage/storage.py +++ b/flocks/storage/storage.py @@ -25,11 +25,13 @@ class NotFoundError(Exception): """Raised when a resource is not found""" + pass class StorageError(Exception): """Base storage error""" + pass @@ -44,10 +46,7 @@ class CheckpointBusyError(StorageError): """ def __init__(self, mode: str, log_pages: int, checkpointed_pages: int) -> None: - super().__init__( - f"wal_checkpoint({mode}) busy: " - f"log_pages={log_pages}, checkpointed_pages={checkpointed_pages}" - ) + super().__init__(f"wal_checkpoint({mode}) busy: log_pages={log_pages}, checkpointed_pages={checkpointed_pages}") self.mode = mode self.log_pages = log_pages self.checkpointed_pages = checkpointed_pages @@ -56,15 +55,15 @@ def __init__(self, mode: str, log_pages: int, checkpointed_pages: int) -> None: class Storage: """ Storage namespace for persistent data operations - + Similar to Flocks's Storage namespace. Provides both TypeScript-compatible API (key arrays) and Python API (key strings). """ - + NotFoundError = NotFoundError StorageError = StorageError CheckpointBusyError = CheckpointBusyError - + _log = Log.create(service="storage") _db_path: Optional[Path] = None _initialized = False @@ -95,21 +94,6 @@ class Storage: _sqlite_write_retry_base_delay_s = 0.05 _multi_db_migration_marker_key = "storage.migration.multi_db.v1" _multi_db_migration_batch_size = 500 - _workflow_key_prefixes = ( - "workflow/", - "workflow_execution/", - "workflow_execution_index/", - "workflow_execution_step/", - "workflow_registry/", - "workflow_release/", - "workflow_runtime/", - "workflow_local_pid/", - "workflow_api_service/", - "workflow_integration_config/", - "workflow_kafka_config/", - "workflow_poller_config/", - "workflow_syslog_config/", - ) # Substrings that mark an SQLite file as unrecoverably damaged at open # time. We deliberately keep this list short and English-only because @@ -121,20 +105,31 @@ class Storage: ) @classmethod - def _invalidate_runtime_caches(cls) -> None: + def _prefix_matches_runtime_cache(cls, prefix: Optional[str], roots: tuple[str, ...]) -> bool: + if prefix is None: + return True + normalized = str(prefix) + return any(normalized == root.rstrip(":/") or normalized.startswith(root) for root in roots) + + @classmethod + def _invalidate_runtime_caches(cls, prefix: Optional[str] = None) -> None: """Clear higher-level caches that depend on the active storage DB.""" - try: - from flocks.session.session import Session - Session.invalidate_cache() - except Exception: - pass + if cls._prefix_matches_runtime_cache(prefix, ("session:",)): + try: + from flocks.session.session import Session + + Session.invalidate_cache() + except Exception: + pass + + if cls._prefix_matches_runtime_cache(prefix, ("message:", "message_parts:")): + try: + from flocks.session.message import Message + + Message.invalidate_cache() + except Exception: + pass - try: - from flocks.session.message import Message - Message.invalidate_cache() - except Exception: - pass - @classmethod def get_db_path(cls) -> Path: """Return the resolved database file path. @@ -155,35 +150,20 @@ def get_workflow_db_path(cls) -> Path: @classmethod def route_db_path_for_key(cls, key: str) -> Path: """Return the storage DB path that owns *key*.""" - if cls._is_workflow_key(key): - return cls.get_workflow_db_path() return cls.get_db_path() @classmethod def route_db_path_for_prefix(cls, prefix: Optional[str]) -> Path: """Return the storage DB path for a prefix-scoped KV operation.""" - if prefix and ( - cls._is_workflow_key(prefix) - or f"{prefix}/" in cls._workflow_key_prefixes - ): - return cls.get_workflow_db_path() return cls.get_db_path() @classmethod - def _is_workflow_key(cls, key: str) -> bool: - return any(key.startswith(prefix) for prefix in cls._workflow_key_prefixes) - - @classmethod - async def configure_connection( - cls, conn: aiosqlite.Connection - ) -> aiosqlite.Connection: + async def configure_connection(cls, conn: aiosqlite.Connection) -> aiosqlite.Connection: """Apply the runtime SQLite contract to an async connection.""" await conn.execute(f"PRAGMA journal_mode={cls._sqlite_journal_mode}") await conn.execute(f"PRAGMA synchronous={cls._sqlite_synchronous}") await conn.execute(f"PRAGMA busy_timeout={cls._sqlite_busy_timeout_ms}") - await conn.execute( - f"PRAGMA wal_autocheckpoint={cls._sqlite_wal_autocheckpoint_pages}" - ) + await conn.execute(f"PRAGMA wal_autocheckpoint={cls._sqlite_wal_autocheckpoint_pages}") await conn.execute("PRAGMA foreign_keys = ON") return conn @@ -193,9 +173,7 @@ def configure_sync_connection(cls, conn: sqlite3.Connection) -> sqlite3.Connecti conn.execute(f"PRAGMA journal_mode={cls._sqlite_journal_mode}") conn.execute(f"PRAGMA synchronous={cls._sqlite_synchronous}") conn.execute(f"PRAGMA busy_timeout={cls._sqlite_busy_timeout_ms}") - conn.execute( - f"PRAGMA wal_autocheckpoint={cls._sqlite_wal_autocheckpoint_pages}" - ) + conn.execute(f"PRAGMA wal_autocheckpoint={cls._sqlite_wal_autocheckpoint_pages}") conn.execute("PRAGMA foreign_keys = ON") return conn @@ -288,6 +266,7 @@ def _quarantine_corrupt_db(cls, db_path: Path) -> Optional[Path]: return None from datetime import UTC + timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%S") suffix = f".corrupt.{timestamp}" @@ -302,10 +281,13 @@ def _quarantine_corrupt_db(cls, db_path: Path) -> Optional[Path]: try: db_path.rename(new_main) except OSError as exc: - cls._log.error("storage.quarantine.rename_failed", { - "path": str(db_path), - "error": str(exc), - }) + cls._log.error( + "storage.quarantine.rename_failed", + { + "path": str(db_path), + "error": str(exc), + }, + ) return None for sidecar_name in (f"{db_path.name}-wal", f"{db_path.name}-shm"): @@ -315,20 +297,26 @@ def _quarantine_corrupt_db(cls, db_path: Path) -> Optional[Path]: try: side_path.rename(side_path.with_name(sidecar_name + suffix)) except OSError as exc: - cls._log.warn("storage.quarantine.sidecar_rename_failed", { - "path": str(side_path), - "error": str(exc), - }) - - cls._log.error("storage.corruption.quarantined", { - "original_path": str(db_path), - "quarantined_path": str(new_main), - "hint": ( - "Server is starting with a fresh empty database. " - "Run scripts/recover_raw_flocks_db.py against the " - "quarantined file to attempt data recovery." - ), - }) + cls._log.warn( + "storage.quarantine.sidecar_rename_failed", + { + "path": str(side_path), + "error": str(exc), + }, + ) + + cls._log.error( + "storage.corruption.quarantined", + { + "original_path": str(db_path), + "quarantined_path": str(new_main), + "hint": ( + "Server is starting with a fresh empty database. " + "Run scripts/recover_raw_flocks_db.py against the " + "quarantined file to attempt data recovery." + ), + }, + ) return new_main @classmethod @@ -405,14 +393,17 @@ async def _run_write_with_retry( raise last_exc = exc sleep_s = delay_s * (2 ** (attempt - 1)) - cls._log.warn("storage.sqlite_write_retry", { - "action": action, - "target": target, - "attempt": attempt, - "max_attempts": attempts, - "sleep_s": round(sleep_s, 3), - "error": str(exc), - }) + cls._log.warn( + "storage.sqlite_write_retry", + { + "action": action, + "target": target, + "attempt": attempt, + "max_attempts": attempts, + "sleep_s": round(sleep_s, 3), + "error": str(exc), + }, + ) await asyncio.sleep(sleep_s) assert last_exc is not None @@ -420,9 +411,7 @@ async def _run_write_with_retry( @classmethod @asynccontextmanager - async def connect( - cls, db_path: Optional[Path] = None - ) -> AsyncIterator[aiosqlite.Connection]: + async def connect(cls, db_path: Optional[Path] = None) -> AsyncIterator[aiosqlite.Connection]: """Open a configured async SQLite connection for the active storage DB.""" target = Path(db_path) if db_path is not None else cls.get_db_path() target.parent.mkdir(parents=True, exist_ok=True) @@ -455,14 +444,14 @@ def register_ddl(cls, ddl: str) -> None: def _resolve_key(key: List[str] | str) -> str: """ Convert key to string format - + Matches TypeScript's resolve() function: - Array keys: ["session", "proj1", "ses1"] -> "session/proj1/ses1" - String keys: passed through unchanged - + Args: key: Key as list or string - + Returns: Key as string """ @@ -473,28 +462,12 @@ def _resolve_key(key: List[str] | str) -> str: @classmethod def _like_prefix_pattern(cls, prefix: str) -> str: """Return a SQLite LIKE pattern that treats prefix chars literally.""" - return ( - prefix.replace("\\", "\\\\") - .replace("%", "\\%") - .replace("_", "\\_") - + "%" - ) + return prefix.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + "%" @classmethod def _like_prefix_clause(cls, column: str = "key") -> str: return f"{column} LIKE ? ESCAPE '\\'" - @classmethod - def _workflow_prefix_filter(cls) -> tuple[str, tuple[str, ...]]: - clause = " OR ".join( - cls._like_prefix_clause("key") for _ in cls._workflow_key_prefixes - ) - params = tuple( - cls._like_prefix_pattern(prefix) - for prefix in cls._workflow_key_prefixes - ) - return clause, params - @staticmethod def _marker_row_count(marker: Dict[str, Any], key: str) -> int: try: @@ -502,30 +475,6 @@ def _marker_row_count(marker: Dict[str, Any], key: str) -> int: except (TypeError, ValueError): return 0 - @classmethod - async def _ensure_storage_table(cls, db_path: Path) -> None: - """Ensure the generic KV storage table exists in *db_path*.""" - db_path.parent.mkdir(parents=True, exist_ok=True) - - async def _create_storage_table() -> None: - async with cls.connect(db_path) as db: - await db.execute(""" - CREATE TABLE IF NOT EXISTS storage ( - key TEXT PRIMARY KEY, - value TEXT NOT NULL, - type TEXT NOT NULL, - created_at TEXT NOT NULL, - updated_at TEXT NOT NULL - ) - """) - await db.commit() - - await cls._run_write_with_retry( - _create_storage_table, - action="init.create_storage_table", - target=str(db_path), - ) - @classmethod def _read_multi_db_migration_marker_sync(cls) -> Dict[str, Any]: if cls._db_path is None or not cls._db_path.exists(): @@ -577,129 +526,29 @@ def _write_multi_db_migration_marker_sync(cls, marker: Dict[str, Any]) -> None: conn.close() @classmethod - def _copy_workflow_kv_to_workflow_db_sync(cls) -> tuple[int, int]: - if cls._db_path is None: - raise StorageError("Storage DB path is not initialized") - workflow_db_path = cls.get_workflow_db_path() - source = cls.connect_sync(cls._db_path) - target = cls.connect_sync(workflow_db_path) - try: - clauses, params = cls._workflow_prefix_filter() - total_migrated = 0 - last_key = "" - while True: - rows = source.execute( - f""" - SELECT key, value, type, created_at, updated_at - FROM storage - WHERE ({clauses}) AND key > ? - ORDER BY key - LIMIT ? - """, - (*params, last_key, cls._multi_db_migration_batch_size), - ).fetchall() - if not rows: - break + async def _ensure_storage_table(cls, db_path: Path) -> None: + """Ensure the generic KV storage table exists in *db_path*.""" + db_path.parent.mkdir(parents=True, exist_ok=True) - target.executemany( - """ - INSERT OR REPLACE INTO storage (key, value, type, created_at, updated_at) - VALUES (?, ?, ?, ?, ?) - """, - [ - ( - row["key"], - row["value"], - row["type"], - row["created_at"], - row["updated_at"], - ) - for row in rows - ], - ) - target.commit() - - keys = [row["key"] for row in rows] - placeholders = ", ".join("?" for _ in keys) - copied = target.execute( - f"SELECT COUNT(*) FROM storage WHERE key IN ({placeholders})", - keys, - ).fetchone()[0] - if copied != len(keys): - raise StorageError( - "Workflow storage migration verification failed: " - f"expected {len(keys)} copied rows, got {copied}" + async def _create_storage_table() -> None: + async with cls.connect(db_path) as db: + await db.execute(""" + CREATE TABLE IF NOT EXISTS storage ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL, + type TEXT NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL ) + """) + await db.commit() - total_migrated += len(rows) - last_key = keys[-1] - - return total_migrated, 0 - except Exception: - if source.in_transaction: - source.rollback() - if target.in_transaction: - target.rollback() - raise - finally: - source.close() - target.close() - - @classmethod - async def _migrate_workflow_storage_to_workflow_db( - cls, - *, - workflow_db_existed_before_init: bool, - ) -> None: - """Copy legacy workflow-domain KV rows from flocks.db to workflow.db.""" - marker = await asyncio.to_thread(cls._read_multi_db_migration_marker_sync) - if marker.get("workflow_migrated") is True: - if ( - not workflow_db_existed_before_init - and cls._marker_row_count(marker, "workflow_rows") > 0 - ): - raise StorageError( - "workflow.db is missing after a completed multi-db migration" - ) - return - - from datetime import UTC + await cls._run_write_with_retry( + _create_storage_table, + action="init.create_storage_table", + target=str(db_path), + ) - marker.setdefault("version", 1) - marker.setdefault("started_at", datetime.now(UTC).isoformat()) - marker["source_db"] = str(cls.get_db_path()) - marker["workflow_db"] = str(cls.get_workflow_db_path()) - try: - migrated, deleted = await asyncio.to_thread( - cls._copy_workflow_kv_to_workflow_db_sync - ) - marker["workflow_migrated"] = True - marker["workflow_migrated_at"] = datetime.now(UTC).isoformat() - marker["workflow_rows"] = migrated - marker["workflow_source_rows_deleted"] = deleted - marker.pop("workflow_error", None) - await asyncio.to_thread(cls._write_multi_db_migration_marker_sync, marker) - cls._log.info("storage.multi_db.workflow_migrated", { - "rows": migrated, - "source_rows_deleted": deleted, - "source_db": marker["source_db"], - "workflow_db": marker["workflow_db"], - }) - except Exception as exc: - marker["workflow_migrated"] = False - marker["workflow_error"] = str(exc) - marker["workflow_failed_at"] = datetime.now(UTC).isoformat() - try: - await asyncio.to_thread(cls._write_multi_db_migration_marker_sync, marker) - except Exception: - pass - cls._log.error("storage.multi_db.workflow_migration_failed", { - "source_db": marker.get("source_db"), - "workflow_db": marker.get("workflow_db"), - "error": str(exc), - }) - raise - @classmethod async def init(cls, db_path: Optional[Path] = None) -> None: """ @@ -733,10 +582,13 @@ async def init(cls, db_path: Optional[Path] = None) -> None: # SQLite never gets a chance to delete adjacent WAL/SHM sidecars # — those sidecars are what the offline recovery script reads. if cls._file_has_invalid_sqlite_header(cls._db_path): - cls._log.error("storage.corruption.invalid_header", { - "db_path": str(cls._db_path), - "size": cls._db_path.stat().st_size, - }) + cls._log.error( + "storage.corruption.invalid_header", + { + "db_path": str(cls._db_path), + "size": cls._db_path.stat().st_size, + }, + ) quarantined = cls._quarantine_corrupt_db(cls._db_path) if quarantined is None: # The pre-flight check confirmed the file is not SQLite, @@ -764,27 +616,19 @@ async def init(cls, db_path: Optional[Path] = None) -> None: except Exception as exc: if not cls._is_db_corruption_error(exc): raise - cls._log.error("storage.corruption.detected_on_init", { - "db_path": str(cls._db_path), - "error": str(exc), - "error_type": type(exc).__name__, - }) + cls._log.error( + "storage.corruption.detected_on_init", + { + "db_path": str(cls._db_path), + "error": str(exc), + "error_type": type(exc).__name__, + }, + ) quarantined = cls._quarantine_corrupt_db(cls._db_path) if quarantined is None: raise await cls._bootstrap_schema() - # Workflow-domain KV rows live in a sibling workflow.db. Keep the - # public Storage API unchanged by routing workflow key prefixes at - # the KV-operation seam, while initialising the sibling DB here so - # startup fails early if that database is unusable. - workflow_db_path = cls.get_workflow_db_path() - workflow_db_existed_before_init = workflow_db_path.exists() - await cls._ensure_storage_table(workflow_db_path) - await cls._migrate_workflow_storage_to_workflow_db( - workflow_db_existed_before_init=workflow_db_existed_before_init, - ) - # Drain any residual WAL frames left by the previous process so the # next ``SIGKILL`` does not have to truncate a 4 MB-class WAL during # recovery (which is exactly when main-DB page 1 / the header can @@ -797,10 +641,13 @@ async def init(cls, db_path: Optional[Path] = None) -> None: cls._init_pid = os.getpid() cls._initialized = True - cls._log.info("storage.initialized", { - "db_path": str(db_path), - "pid": cls._init_pid, - }) + cls._log.info( + "storage.initialized", + { + "db_path": str(db_path), + "pid": cls._init_pid, + }, + ) @classmethod async def _checkpoint(cls, *, mode: str = "TRUNCATE") -> tuple[int, int, int]: @@ -830,9 +677,7 @@ async def _checkpoint(cls, *, mode: str = "TRUNCATE") -> tuple[int, int, int]: if mode_normalised not in valid_modes: raise ValueError(f"invalid checkpoint mode: {mode!r}") async with cls.connect(cls._db_path) as db: - cursor = await db.execute( - f"PRAGMA wal_checkpoint({mode_normalised})" - ) + cursor = await db.execute(f"PRAGMA wal_checkpoint({mode_normalised})") row = await cursor.fetchone() await db.commit() @@ -847,9 +692,7 @@ async def _checkpoint(cls, *, mode: str = "TRUNCATE") -> tuple[int, int, int]: log_pages = int(row[1]) checkpointed_pages = int(row[2]) if busy != 0: - raise CheckpointBusyError( - mode_normalised, log_pages, checkpointed_pages - ) + raise CheckpointBusyError(mode_normalised, log_pages, checkpointed_pages) return (busy, log_pages, checkpointed_pages) # Tunables for the shutdown WAL drain. Kept as class attributes so @@ -892,51 +735,61 @@ async def shutdown(cls) -> None: try: for attempt in range(1, attempts + 1): try: - _, log_pages, checkpointed = await cls._checkpoint( - mode="TRUNCATE" + _, log_pages, checkpointed = await cls._checkpoint(mode="TRUNCATE") + cls._log.info( + "storage.shutdown.checkpoint.done", + { + "db_path": str(cls._db_path) if cls._db_path else None, + "log_pages": log_pages, + "checkpointed_pages": checkpointed, + "attempts": attempt, + }, ) - cls._log.info("storage.shutdown.checkpoint.done", { - "db_path": str(cls._db_path) if cls._db_path else None, - "log_pages": log_pages, - "checkpointed_pages": checkpointed, - "attempts": attempt, - }) return except CheckpointBusyError as exc: last_busy = exc - cls._log.warn("storage.shutdown.checkpoint.busy", { - "attempt": attempt, - "max_attempts": attempts, - "mode": exc.mode, - "log_pages": exc.log_pages, - "checkpointed_pages": exc.checkpointed_pages, - }) + cls._log.warn( + "storage.shutdown.checkpoint.busy", + { + "attempt": attempt, + "max_attempts": attempts, + "mode": exc.mode, + "log_pages": exc.log_pages, + "checkpointed_pages": exc.checkpointed_pages, + }, + ) if attempt < attempts: await asyncio.sleep(backoff * attempt) except Exception as exc: last_failure = exc - cls._log.warn("storage.shutdown.checkpoint.failed", { - "db_path": str(cls._db_path) if cls._db_path else None, - "error": str(exc), - "error_type": type(exc).__name__, - }) + cls._log.warn( + "storage.shutdown.checkpoint.failed", + { + "db_path": str(cls._db_path) if cls._db_path else None, + "error": str(exc), + "error_type": type(exc).__name__, + }, + ) break # Reached only on persistent busy or fatal failure. Do NOT # log "done" — that would mask the residual-WAL risk this # whole method exists to prevent. - cls._log.warn("storage.shutdown.checkpoint.unfinished", { - "db_path": str(cls._db_path) if cls._db_path else None, - "busy": last_busy is not None, - "log_pages": getattr(last_busy, "log_pages", None), - "checkpointed_pages": getattr(last_busy, "checkpointed_pages", None), - "fatal_error": str(last_failure) if last_failure else None, - "hint": ( - "WAL was not truncated; next startup will run WAL " - "recovery and remains at risk of header corruption if " - "killed mid-recovery." - ), - }) + cls._log.warn( + "storage.shutdown.checkpoint.unfinished", + { + "db_path": str(cls._db_path) if cls._db_path else None, + "busy": last_busy is not None, + "log_pages": getattr(last_busy, "log_pages", None), + "checkpointed_pages": getattr(last_busy, "checkpointed_pages", None), + "fatal_error": str(last_failure) if last_failure else None, + "hint": ( + "WAL was not truncated; next startup will run WAL " + "recovery and remains at risk of header corruption if " + "killed mid-recovery." + ), + }, + ) finally: cls._initialized = False cls._init_pid = None @@ -955,6 +808,7 @@ async def _bootstrap_schema(cls) -> None: # Initialize vector storage tables (for memory system) try: from flocks.storage.vector import ensure_vector_tables + vector_status = await ensure_vector_tables(cls._db_path) cls._log.info("storage.vector.initialized", vector_status) except Exception as e: @@ -966,6 +820,7 @@ async def _bootstrap_schema(cls) -> None: # Run extension DDLs registered before init for ddl in cls._extension_ddls: try: + async def _run_extension_ddl() -> None: async with cls.connect(cls._db_path) as db: await db.executescript(ddl) @@ -987,6 +842,7 @@ async def _create_model_management_tables(cls) -> None: (credentials, model settings, default models, custom providers) is stored in flocks.json / .secret.json. """ + async def _create_tables() -> None: async with cls.connect(cls._db_path) as db: await db.executescript(""" @@ -1020,7 +876,10 @@ async def _create_tables() -> None: schema_additions = [ ("message_id", "ALTER TABLE usage_records ADD COLUMN message_id TEXT"), - ("cache_write_tokens", "ALTER TABLE usage_records ADD COLUMN cache_write_tokens INTEGER NOT NULL DEFAULT 0"), + ( + "cache_write_tokens", + "ALTER TABLE usage_records ADD COLUMN cache_write_tokens INTEGER NOT NULL DEFAULT 0", + ), ("source", "ALTER TABLE usage_records ADD COLUMN source TEXT NOT NULL DEFAULT 'live'"), ("backfilled_at", "ALTER TABLE usage_records ADD COLUMN backfilled_at TEXT"), ] @@ -1050,7 +909,7 @@ async def _create_tables() -> None: target=str(cls._db_path), ) cls._log.info("storage.model_management_tables_ready") - + @classmethod async def _ensure_init(cls) -> None: """Ensure storage is initialized for the *current* process. @@ -1068,29 +927,27 @@ async def _ensure_init(cls) -> None: whenever the current PID differs. """ current_pid = os.getpid() - forked = ( - cls._initialized - and cls._init_pid is not None - and cls._init_pid != current_pid - ) + forked = cls._initialized and cls._init_pid is not None and cls._init_pid != current_pid if forked: - cls._log.warn("storage.fork_detected", { - "parent_pid": cls._init_pid, - "child_pid": current_pid, - "hint": "Reinitialising Storage to avoid sharing SQLite " - "file descriptors across processes.", - }) + cls._log.warn( + "storage.fork_detected", + { + "parent_pid": cls._init_pid, + "child_pid": current_pid, + "hint": "Reinitialising Storage to avoid sharing SQLite file descriptors across processes.", + }, + ) cls._initialized = False cls._init_pid = None if not cls._initialized or cls._db_path is None or not cls._db_path.exists(): await cls.init(cls._db_path) - + @classmethod async def set(cls, key: str, value: Any, value_type: str = "json") -> None: """ Store a value - + Args: key: Storage key value: Value to store (will be JSON serialized) @@ -1098,82 +955,78 @@ async def set(cls, key: str, value: Any, value_type: str = "json") -> None: """ await cls._ensure_init() db_path = cls.route_db_path_for_key(key) - if db_path != cls.get_db_path(): - await cls._ensure_storage_table(db_path) - + if isinstance(value, BaseModel): serialized = value.model_dump_json() else: serialized = json.dumps(value) - + from datetime import UTC + now = datetime.now(UTC).isoformat() - + async def _write() -> None: async with cls.connect(db_path) as db: - await db.execute(""" + await db.execute( + """ INSERT OR REPLACE INTO storage (key, value, type, created_at, updated_at) VALUES (?, ?, ?, COALESCE((SELECT created_at FROM storage WHERE key = ?), ?), ?) - """, (key, serialized, value_type, key, now, now)) + """, + (key, serialized, value_type, key, now, now), + ) await db.commit() await cls._run_write_with_retry(_write, action="set", target=key) - + cls._log.debug("storage.set", {"key": key, "type": value_type}) - + @classmethod async def get(cls, key: str, model: Optional[Type[T]] = None) -> Optional[T | Any]: """ Retrieve a value - + Args: key: Storage key model: Optional Pydantic model class to deserialize into - + Returns: Stored value or None if not found """ await cls._ensure_init() db_path = cls.route_db_path_for_key(key) - if db_path != cls.get_db_path(): - await cls._ensure_storage_table(db_path) - + async with cls.connect(db_path) as db: - async with db.execute( - "SELECT value, type FROM storage WHERE key = ?", (key,) - ) as cursor: + async with db.execute("SELECT value, type FROM storage WHERE key = ?", (key,)) as cursor: row = await cursor.fetchone() - + if row is None: return None - + value_str, value_type = row - + if model is not None and hasattr(model, "model_validate_json"): return model.model_validate_json(value_str) # Fall back to a plain JSON decode when no Pydantic model is supplied # (or when callers accidentally pass a builtin container type such as # ``dict``/``list``, which is not a Pydantic model). return json.loads(value_str) - + @classmethod async def delete(cls, key: str) -> bool: """ Delete a value - + Args: key: Storage key - + Returns: True if deleted, False if not found """ await cls._ensure_init() db_path = cls.route_db_path_for_key(key) - if db_path != cls.get_db_path(): - await cls._ensure_storage_table(db_path) - + async def _delete() -> bool: async with cls.connect(db_path) as db: cursor = await db.execute("DELETE FROM storage WHERE key = ?", (key,)) @@ -1181,32 +1034,29 @@ async def _delete() -> bool: return cursor.rowcount > 0 deleted = await cls._run_write_with_retry(_delete, action="delete", target=key) - + if deleted: cls._log.debug("storage.delete", {"key": key}) - + return deleted - + @classmethod async def list_keys(cls, prefix: Optional[str] = None) -> List[str]: """ List all keys, optionally filtered by prefix - + Args: prefix: Optional key prefix to filter by - + Returns: List of matching keys """ await cls._ensure_init() if prefix is None: - db_paths = (cls.get_db_path(), cls.get_workflow_db_path()) + db_paths = (cls.get_db_path(),) else: db_paths = (cls.route_db_path_for_prefix(prefix),) - for db_path in db_paths: - if db_path != cls.get_db_path(): - await cls._ensure_storage_table(db_path) - + keys: set[str] = set() for db_path in db_paths: async with cls.connect(db_path) as db: @@ -1220,7 +1070,7 @@ async def list_keys(cls, prefix: Optional[str] = None) -> List[str]: async with db.execute(query, params) as cursor: rows = await cursor.fetchall() keys.update(row[0] for row in rows) - + return sorted(keys) @classmethod @@ -1230,12 +1080,9 @@ async def _list_entry_rows( prefix: Optional[str], ) -> List[Tuple[str, str]]: if prefix is None: - db_paths = (cls.get_db_path(), cls.get_workflow_db_path()) + db_paths = (cls.get_db_path(),) else: db_paths = (cls.route_db_path_for_prefix(prefix),) - for db_path in db_paths: - if db_path != cls.get_db_path(): - await cls._ensure_storage_table(db_path) rows_by_key: dict[str, str] = {} for db_path in db_paths: @@ -1297,8 +1144,6 @@ async def list_entries_page( """List one page of entries for a prefix, plus total matching rows.""" await cls._ensure_init() db_path = cls.route_db_path_for_prefix(prefix) - if db_path != cls.get_db_path(): - await cls._ensure_storage_table(db_path) safe_offset = max(int(offset), 0) safe_limit = max(int(limit), 0) @@ -1355,46 +1200,39 @@ async def list_raw( async def exists(cls, key: str) -> bool: """ Check if a key exists - + Args: key: Storage key - + Returns: True if exists, False otherwise """ await cls._ensure_init() db_path = cls.route_db_path_for_key(key) - if db_path != cls.get_db_path(): - await cls._ensure_storage_table(db_path) - + async with cls.connect(db_path) as db: - async with db.execute( - "SELECT 1 FROM storage WHERE key = ?", (key,) - ) as cursor: + async with db.execute("SELECT 1 FROM storage WHERE key = ?", (key,)) as cursor: row = await cursor.fetchone() - + return row is not None - + @classmethod async def clear(cls, prefix: Optional[str] = None) -> int: """ Clear storage, optionally filtered by prefix - + Args: prefix: Optional key prefix to filter by - + Returns: Number of deleted entries """ await cls._ensure_init() if prefix is None: - db_paths = (cls.get_db_path(), cls.get_workflow_db_path()) + db_paths = (cls.get_db_path(),) else: db_paths = (cls.route_db_path_for_prefix(prefix),) - for db_path in db_paths: - if db_path != cls.get_db_path(): - await cls._ensure_storage_table(db_path) - + async def _clear_db(db_path: Path) -> int: async with cls.connect(db_path) as db: if prefix: @@ -1420,73 +1258,73 @@ async def _clear_db(db_path: Path) -> int: action="clear", target=f"{prefix or ''}@{db_path}", ) - + cls._log.info("storage.clear", {"prefix": prefix, "deleted": deleted}) - cls._invalidate_runtime_caches() + cls._invalidate_runtime_caches(prefix) return deleted - + # ==================== TypeScript-compatible API ==================== - + @classmethod async def read(cls, key: List[str] | str, model: Optional[Type[T]] = None) -> Optional[T | Any]: """ Read a value (TypeScript-compatible API) - + Matches TypeScript: Storage.read(key: string[]) - + Args: key: Storage key as list or string model: Optional Pydantic model class - + Returns: Stored value or None if not found - + Raises: NotFoundError: If key not found (when strict mode needed) """ resolved_key = cls._resolve_key(key) return await cls.get(resolved_key, model) - + @classmethod async def write(cls, key: List[str] | str, content: Any) -> None: """ Write a value (TypeScript-compatible API) - + Matches TypeScript: Storage.write(key: string[], content: T) - + Args: key: Storage key as list or string content: Content to store """ resolved_key = cls._resolve_key(key) await cls.set(resolved_key, content) - + @classmethod async def update(cls, key: List[str] | str, fn: callable, model: Optional[Type[T]] = None) -> Optional[T | Any]: """ Update a value in place (TypeScript-compatible API) - + Matches TypeScript: Storage.update(key: string[], fn: (draft: T) => void) - + Args: key: Storage key as list or string fn: Function that modifies the content in place model: Optional Pydantic model class - + Returns: Updated value - + Raises: NotFoundError: If key not found """ resolved_key = cls._resolve_key(key) - + # Read current value content = await cls.get(resolved_key, model) - + if content is None: raise NotFoundError(f"Key not found: {resolved_key}") - + # If it's a dict, apply function if isinstance(content, dict): fn(content) @@ -1499,43 +1337,43 @@ async def update(cls, key: List[str] | str, fn: callable, model: Optional[Type[T else: # For other types, try to call fn on it fn(content) - + # Write back await cls.set(resolved_key, content) - + return content - + @classmethod async def remove(cls, key: List[str] | str) -> bool: """ Remove a value (TypeScript-compatible API) - + Matches TypeScript: Storage.remove(key: string[]) - + Args: key: Storage key as list or string - + Returns: True if deleted, False if not found """ resolved_key = cls._resolve_key(key) return await cls.delete(resolved_key) - + @classmethod async def list(cls, prefix: List[str] | str | None = None) -> List[List[str]]: """ List keys (TypeScript-compatible API) - + Matches TypeScript: Storage.list(prefix: string[]) - + Args: prefix: Optional key prefix as list or string - + Returns: List of keys as lists (e.g., [["session", "proj1", "ses1"], ...]) """ prefix_str = cls._resolve_key(prefix) if prefix else None keys = await cls.list_keys(prefix_str) - + # Convert string keys back to list format return [key.split("/") for key in keys] diff --git a/flocks/tool/device/device_context_tool.py b/flocks/tool/device/device_context_tool.py deleted file mode 100644 index ac9750263..000000000 --- a/flocks/tool/device/device_context_tool.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Built-in `device_context` tool. - -Lets the Agent query the machine-room / device / tool hierarchy on demand, -rather than having everything pre-baked into the system prompt. - -Usage (Agent-side): - Call ``device_context`` with no arguments to get a structured view of: - 机房 → 设备名称 → 该设备对应的工具列表 - -Output is Markdown so the model can parse and cite device names naturally. -""" -from __future__ import annotations - -from flocks.tool.registry import ToolCategory, ToolContext, ToolResult, ToolRegistry -from flocks.utils.log import Log - -log = Log.create(service="tool.device.device_context_tool") - - -@ToolRegistry.register_function( - name="device_context", - description=( - "查询已接入安全设备的机房结构、设备列表及各设备对应工具名称。" - "当用户涉及特定设备操作时,先调用此工具确认设备名称与工具前缀的映射关系," - "再选择对应设备的工具执行任务。" - ), - category=ToolCategory.SYSTEM, - parameters=[], -) -async def device_context(ctx: ToolContext) -> ToolResult: - """Return the machine-room → device → tool hierarchy as Markdown.""" - try: - from flocks.tool.device.prompt import build_device_context_section - - content = await build_device_context_section() - if not content: - return ToolResult( - success=True, - output=( - "当前没有已接入的安全设备。" - "设备不存在时,请提醒用户前往设备接入页面添加设备。" - ), - ) - return ToolResult(success=True, output=content) - except Exception as exc: - log.warn("tool.device_context.failed", {"error": str(exc)}) - return ToolResult(success=False, error=f"查询设备上下文失败: {exc}") diff --git a/flocks/tool/device/manage_tool.py b/flocks/tool/device/manage_tool.py new file mode 100644 index 000000000..d6e8bf20d --- /dev/null +++ b/flocks/tool/device/manage_tool.py @@ -0,0 +1,273 @@ +"""Built-in device management tool for Rex. + +``device_manage`` is the single system-tool entrypoint for device discovery, +non-secret config updates, and standard connectivity checks. Its +``connectivity_test`` action reuses the existing device test path, so card state +stays consistent with the ``POST /api/devices/{id}/test`` endpoint. +""" +from __future__ import annotations + +from typing import Any, Optional + +from flocks.tool.device.intake import DeviceNotFoundError, test_device, update_device +from flocks.tool.device.models import DeviceIntegrationUpdate +from flocks.tool.registry import ( + ParameterType, + ToolCategory, + ToolContext, + ToolParameter, + ToolRegistry, + ToolResult, +) +from flocks.utils.log import Log + +log = Log.create(service="tool.device.manage_tool") + + +@ToolRegistry.register_function( + name="device_manage", + description=( + "管理已接入安全设备。action=list 用于列出机房、设备、device_id 和工具集;" + "action=update 用于写入/更新已有设备实例的非敏感配置字段;" + "action=connectivity_test 用于测试指定设备连通性并更新设备卡片状态。" + ), + description_cn="列出设备、更新非敏感配置或测试设备连通性", + category=ToolCategory.SYSTEM, + parameters=[ + ToolParameter( + name="action", + type=ParameterType.STRING, + description="操作类型:list 列出设备;update 更新已有设备非敏感配置;connectivity_test 测试设备连通性。", + required=True, + enum=["list", "update", "connectivity_test"], + ), + ToolParameter( + name="device_id", + type=ParameterType.STRING, + description="目标设备实例 ID。action=update 或 connectivity_test 时必填。", + required=False, + ), + ToolParameter( + name="fields", + type=ParameterType.OBJECT, + description=( + "要更新的非敏感设备配置字段,例如 {\"base_url\":\"https://device.local\"}。" + "禁止传入 api_key、secret、password、token、cookie、auth_state 等敏感字段。" + ), + required=False, + ), + ToolParameter( + name="verify_ssl", + type=ParameterType.BOOLEAN, + description="是否开启 SSL 证书验证。仅 action=update 时使用。", + required=False, + ), + ], +) +async def device_manage( + ctx: ToolContext, + action: str, + device_id: Optional[str] = None, + fields: Optional[dict[str, Any]] = None, + verify_ssl: Optional[bool] = None, +) -> ToolResult: + """List devices, update non-secret config, or run a standard probe.""" + normalized_action = (action or "").strip() + if normalized_action == "list": + return await _list_devices() + if normalized_action == "update": + return await _update_device_config(ctx, device_id, fields, verify_ssl) + if normalized_action == "connectivity_test": + return await _connectivity_test(ctx, device_id) + return ToolResult( + success=False, + error="未知 action,请使用 list、update 或 connectivity_test。", + ) + + +async def _list_devices() -> ToolResult: + try: + from flocks.tool.device.prompt import build_device_manage_list_section + + content = await build_device_manage_list_section() + if not content: + return ToolResult( + success=True, + output=( + "当前没有已接入的安全设备。" + "设备不存在时,请提醒用户前往设备接入页面添加设备。" + ), + ) + return ToolResult(success=True, output=content) + except Exception as exc: + log.warn("tool.device_manage.list_failed", {"error": str(exc)}) + return ToolResult(success=False, error=f"查询设备列表失败: {exc}") + + +_SENSITIVE_FIELD_KEYS = frozenset( + { + "api_key", + "apikey", + "secret", + "client_secret", + "password", + "passwd", + "token", + "access_token", + "refresh_token", + "cookie", + "auth_state", + } +) + + +def _normalize_update_fields( + fields: Optional[dict[str, Any]], +) -> tuple[dict[str, str], Optional[str]]: + if fields is None: + return {}, None + if not isinstance(fields, dict): + return {}, "fields 必须是对象,例如 {\"base_url\":\"https://device.local\"}。" + + normalized: dict[str, str] = {} + rejected: list[str] = [] + for raw_key, raw_value in fields.items(): + key = str(raw_key).strip() + if not key: + return {}, "fields 不能包含空字段名。" + if key.lower() in _SENSITIVE_FIELD_KEYS: + rejected.append(key) + continue + normalized[key] = "" if raw_value is None else str(raw_value) + + if rejected: + return {}, ( + "拒绝通过 device_manage(action='update') 写入敏感字段:" + + ", ".join(f"`{key}`" for key in sorted(rejected)) + + "。请在设备接入页面的配置表单中填写密钥、密码、Token、Cookie 或 auth_state。" + ) + + return normalized, None + + +async def _update_device_config( + ctx: ToolContext, + device_id: Optional[str], + fields: Optional[dict[str, Any]], + verify_ssl: Optional[bool], +) -> ToolResult: + target = (device_id or "").strip() + if not target: + return ToolResult( + success=False, + error="action=update 时 device_id 不能为空。", + ) + + normalized_fields, field_error = _normalize_update_fields(fields) + if field_error: + return ToolResult(success=False, error=field_error) + if not normalized_fields and verify_ssl is None: + return ToolResult( + success=False, + error="action=update 至少需要提供 fields 或 verify_ssl。", + ) + + log.info( + "tool.device_manage.update.start", + { + "device_id": target, + "session_id": ctx.session_id, + "fields": sorted(normalized_fields), + "verify_ssl": verify_ssl, + }, + ) + + try: + updated = await update_device( + target, + DeviceIntegrationUpdate( + fields=normalized_fields or None, + verify_ssl=verify_ssl, + ), + ) + except DeviceNotFoundError: + return ToolResult( + success=False, + error=f"设备 {target!r} 未找到,请通过 device_manage(action='list') 确认 device_id。", + ) + except ValueError as exc: + return ToolResult(success=False, error=f"设备配置更新失败: {exc}") + except Exception as exc: + log.warn( + "tool.device_manage.update_failed", + {"device_id": target, "error": str(exc)}, + ) + return ToolResult(success=False, error=f"设备配置更新失败: {exc}") + + return ToolResult( + success=True, + output={ + "device_id": updated.id, + "name": updated.name, + "storage_key": updated.storage_key, + "service_id": updated.service_id, + "enabled": updated.enabled, + "verify_ssl": updated.verify_ssl, + "fields": updated.fields, + "fields_set": updated.fields_set, + "updated_fields": sorted(normalized_fields), + }, + metadata={ + "device_id": updated.id, + "updated_fields": sorted(normalized_fields), + "verify_ssl": updated.verify_ssl, + }, + title="设备配置已更新", + ) + + +async def _connectivity_test(ctx: ToolContext, device_id: Optional[str]) -> ToolResult: + target = (device_id or "").strip() + if not target: + return ToolResult( + success=False, + error="action=connectivity_test 时 device_id 不能为空。", + ) + + log.info( + "tool.device_manage.connectivity_test.start", + {"device_id": target, "session_id": ctx.session_id}, + ) + + try: + result = await test_device(target) + except DeviceNotFoundError: + return ToolResult( + success=False, + error=f"设备 {target!r} 未找到,请通过 device_manage(action='list') 确认 device_id。", + ) + except Exception as exc: + log.warn( + "tool.device_manage.connectivity_test_failed", + {"device_id": target, "error": str(exc)}, + ) + return ToolResult(success=False, error=f"设备连通性检测失败: {exc}") + + status = "ok" if result.success else "error" + return ToolResult( + success=True, + output={ + "device_id": target, + "connected": result.success, + "status": status, + "message": result.message, + "latency_ms": result.latency_ms, + }, + metadata={ + "device_id": target, + "status": status, + "latency_ms": result.latency_ms, + "card_status_updated": True, + }, + title="设备连通性检测完成", + ) diff --git a/flocks/tool/device/models.py b/flocks/tool/device/models.py index 6c16d9074..a3b5d15d1 100644 --- a/flocks/tool/device/models.py +++ b/flocks/tool/device/models.py @@ -147,7 +147,7 @@ class DeviceIntegrationUpdate(BaseModel): enabled: Optional[bool] = None verify_ssl: Optional[bool] = None #: Partial update: absent keys keep existing value; empty-string secret - #: fields keep the existing secret ("leave blank = keep current" UX). + #: fields clear the existing secret. fields: Optional[Dict[str, str]] = None @@ -187,6 +187,7 @@ class DeviceTemplate(BaseModel): vendor: Optional[str] = None description: Optional[str] = None description_cn: Optional[str] = None + docs_url: Optional[str] = None credential_schema: List[Dict[str, Any]] = Field(default_factory=list) tool_count: int = 0 installed: bool diff --git a/flocks/tool/device/plugin_index.py b/flocks/tool/device/plugin_index.py index d138de485..04d221cf4 100644 --- a/flocks/tool/device/plugin_index.py +++ b/flocks/tool/device/plugin_index.py @@ -234,6 +234,7 @@ def _template_from_descriptor( vendor=_optional_str(provider.get("vendor")), description=description, description_cn=description_cn, + docs_url=_optional_str(provider.get("docs_url")), credential_schema=[ field.model_dump(mode="json") for field in _build_api_service_credential_schema(descriptor.storage_key, provider) diff --git a/flocks/tool/device/prompt.py b/flocks/tool/device/prompt.py index 62c28c94d..14f06bdb0 100644 --- a/flocks/tool/device/prompt.py +++ b/flocks/tool/device/prompt.py @@ -18,8 +18,8 @@ log = Log.create(service="tool.device.prompt") -async def build_device_context_section() -> Optional[str]: - """Return a Markdown block describing the device asset context. +async def build_device_manage_list_section() -> Optional[str]: + """Return a Markdown block describing the device inventory. Returns None when no devices are registered, so the caller can skip injection. diff --git a/flocks/tool/device/secrets.py b/flocks/tool/device/secrets.py index 6f341a8d6..022e9f6cd 100644 --- a/flocks/tool/device/secrets.py +++ b/flocks/tool/device/secrets.py @@ -79,7 +79,8 @@ def persist_fields( Rules: - **Sensitive, non-empty** → write to SecretManager, store placeholder. - - **Sensitive, empty/absent** → keep existing placeholder (no-op). + - **Sensitive, empty** → delete existing placeholder and secret. + - **Sensitive, absent** → keep existing placeholder (no-op). - **Non-sensitive** → store plaintext. - Keys absent from ``incoming`` inherit from ``prior_db_fields``. """ @@ -92,7 +93,17 @@ def persist_fields( for key, value in (incoming or {}).items(): if key in secret_keys: if not value or not value.strip(): - continue # keep existing placeholder + prior = result.pop(key, None) + sid = _parse_placeholder(prior) + if sid: + try: + secrets.delete(sid) + except Exception as exc: + log.warn( + "tool.device.secret.delete_error", + {"id": sid, "error": str(exc)}, + ) + continue sid = _secret_id(device_id, key) try: secrets.set(sid, value) diff --git a/flocks/tool/registry.py b/flocks/tool/registry.py index f76c1fa53..5492f3459 100644 --- a/flocks/tool/registry.py +++ b/flocks/tool/registry.py @@ -172,7 +172,7 @@ def get_schema(self) -> ToolSchema: properties["device_id"] = { "type": "string", "description": ( - "目标设备实例的唯一 ID(UUID),来自 device_context 工具返回的" + "目标设备实例的唯一 ID(UUID),来自 device_manage(action='list') 返回的" " `device_id` 字段。当系统中接入了多台同类型设备(例如多台 TDP)" "时必须传入;只有单台时也建议显式传入以避免歧义。" ), @@ -789,24 +789,25 @@ async def execute( "params": list(kwargs.keys()), }) - device_id = kwargs.pop("device_id", None) + device_id = None per_device_enabled = None if tool.info.source == "device" and tool.info.provider: + requested_device_id = kwargs.pop("device_id", None) try: resolved_device_id, resolution_error = await cls._resolve_device_target( storage_key=tool.info.provider, - requested_device_id=str(device_id).strip() if device_id else None, + requested_device_id=str(requested_device_id).strip() if requested_device_id else None, ) except Exception as exc: log.warn("tool.device.target_resolve_failed", { "tool": tool_name, "provider": tool.info.provider, - "device_id": device_id, + "device_id": requested_device_id, "error": str(exc), }) resolved_device_id = None - resolution_error = "设备目标解析失败,请通过 device_context 工具确认设备后重试。" + resolution_error = "设备目标解析失败,请通过 device_manage(action='list') 确认设备后重试。" if resolution_error: return ToolResult(success=False, error=resolution_error) @@ -850,7 +851,7 @@ async def execute( if not activated: return ToolResult( success=False, - error=f"设备 {device_id!r} 未找到或已禁用,请通过 device_context 工具确认 device_id 是否正确。", + error=f"设备 {device_id!r} 未找到或已禁用,请通过 device_manage(action='list') 确认 device_id 是否正确。", ) result = await tool.execute(ctx, **kwargs) else: @@ -921,11 +922,11 @@ async def _resolve_device_target( requested = next((device for device in devices if device.id == requested_device_id), None) if requested is None or not requested.enabled: return None, ( - f"设备 {requested_device_id!r} 未找到或已禁用,请通过 device_context 工具确认 device_id 是否正确。" + f"设备 {requested_device_id!r} 未找到或已禁用,请通过 device_manage(action='list') 确认 device_id 是否正确。" ) if requested.storage_key != storage_key: return None, ( - f"设备 {requested_device_id!r} 不属于当前工具对应的设备类型,请通过 device_context 工具确认目标设备。" + f"设备 {requested_device_id!r} 不属于当前工具对应的设备类型,请通过 device_manage(action='list') 确认目标设备。" ) return requested.id, None @@ -938,11 +939,11 @@ async def _resolve_device_target( return resolved, None if not enabled_candidates: - return None, "当前没有可用的目标设备,请通过 device_context 工具确认设备状态。" + return None, "当前没有可用的目标设备,请通过 device_manage(action='list') 确认设备状态。" return None, ( "当前存在多台同类型设备,调用前必须显式传入 `device_id`。" - "请先调用 device_context 工具确认目标设备。" + "请先调用 device_manage(action='list') 确认目标设备。" ) @classmethod @@ -1462,8 +1463,8 @@ def _register_builtin_tools(cls) -> None: ("flocks.tool.system", ["question", "model_config", "memory", "flocks_mcp", "session_manage", "slash_command", "tool_search"]), # skill/ — skill management (search, install, status, deps, remove, load) ("flocks.tool.skill", ["flocks_skills", "skill_load"]), - # device/ — security device asset context - ("flocks.tool.device", ["device_context_tool"]), + # device/ — security device asset context and status probes + ("flocks.tool.device", ["manage_tool"]), # channel/ — IM platform messaging ("flocks.tool.channel", ["channel_message", "im_send_message"]), # wecom/ — 企业微信 MCP(文档、智能表格) @@ -1739,6 +1740,32 @@ def _register_dynamic_tools(cls) -> None: # --------------------------------------------------------------------------- +_TOOL_WATCH_SUBDIRS = ("api", "device", "python") + + +def _tool_event_paths(event: object) -> List[str]: + candidate_paths: List[str] = [] + src = getattr(event, "src_path", "") or "" + if src: + candidate_paths.append(src) + dest = getattr(event, "dest_path", "") or "" + if dest: + candidate_paths.append(dest) + return candidate_paths + + +def _tool_path_matches_watch_scope(path: str, subdir: str | None = None) -> bool: + """Return whether ``path`` is under ``plugins/tools//``.""" + parts = Path(path).parts + allowed = {subdir} if subdir else set(_TOOL_WATCH_SUBDIRS) + for idx, part in enumerate(parts): + if part != "tools": + continue + if idx + 1 < len(parts) and parts[idx + 1] in allowed: + return True + return False + + def _tool_event_should_reload(event: object) -> bool: """Return True if a watchdog filesystem event should trigger a plugin reload. @@ -1753,17 +1780,13 @@ def _tool_event_should_reload(event: object) -> bool: Exposed at module scope so it can be unit-tested without spinning up ``watchdog.observers.Observer`` against a temp directory. """ - candidate_paths: List[str] = [] - src = getattr(event, "src_path", "") or "" - if src: - candidate_paths.append(src) - dest = getattr(event, "dest_path", "") or "" - if dest: - candidate_paths.append(dest) + candidate_paths = _tool_event_paths(event) if not candidate_paths: return False for path in candidate_paths: + if not _tool_path_matches_watch_scope(path): + continue if not (path.endswith(".yaml") or path.endswith(".py")): continue fname = os.path.basename(path) @@ -1775,6 +1798,11 @@ def _tool_event_should_reload(event: object) -> bool: return False +def _tool_event_touches_device_plugin(event: object) -> bool: + """Return True when a reload event targets ``tools/device``.""" + return any(_tool_path_matches_watch_scope(path, "device") for path in _tool_event_paths(event)) + + class ToolFileWatcher: """Watch plugin tool directories and auto-reload plugin tools on change. @@ -1788,7 +1816,7 @@ class ToolFileWatcher: """ _DEBOUNCE_SECONDS = 1.0 - _WATCH_SUBDIRS = ("api", "device", "python") + _WATCH_SUBDIRS = _TOOL_WATCH_SUBDIRS def __init__(self) -> None: self._observer: Optional[object] = None @@ -1842,7 +1870,7 @@ def on_any_event(self, event: FileSystemEvent) -> None: return if not _tool_event_should_reload(event): return - watcher._schedule_refresh() + watcher._schedule_refresh(device_changed=_tool_event_touches_device_plugin(event)) handler = _Handler() observer = Observer() @@ -1874,11 +1902,12 @@ def stop(self) -> None: # ---- internal ---- - def _schedule_refresh(self) -> None: + def _schedule_refresh(self, *, device_changed: bool = False) -> None: """Debounced plugin tool reload.""" with self._lock: if self._debounce_timer is not None: self._debounce_timer.cancel() + self._device_changed = getattr(self, "_device_changed", False) or device_changed self._debounce_timer = threading.Timer( self._DEBOUNCE_SECONDS, self._do_refresh ) @@ -1895,31 +1924,38 @@ def _do_refresh(self) -> None: def _run_refresh(self) -> None: try: + if getattr(self, "_device_changed", False): + try: + from flocks.config.api_versioning import discover_api_service_descriptors + from flocks.tool.device.plugin_index import clear_device_template_cache + + clear_device_template_cache() + discover_api_service_descriptors(refresh=True) + except Exception as exc: + log.debug("tool.watcher.device_cache_clear_failed", {"error": str(exc)}) + finally: + self._device_changed = False ToolRegistry.refresh_plugin_tools() log.debug("tool.watcher.reloaded", {"reason": "plugin tool file changed on disk"}) except Exception as e: log.warn("tool.watcher.reload_failed", {"error": str(e)}) def _collect_watch_dirs(self) -> Set[str]: - """Return the api/ and python/ subdirectories that exist and should be watched.""" + """Return plugin roots that exist and should be watched.""" dirs: Set[str] = set() try: from flocks.plugin.loader import DEFAULT_PLUGIN_ROOT - tools_root = DEFAULT_PLUGIN_ROOT / "tools" + user_plugin_root = DEFAULT_PLUGIN_ROOT except Exception: - tools_root = Path.home() / ".flocks" / "plugins" / "tools" + user_plugin_root = Path.home() / ".flocks" / "plugins" - for subdir in self._WATCH_SUBDIRS: - d = str(tools_root / subdir) - if os.path.isdir(d): - dirs.add(d) + if os.path.isdir(user_plugin_root): + dirs.add(str(user_plugin_root)) try: - project_tools_root = Path.cwd() / ".flocks" / "plugins" / "tools" - for subdir in self._WATCH_SUBDIRS: - d = str(project_tools_root / subdir) - if d not in dirs and os.path.isdir(d): - dirs.add(d) + project_plugin_root = Path.cwd() / ".flocks" / "plugins" + if os.path.isdir(project_plugin_root): + dirs.add(str(project_plugin_root)) except Exception: pass diff --git a/flocks/tool/schema/api_service_schema.py b/flocks/tool/schema/api_service_schema.py index 3b59c09c9..781a57b94 100644 --- a/flocks/tool/schema/api_service_schema.py +++ b/flocks/tool/schema/api_service_schema.py @@ -42,6 +42,7 @@ class APIServiceCredentialField(BaseModel): config_key: str secret_id: Optional[str] = None default_value: Optional[str] = None + internal: bool = False # --------------------------------------------------------------------------- @@ -148,6 +149,7 @@ def _load_provider_yaml_metadata(provider_id: str) -> Optional[Dict[str, Any]]: "version": extract_provider_version(prov), "description": prov.get("description"), "description_cn": prov.get("description_cn"), + "docs_url": prov.get("docs_url"), "auth": prov.get("auth"), "credential_fields": prov.get("credential_fields"), "defaults": prov.get("defaults", {}), @@ -251,6 +253,8 @@ def _normalize_api_service_credential_field( if description is not None and not isinstance(description, str): description = str(description) + internal = bool(raw_field.get("internal") or raw_field.get("hidden")) + return APIServiceCredentialField( key=key, label=label, @@ -262,6 +266,7 @@ def _normalize_api_service_credential_field( config_key=config_key, secret_id=secret_id, default_value=default_value, + internal=internal, ) diff --git a/flocks/tool/security/ssh_utils.py b/flocks/tool/security/ssh_utils.py index 5ea1580c9..89406bdea 100644 --- a/flocks/tool/security/ssh_utils.py +++ b/flocks/tool/security/ssh_utils.py @@ -134,6 +134,16 @@ def __init__( def _key(self, session_id: str, host: str, port: int, username: str) -> tuple[str, str, int, str]: return (session_id, host, port, username) + def _is_connection_closed(self, conn: asyncssh.SSHClientConnection) -> bool: + """Return whether an asyncssh connection is already closed.""" + is_closed = getattr(conn, "is_closed", None) + if not callable(is_closed): + return False + try: + return bool(is_closed()) + except Exception: + return False + async def get_connection( self, session_id: str, @@ -145,10 +155,8 @@ async def get_connection( ) -> asyncssh.SSHClientConnection: """Return an existing connection or create a new one and mark it in use. - Stale connections are not proactively detected here — the caller is - responsible for catching connection errors and calling - ``invalidate_connection()`` before retrying. This avoids relying on - asyncssh private attributes for liveness checks. + Closed connections are evicted before reuse. Other stale connection + failures are handled by the caller, which should invalidate and retry. """ key = self._key(session_id, host, port, username) @@ -156,16 +164,24 @@ async def get_connection( self._prune_idle_locked(time.monotonic()) if key not in self._locks: self._locks[key] = asyncio.Lock() + lock = self._locks[key] - async with self._locks[key]: + async with lock: now = time.monotonic() + stale_conn: Optional[asyncssh.SSHClientConnection] = None async with self._global_lock: entry = self._connections.get(key) if entry is not None: - entry.in_use += 1 - entry.last_used = now - self._connections.move_to_end(key) - return entry.conn + if self._is_connection_closed(entry.conn): + self._connections.pop(key, None) + stale_conn = entry.conn + else: + entry.in_use += 1 + entry.last_used = now + self._connections.move_to_end(key) + return entry.conn + if stale_conn is not None: + self._close_connection(stale_conn) connect_kwargs: dict = dict( host=host, @@ -207,15 +223,30 @@ async def release_connection( self._enforce_limits_locked() async def invalidate_connection( - self, session_id: str, host: str, port: int, username: str - ) -> None: - """Close and evict a stale connection so the next call reconnects.""" + self, + session_id: str, + host: str, + port: int, + username: str, + *, + close_active: bool = True, + ) -> bool: + """Close and evict a stale connection so the next call reconnects. + + When ``close_active`` is false, a connection still used by other + commands is left open so one channel failure doesn't interrupt + unrelated in-flight commands sharing the same SSH transport. + """ key = self._key(session_id, host, port, username) async with self._global_lock: + entry = self._connections.get(key) + if entry is not None and not close_active and entry.in_use > 1: + return False entry = self._connections.pop(key, None) self._locks.pop(key, None) if entry is not None: self._close_connection(entry.conn) + return True async def close_session(self, session_id: str) -> None: """Close all connections belonging to *session_id*.""" @@ -375,15 +406,40 @@ async def execute_ssh_command( result.stdout or "", result.stderr or "", ) - except (asyncssh.ConnectionLost, asyncssh.DisconnectError, BrokenPipeError, OSError): + except ( + asyncssh.ConnectionLost, + asyncssh.DisconnectError, + asyncssh.ChannelOpenError, + BrokenPipeError, + OSError, + ): # Stale connection — evict from pool and retry with a fresh one. - invalidated = True - await _pool.invalidate_connection(session_id, host, port, username) - conn = await _pool.get_connection( - session_id=session_id, - host=host, port=port, username=username, - key_path=key_path, password=password, + invalidated = await _pool.invalidate_connection( + session_id, + host, + port, + username, + close_active=False, ) + if invalidated: + conn = await _pool.get_connection( + session_id=session_id, + host=host, port=port, username=username, + key_path=key_path, password=password, + ) + else: + connect_kwargs: dict = dict( + host=host, + port=port, + username=username, + connect_timeout=15, + known_hosts=None, + ) + if key_path: + connect_kwargs["client_keys"] = [key_path] + elif password: + connect_kwargs["password"] = password + conn = await asyncssh.connect(**connect_kwargs) try: result = await asyncio.wait_for( conn.run(command, check=False), @@ -395,7 +451,13 @@ async def execute_ssh_command( result.stderr or "", ) finally: - await _pool.release_connection(session_id, host, port, username) + if invalidated: + await _pool.release_connection(session_id, host, port, username) + else: + try: + conn.close() + except Exception: + pass finally: if not invalidated: await _pool.release_connection(session_id, host, port, username) diff --git a/flocks/tool/task/run_workflow.py b/flocks/tool/task/run_workflow.py index 4c9c27bdd..a52ca2426 100644 --- a/flocks/tool/task/run_workflow.py +++ b/flocks/tool/task/run_workflow.py @@ -13,10 +13,7 @@ from types import SimpleNamespace from typing import Optional, Dict, Any, Union -from flocks.tool.registry import ( - ToolRegistry, ToolCategory, ToolParameter, ParameterType, ToolResult, ToolContext -) -from flocks.storage.storage import Storage +from flocks.tool.registry import ToolRegistry, ToolCategory, ToolParameter, ParameterType, ToolResult, ToolContext from flocks.utils.log import Log from flocks.session.recorder import Recorder from flocks.workflow.execution_store import ( @@ -30,9 +27,10 @@ record_execution_step, record_execution_result, resolve_execution_outcome, - workflow_execution_key, ) from flocks.workflow.fs_store import read_workflow_from_fs, resolve_workflow_id_from_source +from flocks.workflow.store import WorkflowStore +from flocks.tool.truncation import truncate_output log = Log.create(service="tool.run_workflow") @@ -53,7 +51,10 @@ def _get_workflow_runtime(): return None, None, None try: # Prefer in-repo integration (flocks.workflow). Fallback to external package if installed. - from flocks.workflow import RequirementsInstaller as _ReqInstaller, run_workflow as _run + from flocks.workflow import ( + RequirementsInstaller as _ReqInstaller, + run_workflow as _run, + ) from flocks.workflow.runner import RunWorkflowResult as _Result RequirementsInstaller = _ReqInstaller @@ -80,6 +81,16 @@ def _get_workflow_runtime(): return None, None, None +async def _call_workflow_runtime(runtime_fn, call_kwargs: Dict[str, Any]): + """Call either the async process executor or a legacy sync workflow runtime.""" + if inspect.iscoroutinefunction(runtime_fn): + return await runtime_fn(**call_kwargs) + result = await asyncio.to_thread(runtime_fn, **call_kwargs) + if inspect.isawaitable(result): + return await result + return result + + _BASE_DESCRIPTION = """Execute a workflow definition using the flocks-workflow runtime. When to use: @@ -165,6 +176,7 @@ async def _build_description() -> str: try: from flocks.workflow.center import scan_skill_workflows + entries = await scan_skill_workflows() if not entries: result = _BASE_DESCRIPTION @@ -199,96 +211,137 @@ def _format_workflow_result(result: Any, *, include_history: bool = False) -> st session tool output stays concise. The structured history remains available in metadata and persisted execution records. """ - if hasattr(result, '__dict__'): + if hasattr(result, "__dict__"): # RunWorkflowResult object data = result.__dict__ elif isinstance(result, dict): data = result else: return str(result) - + output_lines = [] output_lines.append(f"Status: {data.get('status', 'UNKNOWN')}") - - if data.get('run_id'): - output_lines.append(f"Run ID: {data.get('run_id')}") - - if data.get('steps'): + + if data.get("steps"): output_lines.append(f"Steps executed: {data.get('steps')}") - - if data.get('last_node_id'): + + if data.get("last_node_id"): output_lines.append(f"Last node: {data.get('last_node_id')}") - - if data.get('error'): + + if data.get("error"): output_lines.append(f"\nError: {data.get('error')}") - - if data.get('outputs'): + + if data.get("outputs"): output_lines.append("\nFinal Outputs:") try: - outputs_str = json.dumps(data.get('outputs'), indent=2, ensure_ascii=False) - output_lines.append(outputs_str) + output_lines.append(json.dumps(data.get("outputs"), indent=2, ensure_ascii=False, default=str)) except Exception: - output_lines.append(str(data.get('outputs'))) - - if include_history and data.get('history'): - history = data.get('history', []) + output_lines.append(str(data.get("outputs"))) + + if include_history and data.get("history"): + history = data.get("history", []) if history: - output_lines.append(f"\n{'='*80}") + output_lines.append(f"\n{'=' * 80}") output_lines.append(f"Execution History ({len(history)} steps):") - output_lines.append('='*80) - + output_lines.append("=" * 80) + for i, step in enumerate(history, 1): - node_id = step.get('node_id', 'unknown') - duration_ms = step.get('duration_ms') - error = step.get('error') - + node_id = step.get("node_id", "unknown") + duration_ms = step.get("duration_ms") + error = step.get("error") + output_lines.append(f"\n[Step {i}] Node: {node_id}") if duration_ms is not None: output_lines.append(f" Duration: {duration_ms:.2f}ms") - + # Show inputs - inputs = step.get('inputs', {}) + inputs = step.get("inputs", {}) if inputs: output_lines.append(" Inputs:") try: inputs_str = json.dumps(inputs, indent=4, ensure_ascii=False) - for line in inputs_str.split('\n'): + for line in inputs_str.split("\n"): output_lines.append(f" {line}") except Exception: output_lines.append(f" {str(inputs)}") - + # Show outputs - outputs = step.get('outputs', {}) + outputs = step.get("outputs", {}) if outputs: output_lines.append(" Outputs:") try: outputs_str = json.dumps(outputs, indent=4, ensure_ascii=False) - for line in outputs_str.split('\n'): + for line in outputs_str.split("\n"): output_lines.append(f" {line}") except Exception: output_lines.append(f" {str(outputs)}") - + # Show stdout if present - stdout = step.get('stdout', '') + stdout = step.get("stdout", "") if stdout: output_lines.append(" Stdout:") - for line in stdout.split('\n'): + for line in stdout.split("\n"): output_lines.append(f" {line}") - + # Show error if present if error: output_lines.append(f" Error: {error}") - traceback_info = step.get('traceback', '') + traceback_info = step.get("traceback", "") if traceback_info: output_lines.append(" Traceback:") - for line in traceback_info.split('\n'): + for line in traceback_info.split("\n"): output_lines.append(f" {line}") - - output_lines.append(f"\n{'='*80}") - + + output_lines.append(f"\n{'=' * 80}") + return "\n".join(output_lines) +def _format_workflow_result_for_tool(result: Any) -> tuple[str, bool, Optional[str]]: + """Format raw workflow output for the agent and save oversized text.""" + output = _format_workflow_result(result) + truncated = truncate_output(output) + return truncated.content, truncated.truncated, truncated.output_path + + +def _output_keys(outputs: Any) -> list[str]: + if isinstance(outputs, dict): + return [str(key) for key in outputs.keys()] + return [] + + +def _workflow_tool_metadata( + *, + workflow_id: str, + workflow_name: str, + total_nodes: Optional[int], + workflow_execution_id: Optional[str], + status: str, + steps: Any = 0, + last_node_id: Optional[str] = None, + outputs: Any = None, + history_count: int = 0, + output_truncated: bool = False, + output_path: Optional[str] = None, +) -> Dict[str, Any]: + metadata: Dict[str, Any] = { + "workflow_id": workflow_id, + "workflow_name": workflow_name, + "total_nodes": total_nodes, + "workflow_execution_id": workflow_execution_id, + "status": status, + "steps": steps, + "last_node_id": last_node_id, + "has_output": bool(outputs), + "output_keys": _output_keys(outputs), + "output_truncated": output_truncated, + "history_count": history_count, + } + if output_path: + metadata["output_path"] = output_path + return metadata + + async def _record_workflow_tool_result(workflow_id: str, result: Any) -> None: """Record workflow tool execution to JSONL (best-effort).""" try: @@ -343,7 +396,7 @@ async def _record_workflow_tool_result(workflow_id: str, result: Any) -> None: name="use_llm", type=ParameterType.BOOLEAN, description=( - "Enable LLM-backed code generation for `type=\"logic\"` nodes (when code is missing). " + 'Enable LLM-backed code generation for `type="logic"` nodes (when code is missing). ' "Recommended to keep enabled for logic-node workflows." ), required=False, @@ -354,22 +407,22 @@ async def _record_workflow_tool_result(workflow_id: str, result: Any) -> None: type=ParameterType.BOOLEAN, description="Whether to automatically install requirements declared in workflow metadata", required=False, - default=True + default=True, ), ToolParameter( name="timeout_s", type=ParameterType.NUMBER, description="Execution timeout in seconds (optional)", - required=False + required=False, ), ToolParameter( name="trace", type=ParameterType.BOOLEAN, description="Enable execution tracing for debugging", required=False, - default=False + default=False, ), - ] + ], ) async def run_workflow_tool( ctx: ToolContext, @@ -382,7 +435,7 @@ async def run_workflow_tool( ) -> ToolResult: """ Execute a workflow using flocks-workflow runtime - + Args: ctx: Tool context workflow: Workflow definition (dict), JSON string, or a workflow JSON file path @@ -391,7 +444,7 @@ async def run_workflow_tool( ensure_requirements: Whether to install requirements automatically timeout_s: Execution timeout in seconds trace: Enable execution tracing - + Returns: ToolResult with workflow execution results """ @@ -402,20 +455,15 @@ async def run_workflow_tool( req_installer, _run_workflow_fn, RunWorkflowResultCls = _get_workflow_runtime() if _run_workflow_fn is None or RunWorkflowResultCls is None: - return ToolResult( - success=False, - error="flocks-workflow package is not available. Please check code" - ) - + return ToolResult(success=False, error="flocks-workflow package is not available. Please check code") + # Validate workflow parameter if not workflow: - return ToolResult( - success=False, - error="workflow parameter is required" - ) - + return ToolResult(success=False, error="workflow parameter is required") + # Accept workflow as dict, JSON string, or file path. workflow_source: Union[Dict[str, Any], Path] + registered_workflow_id: Optional[str] = None if isinstance(workflow, str): raw = workflow.strip() # Try to parse as JSON first (handles JSON-encoded dicts or strings). @@ -441,15 +489,17 @@ async def run_workflow_tool( error=( f"Workflow file not found: {parsed!r}. " "Provide a valid workflow JSON file path or a workflow dict." - ) + ), ) elif parsed is None: # json.loads raised JSONDecodeError — raw is not JSON. # First try to resolve as a registered workflow ID, then fall back to file path. existing_workflow = read_workflow_from_fs(raw) if existing_workflow is not None: - workflow_source = existing_workflow["workflowJson"] - raw = existing_workflow["id"] + workflow_source = dict(existing_workflow["workflowJson"]) + registered_workflow_id = str(existing_workflow["id"]) + workflow_source["id"] = registered_workflow_id + raw = registered_workflow_id else: p = Path(raw).expanduser() if p.exists() and p.is_file(): @@ -460,7 +510,7 @@ async def run_workflow_tool( error=( "Unsupported workflow string. Provide a workflow ID, workflow JSON string, " "or a valid workflow JSON file path." - ) + ), ) else: # json.loads returned list / int / bool — not a valid workflow parameter. @@ -469,16 +519,15 @@ async def run_workflow_tool( error=( f"Invalid workflow parameter: expected a workflow dict or a file path string, " f"got JSON-decoded {type(parsed).__name__} ({parsed!r})." - ) + ), ) elif isinstance(workflow, dict): workflow_source = workflow else: return ToolResult( - success=False, - error=f"workflow must be a dictionary or string, got {type(workflow).__name__}" + success=False, error=f"workflow must be a dictionary or string, got {type(workflow).__name__}" ) - + # Sanity-check dict workflows: must have at least a `start` field so we # surface a clear error instead of a confusing Pydantic validation message. if isinstance(workflow_source, dict) and "start" not in workflow_source: @@ -488,7 +537,7 @@ async def run_workflow_tool( "Invalid workflow definition: the `start` field is required. " "Make sure you pass the workflow JSON (with `start`, `nodes`, `edges`) " "as the `workflow` parameter, not the execution inputs." - ) + ), ) # Request permission (workflow execution can run arbitrary code) @@ -501,11 +550,10 @@ async def run_workflow_tool( workflow_id = str(workflow_source) workflow_inputs = inputs or {} - canonical_workflow_id = resolve_workflow_id_from_source(workflow_source) + canonical_workflow_id = registered_workflow_id or resolve_workflow_id_from_source(workflow_source) display_workflow_id = canonical_workflow_id or workflow_id tracked_execution: Optional[Dict[str, Any]] = None tracked_step_count = 0 - tracked_exec_key: Optional[str] = None pending_step_index: Optional[int] = None pending_step: Optional[Dict[str, Any]] = None loop = asyncio.get_running_loop() @@ -514,25 +562,26 @@ def _emit_metadata(metadata: Dict[str, Any]) -> None: loop.call_soon_threadsafe(ctx.metadata, metadata) def _update_execution_progress(update_fields: Dict[str, Any]) -> None: - if not tracked_exec_key: - return try: if tracked_execution is None: return tracked_execution.update(update_fields) asyncio.run_coroutine_threadsafe( - Storage.write(tracked_exec_key, compact_execution_summary(tracked_execution)), + WorkflowStore.upsert_execution(compact_execution_summary(tracked_execution)), loop, ).result(timeout=5) except Exception as exc: - log.warning("run_workflow.execution_progress.write_failed", { - "workflow_id": display_workflow_id, - "exec_id": tracked_execution["id"] if tracked_execution else None, - "error": str(exc), - }) + log.warning( + "run_workflow.execution_progress.write_failed", + { + "workflow_id": display_workflow_id, + "exec_id": tracked_execution["id"] if tracked_execution else None, + "error": str(exc), + }, + ) def _on_step_start( - run_id: Optional[str], + _run_id: Optional[str], step_index: int, node: Any, _inputs: Dict[str, Any], @@ -547,38 +596,43 @@ def _on_step_start( outputs=None, ) pending_step_index = step_index - pending_step = { - "node_id": current_node_id, - "node_type": current_node_type, - "inputs": _inputs if isinstance(_inputs, dict) else {}, - "outputs": {}, - "error": "Run cancelled before node completed", - } + pending_step = compact_step_for_storage( + { + "node_id": current_node_id, + "node_type": current_node_type, + "inputs": _inputs if isinstance(_inputs, dict) else {}, + "outputs": {}, + "error": "Run cancelled before node completed", + } + ) if tracked_execution is not None: - tracked_execution.update({ - "currentNodeId": current_node_id, - "currentNodeType": current_node_type, - "currentPhase": "running", - "currentStepIndex": step_index, - "loopProgress": loop_progress, - "updatedAt": int(time.time() * 1000), - }) - _emit_metadata({ - "title": f"Running workflow: {workflow_name}", - "metadata": { - "workflow_id": display_workflow_id, - "workflow_name": workflow_name, - "total_nodes": workflow_total_nodes, - "workflow_execution_id": tracked_execution["id"] if tracked_execution else None, - "run_id": run_id, - "status": "running", - "phase": "running", - "current_node_id": current_node_id, - "current_node_type": current_node_type, - "step_index": step_index, - "loop_progress": loop_progress, - }, - }) + _update_execution_progress( + { + "currentNodeId": current_node_id, + "currentNodeType": current_node_type, + "currentPhase": "running", + "currentStepIndex": step_index, + "loopProgress": loop_progress, + "updatedAt": int(time.time() * 1000), + } + ) + _emit_metadata( + { + "title": f"Running workflow: {workflow_name}", + "metadata": { + "workflow_id": display_workflow_id, + "workflow_name": workflow_name, + "total_nodes": workflow_total_nodes, + "workflow_execution_id": tracked_execution["id"] if tracked_execution else None, + "status": "running", + "phase": "running", + "current_node_id": current_node_id, + "current_node_type": current_node_type, + "step_index": step_index, + "loop_progress": loop_progress, + }, + } + ) return step_index def _on_step_complete(step_result: Any) -> None: @@ -601,15 +655,17 @@ def _on_step_complete(step_result: Any) -> None: ) tracked_step_count = step_index if tracked_execution is not None: - tracked_execution.update({ - "stepCount": tracked_step_count, - "currentNodeId": step_dict.get("node_id"), - "currentNodeType": step_dict.get("node_type") or step_dict.get("type"), - "currentPhase": "running", - "currentStepIndex": tracked_step_count, - "loopProgress": loop_progress, - "updatedAt": int(time.time() * 1000), - }) + tracked_execution.update( + { + "stepCount": tracked_step_count, + "currentNodeId": step_dict.get("node_id"), + "currentNodeType": step_dict.get("node_type") or step_dict.get("type"), + "currentPhase": "running", + "currentStepIndex": tracked_step_count, + "loopProgress": loop_progress, + "updatedAt": int(time.time() * 1000), + } + ) if tracked_execution is not None: try: asyncio.run_coroutine_threadsafe( @@ -617,46 +673,49 @@ def _on_step_complete(step_result: Any) -> None: loop, ).result(timeout=5) except Exception as exc: - log.warning("run_workflow.execution_step.write_failed", { - "workflow_id": display_workflow_id, - "exec_id": tracked_execution["id"], - "step_index": step_index, - "error": str(exc), - }) + log.warning( + "run_workflow.execution_step.write_failed", + { + "workflow_id": display_workflow_id, + "exec_id": tracked_execution["id"], + "step_index": step_index, + "error": str(exc), + }, + ) if tracked_step_count % _PROGRESS_FLUSH_EVERY_STEPS == 0: - _update_execution_progress({ - "stepCount": tracked_step_count, - "currentNodeId": step_dict.get("node_id"), - "currentNodeType": step_dict.get("node_type") or step_dict.get("type"), - "currentPhase": "running", - "currentStepIndex": tracked_step_count, - "loopProgress": loop_progress, - "updatedAt": int(time.time() * 1000), - }) - _emit_metadata({ - "title": f"Running workflow: {workflow_name}", - "metadata": { - "workflow_id": display_workflow_id, - "workflow_name": workflow_name, - "total_nodes": workflow_total_nodes, - "workflow_execution_id": tracked_execution["id"] if tracked_execution else None, - "status": "running", - "phase": "running", - "current_node_id": step_dict.get("node_id"), - "current_node_type": step_dict.get("node_type") or step_dict.get("type"), - "step_index": tracked_step_count, - "step_count": tracked_step_count, - "loop_progress": loop_progress, - }, - }) + _update_execution_progress( + { + "stepCount": tracked_step_count, + "currentNodeId": step_dict.get("node_id"), + "currentNodeType": step_dict.get("node_type") or step_dict.get("type"), + "currentPhase": "running", + "currentStepIndex": tracked_step_count, + "loopProgress": loop_progress, + "updatedAt": int(time.time() * 1000), + } + ) + _emit_metadata( + { + "title": f"Running workflow: {workflow_name}", + "metadata": { + "workflow_id": display_workflow_id, + "workflow_name": workflow_name, + "total_nodes": workflow_total_nodes, + "workflow_execution_id": tracked_execution["id"] if tracked_execution else None, + "status": "running", + "phase": "running", + "current_node_id": step_dict.get("node_id"), + "current_node_type": step_dict.get("node_type") or step_dict.get("type"), + "step_index": tracked_step_count, + "step_count": tracked_step_count, + "loop_progress": loop_progress, + }, + } + ) return async def _flush_pending_step() -> None: - if ( - tracked_execution is None - or pending_step_index is None - or pending_step is None - ): + if tracked_execution is None or pending_step_index is None or pending_step is None: return try: await record_execution_step( @@ -665,13 +724,16 @@ async def _flush_pending_step() -> None: pending_step, ) except Exception as exc: - log.warning("run_workflow.pending_step.write_failed", { - "workflow_id": display_workflow_id, - "exec_id": tracked_execution["id"], - "step_index": pending_step_index, - "error": str(exc), - }) - + log.warning( + "run_workflow.pending_step.write_failed", + { + "workflow_id": display_workflow_id, + "exec_id": tracked_execution["id"], + "step_index": pending_step_index, + "error": str(exc), + }, + ) + await ctx.ask( permission="run_workflow", patterns=[workflow_id, workflow_name], @@ -681,37 +743,41 @@ async def _flush_pending_step() -> None: "workflow_name": workflow_name, "ensure_requirements": ensure_requirements, "use_llm": use_llm, - } + }, ) - + if canonical_workflow_id: tracked_execution = await create_execution_record( canonical_workflow_id, input_params=workflow_inputs, ) - tracked_exec_key = workflow_execution_key(tracked_execution["id"]) # Update metadata to show workflow is running - _emit_metadata({ - "title": f"Running workflow: {workflow_name}", - "metadata": { - "workflow_id": display_workflow_id, - "workflow_name": workflow_name, - "total_nodes": workflow_total_nodes, - "workflow_execution_id": tracked_execution["id"] if tracked_execution else None, - "status": "running", - "phase": "queued", - "step_index": 0, - }, - }) + _emit_metadata( + { + "title": f"Running workflow: {workflow_name}", + "metadata": { + "workflow_id": display_workflow_id, + "workflow_name": workflow_name, + "total_nodes": workflow_total_nodes, + "workflow_execution_id": tracked_execution["id"] if tracked_execution else None, + "status": "running", + "phase": "queued", + "step_index": 0, + }, + } + ) try: # Execute workflow - log.info("run_workflow.execute.start", { - "workflow_id": workflow_id, - "workflow_name": workflow_name, - "ensure_requirements": ensure_requirements, - }) + log.info( + "run_workflow.execute.start", + { + "workflow_id": workflow_id, + "workflow_name": workflow_name, + "ensure_requirements": ensure_requirements, + }, + ) execution_started_at = time.time() nested_tool_ctx = _create_nested_tool_context(ctx) @@ -728,31 +794,42 @@ async def _flush_pending_step() -> None: } # Backward-compatibility: older runtimes may not accept `use_llm`. + supports_workflow_id = False supports_use_llm = False + supports_run_id = False supports_step_start = False supports_cancel = False try: sig = inspect.signature(_run_workflow_fn) - supports_use_llm = ( - "use_llm" in sig.parameters - or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()) + supports_workflow_id = "workflow_id" in sig.parameters or any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() + ) + supports_use_llm = "use_llm" in sig.parameters or any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() + ) + supports_run_id = "run_id" in sig.parameters or any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() ) - supports_step_start = ( - "on_step_start" in sig.parameters - or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()) + supports_step_start = "on_step_start" in sig.parameters or any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() ) - supports_cancel = ( - "cancel" in sig.parameters - or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()) + supports_cancel = "cancel" in sig.parameters or any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() ) except Exception: # Best-effort: assume supported. + supports_workflow_id = True supports_use_llm = True + supports_run_id = True supports_step_start = True supports_cancel = True + if supports_workflow_id: + call_kwargs["workflow_id"] = display_workflow_id if supports_use_llm: call_kwargs["use_llm"] = use_llm + if supports_run_id and tracked_execution: + call_kwargs["run_id"] = tracked_execution["id"] if supports_step_start: call_kwargs["on_step_start"] = _on_step_start call_kwargs["on_step_complete"] = _on_step_complete @@ -760,21 +837,27 @@ async def _flush_pending_step() -> None: call_kwargs["cancel"] = ctx.abort.is_set try: - result = await asyncio.to_thread(_run_workflow_fn, **call_kwargs) + result = await _call_workflow_runtime(_run_workflow_fn, call_kwargs) except TypeError as te: # Fallback if the runtime rejects `use_llm` (unexpected keyword). - if supports_use_llm and "use_llm" in str(te): + if supports_workflow_id and "workflow_id" in str(te): + call_kwargs.pop("workflow_id", None) + result = await _call_workflow_runtime(_run_workflow_fn, call_kwargs) + elif supports_use_llm and "use_llm" in str(te): call_kwargs.pop("use_llm", None) - result = await asyncio.to_thread(_run_workflow_fn, **call_kwargs) + result = await _call_workflow_runtime(_run_workflow_fn, call_kwargs) + elif supports_run_id and "run_id" in str(te): + call_kwargs.pop("run_id", None) + result = await _call_workflow_runtime(_run_workflow_fn, call_kwargs) elif supports_step_start and "on_step_start" in str(te): call_kwargs.pop("on_step_start", None) - result = await asyncio.to_thread(_run_workflow_fn, **call_kwargs) + result = await _call_workflow_runtime(_run_workflow_fn, call_kwargs) elif supports_cancel and "cancel" in str(te): call_kwargs.pop("cancel", None) - result = await asyncio.to_thread(_run_workflow_fn, **call_kwargs) + result = await _call_workflow_runtime(_run_workflow_fn, call_kwargs) else: raise - + # Format result if RunWorkflowResultCls and isinstance(result, RunWorkflowResultCls): result_dict = result.__dict__ @@ -782,19 +865,22 @@ async def _flush_pending_step() -> None: result_dict = result else: result_dict = {"status": "UNKNOWN", "output": str(result)} - + status = result_dict.get("status", "UNKNOWN") success = status == "SUCCEEDED" error = result_dict.get("error") - - output = _format_workflow_result(result_dict) - - log.info("run_workflow.execute.complete", { - "workflow_id": workflow_id, - "status": status, - "success": success, - "steps": result_dict.get("steps", 0), - }) + + output, output_truncated, output_path = _format_workflow_result_for_tool(result_dict) + + log.info( + "run_workflow.execute.complete", + { + "workflow_id": workflow_id, + "status": status, + "success": success, + "steps": result_dict.get("steps", 0), + }, + ) # Append-only recording for audit/replay await _record_workflow_tool_result(display_workflow_id, result_dict) @@ -809,7 +895,7 @@ async def _flush_pending_step() -> None: final_step_count = tracked_step_count if pending_step_index is not None: final_step_count = max(final_step_count, pending_step_index) - if tracked_execution and canonical_workflow_id and tracked_exec_key: + if tracked_execution and canonical_workflow_id: current_data = dict(tracked_execution) outcome_result = result if not hasattr(outcome_result, "status"): @@ -819,42 +905,43 @@ async def _flush_pending_step() -> None: error=result_dict.get("error"), ) status_value, error_message = resolve_execution_outcome(outcome_result) # type: ignore[arg-type] - current_data.update({ - "outputResults": compact_outputs_for_storage(result_dict.get("outputs")), - "status": status_value, - "finishedAt": int(time.time() * 1000), - "duration": time.time() - execution_started_at, - "executionLog": compacted_history, - "stepCount": final_step_count, - "errorMessage": error_message, - "currentNodeId": result_dict.get("last_node_id"), - "currentPhase": status_value, - "currentStepIndex": final_step_count, - "updatedAt": int(time.time() * 1000), - }) + current_data.update( + { + "outputResults": compact_outputs_for_storage(result_dict.get("outputs")), + "status": status_value, + "finishedAt": int(time.time() * 1000), + "duration": time.time() - execution_started_at, + "executionLog": compacted_history, + "stepCount": final_step_count, + "errorMessage": error_message, + "currentNodeId": result_dict.get("last_node_id"), + "currentPhase": status_value, + "currentStepIndex": final_step_count, + "updatedAt": int(time.time() * 1000), + } + ) await record_execution_result( canonical_workflow_id, tracked_execution["id"], current_data, ) - _emit_metadata({ - "title": f"Workflow: {workflow_name}", - "metadata": { - "workflow_id": canonical_workflow_id, - "workflow_name": workflow_name, - "total_nodes": workflow_total_nodes, - "workflow_execution_id": tracked_execution["id"], - "run_id": result_dict.get("run_id"), - "status": status_value, - "phase": status_value, - "current_node_id": result_dict.get("last_node_id"), - "step_index": final_step_count, - "step_count": final_step_count, - "loop_progress": current_data.get("loopProgress"), - }, - }) - - compacted_outputs = compact_outputs_for_storage(result_dict.get("outputs")) + _emit_metadata( + { + "title": f"Workflow: {workflow_name}", + "metadata": { + "workflow_id": canonical_workflow_id, + "workflow_name": workflow_name, + "total_nodes": workflow_total_nodes, + "workflow_execution_id": tracked_execution["id"], + "status": status_value, + "phase": status_value, + "current_node_id": result_dict.get("last_node_id"), + "step_index": final_step_count, + "step_count": final_step_count, + "loop_progress": current_data.get("loopProgress"), + }, + } + ) # If workflow failed, include error in ToolResult if not success and error: @@ -863,85 +950,94 @@ async def _flush_pending_step() -> None: error=error, output=output, # Also include formatted output for context title=f"Workflow: {workflow_name}", - metadata={ - "workflow_id": display_workflow_id, - "workflow_name": workflow_name, - "total_nodes": workflow_total_nodes, - "workflow_execution_id": tracked_execution["id"] if tracked_execution else None, - "status": status_value, - "steps": result_dict.get("steps", 0), - "run_id": result_dict.get("run_id"), - "last_node_id": result_dict.get("last_node_id"), - "outputs": compacted_outputs, - "history": [], - "history_count": history_count, - } + metadata=_workflow_tool_metadata( + workflow_id=display_workflow_id, + workflow_name=workflow_name, + total_nodes=workflow_total_nodes, + workflow_execution_id=tracked_execution["id"] if tracked_execution else None, + status=status_value, + steps=result_dict.get("steps", 0), + last_node_id=result_dict.get("last_node_id"), + outputs=result_dict.get("outputs"), + history_count=history_count, + output_truncated=output_truncated, + output_path=output_path, + ), + truncated=output_truncated, ) - + return ToolResult( success=success, output=output, title=f"Workflow: {workflow_name}", - metadata={ - "workflow_id": display_workflow_id, - "workflow_name": workflow_name, - "total_nodes": workflow_total_nodes, - "workflow_execution_id": tracked_execution["id"] if tracked_execution else None, - "status": status_value, - "steps": result_dict.get("steps", 0), - "run_id": result_dict.get("run_id"), - "last_node_id": result_dict.get("last_node_id"), - "outputs": compacted_outputs, - "history": [], - "history_count": history_count, - } + metadata=_workflow_tool_metadata( + workflow_id=display_workflow_id, + workflow_name=workflow_name, + total_nodes=workflow_total_nodes, + workflow_execution_id=tracked_execution["id"] if tracked_execution else None, + status=status_value, + steps=result_dict.get("steps", 0), + last_node_id=result_dict.get("last_node_id"), + outputs=result_dict.get("outputs"), + history_count=history_count, + output_truncated=output_truncated, + output_path=output_path, + ), + truncated=output_truncated, ) - + except Exception as e: error_msg = str(e) - log.error("run_workflow.execute.error", { - "workflow_id": workflow_id, - "error": error_msg, - }) - if tracked_execution and canonical_workflow_id and tracked_exec_key: + log.error( + "run_workflow.execute.error", + { + "workflow_id": workflow_id, + "error": error_msg, + }, + ) + if tracked_execution and canonical_workflow_id: current_data = dict(tracked_execution) - current_data.update({ - "status": "error", - "finishedAt": int(time.time() * 1000), - "errorMessage": error_msg, - "executionLog": [], - "stepCount": tracked_step_count, - "currentPhase": "error", - "currentStepIndex": tracked_step_count, - "updatedAt": int(time.time() * 1000), - }) + current_data.update( + { + "status": "error", + "finishedAt": int(time.time() * 1000), + "errorMessage": error_msg, + "executionLog": [], + "stepCount": tracked_step_count, + "currentPhase": "error", + "currentStepIndex": tracked_step_count, + "updatedAt": int(time.time() * 1000), + } + ) await record_execution_result( canonical_workflow_id, tracked_execution["id"], current_data, ) - _emit_metadata({ - "title": f"Workflow: {workflow_name}", - "metadata": { - "workflow_id": canonical_workflow_id, - "workflow_name": workflow_name, - "total_nodes": workflow_total_nodes, - "workflow_execution_id": tracked_execution["id"], - "status": "error", - "phase": "error", - "step_index": tracked_step_count, - }, - }) - + _emit_metadata( + { + "title": f"Workflow: {workflow_name}", + "metadata": { + "workflow_id": canonical_workflow_id, + "workflow_name": workflow_name, + "total_nodes": workflow_total_nodes, + "workflow_execution_id": tracked_execution["id"], + "status": "error", + "phase": "error", + "step_index": tracked_step_count, + }, + } + ) + return ToolResult( success=False, error=f"Workflow execution failed: {error_msg}", title=f"Workflow: {workflow_name}", - metadata={ - "workflow_id": display_workflow_id, - "workflow_name": workflow_name, - "total_nodes": workflow_total_nodes, - "workflow_execution_id": tracked_execution["id"] if tracked_execution else None, - "status": "FAILED", - } + metadata=_workflow_tool_metadata( + workflow_id=display_workflow_id, + workflow_name=workflow_name, + total_nodes=workflow_total_nodes, + workflow_execution_id=tracked_execution["id"] if tracked_execution else None, + status="FAILED", + ), ) diff --git a/flocks/updater/models.py b/flocks/updater/models.py index bde9a8e27..7f9336632 100644 --- a/flocks/updater/models.py +++ b/flocks/updater/models.py @@ -21,6 +21,12 @@ class VersionInfo(BaseModel): current_version: str latest_version: str | None = None + current_core_version: str | None = None + latest_core_version: str | None = None + current_bundle_version: str | None = None + latest_bundle_version: str | None = None + current_pro_component_version: str | None = None + latest_pro_component_version: str | None = None edition: Literal["flocks", "flockspro"] = "flocks" has_update: bool = False release_notes: str | None = None diff --git a/flocks/updater/updater.py b/flocks/updater/updater.py index ce6c3c09f..d076af229 100644 --- a/flocks/updater/updater.py +++ b/flocks/updater/updater.py @@ -80,6 +80,15 @@ class ConsoleManifestRelease: bundle_format: str manifest: dict[str, Any] console_session_token: str | None = None + release_id: str | None = None + bundle_release_id: str | None = None + + +@dataclass(frozen=True) +class ProVersionState: + bundle_version: str + core_version: str + pro_component_version: str | None def _record_update_journal(message: str) -> None: @@ -952,7 +961,7 @@ async def _resolve_sources_for_edition(configured_sources: list[str]) -> list[st The generic Flocks update endpoint checks OSS releases. Pro bundle upgrades opt into the Console manifest explicitly via ``force_console_manifest``. """ - return list(configured_sources) + return [source for source in configured_sources if source != "console-manifest"] def _is_flockspro_license_active() -> bool: @@ -1074,14 +1083,57 @@ def _archive_format_for_url(url: str, manifest_format: str | None = None) -> str def _console_manifest_display_version(data: dict[str, Any]) -> str: - component_version = str(data.get("flockspro_component_version") or "").strip() - if component_version: - return component_version display_version = str(data.get("display_version") or data.get("version") or data.get("latest_version") or "").strip() if display_version: return display_version + core_version = str(data.get("core_version") or "").strip() + if core_version: + return core_version + oss_version = str(data.get("oss_version") or "").strip() + if oss_version: + return oss_version compare_version = str(data.get("compare_version") or "").strip() - return f"pro-v{compare_version}" if compare_version else "" + return f"v{compare_version}" if compare_version and not compare_version.startswith(("v", "V")) else compare_version + + +def _pro_bundle_core_version(data: dict[str, Any]) -> str: + core_version = str(data.get("core_version") or "").strip() + if core_version: + return core_version + oss_version = str(data.get("oss_version") or "").strip() + if oss_version: + return oss_version + return _console_manifest_display_version(data) + + +def _pro_bundle_oss_version(data: dict[str, Any]) -> str: + return _pro_bundle_core_version(data) + + +def _version_label(version: str | None) -> str: + normalized = str(version or "").strip() + if not normalized: + return "" + return normalized if normalized.startswith(("v", "V")) else f"v{normalized}" + + +def _is_pro_bundle_oss_older_than_local(manifest: dict[str, Any], current_version: str | None = None) -> bool: + bundle_oss_version = _pro_bundle_oss_version(manifest) + if not bundle_oss_version: + return False + local_version = str(current_version or get_current_version() or "").strip() + if not local_version: + return False + return _parse_version(bundle_oss_version) < _parse_version(local_version) + + +def _effective_pro_bundle_manifest(manifest: dict[str, Any], effective_oss_version: str) -> dict[str, Any]: + payload = dict(manifest) + effective_label = _version_label(effective_oss_version) + if effective_label: + payload["core_version"] = effective_label + payload["oss_version"] = effective_label + return payload def _archive_filename_for_format(latest_tag: str, fmt: str) -> str: @@ -1258,6 +1310,8 @@ async def _fetch_console_manifest_release_info(console_session_token: str | None if not bundle_url: raise ValueError("manifest 响应缺少 bundle_url") bundle_format = _archive_format_for_url(bundle_url, str(data.get("bundle_format") or data.get("archive_format") or "")) + release_id = str(data.get("release_id") or data.get("bundle_release_id") or "").strip() or None + bundle_release_id = str(data.get("bundle_release_id") or data.get("release_id") or "").strip() or None return ConsoleManifestRelease( version=latest, release_notes=data.get("release_notes") or data.get("notes"), @@ -1267,6 +1321,8 @@ async def _fetch_console_manifest_release_info(console_session_token: str | None bundle_format=bundle_format, manifest=data, console_session_token=token, + release_id=release_id, + bundle_release_id=bundle_release_id, ) @@ -1779,13 +1835,44 @@ async def run_handoff_upgrade_tasks( return None +def _merge_console_manifest_release_identity( + bundle_manifest: dict[str, Any], + console_manifest: dict[str, Any] | None, +) -> dict[str, Any]: + if not console_manifest: + return bundle_manifest + merged = dict(bundle_manifest) + release_id = str(console_manifest.get("release_id") or console_manifest.get("bundle_release_id") or "").strip() + bundle_release_id = str(console_manifest.get("bundle_release_id") or console_manifest.get("release_id") or "").strip() + if release_id and not merged.get("release_id"): + merged["release_id"] = release_id + if bundle_release_id and not merged.get("bundle_release_id"): + merged["bundle_release_id"] = bundle_release_id + for key in ("display_version", "version", "latest_version", "compare_version", "flockspro_component_version", "build_id"): + if console_manifest.get(key): + merged[key] = console_manifest.get(key) + core_version = console_manifest.get("core_version") or console_manifest.get("oss_version") or merged.get("core_version") or merged.get("oss_version") + if core_version: + merged["core_version"] = core_version + merged["oss_version"] = core_version + return merged + + def _write_pro_bundle_install_marker(manifest: dict[str, Any], *, bundle_sha256: str | None = None) -> None: marker = _flocks_root() / "run" / "pro-bundle-installed.json" marker.parent.mkdir(parents=True, exist_ok=True) + release_id = manifest.get("release_id") or manifest.get("bundle_release_id") + bundle_release_id = manifest.get("bundle_release_id") or manifest.get("release_id") + display_version = _console_manifest_display_version(manifest) + core_version = manifest.get("core_version") or manifest.get("oss_version") payload = { - "display_version": manifest.get("display_version"), - "installed_version": manifest.get("display_version"), - "oss_version": manifest.get("oss_version"), + "release_id": release_id, + "bundle_release_id": bundle_release_id, + "bundle_version": display_version, + "display_version": display_version, + "installed_version": display_version, + "core_version": core_version, + "oss_version": core_version, "flockspro_component_version": manifest.get("flockspro_component_version"), "build_id": manifest.get("build_id"), "bundle_sha256": bundle_sha256 or manifest.get("bundle_sha256"), @@ -2435,25 +2522,75 @@ def _read_pyproject_version() -> str: return "" -def _read_pro_bundle_installed_version() -> str: +def _read_pro_bundle_install_marker() -> dict[str, Any]: marker = _flocks_root() / "run" / "pro-bundle-installed.json" try: payload = json.loads(marker.read_text(encoding="utf-8")) except Exception: - return "" - if not isinstance(payload, dict): - return "" - component_version = str(payload.get("flockspro_component_version") or "").strip() - if component_version: - return component_version - installed_version = str(payload.get("installed_version") or "").strip() - if installed_version.startswith(("pro-v", "pro-V")): - return installed_version - if installed_version: - return f"pro-{installed_version}" if installed_version.startswith(("v", "V")) else f"pro-v{installed_version}" + return {} + return payload if isinstance(payload, dict) else {} + + +def _read_pro_bundle_installed_bundle_version() -> str: + payload = _read_pro_bundle_install_marker() + for key in ("bundle_version", "installed_version", "display_version"): + version = str(payload.get(key) or "").strip() + if version: + return version return "" +def _read_pro_bundle_installed_core_version() -> str: + payload = _read_pro_bundle_install_marker() + for key in ("core_version", "oss_version"): + version = str(payload.get(key) or "").strip() + if version: + return version + return "" + + +def _read_pro_bundle_installed_component_version() -> str: + payload = _read_pro_bundle_install_marker() + return str(payload.get("flockspro_component_version") or "").strip() + + +def _current_pro_version_state(local_core_version: str) -> ProVersionState: + core_version = _pick_newer_version(_read_pro_bundle_installed_core_version(), local_core_version) + return ProVersionState( + bundle_version=_version_label(_read_pro_bundle_installed_bundle_version()), + core_version=_version_label(core_version), + pro_component_version=_read_pro_bundle_installed_component_version() or None, + ) + + +def _latest_pro_version_state(manifest_info: ConsoleManifestRelease) -> ProVersionState: + return ProVersionState( + bundle_version=_version_label(manifest_info.version), + core_version=_version_label(_pro_bundle_core_version(manifest_info.manifest)), + pro_component_version=str(manifest_info.manifest.get("flockspro_component_version") or "").strip() or None, + ) + + +def _is_newer_version(latest: str | None, current: str | None) -> bool: + latest_version = str(latest or "").strip() + if not latest_version: + return False + current_version = str(current or "").strip() + if not current_version: + return True + return _parse_version(latest_version) > _parse_version(current_version) + + +def _pick_newer_version(left: str | None, right: str | None) -> str: + left_version = str(left or "").strip() + right_version = str(right or "").strip() + if not left_version: + return right_version + if not right_version: + return left_version + return left_version if _parse_version(left_version) >= _parse_version(right_version) else right_version + + def get_current_version() -> str: """ Return the running version. @@ -2488,7 +2625,7 @@ async def get_latest_release( repo = repo or ucfg.repo token = token or ucfg.token base_url = base_url or ucfg.base_url - sources = list(sources_override or ucfg.sources) + sources = list(ucfg.sources if sources_override is None else sources_override) if provider: sources = [provider] @@ -2544,23 +2681,31 @@ async def check_update( region=region, locale=locale, ) - current = _read_pro_bundle_installed_version() if profile.sources == ["console-manifest"] else get_current_version() + is_console_manifest = profile.sources == ["console-manifest"] + local_core_version = get_current_version() + current_pro_state = _current_pro_version_state(local_core_version) if is_console_manifest else None + current = current_pro_state.bundle_version if current_pro_state else local_core_version if not current: - current = get_current_version() + current = local_core_version if not ucfg.enabled: return VersionInfo( current_version=current, + current_core_version=current_pro_state.core_version if current_pro_state else None, + current_bundle_version=current_pro_state.bundle_version if current_pro_state else None, + current_pro_component_version=current_pro_state.pro_component_version if current_pro_state else None, deploy_mode=mode, update_allowed=(mode != "docker"), ) bundle_sha256: str | None = None bundle_format: str | None = None + latest_pro_state: ProVersionState | None = None try: - if profile.sources == ["console-manifest"]: + if is_console_manifest: manifest_info = await _fetch_console_manifest_release_info() - tag = manifest_info.version + latest_pro_state = _latest_pro_version_state(manifest_info) + tag = latest_pro_state.bundle_version notes = manifest_info.release_notes url = manifest_info.release_url bundle_sha256 = manifest_info.bundle_sha256 @@ -2577,18 +2722,35 @@ async def check_update( log.warning("updater.check_failed", {"error": str(exc)}) return VersionInfo( current_version=current, + current_core_version=current_pro_state.core_version if current_pro_state else None, + current_bundle_version=current_pro_state.bundle_version if current_pro_state else None, + current_pro_component_version=current_pro_state.pro_component_version if current_pro_state else None, error="Failed to check for updates. Please check your network connection.", deploy_mode=mode, update_allowed=(mode != "docker"), ) - has_update = _parse_version(tag) > _parse_version(current) + bundle_has_update = _is_newer_version(tag, current) + core_has_update = ( + _is_newer_version(latest_pro_state.core_version, current_pro_state.core_version) + if latest_pro_state and current_pro_state + else False + ) + pro_component_has_update = ( + _is_newer_version(latest_pro_state.pro_component_version, current_pro_state.pro_component_version) + if latest_pro_state and current_pro_state + else False + ) + has_update = bundle_has_update or core_has_update or pro_component_has_update log.info( "updater.check.result", { "current": current, "latest": tag, "has_update": has_update, + "bundle_has_update": bundle_has_update, + "core_has_update": core_has_update, + "pro_component_has_update": pro_component_has_update, "sources": profile.sources, "region": profile.region, }, @@ -2596,7 +2758,13 @@ async def check_update( return VersionInfo( current_version=current, latest_version=tag, - edition="flockspro" if profile.sources == ["console-manifest"] else "flocks", + current_core_version=current_pro_state.core_version if current_pro_state else None, + latest_core_version=latest_pro_state.core_version if latest_pro_state else None, + current_bundle_version=current_pro_state.bundle_version if current_pro_state else None, + latest_bundle_version=latest_pro_state.bundle_version if latest_pro_state else None, + current_pro_component_version=current_pro_state.pro_component_version if current_pro_state else None, + latest_pro_component_version=latest_pro_state.pro_component_version if latest_pro_state else None, + edition="flockspro" if is_console_manifest else "flocks", has_update=has_update, release_notes=notes, release_url=url, @@ -2703,6 +2871,7 @@ async def perform_pro_bundle_install( bundle_sha256=manifest_info.bundle_sha256, bundle_format=manifest_info.bundle_format, console_session_token=manifest_info.console_session_token, + console_manifest_payload=manifest_info.manifest, restart=restart, force_console_manifest=True, ): @@ -2717,6 +2886,7 @@ async def perform_update( bundle_sha256: str | None = None, bundle_format: str | None = None, console_session_token: str | None = None, + console_manifest_payload: dict[str, Any] | None = None, restart: bool = True, locale: str | None = None, region: str | None = None, @@ -2742,8 +2912,11 @@ async def perform_update( ) install_root = _get_repo_root() current_version = get_current_version() + effective_update_version = current_version + skip_core_replace = False handover_active = False console_manifest_info: ConsoleManifestRelease | None = None + console_manifest_payload = console_manifest_payload if isinstance(console_manifest_payload, dict) else None fmt = _choose_archive_format(ucfg.archive_format) if profile.sources == ["console-manifest"]: if not (zipball_url or tarball_url) or console_session_token is None: @@ -2771,6 +2944,7 @@ async def perform_update( bundle_sha256 = console_manifest_info.bundle_sha256 bundle_format = console_manifest_info.bundle_format console_session_token = console_manifest_info.console_session_token + console_manifest_payload = console_manifest_info.manifest primary_bundle_url = zipball_url or tarball_url or "" fmt = _archive_format_for_url(primary_bundle_url, bundle_format) @@ -2879,6 +3053,11 @@ async def _queue_download_progress(progress: UpdateProgress) -> None: extract_dir, ) content_root, pro_wheel_path, pro_bundle_manifest = _resolve_pro_bundle_content(content_root) + if profile.sources == ["console-manifest"]: + pro_bundle_manifest = _merge_console_manifest_release_identity( + pro_bundle_manifest, + console_manifest_payload, + ) if profile.sources == ["console-manifest"] and pro_wheel_path is None: raise ValueError("Pro bundle 中未找到 flockspro wheel") except Exception as exc: @@ -2889,12 +3068,29 @@ async def _queue_download_progress(progress: UpdateProgress) -> None: _record_update_journal(f"ERROR {msg}") yield UpdateProgress(stage="error", message=msg, success=False) return + if profile.sources == ["console-manifest"]: + skip_core_replace = _is_pro_bundle_oss_older_than_local(pro_bundle_manifest, current_version) + if skip_core_replace: + bundle_oss_version = _pro_bundle_oss_version(pro_bundle_manifest) + pro_bundle_manifest = _effective_pro_bundle_manifest(pro_bundle_manifest, current_version) + log.info( + "updater.pro_bundle.keep_local_core", + {"local_version": current_version, "bundle_oss_version": bundle_oss_version}, + ) + else: + effective_update_version = latest_tag + else: + effective_update_version = latest_tag # ------------------------------------------------------------------ # # Step 3 – determine whether frontend handover is needed # ------------------------------------------------------------------ # staged_webui_dir = content_root / "webui" - needs_handover = staged_webui_dir.is_dir() and (staged_webui_dir / "package.json").exists() + needs_handover = ( + not skip_core_replace + and staged_webui_dir.is_dir() + and (staged_webui_dir / "package.json").exists() + ) # ------------------------------------------------------------------ # # Step 4 – replace install tree @@ -2902,7 +3098,11 @@ async def _queue_download_progress(progress: UpdateProgress) -> None: # ------------------------------------------------------------------ # yield UpdateProgress( stage="applying", - message=f"Applying v{latest_tag}...", + message=( + f"Keeping local Flocks {_version_label(current_version)} and installing the Pro component..." + if skip_core_replace + else f"Applying v{latest_tag}..." + ), ) async def _restore_after_apply_failure() -> None: @@ -2929,11 +3129,12 @@ async def _restore_after_apply_failure() -> None: ) try: - await asyncio.to_thread( - _replace_install_dir, - content_root, - install_root, - ) + if not skip_core_replace: + await asyncio.to_thread( + _replace_install_dir, + content_root, + install_root, + ) except Exception as exc: final_replace_error: Exception | None = exc if ( @@ -2947,11 +3148,12 @@ async def _restore_after_apply_failure() -> None: try: _prepare_upgrade_handover(latest_tag) handover_active = True - await asyncio.to_thread( - _replace_install_dir, - content_root, - install_root, - ) + if not skip_core_replace: + await asyncio.to_thread( + _replace_install_dir, + content_root, + install_root, + ) except Exception as retry_exc: final_replace_error = retry_exc else: @@ -3007,7 +3209,7 @@ async def _restore_after_apply_failure() -> None: task_error = await run_handoff_upgrade_tasks( install_root=install_root, uv_path=uv_path, - version=latest_tag, + version=effective_update_version, uv_default_index=profile.uv_default_index, npm_registry=profile.npm_registry, pro_wheel_path=pro_wheel_path, @@ -3023,10 +3225,10 @@ async def _restore_after_apply_failure() -> None: return shutil.rmtree(tmp_dir, ignore_errors=True) - log.info("updater.apply.done", {"version": latest_tag, "restart": False, "region": profile.region}) + log.info("updater.apply.done", {"version": effective_update_version, "restart": False, "region": profile.region}) yield UpdateProgress( stage="done", - message=f"Upgraded to v{latest_tag}", + message=f"Upgraded to {_version_label(effective_update_version)}", success=True, ) return @@ -3037,6 +3239,7 @@ async def _restore_after_apply_failure() -> None: "updater.restart", { "tag": latest_tag, + "effective_version": effective_update_version, "sources": profile.sources, "repo": ucfg.repo, "region": profile.region, @@ -3085,7 +3288,7 @@ async def _restore_after_apply_failure() -> None: install_root, uv_path=uv_path, sync_timeout=sync_timeout, - version=latest_tag, + version=effective_update_version, current_version=current_version, backup_path=backup_path, uv_default_index=profile.uv_default_index, diff --git a/flocks/user_defined_pages/__init__.py b/flocks/user_defined_pages/__init__.py deleted file mode 100644 index 6303f1d05..000000000 --- a/flocks/user_defined_pages/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""User-space user-defined custom pages under ~/.flocks/plugins/user_defined_pages.""" - -from flocks.user_defined_pages.store import UserDefinedPagesStore -from flocks.user_defined_pages.watcher import UserDefinedPagesWatcher - -__all__ = ["UserDefinedPagesStore", "UserDefinedPagesWatcher"] diff --git a/flocks/user_defined_pages/models.py b/flocks/user_defined_pages/models.py deleted file mode 100644 index 02189fca4..000000000 --- a/flocks/user_defined_pages/models.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Pydantic models for user-defined custom pages.""" - -from __future__ import annotations - -from typing import Literal, Optional - -from pydantic import BaseModel, ConfigDict, Field - - -class UserDefinedPageManifest(BaseModel): - model_config = ConfigDict(populate_by_name=True) - - id: str = Field(..., description="Stable page identifier") - title: str = Field(..., description="Navigation label") - route: str = Field(..., description="WebUI route path") - icon: str = Field("LayoutDashboard", description="Lucide icon name") - order: int = Field(100, description="Sort order in navigation") - enabled: bool = Field(True, description="Whether page appears in navigation") - placement: Literal["home.after"] = Field( - "home.after", - description="Where to insert the nav item", - ) - entry: str = Field("src/index.tsx", description="Source entry relative to page dir") - updatedAt: int = Field(0, description="Last manifest update timestamp (ms)") - - -class UserDefinedPageBuildMeta(BaseModel): - model_config = ConfigDict(populate_by_name=True) - - hash: str = Field("", description="Content hash for cache busting") - builtAt: int = Field(0, description="Build timestamp (ms)") - status: Literal["idle", "building", "ready", "failed"] = Field("idle") - error: Optional[str] = Field(None, description="Last build error message") - runtime: str = Field("user_defined_page", description="Builder runtime marker") - runtimeVersion: int = Field(1, description="Builder runtime version") - sdkImport: str = Field("@flocks/user-defined-page-sdk", description="SDK import marker") - - -class UserDefinedPageApiMeta(BaseModel): - model_config = ConfigDict(populate_by_name=True) - - status: Literal["idle", "ready", "failed"] = Field("idle") - loadedAt: int = Field(0, description="Runtime load timestamp (ms)") - error: Optional[str] = Field(None, description="Last API runtime error") - routes: list[dict[str, str]] = Field(default_factory=list, description="Loaded route descriptors") - - -class UserDefinedPageListItem(BaseModel): - model_config = ConfigDict(populate_by_name=True, by_alias=True) - - id: str - title: str - route: str - icon: str - order: int - enabled: bool - placement: str - buildHash: str = Field("", alias="buildHash") - buildStatus: str = Field("idle", alias="buildStatus") - - -class UserDefinedPageDetail(BaseModel): - model_config = ConfigDict(populate_by_name=True, by_alias=True) - - manifest: UserDefinedPageManifest - build: UserDefinedPageBuildMeta - sourceFiles: list[str] = Field(default_factory=list, alias="sourceFiles") diff --git a/flocks/user_defined_pages/shims/jsx-runtime.js b/flocks/user_defined_pages/shims/jsx-runtime.js deleted file mode 100644 index dc2740653..000000000 --- a/flocks/user_defined_pages/shims/jsx-runtime.js +++ /dev/null @@ -1,7 +0,0 @@ -const runtime = globalThis.__FLOCKS_USER_DEFINED_PAGE_SDK__; -if (!runtime?.jsx || !runtime?.jsxs) { - throw new Error('Flocks user-defined page runtime is not initialized (missing jsx runtime).'); -} -export const jsx = runtime.jsx; -export const jsxs = runtime.jsxs; -export const Fragment = runtime.React.Fragment; diff --git a/flocks/user_defined_pages/shims/sdk.js b/flocks/user_defined_pages/shims/sdk.js deleted file mode 100644 index 81efdd874..000000000 --- a/flocks/user_defined_pages/shims/sdk.js +++ /dev/null @@ -1,7 +0,0 @@ -const sdk = globalThis.__FLOCKS_USER_DEFINED_PAGE_SDK__; -if (!sdk) { - throw new Error('Flocks user-defined page runtime is not initialized (missing SDK).'); -} -export const api = sdk.api; -export const Card = sdk.Card; -export const useCurrentUser = sdk.useCurrentUser; diff --git a/flocks/user_defined_pages/store.py b/flocks/user_defined_pages/store.py deleted file mode 100644 index 0160f0395..000000000 --- a/flocks/user_defined_pages/store.py +++ /dev/null @@ -1,331 +0,0 @@ -"""Filesystem store for user-defined pages.""" - -from __future__ import annotations - -import json -import os -import re -import time -from pathlib import Path -from typing import Any, Optional - -from flocks.user_defined_pages.models import ( - UserDefinedPageApiMeta, - UserDefinedPageBuildMeta, - UserDefinedPageDetail, - UserDefinedPageListItem, - UserDefinedPageManifest, -) -from flocks.utils.log import Log - -log = Log.create(service="user-defined-pages-store") - -PAGE_ID_RE = re.compile(r"^[a-z0-9][a-z0-9-]*$") -MAX_SOURCE_FILE_BYTES = 512_000 -ALLOWED_WRITE_PREFIXES = ("src/", "assets/", "api/") -ALLOWED_WRITE_FILES = frozenset({"manifest.json"}) -_SOURCE_SUFFIXES = {".tsx", ".ts", ".jsx", ".js", ".css", ".json"} -_API_SUFFIXES = {".py", ".yaml", ".yml"} - -def _default_page_tsx(title: str) -> str: - safe_title = title.replace("\\", "\\\\").replace('"', '\\"') - return f"""import {{ useEffect, useState }} from 'react'; -import {{ Card }} from '@flocks/user-defined-page-sdk'; - -export default function Page() {{ - const [ready, setReady] = useState(false); - - useEffect(() => {{ - setReady(true); - }}, []); - - return ( - - {{ready ? 'Ready' : 'Loading...'}} - - ); -}} -""" - -_DEFAULT_INDEX_TSX = """import Page from './Page'; - -export default Page; -""" - - -def get_user_defined_pages_root() -> Path: - """Return canonical user-space root for user-defined pages.""" - override = os.environ.get("FLOCKS_USER_DEFINED_PAGES_ROOT") - if override: - return Path(override).expanduser().resolve() - return (Path.home() / ".flocks" / "plugins" / "user_defined_pages").resolve() - - -class UserDefinedPagesStore: - """CRUD and scan helpers for ~/.flocks/plugins/user_defined_pages.""" - - def __init__(self, root: Optional[Path] = None) -> None: - self._root = (root or get_user_defined_pages_root()).resolve() - - @property - def root(self) -> Path: - return self._root - - def ensure_root(self) -> Path: - self._root.mkdir(parents=True, exist_ok=True) - return self._root - - @staticmethod - def validate_page_id(page_id: str) -> str: - normalized = (page_id or "").strip().lower() - if not PAGE_ID_RE.fullmatch(normalized): - raise ValueError("invalid page id: use lowercase letters, numbers, and hyphens") - return normalized - - def page_dir(self, page_id: str) -> Path: - page_id = self.validate_page_id(page_id) - page_path = (self._root / page_id).resolve() - try: - page_path.relative_to(self._root) - except ValueError: - raise ValueError("invalid page path") - return page_path - - def _assert_writable_relative(self, relative_path: str) -> Path: - if not relative_path or Path(relative_path).is_absolute(): - raise ValueError("absolute path is not allowed") - rel = relative_path.replace("\\", "/").lstrip("/") - if rel in ALLOWED_WRITE_FILES: - return Path(rel) - if any(rel.startswith(prefix) for prefix in ALLOWED_WRITE_PREFIXES): - parts = rel.split("/") - if ".." in parts: - raise ValueError("path traversal is not allowed") - if any(part.startswith(".") for part in parts if part): - raise ValueError("hidden path is not allowed") - return Path(rel) - raise ValueError(f"writes are not allowed for path: {relative_path}") - - def list_pages(self, *, enabled_only: bool = False) -> list[UserDefinedPageListItem]: - self.ensure_root() - items: list[UserDefinedPageListItem] = [] - for child in sorted(self._root.iterdir()): - if not child.is_dir(): - continue - manifest = self._read_manifest(child.name) - if manifest is None: - continue - if enabled_only and not manifest.enabled: - continue - build = self._read_build_meta(child.name) - items.append( - UserDefinedPageListItem( - id=manifest.id, - title=manifest.title, - route=manifest.route, - icon=manifest.icon, - order=manifest.order, - enabled=manifest.enabled, - placement=manifest.placement, - buildHash=build.hash, - buildStatus=build.status, - ) - ) - items.sort(key=lambda item: (item.order, item.title)) - return items - - def get_page(self, page_id: str) -> UserDefinedPageDetail: - page_dir = self.page_dir(page_id) - if not page_dir.is_dir(): - raise FileNotFoundError(f"page not found: {page_id}") - manifest = self._read_manifest(page_id) - if manifest is None: - raise FileNotFoundError(f"manifest missing for page: {page_id}") - build = self._read_build_meta(page_id) - source_files = sorted( - str(path.relative_to(page_dir)).replace("\\", "/") - for path in page_dir.rglob("*") - if path.is_file() and "dist/" not in str(path.relative_to(page_dir)).replace("\\", "/") - ) - return UserDefinedPageDetail(manifest=manifest, build=build, sourceFiles=source_files) - - def create_page( - self, - *, - page_id: str, - title: str, - icon: str = "LayoutDashboard", - order: int = 100, - ) -> UserDefinedPageDetail: - page_id = self.validate_page_id(page_id) - page_dir = self.page_dir(page_id) - if page_dir.exists(): - raise FileExistsError(f"page already exists: {page_id}") - - now_ms = int(time.time() * 1000) - manifest = UserDefinedPageManifest( - id=page_id, - title=title.strip() or page_id, - route=f"/user-defined-pages/{page_id}", - icon=icon, - order=order, - enabled=True, - placement="home.after", - entry="src/index.tsx", - updatedAt=now_ms, - ) - - page_dir.mkdir(parents=True, exist_ok=False) - (page_dir / "src").mkdir(parents=True, exist_ok=True) - (page_dir / "api").mkdir(parents=True, exist_ok=True) - (page_dir / "assets").mkdir(parents=True, exist_ok=True) - (page_dir / "dist").mkdir(parents=True, exist_ok=True) - - self._write_manifest(page_id, manifest) - self._write_source_file(page_id, "src/Page.tsx", _default_page_tsx(manifest.title)) - self._write_source_file(page_id, "src/index.tsx", _DEFAULT_INDEX_TSX) - self._write_build_meta( - page_id, - UserDefinedPageBuildMeta(status="idle", hash="", builtAt=0, error=None), - ) - log.info("user_defined_pages.created", {"pageId": page_id}) - return self.get_page(page_id) - - def save_manifest(self, page_id: str, manifest_data: dict[str, Any]) -> UserDefinedPageManifest: - page_id = self.validate_page_id(page_id) - existing = self._read_manifest(page_id) - if existing is None: - raise FileNotFoundError(f"page not found: {page_id}") - - merged = existing.model_dump() - merged.update(manifest_data) - merged["id"] = page_id - merged["route"] = f"/user-defined-pages/{page_id}" - merged["updatedAt"] = int(time.time() * 1000) - manifest = UserDefinedPageManifest.model_validate(merged) - self._write_manifest(page_id, manifest) - return manifest - - def save_source_file(self, page_id: str, relative_path: str, content: str) -> None: - rel = self._assert_writable_relative(relative_path) - rel_str = str(rel).replace("\\", "/") - if rel_str.startswith("api/"): - allowed_suffixes = _API_SUFFIXES - else: - allowed_suffixes = _SOURCE_SUFFIXES - if rel.suffix not in allowed_suffixes: - raise ValueError("unsupported source file type") - encoded = content.encode("utf-8") - if len(encoded) > MAX_SOURCE_FILE_BYTES: - raise ValueError("source file is too large") - self._write_source_file(page_id, rel_str, content) - - def read_source_file(self, page_id: str, relative_path: str) -> str: - rel = self._assert_writable_relative(relative_path) - path = self.page_dir(page_id) / rel - if not path.is_file(): - raise FileNotFoundError(relative_path) - return path.read_text(encoding="utf-8") - - def bundle_path(self, page_id: str) -> Path: - return self.page_dir(page_id) / "dist" / "page.js" - - def asset_path(self, page_id: str, relative_path: str) -> Path: - rel = relative_path.replace("\\", "/").lstrip("/") - if ".." in rel.split("/"): - raise ValueError("path traversal is not allowed") - path = (self.page_dir(page_id) / "assets" / rel).resolve() - assets_root = (self.page_dir(page_id) / "assets").resolve() - try: - path.relative_to(assets_root) - except ValueError: - raise ValueError("invalid asset path") - return path - - def write_build_meta(self, page_id: str, meta: UserDefinedPageBuildMeta) -> None: - self._write_build_meta(page_id, meta) - - def read_build_meta(self, page_id: str) -> UserDefinedPageBuildMeta: - return self._read_build_meta(page_id) - - def routes_path(self, page_id: str) -> Path: - return self.page_dir(page_id) / "api" / "routes.yaml" - - def api_handlers_path(self, page_id: str) -> Path: - return self.page_dir(page_id) / "api" / "handlers.py" - - def read_api_routes(self, page_id: str) -> Optional[str]: - path = self.routes_path(page_id) - if not path.is_file(): - return None - return path.read_text(encoding="utf-8") - - def write_api_meta(self, page_id: str, meta: UserDefinedPageApiMeta) -> None: - path = self._api_meta_path(page_id) - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text( - json.dumps(meta.model_dump(), ensure_ascii=False, indent=2), - encoding="utf-8", - ) - - def read_api_meta(self, page_id: str) -> UserDefinedPageApiMeta: - path = self._api_meta_path(page_id) - if not path.is_file(): - return UserDefinedPageApiMeta() - try: - raw = json.loads(path.read_text(encoding="utf-8")) - return UserDefinedPageApiMeta.model_validate(raw) - except Exception: - return UserDefinedPageApiMeta() - - def _manifest_path(self, page_id: str) -> Path: - return self.page_dir(page_id) / "manifest.json" - - def _build_meta_path(self, page_id: str) -> Path: - return self.page_dir(page_id) / "dist" / "meta.json" - - def _api_meta_path(self, page_id: str) -> Path: - return self.page_dir(page_id) / "dist" / "api-meta.json" - - def _read_manifest(self, page_id: str) -> Optional[UserDefinedPageManifest]: - path = self._manifest_path(page_id) - if not path.is_file(): - return None - try: - raw = json.loads(path.read_text(encoding="utf-8")) - return UserDefinedPageManifest.model_validate(raw) - except Exception as exc: - log.warning("user_defined_pages.manifest.invalid", {"pageId": page_id, "error": str(exc)}) - return None - - def _write_manifest(self, page_id: str, manifest: UserDefinedPageManifest) -> None: - path = self._manifest_path(page_id) - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text( - json.dumps(manifest.model_dump(), ensure_ascii=False, indent=2), - encoding="utf-8", - ) - - def _read_build_meta(self, page_id: str) -> UserDefinedPageBuildMeta: - path = self._build_meta_path(page_id) - if not path.is_file(): - return UserDefinedPageBuildMeta() - try: - raw = json.loads(path.read_text(encoding="utf-8")) - return UserDefinedPageBuildMeta.model_validate(raw) - except Exception: - return UserDefinedPageBuildMeta() - - def _write_build_meta(self, page_id: str, meta: UserDefinedPageBuildMeta) -> None: - path = self._build_meta_path(page_id) - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text( - json.dumps(meta.model_dump(), ensure_ascii=False, indent=2), - encoding="utf-8", - ) - - def _write_source_file(self, page_id: str, relative_path: str, content: str) -> None: - rel = self._assert_writable_relative(relative_path) - target = self.page_dir(page_id) / rel - target.parent.mkdir(parents=True, exist_ok=True) - target.write_text(content, encoding="utf-8") diff --git a/flocks/workflow/center.py b/flocks/workflow/center.py index b286e2677..ad05aff3b 100644 --- a/flocks/workflow/center.py +++ b/flocks/workflow/center.py @@ -26,7 +26,7 @@ from flocks.config.config import Config from flocks.plugin.loader import DEFAULT_PLUGIN_ROOT from flocks.sandbox.docker import docker_container_state, exec_docker -from flocks.storage.storage import Storage +from flocks.workflow.store import WorkflowStore from flocks.utils.log import Log from flocks.workflow.models import Workflow from flocks.workflow.requirements import resolve_python_package_index_url @@ -135,8 +135,8 @@ def resolve_global_workflow_roots() -> list[Path]: """ home = Path.home() / ".flocks" return [ - home / "plugins" / "workflow", # legacy compat (read-only) - home / "workflow", # legacy compat (read-only) + home / "plugins" / "workflow", # legacy compat (read-only) + home / "workflow", # legacy compat (read-only) home / "plugins" / "workflows", # new canonical (read + write) ] @@ -150,8 +150,8 @@ def resolve_project_workflow_roots(base_dir: Optional[Path] = None) -> list[Path root = base_dir or Path.cwd() flocks = root / ".flocks" return [ - flocks / "plugins" / "workflow", # legacy compat (read-only) - flocks / "workflow", # legacy compat (read-only) + flocks / "plugins" / "workflow", # legacy compat (read-only) + flocks / "workflow", # legacy compat (read-only) flocks / "plugins" / "workflows", # new canonical (read + write) ] @@ -206,14 +206,14 @@ async def _reserved_service_ports() -> set[int]: ports: set[int] = set() for prefix in (_RUNTIME_PREFIX, _API_SERVICE_PREFIX, _REGISTRY_PREFIX): try: - keys = await Storage.list_keys(prefix) + keys = await WorkflowStore.kv_list_keys(prefix) except Exception as exc: log.warning("workflow.port.list_reserved_failed", {"prefix": prefix, "error": str(exc)}) continue for key in keys: try: - record = await Storage.read(key) + record = await WorkflowStore.kv_get(key) except Exception as exc: log.warning("workflow.port.read_reserved_failed", {"key": _key_to_string(key), "error": str(exc)}) continue @@ -223,11 +223,7 @@ async def _reserved_service_ports() -> set[int]: def _reserved_in_flight_ports() -> set[int]: now = time.time() - expired = [ - port - for port, expires_at in _IN_FLIGHT_PORT_RESERVATIONS.items() - if expires_at <= now - ] + expired = [port for port, expires_at in _IN_FLIGHT_PORT_RESERVATIONS.items() if expires_at <= now] for port in expired: _IN_FLIGHT_PORT_RESERVATIONS.pop(port, None) return set(_IN_FLIGHT_PORT_RESERVATIONS) @@ -254,7 +250,7 @@ async def _allocate_port() -> int: async def _read_registry(workflow_id: str) -> Dict[str, Any]: - data = await Storage.read(_registry_key(workflow_id)) + data = await WorkflowStore.kv_get(_registry_key(workflow_id)) if not data: raise WorkflowNotFoundError(f"Workflow not registered: {workflow_id}") return data @@ -276,11 +272,7 @@ async def _scan_workflow_dir( try: raw = json.loads(workflow_path.read_text(encoding="utf-8")) meta_path = workflow_path.parent / "meta.json" - meta = ( - json.loads(meta_path.read_text(encoding="utf-8")) - if meta_path.is_file() - else None - ) + meta = json.loads(meta_path.read_text(encoding="utf-8")) if meta_path.is_file() else None if is_hidden_workflow(raw, meta): continue Workflow.from_dict(raw) @@ -294,7 +286,7 @@ async def _scan_workflow_dir( workflow_id = _normalize_workflow_id(workflow_path) fp = _fingerprint(workflow_path) now_ms = _now_ms() - existing = await Storage.read(_registry_key(workflow_id)) or {} + existing = await WorkflowStore.kv_get(_registry_key(workflow_id)) or {} created_at = existing.get("registeredAt", now_ms) draft_changed = bool(existing) and existing.get("fingerprint") != fp entry = { @@ -315,7 +307,7 @@ async def _scan_workflow_dir( "serviceKey": existing.get("serviceKey"), "serviceUrl": existing.get("serviceUrl"), } - await Storage.write(_registry_key(workflow_id), entry) + await WorkflowStore.kv_put(_registry_key(workflow_id), entry) by_id[workflow_id] = entry @@ -383,11 +375,11 @@ def format_workflow_entries( async def list_registry_entries() -> List[Dict[str, Any]]: """List registered skill workflows.""" - keys = await Storage.list(_REGISTRY_PREFIX) + keys = await WorkflowStore.kv_list(_REGISTRY_PREFIX) items: List[Dict[str, Any]] = [] for raw_key in keys: key = _key_to_string(raw_key) - entry = await Storage.read(key) + entry = await WorkflowStore.kv_get(key) if entry: items.append(entry) items.sort(key=lambda item: item.get("updatedAt", 0), reverse=True) @@ -424,10 +416,14 @@ async def _write_requirements_snapshot(release_dir: Path) -> bool: Returns True on success, False if the snapshot could not be created. """ import sys + req_file = release_dir / "requirements.txt" try: proc = await asyncio.create_subprocess_exec( - sys.executable, "-m", "pip", "freeze", + sys.executable, + "-m", + "pip", + "freeze", "--exclude-editable", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, @@ -508,16 +504,18 @@ def _docker_proxy_env_value(env_value: str) -> str: netloc = f"{userinfo}host.docker.internal" if parsed.port: netloc += f":{parsed.port}" - rewritten = urlunparse(( - parsed.scheme, - netloc, - parsed.path, - parsed.params, - parsed.query, - parsed.fragment, - )) + rewritten = urlunparse( + ( + parsed.scheme, + netloc, + parsed.path, + parsed.params, + parsed.query, + parsed.fragment, + ) + ) if not has_scheme and rewritten.startswith("http://"): - return rewritten[len("http://"):] + return rewritten[len("http://") :] return rewritten except Exception: return env_value @@ -563,8 +561,7 @@ async def _wait_docker_service_healthy( logs = await _docker_logs_tail(container_name) detail = logs or "container exited before reporting healthy" raise WorkflowCenterError( - "Published workflow service container exited before health check " - f"passed: {detail}" + f"Published workflow service container exited before health check passed: {detail}" ) await asyncio.sleep(interval_s) return False @@ -633,12 +630,12 @@ def _signal_local_process(pid: int, sig: signal.Signals, process_group_id: Optio async def _stop_local_service(workflow_id: str) -> None: """Kill a previously started local workflow service process.""" - pid_record = await Storage.read(_local_pid_key(workflow_id)) + pid_record = await WorkflowStore.kv_get(_local_pid_key(workflow_id)) if not pid_record: return pid = pid_record.get("pid") if not pid: - await Storage.remove(_local_pid_key(workflow_id)) + await WorkflowStore.kv_remove(_local_pid_key(workflow_id)) return pid_int = int(pid) process_group_id = pid_record.get("processGroupId") @@ -651,7 +648,7 @@ async def _stop_local_service(workflow_id: str) -> None: _signal_local_process(pid_int, signal.SIGKILL, process_group_id) await _wait_for_pid_exit(pid_int, 1.0) finally: - await Storage.remove(_local_pid_key(workflow_id)) + await WorkflowStore.kv_remove(_local_pid_key(workflow_id)) async def _stop_local_runtime(workflow_id: str, runtime: Dict[str, Any]) -> bool: @@ -671,7 +668,7 @@ async def _stop_local_runtime(workflow_id: str, runtime: Dict[str, Any]) -> bool log.warning("workflow.local.force_kill", {"workflow_id": workflow_id, "pid": pid}) _signal_local_process(pid, signal.SIGKILL, process_group_id) exited = await _wait_for_pid_exit(pid, 1.0) - await Storage.remove(_local_pid_key(workflow_id)) + await WorkflowStore.kv_remove(_local_pid_key(workflow_id)) return exited @@ -692,11 +689,11 @@ def _runtime_driver(runtime: Optional[Dict[str, Any]]) -> str: async def _mark_release_inactive(workflow_id: str, release_id: Optional[Any]) -> None: if not release_id: return - release_record = await Storage.read(_release_key(workflow_id, str(release_id))) or {} + release_record = await WorkflowStore.kv_get(_release_key(workflow_id, str(release_id))) or {} if release_record: release_record["status"] = "inactive" release_record["deactivatedAt"] = _now_ms() - await Storage.write(_release_key(workflow_id, str(release_id)), release_record) + await WorkflowStore.kv_put(_release_key(workflow_id, str(release_id)), release_record) async def _stop_runtime_record( @@ -719,18 +716,18 @@ async def _stop_runtime_record( if not stopped: raise WorkflowCenterError(f"Failed to stop Docker container: {container_name}") - active = (await Storage.read(_active_release_key(workflow_id)) or {}) if clear_runtime_keys else {} + active = (await WorkflowStore.kv_get(_active_release_key(workflow_id)) or {}) if clear_runtime_keys else {} release_id = runtime.get("releaseId") or active.get("releaseId") await _mark_release_inactive(workflow_id, release_id) if clear_runtime_keys: - await Storage.remove(_runtime_key(workflow_id)) - await Storage.remove(_active_release_key(workflow_id)) + await WorkflowStore.kv_remove(_runtime_key(workflow_id)) + await WorkflowStore.kv_remove(_active_release_key(workflow_id)) if update_registry: registry["publishStatus"] = "stopped" registry["updatedAt"] = _now_ms() registry["serviceUrl"] = None - await Storage.write(_registry_key(workflow_id), registry) + await WorkflowStore.kv_put(_registry_key(workflow_id), registry) return { "workflowId": workflow_id, @@ -742,7 +739,7 @@ async def _stop_runtime_record( async def _stop_existing_runtime_for_publish(workflow_id: str) -> None: """Best-effort cleanup before starting a replacement service.""" - runtime = await Storage.read(_runtime_key(workflow_id)) + runtime = await WorkflowStore.kv_get(_runtime_key(workflow_id)) if isinstance(runtime, dict) and runtime: await _stop_runtime_record(workflow_id, runtime, update_registry=False) else: @@ -767,7 +764,7 @@ async def publish_workflow_local(workflow_id: str, *, api_key: Optional[str] = N now_ms = _now_ms() registry["publishStatus"] = "publishing" registry["updatedAt"] = now_ms - await Storage.write(_registry_key(workflow_id), registry) + await WorkflowStore.kv_put(_registry_key(workflow_id), registry) release_snapshot_file = await _write_release_snapshot(workflow_id, release_id, workflow_json) @@ -784,26 +781,37 @@ async def publish_workflow_local(workflow_id: str, *, api_key: Optional[str] = N env[_SERVICE_API_KEY_ENV] = runtime_api_key proc = await asyncio.create_subprocess_exec( sys.executable, - "-m", "flocks.workflow.service_runtime", - "--workflow", str(release_snapshot_file), - "--workflow-id", workflow_id, - "--release-id", release_id, - "--host", "127.0.0.1", - "--port", str(host_port), + "-m", + "flocks.workflow.service_runtime", + "--workflow", + str(release_snapshot_file), + "--workflow-id", + workflow_id, + "--release-id", + release_id, + "--host", + "127.0.0.1", + "--port", + str(host_port), stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.DEVNULL, env=env, start_new_session=True, ) - await Storage.write(_local_pid_key(workflow_id), { - "pid": proc.pid, - "processGroupId": proc.pid, - "port": host_port, - }) + await WorkflowStore.kv_put( + _local_pid_key(workflow_id), + { + "pid": proc.pid, + "processGroupId": proc.pid, + "port": host_port, + }, + ) health_retries = int(os.getenv("FLOCKS_WORKFLOW_SERVICE_HEALTH_RETRIES", str(_DEFAULT_HEALTH_RETRIES))) - health_interval_s = float(os.getenv("FLOCKS_WORKFLOW_SERVICE_HEALTH_INTERVAL_S", str(_DEFAULT_HEALTH_INTERVAL_S))) + health_interval_s = float( + os.getenv("FLOCKS_WORKFLOW_SERVICE_HEALTH_INTERVAL_S", str(_DEFAULT_HEALTH_INTERVAL_S)) + ) healthy = await _wait_service_healthy(service_url, retries=health_retries, interval_s=health_interval_s) if not healthy: @@ -828,8 +836,8 @@ async def publish_workflow_local(workflow_id: str, *, api_key: Optional[str] = N "driver": "local", "apiKey": runtime_api_key, } - await Storage.write(_active_release_key(workflow_id), active_record) - await Storage.write(_runtime_key(workflow_id), active_record) + await WorkflowStore.kv_put(_active_release_key(workflow_id), active_record) + await WorkflowStore.kv_put(_runtime_key(workflow_id), active_record) _release_port_reservation(host_port) except Exception as exc: _release_port_reservation(host_port) @@ -841,7 +849,7 @@ async def publish_workflow_local(workflow_id: str, *, api_key: Optional[str] = N pass registry["publishStatus"] = "failed" registry["updatedAt"] = _now_ms() - await Storage.write(_registry_key(workflow_id), registry) + await WorkflowStore.kv_put(_registry_key(workflow_id), registry) if isinstance(exc, WorkflowCenterError): raise raise WorkflowCenterError(str(exc)) from exc @@ -851,7 +859,7 @@ async def publish_workflow_local(workflow_id: str, *, api_key: Optional[str] = N registry["serviceKey"] = service_key registry["serviceUrl"] = service_url registry["updatedAt"] = _now_ms() - await Storage.write(_registry_key(workflow_id), registry) + await WorkflowStore.kv_put(_registry_key(workflow_id), registry) log.info("workflow.local.published", {"id": workflow_id, "port": host_port, "pid": proc.pid}) return active_record @@ -861,12 +869,12 @@ async def stop_local_service(workflow_id: str) -> Dict[str, Any]: """Stop a local workflow service process.""" await _stop_local_service(workflow_id) registry = await _read_registry(workflow_id) - await Storage.remove(_runtime_key(workflow_id)) - await Storage.remove(_active_release_key(workflow_id)) + await WorkflowStore.kv_remove(_runtime_key(workflow_id)) + await WorkflowStore.kv_remove(_active_release_key(workflow_id)) registry["publishStatus"] = "stopped" registry["updatedAt"] = _now_ms() registry["serviceUrl"] = None - await Storage.write(_registry_key(workflow_id), registry) + await WorkflowStore.kv_put(_registry_key(workflow_id), registry) return {"workflowId": workflow_id, "status": "stopped", "stopped": True} @@ -874,6 +882,7 @@ async def stop_local_service(workflow_id: str) -> Dict[str, Any]: # Unified publish / stop entry points (driver-aware) # ───────────────────────────────────────────────────────────────────────────── + def _service_driver() -> str: return os.getenv("FLOCKS_WORKFLOW_SERVICE_DRIVER", _DEFAULT_SERVICE_DRIVER).lower() @@ -895,10 +904,10 @@ async def publish_workflow( async def stop_workflow_service(workflow_id: str) -> Dict[str, Any]: """Stop a published workflow service (driver-aware).""" - runtime = await Storage.read(_runtime_key(workflow_id)) + runtime = await WorkflowStore.kv_get(_runtime_key(workflow_id)) if isinstance(runtime, dict) and runtime: return await _stop_runtime_record(workflow_id, runtime, update_registry=True) - active = await Storage.read(_active_release_key(workflow_id)) + active = await WorkflowStore.kv_get(_active_release_key(workflow_id)) if isinstance(active, dict) and active: return await _stop_runtime_record(workflow_id, active, update_registry=True) @@ -925,7 +934,7 @@ async def _publish_workflow_docker( now_ms = _now_ms() registry["publishStatus"] = "publishing" registry["updatedAt"] = now_ms - await Storage.write(_registry_key(workflow_id), registry) + await WorkflowStore.kv_put(_registry_key(workflow_id), registry) release_snapshot_file = await _write_release_snapshot(workflow_id, release_id, workflow_json) release_runtime_dir = release_snapshot_file.parent @@ -941,10 +950,10 @@ async def _publish_workflow_docker( "activatedAt": None, "deactivatedAt": None, } - await Storage.write(_release_key(workflow_id, release_id), release_record) + await WorkflowStore.kv_put(_release_key(workflow_id, release_id), release_record) - previous_runtime = await Storage.read(_runtime_key(workflow_id)) or {} - previous_active = await Storage.read(_active_release_key(workflow_id)) or {} + previous_runtime = await WorkflowStore.kv_get(_runtime_key(workflow_id)) or {} + previous_active = await WorkflowStore.kv_get(_active_release_key(workflow_id)) or {} previous_container_name = previous_active.get("containerName") previous_release_id = previous_active.get("releaseId") @@ -956,15 +965,9 @@ async def _publish_workflow_docker( "false", "no", } - health_interval_s = float( - os.getenv("FLOCKS_WORKFLOW_SERVICE_HEALTH_INTERVAL_S", str(_DEFAULT_HEALTH_INTERVAL_S)) - ) - default_retries = ( - _DEFAULT_RUNTIME_INSTALL_HEALTH_RETRIES if runtime_install else _DEFAULT_HEALTH_RETRIES - ) - health_retries = int( - os.getenv("FLOCKS_WORKFLOW_SERVICE_HEALTH_RETRIES", str(default_retries)) - ) + health_interval_s = float(os.getenv("FLOCKS_WORKFLOW_SERVICE_HEALTH_INTERVAL_S", str(_DEFAULT_HEALTH_INTERVAL_S))) + default_retries = _DEFAULT_RUNTIME_INSTALL_HEALTH_RETRIES if runtime_install else _DEFAULT_HEALTH_RETRIES + health_retries = int(os.getenv("FLOCKS_WORKFLOW_SERVICE_HEALTH_RETRIES", str(default_retries))) project_root = Path.cwd().resolve() user_config_dir = Config.get_config_path().resolve() pip_cache_dir = _service_cache_dir("pip") @@ -1002,7 +1005,7 @@ async def _publish_workflow_docker( image_name, ] if user_config_dir.exists(): - cmd[cmd.index(image_name):cmd.index(image_name)] = [ + cmd[cmd.index(image_name) : cmd.index(image_name)] = [ "-v", f"{user_config_dir}:/runtime/.flocks-config:ro", ] @@ -1028,14 +1031,14 @@ async def _publish_workflow_docker( proxy_injections.extend(["-e", f"{env_name}={docker_env_value}"]) if proxy_injections: if needs_host_gateway: - cmd[cmd.index(image_name):cmd.index(image_name)] = [ + cmd[cmd.index(image_name) : cmd.index(image_name)] = [ "--add-host", "host.docker.internal:host-gateway", ] - cmd[cmd.index(image_name):cmd.index(image_name)] = proxy_injections + cmd[cmd.index(image_name) : cmd.index(image_name)] = proxy_injections python_index_url = resolve_python_package_index_url() if python_index_url: - cmd[cmd.index(image_name):cmd.index(image_name)] = [ + cmd[cmd.index(image_name) : cmd.index(image_name)] = [ "-e", f"PIP_INDEX_URL={python_index_url}", "-e", @@ -1097,7 +1100,7 @@ async def _publish_workflow_docker( release_record["status"] = "active" release_record["activatedAt"] = _now_ms() - await Storage.write(_release_key(workflow_id, release_id), release_record) + await WorkflowStore.kv_put(_release_key(workflow_id, release_id), release_record) active_record = { "releaseId": release_id, @@ -1113,8 +1116,8 @@ async def _publish_workflow_docker( "driver": "docker", "apiKey": runtime_api_key, } - await Storage.write(_active_release_key(workflow_id), active_record) - await Storage.write(_runtime_key(workflow_id), active_record) + await WorkflowStore.kv_put(_active_release_key(workflow_id), active_record) + await WorkflowStore.kv_put(_runtime_key(workflow_id), active_record) _release_port_reservation(host_port) registry["publishStatus"] = "active" @@ -1122,7 +1125,7 @@ async def _publish_workflow_docker( registry["serviceKey"] = service_key registry["serviceUrl"] = service_url registry["updatedAt"] = _now_ms() - await Storage.write(_registry_key(workflow_id), registry) + await WorkflowStore.kv_put(_registry_key(workflow_id), registry) if isinstance(previous_runtime, dict) and previous_runtime: await _stop_runtime_record( @@ -1142,22 +1145,24 @@ async def _publish_workflow_docker( await _stop_and_remove_container(container_name) release_record["status"] = "failed" release_record["deactivatedAt"] = _now_ms() - await Storage.write(_release_key(workflow_id, release_id), release_record) + await WorkflowStore.kv_put(_release_key(workflow_id, release_id), release_record) registry["publishStatus"] = "failed" registry["updatedAt"] = _now_ms() - await Storage.write(_registry_key(workflow_id), registry) + await WorkflowStore.kv_put(_registry_key(workflow_id), registry) raise WorkflowCenterError(str(exc)) from exc async def _stop_workflow_service_docker(workflow_id: str) -> Dict[str, Any]: """Stop a published workflow Docker service container.""" registry = await _read_registry(workflow_id) - runtime = await Storage.read(_runtime_key(workflow_id)) or await Storage.read(_active_release_key(workflow_id)) + runtime = await WorkflowStore.kv_get(_runtime_key(workflow_id)) or await WorkflowStore.kv_get( + _active_release_key(workflow_id) + ) if not runtime: registry["publishStatus"] = "stopped" registry["updatedAt"] = _now_ms() - await Storage.write(_registry_key(workflow_id), registry) + await WorkflowStore.kv_put(_registry_key(workflow_id), registry) return {"workflowId": workflow_id, "status": "stopped", "stopped": False} container_name = runtime.get("containerName") @@ -1166,28 +1171,28 @@ async def _stop_workflow_service_docker(workflow_id: str) -> Dict[str, Any]: if not stopped: raise WorkflowCenterError(f"Failed to stop Docker container: {container_name}") - active = await Storage.read(_active_release_key(workflow_id)) or {} + active = await WorkflowStore.kv_get(_active_release_key(workflow_id)) or {} release_id = active.get("releaseId") if release_id: - release_record = await Storage.read(_release_key(workflow_id, release_id)) or {} + release_record = await WorkflowStore.kv_get(_release_key(workflow_id, release_id)) or {} if release_record: release_record["status"] = "inactive" release_record["deactivatedAt"] = _now_ms() - await Storage.write(_release_key(workflow_id, str(release_id)), release_record) + await WorkflowStore.kv_put(_release_key(workflow_id, str(release_id)), release_record) - await Storage.remove(_runtime_key(workflow_id)) - await Storage.remove(_active_release_key(workflow_id)) + await WorkflowStore.kv_remove(_runtime_key(workflow_id)) + await WorkflowStore.kv_remove(_active_release_key(workflow_id)) registry["publishStatus"] = "stopped" registry["updatedAt"] = _now_ms() registry["serviceUrl"] = None - await Storage.write(_registry_key(workflow_id), registry) + await WorkflowStore.kv_put(_registry_key(workflow_id), registry) return {"workflowId": workflow_id, "status": "stopped", "stopped": True} async def get_workflow_health(workflow_id: str) -> Dict[str, Any]: """Get workflow container and HTTP health status.""" _ = await _read_registry(workflow_id) - runtime = await Storage.read(_runtime_key(workflow_id)) + runtime = await WorkflowStore.kv_get(_runtime_key(workflow_id)) if not runtime: return {"workflowId": workflow_id, "published": False, "containerRunning": False, "ok": False} @@ -1221,7 +1226,9 @@ async def get_workflow_health(workflow_id: str) -> Dict[str, Any]: "driver": "local", } - docker_state = await docker_container_state(container_name) if container_name else {"exists": False, "running": False} + docker_state = ( + await docker_container_state(container_name) if container_name else {"exists": False, "running": False} + ) endpoint_ok = False endpoint_payload: Dict[str, Any] = {} @@ -1255,7 +1262,7 @@ async def invoke_published_workflow( ) -> Dict[str, Any]: """Invoke active published workflow service by workflow_id.""" _ = await _read_registry(workflow_id) - runtime = await Storage.read(_runtime_key(workflow_id)) + runtime = await WorkflowStore.kv_get(_runtime_key(workflow_id)) if not runtime: raise WorkflowNotPublishedError(f"Workflow not published: {workflow_id}") @@ -1288,13 +1295,13 @@ async def invoke_published_workflow( async def list_workflow_releases(workflow_id: str) -> List[Dict[str, Any]]: """List release history for one workflow.""" _ = await _read_registry(workflow_id) - keys = await Storage.list(f"{_RELEASE_PREFIX}{workflow_id}/") + keys = await WorkflowStore.kv_list(f"{_RELEASE_PREFIX}{workflow_id}/") releases: List[Dict[str, Any]] = [] for raw_key in keys: key = _key_to_string(raw_key) if key.endswith("/active"): continue - release = await Storage.read(key) + release = await WorkflowStore.kv_get(key) if release: releases.append(release) releases.sort(key=lambda item: item.get("createdAt", 0), reverse=True) diff --git a/flocks/workflow/edge_resolver.py b/flocks/workflow/edge_resolver.py new file mode 100644 index 000000000..466c3722b --- /dev/null +++ b/flocks/workflow/edge_resolver.py @@ -0,0 +1,190 @@ +"""Workflow edge selection and input mapping.""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Literal, Optional + +from .models import Edge, Node + + +_logger = logging.getLogger("flocks.workflow.edge_resolver") +_BROADCAST_NODE_TYPES = {"python", "tool", "llm", "http_request", "subworkflow"} + + +class EdgeResolver: + """Resolve selected edges and downstream inputs for one workflow run.""" + + def __init__( + self, + *, + dataflow_mode: Literal["legacy", "vertex_cache"], + trace: bool = False, + ) -> None: + self.dataflow_mode = dataflow_mode + self.trace = trace + + def resolve( + self, + *, + node: Node, + node_inputs: Dict[str, Any], + node_outputs: Dict[str, Any], + edges: List[Edge], + ) -> list[tuple[Edge, Dict[str, Any]]]: + if self.dataflow_mode == "vertex_cache": + scopes = [node_outputs, node_inputs] + selected = self.select_edges_from_scopes(node, scopes, edges) + return [(edge, self.build_downstream_inputs_from_scopes(scopes, edge)) for edge in selected] + + upstream = dict(node_inputs) + upstream.update(node_outputs) + selected = self.select_edges(node, upstream, edges) + return [(edge, self.build_downstream_inputs(upstream, edge)) for edge in selected] + + def select_edges(self, node: Node, payload: Dict[str, Any], edges: List[Edge]) -> List[Edge]: + if not edges: + return [] + if node.type in _BROADCAST_NODE_TYPES: + return list(edges) + key = node.select_key or "result" + value = self.get_by_path(payload, key) + return self._select_by_label(value, edges) + + def select_edges_from_scopes( + self, + node: Node, + scopes: List[Dict[str, Any]], + edges: List[Edge], + ) -> List[Edge]: + if not edges: + return [] + if node.type in _BROADCAST_NODE_TYPES: + return list(edges) + key = node.select_key or "result" + found, value = self.try_get_by_path_from_scopes(scopes, key) + return self._select_by_label(value if found else None, edges) + + def build_downstream_inputs(self, upstream: Dict[str, Any], edge: Edge) -> Dict[str, Any]: + if edge.mapping: + out: Dict[str, Any] = {} + for dst, src in edge.mapping.items(): + found, value = self.try_get_by_path(upstream, src) + if found: + out[dst] = value + elif self.trace: + self._log_missing_mapping(edge, dst, src, list(upstream.keys())[:10]) + else: + out = dict(upstream) + if edge.const: + out.update(edge.const) + return out + + def build_downstream_inputs_from_scopes( + self, + scopes: List[Dict[str, Any]], + edge: Edge, + ) -> Dict[str, Any]: + if edge.mapping: + out: Dict[str, Any] = {} + for dst, src in edge.mapping.items(): + found, value = self.try_get_by_path_from_scopes(scopes, src) + if found: + out[dst] = value + elif self.trace: + available_keys: list[str] = [] + for scope in scopes: + available_keys.extend(str(key) for key in list(scope.keys())[:10]) + self._log_missing_mapping(edge, dst, src, available_keys[:10]) + else: + # Compatibility fallback for opt-in workflows that have not yet + # enabled strict edge mappings. This preserves legacy input shape, + # but does not get the memory benefits of vertex-cache dataflow. + out = {} + for scope in reversed(scopes): + out.update(scope) + if edge.const: + out.update(edge.const) + return out + + def try_get_by_path_from_scopes( + self, + scopes: List[Dict[str, Any]], + path: str, + ) -> tuple[bool, Any]: + if str(path or "").strip() == "$": + return True, self._merge_scopes(scopes) + for scope in scopes: + found, value = self.try_get_by_path(scope, path) + if found: + return True, value + return False, None + + def _merge_scopes(self, scopes: List[Dict[str, Any]]) -> Dict[str, Any]: + out: Dict[str, Any] = {} + for scope in reversed(scopes): + out.update(scope) + return out + + def try_get_by_path(self, data: Any, path: str) -> tuple[bool, Any]: + if path is None: + return False, None + path = str(path).strip() + if not path: + return False, None + if path == "$": + return True, data + if path.startswith("$."): + path = path[2:] + cur: Any = data + for part in path.split("."): + if isinstance(cur, dict): + if part in cur: + cur = cur[part] + else: + return False, None + elif isinstance(cur, list): + try: + idx = int(part) + except Exception: + return False, None + if 0 <= idx < len(cur): + cur = cur[idx] + else: + return False, None + else: + return False, None + return True, cur + + def get_by_path(self, data: Any, path: str) -> Any: + found, value = self.try_get_by_path(data, path) + return value if found else None + + def _select_by_label(self, value: Any, edges: List[Edge]) -> List[Edge]: + selected_label: Optional[str] + if value is None: + selected_label = None + elif isinstance(value, bool): + selected_label = "true" if value else "false" + elif isinstance(value, str): + selected_label = value + else: + selected_label = str(value) + matched = [e for e in edges if e.label == selected_label] if selected_label is not None else [] + if matched: + return matched + defaults = [e for e in edges if e.label is None] + return defaults[:1] if defaults else [] + + def _log_missing_mapping(self, edge: Edge, dst: str, src: str, available_keys: list[Any]) -> None: + _logger.warning( + "wf.edge.mapping.none_value", + extra={ + "edge_from": edge.from_, + "edge_to": edge.to, + "dst_key": dst, + "src_path": src, + "available_keys": available_keys, + "dataflow_mode": self.dataflow_mode, + }, + ) diff --git a/flocks/workflow/engine.py b/flocks/workflow/engine.py index 60c86ea98..e0621865a 100644 --- a/flocks/workflow/engine.py +++ b/flocks/workflow/engine.py @@ -4,6 +4,7 @@ from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError, as_completed from dataclasses import dataclass, field import hashlib +from itertools import islice import logging import time import traceback @@ -11,10 +12,11 @@ import json from typing import Any, Callable, Deque, Dict, List, Literal, NamedTuple, Optional, Set, Tuple, TypeVar -from pydantic import BaseModel, Field - from .code_gen import CodeGen, SimpleCodeGen, LLMCodeGen +from .edge_resolver import EdgeResolver from .errors import MaxStepsExceededError, NodeExecutionError, RunCancelledError, RunTimeoutError +from .execution_plan import WorkflowExecutionPlan +from .execution_state import ExecutionResult, StepResult, WorkflowExecutionState from .models import Edge, Workflow, Node from .repl_runtime import PythonExecRuntime, Runtime @@ -29,6 +31,7 @@ class _ExecOutcome(NamedTuple): """Result of executing a single node within a batch.""" + idx: int outputs: Dict[str, Any] stdout: str @@ -45,23 +48,24 @@ def _summarize_for_observability(value: Any, *, depth: int = 0) -> Any: if isinstance(value, (list, tuple, set)): return {"_type": type(value).__name__, "count": len(value)} if isinstance(value, dict): - return {"_type": "dict", "keys": list(value.keys())[:20]} + return {"_type": "dict", "keys": list(islice(value.keys(), 20))} if isinstance(value, str) and len(value) > 200: return {"_type": "string", "chars": len(value), "preview": value[:200]} return value if isinstance(value, dict): return { key: _summarize_for_observability(item, depth=depth + 1) - for key, item in list(value.items())[:50] + for key, item in islice(value.items(), 50) } if isinstance(value, (list, tuple, set)): + if isinstance(value, (list, tuple)): + preview_items = value[:3] + else: + preview_items = islice(value, 3) return { "_type": type(value).__name__, "count": len(value), - "preview": [ - _summarize_for_observability(item, depth=depth + 1) - for item in list(value)[:3] - ], + "preview": [_summarize_for_observability(item, depth=depth + 1) for item in preview_items], } if isinstance(value, str) and len(value) > 200: return {"_type": "string", "chars": len(value), "preview": value[:200]} @@ -83,37 +87,26 @@ def _outputs_for_log(outputs: Dict[str, Any], *, max_chars: int = 4000) -> str: return text[:max_chars] + f"...[truncated:{len(text) - max_chars}]" -class StepResult(BaseModel): - node_id: str - inputs: Dict[str, Any] = Field(default_factory=dict) - outputs: Dict[str, Any] = Field(default_factory=dict) - stdout: str = "" - error: Optional[str] = None - traceback: Optional[str] = None - duration_ms: Optional[float] = None - - -class ExecutionResult(BaseModel): - steps: int - history: list[StepResult] = Field(default_factory=list) - last_node_id: Optional[str] = None - outputs: Dict[str, Any] = Field(default_factory=dict) - run_id: str = Field(default_factory=lambda: uuid.uuid4().hex) - - def _default_workflow_loader(workflow_id: str) -> "Workflow": - """Default loader: resolves workflow from Storage by ID (sync wrapper).""" + """Default loader: resolves workflow by ID from disk, then legacy KV.""" import asyncio async def _load(): + from flocks.workflow.fs_store import read_workflow_from_fs from flocks.storage.storage import Storage - data = await Storage.read(f"workflow/{workflow_id}") + + data = read_workflow_from_fs(workflow_id) + if data is None: + data = await Storage.read(f"workflow/{workflow_id}") if data is None: raise NodeExecutionError( node_id="", message=f"Workflow not found: {workflow_id!r}", ) wf_json = data.get("workflowJson") or data + if isinstance(wf_json, dict): + wf_json = dict(wf_json) + wf_json.setdefault("id", workflow_id) return Workflow.from_dict(wf_json) try: @@ -145,6 +138,8 @@ class WorkflowEngine: workflow_path: Optional[str] = None node_timeout_s: Optional[float] = 300.0 history_mode: Literal["full", "summary"] = "summary" + dataflow_mode: Literal["legacy", "vertex_cache"] = "legacy" + execution_plan: Optional[WorkflowExecutionPlan] = None _depth: int = 0 max_parallel_workers: int = 4 workflow_loader: Optional[Callable[[str], "Workflow"]] = field(default=None, repr=False) @@ -156,6 +151,8 @@ def __post_init__(self) -> None: raise ValueError("max_parallel_workers must be >= 1") if self.history_mode not in ("full", "summary"): raise ValueError("history_mode must be 'full' or 'summary'") + if self.dataflow_mode not in ("legacy", "vertex_cache"): + raise ValueError("dataflow_mode must be 'legacy' or 'vertex_cache'") if self.runtime is None: self.runtime = PythonExecRuntime() if self.code_gen is None: @@ -181,6 +178,7 @@ def _get_isolated_runtime(self) -> "Runtime": globals=dict(self.runtime.globals), tool_registry=self.runtime.tool_registry, cancel_checker=self.runtime.cancel_checker, + cleanup_globals_after_execute=self.runtime.cleanup_globals_after_execute, ) return self.runtime @@ -196,24 +194,23 @@ def run( retain_history: bool = False, ) -> ExecutionResult: assert self.runtime is not None - nodes = self.workflow.nodes_by_id() - adj = self.workflow.adjacency() - incoming_from: Dict[str, List[str]] = {n.id: [] for n in self.workflow.nodes} - for e in self.workflow.edges: - incoming_from.setdefault(e.to, []).append(e.from_) - for k in incoming_from: - incoming_from[k].sort() - q: Deque[Tuple[str, Dict[str, Any], Optional[str]]] = deque( - [(self.workflow.start, initial_inputs or {}, None)] - ) - history: list[StepResult] = [] - last_outputs: Dict[str, Any] = {} - step_count = 0 - last_node_id: Optional[str] = None + if self.execution_plan is not None and self.execution_plan.workflow is self.workflow: + nodes = self.execution_plan.nodes_by_id + adj = self.execution_plan.adjacency + incoming_from = self.execution_plan.incoming_from + else: + nodes = self.workflow.nodes_by_id() + adj = self.workflow.adjacency() + incoming_from: Dict[str, List[str]] = {n.id: [] for n in self.workflow.nodes} + for e in self.workflow.edges: + incoming_from.setdefault(e.to, []).append(e.from_) + for k in incoming_from: + incoming_from[k].sort() + q: Deque[Tuple[str, Dict[str, Any], Optional[str]]] = deque([(self.workflow.start, initial_inputs or {}, None)]) rid = (run_id or uuid.uuid4().hex).strip() or uuid.uuid4().hex + state = WorkflowExecutionState(run_id=rid, history_mode=self.history_mode, retain_history=retain_history) + edge_resolver = EdgeResolver(dataflow_mode=self.dataflow_mode, trace=self.trace) run_t0 = time.perf_counter() - join_inputs: Dict[str, Dict[str, Dict[str, Any]]] = {} - join_seen_sources: Dict[str, Set[str]] = {} # Dedup: track (node_id -> last_input_hash) to skip identical re-executions. _dedup_hashes: Dict[str, str] = {} step_timeout_s = self.node_timeout_s if (self.node_timeout_s is not None and self.node_timeout_s > 0) else None @@ -225,22 +222,10 @@ def run( previous_cancel_checker = self.runtime.cancel_checker self.runtime.cancel_checker = cancel try: - def _retain_step(step: StepResult) -> None: - if retain_history: - history.append(step) - - def _build_execution_context() -> Dict[str, Any]: - return { - "run_id": rid, - "steps": step_count, - "last_node_id": last_node_id, - "outputs": last_outputs, - "history": history, - } def _raise_cancelled() -> None: err = RunCancelledError(rid) - err.execution_context = _build_execution_context() + err.execution_context = state.build_context() raise err while q: @@ -249,28 +234,28 @@ def _raise_cancelled() -> None: if timeout_s is not None and timeout_s > 0: if (time.perf_counter() - run_t0) > float(timeout_s): err = RunTimeoutError(rid, float(timeout_s)) - err.execution_context = _build_execution_context() + err.execution_context = state.build_context() raise err - if step_count >= self.max_steps: + if state.steps >= self.max_steps: err = MaxStepsExceededError(self.max_steps) - err.execution_context = _build_execution_context() + err.execution_context = state.build_context() raise err # ── Phase 1: drain queue, apply join / dedup ────────────── ready: List[Tuple[str, Node, Dict[str, Any], Optional[str]]] = [] while q: node_id, inputs, src_node_id = q.popleft() - last_node_id = node_id + state.last_node_id = node_id node = nodes[node_id] # Join handling if getattr(node, "join", False) and incoming_from.get(node_id): expected = incoming_from[node_id] - by_src = join_inputs.setdefault(node_id, {}) + by_src = state.join_inputs.setdefault(node_id, {}) src_key = src_node_id or "__start__" buf = by_src.setdefault(src_key, {}) buf.update(inputs) - seen = join_seen_sources.setdefault(node_id, set()) + seen = state.join_seen_sources.setdefault(node_id, set()) if src_node_id is not None: if src_node_id in expected: seen.add(src_node_id) @@ -278,7 +263,7 @@ def _raise_cancelled() -> None: seen.add("__start__") if len(seen.intersection(set(expected))) < len(expected): continue - by_src = join_inputs.pop(node_id, by_src) + by_src = state.join_inputs.pop(node_id, by_src) merged: Dict[str, Any] = {} origin: Dict[str, str] = {} conflict_mode = getattr(node, "join_conflict", "overwrite") @@ -308,7 +293,7 @@ def merge_payload(src: str, payload: Dict[str, Any]) -> None: if join_mode == "namespace": merged.setdefault(namespace_key, dict(by_src)) inputs = merged - join_seen_sources.pop(node_id, None) + state.join_seen_sources.pop(node_id, None) # Dedup: skip if same node already ran with identical inputs. # Lightweight history mode is used by high-throughput ingest @@ -321,7 +306,9 @@ def merge_payload(src: str, payload: Dict[str, Any]) -> None: try: _hash_raw = json.dumps( {"n": node_id, "i": inputs}, - sort_keys=True, ensure_ascii=False, default=str, + sort_keys=True, + ensure_ascii=False, + default=str, ) _input_hash = hashlib.sha256(_hash_raw.encode()).hexdigest()[:16] except Exception: @@ -329,7 +316,8 @@ def merge_payload(src: str, payload: Dict[str, Any]) -> None: if _input_hash and node_id in _dedup_hashes and _dedup_hashes[node_id] == _input_hash: _logger.info( "wf.step.dedup_skip node=%s (identical input hash %s)", - node_id, _input_hash, + node_id, + _input_hash, extra={"run_id": rid, "node_id": node_id, "input_hash": _input_hash}, ) continue @@ -352,20 +340,24 @@ def merge_payload(src: str, payload: Dict[str, Any]) -> None: _desc = (_nd.description or "").strip().splitlines()[0:1] _desc_text = _desc[0] if _desc else "" _par_tag = " [parallel]" if use_parallel else "" - print(f"\n[WF] step={step_count+_idx+1} node={_nid} type={_nd.type} {_desc_text}{_par_tag}".rstrip()) + print( + f"\n[WF] step={state.steps + _idx + 1} node={_nid} type={_nd.type} {_desc_text}{_par_tag}".rstrip() + ) _logger.info( "wf.step.start step=%s node=%s type=%s%s", - step_count + _idx + 1, _nid, _nd.type, + state.steps + _idx + 1, + _nid, + _nd.type, " (parallel)" if use_parallel else "", - extra={"run_id": rid, "step": step_count + _idx + 1, "node_id": _nid, "node_type": _nd.type}, + extra={"run_id": rid, "step": state.steps + _idx + 1, "node_id": _nid, "node_type": _nd.type}, ) if on_step_start is not None: try: - step_tokens[_idx] = on_step_start(rid, step_count + _idx + 1, _nd, _inp) + step_tokens[_idx] = on_step_start(rid, state.steps + _idx + 1, _nd, _inp) except Exception: _logger.exception( "wf.step_start.hook_error", - extra={"run_id": rid, "step": step_count + _idx + 1, "node_id": _nid}, + extra={"run_id": rid, "step": state.steps + _idx + 1, "node_id": _nid}, ) if use_parallel: @@ -383,9 +375,14 @@ def _par_exec( return _ExecOutcome(_pi, _pouts, _pso, None, None, (time.perf_counter() - _t0) * 1000.0) except RunCancelledError as _ce: return _ExecOutcome( - _pi, {}, "", str(_ce), None, + _pi, + {}, + "", + str(_ce), + None, (time.perf_counter() - _t0) * 1000.0, - False, True, + False, + True, ) except Exception as _pe: _perr = str(_pe) @@ -411,13 +408,17 @@ def _par_exec( except FuturesTimeoutError: for _f2, _ci2 in _fut_map.items(): if _ci2 not in _completed_idx: - exec_results.append(_ExecOutcome( - idx=_ci2, outputs={}, stdout="", - error=f"节点执行超时 ({self.node_timeout_s}s)", - traceback=None, - duration_ms=(step_timeout_s or 0) * 1000.0, - is_timeout=True, - )) + exec_results.append( + _ExecOutcome( + idx=_ci2, + outputs={}, + stdout="", + error=f"节点执行超时 ({self.node_timeout_s}s)", + traceback=None, + duration_ms=(step_timeout_s or 0) * 1000.0, + is_timeout=True, + ) + ) finally: try: _pool.shutdown(wait=False, cancel_futures=True) @@ -434,11 +435,17 @@ def _par_exec( _outs, _so = _sfut.result(timeout=step_timeout_s) else: _outs, _so = self._execute_node(_nd, _inp) - exec_results.append(_ExecOutcome(_idx, _outs, _so, None, None, (time.perf_counter() - _t0) * 1000.0)) + exec_results.append( + _ExecOutcome(_idx, _outs, _so, None, None, (time.perf_counter() - _t0) * 1000.0) + ) except FuturesTimeoutError as _fte: _fte_msg = str(_fte).strip() _err = _fte_msg if _fte_msg else f"节点执行超时 ({self.node_timeout_s}s)" - exec_results.append(_ExecOutcome(_idx, {}, "", _err, None, (time.perf_counter() - _t0) * 1000.0, is_timeout=True)) + exec_results.append( + _ExecOutcome( + _idx, {}, "", _err, None, (time.perf_counter() - _t0) * 1000.0, is_timeout=True + ) + ) if timeout_executor is not None: try: timeout_executor.shutdown(wait=False, cancel_futures=True) @@ -446,7 +453,7 @@ def _par_exec( timeout_executor.shutdown(wait=False) timeout_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="wf-node") except RunCancelledError as _ce: - _ce.execution_context = _build_execution_context() + _ce.execution_context = state.build_context() raise except Exception as _e: _err = str(_e) @@ -456,18 +463,18 @@ def _par_exec( _so = _e.stdout or "" if isinstance(_e, NodeExecutionError) and getattr(_e, "traceback", None): _tb = _e.traceback - exec_results.append(_ExecOutcome(_idx, {}, _so, _err, _tb, (time.perf_counter() - _t0) * 1000.0)) + exec_results.append( + _ExecOutcome(_idx, {}, _so, _err, _tb, (time.perf_counter() - _t0) * 1000.0) + ) # ── Phase 3: record results, hooks, enqueue downstream ──── _stop_exc: Optional[NodeExecutionError] = None for _eo in exec_results: _nid, _nd, _inp, _src = ready[_eo.idx] - _sn = step_count + _eo.idx + 1 - last_node_id = _nid - last_outputs = ( - _summarize_for_observability(_eo.outputs) - if self.history_mode == "summary" - else _eo.outputs + _sn = state.steps + _eo.idx + 1 + state.last_node_id = _nid + state.last_outputs = ( + _summarize_for_observability(_eo.outputs) if self.history_mode == "summary" else _eo.outputs ) def _build_step_result( @@ -504,7 +511,7 @@ def _build_step_result( error=_eo.error or "Run cancelled", traceback_text=_eo.traceback, ) - _retain_step(step_res) + state.retain_step(step_res) if on_step_end is not None and _eo.idx in step_tokens: try: on_step_end(step_tokens[_eo.idx], step_res) @@ -530,24 +537,40 @@ def _build_step_result( error=_eo.error, traceback_text=_eo.traceback, ) - _retain_step(step_res) + state.retain_step(step_res) _status = "timeout" if _eo.is_timeout else "error" (_logger.warning if _eo.is_timeout else _logger.error)( f"wf.step.{_status}", extra={ - "run_id": rid, "step": _sn, "node_id": _nid, - "node_type": _nd.type, "error": _eo.error, - **({"timeout_s": self.node_timeout_s} if _eo.is_timeout else {"traceback": (_eo.traceback or "")[:500]}), + "run_id": rid, + "step": _sn, + "node_id": _nid, + "node_type": _nd.type, + "error": _eo.error, + **( + {"timeout_s": self.node_timeout_s} + if _eo.is_timeout + else {"traceback": (_eo.traceback or "")[:500]} + ), }, ) outputs_keys = list(_eo.outputs.keys()) _logger.info( "wf.step.end step=%s node=%s type=%s status=%s duration_ms=%.3f outputs_keys=%s", - _sn, _nid, _nd.type, _status, _eo.duration_ms, outputs_keys, + _sn, + _nid, + _nd.type, + _status, + _eo.duration_ms, + outputs_keys, extra={ - "run_id": rid, "step": _sn, "node_id": _nid, - "node_type": _nd.type, "status": _status, - "duration_ms": _eo.duration_ms, "outputs_keys": outputs_keys, + "run_id": rid, + "step": _sn, + "node_id": _nid, + "node_type": _nd.type, + "status": _status, + "duration_ms": _eo.duration_ms, + "outputs_keys": outputs_keys, "error": _eo.error, }, ) @@ -555,10 +578,16 @@ def _build_step_result( outputs_for_debug = _outputs_for_log(_eo.outputs) _logger.debug( "wf.step.outputs step=%s node=%s status=%s outputs=%s", - _sn, _nid, _status, outputs_for_debug, + _sn, + _nid, + _status, + outputs_for_debug, extra={ - "run_id": rid, "step": _sn, "node_id": _nid, - "node_type": _nd.type, "status": _status, + "run_id": rid, + "step": _sn, + "node_id": _nid, + "node_type": _nd.type, + "status": _status, "outputs": outputs_for_debug, "error": _eo.error, }, @@ -573,14 +602,16 @@ def _build_step_result( ) if self.stop_on_error and _stop_exc is None and not _eo.is_timeout: _stop_exc = NodeExecutionError( - node_id=_nid, message=_eo.error, - stdout=_eo.stdout, traceback=_eo.traceback, + node_id=_nid, + message=_eo.error, + stdout=_eo.stdout, + traceback=_eo.traceback, execution_context={ "run_id": rid, - "steps": step_count + len(exec_results), + "steps": state.steps + len(exec_results), "last_node_id": _nid, - "outputs": last_outputs, - "history": history, + "outputs": state.last_outputs, + "history": state.history, }, ) continue @@ -598,15 +629,24 @@ def _build_step_result( stdout=_eo.stdout, error=None, ) - _retain_step(step_res) + state.retain_step(step_res) outputs_keys = list(_eo.outputs.keys()) _logger.info( "wf.step.end step=%s node=%s type=%s status=%s duration_ms=%.3f outputs_keys=%s", - _sn, _nid, _nd.type, "ok", _eo.duration_ms, outputs_keys, + _sn, + _nid, + _nd.type, + "ok", + _eo.duration_ms, + outputs_keys, extra={ - "run_id": rid, "step": _sn, "node_id": _nid, - "node_type": _nd.type, "status": "ok", - "duration_ms": _eo.duration_ms, "outputs_keys": outputs_keys, + "run_id": rid, + "step": _sn, + "node_id": _nid, + "node_type": _nd.type, + "status": "ok", + "duration_ms": _eo.duration_ms, + "outputs_keys": outputs_keys, "error": None, }, ) @@ -614,10 +654,16 @@ def _build_step_result( outputs_for_debug = _outputs_for_log(_eo.outputs) _logger.debug( "wf.step.outputs step=%s node=%s status=%s outputs=%s", - _sn, _nid, "ok", outputs_for_debug, + _sn, + _nid, + "ok", + outputs_for_debug, extra={ - "run_id": rid, "step": _sn, "node_id": _nid, - "node_type": _nd.type, "status": "ok", + "run_id": rid, + "step": _sn, + "node_id": _nid, + "node_type": _nd.type, + "status": "ok", "outputs": outputs_for_debug, "error": None, }, @@ -633,37 +679,34 @@ def _build_step_result( # Enqueue downstream (skip for failed node when stop_on_error) if _stop_exc is None: - upstream = dict(_inp) - upstream.update(_eo.outputs) - selected = self._select_edges(_nd, upstream, adj.get(_nid, [])) - for edge in selected: - q.append((edge.to, self._build_downstream_inputs(upstream, edge), _nid)) - - step_count += len(exec_results) + state.record_vertex_output(_nid, _eo.outputs) + for edge, edge_inputs in edge_resolver.resolve( + node=_nd, + node_inputs=_inp, + node_outputs=_eo.outputs, + edges=adj.get(_nid, []), + ): + q.append((edge.to, edge_inputs, _nid)) + + state.steps += len(exec_results) if _stop_exc is not None: raise _stop_exc if cancel is not None and cancel(): _raise_cancelled() pending_joins = [] - for nid, buf in join_inputs.items(): + for nid, buf in state.join_inputs.items(): n = nodes.get(nid) if n is not None and getattr(n, "join", False): expected = incoming_from.get(nid, []) - seen = join_seen_sources.get(nid, set()) + seen = state.join_seen_sources.get(nid, set()) pending_joins.append((nid, expected, sorted(seen))) if pending_joins: msg = "Join node(s) did not receive all incoming inputs: " + "; ".join( f"{nid} expected={expected} seen={seen}" for nid, expected, seen in pending_joins ) raise NodeExecutionError(node_id=pending_joins[0][0], message=msg) - return ExecutionResult( - steps=step_count, - history=history, - last_node_id=last_node_id, - outputs=last_outputs, - run_id=rid, - ) + return state.to_result() finally: if isinstance(self.runtime, PythonExecRuntime): self.runtime.cancel_checker = previous_cancel_checker @@ -749,9 +792,7 @@ def _execute_node( return self._execute_http_request_node(node, inputs) if node.type == "subworkflow": return self._execute_subworkflow_node(node, inputs, _runtime=_runtime) - raise NodeExecutionError( - node_id=node_id, message=f"Unsupported node.type={node.type!r}" - ) + raise NodeExecutionError(node_id=node_id, message=f"Unsupported node.type={node.type!r}") def _execute_tool_node( self, @@ -801,6 +842,7 @@ def _execute_llm_node( assert node.prompt, "llm node requires prompt" try: from jinja2 import Template, TemplateError + rendered = Template(node.prompt).render(**inputs) except Exception as e: raise NodeExecutionError( @@ -808,6 +850,7 @@ def _execute_llm_node( message=f"Prompt template render failed: {type(e).__name__}: {e}", ) from e from .llm import get_llm_client + _rt = _runtime or self.runtime cancel_checker = getattr(_rt, "cancel_checker", None) try: @@ -829,6 +872,7 @@ def _execute_http_request_node(self, node: Node, inputs: Dict[str, Any]) -> Tupl assert node.method, "http_request node requires method" try: from jinja2 import Template + url = Template(node.url).render(**inputs) method = node.method.upper() headers = node.headers or {} @@ -842,6 +886,7 @@ def _execute_http_request_node(self, node: Node, inputs: Dict[str, Any]) -> Tupl ) from e try: import httpx + with httpx.Client(timeout=30.0) as client: if method in {"GET", "DELETE", "HEAD"}: resp = client.request(method, url, headers=headers) @@ -904,7 +949,9 @@ def _execute_subworkflow_node( trace=self.trace, node_timeout_s=self.node_timeout_s, history_mode=self.history_mode, + dataflow_mode=self.dataflow_mode, _depth=self._depth + 1, + max_parallel_workers=self.max_parallel_workers, workflow_loader=self.workflow_loader, ) result = sub_engine.run(initial_inputs=sub_inputs) @@ -913,76 +960,42 @@ def _execute_subworkflow_node( return {output_k: last_outputs}, "" def _select_edges(self, node: Any, payload: Dict[str, Any], edges: List[Edge]) -> List[Edge]: - if not edges: - return [] - if node.type in {"python", "tool", "llm", "http_request", "subworkflow"}: - return list(edges) - key = node.select_key or "result" - value = self._get_by_path(payload, key) - selected_label: Optional[str] - if value is None: - selected_label = None - elif isinstance(value, bool): - selected_label = "true" if value else "false" - elif isinstance(value, str): - selected_label = value - else: - selected_label = str(value) - matched = [e for e in edges if e.label == selected_label] if selected_label is not None else [] - if matched: - return matched - defaults = [e for e in edges if e.label is None] - return defaults[:1] if defaults else [] + return EdgeResolver(dataflow_mode=self.dataflow_mode, trace=self.trace).select_edges(node, payload, edges) + + def _select_edges_from_scopes( + self, + node: Any, + scopes: List[Dict[str, Any]], + edges: List[Edge], + ) -> List[Edge]: + return EdgeResolver(dataflow_mode=self.dataflow_mode, trace=self.trace).select_edges_from_scopes( + node, + scopes, + edges, + ) def _build_downstream_inputs(self, upstream: Dict[str, Any], edge: Edge) -> Dict[str, Any]: - if edge.mapping: - out: Dict[str, Any] = {} - for dst, src in edge.mapping.items(): - found, value = self._try_get_by_path(upstream, src) - if found: - out[dst] = value - elif self.trace: - available_keys = list(upstream.keys())[:10] - _logger.warning( - "wf.edge.mapping.none_value", - extra={"edge_from": edge.from_, "edge_to": edge.to, "dst_key": dst, "src_path": src, "available_keys": available_keys}, - ) - else: - out = dict(upstream) - if edge.const: - out.update(edge.const) - return out + return EdgeResolver(dataflow_mode=self.dataflow_mode, trace=self.trace).build_downstream_inputs(upstream, edge) + + def _build_downstream_inputs_from_scopes( + self, + scopes: List[Dict[str, Any]], + edge: Edge, + ) -> Dict[str, Any]: + return EdgeResolver(dataflow_mode=self.dataflow_mode, trace=self.trace).build_downstream_inputs_from_scopes( + scopes, + edge, + ) + + def _try_get_by_path_from_scopes( + self, + scopes: List[Dict[str, Any]], + path: str, + ) -> tuple[bool, Any]: + return EdgeResolver(dataflow_mode=self.dataflow_mode, trace=self.trace).try_get_by_path_from_scopes(scopes, path) def _try_get_by_path(self, data: Any, path: str) -> tuple[bool, Any]: - if path is None: - return False, None - path = str(path).strip() - if not path: - return False, None - if path == "$": - return True, data - if path.startswith("$."): - path = path[2:] - cur: Any = data - for part in path.split("."): - if isinstance(cur, dict): - if part in cur: - cur = cur[part] - else: - return False, None - elif isinstance(cur, list): - try: - idx = int(part) - except Exception: - return False, None - if 0 <= idx < len(cur): - cur = cur[idx] - else: - return False, None - else: - return False, None - return True, cur + return EdgeResolver(dataflow_mode=self.dataflow_mode, trace=self.trace).try_get_by_path(data, path) def _get_by_path(self, data: Any, path: str) -> Any: - found, value = self._try_get_by_path(data, path) - return value if found else None + return EdgeResolver(dataflow_mode=self.dataflow_mode, trace=self.trace).get_by_path(data, path) diff --git a/flocks/workflow/execution_plan.py b/flocks/workflow/execution_plan.py new file mode 100644 index 000000000..ca7e6e4e8 --- /dev/null +++ b/flocks/workflow/execution_plan.py @@ -0,0 +1,91 @@ +"""Reusable static workflow execution plan. + +The plan intentionally contains only immutable-ish workflow structure and +derived metadata. Per-run state such as node outputs, history, joins, and +runtime globals must stay in ``WorkflowExecutionState`` / runtime instances. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional + +from .models import Edge, Node, Workflow +from .requirements import requirements_from_workflow_metadata +from .workflow_lint import lint_workflow + + +def resolve_workflow_dataflow_mode(workflow_metadata: Optional[Dict[str, Any]]) -> Literal["legacy", "vertex_cache"]: + """Resolve workflow dataflow mode from metadata. + + Missing metadata intentionally stays legacy so historical workflow files run + with their original full-payload edge semantics. + """ + if not isinstance(workflow_metadata, dict): + return "legacy" + + candidates: list[Any] = [ + workflow_metadata.get("dataflow_mode"), + workflow_metadata.get("dataflowMode"), + ] + for section_key in ("runtime", "runtime_defaults", "runtimeDefaults"): + section = workflow_metadata.get(section_key) + if isinstance(section, dict): + candidates.extend( + [ + section.get("dataflow_mode"), + section.get("dataflowMode"), + ] + ) + + for value in candidates: + normalized = str(value or "").strip().lower().replace("-", "_") + if normalized in {"vertex_cache", "vertex", "cache"}: + return "vertex_cache" + if normalized in {"legacy", "classic", "default"}: + return "legacy" + return "legacy" + + +def _incoming_edges_by_node(workflow: Workflow) -> Dict[str, List[str]]: + incoming_from: Dict[str, List[str]] = {node.id: [] for node in workflow.nodes} + for edge in workflow.edges: + incoming_from.setdefault(edge.to, []).append(edge.from_) + for node_id in incoming_from: + incoming_from[node_id].sort() + return incoming_from + + +@dataclass(frozen=True) +class WorkflowExecutionPlan: + """Precomputed static workflow graph data safe to reuse across runs.""" + + workflow: Workflow + workflow_path: Optional[str] + use_llm: Optional[bool] + lint_results: tuple[Dict[str, Any], ...] + requirements: tuple[str, ...] + dataflow_mode: Literal["legacy", "vertex_cache"] + nodes_by_id: Dict[str, Node] + adjacency: Dict[str, List[Edge]] + incoming_from: Dict[str, List[str]] + + +def build_workflow_execution_plan( + workflow: Workflow, + *, + workflow_path: Optional[str] = None, + use_llm: Optional[bool] = None, +) -> WorkflowExecutionPlan: + """Build reusable static execution data for a workflow.""" + return WorkflowExecutionPlan( + workflow=workflow, + workflow_path=workflow_path, + use_llm=use_llm, + lint_results=tuple(lint_workflow(workflow)), + requirements=tuple(requirements_from_workflow_metadata(workflow.metadata)), + dataflow_mode=resolve_workflow_dataflow_mode(workflow.metadata), + nodes_by_id=workflow.nodes_by_id(), + adjacency=workflow.adjacency(), + incoming_from=_incoming_edges_by_node(workflow), + ) diff --git a/flocks/workflow/execution_state.py b/flocks/workflow/execution_state.py new file mode 100644 index 000000000..e418217fa --- /dev/null +++ b/flocks/workflow/execution_state.py @@ -0,0 +1,75 @@ +"""Workflow execution state containers.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +import uuid +from typing import Any, Dict, Literal, Optional, Set + +from pydantic import BaseModel, Field + + +_VERTEX_OUTPUT_KEY_LIMIT = 50 + + +class StepResult(BaseModel): + node_id: str + inputs: Dict[str, Any] = Field(default_factory=dict) + outputs: Dict[str, Any] = Field(default_factory=dict) + stdout: str = "" + error: Optional[str] = None + traceback: Optional[str] = None + duration_ms: Optional[float] = None + + +class ExecutionResult(BaseModel): + steps: int + history: list[StepResult] = Field(default_factory=list) + last_node_id: Optional[str] = None + outputs: Dict[str, Any] = Field(default_factory=dict) + run_id: str = Field(default_factory=lambda: uuid.uuid4().hex) + + +@dataclass +class WorkflowExecutionState: + """Mutable state for one workflow engine run.""" + + run_id: str + history_mode: Literal["full", "summary"] + retain_history: bool = False + steps: int = 0 + last_node_id: Optional[str] = None + last_outputs: Dict[str, Any] = field(default_factory=dict) + history: list[StepResult] = field(default_factory=list) + join_inputs: Dict[str, Dict[str, Dict[str, Any]]] = field(default_factory=dict) + join_seen_sources: Dict[str, Set[str]] = field(default_factory=dict) + vertex_outputs: Dict[str, Dict[str, Any]] = field(default_factory=dict) + + def retain_step(self, step: StepResult) -> None: + if self.retain_history: + self.history.append(step) + + def record_vertex_output(self, node_id: str, outputs: Dict[str, Any]) -> None: + self.vertex_outputs[node_id] = outputs + + def build_context(self) -> Dict[str, Any]: + return { + "run_id": self.run_id, + "steps": self.steps, + "last_node_id": self.last_node_id, + "outputs": self.last_outputs, + "history": self.history, + "vertex_output_keys": { + node_id: list(outputs.keys())[:_VERTEX_OUTPUT_KEY_LIMIT] + for node_id, outputs in self.vertex_outputs.items() + }, + } + + def to_result(self) -> ExecutionResult: + return ExecutionResult( + steps=self.steps, + history=self.history, + last_node_id=self.last_node_id, + outputs=self.last_outputs, + run_id=self.run_id, + ) diff --git a/flocks/workflow/execution_store.py b/flocks/workflow/execution_store.py index f175f97ac..c9389c445 100644 --- a/flocks/workflow/execution_store.py +++ b/flocks/workflow/execution_store.py @@ -3,14 +3,15 @@ from __future__ import annotations import asyncio +from itertools import islice import time import uuid from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple from flocks.session.recorder import Recorder -from flocks.storage.storage import Storage from flocks.utils.log import Log from flocks.workflow.runner import RunWorkflowResult +from flocks.workflow.store import WorkflowStore log = Log.create(service="workflow.execution_store") @@ -33,6 +34,105 @@ # protects against accidentally stripping small metadata lists that happen # to share a name with a known large-list key. DEFAULT_COMPACT_SIZE_THRESHOLD: int = 100 +DEFAULT_GENERIC_SEQUENCE_THRESHOLD: int = 1_000 +DEFAULT_MAX_INLINE_STRING_CHARS: int = 20_000 +DEFAULT_MAX_INLINE_DICT_KEYS: int = 200 +DEFAULT_PREVIEW_ITEMS: int = 3 +DEFAULT_PREVIEW_CHARS: int = 500 + + +def _sequence_preview(value: Any, *, limit: int = DEFAULT_PREVIEW_ITEMS) -> list[Any]: + if isinstance(value, (list, tuple)): + items = value[:limit] + else: + items = islice(value, limit) + return [_summarize_large_value(item, depth=1) for item in items] + + +def _summarize_large_value(value: Any, *, depth: int = 0) -> Dict[str, Any]: + if isinstance(value, str): + return { + "_type": "string", + "chars": len(value), + "preview": value[:DEFAULT_PREVIEW_CHARS], + } + if isinstance(value, dict): + return { + "_type": "dict", + "key_count": len(value), + "keys": list(islice(value.keys(), DEFAULT_PREVIEW_ITEMS * 10)), + } + if isinstance(value, (list, tuple, set)): + summary: Dict[str, Any] = { + "_type": type(value).__name__, + "count": len(value), + } + if depth == 0: + summary["preview"] = _sequence_preview(value) + return summary + return { + "_type": type(value).__name__, + "preview": str(value)[:DEFAULT_PREVIEW_CHARS], + } + + +def _compact_value_for_storage( + value: Any, + *, + key: Optional[str], + known_large_keys: frozenset[str], + size_threshold: int, + generic_sequence_threshold: int, + max_inline_string_chars: int, + max_inline_dict_keys: int, + depth: int = 0, +) -> Any: + if ( + key in known_large_keys + and isinstance(value, (list, tuple)) + and len(value) > size_threshold + ): + return {f"_{key}_count": len(value)} + + if isinstance(value, str): + if len(value) > max_inline_string_chars: + return _summarize_large_value(value) + return value + + if isinstance(value, (list, tuple, set)): + if len(value) > generic_sequence_threshold: + return _summarize_large_value(value) + return value + + if isinstance(value, dict): + if len(value) > max_inline_dict_keys: + return _summarize_large_value(value) + if depth >= 2: + return value + compacted: Dict[str, Any] = {} + changed = False + for child_key, child_value in value.items(): + child_compacted = _compact_value_for_storage( + child_value, + key=str(child_key), + known_large_keys=known_large_keys, + size_threshold=size_threshold, + generic_sequence_threshold=generic_sequence_threshold, + max_inline_string_chars=max_inline_string_chars, + max_inline_dict_keys=max_inline_dict_keys, + depth=depth + 1, + ) + if isinstance(child_compacted, dict) and len(child_compacted) == 1: + marker_key = next(iter(child_compacted)) + if marker_key.startswith("_") and marker_key.endswith("_count"): + compacted[marker_key] = child_compacted[marker_key] + changed = True + continue + compacted[child_key] = child_compacted + changed = changed or child_compacted is not child_value + return compacted if changed else value + + return value def compact_outputs_for_storage( @@ -40,37 +140,37 @@ def compact_outputs_for_storage( *, keys: Iterable[str] = DEFAULT_LARGE_LIST_KEYS, size_threshold: int = DEFAULT_COMPACT_SIZE_THRESHOLD, + generic_sequence_threshold: int = DEFAULT_GENERIC_SEQUENCE_THRESHOLD, + max_inline_string_chars: int = DEFAULT_MAX_INLINE_STRING_CHARS, + max_inline_dict_keys: int = DEFAULT_MAX_INLINE_DICT_KEYS, ) -> Dict[str, Any]: - """Return a copy of *outputs* with large alert lists replaced by counts. - - Only **list or tuple** values whose key is in *keys* AND whose length - exceeds *size_threshold* are compacted to ``__count``; everything - else is passed through unchanged. This prevents megabytes of alert data - from being serialised into the ``workflow_execution`` SQLite row on every - invocation, while still keeping small sequences (e.g. error details, short - configuration arrays) fully inspectable in the execution-history UI. - - **Keys that are compacted by default** (see ``DEFAULT_LARGE_LIST_KEYS``): - ``enriched_alerts``, ``unique_alerts``, ``raw_alerts``, - ``normalized_alerts``, ``filtered_alerts``. Keys outside this set — such - as a generic ``alerts`` parameter — are *not* compacted unless the caller - passes a custom *keys* argument. Callers who depend on inspecting the - full list contents of compacted keys must read the data from the JSONL - files written by the workflow itself. + """Return a bounded copy of *outputs* safe for execution records. + + Known large-list keys keep the historical ``__count`` shape. Other + oversized strings, sequences, and dictionaries are replaced with bounded + summaries so unknown workflow payload names cannot inflate SQLite rows or + tool metadata. """ if not isinstance(outputs, dict): return {} - key_set = frozenset(keys) + known_large_keys = frozenset(keys) compacted: Dict[str, Any] = {} for k, v in outputs.items(): - if ( - k in key_set - and isinstance(v, (list, tuple)) - and len(v) > size_threshold - ): - compacted[f"_{k}_count"] = len(v) - else: - compacted[k] = v + value = _compact_value_for_storage( + v, + key=str(k), + known_large_keys=known_large_keys, + size_threshold=size_threshold, + generic_sequence_threshold=generic_sequence_threshold, + max_inline_string_chars=max_inline_string_chars, + max_inline_dict_keys=max_inline_dict_keys, + ) + if isinstance(value, dict) and len(value) == 1: + marker_key = next(iter(value)) + if marker_key == f"_{k}_count": + compacted[marker_key] = value[marker_key] + continue + compacted[k] = value return compacted @@ -89,9 +189,7 @@ def compact_step_for_storage( for field in ("inputs", "outputs"): raw_value = step_copy.get(field) if isinstance(raw_value, dict): - step_copy[field] = compact_outputs_for_storage( - raw_value, keys=keys, size_threshold=size_threshold - ) + step_copy[field] = compact_outputs_for_storage(raw_value, keys=keys, size_threshold=size_threshold) return step_copy @@ -109,10 +207,7 @@ def compact_history_for_storage( """ if not history: return [] - return [ - compact_step_for_storage(step, keys=keys, size_threshold=size_threshold) - for step in history - ] + return [compact_step_for_storage(step, keys=keys, size_threshold=size_threshold) for step in history] def _first_value(data: Dict[str, Any], keys: Iterable[str]) -> Any: @@ -159,23 +254,27 @@ def derive_loop_progress( if isinstance(outputs, dict): merged.update(outputs) - iteration = _as_positive_int(_first_value( - merged, - ("iteration", "loop_index", "current_index", "item_idx", "item_index", "host_idx"), - )) - total = _as_positive_int(_first_value( - merged, - ( - "total_iterations", - "total_items", - "item_count", - "items_count", - "total_hosts", - "host_count", - "hosts_count", - "hosts_total", - ), - )) + iteration = _as_positive_int( + _first_value( + merged, + ("iteration", "loop_index", "current_index", "item_idx", "item_index", "host_idx"), + ) + ) + total = _as_positive_int( + _first_value( + merged, + ( + "total_iterations", + "total_items", + "item_count", + "items_count", + "total_hosts", + "host_count", + "hosts_count", + "hosts_total", + ), + ) + ) if total is None: hosts = merged.get("hosts") if isinstance(hosts, list): @@ -198,6 +297,7 @@ def derive_loop_progress( "global_step_index": global_step_index, } + # Maximum number of execution history records retained per workflow. # Keep this intentionally small so high-frequency workflows do not keep # inflating the SQLite row set and matching JSONL audit files indefinitely. @@ -254,26 +354,15 @@ async def _update_workflow_stats(workflow_id: str, success: bool, duration: floa lock = _get_stats_lock(workflow_id) async with lock: try: - key = _workflow_stats_key(workflow_id) - try: - stats: Dict[str, Any] = await Storage.read(key) or dict(_DEFAULT_STATS) - except Exception: - stats = dict(_DEFAULT_STATS) - stats["callCount"] = stats.get("callCount", 0) + 1 - if success: - stats["successCount"] = stats.get("successCount", 0) + 1 - else: - stats["errorCount"] = stats.get("errorCount", 0) + 1 - total = stats.get("totalRuntime", 0.0) + duration - stats["totalRuntime"] = total - call_count = stats["callCount"] - stats["avgRuntime"] = (total / call_count) if call_count > 0 else 0.0 - await Storage.write(key, stats) + await WorkflowStore.increment_stats(workflow_id, success=success, duration=duration) except Exception as exc: - log.warning("workflow.stats.update_failed", { - "workflow_id": workflow_id, - "error": str(exc), - }) + log.warning( + "workflow.stats.update_failed", + { + "workflow_id": workflow_id, + "error": str(exc), + }, + ) def workflow_execution_key(exec_id: str) -> str: @@ -324,7 +413,7 @@ async def record_execution_step( ) -> Dict[str, Any]: """Persist one compacted execution step and return the stored payload.""" step_payload = compact_step_for_storage(step) - await Storage.write(workflow_execution_step_key(exec_id, step_index), step_payload) + await WorkflowStore.record_step(exec_id, step_index, step_payload) return step_payload @@ -363,26 +452,31 @@ def on_step_complete(self, step_result: Any) -> None: inputs=step_dict.get("inputs"), outputs=step_dict.get("outputs"), ) - self.summary.update({ - "stepCount": self.step_count, - "currentNodeId": step_dict.get("node_id"), - "currentNodeType": step_dict.get("node_type") or step_dict.get("type"), - "currentPhase": "running", - "currentStepIndex": self.step_count, - "loopProgress": loop_progress, - "updatedAt": int(time.time() * 1000), - }) + self.summary.update( + { + "stepCount": self.step_count, + "currentNodeId": step_dict.get("node_id"), + "currentNodeType": step_dict.get("node_type") or step_dict.get("type"), + "currentPhase": "running", + "currentStepIndex": self.step_count, + "loopProgress": loop_progress, + "updatedAt": int(time.time() * 1000), + } + ) try: asyncio.run_coroutine_threadsafe( record_execution_step(self.exec_id, self.step_count, step_dict), self.loop, ).result(timeout=self.write_timeout_s) except Exception as exc: - self.logger.warning(self.log_event, { - "exec_id": self.exec_id, - "step_index": self.step_count, - "error": str(exc), - }) + self.logger.warning( + self.log_event, + { + "exec_id": self.exec_id, + "step_index": self.step_count, + "error": str(exc), + }, + ) async def _backfill_execution_steps( @@ -399,14 +493,17 @@ async def _backfill_execution_steps( if not isinstance(step_payload, dict): continue try: - await Storage.write(workflow_execution_step_key(exec_id, step_index), step_payload) + await WorkflowStore.record_step(exec_id, step_index, step_payload) written += 1 except Exception as exc: - log.warning("workflow.execution_step.backfill_failed", { - "exec_id": exec_id, - "step_index": step_index, - "error": str(exc), - }) + log.warning( + "workflow.execution_step.backfill_failed", + { + "exec_id": exec_id, + "step_index": step_index, + "error": str(exc), + }, + ) return written @@ -418,14 +515,11 @@ async def load_execution_steps( ) -> Tuple[List[Dict[str, Any]], int]: """Load persisted step logs for an execution, sorted by step key.""" page_limit = 500 if limit is None else max(limit, 0) - selected, total = await Storage.list_entries_page( - workflow_execution_step_prefix(exec_id), + return await WorkflowStore.list_steps( + exec_id, offset=max(offset, 0), limit=page_limit, ) - return [ - value for _key, value in selected if isinstance(value, dict) - ], total def normalize_execution_status(status: str) -> str: @@ -462,9 +556,7 @@ def resolve_execution_outcome(result: RunWorkflowResult) -> tuple[str, Optional[ if result.outputs.get("workflow_success") is False: return ( "error", - error_message - or _extract_business_failure_message(result.outputs) - or "Workflow reported business failure.", + error_message or _extract_business_failure_message(result.outputs) or "Workflow reported business failure.", ) return status_value, error_message @@ -510,14 +602,7 @@ async def create_execution_record( input_params=compacted_params, exec_id=exec_id, ) - exec_key = workflow_execution_key(exec_data["id"]) - await Storage.write(exec_key, compact_execution_summary(exec_data)) - await _write_execution_index( - workflow_id=workflow_id, - exec_id=str(exec_data["id"]), - execution_key=exec_key, - started_at=int(exec_data.get("startedAt") or 0), - ) + await WorkflowStore.upsert_execution(compact_execution_summary(exec_data)) return exec_data @@ -533,13 +618,7 @@ async def record_execution_result( if backfilled_steps and (existing_step_count is None or existing_step_count < backfilled_steps): summary_data["stepCount"] = backfilled_steps - await Storage.write(workflow_execution_key(exec_id), compact_execution_summary(summary_data)) - await _write_execution_index( - workflow_id=workflow_id, - exec_id=exec_id, - execution_key=workflow_execution_key(exec_id), - started_at=int(summary_data.get("startedAt") or 0), - ) + await WorkflowStore.upsert_execution(compact_execution_summary(summary_data)) # Update call/success/error counters so all trigger paths (HTTP, syslog, etc.) # are reflected in the UI stats panel. @@ -556,6 +635,7 @@ async def record_execution_result( # Run it as a background task so the syslog/HTTP dispatcher can release the # concurrency slot immediately instead of waiting on session-history I/O. try: + async def _record_audit() -> None: try: await Recorder.record_workflow_execution( @@ -564,10 +644,13 @@ async def _record_audit() -> None: run_result=exec_data, ) except Exception as exc: - log.debug("workflow.audit.record_failed", { - "exec_id": exec_id, - "error": str(exc), - }) + log.debug( + "workflow.audit.record_failed", + { + "exec_id": exec_id, + "error": str(exc), + }, + ) asyncio.create_task(_record_audit(), name=f"audit-{exec_id}") except RuntimeError: @@ -587,51 +670,14 @@ async def _record_audit() -> None: try: await _trim_execution_history(workflow_id) except Exception as exc: - log.error("workflow.history.trim_failed", { - "workflow_id": workflow_id, - "exec_id": exec_id, - "error": str(exc), - }) - - -async def _write_execution_index( - *, - workflow_id: str, - exec_id: str, - execution_key: str, - started_at: int, -) -> str: - index_key = workflow_execution_index_key(workflow_id, started_at, exec_id) - await Storage.write( - index_key, - { - "workflowId": workflow_id, - "execId": exec_id, - "executionKey": execution_key, - "startedAt": started_at, - }, - ) - return index_key - - -async def _list_workflow_execution_entries( - workflow_id: str, -) -> List[Tuple[str, int, Optional[str]]]: - """Return indexed ``(execution_key, started_at, index_key)`` rows for one workflow.""" - indexed_rows = await Storage.list_entries(workflow_execution_index_prefix(workflow_id)) - entries: List[Tuple[str, int, Optional[str]]] = [] - for index_key, value in indexed_rows: - if not isinstance(value, dict): - continue - exec_key = value.get("executionKey") - exec_id = value.get("execId") - if not isinstance(exec_key, str) and isinstance(exec_id, str): - exec_key = workflow_execution_key(exec_id) - if not isinstance(exec_key, str): - continue - started_at = _as_positive_int(value.get("startedAt")) or 0 - entries.append((exec_key, started_at, index_key)) - return entries + log.error( + "workflow.history.trim_failed", + { + "workflow_id": workflow_id, + "exec_id": exec_id, + "error": str(exc), + }, + ) async def _delete_execution_history_record( @@ -640,18 +686,19 @@ async def _delete_execution_history_record( index_key: Optional[str] = None, ) -> None: exec_id = execution_key.rsplit("/", 1)[-1] - deleted_steps = await Storage.clear(workflow_execution_step_prefix(exec_id)) - removed_execution = await Storage.remove(execution_key) - if index_key: - await Storage.remove(index_key) + deleted_steps = await WorkflowStore.clear_steps(exec_id) + removed_execution = await WorkflowStore.delete_execution(exec_id) record_path = Recorder.paths().workflow_dir / f"{exec_id}.jsonl" await asyncio.to_thread(record_path.unlink, missing_ok=True) - log.debug("workflow.history.trim_deleted", { - "exec_id": exec_id, - "execution_key": execution_key, - "steps": deleted_steps, - "removed_execution": removed_execution, - }) + log.debug( + "workflow.history.trim_deleted", + { + "exec_id": exec_id, + "execution_key": execution_key, + "steps": deleted_steps, + "removed_execution": removed_execution, + }, + ) async def _trim_execution_history(workflow_id: str) -> None: @@ -668,27 +715,24 @@ async def _trim_execution_history(workflow_id: str) -> None: """ lock = _get_trim_lock(workflow_id) async with lock: - wf_entries = await _list_workflow_execution_entries(workflow_id) - if len(wf_entries) <= _MAX_EXECUTION_HISTORY_PER_WORKFLOW: - return - - # Sort ascending by startedAt and remove the oldest excess records. - wf_entries.sort(key=lambda kd: kd[1]) - excess = len(wf_entries) - _MAX_EXECUTION_HISTORY_PER_WORKFLOW failures: List[str] = [] - for key, _started_at, index_key in wf_entries[:excess]: + for exec_id in await WorkflowStore.trim_executions( + workflow_id, + keep=_MAX_EXECUTION_HISTORY_PER_WORKFLOW, + ): try: - await _delete_execution_history_record(key, index_key=index_key) + record_path = Recorder.paths().workflow_dir / f"{exec_id}.jsonl" + await asyncio.to_thread(record_path.unlink, missing_ok=True) except Exception as exc: - failures.append(f"{key}: {exc}") - log.warning("workflow.history.trim_delete_failed", { - "workflow_id": workflow_id, - "key": key, - "error": str(exc), - }) + failures.append(f"{exec_id}: {exc}") + log.warning( + "workflow.history.trim_delete_failed", + { + "workflow_id": workflow_id, + "exec_id": exec_id, + "error": str(exc), + }, + ) if failures: - raise RuntimeError( - "Failed to trim workflow execution history: " - + "; ".join(failures[:3]) - ) + raise RuntimeError("Failed to trim workflow execution history: " + "; ".join(failures[:3])) diff --git a/flocks/workflow/models.py b/flocks/workflow/models.py index e20f296e3..57f3761a4 100644 --- a/flocks/workflow/models.py +++ b/flocks/workflow/models.py @@ -10,6 +10,8 @@ class Node(BaseModel): + model_config = ConfigDict(populate_by_name=True) + id: str = Field(min_length=1) type: Literal[ "python", "logic", "branch", "loop", @@ -22,6 +24,8 @@ class Node(BaseModel): join_mode: Literal["flat", "namespace"] = "flat" join_conflict: Literal["overwrite", "error"] = "overwrite" join_namespace_key: str = "__by_source__" + input_schema: Optional[Dict[str, Any]] = Field(None, alias="inputSchema") + output_schema: Optional[Dict[str, Any]] = Field(None, alias="outputSchema") # tool 节点 tool_name: Optional[str] = None diff --git a/flocks/workflow/poller_manager.py b/flocks/workflow/poller_manager.py index 923baf500..c4596286c 100644 --- a/flocks/workflow/poller_manager.py +++ b/flocks/workflow/poller_manager.py @@ -15,7 +15,6 @@ from croniter import croniter -from flocks.storage.storage import Storage from flocks.utils.log import Log from flocks.workflow.execution_store import ( compact_outputs_for_storage, @@ -24,8 +23,11 @@ record_execution_result, resolve_execution_outcome, ) +from flocks.workflow.execution_plan import build_workflow_execution_plan from flocks.workflow.fs_store import read_workflow_from_fs +from flocks.workflow.models import Workflow from flocks.workflow.runner import RunWorkflowResult, run_workflow +from flocks.workflow.store import WorkflowStore WORKFLOW_POLLER_CONFIG_PREFIX = "workflow_poller_config/" DEFAULT_INTERVAL_SECONDS = 30 @@ -180,22 +182,14 @@ def get_status(self, workflow_id: str) -> Dict[str, Any]: async def start_all(self) -> None: try: - keys = await Storage.list_keys(WORKFLOW_POLLER_CONFIG_PREFIX) + configs = await WorkflowStore.list_configs(kind="workflow_poller_config") except Exception as exc: - log.warning("poller.list_keys_failed", {"error": str(exc)}) + log.warning("poller.list_configs_failed", {"error": str(exc)}) return - for key in keys: - if not key.startswith(WORKFLOW_POLLER_CONFIG_PREFIX): - continue - workflow_id = key[len(WORKFLOW_POLLER_CONFIG_PREFIX):] + for workflow_id, data in configs: if not workflow_id: continue - try: - data = await Storage.read(key) - except Exception as exc: - log.warning("poller.config_read_failed", {"key": key, "error": str(exc)}) - continue if isinstance(data, dict) and data.get("enabled"): await self.restart_workflow(workflow_id) @@ -238,7 +232,7 @@ async def stop_workflow(self, workflow_id: str) -> None: async def restart_workflow(self, workflow_id: str) -> Dict[str, Any]: await self.stop_workflow(workflow_id) try: - stored = await Storage.read(self._config_key(workflow_id)) + stored = await WorkflowStore.get_config(workflow_id, kind="workflow_poller_config") except Exception as exc: log.warning("poller.restart_read_failed", {"workflow_id": workflow_id, "error": str(exc)}) return {"workflowId": workflow_id, "state": "failed", "error": str(exc)} @@ -275,6 +269,19 @@ async def restart_workflow(self, workflow_id: str) -> Dict[str, Any]: } return self.get_status(workflow_id) + try: + workflow_plan = build_workflow_execution_plan(Workflow.from_dict(workflow_json)) + except Exception as exc: + err = f"workflow_plan_failed: {exc}" + self._status[workflow_id] = { + **self.get_status(workflow_id), + "workflowId": workflow_id, + "state": "failed", + "error": err, + } + log.warning("poller.workflow_plan_failed", {"workflow_id": workflow_id, "error": str(exc)}) + return self.get_status(workflow_id) + abort_event = asyncio.Event() self._abort_events[workflow_id] = abort_event self._status[workflow_id] = { @@ -290,7 +297,7 @@ async def restart_workflow(self, workflow_id: str) -> Dict[str, Any]: "nextRunAt": self._compute_next_run_at_ms(config), } task = asyncio.create_task( - self._poller_loop(workflow_id, workflow_json, config, abort_event), + self._poller_loop(workflow_id, workflow_plan, config, abort_event), name=f"workflow-poller-{workflow_id}", ) self._tasks[workflow_id] = task @@ -298,7 +305,7 @@ async def restart_workflow(self, workflow_id: str) -> Dict[str, Any]: async def run_once(self, workflow_id: str) -> Dict[str, Any]: try: - stored = await Storage.read(self._config_key(workflow_id)) + stored = await WorkflowStore.get_config(workflow_id, kind="workflow_poller_config") except Exception as exc: log.warning("poller.run_once_read_failed", {"workflow_id": workflow_id, "error": str(exc)}) current = self.get_status(workflow_id) @@ -330,7 +337,7 @@ async def run_once(self, workflow_id: str) -> Dict[str, Any]: async def _poller_loop( self, workflow_id: str, - workflow_json: Dict[str, Any], + workflow_json: Any, config: Dict[str, Any], abort_event: asyncio.Event, ) -> None: @@ -383,7 +390,7 @@ async def _poller_loop( async def _schedule_run( self, workflow_id: str, - workflow_json: Dict[str, Any], + workflow_json: Any, config: Dict[str, Any], ) -> None: active_runs = self._cleanup_done_runs(workflow_id) @@ -432,8 +439,10 @@ async def _execute_run( run_workflow, workflow=workflow_json, inputs=inputs, + run_id=exec_id, timeout_s=config["timeoutSeconds"], trace=False, + execution_profile="high_frequency", cancel=cancel_event.is_set, on_step_complete=step_recorder.on_step_complete, ) @@ -448,23 +457,25 @@ async def _execute_run( summary = self._summarize_outputs(result.outputs) step_count = step_recorder.step_count or result.steps exec_data.update(step_recorder.summary) - exec_data.update({ - "outputResults": compact_outputs_for_storage(result.outputs), - "status": status_value, - "finishedAt": _now_ms(), - "duration": duration_s, - "executionLog": [], - "errorMessage": error_message, - "stepCount": step_count, - "currentNodeId": result.last_node_id, - "currentPhase": status_value, - "currentStepIndex": step_count, - "triggerId": "schedule-default", - "triggerType": "schedule", - "deliveryId": inputs.get("_flocks", {}).get("trigger", {}).get("deliveryId"), - "attempt": 1, - "triggerSource": "poller", - }) + exec_data.update( + { + "outputResults": compact_outputs_for_storage(result.outputs), + "status": status_value, + "finishedAt": _now_ms(), + "duration": duration_s, + "executionLog": [], + "errorMessage": error_message, + "stepCount": step_count, + "currentNodeId": result.last_node_id, + "currentPhase": status_value, + "currentStepIndex": step_count, + "triggerId": "schedule-default", + "triggerType": "schedule", + "deliveryId": inputs.get("_flocks", {}).get("trigger", {}).get("deliveryId"), + "attempt": 1, + "triggerSource": "poller", + } + ) current = self._status.get(workflow_id) or self._base_status(workflow_id) current.update(summary) current["lastRunAt"] = started_at_ms @@ -483,19 +494,21 @@ async def _execute_run( status_value = "cancelled" if cancel_event.is_set() else "error" finished_at_ms = _now_ms() exec_data.update(step_recorder.summary) - exec_data.update({ - "status": status_value, - "finishedAt": finished_at_ms, - "duration": duration_s, - "errorMessage": str(exc), - "executionLog": [], - "currentPhase": status_value, - "triggerId": "schedule-default", - "triggerType": "schedule", - "deliveryId": inputs.get("_flocks", {}).get("trigger", {}).get("deliveryId"), - "attempt": 1, - "triggerSource": "poller", - }) + exec_data.update( + { + "status": status_value, + "finishedAt": finished_at_ms, + "duration": duration_s, + "errorMessage": str(exc), + "executionLog": [], + "currentPhase": status_value, + "triggerId": "schedule-default", + "triggerType": "schedule", + "deliveryId": inputs.get("_flocks", {}).get("trigger", {}).get("deliveryId"), + "attempt": 1, + "triggerSource": "poller", + } + ) current = self._status.get(workflow_id) or self._base_status(workflow_id) current["lastRunAt"] = started_at_ms current["lastDurationMs"] = duration_ms diff --git a/flocks/workflow/repl_runtime.py b/flocks/workflow/repl_runtime.py index 2f010fb75..ad78f0b59 100644 --- a/flocks/workflow/repl_runtime.py +++ b/flocks/workflow/repl_runtime.py @@ -54,8 +54,6 @@ class PythonExecRuntime(Runtime): _RUNTIME_GLOBAL_KEYS: ClassVar[frozenset[str]] = frozenset( { "__builtins__", - "inputs", - "outputs", "cancelled", "is_cancelled", "llm", @@ -141,73 +139,80 @@ def _cancel_trace(_frame: Any, event: str, _arg: Any) -> Any: return _cancel_trace try: - if self.cancel_checker is not None: - previous_trace = sys.gettrace() - sys.settrace(_cancel_trace) - with contextlib.redirect_stdout(buf): - exec(code, g, g) - except SystemExit: - # Node code called exit() / sys.exit() — treat as early return with - # whatever has been written to outputs so far. Do NOT propagate - # SystemExit; that would kill the asyncio event loop. - pass - except RunCancelledError: - raise - except _FuturesTimeoutError: - raise - except SyntaxError as e: - tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__)) - raise NodeExecutionError( - node_id="", - message=f"Syntax error in code at line {e.lineno}: {e.msg}", - stdout=buf.getvalue(), - traceback=tb_str - ) from e - except AttributeError as e: - tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__)) - error_msg = f"AttributeError: {e}" - if "'NoneType' object has no attribute" in str(e): - attr_name = str(e).split("'")[-2] if "'" in str(e) else "unknown" - error_msg += f"\n提示: 对象为 None,无法访问属性 '{attr_name}'。" - raise NodeExecutionError(node_id="", message=error_msg, stdout=buf.getvalue(), traceback=tb_str) from e - except KeyError as e: - tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__)) - raise NodeExecutionError( - node_id="", - message=f"Missing required input key: {e}", - stdout=buf.getvalue(), - traceback=tb_str - ) from e - except NameError as e: - tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__)) - raise NodeExecutionError( - node_id="", - message=f"Undefined variable or function: {e}", - stdout=buf.getvalue(), - traceback=tb_str - ) from e - except TypeError as e: - tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__)) - error_msg = f"Type error during execution: {e}" - raise NodeExecutionError(node_id="", message=error_msg, stdout=buf.getvalue(), traceback=tb_str) from e - except Exception as e: - tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__)) - error_msg = f"Runtime error ({type(e).__name__}): {e}" - raise NodeExecutionError(node_id="", message=error_msg, stdout=buf.getvalue(), traceback=tb_str) from e + try: + if self.cancel_checker is not None: + previous_trace = sys.gettrace() + sys.settrace(_cancel_trace) + with contextlib.redirect_stdout(buf): + exec(code, g, g) + except SystemExit: + # Node code called exit() / sys.exit() — treat as early return with + # whatever has been written to outputs so far. Do NOT propagate + # SystemExit; that would kill the asyncio event loop. + pass + except RunCancelledError: + raise + except _FuturesTimeoutError: + raise + except SyntaxError as e: + tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__)) + raise NodeExecutionError( + node_id="", + message=f"Syntax error in code at line {e.lineno}: {e.msg}", + stdout=buf.getvalue(), + traceback=tb_str, + ) from e + except AttributeError as e: + tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__)) + error_msg = f"AttributeError: {e}" + if "'NoneType' object has no attribute" in str(e): + attr_name = str(e).split("'")[-2] if "'" in str(e) else "unknown" + error_msg += f"\n提示: 对象为 None,无法访问属性 '{attr_name}'。" + raise NodeExecutionError( + node_id="", message=error_msg, stdout=buf.getvalue(), traceback=tb_str + ) from e + except KeyError as e: + tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__)) + raise NodeExecutionError( + node_id="", + message=f"Missing required input key: {e}", + stdout=buf.getvalue(), + traceback=tb_str, + ) from e + except NameError as e: + tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__)) + raise NodeExecutionError( + node_id="", + message=f"Undefined variable or function: {e}", + stdout=buf.getvalue(), + traceback=tb_str, + ) from e + except TypeError as e: + tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__)) + error_msg = f"Type error during execution: {e}" + raise NodeExecutionError( + node_id="", message=error_msg, stdout=buf.getvalue(), traceback=tb_str + ) from e + except Exception as e: + tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__)) + error_msg = f"Runtime error ({type(e).__name__}): {e}" + raise NodeExecutionError( + node_id="", message=error_msg, stdout=buf.getvalue(), traceback=tb_str + ) from e + + out_obj = g.get("outputs", {}) + if out_obj is None: + out_obj = {} + if not isinstance(out_obj, dict): + raise NodeExecutionError(node_id="", message="`outputs` must be a dict") + return dict(out_obj), buf.getvalue() finally: if self.cancel_checker is not None: sys.settrace(previous_trace) - - out_obj = g.get("outputs", {}) - if out_obj is None: - out_obj = {} - if not isinstance(out_obj, dict): - raise NodeExecutionError(node_id="", message="`outputs` must be a dict") - if self.cleanup_globals_after_execute: - for key in list(g.keys()): - if key not in self._RUNTIME_GLOBAL_KEYS: - g.pop(key, None) - return out_obj, buf.getvalue() + if self.cleanup_globals_after_execute: + for key in list(g.keys()): + if key not in self._RUNTIME_GLOBAL_KEYS: + g.pop(key, None) def reset(self) -> None: self.globals.clear() @@ -646,7 +651,7 @@ def execute(self, code: str, inputs: Dict[str, Any]) -> Tuple[Dict[str, Any], st outputs: Dict[str, Any] = {} for line in stdout.splitlines(): if line.startswith(marker): - payload = line[len(marker):] + payload = line[len(marker) :] try: obj = json.loads(payload) if payload else {} except Exception: diff --git a/flocks/workflow/runner.py b/flocks/workflow/runner.py index 51f27a48d..f8c82e6e7 100644 --- a/flocks/workflow/runner.py +++ b/flocks/workflow/runner.py @@ -14,19 +14,20 @@ from flocks.config.config import Config from flocks.sandbox.context import resolve_sandbox_context -from .errors import FlocksWorkflowError, RunCancelledError, RunTimeoutError +from .errors import FlocksWorkflowError, RunCancelledError, RunTimeoutError, WorkflowValidationError from .io import dump_workflow, load_workflow from .compiler import default_exec_path, compile_workflow, workflow_has_logic_nodes +from .execution_plan import ( + WorkflowExecutionPlan, + build_workflow_execution_plan, + resolve_workflow_dataflow_mode, +) from .models import Workflow from .engine import WorkflowEngine from .repl_runtime import PythonExecRuntime, SandboxPythonExecRuntime from .tools import get_tool_registry -from .requirements import ( - RequirementsInstaller, - SandboxRequirementsInstaller, - requirements_from_workflow_metadata, -) -from .workflow_lint import lint_workflow +from .requirements import RequirementsInstaller, SandboxRequirementsInstaller +from .workflow_lint import is_schema_lint_error from .logging_config import setup_workflow_logging @@ -242,7 +243,8 @@ def _ensure_logging_configured() -> None: setup_workflow_logging() -WorkflowSource = Union[Dict[str, Any], str, Path, Workflow] +WorkflowSource = Union[Dict[str, Any], str, Path, Workflow, WorkflowExecutionPlan] +ExecutionProfile = Literal["default", "high_frequency"] @dataclass @@ -280,6 +282,35 @@ def _build_initial_inputs( return initial_inputs +def _apply_execution_profile( + profile: Optional[ExecutionProfile], + *, + node_timeout_s: Optional[float], + trace: bool, + history_mode: Literal["full", "summary"], + retain_history: bool, +) -> tuple[Optional[float], bool, Literal["full", "summary"], bool]: + """Apply a coarse runtime profile before execution.""" + if profile != "high_frequency": + return node_timeout_s, trace, history_mode, retain_history + return node_timeout_s, False, "summary", False + + +def _validate_lint_results(lint_results: tuple[Dict[str, Any], ...]) -> None: + lint_errors = [r for r in lint_results if r.get("severity") == "error"] + lint_warnings = [r for r in lint_results if r.get("severity") != "error"] + if lint_errors: + _logger.error(f"workflow lint 检查发现 {len(lint_errors)} 个错误: {lint_errors[:5]}") + strict_mapping_errors = [item for item in lint_errors if item.get("kind") == "implicit_full_payload_edge"] + if strict_mapping_errors: + raise WorkflowValidationError(f"Workflow strict edge mapping failed: {strict_mapping_errors[:5]}") + schema_errors = [item for item in lint_errors if is_schema_lint_error(item)] + if schema_errors: + raise WorkflowValidationError(f"Workflow schema lint failed: {schema_errors[:5]}") + if lint_warnings: + _logger.warning(f"workflow lint 检查发现 {len(lint_warnings)} 个警告: {lint_warnings[:5]}") + + def run_workflow( *, workflow: WorkflowSource, @@ -293,21 +324,39 @@ def run_workflow( ensure_requirements: bool = True, requirements_installer: Optional[RequirementsInstaller] = None, sandbox_requirements_installer: Optional[SandboxRequirementsInstaller] = None, + run_id: Optional[str] = None, on_step_start: Optional[Any] = None, on_step_complete: Optional[Any] = None, max_parallel_workers: int = 4, history_mode: Literal["full", "summary"] = "summary", cancel: Optional[Callable[[], bool]] = None, retain_history: bool = False, + execution_profile: Optional[ExecutionProfile] = None, ) -> RunWorkflowResult: # 确保日志已配置 _ensure_logging_configured() - + _logger.debug("=== 开始执行 workflow ===") - + + node_timeout_s, trace, history_mode, retain_history = _apply_execution_profile( + execution_profile, + node_timeout_s=node_timeout_s, + trace=trace, + history_mode=history_mode, + retain_history=retain_history, + ) + workflow_path_for_engine: Optional[str] = None effective_use_llm: Optional[bool] = use_llm - if isinstance(workflow, Workflow): + plan: Optional[WorkflowExecutionPlan] = None + if isinstance(workflow, WorkflowExecutionPlan): + _logger.debug("workflow 来源: execution plan") + plan = workflow + wf = plan.workflow + workflow_path_for_engine = plan.workflow_path + if effective_use_llm is None: + effective_use_llm = plan.use_llm + elif isinstance(workflow, Workflow): _logger.debug("workflow 来源: Workflow 对象") wf = workflow elif isinstance(workflow, (str, Path)): @@ -365,16 +414,19 @@ def run_workflow( if effective_use_llm is None: effective_use_llm = workflow_has_logic_nodes(wf) + if plan is None: + plan = build_workflow_execution_plan( + wf, + workflow_path=workflow_path_for_engine, + use_llm=effective_use_llm, + ) + _logger.debug("workflow 信息: nodes=%s, edges=%s, start=%s", len(wf.nodes), len(wf.edges), wf.start) - + try: - lint_results = lint_workflow(wf) - lint_errors = [r for r in lint_results if r.get("severity") == "error"] - lint_warnings = [r for r in lint_results if r.get("severity") != "error"] - if lint_errors: - _logger.error(f"workflow lint 检查发现 {len(lint_errors)} 个错误: {lint_errors[:5]}") - if lint_warnings: - _logger.warning(f"workflow lint 检查发现 {len(lint_warnings)} 个警告: {lint_warnings[:5]}") + _validate_lint_results(plan.lint_results) + except WorkflowValidationError: + raise except Exception: pass @@ -382,8 +434,9 @@ def run_workflow( node_timeout_s, wf.metadata, ) + dataflow_mode = plan.dataflow_mode - reqs = requirements_from_workflow_metadata(wf.metadata) + reqs = list(plan.requirements) if ensure_requirements: _logger.debug("检查依赖包...") if reqs: @@ -420,14 +473,15 @@ def run_workflow( tool_registry=registry, cleanup_globals_after_execute=(history_mode == "summary"), ) - + _logger.debug( - "创建执行引擎 (use_llm=%s, trace=%s, node_timeout=%ss, parallel_workers=%s, history_mode=%s)", + "创建执行引擎 (use_llm=%s, trace=%s, node_timeout=%ss, parallel_workers=%s, history_mode=%s, dataflow_mode=%s)", effective_use_llm, trace, effective_node_timeout_s, max_parallel_workers, history_mode, + dataflow_mode, ) engine = WorkflowEngine( wf, @@ -438,8 +492,10 @@ def run_workflow( node_timeout_s=effective_node_timeout_s, max_parallel_workers=max_parallel_workers, history_mode=history_mode, + dataflow_mode=dataflow_mode, + execution_plan=plan, ) - + initial_inputs = _build_initial_inputs(inputs, workflow_path_for_engine) _logger.debug( "开始执行 workflow (timeout=%ss, inputs=%s)", @@ -450,18 +506,23 @@ def run_workflow( _on_step_start = None _on_step_end = None if on_step_start is not None: + def _on_step_start(_rid, _step, _node, _inp): return on_step_start(_rid, _step, _node, _inp) elif on_step_complete is not None: + def _on_step_start(_rid, _step, _node, _inp): return True + if on_step_complete is not None: + def _on_step_end(_token, step_result): return on_step_complete(step_result) try: result = engine.run( initial_inputs=initial_inputs, + run_id=run_id, timeout_s=timeout_s, cancel=cancel, on_step_start=_on_step_start, @@ -470,17 +531,17 @@ def _on_step_end(_token, step_result): ) except FlocksWorkflowError as e: # Extract execution context from error if available - exec_ctx = getattr(e, 'execution_context', {}) - history_from_error = exec_ctx.get('history', []) - + exec_ctx = getattr(e, "execution_context", {}) + history_from_error = exec_ctx.get("history", []) + # Convert StepResult objects to dicts if needed - if history_from_error and hasattr(history_from_error[0], 'model_dump'): + if history_from_error and hasattr(history_from_error[0], "model_dump"): history_from_error = [s.model_dump(mode="json") for s in history_from_error] - - last_outputs = exec_ctx.get('outputs') or ( - history_from_error[-1].get('outputs', {}) if history_from_error else {} + + last_outputs = exec_ctx.get("outputs") or ( + history_from_error[-1].get("outputs", {}) if history_from_error else {} ) - + status = "FAILED" if isinstance(e, RunCancelledError): status = "CANCELLED" @@ -490,9 +551,9 @@ def _on_step_end(_token, step_result): return RunWorkflowResult( status=status, error=f"{type(e).__name__}: {e}", - run_id=exec_ctx.get('run_id'), - steps=exec_ctx.get('steps', 0), - last_node_id=exec_ctx.get('last_node_id'), + run_id=exec_ctx.get("run_id"), + steps=exec_ctx.get("steps", 0), + last_node_id=exec_ctx.get("last_node_id"), outputs=last_outputs, history=history_from_error, ) @@ -510,9 +571,11 @@ def _on_step_end(_token, step_result): history=history, error=f"RunCancelledError: Run cancelled: run_id={result.run_id}", ) - - _logger.info(f"=== workflow 执行成功 === run_id={result.run_id}, steps={result.steps}, last_node={result.last_node_id}") - + + _logger.info( + f"=== workflow 执行成功 === run_id={result.run_id}, steps={result.steps}, last_node={result.last_node_id}" + ) + return RunWorkflowResult( status="SUCCEEDED", run_id=result.run_id, diff --git a/flocks/workflow/store.py b/flocks/workflow/store.py new file mode 100644 index 000000000..8212383ad --- /dev/null +++ b/flocks/workflow/store.py @@ -0,0 +1,766 @@ +"""SQLite persistence for workflow runtime data.""" + +from __future__ import annotations + +import asyncio +import json +import os +import sqlite3 +from datetime import UTC, datetime +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import aiosqlite + +from flocks.storage.storage import Storage +from flocks.utils.log import Log + +log = Log.create(service="workflow.store") + +_MIGRATION_MARKER_KEY = "workflow_store.migration.tables.v1" +_JSON_TYPE = "json" +_WORKFLOW_KV_PREFIXES = ( + "workflow_registry/", + "workflow_release/", + "workflow_runtime/", + "workflow_local_pid/", + "workflow_api_service/", +) +_WORKFLOW_TABLE_PREFIXES = ( + "workflow_execution/", + "workflow_execution_index/", + "workflow_execution_step/", + "workflow/", + "workflow_integration_config/", + "workflow_kafka_config/", + "workflow_poller_config/", + "workflow_syslog_config/", +) +_WORKFLOW_PREFIXES = _WORKFLOW_KV_PREFIXES + _WORKFLOW_TABLE_PREFIXES + + +class WorkflowStore: + """Workflow-domain store backed by ``workflow.db`` tables.""" + + _initialized = False + _conn: Optional[aiosqlite.Connection] = None + _init_pid: Optional[int] = None + _db_path: Optional[Path] = None + + @classmethod + def get_db_path(cls) -> Path: + return Storage.get_workflow_db_path() + + @classmethod + async def init(cls) -> None: + current_pid = os.getpid() + db_path = cls.get_db_path() + if cls._initialized and cls._init_pid == current_pid and cls._db_path == db_path: + return + if cls._initialized and ( + (cls._init_pid is not None and cls._init_pid != current_pid) + or (cls._db_path is not None and cls._db_path != db_path) + ): + log.warn( + "workflow.store.fork_detected", + { + "parent_pid": cls._init_pid, + "child_pid": current_pid, + "old_db_path": str(cls._db_path) if cls._db_path else None, + "new_db_path": str(db_path), + }, + ) + if cls._conn: + await cls._conn.close() + cls._conn = None + cls._initialized = False + cls._init_pid = None + + await Storage._ensure_init() + db_path.parent.mkdir(parents=True, exist_ok=True) + try: + cls._conn = await aiosqlite.connect( + db_path, + timeout=Storage._sqlite_timeout_s, + ) + cls._conn.row_factory = aiosqlite.Row + await Storage.configure_connection(cls._conn) + await cls._conn.executescript(_WORKFLOW_DDL) + for stmt in _INDEX_STMTS: + await cls._conn.execute(stmt) + await cls._conn.commit() + cls._initialized = True + cls._init_pid = current_pid + cls._db_path = db_path + await cls._migrate_legacy_kv() + log.info("workflow.store.initialized") + except Exception: + if cls._conn: + await cls._conn.close() + cls._conn = None + cls._initialized = False + cls._init_pid = None + cls._db_path = None + raise + + @classmethod + async def close(cls) -> None: + if cls._conn: + await cls._conn.close() + cls._conn = None + cls._initialized = False + cls._init_pid = None + cls._db_path = None + + @classmethod + async def _db(cls) -> aiosqlite.Connection: + if cls._initialized and cls._init_pid is not None and cls._init_pid != os.getpid(): + await cls.init() + if not cls._conn or not cls._initialized: + await cls.init() + return cls._conn # type: ignore[return-value] + + @classmethod + async def raw_db(cls) -> aiosqlite.Connection: + return await cls._db() + + @staticmethod + def _json_dumps(value: Any) -> str: + return json.dumps(value, ensure_ascii=False, default=str) + + @staticmethod + def _json_loads(value: Optional[str], default: Any = None) -> Any: + if value is None: + return default + try: + return json.loads(value) + except Exception: + return default + + @staticmethod + def _now_iso() -> str: + return datetime.now(UTC).isoformat() + + @staticmethod + def _now_ms() -> int: + return int(datetime.now(UTC).timestamp() * 1000) + + @staticmethod + def _as_int(value: Any) -> Optional[int]: + if isinstance(value, bool): + return None + try: + return int(value) + except (TypeError, ValueError): + return None + + @staticmethod + def _as_float(value: Any) -> Optional[float]: + if isinstance(value, bool): + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + @staticmethod + def _quote_identifier(identifier: str) -> str: + return '"' + identifier.replace('"', '""') + '"' + + @classmethod + def _legacy_table_exists(cls, conn: sqlite3.Connection) -> bool: + row = conn.execute("SELECT 1 FROM sqlite_master WHERE type='table' AND name = 'storage'").fetchone() + return row is not None + + @classmethod + def _legacy_rows_from_db(cls, db_path: Path) -> list[sqlite3.Row]: + if not db_path.exists(): + return [] + conn = Storage.connect_sync(db_path) + try: + if not cls._legacy_table_exists(conn): + return [] + clauses = " OR ".join("key LIKE ?" for _ in _WORKFLOW_PREFIXES) + params = tuple(f"{prefix}%" for prefix in _WORKFLOW_PREFIXES) + return conn.execute( + f""" + SELECT key, value, type, created_at, updated_at + FROM storage + WHERE {clauses} + ORDER BY key + """, + params, + ).fetchall() + finally: + conn.close() + + @classmethod + async def _migrate_legacy_kv(cls) -> None: + if await cls.kv_get(_MIGRATION_MARKER_KEY) is not None: + return + rows_by_key: dict[str, sqlite3.Row] = {} + for db_path in (Storage.get_db_path(), cls.get_db_path()): + for row in await asyncio.to_thread(cls._legacy_rows_from_db, db_path): + rows_by_key[str(row["key"])] = row + + counts = { + "executions": 0, + "steps": 0, + "stats": 0, + "configs": 0, + "kv": 0, + "skipped": 0, + } + for key, row in rows_by_key.items(): + value = cls._json_loads(str(row["value"]), None) + if value is None: + counts["skipped"] += 1 + continue + if key.startswith("workflow_execution_step/") and isinstance(value, dict): + parts = key.split("/") + if len(parts) >= 3: + try: + await cls.record_step(parts[1], int(parts[2]), value) + counts["steps"] += 1 + except Exception: + counts["skipped"] += 1 + continue + if key.startswith("workflow_execution/") and isinstance(value, dict): + await cls.upsert_execution(value) + counts["executions"] += 1 + continue + if key.startswith("workflow/") and key.endswith("/stats") and isinstance(value, dict): + workflow_id = key[len("workflow/") : -len("/stats")] + await cls.put_stats(workflow_id, value) + counts["stats"] += 1 + continue + if key.startswith("workflow_integration_config/") and isinstance(value, dict): + workflow_id = key[len("workflow_integration_config/") :] + await cls.put_config(workflow_id, value) + counts["configs"] += 1 + continue + if key.startswith(_WORKFLOW_KV_PREFIXES): + await cls.kv_put(key, value) + counts["kv"] += 1 + continue + if key.startswith( + ( + "workflow_kafka_config/", + "workflow_poller_config/", + "workflow_syslog_config/", + ) + ) and isinstance(value, dict): + workflow_id = key.rsplit("/", 1)[-1] + await cls.put_config(workflow_id, value, kind=key.split("/", 1)[0]) + counts["configs"] += 1 + + await cls.kv_put( + _MIGRATION_MARKER_KEY, + { + "version": 1, + "migrated_at": cls._now_iso(), + "source_db": str(Storage.get_db_path()), + "workflow_db": str(cls.get_db_path()), + **counts, + }, + ) + log.info("workflow.store.legacy_kv_migrated", counts) + + @classmethod + async def upsert_execution(cls, exec_data: Dict[str, Any]) -> None: + db = await cls._db() + payload = dict(exec_data) + exec_id = str(payload.get("id") or "") + workflow_id = str(payload.get("workflowId") or payload.get("workflow_id") or "") + if not exec_id or not workflow_id: + raise ValueError("workflow execution requires id and workflowId") + await db.execute( + """ + INSERT OR REPLACE INTO workflow_executions + (id, workflow_id, status, current_phase, current_node_id, current_node_type, + current_step_index, step_count, input_params, output_results, error_message, + trigger_id, trigger_type, started_at, finished_at, duration, updated_at, payload) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + exec_id, + workflow_id, + str(payload.get("status") or "running"), + payload.get("currentPhase"), + payload.get("currentNodeId"), + payload.get("currentNodeType"), + cls._as_int(payload.get("currentStepIndex")), + cls._as_int(payload.get("stepCount")) or 0, + cls._json_dumps(payload.get("inputParams") or {}), + cls._json_dumps(payload.get("outputResults") or {}), + payload.get("errorMessage"), + payload.get("triggerId"), + payload.get("triggerType"), + cls._as_int(payload.get("startedAt")) or cls._now_ms(), + cls._as_int(payload.get("finishedAt")), + cls._as_float(payload.get("duration")), + cls._as_int(payload.get("updatedAt")) or cls._now_ms(), + cls._json_dumps(payload), + ), + ) + await db.commit() + + @classmethod + async def get_execution(cls, exec_id: str) -> Optional[Dict[str, Any]]: + db = await cls._db() + async with db.execute( + "SELECT payload FROM workflow_executions WHERE id = ?", + (exec_id,), + ) as cur: + row = await cur.fetchone() + if not row: + return None + value = cls._json_loads(row["payload"], None) + return value if isinstance(value, dict) else None + + @classmethod + async def list_executions( + cls, + workflow_id: str, + *, + limit: int = 50, + trigger_id: Optional[str] = None, + trigger_type: Optional[str] = None, + ) -> List[Dict[str, Any]]: + db = await cls._db() + clauses = ["workflow_id = ?"] + params: list[Any] = [workflow_id] + if trigger_id: + clauses.append("trigger_id = ?") + params.append(trigger_id) + if trigger_type: + clauses.append("trigger_type = ?") + params.append(trigger_type) + params.append(max(int(limit), 0)) + async with db.execute( + f""" + SELECT payload FROM workflow_executions + WHERE {" AND ".join(clauses)} + ORDER BY started_at DESC + LIMIT ? + """, + tuple(params), + ) as cur: + rows = await cur.fetchall() + items: List[Dict[str, Any]] = [] + for row in rows: + value = cls._json_loads(row["payload"], None) + if isinstance(value, dict): + items.append(value) + return items + + @classmethod + async def delete_execution(cls, exec_id: str) -> bool: + db = await cls._db() + await db.execute("DELETE FROM workflow_execution_steps WHERE exec_id = ?", (exec_id,)) + cur = await db.execute("DELETE FROM workflow_executions WHERE id = ?", (exec_id,)) + await db.commit() + return cur.rowcount > 0 + + @classmethod + async def delete_executions_for_workflow(cls, workflow_id: str) -> int: + db = await cls._db() + async with db.execute( + "SELECT id FROM workflow_executions WHERE workflow_id = ?", + (workflow_id,), + ) as cur: + exec_ids = [str(row["id"]) for row in await cur.fetchall()] + for exec_id in exec_ids: + await db.execute("DELETE FROM workflow_execution_steps WHERE exec_id = ?", (exec_id,)) + cur = await db.execute("DELETE FROM workflow_executions WHERE workflow_id = ?", (workflow_id,)) + await db.commit() + return cur.rowcount + + @classmethod + async def trim_executions(cls, workflow_id: str, *, keep: int) -> List[str]: + db = await cls._db() + async with db.execute( + """ + SELECT id FROM workflow_executions + WHERE workflow_id = ? + ORDER BY started_at DESC + LIMIT -1 OFFSET ? + """, + (workflow_id, max(int(keep), 0)), + ) as cur: + exec_ids = [str(row["id"]) for row in await cur.fetchall()] + for exec_id in exec_ids: + await db.execute("DELETE FROM workflow_execution_steps WHERE exec_id = ?", (exec_id,)) + await db.execute("DELETE FROM workflow_executions WHERE id = ?", (exec_id,)) + await db.commit() + return exec_ids + + @classmethod + async def record_step( + cls, + exec_id: str, + step_index: int, + step_payload: Dict[str, Any], + ) -> None: + db = await cls._db() + await db.execute( + """ + INSERT OR REPLACE INTO workflow_execution_steps + (exec_id, step_index, node_id, node_type, inputs, outputs, error, payload) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + exec_id, + int(step_index), + step_payload.get("node_id"), + step_payload.get("node_type") or step_payload.get("type"), + cls._json_dumps(step_payload.get("inputs") or {}), + cls._json_dumps(step_payload.get("outputs") or {}), + step_payload.get("error"), + cls._json_dumps(step_payload), + ), + ) + await db.commit() + + @classmethod + async def list_steps( + cls, + exec_id: str, + *, + offset: int = 0, + limit: int = 500, + ) -> Tuple[List[Dict[str, Any]], int]: + db = await cls._db() + safe_offset = max(int(offset), 0) + safe_limit = max(int(limit), 0) + async with db.execute( + "SELECT COUNT(*) AS total FROM workflow_execution_steps WHERE exec_id = ?", + (exec_id,), + ) as cur: + row = await cur.fetchone() + total = int(row["total"]) if row else 0 + if safe_limit == 0: + return [], total + async with db.execute( + """ + SELECT payload FROM workflow_execution_steps + WHERE exec_id = ? + ORDER BY step_index + LIMIT ? OFFSET ? + """, + (exec_id, safe_limit, safe_offset), + ) as cur: + rows = await cur.fetchall() + steps: List[Dict[str, Any]] = [] + for row in rows: + value = cls._json_loads(row["payload"], None) + if isinstance(value, dict): + steps.append(value) + return steps, total + + @classmethod + async def clear_steps(cls, exec_id: str) -> int: + db = await cls._db() + cur = await db.execute("DELETE FROM workflow_execution_steps WHERE exec_id = ?", (exec_id,)) + await db.commit() + return cur.rowcount + + @classmethod + async def get_stats(cls, workflow_id: str) -> Optional[Dict[str, Any]]: + db = await cls._db() + async with db.execute( + "SELECT * FROM workflow_stats WHERE workflow_id = ?", + (workflow_id,), + ) as cur: + row = await cur.fetchone() + if not row: + return None + return { + "callCount": int(row["call_count"] or 0), + "successCount": int(row["success_count"] or 0), + "errorCount": int(row["error_count"] or 0), + "totalRuntime": float(row["total_runtime"] or 0.0), + "avgRuntime": float(row["avg_runtime"] or 0.0), + "thumbsUp": int(row["thumbs_up"] or 0), + "thumbsDown": int(row["thumbs_down"] or 0), + } + + @classmethod + async def put_stats(cls, workflow_id: str, stats: Dict[str, Any]) -> None: + db = await cls._db() + await db.execute( + """ + INSERT OR REPLACE INTO workflow_stats + (workflow_id, call_count, success_count, error_count, total_runtime, + avg_runtime, thumbs_up, thumbs_down, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + workflow_id, + int(stats.get("callCount") or stats.get("call_count") or 0), + int(stats.get("successCount") or stats.get("success_count") or 0), + int(stats.get("errorCount") or stats.get("error_count") or 0), + float(stats.get("totalRuntime") or stats.get("total_runtime") or 0.0), + float(stats.get("avgRuntime") or stats.get("avg_runtime") or 0.0), + int(stats.get("thumbsUp") or stats.get("thumbs_up") or 0), + int(stats.get("thumbsDown") or stats.get("thumbs_down") or 0), + cls._now_ms(), + ), + ) + await db.commit() + + @classmethod + async def delete_stats(cls, workflow_id: str) -> bool: + db = await cls._db() + cur = await db.execute("DELETE FROM workflow_stats WHERE workflow_id = ?", (workflow_id,)) + await db.commit() + return cur.rowcount > 0 + + @classmethod + async def increment_stats(cls, workflow_id: str, *, success: bool, duration: float) -> None: + db = await cls._db() + runtime = float(duration) + success_delta = 1 if success else 0 + error_delta = 0 if success else 1 + updated_at = cls._now_ms() + await db.execute( + """ + INSERT INTO workflow_stats ( + workflow_id, + call_count, + success_count, + error_count, + total_runtime, + avg_runtime, + thumbs_up, + thumbs_down, + updated_at + ) + VALUES (?, 1, ?, ?, ?, ?, 0, 0, ?) + ON CONFLICT(workflow_id) DO UPDATE SET + call_count = workflow_stats.call_count + 1, + success_count = workflow_stats.success_count + excluded.success_count, + error_count = workflow_stats.error_count + excluded.error_count, + total_runtime = workflow_stats.total_runtime + excluded.total_runtime, + avg_runtime = ( + workflow_stats.total_runtime + excluded.total_runtime + ) / (workflow_stats.call_count + 1), + updated_at = excluded.updated_at + """, + ( + workflow_id, + success_delta, + error_delta, + runtime, + runtime, + updated_at, + ), + ) + await db.commit() + + @classmethod + async def put_config( + cls, + workflow_id: str, + config: Dict[str, Any], + *, + kind: Optional[str] = None, + ) -> None: + db = await cls._db() + config_kind = kind or str(config.get("kind") or "workflow.integration-config") + version = cls._as_int(config.get("version")) + await db.execute( + """ + INSERT OR REPLACE INTO workflow_configs + (workflow_id, kind, version, config, updated_at) + VALUES (?, ?, ?, ?, ?) + """, + (workflow_id, config_kind, version, cls._json_dumps(config), cls._now_ms()), + ) + await db.commit() + + @classmethod + async def get_config( + cls, + workflow_id: str, + *, + kind: str = "workflow.integration-config", + ) -> Optional[Dict[str, Any]]: + db = await cls._db() + async with db.execute( + "SELECT config FROM workflow_configs WHERE workflow_id = ? AND kind = ?", + (workflow_id, kind), + ) as cur: + row = await cur.fetchone() + if not row: + return None + value = cls._json_loads(row["config"], None) + return value if isinstance(value, dict) else None + + @classmethod + async def list_configs(cls, *, kind: str) -> List[Tuple[str, Dict[str, Any]]]: + db = await cls._db() + async with db.execute( + "SELECT workflow_id, config FROM workflow_configs WHERE kind = ? ORDER BY workflow_id", + (kind,), + ) as cur: + rows = await cur.fetchall() + items: List[Tuple[str, Dict[str, Any]]] = [] + for row in rows: + value = cls._json_loads(row["config"], None) + if isinstance(value, dict): + items.append((str(row["workflow_id"]), value)) + return items + + @classmethod + async def delete_config(cls, workflow_id: str, *, kind: Optional[str] = None) -> int: + db = await cls._db() + if kind: + cur = await db.execute( + "DELETE FROM workflow_configs WHERE workflow_id = ? AND kind = ?", + (workflow_id, kind), + ) + else: + cur = await db.execute("DELETE FROM workflow_configs WHERE workflow_id = ?", (workflow_id,)) + await db.commit() + return cur.rowcount + + @classmethod + async def kv_put(cls, key: str, value: Any, value_type: str = _JSON_TYPE) -> None: + db = await cls._db() + now = cls._now_iso() + await db.execute( + """ + INSERT OR REPLACE INTO workflow_kv (key, value, type, created_at, updated_at) + VALUES (?, ?, ?, + COALESCE((SELECT created_at FROM workflow_kv WHERE key = ?), ?), + ?) + """, + (key, cls._json_dumps(value), value_type, key, now, now), + ) + await db.commit() + + @classmethod + async def kv_get(cls, key: str) -> Optional[Any]: + db = await cls._db() + async with db.execute("SELECT value FROM workflow_kv WHERE key = ?", (key,)) as cur: + row = await cur.fetchone() + if not row: + return None + return cls._json_loads(row["value"], None) + + @classmethod + async def kv_remove(cls, key: str) -> bool: + db = await cls._db() + cur = await db.execute("DELETE FROM workflow_kv WHERE key = ?", (key,)) + await db.commit() + return cur.rowcount > 0 + + @classmethod + async def kv_list_keys(cls, prefix: str) -> List[str]: + db = await cls._db() + async with db.execute( + "SELECT key FROM workflow_kv WHERE key LIKE ? ESCAPE '\\' ORDER BY key", + (Storage._like_prefix_pattern(prefix),), + ) as cur: + rows = await cur.fetchall() + return [str(row["key"]) for row in rows] + + @classmethod + async def kv_list(cls, prefix: str) -> List[str]: + return await cls.kv_list_keys(prefix) + + @classmethod + async def kv_entries(cls, prefix: str) -> List[Tuple[str, Any]]: + db = await cls._db() + async with db.execute( + "SELECT key, value FROM workflow_kv WHERE key LIKE ? ESCAPE '\\' ORDER BY key", + (Storage._like_prefix_pattern(prefix),), + ) as cur: + rows = await cur.fetchall() + entries: List[Tuple[str, Any]] = [] + for row in rows: + entries.append((str(row["key"]), cls._json_loads(row["value"], None))) + return entries + + @classmethod + async def kv_clear(cls, prefix: str) -> int: + db = await cls._db() + cur = await db.execute( + "DELETE FROM workflow_kv WHERE key LIKE ? ESCAPE '\\'", + (Storage._like_prefix_pattern(prefix),), + ) + await db.commit() + return cur.rowcount + + +_WORKFLOW_DDL = """ +CREATE TABLE IF NOT EXISTS workflow_executions ( + id TEXT PRIMARY KEY, + workflow_id TEXT NOT NULL, + status TEXT NOT NULL, + current_phase TEXT, + current_node_id TEXT, + current_node_type TEXT, + current_step_index INTEGER, + step_count INTEGER NOT NULL DEFAULT 0, + input_params TEXT NOT NULL DEFAULT '{}', + output_results TEXT NOT NULL DEFAULT '{}', + error_message TEXT, + trigger_id TEXT, + trigger_type TEXT, + started_at INTEGER NOT NULL, + finished_at INTEGER, + duration REAL, + updated_at INTEGER, + payload TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS workflow_execution_steps ( + exec_id TEXT NOT NULL, + step_index INTEGER NOT NULL, + node_id TEXT, + node_type TEXT, + inputs TEXT NOT NULL DEFAULT '{}', + outputs TEXT NOT NULL DEFAULT '{}', + error TEXT, + payload TEXT NOT NULL, + PRIMARY KEY (exec_id, step_index) +); + +CREATE TABLE IF NOT EXISTS workflow_stats ( + workflow_id TEXT PRIMARY KEY, + call_count INTEGER NOT NULL DEFAULT 0, + success_count INTEGER NOT NULL DEFAULT 0, + error_count INTEGER NOT NULL DEFAULT 0, + total_runtime REAL NOT NULL DEFAULT 0, + avg_runtime REAL NOT NULL DEFAULT 0, + thumbs_up INTEGER NOT NULL DEFAULT 0, + thumbs_down INTEGER NOT NULL DEFAULT 0, + updated_at INTEGER +); + +CREATE TABLE IF NOT EXISTS workflow_configs ( + workflow_id TEXT NOT NULL, + kind TEXT NOT NULL, + version INTEGER, + config TEXT NOT NULL, + updated_at INTEGER, + PRIMARY KEY (workflow_id, kind) +); + +CREATE TABLE IF NOT EXISTS workflow_kv ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL, + type TEXT NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL +); +""" + +_INDEX_STMTS = [ + "CREATE INDEX IF NOT EXISTS idx_workflow_executions_workflow_started ON workflow_executions(workflow_id, started_at DESC)", + "CREATE INDEX IF NOT EXISTS idx_workflow_executions_workflow_status ON workflow_executions(workflow_id, status)", + "CREATE INDEX IF NOT EXISTS idx_workflow_executions_trigger ON workflow_executions(workflow_id, trigger_type, trigger_id)", + "CREATE INDEX IF NOT EXISTS idx_workflow_execution_steps_exec_step ON workflow_execution_steps(exec_id, step_index)", +] diff --git a/flocks/workflow/triggers/runtime.py b/flocks/workflow/triggers/runtime.py index 5f6d94a52..74cb329e2 100644 --- a/flocks/workflow/triggers/runtime.py +++ b/flocks/workflow/triggers/runtime.py @@ -7,7 +7,6 @@ import time from typing import Any, Dict, List, Optional, Tuple -from flocks.storage.storage import Storage from flocks.utils.log import Log from flocks.workflow.execution_store import ( compact_history_for_storage, @@ -18,6 +17,7 @@ ) from flocks.workflow.fs_store import read_workflow_dir, workflow_scan_dirs from flocks.workflow.runner import run_workflow +from flocks.workflow.store import WorkflowStore from .compat import ( LEGACY_KAFKA_CONFIG_PREFIX, @@ -45,6 +45,16 @@ def _now_ms() -> int: return int(time.time() * 1000) +def _config_kind_from_legacy_key(key: str) -> Optional[str]: + if key.startswith(LEGACY_POLLER_CONFIG_PREFIX): + return "workflow_poller_config" + if key.startswith(LEGACY_SYSLOG_CONFIG_PREFIX): + return "workflow_syslog_config" + if key.startswith(LEGACY_KAFKA_CONFIG_PREFIX): + return "workflow_kafka_config" + return None + + class TriggerRuntime: """Unified trigger runtime that wraps legacy managers and custom adapters.""" @@ -70,17 +80,20 @@ def _iter_workflows(self) -> List[Dict[str, Any]]: async def _write_disabled_legacy_configs(self, workflow_id: str) -> None: now_ms = _now_ms() - await Storage.write( - f"{LEGACY_POLLER_CONFIG_PREFIX}{workflow_id}", + await WorkflowStore.put_config( + workflow_id, {"workflowId": workflow_id, "enabled": False, "updatedAt": now_ms}, + kind="workflow_poller_config", ) - await Storage.write( - f"{LEGACY_SYSLOG_CONFIG_PREFIX}{workflow_id}", + await WorkflowStore.put_config( + workflow_id, {"workflowId": workflow_id, "enabled": False, "updatedAt": now_ms}, + kind="workflow_syslog_config", ) - await Storage.write( - f"{LEGACY_KAFKA_CONFIG_PREFIX}{workflow_id}", + await WorkflowStore.put_config( + workflow_id, {"workflowId": workflow_id, "enabled": False, "updatedAt": now_ms}, + kind="workflow_kafka_config", ) @staticmethod @@ -88,7 +101,9 @@ def _trigger_signature(trigger: TriggerDefinition) -> str: payload = trigger.model_dump(mode="json", exclude_none=True) return json.dumps(payload, sort_keys=True, separators=(",", ":")) - async def _sync_legacy_configs_from_workflow(self, workflow_id: str, workflow_json: Dict[str, Any]) -> List[TriggerDefinition]: + async def _sync_legacy_configs_from_workflow( + self, workflow_id: str, workflow_json: Dict[str, Any] + ) -> List[TriggerDefinition]: triggers = workflow_trigger_definitions_from_json(workflow_json) if not triggers: if workflow_json_declares_triggers(workflow_json): @@ -99,22 +114,27 @@ async def _sync_legacy_configs_from_workflow(self, workflow_id: str, workflow_js for trigger in triggers: key, value = trigger_to_legacy_config(workflow_id, trigger) if key and value is not None: - await Storage.write(key, value) + kind = _config_kind_from_legacy_key(key) + if kind: + await WorkflowStore.put_config(workflow_id, value, kind=kind) if "schedule" not in by_type: - await Storage.write( - f"{LEGACY_POLLER_CONFIG_PREFIX}{workflow_id}", + await WorkflowStore.put_config( + workflow_id, {"workflowId": workflow_id, "enabled": False, "updatedAt": _now_ms()}, + kind="workflow_poller_config", ) if "syslog" not in by_type: - await Storage.write( - f"{LEGACY_SYSLOG_CONFIG_PREFIX}{workflow_id}", + await WorkflowStore.put_config( + workflow_id, {"workflowId": workflow_id, "enabled": False, "updatedAt": _now_ms()}, + kind="workflow_syslog_config", ) if "kafka" not in by_type: - await Storage.write( - f"{LEGACY_KAFKA_CONFIG_PREFIX}{workflow_id}", + await WorkflowStore.put_config( + workflow_id, {"workflowId": workflow_id, "enabled": False, "updatedAt": _now_ms()}, + kind="workflow_kafka_config", ) return triggers @@ -307,12 +327,11 @@ async def _start_custom_adapters_for_workflow(self, workflow_id: str, workflow_j continue key = (workflow_id, trigger.id or "") trigger_signature = desired_signatures[key] - if ( - key in self._custom_adapter_tasks - and self._custom_adapter_signatures.get(key) == trigger_signature - ): + if key in self._custom_adapter_tasks and self._custom_adapter_signatures.get(key) == trigger_signature: continue - plugin_id = str((trigger.source or {}).get("adapterId") or (trigger.source or {}).get("pluginId") or "").strip() + plugin_id = str( + (trigger.source or {}).get("adapterId") or (trigger.source or {}).get("pluginId") or "" + ).strip() plugin_spec = next((item for item in list_trigger_plugins() if item.get("id") == plugin_id), None) if plugin_spec is None: self._custom_status[key] = { @@ -352,11 +371,15 @@ async def _start_custom_adapters_for_workflow(self, workflow_id: str, workflow_j continue async def _emit(payload: Any, *, _trigger: TriggerDefinition = trigger) -> Dict[str, Any]: - event = payload if isinstance(payload, TriggerEvent) else build_trigger_event( - workflow_id=workflow_id, - trigger=_trigger, - body=payload, - raw=payload, + event = ( + payload + if isinstance(payload, TriggerEvent) + else build_trigger_event( + workflow_id=workflow_id, + trigger=_trigger, + body=payload, + raw=payload, + ) ) try: result = await self.dispatch_event( diff --git a/flocks/workflow/workflow_lint.py b/flocks/workflow/workflow_lint.py index 621fda325..82b0bc345 100644 --- a/flocks/workflow/workflow_lint.py +++ b/flocks/workflow/workflow_lint.py @@ -15,9 +15,26 @@ _CN_SECTION_OUTPUT_RE = re.compile(r"^\s*输出要求\s*[::]?\s*$") # Patterns that indicate an "expensive" node (LLM call / file write). -_EXPENSIVE_CALL_RE = re.compile( - r"""llm\.ask\s*\(|tool\.run\s*\(\s*['"]write['"]""" -) +_EXPENSIVE_CALL_RE = re.compile(r"""llm\.ask\s*\(|tool\.run\s*\(\s*['"]write['"]""") +_STRICT_MAPPING_TRIGGER_TYPES = {"syslog", "kafka", "schedule"} +_SCHEMA_ERROR_KINDS = { + "schema_mapping_src_not_declared", + "schema_mapping_dst_not_declared", + "schema_mapping_type_mismatch", + "schema_mapping_large_payload", + "schema_required_input_missing", +} +_TYPE_ALIASES = { + "array": "list", + "sequence": "list", + "object": "dict", + "map": "dict", + "integer": "int", + "number": "float", + "boolean": "bool", + "text": "str", + "string": "str", +} def _split_keys(raw: str) -> list[str]: @@ -69,6 +86,161 @@ def estimate_node_output_keys(node: Node) -> Set[str]: return keys +def _normalize_schema_field(raw: Any) -> Dict[str, Any]: + if isinstance(raw, str): + return {"type": _normalize_type_name(raw)} + if isinstance(raw, dict): + normalized = dict(raw) + if "type" in normalized: + normalized["type"] = _normalize_type_name(normalized.get("type")) + return normalized + return {} + + +def _normalize_type_name(value: Any) -> str: + raw = str(value or "").strip().lower() + return _TYPE_ALIASES.get(raw, raw) + + +def _schema_fields(raw_schema: Optional[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: + if not isinstance(raw_schema, dict): + return {} + fields = raw_schema.get("fields") if isinstance(raw_schema.get("fields"), dict) else raw_schema + if not isinstance(fields, dict): + return {} + return {str(key): _normalize_schema_field(value) for key, value in fields.items()} + + +def _field_type(field: Dict[str, Any]) -> str: + return _normalize_type_name(field.get("type")) + + +def _types_compatible(src_type: str, dst_type: str) -> bool: + if not src_type or not dst_type: + return True + if src_type == dst_type: + return True + if "any" in {src_type, dst_type}: + return True + numeric = {"int", "float"} + return src_type in numeric and dst_type in numeric + + +def _path_top_key(path: Any) -> str: + src_path = "" if path is None else str(path).strip() + if src_path == "$": + return "$" + if src_path.startswith("$."): + src_path = src_path[2:] + return src_path.split(".", 1)[0] if src_path else "" + + +def is_schema_lint_error(item: Dict[str, Any]) -> bool: + return item.get("kind") in _SCHEMA_ERROR_KINDS and item.get("severity") == "error" + + +def _strict_edge_mapping_enabled(workflow: Workflow) -> bool: + metadata = workflow.metadata if isinstance(workflow.metadata, dict) else {} + candidates = [ + metadata.get("strict_edge_mapping"), + metadata.get("strictEdgeMapping"), + ] + for section_key in ("runtime", "runtime_defaults", "runtimeDefaults"): + section = metadata.get(section_key) + if isinstance(section, dict): + candidates.extend( + [ + section.get("strict_edge_mapping"), + section.get("strictEdgeMapping"), + ] + ) + for value in candidates: + if isinstance(value, bool): + return value + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in {"1", "true", "yes", "on"}: + return True + if normalized in {"0", "false", "no", "off"}: + return False + return False + + +def _workflow_has_strict_mapping_setting(workflow: Workflow) -> bool: + metadata = workflow.metadata if isinstance(workflow.metadata, dict) else {} + if "strict_edge_mapping" in metadata or "strictEdgeMapping" in metadata: + return True + for section_key in ("runtime", "runtime_defaults", "runtimeDefaults"): + section = metadata.get(section_key) + if isinstance(section, dict) and ( + "strict_edge_mapping" in section or "strictEdgeMapping" in section + ): + return True + return False + + +def lint_implicit_full_payload_edges(workflow: Workflow) -> List[Dict[str, Any]]: + nodes = workflow.nodes_by_id() + strict = _strict_edge_mapping_enabled(workflow) + severity = "error" if strict else "warning" + results: list[dict[str, Any]] = [] + for edge in workflow.edges: + if edge.mapping: + continue + upstream = nodes.get(edge.from_) + downstream = nodes.get(edge.to) + results.append( + { + "kind": "implicit_full_payload_edge", + "severity": severity, + "edge_from": edge.from_, + "edge_to": edge.to, + "upstream_type": getattr(upstream, "type", None), + "downstream_type": getattr(downstream, "type", None), + "strict_edge_mapping": strict, + "message": ( + f"edge {edge.from_!r}->{edge.to!r} has no mapping; the full upstream " + "payload will be passed to the downstream node" + ), + } + ) + return results + + +def lint_recommend_strict_edge_mapping(workflow: Workflow) -> List[Dict[str, Any]]: + """Recommend strict edge mapping for high-volume trigger workflows. + + This intentionally stays a warning so existing workflow execution remains + compatible unless the workflow explicitly opts into strict mode. + """ + if _workflow_has_strict_mapping_setting(workflow): + return [] + + trigger_types = sorted( + { + trigger.type + for trigger in workflow.triggers + if getattr(trigger, "type", None) in _STRICT_MAPPING_TRIGGER_TYPES + } + ) + if not trigger_types: + return [] + + return [ + { + "kind": "recommend_strict_edge_mapping", + "severity": "warning", + "trigger_types": trigger_types, + "message": ( + "workflows triggered by syslog, kafka, or schedule can process high-volume " + "payloads; set metadata.runtime.strict_edge_mapping=true, " + "metadata.runtime.dataflow_mode='vertex_cache', and use explicit edge " + "mappings for new workflow definitions" + ), + } + ] + + def lint_workflow_mappings(workflow: Workflow) -> List[Dict[str, Any]]: nodes = workflow.nodes_by_id() warnings: list[dict[str, Any]] = [] @@ -85,35 +257,151 @@ def lint_workflow_mappings(workflow: Workflow) -> List[Dict[str, Any]]: src_path = src_path[2:] top_key = src_path.split(".", 1)[0] if src_path else "" if top_key and upstream_out and top_key not in upstream_out: - warnings.append({ - "kind": "mapping_src_key_not_in_upstream_outputs", - "edge_from": e.from_, - "edge_to": e.to, - "dst_key": dst, - "src_path": src, - "upstream_type": getattr(upstream, "type", None), - "estimated_upstream_output_keys": sorted(upstream_out)[:50], - "message": ( - f"edge.mapping maps src {src!r} but upstream node {e.from_!r} " - "does not appear to write that key to outputs; mapping may produce missing value" - ), - }) - if dst == src and not (e.const or {}): - warnings.append({ - "kind": "scheme_a_suggest_omit_identity_mapping", - "severity": "warning", - "edge_from": e.from_, - "edge_to": e.to, - "dst_key": dst, - "src_path": src, - "message": ( - "edge.mapping is an identity mapping. Scheme A recommends omitting mapping " - "to pass through the full payload and reduce missing-key issues." - ), - }) + warnings.append( + { + "kind": "mapping_src_key_not_in_upstream_outputs", + "edge_from": e.from_, + "edge_to": e.to, + "dst_key": dst, + "src_path": src, + "upstream_type": getattr(upstream, "type", None), + "estimated_upstream_output_keys": sorted(upstream_out)[:50], + "message": ( + f"edge.mapping maps src {src!r} but upstream node {e.from_!r} " + "does not appear to write that key to outputs; mapping may produce missing value" + ), + } + ) return warnings +def lint_workflow_schema(workflow: Workflow) -> List[Dict[str, Any]]: + """Validate explicit edge mappings against lightweight node schemas. + + This is intentionally opt-in: old workflows without ``inputSchema`` or + ``outputSchema`` keep their current behavior. When a node does declare a + schema, mappings to/from unknown fields or incompatible field types become + lint errors. + """ + nodes = workflow.nodes_by_id() + results: list[dict[str, Any]] = [] + + provided_inputs: Dict[str, Set[str]] = {node.id: set() for node in workflow.nodes} + for edge in workflow.edges: + downstream = nodes.get(edge.to) + if downstream is None: + continue + if edge.mapping: + provided_inputs.setdefault(edge.to, set()).update(str(dst) for dst in edge.mapping) + if edge.const: + provided_inputs.setdefault(edge.to, set()).update(str(key) for key in edge.const) + + for edge in workflow.edges: + if not edge.mapping: + continue + upstream = nodes.get(edge.from_) + downstream = nodes.get(edge.to) + if upstream is None or downstream is None: + continue + output_schema = _schema_fields(upstream.output_schema) + input_schema = _schema_fields(downstream.input_schema) + for dst, src in edge.mapping.items(): + dst_key = str(dst) + src_key = _path_top_key(src) + output_field = output_schema.get(src_key) + input_field = input_schema.get(dst_key) + if output_schema and src_key and src_key != "$" and output_field is None: + results.append( + { + "kind": "schema_mapping_src_not_declared", + "severity": "error", + "edge_from": edge.from_, + "edge_to": edge.to, + "dst_key": dst_key, + "src_path": src, + "declared_output_keys": sorted(output_schema), + "message": ( + f"edge {edge.from_!r}->{edge.to!r} maps src {src!r}, " + f"but upstream node {edge.from_!r} outputSchema does not declare {src_key!r}" + ), + } + ) + continue + if input_schema and input_field is None: + results.append( + { + "kind": "schema_mapping_dst_not_declared", + "severity": "error", + "edge_from": edge.from_, + "edge_to": edge.to, + "dst_key": dst_key, + "src_path": src, + "declared_input_keys": sorted(input_schema), + "message": ( + f"edge {edge.from_!r}->{edge.to!r} maps to input {dst_key!r}, " + f"but downstream node {edge.to!r} inputSchema does not declare it" + ), + } + ) + continue + if output_field is None or input_field is None: + continue + src_type = _field_type(output_field) + dst_type = _field_type(input_field) + if not _types_compatible(src_type, dst_type): + results.append( + { + "kind": "schema_mapping_type_mismatch", + "severity": "error", + "edge_from": edge.from_, + "edge_to": edge.to, + "dst_key": dst_key, + "src_path": src, + "output_type": src_type, + "input_type": dst_type, + "message": ( + f"edge {edge.from_!r}->{edge.to!r} maps {src_key!r} ({src_type}) " + f"to {dst_key!r} ({dst_type})" + ), + } + ) + if output_field.get("large") and not input_field.get("large"): + results.append( + { + "kind": "schema_mapping_large_payload", + "severity": "error", + "edge_from": edge.from_, + "edge_to": edge.to, + "dst_key": dst_key, + "src_path": src, + "message": ( + f"edge {edge.from_!r}->{edge.to!r} maps large output {src_key!r} " + f"to input {dst_key!r} that is not marked large" + ), + } + ) + + for node in workflow.nodes: + if node.id == workflow.start: + continue + input_schema = _schema_fields(node.input_schema) + if not input_schema: + continue + required = {key for key, field in input_schema.items() if field.get("required")} + missing = sorted(required - provided_inputs.get(node.id, set())) + if missing: + results.append( + { + "kind": "schema_required_input_missing", + "severity": "error", + "node_id": node.id, + "missing_inputs": missing, + "message": f"node {node.id!r} inputSchema requires inputs that no incoming edge provides: {missing}", + } + ) + return results + + # --------------------------------------------------------------------------- # Join-safety checks # --------------------------------------------------------------------------- @@ -168,6 +456,8 @@ def lint_join_requirements(workflow: Workflow) -> List[Dict[str, Any]]: continue if getattr(node, "join", False): continue # already has join, OK + if node.type == "loop": + continue unique_sources = set(sources) @@ -184,18 +474,20 @@ def lint_join_requirements(workflow: Workflow) -> List[Dict[str, Any]]: break if not is_exclusive: - results.append({ - "kind": "multi_incoming_no_join", - "severity": "error", - "node_id": nid, - "sources": sorted(sources), - "message": ( - f"Node {nid!r} has {len(sources)} incoming edges from " - f"non-exclusive sources {sorted(unique_sources)} but join=false. " - "This will cause the node to execute multiple times. " - "Set join=true on this node or restructure edges." - ), - }) + results.append( + { + "kind": "multi_incoming_no_join", + "severity": "warning", + "node_id": nid, + "sources": sorted(sources), + "message": ( + f"Node {nid!r} has {len(sources)} incoming edges from " + f"non-exclusive sources {sorted(unique_sources)} but join=false. " + "This may cause the node to execute multiple times. " + "Set join=true on this node or restructure edges." + ), + } + ) return results @@ -234,18 +526,20 @@ def lint_expensive_node_multi_trigger(workflow: Workflow) -> List[Dict[str, Any] break if not is_exclusive: - results.append({ - "kind": "expensive_node_multi_trigger", - "severity": "error", - "node_id": nid, - "sources": sorted(sources), - "message": ( - f"Expensive node {nid!r} (contains LLM/write calls) has " - f"{len(sources)} non-exclusive incoming edges but join=false. " - "This may cause costly duplicate execution. " - "Add a join node before this expensive node." - ), - }) + results.append( + { + "kind": "expensive_node_multi_trigger", + "severity": "error", + "node_id": nid, + "sources": sorted(sources), + "message": ( + f"Expensive node {nid!r} (contains LLM/write calls) has " + f"{len(sources)} non-exclusive incoming edges but join=false. " + "This may cause costly duplicate execution. " + "Add a join node before this expensive node." + ), + } + ) return results @@ -266,15 +560,17 @@ def lint_subworkflow_depth(workflow: Workflow) -> List[Dict[str, Any]]: results: List[Dict[str, Any]] = [] for node in workflow.nodes: if node.type == "subworkflow": - results.append({ - "kind": "SW-001", - "severity": "error", - "node_id": node.id, - "message": ( - f"Node {node.id!r} is a subworkflow node. " - "Sub-workflows cannot nest further sub-workflows (max depth=1)." - ), - }) + results.append( + { + "kind": "SW-001", + "severity": "error", + "node_id": node.id, + "message": ( + f"Node {node.id!r} is a subworkflow node. " + "Sub-workflows cannot nest further sub-workflows (max depth=1)." + ), + } + ) return results @@ -294,23 +590,27 @@ def lint_subworkflow_ids( if node.type == "subworkflow": wid = node.workflow_id or "" if not wid: - results.append({ - "kind": "SW-002", - "severity": "error", - "node_id": node.id, - "message": f"subworkflow node {node.id!r} has no workflow_id set.", - }) + results.append( + { + "kind": "SW-002", + "severity": "error", + "node_id": node.id, + "message": f"subworkflow node {node.id!r} has no workflow_id set.", + } + ) elif wid not in known_workflow_ids: - results.append({ - "kind": "SW-002", - "severity": "error", - "node_id": node.id, - "workflow_id": wid, - "message": ( - f"subworkflow node {node.id!r} references workflow_id={wid!r} " - "which was not found in the known workflow registry." - ), - }) + results.append( + { + "kind": "SW-002", + "severity": "error", + "node_id": node.id, + "workflow_id": wid, + "message": ( + f"subworkflow node {node.id!r} references workflow_id={wid!r} " + "which was not found in the known workflow registry." + ), + } + ) return results @@ -338,10 +638,13 @@ def lint_workflow( subworkflow nodes. """ results: List[Dict[str, Any]] = [] + results.extend(lint_implicit_full_payload_edges(workflow)) + results.extend(lint_recommend_strict_edge_mapping(workflow)) # Existing mapping checks (warnings) for item in lint_workflow_mappings(workflow): item.setdefault("severity", "warning") results.append(item) + results.extend(lint_workflow_schema(workflow)) # Join safety (errors) results.extend(lint_join_requirements(workflow)) # Expensive node multi-trigger (errors) diff --git a/pyproject.toml b/pyproject.toml index 45d82a8ec..81ebf74ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "flocks" -version = "v2026.6.24" +version = "v2026.7.1" description = "AI-Native SecOps platform with multi-agent collaboration" authors = [ {name = "Flocks Team", email = "team@example.com"} diff --git a/scripts/dev.sh b/scripts/dev.sh index cc0a2e262..8a7554c24 100644 --- a/scripts/dev.sh +++ b/scripts/dev.sh @@ -17,6 +17,7 @@ BACKEND_HOST="${BACKEND_HOST:-127.0.0.1}" BACKEND_PORT="${BACKEND_PORT:-8000}" FRONTEND_HOST="${FRONTEND_HOST:-127.0.0.1}" FRONTEND_PORT="${FRONTEND_PORT:-5173}" +FLOCKS_CONSOLE_BASE_URL="${FLOCKS_CONSOLE_BASE_URL:-https://portalflocks.threatbook.cn}" BACKEND_ACCESS_HOST="${BACKEND_HOST}" if [ "${BACKEND_ACCESS_HOST}" = "0.0.0.0" ] || [ "${BACKEND_ACCESS_HOST}" = "::" ]; then @@ -42,6 +43,7 @@ Environment variables: BACKEND_PORT 默认 8000 FRONTEND_HOST 默认 127.0.0.1 FRONTEND_PORT 默认 5173 + FLOCKS_CONSOLE_BASE_URL 默认 https://portalflocks.threatbook.cn EOF } @@ -206,6 +208,7 @@ start_backend() { cd "${PROJECT_ROOT}" _FLOCKS_WEBUI_HOST="${FRONTEND_HOST}" \ _FLOCKS_WEBUI_PORT="${FRONTEND_PORT}" \ + FLOCKS_CONSOLE_BASE_URL="${FLOCKS_CONSOLE_BASE_URL}" \ uv run uvicorn flocks.server.app:app \ --host "${BACKEND_HOST}" \ --port "${BACKEND_PORT}" \ @@ -226,6 +229,7 @@ start_all() { cd "${PROJECT_ROOT}" _FLOCKS_WEBUI_HOST="${FRONTEND_HOST}" \ _FLOCKS_WEBUI_PORT="${FRONTEND_PORT}" \ + FLOCKS_CONSOLE_BASE_URL="${FLOCKS_CONSOLE_BASE_URL}" \ uv run uvicorn flocks.server.app:app \ --host "${BACKEND_HOST}" \ --port "${BACKEND_PORT}" \ @@ -262,4 +266,4 @@ case "${MODE}" in usage exit 1 ;; -esac \ No newline at end of file +esac diff --git a/tests/contracts/access/test_runtime.py b/tests/contracts/access/test_runtime.py new file mode 100644 index 000000000..ed4d0fac8 --- /dev/null +++ b/tests/contracts/access/test_runtime.py @@ -0,0 +1,669 @@ +from __future__ import annotations + +import json +import sqlite3 +from pathlib import Path +from typing import Any + +import pytest + +from flocks.auth.context import AuthUser +from flocks.contracts.access.discovery import discover_contract_plugins +from flocks.contracts.access.models import ( + Binding, + Contract, + ContractOperation, + ContractRuntimeError, + DriverResult, + InternalDataRow, + PolicyContext, + RuntimeContext, + WebUIContractPlugin, +) +from flocks.contracts.access.pipeline import OverlayStore +from flocks.contracts.access.plans import FieldDependencyPlanCompiler, PolicyPlanCompiler +from flocks.contracts.access.runtime import BindingTestHarness, NO_POLICY_SCOPE, OperationRuntime, PolicyContextResolver +from flocks.contracts.webui.store import WebUIPagesStore +from flocks.plugin.loader import PluginLoader + +SOURCE_PAGE_ID = "contract-source" +PAGE_ID = "test/records" +CONTRACT_ID = "test.records" +CONTRACT_VERSION = "1.0" +DRIVER_FIELDS = frozenset({"id", "tenant", "asset_group", "status", "severity", "time"}) + + +def _write_contract_assets(store: WebUIPagesStore, records: list[dict[str, Any]]) -> None: + store.create_page(page_id=SOURCE_PAGE_ID, title="Contract source") + asset_dir = store.asset_path(SOURCE_PAGE_ID, "2026-06-25") + asset_dir.mkdir(parents=True, exist_ok=True) + asset_path = asset_dir / "records.jsonl" + lines = [{"_type": "file_header", "date": "2026-06-25"}, *records] + asset_path.write_text( + "\n".join(json.dumps(line, ensure_ascii=False) for line in lines), + encoding="utf-8", + ) + + +def _write_contract_sqlite(db_path: Path, records: list[dict[str, Any]]) -> None: + db_path.parent.mkdir(parents=True, exist_ok=True) + connection = sqlite3.connect(db_path) + connection.execute( + """ + CREATE TABLE records ( + id TEXT PRIMARY KEY, + record_date TEXT NOT NULL, + record_json TEXT NOT NULL + ) + """ + ) + connection.executemany( + "INSERT INTO records (id, record_date, record_json) VALUES (?, ?, ?)", + [ + ( + str(record.get("id") or index), + str(record.get("record_date") or "2026-06-25"), + json.dumps(record, ensure_ascii=False), + ) + for index, record in enumerate(records, start=1) + ], + ) + connection.commit() + connection.close() + + +def _contract_record(**overrides: Any) -> dict[str, Any]: + record = { + "id": "record-1", + "tenant": "tenant-a", + "asset_group": "core", + "status": "open", + "severity": "high", + "time": 1779086941, + } + record.update(overrides) + return record + + +def _store(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> WebUIPagesStore: + root = tmp_path / "contracts-webui" + monkeypatch.setenv("FLOCKS_CONTRACTS_WEBUI_ROOT", str(root)) + return WebUIPagesStore(root=root) + + +def _contract() -> Contract: + return Contract( + contract_id=CONTRACT_ID, + version=CONTRACT_VERSION, + page_id=PAGE_ID, + operations={ + "list": ContractOperation( + name="list", + operation_type="query", + adapter_required_fields=DRIVER_FIELDS, + identity_fields=frozenset({"id"}), + public_fields=frozenset({"summary", "items", "meta"}), + filter_fields=frozenset({"tenant", "asset_group", "status", "severity"}), + filter_param_fields={"status": "status", "severity": "severity"}, + tenant_policy_field="tenant", + asset_group_policy_field="asset_group", + cursor_fields=frozenset({"time", "id"}), + sort_fields=frozenset({"time", "id"}), + default_limit=100, + max_limit=1000, + ), + "update": ContractOperation( + name="update", + operation_type="mutation", + adapter_required_fields=frozenset(), + identity_fields=frozenset({"entityType", "entityId"}), + public_fields=frozenset({"ok", "entityType", "entityId", "overlayVersion", "writeThrough"}), + requires_idempotency_key=True, + requires_expected_overlay_version=True, + mutation_entity_types=frozenset({"record"}), + ), + }, + ) + + +class _BindingResolver: + def __init__( + self, + store: WebUIPagesStore, + *, + capabilities: frozenset[str] | None = None, + adapter_kind: str = "builtin-jsonl", + source_root: Path | None = None, + driver_options: dict[str, Any] | None = None, + ) -> None: + self._store = store + self._capabilities = capabilities or frozenset({"query", "mutation"}) + self._adapter_kind = adapter_kind + self._source_root = source_root + self._driver_options = driver_options or {} + + def resolve(self, *, page_id: str, slot_id: str, contract_id: str, contract_version: str) -> Binding: + source_root = self._source_root or self._store.asset_path(SOURCE_PAGE_ID, "") + return Binding( + binding_id=f"test-{self._adapter_kind}", + binding_version=1, + page_id=page_id, + slot_id=slot_id, + contract_id=contract_id, + contract_version=contract_version, + adapter_kind=self._adapter_kind, + source_page_id=SOURCE_PAGE_ID, + source_root=source_root, + driver_available_fields=DRIVER_FIELDS, + driver_allowlist_roots=(source_root if source_root.is_dir() else source_root.parent,), + driver_options=self._driver_options, + capabilities=self._capabilities, + ) + + +class _Adapter: + def normalize(self, driver_result: DriverResult) -> list[InternalDataRow]: + return [ + InternalDataRow( + raw=row, + identity={"entityType": "record", "entityId": f"record:{row['id']}"}, + ) + for row in driver_result.rows + ] + + +class _ResponsePipeline: + def __init__(self, overlay_store: OverlayStore) -> None: + self._overlay_store = overlay_store + + def run_query( + self, + *, + context: RuntimeContext, + binding_source_page_id: str, + driver_result: DriverResult, + rows: list[InternalDataRow], + filter_stages_applied: list[dict[str, str]], + ) -> dict[str, Any]: + merged_rows = self._overlay_store.merge(rows, context) + items = [ + { + **row.raw, + "entityType": row.identity["entityType"], + "entityId": row.identity["entityId"], + "overlayVersion": row.raw.get("_overlay_version", 0), + } + for row in merged_rows + ] + return { + "summary": { + "totalRaw": driver_result.total_raw, + "filteredUnique": driver_result.filtered_unique, + "closed": sum(1 for item in items if item.get("manualStatus") == "closed"), + }, + "items": items, + "meta": { + "sourcePageId": binding_source_page_id, + "filterStagesApplied": filter_stages_applied, + }, + } + + +def _plugin( + store: WebUIPagesStore, + *, + capabilities: frozenset[str] | None = None, + adapter_kind: str = "builtin-jsonl", + source_root: Path | None = None, + driver_options: dict[str, Any] | None = None, +) -> WebUIContractPlugin: + overlay_store = OverlayStore() + return WebUIContractPlugin( + plugin_id="test-records", + contracts=(_contract(),), + binding_resolver=_BindingResolver( + store, + capabilities=capabilities, + adapter_kind=adapter_kind, + source_root=source_root, + driver_options=driver_options, + ), + adapter=_Adapter(), + response_pipeline=_ResponsePipeline(overlay_store), + overlay_store=overlay_store, + ) + + +def _runtime( + store: WebUIPagesStore, + policy_context: PolicyContext | None = None, + *, + capabilities: frozenset[str] | None = None, +) -> OperationRuntime: + class Resolver(PolicyContextResolver): + def resolve(self, _principal: Any) -> PolicyContext: + return policy_context or PolicyContext() + + return OperationRuntime( + plugins=(_plugin(store, capabilities=capabilities),), + policy_context_resolver=Resolver(), + ) + + +def test_query_uses_shared_runtime_and_jsonl_driver(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + store = _store(tmp_path, monkeypatch) + _write_contract_assets(store, [_contract_record()]) + runtime = _runtime(store) + + response = runtime.execute( + page_id=PAGE_ID, + contract_id=CONTRACT_ID, + operation_name="list", + payload={"params": {"limit": 10}}, + principal=AuthUser(id="u1", username="alice", role="admin"), + ) + + assert response.body["summary"]["totalRaw"] == 1 + assert response.body["items"][0]["entityId"] == "record:record-1" + assert response.body["items"][0]["status"] == "open" + assert response.body["meta"]["sourcePageId"] == SOURCE_PAGE_ID + + +def test_query_can_use_sqlite_json_driver(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + store = _store(tmp_path, monkeypatch) + db_path = tmp_path / "contract_records.db" + _write_contract_sqlite( + db_path, + [ + _contract_record(id="allowed", status="open", severity="high", record_date="2026-06-25"), + _contract_record(id="hidden", status="closed", severity="low", record_date="2026-06-25"), + _contract_record(id="outside-date", status="open", severity="high", record_date="2026-06-26"), + ], + ) + runtime = OperationRuntime( + plugins=( + _plugin( + store, + adapter_kind="builtin-sqlite-json", + source_root=db_path, + driver_options={ + "table": "records", + "recordColumn": "record_json", + "dateColumn": "record_date", + }, + ), + ), + ) + + response = runtime.execute( + page_id=PAGE_ID, + contract_id=CONTRACT_ID, + operation_name="list", + payload={"params": {"date": "2026-06-25", "filters": {"status": ["open"]}, "limit": 10}}, + principal=AuthUser(id="u1", username="alice", role="admin"), + ) + + assert response.body["summary"]["totalRaw"] == 2 + assert [item["id"] for item in response.body["items"]] == ["allowed"] + assert response.body["items"][0]["entityId"] == "record:allowed" + + +def test_query_rejects_page_supplied_binding_or_idempotency_key(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + store = _store(tmp_path, monkeypatch) + _write_contract_assets(store, [_contract_record()]) + runtime = _runtime(store) + + with pytest.raises(ContractRuntimeError) as binding_error: + runtime.execute( + page_id=PAGE_ID, + contract_id=CONTRACT_ID, + operation_name="list", + payload={"bindingId": "bad", "params": {}}, + principal=None, + ) + assert binding_error.value.code == "forbidden_request_field" + + with pytest.raises(ContractRuntimeError) as idempotency_error: + runtime.execute( + page_id=PAGE_ID, + contract_id=CONTRACT_ID, + operation_name="list", + payload={"idempotencyKey": "query-key", "params": {}}, + principal=None, + ) + assert idempotency_error.value.code == "forbidden_request_field" + + +def test_default_policy_resolver_filters_member_scope(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + store = _store(tmp_path, monkeypatch) + _write_contract_assets( + store, + [ + _contract_record(id="allowed", tenant="tenant-a", asset_group="core"), + _contract_record(id="blocked", tenant="tenant-b", asset_group="core"), + ], + ) + runtime = OperationRuntime(plugins=(_plugin(store),)) + + response = runtime.execute( + page_id=PAGE_ID, + contract_id=CONTRACT_ID, + operation_name="list", + payload={"params": {"limit": 10}}, + principal=AuthUser( + id="u1", + username="analyst", + role="member", + tenant_ids=("tenant-a",), + asset_groups=("core",), + ), + ) + + assert [item["id"] for item in response.body["items"]] == ["allowed"] + assert response.body["meta"]["filterStagesApplied"][:2] == [ + { + "field": "tenant", + "source": "policy.tenantIds", + "stage": "driver-native", + "enforcement": "driver-required", + }, + { + "field": "asset_group", + "source": "policy.assetGroups", + "stage": "driver-native", + "enforcement": "driver-required", + }, + ] + + +def test_default_policy_resolver_fails_closed_for_unscoped_member(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + store = _store(tmp_path, monkeypatch) + _write_contract_assets(store, [_contract_record(id="hidden", tenant="tenant-a", asset_group="core")]) + runtime = OperationRuntime(plugins=(_plugin(store),)) + + response = runtime.execute( + page_id=PAGE_ID, + contract_id=CONTRACT_ID, + operation_name="list", + payload={"params": {"limit": 10}}, + principal=AuthUser(id="u1", username="analyst", role="member"), + ) + + assert response.body["items"] == [] + assert response.body["meta"]["filterStagesApplied"][0]["source"] == "policy.tenantIds" + assert response.body["meta"]["filterStagesApplied"][0]["field"] == "tenant" + assert response.body["meta"]["filterStagesApplied"][0]["stage"] == "driver-native" + assert PolicyContextResolver().resolve(AuthUser(id="u1", username="analyst", role="member")).tenant_ids == ( + NO_POLICY_SCOPE, + ) + + +def test_policy_and_field_dependency_plans_drive_query_projection(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + store = _store(tmp_path, monkeypatch) + _write_contract_assets(store, [_contract_record()]) + provider = _plugin(store) + binding = provider.binding_resolver.resolve( + page_id=PAGE_ID, + slot_id="primary", + contract_id=CONTRACT_ID, + contract_version=CONTRACT_VERSION, + ) + operation = provider.contracts[0].operations["list"] + policy_plan = PolicyPlanCompiler().compile( + operation=operation, + binding=binding, + policy_context=PolicyContext(tenant_ids=("tenant-a",), asset_groups=("core",)), + params={"filters": {"status": ["open"]}}, + ) + field_plan = FieldDependencyPlanCompiler().compile(operation=operation, policy_plan=policy_plan) + + assert {"tenant", "asset_group", "status", "id", "time"}.issubset(field_plan.driver_required_fields) + + +def test_policy_filter_not_enforceable_when_binding_lacks_required_field(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + store = _store(tmp_path, monkeypatch) + _write_contract_assets(store, [_contract_record()]) + provider = _plugin(store) + binding = provider.binding_resolver.resolve( + page_id=PAGE_ID, + slot_id="primary", + contract_id=CONTRACT_ID, + contract_version=CONTRACT_VERSION, + ) + binding = binding.__class__( + **{ + **binding.__dict__, + "driver_available_fields": frozenset({"id"}), + } + ) + + with pytest.raises(ContractRuntimeError) as exc: + PolicyPlanCompiler().compile( + operation=provider.contracts[0].operations["list"], + binding=binding, + policy_context=PolicyContext(tenant_ids=("tenant-a",)), + params={}, + ) + + assert exc.value.code == "policy_filter_not_enforceable" + + +def test_mutation_pipeline_enforces_overlay_version_and_idempotency(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + store = _store(tmp_path, monkeypatch) + _write_contract_assets(store, [_contract_record()]) + runtime = _runtime(store) + payload = { + "idempotencyKey": "key-1", + "expectedOverlayVersion": None, + "params": { + "entityType": "record", + "entityId": "record:record-1", + "manualStatus": "closed", + "note": "confirmed", + }, + } + + response = runtime.execute( + page_id=PAGE_ID, + contract_id=CONTRACT_ID, + operation_name="update", + payload=payload, + principal=None, + ) + replay = runtime.execute( + page_id=PAGE_ID, + contract_id=CONTRACT_ID, + operation_name="update", + payload=payload, + principal=None, + ) + + assert response.body["overlayVersion"] == 1 + assert replay.body == response.body + + with pytest.raises(ContractRuntimeError) as exc: + runtime.execute( + page_id=PAGE_ID, + contract_id=CONTRACT_ID, + operation_name="update", + payload={**payload, "params": {**payload["params"], "note": "different"}}, + principal=None, + ) + assert exc.value.code == "conflict" + + +def test_mutation_overlay_is_visible_on_followup_query(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + store = _store(tmp_path, monkeypatch) + _write_contract_assets(store, [_contract_record()]) + runtime = _runtime(store) + + runtime.execute( + page_id=PAGE_ID, + contract_id=CONTRACT_ID, + operation_name="update", + payload={ + "idempotencyKey": "overlay-query-1", + "expectedOverlayVersion": None, + "params": { + "entityType": "record", + "entityId": "record:record-1", + "manualStatus": "closed", + "note": "confirmed from logs", + }, + }, + principal=None, + ) + updated = runtime.execute( + page_id=PAGE_ID, + contract_id=CONTRACT_ID, + operation_name="list", + payload={"params": {"limit": 10}}, + principal=None, + ) + + item = updated.body["items"][0] + assert updated.body["summary"]["closed"] == 1 + assert item["manualStatus"] == "closed" + assert item["note"] == "confirmed from logs" + assert item["overlayVersion"] == 1 + + +def test_mutation_rejects_unsupported_entity_type(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + store = _store(tmp_path, monkeypatch) + _write_contract_assets(store, [_contract_record()]) + runtime = _runtime(store) + + with pytest.raises(ContractRuntimeError) as exc: + runtime.execute( + page_id=PAGE_ID, + contract_id=CONTRACT_ID, + operation_name="update", + payload={ + "idempotencyKey": "bad-entity-type", + "expectedOverlayVersion": None, + "params": { + "entityType": "case", + "entityId": "case-1", + "manualStatus": "closed", + }, + }, + principal=None, + ) + + assert exc.value.code == "invalid_request" + + +def test_runtime_rejects_operations_outside_binding_capabilities(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + store = _store(tmp_path, monkeypatch) + _write_contract_assets(store, [_contract_record()]) + runtime = _runtime(store, capabilities=frozenset({"query"})) + + with pytest.raises(ContractRuntimeError) as exc: + runtime.execute( + page_id=PAGE_ID, + contract_id=CONTRACT_ID, + operation_name="update", + payload={ + "idempotencyKey": "readonly", + "expectedOverlayVersion": None, + "params": { + "entityType": "record", + "entityId": "record:record-1", + }, + }, + principal=None, + ) + + assert exc.value.code == "operation_not_supported" + + +def test_binding_test_harness_reuses_operation_runtime(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + store = _store(tmp_path, monkeypatch) + _write_contract_assets(store, [_contract_record()]) + harness = BindingTestHarness(runtime=_runtime(store)) + + results = harness.run( + page_id=PAGE_ID, + contract_id=CONTRACT_ID, + operation_name="list", + profiles=(AuthUser(id="u1", username="alice", role="admin"),), + ) + + assert results == [{"ok": True, "statusCode": 200}] + + +def test_discovery_keeps_user_plugin_when_project_plugin_has_same_id(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + user_root = tmp_path / "user-plugins" + project_dir = tmp_path / "project" + user_access = user_root / "contracts" / "access" + project_access = project_dir / ".flocks" / "plugins" / "contracts" / "access" + user_access.mkdir(parents=True) + project_access.mkdir(parents=True) + + plugin_template = """ +from flocks.contracts.access.models import WebUIContractPlugin + +CONTRACTS = [ + WebUIContractPlugin( + plugin_id="same-id", + contracts=(), + binding_resolver=object(), + adapter=object(), + response_pipeline=object(), + version="{version}", + ) +] +""" + (user_access / "plugin.py").write_text(plugin_template.format(version="user"), encoding="utf-8") + (project_access / "plugin.py").write_text(plugin_template.format(version="project"), encoding="utf-8") + + monkeypatch.setattr(PluginLoader, "_plugin_root", user_root) + monkeypatch.setattr(PluginLoader, "_extension_points", dict(PluginLoader._extension_points)) + PluginLoader.clear_extension_points() + + plugins = discover_contract_plugins(project_dir=project_dir) + + assert [(plugin.plugin_id, plugin.version) for plugin in plugins] == [("same-id", "user")] + + +def test_discovery_keeps_user_contract_when_project_plugin_has_same_contract_id( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +): + user_root = tmp_path / "user-plugins" + project_dir = tmp_path / "project" + user_access = user_root / "contracts" / "access" + project_access = project_dir / ".flocks" / "plugins" / "contracts" / "access" + user_access.mkdir(parents=True) + project_access.mkdir(parents=True) + + plugin_template = """ +from flocks.contracts.access.models import Contract, WebUIContractPlugin + +CONTRACTS = [ + WebUIContractPlugin( + plugin_id="{plugin_id}", + contracts=(Contract(contract_id="same.contract", version="1.0", page_id="page", operations={{}}),), + binding_resolver=object(), + adapter=object(), + response_pipeline=object(), + version="{version}", + ) +] +""" + (user_access / "user_plugin.py").write_text( + plugin_template.format(plugin_id="user-plugin", version="user"), + encoding="utf-8", + ) + (project_access / "project_plugin.py").write_text( + plugin_template.format(plugin_id="project-plugin", version="project"), + encoding="utf-8", + ) + + monkeypatch.setattr(PluginLoader, "_plugin_root", user_root) + monkeypatch.setattr(PluginLoader, "_extension_points", dict(PluginLoader._extension_points)) + PluginLoader.clear_extension_points() + + plugins = discover_contract_plugins(project_dir=project_dir) + + assert [(plugin.plugin_id, plugin.version) for plugin in plugins] == [("user-plugin", "user")] diff --git a/tests/user_defined_pages/test_api_runtime.py b/tests/contracts/webui/test_api_runtime.py similarity index 74% rename from tests/user_defined_pages/test_api_runtime.py rename to tests/contracts/webui/test_api_runtime.py index d08462cc4..7af873f53 100644 --- a/tests/user_defined_pages/test_api_runtime.py +++ b/tests/contracts/webui/test_api_runtime.py @@ -2,15 +2,15 @@ from fastapi import FastAPI, Request from httpx import AsyncClient, ASGITransport -from flocks.user_defined_pages.api_runtime import UserDefinedPageApiRuntime -from flocks.user_defined_pages.store import UserDefinedPagesStore +from flocks.contracts.webui.api_runtime import WebUIPageApiRuntime +from flocks.contracts.webui.store import WebUIPagesStore @pytest.fixture def runtime_store(tmp_path, monkeypatch): - root = tmp_path / "user_defined_pages" - monkeypatch.setenv("FLOCKS_USER_DEFINED_PAGES_ROOT", str(root)) - store = UserDefinedPagesStore() + root = tmp_path / "webui_pages" + monkeypatch.setenv("FLOCKS_CONTRACTS_WEBUI_ROOT", str(root)) + store = WebUIPagesStore() store.create_page(page_id="runtime-page", title="运行时页面") return store @@ -21,7 +21,7 @@ def runtime_app(): @pytest.mark.asyncio -async def test_api_runtime_dispatch_sync_and_async(runtime_store: UserDefinedPagesStore, runtime_app: FastAPI): +async def test_api_runtime_dispatch_sync_and_async(runtime_store: WebUIPagesStore, runtime_app: FastAPI): runtime_store.save_source_file( "runtime-page", "api/routes.yaml", @@ -46,28 +46,28 @@ async def test_api_runtime_dispatch_sync_and_async(runtime_store: UserDefinedPag " return {'acked': body.get('id')}\n" ), ) - runtime = UserDefinedPageApiRuntime(runtime_store) + runtime = WebUIPageApiRuntime(runtime_store) - @runtime_app.get("/api/user-defined-pages/{page_id}/api/{api_path:path}") + @runtime_app.get("/api/contracts/webui/pages/{page_id}/api/{api_path:path}") async def _get_dispatch(page_id: str, api_path: str, request: Request): return await runtime.dispatch(page_id, api_path, request, {"role": "admin"}) - @runtime_app.post("/api/user-defined-pages/{page_id}/api/{api_path:path}") + @runtime_app.post("/api/contracts/webui/pages/{page_id}/api/{api_path:path}") async def _post_dispatch(page_id: str, api_path: str, request: Request): return await runtime.dispatch(page_id, api_path, request, {"role": "admin"}) async with AsyncClient(transport=ASGITransport(app=runtime_app), base_url="http://test") as client: - resp_get = await client.get("/api/user-defined-pages/runtime-page/api/stats") + resp_get = await client.get("/api/contracts/webui/pages/runtime-page/api/stats") assert resp_get.status_code == 200 assert resp_get.json()["pageId"] == "runtime-page" - resp_post = await client.post("/api/user-defined-pages/runtime-page/api/ack", json={"id": "a-1"}) + resp_post = await client.post("/api/contracts/webui/pages/runtime-page/api/ack", json={"id": "a-1"}) assert resp_post.status_code == 200 assert resp_post.json() == {"acked": "a-1"} @pytest.mark.asyncio -async def test_api_runtime_timeout_and_reload(runtime_store: UserDefinedPagesStore, runtime_app: FastAPI): +async def test_api_runtime_timeout_and_reload(runtime_store: WebUIPagesStore, runtime_app: FastAPI): runtime_store.save_source_file( "runtime-page", "api/routes.yaml", @@ -89,14 +89,14 @@ async def test_api_runtime_timeout_and_reload(runtime_store: UserDefinedPagesSto " return {'ok': True}\n" ), ) - runtime = UserDefinedPageApiRuntime(runtime_store) + runtime = WebUIPageApiRuntime(runtime_store) - @runtime_app.get("/api/user-defined-pages/{page_id}/api/{api_path:path}") + @runtime_app.get("/api/contracts/webui/pages/{page_id}/api/{api_path:path}") async def _dispatch(page_id: str, api_path: str, request: Request): return await runtime.dispatch(page_id, api_path, request, {"role": "admin"}) async with AsyncClient(transport=ASGITransport(app=runtime_app), base_url="http://test") as client: - timeout_resp = await client.get("/api/user-defined-pages/runtime-page/api/slow") + timeout_resp = await client.get("/api/contracts/webui/pages/runtime-page/api/slow") assert timeout_resp.status_code == 504 runtime_store.save_source_file( @@ -118,13 +118,13 @@ async def _dispatch(page_id: str, api_path: str, request: Request): assert routes[0]["handler"] == "handlers.fast" async with AsyncClient(transport=ASGITransport(app=runtime_app), base_url="http://test") as client: - ok_resp = await client.get("/api/user-defined-pages/runtime-page/api/slow") + ok_resp = await client.get("/api/contracts/webui/pages/runtime-page/api/slow") assert ok_resp.status_code == 200 assert ok_resp.json() == {"ok": True} @pytest.mark.asyncio -async def test_api_runtime_rejects_oversized_request_body(runtime_store: UserDefinedPagesStore, runtime_app: FastAPI): +async def test_api_runtime_rejects_oversized_request_body(runtime_store: WebUIPagesStore, runtime_app: FastAPI): runtime_store.save_source_file( "runtime-page", "api/routes.yaml", @@ -144,21 +144,21 @@ async def test_api_runtime_rejects_oversized_request_body(runtime_store: UserDef " return {'size': len(body)}\n" ), ) - runtime = UserDefinedPageApiRuntime(runtime_store) + runtime = WebUIPageApiRuntime(runtime_store) - @runtime_app.post("/api/user-defined-pages/{page_id}/api/{api_path:path}") + @runtime_app.post("/api/contracts/webui/pages/{page_id}/api/{api_path:path}") async def _dispatch(page_id: str, api_path: str, request: Request): return await runtime.dispatch(page_id, api_path, request, {"role": "admin"}) payload = "x" * 1_000_001 async with AsyncClient(transport=ASGITransport(app=runtime_app), base_url="http://test") as client: - resp = await client.post("/api/user-defined-pages/runtime-page/api/echo", content=payload) + resp = await client.post("/api/contracts/webui/pages/runtime-page/api/echo", content=payload) assert resp.status_code == 413 @pytest.mark.asyncio -async def test_api_runtime_treats_client_disconnect_as_closed_request(runtime_store: UserDefinedPagesStore): - runtime = UserDefinedPageApiRuntime(runtime_store) +async def test_api_runtime_treats_client_disconnect_as_closed_request(runtime_store: WebUIPagesStore): + runtime = WebUIPageApiRuntime(runtime_store) async def receive(): return {"type": "http.disconnect"} @@ -170,8 +170,8 @@ async def receive(): "http_version": "1.1", "method": "POST", "scheme": "http", - "path": "/api/user-defined-pages/runtime-page/api/echo", - "raw_path": b"/api/user-defined-pages/runtime-page/api/echo", + "path": "/api/contracts/webui/pages/runtime-page/api/echo", + "raw_path": b"/api/contracts/webui/pages/runtime-page/api/echo", "query_string": b"", "headers": [], "client": ("127.0.0.1", 12345), @@ -186,7 +186,7 @@ async def receive(): @pytest.mark.asyncio -async def test_api_runtime_blocks_non_local_imports(runtime_store: UserDefinedPagesStore, runtime_app: FastAPI): +async def test_api_runtime_blocks_non_local_imports(runtime_store: WebUIPagesStore, runtime_app: FastAPI): runtime_store.save_source_file( "runtime-page", "api/routes.yaml", @@ -206,12 +206,12 @@ async def test_api_runtime_blocks_non_local_imports(runtime_store: UserDefinedPa " return {'ok': True}\n" ), ) - runtime = UserDefinedPageApiRuntime(runtime_store) + runtime = WebUIPageApiRuntime(runtime_store) - @runtime_app.get("/api/user-defined-pages/{page_id}/api/{api_path:path}") + @runtime_app.get("/api/contracts/webui/pages/{page_id}/api/{api_path:path}") async def _dispatch(page_id: str, api_path: str, request: Request): return await runtime.dispatch(page_id, api_path, request, {"role": "admin"}) async with AsyncClient(transport=ASGITransport(app=runtime_app), base_url="http://test") as client: - resp = await client.get("/api/user-defined-pages/runtime-page/api/unsafe") + resp = await client.get("/api/contracts/webui/pages/runtime-page/api/unsafe") assert resp.status_code == 500 diff --git a/tests/user_defined_pages/test_bootstrap.py b/tests/contracts/webui/test_bootstrap.py similarity index 70% rename from tests/user_defined_pages/test_bootstrap.py rename to tests/contracts/webui/test_bootstrap.py index 8aecc6e85..62fdc9e7e 100644 --- a/tests/user_defined_pages/test_bootstrap.py +++ b/tests/contracts/webui/test_bootstrap.py @@ -1,7 +1,7 @@ import pytest -from flocks.user_defined_pages.bootstrap import reconcile_user_defined_pages -from flocks.user_defined_pages.store import UserDefinedPagesStore +from flocks.contracts.webui.bootstrap import reconcile_webui_pages +from flocks.contracts.webui.store import WebUIPagesStore class _BuilderStub: @@ -24,16 +24,16 @@ async def reload_page(self, page_id: str): @pytest.mark.asyncio async def test_reconcile_rebuilds_missing_bundle_and_preloads_api(tmp_path, monkeypatch): - root = tmp_path / "user_defined_pages" - monkeypatch.setenv("FLOCKS_USER_DEFINED_PAGES_ROOT", str(root)) - store = UserDefinedPagesStore() + root = tmp_path / "webui_pages" + monkeypatch.setenv("FLOCKS_CONTRACTS_WEBUI_ROOT", str(root)) + store = WebUIPagesStore() store.create_page(page_id="boot-page", title="启动页") store.save_source_file("boot-page", "api/routes.yaml", "routes: []\n") store.save_source_file("boot-page", "api/handlers.py", "def noop(ctx, request):\n return {}\n") builder = _BuilderStub() runtime = _RuntimeStub() - await reconcile_user_defined_pages(store=store, builder=builder, runtime=runtime) + await reconcile_webui_pages(store=store, builder=builder, runtime=runtime) assert builder.calls == ["boot-page"] assert runtime.calls == ["boot-page"] diff --git a/tests/user_defined_pages/test_builder.py b/tests/contracts/webui/test_builder.py similarity index 54% rename from tests/user_defined_pages/test_builder.py rename to tests/contracts/webui/test_builder.py index e4af8c8f2..d2102af1a 100644 --- a/tests/user_defined_pages/test_builder.py +++ b/tests/contracts/webui/test_builder.py @@ -1,32 +1,32 @@ import pytest -from flocks.user_defined_pages.builder import UserDefinedPagesBuilder, resolve_esbuild_bin -from flocks.user_defined_pages.store import UserDefinedPagesStore +from flocks.contracts.webui.builder import WebUIPageBuilder, resolve_esbuild_bin +from flocks.contracts.webui.store import WebUIPagesStore @pytest.fixture def built_store(tmp_path, monkeypatch): - root = tmp_path / "user_defined_pages" - monkeypatch.setenv("FLOCKS_USER_DEFINED_PAGES_ROOT", str(root)) - store = UserDefinedPagesStore() + root = tmp_path / "webui_pages" + monkeypatch.setenv("FLOCKS_CONTRACTS_WEBUI_ROOT", str(root)) + store = WebUIPagesStore() store.create_page(page_id="build-page", title="构建页") return store @pytest.mark.skipif(resolve_esbuild_bin() is None, reason="esbuild is not installed") -def test_builder_produces_ready_bundle(built_store: UserDefinedPagesStore): - builder = UserDefinedPagesBuilder(built_store) +def test_builder_produces_ready_bundle(built_store: WebUIPagesStore): + builder = WebUIPageBuilder(built_store) meta = builder.build("build-page") assert meta.status == "ready" assert meta.hash assert built_store.bundle_path("build-page").is_file() -def test_builder_rejects_entry_outside_page_dir(built_store: UserDefinedPagesStore): +def test_builder_rejects_entry_outside_page_dir(built_store: WebUIPagesStore): built_store.create_page(page_id="build-page-neighbor", title="相邻页") built_store.save_manifest("build-page", {"entry": "../build-page-neighbor/src/index.tsx"}) - builder = UserDefinedPagesBuilder(built_store) + builder = WebUIPageBuilder(built_store) with pytest.raises(ValueError, match="invalid entry path"): builder.build("build-page") diff --git a/tests/contracts/webui/test_store.py b/tests/contracts/webui/test_store.py new file mode 100644 index 000000000..d88f35175 --- /dev/null +++ b/tests/contracts/webui/test_store.py @@ -0,0 +1,283 @@ +import json + +import pytest + +from flocks.contracts.webui.store import WebUIPagesStore + + +@pytest.fixture +def store(tmp_path, monkeypatch): + root = tmp_path / "webui_pages" + monkeypatch.setenv("FLOCKS_CONTRACTS_WEBUI_ROOT", str(root)) + return WebUIPagesStore() + + +def _write_page(root, page_id: str, title: str, order: int = 100) -> None: + _write_page_at(root, page_id, page_id, title, order=order) + + +def _write_page_at(root, page_path: str, page_id: str, title: str, order: int = 100) -> None: + page_dir = root / page_path + (page_dir / "dist").mkdir(parents=True, exist_ok=True) + (page_dir / "src").mkdir(parents=True, exist_ok=True) + (page_dir / "manifest.json").write_text( + json.dumps( + { + "id": page_id, + "title": title, + "route": f"/contracts/webui/{page_id}", + "icon": "LayoutDashboard", + "order": order, + "enabled": True, + "placement": "home.after", + "entry": "src/index.tsx", + "updatedAt": 0, + }, + ensure_ascii=False, + indent=2, + ), + encoding="utf-8", + ) + (page_dir / "src" / "index.tsx").write_text("export default function Page() { return null; }\n", encoding="utf-8") + + +def _read_manifest(root, page_id: str): + return json.loads((root / page_id / "manifest.json").read_text(encoding="utf-8")) + + +def _write_workspace( + root, + workspace_id: str, + title: str, + order: int = 100, + default_page_id: str | None = None, + sections: list[dict] | None = None, +) -> None: + workspace_dir = root / workspace_id + workspace_dir.mkdir(parents=True, exist_ok=True) + payload = { + "id": workspace_id, + "title": title, + "icon": "ShieldCheck", + "order": order, + "enabled": True, + "placement": "sceneWorkspace", + } + if default_page_id is not None: + payload["defaultPageId"] = default_page_id + if sections is not None: + payload["sections"] = sections + (workspace_dir / "workspace.json").write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") + + +def test_create_page_scaffold(store: WebUIPagesStore): + detail = store.create_page(page_id="my-dashboard", title="我的大屏") + assert detail.manifest.id == "my-dashboard" + assert detail.manifest.route == "/contracts/webui/my-dashboard" + assert (store.page_dir("my-dashboard") / "src" / "Page.tsx").is_file() + assert (store.page_dir("my-dashboard") / "manifest.json").is_file() + + +def test_list_pages_enabled_only(store: WebUIPagesStore): + store.create_page(page_id="enabled-page", title="启用页") + disabled = store.create_page(page_id="disabled-page", title="禁用页") + store.save_manifest("disabled-page", {**disabled.manifest.model_dump(), "enabled": False}) + + all_pages = store.list_pages(enabled_only=False) + enabled_pages = store.list_pages(enabled_only=True) + + assert {page.id for page in all_pages} == {"enabled-page", "disabled-page"} + assert [page.id for page in enabled_pages] == ["enabled-page"] + + +def test_list_pages_scans_user_and_project_roots_with_user_priority(tmp_path): + user_root = tmp_path / "user" / "contracts" / "webui" + project_root = tmp_path / "project" / ".flocks" / "plugins" / "contracts" / "webui" + _write_page(project_root, "shared-page", "Project shared", order=20) + _write_page(project_root, "project-page", "Project page", order=30) + _write_page(user_root, "shared-page", "User shared", order=10) + _write_page(user_root, "user-page", "User page", order=40) + + store = WebUIPagesStore(root=user_root, project_root=project_root) + + pages = store.list_pages() + assert [page.id for page in pages] == ["shared-page", "project-page", "user-page"] + assert pages[0].title == "User shared" + assert store.get_page("shared-page").manifest.title == "User shared" + assert store.get_page("project-page").manifest.title == "Project page" + assert store.page_dir("shared-page").is_relative_to(user_root) + assert store.page_dir("project-page").is_relative_to(project_root) + + +def test_grouped_page_directory_uses_manifest_id_for_lookup(tmp_path): + user_root = tmp_path / "user" / "contracts" / "webui" + _write_workspace(user_root, "scene_workspace", "场景工作区") + _write_page_at(user_root, "scene_workspace/investigation_list", "investigation-list", "Investigation List") + + store = WebUIPagesStore(root=user_root, project_root=None, legacy_root=None) + + pages = store.list_pages() + assert [page.id for page in pages] == ["investigation-list"] + assert pages[0].route == "/contracts/webui/investigation-list" + assert pages[0].workspaceId == "scene_workspace" + assert pages[0].workspaceTitle == "场景工作区" + assert pages[0].workspaceRoute == "/contracts/webui/workspaces/scene_workspace" + assert store.page_dir("investigation-list") == user_root / "scene_workspace" / "investigation_list" + + store.save_manifest("investigation-list", {"title": "调查列表"}) + store.write_build_meta("investigation-list", store.read_build_meta("investigation-list").model_copy(update={"status": "ready", "hash": "abc"})) + + nested_manifest = json.loads((user_root / "scene_workspace" / "investigation_list" / "manifest.json").read_text(encoding="utf-8")) + assert nested_manifest["title"] == "调查列表" + assert (user_root / "scene_workspace" / "investigation_list" / "dist" / "meta.json").is_file() + assert not (user_root / "investigation-list").exists() + + +def test_list_workspaces_returns_grouped_pages(tmp_path): + user_root = tmp_path / "user" / "contracts" / "webui" + _write_workspace( + user_root, + "scene_workspace", + "场景工作区", + order=5, + default_page_id="ops-overview", + sections=[ + { + "id": "operations", + "label": "调查列表", + "pageIds": ["ops-overview", "investigation-list"], + "defaultPageId": "ops-overview", + "contentPadding": "comfortable", + }, + ], + ) + _write_page_at(user_root, "scene_workspace/investigation_list", "investigation-list", "Investigation List", order=20) + _write_page_at(user_root, "scene_workspace/ops_overview", "ops-overview", "Ops Overview", order=10) + store = WebUIPagesStore(root=user_root, project_root=None, legacy_root=None) + store.write_build_meta("ops-overview", store.read_build_meta("ops-overview").model_copy(update={"status": "ready", "hash": "abc"})) + + workspaces = store.list_workspaces() + + assert [workspace.id for workspace in workspaces] == ["scene_workspace"] + assert workspaces[0].title == "场景工作区" + assert workspaces[0].route == "/contracts/webui/workspaces/scene_workspace" + assert workspaces[0].placement == "sceneWorkspace" + assert workspaces[0].defaultPageId == "ops-overview" + assert len(workspaces[0].sections) == 1 + assert workspaces[0].sections[0].id == "operations" + assert workspaces[0].sections[0].label == "调查列表" + assert workspaces[0].sections[0].pageIds == ["ops-overview", "investigation-list"] + assert workspaces[0].sections[0].defaultPageId == "ops-overview" + assert workspaces[0].sections[0].contentPadding == "comfortable" + assert [page.id for page in workspaces[0].pages] == ["ops-overview", "investigation-list"] + assert workspaces[0].pages[0].buildStatus == "ready" + + +def test_legacy_migration_skips_existing_grouped_page(tmp_path): + user_root = tmp_path / "user" / "contracts" / "webui" + legacy_root = tmp_path / "user" / "user_defined_pages" + _write_page_at( + user_root, + "scene_workspace/risk_dashboard", + "risk-dashboard", + "Grouped page", + ) + _write_page(legacy_root, "risk-dashboard", "Legacy page") + + store = WebUIPagesStore(root=user_root, project_root=None, legacy_root=legacy_root) + pages = store.list_pages() + + assert [page.id for page in pages] == ["risk-dashboard"] + assert store.page_dir("risk-dashboard") == user_root / "scene_workspace" / "risk_dashboard" + assert not (user_root / "risk-dashboard").exists() + + +def test_save_project_root_page_materializes_user_copy(tmp_path): + user_root = tmp_path / "user" / "contracts" / "webui" + project_root = tmp_path / "project" / ".flocks" / "plugins" / "contracts" / "webui" + _write_page(project_root, "project-page", "Project page") + + store = WebUIPagesStore(root=user_root, project_root=project_root) + store.save_manifest("project-page", {"title": "User override"}) + store.write_build_meta("project-page", store.read_build_meta("project-page").model_copy(update={"status": "ready", "hash": "abc"})) + + assert _read_manifest(project_root, "project-page")["title"] == "Project page" + assert _read_manifest(user_root, "project-page")["title"] == "User override" + assert (user_root / "project-page" / "dist" / "meta.json").is_file() + assert not (project_root / "project-page" / "dist" / "meta.json").is_file() + assert store.page_dir("project-page").is_relative_to(user_root) + + +def test_legacy_user_defined_pages_are_migrated_to_contract_root(tmp_path): + user_root = tmp_path / "user" / "contracts" / "webui" + legacy_root = tmp_path / "user" / "user_defined_pages" + _write_page(legacy_root, "legacy-page", "Legacy page") + legacy_manifest = _read_manifest(legacy_root, "legacy-page") + legacy_manifest["route"] = "/user-defined-pages/legacy-page" + (legacy_root / "legacy-page" / "manifest.json").write_text(json.dumps(legacy_manifest), encoding="utf-8") + (legacy_root / "legacy-page" / "src" / "index.tsx").write_text( + "import { Card } from '@flocks/user-defined-page-sdk';\n" + "const sdk = globalThis.__FLOCKS_USER_DEFINED_PAGE_SDK__;\n", + encoding="utf-8", + ) + + store = WebUIPagesStore(root=user_root, project_root=None, legacy_root=legacy_root) + pages = store.list_pages() + + assert [page.id for page in pages] == ["legacy-page"] + assert pages[0].route == "/contracts/webui/legacy-page" + migrated_source = (user_root / "legacy-page" / "src" / "index.tsx").read_text(encoding="utf-8") + assert "@flocks/webui-contract-sdk" in migrated_source + assert "__FLOCKS_WEBUI_CONTRACT_SDK__" in migrated_source + assert _read_manifest(user_root, "legacy-page")["route"] == "/contracts/webui/legacy-page" + assert json.loads((user_root / "legacy-page" / "dist" / "meta.json").read_text(encoding="utf-8"))["status"] == "idle" + + +def test_reject_path_traversal_on_write(store: WebUIPagesStore): + store.create_page(page_id="safe-page", title="安全页") + with pytest.raises(ValueError, match="writes are not allowed"): + store.save_source_file("safe-page", "../escape.tsx", "bad") + + +def test_allow_page_api_source_files(store: WebUIPagesStore): + store.create_page(page_id="api-page", title="API 页") + store.save_source_file("api-page", "api/routes.yaml", "routes: []\n") + store.save_source_file("api-page", "api/handlers.py", "def ping(ctx, request):\n return {'ok': True}\n") + assert store.read_source_file("api-page", "api/routes.yaml").startswith("routes:") + detail = store.get_page("api-page") + assert "api/routes.yaml" in detail.sourceFiles + assert "api/handlers.py" in detail.sourceFiles + + +def test_reject_unsupported_api_extension(store: WebUIPagesStore): + store.create_page(page_id="api-ext-page", title="API 后缀页") + with pytest.raises(ValueError, match="unsupported source file type"): + store.save_source_file("api-ext-page", "api/secret.txt", "nope") + + +def test_reject_invalid_page_id(store: WebUIPagesStore): + with pytest.raises(ValueError, match="invalid page id"): + store.validate_page_id("../bad") + + +def test_asset_path_stays_inside_assets_dir(store: WebUIPagesStore): + store.create_page(page_id="asset-page", title="资源页") + with pytest.raises(ValueError, match="path traversal is not allowed"): + store.asset_path("asset-page", "../manifest.json") + + +def test_manifest_roundtrip(store: WebUIPagesStore): + store.create_page(page_id="roundtrip", title="原始标题") + manifest = store.save_manifest( + "roundtrip", + { + "title": "新标题", + "order": 10, + "route": "/custom/route", + }, + ) + assert manifest.title == "新标题" + assert manifest.order == 10 + assert manifest.route == "/contracts/webui/roundtrip" + raw = json.loads((store.page_dir("roundtrip") / "manifest.json").read_text(encoding="utf-8")) + assert raw["route"] == "/contracts/webui/roundtrip" diff --git a/tests/contracts/webui/test_watcher.py b/tests/contracts/webui/test_watcher.py new file mode 100644 index 000000000..a5043c6a1 --- /dev/null +++ b/tests/contracts/webui/test_watcher.py @@ -0,0 +1,106 @@ +import json + +from flocks.contracts.webui.store import WebUIPagesStore +from flocks.contracts.webui import watcher as watcher_module +from flocks.contracts.webui.watcher import WebUIPagesWatcher, _PendingAction + + +class _RuntimeStub: + async def reload_page(self, _page_id: str): + return [{"method": "GET", "path": "/stats", "handler": "handlers.stats"}] + + +class _BuilderStub: + def build(self, _page_id: str): + raise AssertionError("build should not be called for api-only change") + + +def test_watcher_api_change_uses_main_loop_bridge(monkeypatch): + emitted: list[tuple[str, dict]] = [] + bridge_calls: list[str] = [] + + def _bridge(coro, *, timeout_seconds=5.0): + bridge_calls.append("called") + coro.close() + return [{"method": "GET", "path": "/stats", "handler": "handlers.stats"}] + + def _emit(event_type: str, properties: dict): + emitted.append((event_type, properties)) + + monkeypatch.setattr(watcher_module, "_run_on_main_loop_sync", _bridge) + monkeypatch.setattr(watcher_module, "_publish_event_sync", _emit) + + watcher = WebUIPagesWatcher(builder=_BuilderStub(), api_runtime=_RuntimeStub()) + watcher._pending_pages["demo-page"] = _PendingAction(api_changed=True) + watcher._run_pending_builds() + + assert bridge_calls == ["called"] + assert emitted[0][0] == "contracts.webui.pages.api_changed" + assert emitted[0][1]["id"] == "demo-page" + + +def test_watcher_classifies_nested_page_api_change(tmp_path): + root = tmp_path / "webui_pages" + page_dir = root / "scene_workspace" / "investigation_list" + (page_dir / "api").mkdir(parents=True) + (page_dir / "manifest.json").write_text( + json.dumps( + { + "id": "investigation-list", + "title": "Investigation List", + "route": "/contracts/webui/investigation-list", + "icon": "AlertTriangle", + "order": 20, + "enabled": True, + "placement": "home.after", + "entry": "src/index.tsx", + "updatedAt": 0, + } + ), + encoding="utf-8", + ) + + store = WebUIPagesStore(root=root, project_root=None, legacy_root=None) + watcher = WebUIPagesWatcher(store=store, builder=_BuilderStub(), api_runtime=_RuntimeStub()) + + page_id, pending = watcher._classify_event( + page_dir / "api" / "handlers.py", + root, + event_type="modified", + is_directory=False, + ) + + assert page_id == "investigation-list" + assert pending.api_changed + + +def test_watcher_classifies_workspace_manifest_change(tmp_path): + root = tmp_path / "webui_pages" + workspace_dir = root / "scene_workspace" + workspace_dir.mkdir(parents=True) + (workspace_dir / "workspace.json").write_text( + json.dumps( + { + "id": "scene_workspace", + "title": "场景工作区", + "icon": "ShieldCheck", + "order": 10, + "enabled": True, + "placement": "sceneWorkspace", + } + ), + encoding="utf-8", + ) + + store = WebUIPagesStore(root=root, project_root=None, legacy_root=None) + watcher = WebUIPagesWatcher(store=store, builder=_BuilderStub(), api_runtime=_RuntimeStub()) + + page_id, pending = watcher._classify_event( + workspace_dir / "workspace.json", + root, + event_type="modified", + is_directory=False, + ) + + assert page_id == "scene_workspace" + assert pending.manifest_changed diff --git a/tests/hub/test_bundled_tools.py b/tests/hub/test_bundled_tools.py index 16436e7c3..e704e4848 100644 --- a/tests/hub/test_bundled_tools.py +++ b/tests/hub/test_bundled_tools.py @@ -215,6 +215,7 @@ def test_bundled_only_lists_as_available(self, isolated_hub): entries = catalog.list_catalog(plugin_type="tool") match = next((e for e in entries if e.id == "onesig_v2_5_3_D20250710"), None) assert match is not None, "bundled tool should appear in catalog" + assert match.version == "2.5.3 D20250710" assert match.state == "available" assert match.installedVersion is None assert match.source == "bundled" diff --git a/tests/ingest/test_kafka_manager.py b/tests/ingest/test_kafka_manager.py index 199d8d2ac..31de29d02 100644 --- a/tests/ingest/test_kafka_manager.py +++ b/tests/ingest/test_kafka_manager.py @@ -167,10 +167,12 @@ async def test_restart_disabled_config_reports_stopped(monkeypatch: pytest.Monke manager = kafka_manager.KafkaManager() - async def _fake_read(key): # noqa: ANN001 + async def _fake_get_config(workflow_id: str, *, kind: str) -> dict: + assert workflow_id == "wf-disabled" + assert kind == "workflow_kafka_config" return {"enabled": False} - monkeypatch.setattr(kafka_manager.Storage, "read", _fake_read) + monkeypatch.setattr(kafka_manager.WorkflowStore, "get_config", _fake_get_config) status = await manager.restart_workflow("wf-disabled") assert status == {"state": "stopped", "error": None} @@ -182,10 +184,12 @@ async def test_restart_missing_broker_reports_failed(monkeypatch: pytest.MonkeyP manager = kafka_manager.KafkaManager() - async def _fake_read(key): # noqa: ANN001 + async def _fake_get_config(workflow_id: str, *, kind: str) -> dict: + assert workflow_id == "wf-no-broker" + assert kind == "workflow_kafka_config" return {"enabled": True, "inputBroker": "", "inputTopic": ""} - monkeypatch.setattr(kafka_manager.Storage, "read", _fake_read) + monkeypatch.setattr(kafka_manager.WorkflowStore, "get_config", _fake_get_config) status = await manager.restart_workflow("wf-no-broker") assert status["state"] == "failed" @@ -201,7 +205,9 @@ async def test_restart_workflow_cleans_resources_after_connect_failure( manager = kafka_manager.KafkaManager() workflow_id = "wf-connect-failed" - async def _fake_read(key): # noqa: ANN001 + async def _fake_get_config(workflow_id_value: str, *, kind: str) -> dict: + assert workflow_id_value == workflow_id + assert kind == "workflow_kafka_config" return { "enabled": True, "inputBroker": "localhost:9092", @@ -220,11 +226,17 @@ async def start(self) -> None: async def stop(self) -> None: self.stopped = True - monkeypatch.setattr(kafka_manager.Storage, "read", _fake_read) + monkeypatch.setattr(kafka_manager.WorkflowStore, "get_config", _fake_get_config) monkeypatch.setattr( kafka_manager, "read_workflow_from_fs", - lambda _workflow_id: {"workflowJson": {"start": "n1", "nodes": [], "edges": []}}, + lambda _workflow_id: { + "workflowJson": { + "start": "n1", + "nodes": [{"id": "n1", "type": "python", "code": "outputs['ok'] = True"}], + "edges": [], + } + }, ) monkeypatch.setitem(sys.modules, "aiokafka", SimpleNamespace(AIOKafkaConsumer=_Consumer)) @@ -318,7 +330,8 @@ def _fake_run_workflow(**kwargs): # noqa: ANN003 assert captured_input_params["kafka_message"]["alarmData"]["_type"] == "string" assert captured_input_params["kafka_message"]["alarmData"]["chars"] == 50_000 - assert captured_run_kwargs["history_mode"] == "summary" + assert captured_run_kwargs["run_id"] == "exec-compact" + assert captured_run_kwargs["execution_profile"] == "high_frequency" assert callable(captured_run_kwargs["on_step_complete"]) assert captured_exec_data["outputResults"] == { "_enriched_alerts_count": 1, diff --git a/tests/ingest/test_syslog_manager_backpressure.py b/tests/ingest/test_syslog_manager_backpressure.py index 13a581817..a042e2fc5 100644 --- a/tests/ingest/test_syslog_manager_backpressure.py +++ b/tests/ingest/test_syslog_manager_backpressure.py @@ -238,6 +238,8 @@ def _fake_run_workflow(**kwargs): # noqa: ANN003 assert captured_run_kwargs["inputs"]["message"] == "demo" assert captured_run_kwargs["inputs"]["hostname"] == "router-a" assert captured_run_kwargs["inputs"]["pipeline"] == "syslog" + assert captured_run_kwargs["run_id"] == "exec-syslog" + assert captured_run_kwargs["execution_profile"] == "high_frequency" assert callable(captured_run_kwargs["on_step_complete"]) assert recorded_steps[0][0] == "exec-syslog" assert recorded_steps[0][1] == 1 diff --git a/tests/provider/test_api_service_management.py b/tests/provider/test_api_service_management.py index 7cb002040..768978cad 100644 --- a/tests/provider/test_api_service_management.py +++ b/tests/provider/test_api_service_management.py @@ -399,6 +399,7 @@ def test_load_provider_yaml_metadata_from_project_plugins(self, tmp_path, monkey "service_id": "threatbook_api", "description": "Threat intelligence", "description_cn": "威胁情报", + "docs_url": "https://docs.example.com/threatbook", }), encoding="utf-8") (provider_dir / "threatbook_ip_query.yaml").write_text(yaml.safe_dump({ "name": "threatbook_ip_query", @@ -413,6 +414,7 @@ def test_load_provider_yaml_metadata_from_project_plugins(self, tmp_path, monkey assert metadata is not None assert metadata["name"] == "ThreatBook" assert metadata["description_cn"] == "威胁情报" + assert metadata["docs_url"] == "https://docs.example.com/threatbook" assert metadata["apis"][0]["name"] == "threatbook_ip_query" @pytest.mark.asyncio @@ -424,6 +426,7 @@ async def test_get_api_service_metadata_returns_credential_schema(self): return_value={ "name": "Qingteng", "description": "Qingteng API service", + "docs_url": "https://docs.example.com/qingteng", "credential_fields": [ {"key": "base_url", "storage": "config", "config_key": "base_url", "input_type": "url"}, {"key": "username", "storage": "config", "config_key": "username"}, @@ -434,6 +437,7 @@ async def test_get_api_service_metadata_returns_credential_schema(self): result = await get_api_service_metadata("qingteng") assert result.name == "Qingteng" + assert result.docs_url == "https://docs.example.com/qingteng" assert result.credential_schema is not None assert [field["key"] for field in result.credential_schema] == ["base_url", "username", "password"] assert result.credential_schema[2]["secret_id"] == "qingteng_password" diff --git a/tests/scripts/test_bash_scripts.py b/tests/scripts/test_bash_scripts.py index 13e81cc39..4dcb69e9c 100644 --- a/tests/scripts/test_bash_scripts.py +++ b/tests/scripts/test_bash_scripts.py @@ -34,3 +34,10 @@ def test_dev_script_stops_backend_process_tree_on_exit() -> None: assert 'kill -TERM "${kill_targets[@]}"' in script assert 'kill -KILL "${remaining[@]}"' in script assert 'kill "${BACKEND_PID}" 2>/dev/null || true' not in script + + +def test_dev_script_sets_console_base_url_default() -> None: + script = (SCRIPT_DIR / "dev.sh").read_text(encoding="utf-8") + + assert 'FLOCKS_CONSOLE_BASE_URL="${FLOCKS_CONSOLE_BASE_URL:-https://portalflocks.threatbook.cn}"' in script + assert 'FLOCKS_CONSOLE_BASE_URL="${FLOCKS_CONSOLE_BASE_URL}" \\' in script diff --git a/tests/server/routes/test_admin_users_routes.py b/tests/server/routes/test_admin_users_routes.py index 8f9e335d6..1ed56409e 100644 --- a/tests/server/routes/test_admin_users_routes.py +++ b/tests/server/routes/test_admin_users_routes.py @@ -22,6 +22,36 @@ async def test_admin_routes_list_users(client: AsyncClient): assert len(users) >= 1 admin_user = users[0] assert admin_user["role"] == "admin" + assert admin_user["tenant_ids"] == [] + assert admin_user["asset_groups"] == [] + + +@pytest.mark.asyncio +async def test_admin_routes_update_contract_scope(client: AsyncClient): + from flocks.auth.service import AuthService + + member = await AuthService._create_user_internal( + username="analyst", + password="Password123!", + role="member", + ) + + response = await client.put( + f"/api/admin/users/{member.id}/contract-scope", + json={ + "tenant_ids": ["tenant-a", "tenant-a", " "], + "asset_groups": ["core"], + }, + ) + + assert response.status_code == 200, response.text + assert response.json()["tenant_ids"] == ["tenant-a"] + assert response.json()["asset_groups"] == ["core"] + + stored = await AuthService.get_user_by_id(member.id) + assert stored is not None + assert stored.to_auth_user().tenant_ids == ("tenant-a",) + assert stored.to_auth_user().asset_groups == ("core",) @pytest.mark.asyncio diff --git a/tests/server/routes/test_console_upgrade_routes.py b/tests/server/routes/test_console_upgrade_routes.py index 8eb8380bd..575e1e10e 100644 --- a/tests/server/routes/test_console_upgrade_routes.py +++ b/tests/server/routes/test_console_upgrade_routes.py @@ -3,7 +3,7 @@ import json import pytest -from fastapi import status +from fastapi import FastAPI, status import httpx from types import ModuleType from httpx import AsyncClient @@ -174,6 +174,40 @@ async def test_fallback_license_state_does_not_mark_license_activated( assert record["details"]["license_activate_fallback_saved_at"] +async def test_pro_upgrade_activation_can_disable_fallback_license_state( + tmp_path, + monkeypatch: pytest.MonkeyPatch, +): + from flocks.server.routes import console_upgrade as console_routes + + class _Checker: + def activate(self, *_args, **_kwargs): + raise RuntimeError("activation service unavailable") + + runtime_module = ModuleType("flockspro.license.runtime") + flockspro_module = ModuleType("flockspro") + license_module = ModuleType("flockspro.license") + runtime_module.get_license_checker = lambda: _Checker() # type: ignore[attr-defined] + monkeypatch.setitem(__import__("sys").modules, "flockspro", flockspro_module) + monkeypatch.setitem(__import__("sys").modules, "flockspro.license", license_module) + monkeypatch.setitem(__import__("sys").modules, "flockspro.license.runtime", runtime_module) + monkeypatch.setenv("FLOCKS_ROOT", str(tmp_path)) + monkeypatch.setattr(console_routes, "_is_pro_component_installed", lambda: True) + + record = { + "request_id": "req_no_fallback", + "license_id": "lic_no_fallback", + "activate_key": "signed.token.value", + "details": {}, + } + + await console_routes._maybe_activate_pro_license(record, allow_fallback=False) + + assert record["details"]["license_activate_error"] == "activation service unavailable" + assert not (tmp_path / "flockspro" / "license.json").exists() + assert "license_activate_fallback_saved_at" not in record["details"] + + async def test_refresh_pro_license_updates_record_timestamp(monkeypatch: pytest.MonkeyPatch): from flocks.server.routes import console_upgrade as console_routes @@ -226,429 +260,81 @@ async def test_pro_package_status_reports_installed_marker( assert resp.status_code == status.HTTP_200_OK payload = resp.json() assert payload["installed"] is True + assert payload["runtime_importable"] is True + assert payload["install_marker_present"] is True assert payload["flockspro_component_version"] == "1.2.3" -async def test_sync_console_license_revocations_without_pro_package_only_syncs_console_records( +async def test_pro_package_status_treats_install_marker_as_installed( client: AsyncClient, monkeypatch: pytest.MonkeyPatch, ): from flocks.server.routes import console_upgrade as console_routes - from flocks.storage.storage import Storage - monkeypatch.setenv("FLOCKS_CONSOLE_BASE_URL", "http://console.local") - monkeypatch.setattr(console_routes, "require_admin", lambda _req: _mock_admin()) monkeypatch.setattr(console_routes, "_is_pro_component_installed", lambda: False) - await _set_bound_console_session() - await Storage.set("console:upgrade_request_ids", ["req_install"], "json") - await Storage.set( - "console:upgrade_request:req_install", - { - "request_id": "req_install", - "status": "approved", - "activate_key": "install_token", - "license_id": "lic_install", - "license_status": "poc", - "details": {"console_account_name": "alice", "license_id": "lic_install"}, - "created_at": "2026-05-15T10:00:00+00:00", - "updated_at": "2026-05-15T10:00:00+00:00", + monkeypatch.setattr( + console_routes, + "_read_pro_bundle_install_marker", + lambda: { + "installed_version": "pro-v2026.6.23", + "flockspro_component_version": "2026.6.23", + "installed_at": "2026-06-29T04:00:00+00:00", }, - "json", ) - class _FakeResponse: - def __init__(self, payload: dict, status_code: int = status.HTTP_200_OK) -> None: - self._payload = payload - self.status_code = status_code - - def json(self) -> dict: - return self._payload - - def raise_for_status(self) -> None: - return None - - class _FakeClient: - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, tb): - return False - - async def get(self, url, headers=None): - assert headers == {"Authorization": "Bearer token_abc"} - if url == "http://console.local/v1/licenses/revocations": - return _FakeResponse({"revoked_license_ids": ["lic_revoked"]}) - if url == "http://console.local/v1/licenses/lic_install": - return _FakeResponse( - { - "license_id": "lic_install", - "license_status": "poc", - "effective_status": "poc", - "effective_expires_at": 1781417933, - "effective_max_admins": 3, - "effective_max_members": 9, - } - ) - raise AssertionError(url) - - monkeypatch.setattr(console_routes.httpx, "AsyncClient", lambda timeout=10: _FakeClient()) - - resp = await client.post("/api/console/licenses/sync-revocations") + resp = await client.get("/api/console/pro-package-status") assert resp.status_code == status.HTTP_200_OK payload = resp.json() - assert payload["imported"] is False + assert payload["installed"] is True + assert payload["runtime_importable"] is False + assert payload["install_marker_present"] is True assert payload["inactive_reason"] == "flockspro_not_installed" - assert payload["synced_license_ids"] == ["lic_install"] - stored = await Storage.get("console:upgrade_request:req_install") - assert stored["max_admins"] == 3 - assert stored["max_members"] == 9 - - -async def test_sync_console_license_revocations_imports_into_checker( - client: AsyncClient, - monkeypatch: pytest.MonkeyPatch, -): - from flocks.server.routes import console_upgrade as console_routes - - monkeypatch.setenv("FLOCKS_CONSOLE_BASE_URL", "http://console.local") - monkeypatch.setattr(console_routes, "require_admin", lambda _req: _mock_admin()) - await _set_bound_console_session() - - class _FakeResponse: - def json(self) -> dict: - return {"revoked_license_ids": ["lic_revoked"]} - - def raise_for_status(self) -> None: - return None - - class _FakeClient: - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, tb): - return False - - async def get(self, url, headers=None): - assert url == "http://console.local/v1/licenses/revocations" - assert headers == {"Authorization": "Bearer token_abc"} - return _FakeResponse() - - imported: list[str] = [] - - class _Checker: - def import_revocation(self, revoked_license_ids): - imported.extend(revoked_license_ids) - - runtime_module = ModuleType("flockspro.license.runtime") - flockspro_module = ModuleType("flockspro") - license_module = ModuleType("flockspro.license") - runtime_module.get_license_checker = lambda: _Checker() # type: ignore[attr-defined] - runtime_module.get_pro_capability_status = lambda: {"pro_enabled": False, "active": False} # type: ignore[attr-defined] - monkeypatch.setattr(console_routes.httpx, "AsyncClient", lambda timeout=10: _FakeClient()) - monkeypatch.setattr(console_routes, "_is_pro_component_installed", lambda: True) - monkeypatch.setitem(__import__("sys").modules, "flockspro", flockspro_module) - monkeypatch.setitem(__import__("sys").modules, "flockspro.license", license_module) - monkeypatch.setitem(__import__("sys").modules, "flockspro.license.runtime", runtime_module) - - resp = await client.post("/api/console/licenses/sync-revocations") - - assert resp.status_code == status.HTTP_200_OK - assert resp.json()["imported"] is True - assert imported == ["lic_revoked"] - - -async def test_sync_console_license_revocations_switches_from_revoked_runtime_license( - client: AsyncClient, - monkeypatch: pytest.MonkeyPatch, -): - from flocks.server.routes import console_upgrade as console_routes - from flocks.storage.storage import Storage - - monkeypatch.setenv("FLOCKS_CONSOLE_BASE_URL", "http://console.local") - monkeypatch.setattr(console_routes, "require_admin", lambda _req: _mock_admin()) - await _set_bound_console_session() - await Storage.set("console:upgrade_request_ids", ["req_old", "req_new", "req_later_revoked"], "json") - await Storage.set( - "console:upgrade_request:req_old", - { - "request_id": "req_old", - "status": "activated", - "activate_key": "old_token", - "license_id": "lic_old", - "license_status": "poc", - "details": {"license_id": "lic_old"}, - "created_at": "2026-05-15T10:00:00+00:00", - "updated_at": "2026-05-15T10:00:00+00:00", - }, - "json", - ) - await Storage.set( - "console:upgrade_request:req_new", - { - "request_id": "req_new", - "status": "approved", - "activate_key": "new_token", - "license_id": "lic_new", - "license_status": "poc", - "details": {"license_id": "lic_new"}, - "created_at": "2026-05-15T11:00:00+00:00", - "updated_at": "2026-05-15T11:00:00+00:00", - }, - "json", - ) - await Storage.set( - "console:upgrade_request:req_later_revoked", - { - "request_id": "req_later_revoked", - "status": "approved", - "activate_key": "later_revoked_token", - "license_id": "lic_later_revoked", - "license_status": "revoked", - "details": {"license_id": "lic_later_revoked"}, - "created_at": "2026-05-15T12:00:00+00:00", - "updated_at": "2026-05-15T12:00:00+00:00", - }, - "json", - ) - - class _FakeResponse: - def __init__(self, payload: dict, status_code: int = status.HTTP_200_OK) -> None: - self._payload = payload - self.status_code = status_code - - def json(self) -> dict: - return self._payload - - def raise_for_status(self) -> None: - return None - - class _FakeClient: - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, tb): - return False - - async def get(self, url, headers=None): - if url == "http://console.local/v1/licenses/revocations": - return _FakeResponse({"revoked_license_ids": ["lic_old"]}) - if url == "http://console.local/v1/licenses/lic_old": - return _FakeResponse( - { - "license_id": "lic_old", - "revoked": True, - "license_status": "revoked", - "effective_status": "revoked", - "effective_expires_at": 1778825933, - } - ) - if url == "http://console.local/v1/licenses/lic_new": - return _FakeResponse( - { - "license_id": "lic_new", - "license_status": "poc", - "effective_status": "poc", - "effective_expires_at": 1781417933, - "effective_max_admins": 2, - "effective_max_members": 6, - } - ) - if url == "http://console.local/v1/licenses/lic_later_revoked": - return _FakeResponse( - { - "license_id": "lic_later_revoked", - "revoked": True, - "license_status": "revoked", - "effective_status": "revoked", - "effective_expires_at": 1781417933, - } - ) - raise AssertionError(url) - class _Checker: - def __init__(self) -> None: - self.license_id = "lic_old" - self.active = False - self.activated_tokens: list[str] = [] - self.refreshed = False - def import_revocation(self, revoked_license_ids): - assert revoked_license_ids == ["lic_old"] +async def test_flockspro_license_status_fallback_reports_uninstalled(monkeypatch: pytest.MonkeyPatch): + from flocks.server.routes import flockspro_license as license_routes - def status(self): - return { - "license_id": self.license_id, - "license_status": "revoked" if not self.active else "poc", - "active": self.active, - } - - def activate(self, token: str): - self.activated_tokens.append(token) - self.license_id = "lic_new" - self.active = True - return self.status() - - async def refresh(self): - self.refreshed = True - return self.status() - - checker = _Checker() - runtime_module = ModuleType("flockspro.license.runtime") - flockspro_module = ModuleType("flockspro") - license_module = ModuleType("flockspro.license") - runtime_module.get_license_checker = lambda: checker # type: ignore[attr-defined] - runtime_module.get_pro_capability_status = lambda: { # type: ignore[attr-defined] - **checker.status(), - "pro_enabled": checker.active, - } - monkeypatch.setattr(console_routes.httpx, "AsyncClient", lambda timeout=10: _FakeClient()) - monkeypatch.setattr(console_routes, "_is_pro_component_installed", lambda: True) - monkeypatch.setitem(__import__("sys").modules, "flockspro", flockspro_module) - monkeypatch.setitem(__import__("sys").modules, "flockspro.license", license_module) - monkeypatch.setitem(__import__("sys").modules, "flockspro.license.runtime", runtime_module) + app = FastAPI() + app.include_router(license_routes.router, prefix="/api/flockspro/license") + monkeypatch.setattr(license_routes, "_is_pro_component_installed", lambda: False) + monkeypatch.setattr(license_routes, "require_user", lambda _req: _mock_admin()) - resp = await client.post("/api/console/licenses/sync-revocations") + transport = httpx.ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as local_client: + resp = await local_client.get("/api/flockspro/license/status") assert resp.status_code == status.HTTP_200_OK payload = resp.json() - assert payload["activated_license_id"] == "lic_new" - assert payload["refreshed_license_id"] == "lic_new" - assert checker.activated_tokens == ["new_token"] - assert checker.refreshed is True + assert payload["active"] is False + assert payload["pro_enabled"] is False + assert payload["license_status"] == "uninstalled" + assert payload["inactive_reason"] == "flockspro_not_installed" -async def test_sync_console_license_revocations_switches_to_newer_active_license( - client: AsyncClient, - monkeypatch: pytest.MonkeyPatch, -): - from flocks.server.routes import console_upgrade as console_routes - from flocks.storage.storage import Storage +async def test_flockspro_license_status_delegates_to_pro_runtime(monkeypatch: pytest.MonkeyPatch): + from flocks.server.routes import flockspro_license as license_routes - monkeypatch.setenv("FLOCKS_CONSOLE_BASE_URL", "http://console.local") - monkeypatch.setattr(console_routes, "require_admin", lambda _req: _mock_admin()) - await _set_bound_console_session() - await Storage.set("console:upgrade_request_ids", ["req_old", "req_new"], "json") - await Storage.set( - "console:upgrade_request:req_old", - { - "request_id": "req_old", - "status": "activated", - "activate_key": "old_token", - "license_id": "lic_old", - "license_status": "poc", - "details": {"license_id": "lic_old"}, - "created_at": "2026-05-15T10:00:00+00:00", - "updated_at": "2026-05-15T10:00:00+00:00", - }, - "json", - ) - await Storage.set( - "console:upgrade_request:req_new", - { - "request_id": "req_new", - "status": "approved", - "activate_key": "new_token", - "license_id": "lic_new", - "license_status": "poc", - "details": {"license_id": "lic_new"}, - "created_at": "2026-05-15T11:00:00+00:00", - "updated_at": "2026-05-15T11:00:00+00:00", - }, - "json", + app = FastAPI() + app.include_router(license_routes.router, prefix="/api/flockspro/license") + monkeypatch.setattr(license_routes, "_is_pro_component_installed", lambda: True) + monkeypatch.setattr( + license_routes, + "_get_pro_capability_status", + lambda: {"active": True, "pro_enabled": True, "license_status": "poc", "license_id": "lic_1"}, ) + monkeypatch.setattr(license_routes, "require_user", lambda _req: _mock_admin()) - class _FakeResponse: - def __init__(self, payload: dict, status_code: int = status.HTTP_200_OK) -> None: - self._payload = payload - self.status_code = status_code - - def json(self) -> dict: - return self._payload - - def raise_for_status(self) -> None: - return None - - class _FakeClient: - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, tb): - return False - - async def get(self, url, headers=None): - if url == "http://console.local/v1/licenses/revocations": - return _FakeResponse({"revoked_license_ids": []}) - if url == "http://console.local/v1/licenses/lic_old": - return _FakeResponse( - { - "license_id": "lic_old", - "license_status": "poc", - "effective_status": "poc", - "effective_max_admins": 3, - "effective_max_members": 20, - } - ) - if url == "http://console.local/v1/licenses/lic_new": - return _FakeResponse( - { - "license_id": "lic_new", - "license_status": "poc", - "effective_status": "poc", - "effective_max_admins": 3, - "effective_max_members": 21, - } - ) - raise AssertionError(url) - - class _Checker: - def __init__(self) -> None: - self.license_id = "lic_old" - self.activated_tokens: list[str] = [] - self.refreshed = False - - def import_revocation(self, revoked_license_ids): - assert revoked_license_ids == [] - - def status(self): - return { - "license_id": self.license_id, - "license_status": "poc", - "active": True, - } - - def activate(self, token: str): - self.activated_tokens.append(token) - self.license_id = "lic_new" - return self.status() - - async def refresh(self): - self.refreshed = True - return self.status() - - checker = _Checker() - runtime_module = ModuleType("flockspro.license.runtime") - flockspro_module = ModuleType("flockspro") - license_module = ModuleType("flockspro.license") - runtime_module.get_license_checker = lambda: checker # type: ignore[attr-defined] - runtime_module.get_pro_capability_status = lambda: { # type: ignore[attr-defined] - **checker.status(), - "pro_enabled": True, - } - monkeypatch.setattr(console_routes.httpx, "AsyncClient", lambda timeout=10: _FakeClient()) - monkeypatch.setattr(console_routes, "_is_pro_component_installed", lambda: True) - monkeypatch.setitem(__import__("sys").modules, "flockspro", flockspro_module) - monkeypatch.setitem(__import__("sys").modules, "flockspro.license", license_module) - monkeypatch.setitem(__import__("sys").modules, "flockspro.license.runtime", runtime_module) - - resp = await client.post("/api/console/licenses/sync-revocations") + transport = httpx.ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as local_client: + resp = await local_client.get("/api/flockspro/license/status") assert resp.status_code == status.HTTP_200_OK payload = resp.json() - assert payload["activated_license_id"] == "lic_new" - assert payload["refreshed_license_id"] == "lic_new" - assert checker.activated_tokens == ["new_token"] - assert checker.refreshed is True + assert payload["active"] is True + assert payload["pro_enabled"] is True + assert payload["activated"] is True + assert payload["license_id"] == "lic_1" async def test_create_upgrade_request_does_not_link_previous_request_when_omitted( @@ -776,6 +462,59 @@ async def post(self, url, json=None, headers=None): assert "console unavailable" in resp.text +async def test_create_upgrade_request_sanitizes_html_console_failure( + client: AsyncClient, + monkeypatch: pytest.MonkeyPatch, +): + from flocks.server.routes import console_upgrade as console_routes + + monkeypatch.setenv("FLOCKS_CONSOLE_BASE_URL", "http://console.local") + monkeypatch.setattr(console_routes, "require_admin", lambda _req: _mock_admin()) + await _set_bound_console_session() + + class _FakeResponse: + status_code = status.HTTP_502_BAD_GATEWAY + + def raise_for_status(self) -> None: + request = httpx.Request("POST", "http://console.local/v1/upgrade-requests") + response = httpx.Response( + self.status_code, + request=request, + text="502 Bad Gatewaybad gateway", + headers={"content-type": "text/html"}, + ) + raise httpx.HTTPStatusError("console call failed", request=request, response=response) + + class _FakeClient: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def post(self, url, json=None, headers=None): + assert url == "http://console.local/v1/upgrade-requests" + return _FakeResponse() + + monkeypatch.setattr(console_routes.httpx, "AsyncClient", lambda timeout=10: _FakeClient()) + + resp = await client.post( + "/api/console/upgrade-requests", + json={ + "product": "Flocks Pro", + "license_type": "poc", + "company": "acme", + "applicant_name": "alice", + "applicant_email": "alice@example.com", + "applicant_phone": "+1 415 555 0100", + }, + ) + + assert resp.status_code == status.HTTP_502_BAD_GATEWAY + assert "console 升级服务暂不可用" in resp.text + assert "" not in resp.text + + async def test_cancel_approved_request_falls_back_to_local_cancel_when_console_rejects( client: AsyncClient, monkeypatch: pytest.MonkeyPatch, @@ -886,6 +625,95 @@ async def test_refresh_approved_request_does_not_auto_activate_install( assert "auto_install_task_scheduled_at" not in payload["details"] +async def test_refresh_request_remote_form_data_overwrites_stale_local_details( + client: AsyncClient, + monkeypatch: pytest.MonkeyPatch, +): + from flocks.server.routes import console_upgrade as console_routes + from flocks.storage.storage import Storage + + monkeypatch.setenv("FLOCKS_CONSOLE_BASE_URL", "http://console.local") + monkeypatch.setattr(console_routes, "require_admin", lambda _req: _mock_admin()) + await _set_bound_console_session() + request_id = "req_refresh_remote_details" + await Storage.set( + f"console:upgrade_request:{request_id}", + { + "request_id": request_id, + "status": "approved", + "previous_request_id": None, + "reason": None, + "suggestion": None, + "activate_key": "old_key", + "manifest_url": "https://manifest.example.com/v1/manifest/latest", + "license_id": "lic_old", + "license_status": "poc", + "max_admins": 1, + "max_members": 2, + "expires_at": 100, + "details": { + "company": "acme", + "license_id": "lic_old", + "license_effective_expires_at": 100, + "local_only": "keep", + }, + "created_at": "2026-05-08T08:00:00+00:00", + "updated_at": "2026-05-08T08:00:00+00:00", + }, + "json", + ) + + class _FakeResponse: + def raise_for_status(self) -> None: + return None + + def json(self) -> dict: + return { + "request_id": request_id, + "status": "activated", + "activate_key": "new_key", + "license_id": "lic_new", + "license_status": "poc", + "max_admins": 3, + "max_members": 20, + "expires_at": 1782532851, + "form_data": { + "company": "acme", + "license_id": "lic_new", + "license_effective_expires_at": 1782532851, + }, + } + + class _FakeClient: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def get(self, url, headers=None): + assert url == f"http://console.local/v1/upgrade-requests/{request_id}" + assert headers == {"Authorization": "Bearer token_abc"} + return _FakeResponse() + + monkeypatch.setattr(console_routes.httpx, "AsyncClient", lambda timeout=10: _FakeClient()) + + resp = await client.post(f"/api/console/upgrade-requests/{request_id}/refresh") + + assert resp.status_code == status.HTTP_200_OK + payload = resp.json() + assert payload["status"] == "activated" + assert payload["license_id"] == "lic_new" + assert payload["max_admins"] == 3 + assert payload["max_members"] == 20 + assert payload["details"]["license_id"] == "lic_new" + assert payload["details"]["license_effective_expires_at"] == 1782532851 + assert payload["details"]["license_refreshed_at"] + assert payload["details"]["license_refreshed_at"] == payload["updated_at"] + assert payload["details"]["license_refreshed_at"] != "2026-05-08T08:00:00+00:00" + assert payload["details"]["local_only"] == "keep" + + async def test_start_approved_request_streams_restart_without_marking_activated( client: AsyncClient, monkeypatch: pytest.MonkeyPatch, @@ -922,7 +750,7 @@ async def _fake_perform_pro_bundle_install(*args, **kwargs): yield UpdateProgress(stage="fetching", message="Downloading Flocks Pro bundle...", success=None) yield UpdateProgress(stage="restarting", message="Restarting service...", success=None) - async def _noop(_record: dict): + async def _noop(_record: dict, **_kwargs): return None reported: list[tuple[str, str | None]] = [] @@ -939,7 +767,10 @@ async def _fake_report(record: dict, *, install_result: str, error_message: str monkeypatch.setattr( console_routes, "_read_pro_bundle_install_marker", - lambda: {"installed_version": "v2026.5.9"}, + lambda: { + "installed_version": "v2026.6.5", + "flockspro_component_version": "v2026.6.5", + }, ) resp = await client.post(f"/api/console/upgrade-requests/{request_id}/start") @@ -954,6 +785,82 @@ async def _fake_report(record: dict, *, install_result: str, error_message: str assert reported == [] +async def test_restarting_request_reports_receipt_after_service_restart( + client: AsyncClient, + monkeypatch: pytest.MonkeyPatch, +): + from flocks.server.routes import console_upgrade as console_routes + from flocks.storage.storage import Storage + + monkeypatch.setenv("FLOCKS_CONSOLE_BASE_URL", "https://console.example.com") + monkeypatch.setattr(console_routes, "require_admin", lambda _req: _mock_admin()) + await _set_bound_console_session() + request_id = "req_restart_complete" + await Storage.set("console:upgrade_request_ids", [request_id], "json") + await Storage.set( + f"console:upgrade_request:{request_id}", + { + "request_id": request_id, + "status": "approved", + "previous_request_id": None, + "reason": None, + "suggestion": None, + "activate_key": "key_start", + "manifest_url": "https://manifest.example.com/v1/manifest/latest", + "details": { + "auto_install_result": "restarting", + "approved_bundle_release_id": "rel_restart", + "latest_pro_bundle": { + "release_id": "rel_restart", + "display_version": "v2026.6.24", + "core_version": "v2026.6.21", + "flockspro_component_version": "v2026.6.24", + "build_id": "job_restart", + }, + }, + "created_at": "2026-05-08T08:00:00+00:00", + "updated_at": "2026-05-08T08:00:00+00:00", + }, + "json", + ) + + async def _noop(_record: dict, **_kwargs): + return None + + reported: list[tuple[str, str | None]] = [] + + async def _fake_report(record: dict, *, install_result: str, error_message: str | None = None): + reported.append((install_result, error_message)) + record.setdefault("details", {})["install_receipt_reported_at"] = "2026-06-24T08:11:09+00:00" + + monkeypatch.setattr(console_routes, "_maybe_activate_pro_license", _noop) + monkeypatch.setattr(console_routes, "_maybe_refresh_pro_license", _noop) + monkeypatch.setattr(console_routes, "_report_pro_bundle_installation", _fake_report) + monkeypatch.setattr(console_routes, "_mark_console_upgrade_activated", _noop) + monkeypatch.setattr(console_routes, "_get_pro_capability_status", lambda: {"pro_enabled": True, "active": True}) + monkeypatch.setattr( + console_routes, + "_read_pro_bundle_install_marker", + lambda: { + "release_id": "rel_restart", + "bundle_release_id": "rel_restart", + "installed_version": "v2026.6.24", + "core_version": "v2026.6.21", + "flockspro_component_version": "v2026.6.24", + "build_id": "job_restart", + }, + ) + + resp = await client.get(f"/api/console/upgrade-requests/{request_id}") + + assert resp.status_code == status.HTTP_200_OK + payload = resp.json() + assert payload["status"] == "activated" + assert payload["details"]["auto_install_result"] == "done" + assert payload["details"]["auto_install_version"] == "v2026.6.24" + assert reported == [("success", None)] + + async def test_start_approved_request_reports_error_after_restart_stage( client: AsyncClient, monkeypatch: pytest.MonkeyPatch, @@ -988,7 +895,7 @@ async def _fake_perform_pro_bundle_install(*args, **kwargs): yield UpdateProgress(stage="restarting", message="Restarting service...", success=None) yield UpdateProgress(stage="error", message="Failed to build restart command: missing python", success=False) - async def _noop(_record: dict): + async def _noop(_record: dict, **_kwargs): return None reported: list[tuple[str, str | None]] = [] @@ -1054,7 +961,7 @@ async def _fake_perform_pro_bundle_install(*args, **kwargs): installed = True yield UpdateProgress(stage="done", message="Flocks Pro component installed.", success=True) - async def _noop(_record: dict): + async def _noop(_record: dict, **_kwargs): return None async def _fake_report(record: dict, *, install_result: str, error_message: str | None = None): @@ -1124,7 +1031,7 @@ async def test_auto_activate_reports_already_latest_install( async def _fake_report(record: dict, *, install_result: str, error_message: str | None = None): reported.append((install_result, error_message)) - async def _noop(_record: dict): + async def _noop(_record: dict, **_kwargs): return None monkeypatch.setattr(console_routes, "_maybe_activate_pro_license", _noop) @@ -1154,15 +1061,87 @@ async def _noop(_record: dict): assert reported == [("success", None)] -async def test_auto_activate_does_not_mark_activated_when_license_inactive( +async def test_auto_activate_reinstalls_when_existing_pro_marker_is_not_target_bundle( monkeypatch: pytest.MonkeyPatch, ): from flocks.server.routes import console_upgrade as console_routes + from flocks.updater.models import UpdateProgress + + marker_state = { + "payload": { + "release_id": "rel_20260601", + "bundle_release_id": "rel_20260601", + "installed_version": "v2026.6.1", + "flockspro_component_version": "v2026.6.1", + "build_id": "job_20260601", + } + } + reported: list[tuple[str, str | None]] = [] + + async def _fake_perform_pro_bundle_install(*args, **kwargs): + assert args == () + assert kwargs["restart"] is False + marker_state["payload"] = { + "release_id": "rel_20260605", + "bundle_release_id": "rel_20260605", + "installed_version": "v2026.6.5", + "flockspro_component_version": "v2026.6.5", + "build_id": "job_20260605", + } + yield UpdateProgress(stage="done", message="Flocks Pro component installed.", success=True) async def _fake_report(record: dict, *, install_result: str, error_message: str | None = None): + reported.append((install_result, error_message)) + + async def _noop(_record: dict, **_kwargs): return None - async def _noop(_record: dict): + monkeypatch.setattr(console_routes, "perform_pro_bundle_install", _fake_perform_pro_bundle_install) + monkeypatch.setattr(console_routes, "_maybe_activate_pro_license", _noop) + monkeypatch.setattr(console_routes, "_maybe_refresh_pro_license", _noop) + monkeypatch.setattr(console_routes, "_report_pro_bundle_installation", _fake_report) + monkeypatch.setattr(console_routes, "_is_pro_component_installed", lambda: True) + monkeypatch.setattr(console_routes, "_get_pro_capability_status", lambda: {"pro_enabled": True, "active": True}) + monkeypatch.setattr(console_routes, "_read_pro_bundle_install_marker", lambda: marker_state["payload"]) + + record = { + "request_id": "req_auto_reinstall_target_bundle", + "status": "approved", + "activate_key": "key_auto", + "details": { + "auto_install_result": "already_latest", + "approved_bundle_release_id": "rel_20260605", + "latest_pro_bundle": { + "release_id": "rel_20260605", + "display_version": "v2026.6.5", + "flockspro_component_version": "v2026.6.5", + "build_id": "job_20260605", + }, + }, + "created_at": "2026-06-05T08:00:00+00:00", + "updated_at": "2026-06-05T08:00:00+00:00", + } + + payload = await console_routes._maybe_auto_activate_upgrade(record) + + assert payload["status"] == "activated" + assert payload["details"]["auto_install_result"] == "done" + assert payload["details"]["auto_install_release_id"] == "rel_20260605" + assert payload["details"]["auto_install_version"] == "v2026.6.5" + assert reported == [("success", None)] + + +async def test_auto_activate_does_not_mark_activated_when_license_inactive( + monkeypatch: pytest.MonkeyPatch, +): + from flocks.server.routes import console_upgrade as console_routes + + reported: list[tuple[str, str | None]] = [] + + async def _fake_report(record: dict, *, install_result: str, error_message: str | None = None): + reported.append((install_result, error_message)) + + async def _noop(_record: dict, **_kwargs): return None monkeypatch.setattr(console_routes, "_maybe_activate_pro_license", _noop) @@ -1187,8 +1166,12 @@ async def _noop(_record: dict): payload = await console_routes._maybe_auto_activate_upgrade(record) assert payload["status"] == "approved" - assert payload["details"]["auto_install_result"] == "license_inactive" + assert payload["details"]["auto_install_result"] == "failed" + assert "license activation is inactive" in payload["details"]["auto_install_error"] assert payload["details"]["runtime_license_inactive_reason"] == "expired" + assert reported + assert reported[-1][0] == "failed" + assert "license activation is inactive" in (reported[-1][1] or "") async def test_auto_activate_installs_pro_bundle_when_core_version_is_latest( @@ -1210,7 +1193,7 @@ async def _fake_perform_pro_bundle_install(*args, **kwargs): async def _fake_report(record: dict, *, install_result: str, error_message: str | None = None): return None - async def _noop(_record: dict): + async def _noop(_record: dict, **_kwargs): return None monkeypatch.setattr(console_routes, "perform_pro_bundle_install", _fake_perform_pro_bundle_install) @@ -1278,9 +1261,94 @@ async def post(self, url, json=None, headers=None): "status": "approved", "activate_key": "activation_token", "license_id": "lic_receipt", - "details": {"license_id": "lic_receipt"}, + "details": { + "license_id": "lic_receipt", + "approved_bundle_release_id": "rel_receipt", + "latest_pro_bundle": { + "release_id": "rel_receipt", + "display_version": "v2026.6.5", + "core_version": "v2026.6.1", + "flockspro_component_version": "v2026.6.5", + "build_id": "job_receipt", + }, + }, } await console_routes._report_pro_bundle_installation(record, install_result="success") assert posted_payloads[0]["license_id"] == "lic_receipt" + assert posted_payloads[0]["request_id"] == "req_receipt" + assert posted_payloads[0]["release_id"] == "rel_receipt" + assert posted_payloads[0]["bundle_release_id"] == "rel_receipt" + assert posted_payloads[0]["core_version"] == "v2026.6.1" + assert posted_payloads[0]["oss_version"] == "v2026.6.1" + assert posted_payloads[0]["build_id"] == "job_receipt" + + +async def test_report_failed_installation_uses_target_bundle_when_marker_is_stale(monkeypatch: pytest.MonkeyPatch): + from flocks.server.routes import console_upgrade as console_routes + + posted_payloads: list[dict] = [] + + class _Response: + def raise_for_status(self): + return None + + class _Client: + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def post(self, url, json=None, headers=None): + posted_payloads.append(json) + return _Response() + + monkeypatch.setenv("FLOCKS_CONSOLE_BASE_URL", "https://console.example.com") + await _set_bound_console_session() + monkeypatch.setattr(console_routes.httpx, "AsyncClient", lambda timeout=10: _Client()) + monkeypatch.setattr( + console_routes, + "_read_pro_bundle_install_marker", + lambda: { + "release_id": "rel_old", + "bundle_release_id": "rel_old", + "installed_version": "v2026.6.1", + "flockspro_component_version": "v2026.6.1", + "build_id": "job_old", + }, + ) + + record = { + "request_id": "req_failed_receipt", + "status": "approved", + "activate_key": "activation_token", + "license_id": "lic_receipt", + "details": { + "license_id": "lic_receipt", + "approved_bundle_release_id": "rel_new", + "latest_pro_bundle": { + "release_id": "rel_new", + "display_version": "v2026.6.5", + "oss_version": "v2026.6.5", + "flockspro_component_version": "v2026.6.5", + "build_id": "job_new", + }, + }, + } + + await console_routes._report_pro_bundle_installation( + record, + install_result="failed", + error_message="install failed", + ) + + assert posted_payloads[0]["release_id"] == "rel_new" + assert posted_payloads[0]["bundle_release_id"] == "rel_new" + assert posted_payloads[0]["installed_version"] == "v2026.6.5" + assert posted_payloads[0]["build_id"] == "job_new" + assert posted_payloads[0]["install_result"] == "failed" diff --git a/tests/server/routes/test_contracts_access_routes.py b/tests/server/routes/test_contracts_access_routes.py new file mode 100644 index 000000000..7f0026e99 --- /dev/null +++ b/tests/server/routes/test_contracts_access_routes.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest +from httpx import AsyncClient + +from flocks.auth.service import AuthService +from flocks.server.auth import SESSION_COOKIE_NAME +from flocks.server.routes import contracts as contracts_routes +from flocks.contracts.access.runtime import OperationRuntime +from flocks.contracts.webui.store import WebUIPagesStore +from tests.contracts.access.test_runtime import ( + CONTRACT_ID, + PAGE_ID, + _contract_record, + _plugin, + _write_contract_assets, +) + + +def test_contract_route_runtime_reloads_when_plugin_signature_changes(monkeypatch: pytest.MonkeyPatch): + created: list[object] = [] + signature = (("plugin.py", 1, 100),) + + class RuntimeStub: + def __init__(self) -> None: + created.append(self) + + monkeypatch.setattr(contracts_routes, "OperationRuntime", RuntimeStub) + monkeypatch.setattr(contracts_routes, "_contract_plugin_signature", lambda: signature) + contracts_routes.reset_route_dependencies() + + first = contracts_routes._get_runtime() + second = contracts_routes._get_runtime() + assert first is second + assert len(created) == 1 + + signature = (("plugin.py", 2, 100),) + third = contracts_routes._get_runtime() + assert third is not first + assert len(created) == 2 + + +@pytest.fixture +def contract_pages(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + root = tmp_path / "contracts-webui" + monkeypatch.setenv("FLOCKS_CONTRACTS_WEBUI_ROOT", str(root)) + store = WebUIPagesStore(root=root) + _write_contract_assets(store, [_contract_record(id="record-route-1")]) + contracts_routes.reset_route_dependencies(runtime=OperationRuntime(plugins=(_plugin(store),))) + return store + + +@pytest.fixture +def contract_pages_with_policy_rows(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + root = tmp_path / "contracts-webui-policy" + monkeypatch.setenv("FLOCKS_CONTRACTS_WEBUI_ROOT", str(root)) + store = WebUIPagesStore(root=root) + _write_contract_assets( + store, + [ + _contract_record(id="allowed", tenant="tenant-a", asset_group="core"), + _contract_record(id="blocked", tenant="tenant-b", asset_group="core"), + ], + ) + contracts_routes.reset_route_dependencies(runtime=OperationRuntime(plugins=(_plugin(store),))) + return store + + +@pytest.mark.asyncio +async def test_contract_operation_route(client: AsyncClient, contract_pages: WebUIPagesStore): + resp = await client.post( + f"/api/contracts/webui/pages/{PAGE_ID}/access/{CONTRACT_ID}/operations/list", + json={"params": {"limit": 10}}, + ) + + assert resp.status_code == 200, resp.text + body = resp.json() + assert body["summary"]["totalRaw"] == 1 + assert body["items"][0]["id"] == "record-route-1" + + +@pytest.mark.asyncio +async def test_contract_operation_route_applies_loaded_member_policy( + client: AsyncClient, + contract_pages_with_policy_rows: WebUIPagesStore, +): + await AuthService._create_user_internal( + username="analyst", + password="Password123!", + role="member", + tenant_ids=("tenant-a",), + asset_groups=("core",), + ) + _user, session_id = await AuthService.login("analyst", "Password123!") + client.cookies.set(SESSION_COOKIE_NAME, session_id) + + response = await client.post( + f"/api/contracts/webui/pages/{PAGE_ID}/access/{CONTRACT_ID}/operations/list", + json={"params": {"limit": 10}}, + ) + + assert response.status_code == 200 + body = response.json() + assert [item["id"] for item in body["items"]] == ["allowed"] + assert body["meta"]["filterStagesApplied"][0]["source"] == "policy.tenantIds" + + +@pytest.mark.asyncio +async def test_contract_operation_route_rejects_forbidden_fields( + client: AsyncClient, + contract_pages: WebUIPagesStore, +): + resp = await client.post( + f"/api/contracts/webui/pages/{PAGE_ID}/access/{CONTRACT_ID}/operations/list", + json={"params": {"driver": "sqlite"}}, + ) + + assert resp.status_code == 400 + assert resp.json()["error"]["code"] == "forbidden_request_field" diff --git a/tests/server/routes/test_contracts_webui_routes.py b/tests/server/routes/test_contracts_webui_routes.py new file mode 100644 index 000000000..ef2bd1db2 --- /dev/null +++ b/tests/server/routes/test_contracts_webui_routes.py @@ -0,0 +1,459 @@ +import io +import json +import zipfile +from unittest.mock import AsyncMock, patch + +import pytest +from httpx import AsyncClient + +from flocks.server.app import app +from flocks.server.auth import require_admin +from flocks.server.routes import webui as webui_routes +from flocks.contracts.webui.builder import WebUIPageBuilder +from flocks.contracts.webui.models import WebUIPageBuildMeta +from flocks.contracts.webui.store import WebUIPagesStore + + +def _make_page_archive(page_id: str, manifest: dict, extra_files: dict[str, str] | None = None) -> bytes: + buffer = io.BytesIO() + files = { + "manifest.json": json.dumps(manifest), + "src/index.tsx": "export default function Page(){return
ok
;}", + } + if extra_files: + files.update(extra_files) + with zipfile.ZipFile(buffer, "w", compression=zipfile.ZIP_DEFLATED) as zf: + for relative_path, content in files.items(): + zf.writestr(f"{page_id}/{relative_path}", content) + return buffer.getvalue() + + +@pytest.fixture +def webui_pages_env(tmp_path, monkeypatch): + root = tmp_path / "webui_pages" + monkeypatch.setenv("FLOCKS_CONTRACTS_WEBUI_ROOT", str(root)) + store = WebUIPagesStore() + builder = WebUIPageBuilder(store) + webui_routes.reset_route_dependencies(store=store, builder=builder) + return store + + +@pytest.mark.asyncio +async def test_create_and_list_webui_pages(client: AsyncClient, webui_pages_env: WebUIPagesStore): + create_resp = await client.post( + "/api/contracts/webui/pages", + json={"id": "dash-1", "title": "仪表盘"}, + ) + assert create_resp.status_code == 201, create_resp.text + data = create_resp.json() + assert data["manifest"]["id"] == "dash-1" + + list_resp = await client.get("/api/contracts/webui/pages", params={"enabledOnly": True}) + assert list_resp.status_code == 200 + items = list_resp.json() + assert len(items) == 1 + assert items[0]["title"] == "仪表盘" + assert items[0]["route"] == "/contracts/webui/dash-1" + + +@pytest.mark.asyncio +async def test_list_webui_workspaces_returns_grouped_pages(client: AsyncClient, webui_pages_env: WebUIPagesStore): + root = webui_pages_env.root + workspace_dir = root / "scene_workspace" + workspace_dir.mkdir(parents=True, exist_ok=True) + (workspace_dir / "workspace.json").write_text( + json.dumps( + { + "id": "scene_workspace", + "title": "场景工作区", + "icon": "ShieldCheck", + "order": 10, + "enabled": True, + "placement": "sceneWorkspace", + "defaultPageId": "ops-overview", + "sections": [ + { + "id": "operations", + "label": "调查列表", + "pageIds": ["ops-overview", "investigation-list"], + "defaultPageId": "ops-overview", + "contentPadding": "comfortable", + } + ], + } + ), + encoding="utf-8", + ) + for page_id, page_dir_name, title, order in [ + ("ops-overview", "ops_overview", "运营总览", 10), + ("investigation-list", "investigation_list", "调查列表", 20), + ]: + page_dir = workspace_dir / page_dir_name + (page_dir / "src").mkdir(parents=True, exist_ok=True) + (page_dir / "dist").mkdir(parents=True, exist_ok=True) + (page_dir / "manifest.json").write_text( + json.dumps( + { + "id": page_id, + "title": title, + "route": f"/contracts/webui/{page_id}", + "icon": "LayoutDashboard", + "order": order, + "enabled": True, + "placement": "home.after", + "entry": "src/index.tsx", + "updatedAt": 0, + } + ), + encoding="utf-8", + ) + + resp = await client.get("/api/contracts/webui/workspaces", params={"enabledOnly": True}) + + assert resp.status_code == 200 + assert resp.json() == [ + { + "id": "scene_workspace", + "title": "场景工作区", + "route": "/contracts/webui/workspaces/scene_workspace", + "icon": "ShieldCheck", + "order": 10, + "enabled": True, + "placement": "sceneWorkspace", + "defaultPageId": "ops-overview", + "sections": [ + { + "id": "operations", + "label": "调查列表", + "pageIds": ["ops-overview", "investigation-list"], + "defaultPageId": "ops-overview", + "contentPadding": "comfortable", + "themeOverride": None, + } + ], + "pages": [ + { + "id": "ops-overview", + "title": "运营总览", + "route": "/contracts/webui/ops-overview", + "icon": "LayoutDashboard", + "order": 10, + "enabled": True, + "placement": "home.after", + "buildHash": "", + "buildStatus": "idle", + "workspaceId": "scene_workspace", + "workspaceTitle": "场景工作区", + "workspaceRoute": "/contracts/webui/workspaces/scene_workspace", + }, + { + "id": "investigation-list", + "title": "调查列表", + "route": "/contracts/webui/investigation-list", + "icon": "LayoutDashboard", + "order": 20, + "enabled": True, + "placement": "home.after", + "buildHash": "", + "buildStatus": "idle", + "workspaceId": "scene_workspace", + "workspaceTitle": "场景工作区", + "workspaceRoute": "/contracts/webui/workspaces/scene_workspace", + }, + ], + } + ] + + +@pytest.mark.asyncio +async def test_legacy_user_defined_pages_api_alias_uses_contract_routes( + client: AsyncClient, + webui_pages_env: WebUIPagesStore, +): + create_resp = await client.post( + "/api/user-defined-pages", + json={"id": "legacy-dash", "title": "旧页面"}, + ) + assert create_resp.status_code == 201, create_resp.text + assert create_resp.json()["manifest"]["route"] == "/contracts/webui/legacy-dash" + + list_resp = await client.get("/api/user-defined-pages", params={"enabledOnly": True}) + assert list_resp.status_code == 200 + assert list_resp.json()[0]["route"] == "/contracts/webui/legacy-dash" + + +@pytest.mark.asyncio +async def test_legacy_user_defined_pages_runtime_aliases( + client: AsyncClient, + webui_pages_env: WebUIPagesStore, +): + create_resp = await client.post( + "/api/user-defined-pages", + json={"id": "legacy-runtime", "title": "旧运行时"}, + ) + assert create_resp.status_code == 201, create_resp.text + + bundle_path = webui_pages_env.bundle_path("legacy-runtime") + bundle_path.parent.mkdir(parents=True, exist_ok=True) + bundle_path.write_text("export default function Page(){return null;}", encoding="utf-8") + webui_pages_env.write_build_meta( + "legacy-runtime", + WebUIPageBuildMeta(status="ready", hash="legacy-hash", builtAt=1), + ) + bundle_resp = await client.get("/api/user-defined-pages/legacy-runtime/bundle.js") + assert bundle_resp.status_code == 200 + assert "application/javascript" in bundle_resp.headers.get("content-type", "") + + asset_path = webui_pages_env.asset_path("legacy-runtime", "logo.txt") + asset_path.parent.mkdir(parents=True, exist_ok=True) + asset_path.write_text("asset-ok", encoding="utf-8") + asset_resp = await client.get("/api/user-defined-pages/legacy-runtime/assets/logo.txt") + assert asset_resp.status_code == 200 + assert asset_resp.text == "asset-ok" + + webui_pages_env.save_source_file( + "legacy-runtime", + "api/routes.yaml", + "routes:\n - method: GET\n path: /stats\n handler: handlers.get_stats\n", + ) + webui_pages_env.save_source_file( + "legacy-runtime", + "api/handlers.py", + "def get_stats(ctx, request):\n return {'ok': True}\n", + ) + list_api_resp = await client.get("/api/user-defined-pages/legacy-runtime/api") + assert list_api_resp.status_code == 200 + assert list_api_resp.json()[0]["path"] == "/stats" + reload_resp = await client.post("/api/user-defined-pages/legacy-runtime/api/reload") + assert reload_resp.status_code == 200 + dispatch_resp = await client.get("/api/user-defined-pages/legacy-runtime/api/stats") + assert dispatch_resp.status_code == 200 + assert dispatch_resp.json()["ok"] is True + + export_resp = await client.get("/api/user-defined-pages/legacy-runtime/export") + assert export_resp.status_code == 200 + assert export_resp.headers.get("content-type", "").startswith("application/zip") + + archive = _make_page_archive( + "legacy-imported", + { + "id": "legacy-imported", + "title": "旧导入", + "route": "/user-defined-pages/legacy-imported", + "icon": "LayoutDashboard", + "order": 10, + "enabled": True, + "placement": "home.after", + "entry": "src/index.tsx", + "updatedAt": 1, + }, + ) + import_resp = await client.post( + "/api/user-defined-pages/import", + files={"file": ("legacy-imported.zip", archive, "application/zip")}, + ) + assert import_resp.status_code == 200, import_resp.text + assert import_resp.json()["manifest"]["route"] == "/contracts/webui/legacy-imported" + + +@pytest.mark.asyncio +async def test_save_source_triggers_build_and_event(client: AsyncClient, webui_pages_env: WebUIPagesStore): + await client.post("/api/contracts/webui/pages", json={"id": "live-page", "title": "实时页"}) + source = webui_pages_env.read_source_file("live-page", "src/Page.tsx") + + with patch("flocks.server.routes.webui._builder.build") as build_mock: + build_mock.return_value = WebUIPageBuildMeta( + status="ready", + hash="abc123", + builtAt=1, + error=None, + ) + with patch("flocks.server.routes.webui.publish_event", new_callable=AsyncMock) as publish_mock: + save_resp = await client.put( + "/api/contracts/webui/pages/live-page", + json={"sourcePath": "src/Page.tsx", "sourceContent": source}, + ) + + assert save_resp.status_code == 200, save_resp.text + body = save_resp.json() + assert body["build"]["status"] == "ready" + publish_mock.assert_any_await("contracts.webui.pages.updated", {"id": "live-page", "hash": "abc123"}) + + +@pytest.mark.asyncio +async def test_bundle_endpoint_available_after_create(client: AsyncClient, webui_pages_env: WebUIPagesStore): + await client.post("/api/contracts/webui/pages", json={"id": "empty-page", "title": "空页面"}) + bundle_path = webui_pages_env.bundle_path("empty-page") + bundle_path.parent.mkdir(parents=True, exist_ok=True) + bundle_path.write_text("export default function Page(){return null;}", encoding="utf-8") + webui_pages_env.write_build_meta( + "empty-page", + WebUIPageBuildMeta(status="ready", hash="test-hash", builtAt=1), + ) + bundle_resp = await client.get("/api/contracts/webui/pages/empty-page/bundle.js") + assert bundle_resp.status_code == 200 + assert "application/javascript" in bundle_resp.headers.get("content-type", "") + assert "content-disposition" not in bundle_resp.headers + assert bundle_resp.text.strip() + + +@pytest.mark.asyncio +async def test_reject_invalid_page_id_on_create(client: AsyncClient, webui_pages_env: WebUIPagesStore): + resp = await client.post("/api/contracts/webui/pages", json={"id": "../bad", "title": "坏页面"}) + assert resp.status_code == 400 + + +@pytest.mark.asyncio +async def test_admin_required_for_create(client: AsyncClient, webui_pages_env: WebUIPagesStore): + from fastapi import HTTPException, Request + + def _deny_admin(_request: Request): + raise HTTPException(status_code=403, detail="仅管理员可执行该操作") + + app.dependency_overrides[require_admin] = _deny_admin + try: + resp = await client.post("/api/contracts/webui/pages", json={"id": "denied-page", "title": "禁止"}) + assert resp.status_code == 403 + finally: + app.dependency_overrides.pop(require_admin, None) + + +@pytest.mark.asyncio +async def test_admin_required_for_build_and_api_reload(client: AsyncClient, webui_pages_env: WebUIPagesStore): + from fastapi import HTTPException, Request + + await client.post("/api/contracts/webui/pages", json={"id": "admin-guard-page", "title": "权限页"}) + webui_pages_env.save_source_file( + "admin-guard-page", + "api/routes.yaml", + "routes:\n - method: GET\n path: /x\n handler: handlers.x\n", + ) + webui_pages_env.save_source_file( + "admin-guard-page", + "api/handlers.py", + "def x(ctx, request):\n return {'ok': True}\n", + ) + + def _deny_admin(_request: Request): + raise HTTPException(status_code=403, detail="仅管理员可执行该操作") + + app.dependency_overrides[require_admin] = _deny_admin + try: + build_resp = await client.post("/api/contracts/webui/pages/admin-guard-page/build") + assert build_resp.status_code == 403 + + reload_resp = await client.post("/api/contracts/webui/pages/admin-guard-page/api/reload") + assert reload_resp.status_code == 403 + finally: + app.dependency_overrides.pop(require_admin, None) + + +@pytest.mark.asyncio +async def test_page_api_routes_reload_and_dispatch(client: AsyncClient, webui_pages_env: WebUIPagesStore): + await client.post("/api/contracts/webui/pages", json={"id": "api-page", "title": "接口页"}) + webui_pages_env.save_source_file( + "api-page", + "api/routes.yaml", + "routes:\n - method: GET\n path: /stats\n handler: handlers.get_stats\n", + ) + webui_pages_env.save_source_file( + "api-page", + "api/handlers.py", + "def get_stats(ctx, request):\n return {'ok': True}\n", + ) + + list_resp = await client.get("/api/contracts/webui/pages/api-page/api") + assert list_resp.status_code == 200 + assert list_resp.json()[0]["path"] == "/stats" + + reload_resp = await client.post("/api/contracts/webui/pages/api-page/api/reload") + assert reload_resp.status_code == 200 + assert reload_resp.json()["routes"][0]["handler"] == "handlers.get_stats" + + dispatch_resp = await client.get("/api/contracts/webui/pages/api-page/api/stats") + assert dispatch_resp.status_code == 200 + assert dispatch_resp.json()["ok"] is True + + +@pytest.mark.asyncio +async def test_export_and_import_webui_page(client: AsyncClient, webui_pages_env: WebUIPagesStore): + await client.post("/api/contracts/webui/pages", json={"id": "backup-page", "title": "备份页"}) + await client.put( + "/api/contracts/webui/pages/backup-page", + json={"sourcePath": "src/Page.tsx", "sourceContent": "export default function Page(){return
backup
;}"}, + ) + + export_resp = await client.get("/api/contracts/webui/pages/backup-page/export") + assert export_resp.status_code == 200 + assert export_resp.headers.get("content-type", "").startswith("application/zip") + + import_resp = await client.post( + "/api/contracts/webui/pages/import?overwrite=true", + files={"file": ("backup-page.zip", export_resp.content, "application/zip")}, + ) + assert import_resp.status_code == 200 + assert import_resp.json()["manifest"]["id"] == "backup-page" + + +@pytest.mark.asyncio +async def test_import_normalizes_manifest_identity(client: AsyncClient, webui_pages_env: WebUIPagesStore): + archive = _make_page_archive( + "fixed-page", + { + "id": "wrong-page", + "title": "导入页", + "route": "/contracts/webui/wrong-page", + "icon": "LayoutDashboard", + "order": 10, + "enabled": True, + "placement": "home.after", + "entry": "src/index.tsx", + "updatedAt": 1, + }, + ) + + import_resp = await client.post( + "/api/contracts/webui/pages/import", + files={"file": ("fixed-page.zip", archive, "application/zip")}, + ) + + assert import_resp.status_code == 200, import_resp.text + body = import_resp.json() + assert body["manifest"]["id"] == "fixed-page" + assert body["manifest"]["route"] == "/contracts/webui/fixed-page" + + list_resp = await client.get("/api/contracts/webui/pages") + assert list_resp.status_code == 200 + assert list_resp.json()[0]["id"] == "fixed-page" + assert list_resp.json()[0]["route"] == "/contracts/webui/fixed-page" + + +@pytest.mark.asyncio +async def test_import_rejects_archives_with_too_many_files( + client: AsyncClient, + webui_pages_env: WebUIPagesStore, + monkeypatch, +): + monkeypatch.setattr(webui_routes, "MAX_IMPORT_FILES", 1) + archive = _make_page_archive( + "too-many", + { + "id": "too-many", + "title": "过多文件", + "route": "/contracts/webui/too-many", + "icon": "LayoutDashboard", + "order": 10, + "enabled": True, + "placement": "home.after", + "entry": "src/index.tsx", + "updatedAt": 1, + }, + ) + + import_resp = await client.post( + "/api/contracts/webui/pages/import", + files={"file": ("too-many.zip", archive, "application/zip")}, + ) + + assert import_resp.status_code == 400 + assert "too many files" in import_resp.text diff --git a/tests/server/routes/test_session_routes.py b/tests/server/routes/test_session_routes.py index 42023abd6..5aa2e8994 100644 --- a/tests/server/routes/test_session_routes.py +++ b/tests/server/routes/test_session_routes.py @@ -218,6 +218,175 @@ async def test_delete_session(self, client: AsyncClient, session_id: str): get_resp = await client.get(f"/api/session/{session_id}") assert get_resp.status_code == status.HTTP_404_NOT_FOUND + @staticmethod + def _patch_delete_session_dependencies( + monkeypatch, + session_routes, + *, + session_id: str, + order: list[str], + session_list, + ) -> None: + async def fake_abort_session_processing(abort_session_id: str) -> bool: + order.append(f"abort:{abort_session_id}") + return True + + async def fake_wait_for_sessions_idle(session_ids: list[str], timeout_s: float = 5.0) -> None: + order.append(f"wait:{','.join(session_ids)}") + + async def fake_interaction_queue_clear(_session_id: str) -> None: + order.append("queue_clear") + + async def fake_goal_clear(_session_id: str) -> None: + order.append("goal_clear") + + async def fake_session_delete(_project_id: str, delete_session_id: str) -> bool: + assert delete_session_id == session_id + order.append("delete") + return True + + monkeypatch.setattr(session_routes.Session, "list", session_list) + monkeypatch.setattr( + session_routes, + "_abort_session_processing", + fake_abort_session_processing, + ) + monkeypatch.setattr( + session_routes, + "_wait_for_sessions_idle", + fake_wait_for_sessions_idle, + ) + monkeypatch.setattr( + "flocks.session.interaction_queue.InteractionQueue.clear", + fake_interaction_queue_clear, + ) + monkeypatch.setattr( + "flocks.session.goal.GoalManager.clear", + fake_goal_clear, + ) + monkeypatch.setattr(session_routes.Session, "delete", fake_session_delete) + + @pytest.mark.asyncio + async def test_delete_session_aborts_and_waits_before_delete( + self, + client: AsyncClient, + session_id: str, + monkeypatch, + ): + """DELETE waits for active processing to stop before clearing messages.""" + from flocks.server.routes import session as session_routes + + order: list[str] = [] + + async def fake_session_list(_project_id: str): + return [] + + self._patch_delete_session_dependencies( + monkeypatch, + session_routes, + session_id=session_id, + order=order, + session_list=fake_session_list, + ) + + resp = await client.delete(f"/api/session/{session_id}") + + assert resp.status_code == status.HTTP_200_OK + assert resp.json() is True + assert order == [ + f"abort:{session_id}", + "queue_clear", + "goal_clear", + f"wait:{session_id}", + "delete", + ] + + @pytest.mark.asyncio + async def test_delete_session_waits_for_descendant_loops_before_delete( + self, + client: AsyncClient, + session_id: str, + monkeypatch, + ): + """DELETE waits for child and grandchild loops before recursive delete.""" + from flocks.server.routes import session as session_routes + + child_id = "ses_delete_child_wait" + grandchild_id = "ses_delete_grandchild_wait" + order: list[str] = [] + + async def fake_session_list(_project_id: str): + return [ + SimpleNamespace(id=child_id, parent_id=session_id), + SimpleNamespace(id=grandchild_id, parent_id=child_id), + ] + + self._patch_delete_session_dependencies( + monkeypatch, + session_routes, + session_id=session_id, + order=order, + session_list=fake_session_list, + ) + + resp = await client.delete(f"/api/session/{session_id}") + + assert resp.status_code == status.HTTP_200_OK + assert resp.json() is True + assert order == [ + f"abort:{session_id}", + "queue_clear", + "goal_clear", + f"wait:{session_id}", + f"abort:{child_id}", + f"abort:{grandchild_id}", + f"wait:{child_id},{grandchild_id}", + "delete", + ] + + @pytest.mark.asyncio + async def test_delete_session_aborts_descendant_that_appears_after_parent_wait( + self, + client: AsyncClient, + session_id: str, + monkeypatch, + ): + """DELETE re-collects descendants after parent abort to catch late children.""" + from flocks.server.routes import session as session_routes + + child_id = "ses_delete_late_child_wait" + list_calls = 0 + order: list[str] = [] + + async def fake_session_list(_project_id: str): + nonlocal list_calls + list_calls += 1 + if list_calls == 1: + return [] + return [SimpleNamespace(id=child_id, parent_id=session_id)] + + self._patch_delete_session_dependencies( + monkeypatch, + session_routes, + session_id=session_id, + order=order, + session_list=fake_session_list, + ) + + resp = await client.delete(f"/api/session/{session_id}") + + assert resp.status_code == status.HTTP_200_OK + assert resp.json() is True + assert order == [ + f"abort:{session_id}", + "queue_clear", + "goal_clear", + f"wait:{session_id}", + f"abort:{child_id}", + f"wait:{child_id}", + "delete", + ] + @pytest.mark.asyncio async def test_delete_session_not_found(self, client: AsyncClient): """DELETE for unknown session returns 404.""" diff --git a/tests/server/routes/test_user_defined_pages_routes.py b/tests/server/routes/test_user_defined_pages_routes.py deleted file mode 100644 index 686a10d33..000000000 --- a/tests/server/routes/test_user_defined_pages_routes.py +++ /dev/null @@ -1,259 +0,0 @@ -import io -import json -import zipfile -from unittest.mock import AsyncMock, patch - -import pytest -from httpx import AsyncClient - -from flocks.server.app import app -from flocks.server.auth import require_admin -from flocks.server.routes import user_defined_pages as user_defined_pages_routes -from flocks.user_defined_pages.builder import UserDefinedPagesBuilder -from flocks.user_defined_pages.models import UserDefinedPageBuildMeta -from flocks.user_defined_pages.store import UserDefinedPagesStore - - -def _make_page_archive(page_id: str, manifest: dict, extra_files: dict[str, str] | None = None) -> bytes: - buffer = io.BytesIO() - files = { - "manifest.json": json.dumps(manifest), - "src/index.tsx": "export default function Page(){return
ok
;}", - } - if extra_files: - files.update(extra_files) - with zipfile.ZipFile(buffer, "w", compression=zipfile.ZIP_DEFLATED) as zf: - for relative_path, content in files.items(): - zf.writestr(f"{page_id}/{relative_path}", content) - return buffer.getvalue() - - -@pytest.fixture -def user_defined_pages_env(tmp_path, monkeypatch): - root = tmp_path / "user_defined_pages" - monkeypatch.setenv("FLOCKS_USER_DEFINED_PAGES_ROOT", str(root)) - store = UserDefinedPagesStore() - builder = UserDefinedPagesBuilder(store) - user_defined_pages_routes.reset_route_dependencies(store=store, builder=builder) - return store - - -@pytest.mark.asyncio -async def test_create_and_list_user_defined_pages(client: AsyncClient, user_defined_pages_env: UserDefinedPagesStore): - create_resp = await client.post( - "/api/user-defined-pages", - json={"id": "dash-1", "title": "仪表盘"}, - ) - assert create_resp.status_code == 201, create_resp.text - data = create_resp.json() - assert data["manifest"]["id"] == "dash-1" - - list_resp = await client.get("/api/user-defined-pages", params={"enabledOnly": True}) - assert list_resp.status_code == 200 - items = list_resp.json() - assert len(items) == 1 - assert items[0]["title"] == "仪表盘" - assert items[0]["route"] == "/user-defined-pages/dash-1" - - -@pytest.mark.asyncio -async def test_save_source_triggers_build_and_event(client: AsyncClient, user_defined_pages_env: UserDefinedPagesStore): - await client.post("/api/user-defined-pages", json={"id": "live-page", "title": "实时页"}) - source = user_defined_pages_env.read_source_file("live-page", "src/Page.tsx") - - with patch("flocks.server.routes.user_defined_pages._builder.build") as build_mock: - build_mock.return_value = UserDefinedPageBuildMeta( - status="ready", - hash="abc123", - builtAt=1, - error=None, - ) - with patch("flocks.server.routes.user_defined_pages.publish_event", new_callable=AsyncMock) as publish_mock: - save_resp = await client.put( - "/api/user-defined-pages/live-page", - json={"sourcePath": "src/Page.tsx", "sourceContent": source}, - ) - - assert save_resp.status_code == 200, save_resp.text - body = save_resp.json() - assert body["build"]["status"] == "ready" - publish_mock.assert_any_await("user_defined_pages.updated", {"id": "live-page", "hash": "abc123"}) - - -@pytest.mark.asyncio -async def test_bundle_endpoint_available_after_create(client: AsyncClient, user_defined_pages_env: UserDefinedPagesStore): - await client.post("/api/user-defined-pages", json={"id": "empty-page", "title": "空页面"}) - bundle_path = user_defined_pages_env.bundle_path("empty-page") - bundle_path.parent.mkdir(parents=True, exist_ok=True) - bundle_path.write_text("export default function Page(){return null;}", encoding="utf-8") - user_defined_pages_env.write_build_meta( - "empty-page", - UserDefinedPageBuildMeta(status="ready", hash="test-hash", builtAt=1), - ) - bundle_resp = await client.get("/api/user-defined-pages/empty-page/bundle.js") - assert bundle_resp.status_code == 200 - assert "application/javascript" in bundle_resp.headers.get("content-type", "") - assert "content-disposition" not in bundle_resp.headers - assert bundle_resp.text.strip() - - -@pytest.mark.asyncio -async def test_reject_invalid_page_id_on_create(client: AsyncClient, user_defined_pages_env: UserDefinedPagesStore): - resp = await client.post("/api/user-defined-pages", json={"id": "../bad", "title": "坏页面"}) - assert resp.status_code == 400 - - -@pytest.mark.asyncio -async def test_admin_required_for_create(client: AsyncClient, user_defined_pages_env: UserDefinedPagesStore): - from fastapi import HTTPException, Request - - def _deny_admin(_request: Request): - raise HTTPException(status_code=403, detail="仅管理员可执行该操作") - - app.dependency_overrides[require_admin] = _deny_admin - try: - resp = await client.post("/api/user-defined-pages", json={"id": "denied-page", "title": "禁止"}) - assert resp.status_code == 403 - finally: - app.dependency_overrides.pop(require_admin, None) - - -@pytest.mark.asyncio -async def test_admin_required_for_build_and_api_reload(client: AsyncClient, user_defined_pages_env: UserDefinedPagesStore): - from fastapi import HTTPException, Request - - await client.post("/api/user-defined-pages", json={"id": "admin-guard-page", "title": "权限页"}) - user_defined_pages_env.save_source_file( - "admin-guard-page", - "api/routes.yaml", - "routes:\n - method: GET\n path: /x\n handler: handlers.x\n", - ) - user_defined_pages_env.save_source_file( - "admin-guard-page", - "api/handlers.py", - "def x(ctx, request):\n return {'ok': True}\n", - ) - - def _deny_admin(_request: Request): - raise HTTPException(status_code=403, detail="仅管理员可执行该操作") - - app.dependency_overrides[require_admin] = _deny_admin - try: - build_resp = await client.post("/api/user-defined-pages/admin-guard-page/build") - assert build_resp.status_code == 403 - - reload_resp = await client.post("/api/user-defined-pages/admin-guard-page/api/reload") - assert reload_resp.status_code == 403 - finally: - app.dependency_overrides.pop(require_admin, None) - - -@pytest.mark.asyncio -async def test_page_api_routes_reload_and_dispatch(client: AsyncClient, user_defined_pages_env: UserDefinedPagesStore): - await client.post("/api/user-defined-pages", json={"id": "api-page", "title": "接口页"}) - user_defined_pages_env.save_source_file( - "api-page", - "api/routes.yaml", - "routes:\n - method: GET\n path: /stats\n handler: handlers.get_stats\n", - ) - user_defined_pages_env.save_source_file( - "api-page", - "api/handlers.py", - "def get_stats(ctx, request):\n return {'ok': True}\n", - ) - - list_resp = await client.get("/api/user-defined-pages/api-page/api") - assert list_resp.status_code == 200 - assert list_resp.json()[0]["path"] == "/stats" - - reload_resp = await client.post("/api/user-defined-pages/api-page/api/reload") - assert reload_resp.status_code == 200 - assert reload_resp.json()["routes"][0]["handler"] == "handlers.get_stats" - - dispatch_resp = await client.get("/api/user-defined-pages/api-page/api/stats") - assert dispatch_resp.status_code == 200 - assert dispatch_resp.json()["ok"] is True - - -@pytest.mark.asyncio -async def test_export_and_import_user_defined_page(client: AsyncClient, user_defined_pages_env: UserDefinedPagesStore): - await client.post("/api/user-defined-pages", json={"id": "backup-page", "title": "备份页"}) - await client.put( - "/api/user-defined-pages/backup-page", - json={"sourcePath": "src/Page.tsx", "sourceContent": "export default function Page(){return
backup
;}"}, - ) - - export_resp = await client.get("/api/user-defined-pages/backup-page/export") - assert export_resp.status_code == 200 - assert export_resp.headers.get("content-type", "").startswith("application/zip") - - import_resp = await client.post( - "/api/user-defined-pages/import?overwrite=true", - files={"file": ("backup-page.zip", export_resp.content, "application/zip")}, - ) - assert import_resp.status_code == 200 - assert import_resp.json()["manifest"]["id"] == "backup-page" - - -@pytest.mark.asyncio -async def test_import_normalizes_manifest_identity(client: AsyncClient, user_defined_pages_env: UserDefinedPagesStore): - archive = _make_page_archive( - "fixed-page", - { - "id": "wrong-page", - "title": "导入页", - "route": "/user-defined-pages/wrong-page", - "icon": "LayoutDashboard", - "order": 10, - "enabled": True, - "placement": "home.after", - "entry": "src/index.tsx", - "updatedAt": 1, - }, - ) - - import_resp = await client.post( - "/api/user-defined-pages/import", - files={"file": ("fixed-page.zip", archive, "application/zip")}, - ) - - assert import_resp.status_code == 200, import_resp.text - body = import_resp.json() - assert body["manifest"]["id"] == "fixed-page" - assert body["manifest"]["route"] == "/user-defined-pages/fixed-page" - - list_resp = await client.get("/api/user-defined-pages") - assert list_resp.status_code == 200 - assert list_resp.json()[0]["id"] == "fixed-page" - assert list_resp.json()[0]["route"] == "/user-defined-pages/fixed-page" - - -@pytest.mark.asyncio -async def test_import_rejects_archives_with_too_many_files( - client: AsyncClient, - user_defined_pages_env: UserDefinedPagesStore, - monkeypatch, -): - monkeypatch.setattr(user_defined_pages_routes, "MAX_IMPORT_FILES", 1) - archive = _make_page_archive( - "too-many", - { - "id": "too-many", - "title": "过多文件", - "route": "/user-defined-pages/too-many", - "icon": "LayoutDashboard", - "order": 10, - "enabled": True, - "placement": "home.after", - "entry": "src/index.tsx", - "updatedAt": 1, - }, - ) - - import_resp = await client.post( - "/api/user-defined-pages/import", - files={"file": ("too-many.zip", archive, "application/zip")}, - ) - - assert import_resp.status_code == 400 - assert "too many files" in import_resp.text diff --git a/tests/server/routes/test_workflow_poller_routes.py b/tests/server/routes/test_workflow_poller_routes.py index ff8679fd2..f683b4057 100644 --- a/tests/server/routes/test_workflow_poller_routes.py +++ b/tests/server/routes/test_workflow_poller_routes.py @@ -16,8 +16,8 @@ async def test_save_poller_config_restarts_manager( ) -> None: writes: list[tuple[str, dict[str, Any]]] = [] - async def _fake_write(key: Any, value: dict[str, Any]) -> None: - writes.append((key, value)) + async def _fake_put_config(workflow_id: str, config: dict[str, Any], *, kind: str | None = None) -> None: + writes.append((f"{kind}/{workflow_id}", config)) async def _fake_restart(workflow_id: str) -> dict[str, Any]: assert workflow_id == "wf-1" @@ -26,9 +26,11 @@ async def _fake_restart(workflow_id: str) -> dict[str, Any]: monkeypatch.setattr( workflow_routes, "_read_workflow_from_fs", - lambda workflow_id: {"workflowJson": {"start": "n1", "nodes": [], "edges": []}} if workflow_id == "wf-1" else None, + lambda workflow_id: ( + {"workflowJson": {"start": "n1", "nodes": [], "edges": []}} if workflow_id == "wf-1" else None + ), ) - monkeypatch.setattr(workflow_routes.Storage, "write", _fake_write) + monkeypatch.setattr(workflow_routes.WorkflowStore, "put_config", _fake_put_config) monkeypatch.setattr( "flocks.workflow.poller_manager.default_manager", SimpleNamespace(restart_workflow=_fake_restart), @@ -64,8 +66,8 @@ async def test_save_poller_config_preserves_cron_schedule( writes: list[tuple[str, dict[str, Any]]] = [] persisted_sources: list[dict[str, Any]] = [] - async def _fake_write(key: Any, value: dict[str, Any]) -> None: - writes.append((key, value)) + async def _fake_put_config(workflow_id: str, config: dict[str, Any], *, kind: str | None = None) -> None: + writes.append((f"{kind}/{workflow_id}", config)) async def _fake_persist( _workflow_id: str, @@ -81,9 +83,11 @@ async def _fake_restart(workflow_id: str) -> dict[str, Any]: monkeypatch.setattr( workflow_routes, "_read_workflow_from_fs", - lambda workflow_id: {"workflowJson": {"start": "n1", "nodes": [], "edges": []}} if workflow_id == "wf-1" else None, + lambda workflow_id: ( + {"workflowJson": {"start": "n1", "nodes": [], "edges": []}} if workflow_id == "wf-1" else None + ), ) - monkeypatch.setattr(workflow_routes.Storage, "write", _fake_write) + monkeypatch.setattr(workflow_routes.WorkflowStore, "put_config", _fake_put_config) monkeypatch.setattr(workflow_routes, "_persist_workflow_triggers", _fake_persist) monkeypatch.setattr( "flocks.workflow.poller_manager.default_manager", @@ -119,8 +123,8 @@ async def test_get_poller_config_returns_saved_data( client: AsyncClient, monkeypatch: pytest.MonkeyPatch, ) -> None: - async def _fake_read(_key: Any, *_args: Any, **_kwargs: Any) -> dict[str, Any] | None: - if _key != "workflow_poller_config/wf-1": + async def _fake_get_config(workflow_id: str, *, kind: str = "workflow.integration-config") -> dict[str, Any] | None: + if workflow_id != "wf-1" or kind != "workflow_poller_config": return None return { "workflowId": "wf-1", @@ -131,7 +135,7 @@ async def _fake_read(_key: Any, *_args: Any, **_kwargs: Any) -> dict[str, Any] | "inputs": {"dedup_source_workflow_name": "stream_alert_denoise_gt_fast"}, } - monkeypatch.setattr(workflow_routes.Storage, "read", _fake_read) + monkeypatch.setattr(workflow_routes.WorkflowStore, "get_config", _fake_get_config) response = await client.get("/api/workflow/wf-1/poller-config") assert response.status_code == 200, response.text @@ -178,7 +182,9 @@ async def _fake_run_once(workflow_id: str) -> dict[str, Any]: monkeypatch.setattr( workflow_routes, "_read_workflow_from_fs", - lambda workflow_id: {"workflowJson": {"start": "n1", "nodes": [], "edges": []}} if workflow_id == "wf-1" else None, + lambda workflow_id: ( + {"workflowJson": {"start": "n1", "nodes": [], "edges": []}} if workflow_id == "wf-1" else None + ), ) monkeypatch.setattr( "flocks.workflow.poller_manager.default_manager", diff --git a/tests/server/routes/test_workflow_publish_api.py b/tests/server/routes/test_workflow_publish_api.py index 8ac6afc40..a715239f4 100644 --- a/tests/server/routes/test_workflow_publish_api.py +++ b/tests/server/routes/test_workflow_publish_api.py @@ -20,16 +20,20 @@ async def test_publish_workflow_as_api_reuses_key_for_runtime( monkeypatch.setattr( workflow_routes, "_read_workflow_from_fs", - lambda requested_id: { - "id": requested_id, - "name": "Demo Workflow", - "workflowJson": { + lambda requested_id: ( + { "id": requested_id, - "start": "n1", - "nodes": [{"id": "n1", "type": "python", "code": "outputs['ok'] = True"}], - "edges": [], - }, - } if requested_id == workflow_id else None, + "name": "Demo Workflow", + "workflowJson": { + "id": requested_id, + "start": "n1", + "nodes": [{"id": "n1", "type": "python", "code": "outputs['ok'] = True"}], + "edges": [], + }, + } + if requested_id == workflow_id + else None + ), ) monkeypatch.setattr(workflow_routes.Config, "get_data_path", lambda: tmp_path) @@ -47,12 +51,14 @@ async def fake_publish_workflow( driver: str | None = None, api_key: str | None = None, ) -> dict[str, Any]: - publish_calls.append({ - "workflow_id": requested_id, - "image": image, - "driver": driver, - "api_key": api_key, - }) + publish_calls.append( + { + "workflow_id": requested_id, + "image": image, + "driver": driver, + "api_key": api_key, + } + ) return { "serviceUrl": "http://127.0.0.1:19000", "containerName": "local-wf-1", @@ -60,8 +66,8 @@ async def fake_publish_workflow( "apiKey": api_key, } - monkeypatch.setattr(workflow_routes.Storage, "read", fake_read) - monkeypatch.setattr(workflow_routes.Storage, "write", fake_write) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_get", fake_read) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_put", fake_write) monkeypatch.setattr(workflow_routes, "publish_workflow", fake_publish_workflow) result = await workflow_routes.publish_workflow_as_api( @@ -69,12 +75,14 @@ async def fake_publish_workflow( workflow_routes.WorkflowCenterPublishRequest(driver="local"), ) - assert publish_calls == [{ - "workflow_id": workflow_id, - "image": None, - "driver": "local", - "api_key": existing_key, - }] + assert publish_calls == [ + { + "workflow_id": workflow_id, + "image": None, + "driver": "local", + "api_key": existing_key, + } + ] assert result["apiKey"] == existing_key assert writes[workflow_routes._api_service_key(workflow_id)]["apiKey"] == existing_key @@ -124,12 +132,14 @@ async def fake_publish_workflow( driver: str | None = None, api_key: str | None = None, ) -> dict[str, Any]: - publish_calls.append({ - "workflow_id": requested_id, - "image": image, - "driver": driver, - "api_key": api_key, - }) + publish_calls.append( + { + "workflow_id": requested_id, + "image": image, + "driver": driver, + "api_key": api_key, + } + ) return { "serviceUrl": "http://127.0.0.1:19001", "containerName": "flocks-wf-wf-1-rel-1", @@ -138,9 +148,9 @@ async def fake_publish_workflow( "apiKey": api_key, } - monkeypatch.setattr(workflow_routes.Storage, "list_keys", fake_list_keys) - monkeypatch.setattr(workflow_routes.Storage, "read", fake_read) - monkeypatch.setattr(workflow_routes.Storage, "write", fake_write) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_list_keys", fake_list_keys) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_get", fake_read) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_put", fake_write) monkeypatch.setattr(workflow_routes, "get_workflow_health", fake_health) monkeypatch.setattr(workflow_routes, "_prepare_workflow_api_registry", fake_prepare_registry) monkeypatch.setattr(workflow_routes, "publish_workflow", fake_publish_workflow) @@ -149,12 +159,14 @@ async def fake_publish_workflow( assert result["checked"] == 1 assert result["restarted"] == 1 - assert publish_calls == [{ - "workflow_id": workflow_id, - "image": "custom-image:latest", - "driver": "docker", - "api_key": existing_key, - }] + assert publish_calls == [ + { + "workflow_id": workflow_id, + "image": "custom-image:latest", + "driver": "docker", + "api_key": existing_key, + } + ] assert store[service_key]["status"] == "running" assert store[service_key]["apiKey"] == existing_key assert store[service_key]["serviceUrl"] == "http://127.0.0.1:19001" @@ -208,9 +220,9 @@ async def fake_publish_workflow( "apiKey": api_key, } - monkeypatch.setattr(workflow_routes.Storage, "list_keys", fake_list_keys) - monkeypatch.setattr(workflow_routes.Storage, "read", fake_read) - monkeypatch.setattr(workflow_routes.Storage, "write", fake_write) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_list_keys", fake_list_keys) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_get", fake_read) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_put", fake_write) monkeypatch.setattr(workflow_routes, "get_workflow_health", fake_health) monkeypatch.setattr(workflow_routes, "_prepare_workflow_api_registry", fake_prepare_registry) monkeypatch.setattr(workflow_routes, "publish_workflow", fake_publish_workflow) @@ -251,8 +263,8 @@ async def fake_health(requested_id: str) -> dict[str, Any]: health_calls.append(requested_id) return {"ok": True} - monkeypatch.setattr(workflow_routes.Storage, "list_keys", fake_list_keys) - monkeypatch.setattr(workflow_routes.Storage, "read", fake_read) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_list_keys", fake_list_keys) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_get", fake_read) monkeypatch.setattr(workflow_routes, "get_workflow_health", fake_health) result = await workflow_routes.reconcile_published_workflow_api_services() @@ -286,8 +298,8 @@ async def fake_health(requested_id: str) -> dict[str, Any]: health_calls.append(requested_id) return {"ok": False, "published": False} - monkeypatch.setattr(workflow_routes.Storage, "read", fake_read) - monkeypatch.setattr(workflow_routes.Storage, "write", fake_write) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_get", fake_read) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_put", fake_write) monkeypatch.setattr(workflow_routes, "get_workflow_health", fake_health) result = await workflow_routes.get_workflow_service(workflow_id) @@ -327,9 +339,9 @@ async def fake_read(key: Any, *_args: Any, **_kwargs: Any) -> Any: async def fake_write(key: Any, value: Any) -> None: writes.append((key, value)) - monkeypatch.setattr(workflow_routes.Storage, "list_keys", fake_list_keys) - monkeypatch.setattr(workflow_routes.Storage, "read", fake_read) - monkeypatch.setattr(workflow_routes.Storage, "write", fake_write) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_list_keys", fake_list_keys) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_get", fake_read) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_put", fake_write) result = await workflow_routes.list_workflow_services() diff --git a/tests/server/routes/test_workflow_run_route.py b/tests/server/routes/test_workflow_run_route.py index af1b4aa7a..b3654ddd9 100644 --- a/tests/server/routes/test_workflow_run_route.py +++ b/tests/server/routes/test_workflow_run_route.py @@ -8,6 +8,221 @@ import flocks.server.routes.workflow as workflow_module +def _minimal_workflow_json(metadata=None): + workflow = { + "name": "minimal", + "start": "start", + "nodes": [{"id": "start", "type": "python", "code": "outputs['ok'] = True"}], + "edges": [], + } + if metadata is not None: + workflow["metadata"] = metadata + return workflow + + +def _two_node_workflow_json(edge): + return { + "name": "two-node", + "start": "prepare_message", + "nodes": [ + { + "id": "prepare_message", + "type": "python", + "code": "outputs['message_text'] = inputs.get('message', '')", + }, + { + "id": "transform_message", + "type": "python", + "code": "outputs['final_message'] = inputs.get('message_text', '').upper()", + }, + ], + "edges": [edge], + } + + +@pytest.mark.asyncio +async def test_create_workflow_applies_vertex_cache_runtime_defaults(monkeypatch: pytest.MonkeyPatch) -> None: + writes: list[dict] = [] + + def _fake_write_workflow_to_fs(workflow_id, workflow_json, meta, *args, **kwargs): + writes.append({"workflow_id": workflow_id, "workflow_json": workflow_json, "meta": meta}) + + monkeypatch.setattr(workflow_module, "_write_workflow_to_fs", _fake_write_workflow_to_fs) + monkeypatch.setattr(workflow_module, "_get_workflow_stats", AsyncMock(return_value={})) + monkeypatch.setattr(workflow_module, "publish_event", AsyncMock(return_value=None)) + + req = workflow_module.WorkflowCreateRequest( + name="new workflow", + workflowJson=_minimal_workflow_json(), + ) + + result = await workflow_module.create_workflow(req) + + runtime = result.workflowJson["metadata"]["runtime"] + assert runtime["strict_edge_mapping"] is True + assert runtime["dataflow_mode"] == "vertex_cache" + assert writes[0]["workflow_json"]["metadata"]["runtime"] == runtime + + +@pytest.mark.asyncio +async def test_create_workflow_preserves_explicit_runtime_defaults(monkeypatch: pytest.MonkeyPatch) -> None: + writes: list[dict] = [] + + def _fake_write_workflow_to_fs(workflow_id, workflow_json, meta, *args, **kwargs): + writes.append({"workflow_id": workflow_id, "workflow_json": workflow_json, "meta": meta}) + + monkeypatch.setattr(workflow_module, "_write_workflow_to_fs", _fake_write_workflow_to_fs) + monkeypatch.setattr(workflow_module, "_get_workflow_stats", AsyncMock(return_value={})) + monkeypatch.setattr(workflow_module, "publish_event", AsyncMock(return_value=None)) + + req = workflow_module.WorkflowCreateRequest( + name="legacy workflow", + workflowJson=_minimal_workflow_json( + { + "runtime": { + "strict_edge_mapping": False, + "dataflow_mode": "legacy", + } + } + ), + ) + + result = await workflow_module.create_workflow(req) + + runtime = result.workflowJson["metadata"]["runtime"] + assert runtime["strict_edge_mapping"] is False + assert runtime["dataflow_mode"] == "legacy" + assert writes[0]["workflow_json"]["metadata"]["runtime"] == runtime + + +@pytest.mark.asyncio +async def test_create_workflow_rejects_unmapped_edges_after_strict_default( + monkeypatch: pytest.MonkeyPatch, +) -> None: + write_workflow = Mock() + monkeypatch.setattr(workflow_module, "_write_workflow_to_fs", write_workflow) + + req = workflow_module.WorkflowCreateRequest( + name="new workflow", + workflowJson=_two_node_workflow_json( + {"from": "prepare_message", "to": "transform_message", "order": 0} + ), + ) + + with pytest.raises(workflow_module.HTTPException) as exc_info: + await workflow_module.create_workflow(req) + + assert exc_info.value.status_code == 400 + assert "Workflow strict edge mapping failed" in str(exc_info.value.detail) + assert "prepare_message" in str(exc_info.value.detail) + write_workflow.assert_not_called() + + +@pytest.mark.asyncio +async def test_create_workflow_accepts_explicit_mapping_after_strict_default( + monkeypatch: pytest.MonkeyPatch, +) -> None: + writes: list[dict] = [] + + def _fake_write_workflow_to_fs(workflow_id, workflow_json, meta, *args, **kwargs): + writes.append({"workflow_id": workflow_id, "workflow_json": workflow_json, "meta": meta}) + + monkeypatch.setattr(workflow_module, "_write_workflow_to_fs", _fake_write_workflow_to_fs) + monkeypatch.setattr(workflow_module, "_get_workflow_stats", AsyncMock(return_value={})) + monkeypatch.setattr(workflow_module, "publish_event", AsyncMock(return_value=None)) + + req = workflow_module.WorkflowCreateRequest( + name="new mapped workflow", + workflowJson=_two_node_workflow_json( + { + "from": "prepare_message", + "to": "transform_message", + "order": 0, + "mapping": {"message_text": "message_text"}, + } + ), + ) + + result = await workflow_module.create_workflow(req) + + runtime = result.workflowJson["metadata"]["runtime"] + assert runtime["strict_edge_mapping"] is True + assert runtime["dataflow_mode"] == "vertex_cache" + assert writes[0]["workflow_json"]["edges"][0]["mapping"] == {"message_text": "message_text"} + + +@pytest.mark.asyncio +async def test_create_workflow_rejects_schema_lint_errors( + monkeypatch: pytest.MonkeyPatch, +) -> None: + write_workflow = Mock() + monkeypatch.setattr(workflow_module, "_write_workflow_to_fs", write_workflow) + + workflow_json = _two_node_workflow_json( + { + "from": "prepare_message", + "to": "transform_message", + "order": 0, + "mapping": {"message_text": "missing_message_text"}, + } + ) + workflow_json["nodes"][0]["outputSchema"] = {"message_text": {"type": "str"}} + + req = workflow_module.WorkflowCreateRequest( + name="bad schema workflow", + workflowJson=workflow_json, + ) + + with pytest.raises(workflow_module.HTTPException) as exc_info: + await workflow_module.create_workflow(req) + + assert exc_info.value.status_code == 400 + assert "Workflow schema lint failed" in str(exc_info.value.detail) + assert "schema_mapping_src_not_declared" in str(exc_info.value.detail) + write_workflow.assert_not_called() + + +@pytest.mark.asyncio +async def test_update_workflow_rejects_unmapped_edges_when_strict( + monkeypatch: pytest.MonkeyPatch, +) -> None: + write_workflow = Mock() + existing = { + "id": "wf-1", + "name": "existing workflow", + "category": "default", + "status": "draft", + "createdAt": 1, + "updatedAt": 1, + "source": "global", + "workflowJson": _minimal_workflow_json( + {"runtime": {"strict_edge_mapping": True, "dataflow_mode": "vertex_cache"}} + ), + "markdownContent": None, + "editMarkdownContent": None, + } + + monkeypatch.setattr(workflow_module, "_read_workflow_from_fs", lambda _workflow_id: dict(existing)) + monkeypatch.setattr(workflow_module, "_write_workflow_to_fs", write_workflow) + + req = workflow_module.WorkflowUpdateRequest( + workflowJson={ + **_two_node_workflow_json( + {"from": "prepare_message", "to": "transform_message", "order": 0} + ), + "metadata": {"runtime": {"strict_edge_mapping": True, "dataflow_mode": "vertex_cache"}}, + } + ) + + with pytest.raises(workflow_module.HTTPException) as exc_info: + await workflow_module.update_workflow("wf-1", req) + + assert exc_info.value.status_code == 400 + assert "Workflow strict edge mapping failed" in str(exc_info.value.detail) + assert "prepare_message" in str(exc_info.value.detail) + write_workflow.assert_not_called() + + @pytest.mark.asyncio async def test_run_workflow_execution_task_reuses_existing_mcp_without_reinit( monkeypatch: pytest.MonkeyPatch, @@ -64,7 +279,7 @@ async def test_save_kafka_config_persists_consumer_settings( ) -> None: from flocks.ingest.kafka import manager as kafka_manager - storage_write = AsyncMock(return_value=None) + put_config = AsyncMock(return_value=None) restart_workflow = AsyncMock(return_value={"state": "running", "error": None}) persisted_triggers: list[list[str]] = [] @@ -73,7 +288,7 @@ async def test_save_kafka_config_persists_consumer_settings( "_read_workflow_from_fs", lambda _workflow_id: {"id": "wf-input", "workflowJson": {}}, ) - monkeypatch.setattr(workflow_module.Storage, "write", storage_write) + monkeypatch.setattr(workflow_module.WorkflowStore, "put_config", put_config) monkeypatch.setattr(kafka_manager.default_manager, "restart_workflow", restart_workflow) monkeypatch.setattr(workflow_module, "_get_workflow_trigger_defs", AsyncMock(return_value=[])) @@ -105,8 +320,10 @@ async def _fake_persist(workflow_id: str, workflow_data: dict, triggers: list) - response = await workflow_module.save_kafka_config("wf-input", req) assert response == {"ok": True, "consumer": {"state": "running", "error": None}} - storage_write.assert_awaited_once() - _, saved_config = storage_write.await_args.args + put_config.assert_awaited_once() + workflow_id, saved_config = put_config.await_args.args + assert workflow_id == "wf-input" + assert put_config.await_args.kwargs["kind"] == "workflow_kafka_config" assert saved_config["enabled"] is True assert saved_config["inputBroker"] == "localhost:9092" assert saved_config["inputTopic"] == "workflow-input" @@ -129,7 +346,7 @@ async def test_save_syslog_config_persists_listener_settings( ) -> None: from flocks.ingest.syslog import manager as syslog_manager - storage_write = AsyncMock(return_value=None) + put_config = AsyncMock(return_value=None) restart_workflow = AsyncMock(return_value={"state": "listening", "error": None}) persisted_triggers: list[list[str]] = [] @@ -138,7 +355,7 @@ async def test_save_syslog_config_persists_listener_settings( "_read_workflow_from_fs", lambda _workflow_id: {"id": "wf-input", "workflowJson": {}}, ) - monkeypatch.setattr(workflow_module.Storage, "write", storage_write) + monkeypatch.setattr(workflow_module.WorkflowStore, "put_config", put_config) monkeypatch.setattr(syslog_manager.default_manager, "restart_workflow", restart_workflow) monkeypatch.setattr(workflow_module, "_get_workflow_trigger_defs", AsyncMock(return_value=[])) @@ -166,8 +383,10 @@ async def _fake_persist(workflow_id: str, workflow_data: dict, triggers: list) - response = await workflow_module.save_syslog_config("wf-input", req) assert response == {"ok": True, "listener": {"state": "listening", "error": None}} - storage_write.assert_awaited_once() - _, saved_config = storage_write.await_args.args + put_config.assert_awaited_once() + workflow_id, saved_config = put_config.await_args.args + assert workflow_id == "wf-input" + assert put_config.await_args.kwargs["kind"] == "workflow_syslog_config" assert saved_config["enabled"] is True assert saved_config["protocol"] == "udp" assert saved_config["host"] == "0.0.0.0" diff --git a/tests/server/routes/test_workflow_trigger_routes.py b/tests/server/routes/test_workflow_trigger_routes.py index d811c5034..bcd8d08e2 100644 --- a/tests/server/routes/test_workflow_trigger_routes.py +++ b/tests/server/routes/test_workflow_trigger_routes.py @@ -20,22 +20,26 @@ async def test_list_workflow_triggers_returns_unified_status( monkeypatch.setattr( workflow_routes, "_read_workflow_from_fs", - lambda workflow_id: { - "id": workflow_id, - "workflowJson": { - "start": "n1", - "nodes": [{"id": "n1", "type": "python", "code": "result = {'ok': True}"}], - "edges": [], - "triggers": [ - { - "id": "schedule-default", - "type": "schedule", - "enabled": True, - "source": {"intervalSeconds": 60}, - } - ], - }, - } if workflow_id == "wf-1" else None, + lambda workflow_id: ( + { + "id": workflow_id, + "workflowJson": { + "start": "n1", + "nodes": [{"id": "n1", "type": "python", "code": "result = {'ok': True}"}], + "edges": [], + "triggers": [ + { + "id": "schedule-default", + "type": "schedule", + "enabled": True, + "source": {"intervalSeconds": 60}, + } + ], + }, + } + if workflow_id == "wf-1" + else None + ), ) async def _fake_statuses(_workflow_id: str, _workflow_json: dict[str, Any]) -> list[dict[str, Any]]: @@ -70,15 +74,19 @@ async def test_list_workflow_triggers_respects_explicit_empty_trigger_list( monkeypatch.setattr( workflow_routes, "_read_workflow_from_fs", - lambda workflow_id: { - "id": workflow_id, - "workflowJson": { - "start": "n1", - "nodes": [{"id": "n1", "type": "python", "code": "result = {'ok': True}"}], - "edges": [], - "triggers": [], - }, - } if workflow_id == "wf-1" else None, + lambda workflow_id: ( + { + "id": workflow_id, + "workflowJson": { + "start": "n1", + "nodes": [{"id": "n1", "type": "python", "code": "result = {'ok': True}"}], + "edges": [], + "triggers": [], + }, + } + if workflow_id == "wf-1" + else None + ), ) async def _fake_legacy_triggers(_workflow_id: str) -> list[Any]: @@ -148,27 +156,31 @@ async def test_workflow_config_response_keeps_template_separate_from_runtime( monkeypatch.setattr( workflow_routes, "_read_workflow_from_fs", - lambda workflow_id: { - "id": workflow_id, - "name": "demo", - "workflowJson": { - "start": "n1", - "nodes": [{"id": "n1", "type": "python", "code": "result = {'ok': True}"}], - "edges": [], - "triggers": [ - { - "id": "syslog-default", - "type": "syslog", - "enabled": True, - } - ], - }, - } if workflow_id == "wf-1" else None, + lambda workflow_id: ( + { + "id": workflow_id, + "name": "demo", + "workflowJson": { + "start": "n1", + "nodes": [{"id": "n1", "type": "python", "code": "result = {'ok': True}"}], + "edges": [], + "triggers": [ + { + "id": "syslog-default", + "type": "syslog", + "enabled": True, + } + ], + }, + } + if workflow_id == "wf-1" + else None + ), ) stored_writes: dict[str, Any] = {} - async def _fake_read(key: Any, _model: Any = None) -> dict[str, Any] | None: + async def _fake_kv_get(key: Any) -> dict[str, Any] | None: if key == workflow_routes._api_service_key("wf-1"): return { "workflowId": "wf-1", @@ -177,8 +189,8 @@ async def _fake_read(key: Any, _model: Any = None) -> dict[str, Any] | None: } return None - async def _fake_write(key: Any, value: Any) -> None: - stored_writes[str(key)] = value + async def _fake_put_config(workflow_id: str, config: dict[str, Any], *, kind: str | None = None) -> None: + stored_writes[workflow_routes._workflow_integration_config_key(workflow_id)] = config async def _fake_statuses(_workflow_id: str, _workflow_json: dict[str, Any]) -> list[dict[str, Any]]: return [ @@ -190,8 +202,8 @@ async def _fake_statuses(_workflow_id: str, _workflow_json: dict[str, Any]) -> l } ] - monkeypatch.setattr(workflow_routes.Storage, "read", _fake_read) - monkeypatch.setattr(workflow_routes.Storage, "write", _fake_write) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_get", _fake_kv_get) + monkeypatch.setattr(workflow_routes.WorkflowStore, "put_config", _fake_put_config) monkeypatch.setattr( workflow_routes, "default_trigger_runtime", @@ -251,26 +263,30 @@ async def test_workflow_config_prefers_storage_over_config_file( monkeypatch.setattr( workflow_routes, "_read_workflow_from_fs", - lambda workflow_id: { - "id": workflow_id, - "name": "demo", - "workflowJson": {"start": "n1", "nodes": [], "edges": [], "triggers": []}, - } if workflow_id == "wf-1" else None, + lambda workflow_id: ( + { + "id": workflow_id, + "name": "demo", + "workflowJson": {"start": "n1", "nodes": [], "edges": [], "triggers": []}, + } + if workflow_id == "wf-1" + else None + ), ) - async def _fake_read(key: Any, _model: Any = None) -> dict[str, Any] | None: - if key == workflow_routes._workflow_integration_config_key("wf-1"): + async def _fake_get_config(workflow_id: str, *, kind: str = "workflow.integration-config") -> dict[str, Any] | None: + if workflow_id == "wf-1" and kind == "workflow.integration-config": return stored_config return None - async def _fake_write(key: Any, value: Any) -> None: - write_calls.append((key, value)) + async def _fake_put_config(workflow_id: str, config: dict[str, Any], *, kind: str | None = None) -> None: + write_calls.append((workflow_routes._workflow_integration_config_key(workflow_id), config)) async def _fake_statuses(_workflow_id: str, _workflow_json: dict[str, Any]) -> list[dict[str, Any]]: return [] - monkeypatch.setattr(workflow_routes.Storage, "read", _fake_read) - monkeypatch.setattr(workflow_routes.Storage, "write", _fake_write) + monkeypatch.setattr(workflow_routes.WorkflowStore, "get_config", _fake_get_config) + monkeypatch.setattr(workflow_routes.WorkflowStore, "put_config", _fake_put_config) monkeypatch.setattr( workflow_routes, "default_trigger_runtime", @@ -323,10 +339,9 @@ async def test_update_workflow_config_writes_template_without_mutating_runtime( monkeypatch.chdir(workspace) monkeypatch.setattr(fs_store, "_workspace_root", None) - original_storage_read = workflow_routes.Storage.read stored_writes: dict[str, Any] = {} - async def _fake_storage_read(key: Any, *args: Any, **kwargs: Any) -> Any: + async def _fake_kv_get(key: Any) -> Any: if key == workflow_routes._api_service_key(workflow_id): return { "workflowId": workflow_id, @@ -338,10 +353,10 @@ async def _fake_storage_read(key: Any, *args: Any, **kwargs: Any) -> Any: "driver": "local", "publishedAt": 123, } - return await original_storage_read(key, *args, **kwargs) + return None - async def _fake_storage_write(key: Any, value: Any) -> None: - stored_writes[str(key)] = value + async def _fake_put_config(workflow_id: str, config: dict[str, Any], *, kind: str | None = None) -> None: + stored_writes[workflow_routes._workflow_integration_config_key(workflow_id)] = config async def _fake_statuses(_workflow_id: str, _workflow_json: dict[str, Any]) -> list[dict[str, Any]]: return [ @@ -353,8 +368,8 @@ async def _fake_statuses(_workflow_id: str, _workflow_json: dict[str, Any]) -> l } ] - monkeypatch.setattr(workflow_routes.Storage, "read", _fake_storage_read) - monkeypatch.setattr(workflow_routes.Storage, "write", _fake_storage_write) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_get", _fake_kv_get) + monkeypatch.setattr(workflow_routes.WorkflowStore, "put_config", _fake_put_config) monkeypatch.setattr( workflow_routes, "default_trigger_runtime", @@ -422,7 +437,7 @@ async def test_delete_workflow_service_removes_runtime_service_record( monkeypatch: pytest.MonkeyPatch, ) -> None: workflow_id = "wf-service-delete" - await workflow_routes.Storage.write( + await workflow_routes.WorkflowStore.kv_put( workflow_routes._api_service_key(workflow_id), { "workflowId": workflow_id, @@ -448,7 +463,7 @@ async def _fake_stop_service(wid: str) -> dict[str, Any]: assert response.status_code == 200, response.text assert response.json() == {"ok": True, "workflowId": workflow_id} assert stopped == [workflow_id] - assert await workflow_routes.Storage.read(workflow_routes._api_service_key(workflow_id)) is None + assert await workflow_routes.WorkflowStore.kv_get(workflow_routes._api_service_key(workflow_id)) is None @pytest.mark.asyncio @@ -459,11 +474,15 @@ async def test_update_workflow_config_rejects_mismatched_workflow_id( monkeypatch.setattr( workflow_routes, "_read_workflow_from_fs", - lambda workflow_id: { - "id": workflow_id, - "name": "demo", - "workflowJson": {"start": "n1", "nodes": [], "edges": []}, - } if workflow_id == "wf-1" else None, + lambda workflow_id: ( + { + "id": workflow_id, + "name": "demo", + "workflowJson": {"start": "n1", "nodes": [], "edges": []}, + } + if workflow_id == "wf-1" + else None + ), ) response = await client.put( @@ -506,27 +525,35 @@ async def test_delete_workflow_cleans_directory_and_storage( monkeypatch.chdir(workspace) monkeypatch.setattr(fs_store, "_workspace_root", None) - storage_keys = [ - workflow_routes._workflow_stats_key(workflow_id), - workflow_routes._workflow_integration_config_key(workflow_id), + await workflow_routes.WorkflowStore.put_stats(workflow_id, {"callCount": 1}) + await workflow_routes.WorkflowStore.put_config(workflow_id, {"workflowId": workflow_id}) + await workflow_routes.WorkflowStore.put_config( + workflow_id, + {"workflowId": workflow_id}, + kind="workflow_syslog_config", + ) + await workflow_routes.WorkflowStore.put_config( + workflow_id, + {"workflowId": workflow_id}, + kind="workflow_kafka_config", + ) + await workflow_routes.WorkflowStore.put_config( + workflow_id, + {"workflowId": workflow_id}, + kind="workflow_poller_config", + ) + kv_keys = [ workflow_routes._api_service_key(workflow_id), - workflow_routes._syslog_config_key(workflow_id), - workflow_routes._kafka_config_key(workflow_id), - f"workflow_poller_config/{workflow_id}", f"workflow_registry/{workflow_id}", f"workflow_runtime/{workflow_id}", f"workflow_local_pid/{workflow_id}", f"workflow_release/{workflow_id}/active", f"workflow_release/{workflow_id}/rel-1", - workflow_routes._workflow_execution_key("exec-delete"), - "workflow_execution_step/exec-delete/00000001", - f"workflow_execution_index/{workflow_id}/00000000000000000001/exec-delete", ] - for key in storage_keys: - payload = {"workflowId": workflow_id} - if key == workflow_routes._workflow_execution_key("exec-delete"): - payload = {"id": "exec-delete", "workflowId": workflow_id} - await workflow_routes.Storage.write(key, payload) + for key in kv_keys: + await workflow_routes.WorkflowStore.kv_put(key, {"workflowId": workflow_id}) + await workflow_routes.WorkflowStore.upsert_execution({"id": "exec-delete", "workflowId": workflow_id}) + await workflow_routes.WorkflowStore.record_step("exec-delete", 1, {"workflowId": workflow_id}) stopped: list[Any] = [] @@ -552,9 +579,15 @@ async def _fake_restart_workflow(wid: str, workflow_json: dict[str, Any]) -> dic assert not service_dir.exists() assert ("service", workflow_id) in stopped assert ("triggers", workflow_id, {"triggers": []}) in stopped - for key in storage_keys: - assert await workflow_routes.Storage.read(key) is None - assert await workflow_routes.Storage.list(f"workflow_release/{workflow_id}/") == [] + assert await workflow_routes.WorkflowStore.get_stats(workflow_id) is None + assert await workflow_routes.WorkflowStore.get_config(workflow_id) is None + assert await workflow_routes.WorkflowStore.get_config(workflow_id, kind="workflow_syslog_config") is None + assert await workflow_routes.WorkflowStore.get_config(workflow_id, kind="workflow_kafka_config") is None + assert await workflow_routes.WorkflowStore.get_config(workflow_id, kind="workflow_poller_config") is None + assert await workflow_routes.WorkflowStore.get_execution("exec-delete") is None + for key in kv_keys: + assert await workflow_routes.WorkflowStore.kv_get(key) is None + assert await workflow_routes.WorkflowStore.kv_list(f"workflow_release/{workflow_id}/") == [] @pytest.mark.asyncio @@ -675,22 +708,26 @@ async def test_create_workflow_trigger_rejects_multiple_legacy_singletons( monkeypatch.setattr( workflow_routes, "_read_workflow_from_fs", - lambda workflow_id: { - "id": workflow_id, - "workflowJson": { - "start": "n1", - "nodes": [{"id": "n1", "type": "python", "code": "result = {'ok': True}"}], - "edges": [], - "triggers": [ - { - "id": "schedule-default", - "type": "schedule", - "enabled": True, - "source": {"intervalSeconds": 60}, - } - ], - }, - } if workflow_id == "wf-1" else None, + lambda workflow_id: ( + { + "id": workflow_id, + "workflowJson": { + "start": "n1", + "nodes": [{"id": "n1", "type": "python", "code": "result = {'ok': True}"}], + "edges": [], + "triggers": [ + { + "id": "schedule-default", + "type": "schedule", + "enabled": True, + "source": {"intervalSeconds": 60}, + } + ], + }, + } + if workflow_id == "wf-1" + else None + ), ) response = await client.post( @@ -982,10 +1019,9 @@ async def test_sync_workflow_config_writes_publish_and_trigger_capabilities( monkeypatch.chdir(workspace) monkeypatch.setattr(fs_store, "_workspace_root", None) - original_storage_read = workflow_routes.Storage.read stored_writes: dict[str, Any] = {} - async def _fake_storage_read(key: Any, *args: Any, **kwargs: Any) -> Any: + async def _fake_kv_get(key: Any) -> Any: if key == workflow_routes._api_service_key(workflow_id): return { "workflowId": workflow_id, @@ -997,13 +1033,13 @@ async def _fake_storage_read(key: Any, *args: Any, **kwargs: Any) -> Any: "driver": "local", "publishedAt": 123, } - return await original_storage_read(key, *args, **kwargs) + return None - async def _fake_storage_write(key: Any, value: Any) -> None: - stored_writes[str(key)] = value + async def _fake_put_config(workflow_id: str, config: dict[str, Any], *, kind: str | None = None) -> None: + stored_writes[workflow_routes._workflow_integration_config_key(workflow_id)] = config - monkeypatch.setattr(workflow_routes.Storage, "read", _fake_storage_read) - monkeypatch.setattr(workflow_routes.Storage, "write", _fake_storage_write) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_get", _fake_kv_get) + monkeypatch.setattr(workflow_routes.WorkflowStore, "put_config", _fake_put_config) response = await client.post(f"/api/workflow/{workflow_id}/config/sync") @@ -1049,10 +1085,10 @@ async def test_persist_workflow_triggers_does_not_overwrite_config_template( monkeypatch.chdir(workspace) monkeypatch.setattr(fs_store, "_workspace_root", None) - async def _fake_storage_read(_key: Any, *_args: Any, **_kwargs: Any) -> None: + async def _fake_kv_get_none(_key: Any) -> None: return None - monkeypatch.setattr(workflow_routes.Storage, "read", _fake_storage_read) + monkeypatch.setattr(workflow_routes.WorkflowStore, "kv_get", _fake_kv_get_none) workflow_data = { "id": workflow_id, @@ -1129,5 +1165,5 @@ async def test_sync_workflow_config_preserves_existing_template( assert body["config"]["publish"] == {"type": "api_service"} assert body["config"]["triggers"] == [] assert config_path.read_text(encoding="utf-8") == before - stored = await workflow_routes.Storage.read(workflow_routes._workflow_integration_config_key(workflow_id)) + stored = await workflow_routes.WorkflowStore.get_config(workflow_id) assert stored == body["config"] diff --git a/tests/session/test_callable_schema.py b/tests/session/test_callable_schema.py index 3eb823573..f1d8a572b 100644 --- a/tests/session/test_callable_schema.py +++ b/tests/session/test_callable_schema.py @@ -141,13 +141,13 @@ async def test_callable_schema_keeps_user_plugin_tools_visible(monkeypatch: pyte @pytest.mark.asyncio -async def test_callable_schema_dynamically_exposes_device_context_when_enabled_devices_exist( +async def test_callable_schema_dynamically_exposes_device_manage_when_enabled_devices_exist( monkeypatch: pytest.MonkeyPatch, ) -> None: tools = [ _tool("question", ToolCategory.SYSTEM), _tool("tool_search", ToolCategory.SYSTEM), - _tool("device_context", ToolCategory.SYSTEM), + _tool("device_manage", ToolCategory.SYSTEM), ] monkeypatch.setattr("flocks.session.callable_schema.ToolRegistry.list_tools", lambda: tools) @@ -157,8 +157,8 @@ async def test_callable_schema_dynamically_exposes_device_context_when_enabled_d ) monkeypatch.setattr( "flocks.session.callable_schema.ToolRegistry.get", - lambda name: SimpleNamespace(info=_tool("device_context", ToolCategory.SYSTEM)) - if name == "device_context" + lambda name: SimpleNamespace(info=_tool(name, ToolCategory.SYSTEM)) + if name == "device_manage" else None, ) monkeypatch.setattr( @@ -169,5 +169,5 @@ async def test_callable_schema_dynamically_exposes_device_context_when_enabled_d result = await list_session_callable_tool_infos(session_id="session-device-aware") names = [tool.name for tool in result.tool_infos] - assert set(names) == {"device_context", "question", "tool_search"} - assert "device_context" in result.metadata["alwaysLoadToolNames"] + assert set(names) == {"device_manage", "question", "tool_search"} + assert "device_manage" in result.metadata["alwaysLoadToolNames"] diff --git a/tests/session/test_message_parts_persistence.py b/tests/session/test_message_parts_persistence.py index fef4d425f..4cb058253 100644 --- a/tests/session/test_message_parts_persistence.py +++ b/tests/session/test_message_parts_persistence.py @@ -1,9 +1,12 @@ """Persistence tests for message parts storage formats.""" +import asyncio + import pytest from flocks.session.message import ( Message, + MessageCacheInvalidatedError, MessageRole, TextPart, UserMessageInfo, @@ -197,6 +200,198 @@ async def test_clear_removes_legacy_blob_and_per_message_keys() -> None: assert await Storage.list_keys(prefix=f"message_parts:{per_message_session_id}:") == [] +@pytest.mark.asyncio +async def test_clear_tolerates_cache_invalidation_during_full_parts_load(monkeypatch) -> None: + session_id = "ses_parts_clear_lru_race" + Message.invalidate_cache() + Message._lru[session_id] = True + Message._messages_cache[session_id] = [] + Message._parts_cache[session_id] = {} + Message._parts_revision_cache[session_id] = {} + Message._parts_serialized_cache[session_id] = {} + Message._parts_storage_format[session_id] = "per_message" + Message._parts_persisted_mids[session_id] = set() + Message._parts_fully_loaded.discard(session_id) + + async def fake_load_all_parts_locked(cls, sid: str, *, message_times: dict) -> None: + assert sid == session_id + Message.invalidate_cache(sid) + await asyncio.sleep(0) + Message._parts_cache[sid] = {} + Message._parts_revision_cache[sid] = {} + Message._parts_serialized_cache[sid] = {} + Message._parts_storage_format[sid] = "per_message" + Message._parts_persisted_mids[sid] = set() + + monkeypatch.setattr( + Message, + "_load_all_parts_locked", + classmethod(fake_load_all_parts_locked), + ) + + assert await Message.clear(session_id) == 0 + assert session_id not in Message._lru + assert session_id not in Message._parts_fully_loaded + + +@pytest.mark.asyncio +async def test_ensure_cache_raises_after_repeated_invalidation(monkeypatch) -> None: + session_id = "ses_parts_full_load_repeated_invalidate" + await _write_legacy_session(session_id, {"msg_a": "never stabilizes"}) + + async def fake_load_all_parts_locked(cls, sid: str, *, message_times: dict) -> None: + assert sid == session_id + Message.invalidate_cache(sid) + for index in range(Message._MAX_CACHE_GENERATIONS + 5): + Message.invalidate_cache(f"ses_parts_full_load_churn_{index}") + Message._parts_cache[sid] = { + "msg_a": [_text_part(sid, "msg_a", "partial")] + } + await asyncio.sleep(0) + + monkeypatch.setattr( + Message, + "_load_all_parts_locked", + classmethod(fake_load_all_parts_locked), + ) + + with pytest.raises(MessageCacheInvalidatedError): + await Message._ensure_cache(session_id) + + assert session_id not in Message._lru + assert session_id not in Message._messages_cache + assert session_id not in Message._parts_cache + assert session_id not in Message._parts_fully_loaded + + +@pytest.mark.asyncio +async def test_ensure_cache_retries_when_invalidated_during_full_parts_load(monkeypatch) -> None: + session_id = "ses_parts_full_load_invalidate_retry" + await _write_legacy_session(session_id, {"msg_a": "survives reload"}) + + original_load_all_parts_locked = Message._load_all_parts_locked + invalidated = False + + async def fake_load_all_parts_locked(cls, sid: str, *, message_times: dict) -> None: + nonlocal invalidated + assert sid == session_id + if not invalidated: + invalidated = True + Message.invalidate_cache(sid) + await asyncio.sleep(0) + return + await original_load_all_parts_locked(sid, message_times=message_times) + + monkeypatch.setattr( + Message, + "_load_all_parts_locked", + classmethod(fake_load_all_parts_locked), + ) + + await Message._ensure_cache(session_id) + + assert invalidated is True + assert session_id in Message._lru + assert session_id in Message._messages_cache + assert session_id in Message._parts_fully_loaded + messages = await Message.list(session_id) + assert [message.id for message in messages] == ["msg_a"] + + +@pytest.mark.asyncio +async def test_ensure_cache_ignores_unrelated_session_invalidation(monkeypatch) -> None: + session_id = "ses_parts_full_load_target" + unrelated_id = "ses_parts_full_load_unrelated" + await _write_legacy_session(session_id, {"msg_a": "target survives"}) + await _write_legacy_session(unrelated_id, {"msg_b": "unrelated"}) + + original_load_all_parts_locked = Message._load_all_parts_locked + load_calls = 0 + + async def fake_load_all_parts_locked(cls, sid: str, *, message_times: dict) -> None: + nonlocal load_calls + assert sid == session_id + load_calls += 1 + Message.invalidate_cache(unrelated_id) + await original_load_all_parts_locked(sid, message_times=message_times) + + monkeypatch.setattr( + Message, + "_load_all_parts_locked", + classmethod(fake_load_all_parts_locked), + ) + + await Message._ensure_cache(session_id) + + assert load_calls == 1 + assert session_id in Message._parts_fully_loaded + messages = await Message.list(session_id) + assert [message.id for message in messages] == ["msg_a"] + + +def test_session_cache_generation_map_is_bounded() -> None: + Message.invalidate_cache() + + for index in range(Message._MAX_CACHE_GENERATIONS + 25): + Message.invalidate_cache(f"ses_generation_bound_{index}") + + assert len(Message._session_cache_generations) == Message._MAX_CACHE_GENERATIONS + + +@pytest.mark.asyncio +async def test_message_list_recovers_when_lru_outlives_message_cache() -> None: + session_id = "ses_parts_stale_lru_without_messages" + await _write_legacy_session(session_id, {"msg_a": "restored from disk"}) + Message._lru[session_id] = True + Message._messages_cache.pop(session_id, None) + Message._parts_fully_loaded.add(session_id) + + messages = await Message.list(session_id) + + assert [message.id for message in messages] == ["msg_a"] + assert session_id in Message._lru + assert session_id in Message._messages_cache + assert session_id not in Message._parts_fully_loaded + + +@pytest.mark.asyncio +async def test_parts_without_session_uses_cache_snapshot(monkeypatch) -> None: + session_id = "ses_parts_snapshot_search" + Message.invalidate_cache() + Message._parts_cache[session_id] = {} + + async def fake_ensure_cache(cls, sid: str) -> None: + assert sid == session_id + Message._parts_cache["ses_parts_snapshot_added"] = {} + await asyncio.sleep(0) + + monkeypatch.setattr(Message, "_ensure_cache", classmethod(fake_ensure_cache)) + + assert await Message.parts("missing_message_id") == [] + + +@pytest.mark.asyncio +async def test_persist_parts_uses_snapshot_when_cache_changes(monkeypatch) -> None: + session_id = "ses_parts_persist_snapshot" + Message.invalidate_cache() + Message._parts_cache[session_id] = { + "msg_a": [_text_part(session_id, "msg_a", "a")], + } + Message._parts_serialized_cache[session_id] = {} + Message._parts_storage_format[session_id] = "per_message" + Message._parts_persisted_mids[session_id] = set() + + async def fake_storage_set(key: str, value, value_type: str = "json") -> None: + Message._parts_cache[session_id]["msg_b"] = [ + _text_part(session_id, "msg_b", "b") + ] + await asyncio.sleep(0) + + monkeypatch.setattr(Storage, "set", fake_storage_set) + + await Message._persist_parts(session_id) + + def test_deserialize_legacy_text_part_normalizes_content_and_time() -> None: part = Message.deserialize_part( { diff --git a/tests/session/test_runner_device_hint.py b/tests/session/test_runner_device_hint.py index 17451aa61..01773730f 100644 --- a/tests/session/test_runner_device_hint.py +++ b/tests/session/test_runner_device_hint.py @@ -61,7 +61,7 @@ async def test_device_asset_hint_stays_short_and_strategy_only() -> None: assert "已省略" not in hint assert "threatbook" in hint assert "qianxin" in hint - assert "`device_context`" in hint + assert "`device_manage(action='list')`" in hint assert "`tool_search`" in hint assert "`device_id`" in hint assert "机房:" not in hint diff --git a/tests/session/test_session_abort_inject.py b/tests/session/test_session_abort_inject.py index 0f3138b1a..e690ea944 100644 --- a/tests/session/test_session_abort_inject.py +++ b/tests/session/test_session_abort_inject.py @@ -357,6 +357,59 @@ def _make_msg(msg_id: str, role: str, finish: str = None, *, tokens=None, summar msg.summary = summary return msg + @pytest.mark.asyncio + async def test_run_loop_stops_turn_when_messages_are_empty(self): + session = SimpleNamespace( + id="turn_no_messages_session", + agent="rex", + directory="/tmp", + memory_enabled=False, + ) + ctx = LoopContext( + session=session, + provider_id="test-provider", + model_id="test-model", + agent_name="rex", + ) + ctx.session_ctx = SimpleNamespace(get_messages=AsyncMock(return_value=[])) + event_callback = AsyncMock() + callbacks = LoopCallbacks(event_publish_callback=event_callback) + + result = await SessionLoop._run_loop(ctx, callbacks) + + assert result.action == "stop" + event_names = [call.args[0] for call in event_callback.await_args_list] + assert event_names == ["turn.started", "turn.stopped"] + stopped_payload = event_callback.await_args_list[1].args[1] + assert stopped_payload["stop_reason"] == "no_messages" + + @pytest.mark.asyncio + async def test_run_loop_stops_turn_when_no_user_message_exists(self): + session = SimpleNamespace( + id="turn_no_user_session", + agent="rex", + directory="/tmp", + memory_enabled=False, + ) + ctx = LoopContext( + session=session, + provider_id="test-provider", + model_id="test-model", + agent_name="rex", + ) + assistant = self._make_msg("msg_001", "assistant", finish="stop") + ctx.session_ctx = SimpleNamespace(get_messages=AsyncMock(return_value=[assistant])) + event_callback = AsyncMock() + callbacks = LoopCallbacks(event_publish_callback=event_callback) + + result = await SessionLoop._run_loop(ctx, callbacks) + + assert result.action == "stop" + event_names = [call.args[0] for call in event_callback.await_args_list] + assert event_names == ["turn.started", "turn.stopped"] + stopped_payload = event_callback.await_args_list[1].args[1] + assert stopped_payload["stop_reason"] == "no_user_message" + @pytest.mark.asyncio async def test_run_loop_continues_for_active_goal_after_stop(self): session = SimpleNamespace( diff --git a/tests/storage/test_multi_db_routing.py b/tests/storage/test_multi_db_routing.py index 1cdfb93b2..8dd5eb7f4 100644 --- a/tests/storage/test_multi_db_routing.py +++ b/tests/storage/test_multi_db_routing.py @@ -6,9 +6,10 @@ import pytest from flocks.config.config import Config -from flocks.storage.storage import Storage, StorageError +from flocks.storage.storage import Storage from flocks.task.models import TaskExecution, TaskScheduler from flocks.task.store import TaskStore, _TASKS_DDL +from flocks.workflow.store import WorkflowStore def _reset_state() -> None: @@ -20,6 +21,10 @@ def _reset_state() -> None: TaskStore._initialized = False TaskStore._conn = None TaskStore._init_pid = None + WorkflowStore._initialized = False + WorkflowStore._conn = None + WorkflowStore._init_pid = None + WorkflowStore._db_path = None def _fetch_storage_value(db_path: Path, key: str): @@ -34,6 +39,18 @@ def _fetch_storage_value(db_path: Path, key: str): conn.close() +def _fetch_workflow_kv_value(db_path: Path, key: str): + conn = sqlite3.connect(db_path) + try: + row = conn.execute( + "SELECT value FROM workflow_kv WHERE key = ?", + (key,), + ).fetchone() + return row[0] if row else None + finally: + conn.close() + + def _fetch_table_count(db_path: Path, table_name: str) -> int: conn = sqlite3.connect(db_path) try: @@ -50,11 +67,12 @@ async def isolated_multi_db_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) _reset_state() yield await TaskStore.close() + await WorkflowStore.close() _reset_state() @pytest.mark.asyncio -async def test_workflow_keys_route_to_workflow_db() -> None: +async def test_storage_no_longer_routes_workflow_keys_to_workflow_db() -> None: await Storage.init() await Storage.write("workflow/wf-1", {"name": "workflow"}) @@ -63,10 +81,9 @@ async def test_workflow_keys_route_to_workflow_db() -> None: flocks_db = Storage.get_db_path() workflow_db = Storage.get_workflow_db_path() - assert _fetch_storage_value(workflow_db, "workflow/wf-1") is not None - assert _fetch_storage_value(flocks_db, "workflow/wf-1") is None + assert _fetch_storage_value(flocks_db, "workflow/wf-1") is not None assert _fetch_storage_value(flocks_db, "project/proj-1") is not None - assert _fetch_storage_value(workflow_db, "project/proj-1") is None + assert not workflow_db.exists() assert await Storage.read("workflow/wf-1") == {"name": "workflow"} assert await Storage.list_keys("workflow") == ["workflow/wf-1"] @@ -78,11 +95,11 @@ async def test_short_non_workflow_prefix_stays_on_flocks_db() -> None: await Storage.write("workspace/item-1", {"name": "workspace"}) await Storage.write("workflow/item-1", {"name": "workflow"}) - assert await Storage.list_keys("work") == ["workspace/item-1"] + assert await Storage.list_keys("work") == ["workflow/item-1", "workspace/item-1"] @pytest.mark.asyncio -async def test_clear_without_prefix_clears_flocks_and_workflow_dbs() -> None: +async def test_clear_without_prefix_clears_flocks_db_only() -> None: await Storage.init() await Storage.write("project/proj-1", {"name": "project"}) @@ -94,7 +111,61 @@ async def test_clear_without_prefix_clears_flocks_and_workflow_dbs() -> None: @pytest.mark.asyncio -async def test_list_without_prefix_merges_flocks_and_workflow_dbs() -> None: +async def test_workflow_prefix_clear_does_not_invalidate_session_caches( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from flocks.session.message import Message + from flocks.session.session import Session + + await Storage.init() + await Storage.write("workflow_execution_step/exec-1/00000001", {"ok": True}) + calls: list[str] = [] + + monkeypatch.setattr(Session, "invalidate_cache", lambda: calls.append("session")) + monkeypatch.setattr( + Message, + "invalidate_cache", + lambda session_id=None: calls.append(f"message:{session_id}"), + ) + + assert await Storage.clear("workflow_execution_step/exec-1/") == 1 + assert calls == [] + + +@pytest.mark.asyncio +async def test_session_and_message_prefix_clear_invalidate_matching_runtime_cache( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from flocks.session.message import Message + from flocks.session.session import Session + + await Storage.init() + await Storage.write("session:proj:ses-1", {"id": "ses-1"}) + await Storage.write("message:ses-1", []) + await Storage.write("message_parts:ses-1/msg-1", []) + calls: list[str] = [] + + monkeypatch.setattr(Session, "invalidate_cache", lambda: calls.append("session")) + monkeypatch.setattr( + Message, + "invalidate_cache", + lambda session_id=None: calls.append(f"message:{session_id}"), + ) + + assert await Storage.clear("session:") == 1 + assert calls == ["session"] + + calls.clear() + assert await Storage.clear("message:") == 1 + assert calls == ["message:None"] + + calls.clear() + assert await Storage.clear("message_parts:") == 1 + assert calls == ["message:None"] + + +@pytest.mark.asyncio +async def test_list_without_prefix_reads_flocks_db_only() -> None: await Storage.init() await Storage.write("project/proj-1", {"name": "project"}) @@ -140,16 +211,15 @@ async def test_workflow_kv_migrates_from_legacy_flocks_db() -> None: finally: conn.close() - await Storage.init() + await WorkflowStore.init() workflow_db = Storage.get_workflow_db_path() - assert _fetch_storage_value(workflow_db, "workflow_registry/wf-legacy") == '{"ok": true}' + assert _fetch_workflow_kv_value(workflow_db, "workflow_registry/wf-legacy") == '{"ok": true}' assert _fetch_storage_value(flocks_db, "workflow_registry/wf-legacy") == '{"ok": true}' - assert _fetch_storage_value(workflow_db, "session:legacy") is None - marker = await Storage.get(Storage._multi_db_migration_marker_key) - assert marker["workflow_migrated"] is True - assert marker["workflow_rows"] == 1 - assert marker["workflow_source_rows_deleted"] == 0 + assert _fetch_workflow_kv_value(workflow_db, "session:legacy") is None + assert await WorkflowStore.kv_get("workflow_registry/wf-legacy") == {"ok": True} + marker = await WorkflowStore.kv_get("workflow_store.migration.tables.v1") + assert marker["kv"] == 1 @pytest.mark.asyncio @@ -181,18 +251,19 @@ async def test_workflow_prefix_migration_treats_underscore_literally() -> None: finally: conn.close() - await Storage.init() + await WorkflowStore.init() workflow_db = Storage.get_workflow_db_path() - assert _fetch_storage_value(workflow_db, "workflow_registry/wf-ok") == '{"ok": true}' - assert _fetch_storage_value(workflow_db, "workflowXregistry/wf-bad") is None + assert _fetch_workflow_kv_value(workflow_db, "workflow_registry/wf-ok") == '{"ok": true}' + assert _fetch_workflow_kv_value(workflow_db, "workflowXregistry/wf-bad") is None assert _fetch_storage_value(flocks_db, "workflow_registry/wf-ok") == '{"ok": true}' assert _fetch_storage_value(flocks_db, "workflowXregistry/wf-bad") == '{"bad": true}' assert await Storage.list_keys("workflow_registry/") == ["workflow_registry/wf-ok"] + assert await WorkflowStore.kv_list_keys("workflow_registry/") == ["workflow_registry/wf-ok"] @pytest.mark.asyncio -async def test_completed_workflow_migration_fails_if_workflow_db_disappears() -> None: +async def test_workflow_store_recreates_workflow_db_if_it_disappears() -> None: flocks_db = Config.get_data_path() / "flocks.db" flocks_db.parent.mkdir(parents=True, exist_ok=True) conn = sqlite3.connect(flocks_db) @@ -210,20 +281,23 @@ async def test_completed_workflow_migration_fails_if_workflow_db_disappears() -> ) conn.execute( "INSERT INTO storage (key, value, type, created_at, updated_at) VALUES (?, ?, ?, ?, ?)", - ("workflow/wf-legacy", '{"ok": true}', "json", "old", "old"), + ("workflow_registry/wf-legacy", '{"ok": true}', "json", "old", "old"), ) conn.commit() finally: conn.close() - await Storage.init() + await WorkflowStore.init() workflow_db = Storage.get_workflow_db_path() assert workflow_db.exists() + assert await WorkflowStore.kv_get("workflow_registry/wf-legacy") == {"ok": True} + await WorkflowStore.close() workflow_db.unlink() _reset_state() - with pytest.raises(StorageError, match="workflow.db is missing"): - await Storage.init() + await WorkflowStore.init() + assert Storage.get_workflow_db_path().exists() + assert await WorkflowStore.kv_get("workflow_registry/wf-legacy") == {"ok": True} @pytest.mark.asyncio diff --git a/tests/tool/test_api_service_schema.py b/tests/tool/test_api_service_schema.py new file mode 100644 index 000000000..a56b1e6fd --- /dev/null +++ b/tests/tool/test_api_service_schema.py @@ -0,0 +1,37 @@ +from flocks.tool.schema.api_service_schema import _build_api_service_credential_schema + + +def test_credential_schema_preserves_internal_fields() -> None: + fields = _build_api_service_credential_schema( + "webcli_device", + { + "credential_fields": [ + { + "key": "base_url", + "label": "Base URL", + "storage": "config", + "config_key": "base_url", + }, + { + "key": "auth_state", + "label": "Auth State", + "storage": "secret", + "config_key": "auth_state", + "internal": True, + }, + { + "key": "legacy_cookie", + "label": "Legacy Cookie", + "storage": "secret", + "config_key": "legacy_cookie", + "hidden": True, + }, + ], + }, + ) + + by_key = {field.key: field for field in fields} + assert by_key["base_url"].internal is False + assert by_key["auth_state"].internal is True + assert by_key["auth_state"].storage == "secret" + assert by_key["legacy_cookie"].internal is True diff --git a/tests/tool/test_device_context_prompt.py b/tests/tool/test_device_manage_prompt.py similarity index 92% rename from tests/tool/test_device_context_prompt.py rename to tests/tool/test_device_manage_prompt.py index 487f9d817..2ec818674 100644 --- a/tests/tool/test_device_context_prompt.py +++ b/tests/tool/test_device_manage_prompt.py @@ -3,7 +3,7 @@ import pytest -from flocks.tool.device.prompt import build_device_context_section +from flocks.tool.device.prompt import build_device_manage_list_section from flocks.tool.registry import ParameterType, ToolCategory, ToolInfo, ToolParameter def _stub_groups(monkeypatch: pytest.MonkeyPatch, groups): @@ -32,7 +32,7 @@ def _stub_tools(monkeypatch: pytest.MonkeyPatch, tools): @pytest.mark.asyncio -async def test_device_context_deduplicates_tool_sets_and_references_them_from_devices( +async def test_device_manage_list_deduplicates_tool_sets_and_references_them_from_devices( monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.setattr( @@ -91,7 +91,7 @@ async def test_device_context_deduplicates_tool_sets_and_references_them_from_de ], ) - content = await build_device_context_section() + content = await build_device_manage_list_section() assert content is not None assert "设备不存在时,请提醒用户前往设备接入页面添加设备。" in content @@ -109,7 +109,7 @@ async def test_device_context_deduplicates_tool_sets_and_references_them_from_de @pytest.mark.asyncio -async def test_device_context_shows_per_device_disabled_tools_only_for_their_device( +async def test_device_manage_list_shows_per_device_disabled_tools_only_for_their_device( monkeypatch: pytest.MonkeyPatch, ) -> None: """Per-device disabled tools must be annotated only under the device that @@ -143,7 +143,7 @@ async def test_device_context_shows_per_device_disabled_tools_only_for_their_dev ), ]) - content = await build_device_context_section() + content = await build_device_manage_list_section() assert content is not None # Locate the per-device blocks by anchoring on the device name line. @@ -166,7 +166,7 @@ async def test_device_context_shows_per_device_disabled_tools_only_for_their_dev @pytest.mark.asyncio -async def test_device_context_omits_notice_when_no_per_device_overrides( +async def test_device_manage_list_omits_notice_when_no_per_device_overrides( monkeypatch: pytest.MonkeyPatch, ) -> None: """No notice line should appear when a device has no per-device overrides.""" @@ -186,6 +186,6 @@ async def test_device_context_omits_notice_when_no_per_device_overrides( ), ]) - content = await build_device_context_section() + content = await build_device_manage_list_section() assert content is not None assert "已单独禁用" not in content diff --git a/tests/tool/test_device_manage_tool.py b/tests/tool/test_device_manage_tool.py new file mode 100644 index 000000000..8e7c82861 --- /dev/null +++ b/tests/tool/test_device_manage_tool.py @@ -0,0 +1,241 @@ +"""Tests for the built-in device_manage tool.""" +from __future__ import annotations + +from unittest.mock import AsyncMock, patch + +import pytest + +from flocks.tool.device.intake import DeviceNotFoundError +from flocks.tool.device.manage_tool import device_manage +from flocks.tool.device.models import DeviceIntegration, DeviceTestResult +from flocks.tool.registry import ToolContext, ToolRegistry + + +def make_ctx() -> ToolContext: + return ToolContext( + session_id="session-device-test", + message_id="message-device-test", + agent="rex", + ) + + +def make_device(**overrides) -> DeviceIntegration: + data = { + "id": "dev-1", + "group_id": "default-room", + "name": "自定义设备", + "storage_key": "custom_device_v1", + "service_id": "custom_device", + "enabled": True, + "verify_ssl": False, + "fields": {"base_url": "https://device.local"}, + "fields_set": {"base_url": True}, + "status": "unknown", + "message": None, + "latency_ms": None, + "checked_at": None, + "created_at": 1, + "updated_at": 2, + } + data.update(overrides) + return DeviceIntegration(**data) + + +def test_device_manage_is_registered(): + tools = {tool.name for tool in ToolRegistry.list_tools()} + assert "device_manage" in tools + + +def test_device_manage_schema_includes_update_action(): + tool = ToolRegistry.get("device_manage") + assert tool is not None + + action_param = next(param for param in tool.info.parameters if param.name == "action") + assert action_param.enum == ["list", "update", "connectivity_test"] + assert {param.name for param in tool.info.parameters} >= { + "device_id", + "fields", + "verify_ssl", + } + + +@pytest.mark.asyncio +async def test_device_manage_list_returns_device_inventory(): + with patch( + "flocks.tool.device.prompt.build_device_manage_list_section", + AsyncMock(return_value="### 已接入设备\n- device_id: `dev-1`"), + ): + result = await device_manage(make_ctx(), action="list") + + assert result.success is True + assert "dev-1" in result.output + + +@pytest.mark.asyncio +async def test_device_manage_update_updates_existing_device_non_secret_config(): + updated_device = make_device(verify_ssl=True) + with patch( + "flocks.tool.device.manage_tool.update_device", + AsyncMock(return_value=updated_device), + ) as mocked_update: + result = await device_manage( + make_ctx(), + action="update", + device_id="dev-1", + fields={"base_url": "https://device.local", "port": 443}, + verify_ssl=True, + ) + + mocked_update.assert_awaited_once() + called_device_id, update_body = mocked_update.await_args.args + assert called_device_id == "dev-1" + assert update_body.fields == { + "base_url": "https://device.local", + "port": "443", + } + assert update_body.verify_ssl is True + assert result.success is True + assert result.output["device_id"] == "dev-1" + assert result.output["updated_fields"] == ["base_url", "port"] + assert result.metadata["verify_ssl"] is True + + +@pytest.mark.asyncio +async def test_device_manage_update_rejects_sensitive_fields(): + with patch( + "flocks.tool.device.manage_tool.update_device", + AsyncMock(), + ) as mocked_update: + result = await device_manage( + make_ctx(), + action="update", + device_id="dev-1", + fields={"api_key": "secret-value"}, + ) + + mocked_update.assert_not_awaited() + assert result.success is False + assert "敏感字段" in (result.error or "") + assert "api_key" in (result.error or "") + + +@pytest.mark.asyncio +async def test_device_manage_update_requires_fields_or_verify_ssl(): + result = await device_manage( + make_ctx(), + action="update", + device_id="dev-1", + ) + + assert result.success is False + assert "至少需要提供 fields 或 verify_ssl" in (result.error or "") + + +@pytest.mark.asyncio +async def test_device_manage_update_reports_missing_device_as_tool_error(): + with patch( + "flocks.tool.device.manage_tool.update_device", + AsyncMock(side_effect=DeviceNotFoundError("missing")), + ): + result = await device_manage( + make_ctx(), + action="update", + device_id="missing-id", + fields={"base_url": "https://device.local"}, + ) + + assert result.success is False + assert "未找到" in (result.error or "") + + +@pytest.mark.asyncio +async def test_device_manage_connectivity_test_writes_status_via_existing_test_path(): + with patch( + "flocks.tool.device.manage_tool.test_device", + AsyncMock( + return_value=DeviceTestResult( + success=True, + message="HTTP 200,延迟 12ms", + latency_ms=12, + ) + ), + ) as mocked_test: + result = await device_manage( + make_ctx(), + action="connectivity_test", + device_id="dev-1", + ) + + mocked_test.assert_awaited_once_with("dev-1") + assert result.success is True + assert result.output == { + "device_id": "dev-1", + "connected": True, + "status": "ok", + "message": "HTTP 200,延迟 12ms", + "latency_ms": 12, + } + assert result.metadata["card_status_updated"] is True + + +@pytest.mark.asyncio +async def test_device_manage_keeps_device_id_when_executed_through_registry(): + with patch( + "flocks.tool.device.manage_tool.test_device", + AsyncMock( + return_value=DeviceTestResult( + success=True, + message="HTTP 200,延迟 12ms", + latency_ms=12, + ) + ), + ) as mocked_test: + result = await ToolRegistry.execute( + "device_manage", + make_ctx(), + action="connectivity_test", + device_id="dev-1", + ) + + mocked_test.assert_awaited_once_with("dev-1") + assert result.success is True + assert result.metadata["device_id"] == "dev-1" + + +@pytest.mark.asyncio +async def test_device_manage_connectivity_test_returns_successful_tool_result_for_failed_probe(): + with patch( + "flocks.tool.device.manage_tool.test_device", + AsyncMock( + return_value=DeviceTestResult( + success=False, + message="无法连接到 https://device.local", + latency_ms=10000, + ) + ), + ): + result = await device_manage( + make_ctx(), + action="connectivity_test", + device_id="dev-1", + ) + + assert result.success is True + assert result.output["connected"] is False + assert result.output["status"] == "error" + + +@pytest.mark.asyncio +async def test_device_manage_connectivity_test_reports_missing_device_as_tool_error(): + with patch( + "flocks.tool.device.manage_tool.test_device", + AsyncMock(side_effect=DeviceNotFoundError("missing")), + ): + result = await device_manage( + make_ctx(), + action="connectivity_test", + device_id="missing-id", + ) + + assert result.success is False + assert "未找到" in (result.error or "") diff --git a/tests/tool/test_device_plugin_index.py b/tests/tool/test_device_plugin_index.py index 1015d7afa..678159d80 100644 --- a/tests/tool/test_device_plugin_index.py +++ b/tests/tool/test_device_plugin_index.py @@ -71,6 +71,7 @@ def test_device_plugin_index_filters_and_shapes_templates(monkeypatch, tmp_path) "version": "1.2.3", "integration_type": "device", "vendor": "demo", + "docs_url": "https://docs.example.com/demo-device", "credential_fields": [ {"key": "base_url", "label": "Base URL", "storage": "config"}, ], @@ -128,6 +129,7 @@ def test_device_plugin_index_filters_and_shapes_templates(monkeypatch, tmp_path) assert template.storage_key == "demo_api_v1_2_3" assert template.service_id == "demo_api" assert template.vendor == "demo" + assert template.docs_url == "https://docs.example.com/demo-device" assert template.installed is False assert template.state == "available" assert template.source == "bundled" @@ -179,6 +181,37 @@ def test_device_plugin_index_normalizes_plugin_id_name(monkeypatch, tmp_path): assert templates[0].version == "2.5.3 D20250710" +def test_onesig_strategy_api_template_uses_unversioned_service_id(monkeypatch, tmp_path): + from flocks.tool.device import plugin_index + + _reset_env(monkeypatch, tmp_path) + root = Path.cwd() / ".flocks" / "plugins" / "tools" / "device" / "onesig_v2_5_3" + _write_provider( + root, + { + "name": "onesig", + "service_id": "onesig_api", + "version": "2.5.3", + "integration_type": "device", + "vendor": "threatbook", + "credential_fields": [ + {"key": "base_url", "label": "Base URL", "storage": "config"}, + ], + }, + ) + _write_tool(root, "onesig_strategy_api_query") + + monkeypatch.setattr(plugin_index.hub_catalog, "list_catalog", lambda plugin_type=None: []) + monkeypatch.setattr(plugin_index.ToolRegistry, "init", classmethod(lambda cls: None)) + monkeypatch.setattr(plugin_index.ToolRegistry, "list_tools", classmethod(lambda cls: [])) + + templates = plugin_index.list_device_templates(refresh=True) + + assert len(templates) == 1 + assert templates[0].service_id == "onesig_api" + assert templates[0].storage_key == "onesig_api_v2_5_3" + + def test_device_template_refresh_reloads_plugin_tools(monkeypatch, tmp_path): from flocks.tool.device import plugin_index diff --git a/tests/tool/test_device_secrets.py b/tests/tool/test_device_secrets.py index 770b30501..6b6074c26 100644 --- a/tests/tool/test_device_secrets.py +++ b/tests/tool/test_device_secrets.py @@ -23,3 +23,41 @@ def test_persist_fields_keeps_non_tdp_base_url_paths(): ) assert fields["base_url"] == "https://proxy.local/config/api" + + +def test_persist_fields_deletes_secret_when_empty_string_is_submitted(): + secret_manager = MagicMock() + prior = { + "api_key": "{secret:device_device-1_api_key}", + "base_url": "https://tdp.local", + } + + with patch("flocks.security.get_secret_manager", return_value=secret_manager): + fields = persist_fields( + "device-1", + "tdp_api_v3_3_10", + {"api_key": ""}, + prior_db_fields=prior, + ) + + assert "api_key" not in fields + assert fields["base_url"] == "https://tdp.local" + secret_manager.delete.assert_called_once_with("device_device-1_api_key") + + +def test_persist_fields_keeps_secret_when_key_is_absent(): + prior = { + "api_key": "{secret:device_device-1_api_key}", + "base_url": "https://tdp.local", + } + + with patch("flocks.security.get_secret_manager", return_value=MagicMock()): + fields = persist_fields( + "device-1", + "tdp_api_v3_3_10", + {"base_url": "https://tdp.local/config/api"}, + prior_db_fields=prior, + ) + + assert fields["api_key"] == "{secret:device_device-1_api_key}" + assert fields["base_url"] == "https://tdp.local" diff --git a/tests/tool/test_skyeye_api_tools.py b/tests/tool/test_skyeye_api_tools.py index 4c1e21187..97c71bf07 100644 --- a/tests/tool/test_skyeye_api_tools.py +++ b/tests/tool/test_skyeye_api_tools.py @@ -98,7 +98,7 @@ async def test_skyeye_dashboard_view_tool_uses_custom_login_flow(): method, auth_url, auth_kwargs = fake_session.calls[0] assert method == "POST" - assert auth_url == "https://skyeye.local/api/v1/admin/auth" + assert auth_url == "https://skyeye.local/skyeye/api/v1/admin/auth" assert auth_kwargs["data"] == { "client_id": auth_kwargs["data"]["client_id"], "username": "tapadmin", @@ -108,7 +108,7 @@ async def test_skyeye_dashboard_view_tool_uses_custom_login_flow(): method, request_url, request_kwargs = fake_session.calls[2] assert method == "GET" - assert request_url == "https://skyeye.local/api/v1/monitor-center/dashboard/view" + assert request_url == "https://skyeye.local/skyeye/api/v1/monitor-center/dashboard/view" assert request_kwargs["params"]["name"] == "overall_view" assert request_kwargs["params"]["interval_time"] == 7 assert request_kwargs["params"]["csrf_token"] == "abcdef1234567890" @@ -142,11 +142,11 @@ async def test_skyeye_alarm_params_tool_can_use_secret_host_and_login_key(): method, auth_url, _ = fake_session.calls[0] assert method == "POST" - assert auth_url == "https://skyeye.internal:443/v1/admin/auth" + assert auth_url == "https://skyeye.internal:443/skyeye/v1/admin/auth" method, request_url, request_kwargs = fake_session.calls[2] assert method == "GET" - assert request_url == "https://skyeye.internal:443/v1/alarm/alarm/alarm-params" + assert request_url == "https://skyeye.internal:443/skyeye/v1/alarm/alarm/alarm-params" assert request_kwargs["params"]["data_source"] == 0 assert request_kwargs["params"]["csrf_token"] == "1234567890abcdef" @@ -202,7 +202,7 @@ async def test_skyeye_alarm_list_accepts_legacy_threat_level_alias(): assert result.success is True method, request_url, request_kwargs = fake_session.calls[2] assert method == "GET" - assert request_url == "https://skyeye.local/api/v1/alarm/alarm/list" + assert request_url == "https://skyeye.local/skyeye/api/v1/alarm/alarm/list" assert request_kwargs["params"]["hazard_level"] == "3" assert "threat_level" not in request_kwargs["params"] diff --git a/tests/tool/test_ssh_utils_pool.py b/tests/tool/test_ssh_utils_pool.py index 7937e8c0e..21507e3ae 100644 --- a/tests/tool/test_ssh_utils_pool.py +++ b/tests/tool/test_ssh_utils_pool.py @@ -14,10 +14,24 @@ def __init__(self, host: str) -> None: def close(self) -> None: self.closed = True + def is_closed(self) -> bool: + return self.closed + async def run(self, command: str, check: bool = False) -> SimpleNamespace: return SimpleNamespace(exit_status=0, stdout=f"{self.host}:{command}", stderr="") +class ChannelFailingConnection(DummyConnection): + def __init__(self, host: str, *, fail: bool = False) -> None: + super().__init__(host) + self.fail = fail + + async def run(self, command: str, check: bool = False) -> SimpleNamespace: + if self.fail: + raise ssh_utils.asyncssh.ChannelOpenError(2, "session channel closed") + return await super().run(command, check=check) + + @pytest.mark.asyncio async def test_ssh_pool_evicts_least_recent_idle_connection(monkeypatch: pytest.MonkeyPatch) -> None: created: list[DummyConnection] = [] @@ -79,6 +93,32 @@ async def fake_connect(**kwargs): assert pool.stats()["locks"] == 0 +@pytest.mark.asyncio +async def test_ssh_pool_reconnects_instead_of_reusing_closed_connection( + monkeypatch: pytest.MonkeyPatch, +) -> None: + created: list[DummyConnection] = [] + + async def fake_connect(**kwargs): + conn = DummyConnection(kwargs["host"]) + created.append(conn) + return conn + + monkeypatch.setattr(ssh_utils.asyncssh, "connect", fake_connect) + pool = ssh_utils.SSHConnectionPool(max_connections=10, idle_ttl_s=3600) + + first = await pool.get_connection("session", "host-1", 22, "root", None, None) + await pool.release_connection("session", "host-1", 22, "root") + first.close() + + second = await pool.get_connection("session", "host-1", 22, "root", None, None) + + assert second is not first + assert len(created) == 2 + assert pool.stats()["connections"] == 1 + assert pool.stats()["locks"] == 1 + + @pytest.mark.asyncio async def test_execute_ssh_command_releases_connection_after_run( monkeypatch: pytest.MonkeyPatch, @@ -106,3 +146,80 @@ async def fake_connect(**kwargs): assert stderr == "" assert pool.stats()["connections"] == 1 assert pool.stats()["active_connections"] == 0 + + +@pytest.mark.asyncio +async def test_execute_ssh_command_reconnects_after_channel_open_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + created: list[ChannelFailingConnection] = [] + + async def fake_connect(**kwargs): + conn = ChannelFailingConnection(kwargs["host"], fail=not created) + created.append(conn) + return conn + + pool = ssh_utils.SSHConnectionPool(max_connections=10, idle_ttl_s=3600) + monkeypatch.setattr(ssh_utils.asyncssh, "connect", fake_connect) + monkeypatch.setattr(ssh_utils, "_pool", pool) + + exit_code, stdout, stderr = await ssh_utils.execute_ssh_command( + host="host-1", + command="uptime", + username="root", + port=22, + key_path=None, + password=None, + timeout_s=5, + session_id="session", + ) + + assert exit_code == 0 + assert stdout == "host-1:uptime" + assert stderr == "" + assert len(created) == 2 + assert created[0].closed is True + assert created[1].closed is False + assert pool.stats()["connections"] == 1 + assert pool.stats()["active_connections"] == 0 + + +@pytest.mark.asyncio +async def test_execute_ssh_command_does_not_close_active_connection_on_channel_open_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + created: list[ChannelFailingConnection] = [] + + async def fake_connect(**kwargs): + conn = ChannelFailingConnection(kwargs["host"], fail=not created) + created.append(conn) + return conn + + pool = ssh_utils.SSHConnectionPool(max_connections=10, idle_ttl_s=3600) + monkeypatch.setattr(ssh_utils.asyncssh, "connect", fake_connect) + monkeypatch.setattr(ssh_utils, "_pool", pool) + + active_conn = await pool.get_connection("session", "host-1", 22, "root", None, None) + + exit_code, stdout, stderr = await ssh_utils.execute_ssh_command( + host="host-1", + command="uptime", + username="root", + port=22, + key_path=None, + password=None, + timeout_s=5, + session_id="session", + ) + + assert exit_code == 0 + assert stdout == "host-1:uptime" + assert stderr == "" + assert len(created) == 2 + assert active_conn.closed is False + assert created[1].closed is True + assert pool.stats()["connections"] == 1 + assert pool.stats()["active_connections"] == 1 + + await pool.release_connection("session", "host-1", 22, "root") + assert pool.stats()["active_connections"] == 0 diff --git a/tests/tool/test_tdp_skyeye_api_plugins.py b/tests/tool/test_tdp_skyeye_api_plugins.py index e5c75907a..0d6a60521 100644 --- a/tests/tool/test_tdp_skyeye_api_plugins.py +++ b/tests/tool/test_tdp_skyeye_api_plugins.py @@ -833,6 +833,15 @@ def test_skyeye_verify_ssl_defaults_false_when_unset(): assert module._verify_ssl({"verify_ssl": False}) is False +def test_skyeye_base_url_appends_skyeye_suffix_once(): + module = _load_module("test_skyeye_handler_base_url", _SKYEYE_HANDLER) + + assert module._ensure_skyeye_base_path("https://skyeye.local") == "https://skyeye.local/skyeye" + assert module._ensure_skyeye_base_path("https://skyeye.local/") == "https://skyeye.local/skyeye" + assert module._ensure_skyeye_base_path("https://skyeye.local/skyeye") == "https://skyeye.local/skyeye" + assert module._ensure_skyeye_base_path("https://skyeye.local/skyeye/") == "https://skyeye.local/skyeye" + + def test_tdp_resolve_verify_ssl_defaults_false_when_unset(): module = _load_module("test_tdp_handler_verify_ssl", _TDP_HANDLER) assert module._resolve_verify_ssl({}) is False diff --git a/tests/tool/test_watcher_atomic_save.py b/tests/tool/test_watcher_atomic_save.py index 913b6c84a..d3f1979ff 100644 --- a/tests/tool/test_watcher_atomic_save.py +++ b/tests/tool/test_watcher_atomic_save.py @@ -13,7 +13,12 @@ from types import SimpleNamespace -from flocks.tool.registry import ToolFileWatcher, _tool_event_should_reload +from flocks.tool.registry import ( + ToolFileWatcher, + ToolRegistry, + _tool_event_should_reload, + _tool_event_touches_device_plugin, +) from flocks.agent.registry import _agent_event_should_reload from flocks.skill.skill import _skill_event_should_reload @@ -64,6 +69,54 @@ def test_tool_watcher_includes_device_plugin_directory() -> None: assert "device" in ToolFileWatcher._WATCH_SUBDIRS +def test_tool_watcher_detects_device_plugin_events() -> None: + evt = _modify_event("/repo/.flocks/plugins/tools/device/tdp/tdp.yaml") + assert _tool_event_should_reload(evt) is True + assert _tool_event_touches_device_plugin(evt) is True + + +def test_tool_watcher_rejects_unwatched_tool_subdirs() -> None: + evt = _modify_event("/repo/.flocks/plugins/tools/mcp/demo/server.yaml") + assert _tool_event_should_reload(evt) is False + + +def test_tool_watcher_watches_plugin_root_to_catch_new_tool_dirs(monkeypatch, tmp_path) -> None: + from flocks.plugin import loader + + user_plugin_root = tmp_path / "home" / ".flocks" / "plugins" + user_plugin_root.mkdir(parents=True) + project_root = tmp_path / "missing-project" + project_root.mkdir() + monkeypatch.setattr(loader, "DEFAULT_PLUGIN_ROOT", user_plugin_root) + monkeypatch.chdir(project_root) + + assert ToolFileWatcher()._collect_watch_dirs() == {str(user_plugin_root)} + + +def test_tool_watcher_refresh_clears_device_caches(monkeypatch) -> None: + from flocks.config import api_versioning + from flocks.tool.device import plugin_index + + calls: list[str] = [] + monkeypatch.setattr(plugin_index, "clear_device_template_cache", lambda: calls.append("templates")) + monkeypatch.setattr( + api_versioning, + "discover_api_service_descriptors", + lambda *, refresh=False: calls.append(f"descriptors:{refresh}") or [], + ) + monkeypatch.setattr( + ToolRegistry, + "refresh_plugin_tools", + classmethod(lambda cls: calls.append("tools") or []), + ) + + watcher = ToolFileWatcher() + watcher._device_changed = True + watcher._run_refresh() + + assert calls == ["templates", "descriptors:True", "tools"] + + # --------------------------------------------------------------------------- # Agent watcher predicate # --------------------------------------------------------------------------- diff --git a/tests/updater/test_updater_console_manifest_bundle.py b/tests/updater/test_updater_console_manifest_bundle.py index 28dd076fe..518f3b862 100644 --- a/tests/updater/test_updater_console_manifest_bundle.py +++ b/tests/updater/test_updater_console_manifest_bundle.py @@ -1,6 +1,7 @@ from __future__ import annotations import zipfile +import json from datetime import UTC, datetime, timedelta from types import SimpleNamespace @@ -47,7 +48,7 @@ async def get(self, url, headers=None, follow_redirects=True): monkeypatch.setattr(updater.httpx, "AsyncClient", lambda timeout=15: _Client()) result = await updater._fetch_console_manifest_release() assert result == ( - "pro-v2026-5-10", + "v2026.5.10", "bundle release", "https://cdn.example.com/flockspro-bundle-v2026.5.10.tar.gz", None, @@ -59,7 +60,10 @@ async def get(self, url, headers=None, follow_redirects=True): @pytest.mark.asyncio -async def test_check_update_uses_pro_marker_and_component_version(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None: +async def test_check_update_uses_pro_marker_bundle_version_and_component_metadata( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +) -> None: marker = tmp_path / "run" / "pro-bundle-installed.json" marker.parent.mkdir(parents=True) marker.write_text( @@ -75,13 +79,16 @@ async def _fake_sources(_sources): async def _fake_manifest_info(): return updater.ConsoleManifestRelease( - version="pro-v2026-05-23", + version="v2026.5.23", release_notes="latest pro", release_url="https://cdn.example.com/flockspro-bundle-pro-v2026-05-23.zip", bundle_url="https://cdn.example.com/flockspro-bundle-pro-v2026-05-23.zip", bundle_sha256=None, bundle_format="zip", - manifest={"flockspro_component_version": "pro-v2026-05-23"}, + manifest={ + "display_version": "v2026.5.23", + "flockspro_component_version": "pro-v2026-05-23", + }, ) async def _fake_config(): @@ -92,15 +99,22 @@ async def _fake_config(): monkeypatch.setattr(updater, "_get_updater_config", _fake_config) monkeypatch.setattr(updater, "_resolve_sources_for_edition", _fake_sources) monkeypatch.setattr(updater, "_fetch_console_manifest_release_info", _fake_manifest_info) + monkeypatch.setattr(updater, "get_current_version", lambda: "2026.5.23") info = await updater.check_update() - assert info.current_version == "pro-v2026-05-23" - assert info.latest_version == "pro-v2026-05-23" + assert info.current_version == "v2026.5.23" + assert info.latest_version == "v2026.5.23" + assert info.current_core_version == "v2026.5.23" + assert info.latest_core_version == "v2026.5.23" + assert info.current_bundle_version == "v2026.5.23" + assert info.latest_bundle_version == "v2026.5.23" + assert info.current_pro_component_version == "pro-v2026-05-23" + assert info.latest_pro_component_version == "pro-v2026-05-23" assert info.has_update is False @pytest.mark.asyncio -async def test_check_update_force_console_manifest_uses_pro_versions(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None: +async def test_check_update_force_console_manifest_uses_bundle_versions(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None: marker = tmp_path / "run" / "pro-bundle-installed.json" marker.parent.mkdir(parents=True) marker.write_text( @@ -116,29 +130,172 @@ async def _fake_config(): async def _fake_manifest_info(): return updater.ConsoleManifestRelease( - version="pro-v2026-05-24", + version="v2026.5.24", release_notes="latest pro", release_url="https://console.example.com/v1/pro-bundles/rel_1/download", bundle_url="https://console.example.com/v1/pro-bundles/rel_1/download", bundle_sha256="abc123", bundle_format="zip", - manifest={"flockspro_component_version": "pro-v2026-05-24"}, + manifest={ + "display_version": "v2026.5.24", + "core_version": "v2026.5.23", + "flockspro_component_version": "pro-v2026-05-24", + }, ) monkeypatch.setenv("FLOCKS_ROOT", str(tmp_path)) monkeypatch.setattr("flocks.updater.deploy.detect_deploy_mode", lambda: "source") monkeypatch.setattr(updater, "_get_updater_config", _fake_config) monkeypatch.setattr(updater, "_fetch_console_manifest_release_info", _fake_manifest_info) + monkeypatch.setattr(updater, "get_current_version", lambda: "2026.5.23") info = await updater.check_update(force_console_manifest=True) assert info.edition == "flockspro" - assert info.current_version == "pro-v2026-05-23" - assert info.latest_version == "pro-v2026-05-24" + assert info.current_version == "v2026.5.23" + assert info.latest_version == "v2026.5.24" + assert info.current_core_version == "v2026.5.23" + assert info.latest_core_version == "v2026.5.23" + assert info.current_bundle_version == "v2026.5.23" + assert info.latest_bundle_version == "v2026.5.24" + assert info.current_pro_component_version == "pro-v2026-05-23" + assert info.latest_pro_component_version == "pro-v2026-05-24" assert info.bundle_sha256 == "abc123" assert info.has_update is True +@pytest.mark.asyncio +async def test_check_update_force_console_manifest_detects_component_only_update( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +) -> None: + marker = tmp_path / "run" / "pro-bundle-installed.json" + marker.parent.mkdir(parents=True) + marker.write_text( + """{ + "installed_version": "v2026.6.18", + "flockspro_component_version": "v2026.6.1" +}""", + encoding="utf-8", + ) + + async def _fake_config(): + return SimpleNamespace(enabled=True, sources=["github"], repo="", token=None) + + async def _fake_manifest_info(): + return updater.ConsoleManifestRelease( + version="v2026.6.18", + release_notes="latest pro", + release_url="https://console.example.com/v1/pro-bundles/rel_2/download", + bundle_url="https://console.example.com/v1/pro-bundles/rel_2/download", + bundle_sha256="def456", + bundle_format="zip", + manifest={ + "display_version": "v2026.6.18", + "oss_version": "v2026.6.18", + "flockspro_component_version": "v2026.6.2", + }, + ) + + monkeypatch.setenv("FLOCKS_ROOT", str(tmp_path)) + monkeypatch.setattr("flocks.updater.deploy.detect_deploy_mode", lambda: "source") + monkeypatch.setattr(updater, "_get_updater_config", _fake_config) + monkeypatch.setattr(updater, "_fetch_console_manifest_release_info", _fake_manifest_info) + + info = await updater.check_update(force_console_manifest=True) + + assert info.current_version == "v2026.6.18" + assert info.latest_version == "v2026.6.18" + assert info.current_pro_component_version == "v2026.6.1" + assert info.latest_pro_component_version == "v2026.6.2" + assert info.has_update is True + + +@pytest.mark.asyncio +async def test_check_update_force_console_manifest_reports_stale_product_marker_as_bundle_update( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +) -> None: + marker = tmp_path / "run" / "pro-bundle-installed.json" + marker.parent.mkdir(parents=True) + marker.write_text( + """{ + "installed_version": "v2026.6.22", + "core_version": "v2026.6.21", + "flockspro_component_version": "v2026.6.23" +}""", + encoding="utf-8", + ) + + async def _fake_config(): + return SimpleNamespace(enabled=True, sources=["github"], repo="", token=None) + + async def _fake_manifest_info(): + return updater.ConsoleManifestRelease( + version="v2026.6.23", + release_notes="latest pro", + release_url="https://console.example.com/v1/pro-bundles/rel_3/download", + bundle_url="https://console.example.com/v1/pro-bundles/rel_3/download", + bundle_sha256="ghi789", + bundle_format="zip", + manifest={ + "display_version": "v2026.6.23", + "core_version": "v2026.6.21", + "flockspro_component_version": "v2026.6.23", + }, + ) + + monkeypatch.setenv("FLOCKS_ROOT", str(tmp_path)) + monkeypatch.setattr("flocks.updater.deploy.detect_deploy_mode", lambda: "source") + monkeypatch.setattr(updater, "_get_updater_config", _fake_config) + monkeypatch.setattr(updater, "_fetch_console_manifest_release_info", _fake_manifest_info) + monkeypatch.setattr(updater, "get_current_version", lambda: "2026.6.21") + + info = await updater.check_update(force_console_manifest=True) + + assert info.current_version == "v2026.6.22" + assert info.latest_version == "v2026.6.23" + assert info.current_bundle_version == "v2026.6.22" + assert info.latest_bundle_version == "v2026.6.23" + assert info.current_core_version == "v2026.6.21" + assert info.latest_core_version == "v2026.6.21" + assert info.current_pro_component_version == "v2026.6.23" + assert info.latest_pro_component_version == "v2026.6.23" + assert info.has_update is True + + +def test_console_manifest_release_identity_writes_product_and_core_versions( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +) -> None: + monkeypatch.setenv("FLOCKS_ROOT", str(tmp_path)) + merged = updater._merge_console_manifest_release_identity( + { + "display_version": "v2026.6.21", + "core_version": "v2026.6.21", + "flockspro_component_version": "v2026.6.23", + }, + { + "release_id": "rel_623", + "display_version": "v2026.6.23", + "core_version": "v2026.6.21", + "flockspro_component_version": "v2026.6.23", + "build_id": "job_623", + }, + ) + + assert merged["display_version"] == "v2026.6.23" + assert merged["core_version"] == "v2026.6.21" + updater._write_pro_bundle_install_marker(merged, bundle_sha256="sha623") + + marker = json.loads((tmp_path / "run" / "pro-bundle-installed.json").read_text(encoding="utf-8")) + assert marker["installed_version"] == "v2026.6.23" + assert marker["core_version"] == "v2026.6.21" + assert marker["oss_version"] == "v2026.6.21" + assert marker["flockspro_component_version"] == "v2026.6.23" + assert marker["build_id"] == "job_623" + + @pytest.mark.asyncio async def test_load_console_session_token_falls_back_to_shared_session( monkeypatch: pytest.MonkeyPatch, @@ -442,6 +599,7 @@ async def test_perform_pro_bundle_install_replaces_core_and_installs_wheel( (venv_bin / "python").write_text("#!/usr/bin/env python\n", encoding="utf-8") monkeypatch.setenv("FLOCKS_ROOT", str(tmp_path / "flocks-root")) monkeypatch.setattr(updater, "_get_repo_root", lambda: install_root) + monkeypatch.setattr(updater, "get_current_version", lambda: "2026.5.10") monkeypatch.setattr(updater, "_fetch_console_manifest_release_info", lambda: _async_manifest_info(bundle)) monkeypatch.setattr(updater, "_download_console_bundle", lambda *_args, **_kwargs: _async_path(bundle)) monkeypatch.setattr(updater, "_verify_download_sha256", lambda *_args, **_kwargs: None) @@ -474,6 +632,99 @@ async def _fake_run_async(cmd, **_kwargs): assert marker_payload["oss_version"] == "v2026.5.10" +@pytest.mark.asyncio +async def test_perform_pro_bundle_install_keeps_newer_local_core_when_bundle_oss_is_older( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +) -> None: + bundle_root = tmp_path / "bundle-root" + core_root = bundle_root / "flocks" + core_root.mkdir(parents=True) + (core_root / "pyproject.toml").write_text('[project]\nname = "flocks"\n', encoding="utf-8") + (core_root / "older_core.py").write_text("OLDER = True\n", encoding="utf-8") + wheels = bundle_root / "wheels" + wheels.mkdir() + wheel = wheels / "flockspro-0.2.0-py3-none-any.whl" + wheel.write_bytes(b"fake-wheel") + (bundle_root / "manifest.json").write_text( + """{ + "display_version": "v2026.6.13", + "oss_version": "v2026.6.13", + "flockspro_component_version": "v2026.6.2", + "flockspro_wheel": "wheels/flockspro-0.2.0-py3-none-any.whl", + "build_id": "job_new_pro_old_core" +}""", + encoding="utf-8", + ) + bundle = tmp_path / "flockspro-bundle.zip" + with zipfile.ZipFile(bundle, "w") as archive: + for path in bundle_root.rglob("*"): + if path.is_file(): + archive.write(path, path.relative_to(bundle_root).as_posix()) + + install_root = tmp_path / "install" + install_root.mkdir() + (install_root / "current_core.py").write_text("CURRENT = True\n", encoding="utf-8") + venv_bin = install_root / ".venv" / "bin" + venv_bin.mkdir(parents=True) + (venv_bin / "python").write_text("#!/usr/bin/env python\n", encoding="utf-8") + + monkeypatch.setenv("FLOCKS_ROOT", str(tmp_path / "flocks-root")) + monkeypatch.setattr(updater, "_get_repo_root", lambda: install_root) + monkeypatch.setattr(updater, "get_current_version", lambda: "2026.6.18") + + async def _fake_manifest_info(): + return updater.ConsoleManifestRelease( + version="v2026.6.13", + release_notes="new Pro on older core", + release_url=str(bundle), + bundle_url=str(bundle), + bundle_sha256=None, + bundle_format="zip", + manifest={ + "release_id": "rel_new_pro_old_core", + "display_version": "v2026.6.13", + "oss_version": "v2026.6.13", + "flockspro_component_version": "v2026.6.2", + "build_id": "job_new_pro_old_core", + }, + ) + + monkeypatch.setattr(updater, "_fetch_console_manifest_release_info", _fake_manifest_info) + monkeypatch.setattr(updater, "_download_console_bundle", lambda *_args, **_kwargs: _async_path(bundle)) + monkeypatch.setattr(updater, "_verify_download_sha256", lambda *_args, **_kwargs: None) + monkeypatch.setattr(updater, "_find_executable", lambda name: "/usr/bin/uv" if name == "uv" else None) + monkeypatch.setattr(updater, "_backup_current_version", lambda *_args, **_kwargs: tmp_path / "backup.tar.gz") + monkeypatch.setattr(updater, "_write_version_marker", lambda *_args, **_kwargs: None) + monkeypatch.setattr(updater, "_refresh_global_cli_entry", lambda *_args, **_kwargs: None) + + captured: list[list[str]] = [] + + async def _fake_run_async(cmd, **_kwargs): + captured.append(cmd) + return 0, "", "" + + monkeypatch.setattr(updater, "_run_async", _fake_run_async) + + progresses = [step async for step in updater.perform_pro_bundle_install(restart=False)] + + assert progresses[-1].stage == "done" + assert any("Keeping local Flocks v2026.6.18" in step.message for step in progresses) + assert (install_root / "current_core.py").is_file() + assert not (install_root / "older_core.py").exists() + pip_installs = [cmd for cmd in captured if cmd[:3] == ["/usr/bin/uv", "pip", "install"]] + assert pip_installs + assert str(wheel.name) in pip_installs[-1][-1] + marker = tmp_path / "flocks-root" / "run" / "pro-bundle-installed.json" + marker_payload = __import__("json").loads(marker.read_text(encoding="utf-8")) + assert marker_payload["release_id"] == "rel_new_pro_old_core" + assert marker_payload["display_version"] == "v2026.6.13" + assert marker_payload["installed_version"] == "v2026.6.13" + assert marker_payload["core_version"] == "v2026.6.18" + assert marker_payload["oss_version"] == "v2026.6.18" + assert marker_payload["flockspro_component_version"] == "v2026.6.2" + + @pytest.mark.asyncio async def test_perform_pro_bundle_install_schedules_restart_before_stream_can_close( monkeypatch: pytest.MonkeyPatch, @@ -509,6 +760,7 @@ async def test_perform_pro_bundle_install_schedules_restart_before_stream_can_cl (venv_bin / "python").write_text("#!/usr/bin/env python\n", encoding="utf-8") monkeypatch.setenv("FLOCKS_ROOT", str(tmp_path / "flocks-root")) monkeypatch.setattr(updater, "_get_repo_root", lambda: install_root) + monkeypatch.setattr(updater, "get_current_version", lambda: "2026.5.10") monkeypatch.setattr(updater, "_fetch_console_manifest_release_info", lambda: _async_manifest_info(bundle)) monkeypatch.setattr(updater, "_download_console_bundle", lambda *_args, **_kwargs: _async_path(bundle)) monkeypatch.setattr(updater, "_verify_download_sha256", lambda *_args, **_kwargs: None) @@ -550,4 +802,3 @@ async def _async_manifest_info(bundle): async def _async_path(path): return path - diff --git a/tests/updater/test_updater_edition_sources.py b/tests/updater/test_updater_edition_sources.py index b93706b72..2798ed541 100644 --- a/tests/updater/test_updater_edition_sources.py +++ b/tests/updater/test_updater_edition_sources.py @@ -51,6 +51,24 @@ async def test_console_session_does_not_change_oss_sources(monkeypatch, tmp_path assert sources == ["github", "gitee"] +@pytest.mark.asyncio +async def test_console_manifest_only_config_is_filtered_from_oss_sources(monkeypatch, tmp_path): + monkeypatch.setenv("FLOCKS_ROOT", str(tmp_path)) + + sources = await _resolve_sources_for_edition(["console-manifest"]) + + assert sources == [] + + +@pytest.mark.asyncio +async def test_console_manifest_mixed_config_is_filtered_from_oss_sources(monkeypatch, tmp_path): + monkeypatch.setenv("FLOCKS_ROOT", str(tmp_path)) + + sources = await _resolve_sources_for_edition(["github", "console-manifest", "gitee"]) + + assert sources == ["github", "gitee"] + + def test_flockspro_license_active_uses_runtime_capability(monkeypatch): monkeypatch.setattr(updater.importlib.util, "find_spec", lambda name: object() if name == "flockspro" else None) diff --git a/tests/user_defined_pages/test_store.py b/tests/user_defined_pages/test_store.py deleted file mode 100644 index 4952598d9..000000000 --- a/tests/user_defined_pages/test_store.py +++ /dev/null @@ -1,74 +0,0 @@ -import json - -import pytest - -from flocks.user_defined_pages.store import UserDefinedPagesStore - - -@pytest.fixture -def store(tmp_path, monkeypatch): - root = tmp_path / "user_defined_pages" - monkeypatch.setenv("FLOCKS_USER_DEFINED_PAGES_ROOT", str(root)) - return UserDefinedPagesStore() - - -def test_create_page_scaffold(store: UserDefinedPagesStore): - detail = store.create_page(page_id="my-dashboard", title="我的大屏") - assert detail.manifest.id == "my-dashboard" - assert detail.manifest.route == "/user-defined-pages/my-dashboard" - assert (store.page_dir("my-dashboard") / "src" / "Page.tsx").is_file() - assert (store.page_dir("my-dashboard") / "manifest.json").is_file() - - -def test_list_pages_enabled_only(store: UserDefinedPagesStore): - store.create_page(page_id="enabled-page", title="启用页") - disabled = store.create_page(page_id="disabled-page", title="禁用页") - store.save_manifest("disabled-page", {**disabled.manifest.model_dump(), "enabled": False}) - - all_pages = store.list_pages(enabled_only=False) - enabled_pages = store.list_pages(enabled_only=True) - - assert {page.id for page in all_pages} == {"enabled-page", "disabled-page"} - assert [page.id for page in enabled_pages] == ["enabled-page"] - - -def test_reject_path_traversal_on_write(store: UserDefinedPagesStore): - store.create_page(page_id="safe-page", title="安全页") - with pytest.raises(ValueError, match="writes are not allowed"): - store.save_source_file("safe-page", "../escape.tsx", "bad") - - -def test_allow_page_api_source_files(store: UserDefinedPagesStore): - store.create_page(page_id="api-page", title="API 页") - store.save_source_file("api-page", "api/routes.yaml", "routes: []\n") - store.save_source_file("api-page", "api/handlers.py", "def ping(ctx, request):\n return {'ok': True}\n") - assert store.read_source_file("api-page", "api/routes.yaml").startswith("routes:") - detail = store.get_page("api-page") - assert "api/routes.yaml" in detail.sourceFiles - assert "api/handlers.py" in detail.sourceFiles - - -def test_reject_unsupported_api_extension(store: UserDefinedPagesStore): - store.create_page(page_id="api-ext-page", title="API 后缀页") - with pytest.raises(ValueError, match="unsupported source file type"): - store.save_source_file("api-ext-page", "api/secret.txt", "nope") - - -def test_reject_invalid_page_id(store: UserDefinedPagesStore): - with pytest.raises(ValueError, match="invalid page id"): - store.validate_page_id("../bad") - - -def test_asset_path_stays_inside_assets_dir(store: UserDefinedPagesStore): - store.create_page(page_id="asset-page", title="资源页") - with pytest.raises(ValueError, match="path traversal is not allowed"): - store.asset_path("asset-page", "../manifest.json") - - -def test_manifest_roundtrip(store: UserDefinedPagesStore): - store.create_page(page_id="roundtrip", title="原始标题") - manifest = store.save_manifest("roundtrip", {"title": "新标题", "order": 10}) - assert manifest.title == "新标题" - assert manifest.order == 10 - raw = json.loads((store.page_dir("roundtrip") / "manifest.json").read_text(encoding="utf-8")) - assert raw["route"] == "/user-defined-pages/roundtrip" diff --git a/tests/user_defined_pages/test_watcher.py b/tests/user_defined_pages/test_watcher.py deleted file mode 100644 index 535ef23b2..000000000 --- a/tests/user_defined_pages/test_watcher.py +++ /dev/null @@ -1,36 +0,0 @@ -from flocks.user_defined_pages import watcher as watcher_module -from flocks.user_defined_pages.watcher import UserDefinedPagesWatcher, _PendingAction - - -class _RuntimeStub: - async def reload_page(self, _page_id: str): - return [{"method": "GET", "path": "/stats", "handler": "handlers.stats"}] - - -class _BuilderStub: - def build(self, _page_id: str): - raise AssertionError("build should not be called for api-only change") - - -def test_watcher_api_change_uses_main_loop_bridge(monkeypatch): - emitted: list[tuple[str, dict]] = [] - bridge_calls: list[str] = [] - - def _bridge(coro, *, timeout_seconds=5.0): - bridge_calls.append("called") - coro.close() - return [{"method": "GET", "path": "/stats", "handler": "handlers.stats"}] - - def _emit(event_type: str, properties: dict): - emitted.append((event_type, properties)) - - monkeypatch.setattr(watcher_module, "_run_on_main_loop_sync", _bridge) - monkeypatch.setattr(watcher_module, "_publish_event_sync", _emit) - - watcher = UserDefinedPagesWatcher(builder=_BuilderStub(), api_runtime=_RuntimeStub()) - watcher._pending_pages["demo-page"] = _PendingAction(api_changed=True) - watcher._run_pending_builds() - - assert bridge_calls == ["called"] - assert emitted[0][0] == "user_defined_pages.api_changed" - assert emitted[0][1]["id"] == "demo-page" diff --git a/tests/workflow/test_execution_store_compact.py b/tests/workflow/test_execution_store_compact.py index 9e95f2813..21a268e4a 100644 --- a/tests/workflow/test_execution_store_compact.py +++ b/tests/workflow/test_execution_store_compact.py @@ -22,6 +22,7 @@ import pytest from flocks.workflow.execution_store import ( DEFAULT_COMPACT_SIZE_THRESHOLD, + DEFAULT_GENERIC_SEQUENCE_THRESHOLD, DEFAULT_LARGE_LIST_KEYS, _trim_execution_history, compact_history_for_storage, @@ -29,11 +30,9 @@ compact_outputs_for_storage, compact_step_for_storage, record_execution_result, - workflow_execution_index_key, - workflow_execution_step_prefix, workflow_execution_step_key, ) -from flocks.storage.storage import Storage +from flocks.workflow.store import WorkflowStore def _make_alerts(n: int) -> List[Dict[str, Any]]: @@ -77,15 +76,26 @@ def test_compact_outputs_keeps_small_lists_verbatim() -> None: assert "_enriched_alerts_count" not in compacted -def test_compact_outputs_ignores_unknown_keys() -> None: - big_unknown = _make_alerts(5_000) +def test_compact_outputs_summarizes_unknown_large_sequences() -> None: + big_unknown = _make_alerts(DEFAULT_GENERIC_SEQUENCE_THRESHOLD + 1) outputs = {"some_other_alerts": big_unknown} compacted = compact_outputs_for_storage(outputs) - # Unknown keys are not in the default large-list set; they must pass - # through even if huge, so callers don't get surprising drops. - assert compacted["some_other_alerts"] is big_unknown + assert compacted["some_other_alerts"]["_type"] == "list" + assert compacted["some_other_alerts"]["count"] == DEFAULT_GENERIC_SEQUENCE_THRESHOLD + 1 + assert len(compacted["some_other_alerts"]["preview"]) == 3 + assert compacted["some_other_alerts"] is not big_unknown + + +def test_compact_outputs_summarizes_large_strings() -> None: + outputs = {"huge_text": "x" * 25_000} + + compacted = compact_outputs_for_storage(outputs) + + assert compacted["huge_text"]["_type"] == "string" + assert compacted["huge_text"]["chars"] == 25_000 + assert len(compacted["huge_text"]["preview"]) < 25_000 def test_compact_outputs_accepts_custom_keys_and_threshold() -> None: @@ -157,7 +167,7 @@ def test_compact_outputs_drastically_reduces_serialised_size() -> None: after = len(json.dumps(compact_outputs_for_storage(outputs)).encode()) assert before > 1_000_000 # ≥ 1 MB before - assert after < 1_000 # < 1 KB after + assert after < 1_000 # < 1 KB after assert before / after > 1_000 @@ -267,15 +277,13 @@ def test_compact_execution_summary_drops_execution_log() -> None: def test_workflow_execution_step_key_is_append_only_namespaced() -> None: - assert ( - workflow_execution_step_key("exec-1", 12) - == "workflow_execution_step/exec-1/00000012" - ) + assert workflow_execution_step_key("exec-1", 12) == "workflow_execution_step/exec-1/00000012" @pytest.mark.asyncio async def test_record_execution_result_backfills_execution_log_steps() -> None: - storage_write = AsyncMock(return_value=None) + record_step = AsyncMock(return_value=None) + upsert_execution = AsyncMock(return_value=None) update_stats = AsyncMock(return_value=None) exec_data = { "id": "exec-1", @@ -292,22 +300,25 @@ def raise_create_task(coro, *args, **kwargs): # noqa: ANN001, ARG001 coro.close() raise RuntimeError - with patch.object(Storage, "write", storage_write), \ - patch("flocks.workflow.execution_store._update_workflow_stats", update_stats), \ - patch("flocks.session.recorder.Recorder.record_workflow_execution", AsyncMock(return_value=None)), \ - patch("flocks.workflow.execution_store.asyncio.create_task", side_effect=raise_create_task), \ - patch("flocks.workflow.execution_store._trim_execution_history", AsyncMock(return_value=None)): + with ( + patch.object(WorkflowStore, "record_step", record_step), + patch.object(WorkflowStore, "upsert_execution", upsert_execution), + patch("flocks.workflow.execution_store._update_workflow_stats", update_stats), + patch("flocks.session.recorder.Recorder.record_workflow_execution", AsyncMock(return_value=None)), + patch("flocks.workflow.execution_store.asyncio.create_task", side_effect=raise_create_task), + patch("flocks.workflow.execution_store._trim_execution_history", AsyncMock(return_value=None)), + ): await record_execution_result("wf", "exec-1", exec_data) - write_calls = storage_write.await_args_list - assert write_calls[0].args[0] == "workflow_execution_step/exec-1/00000001" - assert write_calls[0].args[1]["outputs"] == {"_raw_alerts_count": 150} - assert write_calls[1].args[0] == "workflow_execution_step/exec-1/00000002" - assert write_calls[1].args[1]["inputs"] == {"_filtered_alerts_count": 150} - assert write_calls[2].args[0] == "workflow_execution/exec-1" - assert write_calls[2].args[1]["executionLog"] == [] - assert write_calls[2].args[1]["stepCount"] == 2 - assert write_calls[3].args[0].startswith("workflow_execution_index/wf/") + step_calls = record_step.await_args_list + assert step_calls[0].args[:2] == ("exec-1", 1) + assert step_calls[0].args[2]["outputs"] == {"_raw_alerts_count": 150} + assert step_calls[1].args[:2] == ("exec-1", 2) + assert step_calls[1].args[2]["inputs"] == {"_filtered_alerts_count": 150} + upsert_execution.assert_awaited_once() + summary = upsert_execution.await_args.args[0] + assert summary["executionLog"] == [] + assert summary["stepCount"] == 2 def test_compact_history_compacts_each_step_inputs() -> None: @@ -398,19 +409,8 @@ async def test_trim_execution_history_keeps_only_30_and_deletes_matching_jsonl( tmp_path, ) -> None: workflow_id = "wf-trim" - indexed_rows = [] for idx in range(32): exec_id = f"exec-{idx:02d}" - index_key = workflow_execution_index_key(workflow_id, idx, exec_id) - indexed_rows.append(( - index_key, - { - "workflowId": workflow_id, - "execId": exec_id, - "executionKey": f"workflow_execution/{exec_id}", - "startedAt": idx, - }, - )) workflow_record = tmp_path / "workflow" / f"{exec_id}.jsonl" workflow_record.parent.mkdir(parents=True, exist_ok=True) workflow_record.write_text('{"type":"workflow.summary"}\n', encoding="utf-8") @@ -421,28 +421,15 @@ async def test_trim_execution_history_keeps_only_30_and_deletes_matching_jsonl( other_record.parent.mkdir(parents=True, exist_ok=True) other_record.write_text('{"type":"workflow.summary"}\n', encoding="utf-8") - remove_mock = AsyncMock(return_value=None) - clear_mock = AsyncMock(return_value=2) - list_raw_mock = AsyncMock(return_value=[]) + trim_mock = AsyncMock(return_value=["exec-00", "exec-01"]) - with patch.object(Storage, "list_entries", AsyncMock(return_value=indexed_rows)), \ - patch.object(Storage, "list_raw", list_raw_mock), \ - patch.object(Storage, "clear", clear_mock), \ - patch.object(Storage, "remove", remove_mock), \ - patch("flocks.session.recorder._record_dir", return_value=tmp_path): + with ( + patch.object(WorkflowStore, "trim_executions", trim_mock), + patch("flocks.session.recorder._record_dir", return_value=tmp_path), + ): await _trim_execution_history(workflow_id) - list_raw_mock.assert_not_awaited() - removed_keys = [call.args[0] for call in remove_mock.await_args_list] - assert "workflow_execution/exec-00" in removed_keys - assert "workflow_execution/exec-01" in removed_keys - assert workflow_execution_index_key(workflow_id, 0, "exec-00") in removed_keys - assert workflow_execution_index_key(workflow_id, 1, "exec-01") in removed_keys - cleared_prefixes = [call.args[0] for call in clear_mock.await_args_list] - assert cleared_prefixes == [ - workflow_execution_step_prefix("exec-00"), - workflow_execution_step_prefix("exec-01"), - ] + trim_mock.assert_awaited_once_with(workflow_id, keep=30) assert not (tmp_path / "workflow" / "exec-00.jsonl").exists() assert not (tmp_path / "workflow" / "exec-01.jsonl").exists() assert (tmp_path / "workflow" / "exec-02.jsonl").exists() @@ -452,64 +439,28 @@ async def test_trim_execution_history_keeps_only_30_and_deletes_matching_jsonl( @pytest.mark.asyncio async def test_trim_execution_history_uses_index_without_full_scan(tmp_path) -> None: workflow_id = "wf-indexed" - indexed_rows = [] for idx in range(32): exec_id = f"exec-{idx:02d}" - index_key = workflow_execution_index_key(workflow_id, idx, exec_id) - indexed_rows.append(( - index_key, - { - "workflowId": workflow_id, - "execId": exec_id, - "executionKey": f"workflow_execution/{exec_id}", - "startedAt": idx, - }, - )) workflow_record = tmp_path / "workflow" / f"{exec_id}.jsonl" workflow_record.parent.mkdir(parents=True, exist_ok=True) workflow_record.write_text('{"type":"workflow.summary"}\n', encoding="utf-8") - list_raw_mock = AsyncMock(return_value=[]) - clear_mock = AsyncMock(return_value=1) - remove_mock = AsyncMock(return_value=True) + trim_mock = AsyncMock(return_value=["exec-00", "exec-01"]) - with patch.object(Storage, "list_entries", AsyncMock(return_value=indexed_rows)), \ - patch.object(Storage, "list_raw", list_raw_mock), \ - patch.object(Storage, "clear", clear_mock), \ - patch.object(Storage, "remove", remove_mock), \ - patch("flocks.session.recorder._record_dir", return_value=tmp_path): + with ( + patch.object(WorkflowStore, "trim_executions", trim_mock), + patch("flocks.session.recorder._record_dir", return_value=tmp_path), + ): await _trim_execution_history(workflow_id) - list_raw_mock.assert_not_awaited() - assert [call.args[0] for call in clear_mock.await_args_list] == [ - workflow_execution_step_prefix("exec-00"), - workflow_execution_step_prefix("exec-01"), - ] - removed = [call.args[0] for call in remove_mock.await_args_list] - assert "workflow_execution/exec-00" in removed - assert workflow_execution_index_key(workflow_id, 0, "exec-00") in removed + trim_mock.assert_awaited_once_with(workflow_id, keep=30) + assert not (tmp_path / "workflow" / "exec-00.jsonl").exists() + assert not (tmp_path / "workflow" / "exec-01.jsonl").exists() @pytest.mark.asyncio async def test_trim_execution_history_surfaces_delete_failures() -> None: workflow_id = "wf-trim-fail" - indexed_rows = [ - ( - workflow_execution_index_key(workflow_id, idx, f"exec-{idx:02d}"), - { - "workflowId": workflow_id, - "execId": f"exec-{idx:02d}", - "executionKey": f"workflow_execution/exec-{idx:02d}", - "startedAt": idx, - }, - ) - for idx in range(31) - ] - list_raw_mock = AsyncMock(return_value=[]) - - with patch.object(Storage, "list_entries", AsyncMock(return_value=indexed_rows)), \ - patch.object(Storage, "list_raw", list_raw_mock), \ - patch.object(Storage, "clear", AsyncMock(side_effect=RuntimeError("locked"))): - with pytest.raises(RuntimeError, match="Failed to trim workflow execution history"): + with patch.object(WorkflowStore, "trim_executions", AsyncMock(side_effect=RuntimeError("locked"))): + with pytest.raises(RuntimeError, match="locked"): await _trim_execution_history(workflow_id) - list_raw_mock.assert_not_awaited() diff --git a/tests/workflow/test_poller_manager.py b/tests/workflow/test_poller_manager.py index 63534e535..4b7f915df 100644 --- a/tests/workflow/test_poller_manager.py +++ b/tests/workflow/test_poller_manager.py @@ -16,10 +16,10 @@ async def test_restart_disabled_config_reports_stopped(monkeypatch: pytest.MonkeyPatch) -> None: manager = poller_manager.WorkflowPollerManager() - async def _fake_read(_key: str) -> dict[str, Any]: + async def _fake_get_config(_workflow_id: str, *, kind: str) -> dict[str, Any]: return {"enabled": False} - monkeypatch.setattr(poller_manager.Storage, "read", _fake_read) + monkeypatch.setattr(poller_manager.WorkflowStore, "get_config", _fake_get_config) status = await manager.restart_workflow("wf-disabled") assert status["state"] == "stopped" @@ -30,10 +30,10 @@ async def _fake_read(_key: str) -> dict[str, Any]: async def test_restart_missing_workflow_reports_failed(monkeypatch: pytest.MonkeyPatch) -> None: manager = poller_manager.WorkflowPollerManager() - async def _fake_read(_key: str) -> dict[str, Any]: + async def _fake_get_config(_workflow_id: str, *, kind: str) -> dict[str, Any]: return {"enabled": True, "intervalSeconds": 30} - monkeypatch.setattr(poller_manager.Storage, "read", _fake_read) + monkeypatch.setattr(poller_manager.WorkflowStore, "get_config", _fake_get_config) monkeypatch.setattr(poller_manager, "read_workflow_from_fs", lambda _workflow_id: None) status = await manager.restart_workflow("wf-missing") @@ -46,7 +46,7 @@ async def test_run_once_injects_dynamic_inputs_and_summary(monkeypatch: pytest.M manager = poller_manager.WorkflowPollerManager() captured_inputs: dict[str, Any] = {} - async def _fake_read(_key: str) -> dict[str, Any]: + async def _fake_get_config(_workflow_id: str, *, kind: str) -> dict[str, Any]: return { "enabled": False, "timeoutSeconds": 9, @@ -61,11 +61,16 @@ def _fake_run_workflow( # noqa: ANN001 trace: bool, cancel, on_step_complete, + run_id: str, + execution_profile: str, ): captured_inputs.update(inputs) - assert workflow == {"start": "n1", "nodes": [], "edges": []} + assert workflow["start"] == "n1" + assert workflow["nodes"][0]["id"] == "n1" assert timeout_s == 9 assert trace is False + assert run_id == "exec-wf-run-once" + assert execution_profile == "high_frequency" assert cancel() is False return RunWorkflowResult( status="success", @@ -78,25 +83,34 @@ def _fake_run_workflow( # noqa: ANN001 }, ) - monkeypatch.setattr(poller_manager.Storage, "read", _fake_read) + monkeypatch.setattr(poller_manager.WorkflowStore, "get_config", _fake_get_config) monkeypatch.setattr( poller_manager, "read_workflow_from_fs", - lambda _workflow_id: {"workflowJson": {"start": "n1", "nodes": [], "edges": []}}, + lambda _workflow_id: { + "workflowJson": { + "start": "n1", + "nodes": [{"id": "n1", "type": "python", "code": "outputs['ok'] = True"}], + "edges": [], + } + }, ) monkeypatch.setattr( poller_manager, "create_execution_record", - lambda workflow_id, *, input_params=None, exec_id=None: asyncio.sleep(0, result={ - "id": exec_id or f"exec-{workflow_id}", - "workflowId": workflow_id, - "inputParams": input_params or {}, - "status": "running", - "startedAt": 111, - "executionLog": [], - "currentPhase": "queued", - "currentStepIndex": 0, - }), + lambda workflow_id, *, input_params=None, exec_id=None: asyncio.sleep( + 0, + result={ + "id": exec_id or f"exec-{workflow_id}", + "workflowId": workflow_id, + "inputParams": input_params or {}, + "status": "running", + "startedAt": 111, + "executionLog": [], + "currentPhase": "queued", + "currentStepIndex": 0, + }, + ), ) monkeypatch.setattr( poller_manager, @@ -127,7 +141,7 @@ async def test_run_once_records_execution_and_normalizes_business_failure( recorded_results: list[dict[str, Any]] = [] recorded_steps: list[tuple[str, int, dict[str, Any]]] = [] - async def _fake_read(_key: str) -> dict[str, Any]: + async def _fake_get_config(_workflow_id: str, *, kind: str) -> dict[str, Any]: return { "enabled": False, "timeoutSeconds": 9, @@ -177,10 +191,15 @@ def _fake_run_workflow( # noqa: ANN001 trace: bool, cancel, on_step_complete, + run_id: str, + execution_profile: str, ): - assert workflow == {"start": "n1", "nodes": [], "edges": []} + assert workflow["start"] == "n1" + assert workflow["nodes"][0]["id"] == "n1" assert timeout_s == 9 assert trace is False + assert run_id == "exec-1" + assert execution_profile == "high_frequency" assert cancel() is False assert inputs["dedup_source_workflow_name"] == "stream_alert_denoise_gt_fast" on_step_complete( @@ -191,7 +210,7 @@ def _fake_run_workflow( # noqa: ANN001 "inputs": {"iteration": 1, "total_iterations": 2}, "outputs": {"load_stats": {"record_count": 9}}, } - ) + ), ) return RunWorkflowResult( status="SUCCEEDED", @@ -205,11 +224,17 @@ def _fake_run_workflow( # noqa: ANN001 }, ) - monkeypatch.setattr(poller_manager.Storage, "read", _fake_read) + monkeypatch.setattr(poller_manager.WorkflowStore, "get_config", _fake_get_config) monkeypatch.setattr( poller_manager, "read_workflow_from_fs", - lambda _workflow_id: {"workflowJson": {"start": "n1", "nodes": [], "edges": []}}, + lambda _workflow_id: { + "workflowJson": { + "start": "n1", + "nodes": [{"id": "n1", "type": "python", "code": "outputs['ok'] = True"}], + "edges": [], + } + }, ) monkeypatch.setattr(poller_manager, "create_execution_record", _fake_create_execution_record) monkeypatch.setattr(poller_manager, "record_execution_result", _fake_record_execution_result) @@ -239,7 +264,7 @@ async def test_no_overlap_skips_when_previous_run_is_still_active( monkeypatch: pytest.MonkeyPatch, ) -> None: manager = poller_manager.WorkflowPollerManager() - threading_event = asyncio.Event() + threading_event = threading.Event() config = { "enabled": True, @@ -257,27 +282,33 @@ def _fake_run_workflow( # noqa: ANN001 trace: bool, cancel, on_step_complete, + run_id: str, + execution_profile: str, ): - _ = workflow, inputs, timeout_s, trace, cancel + _ = workflow, inputs, timeout_s, trace, cancel, run_id _ = on_step_complete + assert execution_profile == "high_frequency" # Keep the run active until the test releases it so a second tick skips. - asyncio.run(asyncio.wait_for(threading_event.wait(), timeout=2.0)) + assert threading_event.wait(timeout=2.0) return RunWorkflowResult(status="success", outputs={"load_stats": {"record_count": 1}}) monkeypatch.setattr(poller_manager, "run_workflow", _fake_run_workflow) monkeypatch.setattr( poller_manager, "create_execution_record", - lambda workflow_id, *, input_params=None, exec_id=None: asyncio.sleep(0, result={ - "id": exec_id or f"exec-{workflow_id}", - "workflowId": workflow_id, - "inputParams": input_params or {}, - "status": "running", - "startedAt": 111, - "executionLog": [], - "currentPhase": "queued", - "currentStepIndex": 0, - }), + lambda workflow_id, *, input_params=None, exec_id=None: asyncio.sleep( + 0, + result={ + "id": exec_id or f"exec-{workflow_id}", + "workflowId": workflow_id, + "inputParams": input_params or {}, + "status": "running", + "startedAt": 111, + "executionLog": [], + "currentPhase": "queued", + "currentStepIndex": 0, + }, + ), ) monkeypatch.setattr( poller_manager, @@ -287,7 +318,13 @@ def _fake_run_workflow( # noqa: ANN001 monkeypatch.setattr( poller_manager, "read_workflow_from_fs", - lambda _workflow_id: {"workflowJson": {"start": "n1", "nodes": [], "edges": []}}, + lambda _workflow_id: { + "workflowJson": { + "start": "n1", + "nodes": [{"id": "n1", "type": "python", "code": "outputs['ok'] = True"}], + "edges": [], + } + }, ) await manager._schedule_run("wf-overlap", {"start": "n1", "nodes": [], "edges": []}, config) @@ -341,10 +378,13 @@ def _fake_run_workflow( # noqa: ANN001 trace: bool, cancel, on_step_complete, + run_id: str, + execution_profile: str, ): - _ = workflow, inputs, timeout_s, trace, cancel + _ = workflow, inputs, timeout_s, trace, cancel, run_id _ = on_step_complete - release_run.wait(timeout=0.2) + assert execution_profile == "high_frequency" + release_run.wait(0.2) return RunWorkflowResult(status="SUCCEEDED", run_id="run-stop") monkeypatch.setattr(poller_manager, "RUN_SHUTDOWN_GRACE_SECONDS", 0.01) @@ -372,21 +412,17 @@ async def test_start_all_only_restarts_enabled_configs(monkeypatch: pytest.Monke manager = poller_manager.WorkflowPollerManager() restarted: list[str] = [] - async def _fake_list_keys(_prefix: str) -> list[str]: + async def _fake_list_configs(*, kind: str) -> list[tuple[str, dict[str, Any]]]: return [ - "workflow_poller_config/wf-enabled", - "workflow_poller_config/wf-disabled", + ("wf-enabled", {"enabled": True}), + ("wf-disabled", {"enabled": False}), ] - async def _fake_read(key: str) -> dict[str, Any]: - return {"enabled": key.endswith("wf-enabled")} - async def _fake_restart(workflow_id: str) -> dict[str, Any]: restarted.append(workflow_id) return {"workflowId": workflow_id, "state": "running"} - monkeypatch.setattr(poller_manager.Storage, "list_keys", _fake_list_keys) - monkeypatch.setattr(poller_manager.Storage, "read", _fake_read) + monkeypatch.setattr(poller_manager.WorkflowStore, "list_configs", _fake_list_configs) monkeypatch.setattr(manager, "restart_workflow", _fake_restart) await manager.start_all() @@ -398,17 +434,23 @@ async def test_restart_workflow_replaces_existing_task(monkeypatch: pytest.Monke manager = poller_manager.WorkflowPollerManager() config = {"enabled": True, "intervalSeconds": 30, "timeoutSeconds": 10, "noOverlap": True, "inputs": {}} - async def _fake_read(_key: str) -> dict[str, Any]: + async def _fake_get_config(_workflow_id: str, *, kind: str) -> dict[str, Any]: return config async def _fake_loop(*args, **kwargs) -> None: # noqa: ANN002, ANN003 await asyncio.sleep(60) - monkeypatch.setattr(poller_manager.Storage, "read", _fake_read) + monkeypatch.setattr(poller_manager.WorkflowStore, "get_config", _fake_get_config) monkeypatch.setattr( poller_manager, "read_workflow_from_fs", - lambda _workflow_id: {"workflowJson": {"start": "n1", "nodes": [], "edges": []}}, + lambda _workflow_id: { + "workflowJson": { + "start": "n1", + "nodes": [{"id": "n1", "type": "python", "code": "outputs['ok'] = True"}], + "edges": [], + } + }, ) monkeypatch.setattr(manager, "_poller_loop", _fake_loop) diff --git a/tests/workflow/test_run_workflow_history.py b/tests/workflow/test_run_workflow_history.py index 9bd6d3f40..91bce7c2e 100644 --- a/tests/workflow/test_run_workflow_history.py +++ b/tests/workflow/test_run_workflow_history.py @@ -65,14 +65,14 @@ async def test_workflow_history_in_output(): assert result.success is True assert result.output is not None - # Final tool metadata should not retain full per-step history in memory. - assert "history" in result.metadata - history = result.metadata["history"] - assert history == [] + # Final tool metadata should not retain full per-step history or outputs in memory. + assert "history" not in result.metadata + assert "outputs" not in result.metadata - # Verify final outputs in metadata - assert "outputs" in result.metadata - assert result.metadata["outputs"]["final"] == 35 + # Verify final outputs remain available in the agent-facing output text. + assert result.metadata["has_output"] is True + assert result.metadata["output_keys"] == ["final"] + assert '"final": 35' in result.output # Verify output text no longer expands the full execution history assert "Status: SUCCEEDED" in result.output @@ -123,9 +123,8 @@ async def test_workflow_history_with_error(): # Per-step details are written through execution step rows, not retained # in the final ToolResult metadata. - assert "history" in result.metadata - history = result.metadata["history"] - assert history == [] + assert "history" not in result.metadata + assert "outputs" not in result.metadata # Output should contain only the top-level failure summary assert "Error:" in result.output @@ -163,8 +162,8 @@ async def test_workflow_history_with_stdout(): assert result.success is True - history = result.metadata["history"] - assert history == [] + assert "history" not in result.metadata + assert "outputs" not in result.metadata # Output should stay concise and omit per-step stdout details assert "Stdout:" not in result.output diff --git a/tests/workflow/test_tool_run_workflow.py b/tests/workflow/test_tool_run_workflow.py index 7edfd9437..95e9d038a 100644 --- a/tests/workflow/test_tool_run_workflow.py +++ b/tests/workflow/test_tool_run_workflow.py @@ -47,10 +47,30 @@ def _make_large_alerts(count: int) -> list[dict[str, Any]]: return [{"id": idx, "payload": "x" * 20} for idx in range(count)] +def test_format_workflow_result_keeps_raw_outputs() -> None: + large_alerts = _make_large_alerts(5_000) + + text = run_workflow_module._format_workflow_result( + { + "status": "SUCCEEDED", + "steps": 1, + "outputs": { + "enriched_alerts": large_alerts, + "message": "done", + }, + } + ) + + assert '"enriched_alerts"' in text + assert '"payload"' in text + assert "_enriched_alerts_count" not in text + + # ============================================================================= # Fixtures # ============================================================================= + @pytest.fixture(autouse=True) def init_tool_registry(): """Ensure ToolRegistry is initialized before each test""" @@ -73,10 +93,10 @@ def tool_context(): def tool_context_with_permission(): """Create a tool context with permission tracking""" permissions_requested = [] - + async def track_permission(request): permissions_requested.append(request) - + ctx = ToolContext( session_id="test-session-workflow-perm", message_id="test-message-workflow-perm", @@ -97,13 +117,9 @@ def simple_workflow(): "metadata": {}, "start": "node-1", "nodes": [ - { - "id": "node-1", - "type": "python", - "code": "result = {'message': 'Hello from workflow!', 'value': 42}" - } + {"id": "node-1", "type": "python", "code": "result = {'message': 'Hello from workflow!', 'value': 42}"} ], - "edges": [] + "edges": [], } @@ -114,18 +130,10 @@ def workflow_with_requirements(): "id": "test-workflow-002", "name": "Test Workflow with Requirements", "start": "node-1", - "metadata": { - "requirements": ["requests>=2.31,<3"] - }, + "metadata": {"requirements": ["requests>=2.31,<3"]}, "start": "node-1", - "nodes": [ - { - "id": "node-1", - "type": "python", - "code": "result = {'status': 'ok'}" - } - ], - "edges": [] + "nodes": [{"id": "node-1", "type": "python", "code": "result = {'status': 'ok'}"}], + "edges": [], } @@ -142,10 +150,10 @@ def workflow_with_inputs(): { "id": "node-1", "type": "python", - "code": "greeting = f'Hello, {inputs.get(\"name\", \"World\")}!'; result = {'greeting': greeting}" + "code": "greeting = f'Hello, {inputs.get(\"name\", \"World\")}!'; result = {'greeting': greeting}", } ], - "edges": [] + "edges": [], } @@ -153,9 +161,10 @@ def workflow_with_inputs(): # Test Tool Registration # ============================================================================= + class TestRunWorkflowToolRegistration: """Test run_workflow tool registration""" - + def test_run_workflow_tool_exists(self): """Test that run_workflow tool is registered""" ToolRegistry.init() @@ -164,7 +173,7 @@ def test_run_workflow_tool_exists(self): assert tool.info.name == "run_workflow" assert tool.info.category.value == "system" assert tool.info.requires_confirmation is True - + def test_run_workflow_tool_schema(self): """Test run_workflow tool schema""" ToolRegistry.init() @@ -182,9 +191,10 @@ def test_run_workflow_tool_schema(self): # Test Error Handling (flocks_workflow not available) # ============================================================================= + class TestRunWorkflowToolWithoutDependency: """Test run_workflow tool when flocks_workflow is not available""" - + @pytest.mark.anyio async def test_run_workflow_without_flocks_workflow(self, tool_context, simple_workflow): """Test that tool returns error when flocks_workflow is not available""" @@ -202,9 +212,10 @@ async def test_run_workflow_without_flocks_workflow(self, tool_context, simple_w # Test Parameter Validation # ============================================================================= + class TestRunWorkflowToolValidation: """Test run_workflow tool parameter validation""" - + @pytest.mark.anyio async def test_run_workflow_missing_workflow(self, tool_context): """Test that missing workflow parameter returns error""" @@ -217,7 +228,7 @@ async def test_run_workflow_missing_workflow(self, tool_context): ) assert result.success is False assert "workflow parameter is required" in result.error - + @pytest.mark.anyio async def test_run_workflow_invalid_workflow_type(self, tool_context): """Test that invalid workflow type returns error""" @@ -230,7 +241,7 @@ async def test_run_workflow_invalid_workflow_type(self, tool_context): ) assert result.success is False assert "workflow must be a dictionary or string" in result.error - + @pytest.mark.anyio async def test_run_workflow_empty_workflow(self, tool_context): """Test that empty workflow returns error""" @@ -249,40 +260,44 @@ async def test_run_workflow_empty_workflow(self, tool_context): # Test Workflow Execution (with mocked flocks_workflow) # ============================================================================= + class TestRunWorkflowToolExecution: """Test run_workflow tool execution with mocked dependencies""" - + @pytest.mark.anyio async def test_run_workflow_success(self, tool_context_with_permission, simple_workflow): """Test successful workflow execution""" - fake = FakeRunWorkflowResult(**{ - "status": "SUCCEEDED", - "run_id": "run-123", - "steps": 1, - "last_node_id": "node-1", - "outputs": {"message": "Hello from workflow!", "value": 42}, - "history": [ - {"node_id": "node-1", "status": "SUCCEEDED", "outputs": {"message": "Hello from workflow!", "value": 42}} - ], - "error": None - }) + fake = FakeRunWorkflowResult( + **{ + "status": "SUCCEEDED", + "run_id": "run-123", + "steps": 1, + "last_node_id": "node-1", + "outputs": {"message": "Hello from workflow!", "value": 42}, + "history": [ + { + "node_id": "node-1", + "status": "SUCCEEDED", + "outputs": {"message": "Hello from workflow!", "value": 42}, + } + ], + "error": None, + } + ) mock_run = Mock(name="run_workflow", return_value=fake) with patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run)): result = await ToolRegistry.execute( - "run_workflow", - ctx=tool_context_with_permission, - workflow=simple_workflow, - inputs={} + "run_workflow", ctx=tool_context_with_permission, workflow=simple_workflow, inputs={} ) - + assert result.success is True assert "SUCCEEDED" in result.output - assert "run-123" in result.output + assert "run-123" not in result.output assert "Steps executed: 1" in result.output assert result.metadata["status"] == "success" assert result.metadata["steps"] == 1 - assert result.metadata["run_id"] == "run-123" - + assert "run_id" not in result.metadata + # Check that permission was requested assert len(tool_context_with_permission._permissions_requested) > 0 @@ -296,15 +311,18 @@ async def test_run_workflow_registered_id_updates_execution_history( tool_context_with_permission._metadata_callback = metadata_updates.append def run_side_effect(**kwargs): - kwargs["on_step_start"]("run-registered", 1, MagicMock(id="node-1", type="python"), {}) - kwargs["on_step_complete"]({ - "node_id": "node-1", - "node_type": "python", - "outputs": {"message": "ok"}, - }) + assert kwargs["run_id"] == "exec-registered" + kwargs["on_step_start"](kwargs["run_id"], 1, MagicMock(id="node-1", type="python"), {}) + kwargs["on_step_complete"]( + { + "node_id": "node-1", + "node_type": "python", + "outputs": {"message": "ok"}, + } + ) return FakeRunWorkflowResult( status="SUCCEEDED", - run_id="run-registered", + run_id=kwargs["run_id"], steps=1, last_node_id="node-1", outputs={"message": "ok"}, @@ -313,32 +331,31 @@ def run_side_effect(**kwargs): ) mock_run = Mock(name="run_workflow", side_effect=run_side_effect) - create_execution = AsyncMock(return_value={ - "id": "exec-registered", - "workflowId": "test-workflow-001", - "inputParams": {"name": "Flocks"}, - "status": "running", - "startedAt": 1, - "executionLog": [], - }) - storage_read = AsyncMock(return_value={ - "id": "exec-registered", - "workflowId": "test-workflow-001", - "inputParams": {"name": "Flocks"}, - "status": "running", - "startedAt": 1, - "executionLog": [], - }) - storage_write = AsyncMock(return_value=None) + create_execution = AsyncMock( + return_value={ + "id": "exec-registered", + "workflowId": "test-workflow-001", + "inputParams": {"name": "Flocks"}, + "status": "running", + "startedAt": 1, + "executionLog": [], + } + ) + upsert_execution = AsyncMock(return_value=None) record_result = AsyncMock(return_value=None) - with patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run)), \ - patch.object(run_workflow_module, "read_workflow_from_fs", return_value={"id": "test-workflow-001", "workflowJson": simple_workflow}), \ - patch.object(run_workflow_module, "resolve_workflow_id_from_source", return_value="test-workflow-001"), \ - patch.object(run_workflow_module, "create_execution_record", create_execution), \ - patch.object(run_workflow_module.Storage, "read", storage_read), \ - patch.object(run_workflow_module.Storage, "write", storage_write), \ - patch.object(run_workflow_module, "record_execution_result", record_result): + with ( + patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run)), + patch.object( + run_workflow_module, + "read_workflow_from_fs", + return_value={"id": "test-workflow-001", "workflowJson": simple_workflow}, + ), + patch.object(run_workflow_module, "resolve_workflow_id_from_source", return_value="test-workflow-001"), + patch.object(run_workflow_module, "create_execution_record", create_execution), + patch.object(run_workflow_module.WorkflowStore, "upsert_execution", upsert_execution), + patch.object(run_workflow_module, "record_execution_result", record_result), + ): result = await ToolRegistry.execute( "run_workflow", ctx=tool_context_with_permission, @@ -348,11 +365,74 @@ def run_side_effect(**kwargs): assert result.success is True assert result.metadata["workflow_execution_id"] == "exec-registered" + assert "run_id" not in result.metadata create_execution.assert_awaited_once() record_result.assert_awaited_once() - assert storage_write.await_count >= 1 + assert upsert_execution.await_count >= 1 assert any(update.get("workflow_execution_id") == "exec-registered" for update in metadata_updates) + @pytest.mark.anyio + async def test_run_workflow_registered_id_overrides_missing_workflow_json_id( + self, + tool_context_with_permission, + ): + workflow_without_id = { + "name": "Display Name Only", + "start": "node-1", + "nodes": [{"id": "node-1", "type": "python", "code": "outputs['ok'] = True"}], + "edges": [], + } + captured_kwargs: dict[str, Any] = {} + + def run_side_effect(**kwargs): + captured_kwargs.update(kwargs) + return FakeRunWorkflowResult( + status="SUCCEEDED", + run_id="run-directory-id", + steps=1, + last_node_id="node-1", + outputs={"workflow_id": kwargs.get("workflow_id")}, + history=[], + error=None, + ) + + mock_run = Mock(name="run_workflow", side_effect=run_side_effect) + create_execution = AsyncMock( + return_value={ + "id": "exec-directory-id", + "workflowId": "wf-directory-id", + "inputParams": {}, + "status": "running", + "startedAt": 1, + "executionLog": [], + } + ) + + with ( + patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run)), + patch.object( + run_workflow_module, + "read_workflow_from_fs", + return_value={"id": "wf-directory-id", "workflowJson": workflow_without_id}, + ), + patch.object(run_workflow_module, "create_execution_record", create_execution), + patch.object(run_workflow_module.WorkflowStore, "upsert_execution", AsyncMock(return_value=None)), + patch.object(run_workflow_module, "record_execution_result", AsyncMock(return_value=None)), + ): + result = await ToolRegistry.execute( + "run_workflow", + ctx=tool_context_with_permission, + workflow="wf-directory-id", + inputs={}, + ) + + assert result.success is True + assert result.metadata["workflow_id"] == "wf-directory-id" + assert captured_kwargs["workflow_id"] == "wf-directory-id" + assert captured_kwargs["workflow"]["id"] == "wf-directory-id" + create_execution.assert_awaited_once() + assert create_execution.await_args.args[0] == "wf-directory-id" + @pytest.mark.anyio async def test_run_workflow_compacts_large_outputs_for_progress_and_final_record( self, @@ -364,16 +444,19 @@ async def test_run_workflow_compacts_large_outputs_for_progress_and_final_record tool_context_with_permission._metadata_callback = metadata_updates.append def run_side_effect(**kwargs): - kwargs["on_step_start"]("run-compacted", 1, MagicMock(id="node-1", type="python"), {}) - kwargs["on_step_complete"]({ - "node_id": "node-1", - "node_type": "python", - "inputs": {"raw_alerts": large_alerts, "source": "syslog"}, - "outputs": {"raw_alerts": large_alerts, "message": "ok"}, - }) + assert kwargs["run_id"] == "exec-compacted" + kwargs["on_step_start"](kwargs["run_id"], 1, MagicMock(id="node-1", type="python"), {}) + kwargs["on_step_complete"]( + { + "node_id": "node-1", + "node_type": "python", + "inputs": {"raw_alerts": large_alerts, "source": "syslog"}, + "outputs": {"raw_alerts": large_alerts, "message": "ok"}, + } + ) return FakeRunWorkflowResult( status="SUCCEEDED", - run_id="run-compacted", + run_id=kwargs["run_id"], steps=1, last_node_id="node-1", outputs={"enriched_alerts": large_alerts, "message": "done"}, @@ -382,32 +465,29 @@ def run_side_effect(**kwargs): ) mock_run = Mock(name="run_workflow", side_effect=run_side_effect) - create_execution = AsyncMock(return_value={ - "id": "exec-compacted", - "workflowId": "test-workflow-001", - "inputParams": {}, - "status": "running", - "startedAt": 1, - "executionLog": [], - }) - storage_read = AsyncMock(return_value={ - "id": "exec-compacted", - "workflowId": "test-workflow-001", - "inputParams": {}, - "status": "running", - "startedAt": 1, - "executionLog": [], - }) - storage_write = AsyncMock(return_value=None) + create_execution = AsyncMock( + return_value={ + "id": "exec-compacted", + "workflowId": "test-workflow-001", + "inputParams": {}, + "status": "running", + "startedAt": 1, + "executionLog": [], + } + ) + upsert_execution = AsyncMock(return_value=None) + record_step = AsyncMock(return_value=None) record_result = AsyncMock(return_value=None) - with patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run)), \ - patch.object(run_workflow_module, "resolve_workflow_id_from_source", return_value="test-workflow-001"), \ - patch.object(run_workflow_module, "create_execution_record", create_execution), \ - patch.object(run_workflow_module.Storage, "read", storage_read), \ - patch.object(run_workflow_module.Storage, "write", storage_write), \ - patch.object(run_workflow_module, "record_execution_result", record_result), \ - patch.object(run_workflow_module, "_record_workflow_tool_result", AsyncMock(return_value=None)): + with ( + patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run)), + patch.object(run_workflow_module, "resolve_workflow_id_from_source", return_value="test-workflow-001"), + patch.object(run_workflow_module, "create_execution_record", create_execution), + patch.object(run_workflow_module.WorkflowStore, "upsert_execution", upsert_execution), + patch.object(run_workflow_module, "record_execution_step", record_step), + patch.object(run_workflow_module, "record_execution_result", record_result), + patch.object(run_workflow_module, "_record_workflow_tool_result", AsyncMock(return_value=None)), + ): result = await ToolRegistry.execute( "run_workflow", ctx=tool_context_with_permission, @@ -416,9 +496,8 @@ def run_side_effect(**kwargs): ) assert result.success is True - step_write = storage_write.await_args_list[-1] - assert step_write.args[0] == "workflow_execution_step/exec-compacted/00000001" - step_payload = step_write.args[1] + record_step.assert_awaited() + step_payload = record_step.await_args.args[2] assert step_payload["inputs"] == { "_raw_alerts_count": 150, "source": "syslog", @@ -427,11 +506,10 @@ def run_side_effect(**kwargs): "_raw_alerts_count": 150, "message": "ok", } - assert result.metadata["outputs"] == { - "_enriched_alerts_count": 150, - "message": "done", - } - assert result.metadata["history"] == [] + assert result.metadata["has_output"] is True + assert result.metadata["output_keys"] == ["enriched_alerts", "message"] + assert "outputs" not in result.metadata + assert "history" not in result.metadata assert result.metadata["history_count"] == 0 final_exec_data = record_result.await_args.args[2] @@ -485,112 +563,112 @@ async def test_run_workflow_uses_isolated_child_tool_context( assert nested_ctx.event_publish_callback == tool_context_with_permission.event_publish_callback assert nested_ctx._permission_callback == tool_context_with_permission._permission_callback assert nested_ctx._metadata_callback is None - + @pytest.mark.anyio async def test_run_workflow_with_inputs(self, tool_context_with_permission, workflow_with_inputs): """Test workflow execution with input parameters""" - fake = FakeRunWorkflowResult(**{ - "status": "SUCCEEDED", - "run_id": "run-456", - "steps": 1, - "last_node_id": "node-1", - "outputs": {"greeting": "Hello, Flocks!"}, - "history": [ - {"node_id": "node-1", "status": "SUCCEEDED", "outputs": {"greeting": "Hello, Flocks!"}} - ], - "error": None - }) + fake = FakeRunWorkflowResult( + **{ + "status": "SUCCEEDED", + "run_id": "run-456", + "steps": 1, + "last_node_id": "node-1", + "outputs": {"greeting": "Hello, Flocks!"}, + "history": [{"node_id": "node-1", "status": "SUCCEEDED", "outputs": {"greeting": "Hello, Flocks!"}}], + "error": None, + } + ) mock_run = Mock(name="run_workflow", return_value=fake) with patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run)): result = await ToolRegistry.execute( "run_workflow", ctx=tool_context_with_permission, workflow=workflow_with_inputs, - inputs={"name": "Flocks"} + inputs={"name": "Flocks"}, ) - + assert result.success is True assert "SUCCEEDED" in result.output - + @pytest.mark.anyio async def test_run_workflow_with_requirements(self, tool_context_with_permission, workflow_with_requirements): """Test workflow execution with requirements installation""" - fake = FakeRunWorkflowResult(**{ - "status": "SUCCEEDED", - "run_id": "run-789", - "steps": 1, - "last_node_id": "node-1", - "outputs": {"status": "ok"}, - "history": [ - {"node_id": "node-1", "status": "SUCCEEDED", "outputs": {"status": "ok"}} - ], - "error": None - }) + fake = FakeRunWorkflowResult( + **{ + "status": "SUCCEEDED", + "run_id": "run-789", + "steps": 1, + "last_node_id": "node-1", + "outputs": {"status": "ok"}, + "history": [{"node_id": "node-1", "status": "SUCCEEDED", "outputs": {"status": "ok"}}], + "error": None, + } + ) mock_run = Mock(name="run_workflow", return_value=fake) installer_cls = Mock(name="RequirementsInstaller") - with patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run, installer_cls=installer_cls)): + with patch.object( + run_workflow_module, + "_get_workflow_runtime", + return_value=_runtime_tuple(run_fn=mock_run, installer_cls=installer_cls), + ): result = await ToolRegistry.execute( "run_workflow", ctx=tool_context_with_permission, workflow=workflow_with_requirements, inputs={}, - ensure_requirements=True + ensure_requirements=True, ) - + assert result.success is True assert installer_cls.called is True - + @pytest.mark.anyio async def test_run_workflow_with_timeout(self, tool_context_with_permission, simple_workflow): """Test workflow execution with timeout""" - fake = FakeRunWorkflowResult(**{ - "status": "SUCCEEDED", - "run_id": "run-timeout", - "steps": 1, - "last_node_id": "node-1", - "outputs": {}, - "history": [], - "error": None - }) + fake = FakeRunWorkflowResult( + **{ + "status": "SUCCEEDED", + "run_id": "run-timeout", + "steps": 1, + "last_node_id": "node-1", + "outputs": {}, + "history": [], + "error": None, + } + ) mock_run = Mock(name="run_workflow", return_value=fake) with patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run)): result = await ToolRegistry.execute( - "run_workflow", - ctx=tool_context_with_permission, - workflow=simple_workflow, - inputs={}, - timeout_s=300.0 + "run_workflow", ctx=tool_context_with_permission, workflow=simple_workflow, inputs={}, timeout_s=300.0 ) - + assert result.success is True # Verify timeout was passed to run_workflow mock_run.assert_called_once() call_kwargs = mock_run.call_args[1] assert call_kwargs.get("timeout_s") == 300.0 assert call_kwargs.get("use_llm") is True - + @pytest.mark.anyio async def test_run_workflow_with_trace(self, tool_context_with_permission, simple_workflow): """Test workflow execution with tracing enabled""" - fake = FakeRunWorkflowResult(**{ - "status": "SUCCEEDED", - "run_id": "run-trace", - "steps": 1, - "last_node_id": "node-1", - "outputs": {}, - "history": [], - "error": None - }) + fake = FakeRunWorkflowResult( + **{ + "status": "SUCCEEDED", + "run_id": "run-trace", + "steps": 1, + "last_node_id": "node-1", + "outputs": {}, + "history": [], + "error": None, + } + ) mock_run = Mock(name="run_workflow", return_value=fake) with patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run)): result = await ToolRegistry.execute( - "run_workflow", - ctx=tool_context_with_permission, - workflow=simple_workflow, - inputs={}, - trace=True + "run_workflow", ctx=tool_context_with_permission, workflow=simple_workflow, inputs={}, trace=True ) - + assert result.success is True # Verify trace was passed to run_workflow call_kwargs = mock_run.call_args[1] @@ -600,15 +678,17 @@ async def test_run_workflow_with_trace(self, tool_context_with_permission, simpl @pytest.mark.anyio async def test_run_workflow_passes_cancel_callback(self, tool_context_with_permission, simple_workflow): """Session abort should be forwarded to workflow runtime cancellation.""" - fake = FakeRunWorkflowResult(**{ - "status": "SUCCEEDED", - "run_id": "run-cancel", - "steps": 1, - "last_node_id": "node-1", - "outputs": {}, - "history": [], - "error": None, - }) + fake = FakeRunWorkflowResult( + **{ + "status": "SUCCEEDED", + "run_id": "run-cancel", + "steps": 1, + "last_node_id": "node-1", + "outputs": {}, + "history": [], + "error": None, + } + ) mock_run = Mock(name="run_workflow", return_value=fake) with patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run)): result = await ToolRegistry.execute( @@ -629,15 +709,17 @@ async def test_run_workflow_passes_cancel_callback(self, tool_context_with_permi @pytest.mark.anyio async def test_run_workflow_disable_llm(self, tool_context_with_permission, simple_workflow): """Test workflow execution with use_llm disabled""" - fake = FakeRunWorkflowResult(**{ - "status": "SUCCEEDED", - "run_id": "run-no-llm", - "steps": 1, - "last_node_id": "node-1", - "outputs": {}, - "history": [], - "error": None - }) + fake = FakeRunWorkflowResult( + **{ + "status": "SUCCEEDED", + "run_id": "run-no-llm", + "steps": 1, + "last_node_id": "node-1", + "outputs": {}, + "history": [], + "error": None, + } + ) mock_run = Mock(name="run_workflow", return_value=fake) with patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run)): result = await ToolRegistry.execute( @@ -651,7 +733,7 @@ async def test_run_workflow_disable_llm(self, tool_context_with_permission, simp assert result.success is True call_kwargs = mock_run.call_args[1] assert call_kwargs.get("use_llm") is False - + @pytest.mark.anyio async def test_run_workflow_execution_failure(self, tool_context_with_permission, simple_workflow): """Test workflow execution failure handling""" @@ -659,37 +741,33 @@ async def test_run_workflow_execution_failure(self, tool_context_with_permission mock_run = Mock(name="run_workflow", side_effect=Exception("Workflow execution failed")) with patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run)): result = await ToolRegistry.execute( - "run_workflow", - ctx=tool_context_with_permission, - workflow=simple_workflow, - inputs={} + "run_workflow", ctx=tool_context_with_permission, workflow=simple_workflow, inputs={} ) - + assert result.success is False assert "Workflow execution failed" in result.error assert result.metadata["status"] == "FAILED" - + @pytest.mark.anyio async def test_run_workflow_failed_status(self, tool_context_with_permission, simple_workflow): """Test workflow execution with FAILED status""" - fake = FakeRunWorkflowResult(**{ - "status": "FAILED", - "run_id": "run-failed", - "steps": 0, - "last_node_id": None, - "outputs": {}, - "history": [], - "error": "NodeExecutionError: Error in node 'node-1'" - }) + fake = FakeRunWorkflowResult( + **{ + "status": "FAILED", + "run_id": "run-failed", + "steps": 0, + "last_node_id": None, + "outputs": {}, + "history": [], + "error": "NodeExecutionError: Error in node 'node-1'", + } + ) mock_run = Mock(name="run_workflow", return_value=fake) with patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run)): result = await ToolRegistry.execute( - "run_workflow", - ctx=tool_context_with_permission, - workflow=simple_workflow, - inputs={} + "run_workflow", ctx=tool_context_with_permission, workflow=simple_workflow, inputs={} ) - + assert result.success is False assert "FAILED" in result.output assert result.metadata["status"] == "error" @@ -699,71 +777,68 @@ async def test_run_workflow_failed_status(self, tool_context_with_permission, si # Test Result Formatting # ============================================================================= + class TestRunWorkflowToolResultFormatting: """Test run_workflow tool result formatting""" - + @pytest.mark.anyio async def test_run_workflow_result_formatting(self, tool_context_with_permission, simple_workflow): """Test that workflow results are properly formatted""" - fake = FakeRunWorkflowResult(**{ - "status": "SUCCEEDED", - "run_id": "run-format", - "steps": 3, - "last_node_id": "node-3", - "outputs": { - "result": "processed", - "count": 42 - }, - "history": [ - {"node_id": "node-1", "status": "SUCCEEDED"}, - {"node_id": "node-2", "status": "SUCCEEDED"}, - {"node_id": "node-3", "status": "SUCCEEDED"}, - ], - "error": None - }) + fake = FakeRunWorkflowResult( + **{ + "status": "SUCCEEDED", + "run_id": "run-format", + "steps": 3, + "last_node_id": "node-3", + "outputs": {"result": "processed", "count": 42}, + "history": [ + {"node_id": "node-1", "status": "SUCCEEDED"}, + {"node_id": "node-2", "status": "SUCCEEDED"}, + {"node_id": "node-3", "status": "SUCCEEDED"}, + ], + "error": None, + } + ) mock_run = Mock(name="run_workflow", return_value=fake) with patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run)): result = await ToolRegistry.execute( - "run_workflow", - ctx=tool_context_with_permission, - workflow=simple_workflow, - inputs={} + "run_workflow", ctx=tool_context_with_permission, workflow=simple_workflow, inputs={} ) - + assert result.success is True output = result.output - + # Check that all key information is present assert "Status: SUCCEEDED" in output - assert "Run ID: run-format" in output + assert "Run ID:" not in output assert "Steps executed: 3" in output assert "Last node: node-3" in output assert "Final Outputs:" in output assert "Execution History" not in output - assert result.metadata["history"] == [] + assert "history" not in result.metadata + assert result.metadata["output_keys"] == ["result", "count"] assert result.metadata["history_count"] == len(fake.history) - + @pytest.mark.anyio async def test_run_workflow_result_with_error(self, tool_context_with_permission, simple_workflow): """Test result formatting when workflow has error""" - fake = FakeRunWorkflowResult(**{ - "status": "FAILED", - "run_id": "run-error", - "steps": 1, - "last_node_id": "node-1", - "outputs": {}, - "history": [], - "error": "NodeExecutionError: Invalid code" - }) + fake = FakeRunWorkflowResult( + **{ + "status": "FAILED", + "run_id": "run-error", + "steps": 1, + "last_node_id": "node-1", + "outputs": {}, + "history": [], + "error": "NodeExecutionError: Invalid code", + } + ) mock_run = Mock(name="run_workflow", return_value=fake) with patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run)): result = await ToolRegistry.execute( - "run_workflow", - ctx=tool_context_with_permission, - workflow=simple_workflow, - inputs={} + "run_workflow", ctx=tool_context_with_permission, workflow=simple_workflow, inputs={} ) - + assert result.success is False assert "Error:" in result.output assert "NodeExecutionError" in result.output @@ -773,30 +848,30 @@ async def test_run_workflow_result_with_error(self, tool_context_with_permission # Test Permission Handling # ============================================================================= + class TestRunWorkflowToolPermissions: """Test run_workflow tool permission handling""" - + @pytest.mark.anyio async def test_run_workflow_requests_permission(self, tool_context_with_permission, simple_workflow): """Test that workflow execution requests permission""" - fake = FakeRunWorkflowResult(**{ - "status": "SUCCEEDED", - "run_id": "run-perm", - "steps": 1, - "last_node_id": "node-1", - "outputs": {}, - "history": [], - "error": None - }) + fake = FakeRunWorkflowResult( + **{ + "status": "SUCCEEDED", + "run_id": "run-perm", + "steps": 1, + "last_node_id": "node-1", + "outputs": {}, + "history": [], + "error": None, + } + ) mock_run = Mock(name="run_workflow", return_value=fake) with patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run)): await ToolRegistry.execute( - "run_workflow", - ctx=tool_context_with_permission, - workflow=simple_workflow, - inputs={} + "run_workflow", ctx=tool_context_with_permission, workflow=simple_workflow, inputs={} ) - + # Verify permission was requested assert len(tool_context_with_permission._permissions_requested) == 1 perm_request = tool_context_with_permission._permissions_requested[0] @@ -810,74 +885,73 @@ async def test_run_workflow_requests_permission(self, tool_context_with_permissi # JSON Parsing Tests (simulating LLM-generated tool calls) # ============================================================================= + class TestRunWorkflowToolJSONParsing: """Test JSON parsing scenarios that occur when LLM generates tool calls.""" - + @pytest.mark.anyio async def test_workflow_path_with_quotes_valid_json(self, tool_context_with_permission, tmp_path): """Test that workflow path with quotes is valid JSON and can be parsed.""" import json - + # Create a workflow file workflow_path = str(tmp_path / "test_workflow.json") workflow_content = { "id": "test-json-parsing", "name": "Test JSON Parsing", "start": "node-1", - "nodes": [ - { - "id": "node-1", - "type": "python", - "code": "outputs['result'] = 'success'" - } - ], - "edges": [] + "nodes": [{"id": "node-1", "type": "python", "code": "outputs['result'] = 'success'"}], + "edges": [], } - with open(workflow_path, 'w') as f: + with open(workflow_path, "w") as f: json.dump(workflow_content, f) - + # Simulate LLM generating JSON with QUOTED path (correct) - arguments_json_string = json.dumps({ - "workflow": workflow_path, # This will be properly quoted in JSON - "inputs": {} - }) - + arguments_json_string = json.dumps( + { + "workflow": workflow_path, # This will be properly quoted in JSON + "inputs": {}, + } + ) + # Verify it's valid JSON parsed_args = json.loads(arguments_json_string) assert parsed_args["workflow"] == workflow_path - + # Now execute with the parsed arguments - fake = FakeRunWorkflowResult(**{ - "status": "SUCCEEDED", - "run_id": "run-json-test", - "steps": 1, - "last_node_id": "node-1", - "outputs": {"result": "success"}, - "history": [] - }) - + fake = FakeRunWorkflowResult( + **{ + "status": "SUCCEEDED", + "run_id": "run-json-test", + "steps": 1, + "last_node_id": "node-1", + "outputs": {"result": "success"}, + "history": [], + } + ) + mock_run = Mock(name="run_workflow", return_value=fake) with patch.object(run_workflow_module, "_get_workflow_runtime", return_value=_runtime_tuple(run_fn=mock_run)): result = await ToolRegistry.execute( "run_workflow", ctx=tool_context_with_permission, - **parsed_args # Unpack parsed arguments + **parsed_args, # Unpack parsed arguments ) - + assert result.success is True - + @pytest.mark.anyio async def test_workflow_path_without_quotes_invalid_json(self): """Test that workflow path without quotes is INVALID JSON and cannot be parsed.""" import json - + # Simulate LLM generating JSON with UNQUOTED path (incorrect - this is the bug) invalid_json_string = '{"workflow": workflow/alert_triage/workflow.json, "inputs": {}}' - + # Verify it's INVALID JSON with pytest.raises(json.JSONDecodeError): json.loads(invalid_json_string) - + # This is exactly what causes "Failed to parse tool arguments" error in production @@ -885,9 +959,10 @@ async def test_workflow_path_without_quotes_invalid_json(self): # Integration Test (if flocks_workflow is available) # ============================================================================= + class TestRunWorkflowToolIntegration: """Integration tests with the real in-repo workflow runtime.""" - + @pytest.mark.anyio async def test_run_workflow_integration(self, tool_context_with_permission): """Integration test that exercises the tool end-to-end.""" @@ -897,23 +972,19 @@ async def test_run_workflow_integration(self, tool_context_with_permission): "metadata": {}, "start": "node-1", "nodes": [ - { - "id": "node-1", - "type": "python", - "code": "outputs['result'] = {'test': 'integration', 'value': 100}" - } + {"id": "node-1", "type": "python", "code": "outputs['result'] = {'test': 'integration', 'value': 100}"} ], - "edges": [] + "edges": [], } - + result = await ToolRegistry.execute( "run_workflow", ctx=tool_context_with_permission, workflow=workflow, inputs={}, - ensure_requirements=False # Skip requirements for test + ensure_requirements=False, # Skip requirements for test ) - + assert result is not None assert result.success is True assert "Status: SUCCEEDED" in (result.output or "") diff --git a/tests/workflow/test_tool_run_workflow_simple.py b/tests/workflow/test_tool_run_workflow_simple.py index 15c4a677b..806cdb250 100644 --- a/tests/workflow/test_tool_run_workflow_simple.py +++ b/tests/workflow/test_tool_run_workflow_simple.py @@ -24,7 +24,7 @@ if str(_REPO_ROOT) not in sys.path: sys.path.insert(0, str(_REPO_ROOT)) -from flocks.tool import ToolRegistry, ToolContext +from flocks.tool import ToolContext, ToolRegistry # noqa: E402 def _dump_tool_result(result) -> str: @@ -71,7 +71,7 @@ async def run_simple_workflow_full_result(): print(_dump_tool_result(result)) assert result.success is True - assert result.metadata.get("status") == "SUCCEEDED" + assert result.metadata.get("status") == "success" assert "Status: SUCCEEDED" in (result.output or "") @@ -148,7 +148,7 @@ async def run_specified_workflow_file_full_result(workflow_path=None, query=None # 基础断言:工作流应该成功执行 assert result.success is True, f"Workflow execution failed: {result.error}" - assert result.metadata.get("status") == "SUCCEEDED", f"Expected SUCCEEDED status, got: {result.metadata.get('status')}" + assert result.metadata.get("status") == "success", f"Expected success status, got: {result.metadata.get('status')}" assert "Status: SUCCEEDED" in (result.output or ""), "Output should contain success status" # 验证最终节点:应该是 finalize_output diff --git a/tests/workflow/test_trigger_runtime.py b/tests/workflow/test_trigger_runtime.py index e230a29b1..023ce4f32 100644 --- a/tests/workflow/test_trigger_runtime.py +++ b/tests/workflow/test_trigger_runtime.py @@ -14,10 +14,10 @@ async def test_sync_legacy_configs_disables_explicit_empty_trigger_list( ) -> None: writes: list[tuple[str, dict]] = [] - async def _fake_write(key: str, value: dict) -> None: - writes.append((key, value)) + async def _fake_put_config(workflow_id: str, config: dict, *, kind: str) -> None: + writes.append((f"{kind}/{workflow_id}", config)) - monkeypatch.setattr(runtime_module.Storage, "write", _fake_write) + monkeypatch.setattr(runtime_module.WorkflowStore, "put_config", _fake_put_config) runtime = runtime_module.TriggerRuntime() triggers = await runtime._sync_legacy_configs_from_workflow( # noqa: SLF001 @@ -26,9 +26,7 @@ async def _fake_write(key: str, value: dict) -> None: ) assert triggers == [] - assert { - key for key, _value in writes - } == { + assert {key for key, _value in writes} == { "workflow_poller_config/wf-empty", "workflow_syslog_config/wf-empty", "workflow_kafka_config/wf-empty", @@ -62,9 +60,7 @@ def stop(self) -> None: monkeypatch.setattr( runtime_module, "load_trigger_plugin_module", - lambda _plugin_spec: SimpleNamespace( - create_trigger_adapter=lambda definition: _FakeAdapter(definition) - ), + lambda _plugin_spec: SimpleNamespace(create_trigger_adapter=lambda definition: _FakeAdapter(definition)), ) runtime = runtime_module.TriggerRuntime() diff --git a/tests/workflow/test_workflow_center_lifecycle.py b/tests/workflow/test_workflow_center_lifecycle.py index cf1046578..43d11ab83 100644 --- a/tests/workflow/test_workflow_center_lifecycle.py +++ b/tests/workflow/test_workflow_center_lifecycle.py @@ -54,9 +54,9 @@ async def fake_stop_container(container_name: str) -> bool: return True monkeypatch.setenv("FLOCKS_WORKFLOW_SERVICE_DRIVER", "local") - monkeypatch.setattr(center.Storage, "read", fake_read) - monkeypatch.setattr(center.Storage, "write", fake_write) - monkeypatch.setattr(center.Storage, "remove", fake_remove) + monkeypatch.setattr(center.WorkflowStore, "kv_get", fake_read) + monkeypatch.setattr(center.WorkflowStore, "kv_put", fake_write) + monkeypatch.setattr(center.WorkflowStore, "kv_remove", fake_remove) monkeypatch.setattr(center, "_stop_and_remove_container", fake_stop_container) result = await center.stop_workflow_service("wf-1") @@ -108,9 +108,9 @@ async def fake_stop_container(container_name: str) -> bool: stopped_containers.append(container_name) return True - monkeypatch.setattr(center.Storage, "read", fake_read) - monkeypatch.setattr(center.Storage, "write", fake_write) - monkeypatch.setattr(center.Storage, "remove", fake_remove) + monkeypatch.setattr(center.WorkflowStore, "kv_get", fake_read) + monkeypatch.setattr(center.WorkflowStore, "kv_put", fake_write) + monkeypatch.setattr(center.WorkflowStore, "kv_remove", fake_remove) monkeypatch.setattr(center, "_stop_and_remove_container", fake_stop_container) await center._stop_existing_runtime_for_publish("wf-1") @@ -152,8 +152,8 @@ async def fake_read(key): monkeypatch.setenv("FLOCKS_WORKFLOW_SERVICE_PORT_START", "19000") monkeypatch.setenv("FLOCKS_WORKFLOW_SERVICE_PORT_END", "19003") - monkeypatch.setattr(center.Storage, "list_keys", fake_list_keys) - monkeypatch.setattr(center.Storage, "read", fake_read) + monkeypatch.setattr(center.WorkflowStore, "kv_list_keys", fake_list_keys) + monkeypatch.setattr(center.WorkflowStore, "kv_get", fake_read) monkeypatch.setattr(center, "_is_port_available", lambda _port: True) assert await center._allocate_port() == 19003 @@ -169,7 +169,7 @@ async def fake_list_keys(_prefix): monkeypatch.setenv("FLOCKS_WORKFLOW_SERVICE_PORT_START", "19000") monkeypatch.setenv("FLOCKS_WORKFLOW_SERVICE_PORT_END", "19001") - monkeypatch.setattr(center.Storage, "list_keys", fake_list_keys) + monkeypatch.setattr(center.WorkflowStore, "kv_list_keys", fake_list_keys) monkeypatch.setattr(center, "_is_port_available", lambda _port: True) try: @@ -187,12 +187,14 @@ async def test_publish_workflow_local_releases_reserved_port_on_spawn_failure( workflow_id = "wf-local-spawn-fail" workflow_path = tmp_path / "workflow.json" workflow_path.write_text( - json.dumps({ - "id": workflow_id, - "start": "n1", - "nodes": [{"id": "n1", "type": "python", "code": "outputs['ok'] = True"}], - "edges": [], - }), + json.dumps( + { + "id": workflow_id, + "start": "n1", + "nodes": [{"id": "n1", "type": "python", "code": "outputs['ok'] = True"}], + "edges": [], + } + ), encoding="utf-8", ) store: dict[str, Any] = { @@ -223,8 +225,8 @@ async def fake_create_subprocess_exec(*_args, **_kwargs): raise OSError("spawn failed") center._IN_FLIGHT_PORT_RESERVATIONS.clear() - monkeypatch.setattr(center.Storage, "read", fake_read) - monkeypatch.setattr(center.Storage, "write", fake_write) + monkeypatch.setattr(center.WorkflowStore, "kv_get", fake_read) + monkeypatch.setattr(center.WorkflowStore, "kv_put", fake_write) monkeypatch.setattr(center, "_stop_existing_runtime_for_publish", fake_stop_existing_runtime_for_publish) monkeypatch.setattr(center, "_write_release_snapshot", fake_write_release_snapshot) monkeypatch.setattr(center, "_allocate_port", fake_allocate_port) diff --git a/tests/workflow/test_workflow_execution_plan.py b/tests/workflow/test_workflow_execution_plan.py new file mode 100644 index 000000000..264f1a630 --- /dev/null +++ b/tests/workflow/test_workflow_execution_plan.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +from flocks.workflow import runner as runner_module +from flocks.workflow.execution_plan import WorkflowExecutionPlan, build_workflow_execution_plan +from flocks.workflow.models import Workflow +from flocks.workflow.runner import run_workflow + + +def _workflow() -> Workflow: + return Workflow.from_dict( + { + "start": "start", + "nodes": [ + { + "id": "start", + "type": "python", + "code": "outputs['ok'] = inputs.get('value', 1)", + } + ], + "edges": [], + } + ) + + +def test_run_workflow_accepts_execution_plan_without_rebuilding(monkeypatch) -> None: + plan = build_workflow_execution_plan(_workflow()) + + def _fail_build(*args, **kwargs): # noqa: ANN002, ANN003 + raise AssertionError("plan should not be rebuilt") + + monkeypatch.setattr(runner_module, "build_workflow_execution_plan", _fail_build) + monkeypatch.setattr(runner_module, "_resolve_workflow_runtime_preference", lambda _ctx: "host") + monkeypatch.setattr(runner_module, "get_tool_registry", lambda tool_context=None: None) + + result = run_workflow( + workflow=plan, + inputs={"value": 7}, + ensure_requirements=False, + ) + + assert result.status == "SUCCEEDED" + assert result.outputs == {"ok": 7} + + +def test_high_frequency_profile_uses_lightweight_runtime_options(monkeypatch) -> None: + captured_init: dict[str, Any] = {} + captured_run: dict[str, Any] = {} + + class FakeEngine: + def __init__(self, *args, **kwargs): # noqa: ANN002, ANN003 + captured_init.update(kwargs) + + def run(self, *args, **kwargs): # noqa: ANN002, ANN003 + captured_run.update(kwargs) + return SimpleNamespace( + run_id=kwargs.get("run_id"), + steps=1, + last_node_id="start", + outputs={"ok": True}, + history=[], + ) + + monkeypatch.setattr(runner_module, "WorkflowEngine", FakeEngine) + monkeypatch.setattr(runner_module, "_resolve_workflow_runtime_preference", lambda _ctx: "host") + monkeypatch.setattr(runner_module, "get_tool_registry", lambda tool_context=None: None) + + result = run_workflow( + workflow=_workflow(), + trace=True, + node_timeout_s=300, + history_mode="full", + retain_history=True, + execution_profile="high_frequency", + run_id="exec-1", + ensure_requirements=False, + ) + + assert result.status == "SUCCEEDED" + assert captured_init["trace"] is False + assert captured_init["node_timeout_s"] == 300 + assert captured_init["history_mode"] == "summary" + assert isinstance(captured_init["execution_plan"], WorkflowExecutionPlan) + assert captured_run["retain_history"] is False + assert captured_run["run_id"] == "exec-1" diff --git a/tests/workflow/test_workflow_fixes.py b/tests/workflow/test_workflow_fixes.py index 5b0cddc98..c516a8e2b 100644 --- a/tests/workflow/test_workflow_fixes.py +++ b/tests/workflow/test_workflow_fixes.py @@ -7,17 +7,24 @@ from __future__ import annotations import json +from pathlib import Path from typing import Any, Dict from unittest.mock import MagicMock, patch import pytest +from flocks.tool import Tool, ToolCategory, ToolContext, ToolInfo, ToolRegistry, ToolResult +from flocks.workflow.errors import WorkflowValidationError from flocks.workflow.models import Workflow from flocks.workflow.engine import WorkflowEngine +from flocks.workflow import fs_store from flocks.workflow.repl_runtime import PythonExecRuntime +from flocks.workflow.runner import run_workflow +from flocks.workflow import tools_adapter as tools_adapter_module from flocks.workflow.tools import ToolFacade from flocks.workflow.tools_adapter import FlocksToolAdapter from flocks.workflow.workflow_lint import ( + lint_implicit_full_payload_edges, lint_expensive_node_multi_trigger, lint_join_requirements, lint_workflow, @@ -28,6 +35,7 @@ # Helper: mock adapter that returns controllable outputs # --------------------------------------------------------------------------- + class _MockToolAdapter(FlocksToolAdapter): """FlocksToolAdapter subclass that bypasses real tool registry.""" @@ -42,11 +50,60 @@ def run(self, name: str, /, **kwargs: Any) -> Any: val = self._outputs[name] if isinstance(val, Exception): from flocks.workflow.errors import NodeExecutionError + raise NodeExecutionError(node_id="", message=str(val)) return val return f"mock_output_for_{name}" +def test_subworkflow_loader_reads_filesystem_workflow_without_legacy_kv( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +): + workflow_id = "child-filesystem-workflow" + workflow_dir = tmp_path / ".flocks" / "plugins" / "workflows" / workflow_id + workflow_dir.mkdir(parents=True) + (workflow_dir / "workflow.json").write_text( + json.dumps( + { + "name": "Child Filesystem Workflow", + "start": "child_node", + "nodes": [ + { + "id": "child_node", + "type": "python", + "code": "outputs['child_result'] = inputs.get('value', 'missing')", + } + ], + "edges": [], + } + ), + encoding="utf-8", + ) + monkeypatch.chdir(tmp_path) + monkeypatch.setattr(fs_store, "_workspace_root", None) + + result = run_workflow( + workflow={ + "id": "parent-filesystem-workflow", + "start": "call_child", + "nodes": [ + { + "id": "call_child", + "type": "subworkflow", + "workflow_id": workflow_id, + } + ], + "edges": [], + }, + inputs={"value": "ok"}, + ensure_requirements=False, + ) + + assert result.status == "SUCCEEDED" + assert result.outputs == {"output": {"child_result": "ok"}} + + # =================================================================== # 方案 1: run_safe() tests # =================================================================== @@ -64,9 +121,7 @@ def test_run_safe_str_output(self): assert result["error"] is None def test_run_safe_dict_output(self): - adapter = _MockToolAdapter(outputs={ - "memory_search": {"results": [{"id": 1}], "count": 1} - }) + adapter = _MockToolAdapter(outputs={"memory_search": {"results": [{"id": 1}], "count": 1}}) result = adapter.run_safe("memory_search", query="test") assert result["success"] is True assert isinstance(result["text"], str) @@ -82,9 +137,7 @@ def test_run_safe_none_output(self): assert result["obj"] is None def test_run_safe_error(self): - adapter = _MockToolAdapter(outputs={ - "bad_tool": Exception("connection timeout") - }) + adapter = _MockToolAdapter(outputs={"bad_tool": Exception("connection timeout")}) result = adapter.run_safe("bad_tool") assert result["success"] is False assert result["text"] == "" @@ -103,9 +156,7 @@ def test_run_safe_unknown_tool(self): def test_run_safe_explicit_error(self): """Explicit error entry in outputs -> run_safe catches and wraps.""" - adapter = _MockToolAdapter(outputs={ - "failing": Exception("service unavailable") - }) + adapter = _MockToolAdapter(outputs={"failing": Exception("service unavailable")}) result = adapter.run_safe("failing") assert result["success"] is False assert result["obj"] is None @@ -134,92 +185,125 @@ def test_run_safe_list_output(self): class TestLintJoinRequirements: """Tests for lint_join_requirements().""" - def test_multi_incoming_no_join_error(self): - """Node with 2+ non-exclusive incoming edges and no join -> error.""" - wf = Workflow.from_dict({ - "name": "bad_join", - "start": "a", - "nodes": [ - {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, - {"id": "b", "type": "python", "code": "outputs['y'] = 2"}, - {"id": "c", "type": "python", "code": "outputs['z'] = inputs.get('x', 0)"}, - ], - "edges": [ - {"from": "a", "to": "c"}, - {"from": "b", "to": "c"}, - ], - }) + def test_multi_incoming_no_join_warning(self): + """Node with 2+ non-exclusive incoming edges and no join -> warning.""" + wf = Workflow.from_dict( + { + "name": "bad_join", + "start": "a", + "nodes": [ + {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, + {"id": "b", "type": "python", "code": "outputs['y'] = 2"}, + {"id": "c", "type": "python", "code": "outputs['z'] = inputs.get('x', 0)"}, + ], + "edges": [ + {"from": "a", "to": "c"}, + {"from": "b", "to": "c"}, + ], + } + ) results = lint_join_requirements(wf) assert len(results) == 1 assert results[0]["kind"] == "multi_incoming_no_join" - assert results[0]["severity"] == "error" + assert results[0]["severity"] == "warning" assert results[0]["node_id"] == "c" + def test_loop_node_multi_incoming_no_join_ok(self): + """Loop nodes naturally have an initial edge and a back edge.""" + wf = Workflow.from_dict( + { + "name": "loop_back_edge", + "start": "init", + "nodes": [ + {"id": "init", "type": "python", "code": "outputs['should_continue'] = True"}, + {"id": "loop_check", "type": "loop", "select_key": "should_continue"}, + {"id": "body", "type": "python", "code": "outputs['should_continue'] = False"}, + {"id": "done", "type": "python", "code": "pass"}, + ], + "edges": [ + {"from": "init", "to": "loop_check"}, + {"from": "loop_check", "to": "body", "label": "continue"}, + {"from": "body", "to": "loop_check"}, + {"from": "loop_check", "to": "done", "label": "exit"}, + ], + } + ) + + results = lint_join_requirements(wf) + assert results == [] + def test_exclusive_branch_no_error(self): """Edges from same branch with different labels are exclusive -> no error.""" - wf = Workflow.from_dict({ - "name": "ok_branch", - "start": "start", - "nodes": [ - {"id": "start", "type": "python", "code": "outputs['flag'] = True"}, - {"id": "br", "type": "branch", "select_key": "flag"}, - {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, - {"id": "b", "type": "python", "code": "outputs['x'] = 2"}, - {"id": "merge", "type": "python", "code": "outputs['r'] = inputs.get('x')"}, - ], - "edges": [ - {"from": "start", "to": "br"}, - {"from": "br", "to": "a", "label": "true"}, - {"from": "br", "to": "b", "label": "false"}, - {"from": "a", "to": "merge"}, - {"from": "b", "to": "merge"}, - ], - }) + wf = Workflow.from_dict( + { + "name": "ok_branch", + "start": "start", + "nodes": [ + {"id": "start", "type": "python", "code": "outputs['flag'] = True"}, + {"id": "br", "type": "branch", "select_key": "flag"}, + {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, + {"id": "b", "type": "python", "code": "outputs['x'] = 2"}, + {"id": "merge", "type": "python", "code": "outputs['r'] = inputs.get('x')"}, + ], + "edges": [ + {"from": "start", "to": "br"}, + {"from": "br", "to": "a", "label": "true"}, + {"from": "br", "to": "b", "label": "false"}, + {"from": "a", "to": "merge"}, + {"from": "b", "to": "merge"}, + ], + } + ) results = lint_join_requirements(wf) assert len(results) == 0 def test_join_true_no_error(self): """Node with join=true should not trigger the lint.""" - wf = Workflow.from_dict({ - "name": "ok_join", - "start": "a", - "nodes": [ - {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, - {"id": "b", "type": "python", "code": "outputs['y'] = 2"}, - {"id": "c", "type": "python", "code": "pass", "join": True}, - ], - "edges": [ - {"from": "a", "to": "c"}, - {"from": "b", "to": "c"}, - ], - }) + wf = Workflow.from_dict( + { + "name": "ok_join", + "start": "a", + "nodes": [ + {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, + {"id": "b", "type": "python", "code": "outputs['y'] = 2"}, + {"id": "c", "type": "python", "code": "pass", "join": True}, + ], + "edges": [ + {"from": "a", "to": "c"}, + {"from": "b", "to": "c"}, + ], + } + ) results = lint_join_requirements(wf) assert len(results) == 0 def test_mixed_exclusive_and_non_exclusive(self): """Branch targets + extra direct edge -> error (not fully exclusive).""" - wf = Workflow.from_dict({ - "name": "mixed", - "start": "start", - "nodes": [ - {"id": "start", "type": "python", "code": "outputs['flag'] = True"}, - {"id": "br", "type": "branch", "select_key": "flag"}, - {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, - {"id": "b", "type": "python", "code": "outputs['x'] = 2"}, - {"id": "merge", "type": "python", "code": "pass"}, - ], - "edges": [ - {"from": "start", "to": "br"}, - {"from": "br", "to": "a", "label": "true"}, - {"from": "br", "to": "b", "label": "false"}, - {"from": "a", "to": "merge"}, - {"from": "b", "to": "merge"}, - {"from": "start", "to": "merge"}, # extra non-exclusive edge - ], - }) + wf = Workflow.from_dict( + { + "name": "mixed", + "start": "start", + "nodes": [ + {"id": "start", "type": "python", "code": "outputs['flag'] = True"}, + {"id": "br", "type": "branch", "select_key": "flag"}, + {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, + {"id": "b", "type": "python", "code": "outputs['x'] = 2"}, + {"id": "merge", "type": "python", "code": "pass"}, + ], + "edges": [ + {"from": "start", "to": "br"}, + {"from": "br", "to": "a", "label": "true"}, + {"from": "br", "to": "b", "label": "false"}, + {"from": "a", "to": "merge"}, + {"from": "b", "to": "merge"}, + {"from": "start", "to": "merge"}, # extra non-exclusive edge + ], + } + ) results = lint_join_requirements(wf) assert len(results) == 1 assert results[0]["node_id"] == "merge" + assert results[0]["severity"] == "warning" class TestLintExpensiveNodeMultiTrigger: @@ -227,23 +311,25 @@ class TestLintExpensiveNodeMultiTrigger: def test_expensive_node_multi_incoming_error(self): """Expensive node (LLM call) with multiple non-exclusive edges -> error.""" - wf = Workflow.from_dict({ - "name": "expensive_bad", - "start": "a", - "nodes": [ - {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, - {"id": "b", "type": "python", "code": "outputs['y'] = 2"}, - { - "id": "expensive", - "type": "python", - "code": "result = llm.ask('summarize')\noutputs['summary'] = result", - }, - ], - "edges": [ - {"from": "a", "to": "expensive"}, - {"from": "b", "to": "expensive"}, - ], - }) + wf = Workflow.from_dict( + { + "name": "expensive_bad", + "start": "a", + "nodes": [ + {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, + {"id": "b", "type": "python", "code": "outputs['y'] = 2"}, + { + "id": "expensive", + "type": "python", + "code": "result = llm.ask('summarize')\noutputs['summary'] = result", + }, + ], + "edges": [ + {"from": "a", "to": "expensive"}, + {"from": "b", "to": "expensive"}, + ], + } + ) results = lint_expensive_node_multi_trigger(wf) assert len(results) == 1 assert results[0]["kind"] == "expensive_node_multi_trigger" @@ -251,41 +337,45 @@ def test_expensive_node_multi_incoming_error(self): def test_non_expensive_node_no_error(self): """Non-expensive node with multiple edges -> no error from this check.""" - wf = Workflow.from_dict({ - "name": "cheap_ok", - "start": "a", - "nodes": [ - {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, - {"id": "b", "type": "python", "code": "outputs['y'] = 2"}, - {"id": "c", "type": "python", "code": "outputs['z'] = 3"}, - ], - "edges": [ - {"from": "a", "to": "c"}, - {"from": "b", "to": "c"}, - ], - }) + wf = Workflow.from_dict( + { + "name": "cheap_ok", + "start": "a", + "nodes": [ + {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, + {"id": "b", "type": "python", "code": "outputs['y'] = 2"}, + {"id": "c", "type": "python", "code": "outputs['z'] = 3"}, + ], + "edges": [ + {"from": "a", "to": "c"}, + {"from": "b", "to": "c"}, + ], + } + ) results = lint_expensive_node_multi_trigger(wf) assert len(results) == 0 def test_write_tool_detected_as_expensive(self): """Node calling tool.run('write', ...) is detected as expensive.""" - wf = Workflow.from_dict({ - "name": "write_bad", - "start": "a", - "nodes": [ - {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, - {"id": "b", "type": "python", "code": "outputs['y'] = 2"}, - { - "id": "writer", - "type": "python", - "code": "tool.run('write', filePath='out.md', content='hi')", - }, - ], - "edges": [ - {"from": "a", "to": "writer"}, - {"from": "b", "to": "writer"}, - ], - }) + wf = Workflow.from_dict( + { + "name": "write_bad", + "start": "a", + "nodes": [ + {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, + {"id": "b", "type": "python", "code": "outputs['y'] = 2"}, + { + "id": "writer", + "type": "python", + "code": "tool.run('write', filePath='out.md', content='hi')", + }, + ], + "edges": [ + {"from": "a", "to": "writer"}, + {"from": "b", "to": "writer"}, + ], + } + ) results = lint_expensive_node_multi_trigger(wf) assert len(results) == 1 assert results[0]["node_id"] == "writer" @@ -296,36 +386,482 @@ class TestLintWorkflowUnified: def test_lint_workflow_combines_all_checks(self): """lint_workflow() should return results from all check functions.""" - wf = Workflow.from_dict({ - "name": "combined", - "start": "a", - "nodes": [ - {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, - {"id": "b", "type": "python", "code": "outputs['y'] = 2"}, - {"id": "c", "type": "python", "code": "pass"}, - ], - "edges": [ - {"from": "a", "to": "c"}, - {"from": "b", "to": "c"}, - ], - }) + wf = Workflow.from_dict( + { + "name": "combined", + "start": "a", + "nodes": [ + {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, + {"id": "b", "type": "python", "code": "outputs['y'] = 2"}, + {"id": "c", "type": "python", "code": "pass"}, + ], + "edges": [ + {"from": "a", "to": "c"}, + {"from": "b", "to": "c"}, + ], + } + ) results = lint_workflow(wf) kinds = {r["kind"] for r in results} assert "multi_incoming_no_join" in kinds def test_lint_workflow_clean(self): """A well-formed workflow should produce no lint results.""" - wf = Workflow.from_dict({ - "name": "clean", + wf = Workflow.from_dict( + { + "name": "clean", + "start": "a", + "nodes": [ + {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, + {"id": "b", "type": "python", "code": "outputs['y'] = inputs.get('mapped_x')"}, + ], + "edges": [{"from": "a", "to": "b", "mapping": {"mapped_x": "x"}}], + } + ) + results = lint_workflow(wf) + assert len(results) == 0 + + def test_schema_lint_accepts_declared_mapping(self): + wf = Workflow.from_dict( + { + "name": "schema_clean", + "start": "a", + "nodes": [ + { + "id": "a", + "type": "python", + "code": "outputs['summary'] = 'ok'", + "outputSchema": {"summary": {"type": "str"}}, + }, + { + "id": "b", + "type": "python", + "code": "outputs['y'] = inputs['text']", + "inputSchema": {"text": {"type": "str", "required": True}}, + }, + ], + "edges": [{"from": "a", "to": "b", "mapping": {"text": "summary"}}], + } + ) + + results = lint_workflow(wf) + + assert not [item for item in results if str(item.get("kind", "")).startswith("schema_")] + + def test_schema_lint_rejects_unknown_source_key(self): + wf = Workflow.from_dict( + { + "name": "schema_bad_src", + "start": "a", + "nodes": [ + { + "id": "a", + "type": "python", + "code": "outputs['summary'] = 'ok'", + "outputSchema": {"summary": {"type": "str"}}, + }, + {"id": "b", "type": "python", "code": "outputs['y'] = inputs.get('text')"}, + ], + "edges": [{"from": "a", "to": "b", "mapping": {"text": "missing"}}], + } + ) + + results = lint_workflow(wf) + + assert any(item["kind"] == "schema_mapping_src_not_declared" for item in results) + + def test_schema_lint_rejects_type_mismatch(self): + wf = Workflow.from_dict( + { + "name": "schema_type_mismatch", + "start": "a", + "nodes": [ + { + "id": "a", + "type": "python", + "code": "outputs['count'] = 1", + "outputSchema": {"count": {"type": "int"}}, + }, + { + "id": "b", + "type": "python", + "code": "outputs['y'] = inputs.get('text')", + "inputSchema": {"text": {"type": "str", "required": True}}, + }, + ], + "edges": [{"from": "a", "to": "b", "mapping": {"text": "count"}}], + } + ) + + results = lint_workflow(wf) + + assert any(item["kind"] == "schema_mapping_type_mismatch" for item in results) + + def test_schema_lint_rejects_large_output_to_regular_input(self): + wf = Workflow.from_dict( + { + "name": "schema_large_payload", + "start": "a", + "nodes": [ + { + "id": "a", + "type": "python", + "code": "outputs['raw_events'] = []", + "outputSchema": {"raw_events": {"type": "list", "large": True}}, + }, + { + "id": "b", + "type": "python", + "code": "outputs['y'] = len(inputs.get('events', []))", + "inputSchema": {"events": {"type": "list", "required": True}}, + }, + ], + "edges": [{"from": "a", "to": "b", "mapping": {"events": "raw_events"}}], + } + ) + + results = lint_workflow(wf) + + assert any(item["kind"] == "schema_mapping_large_payload" for item in results) + + def test_run_workflow_rejects_schema_lint_errors(self): + workflow = { + "name": "schema_runtime_error", + "start": "a", + "nodes": [ + { + "id": "a", + "type": "python", + "code": "outputs['count'] = 1", + "outputSchema": {"count": {"type": "int"}}, + }, + { + "id": "b", + "type": "python", + "code": "outputs['y'] = inputs.get('text')", + "inputSchema": {"text": {"type": "str", "required": True}}, + }, + ], + "edges": [{"from": "a", "to": "b", "mapping": {"text": "count"}}], + } + + with pytest.raises(WorkflowValidationError, match="Workflow schema lint failed"): + run_workflow(workflow=workflow, ensure_requirements=False) + + def test_missing_edge_mapping_warns_by_default(self): + wf = Workflow.from_dict( + { + "name": "implicit_mapping", + "start": "a", + "nodes": [ + {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, + {"id": "b", "type": "python", "code": "outputs['y'] = inputs.get('x')"}, + ], + "edges": [{"from": "a", "to": "b"}], + } + ) + + results = lint_implicit_full_payload_edges(wf) + + assert len(results) == 1 + assert results[0]["kind"] == "implicit_full_payload_edge" + assert results[0]["severity"] == "warning" + + def test_missing_edge_mapping_is_error_when_strict(self): + wf = Workflow.from_dict( + { + "name": "strict_implicit_mapping", + "start": "a", + "nodes": [ + {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, + {"id": "b", "type": "python", "code": "outputs['y'] = inputs.get('x')"}, + ], + "edges": [{"from": "a", "to": "b"}], + "metadata": {"runtime": {"strict_edge_mapping": True}}, + } + ) + + results = lint_implicit_full_payload_edges(wf) + + assert len(results) == 1 + assert results[0]["kind"] == "implicit_full_payload_edge" + assert results[0]["severity"] == "error" + + def test_run_workflow_rejects_strict_implicit_mapping(self): + workflow = { + "name": "strict_implicit_mapping", "start": "a", "nodes": [ {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, {"id": "b", "type": "python", "code": "outputs['y'] = inputs.get('x')"}, ], "edges": [{"from": "a", "to": "b"}], - }) + "metadata": {"runtime": {"strict_edge_mapping": True}}, + } + + with pytest.raises(WorkflowValidationError): + run_workflow(workflow=workflow, ensure_requirements=False) + + def test_identity_mapping_does_not_suggest_omitting_mapping(self): + wf = Workflow.from_dict( + { + "name": "identity_mapping_ok", + "start": "a", + "nodes": [ + {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, + {"id": "b", "type": "python", "code": "outputs['y'] = inputs.get('x')"}, + ], + "edges": [{"from": "a", "to": "b", "mapping": {"x": "x"}}], + } + ) + results = lint_workflow(wf) - assert len(results) == 0 + + assert all(item.get("kind") != "scheme_a_suggest_omit_identity_mapping" for item in results) + + def test_trigger_workflow_recommends_strict_edge_mapping(self): + wf = Workflow.from_dict( + { + "name": "kafka_trigger_mapping_recommendation", + "start": "a", + "nodes": [{"id": "a", "type": "python", "code": "outputs['x'] = 1"}], + "edges": [], + "triggers": [{"type": "kafka"}], + } + ) + + results = lint_workflow(wf) + + recommendation = next(item for item in results if item["kind"] == "recommend_strict_edge_mapping") + assert recommendation["severity"] == "warning" + assert recommendation["trigger_types"] == ["kafka"] + + def test_trigger_workflow_with_strict_mapping_does_not_recommend_again(self): + wf = Workflow.from_dict( + { + "name": "strict_kafka_trigger_mapping", + "start": "a", + "nodes": [{"id": "a", "type": "python", "code": "outputs['x'] = 1"}], + "edges": [], + "triggers": [{"type": "kafka"}], + "metadata": {"runtime": {"strict_edge_mapping": True}}, + } + ) + + results = lint_workflow(wf) + + assert all(item.get("kind") != "recommend_strict_edge_mapping" for item in results) + + +class TestPayloadRiskRemoved: + def test_large_payload_no_mapping_preserves_inputs_without_payload_risk_result(self): + wf = Workflow.from_dict( + { + "name": "large_payload_no_mapping", + "start": "a", + "nodes": [ + {"id": "a", "type": "python", "code": "outputs['raw_alerts'] = list(range(1500))"}, + { + "id": "b", + "type": "python", + "code": "outputs['count'] = len(inputs['raw_alerts'])\noutputs['is_list'] = isinstance(inputs['raw_alerts'], list)", + }, + ], + "edges": [{"from": "a", "to": "b"}], + } + ) + + result = WorkflowEngine(wf, runtime=PythonExecRuntime()).run() + + assert result.outputs == {"count": 1500, "is_list": True} + assert not hasattr(result, "payload_risk_summary") + + def test_large_payload_with_mapping_preserves_outputs_without_payload_risk_result(self): + wf = Workflow.from_dict( + { + "name": "large_payload_with_mapping", + "start": "a", + "nodes": [ + {"id": "a", "type": "python", "code": "outputs['raw_alerts'] = list(range(1500))"}, + {"id": "b", "type": "python", "code": "outputs['count'] = len(inputs['alerts'])"}, + ], + "edges": [{"from": "a", "to": "b", "mapping": {"alerts": "raw_alerts"}}], + } + ) + + result = WorkflowEngine(wf, runtime=PythonExecRuntime()).run() + + assert result.outputs == {"count": 1500} + assert not hasattr(result, "payload_risk_summary") + + def test_large_payload_join_buffer_preserves_outputs_without_payload_risk_result(self): + wf = Workflow.from_dict( + { + "name": "large_payload_join", + "start": "start", + "nodes": [ + {"id": "start", "type": "python", "code": "outputs['raw_alerts'] = list(range(1500))"}, + {"id": "b", "type": "python", "code": "outputs['x'] = 1"}, + { + "id": "join", + "type": "python", + "join": True, + "code": "outputs['count'] = len(inputs.get('raw_alerts', []))", + }, + ], + "edges": [ + {"from": "start", "to": "join"}, + {"from": "start", "to": "b"}, + {"from": "b", "to": "join"}, + ], + } + ) + + result = WorkflowEngine(wf, runtime=PythonExecRuntime()).run() + + assert result.outputs == {"count": 1500} + assert not hasattr(result, "payload_risk_summary") + + +class TestVertexCacheDataflow: + def test_legacy_mode_preserves_mapped_fanout_outputs(self): + wf = Workflow.from_dict( + { + "name": "legacy_large_source_small_mapped_fanout", + "start": "a", + "nodes": [ + { + "id": "a", + "type": "python", + "code": "outputs['events'] = list(range(1500))\noutputs['count'] = len(outputs['events'])", + }, + {"id": "b", "type": "python", "code": "outputs['b_count'] = inputs['count']"}, + {"id": "c", "type": "python", "code": "outputs['c_count'] = inputs['count']"}, + ], + "edges": [ + {"from": "a", "to": "b", "mapping": {"count": "count"}}, + {"from": "a", "to": "c", "mapping": {"count": "count"}}, + ], + } + ) + + result = WorkflowEngine(wf, runtime=PythonExecRuntime(), dataflow_mode="legacy").run() + + assert result.outputs == {"c_count": 1500} + assert not hasattr(result, "payload_risk_summary") + + def test_vertex_cache_mode_uses_resolved_edge_payload(self): + wf = Workflow.from_dict( + { + "name": "vertex_cache_small_mapped_fanout", + "start": "a", + "nodes": [ + { + "id": "a", + "type": "python", + "code": "outputs['events'] = list(range(1500))\noutputs['count'] = len(outputs['events'])", + }, + {"id": "b", "type": "python", "code": "outputs['b_count'] = inputs['count']"}, + {"id": "c", "type": "python", "code": "outputs['c_count'] = inputs['count']"}, + ], + "edges": [ + {"from": "a", "to": "b", "mapping": {"count": "count"}}, + {"from": "a", "to": "c", "mapping": {"count": "count"}}, + ], + } + ) + + result = WorkflowEngine(wf, runtime=PythonExecRuntime(), dataflow_mode="vertex_cache").run() + + assert result.outputs == {"c_count": 1500} + assert not hasattr(result, "payload_risk_summary") + + def test_vertex_cache_no_mapping_preserves_legacy_shape(self): + wf = Workflow.from_dict( + { + "name": "vertex_cache_no_mapping_fallback", + "start": "a", + "nodes": [ + {"id": "a", "type": "python", "code": "outputs['events'] = list(range(1500))"}, + {"id": "b", "type": "python", "code": "outputs['count'] = len(inputs['events'])"}, + ], + "edges": [{"from": "a", "to": "b"}], + } + ) + + result = WorkflowEngine(wf, runtime=PythonExecRuntime(), dataflow_mode="vertex_cache").run() + + assert result.outputs == {"count": 1500} + assert not hasattr(result, "payload_risk_summary") + + def test_vertex_cache_root_mapping_preserves_legacy_merged_payload(self): + wf = Workflow.from_dict( + { + "name": "vertex_cache_root_mapping", + "start": "a", + "nodes": [ + { + "id": "a", + "type": "python", + "code": ( + "outputs['count'] = len(inputs['events'])\n" + "outputs['status'] = 'processed'\n" + "outputs['shared'] = 'output'" + ), + }, + { + "id": "b", + "type": "python", + "code": ( + "full = inputs['full_data']\n" + "outputs['has_events'] = 'events' in full\n" + "outputs['count'] = full['count']\n" + "outputs['status'] = full['status']\n" + "outputs['shared'] = full['shared']" + ), + }, + ], + "edges": [{"from": "a", "to": "b", "mapping": {"full_data": "$"}}], + } + ) + + result = WorkflowEngine(wf, runtime=PythonExecRuntime(), dataflow_mode="vertex_cache").run( + initial_inputs={"events": [1, 2, 3], "status": "raw", "shared": "input"} + ) + + assert result.outputs == { + "has_events": True, + "count": 3, + "status": "processed", + "shared": "output", + } + assert not hasattr(result, "payload_risk_summary") + + def test_run_workflow_uses_vertex_cache_dataflow_metadata(self): + workflow = { + "name": "metadata_vertex_cache_small_fanout", + "start": "a", + "metadata": {"runtime": {"strict_edge_mapping": True, "dataflow_mode": "vertex_cache"}}, + "nodes": [ + { + "id": "a", + "type": "python", + "code": "outputs['events'] = list(range(1500))\noutputs['count'] = len(outputs['events'])", + }, + {"id": "b", "type": "python", "code": "outputs['b_count'] = inputs['count']"}, + {"id": "c", "type": "python", "code": "outputs['c_count'] = inputs['count']"}, + ], + "edges": [ + {"from": "a", "to": "b", "mapping": {"count": "count"}}, + {"from": "a", "to": "c", "mapping": {"count": "count"}}, + ], + } + + result = run_workflow(workflow=workflow, ensure_requirements=False) + + assert result.status == "SUCCEEDED" + assert result.outputs == {"c_count": 1500} + assert not hasattr(result, "payload_risk_summary") # =================================================================== @@ -338,21 +874,23 @@ class TestEngineDedup: def test_dedup_different_inputs_both_execute(self): """When a node receives different inputs from two sources, both execute (no dedup).""" - wf = Workflow.from_dict({ - "name": "dedup_test", - "start": "a", - "nodes": [ - # a fans out to b and c (python node sends to all outgoing edges) - {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, - {"id": "b", "type": "python", "code": "outputs['x'] = 2"}, - {"id": "d", "type": "python", "code": "outputs['result'] = inputs.get('x', 0)"}, - ], - "edges": [ - {"from": "a", "to": "b"}, - {"from": "a", "to": "d"}, - {"from": "b", "to": "d"}, - ], - }) + wf = Workflow.from_dict( + { + "name": "dedup_test", + "start": "a", + "nodes": [ + # a fans out to b and c (python node sends to all outgoing edges) + {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, + {"id": "b", "type": "python", "code": "outputs['x'] = 2"}, + {"id": "d", "type": "python", "code": "outputs['result'] = inputs.get('x', 0)"}, + ], + "edges": [ + {"from": "a", "to": "b"}, + {"from": "a", "to": "d"}, + {"from": "b", "to": "d"}, + ], + } + ) engine = WorkflowEngine( wf, runtime=PythonExecRuntime(), @@ -370,19 +908,21 @@ def test_dedup_different_inputs_both_execute(self): def test_dedup_skips_truly_identical_inputs(self): """Identical inputs to the same node -> second execution is skipped.""" - wf = Workflow.from_dict({ - "name": "dedup_identical", - "start": "a", - "nodes": [ - {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, - # Two edges from a to b (rare but possible) - {"id": "b", "type": "python", "code": "outputs['y'] = inputs.get('x')"}, - ], - "edges": [ - {"from": "a", "to": "b"}, - {"from": "a", "to": "b"}, - ], - }) + wf = Workflow.from_dict( + { + "name": "dedup_identical", + "start": "a", + "nodes": [ + {"id": "a", "type": "python", "code": "outputs['x'] = 1"}, + # Two edges from a to b (rare but possible) + {"id": "b", "type": "python", "code": "outputs['y'] = inputs.get('x')"}, + ], + "edges": [ + {"from": "a", "to": "b"}, + {"from": "a", "to": "b"}, + ], + } + ) engine = WorkflowEngine( wf, runtime=PythonExecRuntime(), @@ -409,6 +949,7 @@ class TestExampleWorkflow: def test_example_workflow_loads(self): """Example workflow.json should be valid and loadable.""" import os + wf_path = os.path.join( os.path.dirname(__file__), "..", @@ -429,6 +970,7 @@ def test_example_workflow_loads(self): def test_example_workflow_no_lint_errors(self): """Example workflow should pass all lint checks (no errors).""" import os + wf_path = os.path.join( os.path.dirname(__file__), "..", @@ -450,6 +992,7 @@ def test_example_workflow_no_lint_errors(self): def test_no_exec_json_files(self): """workflow-exec.json and workflow-exec-b.json should not exist.""" import os + base = os.path.join( os.path.dirname(__file__), "..", diff --git a/tests/workflow/test_workflow_history_mode.py b/tests/workflow/test_workflow_history_mode.py index 87441c5e5..af4412ea7 100644 --- a/tests/workflow/test_workflow_history_mode.py +++ b/tests/workflow/test_workflow_history_mode.py @@ -137,4 +137,5 @@ def test_python_runtime_can_cleanup_node_globals_after_execute() -> None: assert outputs == {"ok": True} assert "temporary_payload" not in runtime.globals - assert runtime.globals["outputs"] == {"ok": True} + assert "inputs" not in runtime.globals + assert "outputs" not in runtime.globals diff --git a/tests/workflow/test_workflow_store.py b/tests/workflow/test_workflow_store.py new file mode 100644 index 000000000..bca4f3ada --- /dev/null +++ b/tests/workflow/test_workflow_store.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path + +import pytest + +from flocks.config.config import Config +from flocks.storage.storage import Storage +from flocks.workflow.store import WorkflowStore + + +def _reset_state() -> None: + Config._global_config = None + Config._cached_config = None + Storage._db_path = None + Storage._initialized = False + Storage._init_pid = None + WorkflowStore._initialized = False + WorkflowStore._conn = None + WorkflowStore._init_pid = None + WorkflowStore._db_path = None + + +@pytest.fixture(autouse=True) +async def isolated_workflow_store(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + data_dir = tmp_path / "flocks_data" + data_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setenv("FLOCKS_DATA_DIR", str(data_dir)) + _reset_state() + yield + await WorkflowStore.close() + _reset_state() + + +@pytest.mark.asyncio +async def test_workflow_store_records_execution_steps_config_and_kv() -> None: + await WorkflowStore.init() + + await WorkflowStore.upsert_execution( + { + "id": "exec-1", + "workflowId": "wf-1", + "status": "running", + "startedAt": 100, + "triggerId": "trigger-1", + "triggerType": "schedule", + } + ) + await WorkflowStore.upsert_execution( + { + "id": "exec-2", + "workflowId": "wf-1", + "status": "success", + "startedAt": 200, + } + ) + await WorkflowStore.upsert_execution( + { + "id": "exec-other", + "workflowId": "wf-other", + "status": "success", + "startedAt": 300, + } + ) + + rows = await WorkflowStore.list_executions("wf-1", limit=10) + assert [row["id"] for row in rows] == ["exec-2", "exec-1"] + filtered = await WorkflowStore.list_executions( + "wf-1", + limit=10, + trigger_id="trigger-1", + trigger_type="schedule", + ) + assert [row["id"] for row in filtered] == ["exec-1"] + + await WorkflowStore.record_step("exec-1", 1, {"node_id": "n1", "outputs": {"ok": 1}}) + await WorkflowStore.record_step("exec-1", 2, {"node_id": "n2", "outputs": {"ok": 2}}) + steps, total = await WorkflowStore.list_steps("exec-1", offset=1, limit=1) + assert total == 2 + assert steps == [{"node_id": "n2", "outputs": {"ok": 2}}] + + await WorkflowStore.put_config("wf-1", {"enabled": True}, kind="workflow_poller_config") + assert await WorkflowStore.get_config("wf-1", kind="workflow_poller_config") == {"enabled": True} + assert await WorkflowStore.list_configs(kind="workflow_poller_config") == [("wf-1", {"enabled": True})] + + await WorkflowStore.kv_put("workflow_runtime/wf-1", {"status": "active"}) + assert await WorkflowStore.kv_get("workflow_runtime/wf-1") == {"status": "active"} + assert await WorkflowStore.kv_list_keys("workflow_runtime/") == ["workflow_runtime/wf-1"] + + +@pytest.mark.asyncio +async def test_workflow_store_increment_stats_is_atomic_for_concurrent_updates() -> None: + await WorkflowStore.init() + updates = [(idx % 3 != 0, 1.0) for idx in range(60)] + + await asyncio.gather( + *( + WorkflowStore.increment_stats("wf-concurrent", success=success, duration=duration) + for success, duration in updates + ) + ) + + stats = await WorkflowStore.get_stats("wf-concurrent") + assert stats is not None + assert stats["callCount"] == 60 + assert stats["successCount"] == sum(1 for success, _ in updates if success) + assert stats["errorCount"] == sum(1 for success, _ in updates if not success) + assert stats["totalRuntime"] == pytest.approx(60.0) + assert stats["avgRuntime"] == pytest.approx(1.0) diff --git a/tests/workspace/test_workspace_routes.py b/tests/workspace/test_workspace_routes.py index 5192ec606..03f09eea3 100644 --- a/tests/workspace/test_workspace_routes.py +++ b/tests/workspace/test_workspace_routes.py @@ -8,8 +8,9 @@ ----------------- Directory: GET /tree, GET /list, POST /dir, DELETE /dir File: POST /upload, GET /file, PUT /file, DELETE /file, - GET /download, POST /download/zip, POST /move -Memory: GET /memory/list, GET /memory/file + GET /preview, GET /download, POST /download/zip, POST /move +Memory: GET /memory/list, GET /memory/file, GET /memory/preview, + GET /memory/download Stats: GET /stats """ @@ -488,6 +489,52 @@ def test_download_zip_skips_invalid_paths(self, workspace_client): assert zf.namelist() == [] +# ─── File preview ───────────────────────────────────────────────────────────── + +class TestPreview: + def test_preview_pdf_inline(self, workspace_client): + ws = _ws(workspace_client) + (ws / "outputs" / "doc.pdf").write_bytes(b"%PDF-1.4") + r = _client(workspace_client).get("/api/workspace/preview?path=outputs/doc.pdf") + assert r.status_code == 200 + assert r.content == b"%PDF-1.4" + assert r.headers["content-type"].startswith("application/pdf") + assert "inline" in r.headers.get("content-disposition", "") + + def test_preview_png_inline(self, workspace_client): + ws = _ws(workspace_client) + (ws / "outputs" / "image.png").write_bytes(b"\x89PNG\r\n\x1a\n") + r = _client(workspace_client).get("/api/workspace/preview?path=outputs/image.png") + assert r.status_code == 200 + assert r.content == b"\x89PNG\r\n\x1a\n" + assert r.headers["content-type"].startswith("image/png") + assert "inline" in r.headers.get("content-disposition", "") + + def test_preview_html_rejected(self, workspace_client): + ws = _ws(workspace_client) + (ws / "outputs" / "demo.html").write_text("") + r = _client(workspace_client).get("/api/workspace/preview?path=outputs/demo.html") + assert r.status_code == 415 + + def test_preview_svg_inline_with_security_headers(self, workspace_client): + ws = _ws(workspace_client) + (ws / "outputs" / "image.svg").write_text("") + r = _client(workspace_client).get("/api/workspace/preview?path=outputs/image.svg") + assert r.status_code == 200 + assert r.headers["content-type"].startswith("image/svg+xml") + assert "inline" in r.headers.get("content-disposition", "") + assert r.headers["x-content-type-options"] == "nosniff" + assert "script-src 'none'" in r.headers["content-security-policy"] + + def test_preview_nonexistent_returns_404(self, workspace_client): + r = _client(workspace_client).get("/api/workspace/preview?path=missing.pdf") + assert r.status_code == 404 + + def test_preview_traversal_rejected(self, workspace_client): + r = _client(workspace_client).get("/api/workspace/preview?path=../../etc/passwd") + assert r.status_code == 400 + + # ─── Move / rename ──────────────────────────────────────────────────────────── class TestMove: @@ -540,6 +587,82 @@ def test_move_traversal_rejected(self, workspace_client): assert r.status_code == 400 +# ─── Reveal in file manager ────────────────────────────────────────────────── + +class TestReveal: + def test_reveal_file_on_macos_selects_file(self, workspace_client, monkeypatch): + from flocks.server.routes import workspace as workspace_routes + + ws = _ws(workspace_client) + target = ws / "outputs" / "report.pdf" + target.write_bytes(b"%PDF") + calls = [] + monkeypatch.setattr(workspace_routes.sys, "platform", "darwin") + monkeypatch.setattr(workspace_routes.subprocess, "Popen", lambda args: calls.append(args)) + + r = _client(workspace_client).post( + "/api/workspace/reveal", + json={"path": "outputs/report.pdf"}, + ) + + assert r.status_code == 200 + assert r.json()["opened"] is True + assert r.json()["target"] == "file" + assert calls == [["open", "-R", str(target)]] + + def test_reveal_directory_on_windows_opens_directory(self, workspace_client, monkeypatch): + from flocks.server.routes import workspace as workspace_routes + + ws = _ws(workspace_client) + target = ws / "outputs" + calls = [] + monkeypatch.setattr(workspace_routes.sys, "platform", "win32") + monkeypatch.setattr(workspace_routes.subprocess, "Popen", lambda args: calls.append(args)) + + r = _client(workspace_client).post( + "/api/workspace/reveal", + json={"path": "outputs"}, + ) + + assert r.status_code == 200 + assert r.json()["target"] == "directory" + assert calls == [["explorer", str(target)]] + + def test_reveal_file_on_linux_opens_parent_directory(self, workspace_client, monkeypatch): + from flocks.server.routes import workspace as workspace_routes + + ws = _ws(workspace_client) + target = ws / "outputs" / "report.txt" + target.write_text("hello") + calls = [] + monkeypatch.setattr(workspace_routes.sys, "platform", "linux") + monkeypatch.setattr(workspace_routes.shutil, "which", lambda name: "/usr/bin/xdg-open" if name == "xdg-open" else None) + monkeypatch.setattr(workspace_routes.subprocess, "Popen", lambda args: calls.append(args)) + + r = _client(workspace_client).post( + "/api/workspace/reveal", + json={"path": "outputs/report.txt"}, + ) + + assert r.status_code == 200 + assert r.json()["mode"] == "open" + assert calls == [["/usr/bin/xdg-open", str(target.parent)]] + + def test_reveal_nonexistent_returns_404(self, workspace_client): + r = _client(workspace_client).post( + "/api/workspace/reveal", + json={"path": "missing.txt"}, + ) + assert r.status_code == 404 + + def test_reveal_traversal_rejected(self, workspace_client): + r = _client(workspace_client).post( + "/api/workspace/reveal", + json={"path": "../../etc/passwd"}, + ) + assert r.status_code == 400 + + # ─── Memory view (read-only) ───────────────────────────────────────────────── class TestMemoryView: @@ -570,6 +693,13 @@ def test_read_memory_file(self, workspace_client): assert "Key facts" in data["content"] assert data["truncated"] is False + def test_read_memory_binary_returns_400(self, workspace_client): + mem = _mem(workspace_client) + (mem / "image.png").write_bytes(b"\x89PNG\r\n\x1a\n") + r = _client(workspace_client).get("/api/workspace/memory/file?path=image.png") + assert r.status_code == 400 + assert "Binary file" in r.json()["detail"] + def test_read_large_memory_file_returns_truncated_preview(self, workspace_client, monkeypatch): monkeypatch.setenv("FLOCKS_WORKSPACE_MAX_READ_BYTES", "8") mem = _mem(workspace_client) @@ -598,6 +728,40 @@ def test_memory_nested_file(self, workspace_client): assert r.status_code == 200 assert r.json()["content"] == "daily note" + def test_preview_memory_pdf_inline(self, workspace_client): + mem = _mem(workspace_client) + (mem / "report.pdf").write_bytes(b"%PDF-1.4\n") + r = _client(workspace_client).get("/api/workspace/memory/preview?path=report.pdf") + assert r.status_code == 200 + assert r.headers["content-type"].startswith("application/pdf") + assert "inline" in r.headers.get("content-disposition", "") + + def test_preview_memory_svg_inline_with_security_headers(self, workspace_client): + mem = _mem(workspace_client) + (mem / "logo.svg").write_text("") + r = _client(workspace_client).get("/api/workspace/memory/preview?path=logo.svg") + assert r.status_code == 200 + assert r.headers["content-type"].startswith("image/svg+xml") + assert "script-src 'none'" in r.headers["content-security-policy"] + + def test_preview_memory_unsupported_returns_415(self, workspace_client): + mem = _mem(workspace_client) + (mem / "archive.zip").write_bytes(b"PK\x03\x04") + r = _client(workspace_client).get("/api/workspace/memory/preview?path=archive.zip") + assert r.status_code == 415 + + def test_download_memory_file(self, workspace_client): + mem = _mem(workspace_client) + (mem / "image.png").write_bytes(b"\x89PNG\r\n\x1a\n") + r = _client(workspace_client).get("/api/workspace/memory/download?path=image.png") + assert r.status_code == 200 + assert r.content == b"\x89PNG\r\n\x1a\n" + assert r.headers["content-type"].startswith("application/octet-stream") + + def test_memory_preview_traversal_rejected(self, workspace_client): + r = _client(workspace_client).get("/api/workspace/memory/preview?path=../../etc/passwd") + assert r.status_code == 400 + def test_memory_write_not_allowed(self, workspace_client): """Memory directory has no write endpoint — PUT /file with memory path is confined to workspace.""" # Trying to write to memory via workspace file endpoint should be rejected diff --git a/uv.lock b/uv.lock index 3be641a1f..e13b4c58e 100644 --- a/uv.lock +++ b/uv.lock @@ -537,7 +537,7 @@ wheels = [ [[package]] name = "flocks" -version = "2026.6.24" +version = "2026.7.1" source = { editable = "." } dependencies = [ { name = "aiofiles" }, diff --git a/webui/package-lock.json b/webui/package-lock.json index cb87d933b..2f6078c9a 100644 --- a/webui/package-lock.json +++ b/webui/package-lock.json @@ -16,6 +16,7 @@ "i18next": "^25.8.14", "i18next-browser-languagedetector": "^8.2.1", "lucide-react": "^0.562.0", + "pdfjs-dist": "^6.1.200", "qrcode.react": "^4.2.0", "react": "^19.2.0", "react-dom": "^19.2.0", @@ -1248,6 +1249,256 @@ "@jridgewell/sourcemap-codec": "^1.4.14" } }, + "node_modules/@napi-rs/canvas": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas/-/canvas-1.0.1.tgz", + "integrity": "sha512-mPD43G7pXbQhIGa7z4IpT/vXm1jbF8cBM1oY5UqjL8LSaTCNGhNi2Lidc/0+LwKbNiqbv/Tq0JlBRwKu+LW3iw==", + "license": "MIT", + "optional": true, + "workspaces": [ + "e2e/*" + ], + "engines": { + "node": ">= 10" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Brooooooklyn" + }, + "optionalDependencies": { + "@napi-rs/canvas-android-arm64": "1.0.1", + "@napi-rs/canvas-darwin-arm64": "1.0.1", + "@napi-rs/canvas-darwin-x64": "1.0.1", + "@napi-rs/canvas-linux-arm-gnueabihf": "1.0.1", + "@napi-rs/canvas-linux-arm64-gnu": "1.0.1", + "@napi-rs/canvas-linux-arm64-musl": "1.0.1", + "@napi-rs/canvas-linux-riscv64-gnu": "1.0.1", + "@napi-rs/canvas-linux-x64-gnu": "1.0.1", + "@napi-rs/canvas-linux-x64-musl": "1.0.1", + "@napi-rs/canvas-win32-arm64-msvc": "1.0.1", + "@napi-rs/canvas-win32-x64-msvc": "1.0.1" + } + }, + "node_modules/@napi-rs/canvas-android-arm64": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-android-arm64/-/canvas-android-arm64-1.0.1.tgz", + "integrity": "sha512-d7ZCwJsgH4QNG50C7HQeVRsRG1gRDa1UeDUb1jEcqgLuiEJp6GVbGiZkFXPlmt0dEs2QHRQCPJoOv+bOkSQR/w==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">= 10" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Brooooooklyn" + } + }, + "node_modules/@napi-rs/canvas-darwin-arm64": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-arm64/-/canvas-darwin-arm64-1.0.1.tgz", + "integrity": "sha512-ppyVSzIHsVldc3B++mdh03ed0Q0hoVR2QDG9O/wEUR0PurJKwDEEYV87uBQDpbSumJBfLEINDndsOPzQj71qEQ==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 10" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Brooooooklyn" + } + }, + "node_modules/@napi-rs/canvas-darwin-x64": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-x64/-/canvas-darwin-x64-1.0.1.tgz", + "integrity": "sha512-/BlXif9VOzf/WP32g9zxl612dO0KLvwqplBFqfRcyr3PyR5fhPrilTuJxSBq3zkwCKGKy82JsoPd2JeQI/HBlA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 10" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Brooooooklyn" + } + }, + "node_modules/@napi-rs/canvas-linux-arm-gnueabihf": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm-gnueabihf/-/canvas-linux-arm-gnueabihf-1.0.1.tgz", + "integrity": "sha512-JTGq93/Pje+iSNVjL+ggB0+pqEfu7nXvQGrHTvugz+Lp08wCDa5rjov4JeEljGDk16/inVBU9sp4N9f0/+o+9A==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Brooooooklyn" + } + }, + "node_modules/@napi-rs/canvas-linux-arm64-gnu": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-gnu/-/canvas-linux-arm64-gnu-1.0.1.tgz", + "integrity": "sha512-i74zqEh5yFmYwHkszFo+4EH4l5ATD4bSlJG21iW2j5kpqiN2b0WN9SG/xdq2O60MjZK0ZLSu3a/Z3aQAsmDQ5A==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Brooooooklyn" + } + }, + "node_modules/@napi-rs/canvas-linux-arm64-musl": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-musl/-/canvas-linux-arm64-musl-1.0.1.tgz", + "integrity": "sha512-COqBxybXcKb6gNgEhjh04rPHrpsJB+n/5+p4ySPgQWl0i+xVNYHn4rvzCtUBIFqOgY6HEJ9UaP6c4W9EMwzfpQ==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Brooooooklyn" + } + }, + "node_modules/@napi-rs/canvas-linux-riscv64-gnu": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-riscv64-gnu/-/canvas-linux-riscv64-gnu-1.0.1.tgz", + "integrity": "sha512-1vdAZGpD85lMUo7K3qtEdoIWeMc0xcpUD5PagK3fVcMSdf8dkSL5bg/KE4Rwv5NF+PYx4plrgfn0KRMOqdKtwA==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Brooooooklyn" + } + }, + "node_modules/@napi-rs/canvas-linux-x64-gnu": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-gnu/-/canvas-linux-x64-gnu-1.0.1.tgz", + "integrity": "sha512-W/iC2qJZGqKKQJ0JrNo3QkhoAy/PvzlmdYLW8Yz5/L6XgT5d7t26dnqgP2rCbL58P3CbPw7ES0Rz8OG0gn7JeQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Brooooooklyn" + } + }, + "node_modules/@napi-rs/canvas-linux-x64-musl": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-musl/-/canvas-linux-x64-musl-1.0.1.tgz", + "integrity": "sha512-a5mmIVwxF92UGUe+1c7Ap32ZRnApbRMnQC/KgYyFB0AXZShBCHVGaURq+BDkiV7jvHhVwvvAP0Q/3aWNhqgVZg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Brooooooklyn" + } + }, + "node_modules/@napi-rs/canvas-win32-arm64-msvc": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-arm64-msvc/-/canvas-win32-arm64-msvc-1.0.1.tgz", + "integrity": "sha512-rVnDhVvcXlqcHMsgnxxhZgdRkRIqVBlx+FJwSAHi4VbWWwsowvV5ldFEecEHD2+Ac/IL3fNWF/LB5CZTghNwRA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Brooooooklyn" + } + }, + "node_modules/@napi-rs/canvas-win32-x64-msvc": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-x64-msvc/-/canvas-win32-x64-msvc-1.0.1.tgz", + "integrity": "sha512-UMstkP/nZHbithgdSJv1EvYVrYhdao5B5N3szMVU2i5/b6ijMcVPXOEyrk0QXl0iPjv8Hkoow+Tap+MiOxppOQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Brooooooklyn" + } + }, "node_modules/@nodelib/fs.scandir": { "version": "2.1.5", "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", @@ -2616,13 +2867,16 @@ "license": "MIT" }, "node_modules/baseline-browser-mapping": { - "version": "2.9.19", - "resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.9.19.tgz", - "integrity": "sha512-ipDqC8FrAl/76p2SSWKSI+H9tFwm7vYqXQrItCuiVPt26Km0jS+NzSsBWAaBusvSbQcfJG+JitdMm+wZAgTYqg==", + "version": "2.10.40", + "resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.10.40.tgz", + "integrity": "sha512-BSSLZ9/Cjjv7Gtj5B68ZzXcXUg8iOf3fme+FCuh8rC/Go+Kmh8cox7M3A8dolou16s64QjLPOSdngh7GxXvkSw==", "dev": true, "license": "Apache-2.0", "bin": { - "baseline-browser-mapping": "dist/cli.js" + "baseline-browser-mapping": "dist/cli.cjs" + }, + "engines": { + "node": ">=6.0.0" } }, "node_modules/binary-extensions": { @@ -2726,9 +2980,9 @@ } }, "node_modules/caniuse-lite": { - "version": "1.0.30001768", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001768.tgz", - "integrity": "sha512-qY3aDRZC5nWPgHUgIB84WL+nySuo19wk0VJpp/XI9T34lrvkyhRvNVOFJOp2kxClQhiFBu+TaUSudf6oa3vkSA==", + "version": "1.0.30001800", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001800.tgz", + "integrity": "sha512-MMHtuAz9Ys840zAY5F4k6fV5GaivZ9sPk+nz0mY+GYVzRBnYkN0mpqkSR92oWRQ19yQWo4HvBV/FnC16AJX8MA==", "dev": true, "funding": [ { @@ -6037,6 +6291,18 @@ "dev": true, "license": "MIT" }, + "node_modules/pdfjs-dist": { + "version": "6.1.200", + "resolved": "https://registry.npmjs.org/pdfjs-dist/-/pdfjs-dist-6.1.200.tgz", + "integrity": "sha512-o8MolyzirkkLrcdsae/HEOiIcXWI7DS5zGpvqW8xTC2YUsW30rltFw2bDGvw/fskUdEMrQm2br68jzDS5BH2vw==", + "license": "Apache-2.0", + "engines": { + "node": ">=22.13.0 || >=24" + }, + "optionalDependencies": { + "@napi-rs/canvas": "^1.0.0" + } + }, "node_modules/picocolors": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", diff --git a/webui/package.json b/webui/package.json index ee2f6e884..c84357bdf 100644 --- a/webui/package.json +++ b/webui/package.json @@ -22,6 +22,7 @@ "i18next": "^25.8.14", "i18next-browser-languagedetector": "^8.2.1", "lucide-react": "^0.562.0", + "pdfjs-dist": "^6.1.200", "qrcode.react": "^4.2.0", "react": "^19.2.0", "react-dom": "^19.2.0", diff --git a/webui/public/vendor-logos/360.png b/webui/public/vendor-logos/360.png new file mode 100644 index 000000000..bda5c9aa8 Binary files /dev/null and b/webui/public/vendor-logos/360.png differ diff --git a/webui/public/vendor-logos/huaweicloud.png b/webui/public/vendor-logos/huaweicloud.png new file mode 100644 index 000000000..2e1c3568b Binary files /dev/null and b/webui/public/vendor-logos/huaweicloud.png differ diff --git a/webui/public/vendor-logos/huorong.png b/webui/public/vendor-logos/huorong.png new file mode 100644 index 000000000..6ff09f179 Binary files /dev/null and b/webui/public/vendor-logos/huorong.png differ diff --git a/webui/public/vendor-logos/nsfocus.png b/webui/public/vendor-logos/nsfocus.png new file mode 100644 index 000000000..2feaf04da Binary files /dev/null and b/webui/public/vendor-logos/nsfocus.png differ diff --git a/webui/public/vendor-logos/qianxin.png b/webui/public/vendor-logos/qianxin.png new file mode 100644 index 000000000..dee91c775 Binary files /dev/null and b/webui/public/vendor-logos/qianxin.png differ diff --git a/webui/public/vendor-logos/qingteng.png b/webui/public/vendor-logos/qingteng.png new file mode 100644 index 000000000..6e7eba2e8 Binary files /dev/null and b/webui/public/vendor-logos/qingteng.png differ diff --git a/webui/public/vendor-logos/sangfor.png b/webui/public/vendor-logos/sangfor.png new file mode 100644 index 000000000..7fd8c2bf3 Binary files /dev/null and b/webui/public/vendor-logos/sangfor.png differ diff --git a/webui/public/vendor-logos/threatbook.png b/webui/public/vendor-logos/threatbook.png new file mode 100644 index 000000000..10ecdfe97 Binary files /dev/null and b/webui/public/vendor-logos/threatbook.png differ diff --git a/webui/src/api/consoleUpgrade.ts b/webui/src/api/consoleUpgrade.ts index d58d12b3e..e211efcb1 100644 --- a/webui/src/api/consoleUpgrade.ts +++ b/webui/src/api/consoleUpgrade.ts @@ -64,6 +64,8 @@ export interface UpgradeRequestStatus { export interface ProPackageStatus { installed: boolean; + runtime_importable?: boolean | null; + install_marker_present?: boolean | null; installed_version?: string | null; flockspro_component_version?: string | null; build_id?: string | null; @@ -89,17 +91,6 @@ export const consoleUpgradeApi = { return response.data; }, - syncRevocations: async (): Promise<{ - revoked_license_ids: string[]; - imported: boolean; - synced_license_ids?: string[]; - activated_license_id?: string | null; - refreshed_license_id?: string | null; - }> => { - const response = await client.post('/api/console/licenses/sync-revocations'); - return response.data; - }, - getRequest: async (requestId: string): Promise => { const response = await client.get(`/api/console/upgrade-requests/${requestId}`); return response.data; @@ -163,4 +154,3 @@ export const consoleUpgradeApi = { }); }, }; - diff --git a/webui/src/api/device.ts b/webui/src/api/device.ts index 1944c58ce..e6b34dd4f 100644 --- a/webui/src/api/device.ts +++ b/webui/src/api/device.ts @@ -109,6 +109,7 @@ export interface DeviceTemplate { vendor?: string | null; description?: string | null; description_cn?: string | null; + docs_url?: string | null; credential_schema: APIServiceCredentialField[]; tool_count: number; installed: boolean; diff --git a/webui/src/api/index.ts b/webui/src/api/index.ts index 9e807ec2e..bbf46d8c0 100644 --- a/webui/src/api/index.ts +++ b/webui/src/api/index.ts @@ -9,3 +9,4 @@ export * from './tool'; export * from './provider'; export * from './mcp'; export * from './hub'; +export * from './webuiContractPages'; diff --git a/webui/src/api/update.ts b/webui/src/api/update.ts index 737b7254b..dd2f84ea2 100644 --- a/webui/src/api/update.ts +++ b/webui/src/api/update.ts @@ -12,6 +12,12 @@ export type UpdateEdition = 'flocks' | 'flockspro'; export interface VersionInfo { current_version: string; latest_version: string | null; + current_core_version?: string | null; + latest_core_version?: string | null; + current_bundle_version?: string | null; + latest_bundle_version?: string | null; + current_pro_component_version?: string | null; + latest_pro_component_version?: string | null; edition?: 'flocks' | 'flockspro'; has_update: boolean; release_notes: string | null; diff --git a/webui/src/api/userDefinedPages.ts b/webui/src/api/userDefinedPages.ts deleted file mode 100644 index 621aec313..000000000 --- a/webui/src/api/userDefinedPages.ts +++ /dev/null @@ -1,73 +0,0 @@ -import client from './client'; - -export interface UserDefinedPageListItem { - id: string; - title: string; - route: string; - icon: string; - order: number; - enabled: boolean; - placement: string; - buildHash: string; - buildStatus: 'idle' | 'building' | 'ready' | 'failed'; -} - -export interface UserDefinedPageManifest { - id: string; - title: string; - route: string; - icon: string; - order: number; - enabled: boolean; - placement: string; - entry: string; - updatedAt: number; -} - -export interface UserDefinedPageBuildMeta { - hash: string; - builtAt: number; - status: 'idle' | 'building' | 'ready' | 'failed'; - error?: string | null; -} - -export interface UserDefinedPageDetail { - manifest: UserDefinedPageManifest; - build: UserDefinedPageBuildMeta; - sourceFiles: string[]; -} - -export interface UserDefinedPageCreateRequest { - id: string; - title: string; - icon?: string; - order?: number; -} - -export interface UserDefinedPageSaveRequest { - manifest?: Partial; - sourcePath?: string; - sourceContent?: string; -} - -export const userDefinedPagesAPI = { - list: (enabledOnly = false) => - client.get('/api/user-defined-pages', { - params: enabledOnly ? { enabledOnly: true } : undefined, - }), - - create: (payload: UserDefinedPageCreateRequest) => - client.post('/api/user-defined-pages', payload), - - get: (pageId: string) => - client.get(`/api/user-defined-pages/${pageId}`), - - save: (pageId: string, payload: UserDefinedPageSaveRequest) => - client.put<{ manifest: UserDefinedPageManifest; build: UserDefinedPageBuildMeta }>( - `/api/user-defined-pages/${pageId}`, - payload, - ), - - build: (pageId: string) => - client.post(`/api/user-defined-pages/${pageId}/build`), -}; diff --git a/webui/src/api/webuiContractPages.ts b/webui/src/api/webuiContractPages.ts new file mode 100644 index 000000000..3f53e8737 --- /dev/null +++ b/webui/src/api/webuiContractPages.ts @@ -0,0 +1,103 @@ +import client from './client'; + +export interface WebUIContractPageListItem { + id: string; + title: string; + route: string; + icon: string; + order: number; + enabled: boolean; + placement: string; + buildHash: string; + buildStatus: 'idle' | 'building' | 'ready' | 'failed'; + workspaceId?: string | null; + workspaceTitle?: string | null; + workspaceRoute?: string | null; +} + +export interface WebUIContractWorkspaceSection { + id: string; + label: string; + pageIds: string[]; + defaultPageId?: string | null; + contentPadding?: 'comfortable' | 'none'; + themeOverride?: 'light' | 'dark' | null; +} + +export interface WebUIContractWorkspaceListItem { + id: string; + title: string; + route: string; + icon: string; + order: number; + enabled: boolean; + placement: 'sceneWorkspace' | 'aiWorkbench'; + defaultPageId?: string | null; + sections?: WebUIContractWorkspaceSection[]; + pages: WebUIContractPageListItem[]; +} + +export interface WebUIContractPageManifest { + id: string; + title: string; + route: string; + icon: string; + order: number; + enabled: boolean; + placement: string; + entry: string; + updatedAt: number; +} + +export interface WebUIContractPageBuildMeta { + hash: string; + builtAt: number; + status: 'idle' | 'building' | 'ready' | 'failed'; + error?: string | null; +} + +export interface WebUIContractPageDetail { + manifest: WebUIContractPageManifest; + build: WebUIContractPageBuildMeta; + sourceFiles: string[]; +} + +export interface WebUIContractPageCreateRequest { + id: string; + title: string; + icon?: string; + order?: number; +} + +export interface WebUIContractPageSaveRequest { + manifest?: Partial; + sourcePath?: string; + sourceContent?: string; +} + +export const webuiContractPagesAPI = { + list: (enabledOnly = false) => + client.get('/api/contracts/webui/pages', { + params: enabledOnly ? { enabledOnly: true } : undefined, + }), + + listWorkspaces: (enabledOnly = false) => + client.get('/api/contracts/webui/workspaces', { + params: enabledOnly ? { enabledOnly: true } : undefined, + }), + + create: (payload: WebUIContractPageCreateRequest) => + client.post('/api/contracts/webui/pages', payload), + + get: (pageId: string) => + client.get(`/api/contracts/webui/pages/${pageId}`), + + save: (pageId: string, payload: WebUIContractPageSaveRequest) => + client.put<{ manifest: WebUIContractPageManifest; build: WebUIContractPageBuildMeta }>( + `/api/contracts/webui/pages/${pageId}`, + payload, + ), + + build: (pageId: string) => + client.post(`/api/contracts/webui/pages/${pageId}/build`), +}; diff --git a/webui/src/api/workspace.ts b/webui/src/api/workspace.ts index d8f3d712d..f346cca9b 100644 --- a/webui/src/api/workspace.ts +++ b/webui/src/api/workspace.ts @@ -85,6 +85,9 @@ export const workspaceAPI = { downloadUrl: (path: string) => `${client.defaults.baseURL ?? ''}/api/workspace/download?path=${encodeURIComponent(path)}`, + previewUrl: (path: string) => + `${client.defaults.baseURL ?? ''}/api/workspace/preview?path=${encodeURIComponent(path)}`, + downloadZip: (paths: string[], archiveName = 'workspace_files.zip') => client.post( '/api/workspace/download/zip', @@ -95,6 +98,12 @@ export const workspaceAPI = { move: (src: string, dst: string) => client.post<{ src: string; dst: string; moved: boolean }>('/api/workspace/move', { src, dst }), + reveal: (path: string) => + client.post<{ path: string; opened: boolean; target: 'file' | 'directory'; mode: string }>( + '/api/workspace/reveal', + { path }, + ), + // Memory (read-only) listMemory: () => client.get('/api/workspace/memory/list'), @@ -102,6 +111,12 @@ export const workspaceAPI = { readMemoryFile: (path: string) => client.get('/api/workspace/memory/file', { params: { path } }), + memoryDownloadUrl: (path: string) => + `${client.defaults.baseURL ?? ''}/api/workspace/memory/download?path=${encodeURIComponent(path)}`, + + memoryPreviewUrl: (path: string) => + `${client.defaults.baseURL ?? ''}/api/workspace/memory/preview?path=${encodeURIComponent(path)}`, + // Stats stats: () => client.get('/api/workspace/stats'), diff --git a/webui/src/components/common/ConfirmDialog.tsx b/webui/src/components/common/ConfirmDialog.tsx index 75c4d38c6..4290b1f75 100644 --- a/webui/src/components/common/ConfirmDialog.tsx +++ b/webui/src/components/common/ConfirmDialog.tsx @@ -49,19 +49,21 @@ export function ConfirmProvider({ children }: { children: ReactNode }) { {children} {state && ( -
-
+
+
{state.options.variant === 'danger' && (
)} -
+
{state.options.title && (

{state.options.title}

)} -

{state.options.description}

+

+ {state.options.description} +

diff --git a/webui/src/components/common/GuideInfoIcon.tsx b/webui/src/components/common/GuideInfoIcon.tsx index 5393af78d..1e0230936 100644 --- a/webui/src/components/common/GuideInfoIcon.tsx +++ b/webui/src/components/common/GuideInfoIcon.tsx @@ -13,12 +13,14 @@ interface GuideInfoIconProps { label: string; description: string; className?: string; + interactive?: boolean; } export default function GuideInfoIcon({ label, description, className = '', + interactive = true, }: GuideInfoIconProps) { const [tooltip, setTooltip] = useState(null); const tooltipId = useId(); @@ -39,11 +41,12 @@ export default function GuideInfoIcon({ return ( <> { event.preventDefault(); event.stopPropagation(); @@ -53,11 +56,11 @@ export default function GuideInfoIcon({ event.stopPropagation(); }} onPointerEnter={(event) => showTooltip(event.currentTarget)} - onFocus={(event) => showTooltip(event.currentTarget)} + onFocus={interactive ? (event) => showTooltip(event.currentTarget) : undefined} onMouseEnter={(event) => showTooltip(event.currentTarget)} onMouseOver={(event) => showTooltip(event.currentTarget)} onPointerLeave={hideTooltip} - onBlur={hideTooltip} + onBlur={interactive ? hideTooltip : undefined} onMouseLeave={hideTooltip} >
@@ -301,7 +302,7 @@ export default function UpdateModal({ initialInfo, edition = 'flocks', canUpgrad <>
- {t('confirmUpgrade', { version: formatUpdateVersion(info.latest_version) })} + {t('confirmUpgrade', { version: formatUpdateVersion(latestDisplayVersion) })}
{t('newVersionDesc')}
@@ -328,17 +329,17 @@ export default function UpdateModal({ initialInfo, edition = 'flocks', canUpgrad
{t('currentVersion')} - {formatUpdateVersion(info?.current_version)} + {formatUpdateVersion(currentDisplayVersion)}
{t('latestVersion')}
{checking ? ( - ) : info?.latest_version ? ( + ) : latestDisplayVersion ? ( <> - {formatUpdateVersion(info.latest_version)} - {info.has_update ? ( + {formatUpdateVersion(latestDisplayVersion)} + {info?.has_update ? ( {t('hasUpdate')} ) : ( {t('upToDate')} @@ -351,43 +352,6 @@ export default function UpdateModal({ initialInfo, edition = 'flocks', canUpgrad
- {info?.has_update && localizedReleaseNotes && ( -
- - - {showReleaseNotes && ( -
-
- {t('releaseNotes')} - {info.release_url && ( - - {t('details')} - - )} -
-
-                    {localizedReleaseNotes}
-                  
-
- )} -
- )} -
) : ( <> - {productName} +
+ {productName} + {hasVisibleUpdate && ( + + )} +
- ) : ( - + + {t('flocksproUpgrade')} + )} - + + { + setAccountMenuOpen(false); + setSidebarOpen(false); + }} + className="flex items-center gap-2 px-3 py-2 text-sm font-medium text-zinc-700 transition-colors hover:bg-zinc-50 hover:text-zinc-950 dark:text-zinc-200 dark:hover:bg-zinc-800 dark:hover:text-zinc-50" + > + + {t('settings')} + + +
)} - {collapsed && ( + {collapsed ? ( + ) : ( + )}
@@ -672,6 +853,118 @@ export default function Layout() { + {activeWorkspaceMenu && ( + + )} + {/* Mobile top menu button */}
+ +
); } @@ -93,8 +100,44 @@ describe('ThemeProvider', () => { }); expect(screen.getByTestId('theme-value')).toHaveTextContent('dark'); + expect(screen.getByTestId('effective-theme-value')).toHaveTextContent('dark'); expect(document.documentElement).toHaveClass('dark'); expect(document.documentElement.style.colorScheme).toBe('dark'); await waitFor(() => expect(localStorage.getItem('flocks_theme')).toBe('dark')); }); + + it('temporarily overrides the displayed theme without changing the stored preference', async () => { + const user = userEvent.setup(); + localStorage.setItem('flocks_theme', 'light'); + + render( + + + , + ); + + expect(screen.getByTestId('theme-value')).toHaveTextContent('light'); + expect(screen.getByTestId('effective-theme-value')).toHaveTextContent('light'); + expect(document.documentElement).not.toHaveClass('dark'); + + await act(async () => { + await user.click(screen.getByRole('button', { name: 'temp dark' })); + }); + + expect(screen.getByTestId('theme-value')).toHaveTextContent('light'); + expect(screen.getByTestId('effective-theme-value')).toHaveTextContent('dark'); + expect(document.documentElement).toHaveClass('dark'); + expect(document.documentElement.style.colorScheme).toBe('dark'); + expect(localStorage.getItem('flocks_theme')).toBe('light'); + + await act(async () => { + await user.click(screen.getByRole('button', { name: 'clear temp' })); + }); + + expect(screen.getByTestId('theme-value')).toHaveTextContent('light'); + expect(screen.getByTestId('effective-theme-value')).toHaveTextContent('light'); + expect(document.documentElement).not.toHaveClass('dark'); + expect(document.documentElement.style.colorScheme).toBe('light'); + expect(localStorage.getItem('flocks_theme')).toBe('light'); + }); }); diff --git a/webui/src/contexts/ThemeContext.tsx b/webui/src/contexts/ThemeContext.tsx index 2f692071e..3aaf1089f 100644 --- a/webui/src/contexts/ThemeContext.tsx +++ b/webui/src/contexts/ThemeContext.tsx @@ -1,19 +1,23 @@ import { createContext, useCallback, useEffect, useLayoutEffect, useMemo, useState, type ReactNode } from 'react'; -type Theme = 'light' | 'dark'; +export type Theme = 'light' | 'dark'; interface ThemeContextValue { theme: Theme; + effectiveTheme: Theme; toggleTheme: () => void; setTheme: (theme: Theme) => void; + setTemporaryThemeOverride: (theme: Theme | null) => void; } const THEME_STORAGE_KEY = 'flocks_theme'; const ThemeContext = createContext({ theme: 'light', + effectiveTheme: 'light', toggleTheme: () => undefined, setTheme: () => undefined, + setTemporaryThemeOverride: () => undefined, }); function getInitialTheme(): Theme { @@ -34,10 +38,12 @@ function applyTheme(theme: Theme) { export function ThemeProvider({ children }: { children: ReactNode }) { const [theme, setThemeState] = useState(getInitialTheme); + const [temporaryThemeOverride, setTemporaryThemeOverride] = useState(null); + const effectiveTheme = temporaryThemeOverride ?? theme; useLayoutEffect(() => { - applyTheme(theme); - }, [theme]); + applyTheme(effectiveTheme); + }, [effectiveTheme]); useEffect(() => { if (typeof window.localStorage?.setItem === 'function') { @@ -53,7 +59,16 @@ export function ThemeProvider({ children }: { children: ReactNode }) { setThemeState((current) => (current === 'dark' ? 'light' : 'dark')); }, []); - const value = useMemo(() => ({ theme, toggleTheme, setTheme }), [setTheme, theme, toggleTheme]); + const value = useMemo( + () => ({ + theme, + effectiveTheme, + toggleTheme, + setTheme, + setTemporaryThemeOverride, + }), + [effectiveTheme, setTheme, theme, toggleTheme], + ); return {children}; } diff --git a/webui/src/hooks/useSessions.test.ts b/webui/src/hooks/useSessions.test.ts index 03a744a3b..f5a6c503d 100644 --- a/webui/src/hooks/useSessions.test.ts +++ b/webui/src/hooks/useSessions.test.ts @@ -324,6 +324,46 @@ describe('updateMessagePart scheduling', () => { expect((result.current.messages[2].parts as any[])[0].text).toBe('new reply'); }); + it('moves a replaced temp user before an already streamed assistant child', async () => { + const { result } = renderHook(() => useSessionMessages('sess-1')); + await act(async () => {}); + + await act(async () => { + result.current.updateMessagePart({ + id: 'new-text', + messageID: 'new-assistant', + sessionID: 'sess-1', + type: 'text', + text: 'new reply', + }); + result.current.updateMessage({ + id: 'new-assistant', + sessionID: 'sess-1', + role: 'assistant', + parentID: 'new-user', + time: { created: 200 }, + }); + result.current.addMessage(makeMsg({ + id: 'temp-user', + role: 'user', + parts: [{ id: 'temp-user-text', type: 'text', text: 'hello' } as any], + })); + result.current.updateMessage({ + id: 'new-user', + sessionID: 'sess-1', + role: 'user', + time: { created: 100 }, + }); + }); + + expect(result.current.messages.map((msg) => msg.id)).toEqual([ + 'new-user', + 'new-assistant', + ]); + expect((result.current.messages[0].parts as any[])[0].text).toBe('hello'); + expect((result.current.messages[1].parts as any[])[0].text).toBe('new reply'); + }); + it('truncateAfterMessage keeps the target by default', async () => { const { result } = renderHook(() => useSessionMessages('sess-1')); await act(async () => {}); diff --git a/webui/src/hooks/useSessions.ts b/webui/src/hooks/useSessions.ts index 4b93a1699..4f556780d 100644 --- a/webui/src/hooks/useSessions.ts +++ b/webui/src/hooks/useSessions.ts @@ -33,10 +33,46 @@ function finalizeStoppedMessageParts(parts: Message['parts'], stoppedAt = Date.n }); } +function normalizeMessageOrder(messages: Message[]): Message[] { + const messageIds = new Set(messages.map((message) => message.id)); + const assistantChildrenByParent = new Map(); + const childIds = new Set(); + + messages.forEach((message) => { + if (message.role !== 'assistant' || !message.parentID || !messageIds.has(message.parentID)) { + return; + } + childIds.add(message.id); + const siblings = assistantChildrenByParent.get(message.parentID) ?? []; + siblings.push(message); + assistantChildrenByParent.set(message.parentID, siblings); + }); + + const ordered: Message[] = []; + const pushed = new Set(); + + messages.forEach((message) => { + if (childIds.has(message.id)) return; + if (!pushed.has(message.id)) { + ordered.push(message); + pushed.add(message.id); + } + + const children = assistantChildrenByParent.get(message.id) ?? []; + children.forEach((child) => { + if (pushed.has(child.id)) return; + ordered.push(child); + pushed.add(child.id); + }); + }); + + return ordered; +} + function mergeFetchedMessages(prev: Message[], fetched: Message[]): Message[] { const previousById = new Map(prev.map((message) => [message.id, message])); - return fetched.map((message) => { + return normalizeMessageOrder(fetched.map((message) => { const existing = previousById.get(message.id); if (!existing) return message; @@ -53,11 +89,11 @@ function mergeFetchedMessages(prev: Message[], fetched: Message[]): Message[] { } return message; - }); + })); } function mergeLatestFetchedMessages(prev: Message[], fetched: Message[]): Message[] { - if (prev.length === 0) return fetched; + if (prev.length === 0) return normalizeMessageOrder(fetched); const fetchedIds = new Set(fetched.map((message) => message.id)); const mergedFetched = mergeFetchedMessages(prev, fetched); const firstFetchedTimestamp = mergedFetched[0]?.timestamp ?? Number.POSITIVE_INFINITY; @@ -67,12 +103,12 @@ function mergeLatestFetchedMessages(prev: Message[], fetched: Message[]): Messag const retainedNewer = prev.filter( (message) => !fetchedIds.has(message.id) && message.timestamp > firstFetchedTimestamp, ); - return [...retainedOlder, ...mergedFetched, ...retainedNewer]; + return normalizeMessageOrder([...retainedOlder, ...mergedFetched, ...retainedNewer]); } function prependOlderMessages(prev: Message[], older: Message[]): Message[] { const existingIds = new Set(prev.map((message) => message.id)); - return [...older.filter((message) => !existingIds.has(message.id)), ...prev]; + return normalizeMessageOrder([...older.filter((message) => !existingIds.has(message.id)), ...prev]); } function transformMessageResponse(data: any): { @@ -441,7 +477,7 @@ export function useSessionMessages(sessionId?: string) { ); if (tempIndex >= 0) { const updated = [...prev]; - updated[tempIndex] = { + const nextUser = { id: messageInfo.id, sessionID: messageInfo.sessionID, role: 'user' as const, @@ -454,7 +490,8 @@ export function useSessionMessages(sessionId?: string) { tokens: messageInfo.tokens, timestamp: messageInfo.time?.created || updated[tempIndex].timestamp, }; - return updated; + updated[tempIndex] = nextUser; + return normalizeMessageOrder(updated); } } @@ -474,20 +511,9 @@ export function useSessionMessages(sessionId?: string) { timestamp: messageInfo.time?.created || Date.now(), }; - if (messageInfo.role === 'user') { - const childIndex = prev.findIndex( - (m) => m.role === 'assistant' && m.parentID === messageInfo.id, - ); - if (childIndex >= 0) { - const updated = [...prev]; - updated.splice(childIndex, 0, nextMessage); - return updated; - } - } - - return [...prev, { + return normalizeMessageOrder([...prev, { ...nextMessage, - }]; + }]); }); }, /** diff --git a/webui/src/hooks/useUserDefinedPages.test.tsx b/webui/src/hooks/useWebUIContractPages.test.tsx similarity index 60% rename from webui/src/hooks/useUserDefinedPages.test.tsx rename to webui/src/hooks/useWebUIContractPages.test.tsx index f10bb4280..afbebdfb1 100644 --- a/webui/src/hooks/useUserDefinedPages.test.tsx +++ b/webui/src/hooks/useWebUIContractPages.test.tsx @@ -1,33 +1,36 @@ import { renderHook, waitFor } from '@testing-library/react'; import { beforeEach, describe, expect, it, vi } from 'vitest'; -import { useUserDefinedPages } from './useUserDefinedPages'; +import { useWebUIContractPages } from './useWebUIContractPages'; import { setupSSEMock } from '@/test/mocks/sse'; -const { listMock } = vi.hoisted(() => ({ +const { listMock, listWorkspacesMock } = vi.hoisted(() => ({ listMock: vi.fn(), + listWorkspacesMock: vi.fn(), })); -vi.mock('@/api/userDefinedPages', () => ({ - userDefinedPagesAPI: { +vi.mock('@/api/webuiContractPages', () => ({ + webuiContractPagesAPI: { list: listMock, + listWorkspaces: listWorkspacesMock, }, })); -describe('useUserDefinedPages', () => { +describe('useWebUIContractPages', () => { const sse = setupSSEMock(); beforeEach(() => { vi.clearAllMocks(); + listWorkspacesMock.mockResolvedValue({ data: [] }); }); - it('loads enabled user defined pages for navigation', async () => { + it('loads enabled WebUI contract pages for navigation', async () => { listMock.mockResolvedValueOnce({ data: [ { id: 'dash-1', title: '仪表盘', - route: '/user-defined-pages/dash-1', + route: '/contracts/webui/dash-1', icon: 'LayoutDashboard', order: 10, enabled: true, @@ -38,7 +41,7 @@ describe('useUserDefinedPages', () => { ], }); - const { result } = renderHook(() => useUserDefinedPages()); + const { result } = renderHook(() => useWebUIContractPages()); await waitFor(() => { expect(result.current.loading).toBe(false); @@ -46,10 +49,12 @@ describe('useUserDefinedPages', () => { expect(result.current.pages).toHaveLength(1); expect(result.current.pages[0].title).toBe('仪表盘'); + expect(result.current.workspaces).toHaveLength(0); expect(listMock).toHaveBeenCalledWith(true); + expect(listWorkspacesMock).toHaveBeenCalledWith(true); }); - it('refetches when user_defined_pages.nav_changed SSE event arrives', async () => { + it('refetches when contracts.webui.pages.nav_changed SSE event arrives', async () => { listMock .mockResolvedValueOnce({ data: [] }) .mockResolvedValueOnce({ @@ -57,7 +62,7 @@ describe('useUserDefinedPages', () => { { id: 'dash-2', title: '新页面', - route: '/user-defined-pages/dash-2', + route: '/contracts/webui/dash-2', icon: 'LayoutDashboard', order: 20, enabled: true, @@ -67,8 +72,9 @@ describe('useUserDefinedPages', () => { }, ], }); + listWorkspacesMock.mockResolvedValue({ data: [] }); - const { result } = renderHook(() => useUserDefinedPages()); + const { result } = renderHook(() => useWebUIContractPages()); await waitFor(() => { expect(result.current.loading).toBe(false); @@ -76,7 +82,7 @@ describe('useUserDefinedPages', () => { sse.open(); sse.send({ - type: 'user_defined_pages.nav_changed', + type: 'contracts.webui.pages.nav_changed', properties: { id: 'dash-2' }, }); @@ -84,5 +90,6 @@ describe('useUserDefinedPages', () => { expect(result.current.pages).toHaveLength(1); }); expect(listMock).toHaveBeenCalledTimes(2); + expect(listWorkspacesMock).toHaveBeenCalledTimes(2); }); }); diff --git a/webui/src/hooks/useUserDefinedPages.ts b/webui/src/hooks/useWebUIContractPages.ts similarity index 68% rename from webui/src/hooks/useUserDefinedPages.ts rename to webui/src/hooks/useWebUIContractPages.ts index c472b8485..4457bd9c8 100644 --- a/webui/src/hooks/useUserDefinedPages.ts +++ b/webui/src/hooks/useWebUIContractPages.ts @@ -1,10 +1,15 @@ import { useCallback, useEffect, useRef, useState } from 'react'; import i18n from '@/i18n'; -import { userDefinedPagesAPI, type UserDefinedPageListItem } from '@/api/userDefinedPages'; +import { + webuiContractPagesAPI, + type WebUIContractPageListItem, + type WebUIContractWorkspaceListItem, +} from '@/api/webuiContractPages'; import { useSSE } from '@/hooks/useSSE'; -export function useUserDefinedPages() { - const [pages, setPages] = useState([]); +export function useWebUIContractPages() { + const [pages, setPages] = useState([]); + const [workspaces, setWorkspaces] = useState([]); const [loading, setLoading] = useState(true); const [error, setError] = useState(null); const loadingRef = useRef(false); @@ -16,11 +21,16 @@ export function useUserDefinedPages() { if (!silent) setLoading(true); setError(null); try { - const response = await userDefinedPagesAPI.list(true); - setPages(Array.isArray(response.data) ? response.data : []); + const [pagesResponse, workspacesResponse] = await Promise.all([ + webuiContractPagesAPI.list(true), + webuiContractPagesAPI.listWorkspaces(true), + ]); + setPages(Array.isArray(pagesResponse.data) ? pagesResponse.data : []); + setWorkspaces(Array.isArray(workspacesResponse.data) ? workspacesResponse.data : []); } catch (err: unknown) { setPages([]); - setError(err instanceof Error ? err.message : i18n.t('nav.fetchFailed', { ns: 'userDefinedPage' })); + setWorkspaces([]); + setError(err instanceof Error ? err.message : i18n.t('nav.fetchFailed', { ns: 'webuiContractPage' })); } finally { loadingRef.current = false; if (!silent) setLoading(false); @@ -58,7 +68,7 @@ export function useUserDefinedPages() { useSSE({ url: '/api/event', onEvent: useCallback((evt) => { - if (evt.type === 'user_defined_pages.nav_changed') { + if (evt.type === 'contracts.webui.pages.nav_changed') { void fetchPages(true); } }, [fetchPages]), @@ -67,6 +77,7 @@ export function useUserDefinedPages() { return { pages, + workspaces, loading, error, refetch: () => fetchPages(), diff --git a/webui/src/i18n.ts b/webui/src/i18n.ts index ef5e0c139..f362b884d 100644 --- a/webui/src/i18n.ts +++ b/webui/src/i18n.ts @@ -22,7 +22,7 @@ import enWorkspace from './locales/en-US/workspace.json'; import enAuth from './locales/en-US/auth.json'; import enNotification from './locales/en-US/notification.json'; import enFlocksPro from './locales/en-US/flockspro.json'; -import enUserDefinedPage from './locales/en-US/userDefinedPage.json'; +import enWebUIContractPage from './locales/en-US/webuiContractPage.json'; import enDevice from './locales/en-US/device.json'; import zhCommon from './locales/zh-CN/common.json'; @@ -45,7 +45,7 @@ import zhWorkspace from './locales/zh-CN/workspace.json'; import zhAuth from './locales/zh-CN/auth.json'; import zhNotification from './locales/zh-CN/notification.json'; import zhFlocksPro from './locales/zh-CN/flockspro.json'; -import zhUserDefinedPage from './locales/zh-CN/userDefinedPage.json'; +import zhWebUIContractPage from './locales/zh-CN/webuiContractPage.json'; import zhDevice from './locales/zh-CN/device.json'; i18n @@ -74,7 +74,7 @@ i18n auth: enAuth, notification: enNotification, flockspro: enFlocksPro, - userDefinedPage: enUserDefinedPage, + webuiContractPage: enWebUIContractPage, device: enDevice, }, 'zh-CN': { @@ -98,13 +98,13 @@ i18n auth: zhAuth, notification: zhNotification, flockspro: zhFlocksPro, - userDefinedPage: zhUserDefinedPage, + webuiContractPage: zhWebUIContractPage, device: zhDevice, }, }, fallbackLng: 'en-US', defaultNS: 'common', - ns: ['common', 'nav', 'home', 'session', 'agent', 'task', 'workflow', 'tool', 'skill', 'model', 'mcp', 'config', 'channel', 'permission', 'monitoring', 'update', 'workspace', 'auth', 'notification', 'flockspro', 'device', 'userDefinedPage'], + ns: ['common', 'nav', 'home', 'session', 'agent', 'task', 'workflow', 'tool', 'skill', 'model', 'mcp', 'config', 'channel', 'permission', 'monitoring', 'update', 'workspace', 'auth', 'notification', 'flockspro', 'device', 'webuiContractPage'], detection: { order: ['localStorage', 'navigator'], lookupLocalStorage: 'flocks-language', diff --git a/webui/src/locales/en-US/device.json b/webui/src/locales/en-US/device.json index 1ec2ac5b3..3cd11af43 100644 --- a/webui/src/locales/en-US/device.json +++ b/webui/src/locales/en-US/device.json @@ -68,11 +68,54 @@ "rex": "Rex", "manual": "Manual" }, + "guide": { + "workbenchTab": "Workbench", + "title": "Rex-assisted integration", + "subtitle": "Choose a guide or example. Rex will confirm the integration path, credential fields, and smoke-test steps to help you connect the device.", + "customTitle": "Custom device integration", + "caseTitle": "Create device", + "examples": { + "supported": "I want to connect a supported security device", + "addressOnly": "I only have the device URL and login method", + "noApi": "This device does not expose an API" + }, + "prompts": { + "api": "I have selected API integration. Continue with API integration and guide me through confirming the vendor, product name, API docs, Base URL, authentication method, device capabilities to implement, and smoke-test steps.", + "browser": "I have selected browser integration. Continue with browser integration and guide me through confirming the product URL, login method, target page behavior, data to collect, and automation validation steps. When generating the device plugin, use auth_state_path for the browser login state; username/password are only for auto re-login after the login state expires and must use storage: secret; inline auth_state, if needed, must use storage: secret and internal: true.", + "addressOnly": "I only have the device URL and login method. First help me determine whether an existing template can be used. If information is missing, ask only the most important question at a time, and output a device configuration JSON draft once enough details are available.", + "tdp": "I have selected the TDP integration example. Guide me through connecting ThreatBook TDP and confirming the connection URL, API key, secret, SSL verification, and connection test steps.", + "onesec": "I have selected the OneSEC integration example. Guide me through connecting OneSEC and confirming the Base URL, API key, secret, SSL verification, and connection test steps." + }, + "actions": { + "api": "API integration", + "browser": "Browser integration" + }, + "cases": { + "tdp": "TDP integration", + "onesec": "OneSEC integration", + "more": "View more" + }, + "descriptions": { + "api": "Use this when the device exposes APIs. Rex will guide API docs, authentication, credential fields, and smoke tests.", + "browser": "Use this when the device has no open API. Rex will guide the console URL, login method, page actions, and automation validation.", + "tdp": "Continue with the known ThreatBook TDP case, focusing on API key, secret, base URL, and connection testing.", + "onesec": "Continue with the known OneSEC case, focusing on the connection URL, credential fields, and connection testing.", + "more": "Browse supported device templates. After a template is selected, Rex will continue guiding the configuration." + } + }, + "supportedList": { + "back": "Back", + "title": "Supported devices", + "subtitle": "Choose a vendor first, then choose the device to connect. Rex will continue guiding the integration after a device is selected.", + "deviceCount": "{{count}} device(s)", + "integratedCount": "{{count}} integrated", + "templateTooltip": "Rex will continue configuration from this device template. If it is not installed, Rex will guide you to install it from FlockHub first." + }, "rex": { "title": "Rex-guided device add", "heading": "Rex-guided integration", "subtitle": "Describe the model, integration path, and materials you have. Rex will help confirm the template, credential fields, and test steps.", - "welcome": "Tell me the device vendor, model, version, and any API docs, console URL, or authentication details you already have.\n\nIf the device is already supported, I will organize the values needed for the manual form. Once there is enough information, I will output a ```json configuration draft that the page can detect and use to fill the form. If there is no matching template yet, I will guide you through creating a custom device integration.", + "welcome": "Tell me the device vendor, model, version, and any API docs, console URL, or authentication details you already have.\n\nIf the device is already supported, I will organize the values needed for the device configuration form. Once there is enough information, I will output a ```json configuration draft that the page can detect and use to fill the form. If there is no matching template yet, I will guide you through creating a custom device integration.", "placeholder": "Describe the device, URL, authentication method, or upload related materials", "pending": "Rex is getting ready...", "manualAction": "Switch to manual", @@ -91,25 +134,25 @@ "guides": { "existing": { "title": "Existing template", - "desc": "Organize manual-form fields for supported devices", + "desc": "Organize configuration fields for supported devices", "prompt": "I want to integrate a security device that is already supported by a template. Guide me through confirming the vendor, model, connection URL, credential fields, and connection test steps." }, "api": { "title": "Custom API", "desc": "Create a custom device plugin from API docs", - "prompt": "I want to integrate an unsupported API device. Help me confirm the API docs, authentication method, device details, and steps for generating a device plugin." + "prompt": "I have selected API integration. Help me confirm the API docs, authentication method, device details, and steps for generating a device plugin." }, "webcli": { - "title": "WebCLI", + "title": "Browser integration", "desc": "Create integration capability from a web console", - "prompt": "I want to integrate a web-console device without an open API. Guide me through preparing the product URL, login method, target pages/actions, and WebCLI integration steps." + "prompt": "I have selected browser integration. Guide me through preparing the product URL, login method, target pages/actions, and browser integration steps. When generating the device plugin, use auth_state_path for the browser login state; username/password are only for auto re-login after the login state expires and must use storage: secret; inline auth_state, if needed, must use storage: secret and internal: true." } } }, "vendorHint": "Choose a vendor — {{count}} available", "chooseVendorOrCustom": "Choose a vendor or create a custom device integration", "customCardTitle": "Custom Device", - "customCardSubtitle": "API / WebCLI / Workflow", + "customCardSubtitle": "API / Browser / Workflow", "customCardCta": "Create a new integration path", "chooseCustomMode": "Choose how this custom device should connect", "customModes": { @@ -118,7 +161,7 @@ "desc": "Use this when the device exposes APIs. Provide API docs and generate a reusable device plugin." }, "webcli": { - "title": "WebCLI Integration", + "title": "Browser Integration", "desc": "Use this when the device has no open API. First generate and integrate the skill or CLI asset; for security devices, also generate a device plugin for the device page." }, "workflow": { @@ -135,7 +178,15 @@ "installed": "Installed", "available": "Available", "updateAvailable": "Update", - "brokenShort": "Broken" + "brokenShort": "Broken", + "installing": "Installing", + "updating": "Updating", + "installingTemplate": "Installing device template \"{{name}}\"", + "updatingTemplate": "Updating device template \"{{name}}\"", + "installDone": "Device template \"{{name}}\" installed", + "updateDone": "Device template \"{{name}}\" updated", + "installFailed": "Failed to install device template \"{{name}}\"", + "updateFailed": "Failed to update device template \"{{name}}\"" }, "productHint": "{{count}} product(s) — the same type can be added multiple times" }, @@ -151,7 +202,6 @@ "namePlaceholder": "e.g. HQ AF Firewall", "roomLabel": "Room", "connectionParams": "Connection parameters", - "secretConfigured": "Already set · Leave unchanged to keep it, clear to remove", "secretRevealFailed": "Failed to reveal credential", "showSecretAction": "Show", "hideSecretAction": "Hide", @@ -161,6 +211,11 @@ "sslHint": "Disable to reach internal devices with self-signed certificates", "enabledLabel": "Enable device", "enabledHint": "When disabled, agents will not call tools from this device", + "aiAssistTitle": "Rex-assisted config", + "aiAssistHint": "Use the current form and setup guide so Rex can help you continue configuring this device.", + "aiAssistTest": "Test device", + "aiAssistTroubleshoot": "Troubleshoot", + "aiAssistSaveFirst": "Save the device before Rex can call its tools", "testBtn": "Test connection", "saveBtn": "Save config", "addBtn": "Confirm integration", @@ -183,7 +238,7 @@ "toolCount": "Tools", "vendor": "Vendor", "serviceDesc": "Service description", - "viewDocs": "View API docs" + "viewDocs": "View setup guide" }, "toast": { @@ -205,12 +260,12 @@ "custom": { "title": { "api": "Custom Device API Integration", - "webcli": "Custom Device WebCLI Integration", + "webcli": "Custom Device Browser Integration", "workflow": "Custom Device Workflow Integration" }, "subtitle": { "api": "Provide API documentation and generate a reusable device plugin", - "webcli": "Provide the product URL and target interfaces to generate a WebCLI device plugin for the device page", + "webcli": "Provide the product URL and target interfaces to generate a browser-integration device plugin for the device page", "workflow": "Configure Syslog, Kafka, and Webhook entries on the workflow Publish page" }, "welcome": { @@ -233,7 +288,7 @@ "prepareTitle": "Prepare the integration details before sending them to Rex", "prepareIntro": "After submission, you will enter the Rex conversation directly. ", "apiNext": "You can continue to add API doc links, upload documentation files, or explain interface details. Once the plugin is generated, return to the device page to finish the rest of the setup.", - "webcliNext": "You can continue to add page actions, capture targets, and authentication details. The final result should be a WebCLI device plugin that the device page can recognize; a CLI can remain as an optional debugging entry." + "webcliNext": "You can continue to add page actions, capture targets, and authentication details. The final result should be a browser-integration device plugin that the device page can recognize; a CLI can remain as an optional debugging entry." }, "rex": { "apiHint": "Upload API docs or describe the endpoints. Rex will clarify key gaps before generating the device plugin.", diff --git a/webui/src/locales/en-US/flockspro.json b/webui/src/locales/en-US/flockspro.json index 1e17cddc2..7a7b9d8f6 100644 --- a/webui/src/locales/en-US/flockspro.json +++ b/webui/src/locales/en-US/flockspro.json @@ -58,6 +58,7 @@ "refreshing": "Refreshing...", "syncLicenseAction": "Sync License Info", "syncingLicense": "Syncing...", + "consoleSyncWarning": "Console license sync is temporarily unavailable. Local license status was refreshed. {{message}}", "applyAction": "Apply for Upgrade", "applyNewLicenseAction": "Apply for New License", "revokedOrExpiredHint": "The current license is no longer valid. You can submit a new Flocks Pro application.", diff --git a/webui/src/locales/en-US/home.json b/webui/src/locales/en-US/home.json index 49248b103..25874ca3f 100644 --- a/webui/src/locales/en-US/home.json +++ b/webui/src/locales/en-US/home.json @@ -3,10 +3,10 @@ "subtitle": "AI-Native SecOps Automation Platform", "description": "Let AI handle your security operations — from alert triage to threat response, fully automated with 10x efficiency", "getStarted": "Setup Guide", - "createUserDefinedPage": "Create Custom Page", - "createUserDefinedPageSessionTitle": "Create Custom Page", - "createUserDefinedPageInitialMessage": "I want to create a new user-defined page in the Flocks left navigation. Please introduce what this feature can do, where pages are stored, where they appear in the navigation, and the full creation and development workflow. Also explain how to hide a page from the navigation or permanently delete it when I no longer need it. If I already have an idea, guide me step by step: choose a page ID and title, create the page scaffold, write the React page code, and explain how live preview works after saving. Finally, tell me what information you still need from me (such as page name, content to display, and data sources).", - "createUserDefinedPageError": "Failed to create session. Please try again later.", + "createWebUIContractPage": "Custom Page", + "createWebUIContractPageSessionTitle": "Custom Page", + "createWebUIContractPageInitialMessage": "I want to create a new custom page in the Flocks left navigation. Please introduce what this feature can do, where page files are stored, where they appear in the navigation, and the full creation and development workflow. Also explain how to hide a page from the navigation or permanently delete it when I no longer need it. If I already have an idea, guide me step by step: choose a page ID and title, create the page scaffold, write the React page code, and explain how live preview works after saving. Finally, tell me what information you still need from me, such as page name, content to display, and data contracts.", + "createWebUIContractPageError": "Failed to create a custom page session. Please try again later.", "openSource": "Open Source", "systemCard": { "title": "Flocks System", diff --git a/webui/src/locales/en-US/nav.json b/webui/src/locales/en-US/nav.json index 4eb050c04..3466706a6 100644 --- a/webui/src/locales/en-US/nav.json +++ b/webui/src/locales/en-US/nav.json @@ -2,6 +2,7 @@ "home": "Home", "flocksHome": "Home", "aiWorkbench": "AI Workbench", + "sceneWorkspaces": "Scene Workspaces", "sessions": "Sessions", "tasks": "Task Center", "workspace": "Workspace", @@ -19,7 +20,24 @@ "accountManagement": "Account", "systemLog": "System Logs", "flocksproUpgrade": "Flocks Pro", + "checkUpdate": "Check for updates", "auditLogs": "Audit Logs", + "settings": "Settings", + "settingsBack": "Back", + "settingsTitle": "Settings", + "settingsDescription": "Manage preferences, account, and system settings", + "settingsPreferences": "Preferences", + "settingsPreferencesDescription": "Adjust the console language and interface theme.", + "settingsGroupPreferences": "Preferences", + "settingsGroupSystem": "Account & System", + "settingsGroupIntegrations": "Models & Channels", + "language": "Language", + "languageDescription": "Choose the language used by the console.", + "theme": "Theme", + "themeDescription": "Choose the default light or dark mode for the console.", + "lightTheme": "Light", + "darkTheme": "Dark", + "logout": "Log out", "expandNav": "Expand navigation", "collapseNav": "Collapse navigation", "switchLanguage": "Switch Language", diff --git a/webui/src/locales/en-US/socPage.json b/webui/src/locales/en-US/socPage.json deleted file mode 100644 index d5424c691..000000000 --- a/webui/src/locales/en-US/socPage.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "host": { - "missingPageId": "Page ID is missing", - "loading": "Loading custom page...", - "unavailableTitle": "Page unavailable", - "notBuilt": "This page has not been built yet", - "loadFailed": "Failed to load page", - "buildFailed": "Page build failed", - "apiFailed": "Page API runtime failed", - "retry": "Retry", - "emptyComponent": "Page component is empty", - "renderFailedTitle": "Custom page failed to run", - "renderFailed": "Page render failed", - "bundleMissingExport": "Page bundle does not export a default component" - }, - "nav": { - "fetchFailed": "Failed to load custom page navigation" - } -} diff --git a/webui/src/locales/en-US/update.json b/webui/src/locales/en-US/update.json index a8f773023..be271bbe1 100644 --- a/webui/src/locales/en-US/update.json +++ b/webui/src/locales/en-US/update.json @@ -3,6 +3,9 @@ "proTitle": "Flocks Pro Version Info", "currentVersion": "Current Version", "latestVersion": "Latest Version", + "coreVersion": "Core Version", + "proComponent": "Pro Component", + "notInstalled": "Not installed", "hasUpdate": "Update Available", "upToDate": "Up to Date", "releaseNotes": "Release Notes", diff --git a/webui/src/locales/en-US/userDefinedPage.json b/webui/src/locales/en-US/webuiContractPage.json similarity index 53% rename from webui/src/locales/en-US/userDefinedPage.json rename to webui/src/locales/en-US/webuiContractPage.json index efbee33ab..93d0e2fc0 100644 --- a/webui/src/locales/en-US/userDefinedPage.json +++ b/webui/src/locales/en-US/webuiContractPage.json @@ -14,5 +14,18 @@ }, "nav": { "fetchFailed": "Failed to load custom page navigation" + }, + "workspace": { + "missingWorkspaceId": "Workspace ID is missing", + "loading": "Loading workspace...", + "unavailableTitle": "Workspace unavailable", + "loadFailed": "Failed to load workspace", + "notFound": "Workspace not found", + "empty": "No pages in this workspace", + "selectPage": "Select a page from the workspace menu", + "pageNotFound": "Page not found", + "sectionNavigation": "Workspace page sections", + "collapseSidebar": "Collapse sidebar", + "expandSidebar": "Expand sidebar" } } diff --git a/webui/src/locales/en-US/workspace.json b/webui/src/locales/en-US/workspace.json index afaf03fb8..963113209 100644 --- a/webui/src/locales/en-US/workspace.json +++ b/webui/src/locales/en-US/workspace.json @@ -19,6 +19,7 @@ "modified": "Modified" }, "download": "Download", + "reveal": "Open containing folder", "delete": "Delete", "edit": "Edit", "save": "Save", @@ -28,6 +29,26 @@ "binaryPreview": "Binary files cannot be previewed", "truncatedPreview": "This file is large, so only the first {{limit}} is previewed. Inline editing is disabled to avoid saving truncated content; download the file for the full contents.", "downloadFile": "Download File", + "preview": { + "previewMode": "Preview", + "sourceMode": "Source", + "fullscreen": "Fullscreen preview", + "resize": "Drag to resize preview", + "htmlSandbox": "HTML is rendered in a restricted sandbox. Scripts will not run.", + "jsonParseFailed": "JSON parsing failed, so the original text is shown.", + "jsonlParseFailed": "{{count}} JSONL line(s) could not be parsed and were kept as original text.", + "pdfLoading": "Loading PDF...", + "pdfRendering": "Rendering page...", + "pdfLoadFailed": "Failed to load PDF preview", + "pdfCanvasUnavailable": "Canvas rendering is not available in this environment.", + "pageIndicator": "{{page}} / {{total}}", + "previousPage": "Previous page", + "nextPage": "Next page", + "zoomIn": "Zoom in", + "zoomOut": "Zoom out", + "unsupportedTitle": "This file cannot be previewed", + "unsupportedDesc": "This file type is not supported for inline preview. You can download it or open its containing folder." + }, "toast": { "loadDirFailed": "Failed to load directory", "readFileFailed": "Failed to read file", @@ -38,7 +59,9 @@ "uploadSuccess": "Uploaded {{count}} file(s) successfully", "uploadPartialFail": "{{count}} file(s) failed to upload", "uploadFailed": "Upload failed", - "createDirFailed": "Failed to create folder" + "createDirFailed": "Failed to create folder", + "revealSuccess": "Opened containing folder", + "revealFailed": "Failed to open folder" }, "confirm": { "deleteTitle": "Confirm Delete", diff --git a/webui/src/locales/zh-CN/device.json b/webui/src/locales/zh-CN/device.json index f6140adcf..8cd4e686b 100644 --- a/webui/src/locales/zh-CN/device.json +++ b/webui/src/locales/zh-CN/device.json @@ -68,11 +68,54 @@ "rex": "Rex 接入", "manual": "手动接入" }, + "guide": { + "workbenchTab": "工作台", + "title": "Rex 辅助接入", + "subtitle": "选择一个引导或案例,Rex 会确认接入方式、凭据字段和冒烟测试步骤,帮助你完成设备接入。", + "customTitle": "自定义设备接入", + "caseTitle": "创建设备", + "examples": { + "supported": "我想接入一台已支持的安全设备", + "addressOnly": "我只有设备地址和登录方式", + "noApi": "这台设备没有开放 API" + }, + "prompts": { + "api": "我已选择 API 接入,请按 API 接入继续,引导我确认厂商、产品名、API 文档、Base URL、认证方式、需要实现的设备能力和冒烟测试步骤。", + "browser": "我已选择浏览器接入,请按浏览器接入继续,引导我确认产品 URL、登录方式、目标页面行为、需要获取的数据和自动化验证步骤。生成 device 插件时,auth_state_path 用于保存浏览器登录态;username/password 仅用于登录态失效后的自动重登,必须使用 storage: secret;如需内联 auth_state,必须 storage: secret 且 internal: true。", + "addressOnly": "我只有待接入设备的地址和登录方式,请先帮我判断它是否已有模板可用;如果信息不足,请一次只问一个最关键问题,并在信息足够后输出设备配置 JSON 草稿。", + "tdp": "我已选择 TDP 接入案例,请引导我接入微步 TDP 设备,确认连接地址、API Key、Secret、SSL 验证和连通测试步骤。", + "onesec": "我已选择 OneSEC 接入案例,请引导我接入 OneSEC 设备,确认 Base URL、API Key、Secret、SSL 验证和连通测试步骤。" + }, + "actions": { + "api": "API 接入", + "browser": "浏览器接入" + }, + "cases": { + "tdp": "TDP 接入", + "onesec": "OneSEC 接入", + "more": "查看更多" + }, + "descriptions": { + "api": "设备提供 API 能力时使用,Rex 会引导确认文档、认证方式、凭据字段和冒烟测试。", + "browser": "设备没有开放 API 时使用,Rex 会引导确认控制台地址、登录方式、页面操作和自动化验证。", + "tdp": "按微步 TDP 的已知接入案例继续,重点确认 API Key、Secret、Base URL 和连通测试。", + "onesec": "按 OneSEC 的已知接入案例继续,重点确认连接地址、认证字段和连通测试。", + "more": "查看当前已支持的设备模板,选择模板后由 Rex 继续引导填写配置。" + } + }, + "supportedList": { + "back": "返回", + "title": "已支持设备列表", + "subtitle": "先选择厂商,再选择要接入的设备。点击设备后 Rex 会继续引导接入。", + "deviceCount": "{{count}} 款设备", + "integratedCount": "已接入 {{count}} 台", + "templateTooltip": "点击后 Rex 会基于该设备模板继续引导配置;如模板未安装,会先引导到 FlockHub 安装。" + }, "rex": { "title": "Rex 引导添加设备", "heading": "Rex 引导接入", "subtitle": "描述设备型号、接入方式和已有资料,Rex 会帮你确认模板、凭据字段和测试步骤。", - "welcome": "请告诉我你要接入的设备厂商、型号、版本,以及当前已有的 API 文档、控制台地址或认证方式。\n\n如果是已支持设备,我会帮你整理手动接入表单需要填写的内容;信息足够后,我会输出一段 ```json 配置草稿,页面检测到后可一键填充表单。如果还没有接入模板,我会引导你创建自定义设备接入。", + "welcome": "请告诉我你要接入的设备厂商、型号、版本,以及当前已有的 API 文档、控制台地址或认证方式。\n\n如果是已支持设备,我会帮你整理设备配置表单需要填写的内容;信息足够后,我会输出一段 ```json 配置草稿,页面检测到后可一键填充表单。如果还没有接入模板,我会引导你创建自定义设备接入。", "placeholder": "描述要接入的设备、地址、认证方式或上传相关资料", "pending": "Rex 准备中...", "manualAction": "切换到手动接入", @@ -91,25 +134,25 @@ "guides": { "existing": { "title": "已有模板", - "desc": "整理已支持设备的手动接入字段", + "desc": "整理已支持设备的配置字段", "prompt": "我要接入一个已有模板支持的安全设备,请引导我确认厂商、设备型号、连接地址、认证字段和连通测试步骤。" }, "api": { "title": "自定义 API", "desc": "通过 API 文档创建自定义 device 插件", - "prompt": "我要接入一个暂未支持的 API 设备,请先帮我确认需要的 API 文档、认证方式、设备信息和生成 device 插件的步骤。" + "prompt": "我已选择 API 接入,请帮我确认需要的 API 文档、认证方式、设备信息和生成 device 插件的步骤。" }, "webcli": { - "title": "WebCLI", + "title": "浏览器接入", "desc": "通过 Web 控制台页面创建接入能力", - "prompt": "我要接入一个没有开放 API 的 Web 控制台设备,请引导我准备产品 URL、登录方式、目标页面/操作和 WebCLI 接入步骤。" + "prompt": "我已选择浏览器接入,请引导我准备产品 URL、登录方式、目标页面/操作和浏览器接入步骤。生成 device 插件时,auth_state_path 用于保存浏览器登录态;username/password 仅用于登录态失效后的自动重登,必须使用 storage: secret;如需内联 auth_state,必须 storage: secret 且 internal: true。" } } }, "vendorHint": "选择设备所属厂商,共 {{count}} 家", "chooseVendorOrCustom": "选择设备所属厂商,或创建自定义设备接入", "customCardTitle": "自定义设备", - "customCardSubtitle": "API / WebCLI / Workflow", + "customCardSubtitle": "API / 浏览器 / Workflow", "customCardCta": "创建新的接入方式", "chooseCustomMode": "请选择自定义设备的接入方式", "customModes": { @@ -118,7 +161,7 @@ "desc": "设备提供 API 能力时使用。需要提供 API 文档,最终生成可复用的 device 插件。" }, "webcli": { - "title": "WebCLI 接入", + "title": "浏览器接入", "desc": "设备没有开放 API 时使用。先生成并集成 skill/CLI 资产;如果是安全设备场景,再额外生成可在设备页使用的 device 插件。" }, "workflow": { @@ -135,7 +178,15 @@ "installed": "已安装", "available": "可安装", "updateAvailable": "可更新", - "brokenShort": "不可用" + "brokenShort": "不可用", + "installing": "安装中", + "updating": "更新中", + "installingTemplate": "正在安装设备模板「{{name}}」", + "updatingTemplate": "正在更新设备模板「{{name}}」", + "installDone": "设备模板「{{name}}」已安装", + "updateDone": "设备模板「{{name}}」已更新", + "installFailed": "设备模板「{{name}}」安装失败", + "updateFailed": "设备模板「{{name}}」更新失败" }, "productHint": "共 {{count}} 款设备,同款设备可多次接入" }, @@ -151,7 +202,6 @@ "namePlaceholder": "例如:总部 AF 防火墙", "roomLabel": "所属机房", "connectionParams": "连接参数", - "secretConfigured": "已配置 · 保持不变请勿修改,清空则删除", "secretRevealFailed": "读取密钥失败", "showSecretAction": "显示", "hideSecretAction": "隐藏", @@ -161,6 +211,11 @@ "sslHint": "关闭可访问自签名证书的内网设备", "enabledLabel": "启用设备", "enabledHint": "关闭后 Agent 不会调用此设备的工具", + "aiAssistTitle": "Rex 辅助配置", + "aiAssistHint": "结合当前表单和配置指引,让 Rex 帮你继续完成配置。", + "aiAssistTest": "测试设备", + "aiAssistTroubleshoot": "排查问题", + "aiAssistSaveFirst": "保存设备后可让 Rex 调用设备工具测试", "testBtn": "连通测试", "saveBtn": "保存配置", "addBtn": "确认接入", @@ -183,7 +238,7 @@ "toolCount": "工具数量", "vendor": "厂商", "serviceDesc": "服务简介", - "viewDocs": "查看 API 文档" + "viewDocs": "查看配置指引" }, "toast": { @@ -205,12 +260,12 @@ "custom": { "title": { "api": "自定义设备 API 接入", - "webcli": "自定义设备 WebCLI 接入", + "webcli": "自定义设备浏览器接入", "workflow": "自定义设备 Workflow 接入" }, "subtitle": { "api": "提供 API 文档,生成可复用的 device 插件", - "webcli": "提供产品 URL 和目标接口,生成可在设备页使用的 WebCLI device 插件", + "webcli": "提供产品 URL 和目标接口,生成可在设备页使用的浏览器接入 device 插件", "workflow": "Syslog、Kafka、Webhook 统一在工作流发布页面配置" }, "welcome": { @@ -233,7 +288,7 @@ "prepareTitle": "提交给 Rex 前请准备好接入资料", "prepareIntro": "提交后会直接进入 Rex 对话。", "apiNext": "你可以继续补充 API 文档链接、上传文档文件或说明接口细节。插件生成完成后,可返回设备页继续后续配置。", - "webcliNext": "你可以继续补充页面操作、抓包目标和认证方式。最终结果应当是可在设备页识别的 WebCLI device 插件;如果需要,也可以额外保留 CLI 作为调试入口。" + "webcliNext": "你可以继续补充页面操作、抓包目标和认证方式。最终结果应当是可在设备页识别的浏览器接入 device 插件;如果需要,也可以额外保留 CLI 作为调试入口。" }, "rex": { "apiHint": "上传 API 文档或描述接口;Rex 会先澄清关键缺口,确认后再生成 device 插件。", diff --git a/webui/src/locales/zh-CN/flockspro.json b/webui/src/locales/zh-CN/flockspro.json index e8656c64a..9c0e30fde 100644 --- a/webui/src/locales/zh-CN/flockspro.json +++ b/webui/src/locales/zh-CN/flockspro.json @@ -58,6 +58,7 @@ "refreshing": "刷新中...", "syncLicenseAction": "同步授权信息", "syncingLicense": "同步中...", + "consoleSyncWarning": "云端授权同步暂不可用,已继续刷新本地授权状态。{{message}}", "applyAction": "申请升级", "applyNewLicenseAction": "申请新授权", "revokedOrExpiredHint": "当前 License 已失效,可重新提交 Flocks Pro 申请。", diff --git a/webui/src/locales/zh-CN/home.json b/webui/src/locales/zh-CN/home.json index c5693e857..3c096cd6f 100644 --- a/webui/src/locales/zh-CN/home.json +++ b/webui/src/locales/zh-CN/home.json @@ -3,10 +3,10 @@ "subtitle": "AI 原生的安全运营自动化平台", "description": "让 AI 替你做安全运营,从告警研判到威胁处置,全流程智能化,效率提升 10 倍", "getStarted": "新手引导", - "createUserDefinedPage": "创建自定义页面", - "createUserDefinedPageSessionTitle": "创建自定义页面", - "createUserDefinedPageInitialMessage": "我想在 Flocks 左侧导航中创建一个新的用户自定义页面。请先帮我介绍这个功能能做什么、页面会保存在哪里、创建后会出现在导航的什么位置,以及完整的创建与开发流程。也请说明如果不再需要某个页面,如何从导航隐藏或彻底删除。如果我已经有具体想法,也请引导我一步步完成:确定页面 ID 和标题、创建页面骨架、编写 React 页面代码,并说明保存后如何实时预览。最后请告诉我,接下来你只需要我提供哪些信息(例如页面名称、展示内容、数据来源)。", - "createUserDefinedPageError": "无法创建会话,请稍后重试", + "createWebUIContractPage": "自定义页面", + "createWebUIContractPageSessionTitle": "自定义页面", + "createWebUIContractPageInitialMessage": "我想在 Flocks 左侧导航中创建一个新的自定义页面。请先帮我介绍这个功能能做什么、页面文件会保存在哪里、创建后会出现在导航的什么位置,以及完整的创建与开发流程。也请说明如果不再需要某个页面,如何从导航隐藏或彻底删除。如果我已经有具体想法,也请引导我一步步完成:确定页面 ID 和标题、创建页面骨架、编写 React 页面代码,并说明保存后如何实时预览。最后请告诉我,接下来你只需要我提供哪些信息(例如页面名称、展示内容、数据契约)。", + "createWebUIContractPageError": "无法创建自定义页面会话,请稍后重试", "openSource": "开源项目", "systemCard": { "title": "Flocks 系统", diff --git a/webui/src/locales/zh-CN/nav.json b/webui/src/locales/zh-CN/nav.json index d5759ed70..9b672ccd8 100644 --- a/webui/src/locales/zh-CN/nav.json +++ b/webui/src/locales/zh-CN/nav.json @@ -2,6 +2,7 @@ "home": "首页", "flocksHome": "首页", "aiWorkbench": "AI 工作台", + "sceneWorkspaces": "场景工作区", "sessions": "会话管理", "tasks": "任务中心", "workspace": "工作空间", @@ -19,7 +20,24 @@ "accountManagement": "账号管理", "systemLog": "系统日志", "flocksproUpgrade": "Flocks Pro", + "checkUpdate": "检查更新", "auditLogs": "审计日志", + "settings": "设置", + "settingsBack": "返回", + "settingsTitle": "设置", + "settingsDescription": "管理偏好、账号与系统配置", + "settingsPreferences": "偏好设置", + "settingsPreferencesDescription": "调整当前工作台的显示语言和界面主题。", + "settingsGroupPreferences": "偏好", + "settingsGroupSystem": "账号与系统", + "settingsGroupIntegrations": "模型与通道", + "language": "语言", + "languageDescription": "选择控制台界面使用的语言。", + "theme": "主题", + "themeDescription": "选择控制台默认使用的浅色或深色模式。", + "lightTheme": "浅色", + "darkTheme": "深色", + "logout": "退出登录", "expandNav": "展开导航", "collapseNav": "收起导航", "switchLanguage": "切换语言", diff --git a/webui/src/locales/zh-CN/socPage.json b/webui/src/locales/zh-CN/socPage.json deleted file mode 100644 index b41ddf059..000000000 --- a/webui/src/locales/zh-CN/socPage.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "host": { - "missingPageId": "未指定页面 ID", - "loading": "正在加载自定义页面...", - "unavailableTitle": "页面暂不可用", - "notBuilt": "页面尚未构建完成", - "loadFailed": "加载页面失败", - "buildFailed": "页面构建失败", - "apiFailed": "页面 API 运行失败", - "retry": "重试", - "emptyComponent": "页面组件为空", - "renderFailedTitle": "自定义页面运行失败", - "renderFailed": "页面渲染失败", - "bundleMissingExport": "页面 bundle 未导出 default 组件" - }, - "nav": { - "fetchFailed": "加载自定义页面导航失败" - } -} diff --git a/webui/src/locales/zh-CN/update.json b/webui/src/locales/zh-CN/update.json index 8ebca06ca..c26ebbe60 100644 --- a/webui/src/locales/zh-CN/update.json +++ b/webui/src/locales/zh-CN/update.json @@ -3,6 +3,9 @@ "proTitle": "Flocks Pro 版本信息", "currentVersion": "当前版本", "latestVersion": "最新版本", + "coreVersion": "Core 版本", + "proComponent": "Pro 组件", + "notInstalled": "未安装", "hasUpdate": "有新版本", "upToDate": "已是最新", "releaseNotes": "发布说明", diff --git a/webui/src/locales/zh-CN/userDefinedPage.json b/webui/src/locales/zh-CN/webuiContractPage.json similarity index 53% rename from webui/src/locales/zh-CN/userDefinedPage.json rename to webui/src/locales/zh-CN/webuiContractPage.json index cfde1a4a7..138dc53a4 100644 --- a/webui/src/locales/zh-CN/userDefinedPage.json +++ b/webui/src/locales/zh-CN/webuiContractPage.json @@ -14,5 +14,18 @@ }, "nav": { "fetchFailed": "加载自定义页面导航失败" + }, + "workspace": { + "missingWorkspaceId": "未指定工作区 ID", + "loading": "正在加载工作区...", + "unavailableTitle": "工作区暂不可用", + "loadFailed": "加载工作区失败", + "notFound": "工作区不存在", + "empty": "工作区暂无页面", + "selectPage": "请从工作区菜单选择页面", + "pageNotFound": "页面不存在", + "sectionNavigation": "工作区页面分类", + "collapseSidebar": "折叠侧边栏", + "expandSidebar": "展开侧边栏" } } diff --git a/webui/src/locales/zh-CN/workspace.json b/webui/src/locales/zh-CN/workspace.json index dbdc9fa52..ecffb9da8 100644 --- a/webui/src/locales/zh-CN/workspace.json +++ b/webui/src/locales/zh-CN/workspace.json @@ -19,6 +19,7 @@ "modified": "修改时间" }, "download": "下载", + "reveal": "打开所在目录", "delete": "删除", "edit": "编辑", "save": "保存", @@ -28,6 +29,26 @@ "binaryPreview": "二进制文件无法预览", "truncatedPreview": "文件较大,当前仅预览前 {{limit}} 内容。为避免误保存截断内容,已禁用在线编辑;如需完整内容请下载文件。", "downloadFile": "下载文件", + "preview": { + "previewMode": "预览", + "sourceMode": "源码", + "fullscreen": "全屏预览", + "resize": "拖拽调整预览宽度", + "htmlSandbox": "HTML 已在受限沙箱中渲染,脚本不会执行。", + "jsonParseFailed": "JSON 解析失败,已按原始文本显示。", + "jsonlParseFailed": "{{count}} 行 JSONL 解析失败,已保留原始文本。", + "pdfLoading": "正在加载 PDF...", + "pdfRendering": "正在渲染页面...", + "pdfLoadFailed": "PDF 预览加载失败", + "pdfCanvasUnavailable": "当前环境不支持 Canvas 渲染。", + "pageIndicator": "{{page}} / {{total}}", + "previousPage": "上一页", + "nextPage": "下一页", + "zoomIn": "放大", + "zoomOut": "缩小", + "unsupportedTitle": "无法预览此文件", + "unsupportedDesc": "该文件类型暂不支持在线预览。你可以下载文件,或在本机文件管理器中打开所在目录。" + }, "toast": { "loadDirFailed": "加载目录失败", "readFileFailed": "读取文件失败", @@ -38,7 +59,9 @@ "uploadSuccess": "上传成功 {{count}} 个文件", "uploadPartialFail": "{{count}} 个文件上传失败", "uploadFailed": "上传失败", - "createDirFailed": "创建目录失败" + "createDirFailed": "创建目录失败", + "revealSuccess": "已打开所在目录", + "revealFailed": "打开目录失败" }, "confirm": { "deleteTitle": "确认删除", diff --git a/webui/src/pages/DeviceIntegration/CustomDeviceAccessPanel.tsx b/webui/src/pages/DeviceIntegration/CustomDeviceAccessPanel.tsx deleted file mode 100644 index b6d9bc2d3..000000000 --- a/webui/src/pages/DeviceIntegration/CustomDeviceAccessPanel.tsx +++ /dev/null @@ -1,169 +0,0 @@ -import { useEffect, useMemo } from 'react'; -import { ChevronLeft, MessageSquare, Route, Workflow, X } from 'lucide-react'; -import { useTranslation } from 'react-i18next'; -import { useNavigate } from 'react-router-dom'; -import SessionChat from '@/components/common/SessionChat'; -import { useSessionChat } from '@/hooks/useSessionChat'; -import type { CustomDeviceAccessMode } from '@/types'; -import { - buildCustomDeviceSessionContext, - buildCustomDeviceWelcomeMessage, -} from './customDevice'; - -export default function CustomDeviceAccessPanel({ - mode, - onClose, - onBack, -}: { - mode: CustomDeviceAccessMode; - onClose: () => void; - onBack: () => void; -}) { - const navigate = useNavigate(); - const { t } = useTranslation('device'); - const isWorkflow = mode === 'workflow'; - const title = useMemo(() => t(`custom.title.${mode}`), [mode, t]); - const subtitle = useMemo(() => t(`custom.subtitle.${mode}`), [mode, t]); - const welcomeMessage = useMemo( - () => (isWorkflow ? t(`custom.welcome.${mode}`) : buildCustomDeviceWelcomeMessage(mode)), - [isWorkflow, mode, t], - ); - - const { sessionId, createAndSend, reset } = useSessionChat({ - title, - category: 'entity-config', - contextMessage: buildCustomDeviceSessionContext(mode), - welcomeMessage, - }); - - useEffect(() => reset, [reset]); - - const handleOpenSession = () => { - if (!sessionId) return; - const params = new URLSearchParams({ session: sessionId }); - navigate(`/sessions?${params.toString()}`); - }; - - return ( -
-
-
-
- -
- {mode === 'api' ? : null} - {mode === 'webcli' ? : null} - {mode === 'workflow' ? : null} -
-
-

{title}

-

{subtitle}

-
-
- -
- -
- {isWorkflow ? ( -
-
-

{t('custom.workflow.heading')}

-

- {t('custom.workflow.body')} -

-
- -
-

{t('custom.workflow.requirementsTitle')}

-
    -
  • {t('custom.workflow.requirement1')}
  • -
  • {t('custom.workflow.requirement2')}
  • -
  • {t('custom.workflow.requirement3')}
  • -
-
-
- ) : ( -
- -
- R -
-
-
Rex
-
- {welcomeMessage} -
-
-
- } - onCreateAndSend={!sessionId ? (text, imageParts) => createAndSend({ text, imageParts }) : undefined} - /> -
- )} -
- -
- {isWorkflow ? ( -
- - -
- ) : ( -
-
- - {sessionId && ( - - )} -
-
- )} -
-
-
- ); -} diff --git a/webui/src/pages/DeviceIntegration/customDevice.test.ts b/webui/src/pages/DeviceIntegration/customDevice.test.ts index 72df920a9..f6291ca98 100644 --- a/webui/src/pages/DeviceIntegration/customDevice.test.ts +++ b/webui/src/pages/DeviceIntegration/customDevice.test.ts @@ -13,16 +13,22 @@ describe('customDevice helpers', () => { expect(buildCustomDeviceServiceId('Acme Guard')).toBe('acme_guard_device'); }); - it('routes custom device onboarding through an explicit access-mode question', () => { + it('routes custom device onboarding without re-asking when the access mode is already selected', () => { const prompt = buildCustomDeviceModeRoutingPrompt(); - expect(prompt).toContain('必须先使用 `question` 工具询问用户选择接入方式'); - expect(prompt).toContain('选项固定为「API 接入」「WebCLI 接入」「Workflow 接入」'); + expect(prompt).toContain('只有用户没有明确选择接入方式时,才使用 `question` 工具询问用户选择接入方式'); + expect(prompt).toContain('选项固定为「API 接入」「浏览器接入」「Workflow 接入」'); + expect(prompt).toContain('如果用户当前消息已经明确写了「API 接入」或「浏览器接入」,不要再询问接入方式'); expect(prompt).toContain('Syslog、Kafka 或 Webhook'); expect(prompt).toContain('【API 接入规则】'); expect(prompt).toContain('tool-builder skill'); - expect(prompt).toContain('【WebCLI 接入规则】'); + expect(prompt).toContain('【浏览器接入规则】'); expect(prompt).toContain('web2cli skill'); + expect(prompt).toContain('username` / `password` 仅用于 cookie 失效后的浏览器认证恢复'); + expect(prompt).toContain('二者都必须声明为 `storage: secret`'); + expect(prompt).toContain('`auth_state`,并声明 `storage: secret` 与 `internal: true`'); + expect(prompt).toContain('handler 只读取 `auth_state_path` 指向的 auth-state 文件'); + expect(prompt).toContain('提示用户重新登录后保存 state'); expect(prompt).toContain('【Workflow 接入规则】'); expect(prompt).toContain('不需要创建 device 插件'); }); diff --git a/webui/src/pages/DeviceIntegration/customDevice.ts b/webui/src/pages/DeviceIntegration/customDevice.ts index 8832dea1a..015825eed 100644 --- a/webui/src/pages/DeviceIntegration/customDevice.ts +++ b/webui/src/pages/DeviceIntegration/customDevice.ts @@ -19,31 +19,22 @@ export function buildCustomDeviceServiceId(deviceName: string): string { return base.endsWith('_device') ? base : `${base}_device`; } -function buildBaseDeviceSessionContext(): string[] { - return [ - '你是 Flocks 的自定义设备接入助手。', - '在正式开始构建设备插件之前,必须先做需求澄清:盘点已知信息、列出缺失/不确定信息,并向用户提出必要问题。', - '当需要用户补充关键信息或澄清不确定项时,使用 `question` 工具明确。', - '除非用户已经提供了足够的信息,否则不要直接写文件或生成插件;优先通过简短问题确认产品名、厂商、版本、认证方式、目标能力、API/页面文档。', - '澄清问题应聚焦关键阻塞项,一次提出 3 到 6 个最重要的问题(支持多选)。', - '不要把用户在表单里填写的账号、密码、Token、Cookie 直接写入插件;这些都应该通过 `credential_fields` 暴露为设备实例配置项。', - ]; -} - export function buildCustomDeviceModeRoutingPrompt(): string { return [ - '如果没有合适的已安装设备模板,按以下规则引导用户进入与「手动接入 > 自定义设备」一致的自定义设备接入路径:', + '如果没有合适的已安装设备模板,按以下规则引导用户进入自定义设备接入路径:', '- 设备提供 API 能力、API 文档或开放接口时,选择「API 接入」。', - '- 设备没有开放 API、主要通过 Web 控制台操作时,选择「WebCLI 接入」。', + '- 设备没有开放 API、主要通过 Web 控制台操作时,选择「浏览器接入」。', '- 数据通过 Syslog、Kafka 或 Webhook 上报时,选择「Workflow 接入」;不要创建 device 插件,请提示用户前往 Workflow 接入完成配置。', - '识别到自定义设备时,不要继续输出设备配置 JSON;必须先使用 `question` 工具询问用户选择接入方式,选项固定为「API 接入」「WebCLI 接入」「Workflow 接入」。', + '只有用户没有明确选择接入方式时,才使用 `question` 工具询问用户选择接入方式,选项固定为「API 接入」「浏览器接入」「Workflow 接入」。', + '如果用户当前消息已经明确写了「API 接入」或「浏览器接入」,不要再询问接入方式,直接按对应规则继续澄清和推进。', + '识别到自定义设备时,不要继续输出设备配置 JSON;先完成对应自定义接入资产创建流程。', '如果能根据用户描述判断推荐路径,先说明推荐原因,再让用户确认或改选接入方式。', '用户确认接入方式后,必须使用下方对应规则继续澄清和推进:', '', '【API 接入规则】', buildCustomDeviceModeInstruction('api'), '', - '【WebCLI 接入规则】', + '【浏览器接入规则】', buildCustomDeviceModeInstruction('webcli'), '', '【Workflow 接入规则】', @@ -64,10 +55,14 @@ export function buildCustomDeviceModeInstruction(mode: CustomDeviceAccessMode): } if (mode === 'webcli') { return [ - '本次接入方式是 WebCLI 接入。', + '本次接入方式是浏览器接入。', '你必须先读取并使用 web2cli skill,再开始捕获与转换流程。', '用户会提供产品 URL 和需要获取的接口/页面行为。目标是安全设备接入,需要生成 device 插件。', - '自定义 CLI 默认复用 `cookie/auth-state`;可选暴露 `username` / `password` 仅用于 cookie 失效后的浏览器认证恢复。只有在站点确实需要补充 header、cookie 或 token 时,才额外暴露对应字段。', + '自定义 CLI 默认复用 `cookie/auth-state`;优先使用 `auth_state_path` 指向 `~/.flocks/browser//auth-state.json`。', + '`username` / `password` 仅用于 cookie 失效后的浏览器认证恢复,二者都必须声明为 `storage: secret`,不要把账号或密码明文写入数据库字段。', + '如果需要保存内联登录态,只能使用 `auth_state`,并声明 `storage: secret` 与 `internal: true`;不要在表单中展示 Cookie、localStorage、token 明文。', + 'handler 只读取 `auth_state_path` 指向的 auth-state 文件;如果文件缺失、过期或无法匹配当前站点,应返回明确错误并提示用户重新登录后保存 state。', + '只有在站点确实需要补充 header、cookie 或 token 时,才额外暴露对应字段,并且必须使用 `storage: secret`。', '最终输出目录和插件结构必须符合 device 插件规范。', ].join('\n'); } @@ -77,43 +72,6 @@ export function buildCustomDeviceModeInstruction(mode: CustomDeviceAccessMode): ].join('\n'); } -export function buildCustomDeviceSessionContext(mode: CustomDeviceAccessMode): string { - return [ - ...buildBaseDeviceSessionContext(), - buildCustomDeviceModeInstruction(mode), - ].join('\n'); -} - -export function buildCustomDeviceWelcomeMessage(mode: CustomDeviceAccessMode): string { - if (mode === 'api') { - return [ - '请提供待接入设备的 API 资料。', - '', - '建议包含以下内容:', - '1. 产品、厂商与版本信息', - '2. API 文档链接或文档附件', - '3. Base URL 或典型部署地址', - '4. 认证方式与凭据类型', - '', - '资料确认后,Rex 将生成可在设备接入页识别和配置的 device 插件。', - ].join('\n'); - } - if (mode === 'webcli') { - return [ - '请提供待接入设备的 Web 控制台资料。', - '', - '建议包含以下内容:', - '1. 产品、厂商与版本信息', - '2. 登录 URL 或目标页面 URL', - '3. 需要沉淀的页面行为或接口', - '4. 认证限制、权限要求与可用登录态', - '', - '资料确认后,Rex 将沉淀 WebCLI 资产,并按需生成可在设备接入页识别和配置的 device 插件。', - ].join('\n'); - } - return 'Workflow 接入不在这里创建插件,请前往工作流发布页面,根据需要配置 Syslog、Kafka 或 Webhook。'; -} - export function findTemplateForCustomDevice( templates: DeviceTemplate[], deviceName: string, diff --git a/webui/src/pages/DeviceIntegration/index.test.tsx b/webui/src/pages/DeviceIntegration/index.test.tsx index 439eb237b..089cab010 100644 --- a/webui/src/pages/DeviceIntegration/index.test.tsx +++ b/webui/src/pages/DeviceIntegration/index.test.tsx @@ -1,6 +1,6 @@ import React from 'react'; import { beforeEach, describe, expect, it, vi } from 'vitest'; -import { render, screen, waitFor } from '@testing-library/react'; +import { fireEvent, render, screen, waitFor } from '@testing-library/react'; import userEvent from '@testing-library/user-event'; import DeviceIntegrationPage from './index'; @@ -29,6 +29,8 @@ const mocks = vi.hoisted(() => ({ listDeviceTools: vi.fn(), updateDeviceTool: vi.fn(), listTemplates: vi.fn(), + hubInstall: vi.fn(), + hubUpdate: vi.fn(), getServiceMetadata: vi.fn(), listTools: vi.fn(), setToolEnabled: vi.fn(), @@ -46,10 +48,15 @@ vi.mock('react-i18next', () => ({ const translations: Record = { pageTitle: '设备接入', pageDescription: '配置安全设备 API 连接,使 Flocks 能够直接调用和控制这些设备', + 'status.connected': '已连接', + 'status.disabled': '已禁用', + 'status.error': '连接失败', + 'status.unknown': '未检测', 'toolbar.refresh': '刷新', 'toolbar.addDevice': '立即添加设备', 'empty.addNow': '立即添加设备', 'config.closeAriaLabel': '关闭设备配置面板', + 'config.newDeviceTitle': '填写配置', 'config.nameLabel': '设备名称', 'config.roomLabel': '所属机房', 'config.saveBtn': '保存配置', @@ -57,9 +64,56 @@ vi.mock('react-i18next', () => ({ 'config.testBtn': '连通测试', 'config.showSecretAction': '显示', 'config.hideSecretAction': '隐藏', + 'config.aiAssistTitle': 'Rex 辅助配置', + 'config.aiAssistHint': '结合当前表单和配置指引,让 Rex 帮你继续完成配置。', + 'config.aiAssistTest': '测试设备', + 'config.aiAssistTroubleshoot': '排查问题', + 'config.aiAssistSaveFirst': '保存设备后可让 Rex 调用设备工具测试', + 'overview.viewDocs': '查看配置指引', 'wizard.selectVendorTitle': `选择 ${String(params?.vendor ?? '')} 设备`, 'wizard.tabs.rex': 'Rex 接入', 'wizard.tabs.manual': '手动接入', + 'wizard.guide.workbenchTab': '工作台', + 'wizard.guide.title': 'Rex 辅助接入', + 'wizard.guide.subtitle': '选择一个引导或案例', + 'wizard.guide.customTitle': '自定义设备接入', + 'wizard.guide.caseTitle': '创建设备', + 'wizard.guide.examples.supported': '我想接入一台已支持的安全设备', + 'wizard.guide.examples.addressOnly': '我只有设备地址和登录方式', + 'wizard.guide.examples.noApi': '这台设备没有开放 API', + 'wizard.guide.prompts.api': '我已选择 API 接入,请按 API 接入继续', + 'wizard.guide.prompts.browser': '我已选择浏览器接入,请按浏览器接入继续', + 'wizard.guide.prompts.addressOnly': '我只有待接入设备的地址和登录方式,请先帮我判断它是否已有模板可用', + 'wizard.guide.prompts.tdp': '我已选择 TDP 接入案例,请引导我接入微步 TDP 设备', + 'wizard.guide.prompts.onesec': '我已选择 OneSEC 接入案例,请引导我接入 OneSEC 设备', + 'wizard.guide.actions.api': 'API 接入', + 'wizard.guide.actions.browser': '浏览器接入', + 'wizard.guide.cases.tdp': 'TDP 接入', + 'wizard.guide.cases.onesec': 'OneSEC 接入', + 'wizard.guide.cases.more': '查看更多', + 'wizard.guide.descriptions.api': '设备提供 API 能力时使用', + 'wizard.guide.descriptions.browser': '设备没有开放 API 时使用', + 'wizard.guide.descriptions.tdp': '按微步 TDP 的已知接入案例继续', + 'wizard.guide.descriptions.onesec': '按 OneSEC 的已知接入案例继续', + 'wizard.guide.descriptions.more': '查看当前已支持的设备模板', + 'wizard.supportedList.back': '返回', + 'wizard.supportedList.title': '已支持设备列表', + 'wizard.supportedList.subtitle': '先选择厂商,再选择要接入的设备', + 'wizard.supportedList.deviceCount': `${String(params?.count ?? '')} 款设备`, + 'wizard.supportedList.integratedCount': `已接入 ${String(params?.count ?? '')} 台`, + 'wizard.supportedList.templateTooltip': '点击后 Rex 会基于该设备模板继续引导配置', + 'wizard.installState.installed': '已安装', + 'wizard.installState.available': '可安装', + 'wizard.installState.updateAvailable': '可更新', + 'wizard.installState.brokenShort': '不可用', + 'wizard.installState.installing': '安装中', + 'wizard.installState.updating': '更新中', + 'wizard.installState.installingTemplate': `正在安装设备模板「${String(params?.name ?? '')}」`, + 'wizard.installState.updatingTemplate': `正在更新设备模板「${String(params?.name ?? '')}」`, + 'wizard.installState.installDone': `设备模板「${String(params?.name ?? '')}」已安装`, + 'wizard.installState.updateDone': `设备模板「${String(params?.name ?? '')}」已更新`, + 'wizard.installState.installFailed': `设备模板「${String(params?.name ?? '')}」安装失败`, + 'wizard.installState.updateFailed': `设备模板「${String(params?.name ?? '')}」更新失败`, 'wizard.rex.title': 'Rex 引导添加设备', 'wizard.rex.heading': 'Rex 引导接入', 'wizard.rex.subtitle': '描述设备型号、接入方式和已有资料', @@ -85,12 +139,12 @@ vi.mock('react-i18next', () => ({ 'wizard.rex.guides.api.title': '自定义 API', 'wizard.rex.guides.api.desc': '通过 API 文档创建自定义 device 插件', 'wizard.rex.guides.api.prompt': '我要接入一个暂未支持的 API 设备', - 'wizard.rex.guides.webcli.title': 'WebCLI', + 'wizard.rex.guides.webcli.title': '浏览器接入', 'wizard.rex.guides.webcli.desc': '通过 Web 控制台页面创建接入能力', 'wizard.rex.guides.webcli.prompt': '我要接入一个没有开放 API 的 Web 控制台设备', 'wizard.customCardTitle': '自定义设备', 'wizard.customModes.api.title': 'API 接入', - 'wizard.customModes.webcli.title': 'WebCLI 接入', + 'wizard.customModes.webcli.title': '浏览器接入', 'wizard.customModes.workflow.title': 'Workflow 接入', 'custom.actions.submit': '提交给 Rex', 'custom.actions.openSessionList': '前往会话列表查看', @@ -231,6 +285,13 @@ vi.mock('@/api/device', () => ({ }, })); +vi.mock('@/api/hub', () => ({ + hubAPI: { + install: (...args: unknown[]) => mocks.hubInstall(...args), + update: (...args: unknown[]) => mocks.hubUpdate(...args), + }, +})); + vi.mock('@/api/provider', () => ({ providerAPI: { getServiceMetadata: (...args: unknown[]) => mocks.getServiceMetadata(...args), @@ -269,7 +330,21 @@ function buildTemplate(overrides: Record = {}) { async function openManualAddWizard(user: ReturnType) { await user.click(await screen.findByRole('button', { name: /立即添加设备/ })); - await user.click(screen.getByRole('button', { name: /^手动接入$/ })); +} + +async function openSupportedDeviceList(user: ReturnType) { + await openManualAddWizard(user); + await user.click(screen.getByRole('button', { name: /查看更多/ })); +} + +async function openApiDeviceGuidance(user: ReturnType) { + await openManualAddWizard(user); + await user.click(screen.getByRole('button', { name: /^API 接入$/ })); +} + +async function openBrowserDeviceGuidance(user: ReturnType) { + await openManualAddWizard(user); + await user.click(screen.getByRole('button', { name: /^浏览器接入$/ })); } describe('DeviceIntegrationPage', () => { @@ -301,6 +376,8 @@ describe('DeviceIntegrationPage', () => { data: [{ id: 'default', name: '默认机房', sort_order: 0, created_at: 0, updated_at: 0 }], }); mocks.listTemplates.mockResolvedValue({ data: [buildTemplate()] }); + mocks.hubInstall.mockResolvedValue({ data: {} }); + mocks.hubUpdate.mockResolvedValue({ data: {} }); mocks.getServiceMetadata.mockResolvedValue({ data: { credential_schema: [] } }); mocks.revealDeviceCredentials.mockResolvedValue({ data: { fields: {} } }); mocks.listTools.mockResolvedValue({ data: [] }); @@ -332,16 +409,47 @@ describe('DeviceIntegrationPage', () => { expect(mocks.syncDevices).not.toHaveBeenCalled(); }); - it('shows custom device option and access modes', async () => { + it('shows custom guidance and example entries on the add-device workbench', async () => { const user = userEvent.setup(); render(); await openManualAddWizard(user); - await user.click(screen.getByRole('button', { name: /自定义设备/ })); expect(screen.getByText('API 接入')).toBeInTheDocument(); - expect(screen.getByText('WebCLI 接入')).toBeInTheDocument(); - expect(screen.getByText('Workflow 接入')).toBeInTheDocument(); + expect(screen.getByText('浏览器接入')).toBeInTheDocument(); + expect(screen.getByText('TDP 接入')).toBeInTheDocument(); + expect(screen.getByText('OneSEC 接入')).toBeInTheDocument(); + expect(screen.getByText('查看更多')).toBeInTheDocument(); + }); + + it('shows tooltip text when hovering a workbench info icon', async () => { + const user = userEvent.setup(); + render(); + + await openManualAddWizard(user); + + const apiButton = screen.getByRole('button', { name: /^API 接入$/ }); + const infoIcon = apiButton.querySelector('span[aria-hidden="true"]'); + expect(infoIcon).toBeTruthy(); + + await user.hover(infoIcon as HTMLElement); + + expect(await screen.findByRole('tooltip')).toHaveTextContent('设备提供 API 能力时使用'); + }); + + it('returns to the add-device workbench when the workbench tab is clicked', async () => { + const user = userEvent.setup(); + render(); + + await openSupportedDeviceList(user); + expect(screen.getByText('已支持设备列表')).toBeInTheDocument(); + + await user.click(screen.getByRole('button', { name: /^工作台$/ })); + + expect(screen.getByText('Rex 辅助接入')).toBeInTheDocument(); + expect(screen.getByRole('button', { name: /^API 接入$/ })).toBeInTheDocument(); + expect(screen.queryByText('已支持设备列表')).not.toBeInTheDocument(); + expect(mocks.resetSession).toHaveBeenCalledTimes(1); }); it('opens the add-device panel on the Rex-guided tab by default', async () => { @@ -352,6 +460,10 @@ describe('DeviceIntegrationPage', () => { expect(await screen.findByText('SessionChat:pending')).toBeInTheDocument(); expect(screen.getByText('SessionChat:pending')).toBeInTheDocument(); + expect(screen.getByText('Rex 辅助接入')).toBeInTheDocument(); + expect(screen.getByRole('button', { name: /^API 接入$/ })).toBeInTheDocument(); + expect(screen.getByRole('button', { name: /^浏览器接入$/ })).toBeInTheDocument(); + expect(screen.getByRole('button', { name: /^查看更多$/ })).toBeInTheDocument(); expect(screen.getByText('Placeholder:描述要接入的设备、地址、认证方式或上传相关资料')).toBeInTheDocument(); expect(screen.queryByRole('button', { name: /^填充表单$/ })).toBeNull(); expect(screen.queryByRole('button', { name: /自定义设备/ })).toBeNull(); @@ -366,14 +478,44 @@ describe('DeviceIntegrationPage', () => { const contextMessage = mocks.useSessionChatOptions.mock.calls.at(-1)?.[0].contextMessage; expect(contextMessage).toContain('```json'); expect(contextMessage).toContain('API 接入'); - expect(contextMessage).toContain('WebCLI 接入'); + expect(contextMessage).toContain('浏览器接入'); expect(contextMessage).toContain('Workflow 接入'); expect(contextMessage).toContain('Syslog、Kafka 或 Webhook'); expect(contextMessage).toContain('不要继续输出设备配置 JSON'); - expect(contextMessage).toContain('必须先使用 `question` 工具询问用户选择接入方式'); + expect(contextMessage).toContain('只有用户没有明确选择接入方式时,才使用 `question` 工具询问用户选择接入方式'); + expect(contextMessage).toContain('如果用户当前消息已经明确写了「API 接入」或「浏览器接入」,不要再询问接入方式'); expect(contextMessage).toContain('用户确认接入方式后,必须使用下方对应规则继续澄清和推进'); }); + it('starts a guided Rex prompt from the add-device welcome card', async () => { + const user = userEvent.setup(); + render(); + + await user.click(await screen.findByRole('button', { name: /立即添加设备/ })); + await user.click(screen.getByRole('button', { name: /TDP 接入/ })); + + expect(mocks.createAndSend).toHaveBeenCalledWith({ + text: expect.stringContaining('我已选择 TDP 接入案例'), + agent: 'rex', + model: { providerID: 'openai', modelID: 'gpt-4.1' }, + }); + }); + + it('sends the custom API guidance prompt from the add-device welcome card', async () => { + const user = userEvent.setup(); + render(); + + await user.click(await screen.findByRole('button', { name: /立即添加设备/ })); + await user.click(screen.getByRole('button', { name: /^API 接入$/ })); + + expect(mocks.createAndSend).toHaveBeenCalledWith({ + text: expect.stringContaining('我已选择 API 接入'), + agent: 'rex', + model: { providerID: 'openai', modelID: 'gpt-4.1' }, + }); + expect(screen.getByText('Placeholder:描述要接入的设备、地址、认证方式或上传相关资料')).toBeInTheDocument(); + }); + it('applies Rex device draft to the add-device config form', async () => { const user = userEvent.setup(); mocks.sessionId = 'session-1'; @@ -390,6 +532,7 @@ describe('DeviceIntegrationPage', () => { service_id: 'qingteng', name: '青藤云安全', vendor: 'qingteng', + docs_url: 'https://docs.example.com/qingteng', credential_schema: [ { key: 'base_url', @@ -438,9 +581,107 @@ describe('DeviceIntegrationPage', () => { expect(screen.getByDisplayValue('admin')).toBeInTheDocument(); expect(screen.getByRole('combobox')).toHaveValue('group-1'); expect(screen.getAllByText('北京机房').length).toBeGreaterThan(0); + expect(screen.getByRole('link', { name: /查看配置指引/ })).toHaveAttribute('href', 'https://docs.example.com/qingteng'); + expect(screen.queryByRole('button', { name: /^工作台$/ })).not.toBeInTheDocument(); expect(mocks.toastSuccess).toHaveBeenCalledWith('已填充设备配置表单'); }); + it('returns to the Rex session and asks for testing guidance after confirming integration', async () => { + const user = userEvent.setup(); + mocks.sessionId = 'session-1'; + const template = buildTemplate({ + storage_key: 'qingteng_v3_4_1_66', + service_id: 'qingteng', + name: '青藤云安全', + vendor: 'qingteng', + credential_schema: [ + { + key: 'base_url', + label: 'Base URL', + storage: 'config', + sensitive: false, + required: true, + input_type: 'url', + config_key: 'base_url', + }, + { + key: 'username', + label: 'Username', + storage: 'config', + sensitive: false, + required: true, + input_type: 'text', + config_key: 'username', + }, + ], + }); + mocks.listGroups.mockResolvedValue({ + data: [ + { id: 'group-1', name: '默认机房', sort_order: 0, created_at: 0, updated_at: 0 }, + ], + }); + mocks.listTemplates.mockResolvedValue({ data: [template] }); + mocks.getSessionMessagesPage.mockResolvedValue({ + items: [ + { + info: { role: 'assistant' }, + parts: [ + { + type: 'text', + text: '```json\n{"storage_key":"qingteng_v3_4_1_66","device_name":"青藤万相","fields":{"base_url":"https://example.com","username":"admin"},"verify_ssl":false}\n```', + }, + ], + }, + ], + }); + mocks.createDevice.mockResolvedValue({ + data: { + id: 'device-new', + group_id: 'group-1', + name: '青藤万相', + storage_key: 'qingteng_v3_4_1_66', + service_id: 'qingteng', + enabled: true, + verify_ssl: false, + fields: { base_url: 'https://example.com', username: 'admin' }, + fields_set: { base_url: true, username: true }, + status: 'unknown', + created_at: 0, + updated_at: 0, + }, + }); + render(); + + await user.click(await screen.findByRole('button', { name: /立即添加设备/ })); + await user.click(await screen.findByRole('button', { name: /mock stream done/ })); + await user.click(await screen.findByRole('button', { name: /^填充表单$/ })); + expect(await screen.findByDisplayValue('青藤万相')).toBeInTheDocument(); + await user.click(await screen.findByRole('button', { name: /^(确认接入|添加设备|config.addBtn)$/ })); + + await waitFor(() => { + expect(mocks.createDevice).toHaveBeenCalledWith(expect.objectContaining({ + name: '青藤万相', + storage_key: 'qingteng_v3_4_1_66', + service_id: 'qingteng', + group_id: 'group-1', + })); + }); + expect(await screen.findByText('SessionChat:session-1')).toBeInTheDocument(); + await waitFor(() => expect(mocks.createAndSend).toHaveBeenCalledWith(expect.objectContaining({ + text: expect.stringContaining('设备「青藤万相」已确认接入并保存'), + displayText: '设备「青藤万相」已确认接入,请帮我测试。', + agent: 'rex', + model: { providerID: 'openai', modelID: 'gpt-4.1' }, + }))); + await waitFor(() => { + expect(mocks.getDevice).toHaveBeenCalledWith('device-new'); + }); + expect(mocks.createAndSend.mock.calls.at(-1)?.[0].text).toContain('不要再询问接入方式'); + expect(mocks.createAndSend.mock.calls.at(-1)?.[0].text).toContain('直接调用这台设备的可用工具完成连通测试'); + expect(mocks.createAndSend.mock.calls.at(-1)?.[0].text).toContain('明确使用上面的 device_id 作为目标设备'); + expect(mocks.createAndSend.mock.calls.at(-1)?.[0].text).not.toContain('页面上执行的连通测试动作'); + }); + it('does not detect Rex prose as a fillable device draft', async () => { const user = userEvent.setup(); mocks.sessionId = 'session-1'; @@ -481,98 +722,84 @@ describe('DeviceIntegrationPage', () => { expect(mocks.toastError).not.toHaveBeenCalled(); }); - it('navigates unavailable templates to FlockHub', async () => { + it('installs unavailable supported templates and then sends them to Rex', async () => { const user = userEvent.setup(); + const availableTemplate = buildTemplate({ + plugin_id: 'onesig_v2_5_3_D20250710', + storage_key: 'onesig_v2_5_3_D20250710_api_v2_5_3_D20250710', + service_id: 'onesig_v2_5_3_D20250710_api', + name: 'onesig', + version: '2.5.3 D20250710', + installed: false, + state: 'available', + }); mocks.listTemplates.mockResolvedValueOnce({ - data: [ - buildTemplate({ - plugin_id: 'onesig_v2_5_3_D20250710', - storage_key: 'onesig_v2_5_3_D20250710_api_v2_5_3_D20250710', - service_id: 'onesig_v2_5_3_D20250710_api', - name: 'onesig', - version: '2.5.3 D20250710', - installed: false, - state: 'available', - }), - ], + data: [availableTemplate], + }); + mocks.listTemplates.mockResolvedValueOnce({ + data: [{ ...availableTemplate, installed: true, state: 'installed' }], }); render(); - await openManualAddWizard(user); + await openSupportedDeviceList(user); await user.click(screen.getByText('微步')); await user.click(screen.getByText('onesig')); - expect(mocks.navigate).toHaveBeenCalledWith( - '/hub?type=device&plugin=onesig_v2_5_3_D20250710&q=onesig_v2_5_3_D20250710', - ); + await waitFor(() => { + expect(mocks.hubInstall).toHaveBeenCalledWith('device', 'onesig_v2_5_3_D20250710'); + }); + expect(mocks.syncDevices).toHaveBeenCalledWith({ refresh: true }); + expect(mocks.listTemplates).toHaveBeenLastCalledWith({ refresh: true }); + expect(mocks.navigate).not.toHaveBeenCalled(); + await waitFor(() => expect(mocks.createAndSend).toHaveBeenCalledWith(expect.objectContaining({ + text: expect.stringContaining('我要接入设备「onesig」'), + agent: 'rex', + model: { providerID: 'openai', modelID: 'gpt-4.1' }, + }))); + expect(mocks.createAndSend.mock.calls[0][0].text).toContain('storage_key=onesig_v2_5_3_D20250710_api_v2_5_3_D20250710'); }); - it('opens api mode directly in Rex chat with built-in guidance', async () => { + it('sends api guidance directly to Rex without opening a custom form', async () => { const user = userEvent.setup(); render(); - await openManualAddWizard(user); - await user.click(screen.getByRole('button', { name: /自定义设备/ })); - await user.click(screen.getByRole('button', { name: /API 接入/ })); + await openApiDeviceGuidance(user); expect(screen.queryByLabelText('设备产品名')).toBeNull(); expect(screen.queryByLabelText('Base URL')).toBeNull(); expect(screen.queryByRole('button', { name: /提交给 Rex/ })).toBeNull(); expect(await screen.findByText('SessionChat:pending')).toBeInTheDocument(); - expect(screen.getByText('Placeholder:请提供产品 API 文档')).toBeInTheDocument(); - expect(screen.getByText(/请提供待接入设备的 API 资料。/)).toBeInTheDocument(); - expect(mocks.createAndSend).not.toHaveBeenCalled(); - const options = mocks.useSessionChatOptions.mock.calls.at(-1)?.[0]; - expect(options).toEqual( - expect.objectContaining({ - category: 'entity-config', - welcomeMessage: expect.stringContaining('API 文档链接'), - }), - ); - expect(options.contextMessage).toContain('本次接入方式是 API 接入'); - expect(options.contextMessage).toContain('在正式开始构建设备插件之前'); - expect(options.contextMessage).toContain('使用 `question` 工具明确'); - expect(options.welcomeMessage).toContain('请提供待接入设备的 API 资料。'); - expect(options.welcomeMessage).toContain('资料确认后,Rex 将生成'); + expect(mocks.createAndSend).toHaveBeenCalledWith({ + text: expect.stringContaining('我已选择 API 接入'), + agent: 'rex', + model: { providerID: 'openai', modelID: 'gpt-4.1' }, + }); }); - it('opens webcli mode directly in Rex chat with skill-first guidance', async () => { + it('sends browser guidance directly to Rex without opening a custom form', async () => { const user = userEvent.setup(); render(); - await openManualAddWizard(user); - await user.click(screen.getByRole('button', { name: /自定义设备/ })); - await user.click(screen.getByRole('button', { name: /WebCLI 接入/ })); + await openBrowserDeviceGuidance(user); expect(screen.queryByLabelText('登录说明')).toBeNull(); expect(screen.queryByLabelText('产品 URL')).toBeNull(); expect(screen.queryByLabelText('需要获取的接口或页面行为')).toBeNull(); expect(screen.queryByRole('button', { name: /提交给 Rex/ })).toBeNull(); expect(await screen.findByText('SessionChat:pending')).toBeInTheDocument(); - expect(screen.getByText('Placeholder:请提供网站地址')).toBeInTheDocument(); - expect(screen.getByText(/请提供待接入设备的 Web 控制台资料。/)).toBeInTheDocument(); - expect(mocks.createAndSend).not.toHaveBeenCalled(); - const options = mocks.useSessionChatOptions.mock.calls.at(-1)?.[0]; - expect(options).toEqual( - expect.objectContaining({ - welcomeMessage: expect.stringContaining('登录 URL'), - }), - ); - expect(options.contextMessage).toContain('本次接入方式是 WebCLI 接入'); - expect(options.contextMessage).toContain('向用户提出必要问题'); - expect(options.contextMessage).toContain('使用 `question` 工具明确'); - expect(options.welcomeMessage).toContain('请提供待接入设备的 Web 控制台资料。'); - expect(options.welcomeMessage).toContain('资料确认后,Rex 将沉淀 WebCLI 资产'); + expect(mocks.createAndSend).toHaveBeenCalledWith({ + text: expect.stringContaining('我已选择浏览器接入'), + agent: 'rex', + model: { providerID: 'openai', modelID: 'gpt-4.1' }, + }); }); - it('creates custom device session only after the user sends a message', async () => { + it('creates a Rex device-add session when the user sends a message', async () => { const user = userEvent.setup(); render(); await openManualAddWizard(user); - await user.click(screen.getByRole('button', { name: /自定义设备/ })); - await user.click(screen.getByRole('button', { name: /API 接入/ })); expect(await screen.findByText('SessionChat:pending')).toBeInTheDocument(); expect(mocks.createAndSend).not.toHaveBeenCalled(); @@ -583,6 +810,8 @@ describe('DeviceIntegrationPage', () => { expect(mocks.createAndSend).toHaveBeenCalledWith({ text: '用户补充资料', imageParts: [], + agent: 'rex', + model: { providerID: 'openai', modelID: 'gpt-4.1' }, }); }); @@ -590,41 +819,137 @@ describe('DeviceIntegrationPage', () => { const user = userEvent.setup(); render(); - await openManualAddWizard(user); - await user.click(screen.getByRole('button', { name: /自定义设备/ })); - await user.click(screen.getByRole('button', { name: /API 接入/ })); + await openApiDeviceGuidance(user); await screen.findByText('SessionChat:pending'); expect(screen.queryByRole('button', { name: /刷新设备模板/ })).toBeNull(); expect(screen.queryByText(/已进入 Rex 对话/)).toBeNull(); }); - it('navigates to the matching session from rex chat view', async () => { + it('localizes known vendor keys in the supported-device list', async () => { const user = userEvent.setup(); - mocks.sessionId = 'session-1'; - render(); + mocks.listTemplates.mockResolvedValueOnce({ + data: [ + buildTemplate({ + plugin_id: 'huorong_v1', + storage_key: 'huorong_v1', + service_id: 'huorong', + name: '火绒终端安全', + vendor: 'huorong', + }), + buildTemplate({ + plugin_id: 'huawei_cloud_v1', + storage_key: 'huawei_cloud_v1', + service_id: 'huaweicloud', + name: '华为云', + vendor: 'huaweicloud', + }), + buildTemplate({ + plugin_id: '360_waf_v5_5', + storage_key: '360_waf_v5_5', + service_id: '360_waf', + name: '360 WAF', + vendor: '360', + }), + ], + }); + const { container } = render(); - await openManualAddWizard(user); - await user.click(screen.getByRole('button', { name: /自定义设备/ })); - await user.click(screen.getByRole('button', { name: /API 接入/ })); + await openSupportedDeviceList(user); - await screen.findByText('SessionChat:session-1'); - await user.click(screen.getByRole('button', { name: /前往会话列表查看/ })); + expect(screen.getByText('火绒')).toBeInTheDocument(); + expect(screen.getByText('华为云')).toBeInTheDocument(); + expect(screen.getAllByText('360').length).toBeGreaterThan(0); + expect(screen.queryByText('huorong')).toBeNull(); + expect(screen.queryByText('huaweicloud')).toBeNull(); - expect(mocks.navigate).toHaveBeenCalledWith('/sessions?session=session-1'); + const huorongLogo = container.querySelector('img[src="/vendor-logos/huorong.png"]'); + expect(huorongLogo).not.toBeNull(); + expect(container.querySelector('img[src="/vendor-logos/huaweicloud.png"]')).not.toBeNull(); + expect(container.querySelector('img[src="/vendor-logos/360.png"]')).not.toBeNull(); + + fireEvent.error(huorongLogo as Element); + await waitFor(() => expect(screen.getByText('火')).toBeInTheDocument()); }); - it('redirects workflow integration flow to workflows page', async () => { + it('shows template versions when supported devices share the same name', async () => { const user = userEvent.setup(); + mocks.listTemplates.mockResolvedValueOnce({ + data: [ + buildTemplate({ + plugin_id: 'onesig_v2_5_3_D20260321', + storage_key: 'onesig_v2_5_3_D20260321', + service_id: 'onesig_api', + name: 'onesig', + vendor: 'threatbook', + version: '2.5.3 D20260321', + installed: true, + state: 'installed', + }), + buildTemplate({ + plugin_id: 'onesig_v2_5_3_D20250710', + storage_key: 'onesig_v2_5_3_D20250710', + service_id: 'onesig_v2_5_3_D20250710_api', + name: 'onesig', + vendor: 'threatbook', + version: '2.5.3 D20250710', + installed: true, + state: 'installed', + }), + ], + }); render(); - await openManualAddWizard(user); - await user.click(screen.getByRole('button', { name: /自定义设备/ })); - await user.click(screen.getByRole('button', { name: /Workflow 接入/ })); - expect(screen.queryByRole('button', { name: /新建工作流/ })).toBeNull(); - await user.click(screen.getByRole('button', { name: /前往工作流列表/ })); + await openSupportedDeviceList(user); + await user.click(screen.getByRole('button', { name: /微步/ })); + + expect(screen.getAllByText('onesig')).toHaveLength(2); + expect(screen.getByText('v2.5.3 D20260321')).toBeInTheDocument(); + expect(screen.getByText('v2.5.3 D20250710')).toBeInTheDocument(); + }); + + it('sends the selected supported device template to Rex from the vendor accordion', async () => { + const user = userEvent.setup(); + mocks.listTemplates.mockResolvedValueOnce({ + data: [ + buildTemplate({ + plugin_id: 'tdp_v3_3_10', + storage_key: 'tdp_api_v3_3_10', + service_id: 'tdp_api', + name: 'TDP', + vendor: 'threatbook', + docs_url: 'https://docs.example.com/tdp', + installed: true, + state: 'installed', + credential_schema: [ + { + key: 'base_url', + label: 'Base URL', + storage: 'config', + sensitive: false, + required: true, + input_type: 'url', + config_key: 'base_url', + }, + ], + }), + ], + }); + render(); - expect(mocks.navigate).toHaveBeenCalledWith('/workflows'); + await openSupportedDeviceList(user); + await user.click(screen.getByRole('button', { name: /微步/ })); + await user.click(screen.getByRole('button', { name: /TDP/ })); + + expect(screen.queryByText('填写配置')).toBeNull(); + expect(mocks.createAndSend).toHaveBeenCalledWith(expect.objectContaining({ + text: expect.stringContaining('我要接入设备「TDP」'), + agent: 'rex', + model: { providerID: 'openai', modelID: 'gpt-4.1' }, + })); + expect(mocks.createAndSend.mock.calls[0][0].text).toContain('storage_key=tdp_api_v3_3_10'); + expect(mocks.createAndSend.mock.calls[0][0].text).toContain('base_url* (Base URL)'); + expect(mocks.createAndSend.mock.calls[0][0].text).toContain('https://docs.example.com/tdp'); }); it('clicking the blank backdrop closes the config panel', async () => { @@ -656,6 +981,18 @@ describe('DeviceIntegrationPage', () => { name: 'TDP', tool_count: 21, vendor: 'threatbook', + docs_url: 'https://docs.example.com/tdp', + credential_schema: [ + { + key: 'base_url', + label: 'Base URL', + storage: 'config', + sensitive: false, + required: true, + input_type: 'url', + config_key: 'base_url', + }, + ], }), ], }); @@ -719,6 +1056,68 @@ describe('DeviceIntegrationPage', () => { }); }); + it('hides internal credential fields from the device config form', async () => { + const user = userEvent.setup(); + mocks.listDevices.mockResolvedValueOnce({ + data: [ + { + id: 'device-1', + group_id: 'group-1', + name: 'webcli-device', + storage_key: 'webcli_device_v1', + service_id: 'webcli_device', + enabled: true, + verify_ssl: false, + fields: { base_url: 'https://device.example.com', auth_state: 'a***xyz' }, + fields_set: { base_url: true, auth_state: true }, + status: 'connected', + created_at: 0, + updated_at: 0, + }, + ], + }); + mocks.listTemplates.mockResolvedValueOnce({ + data: [ + buildTemplate({ + plugin_id: 'webcli_device_v1', + storage_key: 'webcli_device_v1', + service_id: 'webcli_device', + name: 'WebCLI Device', + vendor: 'threatbook', + credential_schema: [ + { + key: 'base_url', + label: 'Base URL', + storage: 'config', + sensitive: false, + required: true, + input_type: 'url', + config_key: 'base_url', + }, + { + key: 'auth_state', + label: 'Auth State', + storage: 'secret', + sensitive: true, + required: false, + input_type: 'password', + config_key: 'auth_state', + internal: true, + }, + ], + }), + ], + }); + + render(); + + await user.click(await screen.findByText('webcli-device')); + + expect(await screen.findByText('Base URL')).toBeInTheDocument(); + expect(screen.queryByText('Auth State')).toBeNull(); + expect(screen.queryByDisplayValue('a***xyz')).toBeNull(); + }); + it('allows editing an existing device room from a selected room view', async () => { const user = userEvent.setup(); const initialDevice = { @@ -745,6 +1144,18 @@ describe('DeviceIntegrationPage', () => { name: 'TDP', tool_count: 21, vendor: 'threatbook', + docs_url: 'https://docs.example.com/tdp', + credential_schema: [ + { + key: 'base_url', + label: 'Base URL', + storage: 'config', + sensitive: false, + required: true, + input_type: 'url', + config_key: 'base_url', + }, + ], }), ], }); @@ -761,6 +1172,7 @@ describe('DeviceIntegrationPage', () => { render(); await user.click(await screen.findByText('TDP-test-02')); + expect(await screen.findByRole('link', { name: /查看配置指引/ })).toHaveAttribute('href', 'https://docs.example.com/tdp'); const roomSelect = await screen.findByRole('combobox'); await user.selectOptions(roomSelect, 'group-2'); await user.click(screen.getByRole('button', { name: /保存配置/ })); @@ -773,7 +1185,142 @@ describe('DeviceIntegrationPage', () => { }); }); - it('tests connectivity with draft fields without replacing the form', async () => { + it('omits untouched masked secrets when saving existing device credentials', async () => { + const user = userEvent.setup(); + const initialDevice = { + id: 'device-1', + group_id: 'group-1', + name: 'TDP-test-02', + storage_key: 'tdp_api_v3_3_10', + service_id: 'tdp_api', + enabled: true, + verify_ssl: false, + fields: { + api_key: '07***0af9', + base_url: 'https://tdp.example.com', + }, + fields_set: { api_key: true, base_url: true }, + status: 'connected', + created_at: 0, + updated_at: 0, + }; + mocks.listDevices.mockResolvedValue({ data: [initialDevice] }); + mocks.listTemplates.mockResolvedValue({ + data: [ + buildTemplate({ + plugin_id: 'tdp_v3_3_10', + storage_key: 'tdp_api_v3_3_10', + service_id: 'tdp_api', + name: 'TDP', + credential_schema: [ + { + key: 'api_key', + label: 'API Key', + storage: 'secret', + sensitive: true, + required: true, + input_type: 'password', + config_key: 'api_key', + }, + { + key: 'base_url', + label: 'Base URL', + storage: 'config', + sensitive: false, + required: true, + input_type: 'url', + config_key: 'base_url', + }, + ], + }), + ], + }); + + render(); + + await user.click(await screen.findByText('TDP-test-02')); + await screen.findByDisplayValue('07***0af9'); + await user.click(screen.getByRole('button', { name: /保存配置/ })); + + await waitFor(() => { + expect(mocks.updateDevice).toHaveBeenCalledWith( + 'device-1', + expect.objectContaining({ + fields: expect.not.objectContaining({ api_key: expect.anything() }), + }), + ); + }); + }); + + it('submits empty strings when existing device credentials are cleared', async () => { + const user = userEvent.setup(); + const initialDevice = { + id: 'device-1', + group_id: 'group-1', + name: 'TDP-test-02', + storage_key: 'tdp_api_v3_3_10', + service_id: 'tdp_api', + enabled: true, + verify_ssl: false, + fields: { + api_key: '07***0af9', + base_url: 'https://tdp.example.com', + }, + fields_set: { api_key: true, base_url: true }, + status: 'connected', + created_at: 0, + updated_at: 0, + }; + mocks.listDevices.mockResolvedValue({ data: [initialDevice] }); + mocks.listTemplates.mockResolvedValue({ + data: [ + buildTemplate({ + plugin_id: 'tdp_v3_3_10', + storage_key: 'tdp_api_v3_3_10', + service_id: 'tdp_api', + name: 'TDP', + credential_schema: [ + { + key: 'api_key', + label: 'API Key', + storage: 'secret', + sensitive: true, + required: true, + input_type: 'password', + config_key: 'api_key', + }, + { + key: 'base_url', + label: 'Base URL', + storage: 'config', + sensitive: false, + required: true, + input_type: 'url', + config_key: 'base_url', + }, + ], + }), + ], + }); + + render(); + + await user.click(await screen.findByText('TDP-test-02')); + const apiKeyInput = await screen.findByDisplayValue('07***0af9'); + await user.clear(apiKeyInput); + await user.click(screen.getByRole('button', { name: /保存配置/ })); + + await waitFor(() => { + expect(mocks.updateDevice).toHaveBeenCalledWith( + 'device-1', + expect.objectContaining({ + fields: expect.objectContaining({ api_key: '' }), + }), + ); + }); + }); + + it('does not show the legacy page connectivity test button', async () => { const user = userEvent.setup(); const initialDevice = { id: 'device-1', @@ -856,24 +1403,230 @@ describe('DeviceIntegrationPage', () => { const baseUrl = await screen.findByDisplayValue('https://persisted.example.com'); await user.clear(baseUrl); await user.type(baseUrl, 'https://draft.example.com'); - await user.click(screen.getByRole('button', { name: /连通测试/ })); - await waitFor(() => { - expect(mocks.testDevice).toHaveBeenCalledWith('device-1', { - fields: expect.objectContaining({ - base_url: 'https://draft.example.com', - api_prefix: '/api', - username: 'admin', - password: 'p***word', - }), - verify_ssl: false, - base_url: 'https://draft.example.com', - }); - }); + expect(screen.queryByRole('button', { name: /连通测试/ })).not.toBeInTheDocument(); + expect(mocks.testDevice).not.toHaveBeenCalled(); expect(mocks.getDevice).not.toHaveBeenCalled(); expect(mocks.listDevices).toHaveBeenCalledTimes(1); expect(screen.getByDisplayValue('https://draft.example.com')).toBeInTheDocument(); - expect(await screen.findByText('HTTP 200, 163ms')).toBeInTheDocument(); + }); + + it('sends current config context to Rex for assisted device testing', async () => { + const user = userEvent.setup(); + const initialDevice = { + id: 'device-1', + group_id: 'group-1', + name: 'onesig-02', + storage_key: 'onesig_api_v2_5_3', + service_id: 'onesig_api', + enabled: true, + verify_ssl: false, + fields: { + base_url: 'https://persisted.example.com', + username: 'admin', + password: 'p***word', + }, + fields_set: { base_url: true, username: true, password: true }, + status: 'connected', + message: 'last ok', + created_at: 0, + updated_at: 0, + }; + mocks.listDevices.mockResolvedValue({ data: [initialDevice] }); + mocks.listTemplates.mockResolvedValue({ + data: [ + buildTemplate({ + plugin_id: 'onesig_v2_5_3', + storage_key: 'onesig_api_v2_5_3', + service_id: 'onesig_api', + name: 'OneSIG', + vendor: 'threatbook', + docs_url: 'https://docs.example.com/onesig', + credential_schema: [ + { + key: 'base_url', + label: 'Base URL', + storage: 'config', + sensitive: false, + required: true, + input_type: 'url', + config_key: 'base_url', + }, + { + key: 'username', + label: 'Username', + storage: 'config', + sensitive: false, + required: true, + input_type: 'text', + config_key: 'username', + }, + { + key: 'password', + label: 'Password', + storage: 'secret', + sensitive: true, + required: true, + input_type: 'password', + config_key: 'password', + }, + ], + }), + ], + }); + + render(); + + await user.click(await screen.findByText('onesig-02')); + expect(await screen.findByText('Rex 辅助配置')).toBeInTheDocument(); + expect(screen.queryByRole('button', { name: /帮我补全/ })).not.toBeInTheDocument(); + await user.click(screen.getByRole('button', { name: /测试设备/ })); + + await waitFor(() => expect(mocks.createAndSend).toHaveBeenCalledWith(expect.objectContaining({ + displayText: '设备「onesig-02」请帮我测试。', + agent: 'rex', + model: { providerID: 'openai', modelID: 'gpt-4.1' }, + }))); + const prompt = mocks.createAndSend.mock.calls.at(-1)?.[0].text; + expect(prompt).toContain('device_id=device-1'); + expect(prompt).toContain('配置指引文档=https://docs.example.com/onesig'); + expect(prompt).toContain('任务:请测试这台设备的连通性并完成基础冒烟验证'); + expect(prompt).toContain('第一步必须调用 `device_manage`'); + expect(prompt).toContain('action="connectivity_test"'); + expect(prompt).toContain('完成标准连通性检测并更新设备卡片状态'); + expect(prompt).toContain('卡片状态只以 `device_manage(action="connectivity_test")` 写入的 status 为准'); + expect(prompt).toContain('必须使用上面的 device_id 作为目标设备'); + expect(prompt).toContain('password (Password): 已填写(敏感值未发送明文)'); + expect(prompt).not.toContain('p***word'); + }); + + it('uses credential defaults as real values for existing devices when saved fields are empty', async () => { + const user = userEvent.setup(); + const initialDevice = { + id: 'device-1', + group_id: 'group-1', + name: 'Eagle Sensor', + storage_key: 'eagle_sensor_v1_0', + service_id: 'eagle_sensor_device', + enabled: true, + verify_ssl: false, + fields: { + base_url: 'https://eagle-sensor.threatbook-inc.cn', + }, + fields_set: { base_url: true }, + status: 'unknown', + created_at: 0, + updated_at: 0, + }; + mocks.listDevices.mockResolvedValue({ data: [initialDevice] }); + mocks.listTemplates.mockResolvedValue({ + data: [ + buildTemplate({ + plugin_id: 'eagle_sensor_v1_0', + storage_key: 'eagle_sensor_v1_0', + service_id: 'eagle_sensor_device', + name: 'Eagle Sensor', + vendor: 'threatbook', + credential_schema: [ + { + key: 'base_url', + label: 'Base URL', + storage: 'config', + sensitive: false, + required: true, + input_type: 'url', + config_key: 'base_url', + }, + { + key: 'auth_state_path', + label: 'Auth State Path', + storage: 'config', + sensitive: false, + required: false, + input_type: 'text', + config_key: 'auth_state_path', + default_value: '~/.flocks/browser/eagle-sensor/auth-state.json', + }, + ], + }), + ], + }); + + render(); + + await user.click(await screen.findByText('Eagle Sensor')); + expect(await screen.findByDisplayValue('~/.flocks/browser/eagle-sensor/auth-state.json')).toBeInTheDocument(); + await user.click(screen.getByRole('button', { name: /测试设备/ })); + + await waitFor(() => expect(mocks.createAndSend).toHaveBeenCalled()); + const prompt = mocks.createAndSend.mock.calls.at(-1)?.[0].text; + expect(prompt).toContain('auth_state_path (Auth State Path): ~/.flocks/browser/eagle-sensor/auth-state.json'); + expect(prompt).not.toContain('auth_state_path (Auth State Path): 未填写'); + }); + + it('refreshes the device card status after Rex assisted testing writes a result', async () => { + const user = userEvent.setup(); + const initialDevice = { + id: 'device-1', + group_id: 'group-1', + name: 'onesig-02', + storage_key: 'onesig_api_v2_5_3', + service_id: 'onesig_api', + enabled: true, + verify_ssl: false, + fields: { + base_url: 'https://persisted.example.com', + }, + fields_set: { base_url: true }, + status: 'unknown', + message: null, + latency_ms: null, + checked_at: null, + created_at: 0, + updated_at: 0, + }; + const updatedDevice = { + ...initialDevice, + status: 'ok', + message: 'HTTP 200,延迟 10ms', + latency_ms: 10, + checked_at: 1782800230695, + }; + mocks.listDevices.mockResolvedValue({ data: [initialDevice] }); + mocks.getDevice.mockResolvedValue({ data: updatedDevice }); + mocks.listTemplates.mockResolvedValue({ + data: [ + buildTemplate({ + plugin_id: 'onesig_v2_5_3', + storage_key: 'onesig_api_v2_5_3', + service_id: 'onesig_api', + name: 'OneSIG', + vendor: 'threatbook', + credential_schema: [ + { + key: 'base_url', + label: 'Base URL', + storage: 'config', + sensitive: false, + required: true, + input_type: 'url', + config_key: 'base_url', + }, + ], + }), + ], + }); + + render(); + + await user.click(await screen.findByText('onesig-02')); + expect(screen.getByText('未检测')).toBeInTheDocument(); + await user.click(screen.getByRole('button', { name: /测试设备/ })); + + await waitFor(() => { + expect(mocks.getDevice).toHaveBeenCalledWith('device-1'); + expect(screen.getByText('已连接')).toBeInTheDocument(); + }); }); it('reveals the full persisted secret when clicking show', async () => { @@ -909,6 +1662,26 @@ describe('DeviceIntegrationPage', () => { name: 'OneSEC', tool_count: 5, vendor: 'threatbook', + credential_schema: [ + { + key: 'api_key', + label: 'API Key', + storage: 'secret', + sensitive: true, + required: true, + input_type: 'password', + config_key: 'api_key', + }, + { + key: 'secret', + label: 'Secret', + storage: 'secret', + sensitive: true, + required: true, + input_type: 'password', + config_key: 'secret', + }, + ], }), ], }); diff --git a/webui/src/pages/DeviceIntegration/index.tsx b/webui/src/pages/DeviceIntegration/index.tsx index d8bca4200..fdf1b3c05 100644 --- a/webui/src/pages/DeviceIntegration/index.tsx +++ b/webui/src/pages/DeviceIntegration/index.tsx @@ -1,25 +1,25 @@ import { useState, useEffect, useCallback, useMemo, useRef } from 'react'; import { useTranslation } from 'react-i18next'; -import { useNavigate } from 'react-router-dom'; import { Shield, CheckCircle, XCircle, AlertTriangle, RefreshCw, Plug, PlugZap, WifiOff, Plus, Settings, Loader2, Eye, EyeOff, Save, Trash2, Activity, X, Server, Pencil, Check, - Wrench, ChevronRight, ChevronLeft, ChevronDown, Building2, ServerCog, + Wrench, ChevronRight, ChevronLeft, ChevronDown, Building2, ServerCog, Sparkles, } from 'lucide-react'; import PageHeader from '@/components/common/PageHeader'; import LoadingSpinner from '@/components/common/LoadingSpinner'; import { useToast } from '@/components/common/Toast'; import SessionChat from '@/components/common/SessionChat'; +import GuideInfoIcon from '@/components/common/GuideInfoIcon'; import { useRexComposerControls } from '@/components/common/useRexComposerControls'; -import { useSessionChat } from '@/hooks/useSessionChat'; +import { useSessionChat, type CreateAndSendOptions } from '@/hooks/useSessionChat'; import { sessionApi } from '@/api/session'; import { providerAPI } from '@/api/provider'; import { deviceAPI, type DeviceIntegration, type DeviceGroup, type DeviceTemplate, type DeviceToolInfo } from '@/api/device'; -import type { APIServiceCredentialField, CustomDeviceAccessMode, Tool } from '@/types'; +import { hubAPI } from '@/api/hub'; +import type { APIServiceCredentialField, Tool } from '@/types'; import { toolAPI } from '@/api/tool'; import ToolDetailModal from '../Tool/components/ToolDetailModal'; -import CustomDeviceAccessPanel from './CustomDeviceAccessPanel'; import { buildCustomDeviceModeRoutingPrompt } from './customDevice'; // ============================================================================ @@ -54,14 +54,19 @@ interface DeviceVendor { nameCn: string; nameEn: string; color: string; + mark?: string; + logoSrc?: string; } const VENDOR_PRESENTATION: Record> = { - sangfor: { nameCn: '深信服', nameEn: 'Sangfor', color: 'bg-blue-100 text-blue-800' }, - qianxin: { nameCn: '奇安信', nameEn: 'Qi-AnXin', color: 'bg-purple-100 text-purple-800' }, - threatbook: { nameCn: '微步', nameEn: 'ThreatBook', color: 'bg-orange-100 text-orange-800' }, - qingteng: { nameCn: '青藤', nameEn: 'Qingteng', color: 'bg-teal-100 text-teal-800' }, - nsfocus: { nameCn: '绿盟', nameEn: 'NSFOCUS', color: 'bg-green-100 text-green-800' }, + '360': { nameCn: '360', nameEn: '360', color: 'bg-zinc-100 text-zinc-700', mark: '360', logoSrc: '/vendor-logos/360.png' }, + huaweicloud: { nameCn: '华为云', nameEn: 'Huawei Cloud', color: 'bg-red-100 text-red-700', mark: '华', logoSrc: '/vendor-logos/huaweicloud.png' }, + huorong: { nameCn: '火绒', nameEn: 'Huorong', color: 'bg-amber-100 text-amber-700', mark: '火', logoSrc: '/vendor-logos/huorong.png' }, + sangfor: { nameCn: '深信服', nameEn: 'Sangfor', color: 'bg-blue-100 text-blue-800', mark: '深', logoSrc: '/vendor-logos/sangfor.png' }, + qianxin: { nameCn: '奇安信', nameEn: 'Qi-AnXin', color: 'bg-purple-100 text-purple-800', mark: '奇', logoSrc: '/vendor-logos/qianxin.png' }, + threatbook: { nameCn: '微步', nameEn: 'ThreatBook', color: 'bg-orange-100 text-orange-800', mark: '微', logoSrc: '/vendor-logos/threatbook.png' }, + qingteng: { nameCn: '青藤', nameEn: 'Qingteng', color: 'bg-teal-100 text-teal-800', mark: '青', logoSrc: '/vendor-logos/qingteng.png' }, + nsfocus: { nameCn: '绿盟', nameEn: 'NSFOCUS', color: 'bg-green-100 text-green-800', mark: '绿', logoSrc: '/vendor-logos/nsfocus.png' }, }; function vendorPresentation(vendorKey: string): DeviceVendor { @@ -72,11 +77,50 @@ function vendorPresentation(vendorKey: string): DeviceVendor { nameCn: vendorKey, nameEn: vendorKey, color: 'bg-zinc-100 text-zinc-700', + mark: vendorKey[0]?.toUpperCase() || '?', }; } -function deviceTemplateHubUrl(template: DeviceTemplate): string { - return `/hub?type=device&plugin=${encodeURIComponent(template.plugin_id)}&q=${encodeURIComponent(template.plugin_id)}`; +function VendorMark({ vendor, label, className = 'h-6 w-6 rounded-md text-[11px]' }: { + vendor: DeviceVendor; + label: string; + className?: string; +}) { + const [logoFailed, setLogoFailed] = useState(false); + const showLogo = !!vendor.logoSrc && !logoFailed; + return ( + + ); +} + +function templateAction(template: DeviceTemplate): 'install' | 'update' | null { + if (template.installed) return null; + if (template.state === 'available') return 'install'; + if (template.state === 'updateAvailable') return 'update'; + return null; +} + +function formatTemplateVersion(version: string): string { + return /^v/i.test(version) ? version : `v${version}`; } // ============================================================================ @@ -155,8 +199,6 @@ function ActiveCard({ device, vendorKey, selected, onClick }: { // Add device wizard panel (step 1: vendor, step 2: product) // ============================================================================ -type DeviceAddTab = 'rex' | 'manual'; - interface DeviceAddDraft { template: DeviceTemplate; name?: string; @@ -185,7 +227,8 @@ function buildDeviceAddSessionContext(templates: DeviceTemplate[]): string { `vendor=${template.vendor || 'unspecified'}`, `state=${template.installed ? 'installed' : template.state}`, fields ? `fields=${fields}` : 'fields=none', - ].join(' | '); + template.docs_url ? `docs_url=${template.docs_url}` : null, + ].filter(Boolean).join(' | '); }); return [ @@ -329,30 +372,274 @@ function buildDeviceDraftAction( }; } +function buildTemplateGuidePrompt(template: DeviceTemplate): string { + const fields = template.credential_schema + .map((field) => `${field.key}${field.required ? '*' : ''}${field.label ? ` (${field.label})` : ''}`) + .join(', '); + const installed = template.installed; + return [ + `我要接入设备「${template.name}」。`, + '我已从已支持设备列表选择了这个设备模板,请按该模板继续引导接入。', + `模板信息:storage_key=${template.storage_key},service_id=${template.service_id},plugin_id=${template.plugin_id},状态=${installed ? 'installed' : template.state}。`, + template.docs_url ? `配置指引文档:${template.docs_url}。请优先结合该文档引导用户完成设备侧准备和 Flocks 侧配置。` : null, + fields ? `该设备表单字段包括:${fields}。` : '该设备模板没有声明额外表单字段。', + installed + ? '请直接引导我确认设备名称、所属机房、连接地址、认证字段、SSL 验证和连通测试步骤。' + : '该模板尚未安装,请先引导我前往 FlockHub 安装或更新该设备模板,安装完成后再继续配置。', + installed + ? '信息足够后,请输出设备配置 JSON 草稿,页面会用它填充表单;不要在 JSON 中写入真实密钥。' + : '模板安装完成前不要输出设备配置 JSON 草稿。', + ].filter(Boolean).join('\n'); +} + +function buildDeviceTestGuidePrompt(device: DeviceIntegration, template: DeviceTemplate): CreateAndSendOptions { + const fieldKeys = Object.keys(device.fields || {}); + const fieldStatus = fieldKeys.length > 0 + ? fieldKeys.map((key) => `${key}${device.fields_set?.[key] ? '(已填写)' : ''}`).join(', ') + : '无额外字段'; + const text = [ + `设备「${device.name}」已确认接入并保存。`, + `device_id=${device.id},storage_key=${device.storage_key},service_id=${device.service_id},模板名称=${template.name}。`, + `设备当前状态=${device.status},enabled=${device.enabled},verify_ssl=${device.verify_ssl},group_id=${device.group_id}。`, + `已填写字段:${fieldStatus}。`, + '请继续留在当前会话,直接调用这台设备的可用工具完成连通测试和基础冒烟验证。', + '不要再询问接入方式,也不要让我在 API 接入、浏览器接入、Workflow 接入之间选择;不要重新输出设备配置 JSON 草稿。', + '测试时请优先选择只读、低风险工具,并明确使用上面的 device_id 作为目标设备;如果需要执行写操作或高风险动作,必须先说明风险并请求确认。', + '如果工具调用失败,请根据返回错误给出优先排查项,例如地址、认证字段、SSL 验证、网络连通性或设备侧权限。', + ].join('\n'); + return { + text, + displayText: `设备「${device.name}」已确认接入,请帮我测试。`, + }; +} + +type DeviceRexAssistAction = 'test' | 'troubleshoot'; + +interface DeviceConfigRexAssistInput { + action: DeviceRexAssistAction; + device?: DeviceIntegration; + template?: DeviceTemplate; + metadata?: { name?: string; version?: string; description?: string; description_cn?: string; docs_url?: string } | null; + name: string; + groupId: string; + fields: Record; + fieldsSet?: Record; + credentialSchema: APIServiceCredentialField[]; + verifySsl: boolean; + enabled: boolean; +} + +function summarizeDeviceFormFields(input: DeviceConfigRexAssistInput): string { + const schema = input.credentialSchema || []; + const knownKeys = new Set(schema.map((field) => field.key)); + const lines = schema.map((field) => { + const isSecret = field.storage === 'secret' || field.input_type === 'password'; + const value = (input.fields[field.key] ?? '').trim(); + const hasPersisted = !!input.fieldsSet?.[field.key]; + const hasValue = Boolean(value) || hasPersisted; + const label = field.label ? ` (${field.label})` : ''; + if (isSecret) { + return `- ${field.key}${label}: ${hasValue ? '已填写(敏感值未发送明文)' : '未填写'}${field.required ? ',必填' : ''}`; + } + return `- ${field.key}${label}: ${value || '未填写'}${field.required ? ',必填' : ''}`; + }); + Object.entries(input.fields).forEach(([key, value]) => { + if (knownKeys.has(key)) return; + lines.push(`- ${key}: ${value || '未填写'}`); + }); + return lines.length ? lines.join('\n') : '- 无额外字段'; +} + +function buildDeviceConfigRexAssistPrompt(input: DeviceConfigRexAssistInput): CreateAndSendOptions { + const template = input.template; + const device = input.device; + const docsUrl = input.metadata?.docs_url ?? template?.docs_url; + const identityLines = [ + `设备名称=${input.name || device?.name || '未命名设备'}`, + device ? `device_id=${device.id}` : 'device_id=尚未保存', + `storage_key=${device?.storage_key ?? template?.storage_key ?? 'unknown'}`, + `service_id=${device?.service_id ?? template?.service_id ?? 'unknown'}`, + template?.name ? `模板名称=${template.name}` : null, + input.metadata?.version || template?.version ? `版本=${input.metadata?.version ?? template?.version}` : null, + `group_id=${input.groupId}`, + `enabled=${input.enabled}`, + `verify_ssl=${input.verifySsl}`, + docsUrl ? `配置指引文档=${docsUrl}` : null, + ].filter(Boolean).join('\n'); + const fieldSummary = summarizeDeviceFormFields(input); + const statusLine = device + ? `当前状态=${device.status},message=${device.message || '无'},latency_ms=${device.latency_ms ?? '无'}` + : '当前状态=尚未保存到设备列表'; + const common = [ + '你是 Flocks 的设备接入助手,请基于当前设备配置上下文继续工作。', + '不要要求我在 API 接入、浏览器接入、Workflow 接入之间重新选择;不要索要或复述真实密钥。', + '敏感字段只根据“已填写/未填写”判断,不要让用户在对话里粘贴真实密码、Token、API Key 或 Secret。', + '', + '设备上下文:', + identityLines, + statusLine, + '', + '当前表单字段:', + fieldSummary, + ]; + + if (input.action === 'test') { + return { + text: [ + ...common, + '', + '任务:请测试这台设备的连通性并完成基础冒烟验证。', + '第一步必须调用 `device_manage`,参数为 action="connectivity_test" 且传入上面的 device_id,完成标准连通性检测并更新设备卡片状态。', + '连通性检测成功后,再调用这台设备的少量可用只读工具完成基础冒烟验证。', + '卡片状态只以 `device_manage(action="connectivity_test")` 写入的 status 为准;其他工具调用结果用于功能验证总结。', + '必须使用上面的 device_id 作为目标设备;优先选择只读、低风险工具。', + '如果需要执行写操作或高风险动作,必须先说明风险并请求确认。', + '完成后总结成功/失败结果;失败时给出地址、认证字段、SSL 验证、网络连通性或设备侧权限等优先排查项。', + ].join('\n'), + displayText: `设备「${input.name || device?.name || '未命名设备'}」请帮我测试。`, + }; + } + + return { + text: [ + ...common, + '', + '任务:请帮我排查这台设备的连接或配置问题。', + '请先根据当前状态、最近测试结果和表单字段判断最可能原因,再给出按优先级排序的排查步骤。', + '如果需要验证,请优先调用只读工具;如果信息不足,只问一个最关键问题。', + ].join('\n'), + displayText: `设备「${input.name || device?.name || '未命名设备'}」请帮我排查连接问题。`, + }; +} + function DeviceAddRexPanel({ templates, + sessionId, + showBuiltInTemplates, + setShowBuiltInTemplates, + workbenchResetToken, + createAndSend, + rexComposerControls, onApplyDraft, onInstallTemplate, + instanceCounts, }: { templates: DeviceTemplate[]; + sessionId: string | null; + showBuiltInTemplates: boolean; + setShowBuiltInTemplates: (show: boolean) => void; + workbenchResetToken: number; + createAndSend: (options: CreateAndSendOptions) => Promise; + rexComposerControls: ReturnType; onApplyDraft: (draft: DeviceAddDraft) => void; - onInstallTemplate: (template: DeviceTemplate) => void; + onInstallTemplate: (template: DeviceTemplate) => Promise; + instanceCounts: Record; }) { - const { t } = useTranslation('device'); + const { t, i18n } = useTranslation('device'); const toast = useToast(); const [extracting, setExtracting] = useState(false); const [detectedAction, setDetectedAction] = useState(null); - const rexComposerControls = useRexComposerControls(); - const contextMessage = useMemo(() => buildDeviceAddSessionContext(templates), [templates]); - const welcomeMessage = t('wizard.rex.welcome'); - const { sessionId, createAndSend, reset } = useSessionChat({ - title: t('wizard.rex.title'), - category: 'entity-config', - contextMessage, - welcomeMessage, - }); + const [expandedVendors, setExpandedVendors] = useState>(new Set()); + const [installingTemplateKey, setInstallingTemplateKey] = useState(null); - useEffect(() => reset, [reset]); + useEffect(() => { + if (workbenchResetToken === 0) return; + setDetectedAction(null); + setExpandedVendors(new Set()); + setInstallingTemplateKey(null); + }, [workbenchResetToken]); + + const startGuidedPrompt = useCallback((prompt: string) => { + createAndSend({ + text: prompt, + agent: rexComposerControls.rexAgentName, + model: rexComposerControls.rexModel, + }).catch(() => {}); + }, [createAndSend, rexComposerControls.rexAgentName, rexComposerControls.rexModel]); + + const vendorGroups = useMemo(() => { + const groups = new Map(); + for (const template of templates) { + const vendorKey = template.vendor || '__unspecified__'; + const vendor = vendorKey === '__unspecified__' + ? { id: vendorKey, nameCn: t('vendor.unspecified'), nameEn: 'Unspecified', color: 'bg-zinc-100 text-zinc-600' } + : vendorPresentation(vendorKey); + if (!groups.has(vendorKey)) { + groups.set(vendorKey, { vendor, templates: [] }); + } + groups.get(vendorKey)!.templates.push(template); + } + return Array.from(groups.values()) + .map((group) => ({ + ...group, + templates: [...group.templates].sort((a, b) => { + if (a.installed !== b.installed) return a.installed ? -1 : 1; + return a.name.localeCompare(b.name); + }), + })) + .sort((a, b) => { + const rank = (vendor: DeviceVendor) => { + if (vendor.id === 'threatbook') return 0; + if (vendor.id === '__unspecified__') return 99; + return 1; + }; + const ra = rank(a.vendor); + const rb = rank(b.vendor); + if (ra !== rb) return ra - rb; + return a.vendor.id.localeCompare(b.vendor.id); + }); + }, [templates, t]); + + const findCaseTemplate = useCallback((keywords: string[]) => { + const normalizedKeywords = keywords.map((keyword) => keyword.toLowerCase()); + const matches = templates.filter((template) => { + const haystack = [ + template.name, + template.plugin_id, + template.storage_key, + template.service_id, + ].join(' ').toLowerCase(); + return normalizedKeywords.some((keyword) => haystack.includes(keyword)); + }); + return matches.find((template) => template.installed) ?? matches[0]; + }, [templates]); + + const handleTemplatePrompt = useCallback(async (template: DeviceTemplate) => { + const action = templateAction(template); + if (!action) { + setShowBuiltInTemplates(false); + startGuidedPrompt(buildTemplateGuidePrompt(template)); + return; + } + setInstallingTemplateKey(template.storage_key); + try { + const installedTemplate = await onInstallTemplate(template); + if (installedTemplate) { + setShowBuiltInTemplates(false); + startGuidedPrompt(buildTemplateGuidePrompt(installedTemplate)); + } + } finally { + setInstallingTemplateKey(null); + } + }, [onInstallTemplate, startGuidedPrompt]); + + const handleCaseTemplate = useCallback((keywords: string[], fallbackPrompt: string) => { + const template = findCaseTemplate(keywords); + if (!template) { + startGuidedPrompt(fallbackPrompt); + return; + } + void handleTemplatePrompt(template); + }, [findCaseTemplate, handleTemplatePrompt, startGuidedPrompt]); + + const toggleVendor = (vendorId: string) => { + setExpandedVendors((current) => { + const next = new Set(current); + if (next.has(vendorId)) next.delete(vendorId); + else next.add(vendorId); + return next; + }); + }; const detectLatestDraft = useCallback(async (silent: boolean) => { if (!sessionId || extracting) return; @@ -379,11 +666,14 @@ function DeviceAddRexPanel({ } }, [extracting, sessionId, t, templates, toast]); - const handleConfirmDetectedDraft = () => { + const handleConfirmDetectedDraft = async () => { if (!detectedAction) return; if (detectedAction.kind === 'install') { - toast.info(t('wizard.rex.installFirst')); - onInstallTemplate(detectedAction.template); + const installedTemplate = await onInstallTemplate(detectedAction.template); + setDetectedAction(null); + if (installedTemplate) { + startGuidedPrompt(buildTemplateGuidePrompt(installedTemplate)); + } return; } onApplyDraft(detectedAction.draft); @@ -414,15 +704,148 @@ function DeviceAddRexPanel({ emptyText={t('wizard.rex.pending')} onStreamingDone={() => void detectLatestDraft(true)} welcomeContent={ -
-
- R -
-
-
Rex
-
- {welcomeMessage} -
+
+
+ {!showBuiltInTemplates ? ( + <> +
+
+ +
+

{t('wizard.guide.title')}

+

+ {t('wizard.guide.subtitle')} +

+
+ +
+ + startGuidedPrompt(t('wizard.guide.prompts.api'))} + /> + startGuidedPrompt(t('wizard.guide.prompts.browser'))} + /> + + + + handleCaseTemplate(['tdp'], t('wizard.guide.prompts.tdp'))} + /> + handleCaseTemplate(['onesec', 'one sec'], t('wizard.guide.prompts.onesec'))} + /> + setShowBuiltInTemplates(true)} + /> + +
+ + ) : ( + <> +
+ +

{t('wizard.supportedList.title')}

+

{t('wizard.supportedList.subtitle')}

+
+ +
+ {vendorGroups.map(({ vendor, templates: vendorTemplates }) => { + const expanded = expandedVendors.has(vendor.id); + const vendorName = i18n.language.startsWith('zh') ? vendor.nameCn : vendor.nameEn; + const integratedCount = vendorTemplates.reduce( + (sum, template) => sum + (instanceCounts[template.storage_key] ?? 0), + 0, + ); + return ( +
+ + {expanded && ( +
+ {vendorTemplates.map((tpl) => { + const count = instanceCounts[tpl.storage_key] ?? 0; + const action = templateAction(tpl); + const installing = installingTemplateKey === tpl.storage_key; + const templateMeta = tpl.version ? formatTemplateVersion(tpl.version) : tpl.storage_key; + const stateBadge = tpl.installed + ? t('wizard.installState.installed') + : tpl.state === 'updateAvailable' + ? t('wizard.installState.updateAvailable') + : tpl.state === 'broken' + ? t('wizard.installState.brokenShort') + : t('wizard.installState.available'); + return ( + + ); + })} +
+ )} +
+ ); + })} +
+ + )}
} @@ -464,63 +887,72 @@ function DeviceAddRexPanel({ ); } -function AddDeviceWizardPanel({ templates, instanceCounts, initialVendor, onSelect, onSelectCustom, onApplyRexDraft, onInstallTemplate, onClose }: { +function WorkbenchSection({ + title, + children, +}: { + title: string; + children: React.ReactNode; +}) { + return ( +
+

{title}

+
+ {children} +
+
+ ); +} + +function WorkbenchAction({ label, description, onClick }: { label: string; description: string; onClick: () => void }) { + return ( + + ); +} + +function AddDeviceWizardPanel({ + templates, + instanceCounts, + sessionId, + createAndSend, + rexComposerControls, + onApplyRexDraft, + onInstallTemplate, + onResetWorkbench, + onClose, +}: { templates: DeviceTemplate[]; instanceCounts: Record; - initialVendor?: DeviceVendor; - onSelect: (template: DeviceTemplate) => void; - onSelectCustom: (mode: CustomDeviceAccessMode) => void; + sessionId: string | null; + createAndSend: (options: CreateAndSendOptions) => Promise; + rexComposerControls: ReturnType; onApplyRexDraft: (draft: DeviceAddDraft) => void; - onInstallTemplate: (template: DeviceTemplate) => void; + onInstallTemplate: (template: DeviceTemplate) => Promise; + onResetWorkbench: () => void; onClose: () => void; }) { - const { t, i18n } = useTranslation('device'); - const navigate = useNavigate(); - const [activeTab, setActiveTab] = useState('rex'); - const [selectedVendor, setSelectedVendor] = useState(initialVendor ?? null); - const [showCustomModes, setShowCustomModes] = useState(false); - - const availableVendors = useMemo(() => { - const seen: string[] = []; - for (const t of templates) { - const key = t.vendor || '__unspecified__'; - if (!seen.includes(key)) seen.push(key); - } - seen.sort((a, b) => { - const rank = (k: string) => { - if (k === 'threatbook') return 0; - if (k === '__unspecified__') return 99; - return 1; - }; - const ra = rank(a); - const rb = rank(b); - if (ra !== rb) return ra - rb; - return a.localeCompare(b); - }); - return seen.map((key) => - key === '__unspecified__' - ? { id: '__unspecified__', nameCn: t('vendor.unspecified'), nameEn: 'Unspecified', color: 'bg-zinc-100 text-zinc-600' } - : vendorPresentation(key), - ); - }, [templates]); - - const vendorTotalCounts = useMemo(() => { - const counts: Record = {}; - for (const t of templates) { - const key = t.vendor || '__unspecified__'; - counts[key] = (counts[key] ?? 0) + (instanceCounts[t.storage_key] ?? 0); - } - return counts; - }, [templates, instanceCounts]); - - const vendorTemplates = useMemo(() => { - if (!selectedVendor) return []; - return templates.filter((t) => (t.vendor || '__unspecified__') === selectedVendor.id); - }, [templates, selectedVendor]); + const { t } = useTranslation('device'); + const [showBuiltInTemplates, setShowBuiltInTemplates] = useState(false); + const [workbenchResetToken, setWorkbenchResetToken] = useState(0); - const inModeSelection = showCustomModes && !selectedVendor; - const shouldShowVendorSecondary = (vendor: DeviceVendor) => - vendor.nameCn.trim().toLocaleLowerCase() !== vendor.nameEn.trim().toLocaleLowerCase(); + const handleWorkbenchClick = () => { + setShowBuiltInTemplates(false); + setWorkbenchResetToken((current) => current + 1); + onResetWorkbench(); + }; return (
@@ -538,244 +970,39 @@ function AddDeviceWizardPanel({ templates, instanceCounts, initialVendor, onSele {/* Header */}
-

{t('wizard.title')}

+
+

{t('wizard.title')}

+
-
- {[ - { key: 'rex' as const, label: t('wizard.tabs.rex') }, - { key: 'manual' as const, label: t('wizard.tabs.manual') }, - ].map((tab) => ( - - ))} +
+
- {activeTab === 'manual' && ( -
- {(selectedVendor || inModeSelection) && ( - - )} -
- {(selectedVendor || inModeSelection) && ( -

- {selectedVendor - ? t('wizard.selectVendorTitle', { vendor: i18n.language.startsWith('zh') ? selectedVendor.nameCn : selectedVendor.nameEn }) - : t('wizard.modeTitle')} -

- )} -
- - {t('wizard.step1Custom')} - - - - {t('wizard.step2Custom')} - - - - {t('wizard.step3Custom')} - -
-
-
- )}
{/* Content */} -
- {activeTab === 'rex' ? ( - - ) : !selectedVendor && !inModeSelection ? ( - <> -

{t('wizard.chooseVendorOrCustom')}

-
- - {availableVendors.map((vendor) => { - const count = vendorTotalCounts[vendor.id] ?? 0; - const productCount = templates.filter( - (t) => (t.vendor || '__unspecified__') === vendor.id, - ).length; - const primaryName = i18n.language.startsWith('zh') ? vendor.nameCn : vendor.nameEn; - const secondaryName = i18n.language.startsWith('zh') ? vendor.nameEn : vendor.nameCn; - const showSecondary = shouldShowVendorSecondary(vendor); - return ( - - ); - })} -
- - ) : inModeSelection ? ( - <> -

{t('wizard.chooseCustomMode')}

-
- {[ - { - key: 'api' as const, - title: t('wizard.customModes.api.title'), - desc: t('wizard.customModes.api.desc'), - }, - { - key: 'webcli' as const, - title: t('wizard.customModes.webcli.title'), - desc: t('wizard.customModes.webcli.desc'), - }, - { - key: 'workflow' as const, - title: t('wizard.customModes.workflow.title'), - desc: t('wizard.customModes.workflow.desc'), - }, - ].map((mode) => ( - - ))} -
- - ) : ( - <> -

- {t('wizard.productHint', { count: vendorTemplates.length })} -

-
- {vendorTemplates.map((tpl) => { - const count = instanceCounts[tpl.storage_key] ?? 0; - const disabled = !tpl.installed; - const stateHint = tpl.state === 'updateAvailable' - ? t('wizard.installState.update') - : tpl.state === 'broken' - ? t('wizard.installState.broken') - : t('wizard.installState.install'); - const stateBadge = tpl.state === 'updateAvailable' - ? t('wizard.installState.updateAvailable') - : tpl.state === 'broken' - ? t('wizard.installState.brokenShort') - : tpl.installed - ? t('wizard.installState.installed') - : t('wizard.installState.available'); - const hubUrl = deviceTemplateHubUrl(tpl); - return ( - - ); - })} -
- - )} +
+
@@ -788,6 +1015,22 @@ function AddDeviceWizardPanel({ templates, instanceCounts, initialVendor, onSele type PanelTab = 'config' | 'tools' | 'overview'; +function applyCredentialDefaults( + schema: APIServiceCredentialField[], + fields: Record, +): Record { + const next = { ...fields }; + schema.forEach((field) => { + const defaultValue = field.default_value; + if (!defaultValue) return; + const currentValue = next[field.key] ?? ''; + if (!currentValue.trim()) { + next[field.key] = defaultValue; + } + }); + return next; +} + function Toggle({ on, onToggle }: { on: boolean; onToggle: () => void }) { return ( )}
- {isSecret && device && hasExisting && ( -

{t('config.secretConfigured')}

- )} {f.description &&

{f.description}

}
); @@ -1206,33 +1480,48 @@ function DeviceConfigPanel({
- {testResult && ( -
- {testResult.success - ? - : } - {testResult.message} + {onRexAssist && ( +
+
+
+
+ +
+

{t('config.aiAssistTitle')}

+
+
+ {([ + ['test', t('config.aiAssistTest'), Activity], + ['troubleshoot', t('config.aiAssistTroubleshoot'), AlertTriangle], + ] as const).map(([action, label, Icon]) => { + const disabled = action === 'test' && !device; + const loading = rexAssistAction === action; + return ( + + ); + })} +
+
+

{t('config.aiAssistHint')}

)}
- {device && onTest && ( - - )}
)} + {licenseSyncWarning && ( +
+ {licenseSyncWarning} +
+ )} {requestError && (
{requestError} @@ -1360,15 +1409,16 @@ export default function FlocksproUpgradePage() {

{t('upgrade.startUpgrade')}

- + {!upgradeInProgress && ( + + )}
{proRestarting ? t('upgrade.waitingRestart') : t('upgrade.installingHint')} @@ -1445,7 +1495,7 @@ export default function FlocksproUpgradePage() {
)}
- {!proUpgrading && !proRestarting && ( + {!upgradeInProgress && (
); } - diff --git a/webui/src/pages/Home/index.test.tsx b/webui/src/pages/Home/index.test.tsx index 26d06a640..f5141e841 100644 --- a/webui/src/pages/Home/index.test.tsx +++ b/webui/src/pages/Home/index.test.tsx @@ -50,10 +50,10 @@ vi.mock('react-i18next', () => ({ }), })); -describe('Home create user defined page entry', () => { +describe('Home create WebUI contract page entry', () => { beforeEach(() => { vi.clearAllMocks(); - createMock.mockResolvedValue({ id: 'session-user-defined-1' }); + createMock.mockResolvedValue({ id: 'session-webui-contract-1' }); useAuthMock.mockReturnValue({ user: { id: 'user-1', @@ -73,18 +73,18 @@ describe('Home create user defined page entry', () => { , ); - await user.click(screen.getByRole('button', { name: 'createUserDefinedPage' })); + await user.click(screen.getByRole('button', { name: 'createWebUIContractPage' })); await waitFor(() => { - expect(createMock).toHaveBeenCalledWith({ title: 'createUserDefinedPageSessionTitle' }); + expect(createMock).toHaveBeenCalledWith({ title: 'createWebUIContractPageSessionTitle' }); }); expect(navigateMock).toHaveBeenCalledWith( - `/sessions?session=session-user-defined-1&message=${encodeURIComponent('createUserDefinedPageInitialMessage')}`, + `/sessions?session=session-webui-contract-1&message=${encodeURIComponent('createWebUIContractPageInitialMessage')}`, ); }); - it('hides the create user defined page entry for non-admin users', () => { + it('hides the create WebUI contract page entry for non-admin users', () => { useAuthMock.mockReturnValue({ user: { id: 'user-2', @@ -101,7 +101,7 @@ describe('Home create user defined page entry', () => { , ); - expect(screen.queryByRole('button', { name: 'createUserDefinedPage' })).not.toBeInTheDocument(); + expect(screen.queryByRole('button', { name: 'createWebUIContractPage' })).not.toBeInTheDocument(); expect(createMock).not.toHaveBeenCalled(); }); }); diff --git a/webui/src/pages/Home/index.tsx b/webui/src/pages/Home/index.tsx index 2dda6f562..717b51e74 100644 --- a/webui/src/pages/Home/index.tsx +++ b/webui/src/pages/Home/index.tsx @@ -35,24 +35,24 @@ export default function Home() { const navigate = useNavigate(); const toast = useToast(); const { user } = useAuth(); - const canCreateUserDefinedPage = user?.role === 'admin'; + const canCreateWebUIContractPage = user?.role === 'admin'; const [isRepoMenuOpen, setIsRepoMenuOpen] = useState(false); - const [creatingUserDefinedPageSession, setCreatingUserDefinedPageSession] = useState(false); + const [creatingWebUIContractPageSession, setCreatingWebUIContractPageSession] = useState(false); - const handleCreateUserDefinedPage = useCallback(async () => { - if (creatingUserDefinedPageSession) return; - setCreatingUserDefinedPageSession(true); + const handleCreateWebUIContractPage = useCallback(async () => { + if (creatingWebUIContractPageSession) return; + setCreatingWebUIContractPageSession(true); try { - const session = await sessionApi.create({ title: t('createUserDefinedPageSessionTitle') }); - const message = t('createUserDefinedPageInitialMessage'); + const session = await sessionApi.create({ title: t('createWebUIContractPageSessionTitle') }); + const message = t('createWebUIContractPageInitialMessage'); navigate(`/sessions?session=${session.id}&message=${encodeURIComponent(message)}`); } catch (err: unknown) { - const detail = err instanceof Error ? err.message : t('createUserDefinedPageError'); - toast.error(t('createUserDefinedPageError'), detail); + const detail = err instanceof Error ? err.message : t('createWebUIContractPageError'); + toast.error(t('createWebUIContractPageError'), detail); } finally { - setCreatingUserDefinedPageSession(false); + setCreatingWebUIContractPageSession(false); } - }, [creatingUserDefinedPageSession, navigate, t, toast]); + }, [creatingWebUIContractPageSession, navigate, t, toast]); return (
@@ -91,19 +91,19 @@ export default function Home() { - {canCreateUserDefinedPage ? ( + {canCreateWebUIContractPage ? ( ) : null} diff --git a/webui/src/pages/Settings/index.test.tsx b/webui/src/pages/Settings/index.test.tsx new file mode 100644 index 000000000..a82c3e46a --- /dev/null +++ b/webui/src/pages/Settings/index.test.tsx @@ -0,0 +1,199 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest'; +import { render, screen, within } from '@testing-library/react'; +import userEvent from '@testing-library/user-event'; +import { MemoryRouter, Route, Routes, useLocation } from 'react-router-dom'; +import SettingsPage from './index'; +import { ThemeContext, type Theme } from '@/contexts/ThemeContext'; + +const { changeLanguage, flocksproUsersApi, setTheme, useAuth } = vi.hoisted(() => ({ + changeLanguage: vi.fn(), + flocksproUsersApi: { + hasCapability: vi.fn(), + }, + setTheme: vi.fn(), + useAuth: vi.fn(), +})); + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + i18n: { + language: 'zh-CN', + changeLanguage, + }, + }), +})); + +vi.mock('@/contexts/AuthContext', () => ({ + useAuth, +})); + +vi.mock('@/api/flocksproUsers', () => ({ + flocksproUsersApi, +})); + +vi.mock('@/pages/Config', () => ({ + default: () =>
account page
, +})); + +vi.mock('@/pages/SystemLog', () => ({ + default: () =>
system logs page
, +})); + +vi.mock('@/pages/AuditLogs', () => ({ + default: () =>
audit logs page
, +})); + +vi.mock('@/pages/FlocksproUpgrade', () => ({ + default: () =>
flocks pro page
, +})); + +vi.mock('@/pages/Model', () => ({ + default: () =>
models page
, +})); + +vi.mock('@/pages/Channel', () => ({ + default: () =>
channels page
, +})); + +function LocationProbe() { + const location = useLocation(); + return
{`${location.pathname}${location.search}${location.hash}`}
; +} + +function renderSettings(path: string, theme: Theme = 'light', state?: Record) { + return render( + + + + } /> + models page
} /> + channels page
} /> + } /> + } /> + + + , + ); +} + +describe('SettingsPage', () => { + beforeEach(() => { + vi.clearAllMocks(); + flocksproUsersApi.hasCapability.mockResolvedValue(true); + useAuth.mockReturnValue({ + user: { + id: 'user-1', + username: 'admin', + role: 'admin', + status: 'active', + must_reset_password: false, + }, + }); + }); + + it('renders preference controls for language and theme', async () => { + const user = userEvent.setup(); + + renderSettings('/settings/preferences', 'light'); + + expect(screen.getByRole('heading', { name: 'settingsPreferences' })).toBeInTheDocument(); + + await user.click(screen.getByRole('button', { name: 'EN' })); + expect(changeLanguage).toHaveBeenCalledWith('en-US'); + + await user.click(screen.getByRole('button', { name: 'darkTheme' })); + expect(setTheme).toHaveBeenCalledWith('dark'); + }); + + it('redirects legacy model and channel settings URLs to workspace pages', async () => { + const { unmount } = renderSettings('/settings/models'); + + expect(await screen.findByText('models page')).toBeInTheDocument(); + expect(screen.queryByRole('link', { name: 'models' })).not.toBeInTheDocument(); + expect(screen.queryByRole('link', { name: 'channels' })).not.toBeInTheDocument(); + + unmount(); + renderSettings('/settings/channels'); + + expect(await screen.findByText('channels page')).toBeInTheDocument(); + }); + + it('returns to the page captured before opening settings', async () => { + const user = userEvent.setup(); + + renderSettings('/settings/system-logs', 'light', { + from: { + pathname: '/contracts/webui/workspaces/scene_workspace', + search: '?view=posture', + hash: '#top', + }, + }); + + expect(await screen.findByText('system logs page')).toBeInTheDocument(); + await user.click(screen.getAllByRole('link', { name: 'accountManagement' })[0]); + expect(await screen.findByText('account page')).toBeInTheDocument(); + + await user.click(screen.getAllByRole('button', { name: 'settingsBack' })[0]); + + expect(await screen.findByTestId('location')).toHaveTextContent( + '/contracts/webui/workspaces/scene_workspace?view=posture#top', + ); + }); + + it('keeps return and section navigation available outside the desktop sidebar', async () => { + renderSettings('/settings/system-logs'); + + expect(await screen.findByText('system logs page')).toBeInTheDocument(); + expect(screen.getAllByRole('button', { name: 'settingsBack' })).toHaveLength(2); + + const mobileNav = screen.getByRole('navigation', { name: 'settingsTitle' }); + expect(within(mobileNav).getByRole('link', { name: 'accountManagement' })).toHaveAttribute('href', '/settings/account'); + expect(within(mobileNav).getByRole('link', { name: 'auditLogs' })).toHaveAttribute('href', '/settings/audit-logs'); + expect(within(mobileNav).queryByRole('link', { name: 'models' })).not.toBeInTheDocument(); + expect(within(mobileNav).queryByRole('link', { name: 'channels' })).not.toBeInTheDocument(); + }); + + it('renders audit logs in settings for Flocks Pro admins', async () => { + renderSettings('/settings/audit-logs'); + + expect(await screen.findByText('audit logs page')).toBeInTheDocument(); + expect(screen.getAllByRole('link', { name: 'auditLogs' })[0]).toHaveAttribute('href', '/settings/audit-logs'); + expect(flocksproUsersApi.hasCapability).toHaveBeenCalled(); + }); + + it('hides audit logs when Flocks Pro capability is unavailable', async () => { + flocksproUsersApi.hasCapability.mockResolvedValue(false); + + renderSettings('/settings/audit-logs'); + + expect(await screen.findByRole('heading', { name: 'settingsPreferences' })).toBeInTheDocument(); + expect(screen.queryByRole('link', { name: 'auditLogs' })).not.toBeInTheDocument(); + }); + + it('hides Flocks Pro settings for non-admin users', async () => { + useAuth.mockReturnValue({ + user: { + id: 'user-2', + username: 'member', + role: 'member', + status: 'active', + must_reset_password: false, + }, + }); + + renderSettings('/settings/flockspro'); + + expect(await screen.findByRole('heading', { name: 'settingsPreferences' })).toBeInTheDocument(); + expect(screen.queryByRole('link', { name: 'flocksproUpgrade' })).not.toBeInTheDocument(); + expect(screen.queryByRole('link', { name: 'auditLogs' })).not.toBeInTheDocument(); + }); +}); diff --git a/webui/src/pages/Settings/index.tsx b/webui/src/pages/Settings/index.tsx new file mode 100644 index 000000000..64d741f3f --- /dev/null +++ b/webui/src/pages/Settings/index.tsx @@ -0,0 +1,410 @@ +import { Suspense, lazy, useContext, useEffect, useMemo, useState } from 'react'; +import type { ReactNode } from 'react'; +import { Link, Navigate, useLocation, useNavigate, useParams } from 'react-router-dom'; +import { useTranslation } from 'react-i18next'; +import { + ArrowLeft, + ArrowUpCircle, + Check, + Languages, + Moon, + ScrollText, + Settings as SettingsIcon, + ShieldCheck, + Sun, + UserCog, + type LucideIcon, +} from 'lucide-react'; +import RoutePageSkeleton from '@/components/common/RoutePageSkeleton'; +import { ThemeContext } from '@/contexts/ThemeContext'; +import { useAuth } from '@/contexts/AuthContext'; +import { flocksproUsersApi } from '@/api/flocksproUsers'; + +const ConfigPage = lazy(() => import('@/pages/Config')); +const SystemLogPage = lazy(() => import('@/pages/SystemLog')); +const FlocksproUpgradePage = lazy(() => import('@/pages/FlocksproUpgrade')); +const AuditLogsPage = lazy(() => import('@/pages/AuditLogs')); + +type SettingsSectionId = 'preferences' | 'account' | 'system-logs' | 'audit-logs' | 'flockspro'; + +interface ReturnLocation { + pathname: string; + search: string; + hash: string; +} + +interface SettingsLocationState { + from?: Partial; +} + +interface SettingsSection { + id: SettingsSectionId; + name: string; + icon: LucideIcon; + adminOnly?: boolean; + requiresFlockspro?: boolean; +} + +interface SettingsGroup { + name: string; + items: SettingsSection[]; +} + +function isSettingsSectionId(value: string | undefined): value is SettingsSectionId { + return ( + value === 'preferences' || + value === 'account' || + value === 'system-logs' || + value === 'audit-logs' || + value === 'flockspro' + ); +} + +function sanitizeReturnLocation(state: unknown): ReturnLocation { + const from = (state as SettingsLocationState | null)?.from; + const pathname = typeof from?.pathname === 'string' ? from.pathname : ''; + if (!pathname.startsWith('/') || pathname.startsWith('/settings')) { + return { pathname: '/', search: '', hash: '' }; + } + + return { + pathname, + search: typeof from?.search === 'string' && from.search.startsWith('?') ? from.search : '', + hash: typeof from?.hash === 'string' && from.hash.startsWith('#') ? from.hash : '', + }; +} + +function buildReturnPath(location: ReturnLocation): string { + return `${location.pathname}${location.search}${location.hash}`; +} + +function PreferenceRow({ + icon: Icon, + title, + description, + children, +}: { + icon: LucideIcon; + title: string; + description: string; + children: ReactNode; +}) { + return ( +
+
+ + + +
+

{title}

+

{description}

+
+
+
+ {children} +
+
+ ); +} + +function SegmentedOption({ + active, + children, + icon: Icon, + onClick, +}: { + active: boolean; + children: ReactNode; + icon: LucideIcon; + onClick: () => void; +}) { + return ( + + ); +} + +function PreferencesPanel() { + const { t, i18n } = useTranslation('nav'); + const { theme, setTheme } = useContext(ThemeContext); + const language = i18n.language?.toLowerCase().startsWith('zh') ? 'zh-CN' : 'en-US'; + + return ( +
+
+

{t('settingsPreferences')}

+

{t('settingsPreferencesDescription')}

+
+ +
+ +
+ void i18n.changeLanguage('en-US')} + > + EN + + void i18n.changeLanguage('zh-CN')} + > + 中 + +
+
+ + +
+ setTheme('light')} + > + {t('lightTheme')} + + setTheme('dark')} + > + {t('darkTheme')} + +
+
+
+
+ ); +} + +function SettingsContent({ sectionId }: { sectionId: SettingsSectionId }) { + if (sectionId === 'preferences') return ; + + return ( + }> + {sectionId === 'account' && } + {sectionId === 'system-logs' && } + {sectionId === 'audit-logs' && } + {sectionId === 'flockspro' && } + + ); +} + +export default function SettingsPage() { + const params = useParams(); + const location = useLocation(); + const navigate = useNavigate(); + const { t } = useTranslation('nav'); + const { user } = useAuth(); + const isAdmin = user?.role === 'admin'; + const sectionId = params.sectionId; + const [flocksproCapabilityReady, setFlocksproCapabilityReady] = useState(false); + const [hasFlocksproCapability, setHasFlocksproCapability] = useState(false); + const returnLocation = useMemo(() => sanitizeReturnLocation(location.state), [location.state]); + const settingsRouteState = useMemo(() => ({ from: returnLocation }), [returnLocation]); + + useEffect(() => { + let cancelled = false; + if (!isAdmin) { + setHasFlocksproCapability(false); + setFlocksproCapabilityReady(true); + return () => { + cancelled = true; + }; + } + + setFlocksproCapabilityReady(false); + const refreshCapability = () => { + void flocksproUsersApi.hasCapability() + .then((ok) => { + if (!cancelled) { + setHasFlocksproCapability(ok); + } + }) + .catch(() => { + if (!cancelled) { + setHasFlocksproCapability(false); + } + }) + .finally(() => { + if (!cancelled) { + setFlocksproCapabilityReady(true); + } + }); + }; + + refreshCapability(); + window.addEventListener('flockspro-license-status-changed', refreshCapability); + return () => { + cancelled = true; + window.removeEventListener('flockspro-license-status-changed', refreshCapability); + }; + }, [isAdmin]); + + const groups = useMemo( + () => [ + { + name: t('settingsGroupPreferences'), + items: [ + { id: 'preferences', name: t('settingsPreferences'), icon: SettingsIcon }, + ], + }, + { + name: t('settingsGroupSystem'), + items: [ + { id: 'account', name: t('accountManagement'), icon: UserCog }, + { id: 'system-logs', name: t('systemLog'), icon: ScrollText }, + { id: 'audit-logs', name: t('auditLogs'), icon: ShieldCheck, adminOnly: true, requiresFlockspro: true }, + { id: 'flockspro', name: t('flocksproUpgrade'), icon: ArrowUpCircle, adminOnly: true }, + ], + }, + ], + [t], + ); + + const visibleGroups = groups + .map((group) => ({ + ...group, + items: group.items.filter((item) => { + if (item.adminOnly && !isAdmin) return false; + if (item.requiresFlockspro && flocksproCapabilityReady && !hasFlocksproCapability) return false; + return true; + }), + })) + .filter((group) => group.items.length > 0); + + if (!sectionId) { + return ; + } + + if (sectionId === 'models') { + return ; + } + + if (sectionId === 'channels') { + return ; + } + + if (!isSettingsSectionId(sectionId)) { + return ; + } + + const currentSection = visibleGroups.flatMap((group) => group.items).find((item) => item.id === sectionId); + + if (!currentSection) { + return ; + } + + return ( +
+ + +
+
+ +
+

{t('settingsTitle')}

+

{currentSection.name}

+
+ +
+ +
+ +
+
+
+ ); +} diff --git a/webui/src/pages/UserDefinedPageHost/index.tsx b/webui/src/pages/WebUIContractPageHost/PageRuntimeHost.tsx similarity index 72% rename from webui/src/pages/UserDefinedPageHost/index.tsx rename to webui/src/pages/WebUIContractPageHost/PageRuntimeHost.tsx index a30cfce6f..e1cccee74 100644 --- a/webui/src/pages/UserDefinedPageHost/index.tsx +++ b/webui/src/pages/WebUIContractPageHost/PageRuntimeHost.tsx @@ -7,40 +7,39 @@ import { useEffect, useState, } from 'react'; -import { useParams } from 'react-router-dom'; import { useTranslation } from 'react-i18next'; import i18n from '@/i18n'; import { AlertCircle, Loader2 } from 'lucide-react'; import { getApiBase } from '@/api/client'; -import { userDefinedPagesAPI } from '@/api/userDefinedPages'; +import { webuiContractPagesAPI } from '@/api/webuiContractPages'; import { useSSE } from '@/hooks/useSSE'; -import { installUserDefinedPageRuntime, loadUserDefinedPageBundle } from './runtime'; +import { installWebUIContractPageRuntime, loadWebUIContractPageBundle } from './runtime'; -interface UserDefinedPageErrorBoundaryProps { +interface WebUIContractPageErrorBoundaryProps { children: ReactNode; errorTitle: string; fallbackMessage: string; onError?: (message: string) => void; } -interface UserDefinedPageErrorBoundaryState { +interface WebUIContractPageErrorBoundaryState { hasError: boolean; message: string; } -class UserDefinedPageErrorBoundary extends Component< - UserDefinedPageErrorBoundaryProps, - UserDefinedPageErrorBoundaryState +class WebUIContractPageErrorBoundary extends Component< + WebUIContractPageErrorBoundaryProps, + WebUIContractPageErrorBoundaryState > { - state: UserDefinedPageErrorBoundaryState = { hasError: false, message: '' }; + state: WebUIContractPageErrorBoundaryState = { hasError: false, message: '' }; - static getDerivedStateFromError(error: Error): UserDefinedPageErrorBoundaryState { + static getDerivedStateFromError(error: Error): WebUIContractPageErrorBoundaryState { return { hasError: true, message: error.message || '' }; } componentDidCatch(error: Error, info: ErrorInfo) { this.props.onError?.(error.message || this.props.fallbackMessage); - console.error('[UserDefinedPageHost] render error:', error, info); + console.error('[WebUIContractPageHost] render error:', error, info); } render() { @@ -60,11 +59,14 @@ class UserDefinedPageErrorBoundary extends Component< } } -export default function UserDefinedPageHost() { - const { pageId } = useParams<{ pageId: string }>(); - const { t } = useTranslation('userDefinedPage'); +interface PageRuntimeHostProps { + pageId?: string; +} + +export default function PageRuntimeHost({ pageId }: PageRuntimeHostProps) { + const { t } = useTranslation('webuiContractPage'); const tr = useCallback( - (key: string) => i18n.t(key, { ns: 'userDefinedPage' }), + (key: string) => i18n.t(key, { ns: 'webuiContractPage' }), [], ); const [PageComponent, setPageComponent] = useState(null); @@ -74,10 +76,10 @@ export default function UserDefinedPageHost() { const loadBundle = useCallback(async (hash: string) => { if (!pageId || !hash) return; - installUserDefinedPageRuntime(pageId); + installWebUIContractPageRuntime(pageId); const base = getApiBase(); - const url = `${base}/api/user-defined-pages/${encodeURIComponent(pageId)}/bundle.js?v=${encodeURIComponent(hash)}`; - const component = await loadUserDefinedPageBundle(url, tr('host.bundleMissingExport')); + const url = `${base}/api/contracts/webui/pages/${encodeURIComponent(pageId)}/bundle.js?v=${encodeURIComponent(hash)}`; + const component = await loadWebUIContractPageBundle(url, tr('host.bundleMissingExport')); setPageComponent(() => component); setError(null); }, [pageId, tr]); @@ -86,7 +88,7 @@ export default function UserDefinedPageHost() { if (!pageId) return; setLoading(true); try { - const response = await userDefinedPagesAPI.get(pageId); + const response = await webuiContractPagesAPI.get(pageId); const nextHash = hash || response.data.build.hash; setBuildHash(nextHash); if (response.data.build.status !== 'ready' || !nextHash) { @@ -111,21 +113,21 @@ export default function UserDefinedPageHost() { url: '/api/event', onEvent: useCallback((evt) => { if (!pageId) return; - if (evt.type === 'user_defined_pages.updated' && evt.properties?.id === pageId) { + if (evt.type === 'contracts.webui.pages.updated' && evt.properties?.id === pageId) { const hash = evt.properties?.hash as string | undefined; void refreshPage(hash); return; } - if (evt.type === 'user_defined_pages.build_failed' && evt.properties?.id === pageId) { + if (evt.type === 'contracts.webui.pages.build_failed' && evt.properties?.id === pageId) { setError((evt.properties?.error as string | undefined) || tr('host.buildFailed')); setLoading(false); return; } - if (evt.type === 'user_defined_pages.api_changed' && evt.properties?.id === pageId) { + if (evt.type === 'contracts.webui.pages.api_changed' && evt.properties?.id === pageId) { setError(null); return; } - if (evt.type === 'user_defined_pages.api_failed' && evt.properties?.id === pageId) { + if (evt.type === 'contracts.webui.pages.api_failed' && evt.properties?.id === pageId) { setError((evt.properties?.error as string | undefined) || tr('host.apiFailed')); setLoading(false); } @@ -170,12 +172,12 @@ export default function UserDefinedPageHost() { } return ( - - + ); } diff --git a/webui/src/pages/UserDefinedPageHost/index.test.tsx b/webui/src/pages/WebUIContractPageHost/index.test.tsx similarity index 74% rename from webui/src/pages/UserDefinedPageHost/index.test.tsx rename to webui/src/pages/WebUIContractPageHost/index.test.tsx index 3f95164ce..ad265474e 100644 --- a/webui/src/pages/UserDefinedPageHost/index.test.tsx +++ b/webui/src/pages/WebUIContractPageHost/index.test.tsx @@ -1,7 +1,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'; import { render, screen, waitFor } from '@testing-library/react'; import { MemoryRouter, Route, Routes } from 'react-router-dom'; -import UserDefinedPageHost from './index'; +import WebUIContractPageHost from './index'; import { setupSSEMock } from '@/test/mocks/sse'; const { getMock, loadBundleMock, installMock } = vi.hoisted(() => ({ @@ -10,8 +10,8 @@ const { getMock, loadBundleMock, installMock } = vi.hoisted(() => ({ installMock: vi.fn(), })); -vi.mock('@/api/userDefinedPages', () => ({ - userDefinedPagesAPI: { +vi.mock('@/api/webuiContractPages', () => ({ + webuiContractPagesAPI: { get: getMock, }, })); @@ -21,8 +21,8 @@ vi.mock('@/api/client', () => ({ })); vi.mock('./runtime', () => ({ - installUserDefinedPageRuntime: installMock, - loadUserDefinedPageBundle: loadBundleMock, + installWebUIContractPageRuntime: installMock, + loadWebUIContractPageBundle: loadBundleMock, })); vi.mock('@/i18n', () => ({ @@ -40,10 +40,10 @@ vi.mock('react-i18next', () => ({ })); function MockPage() { - return
自定义页面内容
; + return
契约页面内容
; } -describe('UserDefinedPageHost', () => { +describe('WebUIContractPageHost', () => { setupSSEMock(); beforeEach(() => { @@ -57,7 +57,7 @@ describe('UserDefinedPageHost', () => { manifest: { id: 'dash-1', title: '仪表盘', - route: '/user-defined-pages/dash-1', + route: '/contracts/webui/dash-1', icon: 'LayoutDashboard', order: 10, enabled: true, @@ -76,19 +76,19 @@ describe('UserDefinedPageHost', () => { }); render( - + - } /> + } /> , ); await waitFor(() => { - expect(screen.getByText('自定义页面内容')).toBeInTheDocument(); + expect(screen.getByText('契约页面内容')).toBeInTheDocument(); }); expect(installMock).toHaveBeenCalledWith('dash-1'); expect(loadBundleMock).toHaveBeenCalledWith( - 'https://api.example.test/api/user-defined-pages/dash-1/bundle.js?v=abc123', + 'https://api.example.test/api/contracts/webui/pages/dash-1/bundle.js?v=abc123', 'host.bundleMissingExport', ); }); @@ -99,7 +99,7 @@ describe('UserDefinedPageHost', () => { manifest: { id: 'dash-2', title: '失败页', - route: '/user-defined-pages/dash-2', + route: '/contracts/webui/dash-2', icon: 'LayoutDashboard', order: 20, enabled: true, @@ -118,9 +118,9 @@ describe('UserDefinedPageHost', () => { }); render( - + - } /> + } /> , ); diff --git a/webui/src/pages/WebUIContractPageHost/index.tsx b/webui/src/pages/WebUIContractPageHost/index.tsx new file mode 100644 index 000000000..f7900f995 --- /dev/null +++ b/webui/src/pages/WebUIContractPageHost/index.tsx @@ -0,0 +1,7 @@ +import { useParams } from 'react-router-dom'; +import PageRuntimeHost from './PageRuntimeHost'; + +export default function WebUIContractPageHost() { + const { pageId } = useParams<{ pageId: string }>(); + return ; +} diff --git a/webui/src/pages/UserDefinedPageHost/runtime.test.tsx b/webui/src/pages/WebUIContractPageHost/runtime.test.tsx similarity index 52% rename from webui/src/pages/UserDefinedPageHost/runtime.test.tsx rename to webui/src/pages/WebUIContractPageHost/runtime.test.tsx index cce4b5c57..7390e79ce 100644 --- a/webui/src/pages/UserDefinedPageHost/runtime.test.tsx +++ b/webui/src/pages/WebUIContractPageHost/runtime.test.tsx @@ -1,18 +1,33 @@ import { describe, expect, it, vi } from 'vitest'; import apiClient from '@/api/client'; -import { installUserDefinedPageRuntime, loadUserDefinedPageBundle } from './runtime'; +import { installWebUIContractPageRuntime, loadWebUIContractPageBundle } from './runtime'; -describe('UserDefinedPage runtime', () => { +describe('WebUIContractPage runtime', () => { it('exposes page-scoped api helper', async () => { const getSpy = vi.spyOn(apiClient, 'get').mockResolvedValue({ data: {} } as never); - installUserDefinedPageRuntime('dash-1'); - const sdk = window.__FLOCKS_USER_DEFINED_PAGE_SDK__; + installWebUIContractPageRuntime('dash-1'); + const sdk = window.__FLOCKS_WEBUI_CONTRACT_SDK__; expect(sdk).toBeTruthy(); await sdk!.api.page.get('/stats'); - expect(getSpy).toHaveBeenCalledWith('/api/user-defined-pages/dash-1/api/stats', undefined); + expect(getSpy).toHaveBeenCalledWith('/api/contracts/webui/pages/dash-1/api/stats', undefined); getSpy.mockRestore(); }); + it('exposes contract operation helper', async () => { + const postSpy = vi.spyOn(apiClient, 'post').mockResolvedValue({ data: {} } as never); + installWebUIContractPageRuntime('dash-1'); + const sdk = window.__FLOCKS_WEBUI_CONTRACT_SDK__; + await sdk!.api + .contract('records/list', 'records.operations') + .operation('list', { params: { limit: 10 } }); + expect(postSpy).toHaveBeenCalledWith( + '/api/contracts/webui/pages/records/list/access/records.operations/operations/list', + { params: { limit: 10 } }, + undefined, + ); + postSpy.mockRestore(); + }); + it('loads page bundles through the credentialed api client', async () => { const source = 'export default function Page(){return null;}'; const getSpy = vi.spyOn(apiClient, 'get').mockResolvedValue({ data: source } as never); @@ -21,14 +36,14 @@ describe('UserDefinedPage runtime', () => { .mockReturnValue(`data:text/javascript,${encodeURIComponent(source)}`); const revokeObjectURLSpy = vi.spyOn(URL, 'revokeObjectURL').mockImplementation(() => {}); - const component = await loadUserDefinedPageBundle( - 'https://api.example.test/api/user-defined-pages/dash-1/bundle.js?v=abc123', + const component = await loadWebUIContractPageBundle( + 'https://api.example.test/api/contracts/webui/pages/dash-1/bundle.js?v=abc123', 'missing default', ); expect(component).toEqual(expect.any(Function)); expect(getSpy).toHaveBeenCalledWith( - 'https://api.example.test/api/user-defined-pages/dash-1/bundle.js?v=abc123', + 'https://api.example.test/api/contracts/webui/pages/dash-1/bundle.js?v=abc123', { responseType: 'text' }, ); expect(createObjectURLSpy).toHaveBeenCalledWith(expect.any(Blob)); diff --git a/webui/src/pages/UserDefinedPageHost/runtime.tsx b/webui/src/pages/WebUIContractPageHost/runtime.tsx similarity index 66% rename from webui/src/pages/UserDefinedPageHost/runtime.tsx rename to webui/src/pages/WebUIContractPageHost/runtime.tsx index 6a0610f78..e2d53fcfa 100644 --- a/webui/src/pages/UserDefinedPageHost/runtime.tsx +++ b/webui/src/pages/WebUIContractPageHost/runtime.tsx @@ -4,7 +4,7 @@ import type { AxiosRequestConfig, AxiosResponse } from 'axios'; import apiClient from '@/api/client'; import { useAuth } from '@/contexts/AuthContext'; -interface UserDefinedPageScopedApi { +interface WebUIContractPageScopedApi { get(path: string, config?: AxiosRequestConfig): Promise>; post(path: string, data?: unknown, config?: AxiosRequestConfig): Promise>; put(path: string, data?: unknown, config?: AxiosRequestConfig): Promise>; @@ -12,22 +12,31 @@ interface UserDefinedPageScopedApi { delete(path: string, config?: AxiosRequestConfig): Promise>; } -type UserDefinedPageApiClient = typeof apiClient & { - page: UserDefinedPageScopedApi; +interface WebUIContractOperationApi { + operation( + operationName: string, + data?: unknown, + config?: AxiosRequestConfig, + ): Promise>; +} + +type WebUIContractPageApiClient = typeof apiClient & { + page: WebUIContractPageScopedApi; + contract(pagePath: string, contractId: string): WebUIContractOperationApi; }; -export interface UserDefinedPageSdk { +export interface WebUIContractPageSdk { React: typeof React; jsx: typeof jsx; jsxs: typeof jsxs; - api: UserDefinedPageApiClient; + api: WebUIContractPageApiClient; Card: typeof Card; useCurrentUser: typeof useCurrentUser; } declare global { interface Window { - __FLOCKS_USER_DEFINED_PAGE_SDK__?: UserDefinedPageSdk; + __FLOCKS_WEBUI_CONTRACT_SDK__?: WebUIContractPageSdk; } } @@ -50,8 +59,16 @@ function normalizePageApiPath(path: string): string { return path.startsWith('/') ? path : `/${path}`; } -function createScopedApi(pageId: string): UserDefinedPageScopedApi { - const base = `/api/user-defined-pages/${encodeURIComponent(pageId)}/api`; +function encodePagePath(pagePath: string): string { + return pagePath + .split('/') + .filter(Boolean) + .map((part) => encodeURIComponent(part)) + .join('/'); +} + +function createScopedApi(pageId: string): WebUIContractPageScopedApi { + const base = `/api/contracts/webui/pages/${encodeURIComponent(pageId)}/api`; return { get(path, config) { return apiClient.get(`${base}${normalizePageApiPath(path)}`, config); @@ -71,11 +88,21 @@ function createScopedApi(pageId: string): UserDefinedPageScopedApi { }; } -export function installUserDefinedPageRuntime(pageId: string): void { +function createContractApi(pagePath: string, contractId: string): WebUIContractOperationApi { + const base = `/api/contracts/webui/pages/${encodePagePath(pagePath)}/access/${encodeURIComponent(contractId)}/operations`; + return { + operation(operationName, data, config) { + return apiClient.post(`${base}/${encodeURIComponent(operationName)}`, data, config); + }, + }; +} + +export function installWebUIContractPageRuntime(pageId: string): void { if (typeof window === 'undefined') return; - const api = apiClient as UserDefinedPageApiClient; + const api = apiClient as WebUIContractPageApiClient; api.page = createScopedApi(pageId); - window.__FLOCKS_USER_DEFINED_PAGE_SDK__ = { + api.contract = createContractApi; + window.__FLOCKS_WEBUI_CONTRACT_SDK__ = { React, jsx, jsxs, @@ -85,7 +112,7 @@ export function installUserDefinedPageRuntime(pageId: string): void { }; } -export async function loadUserDefinedPageBundle( +export async function loadWebUIContractPageBundle( url: string, missingExportMessage = 'Page bundle does not export a default component', ): Promise { diff --git a/webui/src/pages/WebUIContractWorkspaceHost/index.test.tsx b/webui/src/pages/WebUIContractWorkspaceHost/index.test.tsx new file mode 100644 index 000000000..926803175 --- /dev/null +++ b/webui/src/pages/WebUIContractWorkspaceHost/index.test.tsx @@ -0,0 +1,175 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest'; +import { render, screen, waitFor } from '@testing-library/react'; +import { MemoryRouter, Route, Routes } from 'react-router-dom'; +import WebUIContractWorkspaceHost from './index'; +import { setupSSEMock } from '@/test/mocks/sse'; +import { ThemeContext } from '@/contexts/ThemeContext'; + +const { listWorkspacesMock } = vi.hoisted(() => ({ + listWorkspacesMock: vi.fn(), +})); + +vi.mock('@/api/webuiContractPages', () => ({ + webuiContractPagesAPI: { + listWorkspaces: listWorkspacesMock, + }, +})); + +vi.mock('@/pages/WebUIContractPageHost/PageRuntimeHost', () => ({ + default: ({ pageId }: { pageId?: string }) =>
page:{pageId}
, +})); + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + i18n: { language: 'zh-CN' }, + }), +})); + +describe('WebUIContractWorkspaceHost', () => { + setupSSEMock(); + + beforeEach(() => { + vi.clearAllMocks(); + listWorkspacesMock.mockResolvedValue({ + data: [ + { + id: 'scene_workspace', + title: '场景工作区', + route: '/contracts/webui/workspaces/scene_workspace', + icon: 'ShieldCheck', + order: 10, + enabled: true, + placement: 'sceneWorkspace', + defaultPageId: 'ops-overview', + sections: [ + { + id: 'posture', + label: '态势', + pageIds: ['risk-dashboard'], + defaultPageId: 'risk-dashboard', + contentPadding: 'none', + themeOverride: 'dark', + }, + { + id: 'operations', + label: '调查列表', + pageIds: ['ops-overview', 'investigation-list'], + defaultPageId: 'ops-overview', + contentPadding: 'comfortable', + }, + ], + pages: [ + { + id: 'risk-dashboard', + title: '态势看板', + route: '/contracts/webui/risk-dashboard', + icon: 'ShieldCheck', + order: 30, + enabled: true, + placement: 'home.after', + buildHash: 'posture', + buildStatus: 'ready', + workspaceId: 'scene_workspace', + workspaceTitle: '场景工作区', + workspaceRoute: '/contracts/webui/workspaces/scene_workspace', + }, + { + id: 'ops-overview', + title: '运营总览', + route: '/contracts/webui/ops-overview', + icon: 'Shield', + order: 10, + enabled: true, + placement: 'home.after', + buildHash: 'abc', + buildStatus: 'ready', + workspaceId: 'scene_workspace', + workspaceTitle: '场景工作区', + workspaceRoute: '/contracts/webui/workspaces/scene_workspace', + }, + { + id: 'investigation-list', + title: '调查列表', + route: '/contracts/webui/investigation-list', + icon: 'AlertTriangle', + order: 20, + enabled: true, + placement: 'home.after', + buildHash: '', + buildStatus: 'failed', + workspaceId: 'scene_workspace', + workspaceTitle: '场景工作区', + workspaceRoute: '/contracts/webui/workspaces/scene_workspace', + }, + ], + }, + ], + }); + }); + + it('waits for an explicit page selection on the workspace root', async () => { + render( + + + } /> + + , + ); + + await waitFor(() => { + expect(screen.getByText('workspace.selectPage')).toBeInTheDocument(); + }); + expect(screen.queryByText('page:risk-dashboard')).not.toBeInTheDocument(); + expect(screen.queryByRole('navigation', { name: 'workspace.sectionNavigation' })).not.toBeInTheDocument(); + }); + + it('renders a selected operation page without a fixed workspace sidebar', async () => { + render( + + + } /> + + , + ); + + await waitFor(() => { + expect(screen.getByText('page:investigation-list')).toBeInTheDocument(); + }); + expect(screen.getByText('page:investigation-list').parentElement).toHaveClass('p-6'); + expect(screen.queryByRole('navigation', { name: 'workspace.sectionNavigation' })).not.toBeInTheDocument(); + }); + + it('temporarily uses dark theme for the posture dashboard when the user preference is light', async () => { + const setTemporaryThemeOverride = vi.fn(); + const { unmount } = render( + + + + } /> + + + , + ); + + await waitFor(() => { + expect(screen.getByText('page:risk-dashboard')).toBeInTheDocument(); + }); + expect(screen.getByText('page:risk-dashboard').parentElement).not.toHaveClass('p-6'); + await waitFor(() => { + expect(setTemporaryThemeOverride).toHaveBeenCalledWith('dark'); + }); + + unmount(); + + expect(setTemporaryThemeOverride).toHaveBeenLastCalledWith(null); + }); +}); diff --git a/webui/src/pages/WebUIContractWorkspaceHost/index.tsx b/webui/src/pages/WebUIContractWorkspaceHost/index.tsx new file mode 100644 index 000000000..251584c57 --- /dev/null +++ b/webui/src/pages/WebUIContractWorkspaceHost/index.tsx @@ -0,0 +1,143 @@ +import { useCallback, useContext, useEffect, useMemo, useState } from 'react'; +import { useParams } from 'react-router-dom'; +import { useTranslation } from 'react-i18next'; +import { AlertCircle, Loader2 } from 'lucide-react'; +import { + webuiContractPagesAPI, + type WebUIContractWorkspaceListItem, +} from '@/api/webuiContractPages'; +import { useSSE } from '@/hooks/useSSE'; +import { ThemeContext } from '@/contexts/ThemeContext'; +import PageRuntimeHost from '@/pages/WebUIContractPageHost/PageRuntimeHost'; +import { buildWebUIContractWorkspaceSections } from '@/utils/webuiContractWorkspaceSections'; + +export default function WebUIContractWorkspaceHost() { + const { workspaceId, pageId } = useParams<{ workspaceId: string; pageId?: string }>(); + const { t } = useTranslation('webuiContractPage'); + const [workspaces, setWorkspaces] = useState([]); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + const { theme, setTemporaryThemeOverride } = useContext(ThemeContext); + + const fetchWorkspaces = useCallback(async (silent = false) => { + if (!silent) setLoading(true); + setError(null); + try { + const response = await webuiContractPagesAPI.listWorkspaces(true); + setWorkspaces(Array.isArray(response.data) ? response.data : []); + } catch (err: unknown) { + setWorkspaces([]); + setError(err instanceof Error ? err.message : t('workspace.loadFailed')); + } finally { + if (!silent) setLoading(false); + } + }, [t]); + + useEffect(() => { + void fetchWorkspaces(); + }, [fetchWorkspaces]); + + useSSE({ + url: '/api/event', + onEvent: useCallback((evt) => { + if (evt.type === 'contracts.webui.pages.nav_changed') { + void fetchWorkspaces(true); + } + }, [fetchWorkspaces]), + reconnect: { maxRetries: 5, initialDelay: 2000 }, + }); + + const workspace = useMemo( + () => workspaces.find((item) => item.id === workspaceId), + [workspaceId, workspaces], + ); + const pages = useMemo( + () => [...(workspace?.pages ?? [])].sort((a, b) => a.order - b.order || a.title.localeCompare(b.title)), + [workspace?.pages], + ); + const sections = useMemo( + () => (workspace ? buildWebUIContractWorkspaceSections(workspace) : []), + [workspace], + ); + const currentPage = pages.find((page) => page.id === pageId); + const currentSection = currentPage + ? sections.find((section) => section.pages.some((page) => page.id === currentPage.id)) + : undefined; + const temporaryThemeOverride = currentSection?.themeOverride && theme !== currentSection.themeOverride + ? currentSection.themeOverride + : null; + + useEffect(() => { + if (!temporaryThemeOverride) { + setTemporaryThemeOverride(null); + return undefined; + } + + setTemporaryThemeOverride(temporaryThemeOverride); + return () => setTemporaryThemeOverride(null); + }, [setTemporaryThemeOverride, temporaryThemeOverride]); + + if (!workspaceId) { + return
{t('workspace.missingWorkspaceId')}
; + } + + if (loading) { + return ( +
+ + {t('workspace.loading')} +
+ ); + } + + if (error) { + return ( +
+ +
+
{t('workspace.unavailableTitle')}
+
{error}
+ +
+
+ ); + } + + if (!workspace) { + return
{t('workspace.notFound')}
; + } + + if (pages.length === 0) { + return
{t('workspace.empty')}
; + } + + if (!pageId) { + return ( +
+ {t('workspace.selectPage')} +
+ ); + } + + if (!currentPage) { + return
{t('workspace.pageNotFound')}
; + } + + const pageContentClassName = currentSection?.contentPadding === 'none' + ? 'h-full min-w-0 overflow-x-auto' + : 'h-full min-w-0 overflow-x-auto p-6'; + + return ( +
+
+ +
+
+ ); +} diff --git a/webui/src/pages/Workspace/index.test.tsx b/webui/src/pages/Workspace/index.test.tsx index 678bc0ab7..f13ad052d 100644 --- a/webui/src/pages/Workspace/index.test.tsx +++ b/webui/src/pages/Workspace/index.test.tsx @@ -13,6 +13,7 @@ const mocks = vi.hoisted(() => ({ deleteDir: vi.fn(), upload: vi.fn(), createDir: vi.fn(), + reveal: vi.fn(), listMemory: vi.fn(), readMemoryFile: vi.fn(), confirm: vi.fn(), @@ -20,6 +21,31 @@ const mocks = vi.hoisted(() => ({ toastError: vi.fn(), })); +const pdfMocks = vi.hoisted(() => { + const renderPage = vi.fn(() => ({ promise: Promise.resolve(), cancel: vi.fn() })); + const getPage = vi.fn(() => Promise.resolve({ + getViewport: () => ({ width: 600, height: 800 }), + render: renderPage, + })); + const destroyDocument = vi.fn(); + const destroyTask = vi.fn(); + const getDocument = vi.fn(() => ({ + promise: Promise.resolve({ + numPages: 3, + getPage, + destroy: destroyDocument, + }), + destroy: destroyTask, + })); + return { + getDocument, + getPage, + renderPage, + destroyDocument, + destroyTask, + }; +}); + const translations: Record = { description: 'Workspace files', 'tabs.files': 'Files', @@ -33,9 +59,28 @@ const translations: Record = { 'files.back': 'Back', 'files.delete': 'Delete', 'files.download': 'Download', + 'files.reveal': 'Open containing folder', 'files.downloadFile': 'Download file', 'files.binaryPreview': 'Binary file cannot be previewed', 'files.truncatedPreview': 'Preview truncated to first {{limit}}', + 'files.preview.previewMode': 'Preview', + 'files.preview.sourceMode': 'Source', + 'files.preview.fullscreen': 'Fullscreen preview', + 'files.preview.resize': 'Drag to resize preview', + 'files.preview.htmlSandbox': 'HTML sandboxed', + 'files.preview.jsonParseFailed': 'JSON parse failed', + 'files.preview.jsonlParseFailed': '{{count}} JSONL lines failed', + 'files.preview.pdfLoading': 'Loading PDF', + 'files.preview.pdfRendering': 'Rendering page', + 'files.preview.pdfLoadFailed': 'Failed to load PDF preview', + 'files.preview.pdfCanvasUnavailable': 'Canvas unavailable', + 'files.preview.pageIndicator': '{{page}} / {{total}}', + 'files.preview.previousPage': 'Previous page', + 'files.preview.nextPage': 'Next page', + 'files.preview.zoomIn': 'Zoom in', + 'files.preview.zoomOut': 'Zoom out', + 'files.preview.unsupportedTitle': 'This file cannot be previewed', + 'files.preview.unsupportedDesc': 'Download it or open containing folder', 'files.emptyDir': 'Empty directory', 'files.dropHere': 'Drop files here', 'files.uploading': 'Uploading', @@ -62,12 +107,27 @@ vi.mock('react-i18next', () => ({ if (key === 'files.truncatedPreview') { return `Preview truncated to first ${params?.limit ?? ''}`; } + if (key === 'files.preview.jsonlParseFailed') { + return `${params?.count ?? ''} JSONL lines failed`; + } + if (key === 'files.preview.pageIndicator') { + return `${params?.page ?? ''} / ${params?.total ?? ''}`; + } return translations[key] ?? key; }, i18n: { language: 'en-US' }, }), })); +vi.mock('pdfjs-dist', () => ({ + GlobalWorkerOptions: {}, + getDocument: pdfMocks.getDocument, +})); + +vi.mock('pdfjs-dist/build/pdf.worker.min.mjs?url', () => ({ + default: '/pdf.worker.min.mjs', +})); + vi.mock('@/components/common/Toast', () => ({ useToast: () => ({ success: mocks.toastSuccess, @@ -75,6 +135,11 @@ vi.mock('@/components/common/Toast', () => ({ }), })); +Object.defineProperty(HTMLCanvasElement.prototype, 'getContext', { + configurable: true, + value: vi.fn(() => ({})), +}); + vi.mock('@/components/common/ConfirmDialog', () => ({ useConfirm: () => mocks.confirm, })); @@ -105,9 +170,13 @@ vi.mock('@/api/workspace', async () => { deleteDir: mocks.deleteDir, upload: mocks.upload, createDir: mocks.createDir, + reveal: mocks.reveal, listMemory: mocks.listMemory, readMemoryFile: mocks.readMemoryFile, downloadUrl: (path: string) => `/api/workspace/download?path=${encodeURIComponent(path)}`, + previewUrl: (path: string) => `/api/workspace/preview?path=${encodeURIComponent(path)}`, + memoryDownloadUrl: (path: string) => `/api/workspace/memory/download?path=${encodeURIComponent(path)}`, + memoryPreviewUrl: (path: string) => `/api/workspace/memory/preview?path=${encodeURIComponent(path)}`, }, }; }); @@ -135,12 +204,14 @@ function file(name: string, path: string, isTextFile = true) { describe('WorkspacePage', () => { beforeEach(() => { vi.clearAllMocks(); + mocks.list.mockResolvedValue({ data: [] }); mocks.readFile.mockResolvedValue({ data: { content: '' } }); mocks.writeFile.mockResolvedValue({ data: { written: true } }); mocks.deleteFile.mockResolvedValue({ data: { deleted: true } }); mocks.deleteDir.mockResolvedValue({ data: { deleted: true } }); mocks.upload.mockResolvedValue({ data: { uploaded: [] } }); mocks.createDir.mockResolvedValue({ data: { created: true } }); + mocks.reveal.mockResolvedValue({ data: { opened: true } }); mocks.listMemory.mockResolvedValue({ data: [] }); mocks.readMemoryFile.mockResolvedValue({ data: { content: '' } }); mocks.confirm.mockResolvedValue(true); @@ -204,7 +275,280 @@ describe('WorkspacePage', () => { await user.click(await screen.findByText('events.jsonl')); expect(await screen.findByText('Preview truncated to first 16 B')).toBeInTheDocument(); - expect(screen.getByText('{"id":1}')).toBeInTheDocument(); + expect(screen.getByText(/"id": 1/)).toBeInTheDocument(); expect(screen.queryByTitle('Edit')).not.toBeInTheDocument(); }); + + it('Markdown 文件默认渲染预览,并可打开全屏预览', async () => { + mocks.list.mockResolvedValue({ + data: [file('README.md', 'README.md')], + }); + mocks.readFile.mockResolvedValue({ + data: { + path: 'README.md', + content: '# Hello\n\n**World**', + truncated: false, + }, + }); + + const user = userEvent.setup(); + renderWithRouter(); + + await user.click(await screen.findByText('README.md')); + + expect(await screen.findByRole('heading', { name: 'Hello' })).toBeInTheDocument(); + expect(screen.getByRole('button', { name: 'Preview' })).toBeInTheDocument(); + expect(screen.getByRole('button', { name: 'Source' })).toBeInTheDocument(); + expect(screen.getByRole('button', { name: 'Drag to resize preview' })).toBeInTheDocument(); + + await user.click(screen.getByTitle('Fullscreen preview')); + expect(screen.getAllByRole('heading', { name: 'Hello' })).toHaveLength(2); + }); + + it('JSON 文件默认格式化显示,并可切换源码', async () => { + mocks.list.mockResolvedValue({ + data: [file('payload.json', 'payload.json')], + }); + mocks.readFile.mockResolvedValue({ + data: { + path: 'payload.json', + content: '{"message":"ok","count":2}', + truncated: false, + }, + }); + + const user = userEvent.setup(); + renderWithRouter(); + + await user.click(await screen.findByText('payload.json')); + + expect(await screen.findByText(/"message": "ok"/)).toBeInTheDocument(); + + await user.click(screen.getByRole('button', { name: 'Source' })); + expect(screen.getByText('{"message":"ok","count":2}')).toBeInTheDocument(); + }); + + it('CSV 文件默认展示表格,并可切换源码', async () => { + mocks.list.mockResolvedValue({ + data: [file('table.csv', 'table.csv')], + }); + mocks.readFile.mockResolvedValue({ + data: { + path: 'table.csv', + content: 'name,count\nalpha,2\n"beta, inc",5', + truncated: false, + }, + }); + + const user = userEvent.setup(); + renderWithRouter(); + + await user.click(await screen.findByText('table.csv')); + + expect(await screen.findByRole('columnheader', { name: 'name' })).toBeInTheDocument(); + expect(screen.getByRole('columnheader', { name: 'count' })).toBeInTheDocument(); + expect(screen.getByText('alpha')).toBeInTheDocument(); + expect(screen.getByText('beta, inc')).toBeInTheDocument(); + + await user.click(screen.getByRole('button', { name: 'Source' })); + expect(screen.getByText(/name,count/)).toBeInTheDocument(); + }); + + it('PDF 文件使用 inline preview 地址展示', async () => { + mocks.list.mockResolvedValue({ + data: [file('report.pdf', 'report.pdf', false)], + }); + + const user = userEvent.setup(); + renderWithRouter(); + + await user.click(await screen.findByText('report.pdf')); + + await waitFor(() => { + expect(pdfMocks.getDocument).toHaveBeenCalledWith({ + url: '/api/workspace/preview?path=report.pdf', + withCredentials: true, + }); + }); + expect(await screen.findByText('1 / 3')).toBeInTheDocument(); + expect(pdfMocks.getPage).toHaveBeenCalledWith(1); + expect(pdfMocks.renderPage).toHaveBeenCalled(); + expect(screen.getByTitle('Previous page')).toBeDisabled(); + expect(screen.getByTitle('Next page')).toBeEnabled(); + }); + + it('Memory Markdown 文件复用预览渲染和全屏预览', async () => { + mocks.listMemory.mockResolvedValue({ + data: [file('MEMORY.md', 'MEMORY.md')], + }); + mocks.readMemoryFile.mockResolvedValue({ + data: { + path: 'MEMORY.md', + content: '# Memory\n\n**Fact**', + truncated: false, + }, + }); + + const user = userEvent.setup(); + renderWithRouter(); + + await user.click(screen.getByRole('button', { name: 'Memory' })); + await user.click(await screen.findByText('MEMORY.md')); + + expect(await screen.findByRole('heading', { name: 'Memory' })).toBeInTheDocument(); + expect(screen.getByRole('button', { name: 'Preview' })).toBeInTheDocument(); + expect(screen.getByRole('button', { name: 'Source' })).toBeInTheDocument(); + + await user.click(screen.getByTitle('Fullscreen preview')); + expect(screen.getAllByRole('heading', { name: 'Memory' })).toHaveLength(2); + }); + + it('Memory PDF 文件使用 memory inline preview 地址展示', async () => { + mocks.listMemory.mockResolvedValue({ + data: [file('profile.pdf', 'nested/profile.pdf', false)], + }); + + const user = userEvent.setup(); + renderWithRouter(); + + await user.click(screen.getByRole('button', { name: 'Memory' })); + await user.click(await screen.findByText('profile.pdf')); + + await waitFor(() => { + expect(pdfMocks.getDocument).toHaveBeenCalledWith({ + url: '/api/workspace/memory/preview?path=nested%2Fprofile.pdf', + withCredentials: true, + }); + }); + expect(mocks.readMemoryFile).not.toHaveBeenCalled(); + }); + + it('Memory SVG 文件使用图片预览展示', async () => { + mocks.listMemory.mockResolvedValue({ + data: [file('logo.svg', 'icons/logo.svg', false)], + }); + + const user = userEvent.setup(); + renderWithRouter(); + + await user.click(screen.getByRole('button', { name: 'Memory' })); + await user.click(await screen.findByText('logo.svg')); + + const image = await screen.findByRole('img', { name: 'logo.svg' }); + expect(image).toHaveAttribute('src', '/api/workspace/memory/preview?path=icons%2Flogo.svg'); + expect(mocks.readMemoryFile).not.toHaveBeenCalled(); + }); + + it('Memory 文本文件快速切换时忽略过期读取结果', async () => { + let resolveFirst!: (value: { data: { path: string; content: string; truncated: boolean } }) => void; + const firstRead = new Promise<{ data: { path: string; content: string; truncated: boolean } }>((resolve) => { + resolveFirst = resolve; + }); + + mocks.listMemory.mockResolvedValue({ + data: [ + file('first.md', 'first.md'), + file('second.md', 'second.md'), + ], + }); + mocks.readMemoryFile.mockImplementation((path: string) => { + if (path === 'first.md') { + return firstRead; + } + return Promise.resolve({ + data: { + path: 'second.md', + content: '# Second', + truncated: false, + }, + }); + }); + + const user = userEvent.setup(); + renderWithRouter(); + + await user.click(screen.getByRole('button', { name: 'Memory' })); + await user.click(await screen.findByText('first.md')); + await user.click(await screen.findByText('second.md')); + + expect(await screen.findByRole('heading', { name: 'Second' })).toBeInTheDocument(); + + resolveFirst({ + data: { + path: 'first.md', + content: '# First', + truncated: false, + }, + }); + + await waitFor(() => { + expect(screen.queryByRole('heading', { name: 'First' })).not.toBeInTheDocument(); + }); + expect(screen.getByRole('heading', { name: 'Second' })).toBeInTheDocument(); + }); + + it('不支持预览的文件显示下载和打开目录入口', async () => { + mocks.list.mockResolvedValue({ + data: [file('archive.zip', 'archive.zip', false)], + }); + + const user = userEvent.setup(); + renderWithRouter(); + + await user.click(await screen.findByText('archive.zip')); + + expect(screen.getByText('This file cannot be previewed')).toBeInTheDocument(); + expect(screen.getByText('Download file')).toBeInTheDocument(); + const revealButtons = screen.getAllByRole('button', { name: 'Open containing folder' }); + await user.click(revealButtons[revealButtons.length - 1]); + expect(mocks.reveal).toHaveBeenCalledWith('archive.zip'); + }); + + it('目录内容默认按名称升序,并支持按名称、大小和修改时间切换排序', async () => { + mocks.list.mockResolvedValue({ + data: [ + { ...directory('beta', 'beta'), modified_at: 300 }, + { ...directory('alpha', 'alpha'), modified_at: 100 }, + { ...file('gamma.txt', 'gamma.txt'), size: 200, modified_at: 200 }, + { ...file('delta.txt', 'delta.txt'), size: 40, modified_at: 400 }, + ], + }); + + const user = userEvent.setup(); + renderWithRouter(); + + const alpha = await screen.findByText('alpha'); + const beta = screen.getByText('beta'); + const delta = screen.getByText('delta.txt'); + const gamma = screen.getByText('gamma.txt'); + + expect(alpha.compareDocumentPosition(beta) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + expect(beta.compareDocumentPosition(delta) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + expect(delta.compareDocumentPosition(gamma) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + + await user.click(screen.getByRole('button', { name: 'Name' })); + expect(gamma.compareDocumentPosition(delta) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + expect(delta.compareDocumentPosition(beta) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + expect(beta.compareDocumentPosition(alpha) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + + await user.click(screen.getByRole('button', { name: 'Size' })); + expect(alpha.compareDocumentPosition(beta) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + expect(beta.compareDocumentPosition(delta) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + expect(delta.compareDocumentPosition(gamma) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + + await user.click(screen.getByRole('button', { name: 'Size' })); + expect(gamma.compareDocumentPosition(delta) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + expect(delta.compareDocumentPosition(beta) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + expect(beta.compareDocumentPosition(alpha) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + + await user.click(screen.getByRole('button', { name: 'Modified' })); + expect(alpha.compareDocumentPosition(gamma) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + expect(gamma.compareDocumentPosition(beta) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + expect(beta.compareDocumentPosition(delta) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + + await user.click(screen.getByRole('button', { name: 'Modified' })); + expect(delta.compareDocumentPosition(beta) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + expect(beta.compareDocumentPosition(gamma) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + expect(gamma.compareDocumentPosition(alpha) & Node.DOCUMENT_POSITION_FOLLOWING).toBeTruthy(); + }); }); diff --git a/webui/src/pages/Workspace/index.tsx b/webui/src/pages/Workspace/index.tsx index 9d3621e80..2e8632ead 100644 --- a/webui/src/pages/Workspace/index.tsx +++ b/webui/src/pages/Workspace/index.tsx @@ -1,14 +1,18 @@ -import { useState, useEffect, useCallback, useRef, useReducer } from 'react'; +import { useState, useEffect, useCallback, useMemo, useRef, useReducer } from 'react'; import { FolderOpen, Upload, Download, Trash2, Edit3, Save, - X, ChevronRight, RefreshCw, FolderPlus, - Brain, FileText, AlertTriangle, Search, ArrowLeft, + X, ChevronRight, ChevronLeft, ChevronDown, ChevronUp, RefreshCw, FolderPlus, + Brain, AlertTriangle, Search, ArrowLeft, Maximize2, + Code2, Eye, ZoomIn, ZoomOut, } from 'lucide-react'; +import * as pdfjsLib from 'pdfjs-dist'; +import pdfWorkerUrl from 'pdfjs-dist/build/pdf.worker.min.mjs?url'; import { useTranslation } from 'react-i18next'; import PageHeader from '@/components/common/PageHeader'; import LoadingSpinner from '@/components/common/LoadingSpinner'; import { useToast } from '@/components/common/Toast'; import { useConfirm } from '@/components/common/ConfirmDialog'; +import { StreamingMarkdown } from '@/components/common/StreamingMarkdown'; import { workspaceAPI, WorkspaceNode, formatBytes, formatDate, fileIcon, } from '@/api/workspace'; @@ -16,6 +20,60 @@ import { // ─── Types ──────────────────────────────────────────────────────────────── type Tab = 'files' | 'memory'; +type SortField = 'name' | 'size' | 'modified'; +type SortDirection = 'asc' | 'desc'; +type PreviewKind = 'markdown' | 'html' | 'json' | 'jsonl' | 'csv' | 'text' | 'image' | 'pdf' | 'unsupported'; +type PreviewMode = 'preview' | 'source'; + +interface PreviewFileAccess { + previewUrl: (path: string) => string; + downloadUrl: (path: string) => string; +} + +const WORKSPACE_PREVIEW_FILE_ACCESS: PreviewFileAccess = { + previewUrl: (path) => workspaceAPI.previewUrl(path), + downloadUrl: (path) => workspaceAPI.downloadUrl(path), +}; + +const MEMORY_PREVIEW_FILE_ACCESS: PreviewFileAccess = { + previewUrl: (path) => workspaceAPI.memoryPreviewUrl(path), + downloadUrl: (path) => workspaceAPI.memoryDownloadUrl(path), +}; + +const PREVIEW_PANEL_DEFAULT_RATIO = 0.5; +const PREVIEW_PANEL_MIN_WIDTH = 420; +const PREVIEW_PANEL_MIN_LIST_WIDTH = 360; +const PDF_MIN_SCALE = 0.6; +const PDF_MAX_SCALE = 2.2; +const PDF_SCALE_STEP = 0.2; +const PDF_MAX_OUTPUT_SCALE = 3; +const PDF_RENDER_WINDOW = 2; +const IMAGE_MIN_SCALE = 0.5; +const IMAGE_MAX_SCALE = 3; +const IMAGE_SCALE_STEP = 0.25; + +pdfjsLib.GlobalWorkerOptions.workerSrc = pdfWorkerUrl; + +function getViewportWidth(): number { + return typeof window === 'undefined' ? PREVIEW_PANEL_MIN_WIDTH * 2 : window.innerWidth; +} + +function getPreviewPanelMaxWidth(containerWidth = getViewportWidth()): number { + return Math.max(PREVIEW_PANEL_MIN_WIDTH, containerWidth - PREVIEW_PANEL_MIN_LIST_WIDTH); +} + +function getDefaultPreviewPanelWidth(containerWidth = getViewportWidth()): number { + const targetWidth = Math.floor(containerWidth * PREVIEW_PANEL_DEFAULT_RATIO); + return Math.min( + getPreviewPanelMaxWidth(containerWidth), + Math.max(PREVIEW_PANEL_MIN_WIDTH, targetWidth), + ); +} + +interface SortState { + field: SortField; + direction: SortDirection; +} // Preview/edit panel state consolidated into a single object interface PanelState { @@ -77,7 +135,7 @@ export default function WorkspacePage() { const { t } = useTranslation('workspace'); return ( -
+
setActiveTab('memory')} icon={} label={t('tabs.memory')} />
-
+
{activeTab === 'files' ? : }
@@ -112,6 +170,857 @@ function TabButton({ active, onClick, icon, label }: { ); } +function SortHeaderButton({ + label, + field, + sort, + onClick, + align = 'left', +}: { + label: string; + field: SortField; + sort: SortState; + onClick: (field: SortField) => void; + align?: 'left' | 'right'; +}) { + const active = sort.field === field; + const Icon = sort.direction === 'asc' ? ChevronUp : ChevronDown; + + return ( + + ); +} + +function fileExtension(name: string): string { + const index = name.lastIndexOf('.'); + return index >= 0 ? name.slice(index + 1).toLowerCase() : ''; +} + +function getPreviewKind(node: WorkspaceNode): PreviewKind { + const ext = fileExtension(node.name); + if (['png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'].includes(ext)) return 'image'; + if (ext === 'pdf') return 'pdf'; + if (node.is_text_file) { + if (['md', 'markdown'].includes(ext)) return 'markdown'; + if (['html', 'htm'].includes(ext)) return 'html'; + if (ext === 'json') return 'json'; + if (ext === 'jsonl') return 'jsonl'; + if (ext === 'csv') return 'csv'; + return 'text'; + } + return 'unsupported'; +} + +function prettyJson(content: string): { value: string; error: string | null } { + try { + return { value: JSON.stringify(JSON.parse(content), null, 2), error: null }; + } catch (e: any) { + return { value: content, error: e?.message ?? 'Invalid JSON' }; + } +} + +function prettyJsonLines(content: string): { value: string; errorCount: number } { + let errorCount = 0; + const value = content.split(/\r?\n/).map((line) => { + if (!line.trim()) return line; + try { + return JSON.stringify(JSON.parse(line), null, 2); + } catch { + errorCount += 1; + return line; + } + }).join('\n'); + return { value, errorCount }; +} + +function parseCsv(content: string): string[][] { + const rows: string[][] = []; + let row: string[] = []; + let field = ''; + let inQuotes = false; + + for (let i = 0; i < content.length; i += 1) { + const char = content[i]; + const next = content[i + 1]; + + if (char === '"') { + if (inQuotes && next === '"') { + field += '"'; + i += 1; + } else { + inQuotes = !inQuotes; + } + continue; + } + + if (char === ',' && !inQuotes) { + row.push(field); + field = ''; + continue; + } + + if ((char === '\n' || char === '\r') && !inQuotes) { + if (char === '\r' && next === '\n') { + i += 1; + } + row.push(field); + rows.push(row); + row = []; + field = ''; + continue; + } + + field += char; + } + + row.push(field); + if (row.length > 1 || row[0] !== '' || content.endsWith(',')) { + rows.push(row); + } + return rows; +} + +function SourcePreview({ content }: { content: string }) { + return ( +
+      {content}
+    
+ ); +} + +function CsvPreview({ content }: { content: string }) { + const rows = parseCsv(content); + if (rows.length === 0) { + return ; + } + + const [header, ...body] = rows; + const columnCount = Math.max(...rows.map((row) => row.length)); + + return ( +
+ + + + {Array.from({ length: columnCount }).map((_, index) => ( + + ))} + + + + {body.map((row, rowIndex) => ( + + {Array.from({ length: columnCount }).map((_, columnIndex) => ( + + ))} + + ))} + +
+ {header[index] || `Column ${index + 1}`} +
+ {row[columnIndex] ?? ''} +
+
+ ); +} + +function PdfPreview({ + node, + fileAccess, + onReveal, +}: { + node: WorkspaceNode; + fileAccess: PreviewFileAccess; + onReveal?: (node: WorkspaceNode) => void; +}) { + const { t } = useTranslation('workspace'); + const previewAreaRef = useRef(null); + const pageCanvasRefs = useRef(new Map()); + const pageShellRefs = useRef(new Map()); + const renderedPageKeysRef = useRef(new Map()); + const [pdfDoc, setPdfDoc] = useState(null); + const [pageNumber, setPageNumber] = useState(1); + const [pageCount, setPageCount] = useState(0); + const [scale, setScale] = useState(1); + const [previewAreaWidth, setPreviewAreaWidth] = useState(0); + const [pagesToRender, setPagesToRender] = useState>(() => new Set()); + const [loading, setLoading] = useState(true); + const [rendering, setRendering] = useState(false); + const [error, setError] = useState(null); + const previewUrl = fileAccess.previewUrl(node.path); + + useEffect(() => { + let cancelled = false; + setPdfDoc(null); + setPageNumber(1); + setPageCount(0); + setScale(1); + setPagesToRender(new Set()); + renderedPageKeysRef.current.clear(); + setLoading(true); + setError(null); + + const loadingTask = pdfjsLib.getDocument({ url: previewUrl, withCredentials: true }); + loadingTask.promise + .then((doc) => { + if (cancelled) { + return; + } + setPdfDoc(doc); + setPageCount(doc.numPages); + setPagesToRender(new Set(Array.from({ length: Math.min(doc.numPages, PDF_RENDER_WINDOW + 1) }, (_, index) => index + 1))); + setLoading(false); + }) + .catch((e: any) => { + if (cancelled) return; + setError(e?.message ?? 'PDF preview failed'); + setLoading(false); + }); + + return () => { + cancelled = true; + loadingTask.destroy(); + }; + }, [previewUrl]); + + const setPageCanvasRef = useCallback((page: number, element: HTMLCanvasElement | null) => { + if (element) { + pageCanvasRefs.current.set(page, element); + } else { + pageCanvasRefs.current.delete(page); + } + }, []); + + const setPageShellRef = useCallback((page: number, element: HTMLDivElement | null) => { + if (element) { + pageShellRefs.current.set(page, element); + } else { + pageShellRefs.current.delete(page); + } + }, []); + + const scrollToPage = useCallback((page: number) => { + const targetPage = Math.min(pageCount, Math.max(1, page)); + setPageNumber(targetPage); + setPagesToRender((previous) => { + const next = new Set(previous); + for (let candidate = Math.max(1, targetPage - PDF_RENDER_WINDOW); candidate <= Math.min(pageCount, targetPage + PDF_RENDER_WINDOW); candidate += 1) { + next.add(candidate); + } + return next; + }); + pageShellRefs.current.get(targetPage)?.scrollIntoView({ block: 'start' }); + }, [pageCount]); + + const handlePreviewScroll = useCallback(() => { + const area = previewAreaRef.current; + if (!area || pageCount === 0) return; + + const scrollTop = area.scrollTop + 16; + let nearestPage = pageNumber; + let nearestDistance = Number.POSITIVE_INFINITY; + pageShellRefs.current.forEach((element, page) => { + const distance = Math.abs(element.offsetTop - scrollTop); + if (distance < nearestDistance) { + nearestDistance = distance; + nearestPage = page; + } + }); + if (nearestPage !== pageNumber) { + setPageNumber(nearestPage); + } + setPagesToRender((previous) => { + let changed = false; + const next = new Set(previous); + for (let candidate = Math.max(1, nearestPage - PDF_RENDER_WINDOW); candidate <= Math.min(pageCount, nearestPage + PDF_RENDER_WINDOW); candidate += 1) { + if (!next.has(candidate)) { + next.add(candidate); + changed = true; + } + } + return changed ? next : previous; + }); + }, [pageCount, pageNumber]); + + useEffect(() => { + const area = previewAreaRef.current; + if (!area) return; + + if (area.clientWidth > 0) { + setPreviewAreaWidth(area.clientWidth); + } + if (typeof ResizeObserver === 'undefined') return; + + const observer = new ResizeObserver(([entry]) => { + if (entry.contentRect.width > 0) { + setPreviewAreaWidth(entry.contentRect.width); + } + }); + observer.observe(area); + return () => observer.disconnect(); + }, []); + + useEffect(() => { + if (!pdfDoc || pageCount === 0) return; + renderedPageKeysRef.current.clear(); + }, [pageCount, pdfDoc, previewAreaWidth, scale]); + + useEffect(() => { + if (!pdfDoc || pageCount === 0 || pagesToRender.size === 0) return; + let cancelled = false; + const renderTasks: any[] = []; + + async function renderPages() { + setRendering(true); + try { + const orderedPages = [...pagesToRender].sort((a, b) => Math.abs(a - pageNumber) - Math.abs(b - pageNumber)); + for (const pageIndex of orderedPages) { + if (pageIndex < 1 || pageIndex > pageCount) continue; + const canvas = pageCanvasRefs.current.get(pageIndex); + if (!canvas) continue; + + const page = await pdfDoc.getPage(pageIndex); + if (cancelled) return; + const baseViewport = page.getViewport({ scale: 1 }); + const availableWidth = Math.max(0, previewAreaWidth - 32); + const fitScale = availableWidth > 0 ? availableWidth / baseViewport.width : 1; + const viewport = page.getViewport({ scale: fitScale * scale }); + const context = canvas.getContext('2d'); + if (!context) { + throw new Error('Canvas unavailable'); + } + const outputScale = Math.min(window.devicePixelRatio || 1, PDF_MAX_OUTPUT_SCALE); + const cssWidth = Math.ceil(viewport.width); + const cssHeight = Math.ceil(viewport.height); + const renderKey = `${cssWidth}x${cssHeight}@${outputScale}`; + if (renderedPageKeysRef.current.get(pageIndex) === renderKey) { + continue; + } + canvas.width = Math.ceil(viewport.width * outputScale); + canvas.height = Math.ceil(viewport.height * outputScale); + canvas.style.width = `${cssWidth}px`; + canvas.style.height = `${cssHeight}px`; + const renderTask = page.render({ + canvasContext: context, + viewport, + transform: outputScale === 1 ? undefined : [outputScale, 0, 0, outputScale, 0, 0], + }); + renderTasks.push(renderTask); + await renderTask.promise; + renderedPageKeysRef.current.set(pageIndex, renderKey); + } + } catch (e: any) { + if (!cancelled && e?.name !== 'RenderingCancelledException') { + setError(e?.message ?? 'PDF preview failed'); + } + } finally { + if (!cancelled) setRendering(false); + } + } + + renderPages(); + return () => { + cancelled = true; + renderTasks.forEach((task) => task?.cancel?.()); + }; + }, [pageCount, pageNumber, pagesToRender, pdfDoc, previewAreaWidth, scale]); + + if (error) { + return ( +
+
+ +
+
+

{t('files.preview.pdfLoadFailed')}

+

{error}

+
+
+ + + {t('files.downloadFile')} + + {onReveal && ( + + )} +
+
+
+
+ ); + } + + return ( +
+
+
+ + + {loading ? t('files.preview.pdfLoading') : t('files.preview.pageIndicator', { page: pageNumber, total: pageCount })} + + +
+
+ + {Math.round(scale * 100)}% + +
+
+
+ {(loading || rendering) && ( +
+
+ {loading ? t('files.preview.pdfLoading') : t('files.preview.pdfRendering')} +
+
+ )} +
+ {Array.from({ length: pageCount }).map((_, index) => { + const page = index + 1; + return ( +
setPageShellRef(page, element)} + className="flex w-full flex-col items-center gap-1" + > + setPageCanvasRef(page, element)} + className="h-fit max-w-none bg-white shadow" + /> + {page} +
+ ); + })} +
+
+
+ ); +} + +function ImagePreview({ node, fileAccess }: { node: WorkspaceNode; fileAccess: PreviewFileAccess }) { + const { t } = useTranslation('workspace'); + const previewAreaRef = useRef(null); + const [scale, setScale] = useState(1); + const [previewAreaWidth, setPreviewAreaWidth] = useState(0); + const [naturalSize, setNaturalSize] = useState<{ width: number; height: number } | null>(null); + const previewUrl = fileAccess.previewUrl(node.path); + + useEffect(() => { + setScale(1); + setNaturalSize(null); + }, [node.path]); + + useEffect(() => { + const area = previewAreaRef.current; + if (!area) return; + + if (area.clientWidth > 0) { + setPreviewAreaWidth(area.clientWidth); + } + if (typeof ResizeObserver === 'undefined') return; + + const observer = new ResizeObserver(([entry]) => { + if (entry.contentRect.width > 0) { + setPreviewAreaWidth(entry.contentRect.width); + } + }); + observer.observe(area); + return () => observer.disconnect(); + }, []); + + const availableWidth = Math.max(0, previewAreaWidth - 32); + const fitScale = naturalSize && availableWidth > 0 ? Math.min(1, availableWidth / naturalSize.width) : 1; + const displayWidth = naturalSize ? Math.max(1, Math.round(naturalSize.width * fitScale * scale)) : undefined; + + return ( +
+
+
+ + {Math.round(scale * 100)}% + +
+
+
+
+ {node.name} { + setNaturalSize({ + width: event.currentTarget.naturalWidth, + height: event.currentTarget.naturalHeight, + }); + }} + style={displayWidth ? { width: displayWidth } : undefined} + className="h-fit max-w-none self-start bg-white object-contain shadow" + /> +
+
+
+ ); +} + +function RenderedPreview({ + node, + content, + kind, + fileAccess, + onReveal, +}: { + node: WorkspaceNode; + content: string | null; + kind: PreviewKind; + fileAccess: PreviewFileAccess; + onReveal?: (node: WorkspaceNode) => void; +}) { + const { t } = useTranslation('workspace'); + + if (kind === 'image') { + return ; + } + + if (kind === 'pdf') { + return ; + } + + if (kind === 'unsupported') { + return ; + } + + if (content === null) { + return
; + } + + if (kind === 'markdown') { + return ( +
+ +
+ ); + } + + if (kind === 'html') { + return ( +
+
+ {t('files.preview.htmlSandbox')} +
+