Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def __init__(
raise ValueError("Cannot specify both `region_name` and `boto_session`.")

session = boto_session or boto3.Session()
self._boto_session = session
resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION
self._resolved_region = resolved_region
self.config = BedrockModel.BedrockConfig(
model_id=BedrockModel._get_default_model_with_warning(resolved_region, model_config),
include_tool_result_status="auto",
Expand Down Expand Up @@ -177,14 +179,58 @@ def __init__(

logger.debug("region=<%s> | bedrock client created", self.client.meta.region_name)

self._caching_supported: bool | None = None

@property
def _supports_caching(self) -> bool:
    """Whether this model supports prompt caching.

    Returns True for Claude/Anthropic models on Bedrock, including when the model_id
    is an application inference profile ARN derived from such a model.

    The answer is memoized in ``self._caching_supported``; a value of ``None`` there
    means "not computed yet" and triggers a fresh evaluation on the next access.

    Returns:
        True if the configured model is considered caching-capable, False otherwise.
    """
    # Serve the memoized answer so ARN resolution happens at most once per model_id.
    if self._caching_supported is not None:
        return self._caching_supported

    original_model_id = self.config.get("model_id", "")
    model_id = original_model_id.lower()

    if "claude" in model_id or "anthropic" in model_id:
        supported = True
    elif model_id.startswith("arn:") and "application-inference-profile" in model_id:
        # The profile ARN itself does not name the model, so defer to the resolver.
        # Pass the identifier with its original casing, not the lowercased copy.
        supported = self._resolve_arn_supports_caching(original_model_id)
    else:
        supported = False

    self._caching_supported = supported
    return supported

def _resolve_arn_supports_caching(self, arn: str) -> bool:
    """Decide whether an application inference profile ARN maps to a caching-capable model.

    Creates a "bedrock" client from the stored boto session and resolved region,
    fetches the inference profile, and inspects the ``modelArn`` of each entry
    under ``models`` for a Claude/Anthropic marker.

    Args:
        arn: The inference profile identifier, passed to the API unchanged.

    Returns:
        True if any listed model ARN contains "claude" or "anthropic"
        (case-insensitive). False if none does, or if the lookup raises.
    """
    try:
        control_plane = self._boto_session.client(
            service_name="bedrock",
            region_name=self._resolved_region,
        )
        profile = control_plane.get_inference_profile(inferenceProfileIdentifier=arn)
    except Exception:
        # Best-effort lookup: any failure is logged at debug level and treated as "not supported".
        logger.debug("model_id=<%s> | failed to resolve inference profile ARN", arn, exc_info=True)
        return False

    for entry in profile.get("models", []):
        model_arn = entry.get("modelArn", "")
        # Non-string values are ignored rather than coerced.
        if isinstance(model_arn, str) and any(marker in model_arn.lower() for marker in ("claude", "anthropic")):
            logger.debug("model_id=<%s> | resolved ARN to caching-capable model", arn)
            return True

    return False

@override
def update_config(self, **model_config: Unpack[BedrockConfig]) -> None: # type: ignore
Expand All @@ -194,6 +240,8 @@ def update_config(self, **model_config: Unpack[BedrockConfig]) -> None: # type:
**model_config: Configuration overrides.
"""
validate_config_keys(model_config, self.BedrockConfig)
if "model_id" in model_config:
self._caching_supported = None
self.config.update(model_config)

@override
Expand Down
124 changes: 124 additions & 0 deletions tests/strands/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2597,6 +2597,130 @@ def test_supports_caching_false_for_non_claude(bedrock_client):
assert model._supports_caching is False


def test_supports_caching_true_for_application_inference_profile_arn(session_cls):
    """A profile ARN whose underlying model is Claude reports caching support."""
    # One mock stands in for both the runtime client and the bedrock
    # control-plane client (same mock due to the session_cls fixture).
    client = session_cls.return_value.client.return_value
    client.meta = unittest.mock.MagicMock()
    client.meta.region_name = "us-east-1"
    client.get_inference_profile.return_value = {
        "models": [
            {"modelArn": "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20241022-v2:0"}
        ]
    }

    profile_arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/abc123"
    assert BedrockModel(model_id=profile_arn)._supports_caching is True


def test_supports_caching_false_for_non_claude_application_inference_profile_arn(session_cls):
    """A profile ARN whose underlying model is not Claude/Anthropic reports no caching support."""
    client = session_cls.return_value.client.return_value
    client.meta = unittest.mock.MagicMock()
    client.meta.region_name = "us-east-1"

    # The profile resolves to a Nova model, which carries neither marker.
    nova_entry = {"modelArn": "arn:aws:bedrock:us-east-1::foundation-model/amazon.nova-pro-v1:0"}
    client.get_inference_profile.return_value = {"models": [nova_entry]}

    profile_arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/def456"
    assert BedrockModel(model_id=profile_arn)._supports_caching is False


def test_supports_caching_false_when_arn_resolution_fails(session_cls):
    """A failing GetInferenceProfile call yields False instead of propagating the error."""
    client = session_cls.return_value.client.return_value
    client.meta = unittest.mock.MagicMock()
    client.meta.region_name = "us-east-1"

    # Make the profile lookup raise as if the profile did not exist.
    not_found = ClientError(
        {"Error": {"Code": "ResourceNotFoundException", "Message": "Not found"}},
        "GetInferenceProfile",
    )
    client.get_inference_profile.side_effect = not_found

    profile_arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/missing"
    assert BedrockModel(model_id=profile_arn)._supports_caching is False


def test_supports_caching_result_is_cached(session_cls):
    """The resolved answer is memoized, so a second read makes no further API call."""
    client = session_cls.return_value.client.return_value
    client.meta = unittest.mock.MagicMock()
    client.meta.region_name = "us-east-1"
    client.get_inference_profile.return_value = {
        "models": [
            {"modelArn": "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20241022-v2:0"}
        ]
    }

    profile_arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/cached"
    model = BedrockModel(model_id=profile_arn)

    # First read goes through the API.
    assert model._supports_caching is True

    # Clear the recorded calls, then read again: the lookup must not be repeated.
    lookup = client.get_inference_profile
    lookup.reset_mock()
    assert model._supports_caching is True
    lookup.assert_not_called()


def test_supports_caching_cache_invalidated_on_model_id_change(session_cls):
    """Changing model_id through update_config discards the memoized caching answer."""
    client = session_cls.return_value.client.return_value
    client.meta = unittest.mock.MagicMock()
    client.meta.region_name = "us-east-1"

    model = BedrockModel(model_id="amazon.nova-pro-v1:0")
    # Reading the property here memoizes False for the Nova model.
    assert model._supports_caching is False

    # After switching to a Claude model the stale False must not be served.
    model.update_config(model_id="anthropic.claude-3-5-sonnet-20241022-v2:0")
    assert model._supports_caching is True


def test_supports_caching_arn_uses_original_case(session_cls):
    """The API receives the ARN exactly as configured, not a lowercased copy."""
    client = session_cls.return_value.client.return_value
    client.meta = unittest.mock.MagicMock()
    client.meta.region_name = "us-east-1"
    client.get_inference_profile.return_value = {
        "models": [
            {"modelArn": "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20241022-v2:0"}
        ]
    }

    # The profile ID deliberately mixes upper and lower case.
    mixed_case_arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/AbCdEf123"
    model = BedrockModel(model_id=mixed_case_arn)

    assert model._supports_caching is True
    # Exactly one lookup, made with the untouched identifier.
    client.get_inference_profile.assert_called_once_with(inferenceProfileIdentifier=mixed_case_arn)


def test_supports_caching_empty_models_list(session_cls):
    """A profile that lists no models is treated as not supporting caching."""
    client = session_cls.return_value.client.return_value
    client.meta = unittest.mock.MagicMock()
    client.meta.region_name = "us-east-1"

    # The lookup succeeds but gives the resolver nothing to inspect.
    client.get_inference_profile.return_value = {"models": []}

    profile_arn = "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/empty"
    assert BedrockModel(model_id=profile_arn)._supports_caching is False


def test_inject_cache_point_adds_to_last_assistant(bedrock_client):
"""Test that _inject_cache_point adds cache point to last assistant message."""
model = BedrockModel(
Expand Down
Loading