diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 4a54f9eb3f..f429daf8b0 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -20,7 +20,6 @@ import copy import datetime -import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -45,6 +44,11 @@ _raise_bulk_write_error, _Run, ) +from pymongo.command_helpers import ( + _log_command_failed, + _log_command_started, + _log_command_succeeded, +) from pymongo.common import ( validate_is_document_type, validate_ok_for_replace, @@ -57,7 +61,6 @@ OperationFailure, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _DELETE, _INSERT, @@ -252,44 +255,15 @@ async def write_command( ) -> dict[str, Any]: """A proxy for SocketInfo.write_command that handles event publishing.""" cmd[bwc.field] = docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) + _log_command_started(client, bwc.conn, cmd, bwc.db_name, request_id, request_id) if bwc.publish: bwc._start(cmd, request_id, docs) try: reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] duration = datetime.datetime.now() - bwc.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) + _log_command_succeeded( + client, bwc.conn, cmd, bwc.db_name, request_id, request_id, reply, duration + ) if bwc.publish: bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] await client._process_response(reply, bwc.session) # type: ignore[arg-type] @@ -299,24 +273,17 @@ async def write_command( failure: _DocumentOut = exc.details # type: ignore[assignment] else: failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) + _log_command_failed( + client, + bwc.conn, + cmd, + bwc.db_name, + request_id, + request_id, + failure, + duration, + isinstance(exc, OperationFailure), + ) if bwc.publish: bwc._fail(request_id, failure, duration) @@ -337,22 +304,7 @@ async def unack_write( client: AsyncMongoClient[Any], ) -> Optional[Mapping[str, Any]]: """A proxy for AsyncConnection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) + _log_command_started(client, bwc.conn, cmd, bwc.db_name, request_id, request_id) if bwc.publish: cmd = bwc._start(cmd, request_id, docs) try: @@ -363,23 +315,9 @@ async def unack_write( else: # Comply with APM spec. reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) + _log_command_succeeded( + client, bwc.conn, cmd, bwc.db_name, request_id, request_id, reply, duration + ) if bwc.publish: bwc._succeed(request_id, reply, duration) except Exception as exc: @@ -390,24 +328,17 @@ async def unack_write( failure = exc.details # type: ignore[assignment] else: failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) + _log_command_failed( + client, + bwc.conn, + cmd, + bwc.db_name, + request_id, + request_id, + failure, + duration, + isinstance(exc, OperationFailure), + ) if bwc.publish: assert bwc.start_time is not None bwc._fail(request_id, failure, duration) diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 015947d7ef..c29bb64c66 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -20,7 +20,6 @@ import copy import datetime -import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -48,6 +47,11 @@ _merge_command, _throw_client_bulk_write_exception, ) +from pymongo.command_helpers import ( + _log_command_failed, + _log_command_started, + _log_command_succeeded, +) from pymongo.common import ( validate_is_document_type, validate_ok_for_replace, @@ -63,7 +67,6 @@ WaitQueueTimeoutError, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _ClientBulkWriteContext, _convert_client_bulk_exception, @@ -239,44 +242,15 @@ async def write_command( """A proxy for AsyncConnection.write_command that handles event publishing.""" cmd["ops"] = op_docs cmd["nsInfo"] = ns_docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) + _log_command_started(client, bwc.conn, cmd, bwc.db_name, request_id, request_id) if bwc.publish: bwc._start(cmd, request_id, op_docs, ns_docs) try: reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type] duration = datetime.datetime.now() - bwc.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) + _log_command_succeeded( + client, bwc.conn, cmd, bwc.db_name, request_id, request_id, reply, duration + ) if bwc.publish: bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] # Process the response from the server. @@ -287,24 +261,17 @@ async def write_command( failure: _DocumentOut = exc.details # type: ignore[assignment] else: failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) + _log_command_failed( + client, + bwc.conn, + cmd, + bwc.db_name, + request_id, + request_id, + failure, + duration, + isinstance(exc, OperationFailure), + ) if bwc.publish: bwc._fail(request_id, failure, duration) @@ -328,22 +295,7 @@ async def unack_write( client: AsyncMongoClient[Any], ) -> Optional[Mapping[str, Any]]: """A proxy for AsyncConnection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) + _log_command_started(client, bwc.conn, cmd, bwc.db_name, request_id, request_id) if bwc.publish: cmd = bwc._start(cmd, request_id, op_docs, ns_docs) try: @@ -354,23 +306,9 @@ async def unack_write( else: # Comply with APM spec. reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) + _log_command_succeeded( + client, bwc.conn, cmd, bwc.db_name, request_id, request_id, reply, duration + ) if bwc.publish: bwc._succeed(request_id, reply, duration) except Exception as exc: @@ -381,24 +319,17 @@ async def unack_write( failure = exc.details # type: ignore[assignment] else: failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) + _log_command_failed( + client, + bwc.conn, + cmd, + bwc.db_name, + request_id, + request_id, + failure, + duration, + isinstance(exc, OperationFailure), + ) if bwc.publish: assert bwc.start_time is not None bwc._fail(request_id, failure, duration) diff --git a/pymongo/asynchronous/command_cursor.py b/pymongo/asynchronous/command_cursor.py index 5a59c67a15..34194899e1 100644 --- a/pymongo/asynchronous/command_cursor.py +++ b/pymongo/asynchronous/command_cursor.py @@ -189,15 +189,10 @@ async def _send_message(self, operation: _GetMore) -> None: if isinstance(response, PinnedResponse): if not self._sock_mgr: self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type] - if response.from_command: - cursor = response.docs[0]["cursor"] - documents = cursor["nextBatch"] - self._postbatchresumetoken = cursor.get("postBatchResumeToken") - self._id = cursor["id"] - else: - documents = response.docs - assert isinstance(response.data, _OpReply) - self._id = response.data.cursor_id + cursor = response.docs[0]["cursor"] + documents = cursor["nextBatch"] + self._postbatchresumetoken = cursor.get("postBatchResumeToken") + self._id = cursor["id"] if self._id == 0: await self.close() diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index a60c082ade..f7c1671777 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -1020,29 +1020,23 @@ async def _send_message(self, operation: Union[_Query, _GetMore]) -> None: cmd_name = operation.name docs = response.docs - if response.from_command: - if cmd_name != "explain": - cursor = docs[0]["cursor"] - self._id = cursor["id"] - if cmd_name == "find": - documents = cursor["firstBatch"] - # Update the namespace used for future getMore commands. - ns = cursor.get("ns") - if ns: - self._dbname, self._collname = ns.split(".", 1) - else: - documents = cursor["nextBatch"] - self._data = deque(documents) - self._retrieved += len(documents) + if cmd_name != "explain": + cursor = docs[0]["cursor"] + self._id = cursor["id"] + if cmd_name == "find": + documents = cursor["firstBatch"] + # Update the namespace used for future getMore commands. + ns = cursor.get("ns") + if ns: + self._dbname, self._collname = ns.split(".", 1) else: - self._id = 0 - self._data = deque(docs) - self._retrieved += len(docs) + documents = cursor["nextBatch"] + self._data = deque(documents) + self._retrieved += len(documents) else: - assert isinstance(response.data, _OpReply) - self._id = response.data.cursor_id + self._id = 0 self._data = deque(docs) - self._retrieved += response.data.number_returned + self._retrieved += len(docs) if self._id == 0: # Don't wait for garbage collection to call __del__, return the diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py index 5a5dc7fa2c..bd51397ebe 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/network.py @@ -16,10 +16,10 @@ from __future__ import annotations import datetime -import logging from typing import ( TYPE_CHECKING, Any, + Callable, Mapping, MutableMapping, Optional, @@ -30,18 +30,18 @@ from bson import _decode_all_selective from pymongo import _csot, helpers_shared, message +from pymongo.command_helpers import ( + _log_command_failed, + _log_command_started, + _log_command_succeeded, +) from pymongo.compression_support import _NO_COMPRESSION from pymongo.errors import ( NotPrimaryError, OperationFailure, ) -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log -from pymongo.message import _OpMsg +from pymongo.message import _OpMsg, _OpReply from pymongo.monitoring import _is_speculative_authenticate -from pymongo.network_layer import ( - async_receive_message, - async_sendall, -) if TYPE_CHECKING: from bson import CodecOptions @@ -57,6 +57,166 @@ _IS_SYNC = False +_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}} + + +async def _network_command_core( + conn: AsyncConnection, + dbname: str, + spec: MutableMapping[str, Any], + request_id: int, + msg: Optional[bytes], + max_doc_size: int, + codec_options: CodecOptions[_DocumentType], + session: Optional[AsyncClientSession], + client: Optional[AsyncMongoClient[Any]], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + parse_write_concern_error: bool = False, + user_fields: Optional[Mapping[str, Any]] = None, + unacknowledged: bool = False, + more_to_come: bool = False, + unpack_res: Optional[Callable[..., list[_DocumentOut]]] = None, + cursor_id: Optional[int] = None, + orig: Optional[MutableMapping[str, Any]] = None, + speculative_hello: bool = False, +) -> tuple[list[_DocumentOut], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: + """Send/receive a command and return (docs, raw_reply, duration). + + Handles APM logging, send/receive, unpacking, response processing, + and decryption. Both the standard command path and the cursor + (find/getMore) path go through this function. + """ + publish = listeners is not None and listeners.enabled_for_commands + name = next(iter(spec)) + reply: Optional[Union[_OpReply, _OpMsg]] = None + docs: list[_DocumentOut] = [] + + if client is not None: + _log_command_started(client, conn, spec, dbname, request_id, request_id) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_start( + orig if orig is not None else spec, + dbname, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + ) + + try: + if more_to_come: + reply = await conn.receive_message(None) + else: + assert msg is not None + await conn.send_message(msg, max_doc_size) + if unacknowledged: + # Unacknowledged write: fake a successful command response. + docs = [{"ok": 1}] # type: ignore[list-item] + else: + reply = await conn.receive_message(request_id) + + if reply is not None: + conn.more_to_come = reply.more_to_come + if unpack_res is not None: + docs = unpack_res( + reply, + cursor_id, + codec_options, + legacy_response=False, + user_fields=_CURSOR_DOC_FIELDS, + ) + else: + docs = list( + reply.unpack_response(codec_options=codec_options, user_fields=user_fields) + ) + response_doc = docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if client: + await client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + ) + except Exception as exc: + duration = datetime.datetime.now() - start + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = message._convert_exception(exc) + if client is not None: + _log_command_failed( + client, + conn, + spec, + dbname, + request_id, + request_id, + failure, + duration, + isinstance(exc, OperationFailure), + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_failure( + duration, + failure, + name, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + database_name=dbname, + ) + raise + duration = datetime.datetime.now() - start + if client is not None: + _log_command_succeeded( + client, + conn, + spec, + dbname, + request_id, + request_id, + docs[0], + duration, + speculative_hello, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_success( + duration, + docs[0], + name, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + speculative_hello=speculative_hello, + database_name=dbname, + ) + + # Decrypt response. + if client and client._encrypter and reply is not None: + decrypted = await client._encrypter.decrypt(reply.raw_command_response()) + decrypt_fields = _CURSOR_DOC_FIELDS if unpack_res is not None else user_fields + docs = list(_decode_all_selective(decrypted, codec_options, decrypt_fields)) # type: ignore[arg-type] + + return docs, reply, duration + async def command( conn: AsyncConnection, @@ -156,143 +316,30 @@ async def command( request_id, msg, size = message._query( 0, ns, 0, -1, spec, None, codec_options, compression_ctx ) + max_doc_size = 0 if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=spec, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_start( - orig, - dbname, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - await async_sendall(conn.conn.get_conn, msg) - if use_op_msg and unacknowledged: - # Unacknowledged, fake a successful command response. - reply = None - response_doc: _DocumentOut = {"ok": 1} - else: - reply = await async_receive_message(conn, request_id) - conn.more_to_come = reply.more_to_come - unpacked_docs = reply.unpack_response( - codec_options=codec_options, user_fields=user_fields - ) - - response_doc = unpacked_docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if client: - await client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, - ) - except Exception as exc: - duration = datetime.datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = message._convert_exception(exc) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_failure( - duration, - failure, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbname, - ) - raise - duration = datetime.datetime.now() - start - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=response_doc, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - speculative_authenticate="speculativeAuthenticate" in orig, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_success( - duration, - response_doc, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - speculative_hello=speculative_hello, - database_name=dbname, - ) - - if client and client._encrypter and reply: - decrypted = await client._encrypter.decrypt(reply.raw_command_response()) - response_doc = cast( - "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0] - ) - return response_doc # type: ignore[return-value] + docs, _reply, _duration = await _network_command_core( + conn=conn, + dbname=dbname, + spec=spec, + request_id=request_id, + msg=msg, + max_doc_size=max_doc_size, + codec_options=codec_options, + session=session, + client=client, + listeners=listeners, + address=address, + start=start, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + user_fields=user_fields, + unacknowledged=unacknowledged, + orig=orig, + speculative_hello=speculative_hello, + ) + return cast("_DocumentType", docs[0]) diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index f212306174..855cc04a81 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -26,18 +26,14 @@ Union, ) -from bson import _decode_all_selective from pymongo.asynchronous.helpers import _handle_reauth -from pymongo.errors import NotPrimaryError, OperationFailure -from pymongo.helpers_shared import _check_command_response +from pymongo.asynchronous.network import _network_command_core from pymongo.logger import ( - _COMMAND_LOGGER, _SDAM_LOGGER, - _CommandStatusMessage, _debug_log, _SDAMStatusMessage, ) -from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query +from pymongo.message import _GetMore, _Query from pymongo.response import PinnedResponse, Response if TYPE_CHECKING: @@ -55,8 +51,6 @@ _IS_SYNC = False -_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}} - class Server: def __init__( @@ -158,171 +152,46 @@ async def run_operation( :param client: An AsyncMongoClient instance. """ assert listeners is not None - publish = listeners.enabled_for_commands start = datetime.now() - use_cmd = operation.use_command(conn) - more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come - cmd, dbn = await self.operation_to_command(operation, conn, use_cmd) + operation.use_command(conn) + more_to_come = bool(operation.conn_mgr and operation.conn_mgr.more_to_come) + cmd, dbn = await self.operation_to_command(operation, conn, True) if more_to_come: request_id = 0 + msg = None + max_doc_size = 0 else: - message = operation.get_message(read_preference, conn, use_cmd) - request_id, data, max_doc_size = self._split_message(message) - - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - - if publish: - if "$db" not in cmd: - cmd["$db"] = dbn - assert listeners is not None - listeners.publish_command_start( - cmd, - dbn, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - if more_to_come: - reply = await conn.receive_message(None) - else: - await conn.send_message(data, max_doc_size) - reply = await conn.receive_message(request_id) - - # Unpack and check for command errors. - if use_cmd: - user_fields = _CURSOR_DOC_FIELDS - legacy_response = False - else: - user_fields = None - legacy_response = True - docs = unpack_res( - reply, - operation.cursor_id, - operation.codec_options, - legacy_response=legacy_response, - user_fields=user_fields, - ) - if use_cmd: - first = docs[0] - await operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type] - _check_command_response(first, conn.max_wire_version, pool_opts=conn.opts) # type:ignore[has-type] - except Exception as exc: - duration = datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - listeners.publish_command_failure( - duration, - failure, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - raise - duration = datetime.now() - start - # Must publish in find / getMore / explain command response - # format. - if use_cmd: - res = docs[0] - elif operation.name == "explain": - res = docs[0] if docs else {} - else: - res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1} # type: ignore[union-attr] - if operation.name == "find": - res["cursor"]["firstBatch"] = docs - else: - res["cursor"]["nextBatch"] = docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=res, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - if publish: - assert listeners is not None - listeners.publish_command_success( - duration, - res, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - - # Decrypt response. - client = operation.client # type: ignore[assignment] - if client and client._encrypter: - if use_cmd: - decrypted = await client._encrypter.decrypt(reply.raw_command_response()) - docs = _decode_all_selective(decrypted, operation.codec_options, user_fields) - + op_message = operation.get_message(read_preference, conn, True) + request_id, msg, max_doc_size = self._split_message(op_message) + + if listeners.enabled_for_commands and "$db" not in cmd: + cmd["$db"] = dbn + + docs, reply, duration = await _network_command_core( + conn=conn, + dbname=dbn, + spec=cmd, + request_id=request_id, + msg=msg, + max_doc_size=max_doc_size, + codec_options=operation.codec_options, + session=operation.session, # type: ignore[arg-type] + client=client, + listeners=listeners, + address=conn.address, + start=start, + more_to_come=more_to_come, + unpack_res=unpack_res, + cursor_id=operation.cursor_id, + ) + + assert reply is not None response: Response - + client = operation.client # type: ignore[assignment] if client._should_pin_cursor(operation.session) or operation.exhaust: # type: ignore[arg-type] conn.pin_cursor() - if isinstance(reply, _OpMsg): - # In OP_MSG, the server keeps sending only if the - # more_to_come flag is set. - more_to_come = reply.more_to_come - else: - # In OP_REPLY, the server keeps sending until cursor_id is 0. - more_to_come = bool(operation.exhaust and reply.cursor_id) + more_to_come = reply.more_to_come # type: ignore[union-attr] if operation.conn_mgr: operation.conn_mgr.update_exhaust(more_to_come) response = PinnedResponse( @@ -331,7 +200,7 @@ async def run_operation( conn=conn, duration=duration, request_id=request_id, - from_command=use_cmd, + from_command=True, docs=docs, more_to_come=more_to_come, ) @@ -341,7 +210,7 @@ async def run_operation( address=self._description.address, duration=duration, request_id=request_id, - from_command=use_cmd, + from_command=True, docs=docs, ) diff --git a/pymongo/command_helpers.py b/pymongo/command_helpers.py new file mode 100644 index 0000000000..b09bae7975 --- /dev/null +++ b/pymongo/command_helpers.py @@ -0,0 +1,110 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared helpers for command monitoring and logging.""" +from __future__ import annotations + +import datetime +import logging +from typing import Any, Mapping + +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log + + +def _log_command_started( + client: Any, + conn: Any, + cmd: Mapping[str, Any], + dbname: str, + request_id: int, + operation_id: int, +) -> None: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=cmd, + commandName=next(iter(cmd)), + databaseName=dbname, + requestId=request_id, + operationId=operation_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + ) + + +def _log_command_succeeded( + client: Any, + conn: Any, + cmd: Mapping[str, Any], + dbname: str, + request_id: int, + operation_id: int, + reply: Any, + duration: datetime.timedelta, + speculative_authenticate: bool = False, +) -> None: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + databaseName=dbname, + requestId=request_id, + operationId=operation_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + speculative_authenticate=speculative_authenticate, + ) + + +def _log_command_failed( + client: Any, + conn: Any, + cmd: Mapping[str, Any], + dbname: str, + request_id: int, + operation_id: int, + failure: Any, + duration: datetime.timedelta, + is_server_side_error: bool, +) -> None: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=dbname, + requestId=request_id, + operationId=operation_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + isServerSideError=is_server_side_error, + ) diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 22d6a7a76a..73929e0642 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -20,7 +20,6 @@ import copy import datetime -import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -43,6 +42,11 @@ _raise_bulk_write_error, _Run, ) +from pymongo.command_helpers import ( + _log_command_failed, + _log_command_started, + _log_command_succeeded, +) from pymongo.common import ( validate_is_document_type, validate_ok_for_replace, @@ -55,7 +59,6 @@ OperationFailure, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _DELETE, _INSERT, @@ -252,44 +255,15 @@ def write_command( ) -> dict[str, Any]: """A proxy for SocketInfo.write_command that handles event publishing.""" cmd[bwc.field] = docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) + _log_command_started(client, bwc.conn, cmd, bwc.db_name, request_id, request_id) if bwc.publish: bwc._start(cmd, request_id, docs) try: reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] duration = datetime.datetime.now() - bwc.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) + _log_command_succeeded( + client, bwc.conn, cmd, bwc.db_name, request_id, request_id, reply, duration + ) if bwc.publish: bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] client._process_response(reply, bwc.session) # type: ignore[arg-type] @@ -299,24 +273,17 @@ def write_command( failure: _DocumentOut = exc.details # type: ignore[assignment] else: failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) + _log_command_failed( + client, + bwc.conn, + cmd, + bwc.db_name, + request_id, + request_id, + failure, + duration, + isinstance(exc, OperationFailure), + ) if bwc.publish: bwc._fail(request_id, failure, duration) @@ -337,22 +304,7 @@ def unack_write( client: MongoClient[Any], ) -> Optional[Mapping[str, Any]]: """A proxy for Connection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) + _log_command_started(client, bwc.conn, cmd, bwc.db_name, request_id, request_id) if bwc.publish: cmd = bwc._start(cmd, request_id, docs) try: @@ -363,23 +315,9 @@ def unack_write( else: # Comply with APM spec. reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) + _log_command_succeeded( + client, bwc.conn, cmd, bwc.db_name, request_id, request_id, reply, duration + ) if bwc.publish: bwc._succeed(request_id, reply, duration) except Exception as exc: @@ -390,24 +328,17 @@ def unack_write( failure = exc.details # type: ignore[assignment] else: failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) + _log_command_failed( + client, + bwc.conn, + cmd, + bwc.db_name, + request_id, + request_id, + failure, + duration, + isinstance(exc, OperationFailure), + ) if bwc.publish: assert bwc.start_time is not None bwc._fail(request_id, failure, duration) diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 1134594ae9..285701e8fa 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -20,7 +20,6 @@ import copy import datetime -import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -48,6 +47,11 @@ _merge_command, _throw_client_bulk_write_exception, ) +from pymongo.command_helpers import ( + _log_command_failed, + _log_command_started, + _log_command_succeeded, +) from pymongo.common import ( validate_is_document_type, validate_ok_for_replace, @@ -63,7 +67,6 @@ WaitQueueTimeoutError, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _ClientBulkWriteContext, _convert_client_bulk_exception, @@ -239,44 +242,15 @@ def write_command( """A proxy for Connection.write_command that handles event publishing.""" cmd["ops"] = op_docs cmd["nsInfo"] = ns_docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) + _log_command_started(client, bwc.conn, cmd, bwc.db_name, request_id, request_id) if bwc.publish: bwc._start(cmd, request_id, op_docs, ns_docs) try: reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type] duration = datetime.datetime.now() - bwc.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) + _log_command_succeeded( + client, bwc.conn, cmd, bwc.db_name, request_id, request_id, reply, duration + ) if bwc.publish: bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] # Process the response from the server. @@ -287,24 +261,17 @@ def write_command( failure: _DocumentOut = exc.details # type: ignore[assignment] else: failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) + _log_command_failed( + client, + bwc.conn, + cmd, + bwc.db_name, + request_id, + request_id, + failure, + duration, + isinstance(exc, OperationFailure), + ) if bwc.publish: bwc._fail(request_id, failure, duration) @@ -328,22 +295,7 @@ def unack_write( client: MongoClient[Any], ) -> Optional[Mapping[str, Any]]: """A proxy for Connection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) + _log_command_started(client, bwc.conn, cmd, bwc.db_name, request_id, request_id) if bwc.publish: cmd = bwc._start(cmd, request_id, op_docs, ns_docs) try: @@ -354,23 +306,9 @@ def unack_write( else: # Comply with APM spec. reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - ) + _log_command_succeeded( + client, bwc.conn, cmd, bwc.db_name, request_id, request_id, reply, duration + ) if bwc.publish: bwc._succeed(request_id, reply, duration) except Exception as exc: @@ -381,24 +319,17 @@ def unack_write( failure = exc.details # type: ignore[assignment] else: failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=bwc.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=bwc.conn.id, - serverConnectionId=bwc.conn.server_connection_id, - serverHost=bwc.conn.address[0], - serverPort=bwc.conn.address[1], - serviceId=bwc.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) + _log_command_failed( + client, + bwc.conn, + cmd, + bwc.db_name, + request_id, + request_id, + failure, + duration, + isinstance(exc, OperationFailure), + ) if bwc.publish: assert bwc.start_time is not None bwc._fail(request_id, failure, duration) diff --git a/pymongo/synchronous/command_cursor.py b/pymongo/synchronous/command_cursor.py index 34f60c6540..b2023ad5de 100644 --- a/pymongo/synchronous/command_cursor.py +++ b/pymongo/synchronous/command_cursor.py @@ -189,15 +189,10 @@ def _send_message(self, operation: _GetMore) -> None: if isinstance(response, PinnedResponse): if not self._sock_mgr: self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type] - if response.from_command: - cursor = response.docs[0]["cursor"] - documents = cursor["nextBatch"] - self._postbatchresumetoken = cursor.get("postBatchResumeToken") - self._id = cursor["id"] - else: - documents = response.docs - assert isinstance(response.data, _OpReply) - self._id = response.data.cursor_id + cursor = response.docs[0]["cursor"] + documents = cursor["nextBatch"] + self._postbatchresumetoken = cursor.get("postBatchResumeToken") + self._id = cursor["id"] if self._id == 0: self.close() diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index 5a721d8e06..909671dac5 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -1018,29 +1018,23 @@ def _send_message(self, operation: Union[_Query, _GetMore]) -> None: cmd_name = operation.name docs = response.docs - if response.from_command: - if cmd_name != "explain": - cursor = docs[0]["cursor"] - self._id = cursor["id"] - if cmd_name == "find": - documents = cursor["firstBatch"] - # Update the namespace used for future getMore commands. - ns = cursor.get("ns") - if ns: - self._dbname, self._collname = ns.split(".", 1) - else: - documents = cursor["nextBatch"] - self._data = deque(documents) - self._retrieved += len(documents) + if cmd_name != "explain": + cursor = docs[0]["cursor"] + self._id = cursor["id"] + if cmd_name == "find": + documents = cursor["firstBatch"] + # Update the namespace used for future getMore commands. + ns = cursor.get("ns") + if ns: + self._dbname, self._collname = ns.split(".", 1) else: - self._id = 0 - self._data = deque(docs) - self._retrieved += len(docs) + documents = cursor["nextBatch"] + self._data = deque(documents) + self._retrieved += len(documents) else: - assert isinstance(response.data, _OpReply) - self._id = response.data.cursor_id + self._id = 0 self._data = deque(docs) - self._retrieved += response.data.number_returned + self._retrieved += len(docs) if self._id == 0: # Don't wait for garbage collection to call __del__, return the diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py index 7d9bca4d58..1099af24d7 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/network.py @@ -16,10 +16,10 @@ from __future__ import annotations import datetime -import logging from typing import ( TYPE_CHECKING, Any, + Callable, Mapping, MutableMapping, Optional, @@ -30,18 +30,18 @@ from bson import _decode_all_selective from pymongo import _csot, helpers_shared, message +from pymongo.command_helpers import ( + _log_command_failed, + _log_command_started, + _log_command_succeeded, +) from pymongo.compression_support import _NO_COMPRESSION from pymongo.errors import ( NotPrimaryError, OperationFailure, ) -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log -from pymongo.message import _OpMsg +from pymongo.message import _OpMsg, _OpReply from pymongo.monitoring import _is_speculative_authenticate -from pymongo.network_layer import ( - receive_message, - sendall, -) if TYPE_CHECKING: from bson import CodecOptions @@ -57,6 +57,166 @@ _IS_SYNC = True +_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}} + + +def _network_command_core( + conn: Connection, + dbname: str, + spec: MutableMapping[str, Any], + request_id: int, + msg: Optional[bytes], + max_doc_size: int, + codec_options: CodecOptions[_DocumentType], + session: Optional[ClientSession], + client: Optional[MongoClient[Any]], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + parse_write_concern_error: bool = False, + user_fields: Optional[Mapping[str, Any]] = None, + unacknowledged: bool = False, + more_to_come: bool = False, + unpack_res: Optional[Callable[..., list[_DocumentOut]]] = None, + cursor_id: Optional[int] = None, + orig: Optional[MutableMapping[str, Any]] = None, + speculative_hello: bool = False, +) -> tuple[list[_DocumentOut], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: + """Send/receive a command and return (docs, raw_reply, duration). + + Handles APM logging, send/receive, unpacking, response processing, + and decryption. Both the standard command path and the cursor + (find/getMore) path go through this function. + """ + publish = listeners is not None and listeners.enabled_for_commands + name = next(iter(spec)) + reply: Optional[Union[_OpReply, _OpMsg]] = None + docs: list[_DocumentOut] = [] + + if client is not None: + _log_command_started(client, conn, spec, dbname, request_id, request_id) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_start( + orig if orig is not None else spec, + dbname, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + ) + + try: + if more_to_come: + reply = conn.receive_message(None) + else: + assert msg is not None + conn.send_message(msg, max_doc_size) + if unacknowledged: + # Unacknowledged write: fake a successful command response. + docs = [{"ok": 1}] # type: ignore[list-item] + else: + reply = conn.receive_message(request_id) + + if reply is not None: + conn.more_to_come = reply.more_to_come + if unpack_res is not None: + docs = unpack_res( + reply, + cursor_id, + codec_options, + legacy_response=False, + user_fields=_CURSOR_DOC_FIELDS, + ) + else: + docs = list( + reply.unpack_response(codec_options=codec_options, user_fields=user_fields) + ) + response_doc = docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if client: + client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + ) + except Exception as exc: + duration = datetime.datetime.now() - start + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = message._convert_exception(exc) + if client is not None: + _log_command_failed( + client, + conn, + spec, + dbname, + request_id, + request_id, + failure, + duration, + isinstance(exc, OperationFailure), + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_failure( + duration, + failure, + name, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + database_name=dbname, + ) + raise + duration = datetime.datetime.now() - start + if client is not None: + _log_command_succeeded( + client, + conn, + spec, + dbname, + request_id, + request_id, + docs[0], + duration, + speculative_hello, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_success( + duration, + docs[0], + name, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + speculative_hello=speculative_hello, + database_name=dbname, + ) + + # Decrypt response. + if client and client._encrypter and reply is not None: + decrypted = client._encrypter.decrypt(reply.raw_command_response()) + decrypt_fields = _CURSOR_DOC_FIELDS if unpack_res is not None else user_fields + docs = list(_decode_all_selective(decrypted, codec_options, decrypt_fields)) # type: ignore[arg-type] + + return docs, reply, duration + def command( conn: Connection, @@ -156,143 +316,30 @@ def command( request_id, msg, size = message._query( 0, ns, 0, -1, spec, None, codec_options, compression_ctx ) + max_doc_size = 0 if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=spec, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_start( - orig, - dbname, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - sendall(conn.conn.get_conn, msg) - if use_op_msg and unacknowledged: - # Unacknowledged, fake a successful command response. - reply = None - response_doc: _DocumentOut = {"ok": 1} - else: - reply = receive_message(conn, request_id) - conn.more_to_come = reply.more_to_come - unpacked_docs = reply.unpack_response( - codec_options=codec_options, user_fields=user_fields - ) - - response_doc = unpacked_docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if client: - client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, - ) - except Exception as exc: - duration = datetime.datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = message._convert_exception(exc) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_failure( - duration, - failure, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbname, - ) - raise - duration = datetime.datetime.now() - start - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=response_doc, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - speculative_authenticate="speculativeAuthenticate" in orig, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_success( - duration, - response_doc, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - speculative_hello=speculative_hello, - database_name=dbname, - ) - - if client and client._encrypter and reply: - decrypted = client._encrypter.decrypt(reply.raw_command_response()) - response_doc = cast( - "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0] - ) - return response_doc # type: ignore[return-value] + docs, _reply, _duration = _network_command_core( + conn=conn, + dbname=dbname, + spec=spec, + request_id=request_id, + msg=msg, + max_doc_size=max_doc_size, + codec_options=codec_options, + session=session, + client=client, + listeners=listeners, + address=address, + start=start, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + user_fields=user_fields, + unacknowledged=unacknowledged, + orig=orig, + speculative_hello=speculative_hello, + ) + return cast("_DocumentType", docs[0]) diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index f57420918b..a6964518f4 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -26,19 +26,15 @@ Union, ) -from bson import _decode_all_selective -from pymongo.errors import NotPrimaryError, OperationFailure -from pymongo.helpers_shared import _check_command_response from pymongo.logger import ( - _COMMAND_LOGGER, _SDAM_LOGGER, - _CommandStatusMessage, _debug_log, _SDAMStatusMessage, ) -from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query +from pymongo.message import _GetMore, _Query from pymongo.response import PinnedResponse, Response from pymongo.synchronous.helpers import _handle_reauth +from pymongo.synchronous.network import _network_command_core if TYPE_CHECKING: from queue import Queue @@ -55,8 +51,6 @@ _IS_SYNC = True -_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}} - class Server: def __init__( @@ -158,171 +152,46 @@ def run_operation( :param client: A MongoClient instance. """ assert listeners is not None - publish = listeners.enabled_for_commands start = datetime.now() - use_cmd = operation.use_command(conn) - more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come - cmd, dbn = self.operation_to_command(operation, conn, use_cmd) + operation.use_command(conn) + more_to_come = bool(operation.conn_mgr and operation.conn_mgr.more_to_come) + cmd, dbn = self.operation_to_command(operation, conn, True) if more_to_come: request_id = 0 + msg = None + max_doc_size = 0 else: - message = operation.get_message(read_preference, conn, use_cmd) - request_id, data, max_doc_size = self._split_message(message) - - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=cmd, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - - if publish: - if "$db" not in cmd: - cmd["$db"] = dbn - assert listeners is not None - listeners.publish_command_start( - cmd, - dbn, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - if more_to_come: - reply = conn.receive_message(None) - else: - conn.send_message(data, max_doc_size) - reply = conn.receive_message(request_id) - - # Unpack and check for command errors. - if use_cmd: - user_fields = _CURSOR_DOC_FIELDS - legacy_response = False - else: - user_fields = None - legacy_response = True - docs = unpack_res( - reply, - operation.cursor_id, - operation.codec_options, - legacy_response=legacy_response, - user_fields=user_fields, - ) - if use_cmd: - first = docs[0] - operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type] - _check_command_response(first, conn.max_wire_version, pool_opts=conn.opts) # type:ignore[has-type] - except Exception as exc: - duration = datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - listeners.publish_command_failure( - duration, - failure, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - raise - duration = datetime.now() - start - # Must publish in find / getMore / explain command response - # format. - if use_cmd: - res = docs[0] - elif operation.name == "explain": - res = docs[0] if docs else {} - else: - res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1} # type: ignore[union-attr] - if operation.name == "find": - res["cursor"]["firstBatch"] = docs - else: - res["cursor"]["nextBatch"] = docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=res, - commandName=next(iter(cmd)), - databaseName=dbn, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - if publish: - assert listeners is not None - listeners.publish_command_success( - duration, - res, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - - # Decrypt response. - client = operation.client # type: ignore[assignment] - if client and client._encrypter: - if use_cmd: - decrypted = client._encrypter.decrypt(reply.raw_command_response()) - docs = _decode_all_selective(decrypted, operation.codec_options, user_fields) - + op_message = operation.get_message(read_preference, conn, True) + request_id, msg, max_doc_size = self._split_message(op_message) + + if listeners.enabled_for_commands and "$db" not in cmd: + cmd["$db"] = dbn + + docs, reply, duration = _network_command_core( + conn=conn, + dbname=dbn, + spec=cmd, + request_id=request_id, + msg=msg, + max_doc_size=max_doc_size, + codec_options=operation.codec_options, + session=operation.session, # type: ignore[arg-type] + client=client, + listeners=listeners, + address=conn.address, + start=start, + more_to_come=more_to_come, + unpack_res=unpack_res, + cursor_id=operation.cursor_id, + ) + + assert reply is not None response: Response - + client = operation.client # type: ignore[assignment] if client._should_pin_cursor(operation.session) or operation.exhaust: # type: ignore[arg-type] conn.pin_cursor() - if isinstance(reply, _OpMsg): - # In OP_MSG, the server keeps sending only if the - # more_to_come flag is set. - more_to_come = reply.more_to_come - else: - # In OP_REPLY, the server keeps sending until cursor_id is 0. - more_to_come = bool(operation.exhaust and reply.cursor_id) + more_to_come = reply.more_to_come # type: ignore[union-attr] if operation.conn_mgr: operation.conn_mgr.update_exhaust(more_to_come) response = PinnedResponse( @@ -331,7 +200,7 @@ def run_operation( conn=conn, duration=duration, request_id=request_id, - from_command=use_cmd, + from_command=True, docs=docs, more_to_come=more_to_come, ) @@ -341,7 +210,7 @@ def run_operation( address=self._description.address, duration=duration, request_id=request_id, - from_command=use_cmd, + from_command=True, docs=docs, )