Skip to content

Commit 76c09ae

Browse files
committed
(improvement)Optimize column_encryption_policy checks: tests
Add tests, respond to review feedback on added tests. Signed-off-by: Yaniv Kaul <yaniv.kaul@scylladb.com>
1 parent ad8b6f9 commit 76c09ae

2 files changed

Lines changed: 133 additions & 132 deletions

File tree

tests/unit/test_protocol.py

Lines changed: 133 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,26 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import io
1516
import unittest
1617

1718
from unittest.mock import Mock
1819

1920
from cassandra import ProtocolVersion, UnsupportedOperation
21+
from cassandra.cqltypes import Int32Type, UTF8Type
2022
from cassandra.protocol import (
2123
PrepareMessage, QueryMessage, ExecuteMessage, UnsupportedOperation,
2224
_PAGING_OPTIONS_FLAG, _WITH_SERIAL_CONSISTENCY_FLAG,
2325
_PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG,
24-
BatchMessage
26+
BatchMessage,
27+
ResultMessage, RESULT_KIND_ROWS
2528
)
2629
from cassandra.query import BatchType
27-
from cassandra.marshal import uint32_unpack
30+
from cassandra.marshal import uint32_unpack, int32_pack
2831
from cassandra.cluster import ContinuousPagingOptions
2932
import pytest
3033

34+
from cassandra.policies import ColDesc
3135

3236
class MessageTest(unittest.TestCase):
3337

