Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .sampo/changesets/chivalrous-baron-tapio.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
pypi/posthog: patch
---

Fix scoped context support for async functions
17 changes: 17 additions & 0 deletions posthog/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,12 +391,29 @@ def process_payment(payment_id):
# and then re-raised
some_risky_function()
# When stacking decorators, the posthog.scoped decorator must be
# closest to the function. For example, with FastAPI middleware:
@app.middleware("http")
@posthog.scoped()
async def middleware(request, call_next):
return await call_next(request)
Category:
Contexts
"""

def decorator(func: F) -> F:
from functools import wraps
from inspect import iscoroutinefunction

if iscoroutinefunction(func):

@wraps(func)
async def async_wrapper(*args, **kwargs):
with new_context(fresh=fresh, capture_exceptions=capture_exceptions):
return await func(*args, **kwargs)

return cast(F, async_wrapper)

@wraps(func)
def wrapper(*args, **kwargs):
Expand Down
99 changes: 79 additions & 20 deletions posthog/test/test_contexts.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import unittest
from unittest.mock import patch

Expand Down Expand Up @@ -87,29 +88,47 @@ def successful_function(x, y):

@patch("posthog.capture_exception")
def test_scoped_decorator_exception(self, mock_capture):
test_exception = ValueError("Test exception")

def check_context_on_capture(exception, **kwargs):
# Assert tags are available when capture_exception is called
current_tags = get_tags()
assert current_tags.get("important_context") == "value"

mock_capture.side_effect = check_context_on_capture

@scoped()
def failing_function():
tag("important_context", "value")
raise test_exception
for name, is_async in [("sync", False), ("async", True)]:
with self.subTest(name=name):
test_exception = ValueError(f"Test {name} exception")

# Function should raise the exception
with self.assertRaises(ValueError):
failing_function()
if is_async:

# Verify capture_exception was called
mock_capture.assert_called_once_with(test_exception)
@scoped()
async def failing_function():
tag("important_context", "value")
raise test_exception

# Context should be cleared after function execution
assert get_tags() == {}
def run():
return asyncio.run(failing_function())

else:

@scoped()
def failing_function():
tag("important_context", "value")
raise test_exception

run = failing_function

# Function should raise the exception
with self.assertRaises(ValueError):
run()

# Verify capture_exception was called
mock_capture.assert_called_once_with(test_exception)

# Context should be cleared after function execution
assert get_tags() == {}

mock_capture.reset_mock()

@patch("posthog.capture_exception")
def test_new_context_exception_handling(self, mock_capture):
Expand Down Expand Up @@ -219,15 +238,55 @@ def test_child_tags_override_parent_tags_in_non_fresh_context(self):

def test_scoped_decorator_with_context_ids(self):
@scoped()
def function_with_context():
def sync_function_with_context():
identify_context("user456")
set_context_session("session789")
return get_context_distinct_id(), get_context_session_id()

@scoped()
async def async_function_with_context():
identify_context("user456")
set_context_session("session789")
return get_context_distinct_id(), get_context_session_id()

distinct_id, session_id = function_with_context()
assert distinct_id == "user456"
assert session_id == "session789"
cases = [
("sync", sync_function_with_context, lambda func: func()),
("async", async_function_with_context, lambda func: asyncio.run(func())),
]

# Context should be cleared after function execution
assert get_context_distinct_id() is None
assert get_context_session_id() is None
for name, func, run in cases:
with self.subTest(name=name):
distinct_id, session_id = run(func)
assert distinct_id == "user456"
assert session_id == "session789"

# Context should be cleared after function execution
assert get_context_distinct_id() is None
assert get_context_session_id() is None

def test_scoped_decorator_async_concurrent_context_isolation(self):
first_ready = asyncio.Event()
second_ready = asyncio.Event()
first_checked = asyncio.Event()

@scoped()
async def first():
identify_context("user_1")
first_ready.set()
await second_ready.wait()
distinct_id = get_context_distinct_id()
first_checked.set()
return distinct_id

@scoped()
async def second():
await first_ready.wait()
identify_context("user_2")
second_ready.set()
await first_checked.wait()
return get_context_distinct_id()

async def run():
return await asyncio.wait_for(asyncio.gather(first(), second()), timeout=1)

assert asyncio.run(run()) == ["user_1", "user_2"]
Loading