Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ The SDK sends events to SessionBat ingestion by default. Pass `api_key` directly
or set `SESSIONBAT_API_KEY`. For tests or local debugging, pass an explicit
transport such as `MemoryTransport` or `StdoutTransport`.

HTTP ingestion runs in a background thread so recording observations does not
block your application on network I/O. Transient failures are retried with
bounded backoff, and queued events are flushed automatically during interpreter
shutdown. Call `client.flush()` or `client.close()` when you need to wait for
delivery before exiting a short-lived process.

### `Session`

A `Session` records completed observations:
Expand Down
10 changes: 10 additions & 0 deletions src/sessionbat/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,16 @@ def _send(self, payload: dict[str, Any]) -> None:
assert self.transport is not None
self.transport.send(payload)

def flush(self, timeout: float | None = None) -> bool:
if self.transport is None or not hasattr(self.transport, "flush"):
return True
return bool(self.transport.flush(timeout=timeout))

def close(self, timeout: float | None = None) -> bool:
if self.transport is None or not hasattr(self.transport, "close"):
return True
return bool(self.transport.close(timeout=timeout))

def langchain_callback(
self,
*,
Expand Down
114 changes: 111 additions & 3 deletions src/sessionbat/transports.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from __future__ import annotations

import atexit
import json
import random
import sys
from dataclasses import dataclass
import threading
import time
from dataclasses import dataclass, field
from queue import Empty, Full, Queue
from typing import Protocol
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
Expand All @@ -23,8 +28,90 @@ class IngestionTransport:
api_key: str
endpoint: str = DEFAULT_INGESTION_ENDPOINT
timeout: float = 10.0
max_retries: int = 3
base_backoff: float = 0.25
max_backoff: float = 2.0
queue_size: int = 1000
shutdown_timeout: float = 2.0
_queue: Queue[dict] = field(init=False, repr=False)
_worker: threading.Thread | None = field(default=None, init=False, repr=False)
_closed: bool = field(default=False, init=False, repr=False)
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)

def __post_init__(self) -> None:
self._queue = Queue(maxsize=self.queue_size)
atexit.register(self.close, timeout=self.shutdown_timeout)

def send(self, payload: dict) -> None:
with self._lock:
if self._closed:
return
self._ensure_worker_started()

try:
self._queue.put_nowait(payload)
except Full:
return

def flush(self, timeout: float | None = None) -> bool:
deadline = None if timeout is None else time.monotonic() + timeout
while self._queue.unfinished_tasks:
if deadline is not None and time.monotonic() >= deadline:
return False
time.sleep(0.01)
return True

def close(self, timeout: float | None = None) -> bool:
deadline = None if timeout is None else time.monotonic() + timeout
with self._lock:
worker = self._worker
if not self._closed:
self._closed = True

flushed = self.flush(timeout=timeout)
if worker is None:
return flushed

remaining = None if deadline is None else max(0.0, deadline - time.monotonic())
worker.join(timeout=remaining)
return flushed and not worker.is_alive()

def _ensure_worker_started(self) -> None:
if self._worker is not None:
return
self._worker = threading.Thread(target=self._run, name="sessionbat-ingestion")
self._worker.daemon = True
self._worker.start()

def _run(self) -> None:
while True:
try:
payload = self._queue.get(timeout=0.1)
except Empty:
if self._closed:
return
continue

try:
self._send_with_retries(payload)
finally:
self._queue.task_done()

def _send_with_retries(self, payload: dict) -> None:
attempt = 0
while True:
try:
self._send_once(payload)
return
except TransportError as error:
if attempt >= self.max_retries or not _is_retryable(error):
Comment thread
daugaard marked this conversation as resolved.
return
time.sleep(_backoff(attempt, self.base_backoff, self.max_backoff))
attempt += 1
except Exception:
return

