diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py
index 3a6c36624d..1aef5396f9 100644
--- a/src/google/adk/models/lite_llm.py
+++ b/src/google/adk/models/lite_llm.py
@@ -807,9 +807,10 @@ async def _content_to_message_param(
           if isinstance(response, str)
           else _safe_json_serialize(response)
       )
+      tool_role = "tool_responses" if "gemma4" in model.lower() else "tool"
       tool_messages.append(
           ChatCompletionToolMessage(
-              role="tool",
+              role=tool_role,
               tool_call_id=part.function_response.id,
               content=response_content,
           )
@@ -824,6 +825,7 @@ async def _content_to_message_param(
       follow_up = await _content_to_message_param(
           types.Content(role=content.role, parts=non_tool_parts),
           provider=provider,
+          model=model,
       )
       follow_up_messages = (
           follow_up if isinstance(follow_up, list) else [follow_up]
       )
@@ -934,13 +936,13 @@ async def _content_to_message_param(
   )
 
 
-def _ensure_tool_results(messages: List[Message]) -> List[Message]:
-  """Insert placeholder tool messages for missing tool results.
-
-  LiteLLM-backed providers like OpenAI and Anthropic reject histories where an
-  assistant tool call is not followed by tool responses before the next
-  non-tool message. This helps recover from interrupted tool execution.
-  """
+def _ensure_tool_results(messages: List[Message], model: str) -> List[Message]:
+  """Insert placeholder tool messages for missing tool results.
+
+  LiteLLM-backed providers like OpenAI and Anthropic reject histories where an
+  assistant tool call is not followed by tool responses before the next
+  non-tool message; placeholders use the tool role expected by the model.
+  """
   if not messages:
     return messages
 
@@ -948,17 +950,19 @@
   healed_messages: List[Message] = []
   pending_tool_call_ids: List[str] = []
+  expected_tool_role = "tool_responses" if "gemma4" in model.lower() else "tool"
 
   for message in messages:
     role = message.get("role")
-    if pending_tool_call_ids and role != "tool":
+
+    if pending_tool_call_ids and role != expected_tool_role:
       logger.warning(
           "Missing tool results for tool_call_id(s): %s",
           pending_tool_call_ids,
       )
       healed_messages.extend(
           ChatCompletionToolMessage(
-              role="tool",
+              role=expected_tool_role,
              tool_call_id=tool_call_id,
              content=_MISSING_TOOL_RESULT_MESSAGE,
          )
@@ -971,13 +975,14 @@
       pending_tool_call_ids = [
           tool_call.get("id") for tool_call in tool_calls if tool_call.get("id")
       ]
-    elif role == "tool":
+    elif role == expected_tool_role:
       tool_call_id = message.get("tool_call_id")
       if tool_call_id in pending_tool_call_ids:
         pending_tool_call_ids.remove(tool_call_id)
 
     healed_messages.append(message)
 
+  # The final flush below also uses expected_tool_role.
   if pending_tool_call_ids:
     logger.warning(
         "Missing tool results for tool_call_id(s): %s",
@@ -985,7 +990,7 @@
         pending_tool_call_ids,
     )
     healed_messages.extend(
         ChatCompletionToolMessage(
-            role="tool",
+            role=expected_tool_role,
            tool_call_id=tool_call_id,
            content=_MISSING_TOOL_RESULT_MESSAGE,
        )
@@ -1905,7 +1910,7 @@ async def _get_completion_inputs(
           content=llm_request.config.system_instruction,
       ),
   )
-  messages = _ensure_tool_results(messages)
+  messages = _ensure_tool_results(messages, model)
 
   # 2. Convert tool declarations
   tools: Optional[List[Dict]] = None
diff --git a/tests/unittests/models/test_lite_llm_gemma_tool_role.py b/tests/unittests/models/test_lite_llm_gemma_tool_role.py
new file mode 100644
index 0000000000..31afd01fcd
--- /dev/null
+++ b/tests/unittests/models/test_lite_llm_gemma_tool_role.py
@@ -0,0 +1,193 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for Gemma-specific tool role handling in _content_to_message_param.
+
+Gemma's chat template expects role='tool_responses' for tool result messages,
+while the OpenAI-compatible default is role='tool'. This module verifies that
+_content_to_message_param sets the correct role based on the model name.
+"""
+
+from google.adk.models.lite_llm import _content_to_message_param
+from google.genai import types
+import pytest
+
+# Helpers
+
+
+def _make_function_response_content(
+    function_name: str = "get_weather",
+    response_data: dict | None = None,
+    call_id: str = "call_001",
+) -> types.Content:
+  """Builds a types.Content with a single function_response part."""
+  if response_data is None:
+    response_data = {"city": "Santiago de Cuba", "condition": "sunny"}
+  return types.Content(
+      role="user",
+      parts=[
+          types.Part(
+              function_response=types.FunctionResponse(
+                  name=function_name,
+                  response=response_data,
+                  id=call_id,
+              )
+          )
+      ],
+  )
+
+
+def _make_multi_function_response_content(
+    call_ids: list[str] | None = None,
+) -> types.Content:
+  """Builds a types.Content with multiple function_response parts."""
+  if call_ids is None:
+    call_ids = ["call_001", "call_002"]
+  return types.Content(
+      role="user",
+      parts=[
+          types.Part(
+              function_response=types.FunctionResponse(
+                  name=f"tool_{i}",
+                  response={"result": f"value_{i}"},
+                  id=call_id,
+              )
+          )
+          for i, call_id in enumerate(call_ids)
+      ],
+  )
+
+
+def _extract_role(msg) -> str:
+  """Extracts the role from a litellm message, whether dict or object."""
+  if isinstance(msg, dict):
+    return msg["role"]
+  return msg.role
+
+
+# Tests: single function_response part
+
+
+class TestToolRoleSingleResponse:
+  """_content_to_message_param with a single function_response part."""
+
+  @pytest.mark.asyncio
+  async def test_non_gemma_model_uses_tool_role(self):
+    """Non-Gemma models should get role='tool' (OpenAI-compatible default)."""
+    content = _make_function_response_content()
+
+    result = await _content_to_message_param(
+        content, model="ollama/qwen2.5-coder:3b"
+    )
+
+    assert _extract_role(result) == "tool"
+
+  @pytest.mark.asyncio
+  async def test_gemma4_model_uses_tool_responses_role(self):
+    """Models containing 'gemma4' should get role='tool_responses'."""
+    content = _make_function_response_content()
+
+    result = await _content_to_message_param(content, model="ollama/gemma4:e2b")
+
+    assert _extract_role(result) == "tool_responses", (
+        "Gemma models require role='tool_responses' to match their chat "
+        "template; role='tool' causes infinite tool-calling loops."
+ ) + + @pytest.mark.asyncio + async def test_gemma4_uppercase_model_name(self): + """Model name matching should be case-insensitive.""" + content = _make_function_response_content() + + result = await _content_to_message_param(content, model="ollama/Gemma4:31b") + + assert _extract_role(result) == "tool_responses" + + @pytest.mark.asyncio + async def test_tool_call_id_and_content_preserved(self): + """Fix must not alter tool_call_id or content — only role changes.""" + content = _make_function_response_content( + response_data={"status": "ok"}, call_id="my_call_123" + ) + + result = await _content_to_message_param(content, model="ollama/gemma4:e2b") + + if isinstance(result, dict): + assert result["tool_call_id"] == "my_call_123" + assert "ok" in result["content"] + else: + assert result.tool_call_id == "my_call_123" + assert "ok" in result.content + + @pytest.mark.asyncio + async def test_empty_model_string_uses_tool_role(self): + """Empty model string should fall back to default role='tool'.""" + content = _make_function_response_content() + + result = await _content_to_message_param(content, model="") + + assert _extract_role(result) == "tool" + + @pytest.mark.asyncio + async def test_unrelated_models_use_tool_role(self): + """Models that do not contain 'gemma4' must not be affected.""" + unaffected_models = [ + "ollama/llama3:8b", + "anthropic/claude-3-opus", + "openai/gpt-4o", + "ollama/gemma3:4b", # gemma3 != gemma4 + ] + for model in unaffected_models: + content = _make_function_response_content() + result = await _content_to_message_param(content, model=model) + assert ( + _extract_role(result) == "tool" + ), f"Model '{model}' should not be affected by the Gemma4 fix." + + +# Tests: multiple + + +class TestToolRoleMultipleResponses: + """_content_to_message_param with multiple function_response parts.""" + + @pytest.mark.asyncio + async def test_gemma4_all_messages_use_tool_responses_role(self): + """All messages in a multi-response must have role='tool_responses'.""" + content = _make_multi_function_response_content( + call_ids=["call_a", "call_b", "call_c"] + ) + + result = await _content_to_message_param(content, model="ollama/gemma4:4b") + + assert isinstance(result, list) + assert len(result) == 3 + for msg in result: + assert _extract_role(msg) == "tool_responses", ( + "Every tool message in a multi-response must use 'tool_responses' " + "for Gemma4 models." + ) + + @pytest.mark.asyncio + async def test_non_gemma_multi_response_uses_tool_role(self): + """Non-Gemma multi-response messages should all have role='tool'.""" + content = _make_multi_function_response_content( + call_ids=["call_a", "call_b"] + ) + + result = await _content_to_message_param(content, model="openai/gpt-4o") + + assert isinstance(result, list) + for msg in result: + assert _extract_role(msg) == "tool"