diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index be5337f0d..895c2a04d 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -5,6 +5,7 @@
 
 import json
 import logging
+import uuid
 from collections.abc import AsyncGenerator
 from typing import Any, TypedDict, TypeVar, cast
 
@@ -445,7 +446,13 @@ async def _process_tool_calls(self, tool_calls: dict[int, list[Any]]) -> AsyncGe
             Formatted tool call chunks.
         """
         for tool_deltas in tool_calls.values():
-            yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
+            # Some LiteLLM proxy backends return null tool_call IDs.
+            # Ensure the first delta (used for content_start) has a valid ID.
+            first = tool_deltas[0]
+            if getattr(first, "id", None) is None:
+                first.id = f"tooluse_{uuid.uuid4().hex[:24]}"
+
+            yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": first})
 
             for tool_delta in tool_deltas:
                 yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py
index 9bb0e09ca..ac790a583 100644
--- a/tests/strands/models/test_litellm.py
+++ b/tests/strands/models/test_litellm.py
@@ -848,3 +848,64 @@ def test_format_request_messages_with_tool_calls_no_content():
         },
     ]
     assert tru_result == exp_result
+
+
+@pytest.mark.asyncio
+async def test_stream_non_streaming_null_tool_call_id(litellm_acompletion, api_key, model_id, alist):
+    """Verify that null tool_call IDs from proxy backends get a generated fallback.
+
+    Some LiteLLM proxy backends (e.g. running behind a custom proxy server)
+    return tool_calls with id=None, causing downstream failures when
+    the event loop tries to match tool results back to tool_call IDs.
+
+    See: https://github.com/strands-agents/sdk-python/issues/1259
+    """
+    mock_function = unittest.mock.Mock()
+    mock_function.name = "search"
+    mock_function.arguments = '{"query": "test"}'
+
+    mock_tool_call = unittest.mock.Mock(index=0, function=mock_function, id=None)
+
+    mock_message = unittest.mock.Mock()
+    mock_message.content = None
+    mock_message.reasoning_content = None
+    mock_message.tool_calls = [mock_tool_call]
+
+    mock_choice = unittest.mock.Mock()
+    mock_choice.message = mock_message
+    mock_choice.finish_reason = "tool_calls"
+
+    mock_response = unittest.mock.Mock()
+    mock_response.choices = [mock_choice]
+
+    mock_usage = unittest.mock.Mock()
+    mock_usage.prompt_tokens = 5
+    mock_usage.completion_tokens = 10
+    mock_usage.total_tokens = 15
+    mock_usage.prompt_tokens_details = None
+    mock_usage.cache_creation_input_tokens = None
+    mock_response.usage = mock_usage
+
+    litellm_acompletion.side_effect = unittest.mock.AsyncMock(return_value=mock_response)
+
+    model = LiteLLMModel(
+        client_args={"api_key": api_key},
+        model_id=model_id,
+        params={"stream": False},
+    )
+
+    messages = [{"role": "user", "content": [{"type": "text", "text": "search test"}]}]
+    response = model.stream(messages)
+    events = await alist(response)
+
+    # Find the content_start event for the tool
+    tool_start = [
+        e for e in events if "contentBlockStart" in e and "toolUse" in e["contentBlockStart"].get("start", {})
+    ]
+    assert len(tool_start) == 1
+
+    tool_use_id = tool_start[0]["contentBlockStart"]["start"]["toolUse"]["toolUseId"]
+    assert tool_use_id is not None
+    assert isinstance(tool_use_id, str)
+    assert len(tool_use_id) > 0
+    assert tool_use_id.startswith("tooluse_")