Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions src/google/adk/models/lite_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,142 @@ def _build_tool_call_from_json_dict(
return tool_call


# Some DeepSeek responses carry tool calls as plain text delimited by the
# model's proprietary special tokens; the format is specified at
# https://api-docs.deepseek.com/guides/function_calling. LiteLLM usually
# turns these into structured `tool_calls`, but intermittently it does not,
# in which case the raw tokens show up in `content` and are parsed here.
# Note the delimiters are U+FF5C / U+2581, not ASCII "|" and space.
_DS_TCALLS_BEGIN = "\u003c\uff5ctool\u2581calls\u2581begin\uff5c\u003e"
_DS_TCALLS_END = "\u003c\uff5ctool\u2581calls\u2581end\uff5c\u003e"
_DS_TCALL_BEGIN = "\u003c\uff5ctool\u2581call\u2581begin\uff5c\u003e"
_DS_TCALL_END = "\u003c\uff5ctool\u2581call\u2581end\uff5c\u003e"
_DS_TSEP = "\u003c\uff5ctool\u2581sep\uff5c\u003e"

# A single call has the shape
#   <|tool▁call▁begin|>function<|tool▁sep|>NAME \n ARGS <|tool▁call▁end|>
# Group 1 is NAME (the rest of the header line); group 2 is ARGS, matched
# lazily and allowed to span lines because of DOTALL.
_DS_TOOL_CALL_RE = re.compile(
    "".join((
        re.escape(_DS_TCALL_BEGIN),
        r"function",
        re.escape(_DS_TSEP),
        r"([^\n\r]+?)\s*?\n(.*?)",
        re.escape(_DS_TCALL_END),
    )),
    flags=re.DOTALL,
)


def _extract_json_from_deepseek_args(args_text: str) -> Optional[str]:
  """Extracts a JSON object string from DeepSeek arguments text.

  Two layouts are tried, in this order:

  1. A Markdown code fence (optionally tagged ``json``) whose content
     starts with ``{`` and ends with ``}``. The fenced text is returned
     stripped but otherwise verbatim, and only if it parses as JSON.
  2. The first ``{`` anywhere in the text, decoded with ``raw_decode`` so
     that trailing non-JSON text is ignored. The decoded value is
     re-serialised with ``ensure_ascii=False``.

  Args:
    args_text: Raw text containing the function arguments, possibly
      wrapped in Markdown-style code fences.

  Returns:
    The JSON string, or None if no valid JSON object could be found.
  """
  if not args_text:
    return None

  # Preferred layout: ```json { ... } ``` or ``` { ... } ```.
  fence_match = re.search(r"```(?:json)?\s*(\{[\s\S]*?\})\s*```", args_text)
  if fence_match:
    fenced = fence_match.group(1).strip()
    try:
      json.loads(fenced)
    except json.JSONDecodeError:
      # The fence held something that only looks like an object. Handing it
      # on would give callers an unparseable `arguments` string, so fall
      # through and let the brace scan below have a go instead.
      pass
    else:
      return fenced

  # Fallback: decode starting at the first "{" and ignore whatever follows.
  open_brace = args_text.find("{")
  if open_brace == -1:
    return None
  try:
    candidate, _ = _JSON_DECODER.raw_decode(args_text, open_brace)
  except json.JSONDecodeError:
    return None
  return json.dumps(candidate, ensure_ascii=False)


def _parse_deepseek_tool_calls_from_text(
    text_block: str,
) -> tuple[list[ChatCompletionMessageToolCall], Optional[str]]:
  """Parses DeepSeek proprietary inline tool-call tokens from text.

  When LiteLLM does not translate DeepSeek's special tokens into
  structured ``tool_calls``, the raw tokens appear inside the ``content``
  field. This function extracts them and returns standard
  ``ChatCompletionMessageToolCall`` objects.

  Token reference
    ``<|tool▁calls▁begin|>`` … ``<|tool▁calls▁end|>``  → outer wrapper
    ``<|tool▁call▁begin|>function<|tool▁sep|>NAME``    → single call start
    ``<|tool▁call▁end|>``                              → single call end

  The text is scanned left to right and, at each step, whichever opening
  token occurs first (outer wrapper or bare single call) is consumed. Calls
  are therefore returned in source order even when wrapped and unwrapped
  calls are interleaved.

  Malformed input is handled as follows:
    * An outer wrapper with no closing token is left in the remainder
      untouched, together with everything after it.
    * An unwrapped call with no closing token before the next outer wrapper
      (or before the end of the text) is left in the remainder as plain
      text; scanning resumes at that wrapper.
    * A call whose name is empty, whose arguments yield no JSON, or for
      which no tool call object can be built is dropped; its text is not
      added to the remainder.

  Args:
    text_block: The raw text that may contain DeepSeek tokens.

  Returns:
    A tuple of ``(tool_calls, remainder)`` where ``remainder`` is the
    original text with all DeepSeek token regions removed, stripped of
    surrounding whitespace, or None when nothing is left. When
    ``text_block`` is empty or contains neither opening token the result is
    ``([], None)``.
  """
  _ensure_litellm_imported()

  tool_calls: list[ChatCompletionMessageToolCall] = []
  if not text_block:
    return tool_calls, None

  # Cheap guard: skip the scan when neither opening token is present.
  if _DS_TCALLS_BEGIN not in text_block and _DS_TCALL_BEGIN not in text_block:
    return tool_calls, None

  remainder_parts: list[str] = []
  text_length = len(text_block)
  cursor = 0

  while cursor < text_length:
    wrapped_idx = text_block.find(_DS_TCALLS_BEGIN, cursor)
    single_idx = text_block.find(_DS_TCALL_BEGIN, cursor)
    if wrapped_idx == -1 and single_idx == -1:
      remainder_parts.append(text_block[cursor:])
      break

    # Consume whichever opener comes first. Preferring the wrapper
    # unconditionally would sweep an earlier unwrapped call, raw tokens and
    # all, into the remainder.
    in_wrapped_block = wrapped_idx != -1 and (
        single_idx == -1 or wrapped_idx < single_idx
    )
    begin_idx = wrapped_idx if in_wrapped_block else single_idx

    # Everything before the token is ordinary text.
    remainder_parts.append(text_block[cursor:begin_idx])

    if in_wrapped_block:
      body_start = begin_idx + len(_DS_TCALLS_BEGIN)
      end_idx = text_block.find(_DS_TCALLS_END, body_start)
      if end_idx == -1:
        # Unterminated wrapper: keep the rest verbatim and stop.
        remainder_parts.append(text_block[begin_idx:])
        break
      block = text_block[body_start:end_idx]
      cursor = end_idx + len(_DS_TCALLS_END)
    else:
      # Unwrapped call. Its end token must appear before the next wrapper
      # opens; otherwise it would borrow an end token that belongs to a
      # call inside that wrapper.
      limit = wrapped_idx if wrapped_idx != -1 else text_length
      end_idx = text_block.find(
          _DS_TCALL_END, begin_idx + len(_DS_TCALL_BEGIN), limit
      )
      if end_idx == -1:
        remainder_parts.append(text_block[begin_idx:limit])
        cursor = limit
        continue
      cursor = end_idx + len(_DS_TCALL_END)
      block = text_block[begin_idx:cursor]

    # Parse the individual tool calls inside the block.
    for match in _DS_TOOL_CALL_RE.finditer(block):
      func_name = match.group(1).strip()
      args_json = _extract_json_from_deepseek_args(match.group(2).strip())
      if not func_name or not args_json:
        continue
      tool_call = _build_tool_call_from_json_dict(
          {"name": func_name, "arguments": args_json},
          index=len(tool_calls),
      )
      if tool_call:
        tool_calls.append(tool_call)

  remainder = "".join(remainder_parts).strip()
  return tool_calls, remainder or None


def _parse_tool_calls_from_text(
text_block: str,
) -> tuple[list[ChatCompletionMessageToolCall], Optional[str]]:
Expand All @@ -1318,6 +1454,17 @@ def _parse_tool_calls_from_text(

_ensure_litellm_imported()

# Try DeepSeek proprietary format first, then fall back to generic JSON.
ds_tool_calls, ds_remainder = _parse_deepseek_tool_calls_from_text(text_block)
if ds_tool_calls:
# If the remainder still contains content, re-parse it for
# additional generic inline JSON tool calls (mixed formats).
if ds_remainder:
extra_calls, extra_remainder = _parse_tool_calls_from_text(ds_remainder)
tool_calls = ds_tool_calls + (extra_calls or [])
return tool_calls, extra_remainder
return ds_tool_calls, None

remainder_segments = []
cursor = 0
text_length = len(text_block)
Expand Down
124 changes: 124 additions & 0 deletions tests/unittests/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from google.adk.models.lite_llm import _MISSING_TOOL_RESULT_MESSAGE
from google.adk.models.lite_llm import _model_response_to_chunk
from google.adk.models.lite_llm import _model_response_to_generate_content_response
from google.adk.models.lite_llm import _parse_deepseek_tool_calls_from_text
from google.adk.models.lite_llm import _parse_tool_calls_from_text
from google.adk.models.lite_llm import _redirect_litellm_loggers_to_stdout
from google.adk.models.lite_llm import _safe_json_serialize
Expand Down Expand Up @@ -2692,6 +2693,129 @@ def test_parse_tool_calls_from_text_invalid_json_returns_remainder():
assert remainder == 'Leading {"unused": "payload"} trailing text'


# ---------------------------------------------------------------------------
# DeepSeek proprietary inline tool-call format tests
# ---------------------------------------------------------------------------

# Test-local spellings of the DeepSeek special tokens used to build inputs.
# The delimiters are "<" + U+FF5C (fullwidth vertical line) ... U+FF5C + ">",
# and words are joined with U+2581 rather than an ASCII space.
# Outer wrapper around one or more calls.
_DS_BEGIN_CALLS = "\u003c\uff5ctool\u2581calls\u2581begin\uff5c\u003e"
_DS_END_CALLS = "\u003c\uff5ctool\u2581calls\u2581end\uff5c\u003e"
# Delimiters of a single call.
_DS_BEGIN_CALL = "\u003c\uff5ctool\u2581call\u2581begin\uff5c\u003e"
_DS_END_CALL = "\u003c\uff5ctool\u2581call\u2581end\uff5c\u003e"
# Separator between the literal "function" and the function name.
_DS_SEP = "\u003c\uff5ctool\u2581sep\uff5c\u003e"


def _ds_tool_call(name: str, args_json: str) -> str:
  """Render one DeepSeek tool-call segment with code-fenced JSON arguments."""
  header = f"{_DS_BEGIN_CALL}function{_DS_SEP}{name}\n"
  fenced_args = f"```json\n{args_json}\n```"
  return header + fenced_args + _DS_END_CALL


def _ds_wrapped(inner: str) -> str:
  """Enclose ``inner`` in the outer <|tool▁calls▁begin|>/<|tool▁calls▁end|> pair."""
  return _DS_BEGIN_CALLS + inner + _DS_END_CALLS


def test_parse_deepseek_single_tool_call():
  """One wrapped call with fenced JSON args parses into a single tool call."""
  payload = _ds_wrapped(
      _ds_tool_call("get_weather", '{"city": "Beijing", "unit": "celsius"}')
  )

  parsed_calls, leftover = _parse_deepseek_tool_calls_from_text(payload)

  assert len(parsed_calls) == 1
  call = parsed_calls[0]
  assert call.function.name == "get_weather"
  expected_args = {"city": "Beijing", "unit": "celsius"}
  assert json.loads(call.function.arguments) == expected_args
  # Nothing but the wrapped call was supplied, so no text is left over.
  assert leftover is None


def test_parse_deepseek_multi_tool_call():
  """Two calls inside one wrapper are both parsed, in source order."""
  first = _ds_tool_call("func_a", '{"x": 1}')
  second = _ds_tool_call("func_b", '{"y": 2}')

  parsed_calls, leftover = _parse_deepseek_tool_calls_from_text(
      _ds_wrapped(first + second)
  )

  assert len(parsed_calls) == 2
  expected = [("func_a", {"x": 1}), ("func_b", {"y": 2})]
  for call, (name, args) in zip(parsed_calls, expected):
    assert call.function.name == name
    assert json.loads(call.function.arguments) == args
  assert leftover is None


def test_parse_deepseek_plain_json_args():
  """DeepSeek tool call without Markdown code fences around args.

  Also covers non-ASCII argument values, and checks that a text consisting
  solely of the wrapped call leaves no remainder.
  """
  inner = (
      f"{_DS_BEGIN_CALL}function{_DS_SEP}search\n"
      f'{{"query": "天气"}}'
      f"{_DS_END_CALL}"
  )
  text = _ds_wrapped(inner)
  tool_calls, remainder = _parse_deepseek_tool_calls_from_text(text)
  assert len(tool_calls) == 1
  assert tool_calls[0].function.name == "search"
  assert json.loads(tool_calls[0].function.arguments) == {"query": "天气"}
  # Previously unpacked but never checked; the sibling tests assert it.
  assert remainder is None


def test_parse_deepseek_with_surrounding_text():
  """Prose before and after the wrapped call is returned as the remainder."""
  lead_in = "Let me think about this.\n"
  trail = "\nI'll proceed now."
  wrapped_call = _ds_wrapped(_ds_tool_call("calculate", '{"expr": "2+2"}'))

  parsed_calls, leftover = _parse_deepseek_tool_calls_from_text(
      lead_in + wrapped_call + trail
  )

  assert len(parsed_calls) == 1
  assert parsed_calls[0].function.name == "calculate"
  # The token region is cut out; the two prose pieces are joined as-is.
  assert leftover == "Let me think about this.\n\nI'll proceed now."


def test_parse_deepseek_no_tokens_returns_empty():
  """Plain text with no DeepSeek tokens gives ``([], None)``."""
  parsed_calls, leftover = _parse_deepseek_tool_calls_from_text(
      "Just a regular response, no special tokens here."
  )

  assert parsed_calls == []
  assert leftover is None


def test_parse_tool_calls_from_text_handles_deepseek_format():
  """Integration: DeepSeek-formatted input works through the generic parser."""
  payload = _ds_wrapped(
      _ds_tool_call("fetch_page", '{"url": "https://example.com"}')
  )

  parsed_calls, leftover = _parse_tool_calls_from_text(payload)

  assert len(parsed_calls) == 1
  call = parsed_calls[0]
  assert call.function.name == "fetch_page"
  expected_args = {"url": "https://example.com"}
  assert json.loads(call.function.arguments) == expected_args
  assert leftover is None


def test_parse_tool_calls_from_text_mixed_formats():
  """A DeepSeek-token call and a generic inline-JSON call in one text."""
  deepseek_segment = _ds_wrapped(_ds_tool_call("ds_func", '{"a": 1}'))
  generic_segment = '{"name": "std_func", "arguments": {"b": 2}}'

  parsed_calls, leftover = _parse_tool_calls_from_text(
      deepseek_segment + " some text " + generic_segment
  )

  assert len(parsed_calls) == 2
  # The DeepSeek call is listed first, then the generic JSON call.
  parsed_names = [call.function.name for call in parsed_calls]
  assert parsed_names[0] == "ds_func"
  assert parsed_names[1] == "std_func"
  assert leftover == "some text"


def test_parse_deepseek_empty_text():
  """Empty and whitespace-only inputs give no tool calls and no remainder."""
  blank_inputs = ("", " ", "\n\n")
  for blank in blank_inputs:
    parsed_calls, leftover = _parse_deepseek_tool_calls_from_text(blank)
    assert parsed_calls == []
    assert leftover is None


def test_split_message_content_and_tool_calls_inline_text():
message = {
"role": "assistant",
Expand Down