From 1830dae8a36679b9e8e526761939e8fb5b48e3c4 Mon Sep 17 00:00:00 2001 From: g97iulio1609 Date: Sat, 28 Feb 2026 20:18:41 +0100 Subject: [PATCH] fix: support Bedrock application inference profile ARNs in cache detection _supports_caching only checked for 'claude'/'anthropic' in the model_id string, which fails for application inference profile ARNs since they contain opaque profile IDs without model names. Changes: - Resolve application inference profile ARNs via the Bedrock control-plane GetInferenceProfile API to discover the underlying foundation model - Cache the resolution result to avoid repeated API calls - Invalidate cache when model_id changes via update_config() - Preserve original ARN casing for API calls (ARNs are case-sensitive) - Handle API errors gracefully with broad exception catching Fixes strands-agents/sdk-python#1705 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/strands/models/bedrock.py | 52 ++++++++++- tests/strands/models/test_bedrock.py | 124 +++++++++++++++++++++++++++ 2 files changed, 174 insertions(+), 2 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 4a48d7229..2790f029f 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -145,7 +145,9 @@ def __init__( raise ValueError("Cannot specify both `region_name` and `boto_session`.") session = boto_session or boto3.Session() + self._boto_session = session resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION + self._resolved_region = resolved_region self.config = BedrockModel.BedrockConfig( model_id=BedrockModel._get_default_model_with_warning(resolved_region, model_config), include_tool_result_status="auto", @@ -177,14 +179,58 @@ def __init__( logger.debug("region=<%s> | bedrock client created", self.client.meta.region_name) + self._caching_supported: bool | None = None + @property def _supports_caching(self) -> bool: """Whether this model supports prompt caching. - Returns True for Claude models on Bedrock. + Returns True for Claude/Anthropic models on Bedrock, including when the model_id + is an application inference profile ARN derived from such a model. """ + if self._caching_supported is not None: + return self._caching_supported + model_id = self.config.get("model_id", "").lower() - return "claude" in model_id or "anthropic" in model_id + + if "claude" in model_id or "anthropic" in model_id: + self._caching_supported = True + return True + + if model_id.startswith("arn:") and "application-inference-profile" in model_id: + original_model_id = self.config.get("model_id", "") + self._caching_supported = self._resolve_arn_supports_caching(original_model_id) + return self._caching_supported + + self._caching_supported = False + return False + + def _resolve_arn_supports_caching(self, arn: str) -> bool: + """Resolve whether an application inference profile ARN supports caching. + + Calls the Bedrock control-plane API to retrieve the underlying model + and checks if it is a Claude/Anthropic model. + """ + try: + bedrock_client = self._boto_session.client( + service_name="bedrock", + region_name=self._resolved_region, + ) + response = bedrock_client.get_inference_profile(inferenceProfileIdentifier=arn) + except Exception: + logger.debug("model_id=<%s> | failed to resolve inference profile ARN", arn, exc_info=True) + return False + + for model in response.get("models", []): + underlying = model.get("modelArn", "") + if not isinstance(underlying, str): + continue + underlying = underlying.lower() + if "claude" in underlying or "anthropic" in underlying: + logger.debug("model_id=<%s> | resolved ARN to caching-capable model", arn) + return True + + return False @override def update_config(self, **model_config: Unpack[BedrockConfig]) -> None: # type: ignore @@ -194,6 +240,8 @@ def update_config(self, **model_config: Unpack[BedrockConfig]) -> None: # type: **model_config: Configuration overrides. """ validate_config_keys(model_config, self.BedrockConfig) + if "model_id" in model_config: + self._caching_supported = None self.config.update(model_config) @override diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 9dae16be7..f61b52b91 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2597,6 +2597,130 @@ def test_supports_caching_false_for_non_claude(bedrock_client): assert model._supports_caching is False +def test_supports_caching_true_for_application_inference_profile_arn(session_cls): + """Test that _supports_caching resolves application inference profile ARNs to the underlying model.""" + mock_runtime_client = session_cls.return_value.client.return_value + mock_runtime_client.meta = unittest.mock.MagicMock() + mock_runtime_client.meta.region_name = "us-east-1" + + # The bedrock control-plane client (same mock due to session_cls fixture) + mock_bedrock_client = mock_runtime_client + mock_bedrock_client.get_inference_profile.return_value = { + "models": [ + {"modelArn": "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20241022-v2:0"} + ] + } + + arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/abc123" + model = BedrockModel(model_id=arn) + assert model._supports_caching is True + + +def test_supports_caching_false_for_non_claude_application_inference_profile_arn(session_cls): + """Test that _supports_caching returns False for ARNs that resolve to non-Claude models.""" + mock_runtime_client = session_cls.return_value.client.return_value + mock_runtime_client.meta = unittest.mock.MagicMock() + mock_runtime_client.meta.region_name = "us-east-1" + + mock_bedrock_client = mock_runtime_client + mock_bedrock_client.get_inference_profile.return_value = { + "models": [{"modelArn": "arn:aws:bedrock:us-east-1::foundation-model/amazon.nova-pro-v1:0"}] + } + + arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/def456" + model = BedrockModel(model_id=arn) + assert model._supports_caching is False + + +def test_supports_caching_false_when_arn_resolution_fails(session_cls): + """Test that _supports_caching returns False when the inference profile API call fails.""" + mock_runtime_client = session_cls.return_value.client.return_value + mock_runtime_client.meta = unittest.mock.MagicMock() + mock_runtime_client.meta.region_name = "us-east-1" + + mock_bedrock_client = mock_runtime_client + mock_bedrock_client.get_inference_profile.side_effect = ClientError( + {"Error": {"Code": "ResourceNotFoundException", "Message": "Not found"}}, + "GetInferenceProfile", + ) + + arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/missing" + model = BedrockModel(model_id=arn) + assert model._supports_caching is False + + +def test_supports_caching_result_is_cached(session_cls): + """Test that the caching support result is cached after first resolution.""" + mock_runtime_client = session_cls.return_value.client.return_value + mock_runtime_client.meta = unittest.mock.MagicMock() + mock_runtime_client.meta.region_name = "us-east-1" + + mock_bedrock_client = mock_runtime_client + mock_bedrock_client.get_inference_profile.return_value = { + "models": [ + {"modelArn": "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20241022-v2:0"} + ] + } + + arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/cached" + model = BedrockModel(model_id=arn) + + # First call resolves via API + assert model._supports_caching is True + # Second call should use cached value — reset mock to verify no extra calls + mock_bedrock_client.get_inference_profile.reset_mock() + assert model._supports_caching is True + mock_bedrock_client.get_inference_profile.assert_not_called() + + +def test_supports_caching_cache_invalidated_on_model_id_change(session_cls): + """Test that _caching_supported cache is invalidated when model_id changes via update_config.""" + mock_runtime_client = session_cls.return_value.client.return_value + mock_runtime_client.meta = unittest.mock.MagicMock() + mock_runtime_client.meta.region_name = "us-east-1" + + model = BedrockModel(model_id="amazon.nova-pro-v1:0") + assert model._supports_caching is False + + model.update_config(model_id="anthropic.claude-3-5-sonnet-20241022-v2:0") + assert model._supports_caching is True + + +def test_supports_caching_arn_uses_original_case(session_cls): + """Test that the original ARN casing is preserved when calling the API.""" + mock_runtime_client = session_cls.return_value.client.return_value + mock_runtime_client.meta = unittest.mock.MagicMock() + mock_runtime_client.meta.region_name = "us-east-1" + + mock_bedrock_client = mock_runtime_client + mock_bedrock_client.get_inference_profile.return_value = { + "models": [ + {"modelArn": "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20241022-v2:0"} + ] + } + + # ARN with mixed case in the profile ID + arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/AbCdEf123" + model = BedrockModel(model_id=arn) + assert model._supports_caching is True + # Verify the original ARN (not lowercased) was passed to the API + mock_bedrock_client.get_inference_profile.assert_called_once_with(inferenceProfileIdentifier=arn) + + +def test_supports_caching_empty_models_list(session_cls): + """Test that _supports_caching returns False when the API returns an empty models list.""" + mock_runtime_client = session_cls.return_value.client.return_value + mock_runtime_client.meta = unittest.mock.MagicMock() + mock_runtime_client.meta.region_name = "us-east-1" + + mock_bedrock_client = mock_runtime_client + mock_bedrock_client.get_inference_profile.return_value = {"models": []} + + arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/empty" + model = BedrockModel(model_id=arn) + assert model._supports_caching is False + + def test_inject_cache_point_adds_to_last_assistant(bedrock_client): """Test that _inject_cache_point adds cache point to last assistant message.""" model = BedrockModel(