Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pymongo/asynchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,7 +1101,11 @@ def _apply_to(
read_preference: _ServerMode,
conn: AsyncConnection,
) -> None:
if not conn.supports_sessions:
# getMores must be sent with a session if the cursor was opened with one
operation = next(iter(command))
if not conn.supports_sessions and (
isinstance(self._server_session, _EmptyServerSession) or operation != "getMore"
):
if not self._implicit:
raise ConfigurationError("Sessions are not supported by this MongoDB deployment")
return
Expand Down
6 changes: 5 additions & 1 deletion pymongo/synchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,7 +1097,11 @@ def _apply_to(
read_preference: _ServerMode,
conn: Connection,
) -> None:
if not conn.supports_sessions:
# getMores must be sent with a session if the cursor was opened with one
operation = next(iter(command))
if not conn.supports_sessions and (
isinstance(self._server_session, _EmptyServerSession) or operation != "getMore"
):
if not self._implicit:
raise ConfigurationError("Sessions are not supported by this MongoDB deployment")
return
Expand Down
38 changes: 34 additions & 4 deletions test/asynchronous/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Test the client_session module."""
from __future__ import annotations

import asyncio
import copy
import sys
import time
Expand All @@ -24,8 +23,6 @@
from test.asynchronous.helpers import ExceptionCatchingTask
from typing import Any, Callable, List, Set, Tuple

from pymongo.synchronous.mongo_client import MongoClient

sys.path[0:0] = [""]

from test.asynchronous import (
Expand All @@ -45,7 +42,7 @@

from bson import DBRef
from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket
from pymongo import ASCENDING, AsyncMongoClient, _csot, monitoring
from pymongo import ASCENDING, AsyncMongoClient, monitoring
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.helpers import anext
Expand Down Expand Up @@ -938,6 +935,39 @@ async def test_session_binding_end_session(self):

await s2.end_session()

async def test_getmore_preserves_lsid_after_session_support_lost(self):
listener = OvertCommandListener()
client = await self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
coll = client.pymongo_test.test
await coll.drop()
await coll.insert_many([{"x": i} for i in range(10)])
self.addAsyncCleanup(coll.drop)

async with client.start_session() as s:
cursor = coll.find({}, batch_size=2, session=s)
await anext(cursor)

find_event = next(e for e in listener.started_events if e.command_name == "find")
lsid = find_event.command["lsid"]

# Simulate a node stepping down: mark idle connections as not supporting sessions.
for server in client._topology._servers.values():
for conn in server.pool.conns:
conn.supports_sessions = False

listener.reset()
await cursor.to_list()

getmore_events = [e for e in listener.started_events if e.command_name == "getMore"]
self.assertGreater(len(getmore_events), 0, "expected at least one getMore command")
for event in getmore_events:
self.assertIn(
"lsid", event.command, "getMore must include lsid when session is materialized"
)
self.assertEqual(
lsid, event.command["lsid"], "getMore lsid must match the session lsid from find"
)


class TestCausalConsistency(AsyncUnitTest):
listener: SessionTestListener
Expand Down
38 changes: 34 additions & 4 deletions test/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Test the client_session module."""
from __future__ import annotations

import asyncio
import copy
import sys
import time
Expand All @@ -24,8 +23,6 @@
from test.helpers import ExceptionCatchingTask
from typing import Any, Callable, List, Set, Tuple

from pymongo.synchronous.mongo_client import MongoClient

sys.path[0:0] = [""]

from test import (
Expand All @@ -45,7 +42,7 @@

from bson import DBRef
from gridfs.synchronous.grid_file import GridFS, GridFSBucket
from pymongo import ASCENDING, MongoClient, _csot, monitoring
from pymongo import ASCENDING, MongoClient, monitoring
from pymongo.common import _MAX_END_SESSIONS
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
from pymongo.operations import IndexModel, InsertOne, UpdateOne
Expand Down Expand Up @@ -938,6 +935,39 @@ def test_session_binding_end_session(self):

s2.end_session()

def test_getmore_preserves_lsid_after_session_support_lost(self):
listener = OvertCommandListener()
client = self.rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
coll = client.pymongo_test.test
coll.drop()
coll.insert_many([{"x": i} for i in range(10)])
self.addCleanup(coll.drop)

with client.start_session() as s:
cursor = coll.find({}, batch_size=2, session=s)
next(cursor)

find_event = next(e for e in listener.started_events if e.command_name == "find")
lsid = find_event.command["lsid"]

# Simulate a node stepping down: mark idle connections as not supporting sessions.
for server in client._topology._servers.values():
for conn in server.pool.conns:
conn.supports_sessions = False

listener.reset()
cursor.to_list()

getmore_events = [e for e in listener.started_events if e.command_name == "getMore"]
self.assertGreater(len(getmore_events), 0, "expected at least one getMore command")
for event in getmore_events:
self.assertIn(
"lsid", event.command, "getMore must include lsid when session is materialized"
)
self.assertEqual(
lsid, event.command["lsid"], "getMore lsid must match the session lsid from find"
)


class TestCausalConsistency(UnitTest):
listener: SessionTestListener
Expand Down
Loading