From 872e75c4fd176bbe096de1189a1f47447dd023c1 Mon Sep 17 00:00:00 2001 From: Sergio Herrera Date: Fri, 22 May 2026 17:32:48 +0200 Subject: [PATCH] Add support for async activities Signed-off-by: Sergio Herrera <627709+seherv@users.noreply.github.com> --- examples/workflow/async_activities.py | 104 ++ ext/dapr-ext-workflow/AGENTS.md | 21 +- ext/dapr-ext-workflow/benchmarks/RESULTS.md | 132 ++ .../benchmarks/bench_async_activities.py | 1457 +++++++++++++++++ .../dapr/ext/workflow/_durabletask/worker.py | 323 +++- .../dapr/ext/workflow/workflow_runtime.py | 117 +- ext/dapr-ext-workflow/docs/concurrency.md | 83 + .../test_activity_dispatch_routing.py | 90 + .../durabletask/test_activity_executor.py | 21 + .../test_activity_executor_async.py | 101 ++ .../tests/test_async_activity_registration.py | 260 +++ 11 files changed, 2586 insertions(+), 123 deletions(-) create mode 100644 examples/workflow/async_activities.py create mode 100644 ext/dapr-ext-workflow/benchmarks/RESULTS.md create mode 100644 ext/dapr-ext-workflow/benchmarks/bench_async_activities.py create mode 100644 ext/dapr-ext-workflow/docs/concurrency.md create mode 100644 ext/dapr-ext-workflow/tests/durabletask/test_activity_dispatch_routing.py create mode 100644 ext/dapr-ext-workflow/tests/durabletask/test_activity_executor_async.py create mode 100644 ext/dapr-ext-workflow/tests/test_async_activity_registration.py diff --git a/examples/workflow/async_activities.py b/examples/workflow/async_activities.py new file mode 100644 index 000000000..cd741bb3a --- /dev/null +++ b/examples/workflow/async_activities.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# Copyright 2026 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Async activities running alongside sync ones in a single workflow. + +Starts three async activities that do an HTTP request, then a sync activity that +sums up the results. Shows that sync and async activities work side by side. + +Run with: + + dapr run --app-id async-activities --app-protocol grpc --dapr-grpc-port 50001 \\ + -- python async_activities.py +""" + +from __future__ import annotations + +from time import sleep + +import dapr.ext.workflow as wf +import httpx +from pydantic import BaseModel + +wfr = wf.WorkflowRuntime() + + +class FetchRequest(BaseModel): + url: str + timeout_seconds: float = 5.0 + + +class FetchResult(BaseModel): + url: str + status_code: int + body_length: int + + +@wfr.workflow(name='parallel_fetch_workflow') +def parallel_fetch_workflow(ctx: wf.DaprWorkflowContext, urls: list[str]): + fetch_tasks = [ + ctx.call_activity(fetch_url, input=FetchRequest(url=url).model_dump()) for url in urls + ] + results = yield wf.when_all(fetch_tasks) + summary = yield ctx.call_activity(summarize_fetches, input=results) + return summary + + +@wfr.activity(name='fetch_url') +async def fetch_url(ctx: wf.WorkflowActivityContext, request: FetchRequest) -> dict: + """Async activity: fetch a URL with httpx. Multiple instances run concurrently.""" + async with httpx.AsyncClient(timeout=request.timeout_seconds) as client: + response = await client.get(request.url) + result = FetchResult( + url=request.url, + status_code=response.status_code, + body_length=len(response.content), + ) + print( + f'[async] fetched {result.url} -> {result.status_code} ({result.body_length}B)', flush=True + ) + return result.model_dump() + + +@wfr.activity(name='summarize_fetches') +def summarize_fetches(ctx: wf.WorkflowActivityContext, results: list[dict]) -> str: + """Sync activity: runs in the sync-fallback thread pool. Unchanged from before.""" + total_bytes = sum(r['body_length'] for r in results) + summary = f'fetched {len(results)} URLs, total {total_bytes} bytes' + print(f'[sync] {summary}', flush=True) + return summary + + +def main() -> None: + urls = [ + 'https://httpbin.org/uuid', + 'https://httpbin.org/get', + 'https://httpbin.org/headers', + ] + + wfr.start() + sleep(5) # wait for workflow runtime to start + + wf_client = wf.DaprWorkflowClient() + instance_id = wf_client.schedule_new_workflow(workflow=parallel_fetch_workflow, input=urls) + print(f'Workflow started. Instance ID: {instance_id}') + + state = wf_client.wait_for_workflow_completion(instance_id, timeout_in_seconds=60) + assert state is not None + print(f'Workflow completed! Status: {state.runtime_status.name}') + print(f'Workflow result: {state.serialized_output.strip(chr(34))}') + + wfr.shutdown() + + +if __name__ == '__main__': + main() diff --git a/ext/dapr-ext-workflow/AGENTS.md b/ext/dapr-ext-workflow/AGENTS.md index 635cb8705..9aee2fffd 100644 --- a/ext/dapr-ext-workflow/AGENTS.md +++ b/ext/dapr-ext-workflow/AGENTS.md @@ -105,6 +105,24 @@ The entry point for registration and lifecycle: Internally wraps user functions: workflow functions get a `DaprWorkflowContext`, activity functions get a `WorkflowActivityContext`. Tracks registration state via `_workflow_registered` / `_activity_registered` attributes on functions to prevent double registration. +#### Sync and async activities + +Activities can be either `def my_activity(ctx, inp)` or `async def my_activity(ctx, inp)`. At registration, `_make_activity_wrapper` calls `_is_async_callable(fn)` to detect async-ness. That helper unwraps `functools.partial`, `@functools.wraps` chains, and callable-class `__call__` so common decorator patterns route correctly. The wrapper is built `async def` or `def` to match, then stored in the registry. + +At dispatch time (the gRPC stream loop in `_durabletask/worker.py`), `inspect.iscoroutinefunction(activity_fn)` on the wrapper selects between two handlers. + +- **Async activities** go through `_execute_activity_async`, then `_ActivityExecutor.execute_async`, which awaits `fn(...)` directly on the event loop. No thread pool involvement. The gRPC response is delivered via `loop.run_in_executor(None, stub.CompleteActivityTask, ...)` (asyncio's default executor). +- **Sync activities** go through `_execute_activity`, dispatched to the thread pool by `_AsyncWorkerManager._run_func`. The activity runs on a worker thread, and the response is delivered from the same thread. The thread pool size is controlled by `maximum_thread_pool_workers`. + +Workflow (orchestrator) functions must remain generators (`def` with `yield`). They cannot be `async def` because durabletask's deterministic replay depends on synchronous generator semantics. Only activities support async. + +**Decorator ordering gotcha.** Stacking `@wfr.activity` over `@alternate_name(...)` over `async def` works because `@alternate_name` now emits an `async def innerfn` when the wrapped function is async. A user-written decorator that wraps an async function in a sync `def` (without `@functools.wraps` exposing `__wrapped__`) defeats `_is_async_callable`, routes the activity to the sync path, and produces an un-awaited coroutine. Such decorators should use `@functools.wraps(fn)` so the unwrap walks through them. + +**`maximum_thread_pool_workers` gotcha.** This knob sizes the sync-activity thread pool only. Async-activity response delivery uses asyncio's default executor (process-wide, lazily sized to `min(32, cpu_count + 4)`), which is not capped by this knob. Strict thread-count bounds for async response delivery require calling `asyncio.get_event_loop().set_default_executor(ThreadPoolExecutor(max_workers=N))` before `wfr.start()`. A future PR may migrate the worker to `grpc.aio` and remove this caveat by sending responses without any thread pool. + +**Concurrency sizing and load characterization.** See `docs/concurrency.md` for sizing recommendations (`maximum_concurrent_activity_work_items`, `maximum_thread_pool_workers`), an async-vs-sync decision tree, and the default-executor caveat with a worked example. The `benchmarks/` directory ships `bench_async_activities.py` and the generated `RESULTS.md`; re-run it locally before claiming a perf regression — the report captures the run environment so a reader can tell whether a number applies to their hardware. + + ### DaprWorkflowClient (`dapr_workflow_client.py`) Client for workflow lifecycle management: @@ -163,7 +181,7 @@ Retry configuration for activities and child workflows: 1. **Registration**: User decorates functions with `@wfr.workflow` / `@wfr.activity`. The runtime wraps them and stores them in the durabletask worker's registry. 2. **Startup**: `wfr.start()` opens a gRPC stream to the Dapr sidecar. The worker polls for work items. 3. **Scheduling**: Client calls `schedule_new_workflow(fn, input=...)`. The function's name (or `_dapr_alternate_name`) is sent to the backend. -4. **Execution**: The durabletask engine dispatches work items. Workflow functions are Python **generators** that `yield` tasks (activity calls, timers, child workflows). The engine records history; on replay, yielded tasks return cached results without re-executing. +4. **Execution**: The durabletask engine dispatches work items. Workflow functions are Python **generators** that `yield` tasks (activity calls, timers, child workflows). Activity functions are either sync (dispatched to the worker's thread pool) or `async def` (awaited directly on the worker's event loop). The engine records history; on replay, yielded tasks return cached results without re-executing. 5. **Determinism**: Workflows must be deterministic — no random, no wall-clock time, no I/O. Use `ctx.current_utc_datetime` instead of `datetime.now()`. Use `ctx.is_replaying` to guard side effects like logging. 6. **Completion**: Client polls via `wait_for_workflow_completion()` or `get_workflow_state()`. @@ -191,6 +209,7 @@ Two example directories exercise workflows: - `cross-app1.py`, `cross-app2.py`, `cross-app3.py` — cross-app calls - `versioning.py` — workflow versioning with `is_patched()` - `simple_aio_client.py` — async client variant + - `async_activities.py` — `async def` activities (HTTP fan-out with `httpx.AsyncClient`) ## Testing diff --git a/ext/dapr-ext-workflow/benchmarks/RESULTS.md b/ext/dapr-ext-workflow/benchmarks/RESULTS.md new file mode 100644 index 000000000..7d391e244 --- /dev/null +++ b/ext/dapr-ext-workflow/benchmarks/RESULTS.md @@ -0,0 +1,132 @@ +# Async-activity load benchmark results + +Generated by `bench_async_activities.py`. Re-run with: + +```bash +uv run python ext/dapr-ext-workflow/benchmarks/bench_async_activities.py +``` + +## Run environment + +- **Timestamp**: 2026-05-25 20:40:09 UTC +- **Git commit**: `8f13da0-dirty` +- **Python**: CPython 3.13.12 +- **OS**: Darwin 25.5.0 (arm64) on Apple M3 Pro (12 logical cores), 36.0 GB +- **asyncio default executor**: `max_workers=16` (`min(32, cpu_count + 4)`) +- **CI environment**: no + +Numbers are specific to this hardware. Re-run locally to compare. The shape of +the curves (throughput plateau, p99 inflection, drift) is what to compare +across machines. + +Each scenario drives `TaskHubGrpcWorker._execute_activity_async` through +`_AsyncWorkerManager` against a mock `CompleteActivityTask` stub. End-to-end +latency is measured from `submit_activity` to the mock stub seeing the response. + +## 1. Concurrency win (issue #897 repro) + +100 × 1 s HTTP fetches. Async runs them concurrently on the loop, sync gates +them through the thread pool. + +| Scenario | Wallclock (s) | Tput/s | Peak tasks | Peak RSS Δ (MB) | +| --- | ---: | ---: | ---: | ---: | +| Async fan-out | 1.47 | 68.1 | 305 | 86.4 | +| Sync baseline | 13.34 | 7.5 | 121 | 2.4 | + +## 2. Throughput scaling + +Async fan-out, 50 ms activity, sem=5000, pool=16. Throughput plateaus around +N=2500. + +| N | Wallclock (s) | Tput/s | p50 ms | p95 ms | p99 ms | Peak tasks | Peak RSS Δ (MB) | +| ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | +| 100 | 0.06 | 1542.3 | 62.0 | 64.1 | 64.1 | 105 | 0.0 | +| 500 | 0.08 | 5931.1 | 78.6 | 79.6 | 79.6 | 505 | 0.4 | +| 1000 | 0.11 | 8956.5 | 102.9 | 106.2 | 106.3 | 1005 | 2.9 | +| 2500 | 0.24 | 10532.0 | 218.8 | 225.3 | 225.9 | 2505 | 10.0 | +| 5000 | 0.57 | 8696.7 | 543.8 | 557.2 | 558.7 | 5005 | 25.2 | + +## 3. Semaphore-cap sensitivity + +N=2500, 50 ms activity, pool=16. Caps below ~500 starve the loop. Gains +compress above ~1000. + +| Sem | Wallclock (s) | Tput/s | p50 ms | p95 ms | p99 ms | +| ---: | ---: | ---: | ---: | ---: | ---: | +| 50 | 2.69 | 928.6 | 1422.7 | 2583.5 | 2687.0 | +| 100 | 1.42 | 1758.2 | 794.9 | 1360.7 | 1412.0 | +| 500 | 0.40 | 6229.5 | 279.2 | 387.9 | 392.3 | +| 1000 | 0.30 | 8322.3 | 235.6 | 286.9 | 290.2 | +| 5000 | 0.23 | 10720.7 | 215.0 | 222.3 | 222.8 | + +## 4. Failure threshold (queue-wait inflection) + +Cap=1000, ramp N, 50 ms activity. p99 first exceeds 2× server latency at +**N=1000** (p99 = 104.7 ms). + +| N | Wallclock (s) | Tput/s | p50 ms | p95 ms | p99 ms | +| ---: | ---: | ---: | ---: | ---: | ---: | +| 500 | 0.08 | 6264.2 | 70.6 | 77.4 | 77.5 | +| 1000 | 0.11 | 9145.3 | 94.5 | 104.2 | 104.7 | +| 2500 | 0.31 | 8086.8 | 243.6 | 294.2 | 298.0 | +| 5000 | 0.72 | 6983.2 | 584.2 | 691.1 | 700.5 | +| 10000 | 2.08 | 4813.1 | 1801.7 | 2019.3 | 2046.2 | + +## 5. Sidecar response delivery overhead + +N=1000, sem=1000, pool=16, 50 ms activity. Mock `CompleteActivityTask` given +an artificial delay. Async responses go through `loop.run_in_executor(None, ...)`, +sharing asyncio's default executor (`max_workers=16` here). Delays past ~5 ms +saturate that pool. + +| Delivery | Wallclock (s) | Tput/s | p50 ms | p95 ms | p99 ms | +| ---: | ---: | ---: | ---: | ---: | ---: | +| 0 ms | 0.11 | 9497.2 | 98.2 | 101.3 | 101.5 | +| 1 ms | 0.18 | 5699.8 | 133.0 | 167.7 | 171.0 | +| 5 ms | 0.48 | 2077.9 | 287.7 | 458.6 | 473.4 | +| 10 ms | 0.86 | 1162.5 | 494.1 | 820.4 | 843.5 | + +## 6. Sustained load + +200/s for 120 s, 50 ms activity. Submitted/completed: 24 000 / 24 000. +Wallclock 120.05 s (effective 199.9/s). + +- p50 50.2 ms, p95 50.6 ms, p99 50.8 ms, max 62.8 ms. +- First-25% p99 50.8 ms, last-25% p99 50.7 ms. No drift. +- Peak tasks 19, peak queue depth 3, peak RSS Δ 5.8 MB. + +## 7. Real HTTP workload + +Each activity opens a fresh `httpx.AsyncClient` and GETs an aiohttp endpoint +sleeping 50 ms. Mirrors `examples/workflow/async_activities.py`. Pool=16 for +all rows. + +| Scenario | N | Sem | Wallclock (s) | Tput/s | p50 ms | p95 ms | p99 ms | Peak tasks | Peak RSS Δ (MB) | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | +| Async | 100 | 1000 | 0.49 | 205.3 | 485.1 | 485.4 | 485.5 | 305 | 0.0 | +| Async | 500 | 1000 | 2.06 | 243.2 | 1990.2 | 2052.6 | 2053.0 | 1376 | 308.1 | +| Async | 1000 | 1000 | 4.28 | 233.4 | 4200.5 | 4274.9 | 4280.5 | 2555 | 398.5 | +| Async | 2500 | 5000 | 15.16 | 165.0 | 10240.9 | 13260.9 | 15111.6 | 5776 | 1219.1 | +| Sync | 100 | 1000 | 0.51 | 194.2 | 324.6 | 458.5 | 514.4 | 137 | 0.7 | + +## 8. Real HTTP sustained load + +Open-loop 100/s for 60 s with real `httpx.AsyncClient`. Submitted/completed: +6000 / 6000. Wallclock 60.05 s (effective 99.9/s). + +- p50 56.1 ms, p95 68.9 ms, p99 76.0 ms, max 145.2 ms. +- First-25% p99 75.7 ms, last-25% p99 76.2 ms. No drift. +- Peak tasks 45, peak queue depth 6, peak RSS Δ 0.0 MB. + +## 9. OOM safety + +10 000 in-flight async activities, 50 ms, sem=1000, pool=8. ~9 000 Tasks +parked on the semaphore. Peak RSS Δ stays well under the 500 MB budget. + +| N | Sem | Wallclock (s) | Tput/s | Peak tasks | Peak RSS Δ (MB) | +| ---: | ---: | ---: | ---: | ---: | ---: | +| 10000 | 1000 | 2.03 | 4918.2 | 10005 | 0.0 | + +## Operational guidance + +See `ext/dapr-ext-workflow/docs/concurrency.md` for sizing recommendations. \ No newline at end of file diff --git a/ext/dapr-ext-workflow/benchmarks/bench_async_activities.py b/ext/dapr-ext-workflow/benchmarks/bench_async_activities.py new file mode 100644 index 000000000..6a5b8a99d --- /dev/null +++ b/ext/dapr-ext-workflow/benchmarks/bench_async_activities.py @@ -0,0 +1,1457 @@ +# -*- coding: utf-8 -*- +# Copyright 2026 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Async-activity load benchmarks for ``dapr-ext-workflow``. + +Drives the production dispatch path (``TaskHubGrpcWorker._execute_activity_async`` +and ``_execute_activity``) through ``_AsyncWorkerManager`` against a mock sidecar +stub. Captures end-to-end latency (submit -> response delivery), peak in-flight +Tasks, peak RSS, and steady-state behavior so the sidecar response path is part +of the measurement instead of being skipped. + +Run: + + uv run python ext/dapr-ext-workflow/benchmarks/bench_async_activities.py + +Set ``DAPR_BENCH_SUSTAINED_SECONDS`` to override the 120 s sustained run. +Set ``DAPR_BENCH_WITH_SIDECAR=1`` to run the opt-in end-to-end scenario against +a real Dapr sidecar (requires ``dapr run`` wrapping the script). + +Writes ``benchmarks/RESULTS.md`` and asserts pass-criteria budgets so regressions +fail loudly. +""" + +from __future__ import annotations + +import asyncio +import logging +import math +import os +import platform +import resource +import shutil +import socket +import statistics +import subprocess +import sys +import time +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import AsyncIterator, Awaitable, Callable + +import dapr.ext.workflow._durabletask.internal.protos as pb +import httpx +from aiohttp import web +from dapr.ext.workflow._durabletask import task +from dapr.ext.workflow._durabletask.worker import ( + ConcurrencyOptions, + TaskHubGrpcWorker, + _AsyncWorkerManager, +) + +LOGGER = logging.getLogger('bench') +RESULTS_PATH = Path(__file__).parent / 'RESULTS.md' +IS_DARWIN = sys.platform == 'darwin' + +SUSTAINED_DURATION_S = float(os.environ.get('DAPR_BENCH_SUSTAINED_SECONDS', '120')) + + +# ============================================================================ +# Data classes +# ============================================================================ + + +@dataclass(slots=True) +class LatencyStats: + """Summary statistics for a population of end-to-end latency samples.""" + + count: int + mean_ms: float + p50_ms: float + p95_ms: float + p99_ms: float + max_ms: float + + @classmethod + def from_samples(cls, samples_s: list[float]) -> 'LatencyStats': + if not samples_s: + return cls(count=0, mean_ms=0.0, p50_ms=0.0, p95_ms=0.0, p99_ms=0.0, max_ms=0.0) + samples_ms = sorted(s * 1000.0 for s in samples_s) + return cls( + count=len(samples_ms), + mean_ms=statistics.fmean(samples_ms), + p50_ms=_percentile(samples_ms, 0.50), + p95_ms=_percentile(samples_ms, 0.95), + p99_ms=_percentile(samples_ms, 0.99), + max_ms=samples_ms[-1], + ) + + +@dataclass(slots=True) +class ScenarioMetrics: + """Per-scenario summary written to the results table.""" + + name: str + n_items: int + semaphore_cap: int + thread_pool_workers: int + server_latency_s: float + wallclock_s: float + throughput_per_s: float + latency: LatencyStats + peak_tasks: int + peak_queue_depth: int + peak_rss_delta_mb: float + notes: str = '' + + +@dataclass +class _Sampler: + """Background sampler for in-flight task count, queue depth, and RSS.""" + + interval_s: float = 0.05 + peak_tasks: int = 0 + peak_rss_kb: int = 0 + peak_queue_depth: int = 0 + _queues: list[asyncio.Queue] = field(default_factory=list) + _stop_event: asyncio.Event = field(default_factory=asyncio.Event) + + def watch_queue(self, q: asyncio.Queue | None) -> None: + if q is not None: + self._queues.append(q) + + async def run(self) -> None: + while not self._stop_event.is_set(): + self.peak_tasks = max(self.peak_tasks, len(asyncio.all_tasks())) + self.peak_rss_kb = max(self.peak_rss_kb, _current_rss_kb()) + for q in self._queues: + self.peak_queue_depth = max(self.peak_queue_depth, q.qsize()) + try: + await asyncio.wait_for(self._stop_event.wait(), timeout=self.interval_s) + except asyncio.TimeoutError: + continue + + def stop(self) -> None: + self._stop_event.set() + + +# ============================================================================ +# Helpers +# ============================================================================ + + +def _percentile(sorted_samples_ms: list[float], q: float) -> float: + if not sorted_samples_ms: + return 0.0 + if len(sorted_samples_ms) == 1: + return sorted_samples_ms[0] + pos = q * (len(sorted_samples_ms) - 1) + lo = math.floor(pos) + hi = math.ceil(pos) + if lo == hi: + return sorted_samples_ms[lo] + frac = pos - lo + return sorted_samples_ms[lo] + frac * (sorted_samples_ms[hi] - sorted_samples_ms[lo]) + + +def _current_rss_kb() -> int: + """Process RSS in KB. macOS returns bytes from getrusage; Linux returns KB.""" + rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + if IS_DARWIN: + return rss // 1024 + return rss + + +def _read_text(path: str) -> str: + try: + return Path(path).read_text(encoding='utf-8', errors='ignore') + except OSError: + return '' + + +def _cpu_model() -> str: + """Best-effort CPU model name. Cross-platform; returns a placeholder on failure.""" + if IS_DARWIN: + sysctl = shutil.which('sysctl') + if sysctl is not None: + try: + out = subprocess.run( + [sysctl, '-n', 'machdep.cpu.brand_string'], + capture_output=True, + text=True, + timeout=2, + ) + if out.returncode == 0 and out.stdout.strip(): + return out.stdout.strip() + except (subprocess.SubprocessError, OSError): + pass + cpuinfo = _read_text('/proc/cpuinfo') + for line in cpuinfo.splitlines(): + if line.startswith('model name'): + return line.split(':', 1)[1].strip() + return platform.processor() or platform.machine() or 'unknown' + + +def _total_memory_gb() -> float: + """Best-effort total physical memory in GB. Returns 0 on failure.""" + if IS_DARWIN: + sysctl = shutil.which('sysctl') + if sysctl is not None: + try: + out = subprocess.run( + [sysctl, '-n', 'hw.memsize'], + capture_output=True, + text=True, + timeout=2, + ) + if out.returncode == 0 and out.stdout.strip().isdigit(): + return int(out.stdout.strip()) / (1024**3) + except (subprocess.SubprocessError, OSError): + pass + meminfo = _read_text('/proc/meminfo') + for line in meminfo.splitlines(): + if line.startswith('MemTotal:'): + parts = line.split() + if len(parts) >= 2 and parts[1].isdigit(): + return int(parts[1]) / (1024**2) + return 0.0 + + +def _git_commit() -> str: + """Short git commit hash, or 'unknown' if not in a git repo.""" + git = shutil.which('git') + if git is None: + return 'unknown' + try: + out = subprocess.run( + [git, 'rev-parse', '--short', 'HEAD'], + capture_output=True, + text=True, + timeout=2, + cwd=Path(__file__).parent, + ) + if out.returncode == 0: + commit = out.stdout.strip() + # Mark dirty if there are uncommitted changes. + status = subprocess.run( + [git, 'status', '--porcelain'], + capture_output=True, + text=True, + timeout=2, + cwd=Path(__file__).parent, + ) + if status.returncode == 0 and status.stdout.strip(): + return f'{commit}-dirty' + return commit + except (subprocess.SubprocessError, OSError): + pass + return 'unknown' + + +def _default_executor_size() -> int: + """Effective max_workers for asyncio's default executor.""" + return min(32, (os.cpu_count() or 1) + 4) + + +@dataclass(slots=True) +class RunEnvironment: + """Snapshot of the machine the benchmark ran on.""" + + timestamp_utc: str + git_commit: str + python_version: str + python_implementation: str + platform: str + os_release: str + cpu_model: str + cpu_logical_cores: int + cpu_physical_cores_hint: int + total_memory_gb: float + asyncio_default_executor_workers: int + is_ci: bool + + @classmethod + def capture(cls) -> 'RunEnvironment': + return cls( + timestamp_utc=datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC'), + git_commit=_git_commit(), + python_version=platform.python_version(), + python_implementation=platform.python_implementation(), + platform=platform.platform(), + os_release=f'{platform.system()} {platform.release()} ({platform.machine()})', + cpu_model=_cpu_model(), + cpu_logical_cores=os.cpu_count() or 0, + cpu_physical_cores_hint=os.cpu_count() or 0, + total_memory_gb=_total_memory_gb(), + asyncio_default_executor_workers=_default_executor_size(), + is_ci=any(os.environ.get(k) for k in ('CI', 'GITHUB_ACTIONS', 'TRAVIS', 'BUILDKITE')), + ) + + +# ============================================================================ +# Mock sidecar stub (production response path goes through here) +# ============================================================================ + + +class _MockSidecarStub: + """In-process stand-in for ``TaskHubSidecarServiceStub``. + + ``_execute_activity_async`` and ``_execute_activity`` deliver responses via + ``stub.CompleteActivityTask``. The mock records completion timestamps so the + harness can compute end-to-end latency (submit -> delivery). ``send_latency_s`` + simulates a slow sidecar — useful for the response-delivery-overhead scenario. + """ + + def __init__(self, send_latency_s: float = 0.0): + self.send_latency_s = send_latency_s + self.completions: dict[int, float] = {} + self.calls = 0 + + def Hello(self, *_args, **_kwargs) -> None: # noqa: N802 + return None + + def CompleteActivityTask(self, response: pb.ActivityResponse) -> None: # noqa: N802 + if self.send_latency_s > 0: + time.sleep(self.send_latency_s) + self.completions[response.taskId] = time.perf_counter() + self.calls += 1 + + def CompleteOrchestratorTask(self, *_args, **_kwargs) -> None: # noqa: N802 + return None + + +def _build_activity_request(name: str, task_id: int, instance_id: str) -> pb.ActivityRequest: + return pb.ActivityRequest( + name=name, + taskId=task_id, + workflowInstance=pb.WorkflowInstance(instanceId=instance_id), + parentTraceContext=pb.TraceContext(traceParent=''), + taskExecutionId='', + ) + + +# ============================================================================ +# Activity factories — record per-invocation timestamps so the harness can +# decompose end-to-end latency into queue-wait / work / delivery. +# ============================================================================ + + +def _async_sleep_factory( + latency_s: float, start_ts: dict[int, float], end_ts: dict[int, float] +) -> Callable[[task.ActivityContext, object], Awaitable[None]]: + """Build an async activity that sleeps. Records per-task start/end timestamps.""" + + async def sleep(ctx: task.ActivityContext, _inp: object) -> None: + start_ts[ctx.task_id] = time.perf_counter() + await asyncio.sleep(latency_s) + end_ts[ctx.task_id] = time.perf_counter() + + return sleep + + +def _sync_sleep_factory( + latency_s: float, start_ts: dict[int, float], end_ts: dict[int, float] +) -> Callable[[task.ActivityContext, object], None]: + """Build a sync activity that sleeps. Records per-task start/end timestamps.""" + + def sleep(ctx: task.ActivityContext, _inp: object) -> None: + start_ts[ctx.task_id] = time.perf_counter() + time.sleep(latency_s) + end_ts[ctx.task_id] = time.perf_counter() + + return sleep + + +def _async_fetch_factory( + url: str, start_ts: dict[int, float], end_ts: dict[int, float] +) -> Callable[[task.ActivityContext, object], Awaitable[int]]: + """Build an async HTTP-fetch activity that mirrors a real user pattern.""" + + async def fetch(ctx: task.ActivityContext, _inp: object) -> int: + start_ts[ctx.task_id] = time.perf_counter() + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(url) + end_ts[ctx.task_id] = time.perf_counter() + return response.status_code + + return fetch + + +def _sync_fetch_factory( + url: str, start_ts: dict[int, float], end_ts: dict[int, float] +) -> Callable[[task.ActivityContext, object], int]: + def fetch(ctx: task.ActivityContext, _inp: object) -> int: + start_ts[ctx.task_id] = time.perf_counter() + with httpx.Client(timeout=30.0) as client: + response = client.get(url) + end_ts[ctx.task_id] = time.perf_counter() + return response.status_code + + return fetch + + +@asynccontextmanager +async def _slow_aiohttp_server(latency_s: float) -> AsyncIterator[str]: + """Local aiohttp server that returns JSON after ``latency_s`` seconds.""" + + async def handler(_request: web.Request) -> web.Response: + await asyncio.sleep(latency_s) + return web.json_response({'ok': True, 'latency_s': latency_s}) + + app = web.Application() + app.router.add_get('/', handler) + runner = web.AppRunner(app) + await runner.setup() + + listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + listener.bind(('127.0.0.1', 0)) + port = listener.getsockname()[1] + site = web.SockSite(runner, listener) + await site.start() + except BaseException: + listener.close() + raise + base_url = f'http://127.0.0.1:{port}/' + try: + yield base_url + finally: + await runner.cleanup() + + +# ============================================================================ +# Full-path harness — exercises _execute_activity_async / _execute_activity +# through _AsyncWorkerManager with a mock CompleteActivityTask stub. +# ============================================================================ + + +def _build_worker(options: ConcurrencyOptions) -> TaskHubGrpcWorker: + """Build a TaskHubGrpcWorker without calling start(). We only need its dispatch + code and registry; the gRPC stream is replaced by the mock stub. + """ + return TaskHubGrpcWorker( + host_address='in-process-mock', + concurrency_options=options, + ) + + +ActivityFactory = Callable[[dict[int, float], dict[int, float]], Callable[..., object]] + + +async def _run_full( + *, + name: str, + n_items: int, + semaphore_cap: int, + thread_pool_workers: int, + server_latency_s: float, + activity_kind: str, + activity_factory: ActivityFactory | None = None, + send_latency_s: float = 0.0, + notes: str = '', +) -> ScenarioMetrics: + """Submit ``n_items`` activities through the production dispatch path. + + Registers an async or sync activity on the worker's registry, builds real + ``pb.ActivityRequest`` protos, and submits ``_execute_activity_async`` / + ``_execute_activity`` to ``_AsyncWorkerManager``. The mock stub captures the + completion timestamp per task so we can compute end-to-end latency. + + ``activity_factory`` defaults to ``asyncio.sleep`` / ``time.sleep`` (synthetic + work). Pass a custom factory (e.g. ``_async_fetch_factory(url, ...)``) to + exercise real I/O instead. + """ + options = ConcurrencyOptions( + maximum_concurrent_activity_work_items=semaphore_cap, + maximum_concurrent_orchestration_work_items=semaphore_cap, + maximum_thread_pool_workers=thread_pool_workers, + ) + worker = _build_worker(options) + manager = worker._async_worker_manager + stub = _MockSidecarStub(send_latency_s=send_latency_s) + + start_ts: dict[int, float] = {} + end_ts: dict[int, float] = {} + activity_fn: Callable[..., object] + if activity_kind == 'async': + activity_fn = ( + activity_factory(start_ts, end_ts) + if activity_factory is not None + else _async_sleep_factory(server_latency_s, start_ts, end_ts) + ) + handler = worker._execute_activity_async + elif activity_kind == 'sync': + activity_fn = ( + activity_factory(start_ts, end_ts) + if activity_factory is not None + else _sync_sleep_factory(server_latency_s, start_ts, end_ts) + ) + handler = worker._execute_activity + else: + raise ValueError(f'unknown activity_kind: {activity_kind}') + + activity_name = f'bench_{activity_kind}' + worker._registry.add_named_activity(activity_name, activity_fn) + + baseline_rss_kb = _current_rss_kb() + sampler = _Sampler() + sampler_task = asyncio.create_task(sampler.run()) + worker_task = asyncio.create_task(manager.run()) + + # Wait for the manager to set up its activity queue, then attach the sampler. + while manager.activity_queue is None: + await asyncio.sleep(0) + sampler.watch_queue(manager.activity_queue) + + submit_ts: dict[int, float] = {} + submit_start = time.perf_counter() + for i in range(n_items): + req = _build_activity_request(activity_name, task_id=i, instance_id='bench') + submit_ts[i] = time.perf_counter() + manager.submit_activity(handler, req, stub, '') + + await manager.activity_queue.join() + wallclock_s = time.perf_counter() - submit_start + + manager._shutdown = True + sampler.stop() + await asyncio.gather(worker_task, sampler_task, return_exceptions=True) + manager.shutdown() + + e2e_samples: list[float] = [] + for task_id, t_submit in submit_ts.items(): + t_complete = stub.completions.get(task_id) + if t_complete is not None: + e2e_samples.append(t_complete - t_submit) + + throughput = len(e2e_samples) / wallclock_s if wallclock_s > 0 else 0.0 + return ScenarioMetrics( + name=name, + n_items=n_items, + semaphore_cap=semaphore_cap, + thread_pool_workers=thread_pool_workers, + server_latency_s=server_latency_s, + wallclock_s=wallclock_s, + throughput_per_s=throughput, + latency=LatencyStats.from_samples(e2e_samples), + peak_tasks=sampler.peak_tasks, + peak_queue_depth=sampler.peak_queue_depth, + peak_rss_delta_mb=max(0.0, (sampler.peak_rss_kb - baseline_rss_kb) / 1024.0), + notes=notes, + ) + + +# ============================================================================ +# Lite harness — used by the OOM safety test where we just need raw Task +# bookkeeping with no proto/stub overhead. +# ============================================================================ + + +def _make_activity_context(orchestration_id: str, task_id: int) -> task.ActivityContext: + return task.ActivityContext(orchestration_id, task_id, '', propagated_history=None) + + +async def _run_lite( + *, + name: str, + activity: Callable, + n_items: int, + semaphore_cap: int, + thread_pool_workers: int, + server_latency_s: float, + notes: str = '', +) -> ScenarioMetrics: + options = ConcurrencyOptions( + maximum_concurrent_activity_work_items=semaphore_cap, + maximum_concurrent_orchestration_work_items=semaphore_cap, + maximum_thread_pool_workers=thread_pool_workers, + ) + manager = _AsyncWorkerManager(options, logger=LOGGER) + + baseline_rss_kb = _current_rss_kb() + sampler = _Sampler() + sampler_task = asyncio.create_task(sampler.run()) + worker_task = asyncio.create_task(manager.run()) + + while manager.activity_queue is None: + await asyncio.sleep(0) + sampler.watch_queue(manager.activity_queue) + + for i in range(n_items): + ctx = _make_activity_context('bench', i) + manager.submit_activity(activity, ctx, None) + + start = time.perf_counter() + await manager.activity_queue.join() + wallclock_s = time.perf_counter() - start + + manager._shutdown = True + sampler.stop() + await asyncio.gather(worker_task, sampler_task, return_exceptions=True) + manager.shutdown() + + throughput = n_items / wallclock_s if wallclock_s > 0 else 0.0 + return ScenarioMetrics( + name=name, + n_items=n_items, + semaphore_cap=semaphore_cap, + thread_pool_workers=thread_pool_workers, + server_latency_s=server_latency_s, + wallclock_s=wallclock_s, + throughput_per_s=throughput, + latency=LatencyStats.from_samples([]), + peak_tasks=sampler.peak_tasks, + peak_queue_depth=sampler.peak_queue_depth, + peak_rss_delta_mb=max(0.0, (sampler.peak_rss_kb - baseline_rss_kb) / 1024.0), + notes=notes, + ) + + +# ============================================================================ +# Sustained-load harness — open-loop submission at a target rate for D seconds. +# ============================================================================ + + +@dataclass(slots=True) +class SustainedMetrics: + """Steady-state metrics for the sustained-load scenario.""" + + target_rate_per_s: float + duration_s: float + submitted: int + completed: int + wallclock_s: float + throughput_per_s: float + latency_overall: LatencyStats + latency_first_quarter: LatencyStats + latency_last_quarter: LatencyStats + peak_tasks: int + peak_queue_depth: int + peak_rss_delta_mb: float + + +async def _run_sustained( + *, + duration_s: float, + target_rate_per_s: float, + semaphore_cap: int, + thread_pool_workers: int, + server_latency_s: float, + activity_factory: ActivityFactory | None = None, +) -> SustainedMetrics: + """Continuously submit async activities for ``duration_s`` at a target rate. + + Records per-task submit/end timestamps so the harness can split tail latency + by quarter of the run, exposing drift. ``activity_factory`` defaults to + ``asyncio.sleep``; pass an HTTP fetch factory to exercise real I/O. + """ + options = ConcurrencyOptions( + maximum_concurrent_activity_work_items=semaphore_cap, + maximum_concurrent_orchestration_work_items=semaphore_cap, + maximum_thread_pool_workers=thread_pool_workers, + ) + worker = _build_worker(options) + manager = worker._async_worker_manager + stub = _MockSidecarStub() + start_ts: dict[int, float] = {} + end_ts: dict[int, float] = {} + activity_fn = ( + activity_factory(start_ts, end_ts) + if activity_factory is not None + else _async_sleep_factory(server_latency_s, start_ts, end_ts) + ) + activity_name = 'bench_sustained' + worker._registry.add_named_activity(activity_name, activity_fn) + + baseline_rss_kb = _current_rss_kb() + sampler = _Sampler() + sampler_task = asyncio.create_task(sampler.run()) + worker_task = asyncio.create_task(manager.run()) + + while manager.activity_queue is None: + await asyncio.sleep(0) + sampler.watch_queue(manager.activity_queue) + + submit_ts: dict[int, float] = {} + submit_interval = 1.0 / target_rate_per_s + + submitter_done = asyncio.Event() + submitted = 0 + bench_start = time.perf_counter() + + async def submitter() -> None: + nonlocal submitted + try: + next_submit = bench_start + while True: + now = time.perf_counter() + if now - bench_start >= duration_s: + return + if now >= next_submit: + req = _build_activity_request(activity_name, submitted, 'bench-sus') + submit_ts[submitted] = now + manager.submit_activity(worker._execute_activity_async, req, stub, '') + submitted += 1 + next_submit += submit_interval + continue + wait_s = max(0.0, next_submit - time.perf_counter()) + await asyncio.sleep(wait_s) + finally: + submitter_done.set() + + sub_task = asyncio.create_task(submitter()) + await sub_task + await manager.activity_queue.join() + wallclock_s = time.perf_counter() - bench_start + + manager._shutdown = True + sampler.stop() + await asyncio.gather(worker_task, sampler_task, return_exceptions=True) + manager.shutdown() + + e2e_samples_with_submit: list[tuple[float, float]] = [] + for task_id, t_submit in submit_ts.items(): + t_complete = stub.completions.get(task_id) + if t_complete is not None: + e2e_samples_with_submit.append((t_submit, t_complete - t_submit)) + + e2e_samples_with_submit.sort(key=lambda x: x[0]) + overall = [s for _, s in e2e_samples_with_submit] + quarter_size = max(1, len(overall) // 4) + first_quarter = overall[:quarter_size] + last_quarter = overall[-quarter_size:] + + return SustainedMetrics( + target_rate_per_s=target_rate_per_s, + duration_s=duration_s, + submitted=submitted, + completed=len(overall), + wallclock_s=wallclock_s, + throughput_per_s=len(overall) / wallclock_s if wallclock_s > 0 else 0.0, + latency_overall=LatencyStats.from_samples(overall), + latency_first_quarter=LatencyStats.from_samples(first_quarter), + latency_last_quarter=LatencyStats.from_samples(last_quarter), + peak_tasks=sampler.peak_tasks, + peak_queue_depth=sampler.peak_queue_depth, + peak_rss_delta_mb=max(0.0, (sampler.peak_rss_kb - baseline_rss_kb) / 1024.0), + ) + + +# ============================================================================ +# Scenario runners +# ============================================================================ + + +async def run_concurrency_win() -> list[ScenarioMetrics]: + """Issue #897 repro: async fan-out vs sync baseline at 100 x 1 s activities.""" + server_latency = 1.0 + n_items = 100 + async with _slow_aiohttp_server(server_latency) as url: + start_ts_async: dict[int, float] = {} + end_ts_async: dict[int, float] = {} + async_metrics = await _run_lite( + name='Async fan-out (issue #897 repro)', + activity=_async_fetch_factory(url, start_ts_async, end_ts_async), + n_items=n_items, + semaphore_cap=1000, + thread_pool_workers=8, + server_latency_s=server_latency, + notes='100 awaits run concurrently on the loop', + ) + start_ts_sync: dict[int, float] = {} + end_ts_sync: dict[int, float] = {} + sync_metrics = await _run_lite( + name='Sync baseline (pre-#897 behavior)', + activity=_sync_fetch_factory(url, start_ts_sync, end_ts_sync), + n_items=n_items, + semaphore_cap=1000, + thread_pool_workers=8, + server_latency_s=server_latency, + notes='gated by thread pool size, demonstrates the bug from #897', + ) + return [async_metrics, sync_metrics] + + +async def run_throughput_scaling() -> list[ScenarioMetrics]: + """Vary N at fixed 50 ms server latency. Capture throughput plateau.""" + server_latency = 0.05 + semaphore_cap = 5000 + thread_pool_workers = 16 + grid = [100, 500, 1000, 2500, 5000] + metrics: list[ScenarioMetrics] = [] + for n in grid: + m = await _run_full( + name=f'Throughput N={n}', + n_items=n, + semaphore_cap=semaphore_cap, + thread_pool_workers=thread_pool_workers, + server_latency_s=server_latency, + activity_kind='async', + notes='full _execute_activity_async path + mock CompleteActivityTask', + ) + metrics.append(m) + return metrics + + +async def run_semaphore_sensitivity() -> list[ScenarioMetrics]: + """Vary semaphore cap at fixed N=2500 / 50 ms. Shows cap-side trade-off.""" + server_latency = 0.05 + n_items = 2500 + thread_pool_workers = 16 + grid = [50, 100, 500, 1000, 5000] + metrics: list[ScenarioMetrics] = [] + for cap in grid: + m = await _run_full( + name=f'Sem cap={cap}', + n_items=n_items, + semaphore_cap=cap, + thread_pool_workers=thread_pool_workers, + server_latency_s=server_latency, + activity_kind='async', + notes=( + 'lower caps serialize the batch through fewer parallel slots' + if cap <= 100 + else 'caps above N x latency yield no further gain' + ), + ) + metrics.append(m) + return metrics + + +async def run_failure_threshold() -> list[ScenarioMetrics]: + """Hold cap=1000 / 50 ms and ramp N. The threshold is the first row where + p99 exceeds 2 x server_latency, marking the regime where queue wait + dominates work.""" + server_latency = 0.05 + semaphore_cap = 1000 + thread_pool_workers = 16 + grid = [500, 1000, 2500, 5000, 10000] + metrics: list[ScenarioMetrics] = [] + for n in grid: + m = await _run_full( + name=f'Threshold N={n} (cap={semaphore_cap})', + n_items=n, + semaphore_cap=semaphore_cap, + thread_pool_workers=thread_pool_workers, + server_latency_s=server_latency, + activity_kind='async', + notes='N > cap forces queue wait; p99 grows linearly', + ) + metrics.append(m) + return metrics + + +async def run_sustained_load(duration_s: float = SUSTAINED_DURATION_S) -> SustainedMetrics: + """Open-loop steady-state run at a target rate slightly below peak.""" + return await _run_sustained( + duration_s=duration_s, + target_rate_per_s=200.0, + semaphore_cap=1000, + thread_pool_workers=16, + server_latency_s=0.05, + ) + + +async def run_delivery_overhead() -> list[ScenarioMetrics]: + """Hold workload fixed and vary the simulated sidecar CompleteActivityTask + latency. Quantifies the response-delivery cost added by ``run_in_executor``. + """ + server_latency = 0.05 + n_items = 1000 + semaphore_cap = 1000 + thread_pool_workers = 16 + grid = [0.000, 0.001, 0.005, 0.010] + metrics: list[ScenarioMetrics] = [] + for send_latency in grid: + m = await _run_full( + name=f'Delivery latency={int(send_latency * 1000)}ms', + n_items=n_items, + semaphore_cap=semaphore_cap, + thread_pool_workers=thread_pool_workers, + server_latency_s=server_latency, + activity_kind='async', + send_latency_s=send_latency, + notes='asyncio default executor caps response delivery at min(32, cpu+4) workers', + ) + metrics.append(m) + return metrics + + +async def run_oom_safety() -> ScenarioMetrics: + """10 000 in-flight activities with a 1 000-cap semaphore. Validates that the + pile of Tasks parked on the semaphore does not blow up RSS. + """ + server_latency = 0.05 + start_ts: dict[int, float] = {} + end_ts: dict[int, float] = {} + return await _run_lite( + name='OOM safety (10k tasks, 1k semaphore)', + activity=_async_sleep_factory(server_latency, start_ts, end_ts), + n_items=10_000, + semaphore_cap=1000, + thread_pool_workers=8, + server_latency_s=server_latency, + notes='~9k tasks blocked on the semaphore. Peak RSS delta budget is 500 MB.', + ) + + +async def run_real_http_workload() -> list[ScenarioMetrics]: + """Production-shape scenarios driving real ``httpx.AsyncClient`` fetches. + + Mirrors ``examples/workflow/async_activities.py``: each activity opens a fresh + ``AsyncClient`` and GETs a local aiohttp endpoint that sleeps for 50 ms. Uses + the production dispatch path (``_execute_activity_async`` + mock stub) so the + measured latency is submit → response delivery, including TCP, HTTP, JSON + encode/decode, and ``run_in_executor`` for the response send. + + Async fetches at the same grid as the synthetic sweep let users compare + isolated SDK overhead to end-to-end behavior under real I/O. + """ + server_latency = 0.05 + grid = [(100, 1000, 16), (500, 1000, 16), (1000, 1000, 16), (2500, 5000, 16)] + metrics: list[ScenarioMetrics] = [] + async with _slow_aiohttp_server(server_latency) as url: + for n, cap, pool in grid: + async_metrics = await _run_full( + name=f'Real HTTP async N={n}', + n_items=n, + semaphore_cap=cap, + thread_pool_workers=pool, + server_latency_s=server_latency, + activity_kind='async', + activity_factory=lambda s, e, url=url: _async_fetch_factory(url, s, e), + notes='httpx.AsyncClient → aiohttp server (50 ms)', + ) + metrics.append(async_metrics) + # One sync row at N=100 to keep the comparison honest without making the + # bench painful — sync at higher N takes a long wall-clock. + sync_metrics = await _run_full( + name='Real HTTP sync N=100', + n_items=100, + semaphore_cap=1000, + thread_pool_workers=16, + server_latency_s=server_latency, + activity_kind='sync', + activity_factory=lambda s, e, url=url: _sync_fetch_factory(url, s, e), + notes='httpx.Client → aiohttp server, throttled by thread pool', + ) + metrics.append(sync_metrics) + return metrics + + +async def run_real_http_sustained( + duration_s: float = SUSTAINED_DURATION_S, +) -> SustainedMetrics: + """Sustained run mirroring real production: continuous httpx.AsyncClient fetches. + + Same shape as ``run_sustained_load`` but each activity is a real HTTP fetch + against a local aiohttp server, so the numbers reflect a workflow-heavy + deployment doing third-party API calls. + """ + server_latency = 0.05 + async with _slow_aiohttp_server(server_latency) as url: + return await _run_sustained( + duration_s=duration_s, + target_rate_per_s=100.0, + semaphore_cap=1000, + thread_pool_workers=16, + server_latency_s=server_latency, + activity_factory=lambda s, e: _async_fetch_factory(url, s, e), + ) + + +# ============================================================================ +# Report generation +# ============================================================================ + + +def _format_concurrency_table(metrics: list[ScenarioMetrics]) -> str: + header = ( + '| Scenario | N | Sem | Pool | Latency (s) | Wallclock (s) | Tput/s | p50 ms | p95 ms |' + ' p99 ms | Peak tasks | Peak queue | Peak RSS Δ (MB) | Notes |\n' + '| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | --- |\n' + ) + rows = [] + for m in metrics: + rows.append( + f'| {m.name} | {m.n_items} | {m.semaphore_cap} | {m.thread_pool_workers} |' + f' {m.server_latency_s:.3f} | {m.wallclock_s:.2f} | {m.throughput_per_s:.1f} |' + f' {m.latency.p50_ms:.1f} | {m.latency.p95_ms:.1f} | {m.latency.p99_ms:.1f} |' + f' {m.peak_tasks} | {m.peak_queue_depth} | {m.peak_rss_delta_mb:.1f} | {m.notes} |' + ) + return header + '\n'.join(rows) + + +def _format_legacy_table(metrics: list[ScenarioMetrics]) -> str: + """Compatibility table for scenarios without per-item latency (#897 repro, OOM).""" + header = ( + '| Scenario | N | Sem | Pool | Latency (s) | Wallclock (s) | Tput/s | Peak tasks |' + ' Peak queue | Peak RSS Δ (MB) | Notes |\n' + '| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | --- |\n' + ) + rows = [] + for m in metrics: + rows.append( + f'| {m.name} | {m.n_items} | {m.semaphore_cap} | {m.thread_pool_workers} |' + f' {m.server_latency_s:.3f} | {m.wallclock_s:.2f} | {m.throughput_per_s:.1f} |' + f' {m.peak_tasks} | {m.peak_queue_depth} | {m.peak_rss_delta_mb:.1f} | {m.notes} |' + ) + return header + '\n'.join(rows) + + +def _format_sustained_block(m: SustainedMetrics) -> str: + return ( + f'- **Target rate**: {m.target_rate_per_s:.0f}/s for {m.duration_s:.0f} s\n' + f'- **Submitted / completed**: {m.submitted} / {m.completed}\n' + f'- **Wallclock**: {m.wallclock_s:.2f} s (effective throughput' + f' {m.throughput_per_s:.1f}/s)\n' + f'- **Latency (overall)**: p50 {m.latency_overall.p50_ms:.1f} ms,' + f' p95 {m.latency_overall.p95_ms:.1f} ms, p99 {m.latency_overall.p99_ms:.1f} ms,' + f' max {m.latency_overall.max_ms:.1f} ms\n' + f'- **Latency (first 25%)**: p99 {m.latency_first_quarter.p99_ms:.1f} ms\n' + f'- **Latency (last 25%)**: p99 {m.latency_last_quarter.p99_ms:.1f} ms\n' + f'- **Peak tasks**: {m.peak_tasks}, peak queue depth: {m.peak_queue_depth},' + f' peak RSS Δ: {m.peak_rss_delta_mb:.1f} MB\n' + ) + + +def _find_failure_threshold(metrics: list[ScenarioMetrics], baseline_latency_ms: float) -> str: + threshold_factor = 2.0 + threshold_ms = baseline_latency_ms * threshold_factor + for m in metrics: + if m.latency.p99_ms > threshold_ms: + return ( + f'p99 first exceeds {threshold_factor:g}x server latency' + f' ({threshold_ms:.1f} ms) at **N={m.n_items}** with cap={m.semaphore_cap}' + f' (p99 = {m.latency.p99_ms:.1f} ms).' + ) + return ( + f'p99 stayed below {threshold_factor:g}x server latency across the full grid' + f' (max N={metrics[-1].n_items}); the SDK did not degrade in this run.' + ) + + +def _format_environment_block(env: RunEnvironment) -> str: + mem_str = f'{env.total_memory_gb:.1f} GB' if env.total_memory_gb > 0 else 'unknown' + return ( + '## Run environment\n' + '\n' + f'- **Timestamp**: {env.timestamp_utc}\n' + f'- **Git commit**: `{env.git_commit}`\n' + f'- **Python**: {env.python_implementation} {env.python_version}\n' + f'- **OS**: {env.os_release}\n' + f'- **Platform**: `{env.platform}`\n' + f'- **CPU**: {env.cpu_model} ({env.cpu_logical_cores} logical cores)\n' + f'- **Memory**: {mem_str}\n' + f'- **asyncio default executor**: max_workers =' + f' {env.asyncio_default_executor_workers}' + f' (`min(32, cpu_count + 4)`)\n' + f'- **CI environment**: {"yes" if env.is_ci else "no"}\n' + '\n' + '**Numbers from this report are specific to this machine.** Re-run the benchmark' + ' on your hardware before drawing conclusions; on a small CI runner or a busy' + ' workstation they will diverge. The shape of the curves (throughput plateau,' + ' p99 inflection, drift) is what to compare across machines.\n' + ) + + +def _write_results( + *, + env: RunEnvironment, + concurrency: list[ScenarioMetrics], + throughput: list[ScenarioMetrics], + semaphore: list[ScenarioMetrics], + threshold: list[ScenarioMetrics], + delivery: list[ScenarioMetrics], + sustained: SustainedMetrics, + oom: ScenarioMetrics, + real_http: list[ScenarioMetrics], + real_http_sustained: SustainedMetrics, +) -> None: + threshold_summary = _find_failure_threshold( + threshold, baseline_latency_ms=threshold[0].server_latency_s * 1000.0 + ) + body = [ + '# Async-activity load benchmark results', + '', + 'Generated by `bench_async_activities.py`. Re-run with:', + '', + '```bash', + 'uv run python ext/dapr-ext-workflow/benchmarks/bench_async_activities.py', + '```', + '', + _format_environment_block(env), + '', + 'Each scenario drives the production dispatch path' + ' (`TaskHubGrpcWorker._execute_activity_async`) through `_AsyncWorkerManager` against' + ' a mock `CompleteActivityTask` stub. End-to-end latency is measured from `submit_activity`' + ' to the mock stub receiving the response, so queue wait, semaphore acquisition,' + ' activity work, response build, and `run_in_executor` delivery are all included.', + '', + '## 1. Concurrency win (issue #897 repro)', + '', + 'Proves async activities run concurrently on the loop; the sync path is gated by the' + ' thread pool. This row reuses the original repro at 100 × 1 s HTTP fetches.', + '', + _format_legacy_table(concurrency), + '', + '## 2. Throughput scaling', + '', + 'Async fan-out at 50 ms server latency, semaphore cap 5000, thread pool 16. Throughput' + ' is reported as items completed per wallclock second; the sweep shows where the curve' + ' flattens.', + '', + _format_concurrency_table(throughput), + '', + '## 3. Semaphore-cap sensitivity', + '', + 'N=2500 async activities at 50 ms server latency. Cap below ~500 starves the loop and' + ' inflates queue wait. Above that, gains compress.', + '', + _format_concurrency_table(semaphore), + '', + '## 4. Failure threshold (queue-wait inflection)', + '', + 'Cap held at 1000, ramp N. Until N approaches cap, p99 stays close to server latency.' + ' Past it, queue wait dominates and p99 grows ~linearly with `N / cap`.', + '', + _format_concurrency_table(threshold), + '', + f'**Threshold**: {threshold_summary}', + '', + '## 5. Sidecar response delivery overhead', + '', + 'Mock `CompleteActivityTask` is given an artificial delay. Async responses go through' + " `loop.run_in_executor(None, ...)`, so they share asyncio's default executor" + f' (max `min(32, cpu_count + 4)`; on this run, `cpu_count={os.cpu_count() or 1}`).' + ' Delivery latency above ~5 ms × concurrency exceeds the default pool and serializes,' + ' inflating tail latency.', + '', + _format_concurrency_table(delivery), + '', + '## 6. Sustained load', + '', + _format_sustained_block(sustained), + '', + '## 7. Real HTTP workload (production shape)', + '', + 'Each activity opens a fresh `httpx.AsyncClient` and GETs a local aiohttp endpoint' + ' that sleeps 50 ms. Mirrors `examples/workflow/async_activities.py`. The sync row' + ' at N=100 shows the same workload throttled by the thread pool — directly comparable' + ' to the rest of the table.', + '', + _format_concurrency_table(real_http), + '', + '## 8. Real HTTP sustained load', + '', + 'Open-loop submission of real `httpx.AsyncClient` fetches at 100/s. Confirms steady' + ' state under genuine I/O, not synthetic sleep.', + '', + _format_sustained_block(real_http_sustained), + '', + '## 9. OOM safety', + '', + '10 000 in-flight async activities at 50 ms with a 1 000-cap semaphore. The' + ' ~9 000 Tasks parked on the semaphore are the design-discussion concern. Peak RSS' + ' delta stays well under the 500 MB budget, so the unbounded-pending-Task pattern is' + ' fine in practice.', + '', + _format_legacy_table([oom]), + '', + '## How to read this report', + '', + '- **Tput/s** is the closed-loop throughput (items completed / wallclock).' + ' For the sustained scenario it is the steady-state value over the full run.', + '- **p99 ms** is the end-to-end latency for the 99th-percentile item: time from' + ' `submit_activity` to the mock stub seeing the response.', + "- **Peak queue** is the maximum depth of the manager's `activity_queue` during the" + ' run. Non-zero peak queue means submission temporarily outran the semaphore.', + '- **Peak tasks** is the maximum number of live `asyncio.Task` objects in the process,' + ' which doubles as a sanity check on the unbounded-pending-Task pattern.', + '', + '## Operational guidance', + '', + 'See `ext/dapr-ext-workflow/docs/concurrency.md` for the full operational write-up,' + ' including sizing recommendations for `maximum_concurrent_activity_work_items`,' + ' `maximum_thread_pool_workers`, and the asyncio default-executor caveat.', + ] + RESULTS_PATH.write_text('\n'.join(body) + '\n', encoding='utf-8') + + +# ============================================================================ +# Budget assertions +# ============================================================================ + + +def _assert_budgets( + *, + concurrency: list[ScenarioMetrics], + throughput: list[ScenarioMetrics], + semaphore: list[ScenarioMetrics], + threshold: list[ScenarioMetrics], + delivery: list[ScenarioMetrics], + sustained: SustainedMetrics, + oom: ScenarioMetrics, + real_http: list[ScenarioMetrics], + real_http_sustained: SustainedMetrics, +) -> None: + """Pass criteria. Loud failure if a regression makes any of them false. + + Budgets are intentionally generous so CI doesn't flake; they catch order-of-magnitude + regressions, not micro-fluctuations. + """ + async_repro, sync_baseline = concurrency + # Issue #897: async must finish close to a single server-latency window. + assert async_repro.wallclock_s < async_repro.server_latency_s * 5, ( + f'Async fan-out took {async_repro.wallclock_s:.2f}s for' + f' {async_repro.n_items} × {async_repro.server_latency_s}s activities;' + f' async dispatch is not actually concurrent.' + ) + # Issue #897: sync baseline must be at least one extra latency window slower. + assert sync_baseline.wallclock_s > async_repro.wallclock_s + async_repro.server_latency_s, ( + f'Sync baseline ({sync_baseline.wallclock_s:.2f}s) was not at least one' + f' latency window slower than async ({async_repro.wallclock_s:.2f}s);' + f' the comparison is meaningless.' + ) + + # Throughput scaling: each larger N must be at least 80% as fast as the smallest; + # we tolerate the inevitable plateau but reject a collapse. + base_throughput = throughput[0].throughput_per_s + for m in throughput[1:]: + assert m.throughput_per_s >= base_throughput * 0.5, ( + f'Throughput collapsed at N={m.n_items}: {m.throughput_per_s:.1f}/s' + f' vs base {base_throughput:.1f}/s. The scaling curve regressed.' + ) + + # Semaphore sensitivity: the smallest cap must be at least 3x slower than the largest. + smallest_cap = semaphore[0] + largest_cap = semaphore[-1] + assert smallest_cap.wallclock_s > largest_cap.wallclock_s * 1.5, ( + f'Wallclock at cap={smallest_cap.semaphore_cap} ({smallest_cap.wallclock_s:.2f}s)' + f' was not meaningfully slower than at cap={largest_cap.semaphore_cap}' + f' ({largest_cap.wallclock_s:.2f}s). The semaphore is not gating concurrency.' + ) + + # Failure threshold: at N ≤ cap, p99 must be within 5x of server latency. + cap = threshold[0].semaphore_cap + server_latency_ms = threshold[0].server_latency_s * 1000.0 + for m in threshold: + if m.n_items <= cap: + assert m.latency.p99_ms <= server_latency_ms * 5, ( + f'p99 at N={m.n_items} (≤ cap={cap}) was {m.latency.p99_ms:.1f} ms,' + f' >5x server latency ({server_latency_ms:.1f} ms).' + f' The dispatch path has unexpected overhead.' + ) + + # Delivery overhead: zero-delay delivery must keep p99 < 200 ms at N=1000. + zero_delay = delivery[0] + assert zero_delay.latency.p99_ms < 200.0, ( + f'p99 with zero delivery delay was {zero_delay.latency.p99_ms:.1f} ms at N={zero_delay.n_items};' + f' the SDK adds more than 200 ms of overhead on top of the {zero_delay.server_latency_s * 1000:.0f}' + f' ms activity, which is too much.' + ) + + # Sustained: last-quarter p99 must not be more than 3x the first-quarter p99. + drift = sustained.latency_last_quarter.p99_ms + first = max(sustained.latency_first_quarter.p99_ms, 1.0) + assert drift <= first * 3.0, ( + f'Sustained tail latency drifted: first-quarter p99 = {first:.1f} ms,' + f' last-quarter p99 = {drift:.1f} ms.' + f' Steady state is degrading over the run.' + ) + + # OOM safety budgets — unchanged from the original benchmark. + assert oom.peak_tasks <= int(oom.n_items * 1.5), ( + f'Peak Tasks ({oom.peak_tasks}) exceeded 1.5 × N={oom.n_items}.' + f' The per-item Task accounting is inflated.' + ) + assert oom.peak_rss_delta_mb < 500.0, ( + f'Peak RSS delta {oom.peak_rss_delta_mb:.1f} MB exceeded the 500 MB budget.' + f' The unbounded pending-Task pattern needs an asyncio.Queue cap.' + ) + + # Real-HTTP workload: async path must beat the sync path's wallclock + # decisively. At small N, per-call ``httpx.AsyncClient(...)`` setup masks the + # win, so we compare the peak async throughput across the sweep against the + # sync row. + *real_async_rows, real_sync = real_http + peak_async_throughput = max(m.throughput_per_s for m in real_async_rows) + assert peak_async_throughput > real_sync.throughput_per_s * 1.25, ( + f'Real-HTTP peak async throughput ({peak_async_throughput:.1f}/s) was not' + f' >1.25x sync N={real_sync.n_items} ({real_sync.throughput_per_s:.1f}/s).' + f' The async path lost its concurrency advantage under real I/O.' + ) + # And at the largest async N, p99 must scale with the batch — not with the + # entire history of the run. + largest_async = max(real_async_rows, key=lambda m: m.n_items) + assert largest_async.latency.p99_ms < largest_async.wallclock_s * 1000.0 * 1.2, ( + f'Real-HTTP async N={largest_async.n_items}: p99 {largest_async.latency.p99_ms:.0f} ms' + f' exceeds 1.2x the wallclock ({largest_async.wallclock_s * 1000:.0f} ms),' + f' which means some items are blocked beyond the entire batch — pathological.' + ) + + # Real-HTTP sustained: same drift guard as the synthetic sustained run, + # but with slightly more slack because httpx connection churn adds jitter. + first_http = max(real_http_sustained.latency_first_quarter.p99_ms, 1.0) + last_http = real_http_sustained.latency_last_quarter.p99_ms + assert last_http <= first_http * 4.0, ( + f'Real-HTTP sustained tail latency drifted: first-quarter p99 = {first_http:.1f} ms,' + f' last-quarter p99 = {last_http:.1f} ms. Steady state regressed during the run.' + ) + + +# ============================================================================ +# Real-sidecar opt-in scenario +# ============================================================================ + + +async def run_with_real_sidecar() -> None: + """End-to-end scenario against a real Dapr sidecar. + + Skipped unless ``DAPR_BENCH_WITH_SIDECAR=1``. Requires the script to be run under + ``dapr run`` with a workflow-enabled state store, e.g.:: + + dapr run --app-id bench --app-protocol grpc --dapr-grpc-port 50001 \\ + -- env DAPR_BENCH_WITH_SIDECAR=1 \\ + uv run python ext/dapr-ext-workflow/benchmarks/bench_async_activities.py + """ + import dapr.ext.workflow as wf + + n_items = 50 + server_latency_s = 0.5 + wfr = wf.WorkflowRuntime() + + @wfr.workflow(name='bench_real_workflow') + def bench_workflow(ctx: wf.DaprWorkflowContext, payload: list[int]): + tasks = [ctx.call_activity(bench_async_activity, input=i) for i in payload] + return (yield wf.when_all(tasks)) + + @wfr.activity(name='bench_async_activity') + async def bench_async_activity(_ctx: wf.WorkflowActivityContext, _i: int) -> int: + await asyncio.sleep(server_latency_s) + return _i + + wfr.start() + time.sleep(2) + try: + client = wf.DaprWorkflowClient() + instance_id = client.schedule_new_workflow( + workflow=bench_workflow, input=list(range(n_items)) + ) + start = time.perf_counter() + state = client.wait_for_workflow_completion(instance_id, timeout_in_seconds=120) + wallclock = time.perf_counter() - start + assert state is not None, 'workflow timed out against real sidecar' + print( + f'[real-sidecar] {n_items} async activities × {server_latency_s}s' + f' completed in {wallclock:.2f}s (status {state.runtime_status.name})' + ) + finally: + wfr.shutdown() + + +# ============================================================================ +# Entry point +# ============================================================================ + + +async def main() -> None: + logging.basicConfig(level=logging.WARNING) + + env = RunEnvironment.capture() + print( + f'[env] {env.cpu_model} | {env.cpu_logical_cores} cores |' + f' {env.total_memory_gb:.1f} GB | {env.python_implementation} {env.python_version}', + flush=True, + ) + + print('[1/9] concurrency win (issue #897 repro)...', flush=True) + concurrency = await run_concurrency_win() + + print('[2/9] throughput scaling sweep...', flush=True) + throughput = await run_throughput_scaling() + + print('[3/9] semaphore-cap sensitivity sweep...', flush=True) + semaphore = await run_semaphore_sensitivity() + + print('[4/9] failure-threshold ramp...', flush=True) + threshold = await run_failure_threshold() + + print('[5/9] sidecar-delivery overhead sweep...', flush=True) + delivery = await run_delivery_overhead() + + print(f'[6/9] sustained load ({SUSTAINED_DURATION_S:.0f}s)...', flush=True) + sustained = await run_sustained_load() + + print('[7/9] real-HTTP workload sweep...', flush=True) + real_http = await run_real_http_workload() + + real_http_duration = min(SUSTAINED_DURATION_S, 60.0) + print(f'[8/9] real-HTTP sustained load ({real_http_duration:.0f}s)...', flush=True) + real_http_sustained = await run_real_http_sustained(duration_s=real_http_duration) + + print('[9/9] OOM safety...', flush=True) + oom = await run_oom_safety() + + _write_results( + env=env, + concurrency=concurrency, + throughput=throughput, + semaphore=semaphore, + threshold=threshold, + delivery=delivery, + sustained=sustained, + oom=oom, + real_http=real_http, + real_http_sustained=real_http_sustained, + ) + print('\n=== concurrency win ===') + print(_format_legacy_table(concurrency)) + print('\n=== throughput scaling ===') + print(_format_concurrency_table(throughput)) + print('\n=== semaphore sensitivity ===') + print(_format_concurrency_table(semaphore)) + print('\n=== failure threshold ===') + print(_format_concurrency_table(threshold)) + print('\n=== sidecar delivery overhead ===') + print(_format_concurrency_table(delivery)) + print('\n=== sustained load (synthetic) ===') + print(_format_sustained_block(sustained)) + print('\n=== real HTTP workload ===') + print(_format_concurrency_table(real_http)) + print('\n=== real HTTP sustained load ===') + print(_format_sustained_block(real_http_sustained)) + print('\n=== OOM safety ===') + print(_format_legacy_table([oom])) + print(f'\nWrote {RESULTS_PATH.relative_to(Path.cwd())}') + + _assert_budgets( + concurrency=concurrency, + throughput=throughput, + semaphore=semaphore, + threshold=threshold, + delivery=delivery, + sustained=sustained, + oom=oom, + real_http=real_http, + real_http_sustained=real_http_sustained, + ) + + if os.environ.get('DAPR_BENCH_WITH_SIDECAR') == '1': + print('\n[opt-in] running real-sidecar scenario...') + await run_with_real_sidecar() + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/worker.py b/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/worker.py index 84663064f..d66698a8f 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/worker.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/_durabletask/worker.py @@ -65,10 +65,10 @@ def _log_all_threads(logger: logging.Logger, context: str = ''): class ConcurrencyOptions: - """Configuration options for controlling concurrency of different work item types and the thread pool size. + """Concurrency limits for the worker. - This class provides fine-grained control over concurrent processing limits for - activities, orchestrations and the thread pool size. + ``maximum_thread_pool_workers`` only matters for sync activities. + Async activities run as coroutines on the event loop. """ def __init__( @@ -80,11 +80,12 @@ def __init__( """Initialize concurrency options. Args: - maximum_concurrent_activity_work_items: Maximum number of activity work items - that can be processed concurrently. Defaults to 100 * processor_count. - maximum_concurrent_orchestration_work_items: Maximum number of orchestration work items - that can be processed concurrently. Defaults to 100 * processor_count. - maximum_thread_pool_workers: Maximum number of thread pool workers to use. + maximum_concurrent_activity_work_items: Cap on concurrent activity work items. + Defaults to ``100 * cpu_count``. + maximum_concurrent_orchestration_work_items: Cap on concurrent orchestration work + items. Defaults to ``100 * cpu_count``. + maximum_thread_pool_workers: Size of the thread pool used to run sync activities. + Async activities do not use this pool. Defaults to ``cpu_count + 4``. """ processor_count = os.cpu_count() or 1 default_concurrency = 100 * processor_count @@ -658,8 +659,19 @@ def stream_reader(): work_item.completionToken, ) elif work_item.HasField('activityRequest'): + # Async user activities run on the event loop. Sync ones fall through + # to the thread pool via _execute_activity. + activity_fn = self._registry.get_activity( + work_item.activityRequest.name + ) + activity_handler = ( + self._execute_activity_async + if activity_fn is not None + and inspect.iscoroutinefunction(activity_fn) + else self._execute_activity + ) self._async_worker_manager.submit_activity( - self._execute_activity, + activity_handler, work_item.activityRequest, stub, work_item.completionToken, @@ -807,10 +819,15 @@ def _deferred_close(): self._channel_cleanup_threads.append(thread) def stop(self): - """Stops the worker and waits for any pending work items to complete.""" + """Stop the worker and tear down its resources. + + Idempotent and safe to call before ``start()`` because the thread pool + exists from construction. + """ # Guards on _runLoop rather than _is_running so stop() can unblock a start() # that is still waiting for the work item stream to be established. if self._runLoop is None: + self._async_worker_manager.shutdown() return self._logger.info('Stopping gRPC worker...') @@ -964,6 +981,102 @@ def _execute_orchestrator( f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}" ) + def _activity_span(self, req: pb.ActivityRequest, instance_id: str): + """Return an OTel span context manager, or a nullcontext if OTel is not installed.""" + if otel_tracer is None: + return contextlib.nullcontext() + return otel_tracer.start_as_current_span( + name=f'activity: {req.name}', + context=otel_propagator.extract( + carrier={'traceparent': req.parentTraceContext.traceParent} + ), + attributes={ + 'dapr.ext.workflow._durabletask.task.instance_id': instance_id, + 'dapr.ext.workflow._durabletask.task.id': req.taskId, + 'dapr.ext.workflow._durabletask.activity.name': req.name, + }, + ) + + def _propagated_history(self, req: pb.ActivityRequest) -> PropagatedHistory | None: + if req.HasField('propagatedHistory'): + return PropagatedHistory.from_proto(req.propagatedHistory) + return None + + def _build_activity_result_response( + self, + req: pb.ActivityRequest, + instance_id: str, + result: str | None, + completion_token, + ) -> pb.ActivityResponse: + return pb.ActivityResponse( + instanceId=instance_id, + taskId=req.taskId, + result=ph.get_string_value(result), + completionToken=completion_token, + ) + + def _build_activity_failure_response( + self, + req: pb.ActivityRequest, + instance_id: str, + ex: BaseException, + completion_token, + ) -> pb.ActivityResponse: + return pb.ActivityResponse( + instanceId=instance_id, + taskId=req.taskId, + failureDetails=ph.new_failure_details(ex), + completionToken=completion_token, + ) + + def _send_activity_response( + self, + req: pb.ActivityRequest, + stub: stubs.TaskHubSidecarServiceStub, + res: pb.ActivityResponse, + completion_token, + instance_id: str, + ): + """Send an activity response, falling back to a failure response when the + result is too large to deliver.""" + try: + stub.CompleteActivityTask(res) + except grpc.RpcError as rpc_error: # type: ignore + if _is_message_too_large(rpc_error): + # Result is too large to deliver - fail the activity immediately. + # This can only be fixed with infrastructure changes (increasing gRPC max message size). + self._logger.error( + f"Activity '{req.name}#{req.taskId}' result is too large to deliver " + f'(RESOURCE_EXHAUSTED). Failing the activity task: {rpc_error.details()}' + ) + oversize_error = RuntimeError( + f'Activity result exceeds gRPC max message size: {rpc_error.details()}' + ) + failure_res = self._build_activity_failure_response( + req, instance_id, oversize_error, completion_token + ) + try: + stub.CompleteActivityTask(failure_res) + except Exception as ex: + self._logger.exception( + f"Failed to deliver activity failure response for '{req.name}#{req.taskId}' " + f"of orchestration ID '{instance_id}': {ex}" + ) + else: + self._handle_grpc_execution_error(rpc_error, 'activity') + except ValueError: + # gRPC raises ValueError when the underlying channel has been closed (e.g. during reconnection). + self._logger.debug( + f"Could not deliver activity response for '{req.name}#{req.taskId}' of " + f"orchestration ID '{instance_id}': channel was closed (likely due to " + f'reconnection). The sidecar will re-dispatch this work item.' + ) + except Exception as ex: + self._logger.exception( + f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}" + ) + def _execute_activity( self, req: pb.ActivityRequest, @@ -971,91 +1084,69 @@ def _execute_activity( completionToken, ): instance_id = req.workflowInstance.instanceId - - if otel_tracer is not None: - span_context = otel_tracer.start_as_current_span( - name=f'activity: {req.name}', - context=otel_propagator.extract( - carrier={'traceparent': req.parentTraceContext.traceParent} - ), - attributes={ - 'dapr.ext.workflow._durabletask.task.instance_id': instance_id, - 'dapr.ext.workflow._durabletask.task.id': req.taskId, - 'dapr.ext.workflow._durabletask.activity.name': req.name, - }, - ) - else: - span_context = contextlib.nullcontext() - - with span_context: + with self._activity_span(req, instance_id): try: executor = _ActivityExecutor(self._registry, self._logger) - propagated = ( - PropagatedHistory.from_proto(req.propagatedHistory) - if req.HasField('propagatedHistory') - else None - ) result = executor.execute( instance_id, req.name, req.taskId, req.input.value, req.taskExecutionId, - propagated_history=propagated, + propagated_history=self._propagated_history(req), ) - res = pb.ActivityResponse( - instanceId=instance_id, - taskId=req.taskId, - result=ph.get_string_value(result), - completionToken=completionToken, + res = self._build_activity_result_response( + req, instance_id, result, completionToken ) except Exception as ex: - res = pb.ActivityResponse( - instanceId=instance_id, - taskId=req.taskId, - failureDetails=ph.new_failure_details(ex), - completionToken=completionToken, - ) + res = self._build_activity_failure_response(req, instance_id, ex, completionToken) + self._send_activity_response(req, stub, res, completionToken, instance_id) + async def _execute_activity_async( + self, + req: pb.ActivityRequest, + stub: stubs.TaskHubSidecarServiceStub, + completionToken, + ): + """Run an async activity on the event loop and send its result to the sidecar. + The gRPC send goes through ``run_in_executor`` to avoid blocking the loop. + """ + instance_id = req.workflowInstance.instanceId + with self._activity_span(req, instance_id): try: - stub.CompleteActivityTask(res) - except grpc.RpcError as rpc_error: # type: ignore - if _is_message_too_large(rpc_error): - # Result is too large to deliver - fail the activity immediately. - # This can only be fixed with infrastructure changes (increasing gRPC max message size). - self._logger.error( - f"Activity '{req.name}#{req.taskId}' result is too large to deliver " - f'(RESOURCE_EXHAUSTED). Failing the activity task: {rpc_error.details()}' - ) - failure_res = pb.ActivityResponse( - instanceId=instance_id, - taskId=req.taskId, - failureDetails=ph.new_failure_details( - RuntimeError( - f'Activity result exceeds gRPC max message size: {rpc_error.details()}' - ) - ), - completionToken=completionToken, - ) - try: - stub.CompleteActivityTask(failure_res) - except Exception as ex: - self._logger.exception( - f"Failed to deliver activity failure response for '{req.name}#{req.taskId}' " - f"of orchestration ID '{instance_id}': {ex}" - ) - else: - self._handle_grpc_execution_error(rpc_error, 'activity') - except ValueError: - # gRPC raises ValueError when the underlying channel has been closed (e.g. during reconnection). - self._logger.debug( - f"Could not deliver activity response for '{req.name}#{req.taskId}' of " - f"orchestration ID '{instance_id}': channel was closed (likely due to " - f'reconnection). The sidecar will re-dispatch this work item.' + executor = _ActivityExecutor(self._registry, self._logger) + result = await executor.execute_async( + instance_id, + req.name, + req.taskId, + req.input.value, + req.taskExecutionId, + propagated_history=self._propagated_history(req), + ) + res = self._build_activity_result_response( + req, instance_id, result, completionToken ) + except asyncio.CancelledError: + raise except Exception as ex: - self._logger.exception( - f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}" + res = self._build_activity_failure_response(req, instance_id, ex, completionToken) + loop = asyncio.get_running_loop() + try: + await loop.run_in_executor( + None, + self._send_activity_response, + req, + stub, + res, + completionToken, + instance_id, + ) + except RuntimeError as exc: + # Default executor shut down. Raising would only leak a + # 'Task exception was never retrieved' since nobody awaits this task. + self._logger.warning( + f"Could not deliver activity response for '{req.name}#{req.taskId}': " + f'{exc}. The sidecar will re-dispatch this work item.' ) @@ -2002,23 +2093,22 @@ def __init__(self, registry: _Registry, logger: logging.Logger): self._registry = registry self._logger = logger - def execute( + def _resolve( self, orchestration_id: str, name: str, task_id: int, - encoded_input: Optional[str], - task_execution_id: str = '', - propagated_history: Optional[PropagatedHistory] = None, - ) -> Optional[str]: - """Executes an activity function and returns the serialized result, if any.""" + encoded_input: str | None, + task_execution_id: str, + propagated_history: PropagatedHistory | None, + ) -> tuple[task.Activity, task.ActivityContext, Any]: + """Look up the registered activity and build its ``(fn, ctx, input)`` call args.""" self._logger.debug(f"{orchestration_id}/{task_id}: Executing activity '{name}'...") fn = self._registry.get_activity(name) if not fn: raise ActivityNotRegisteredError( f"Activity function named '{name}' was not registered!" ) - activity_input = shared.from_json(encoded_input) if encoded_input else None ctx = task.ActivityContext( orchestration_id, @@ -2026,10 +2116,11 @@ def execute( task_execution_id, propagated_history=propagated_history, ) + return fn, ctx, activity_input - # Execute the activity function - activity_output = fn(ctx, activity_input) - + def _encode_output( + self, orchestration_id: str, name: str, task_id: int, activity_output: Any + ) -> str | None: encoded_output = shared.to_json(activity_output) if activity_output is not None else None chars = len(encoded_output) if encoded_output else 0 self._logger.debug( @@ -2037,6 +2128,58 @@ def execute( ) return encoded_output + def execute( + self, + orchestration_id: str, + name: str, + task_id: int, + encoded_input: Optional[str], + task_execution_id: str = '', + propagated_history: Optional[PropagatedHistory] = None, + ) -> Optional[str]: + """Run a sync activity function and return the serialized result, if any. + + Raises ``RuntimeError`` if the activity returns a coroutine, which happens when + ``_is_async_callable`` fails to detect an async callable at registration. + """ + fn, ctx, activity_input = self._resolve( + orchestration_id, + name, + task_id, + encoded_input, + task_execution_id, + propagated_history, + ) + activity_output = fn(ctx, activity_input) + if inspect.iscoroutine(activity_output): + activity_output.close() + raise RuntimeError( + f"Activity '{name}' returned a coroutine on the sync path. " + f'Declare it with ``async def`` so the worker dispatches it on the event loop.' + ) + return self._encode_output(orchestration_id, name, task_id, activity_output) + + async def execute_async( + self, + orchestration_id: str, + name: str, + task_id: int, + encoded_input: str | None, + task_execution_id: str = '', + propagated_history: PropagatedHistory | None = None, + ) -> str | None: + """Await a coroutine activity function and return the serialized result, if any.""" + fn, ctx, activity_input = self._resolve( + orchestration_id, + name, + task_id, + encoded_input, + task_execution_id, + propagated_history, + ) + activity_output = await fn(ctx, activity_input) + return self._encode_output(orchestration_id, name, task_id, activity_output) + def _get_non_determinism_error(task_id: int, action_name: str) -> task.NonDeterminismError: return task.NonDeterminismError( diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py index f33622a15..b829c6fe7 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py @@ -13,10 +13,11 @@ limitations under the License. """ +import functools import inspect import time from functools import wraps -from typing import Optional, Sequence, TypeVar, Union +from typing import Any, Awaitable, Callable, Optional, Sequence, TypeVar, Union import grpc from dapr.ext.workflow._durabletask import task, worker @@ -45,6 +46,73 @@ grpc.StreamStreamClientInterceptor, ] +# Durabletask returns decoded JSON, so we type the input as ``object | None`` and let the +# wrapper narrow it via the activity's declared model. +SyncActivityWrapper = Callable[[task.ActivityContext, object | None], object] +AsyncActivityWrapper = Callable[[task.ActivityContext, object | None], Awaitable[object]] +ActivityWrapper = SyncActivityWrapper | AsyncActivityWrapper + + +def _is_async_callable(fn: Any) -> bool: + """Return True if ``fn`` is async. Catches ``functools.partial`` of coroutines, + sync decorators that wrap async functions, and callable instances with ``async __call__``. + """ + candidate = fn + while isinstance(candidate, functools.partial): + candidate = candidate.func + candidate = inspect.unwrap(candidate) if callable(candidate) else candidate + if inspect.iscoroutinefunction(candidate): + return True + if not inspect.isfunction(candidate) and hasattr(candidate, '__call__'): + return inspect.iscoroutinefunction(candidate.__call__) + return False + + +def _coerce_activity_input(inp: object | None, input_model: type | None) -> object | None: + """Coerce the raw input to the activity's declared model, if it has one.""" + if inp is None or input_model is None or isinstance(inp, input_model): + return inp + return _model_protocol.coerce_to_model(inp, input_model) + + +def _make_activity_wrapper(fn: Activity, logger: Logger) -> ActivityWrapper: + """Wrap a user activity for the durabletask worker. + + Returns: + An ``async def`` wrapper for async activities, a plain ``def`` for sync. + """ + accepts_input, input_model = _model_protocol.resolve_input(fn) + + if _is_async_callable(fn): + + async def async_activity_wrapper( + ctx: task.ActivityContext, inp: object | None = None + ) -> object: + activity_id = getattr(ctx, 'task_id', 'unknown') + try: + wf_ctx = WorkflowActivityContext(ctx) + if not accepts_input: + return await fn(wf_ctx) + return await fn(wf_ctx, _coerce_activity_input(inp, input_model)) + except Exception as e: + logger.warning(f'Activity execution failed - task_id: {activity_id}, error: {e}') + raise + + return async_activity_wrapper + + def sync_activity_wrapper(ctx: task.ActivityContext, inp: object | None = None) -> object: + activity_id = getattr(ctx, 'task_id', 'unknown') + try: + wf_ctx = WorkflowActivityContext(ctx) + if not accepts_input: + return fn(wf_ctx) + return fn(wf_ctx, _coerce_activity_input(inp, input_model)) + except Exception as e: + logger.warning(f'Activity execution failed - task_id: {activity_id}, error: {e}') + raise + + return sync_activity_wrapper + class WorkflowRuntime: """WorkflowRuntime is the entry point for registering workflows and activities.""" @@ -180,36 +248,14 @@ def orchestrationWrapper(ctx: task.OrchestrationContext, inp: Optional[TInput] = fn.__dict__['_workflow_registered'] = True def register_activity(self, fn: Activity, *, name: Optional[str] = None): - """Registers a workflow activity as a function that takes - a specified input type and returns a specified output type. + """Register a workflow activity. ``def`` and ``async def`` are both supported. + Async activities run on the worker's event loop. Sync activities run in the + thread pool sized by ``maximum_thread_pool_workers``. """ effective_name = name or fn.__name__ self._logger.info(f"Registering activity '{effective_name}' with runtime") - accepts_input, input_model = _model_protocol.resolve_input(fn) - - def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): - """Responsible to call Activity function in activityWrapper""" - activity_id = getattr(ctx, 'task_id', 'unknown') - - try: - wfActivityContext = WorkflowActivityContext(ctx) - if not accepts_input: - result = fn(wfActivityContext) - else: - if ( - (inp is not None) - and (input_model is not None) - and not isinstance(inp, input_model) - ): - inp = _model_protocol.coerce_to_model(inp, input_model) - result = fn(wfActivityContext, inp) - return result - except Exception as e: - self._logger.warning( - f'Activity execution failed - task_id: {activity_id}, error: {e}' - ) - raise + activity_wrapper = _make_activity_wrapper(fn, self._logger) if hasattr(fn, '_activity_registered'): # whenever an activity is registered, it has a _dapr_alternate_name attribute @@ -224,7 +270,7 @@ def activityWrapper(ctx: task.ActivityContext, inp: Optional[TInput] = None): fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ self.__worker._registry.add_named_activity( - fn.__dict__['_dapr_alternate_name'], activityWrapper + fn.__dict__['_dapr_alternate_name'], activity_wrapper ) fn.__dict__['_activity_registered'] = True @@ -446,16 +492,23 @@ def add(ctx, x: int, y: int) -> int: the workflow runtime. Defaults to None. """ - def wrapper(fn: any): + def wrapper(fn: Any): if hasattr(fn, '_dapr_alternate_name'): raise ValueError( f'Function {fn.__name__} already has an alternate name {fn._dapr_alternate_name}' ) fn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ - @wraps(fn) - def innerfn(*args, **kwargs): - return fn(*args, **kwargs) + if _is_async_callable(fn): + + @wraps(fn) + async def innerfn(*args, **kwargs): + return await fn(*args, **kwargs) + else: + + @wraps(fn) + def innerfn(*args, **kwargs): + return fn(*args, **kwargs) innerfn.__dict__['_dapr_alternate_name'] = name if name else fn.__name__ innerfn.__signature__ = inspect.signature(fn) diff --git a/ext/dapr-ext-workflow/docs/concurrency.md b/ext/dapr-ext-workflow/docs/concurrency.md new file mode 100644 index 000000000..6c264e6e3 --- /dev/null +++ b/ext/dapr-ext-workflow/docs/concurrency.md @@ -0,0 +1,83 @@ +# Concurrency configuration for `dapr-ext-workflow` + +Sizing notes for the worker's concurrency knobs. Numbers come from +`benchmarks/bench_async_activities.py`. Re-run it on local hardware to validate. + +## Knobs + +| Setting | Default | Effect | +| --- | --- | --- | +| `maximum_concurrent_activity_work_items` | `100 × cpu_count` | Async semaphore cap on in-flight activity work items. | +| `maximum_concurrent_orchestration_work_items` | `100 × cpu_count` | Same, for orchestrations. | +| `maximum_thread_pool_workers` | `cpu_count + 4` | Thread pool size for **sync** activities. Async activities run as coroutines on the event loop and never enter this pool. | + +A `def` activity consumes a semaphore slot **and** a thread pool worker. An +`async def` activity consumes only a semaphore slot. + +## Sizing the activity cap + +The cap is the lever for throughput and queue wait. Throughput plateaus around +`cap ≈ peak_in_flight`. Past the cap, queue wait grows linearly. The benchmark's +failure-threshold sweep shows the inflection point clearly. Rule of thumb: set +the cap to ~2x the expected steady-state in-flight count to absorb bursts. + +If activities call a downstream with a hard concurrency limit (e.g. a database +with a 100-connection pool), set the cap below that limit so it doubles as +backpressure. + +## Sizing the thread pool + +Two distinct uses of threads exist. + +**Sync activity execution.** Each `def` activity holds one thread for its +duration. Size to peak concurrent sync-activity count. + +**Async response delivery.** Each async activity, on completion, schedules +`stub.CompleteActivityTask` via `loop.run_in_executor(None, ...)`. That uses +asyncio's **default executor**, which is process-wide and sized to +`min(32, cpu_count + 4)`. It is *not* `maximum_thread_pool_workers`. + +If the sidecar takes >5 ms to acknowledge and the worker runs >30 concurrent +async activities, response delivery serializes through the default executor and +tail latency inflates. Install a larger default executor before starting: + +```python +import asyncio +from concurrent.futures import ThreadPoolExecutor + +asyncio.get_event_loop().set_default_executor(ThreadPoolExecutor(max_workers=200)) +``` + +This goes away when the worker migrates to `grpc.aio`. Until then, the default +executor is a separate knob from `maximum_thread_pool_workers`. + +## Sharing httpx clients + +The pattern in `examples/workflow/async_activities.py` opens a fresh +`httpx.AsyncClient` per activity. Correct for most workloads, but each call pays +TCP + TLS setup, and throughput plateaus around a few hundred req/s. + +For higher throughput, share a single client across activities: + +```python +_shared_client: httpx.AsyncClient | None = None + +def _get_client() -> httpx.AsyncClient: + global _shared_client + if _shared_client is None: + _shared_client = httpx.AsyncClient(timeout=30.0) + return _shared_client +``` + +The caller owns closing it during worker shutdown. For activities that hit many +hosts or need per-call timeout isolation, stick with per-call clients. + +## Re-running the benchmark + +```bash +uv run python ext/dapr-ext-workflow/benchmarks/bench_async_activities.py +``` + +Override the 120 s sustained run with `DAPR_BENCH_SUSTAINED_SECONDS=30` for a +faster local check. Set `DAPR_BENCH_WITH_SIDECAR=1` to exercise the end-to-end +path against a real sidecar. Results land in `RESULTS.md`. diff --git a/ext/dapr-ext-workflow/tests/durabletask/test_activity_dispatch_routing.py b/ext/dapr-ext-workflow/tests/durabletask/test_activity_dispatch_routing.py new file mode 100644 index 000000000..356603b39 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/durabletask/test_activity_dispatch_routing.py @@ -0,0 +1,90 @@ +# Copyright 2026 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contract tests for the activity dispatch handlers on ``TaskHubGrpcWorker``. + +The work-item dispatcher at the top of ``worker.py``'s gRPC loop selects between +``_execute_activity`` (sync, runs in the thread pool) and ``_execute_activity_async`` +(coroutine, awaited on the event loop) using ``inspect.iscoroutinefunction(handler)`` +via ``_AsyncWorkerManager._run_func``. These tests pin the async-ness of each handler so +the dispatch routing stays correct. +""" + +import asyncio +import inspect +import logging +import threading +from typing import Iterator + +import pytest +from dapr.ext.workflow._durabletask.worker import ( + ConcurrencyOptions, + TaskHubGrpcWorker, + _AsyncWorkerManager, +) + + +@pytest.fixture +def worker() -> Iterator[TaskHubGrpcWorker]: + instance = TaskHubGrpcWorker() + try: + yield instance + finally: + # Tears down the manager's ThreadPoolExecutor so each test doesn't leak threads. + instance.stop() + + +@pytest.fixture +def manager() -> Iterator[_AsyncWorkerManager]: + instance = _AsyncWorkerManager(ConcurrencyOptions(), logger=logging.getLogger()) + try: + yield instance + finally: + instance.shutdown() + + +def test_sync_activity_handler_is_not_a_coroutine_function(worker: TaskHubGrpcWorker): + assert not inspect.iscoroutinefunction(worker._execute_activity) + + +def test_async_activity_handler_is_a_coroutine_function(worker: TaskHubGrpcWorker): + assert inspect.iscoroutinefunction(worker._execute_activity_async) + + +def test_run_func_awaits_coroutines_directly(manager: _AsyncWorkerManager): + """``_AsyncWorkerManager._run_func`` is the single point that branches on async-ness. + + A coroutine handler returns its value without going through the thread pool. + """ + + async def coroutine_handler(value: int) -> int: + return value + 1 + + async def driver() -> int: + return await manager._run_func(coroutine_handler, 41) + + assert asyncio.run(driver()) == 42 + + +def test_run_func_dispatches_sync_callables_to_thread_pool(manager: _AsyncWorkerManager): + main_thread_id = threading.get_ident() + captured: dict[str, int] = {} + + def sync_handler(value: int) -> int: + captured['thread_id'] = threading.get_ident() + return value + 1 + + async def driver() -> int: + return await manager._run_func(sync_handler, 41) + + result = asyncio.run(driver()) + assert result == 42 + assert captured['thread_id'] != main_thread_id diff --git a/ext/dapr-ext-workflow/tests/durabletask/test_activity_executor.py b/ext/dapr-ext-workflow/tests/durabletask/test_activity_executor.py index f65aaf3f6..145548eda 100644 --- a/ext/dapr-ext-workflow/tests/durabletask/test_activity_executor.py +++ b/ext/dapr-ext-workflow/tests/durabletask/test_activity_executor.py @@ -59,6 +59,27 @@ def test_activity(ctx: task.ActivityContext, _): assert 'Bogus' in str(caught_exception) +def test_sync_execute_rejects_async_activity(): + """Sync ``execute`` must raise a clear RuntimeError when the activity returns a + coroutine. Guards against ``_is_async_callable`` missing an async callable at + registration; without this, JSON encoding would fail with a confusing TypeError. + """ + + async def async_activity(ctx: task.ActivityContext, _): + return 'never reached' + + executor, name = _get_activity_executor(async_activity) + + caught_exception: Optional[Exception] = None + try: + executor.execute(TEST_INSTANCE_ID, name, TEST_TASK_ID, None) + except Exception as ex: + caught_exception = ex + + assert type(caught_exception) is RuntimeError + assert 'returned a coroutine' in str(caught_exception) + + def _get_activity_executor(fn: task.Activity) -> Tuple[worker._ActivityExecutor, str]: registry = worker._Registry() name = registry.add_activity(fn) diff --git a/ext/dapr-ext-workflow/tests/durabletask/test_activity_executor_async.py b/ext/dapr-ext-workflow/tests/durabletask/test_activity_executor_async.py new file mode 100644 index 000000000..5f2bf7171 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/durabletask/test_activity_executor_async.py @@ -0,0 +1,101 @@ +# Copyright 2026 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the async branch of ``_ActivityExecutor``. + +These mirror ``test_activity_executor.py`` but exercise the ``execute_async`` path used +when a registered activity is a coroutine function. +""" + +import asyncio +import inspect +import json +import logging +from typing import Any + +import pytest +from dapr.ext.workflow._durabletask import task, worker + +logging.basicConfig( + format='%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=logging.DEBUG, +) +TEST_LOGGER = logging.getLogger('tests') +TEST_INSTANCE_ID = 'abc123' +TEST_TASK_ID = 42 + + +def _get_activity_executor(fn: task.Activity) -> tuple[worker._ActivityExecutor, str]: + registry = worker._Registry() + name = registry.add_activity(fn) + executor = worker._ActivityExecutor(registry, TEST_LOGGER) + return executor, name + + +def test_async_activity_inputs(): + """Validates that execute_async awaits the activity and returns the encoded result.""" + + async def test_async_activity(ctx: task.ActivityContext, test_input: Any): + await asyncio.sleep(0) + return test_input, ctx.orchestration_id, ctx.task_id + + activity_input = 'Hello, 世界!' + executor, name = _get_activity_executor(test_async_activity) + result = asyncio.run( + executor.execute_async(TEST_INSTANCE_ID, name, TEST_TASK_ID, json.dumps(activity_input)) + ) + assert result is not None + + result_input, result_orchestration_id, result_task_id = json.loads(result) + assert activity_input == result_input + assert TEST_INSTANCE_ID == result_orchestration_id + assert TEST_TASK_ID == result_task_id + + +def test_async_activity_not_registered(): + async def test_async_activity(ctx: task.ActivityContext, _): + pass # not used + + executor, _ = _get_activity_executor(test_async_activity) + + with pytest.raises(worker.ActivityNotRegisteredError) as exc_info: + asyncio.run(executor.execute_async(TEST_INSTANCE_ID, 'Bogus', TEST_TASK_ID, None)) + assert 'Bogus' in str(exc_info.value) + + +def test_async_activity_exception_propagates(): + async def test_async_activity(ctx: task.ActivityContext, _): + raise RuntimeError('boom') + + executor, name = _get_activity_executor(test_async_activity) + + with pytest.raises(RuntimeError) as exc_info: + asyncio.run(executor.execute_async(TEST_INSTANCE_ID, name, TEST_TASK_ID, None)) + assert 'boom' in str(exc_info.value) + + +def test_async_activity_registry_preserves_coroutine_function(): + """The dispatcher relies on iscoroutinefunction(fn) at the registry lookup level. + + If the registry's add_activity ever wraps coroutine functions in a way that hides their + async-ness (e.g. functools.wraps with a sync decorator), the dispatcher would route + them to the thread pool and break I/O concurrency. This test pins that contract. + """ + + async def test_async_activity(ctx: task.ActivityContext, _): + return None + + registry = worker._Registry() + name = registry.add_activity(test_async_activity) + + retrieved = registry.get_activity(name) + assert inspect.iscoroutinefunction(retrieved) diff --git a/ext/dapr-ext-workflow/tests/test_async_activity_registration.py b/ext/dapr-ext-workflow/tests/test_async_activity_registration.py new file mode 100644 index 000000000..d44c7fc89 --- /dev/null +++ b/ext/dapr-ext-workflow/tests/test_async_activity_registration.py @@ -0,0 +1,260 @@ +# -*- coding: utf-8 -*- + +# Copyright 2026 The Dapr Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for sync/async activity registration and the resulting wrappers. + +These tests exercise the helpers in workflow_runtime that decide whether an activity +runs in a thread pool (sync) or as a coroutine on the event loop (async). The +WorkflowRuntime is constructed against a fake registry so we don't need a sidecar. +""" + +import asyncio +import functools +import inspect +import unittest +from unittest import mock + +from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext +from dapr.ext.workflow.workflow_runtime import WorkflowRuntime, _is_async_callable +from pydantic import BaseModel + + +class OrderInput(BaseModel): + order_id: str + amount: float + + +class FakeRegistry: + def __init__(self): + self.activities: dict[str, object] = {} + + def add_named_activity(self, name: str, fn) -> None: + self.activities[name] = fn + + +class FakeWorker: + def __init__(self) -> None: + self._registry = FakeRegistry() + + +class _AsyncActivityRegistrationTestBase(unittest.TestCase): + def setUp(self) -> None: + self._registry_patch = mock.patch( + 'dapr.ext.workflow._durabletask.worker._Registry', return_value=FakeRegistry() + ) + self._registry_patch.start() + self.runtime = WorkflowRuntime() + # Reach into the runtime to grab its registry for assertions. + self.registry: FakeRegistry = self.runtime._WorkflowRuntime__worker._registry + + def tearDown(self) -> None: + # Tear down the worker's ThreadPoolExecutor so each test doesn't leak threads/fds. + self.runtime.shutdown() + mock.patch.stopall() + + +class AsyncActivityRegistrationTest(_AsyncActivityRegistrationTestBase): + def test_async_activity_registers_coroutine_wrapper(self) -> None: + async def my_async_activity(ctx: WorkflowActivityContext, payload: str) -> str: + return payload.upper() + + self.runtime.register_activity(my_async_activity) + + wrapper = self.registry.activities['my_async_activity'] + self.assertTrue(inspect.iscoroutinefunction(wrapper)) + + def test_sync_activity_registers_plain_wrapper(self) -> None: + def my_sync_activity(ctx: WorkflowActivityContext, payload: str) -> str: + return payload.upper() + + self.runtime.register_activity(my_sync_activity) + + wrapper = self.registry.activities['my_sync_activity'] + self.assertFalse(inspect.iscoroutinefunction(wrapper)) + self.assertTrue(callable(wrapper)) + + def test_async_wrapper_awaits_user_function(self) -> None: + recorded: list[tuple[WorkflowActivityContext, str]] = [] + + async def my_async_activity(ctx: WorkflowActivityContext, payload: str) -> str: + await asyncio.sleep(0) + recorded.append((ctx, payload)) + return payload.upper() + + self.runtime.register_activity(my_async_activity) + wrapper = self.registry.activities['my_async_activity'] + + fake_ctx = mock.MagicMock(spec=['task_id']) + fake_ctx.task_id = 7 + result = asyncio.run(wrapper(fake_ctx, 'hello')) + + self.assertEqual(result, 'HELLO') + self.assertEqual(len(recorded), 1) + self.assertEqual(recorded[0][1], 'hello') + self.assertIsInstance(recorded[0][0], WorkflowActivityContext) + + def test_sync_wrapper_calls_user_function(self) -> None: + recorded: list[tuple[WorkflowActivityContext, str]] = [] + + def my_sync_activity(ctx: WorkflowActivityContext, payload: str) -> str: + recorded.append((ctx, payload)) + return payload.upper() + + self.runtime.register_activity(my_sync_activity) + wrapper = self.registry.activities['my_sync_activity'] + + fake_ctx = mock.MagicMock(spec=['task_id']) + fake_ctx.task_id = 3 + result = wrapper(fake_ctx, 'world') + + self.assertEqual(result, 'WORLD') + self.assertEqual(len(recorded), 1) + self.assertEqual(recorded[0][1], 'world') + self.assertIsInstance(recorded[0][0], WorkflowActivityContext) + + def test_async_wrapper_coerces_input_to_declared_model(self) -> None: + seen: list[OrderInput] = [] + + async def place_order(ctx: WorkflowActivityContext, order: OrderInput) -> str: + seen.append(order) + return order.order_id + + self.runtime.register_activity(place_order) + wrapper = self.registry.activities['place_order'] + + fake_ctx = mock.MagicMock(spec=['task_id']) + fake_ctx.task_id = 99 + raw_input = {'order_id': 'abc-1', 'amount': 9.5} + result = asyncio.run(wrapper(fake_ctx, raw_input)) + + self.assertEqual(result, 'abc-1') + self.assertEqual(len(seen), 1) + self.assertIsInstance(seen[0], OrderInput) + self.assertEqual(seen[0].amount, 9.5) + + def test_async_wrapper_propagates_exceptions(self) -> None: + async def failing(ctx: WorkflowActivityContext, payload: str) -> str: + raise RuntimeError('boom') + + self.runtime.register_activity(failing) + wrapper = self.registry.activities['failing'] + + fake_ctx = mock.MagicMock(spec=['task_id']) + fake_ctx.task_id = 1 + with self.assertRaises(RuntimeError) as caught: + asyncio.run(wrapper(fake_ctx, 'x')) + self.assertEqual(str(caught.exception), 'boom') + + def test_async_wrapper_supports_no_input_parameter(self) -> None: + async def heartbeat(ctx: WorkflowActivityContext) -> str: + return 'ok' + + self.runtime.register_activity(heartbeat) + wrapper = self.registry.activities['heartbeat'] + + fake_ctx = mock.MagicMock(spec=['task_id']) + fake_ctx.task_id = 0 + result = asyncio.run(wrapper(fake_ctx, None)) + self.assertEqual(result, 'ok') + + +class IsAsyncCallableTest(unittest.TestCase): + """Pin the contract of ``_is_async_callable`` against decorator shapes that bare + ``inspect.iscoroutinefunction`` would miss. These are the patterns the fix for finding + #5 was meant to address. Without coverage, a future refactor can silently regress + async-activity routing for any of them. + """ + + def test_plain_async_function_is_async(self) -> None: + async def fn() -> None: ... + + self.assertTrue(_is_async_callable(fn)) + + def test_plain_sync_function_is_not_async(self) -> None: + def fn() -> None: ... + + self.assertFalse(_is_async_callable(fn)) + + def test_functools_partial_of_async_is_async(self) -> None: + async def fn(prefix: str, payload: str) -> str: + return prefix + payload + + partial_fn = functools.partial(fn, 'hello-') + self.assertTrue(_is_async_callable(partial_fn)) + + def test_functools_partial_of_sync_is_not_async(self) -> None: + def fn(prefix: str, payload: str) -> str: + return prefix + payload + + partial_fn = functools.partial(fn, 'hello-') + self.assertFalse(_is_async_callable(partial_fn)) + + def test_wraps_chain_over_async_is_async(self) -> None: + """A sync decorator that uses @functools.wraps exposes the inner via __wrapped__.""" + + async def inner(ctx: object, inp: object) -> None: ... + + @functools.wraps(inner) + def outer(ctx: object, inp: object) -> object: + return inner(ctx, inp) + + self.assertTrue(_is_async_callable(outer)) + + def test_nested_partial_and_wraps_chain_is_async(self) -> None: + """partial(@wraps over async). Exercises both unwrap stages in order.""" + + async def inner(prefix: str, payload: str) -> str: + return prefix + payload + + @functools.wraps(inner) + def wrapped(prefix: str, payload: str) -> str: + return inner(prefix, payload) + + partial_wrapped = functools.partial(wrapped, 'hi-') + self.assertTrue(_is_async_callable(partial_wrapped)) + + def test_callable_class_instance_with_async_call_is_async(self) -> None: + class AsyncCallable: + async def __call__(self, ctx: object, inp: object) -> str: + return 'ok' + + self.assertTrue(_is_async_callable(AsyncCallable())) + + def test_callable_class_instance_with_sync_call_is_not_async(self) -> None: + class SyncCallable: + def __call__(self, ctx: object, inp: object) -> str: + return 'ok' + + self.assertFalse(_is_async_callable(SyncCallable())) + + +class AsyncAndSyncCoexistTest(_AsyncActivityRegistrationTestBase): + def test_runtime_registers_mixed_sync_and_async_activities(self) -> None: + async def async_activity(ctx: WorkflowActivityContext, payload: int) -> int: + return payload + 1 + + def sync_activity(ctx: WorkflowActivityContext, payload: int) -> int: + return payload * 2 + + self.runtime.register_activity(async_activity) + self.runtime.register_activity(sync_activity) + + async_wrapper = self.registry.activities['async_activity'] + sync_wrapper = self.registry.activities['sync_activity'] + + self.assertTrue(inspect.iscoroutinefunction(async_wrapper)) + self.assertFalse(inspect.iscoroutinefunction(sync_wrapper)) + + +if __name__ == '__main__': + unittest.main()