diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 3dde11d4d4bc..b1746c4a4cd8 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -1,5 +1,10 @@ ## Release History +### 4.14.5 (2026-01-15) + +#### Bugs Fixed +* Fixed bug where sdk was encountering a timeout issue caused by infinite recursion during the 410 (Gone) error.See [PR 44659](https://github.com/Azure/azure-sdk-for-python/pull/44649) + ### 4.14.4 (2026-01-12) #### Bugs Fixed diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py index 9f96fe7dc026..1a866c2df873 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/base_execution_context.py @@ -25,10 +25,12 @@ from collections import deque import copy +import logging from ...aio import _retry_utility_async from ... import http_constants, exceptions +_LOGGER = logging.getLogger(__name__) # pylint: disable=protected-access @@ -145,8 +147,19 @@ async def callback(**kwargs): # pylint: disable=unused-argument self._client, self._client._global_endpoint_manager, callback, **self._options ) + # Check if this is an internal partition key range fetch - skip 410 retry logic to avoid recursion + # When we call refresh_routing_map_provider(), it triggers _ReadPartitionKeyRanges which would + # come through this same code path. If that also gets a 410 and tries to refresh, we get infinite recursion. + is_pk_range_fetch = self._options.get("_internal_pk_range_fetch", False) + if is_pk_range_fetch: + # For partition key range queries, just execute without 410 partition split retry + # The underlying retry utility will still handle other transient errors + _LOGGER.debug("Partition split retry (async): Skipping 410 retry for internal PK range fetch") + return await execute_fetch() + max_retries = 3 attempt = 0 + while attempt <= max_retries: try: return await execute_fetch() @@ -154,14 +167,33 @@ async def callback(**kwargs): # pylint: disable=unused-argument if exceptions._partition_range_is_gone(e): attempt += 1 if attempt > max_retries: + _LOGGER.error( + "Partition split retry (async): Exhausted all %d retries. " + "state: _has_started=%s, _continuation=%s", + max_retries, self._has_started, self._continuation + ) raise # Exhausted retries, propagate error + _LOGGER.warning( + "Partition split retry (async): 410 error (sub_status=%s). Attempt %d of %d. " + "Refreshing routing map and resetting state.", + getattr(e, 'sub_status', 'N/A'), + attempt, + max_retries + ) + # Refresh routing map to get new partition key ranges self._client.refresh_routing_map_provider() + # Reset execution context state to allow retry from the beginning + self._has_started = False + self._continuation = None # Retry immediately (no backoff needed for partition splits) continue raise # Not a partition split error, propagate immediately + # This should never be reached, but added for safety + return [] + class _DefaultQueryExecutionContext(_QueryExecutionContextBase): """ diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py index 6ef8c7f5a9c6..528ca87f2586 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/base_execution_context.py @@ -25,8 +25,10 @@ from collections import deque import copy +import logging from .. import _retry_utility, http_constants, exceptions +_LOGGER = logging.getLogger(__name__) # pylint: disable=protected-access @@ -143,6 +145,16 @@ def callback(**kwargs): # pylint: disable=unused-argument self._client, self._client._global_endpoint_manager, callback, **self._options ) + # Check if this is an internal partition key range fetch - skip 410 retry logic to avoid recursion + # When we call refresh_routing_map_provider(), it triggers _ReadPartitionKeyRanges which would + # come through this same code path. If that also gets a 410 and tries to refresh, we get infinite recursion. + is_pk_range_fetch = self._options.get("_internal_pk_range_fetch", False) + if is_pk_range_fetch: + # For partition key range queries, just execute without 410 partition split retry + # The underlying retry utility will still handle other transient errors + _LOGGER.debug("Partition split retry: Skipping 410 retry for internal PK range fetch") + return execute_fetch() + max_retries = 3 attempt = 0 @@ -153,13 +165,32 @@ def callback(**kwargs): # pylint: disable=unused-argument if exceptions._partition_range_is_gone(e): attempt += 1 if attempt > max_retries: + _LOGGER.error( + "Partition split retry: Exhausted all %d retries. " + "state: _has_started=%s, _continuation=%s", + max_retries, self._has_started, self._continuation + ) raise # Exhausted retries, propagate error + _LOGGER.warning( + "Partition split retry: 410 error (sub_status=%s). Attempt %d of %d. " + "Refreshing routing map and resetting state.", + getattr(e, 'sub_status', 'N/A'), + attempt, + max_retries + ) + # Refresh routing map to get new partition key ranges self._client.refresh_routing_map_provider() + # Reset execution context state to allow retry from the beginning + self._has_started = False + self._continuation = None # Retry immediately (no backoff needed for partition splits) continue raise # Not a partition split error, propagate immediately + + # This should never be reached, but added for safety + return [] next = __next__ # Python 2 compatibility. diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py index a725a828024a..34e4a3436c27 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py @@ -22,12 +22,15 @@ """Internal class for partition key range cache implementation in the Azure Cosmos database service. """ +import logging from typing import Any, Optional from ... import _base from ..collection_routing_map import CollectionRoutingMap from .. import routing_range +_LOGGER = logging.getLogger(__name__) + # pylint: disable=protected-access @@ -75,18 +78,33 @@ async def init_collection_routing_map_if_needed( ): collection_routing_map = self._collection_routing_map_by_item.get(collection_id) if collection_routing_map is None: + # Pass _internal_pk_range_fetch flag to prevent recursive 410 retry logic + # When a 410 partition split error occurs, the SDK calls refresh_routing_map_provider() + # which clears the cache and retries. The retry needs partition key ranges, which calls + # this method, which triggers _ReadPartitionKeyRanges. If that query also goes through + # the 410 retry logic and calls refresh again, we get infinite recursion. + _LOGGER.debug( + "PK range cache (async): Initializing routing map for collection_id=%s with " + "_internal_pk_range_fetch=True to prevent recursive 410 retry.", + collection_id + ) + pk_range_kwargs = {**kwargs, "_internal_pk_range_fetch": True} collection_pk_ranges = [pk async for pk in self._documentClient._ReadPartitionKeyRanges(collection_link, feed_options, - **kwargs)] + **pk_range_kwargs)] # for large collections, a split may complete between the read partition key ranges query page responses, # causing the partitionKeyRanges to have both the children ranges and their parents. Therefore, we need # to discard the parent ranges to have a valid routing map. - collection_pk_ranges = PartitionKeyRangeCache._discard_parent_ranges(collection_pk_ranges) + collection_pk_ranges = list(PartitionKeyRangeCache._discard_parent_ranges(collection_pk_ranges)) collection_routing_map = CollectionRoutingMap.CompleteRoutingMap( [(r, True) for r in collection_pk_ranges], collection_id ) self._collection_routing_map_by_item[collection_id] = collection_routing_map + _LOGGER.debug( + "PK range cache (async): Cached routing map for collection_id=%s with %d ranges", + collection_id, len(collection_pk_ranges) + ) async def get_range_by_partition_key_range_id( self, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py index 4a26984d9d99..263d31916c49 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py @@ -22,6 +22,7 @@ """Internal class for partition key range cache implementation in the Azure Cosmos database service. """ +import logging from typing import Any, Optional from .. import _base @@ -29,6 +30,8 @@ from . import routing_range from .routing_range import PartitionKeyRange +_LOGGER = logging.getLogger(__name__) + # pylint: disable=protected-access @@ -61,17 +64,32 @@ def init_collection_routing_map_if_needed( ): collection_routing_map = self._collection_routing_map_by_item.get(collection_id) if not collection_routing_map: + # Pass _internal_pk_range_fetch flag to prevent recursive 410 retry logic + # When a 410 partition split error occurs, the SDK calls refresh_routing_map_provider() + # which clears the cache and retries. The retry needs partition key ranges, which calls + # this method, which triggers _ReadPartitionKeyRanges. If that query also goes through + # the 410 retry logic and calls refresh again, we get infinite recursion. + _LOGGER.debug( + "PK range cache: Initializing routing map for collection_id=%s with " + "_internal_pk_range_fetch=True to prevent recursive 410 retry.", + collection_id + ) + pk_range_kwargs = {**kwargs, "_internal_pk_range_fetch": True} collection_pk_ranges = list(self._documentClient._ReadPartitionKeyRanges(collection_link, feed_options, - **kwargs)) + **pk_range_kwargs)) # for large collections, a split may complete between the read partition key ranges query page responses, # causing the partitionKeyRanges to have both the children ranges and their parents. Therefore, we need # to discard the parent ranges to have a valid routing map. - collection_pk_ranges = PartitionKeyRangeCache._discard_parent_ranges(collection_pk_ranges) + collection_pk_ranges = list(PartitionKeyRangeCache._discard_parent_ranges(collection_pk_ranges)) collection_routing_map = CollectionRoutingMap.CompleteRoutingMap( [(r, True) for r in collection_pk_ranges], collection_id ) self._collection_routing_map_by_item[collection_id] = collection_routing_map + _LOGGER.debug( + "PK range cache: Cached routing map for collection_id=%s with %d ranges", + collection_id, len(collection_pk_ranges) + ) def get_overlapping_ranges(self, collection_link, partition_key_ranges, feed_options, **kwargs): """Given a partition key range and a collection, return the list of diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py index d4a56e1545a3..b50a97f8d223 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py @@ -80,6 +80,8 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin """ # pylint: disable=protected-access, too-many-branches kwargs.pop(_Constants.OperationStartTime, None) + # Pop internal flags that should not be passed to the HTTP layer + kwargs.pop("_internal_pk_range_fetch", None) connection_timeout = connection_policy.RequestTimeout connection_timeout = kwargs.pop("connection_timeout", connection_timeout) read_timeout = connection_policy.ReadTimeout diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py index e6b2758537f7..d56668da1d50 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py @@ -19,4 +19,4 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -VERSION = "4.14.4" +VERSION = "4.14.5" diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index f9477b0d3da5..50ee31666826 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -52,6 +52,8 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p """ # pylint: disable=protected-access, too-many-branches kwargs.pop(_Constants.OperationStartTime, None) + # Pop internal flags that should not be passed to the HTTP layer + kwargs.pop("_internal_pk_range_fetch", None) connection_timeout = connection_policy.RequestTimeout read_timeout = connection_policy.ReadTimeout connection_timeout = kwargs.pop("connection_timeout", connection_timeout) diff --git a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py index 9a96d024d3e9..2d30dee5dfbc 100644 --- a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py @@ -18,7 +18,7 @@ class MockedCosmosClientConnection(object): def __init__(self, partition_key_ranges): self.partition_key_ranges = partition_key_ranges - def _ReadPartitionKeyRanges(self, collection_link: str, feed_options: Optional[Mapping[str, Any]] = None): + def _ReadPartitionKeyRanges(self, collection_link: str, feed_options: Optional[Mapping[str, Any]] = None, **kwargs): return self.partition_key_ranges def setUp(self): diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py new file mode 100644 index 000000000000..5564136002b7 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py @@ -0,0 +1,328 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +""" +Sync unit tests for partition split (410) retry logic. +""" + +import gc +import time +import unittest +from unittest.mock import patch + +import pytest + +from azure.cosmos import exceptions +from azure.cosmos._execution_context.base_execution_context import _DefaultQueryExecutionContext +from azure.cosmos.http_constants import StatusCodes, SubStatusCodes + +# tracemalloc is not available in PyPy, so we import conditionally +try: + import tracemalloc + HAS_TRACEMALLOC = True +except ImportError: + HAS_TRACEMALLOC = False + + +# ================================= +# Shared Test Helpers +# ================================= + +class MockGlobalEndpointManager: + """Mock global endpoint manager for testing.""" + def is_circuit_breaker_applicable(self, request): + return False + + +class MockClient: + """Mock Cosmos client for testing partition split retry logic.""" + def __init__(self): + self._global_endpoint_manager = MockGlobalEndpointManager() + self.refresh_routing_map_provider_call_count = 0 + + def refresh_routing_map_provider(self): + self.refresh_routing_map_provider_call_count += 1 + + def reset_counts(self): + """Reset call counts for reuse in tests.""" + self.refresh_routing_map_provider_call_count = 0 + + +def create_410_partition_split_error(): + """Create a 410 partition split error for testing.""" + error = exceptions.CosmosHttpResponseError( + status_code=StatusCodes.GONE, + message="Partition key range is gone" + ) + error.sub_status = SubStatusCodes.PARTITION_KEY_RANGE_GONE + return error + + +def raise_410_partition_split_error(*args, **kwargs): + """Raise a 410 partition split error - for use as mock side_effect.""" + raise create_410_partition_split_error() + + +# ========================== +# Test Class +# ========================== + +@pytest.mark.cosmosEmulator +class TestPartitionSplitRetryUnit(unittest.TestCase): + """ + Sync unit tests for 410 partition split retry logic. + """ + + def test_execution_context_state_reset_on_partition_split(self): + """ + Test that execution context state is properly reset on 410 partition split retry. + Verifies the fix where the while loop in _fetch_items_helper_no_retries + would not execute after a retry because _has_started was still True. + """ + mock_client = MockClient() + + def mock_fetch_function(options): + return ([{"id": "1"}], {}) + + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + + # simulate state after first successful fetch but before 410 error + context._has_started = True + context._continuation = None + + # Verify the loop condition without state reset - this is false + loop_condition_without_reset = context._continuation or not context._has_started + assert not loop_condition_without_reset, \ + "Without state reset, loop condition should be False" + + # Verify _fetch_items_helper_no_retries returns empty when state is not reset + fetch_was_called = [False] + + def tracking_fetch(options): + fetch_was_called[0] = True + return ([{"id": "1"}], {}) + + result = context._fetch_items_helper_no_retries(tracking_fetch) + assert not fetch_was_called[0], \ + "Fetch should NOT be called when _has_started=True and _continuation=None" + assert result == [], \ + "Should return empty list when while loop doesn't execute" + + # Now reset state + context._has_started = False + context._continuation = None + + # verify the loop condition with state reset + loop_condition_with_reset = context._continuation or not context._has_started + assert loop_condition_with_reset, \ + "With state reset, loop condition should be True" + + # verify _fetch_items_helper_no_retries works after state reset + result = context._fetch_items_helper_no_retries(tracking_fetch) + assert fetch_was_called[0], \ + "Fetch SHOULD be called after state reset" + assert result == [{"id": "1"}], \ + "Should return documents after state reset" + + @patch('azure.cosmos._retry_utility.Execute') + def test_retry_with_410_resets_state_and_succeeds(self, mock_execute): + """ + Test the full retry flow: 410 partition split error triggers state reset and retry succeeds. + """ + mock_client = MockClient() + expected_docs = [{"id": "success"}] + call_count = [0] + + def execute_side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + raise create_410_partition_split_error() + return expected_docs + + mock_execute.side_effect = execute_side_effect + + def mock_fetch_function(options): + return (expected_docs, {}) + + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + result = context._fetch_items_helper_with_retries(mock_fetch_function) + + assert call_count[0] == 2, "Should have retried once after 410" + assert mock_client.refresh_routing_map_provider_call_count == 1, \ + "refresh_routing_map_provider should be called once on 410" + assert result == expected_docs, "Should return expected documents after retry" + + @patch('azure.cosmos._retry_utility.Execute') + def test_pk_range_query_skips_410_retry_to_prevent_recursion(self, mock_execute): + """ + Test that partition key range queries (marked with _internal_pk_range_fetch=True) + skip the 410 partition split retry logic to prevent infinite recursion. + + When a 410 partition split error occurs: + 1. SDK calls refresh_routing_map_provider() which clears the routing map cache + 2. SDK retries the query + 3. Retry needs partition key ranges, which triggers _ReadPartitionKeyRanges + 4. If _ReadPartitionKeyRanges also uses 410 retry logic and gets a 410, + it would call refresh_routing_map_provider() again, creating infinite recursion + + This test verifies that queries with _internal_pk_range_fetch=True do not + trigger the 410 retry with refresh logic. + """ + mock_client = MockClient() + options = {"_internal_pk_range_fetch": True} + + mock_execute.side_effect = raise_410_partition_split_error + + def mock_fetch_function(options): + return ([{"id": "1"}], {}) + + context = _DefaultQueryExecutionContext(mock_client, options, mock_fetch_function) + + with pytest.raises(exceptions.CosmosHttpResponseError) as exc_info: + context._fetch_items_helper_with_retries(mock_fetch_function) + + assert exc_info.value.status_code == StatusCodes.GONE + assert mock_client.refresh_routing_map_provider_call_count == 0, \ + "refresh_routing_map_provider should NOT be called for PK range queries" + assert mock_execute.call_count == 1, \ + "Execute should only be called once - no retry for PK range queries" + + @patch('azure.cosmos._retry_utility.Execute') + def test_410_retry_behavior_with_and_without_pk_range_flag(self, mock_execute): + """ + Test that verifies the fix for the partition split recursion problem. + + The fix ensures: + - Regular queries retry up to 3 times on 410, calling refresh each time + - PK range queries (with _internal_pk_range_fetch flag) skip retry entirely, + preventing infinite recursion when refresh_routing_map_provider triggers + another PK range query that also gets a 410 + """ + mock_client = MockClient() + + mock_execute.side_effect = raise_410_partition_split_error + + def mock_fetch_function(options): + return ([{"id": "1"}], {}) + + # Test 1: Regular query (no flag) - should retry 3 times + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + + with pytest.raises(exceptions.CosmosHttpResponseError): + context._fetch_items_helper_with_retries(mock_fetch_function) + + assert mock_client.refresh_routing_map_provider_call_count == 3, \ + f"Expected 3 refresh calls, got {mock_client.refresh_routing_map_provider_call_count}" + assert mock_execute.call_count == 4, \ + f"Expected 4 Execute calls, got {mock_execute.call_count}" + + # Test 2: PK range query (with flag) - should NOT retry + mock_client.reset_counts() + mock_execute.reset_mock() + mock_execute.side_effect = raise_410_partition_split_error + + options_with_flag = {"_internal_pk_range_fetch": True} + context_pk_range = _DefaultQueryExecutionContext(mock_client, options_with_flag, mock_fetch_function) + + with pytest.raises(exceptions.CosmosHttpResponseError): + context_pk_range._fetch_items_helper_with_retries(mock_fetch_function) + + assert mock_client.refresh_routing_map_provider_call_count == 0, \ + f"With flag, expected 0 refresh calls, got {mock_client.refresh_routing_map_provider_call_count}" + assert mock_execute.call_count == 1, \ + f"With flag, expected 1 Execute call, got {mock_execute.call_count}" + + @pytest.mark.skipif(not HAS_TRACEMALLOC, reason="tracemalloc not available in PyPy") + @patch('azure.cosmos._retry_utility.Execute') + def test_memory_bounded_no_leak_on_410_retries(self, mock_execute): + """ + Test that memory usage is bounded during 410 partition split retries. + - Execute calls are bounded (max 4: 1 initial + 3 retries) + - Refresh calls are bounded (max 3) + - Memory growth is minimal (no recursive accumulation) + - No infinite recursion (max depth = 0 for PK range queries) + """ + # tracemalloc.start() begins tracing memory allocations to detect leaks + tracemalloc.start() + # gc.collect() forces garbage collection to get accurate baseline memory measurement + gc.collect() + # take_snapshot() captures current memory state for comparison after test + snapshot_before = tracemalloc.take_snapshot() + start_time = time.time() + + mock_client = MockClient() + + mock_execute.side_effect = raise_410_partition_split_error + + def mock_fetch_function(options): + return ([{"id": "1"}], {}) + + # Test regular query - should have bounded retries + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + + with pytest.raises(exceptions.CosmosHttpResponseError): + context._fetch_items_helper_with_retries(mock_fetch_function) + + elapsed_time = time.time() - start_time + # gc.collect() before snapshot ensures we measure actual leaks, not pending garbage + gc.collect() + snapshot_after = tracemalloc.take_snapshot() + # compare_to() shows memory difference between snapshots to identify growth + top_stats = snapshot_after.compare_to(snapshot_before, 'lineno') + memory_growth = sum(stat.size_diff for stat in top_stats if stat.size_diff > 0) + peak_memory = tracemalloc.get_traced_memory()[1] + # tracemalloc.stop() ends memory tracing and frees tracing overhead + tracemalloc.stop() + + # Collect metrics + execute_calls = mock_execute.call_count + refresh_calls = mock_client.refresh_routing_map_provider_call_count + + # Print metrics + print(f"\n{'=' * 60}") + print("MEMORY METRICS - Partition Split Memory Verification") + print(f"{'=' * 60}") + print(f"Metrics:") + print(f" - Execute calls: {execute_calls} (bounded)") + print(f" - Refresh calls: {refresh_calls}") + print(f" - Elapsed time: {elapsed_time:.2f}s") + print(f" - Memory growth: {memory_growth / 1024:.2f} KB") + print(f" - Peak memory: {peak_memory / 1024:.2f} KB") + print(f"{'=' * 60}") + + assert execute_calls == 4, \ + f"Execute calls should be bounded to 4, got {execute_calls}" + assert refresh_calls == 3, \ + f"Refresh calls should be bounded to 3, got {refresh_calls}" + assert elapsed_time < 1.0, \ + f"Should complete quickly (< 1s), took {elapsed_time:.2f}s - indicates no infinite loop" + assert memory_growth < 500 * 1024, \ + f"Memory growth should be < 500KB, got {memory_growth / 1024:.2f} KB - indicates no memory leak" + + # Test PK range query - should have NO retries (prevents recursion) + mock_client.reset_counts() + mock_execute.reset_mock() + mock_execute.side_effect = raise_410_partition_split_error + + options_with_flag = {"_internal_pk_range_fetch": True} + context_pk = _DefaultQueryExecutionContext(mock_client, options_with_flag, mock_fetch_function) + + with pytest.raises(exceptions.CosmosHttpResponseError): + context_pk._fetch_items_helper_with_retries(mock_fetch_function) + + pk_execute_calls = mock_execute.call_count + pk_refresh_calls = mock_client.refresh_routing_map_provider_call_count + + print(f"\nPK Range Query:") + print(f" - Execute calls: {pk_execute_calls} (no retry)") + print(f" - Refresh calls: {pk_refresh_calls} (no recursion)") + print(f"{'=' * 60}\n") + + assert pk_execute_calls == 1, \ + f"PK range query should have 1 execute call, got {pk_execute_calls}" + assert pk_refresh_calls == 0, \ + f"PK range query should have 0 refresh calls, got {pk_refresh_calls}" + + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py new file mode 100644 index 000000000000..a487fc6a85eb --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py @@ -0,0 +1,317 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +""" +Async unit tests for partition split (410) retry logic. +""" + +import gc +import time +import unittest +from unittest.mock import patch + +import pytest + +from azure.cosmos import exceptions +from azure.cosmos.http_constants import StatusCodes, SubStatusCodes +from azure.cosmos.aio import CosmosClient # noqa: F401 - needed to resolve circular imports +from azure.cosmos._execution_context.aio.base_execution_context import _DefaultQueryExecutionContext + +# tracemalloc is not available in PyPy, so we import conditionally +try: + import tracemalloc + HAS_TRACEMALLOC = True +except ImportError: + HAS_TRACEMALLOC = False + + +# ==================================== +# Shared Test Helpers +# ==================================== + +class MockGlobalEndpointManager: + """Mock global endpoint manager for testing.""" + def is_circuit_breaker_applicable(self, request): + return False + + +class MockClient: + """Mock Cosmos client for testing partition split retry logic.""" + def __init__(self): + self._global_endpoint_manager = MockGlobalEndpointManager() + self.refresh_routing_map_provider_call_count = 0 + + def refresh_routing_map_provider(self): + self.refresh_routing_map_provider_call_count += 1 + + def reset_counts(self): + """Reset call counts for reuse in tests.""" + self.refresh_routing_map_provider_call_count = 0 + + +def create_410_partition_split_error(): + """Create a 410 partition split error for testing.""" + error = exceptions.CosmosHttpResponseError( + status_code=StatusCodes.GONE, + message="Partition key range is gone" + ) + error.sub_status = SubStatusCodes.PARTITION_KEY_RANGE_GONE + return error + + +def raise_410_partition_split_error(*args, **kwargs): + """Raise a 410 partition split error - for use as mock side_effect.""" + raise create_410_partition_split_error() + + +# =============================== +# Test Class +# =============================== + + + +@pytest.mark.cosmosEmulator +class TestPartitionSplitRetryUnitAsync(unittest.IsolatedAsyncioTestCase): + """ + Async unit tests for 410 partition split retry logic. + """ + + async def test_execution_context_state_reset_on_partition_split_async(self): + """ + Test that execution context state is properly reset on 410 partition split retry (async). + Verifies the fix for a bug where the while loop in _fetch_items_helper_no_retries + would not execute after a retry because _has_started was still True. + + """ + mock_client = MockClient() + + async def mock_fetch_function(options): + return ([{"id": "1"}], {}) + + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + + # Simulate state AFTER first successful fetch but BEFORE 410 error + context._has_started = True + context._continuation = None + + # Verify the loop condition WITHOUT state reset - this is FALSE + loop_condition_without_reset = context._continuation or not context._has_started + assert not loop_condition_without_reset, \ + "Without state reset, loop condition should be False" + + # Verify _fetch_items_helper_no_retries returns empty when state is not reset + fetch_was_called = [False] + + async def tracking_fetch(options): + fetch_was_called[0] = True + return ([{"id": "1"}], {}) + + result = await context._fetch_items_helper_no_retries(tracking_fetch) + assert not fetch_was_called[0], \ + "Fetch should NOT be called when _has_started=True and _continuation=None" + assert result == [], \ + "Should return empty list when while loop doesn't execute" + + # reset state + context._has_started = False + context._continuation = None + + # Verify _fetch_items_helper_no_retries works after state reset + result = await context._fetch_items_helper_no_retries(tracking_fetch) + assert fetch_was_called[0], \ + "Fetch SHOULD be called after state reset" + assert result == [{"id": "1"}], \ + "Should return documents after state reset" + + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') + async def test_retry_with_410_resets_state_and_succeeds_async(self, mock_execute): + """ + Test the full retry flow: 410 partition split error triggers state reset and retry succeeds (async). + """ + mock_client = MockClient() + expected_docs = [{"id": "success"}] + call_count = [0] + + async def execute_side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + raise create_410_partition_split_error() + return expected_docs + + mock_execute.side_effect = execute_side_effect + + async def mock_fetch_function(options): + return (expected_docs, {}) + + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + result = await context._fetch_items_helper_with_retries(mock_fetch_function) + + assert call_count[0] == 2, "Should have retried once after 410" + assert mock_client.refresh_routing_map_provider_call_count == 1, \ + "refresh_routing_map_provider should be called once on 410" + assert result == expected_docs, "Should return expected documents after retry" + + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') + async def test_pk_range_query_skips_410_retry_to_prevent_recursion_async(self, mock_execute): + """ + Test that partition key range queries skip 410 retry to prevent recursion (async). + """ + mock_client = MockClient() + options = {"_internal_pk_range_fetch": True} + + mock_execute.side_effect = raise_410_partition_split_error + + async def mock_fetch_function(options): + return ([{"id": "1"}], {}) + + context = _DefaultQueryExecutionContext(mock_client, options, mock_fetch_function) + + with pytest.raises(exceptions.CosmosHttpResponseError) as exc_info: + await context._fetch_items_helper_with_retries(mock_fetch_function) + + assert exc_info.value.status_code == StatusCodes.GONE + assert mock_client.refresh_routing_map_provider_call_count == 0, \ + "refresh_routing_map_provider should NOT be called for PK range queries" + assert mock_execute.call_count == 1, \ + "ExecuteAsync should only be called once - no retry for PK range queries" + + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') + async def test_410_retry_behavior_with_and_without_pk_range_flag_async(self, mock_execute): + """ + Test that verifies the fix for the partition split recursion problem (async). + + The fix ensures: + - Regular queries retry up to 3 times on 410, calling refresh each time + - PK range queries (with _internal_pk_range_fetch flag) skip retry entirely, + preventing infinite recursion when refresh_routing_map_provider triggers + another PK range query that also gets a 410 + """ + mock_client = MockClient() + + mock_execute.side_effect = raise_410_partition_split_error + + async def mock_fetch_function(options): + return ([{"id": "1"}], {}) + + # Test 1: Regular query (no flag) - should retry 3 times + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + + with pytest.raises(exceptions.CosmosHttpResponseError): + await context._fetch_items_helper_with_retries(mock_fetch_function) + + assert mock_client.refresh_routing_map_provider_call_count == 3, \ + f"Expected 3 refresh calls, got {mock_client.refresh_routing_map_provider_call_count}" + assert mock_execute.call_count == 4, \ + f"Expected 4 ExecuteAsync calls, got {mock_execute.call_count}" + + # Test 2: PK range query (with flag) - should NOT retry + mock_client.reset_counts() + mock_execute.reset_mock() + mock_execute.side_effect = raise_410_partition_split_error + + options_with_flag = {"_internal_pk_range_fetch": True} + context_pk_range = _DefaultQueryExecutionContext(mock_client, options_with_flag, mock_fetch_function) + + with pytest.raises(exceptions.CosmosHttpResponseError): + await context_pk_range._fetch_items_helper_with_retries(mock_fetch_function) + + assert mock_client.refresh_routing_map_provider_call_count == 0, \ + f"With flag, expected 0 refresh calls, got {mock_client.refresh_routing_map_provider_call_count}" + assert mock_execute.call_count == 1, \ + f"With flag, expected 1 ExecuteAsync call, got {mock_execute.call_count}" + + @pytest.mark.skipif(not HAS_TRACEMALLOC, reason="tracemalloc not available in PyPy") + @patch('azure.cosmos.aio._retry_utility_async.ExecuteAsync') + async def test_memory_bounded_no_leak_on_410_retries_async(self, mock_execute): + """ + Test that memory usage is bounded during 410 partition split retries. + - Execute calls are bounded (max 4: 1 initial + 3 retries) + - Refresh calls are bounded (max 3) + - Memory growth is minimal (no recursive accumulation) + - No infinite recursion (max depth = 0 for PK range queries) + """ + # tracemalloc.start() begins tracing memory allocations to detect leaks + tracemalloc.start() + # gc.collect() forces garbage collection to get accurate baseline memory measurement + gc.collect() + # take_snapshot() captures current memory state for comparison after test + snapshot_before = tracemalloc.take_snapshot() + start_time = time.time() + + mock_client = MockClient() + + mock_execute.side_effect = raise_410_partition_split_error + + async def mock_fetch_function(options): + return ([{"id": "1"}], {}) + + # Test regular query - should have bounded retries + context = _DefaultQueryExecutionContext(mock_client, {}, mock_fetch_function) + + with pytest.raises(exceptions.CosmosHttpResponseError): + await context._fetch_items_helper_with_retries(mock_fetch_function) + + elapsed_time = time.time() - start_time + # gc.collect() before snapshot ensures we measure actual leaks, not pending garbage + gc.collect() + snapshot_after = tracemalloc.take_snapshot() + # compare_to() shows memory difference between snapshots to identify growth + top_stats = snapshot_after.compare_to(snapshot_before, 'lineno') + memory_growth = sum(stat.size_diff for stat in top_stats if stat.size_diff > 0) + peak_memory = tracemalloc.get_traced_memory()[1] + # tracemalloc.stop() ends memory tracing and frees tracing overhead + tracemalloc.stop() + + # Collect metrics + execute_calls = mock_execute.call_count + refresh_calls = mock_client.refresh_routing_map_provider_call_count + + # Print metrics + print(f"\n{'=' * 60}") + print("MEMORY METRICS (Async) - Partition Split Memory Verification") + print(f"{'=' * 60}") + print(f"Metrics:") + print(f" - Execute calls: {execute_calls} (bounded)") + print(f" - Refresh calls: {refresh_calls}") + print(f" - Elapsed time: {elapsed_time:.2f}s") + print(f" - Memory growth: {memory_growth / 1024:.2f} KB") + print(f" - Peak memory: {peak_memory / 1024:.2f} KB") + print(f"{'=' * 60}") + + assert execute_calls == 4, \ + f"Execute calls should be bounded to 4, got {execute_calls}" + assert refresh_calls == 3, \ + f"Refresh calls should be bounded to 3, got {refresh_calls}" + assert elapsed_time < 1.0, \ + f"Should complete quickly (< 1s), took {elapsed_time:.2f}s - indicates no infinite loop" + assert memory_growth < 500 * 1024, \ + f"Memory growth should be < 500KB, got {memory_growth / 1024:.2f} KB - indicates no memory leak" + + # Test PK range query - should have NO retries (prevents recursion) + mock_client.reset_counts() + mock_execute.reset_mock() + mock_execute.side_effect = raise_410_partition_split_error + + options_with_flag = {"_internal_pk_range_fetch": True} + context_pk = _DefaultQueryExecutionContext(mock_client, options_with_flag, mock_fetch_function) + + with pytest.raises(exceptions.CosmosHttpResponseError): + await context_pk._fetch_items_helper_with_retries(mock_fetch_function) + + pk_execute_calls = mock_execute.call_count + pk_refresh_calls = mock_client.refresh_routing_map_provider_call_count + + print(f"\nPK Range Query:") + print(f" - Execute calls: {pk_execute_calls} (no retry)") + print(f" - Refresh calls: {pk_refresh_calls} (no recursion)") + print(f"{'=' * 60}\n") + + assert pk_execute_calls == 1, \ + f"PK range query should have 1 execute call, got {pk_execute_calls}" + assert pk_refresh_calls == 0, \ + f"PK range query should have 0 refresh calls, got {pk_refresh_calls}" + + +if __name__ == "__main__": + unittest.main() +