diff --git a/.sampo/changesets/chivalrous-baron-tapio.md b/.sampo/changesets/chivalrous-baron-tapio.md new file mode 100644 index 00000000..9524b28a --- /dev/null +++ b/.sampo/changesets/chivalrous-baron-tapio.md @@ -0,0 +1,5 @@ +--- +pypi/posthog: patch +--- + +Fix scoped context support for async functions diff --git a/posthog/contexts.py b/posthog/contexts.py index 326916da..56bed4a9 100644 --- a/posthog/contexts.py +++ b/posthog/contexts.py @@ -391,12 +391,29 @@ def process_payment(payment_id): # and then re-raised some_risky_function() + # When stacking decorators, the posthog.scoped decorator must be + # closest to the function. For example, with FastAPI middleware: + @app.middleware("http") + @posthog.scoped() + async def middleware(request, call_next): + return await call_next(request) + Category: Contexts """ def decorator(func: F) -> F: from functools import wraps + from inspect import iscoroutinefunction + + if iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args, **kwargs): + with new_context(fresh=fresh, capture_exceptions=capture_exceptions): + return await func(*args, **kwargs) + + return cast(F, async_wrapper) @wraps(func) def wrapper(*args, **kwargs): diff --git a/posthog/test/test_contexts.py b/posthog/test/test_contexts.py index 3ff6fddd..a2f81fba 100644 --- a/posthog/test/test_contexts.py +++ b/posthog/test/test_contexts.py @@ -1,3 +1,4 @@ +import asyncio import unittest from unittest.mock import patch @@ -87,8 +88,6 @@ def successful_function(x, y): @patch("posthog.capture_exception") def test_scoped_decorator_exception(self, mock_capture): - test_exception = ValueError("Test exception") - def check_context_on_capture(exception, **kwargs): # Assert tags are available when capture_exception is called current_tags = get_tags() @@ -96,20 +95,40 @@ def check_context_on_capture(exception, **kwargs): mock_capture.side_effect = check_context_on_capture - @scoped() - def failing_function(): - tag("important_context", "value") - raise test_exception + for name, is_async in [("sync", False), ("async", True)]: + with self.subTest(name=name): + test_exception = ValueError(f"Test {name} exception") - # Function should raise the exception - with self.assertRaises(ValueError): - failing_function() + if is_async: - # Verify capture_exception was called - mock_capture.assert_called_once_with(test_exception) + @scoped() + async def failing_function(): + tag("important_context", "value") + raise test_exception - # Context should be cleared after function execution - assert get_tags() == {} + def run(): + return asyncio.run(failing_function()) + + else: + + @scoped() + def failing_function(): + tag("important_context", "value") + raise test_exception + + run = failing_function + + # Function should raise the exception + with self.assertRaises(ValueError): + run() + + # Verify capture_exception was called + mock_capture.assert_called_once_with(test_exception) + + # Context should be cleared after function execution + assert get_tags() == {} + + mock_capture.reset_mock() @patch("posthog.capture_exception") def test_new_context_exception_handling(self, mock_capture): @@ -219,15 +238,55 @@ def test_child_tags_override_parent_tags_in_non_fresh_context(self): def test_scoped_decorator_with_context_ids(self): @scoped() - def function_with_context(): + def sync_function_with_context(): + identify_context("user456") + set_context_session("session789") + return get_context_distinct_id(), get_context_session_id() + + @scoped() + async def async_function_with_context(): identify_context("user456") set_context_session("session789") return get_context_distinct_id(), get_context_session_id() - distinct_id, session_id = function_with_context() - assert distinct_id == "user456" - assert session_id == "session789" + cases = [ + ("sync", sync_function_with_context, lambda func: func()), + ("async", async_function_with_context, lambda func: asyncio.run(func())), + ] - # Context should be cleared after function execution - assert get_context_distinct_id() is None - assert get_context_session_id() is None + for name, func, run in cases: + with self.subTest(name=name): + distinct_id, session_id = run(func) + assert distinct_id == "user456" + assert session_id == "session789" + + # Context should be cleared after function execution + assert get_context_distinct_id() is None + assert get_context_session_id() is None + + def test_scoped_decorator_async_concurrent_context_isolation(self): + first_ready = asyncio.Event() + second_ready = asyncio.Event() + first_checked = asyncio.Event() + + @scoped() + async def first(): + identify_context("user_1") + first_ready.set() + await second_ready.wait() + distinct_id = get_context_distinct_id() + first_checked.set() + return distinct_id + + @scoped() + async def second(): + await first_ready.wait() + identify_context("user_2") + second_ready.set() + await first_checked.wait() + return get_context_distinct_id() + + async def run(): + return await asyncio.wait_for(asyncio.gather(first(), second()), timeout=1) + + assert asyncio.run(run()) == ["user_1", "user_2"]