Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions google/cloud/bigtable/data/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from __future__ import annotations

from typing import (
Callable,
cast,
Any,
AsyncIterable,
Expand Down Expand Up @@ -115,6 +116,7 @@
if TYPE_CHECKING:
from google.cloud.bigtable.data._helpers import RowKeySamples
from google.cloud.bigtable.data._helpers import ShardedQuery
from google.rpc import status_pb2

if CrossSync.is_async:
from google.cloud.bigtable.data._async.mutations_batcher import (
Expand Down Expand Up @@ -1437,6 +1439,7 @@ def mutations_batcher(
batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
batch_retryable_errors: Sequence[type[Exception]]
| TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
_batch_completed_callback: Optional[Callable[list[status_pb2.Status]]] = None,
) -> "MutationsBatcherAsync":
"""
Returns a new mutations batcher instance.
Expand Down Expand Up @@ -1472,6 +1475,7 @@ def mutations_batcher(
batch_operation_timeout=batch_operation_timeout,
batch_attempt_timeout=batch_attempt_timeout,
batch_retryable_errors=batch_retryable_errors,
_batch_completed_callback=_batch_completed_callback,
)

@CrossSync.convert
Expand Down
18 changes: 17 additions & 1 deletion google/cloud/bigtable/data/_async/mutations_batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#
from __future__ import annotations

from typing import Sequence, TYPE_CHECKING, cast
from typing import Callable, Optional, Sequence, TYPE_CHECKING, cast
import atexit
import warnings
from collections import deque
Expand All @@ -24,6 +24,10 @@
from google.cloud.bigtable.data.exceptions import FailedMutationEntryError
from google.cloud.bigtable.data._helpers import _get_retryable_errors
from google.cloud.bigtable.data._helpers import _get_timeouts
from google.cloud.bigtable.data._helpers import (
_populate_statuses_from_mutations_exception_group,
)

from google.cloud.bigtable.data._helpers import TABLE_DEFAULT

from google.cloud.bigtable.data.mutations import (
Expand All @@ -33,6 +37,9 @@

from google.cloud.bigtable.data._cross_sync import CrossSync

from google.rpc import code_pb2
from google.rpc import status_pb2

if TYPE_CHECKING:
from google.cloud.bigtable.data.mutations import RowMutationEntry

Expand Down Expand Up @@ -223,6 +230,7 @@ def __init__(
batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
batch_retryable_errors: Sequence[type[Exception]]
| TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
_batch_completed_callback: Optional[Callable[list[status_pb2.Status]]] = None,
):
self._operation_timeout, self._attempt_timeout = _get_timeouts(
batch_operation_timeout, batch_attempt_timeout, target
Expand Down Expand Up @@ -269,6 +277,7 @@ def __init__(
self._newest_exceptions: deque[Exception] = deque(
maxlen=self._exception_list_limit
)
self._user_batch_completed_callback = _batch_completed_callback
# clean up on program exit
atexit.register(self._on_exit)

Expand Down Expand Up @@ -380,6 +389,7 @@ async def _execute_mutate_rows(
list of FailedMutationEntryError objects for mutations that failed.
FailedMutationEntryError objects will not contain index information
"""
statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len(batch)
try:
operation = CrossSync._MutateRowsOperation(
self._target.client._gapic_client,
Expand All @@ -391,13 +401,19 @@ async def _execute_mutate_rows(
)
await operation.start()
except MutationsExceptionGroup as e:
_populate_statuses_from_mutations_exception_group(statuses, e)

# strip index information from exceptions, since it is not useful in a batch context
for subexc in e.exceptions:
subexc.index = None
return list(e.exceptions)
finally:
# mark batch as complete in flow control
await self._flow_control.remove_from_flow(batch)

# Call batch done callback with list of statuses.
if self._user_batch_completed_callback:
self._user_batch_completed_callback(statuses)
return []

def _add_exceptions(self, excs: list[Exception]):
Expand Down
59 changes: 59 additions & 0 deletions google/cloud/bigtable/data/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
from google.api_core import retry as retries
from google.api_core.retry import RetryFailureReason
from google.cloud.bigtable.data.exceptions import RetryExceptionGroup
from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup
from google.rpc import code_pb2
from google.rpc import status_pb2


if TYPE_CHECKING:
import grpc
Expand Down Expand Up @@ -224,6 +228,61 @@ def _align_timeouts(operation: float, attempt: float | None) -> tuple[float, flo
return operation, final_attempt


def _populate_statuses_from_mutations_exception_group(
    statuses: list[status_pb2.Status], exc_group: MutationsExceptionGroup
):
    """
    Helper function that populates a list of Status objects with exception information from
    the exception group.

    Entries of ``statuses`` whose index appears in ``exc_group`` are overwritten
    in place with a Status derived from the corresponding failure; all other
    entries are left untouched. Entries with a missing or out-of-range index are
    skipped rather than raising, since indices originate from server responses
    and may also have been stripped (set to None) by batching code.

    Args:
        statuses: The initial list of Status objects, mutated in place
        exc_group: The exception group from a mutate rows operation
    """
    # We exception handle as follows:
    #
    # 1. Each exception in the error group is a FailedMutationEntryError, and its
    #    cause is either a singular exception or a RetryExceptionGroup consisting of
    #    multiple exceptions.
    #
    # 2. In the case of a singular exception, if the error does not have a gRPC status
    #    code, we return a status code of UNKNOWN.
    #
    # 3. In the case of a RetryExceptionGroup, we use the terminal exception in the
    #    exception group and process that.
    for error in exc_group.exceptions:
        # Guard against malformed indices: the index comes from the server response
        # and may be None (stripped by the batcher) or out of bounds. Skipping such
        # entries avoids an IndexError/TypeError crashing the client (DoS).
        if not isinstance(error.index, int) or not 0 <= error.index < len(statuses):
            continue
        cause = error.__cause__
        if isinstance(cause, RetryExceptionGroup):
            # Use the last (terminal) attempt's exception as the representative error.
            statuses[error.index] = _get_status(cause.exceptions[-1])
        else:
            statuses[error.index] = _get_status(cause)
Comment on lines +253 to +258
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

The helper function _populate_statuses_from_mutations_exception_group uses error.index to access the statuses list without validating that the index is an integer or within the bounds of the list. Since error.index is derived from server responses (via FailedMutationEntryError), a malicious or malfunctioning server could provide an out-of-bounds index, causing an IndexError. Additionally, some parts of the library (e.g., in MutationsBatcher) explicitly set error.index to None, which would cause a TypeError if the helper is called on such an exception. This can lead to a Denial of Service (crash) of the client application.

Suggested change
for error in exc_group.exceptions:
cause = error.__cause__
if isinstance(cause, RetryExceptionGroup):
statuses[error.index] = _get_status(cause.exceptions[-1])
else:
statuses[error.index] = _get_status(cause)
for error in exc_group.exceptions:
if isinstance(error.index, int) and 0 <= error.index < len(statuses):
cause = error.__cause__
if isinstance(cause, RetryExceptionGroup):
statuses[error.index] = _get_status(cause.exceptions[-1])
else:
statuses[error.index] = _get_status(cause)



def _get_status(exc: Exception) -> status_pb2.Status:
    """
    Helper function that returns a Status object corresponding to the given exception.

    Args:
        exc: An exception to be converted into a Status.
    Returns:
        status_pb2.Status: A Status proto object.
    """
    # Narrow to the gRPC status code when the exception carries one; otherwise None.
    grpc_code = (
        exc.grpc_status_code
        if isinstance(exc, core_exceptions.GoogleAPICallError)
        else None
    )
    if grpc_code is None:
        # No gRPC status available: fall back to UNKNOWN with the stringified error.
        return status_pb2.Status(
            code=code_pb2.Code.UNKNOWN,
            message=str(exc),
        )
    return status_pb2.Status(
        code=grpc_code.value[0],
        message=exc.message,
        details=exc.details,
    )


def _validate_timeouts(
operation_timeout: float, attempt_timeout: float | None, allow_none: bool = False
):
Expand Down
5 changes: 4 additions & 1 deletion google/cloud/bigtable/data/_sync_autogen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# This file is automatically generated by CrossSync. Do not edit manually.

from __future__ import annotations
from typing import cast, Any, Optional, Set, Sequence, TYPE_CHECKING
from typing import Callable, cast, Any, Optional, Set, Sequence, TYPE_CHECKING
import abc
import time
import warnings
Expand Down Expand Up @@ -87,6 +87,7 @@
if TYPE_CHECKING:
from google.cloud.bigtable.data._helpers import RowKeySamples
from google.cloud.bigtable.data._helpers import ShardedQuery
from google.rpc import status_pb2
from google.cloud.bigtable.data._sync_autogen.mutations_batcher import (
MutationsBatcher,
)
Expand Down Expand Up @@ -1190,6 +1191,7 @@ def mutations_batcher(
batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
batch_retryable_errors: Sequence[type[Exception]]
| TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
_batch_completed_callback: Optional[Callable[list[status_pb2.Status]]] = None,
) -> "MutationsBatcher":
"""Returns a new mutations batcher instance.

Expand Down Expand Up @@ -1224,6 +1226,7 @@ def mutations_batcher(
batch_operation_timeout=batch_operation_timeout,
batch_attempt_timeout=batch_attempt_timeout,
batch_retryable_errors=batch_retryable_errors,
_batch_completed_callback=_batch_completed_callback,
)

def mutate_row(
Expand Down
13 changes: 12 additions & 1 deletion google/cloud/bigtable/data/_sync_autogen/mutations_batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# This file is automatically generated by CrossSync. Do not edit manually.

from __future__ import annotations
from typing import Sequence, TYPE_CHECKING, cast
from typing import Callable, Optional, Sequence, TYPE_CHECKING, cast
import atexit
import warnings
from collections import deque
Expand All @@ -25,10 +25,15 @@
from google.cloud.bigtable.data.exceptions import FailedMutationEntryError
from google.cloud.bigtable.data._helpers import _get_retryable_errors
from google.cloud.bigtable.data._helpers import _get_timeouts
from google.cloud.bigtable.data._helpers import (
_populate_statuses_from_mutations_exception_group,
)
from google.cloud.bigtable.data._helpers import TABLE_DEFAULT
from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT
from google.cloud.bigtable.data.mutations import Mutation
from google.cloud.bigtable.data._cross_sync import CrossSync
from google.rpc import code_pb2
from google.rpc import status_pb2

if TYPE_CHECKING:
from google.cloud.bigtable.data.mutations import RowMutationEntry
Expand Down Expand Up @@ -192,6 +197,7 @@ def __init__(
batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
batch_retryable_errors: Sequence[type[Exception]]
| TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS,
_batch_completed_callback: Optional[Callable[list[status_pb2.Status]]] = None,
):
(self._operation_timeout, self._attempt_timeout) = _get_timeouts(
batch_operation_timeout, batch_attempt_timeout, target
Expand Down Expand Up @@ -233,6 +239,7 @@ def __init__(
self._newest_exceptions: deque[Exception] = deque(
maxlen=self._exception_list_limit
)
self._user_batch_completed_callback = _batch_completed_callback
atexit.register(self._on_exit)

def _timer_routine(self, interval: float | None) -> None:
Expand Down Expand Up @@ -324,6 +331,7 @@ def _execute_mutate_rows(
list[FailedMutationEntryError]:
list of FailedMutationEntryError objects for mutations that failed.
FailedMutationEntryError objects will not contain index information"""
statuses = [status_pb2.Status(code=code_pb2.Code.OK)] * len(batch)
try:
operation = CrossSync._Sync_Impl._MutateRowsOperation(
self._target.client._gapic_client,
Expand All @@ -335,11 +343,14 @@ def _execute_mutate_rows(
)
operation.start()
except MutationsExceptionGroup as e:
_populate_statuses_from_mutations_exception_group(statuses, e)
for subexc in e.exceptions:
subexc.index = None
return list(e.exceptions)
finally:
self._flow_control.remove_from_flow(batch)
if self._user_batch_completed_callback:
self._user_batch_completed_callback(statuses)
return []

def _add_exceptions(self, excs: list[Exception]):
Expand Down
44 changes: 7 additions & 37 deletions google/cloud/bigtable/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import Set
import warnings

from google.api_core.exceptions import GoogleAPICallError
from google.api_core.exceptions import Aborted
from google.api_core.exceptions import DeadlineExceeded
from google.api_core.exceptions import NotFound
Expand All @@ -31,10 +30,10 @@
from google.cloud.bigtable.column_family import _gc_rule_from_pb
from google.cloud.bigtable.column_family import ColumnFamily
from google.cloud.bigtable.data._helpers import TABLE_DEFAULT
from google.cloud.bigtable.data.exceptions import (
RetryExceptionGroup,
MutationsExceptionGroup,
from google.cloud.bigtable.data._helpers import (
_populate_statuses_from_mutations_exception_group,
)
from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup
from google.cloud.bigtable.data.mutations import RowMutationEntry
from google.cloud.bigtable.batcher import MutationsBatcher
from google.cloud.bigtable.batcher import FLUSH_COUNT, MAX_MUTATION_SIZE
Expand Down Expand Up @@ -774,41 +773,12 @@ def mutate_rows(self, rows, retry=DEFAULT_RETRY, timeout=DEFAULT):
retryable_errors=retryable_errors,
)
except MutationsExceptionGroup as mut_exc_group:
# We exception handle as follows:
#
# 1. Each exception in the error group is a FailedMutationEntryError, and its
# cause is either a singular exception or a RetryExceptionGroup consisting of
# multiple exceptions.
#
# 2. In the case of a singular exception, if the error does not have a gRPC status
# code, we return a status code of UNKNOWN.
#
# 3. In the case of a RetryExceptionGroup, we use terminal exception in the exception
# group and process that.
for error in mut_exc_group.exceptions:
cause = error.__cause__
if isinstance(cause, RetryExceptionGroup):
return_statuses[error.index] = self._get_status(
cause.exceptions[-1]
)
else:
return_statuses[error.index] = self._get_status(cause)

return return_statuses

@staticmethod
def _get_status(error):
if isinstance(error, GoogleAPICallError) and error.grpc_status_code is not None:
return status_pb2.Status(
code=error.grpc_status_code.value[0],
message=error.message,
details=error.details,
_populate_statuses_from_mutations_exception_group(
return_statuses,
mut_exc_group,
)

return status_pb2.Status(
code=code_pb2.Code.UNKNOWN,
message=str(error),
)
return return_statuses

def sample_row_keys(self):
"""Read a sample of row keys in the table.
Expand Down
37 changes: 37 additions & 0 deletions tests/system/data/test_system_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,43 @@ async def test_mutations_batcher_timer_flush(self, client, target, temp_rows):
# ensure cell is updated
assert (await self._retrieve_cell_value(target, row_key)) == new_value

@pytest.mark.usefixtures("client")
@pytest.mark.usefixtures("target")
@CrossSync.Retry(
    predicate=retry.if_exception_type(ClientError), initial=1, maximum=5
)
@CrossSync.pytest
async def test_mutations_batcher_completed_callback(
    self, client, target, temp_rows
):
    """
    test batcher with batch completed callback. It should be called when the batcher flushes.
    """
    from google.cloud.bigtable.data.mutations import RowMutationEntry
    from google.rpc import code_pb2, status_pb2

    # Use the stdlib mock (Python 3.3+) rather than the third-party `mock`
    # backport package, which may not be installed.
    from unittest import mock

    callback = mock.Mock()

    new_value = uuid.uuid4().hex.encode()
    row_key, mutation = await self._create_row_and_mutation(
        target, temp_rows, new_value=new_value
    )
    bulk_mutation = RowMutationEntry(row_key, [mutation])
    flush_interval = 0.1
    async with target.mutations_batcher(
        flush_interval=flush_interval, _batch_completed_callback=callback
    ) as batcher:
        await batcher.append(bulk_mutation)
        await CrossSync.yield_to_event_loop()
        assert len(batcher._staged_entries) == 1
        await CrossSync.sleep(flush_interval + 0.1)
        assert len(batcher._staged_entries) == 0
        # single successful entry -> one callback invocation with a single OK status
        callback.assert_called_once_with([status_pb2.Status(code=code_pb2.OK)])
    # ensure cell is updated
    assert (await self._retrieve_cell_value(target, row_key)) == new_value

@pytest.mark.usefixtures("client")
@pytest.mark.usefixtures("target")
@CrossSync.Retry(
Expand Down
Loading
Loading