From 54c1934c6373bc4bad03e07393265cc47c21d4f9 Mon Sep 17 00:00:00 2001 From: hemanthsavasere Date: Sat, 16 May 2026 09:49:13 +0000 Subject: [PATCH] feat: allow plain dict yields in single-step streaming actions Intermediate yields in streaming actions can now be plain dicts instead of requiring (dict, None) tuples. The (dict, None) form still works for backward compatibility. --- burr/core/action.py | 6 +- burr/core/application.py | 4 + tests/core/test_application.py | 148 +++++++++++++++++++++++++++++++++ 3 files changed, 155 insertions(+), 3 deletions(-) diff --git a/burr/core/action.py b/burr/core/action.py index 58b5ecd17..08771d411 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -1545,8 +1545,8 @@ def __init__(self, reads: List[str], writes: List[str], tags: Optional[List[str] See the following example for how to use this decorator -- this reads ``prompt`` from the state and writes ``response`` back out, yielding all intermediate chunks. - Note that this *must* return a value. If it does not, we will not know how to update the state, and - we will error out. + Note that this *must* return a final value with a state update. If it does not, we will not know how to update the state, and + we will error out. Intermediate yields can be plain dicts (without a state update). .. code-block:: python @@ -1565,7 +1565,7 @@ def streaming_response(state: State) -> Generator[dict, None, tuple[dict, State] delta = chunk.choices[0].delta.content buffer.append(delta) # yield partial results - yield {'response': delta}, None + yield {'response': delta} full_response = ''.join(buffer) # return the final result return {'response': full_response}, state.update(response=full_response) diff --git a/burr/core/application.py b/burr/core/application.py index 7cbe96149..25bce4a10 100644 --- a/burr/core/application.py +++ b/burr/core/application.py @@ -340,6 +340,8 @@ def _run_single_step_streaming_action( count = 0 try: for item in generator: + if isinstance(item, dict): + item = (item, None) if not isinstance(item, tuple): # TODO -- consider adding support for just returning a result. raise ValueError( @@ -406,6 +408,8 @@ async def _arun_single_step_streaming_action( count = 0 try: async for item in generator: + if isinstance(item, dict): + item = (item, None) if not isinstance(item, tuple): # TODO -- consider adding support for just returning a result. raise ValueError( diff --git a/tests/core/test_application.py b/tests/core/test_application.py index 7383583f0..f8dd56048 100644 --- a/tests/core/test_application.py +++ b/tests/core/test_application.py @@ -827,6 +827,25 @@ def writes(self) -> list[str]: return ["count", "tracker"] +class SingleStepStreamingCounterYieldsDict(SingleStepStreamingAction): + def stream_run_and_update( + self, state: State, **run_kwargs + ) -> Generator[Tuple[dict, Optional[State]], None, None]: + steps_per_count = run_kwargs.get("granularity", 10) + count = state["count"] + for i in range(steps_per_count): + yield {"count": count + ((i + 1) / 10)} + yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + class SingleStepStreamingCounterAsync(SingleStepStreamingAction): async def stream_run_and_update( self, state: State, **run_kwargs @@ -848,6 +867,27 @@ def writes(self) -> list[str]: return ["count", "tracker"] +class SingleStepStreamingCounterYieldsDictAsync(SingleStepStreamingAction): + async def stream_run_and_update( + self, state: State, **run_kwargs + ) -> AsyncGenerator[Tuple[dict, Optional[State]], None]: + steps_per_count = run_kwargs.get("granularity", 10) + count = state["count"] + for i in range(steps_per_count): + await asyncio.sleep(0.01) + yield {"count": count + ((i + 1) / 10)} + await asyncio.sleep(0.01) + yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1) + + @property + def reads(self) -> list[str]: + return ["count"] + + @property + def writes(self) -> list[str]: + return ["count", "tracker"] + + class StreamingActionIncorrectResultType(StreamingAction): def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, dict]: yield {} @@ -1207,6 +1247,60 @@ def test__run_single_step_streaming_action(): assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} +def test__run_single_step_streaming_action_yields_dict(): + action = SingleStepStreamingCounterYieldsDict().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _run_single_step_streaming_action( + action, + state, + inputs={}, + sequence_id=0, + partition_key="partition_key", + app_id="app_id", + ) + last_result = -1 + result, state = None, None + for result, state in generator: + if last_result < 1: + assert result["count"] > last_result + last_result = result["count"] + assert result == {"count": 1} + assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} + + +def test__run_single_step_streaming_action_yields_dict_calls_callbacks(): + action = SingleStepStreamingCounterYieldsDict().with_name("counter") + + class TrackingCallback(PostStreamItemHook): + def __init__(self): + self.items = [] + + def post_stream_item(self, item: Any, **future_kwargs: Any): + self.items.append(item) + + hook = TrackingCallback() + + state = State({"count": 0, "tracker": []}) + generator = _run_single_step_streaming_action( + action, + state, + inputs={}, + sequence_id=0, + partition_key="partition_key", + app_id="app_id", + lifecycle_adapters=LifecycleAdapterSet(hook), + ) + last_result = -1 + result, state = None, None + for result, state in generator: + if last_result < 1: + assert result["count"] > last_result + last_result = result["count"] + assert result == {"count": 1} + assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} + assert len(hook.items) == 10 + + def test__run_single_step_streaming_action_calls_callbacks(): action = base_streaming_single_step_counter.with_name("counter") @@ -1266,6 +1360,60 @@ async def test__run_single_step_streaming_action_async(): assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} +async def test__run_single_step_streaming_action_yields_dict_async(): + async_action = SingleStepStreamingCounterYieldsDictAsync().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _arun_single_step_streaming_action( + action=async_action, + state=state, + inputs={}, + sequence_id=0, + app_id="app_id", + partition_key="partition_key", + lifecycle_adapters=LifecycleAdapterSet(), + ) + last_result = -1 + result, state = None, None + async for result, state in generator: + if last_result < 1: + assert result["count"] > last_result + last_result = result["count"] + assert result == {"count": 1} + assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} + + +async def test__run_single_step_streaming_action_yields_dict_async_callbacks(): + class TrackingCallback(PostStreamItemHookAsync): + def __init__(self): + self.items = [] + + async def post_stream_item(self, item: Any, **future_kwargs: Any): + self.items.append(item) + + hook = TrackingCallback() + + async_action = SingleStepStreamingCounterYieldsDictAsync().with_name("counter") + state = State({"count": 0, "tracker": []}) + generator = _arun_single_step_streaming_action( + action=async_action, + state=state, + inputs={}, + sequence_id=0, + app_id="app_id", + partition_key="partition_key", + lifecycle_adapters=LifecycleAdapterSet(hook), + ) + last_result = -1 + result, state = None, None + async for result, state in generator: + if last_result < 1: + assert result["count"] > last_result + last_result = result["count"] + assert result == {"count": 1} + assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]} + assert len(hook.items) == 10 + + async def test__run_single_step_streaming_action_async_callbacks(): class TrackingCallback(PostStreamItemHookAsync): def __init__(self):