diff --git a/governs_ai/__init__.py b/governs_ai/__init__.py index ae55283..7189ec6 100644 --- a/governs_ai/__init__.py +++ b/governs_ai/__init__.py @@ -7,6 +7,8 @@ """ from .client import GovernsAIClient, GovernsAIConfig +from .sync import SyncClient, PrecheckDecision, GovernsAPIError, precheck +from .async_client import AsyncClient from .clients.precheck import PrecheckClient from .clients.confirmation import ConfirmationClient from .clients.budget import BudgetClient diff --git a/governs_ai/async_client.py b/governs_ai/async_client.py new file mode 100644 index 0000000..d30943f --- /dev/null +++ b/governs_ai/async_client.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: MIT +"""Async convenience wrapper — for asyncio apps (FastAPI, aiohttp, etc.). + +Mirrors `sync.SyncClient` but uses `httpx.AsyncClient` so it can be awaited +without blocking the event loop. Identical Decision dataclass + retry semantics. + + from governs_ai import AsyncClient + async with AsyncClient(api_key='GAI_...', base_url='http://localhost:8082') as c: + d = await c.precheck(tool='chat', raw_text='hi') + if d.denied: + raise HTTPException(status_code=403) +""" +from __future__ import annotations + +import asyncio +import os +import random +from typing import Any, Dict, Optional + +import aiohttp + +from .sync import ( + DEFAULT_BASE_URL, + GovernsAPIError, + PrecheckDecision, + USER_AGENT, +) + + +class AsyncClient: + """asyncio variant of SyncClient. Same retry/error model.""" + + def __init__( + self, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + timeout: float = 5.0, + retries: int = 3, + backoff_base_ms: int = 100, + ): + self.api_key = api_key or os.environ.get("GOVERNS_AI_API_KEY", "") + if not self.api_key: + raise ValueError("api_key is required (or set GOVERNS_AI_API_KEY env var)") + self.base_url = ( + base_url or os.environ.get("GOVERNS_AI_BASE_URL") or DEFAULT_BASE_URL + ).rstrip("/") + self.timeout = timeout + self.retries = max(0, retries) + self.backoff_base_ms = backoff_base_ms + self._client = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=timeout), + headers={ + "X-Governs-Key": self.api_key, + "Content-Type": "application/json", + "User-Agent": USER_AGENT + " (async)", + }, + ) + + async def precheck( + self, + *, + tool: str, + raw_text: str, + scope: Optional[str] = None, + user_id: Optional[str] = None, + corr_id: Optional[str] = None, + policy_config: Optional[Dict[str, Any]] = None, + tool_config: Optional[Dict[str, Any]] = None, + budget_context: Optional[Dict[str, Any]] = None, + ) -> PrecheckDecision: + payload: Dict[str, Any] = {"tool": tool, "raw_text": raw_text} + if scope is not None: + payload["scope"] = scope + if user_id is not None: + payload["user_id"] = user_id + if corr_id is not None: + payload["corr_id"] = corr_id + if policy_config is not None: + payload["policy_config"] = policy_config + if tool_config is not None: + payload["tool_config"] = tool_config + if budget_context is not None: + payload["budget_context"] = budget_context + return await self._post_with_retry("/api/v1/precheck", payload) + + async def postcheck( + self, *, tool: str, raw_text: str, scope: Optional[str] = None, corr_id: Optional[str] = None, + ) -> PrecheckDecision: + payload: Dict[str, Any] = {"tool": tool, "raw_text": raw_text} + if scope is not None: + payload["scope"] = scope + if corr_id is not None: + payload["corr_id"] = corr_id + return await self._post_with_retry("/api/v1/postcheck", payload) + + async def health(self) -> Dict[str, Any]: + url = f"{self.base_url}/api/v1/health" + async with self._client.get(url) as resp: + text = await resp.text() + if resp.status >= 400: + raise GovernsAPIError(resp.status, text) + import json as _json + return _json.loads(text) + + async def close(self) -> None: + await self._client.close() + + async def __aenter__(self) -> "AsyncClient": + return self + + async def __aexit__(self, *_exc) -> None: + await self.close() + + async def _post_with_retry(self, path: str, payload: Dict[str, Any]) -> PrecheckDecision: + import json as _json + + url = f"{self.base_url}{path}" + last_exc: Optional[Exception] = None + for attempt in range(self.retries + 1): + try: + async with self._client.post(url, json=payload) as resp: + text = await resp.text() + status = resp.status + except aiohttp.ClientError as exc: + last_exc = exc + if attempt < self.retries: + await self._backoff(attempt) + continue + raise GovernsAPIError(0, f"network error: {exc}") from exc + + if 200 <= status < 300: + try: + body = _json.loads(text) + except Exception: + raise GovernsAPIError(status, "invalid JSON", text) + return PrecheckDecision.from_dict(body) + + if status == 429 or 500 <= status < 600: + last_exc = GovernsAPIError(status, text) + if attempt < self.retries: + await self._backoff(attempt) + continue + + try: + err_body = _json.loads(text) + except Exception: + err_body = text + raise GovernsAPIError(status, str(err_body), err_body) + + assert last_exc is not None + raise last_exc # type: ignore[misc] + + async def _backoff(self, attempt: int) -> None: + delay_ms = min(self.backoff_base_ms * (2 ** attempt), 3000) + delay_ms = delay_ms * (0.5 + random.random() * 0.5) + await asyncio.sleep(delay_ms / 1000.0) diff --git a/governs_ai/exceptions/precheck.py b/governs_ai/exceptions/precheck.py index 37e4843..4b6eb34 100644 --- a/governs_ai/exceptions/precheck.py +++ b/governs_ai/exceptions/precheck.py @@ -4,6 +4,8 @@ Precheck-specific exceptions. """ +from typing import Any, Dict, Optional + from .base import GovernsAIError diff --git a/governs_ai/sync.py b/governs_ai/sync.py new file mode 100644 index 0000000..f89fa33 --- /dev/null +++ b/governs_ai/sync.py @@ -0,0 +1,243 @@ +# SPDX-License-Identifier: MIT +"""Synchronous convenience wrapper around the GovernsAI HTTP API. + +The full async client (`GovernsAIClient`) is the right tool for production +workloads. This module is the "5-second integration" surface: + + from governs_ai import precheck + decision = precheck( + api_key="GAI_...", + tool="chat", + raw_text="user input here", + base_url="http://localhost:8082", + ) + if decision.decision == "deny": + ... + +Returns a `Decision` dataclass with the same shape the precheck service emits. + +This wrapper uses `requests` (already a hard dep in pyproject.toml) so it has +no asyncio. It is safe to call from any Python framework: Flask, FastAPI +handlers (run-in-threadpool), Django views, plain scripts. + +For retries and connection pooling at scale, instantiate `SyncClient` once +and reuse it instead of calling `precheck(...)` per request. +""" +from __future__ import annotations + +import os +import time +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +import requests + +DEFAULT_BASE_URL = "https://api.governsai.com" +USER_AGENT = "governs-ai-sdk-python (sync)" + + +class GovernsAPIError(RuntimeError): + """Raised for any non-2xx response from the precheck service.""" + + def __init__(self, status_code: int, message: str, body: Optional[Any] = None): + super().__init__(f"GovernsAI API error {status_code}: {message}") + self.status_code = status_code + self.body = body + + +@dataclass +class PrecheckDecision: + """Decision returned by the precheck service. + + Mirrors precheck's `DecisionResponse` (precheck/app/models.py). + """ + + decision: str # one of: allow | transform | deny | confirm + raw_text_out: str + reasons: List[str] = field(default_factory=list) + policy_id: Optional[str] = None + ts: Optional[int] = None + raw: Dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "PrecheckDecision": + return cls( + decision=data.get("decision", "deny"), + raw_text_out=data.get("raw_text_out", ""), + reasons=list(data.get("reasons") or []), + policy_id=data.get("policy_id"), + ts=data.get("ts"), + raw=data, + ) + + @property + def allowed(self) -> bool: + return self.decision in ("allow", "transform") + + @property + def denied(self) -> bool: + return self.decision == "deny" + + +class SyncClient: + """Synchronous client. Construct once, reuse across requests. + + Honors `GOVERNS_AI_API_KEY` and `GOVERNS_AI_BASE_URL` env vars when params + are omitted, so it slots cleanly into 12-factor apps. + """ + + def __init__( + self, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + timeout: float = 5.0, + retries: int = 3, + backoff_base_ms: int = 100, + ): + self.api_key = api_key or os.environ.get("GOVERNS_AI_API_KEY", "") + if not self.api_key: + raise ValueError( + "api_key is required (or set GOVERNS_AI_API_KEY env var)" + ) + self.base_url = ( + base_url + or os.environ.get("GOVERNS_AI_BASE_URL") + or DEFAULT_BASE_URL + ).rstrip("/") + self.timeout = timeout + self.retries = max(0, retries) + self.backoff_base_ms = backoff_base_ms + self._session = requests.Session() + self._session.headers.update({ + "X-Governs-Key": self.api_key, + "Content-Type": "application/json", + "User-Agent": USER_AGENT, + }) + + def precheck( + self, + *, + tool: str, + raw_text: str, + scope: Optional[str] = None, + user_id: Optional[str] = None, + corr_id: Optional[str] = None, + policy_config: Optional[Dict[str, Any]] = None, + tool_config: Optional[Dict[str, Any]] = None, + budget_context: Optional[Dict[str, Any]] = None, + ) -> PrecheckDecision: + """Call POST /api/v1/precheck with retries + simple backoff. + + Retries only on transient failures (network errors, 5xx, 429). + 4xx errors are surfaced immediately via GovernsAPIError. + """ + payload: Dict[str, Any] = {"tool": tool, "raw_text": raw_text} + if scope is not None: + payload["scope"] = scope + if user_id is not None: + payload["user_id"] = user_id + if corr_id is not None: + payload["corr_id"] = corr_id + if policy_config is not None: + payload["policy_config"] = policy_config + if tool_config is not None: + payload["tool_config"] = tool_config + if budget_context is not None: + payload["budget_context"] = budget_context + + return self._post_with_retry("/api/v1/precheck", payload) + + def postcheck( + self, + *, + tool: str, + raw_text: str, + scope: Optional[str] = None, + corr_id: Optional[str] = None, + ) -> PrecheckDecision: + """Call POST /api/v1/postcheck — same shape as precheck.""" + payload: Dict[str, Any] = {"tool": tool, "raw_text": raw_text} + if scope is not None: + payload["scope"] = scope + if corr_id is not None: + payload["corr_id"] = corr_id + return self._post_with_retry("/api/v1/postcheck", payload) + + def health(self) -> Dict[str, Any]: + """GET /api/v1/health — returns the raw payload (no Decision wrap).""" + url = f"{self.base_url}/api/v1/health" + resp = self._session.get(url, timeout=self.timeout) + if not resp.ok: + raise GovernsAPIError(resp.status_code, resp.text) + return resp.json() + + def close(self) -> None: + """Release the underlying TCP pool.""" + self._session.close() + + def __enter__(self) -> "SyncClient": + return self + + def __exit__(self, *_exc) -> None: + self.close() + + # ─── internals ──────────────────────────────────────────────────── + def _post_with_retry(self, path: str, payload: Dict[str, Any]) -> PrecheckDecision: + url = f"{self.base_url}{path}" + last_exc: Optional[Exception] = None + for attempt in range(self.retries + 1): + try: + resp = self._session.post(url, json=payload, timeout=self.timeout) + except requests.RequestException as exc: + last_exc = exc + if attempt < self.retries: + self._backoff(attempt) + continue + raise GovernsAPIError(0, f"network error: {exc}") from exc + + if 200 <= resp.status_code < 300: + try: + body = resp.json() + except ValueError: + raise GovernsAPIError(resp.status_code, "invalid JSON", resp.text) + return PrecheckDecision.from_dict(body) + + # 429 / 5xx are retriable + if resp.status_code == 429 or 500 <= resp.status_code < 600: + last_exc = GovernsAPIError(resp.status_code, resp.text) + if attempt < self.retries: + self._backoff(attempt) + continue + + # 4xx (non-429) — don't retry, surface immediately + try: + err_body = resp.json() + except ValueError: + err_body = resp.text + raise GovernsAPIError(resp.status_code, str(err_body), err_body) + + # exhausted retries + assert last_exc is not None + raise last_exc # type: ignore[misc] + + def _backoff(self, attempt: int) -> None: + # exponential backoff with full jitter, capped at ~3s + delay_ms = min(self.backoff_base_ms * (2 ** attempt), 3000) + # jitter 50–100% of delay_ms + import random + delay_ms = delay_ms * (0.5 + random.random() * 0.5) + time.sleep(delay_ms / 1000.0) + + +# Module-level convenience: one-shot call without instantiating a client. +def precheck( + *, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + tool: str, + raw_text: str, + **kwargs: Any, +) -> PrecheckDecision: + """One-shot precheck. For repeated calls, prefer `SyncClient(...).precheck()`.""" + with SyncClient(api_key=api_key, base_url=base_url) as client: + return client.precheck(tool=tool, raw_text=raw_text, **kwargs) diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..ee71831 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,6 @@ +[pytest] +asyncio_mode = auto +testpaths = tests +filterwarnings = + ignore::DeprecationWarning:pydantic.* + ignore::DeprecationWarning:sqlalchemy.* diff --git a/tests/test_async_client.py b/tests/test_async_client.py new file mode 100644 index 0000000..dd0a72c --- /dev/null +++ b/tests/test_async_client.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: MIT +"""Unit tests for governs_ai.async_client.AsyncClient. + +Uses aiohttp's built-in test server (no external mock dep needed). +""" +from __future__ import annotations + +import asyncio +import pytest + +from aiohttp import web + +from governs_ai import AsyncClient, GovernsAPIError, PrecheckDecision + + +@pytest.fixture +async def server(aiohttp_server): + calls = {"precheck": 0, "fail_modes": []} + + async def precheck_handler(request): + calls["precheck"] += 1 + # check the API key got through + assert request.headers.get("X-Governs-Key") == "GAI_test" + body = await request.json() + # Honor fail_modes for retry tests + if calls["fail_modes"]: + mode = calls["fail_modes"].pop(0) + if mode == "5xx": + return web.json_response({"error": "down"}, status=503) + if mode == "429": + return web.json_response({"error": "slow down"}, status=429) + if mode == "4xx": + return web.json_response({"error": "bad"}, status=400) + return web.json_response({ + "decision": "transform", + "raw_text_out": f"redacted: {body['raw_text']}", + "reasons": ["pii.email"], + "policy_id": "p-1", + "ts": 12345, + }) + + async def health_handler(request): + return web.json_response({"ok": True, "service": "test"}) + + app = web.Application() + app.router.add_post("/api/v1/precheck", precheck_handler) + app.router.add_get("/api/v1/health", health_handler) + s = await aiohttp_server(app) + s.calls = calls # type: ignore[attr-defined] + return s + + +@pytest.mark.asyncio +async def test_precheck_returns_decision(server): + async with AsyncClient(api_key="GAI_test", base_url=str(server.make_url("")).rstrip("/")) as c: + d = await c.precheck(tool="chat", raw_text="hello jane@example.com") + assert isinstance(d, PrecheckDecision) + assert d.decision == "transform" + assert "redacted:" in d.raw_text_out + assert d.reasons == ["pii.email"] + + +@pytest.mark.asyncio +async def test_health(server): + async with AsyncClient(api_key="GAI_test", base_url=str(server.make_url("")).rstrip("/")) as c: + h = await c.health() + assert h == {"ok": True, "service": "test"} + + +@pytest.mark.asyncio +async def test_5xx_retries_then_succeeds(server): + server.calls["fail_modes"] = ["5xx", "5xx"] # first two attempts fail, third wins + async with AsyncClient(api_key="GAI_test", base_url=str(server.make_url("")).rstrip("/"), + retries=2, backoff_base_ms=1) as c: + d = await c.precheck(tool="chat", raw_text="x") + assert d.decision == "transform" + assert server.calls["precheck"] == 3 + + +@pytest.mark.asyncio +async def test_429_retries(server): + server.calls["fail_modes"] = ["429", "429"] + async with AsyncClient(api_key="GAI_test", base_url=str(server.make_url("")).rstrip("/"), + retries=2, backoff_base_ms=1) as c: + d = await c.precheck(tool="chat", raw_text="x") + assert d.decision == "transform" + assert server.calls["precheck"] == 3 + + +@pytest.mark.asyncio +async def test_4xx_raises_no_retry(server): + server.calls["fail_modes"] = ["4xx"] + async with AsyncClient(api_key="GAI_test", base_url=str(server.make_url("")).rstrip("/"), + retries=3, backoff_base_ms=1) as c: + with pytest.raises(GovernsAPIError) as ei: + await c.precheck(tool="chat", raw_text="x") + assert ei.value.status_code == 400 + assert server.calls["precheck"] == 1 # no retry + + +@pytest.mark.asyncio +async def test_api_key_required(): + with pytest.raises(ValueError, match="api_key is required"): + AsyncClient(api_key="") + + +@pytest.mark.asyncio +async def test_env_var_pickup(monkeypatch): + monkeypatch.setenv("GOVERNS_AI_API_KEY", "GAI_env") + monkeypatch.setenv("GOVERNS_AI_BASE_URL", "http://envtest") + c = AsyncClient() + assert c.api_key == "GAI_env" + assert c.base_url == "http://envtest" + await c.close() diff --git a/tests/test_sync_client.py b/tests/test_sync_client.py new file mode 100644 index 0000000..0f43b01 --- /dev/null +++ b/tests/test_sync_client.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: MIT +"""Unit tests for governs_ai.sync — the synchronous convenience client. + +Uses `requests-mock` (built into newer requests via responses) or a plain +monkeypatched session for portability. We use unittest.mock.patch on +the underlying session for zero external test deps. +""" +from __future__ import annotations + +import json +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from governs_ai.sync import ( + PrecheckDecision, + GovernsAPIError, + SyncClient, + precheck as oneshot_precheck, +) + + +def _ok_response(payload): + r = MagicMock(spec=requests.Response) + r.status_code = 200 + r.ok = True + r.json.return_value = payload + r.text = json.dumps(payload) + return r + + +def _err_response(status, body): + r = MagicMock(spec=requests.Response) + r.status_code = status + r.ok = False + if isinstance(body, dict): + r.json.return_value = body + r.text = json.dumps(body) + else: + r.json.side_effect = ValueError("not json") + r.text = body + return r + + +def test_api_key_required(): + with pytest.raises(ValueError, match="api_key is required"): + SyncClient(api_key="") + + +def test_api_key_from_env(monkeypatch): + monkeypatch.setenv("GOVERNS_AI_API_KEY", "GAI_envkey") + c = SyncClient() + assert c.api_key == "GAI_envkey" + + +def test_base_url_from_env(monkeypatch): + monkeypatch.setenv("GOVERNS_AI_API_KEY", "GAI_envkey") + monkeypatch.setenv("GOVERNS_AI_BASE_URL", "http://localhost:8082") + c = SyncClient() + assert c.base_url == "http://localhost:8082" + + +def test_precheck_returns_decision(): + c = SyncClient(api_key="GAI_test", base_url="http://t") + with patch.object(c._session, "post", return_value=_ok_response({ + "decision": "transform", + "raw_text_out": "redacted", + "reasons": ["pii.email"], + "policy_id": "p-1", + "ts": 12345, + })) as mock_post: + d = c.precheck(tool="chat", raw_text="hi jane@example.com") + assert isinstance(d, PrecheckDecision) + assert d.decision == "transform" + assert d.raw_text_out == "redacted" + assert d.allowed is True + assert d.denied is False + assert d.reasons == ["pii.email"] + assert d.policy_id == "p-1" + # verify request shape + call = mock_post.call_args + assert call.args[0].endswith("/api/v1/precheck") + body = call.kwargs["json"] + assert body == {"tool": "chat", "raw_text": "hi jane@example.com"} + + +def test_precheck_passes_optional_fields(): + c = SyncClient(api_key="GAI_test", base_url="http://t") + with patch.object(c._session, "post", return_value=_ok_response({ + "decision": "allow", "raw_text_out": "x" + })) as mock_post: + c.precheck( + tool="chat", + raw_text="x", + scope="user.write", + user_id="u1", + corr_id="c1", + policy_config={"defaults": {"pii": "redact"}}, + ) + body = mock_post.call_args.kwargs["json"] + assert body["scope"] == "user.write" + assert body["user_id"] == "u1" + assert body["corr_id"] == "c1" + assert body["policy_config"] == {"defaults": {"pii": "redact"}} + + +def test_deny_decision_helpers(): + d = PrecheckDecision.from_dict({"decision": "deny", "raw_text_out": "", "reasons": ["x"]}) + assert d.denied + assert not d.allowed + + +def test_4xx_raises_immediately_no_retry(): + c = SyncClient(api_key="GAI_test", base_url="http://t", retries=3) + with patch.object(c._session, "post", return_value=_err_response(400, {"error": "bad"})) as mock_post: + with pytest.raises(GovernsAPIError) as ei: + c.precheck(tool="chat", raw_text="x") + assert ei.value.status_code == 400 + # 4xx should NOT retry + assert mock_post.call_count == 1 + + +def test_5xx_retries_then_fails(): + c = SyncClient(api_key="GAI_test", base_url="http://t", retries=2, backoff_base_ms=1) + with patch.object(c._session, "post", return_value=_err_response(503, {"error": "down"})) as mock_post: + with pytest.raises(GovernsAPIError) as ei: + c.precheck(tool="chat", raw_text="x") + assert ei.value.status_code == 503 + # 2 retries + 1 initial = 3 calls + assert mock_post.call_count == 3 + + +def test_429_retries(): + c = SyncClient(api_key="GAI_test", base_url="http://t", retries=2, backoff_base_ms=1) + with patch.object(c._session, "post", return_value=_err_response(429, "rate limited")) as mock_post: + with pytest.raises(GovernsAPIError): + c.precheck(tool="chat", raw_text="x") + assert mock_post.call_count == 3 + + +def test_5xx_then_ok_succeeds(): + c = SyncClient(api_key="GAI_test", base_url="http://t", retries=2, backoff_base_ms=1) + responses = [ + _err_response(503, {"error": "down"}), + _ok_response({"decision": "allow", "raw_text_out": "ok"}), + ] + with patch.object(c._session, "post", side_effect=responses): + d = c.precheck(tool="chat", raw_text="x") + assert d.decision == "allow" + + +def test_network_error_retries(): + c = SyncClient(api_key="GAI_test", base_url="http://t", retries=2, backoff_base_ms=1) + with patch.object(c._session, "post", side_effect=requests.ConnectionError("boom")) as mock_post: + with pytest.raises(GovernsAPIError): + c.precheck(tool="chat", raw_text="x") + assert mock_post.call_count == 3 # retried + + +def test_oneshot_precheck_uses_context_manager(monkeypatch): + captured = {} + real_precheck = SyncClient.precheck + + def fake_precheck(self, **kwargs): + captured.update(kwargs) + return PrecheckDecision.from_dict({"decision": "allow", "raw_text_out": "ok"}) + + monkeypatch.setattr(SyncClient, "precheck", fake_precheck) + monkeypatch.setattr(SyncClient, "close", lambda self: None) + + d = oneshot_precheck(api_key="GAI_x", base_url="http://t", tool="chat", raw_text="hi") + assert d.decision == "allow" + assert captured["tool"] == "chat" + assert captured["raw_text"] == "hi" + + +def test_auth_header_set(): + c = SyncClient(api_key="GAI_secret", base_url="http://t") + assert c._session.headers["X-Governs-Key"] == "GAI_secret" + assert c._session.headers["Content-Type"] == "application/json"