@@ -189,3 +193,130 @@ def test_batch_message_with_keyspace(self):
189193
(b'\x00\x03',),
190194
(b'\x00\x00\x00\x80',), (b'\x00\x02',), (b'ks',))
191195
)
196+
197+
class ResultTest(unittest.TestCase):
198+
"""
199+
Tests to verify the optimization of column_encryption_policy checks
200+
in recv_results_rows. The optimization checks if the policy exists once
201+
per result message, avoiding the redundant 'column_encryption_policy and ...'
202+
check for every value.
203+
"""
204+
205+
def _create_mock_result_metadata(self):
206+
"""Create mock result metadata for testing"""
207+
return [
208+
('keyspace1', 'table1', 'col1', Int32Type),
209+
('keyspace1', 'table1', 'col2', UTF8Type),
210+
]
211+
212+
def _create_mock_result_message(self):
213+
"""Create a mock result message with data"""
214+
msg = ResultMessage(kind=RESULT_KIND_ROWS)
215+
msg.column_metadata = self._create_mock_result_metadata()
216+
msg.recv_results_metadata = Mock()
217+
msg.recv_row = Mock(side_effect=[
218+
[int32_pack(42), b'hello'],
219+
[int32_pack(100), b'world'],
220+
])
221+
return msg
222+
223+
def _create_mock_stream(self):
224+
"""Create a mock stream for reading rows"""
225+
# Pack rowcount (2 rows)
226+
data = int32_pack(2)
227+
return io.BytesIO(data)
228+
229+
def test_decode_without_encryption_policy(self):
230+
"""
231+
Test that decoding works correctly without column encryption policy.
232+
This should use the optimized simple path.
233+
"""
234+
msg = self._create_mock_result_message()
235+
f = self._create_mock_stream()
236+
237+
msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, None)
238+
239+
# Verify results
240+
self.assertEqual(len(msg.parsed_rows), 2)
241+
self.assertEqual(msg.parsed_rows[0][0], 42)
242+
self.assertEqual(msg.parsed_rows[0][1], 'hello')
243+
self.assertEqual(msg.parsed_rows[1][0], 100)
244+
self.assertEqual(msg.parsed_rows[1][1], 'world')
245+
246+
def test_decode_with_encryption_policy_no_encrypted_columns(self):
247+
"""
248+
Test that decoding works with encryption policy when no columns are encrypted.
249+
"""
250+
msg = self._create_mock_result_message()
251+
f = self._create_mock_stream()
252+
253+
# Create mock encryption policy that has no encrypted columns
254+
mock_policy = Mock()
255+
mock_policy.contains_column = Mock(return_value=False)
256+
257+
msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)
258+
259+
# Verify results
260+
self.assertEqual(len(msg.parsed_rows), 2)
261+
self.assertEqual(msg.parsed_rows[0][0], 42)
262+
self.assertEqual(msg.parsed_rows[0][1], 'hello')
263+
264+
# Verify contains_column was called for each value (but policy existence check happens once)
265+
# Should be called 4 times (2 rows × 2 columns)
266+
self.assertEqual(mock_policy.contains_column.call_count, 4)
267+
268+
def test_decode_with_encryption_policy_with_encrypted_column(self):
269+
"""
270+
Test that decoding works with encryption policy when one column is encrypted.
271+
"""
272+
msg = self._create_mock_result_message()
273+
f = self._create_mock_stream()
274+
275+
# Create mock encryption policy where first column is encrypted
276+
mock_policy = Mock()
277+
def contains_column_side_effect(col_desc):
278+
return col_desc.col == 'col1'
279+
mock_policy.contains_column = Mock(side_effect=contains_column_side_effect)
280+
mock_policy.column_type = Mock(return_value=Int32Type)
281+
mock_policy.decrypt = Mock(side_effect=lambda col_desc, val: val)
282+
283+
msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)
284+
285+
# Verify results
286+
self.assertEqual(len(msg.parsed_rows), 2)
287+
self.assertEqual(msg.parsed_rows[0][0], 42)
288+
self.assertEqual(msg.parsed_rows[0][1], 'hello')
289+
290+
# Verify contains_column was called for each value (but policy existence check happens once)
291+
# Should be called 4 times (2 rows × 2 columns)
292+
self.assertEqual(mock_policy.contains_column.call_count, 4)
293+
294+
# Verify decrypt was called for each encrypted value (2 rows * 1 encrypted column)
295+
self.assertEqual(mock_policy.decrypt.call_count, 2)
296+
297+
def test_optimization_efficiency(self):
298+
"""
299+
Verify that the optimization checks policy existence once per result message.
300+
The key optimization is checking 'if column_encryption_policy:' once,
301+
rather than 'column_encryption_policy and ...' for every value.
302+
"""
303+
msg = self._create_mock_result_message()
304+
305+
# Create more rows to make the check pattern clear
306+
msg.recv_row = Mock(side_effect=[
307+
[int32_pack(i), f'text{i}'.encode()] for i in range(100)
308+
])
309+
310+
# Create mock stream with 100 rows
311+
f = io.BytesIO(int32_pack(100))
312+
313+
mock_policy = Mock()
314+
mock_policy.contains_column = Mock(return_value=False)
315+
316+
msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)
317+
318+
# With optimization: policy existence checked once, contains_column called per value
319+
# = 100 rows * 2 columns = 200 calls to contains_column
320+
# The key is we avoid checking 'column_encryption_policy and ...' 200 times
321+
self.assertEqual(mock_policy.contains_column.call_count, 200,
322+
"contains_column should be called for each value when policy exists")

tests/unit/test_protocol_decode_optimization.py

Lines changed: 0 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -23,133 +23,3 @@
2323
from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS
2424

2525

