diff --git a/google/genai/models.py b/google/genai/models.py index c004c1dfa..1be968e71 100644 --- a/google/genai/models.py +++ b/google/genai/models.py @@ -36,6 +36,20 @@ logger = logging.getLogger('google_genai.models') +def _merge_content_parts(contents: list[types.Content]) -> Optional[types.Content]: + """Merges streamed content chunks into one content block.""" + parts: list[types.Part] = [] + role = None + for content in contents: + if role is None: + role = content.role + if content.parts: + parts.extend(content.parts) + if not parts: + return None + return types.Content(role=role, parts=parts) + + def _PersonGeneration_to_mldev_enum_validate(enum_value: Any) -> None: if enum_value in set(['ALLOW_ALL']): raise ValueError( @@ -6567,6 +6581,7 @@ def generate_content_stream( automatic_function_calling_history: list[types.Content] = [] chunk = None func_response_parts = None + func_call_content = None i = 0 while remaining_remote_calls_afc > 0: i += 1 @@ -6575,6 +6590,8 @@ def generate_content_stream( ) function_map = _extra_utils.get_function_map(parsed_config) + func_response_parts = [] + func_call_contents = [] if i == 1: # First request gets a function call. @@ -6591,12 +6608,16 @@ def generate_content_stream( or not chunk.candidates[0].content.parts ): break - func_response_parts = _extra_utils.get_function_response_parts( + response_parts = _extra_utils.get_function_response_parts( chunk, function_map ) - if not func_response_parts: + if response_parts: + func_response_parts.extend(response_parts) + func_call_contents.append(chunk.candidates[0].content) + else: contents = _extra_utils.append_chunk_contents(contents, chunk) # type: ignore[assignment] yield chunk + func_call_content = _merge_content_parts(func_call_contents) else: # Second request and beyond, yield chunks. @@ -6617,6 +6638,7 @@ def generate_content_stream( func_response_parts = _extra_utils.get_function_response_parts( chunk, function_map ) + func_call_content = chunk.candidates[0].content if not function_map: break @@ -6629,7 +6651,6 @@ def generate_content_stream( # Append function response parts to contents for the next request. if chunk is not None and chunk.candidates is not None: - func_call_content = chunk.candidates[0].content func_response_content = types.Content( role='user', parts=func_response_parts, @@ -8667,6 +8688,7 @@ async def async_generator(model, contents, config): # type: ignore[no-untyped-d ) automatic_function_calling_history: list[types.Content] = [] func_response_parts = None + func_call_content = None chunk = None i = 0 while remaining_remote_calls_afc > 0: @@ -8686,6 +8708,8 @@ async def async_generator(model, contents, config): # type: ignore[no-untyped-d function_map = _extra_utils.get_function_map( config, mcp_to_genai_tool_adapters, is_caller_method_async=True ) + func_response_parts = [] + func_call_contents = [] if i == 1: # First request gets a function call. @@ -8702,14 +8726,18 @@ async def async_generator(model, contents, config): # type: ignore[no-untyped-d or not chunk.candidates[0].content.parts ): break - func_response_parts = ( + response_parts = ( await _extra_utils.get_function_response_parts_async( chunk, function_map ) ) - if not func_response_parts: + if response_parts: + func_response_parts.extend(response_parts) + func_call_contents.append(chunk.candidates[0].content) + else: contents = _extra_utils.append_chunk_contents(contents, chunk) yield chunk + func_call_content = _merge_content_parts(func_call_contents) else: # Second request and beyond, yield chunks. @@ -8733,6 +8761,7 @@ async def async_generator(model, contents, config): # type: ignore[no-untyped-d chunk, function_map ) ) + func_call_content = chunk.candidates[0].content if not function_map: break @@ -8742,7 +8771,6 @@ async def async_generator(model, contents, config): # type: ignore[no-untyped-d if chunk is None: continue # Append function response parts to contents for the next request. - func_call_content = chunk.candidates[0].content func_response_content = types.Content( role='user', parts=func_response_parts, diff --git a/google/genai/tests/afc/test_generate_content_stream_afc.py b/google/genai/tests/afc/test_generate_content_stream_afc.py index 5cdb9b33f..c2c3f8f08 100644 --- a/google/genai/tests/afc/test_generate_content_stream_afc.py +++ b/google/genai/tests/afc/test_generate_content_stream_afc.py @@ -76,6 +76,31 @@ ] +def _function_call_response( + name: str, + args: dict[str, object], + thought_signature: bytes, +) -> types.GenerateContentResponse: + return types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content( + parts=[ + types.Part( + function_call=types.FunctionCall( + name=name, + args=args, + ), + thought_signature=thought_signature, + ) + ], + role='model', + ) + ) + ] + ) + + def get_current_weather(location: str) -> str: """Returns the current weather. @@ -359,6 +384,58 @@ def test_generate_content_stream_with_thought_summaries( ) == TEST_AFC_HISTORY[i].model_dump(exclude_none=True) +def test_generate_content_stream_merges_function_call_chunks_with_signatures(): + with mock.patch.object( + models.Models, '_generate_content_stream' + ) as mock_stream_with_parallel_calls: + mock_stream_with_parallel_calls.side_effect = [ + [ + _function_call_response( + 'get_current_weather', + {'location': 'San Francisco'}, + b'weather-signature', + ), + _function_call_response( + 'get_aqi_from_city', + {'location': 'San Francisco'}, + b'aqi-signature', + ), + ], + [ + types.GenerateContentResponse( + candidates=[types.Candidate(content=TEST_AFC_TEXT_CONTENT)] + ) + ], + ] + models_instance = models.Models(api_client_=mock_api_client) + stream = models_instance.generate_content_stream( + model='test_model', + contents='what is the weather and AQI in San Francisco?', + config=types.GenerateContentConfig( + tools=[get_current_weather, get_aqi_from_city] + ), + ) + + chunks = list(stream) + + assert len(chunks) == 1 + assert mock_stream_with_parallel_calls.call_count == 2 + second_request_contents = mock_stream_with_parallel_calls.call_args_list[ + 1 + ].kwargs['contents'] + function_call_content = second_request_contents[-2] + function_response_content = second_request_contents[-1] + + assert len(function_call_content.parts) == 2 + assert [ + part.function_call.name for part in function_call_content.parts + ] == ['get_current_weather', 'get_aqi_from_city'] + assert [ + part.thought_signature for part in function_call_content.parts + ] == [b'weather-signature', b'aqi-signature'] + assert len(function_response_content.parts) == 2 + + @pytest.mark.asyncio async def test_generate_content_stream_no_function_map_async( mock_generate_content_stream_no_afc, @@ -528,3 +605,59 @@ async def test_generate_content_stream_with_thought_summaries_async( assert chunk.automatic_function_calling_history[i].model_dump( exclude_none=True ) == TEST_AFC_HISTORY[i].model_dump(exclude_none=True) + + +@pytest.mark.asyncio +async def test_generate_content_stream_merges_function_call_chunks_async(): + with mock.patch.object( + models.AsyncModels, '_generate_content_stream' + ) as mock_stream_with_parallel_calls: + + async def async_generator_1(): + yield _function_call_response( + 'get_current_weather', + {'location': 'San Francisco'}, + b'weather-signature', + ) + yield _function_call_response( + 'get_aqi_from_city', + {'location': 'San Francisco'}, + b'aqi-signature', + ) + + async def async_generator_2(): + yield types.GenerateContentResponse( + candidates=[types.Candidate(content=TEST_AFC_TEXT_CONTENT)] + ) + + mock_stream_with_parallel_calls.side_effect = [ + async_generator_1(), + async_generator_2(), + ] + models_instance = models.AsyncModels(api_client_=mock_api_client) + stream = await models_instance.generate_content_stream( + model='test_model', + contents='what is the weather and AQI in San Francisco?', + config=types.GenerateContentConfig( + tools=[get_current_weather, get_aqi_from_city] + ), + ) + + chunks = [chunk async for chunk in stream] + + assert len(chunks) == 1 + assert mock_stream_with_parallel_calls.call_count == 2 + second_request_contents = mock_stream_with_parallel_calls.call_args_list[ + 1 + ].kwargs['contents'] + function_call_content = second_request_contents[-2] + function_response_content = second_request_contents[-1] + + assert len(function_call_content.parts) == 2 + assert [ + part.function_call.name for part in function_call_content.parts + ] == ['get_current_weather', 'get_aqi_from_city'] + assert [ + part.thought_signature for part in function_call_content.parts + ] == [b'weather-signature', b'aqi-signature'] + assert len(function_response_content.parts) == 2