diff --git a/cassandra/obj_parser.pyx b/cassandra/obj_parser.pyx
index cf43771dd7..2d366fc5bb 100644
--- a/cassandra/obj_parser.pyx
+++ b/cassandra/obj_parser.pyx
@@ -31,7 +31,10 @@ cdef class ListParser(ColumnParser):
         cdef Py_ssize_t i, rowcount
         rowcount = read_int(reader)
         cdef RowParser rowparser = TupleRowParser()
-        return [rowparser.unpack_row(reader, desc) for i in range(rowcount)]
+        if desc.column_encryption_policy:
+            return [rowparser.unpack_ce_row(reader, desc) for i in range(rowcount)]
+        else:
+            return [rowparser.unpack_row(reader, desc) for i in range(rowcount)]
 
 
 cdef class LazyParser(ColumnParser):
@@ -47,7 +50,10 @@ def parse_rows_lazy(BytesIOReader reader, ParseDesc desc):
     cdef Py_ssize_t i, rowcount
     rowcount = read_int(reader)
     cdef RowParser rowparser = TupleRowParser()
-    return (rowparser.unpack_row(reader, desc) for i in range(rowcount))
+    if desc.column_encryption_policy:
+        return (rowparser.unpack_ce_row(reader, desc) for i in range(rowcount))
+    else:
+        return (rowparser.unpack_row(reader, desc) for i in range(rowcount))
 
 
 cdef class TupleRowParser(RowParser):
@@ -55,9 +61,11 @@ cdef class TupleRowParser(RowParser):
     Parse a single returned row into a tuple of objects:
 
         (obj1, ..., objN)
+    If a CE (column encryption) policy is enabled, use unpack_ce_row();
+    otherwise use unpack_row().
     """
 
-    cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc):
+    cpdef unpack_ce_row(self, BytesIOReader reader, ParseDesc desc):
         assert desc.rowsize >= 0
 
         cdef Buffer buf
@@ -73,9 +81,9 @@ cdef class TupleRowParser(RowParser):
 
             # Deserialize bytes to python object
            deserializer = desc.deserializers[i]
-            coldesc = desc.coldescs[i]
-            uses_ce = ce_policy and ce_policy.contains_column(coldesc)
             try:
+                coldesc = desc.coldescs[i]
+                uses_ce = ce_policy.contains_column(coldesc)
                 if uses_ce:
                     col_type = ce_policy.column_type(coldesc)
                     decrypted_bytes = ce_policy.decrypt(coldesc, to_bytes(&buf))
@@ -84,11 +92,36 @@ cdef class TupleRowParser(RowParser):
                     val = from_binary(deserializer, &newbuf, desc.protocol_version)
                 else:
                     val = from_binary(deserializer, &buf, desc.protocol_version)
+                # Insert new object into tuple
+                tuple_set(res, i, val)
+            except Exception as e:
+                raise DriverException('Failed decoding result column "%s" of type %s: %s' % (desc.colnames[i],
+                                                                                             desc.coltypes[i].cql_parameterized_type(),
+                                                                                             str(e)))
+
+        return res
+
+    cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc):
+        assert desc.rowsize >= 0
+
+        cdef Buffer buf
+        cdef Py_ssize_t i, rowsize = desc.rowsize
+        cdef Deserializer deserializer
+        cdef tuple res = tuple_new(desc.rowsize)
+
+        for i in range(rowsize):
+            # Read the next few bytes
+            get_buf(reader, &buf)
+
+            # Deserialize bytes to python object
+            deserializer = desc.deserializers[i]
+            try:
+                val = from_binary(deserializer, &buf, desc.protocol_version)
+                # Insert new object into tuple
+                tuple_set(res, i, val)
             except Exception as e:
                 raise DriverException('Failed decoding result column "%s" of type %s: %s' % (desc.colnames[i],
                                                                                              desc.coltypes[i].cql_parameterized_type(),
                                                                                              str(e)))
 
-            # Insert new object into tuple
-            tuple_set(res, i, val)
         return res
diff --git a/cassandra/protocol.py b/cassandra/protocol.py
index e574965de8..5f77818c70 100644
--- a/cassandra/protocol.py
+++ b/cassandra/protocol.py
@@ -719,24 +719,37 @@ def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata,
         rows = [self.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
         self.column_names = [c[2] for c in column_metadata]
         self.column_types = [c[3] for c in column_metadata]
-        col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata]
-        def decode_val(val, col_md, col_desc):
-            uses_ce = column_encryption_policy and column_encryption_policy.contains_column(col_desc)
-            col_type = column_encryption_policy.column_type(col_desc) if uses_ce else col_md[3]
-            raw_bytes = column_encryption_policy.decrypt(col_desc, val) if uses_ce else val
-            return col_type.from_binary(raw_bytes, protocol_version)
+        if column_encryption_policy:
+            col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata]
 
-        def decode_row(row):
-            return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs))
+            def decode_val(val, col_md, col_desc):
+                uses_ce = column_encryption_policy.contains_column(col_desc)
+                if uses_ce:
+                    col_type = column_encryption_policy.column_type(col_desc)
+                    raw_bytes = column_encryption_policy.decrypt(col_desc, val)
+                    return col_type.from_binary(raw_bytes, protocol_version)
+                else:
+                    return col_md[3].from_binary(val, protocol_version)
+
+            def decode_row(row):
+                return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs))
+        else:
+            def decode_row(row):
+                return tuple(col_md[3].from_binary(val, protocol_version) for val, col_md in zip(row, column_metadata))
 
         try:
             self.parsed_rows = [decode_row(row) for row in rows]
         except Exception:
+            if not column_encryption_policy:
+                col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata]
             for row in rows:
                 for val, col_md, col_desc in zip(row, column_metadata, col_descs):
                     try:
-                        decode_val(val, col_md, col_desc)
+                        if column_encryption_policy:
+                            decode_val(val, col_md, col_desc)
+                        else:
+                            col_md[3].from_binary(val, protocol_version)
                     except Exception as e:
                         raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2],
                                                                                                      col_md[3].cql_parameterized_type(),
diff --git a/cassandra/row_parser.pyx b/cassandra/row_parser.pyx
index 88277a4593..1308f5b2ce 100644
--- a/cassandra/row_parser.pyx
+++ b/cassandra/row_parser.pyx
@@ -44,7 +44,11 @@ def make_recv_results_rows(ColumnParser colparser):
             reader.buf_ptr = reader.buf
             reader.pos = 0
             rowcount = read_int(reader)
-            for i in range(rowcount):
-                rowparser.unpack_row(reader, desc)
+            if desc.column_encryption_policy:
+                for i in range(rowcount):
+                    rowparser.unpack_ce_row(reader, desc)
+            else:
+                for i in range(rowcount):
+                    rowparser.unpack_row(reader, desc)
 
     return recv_results_rows
diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py
index 9704811239..ea12fa7b5a 100644
--- a/tests/unit/test_protocol.py
+++ b/tests/unit/test_protocol.py
@@ -12,22 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import io
 import unittest
 
 from unittest.mock import Mock
 
 from cassandra import ProtocolVersion, UnsupportedOperation
+from cassandra.cqltypes import Int32Type, UTF8Type
 from cassandra.protocol import (
     PrepareMessage, QueryMessage, ExecuteMessage, UnsupportedOperation,
     _PAGING_OPTIONS_FLAG, _WITH_SERIAL_CONSISTENCY_FLAG,
     _PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG,
-    BatchMessage
+    BatchMessage,
+    ResultMessage, RESULT_KIND_ROWS
 )
 from cassandra.query import BatchType
-from cassandra.marshal import uint32_unpack
+from cassandra.marshal import uint32_unpack, int32_pack
 from cassandra.cluster import ContinuousPagingOptions
 
 import pytest
+from cassandra.policies import ColDesc
 
 
 class MessageTest(unittest.TestCase):
@@ -189,3 +193,130 @@ def test_batch_message_with_keyspace(self):
             (b'\x00\x03',), (b'\x00\x00\x00\x80',), (b'\x00\x02',), (b'ks',))
         )
 
+
+class ResultTest(unittest.TestCase):
+    """
+    Tests to verify the optimization of column_encryption_policy checks
+    in recv_results_rows. The optimization checks whether the policy exists once
+    per result message, avoiding the redundant 'column_encryption_policy and ...'
+    check for every value.
+    """
+
+    def _create_mock_result_metadata(self):
+        """Create mock result metadata for testing"""
+        return [
+            ('keyspace1', 'table1', 'col1', Int32Type),
+            ('keyspace1', 'table1', 'col2', UTF8Type),
+        ]
+
+    def _create_mock_result_message(self):
+        """Create a mock result message with data"""
+        msg = ResultMessage(kind=RESULT_KIND_ROWS)
+        msg.column_metadata = self._create_mock_result_metadata()
+        msg.recv_results_metadata = Mock()
+        msg.recv_row = Mock(side_effect=[
+            [int32_pack(42), b'hello'],
+            [int32_pack(100), b'world'],
+        ])
+        return msg
+
+    def _create_mock_stream(self):
+        """Create a mock stream for reading rows"""
+        # Pack rowcount (2 rows)
+        data = int32_pack(2)
+        return io.BytesIO(data)
+
+    def test_decode_without_encryption_policy(self):
+        """
+        Test that decoding works correctly without a column encryption policy.
+        This should use the optimized simple path.
+        """
+        msg = self._create_mock_result_message()
+        f = self._create_mock_stream()
+
+        msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, None)
+
+        # Verify results
+        self.assertEqual(len(msg.parsed_rows), 2)
+        self.assertEqual(msg.parsed_rows[0][0], 42)
+        self.assertEqual(msg.parsed_rows[0][1], 'hello')
+        self.assertEqual(msg.parsed_rows[1][0], 100)
+        self.assertEqual(msg.parsed_rows[1][1], 'world')
+
+    def test_decode_with_encryption_policy_no_encrypted_columns(self):
+        """
+        Test that decoding works with an encryption policy when no columns are encrypted.
+        """
+        msg = self._create_mock_result_message()
+        f = self._create_mock_stream()
+
+        # Create mock encryption policy that has no encrypted columns
+        mock_policy = Mock()
+        mock_policy.contains_column = Mock(return_value=False)
+
+        msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)
+
+        # Verify results
+        self.assertEqual(len(msg.parsed_rows), 2)
+        self.assertEqual(msg.parsed_rows[0][0], 42)
+        self.assertEqual(msg.parsed_rows[0][1], 'hello')
+
+        # Verify contains_column was called for each value (but policy existence check happens once)
+        # Should be called 4 times (2 rows × 2 columns)
+        self.assertEqual(mock_policy.contains_column.call_count, 4)
+
+    def test_decode_with_encryption_policy_with_encrypted_column(self):
+        """
+        Test that decoding works with an encryption policy when one column is encrypted.
+        """
+        msg = self._create_mock_result_message()
+        f = self._create_mock_stream()
+
+        # Create mock encryption policy where first column is encrypted
+        mock_policy = Mock()
+        def contains_column_side_effect(col_desc):
+            return col_desc.col == 'col1'
+        mock_policy.contains_column = Mock(side_effect=contains_column_side_effect)
+        mock_policy.column_type = Mock(return_value=Int32Type)
+        mock_policy.decrypt = Mock(side_effect=lambda col_desc, val: val)
+
+        msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)
+
+        # Verify results
+        self.assertEqual(len(msg.parsed_rows), 2)
+        self.assertEqual(msg.parsed_rows[0][0], 42)
+        self.assertEqual(msg.parsed_rows[0][1], 'hello')
+
+        # Verify contains_column was called for each value (but policy existence check happens once)
+        # Should be called 4 times (2 rows × 2 columns)
+        self.assertEqual(mock_policy.contains_column.call_count, 4)
+
+        # Verify decrypt was called for each encrypted value (2 rows * 1 encrypted column)
+        self.assertEqual(mock_policy.decrypt.call_count, 2)
+
+    def test_optimization_efficiency(self):
+        """
+        Verify that the optimization checks policy existence once per result message.
+        The key optimization is checking 'if column_encryption_policy:' once,
+        rather than 'column_encryption_policy and ...' for every value.
+        """
+        msg = self._create_mock_result_message()
+
+        # Create more rows to make the check pattern clear
+        msg.recv_row = Mock(side_effect=[
+            [int32_pack(i), f'text{i}'.encode()] for i in range(100)
+        ])
+
+        # Create mock stream with 100 rows
+        f = io.BytesIO(int32_pack(100))
+
+        mock_policy = Mock()
+        mock_policy.contains_column = Mock(return_value=False)
+
+        msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)
+
+        # With optimization: policy existence checked once, contains_column called per value
+        # = 100 rows * 2 columns = 200 calls to contains_column
+        # The key is we avoid checking 'column_encryption_policy and ...' 200 times
+        self.assertEqual(mock_policy.contains_column.call_count, 200,
+                         "contains_column should be called for each value when policy exists")
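Note for reviewers (not part of the patch): the decode path touched here only ever calls contains_column(), column_type() and decrypt() on the policy object, which is exactly what the unit tests above mock. The sketch below is a minimal, hypothetical NoopColumnEncryptionPolicy that satisfies that interface, handy for exercising both branches end to end; in a real deployment you would use the driver's shipped policy (e.g. AES256ColumnEncryptionPolicy) instead, and the column_encryption_policy keyword on Cluster is assumed here as in recent driver versions.

```python
# Illustrative sketch only -- NoopColumnEncryptionPolicy is not a driver class.
from cassandra.cluster import Cluster
from cassandra.cqltypes import Int32Type
from cassandra.policies import ColDesc


class NoopColumnEncryptionPolicy(object):
    """Marks a single column as 'encrypted' but passes its bytes through unchanged."""

    def __init__(self, col_desc):
        self._col_desc = col_desc

    def contains_column(self, col_desc):
        # Called once per value on the CE path; with this change the driver no
        # longer re-checks whether a policy exists at all for every value.
        return col_desc == self._col_desc

    def column_type(self, col_desc):
        # Type used to deserialize the (here, unmodified) column bytes.
        return Int32Type

    def decrypt(self, col_desc, encrypted_bytes):
        return encrypted_bytes  # a real policy would decrypt here


# Passing a policy selects the CE decode path; passing None keeps the fast path.
policy = NoopColumnEncryptionPolicy(ColDesc('ks1', 'table1', 'col1'))
cluster = Cluster(column_encryption_policy=policy)  # kwarg assumed, per recent drivers
```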