From 4fbf7a84ddd8f23b69f7c8d3d3ef4c9f935ccefb Mon Sep 17 00:00:00 2001 From: simonwe97 Date: Wed, 11 Jun 2025 16:45:18 +0800 Subject: [PATCH 1/2] fix: agent generate config err --- src/google/adk/models/lite_llm.py | 32 +++++++++++-- tests/unittests/models/test_litellm.py | 66 +++++++++++++------------- 2 files changed, 60 insertions(+), 38 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index ed54faecf05..bef18b2d248 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -482,13 +482,13 @@ def _message_to_generate_content_response( def _get_completion_inputs( llm_request: LlmRequest, ) -> tuple[Iterable[Message], Iterable[dict]]: - """Converts an LlmRequest to litellm inputs. + """Converts an LlmRequest to litellm inputs and extracts generation params. Args: llm_request: The LlmRequest to convert. Returns: - The litellm inputs (message list, tool dictionary and response format). + The litellm inputs (message list, tool dictionary, response format, and generation params). 
""" messages = [] for content in llm_request.contents or []: @@ -523,7 +523,28 @@ def _get_completion_inputs( if llm_request.config.response_schema: response_format = llm_request.config.response_schema - return messages, tools, response_format + # Extract generation params + generation_params = {} + if llm_request.config is not None: + config_dict = llm_request.config.model_dump(exclude_none=True) + for key in ( + "temperature", + "max_output_tokens", + "top_p", + "top_k", + "stop_sequences", + "presence_penalty", + "frequency_penalty", + ): + if key in config_dict: + if key == "max_output_tokens": + generation_params["max_tokens"] = config_dict[key] + elif key == "stop_sequences": + generation_params["stop"] = config_dict[key] + else: + generation_params[key] = config_dict[key] + + return messages, tools, response_format, generation_params def _build_function_declaration_log( @@ -660,7 +681,9 @@ async def generate_content_async( self._maybe_append_user_content(llm_request) logger.debug(_build_request_log(llm_request)) - messages, tools, response_format = _get_completion_inputs(llm_request) + messages, tools, response_format, generation_params = ( + _get_completion_inputs(llm_request) + ) completion_args = { "model": self.model, @@ -669,6 +692,7 @@ async def generate_content_async( "response_format": response_format, } completion_args.update(self._additional_args) + completion_args.update(generation_params) if stream: text = "" diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index f316e83ae92..73ddfbc33e5 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -13,7 +13,6 @@ # limitations under the License. 
-import json from unittest.mock import AsyncMock from unittest.mock import Mock @@ -763,39 +762,6 @@ async def test_generate_content_async_with_tool_response( assert kwargs["messages"][2]["content"] == '{"result": "test_result"}' -@pytest.mark.asyncio -async def test_generate_content_async(mock_acompletion, lite_llm_instance): - - async for response in lite_llm_instance.generate_content_async( - LLM_REQUEST_WITH_FUNCTION_DECLARATION - ): - assert response.content.role == "model" - assert response.content.parts[0].text == "Test response" - assert response.content.parts[1].function_call.name == "test_function" - assert response.content.parts[1].function_call.args == { - "test_arg": "test_value" - } - assert response.content.parts[1].function_call.id == "test_tool_call_id" - - mock_acompletion.assert_called_once() - - _, kwargs = mock_acompletion.call_args - assert kwargs["model"] == "test_model" - assert kwargs["messages"][0]["role"] == "user" - assert kwargs["messages"][0]["content"] == "Test prompt" - assert kwargs["tools"][0]["function"]["name"] == "test_function" - assert ( - kwargs["tools"][0]["function"]["description"] - == "Test function description" - ) - assert ( - kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][ - "type" - ] - == "string" - ) - - @pytest.mark.asyncio async def test_generate_content_async_with_usage_metadata( lite_llm_instance, mock_acompletion @@ -1430,3 +1396,35 @@ async def test_generate_content_async_non_compliant_multiple_function_calls( assert final_response.content.parts[1].function_call.name == "function_2" assert final_response.content.parts[1].function_call.id == "1" assert final_response.content.parts[1].function_call.args == {"arg": "value2"} + + +@pytest.mark.asyncio +def test_get_completion_inputs_generation_params(): + # Test that generation_params are extracted and mapped correctly + req = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="hi")]), + ], + 
config=types.GenerateContentConfig( + temperature=0.33, + max_output_tokens=123, + top_p=0.88, + top_k=7, + stop_sequences=["foo", "bar"], + presence_penalty=0.1, + frequency_penalty=0.2, + ), + ) + from google.adk.models.lite_llm import _get_completion_inputs + + _, _, _, generation_params = _get_completion_inputs(req) + assert generation_params["temperature"] == 0.33 + assert generation_params["max_tokens"] == 123 + assert generation_params["top_p"] == 0.88 + assert generation_params["top_k"] == 7 + assert generation_params["stop"] == ["foo", "bar"] + assert generation_params["presence_penalty"] == 0.1 + assert generation_params["frequency_penalty"] == 0.2 + # Should not include max_output_tokens + assert "max_output_tokens" not in generation_params + assert "stop_sequences" not in generation_params From 842c09b9ffa7eb8f72e846fee80e2708b04653a1 Mon Sep 17 00:00:00 2001 From: simonwei97 Date: Sun, 15 Jun 2025 18:45:15 +0800 Subject: [PATCH 2/2] fix: resolve comment --- src/google/adk/models/lite_llm.py | 55 +++++++++++++++++--------- tests/unittests/models/test_litellm.py | 35 +++++++++++++++- 2 files changed, 71 insertions(+), 19 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index bef18b2d248..c954711adbb 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -23,6 +23,7 @@ from typing import Dict from typing import Generator from typing import Iterable +from typing import List from typing import Literal from typing import Optional from typing import Tuple @@ -481,7 +482,12 @@ def _message_to_generate_content_response( def _get_completion_inputs( llm_request: LlmRequest, -) -> tuple[Iterable[Message], Iterable[dict]]: +) -> Tuple[ + List[Message], + Optional[List[Dict]], + Optional[types.SchemaUnion], + Optional[Dict], +]: """Converts an LlmRequest to litellm inputs and extracts generation params. 
Args: @@ -490,7 +496,8 @@ def _get_completion_inputs( Returns: The litellm inputs (message list, tool dictionary, response format, and generation params). """ - messages = [] + # 1. Construct messages + messages: List[Message] = [] for content in llm_request.contents or []: message_param_or_list = _content_to_message_param(content) if isinstance(message_param_or_list, list): @@ -507,7 +514,8 @@ def _get_completion_inputs( ), ) - tools = None + # 2. Convert tool declarations + tools: Optional[List[Dict]] = None if ( llm_request.config and llm_request.config.tools @@ -518,15 +526,22 @@ def _get_completion_inputs( for tool in llm_request.config.tools[0].function_declarations ] - response_format = None - - if llm_request.config.response_schema: - response_format = llm_request.config.response_schema + # 3. Handle response format + response_format: Optional[types.SchemaUnion] = ( + llm_request.config.response_schema if llm_request.config else None + ) - # Extract generation params - generation_params = {} - if llm_request.config is not None: + # 4. Extract generation parameters + generation_params: Optional[Dict] = None + if llm_request.config: config_dict = llm_request.config.model_dump(exclude_none=True) + # Generate LiteLlm parameters here, + # Following https://docs.litellm.ai/docs/completion/input. 
+ generation_params = {} + param_mapping = { + "max_output_tokens": "max_completion_tokens", + "stop_sequences": "stop", + } for key in ( "temperature", "max_output_tokens", @@ -537,12 +552,11 @@ def _get_completion_inputs( "frequency_penalty", ): if key in config_dict: - if key == "max_output_tokens": - generation_params["max_tokens"] = config_dict[key] - elif key == "stop_sequences": - generation_params["stop"] = config_dict[key] - else: - generation_params[key] = config_dict[key] + mapped_key = param_mapping.get(key, key) + generation_params[mapped_key] = config_dict[key] + + if not generation_params: + generation_params = None return messages, tools, response_format, generation_params @@ -691,8 +705,13 @@ async def generate_content_async( "tools": tools, "response_format": response_format, } - completion_args.update(self._additional_args) - completion_args.update(generation_params) + + # Merge additional arguments and generation parameters safely + if hasattr(self, "_additional_args") and self._additional_args: + completion_args.update(self._additional_args) + + if generation_params: + completion_args.update(generation_params) if stream: text = "" diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 73ddfbc33e5..e600ee7f0c5 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -762,6 +762,39 @@ async def test_generate_content_async_with_tool_response( assert kwargs["messages"][2]["content"] == '{"result": "test_result"}' +@pytest.mark.asyncio +async def test_generate_content_async(mock_acompletion, lite_llm_instance): + + async for response in lite_llm_instance.generate_content_async( + LLM_REQUEST_WITH_FUNCTION_DECLARATION + ): + assert response.content.role == "model" + assert response.content.parts[0].text == "Test response" + assert response.content.parts[1].function_call.name == "test_function" + assert response.content.parts[1].function_call.args == { + "test_arg": 
"test_value" + } + assert response.content.parts[1].function_call.id == "test_tool_call_id" + + mock_acompletion.assert_called_once() + + _, kwargs = mock_acompletion.call_args + assert kwargs["model"] == "test_model" + assert kwargs["messages"][0]["role"] == "user" + assert kwargs["messages"][0]["content"] == "Test prompt" + assert kwargs["tools"][0]["function"]["name"] == "test_function" + assert ( + kwargs["tools"][0]["function"]["description"] + == "Test function description" + ) + assert ( + kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][ + "type" + ] + == "string" + ) + + @pytest.mark.asyncio async def test_generate_content_async_with_usage_metadata( lite_llm_instance, mock_acompletion @@ -1419,7 +1452,7 @@ def test_get_completion_inputs_generation_params(): _, _, _, generation_params = _get_completion_inputs(req) assert generation_params["temperature"] == 0.33 - assert generation_params["max_tokens"] == 123 + assert generation_params["max_completion_tokens"] == 123 assert generation_params["top_p"] == 0.88 assert generation_params["top_k"] == 7 assert generation_params["stop"] == ["foo", "bar"]