def _send_once(self, payload: dict) -> None:
body = json.dumps(payload).encode("utf-8")
request = Request(
self.endpoint,
Expand All @@ -41,11 +128,32 @@ def send(self, payload: dict) -> None:
with urlopen(request, timeout=self.timeout) as response:
status = response.status
if status < 200 or status >= 300:
raise TransportError(f"ingestion failed with HTTP {status}")
raise TransportError(f"ingestion failed with HTTP {status}", status)
except HTTPError as error:
raise TransportError(f"ingestion failed with HTTP {error.code}") from error
raise TransportError(f"ingestion failed with HTTP {error.code}", error.code) from error
except URLError as error:
raise TransportError(f"ingestion request failed: {error.reason}") from error
except TimeoutError as error:
raise TransportError("ingestion request timed out") from error
except OSError as error:
raise TransportError(f"ingestion request failed: {error}") from error


def _is_retryable(error: TransportError) -> bool:
if len(error.args) > 1 and isinstance(error.args[1], int):
status = error.args[1]
return status == 408 or status == 429 or status >= 500
cause = error.__cause__
if isinstance(cause, HTTPError):
return cause.code == 408 or cause.code == 429 or cause.code >= 500
if isinstance(cause, URLError | TimeoutError | OSError):
return True
return False


def _backoff(attempt: int, base_backoff: float, max_backoff: float) -> float:
delay = min(max_backoff, base_backoff * (2**attempt))
return random.uniform(0, delay)


class StdoutTransport:
Expand Down
174 changes: 167 additions & 7 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

import json
import threading
import time
from collections.abc import Iterator
from datetime import datetime
from http.server import BaseHTTPRequestHandler, HTTPServer
from queue import Queue

import pytest

from sessionbat import SessionBat
from sessionbat.transports import IngestionTransport, MemoryTransport, TransportError
from sessionbat.transports import IngestionTransport, MemoryTransport


class TestSessionBatClient:
Expand Down Expand Up @@ -131,19 +133,24 @@ def test_records_errors_on_failed_operations(self) -> None:

class _RecordingHandler(BaseHTTPRequestHandler):
requests: list[dict] = []
response_status = 202
response_status: int | list[int] = 202

def do_POST(self) -> None:
length = int(self.headers["Content-Length"])
body = self.rfile.read(length)
response_status = self.__class__.response_status
if isinstance(response_status, list):
status = response_status.pop(0) if response_status else 202
else:
status = response_status
self.__class__.requests.append(
{
"path": self.path,
"headers": self.headers,
"body": json.loads(body),
}
)
self.send_response(self.__class__.response_status)
self.send_response(status)
self.send_header("Content-Type", "application/json")
self.end_headers()
self.wfile.write(b'{"ok":true}')
Expand Down Expand Up @@ -195,6 +202,7 @@ def test_uses_environment_api_key(
session = client.session(session_id="thread_123")
session.tool_call(tool_name="lookup_account", input={"account_id": "acct_123"})

assert client.flush(timeout=1.0)
request = _RecordingHandler.requests[0]
assert request["headers"]["Authorization"] == "Bearer sbat_ingest_env"

Expand All @@ -212,6 +220,7 @@ def test_posts_sdk_payload_with_bearer_auth(self, ingestion_server: str) -> None
output={"status": "locked"},
)

assert client.flush(timeout=1.0)
request = _RecordingHandler.requests[0]
payload = request["body"]
assert request["path"] == "/api/v1/ingestion/events"
Expand All @@ -228,9 +237,160 @@ def test_posts_sdk_payload_with_bearer_auth(self, ingestion_server: str) -> None
assert payload["observation"]["input"] == {"account_id": "acct_123"}
assert payload["observation"]["output"] == {"status": "locked"}

def test_raises_transport_error_for_non_2xx_response(self, ingestion_server: str) -> None:
_RecordingHandler.response_status = 500
def test_retries_transient_http_failures(self, ingestion_server: str) -> None:
_RecordingHandler.response_status = [500, 500, 202]
transport = IngestionTransport(
api_key="sbat_ingest_test",
endpoint=ingestion_server,
base_backoff=0,
max_backoff=0,
)

transport.send({"id": "evt_123"})

assert transport.flush(timeout=1.0)
assert [request["body"]["id"] for request in _RecordingHandler.requests] == [
"evt_123",
"evt_123",
"evt_123",
]

def test_does_not_retry_or_raise_for_non_retryable_http_failures(
self, ingestion_server: str
) -> None:
_RecordingHandler.response_status = 400
transport = IngestionTransport(api_key="sbat_ingest_test", endpoint=ingestion_server)

with pytest.raises(TransportError, match="HTTP 500"):
transport.send({"id": "evt_123"})
transport.send({"id": "evt_123"})

assert transport.flush(timeout=1.0)
assert len(_RecordingHandler.requests) == 1

def test_send_does_not_raise_for_network_failures(self) -> None:
transport = IngestionTransport(
api_key="sbat_ingest_test",
endpoint="http://127.0.0.1:1/api/v1/ingestion/events",
base_backoff=0,
max_backoff=0,
timeout=0.01,
)

transport.send({"id": "evt_123"})

assert transport.flush(timeout=1.0)

def test_unexpected_send_failure_does_not_stop_worker(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
sent: list[str] = []

def send_once(self: IngestionTransport, payload: dict) -> None:
if payload["id"] == "evt_bad":
raise TypeError("not JSON serializable")
sent.append(payload["id"])

monkeypatch.setattr(IngestionTransport, "_send_once", send_once)
transport = IngestionTransport(api_key="sbat_ingest_test")

transport.send({"id": "evt_bad"})
transport.send({"id": "evt_good"})

assert transport.flush(timeout=1.0)
assert sent == ["evt_good"]
assert transport.close(timeout=1.0)

def test_close_waits_for_event_accepted_during_send(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
sent: list[str] = []

class BlockingPutQueue:
def __init__(self) -> None:
self.inner: Queue[dict] = Queue()
self.put_started = threading.Event()
self.release_put = threading.Event()

@property
def unfinished_tasks(self) -> int:
return self.inner.unfinished_tasks

def put_nowait(self, payload: dict) -> None:
self.put_started.set()
self.release_put.wait(timeout=1.0)
self.inner.put_nowait(payload)

def get(self, timeout: float) -> dict:
return self.inner.get(timeout=timeout)

def task_done(self) -> None:
self.inner.task_done()

def send_once(self: IngestionTransport, payload: dict) -> None:
sent.append(payload["id"])

monkeypatch.setattr(IngestionTransport, "_send_once", send_once)
transport = IngestionTransport(api_key="sbat_ingest_test")
queue = BlockingPutQueue()
transport._queue = queue
close_result: list[bool] = []

send_thread = threading.Thread(target=transport.send, args=({"id": "evt_123"},))
send_thread.start()
assert queue.put_started.wait(timeout=1.0)

close_thread = threading.Thread(
target=lambda: close_result.append(transport.close(timeout=1.0))
)
close_thread.start()
time.sleep(0.05)

assert close_thread.is_alive()
queue.release_put.set()
send_thread.join(timeout=1.0)
close_thread.join(timeout=1.0)

assert close_result == [True]
assert sent == ["evt_123"]

def test_full_queue_drops_newest_without_blocking(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
started = threading.Event()
release = threading.Event()
sent: list[str] = []

def slow_send_once(self: IngestionTransport, payload: dict) -> None:
sent.append(payload["id"])
started.set()
release.wait(timeout=1.0)

monkeypatch.setattr(IngestionTransport, "_send_once", slow_send_once)
transport = IngestionTransport(api_key="sbat_ingest_test", queue_size=1)

transport.send({"id": "evt_1"})
assert started.wait(timeout=1.0)
transport.send({"id": "evt_2"})
start = time.monotonic()
transport.send({"id": "evt_3"})

assert time.monotonic() - start < 0.1
release.set()
assert transport.close(timeout=1.0)
assert sent == ["evt_1", "evt_2"]

def test_flush_returns_false_when_timeout_expires(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
release = threading.Event()

def slow_send_once(self: IngestionTransport, payload: dict) -> None:
release.wait(timeout=1.0)

monkeypatch.setattr(IngestionTransport, "_send_once", slow_send_once)
transport = IngestionTransport(api_key="sbat_ingest_test")

transport.send({"id": "evt_123"})

assert transport.flush(timeout=0.01) is False
release.set()
assert transport.close(timeout=1.0)
Loading