Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 211 additions & 0 deletions test/asynchronous/test_async_network_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Copyright 2026-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Async-only unit tests for network_layer.py."""

from __future__ import annotations

import asyncio
import struct
import sys
from unittest.mock import AsyncMock, MagicMock, patch

sys.path[0:0] = [""]

from test.asynchronous import AsyncUnitTest, unittest

from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.errors import ProtocolError
from pymongo.network_layer import PyMongoProtocol, _async_socket_receive


async def _make_protocol(timeout=None):
protocol = PyMongoProtocol(timeout=timeout)
mock_transport = MagicMock()
mock_transport.is_closing.return_value = False
protocol.transport = mock_transport
return protocol


def _make_header(length, request_id, response_to, op_code):
return struct.pack("<iiii", length, request_id, response_to, op_code)


class TestPyMongoProtocol(AsyncUnitTest):
async def _make_proto_with_header(self, header_bytes, max_size=MAX_MESSAGE_SIZE):
protocol = await _make_protocol()
protocol._max_message_size = max_size
protocol._header = memoryview(bytearray(header_bytes))
return protocol

async def test_initial_timeout_from_constructor(self):
protocol = await _make_protocol(timeout=3.0)
self.assertEqual(protocol.gettimeout, 3.0)

async def test_settimeout_updates_value(self):
protocol = await _make_protocol()
protocol.settimeout(7.5)
self.assertEqual(protocol.gettimeout, 7.5)

async def test_default_timeout_is_none(self):
protocol = await _make_protocol()
self.assertIsNone(protocol.gettimeout)

async def test_normal_op_msg(self):
header = _make_header(length=32, request_id=1, response_to=99, op_code=2013)
protocol = await self._make_proto_with_header(header)
body_len, op_code, response_to, expecting_compression = protocol.process_header()
self.assertEqual(body_len, 16)
self.assertEqual(op_code, 2013)
self.assertEqual(response_to, 99)
self.assertFalse(expecting_compression)

async def test_op_compressed(self):
# OP_COMPRESSED=2012; process_header strips the 9-byte compression sub-header
# (op code + uncompressed size + compressor id), then the 16-byte standard header.
# length=35 → after compression sub-header: 26 → body: 10
header = _make_header(length=35, request_id=1, response_to=0, op_code=2012)
protocol = await self._make_proto_with_header(header)
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
self.assertEqual(body_len, 10)
self.assertEqual(op_code, 2012)
self.assertTrue(expecting_compression)

async def test_op_compressed_length_too_small_raises(self):
header = _make_header(length=25, request_id=1, response_to=0, op_code=2012)
protocol = await self._make_proto_with_header(header)
with self.assertRaises(ProtocolError):
protocol.process_header()

async def test_non_compressed_length_too_small_raises(self):
header = _make_header(length=16, request_id=1, response_to=0, op_code=2013)
protocol = await self._make_proto_with_header(header)
with self.assertRaises(ProtocolError):
protocol.process_header()

async def test_length_exceeds_max_raises(self):
header = _make_header(
length=MAX_MESSAGE_SIZE + 1, request_id=1, response_to=0, op_code=2013
)
protocol = await self._make_proto_with_header(header)
with self.assertRaises(ProtocolError):
protocol.process_header()

async def test_op_reply_op_code(self):
header = _make_header(length=20, request_id=0, response_to=0, op_code=1)
protocol = await self._make_proto_with_header(header)
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
self.assertEqual(body_len, 4)
self.assertEqual(op_code, 1)
self.assertFalse(expecting_compression)

async def test_compression_header_snappy_compressor_id(self):
protocol = await _make_protocol()
# <iiB: little-endian, i32 op code=2013, i32 uncompressed size=0, u8 compressor id=1 (snappy)
data = struct.pack("<iiB", 2013, 0, 1)
protocol._compression_header = memoryview(bytearray(data))
op_code, compressor_id = protocol.process_compression_header()
self.assertEqual(op_code, 2013)
self.assertEqual(compressor_id, 1)

async def test_compression_header_zlib_compressor_id(self):
protocol = await _make_protocol()
data = struct.pack("<iiB", 2013, 0, 2)
protocol._compression_header = memoryview(bytearray(data))
_, compressor_id = protocol.process_compression_header()
self.assertEqual(compressor_id, 2)

async def test_message_complete_resolves_pending_future(self):
protocol = await _make_protocol()
protocol._expecting_header = False
protocol._expecting_compression = False
protocol._message_size = 10
protocol._message = memoryview(bytearray(10))
protocol._message_index = 0
protocol._op_code = 2013
protocol._compressor_id = None
protocol._response_to = 42

future = asyncio.get_running_loop().create_future()
protocol._pending_messages.append(future)

protocol.buffer_updated(10)
self.assertTrue(future.done())
op_code, compressor_id, response_to, _ = future.result()
self.assertEqual(op_code, 2013)
self.assertIsNone(compressor_id)
self.assertEqual(response_to, 42)

async def test_close_aborts_transport(self):
protocol = await _make_protocol()
protocol.close()
self.assertTrue(protocol.transport.abort.called)

async def test_connection_lost_twice_does_not_raise(self):
protocol = await _make_protocol()
protocol.connection_lost(None)
protocol.connection_lost(None)

async def test_close_with_exception_propagates_to_pending(self):
protocol = await _make_protocol()
future = asyncio.get_running_loop().create_future()
protocol._pending_messages.append(future)
exc = OSError("connection reset")
protocol.close(exc)
with self.assertRaisesRegex(OSError, "connection reset"):
await future


class TestAsyncSocketReceive(AsyncUnitTest):
async def test_reads_data_in_multiple_chunks(self):
# Covers the loop in _async_socket_receive that accumulates short reads
# until the requested length has been received.
data = b"abcdefgh"
length = len(data)
chunk1, chunk2 = data[:4], data[4:]
mock_socket = MagicMock()
loop = asyncio.get_running_loop()
calls = 0

async def fake_recv_into(sock, buf):
nonlocal calls
if calls == 0:
buf[: len(chunk1)] = chunk1
calls += 1
return len(chunk1)
buf[: len(chunk2)] = chunk2
calls += 1
return len(chunk2)

with patch.object(loop, "sock_recv_into", new=AsyncMock(side_effect=fake_recv_into)):
result = await _async_socket_receive(mock_socket, length, loop)
self.assertEqual(bytes(result), data)
self.assertEqual(calls, 2)

async def test_raises_on_connection_closed(self):
# Covers the explicit `raise OSError("connection closed")` branch when
# sock_recv_into returns 0.
mock_socket = MagicMock()
loop = asyncio.get_running_loop()

async def fake_recv_into(sock, buf):
return 0

with patch.object(loop, "sock_recv_into", new=AsyncMock(side_effect=fake_recv_into)):
with self.assertRaisesRegex(OSError, "connection closed"):
await _async_socket_receive(mock_socket, 10, loop)


if __name__ == "__main__":
unittest.main()
64 changes: 64 additions & 0 deletions test/asynchronous/test_network_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2026-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for code in network_layer.py shared between sync and async APIs.

Async-only tests live in ``test_async_network_layer.py``.
Comment thread
aclark4life marked this conversation as resolved.
"""

