Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/strands/models/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import json
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import Any, TypedDict, TypeVar, cast

Expand Down Expand Up @@ -445,7 +446,13 @@ async def _process_tool_calls(self, tool_calls: dict[int, list[Any]]) -> AsyncGe
Formatted tool call chunks.
"""
for tool_deltas in tool_calls.values():
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
# Some LiteLLM proxy backends return null tool_call IDs.
# Ensure the first delta (used for content_start) has a valid ID.
first = tool_deltas[0]
if getattr(first, "id", None) is None:
first.id = f"tooluse_{uuid.uuid4().hex[:24]}"

yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": first})

for tool_delta in tool_deltas:
yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
Expand Down
61 changes: 61 additions & 0 deletions tests/strands/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,3 +848,64 @@ def test_format_request_messages_with_tool_calls_no_content():
},
]
assert tru_result == exp_result


@pytest.mark.asyncio
async def test_stream_non_streaming_null_tool_call_id(litellm_acompletion, api_key, model_id, alist):
    """A tool call arriving with id=None must receive a synthesized fallback ID.

    Certain LiteLLM proxy backends emit tool_calls whose ``id`` is None,
    which breaks matching tool results back to their tool_call IDs later
    in the event loop.

    See: https://github.com/strands-agents/sdk-python/issues/1259
    """
    # Build a mock non-streaming completion whose single tool call has no ID.
    function = unittest.mock.Mock(arguments='{"query": "test"}')
    # Mock(name=...) would configure the mock's own name, so assign afterwards.
    function.name = "search"
    tool_call = unittest.mock.Mock(index=0, function=function, id=None)

    message = unittest.mock.Mock(content=None, reasoning_content=None, tool_calls=[tool_call])
    choice = unittest.mock.Mock(message=message, finish_reason="tool_calls")

    usage = unittest.mock.Mock(
        prompt_tokens=5,
        completion_tokens=10,
        total_tokens=15,
        prompt_tokens_details=None,
        cache_creation_input_tokens=None,
    )
    response = unittest.mock.Mock(choices=[choice], usage=usage)

    litellm_acompletion.side_effect = unittest.mock.AsyncMock(return_value=response)

    model = LiteLLMModel(
        client_args={"api_key": api_key},
        model_id=model_id,
        params={"stream": False},
    )

    messages = [{"role": "user", "content": [{"type": "text", "text": "search test"}]}]
    events = await alist(model.stream(messages))

    # Locate the single content_start event that opens the tool-use block.
    starts = [
        event
        for event in events
        if "contentBlockStart" in event and "toolUse" in event["contentBlockStart"].get("start", {})
    ]
    assert len(starts) == 1

    generated_id = starts[0]["contentBlockStart"]["start"]["toolUse"]["toolUseId"]
    assert generated_id is not None
    assert isinstance(generated_id, str)
    assert len(generated_id) > 0
    assert generated_id.startswith("tooluse_")