Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -1545,8 +1545,8 @@ def __init__(self, reads: List[str], writes: List[str], tags: Optional[List[str]
See the following example for how to use this decorator -- this reads ``prompt`` from the state and writes
``response`` back out, yielding all intermediate chunks.

Note that this *must* return a value. If it does not, we will not know how to update the state, and
we will error out.
Note that this *must* return a final value with a state update. If it does not, we will not know how to update the state, and
we will error out. Intermediate yields can be plain dicts (without a state update).

.. code-block:: python

Expand All @@ -1565,7 +1565,7 @@ def streaming_response(state: State) -> Generator[dict, None, tuple[dict, State]
delta = chunk.choices[0].delta.content
buffer.append(delta)
# yield partial results
yield {'response': delta}, None
yield {'response': delta}
full_response = ''.join(buffer)
# return the final result
return {'response': full_response}, state.update(response=full_response)
Expand Down
4 changes: 4 additions & 0 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,8 @@ def _run_single_step_streaming_action(
count = 0
try:
for item in generator:
if isinstance(item, dict):
item = (item, None)
if not isinstance(item, tuple):
# TODO -- consider adding support for just returning a result.
raise ValueError(
Expand Down Expand Up @@ -406,6 +408,8 @@ async def _arun_single_step_streaming_action(
count = 0
try:
async for item in generator:
if isinstance(item, dict):
item = (item, None)
if not isinstance(item, tuple):
# TODO -- consider adding support for just returning a result.
raise ValueError(
Expand Down
148 changes: 148 additions & 0 deletions tests/core/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,25 @@ def writes(self) -> list[str]:
return ["count", "tracker"]


class SingleStepStreamingCounterYieldsDict(SingleStepStreamingAction):
def stream_run_and_update(
self, state: State, **run_kwargs
) -> Generator[Tuple[dict, Optional[State]], None, None]:
steps_per_count = run_kwargs.get("granularity", 10)
count = state["count"]
for i in range(steps_per_count):
yield {"count": count + ((i + 1) / 10)}
yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1)

@property
def reads(self) -> list[str]:
return ["count"]

@property
def writes(self) -> list[str]:
return ["count", "tracker"]


class SingleStepStreamingCounterAsync(SingleStepStreamingAction):
async def stream_run_and_update(
self, state: State, **run_kwargs
Expand All @@ -848,6 +867,27 @@ def writes(self) -> list[str]:
return ["count", "tracker"]


class SingleStepStreamingCounterYieldsDictAsync(SingleStepStreamingAction):
async def stream_run_and_update(
self, state: State, **run_kwargs
) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
steps_per_count = run_kwargs.get("granularity", 10)
count = state["count"]
for i in range(steps_per_count):
await asyncio.sleep(0.01)
yield {"count": count + ((i + 1) / 10)}
await asyncio.sleep(0.01)
yield {"count": count + 1}, state.update(count=count + 1).append(tracker=count + 1)

@property
def reads(self) -> list[str]:
return ["count"]

@property
def writes(self) -> list[str]:
return ["count", "tracker"]


class StreamingActionIncorrectResultType(StreamingAction):
def stream_run(self, state: State, **run_kwargs) -> Generator[dict, None, dict]:
yield {}
Expand Down Expand Up @@ -1207,6 +1247,60 @@ def test__run_single_step_streaming_action():
assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]}


def test__run_single_step_streaming_action_yields_dict():
action = SingleStepStreamingCounterYieldsDict().with_name("counter")
state = State({"count": 0, "tracker": []})
generator = _run_single_step_streaming_action(
action,
state,
inputs={},
sequence_id=0,
partition_key="partition_key",
app_id="app_id",
)
last_result = -1
result, state = None, None
for result, state in generator:
if last_result < 1:
assert result["count"] > last_result
last_result = result["count"]
assert result == {"count": 1}
assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]}


def test__run_single_step_streaming_action_yields_dict_calls_callbacks():
action = SingleStepStreamingCounterYieldsDict().with_name("counter")

class TrackingCallback(PostStreamItemHook):
def __init__(self):
self.items = []

def post_stream_item(self, item: Any, **future_kwargs: Any):
self.items.append(item)

hook = TrackingCallback()

state = State({"count": 0, "tracker": []})
generator = _run_single_step_streaming_action(
action,
state,
inputs={},
sequence_id=0,
partition_key="partition_key",
app_id="app_id",
lifecycle_adapters=LifecycleAdapterSet(hook),
)
last_result = -1
result, state = None, None
for result, state in generator:
if last_result < 1:
assert result["count"] > last_result
last_result = result["count"]
assert result == {"count": 1}
assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]}
assert len(hook.items) == 10


def test__run_single_step_streaming_action_calls_callbacks():
action = base_streaming_single_step_counter.with_name("counter")

Expand Down Expand Up @@ -1266,6 +1360,60 @@ async def test__run_single_step_streaming_action_async():
assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]}


async def test__run_single_step_streaming_action_yields_dict_async():
async_action = SingleStepStreamingCounterYieldsDictAsync().with_name("counter")
state = State({"count": 0, "tracker": []})
generator = _arun_single_step_streaming_action(
action=async_action,
state=state,
inputs={},
sequence_id=0,
app_id="app_id",
partition_key="partition_key",
lifecycle_adapters=LifecycleAdapterSet(),
)
last_result = -1
result, state = None, None
async for result, state in generator:
if last_result < 1:
assert result["count"] > last_result
last_result = result["count"]
assert result == {"count": 1}
assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]}


async def test__run_single_step_streaming_action_yields_dict_async_callbacks():
class TrackingCallback(PostStreamItemHookAsync):
def __init__(self):
self.items = []

async def post_stream_item(self, item: Any, **future_kwargs: Any):
self.items.append(item)

hook = TrackingCallback()

async_action = SingleStepStreamingCounterYieldsDictAsync().with_name("counter")
state = State({"count": 0, "tracker": []})
generator = _arun_single_step_streaming_action(
action=async_action,
state=state,
inputs={},
sequence_id=0,
app_id="app_id",
partition_key="partition_key",
lifecycle_adapters=LifecycleAdapterSet(hook),
)
last_result = -1
result, state = None, None
async for result, state in generator:
if last_result < 1:
assert result["count"] > last_result
last_result = result["count"]
assert result == {"count": 1}
assert state.subset("count", "tracker").get_all() == {"count": 1, "tracker": [1]}
assert len(hook.items) == 10


async def test__run_single_step_streaming_action_async_callbacks():
class TrackingCallback(PostStreamItemHookAsync):
def __init__(self):
Expand Down
Loading