from __future__ import annotations

import sys
from unittest.mock import MagicMock

sys.path[0:0] = [""]

from test.asynchronous import AsyncUnitTest, unittest

from pymongo.network_layer import NetworkingInterfaceBase

_IS_SYNC = False


class TestNetworkingInterfaceBase(AsyncUnitTest):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class doesn't need to be synchro'd since it's testing code common to both APIs.

async def asyncSetUp(self):
self.base = NetworkingInterfaceBase(MagicMock())

def test_gettimeout_raises(self):
with self.assertRaises(NotImplementedError):
_ = self.base.gettimeout

def test_settimeout_raises(self):
with self.assertRaises(NotImplementedError):
self.base.settimeout(1.0)

def test_close_raises(self):
with self.assertRaises(NotImplementedError):
self.base.close()

def test_is_closing_raises(self):
with self.assertRaises(NotImplementedError):
self.base.is_closing()

def test_get_conn_raises(self):
with self.assertRaises(NotImplementedError):
_ = self.base.get_conn

def test_sock_raises(self):
with self.assertRaises(NotImplementedError):
_ = self.base.sock


if __name__ == "__main__":
unittest.main()
64 changes: 64 additions & 0 deletions test/test_network_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2026-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for code in network_layer.py shared between sync and async APIs.

Async-only tests live in ``test_async_network_layer.py``.
Comment thread
aclark4life marked this conversation as resolved.
"""

from __future__ import annotations

import sys
from unittest.mock import MagicMock

sys.path[0:0] = [""]

from test import UnitTest, unittest

from pymongo.network_layer import NetworkingInterfaceBase

_IS_SYNC = True


class TestNetworkingInterfaceBase(UnitTest):
def setUp(self):
self.base = NetworkingInterfaceBase(MagicMock())

def test_gettimeout_raises(self):
with self.assertRaises(NotImplementedError):
_ = self.base.gettimeout

def test_settimeout_raises(self):
with self.assertRaises(NotImplementedError):
self.base.settimeout(1.0)

def test_close_raises(self):
with self.assertRaises(NotImplementedError):
self.base.close()

def test_is_closing_raises(self):
with self.assertRaises(NotImplementedError):
self.base.is_closing()

def test_get_conn_raises(self):
with self.assertRaises(NotImplementedError):
_ = self.base.get_conn

def test_sock_raises(self):
with self.assertRaises(NotImplementedError):
_ = self.base.sock


if __name__ == "__main__":
unittest.main()
3 changes: 3 additions & 0 deletions tools/synchro.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
"SpecRunnerTask": "SpecRunnerThread",
"AsyncMockConnection": "MockConnection",
"AsyncMockPool": "MockPool",
"AsyncMock": "MagicMock",
Comment thread
aclark4life marked this conversation as resolved.
"StopAsyncIteration": "StopIteration",
"create_async_event": "create_event",
"async_create_barrier": "create_barrier",
Expand Down Expand Up @@ -190,6 +191,7 @@ def async_only_test(f: str) -> bool:
"test_async_loop_safety.py",
"test_async_contextvars_reset.py",
"test_async_loop_unblocked.py",
"test_async_network_layer.py",
]


Expand Down Expand Up @@ -251,6 +253,7 @@ def async_only_test(f: str) -> bool:
"test_monitor.py",
"test_monitoring.py",
"test_mongos_load_balancing.py",
"test_network_layer.py",
"test_on_demand_csfle.py",
"test_pooling.py",
"test_raw_bson.py",
Expand Down
Loading