26-
class DecodeOptimizationTest(unittest.TestCase):
27-
"""
28-
Tests to verify the optimization of column_encryption_policy checks
29-
in recv_results_rows. The optimization checks if the policy exists once
30-
per result message, avoiding the redundant 'column_encryption_policy and ...'
31-
check for every value.
32-
"""
33-
34-
def _create_mock_result_metadata(self):
35-
"""Create mock result metadata for testing"""
36-
return [
37-
('keyspace1', 'table1', 'col1', Int32Type),
38-
('keyspace1', 'table1', 'col2', UTF8Type),
39-
]
40-
41-
def _create_mock_result_message(self):
42-
"""Create a mock result message with data"""
43-
msg = ResultMessage(kind=RESULT_KIND_ROWS)
44-
msg.column_metadata = self._create_mock_result_metadata()
45-
msg.recv_results_metadata = Mock()
46-
msg.recv_row = Mock(side_effect=[
47-
[int32_pack(42), b'hello'],
48-
[int32_pack(100), b'world'],
49-
])
50-
return msg
51-
52-
def _create_mock_stream(self):
53-
"""Create a mock stream for reading rows"""
54-
# Pack rowcount (2 rows)
55-
data = int32_pack(2)
56-
return io.BytesIO(data)
57-
58-
def test_decode_without_encryption_policy(self):
59-
"""
60-
Test that decoding works correctly without column encryption policy.
61-
This should use the optimized simple path.
62-
"""
63-
msg = self._create_mock_result_message()
64-
f = self._create_mock_stream()
65-
66-
msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, None)
67-
68-
# Verify results
69-
self.assertEqual(len(msg.parsed_rows), 2)
70-
self.assertEqual(msg.parsed_rows[0][0], 42)
71-
self.assertEqual(msg.parsed_rows[0][1], 'hello')
72-
self.assertEqual(msg.parsed_rows[1][0], 100)
73-
self.assertEqual(msg.parsed_rows[1][1], 'world')
74-
75-
def test_decode_with_encryption_policy_no_encrypted_columns(self):
76-
"""
77-
Test that decoding works with encryption policy when no columns are encrypted.
78-
"""
79-
msg = self._create_mock_result_message()
80-
f = self._create_mock_stream()
81-
82-
# Create mock encryption policy that has no encrypted columns
83-
mock_policy = Mock()
84-
mock_policy.contains_column = Mock(return_value=False)
85-
86-
msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)
87-
88-
# Verify results
89-
self.assertEqual(len(msg.parsed_rows), 2)
90-
self.assertEqual(msg.parsed_rows[0][0], 42)
91-
self.assertEqual(msg.parsed_rows[0][1], 'hello')
92-
93-
# Verify contains_column was called for each value (but policy existence check happens once)
94-
# Should be called 4 times (2 rows × 2 columns)
95-
self.assertEqual(mock_policy.contains_column.call_count, 4)
96-
97-
def test_decode_with_encryption_policy_with_encrypted_column(self):
98-
"""
99-
Test that decoding works with encryption policy when one column is encrypted.
100-
"""
101-
msg = self._create_mock_result_message()
102-
f = self._create_mock_stream()
103-
104-
# Create mock encryption policy where first column is encrypted
105-
mock_policy = Mock()
106-
def contains_column_side_effect(col_desc):
107-
return col_desc.col == 'col1'
108-
mock_policy.contains_column = Mock(side_effect=contains_column_side_effect)
109-
mock_policy.column_type = Mock(return_value=Int32Type)
110-
mock_policy.decrypt = Mock(side_effect=lambda col_desc, val: val)
111-
112-
msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)
113-
114-
# Verify results
115-
self.assertEqual(len(msg.parsed_rows), 2)
116-
self.assertEqual(msg.parsed_rows[0][0], 42)
117-
self.assertEqual(msg.parsed_rows[0][1], 'hello')
118-
119-
# Verify contains_column was called for each value (but policy existence check happens once)
120-
# Should be called 4 times (2 rows × 2 columns)
121-
self.assertEqual(mock_policy.contains_column.call_count, 4)
122-
123-
# Verify decrypt was called for each encrypted value (2 rows * 1 encrypted column)
124-
self.assertEqual(mock_policy.decrypt.call_count, 2)
125-
126-
def test_optimization_efficiency(self):
127-
"""
128-
Verify that the optimization checks policy existence once per result message.
129-
The key optimization is checking 'if column_encryption_policy:' once,
130-
rather than 'column_encryption_policy and ...' for every value.
131-
"""
132-
msg = self._create_mock_result_message()
133-
134-
# Create more rows to make the check pattern clear
135-
msg.recv_row = Mock(side_effect=[
136-
[int32_pack(i), f'text{i}'.encode()] for i in range(100)
137-
])
138-
139-
# Create mock stream with 100 rows
140-
f = io.BytesIO(int32_pack(100))
141-
142-
mock_policy = Mock()
143-
mock_policy.contains_column = Mock(return_value=False)
144-
145-
msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)
146-
147-
# With optimization: policy existence checked once, contains_column called per value
148-
# = 100 rows * 2 columns = 200 calls to contains_column
149-
# The key is we avoid checking 'column_encryption_policy and ...' 200 times
150-
self.assertEqual(mock_policy.contains_column.call_count, 200,
151-
"contains_column should be called for each value when policy exists")
152-
153-
154-
if __name__ == '__main__':
155-
unittest.main()

0 commit comments

Comments
 (0)