diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 31e6ceb386..9fbd10c8e9 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -1101,7 +1101,11 @@ def _apply_to( read_preference: _ServerMode, conn: AsyncConnection, ) -> None: - if not conn.supports_sessions: + # getMores must be sent with a session if the cursor was opened with one + operation = next(iter(command)) + if not conn.supports_sessions and ( + isinstance(self._server_session, _EmptyServerSession) or operation != "getMore" + ): if not self._implicit: raise ConfigurationError("Sessions are not supported by this MongoDB deployment") return diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 3165dd52b7..7563850843 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -1097,7 +1097,11 @@ def _apply_to( read_preference: _ServerMode, conn: Connection, ) -> None: - if not conn.supports_sessions: + # getMores must be sent with a session if the cursor was opened with one + operation = next(iter(command)) + if not conn.supports_sessions and ( + isinstance(self._server_session, _EmptyServerSession) or operation != "getMore" + ): if not self._implicit: raise ConfigurationError("Sessions are not supported by this MongoDB deployment") return diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 404a69fdee..13ce578671 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -15,7 +15,6 @@ """Test the client_session module.""" from __future__ import annotations -import asyncio import copy import sys import time @@ -24,8 +23,6 @@ from test.asynchronous.helpers import ExceptionCatchingTask from typing import Any, Callable, List, Set, Tuple -from pymongo.synchronous.mongo_client import MongoClient - sys.path[0:0] = [""] from test.asynchronous import ( @@ -45,7 +42,7 @@ from bson import DBRef from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket -from pymongo import ASCENDING, AsyncMongoClient, _csot, monitoring +from pymongo import ASCENDING, AsyncMongoClient, monitoring from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.cursor import AsyncCursor from pymongo.asynchronous.helpers import anext @@ -938,6 +935,39 @@ async def test_session_binding_end_session(self): await s2.end_session() + async def test_getmore_preserves_lsid_after_session_support_lost(self): + listener = OvertCommandListener() + client = await self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1) + coll = client.pymongo_test.test + await coll.drop() + await coll.insert_many([{"x": i} for i in range(10)]) + self.addAsyncCleanup(coll.drop) + + async with client.start_session() as s: + cursor = coll.find({}, batch_size=2, session=s) + await anext(cursor) + + find_event = next(e for e in listener.started_events if e.command_name == "find") + lsid = find_event.command["lsid"] + + # Simulate a node stepping down: mark idle connections as not supporting sessions. + for server in client._topology._servers.values(): + for conn in server.pool.conns: + conn.supports_sessions = False + + listener.reset() + await cursor.to_list() + + getmore_events = [e for e in listener.started_events if e.command_name == "getMore"] + self.assertGreater(len(getmore_events), 0, "expected at least one getMore command") + for event in getmore_events: + self.assertIn( + "lsid", event.command, "getMore must include lsid when session is materialized" + ) + self.assertEqual( + lsid, event.command["lsid"], "getMore lsid must match the session lsid from find" + ) + class TestCausalConsistency(AsyncUnitTest): listener: SessionTestListener diff --git a/test/test_session.py b/test/test_session.py index 3963f88da0..cf071df49a 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -15,7 +15,6 @@ """Test the client_session module.""" from __future__ import annotations -import asyncio import copy import sys import time @@ -24,8 +23,6 @@ from test.helpers import ExceptionCatchingTask from typing import Any, Callable, List, Set, Tuple -from pymongo.synchronous.mongo_client import MongoClient - sys.path[0:0] = [""] from test import ( @@ -45,7 +42,7 @@ from bson import DBRef from gridfs.synchronous.grid_file import GridFS, GridFSBucket -from pymongo import ASCENDING, MongoClient, _csot, monitoring +from pymongo import ASCENDING, MongoClient, monitoring from pymongo.common import _MAX_END_SESSIONS from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure from pymongo.operations import IndexModel, InsertOne, UpdateOne @@ -938,6 +935,39 @@ def test_session_binding_end_session(self): s2.end_session() + def test_getmore_preserves_lsid_after_session_support_lost(self): + listener = OvertCommandListener() + client = self.rs_or_single_client(event_listeners=[listener], maxPoolSize=1) + coll = client.pymongo_test.test + coll.drop() + coll.insert_many([{"x": i} for i in range(10)]) + self.addCleanup(coll.drop) + + with client.start_session() as s: + cursor = coll.find({}, batch_size=2, session=s) + next(cursor) + + find_event = next(e for e in listener.started_events if e.command_name == "find") + lsid = find_event.command["lsid"] + + # Simulate a node stepping down: mark idle connections as not supporting sessions. + for server in client._topology._servers.values(): + for conn in server.pool.conns: + conn.supports_sessions = False + + listener.reset() + cursor.to_list() + + getmore_events = [e for e in listener.started_events if e.command_name == "getMore"] + self.assertGreater(len(getmore_events), 0, "expected at least one getMore command") + for event in getmore_events: + self.assertIn( + "lsid", event.command, "getMore must include lsid when session is materialized" + ) + self.assertEqual( + lsid, event.command["lsid"], "getMore lsid must match the session lsid from find" + ) + class TestCausalConsistency(UnitTest): listener: SessionTestListener