From fcb0244df3af3c66c9dc3cab983154465e788b78 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Mon, 5 Jan 2026 18:06:21 +0000 Subject: [PATCH] fix: add StatelessModeNotSupported exception and improve tests Address review feedback from PR #1827: - Create StatelessModeNotSupported exception that inherits from RuntimeError - Replace _require_stateful_mode() helper with inline checks in each method - Add match= parameter to pytest.raises() calls for more specific error matching - Create stateless_session and stateful_session fixtures to reduce test boilerplate - Add test for exception's method attribute Claude-Generated-By: Claude Code (cli/claude-opus-4-5=100%) Claude-Steers: 7 Claude-Permission-Prompts: 26 Claude-Escapes: 1 --- src/mcp/server/session.py | 39 ++-- src/mcp/shared/exceptions.py | 18 ++ tests/server/test_stateless_mode.py | 290 ++++++++++------------------ 3 files changed, 132 insertions(+), 215 deletions(-) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 4a7b1768b..62762dfee 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -49,6 +49,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures from mcp.server.models import InitializationOptions from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages +from mcp.shared.exceptions import StatelessModeNotSupported from mcp.shared.experimental.tasks.capabilities import check_tasks_capability from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY from mcp.shared.message import ServerMessageMetadata, SessionMessage @@ -157,26 +158,6 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: return True - def _require_stateful_mode(self, feature_name: str) -> None: - """Raise an error if trying to use a feature that requires stateful mode. - - Server-to-client requests (sampling, elicitation, list_roots) are not - supported in stateless HTTP mode because there is no persistent connection - for bidirectional communication. - - Args: - feature_name: Name of the feature being used (for error message) - - Raises: - RuntimeError: If the session is in stateless mode - """ - if self._stateless: - raise RuntimeError( - f"Cannot use {feature_name} in stateless HTTP mode. " - "Stateless mode does not support server-to-client requests. " - "Use stateful mode (stateless_http=False) to enable this feature." - ) - async def _receive_loop(self) -> None: async with self._incoming_message_stream_writer: await super()._receive_loop() @@ -332,9 +313,10 @@ async def create_message( Raises: McpError: If tools are provided but client doesn't support them. ValueError: If tool_use or tool_result message structure is invalid. - RuntimeError: If called in stateless HTTP mode. + StatelessModeNotSupported: If called in stateless HTTP mode. """ - self._require_stateful_mode("sampling") + if self._stateless: + raise StatelessModeNotSupported(method="sampling") client_caps = self._client_params.capabilities if self._client_params else None validate_sampling_tools(client_caps, tools, tool_choice) validate_tool_use_result_messages(messages) @@ -372,7 +354,8 @@ async def create_message( async def list_roots(self) -> types.ListRootsResult: """Send a roots/list request.""" - self._require_stateful_mode("list_roots") + if self._stateless: + raise StatelessModeNotSupported(method="list_roots") return await self.send_request( types.ServerRequest(types.ListRootsRequest()), types.ListRootsResult, @@ -417,9 +400,10 @@ async def elicit_form( The client's response with form data Raises: - RuntimeError: If called in stateless HTTP mode. + StatelessModeNotSupported: If called in stateless HTTP mode. """ - self._require_stateful_mode("elicitation") + if self._stateless: + raise StatelessModeNotSupported(method="elicitation") return await self.send_request( types.ServerRequest( types.ElicitRequest( @@ -455,9 +439,10 @@ async def elicit_url( The client's response indicating acceptance, decline, or cancellation Raises: - RuntimeError: If called in stateless HTTP mode. + StatelessModeNotSupported: If called in stateless HTTP mode. """ - self._require_stateful_mode("elicitation") + if self._stateless: + raise StatelessModeNotSupported(method="elicitation") return await self.send_request( types.ServerRequest( types.ElicitRequest( diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index 494311491..80f202443 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -18,6 +18,24 @@ def __init__(self, error: ErrorData): self.error = error +class StatelessModeNotSupported(RuntimeError): + """ + Raised when attempting to use a method that is not supported in stateless mode. + + Server-to-client requests (sampling, elicitation, list_roots) are not + supported in stateless HTTP mode because there is no persistent connection + for bidirectional communication. + """ + + def __init__(self, method: str): + super().__init__( + f"Cannot use {method} in stateless HTTP mode. " + "Stateless mode does not support server-to-client requests. " + "Use stateful mode (stateless_http=False) to enable this feature." + ) + self.method = method + + class UrlElicitationRequiredError(McpError): """ Specialized error for when a tool requires URL mode elicitation(s) before proceeding. diff --git a/tests/server/test_stateless_mode.py b/tests/server/test_stateless_mode.py index dfd90d1c3..c59ea2351 100644 --- a/tests/server/test_stateless_mode.py +++ b/tests/server/test_stateless_mode.py @@ -7,47 +7,32 @@ See: https://github.com/modelcontextprotocol/python-sdk/issues/1097 """ +from collections.abc import AsyncGenerator +from typing import Any + import anyio import pytest import mcp.types as types from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession +from mcp.shared.exceptions import StatelessModeNotSupported from mcp.shared.message import SessionMessage from mcp.types import ServerCapabilities -def create_test_streams(): - """Create memory streams for testing.""" +@pytest.fixture +async def stateless_session() -> AsyncGenerator[ServerSession, None]: + """Create a stateless ServerSession for testing.""" server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - return ( - server_to_client_send, - server_to_client_receive, - client_to_server_send, - client_to_server_receive, - ) - -def create_init_options(): - """Create default initialization options for testing.""" - return InitializationOptions( + init_options = InitializationOptions( server_name="test", server_version="0.1.0", capabilities=ServerCapabilities(), ) - -@pytest.mark.anyio -async def test_list_roots_fails_in_stateless_mode(): - """Test that list_roots raises RuntimeError in stateless mode.""" - ( - server_to_client_send, - server_to_client_receive, - client_to_server_send, - client_to_server_receive, - ) = create_test_streams() - async with ( client_to_server_send, client_to_server_receive, @@ -57,159 +42,100 @@ async def test_list_roots_fails_in_stateless_mode(): async with ServerSession( client_to_server_receive, server_to_client_send, - create_init_options(), + init_options, stateless=True, ) as session: - with pytest.raises(RuntimeError) as exc_info: - await session.list_roots() - - assert "stateless HTTP mode" in str(exc_info.value) - assert "list_roots" in str(exc_info.value) + yield session @pytest.mark.anyio -async def test_create_message_fails_in_stateless_mode(): - """Test that create_message raises RuntimeError in stateless mode.""" - ( - server_to_client_send, - server_to_client_receive, - client_to_server_send, - client_to_server_receive, - ) = create_test_streams() +async def test_list_roots_fails_in_stateless_mode(stateless_session: ServerSession): + """Test that list_roots raises StatelessModeNotSupported in stateless mode.""" + with pytest.raises(StatelessModeNotSupported, match="list_roots"): + await stateless_session.list_roots() - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - create_init_options(), - stateless=True, - ) as session: - with pytest.raises(RuntimeError) as exc_info: - await session.create_message( - messages=[ - types.SamplingMessage( - role="user", - content=types.TextContent(type="text", text="hello"), - ) - ], - max_tokens=100, - ) - assert "stateless HTTP mode" in str(exc_info.value) - assert "sampling" in str(exc_info.value) +@pytest.mark.anyio +async def test_create_message_fails_in_stateless_mode(stateless_session: ServerSession): + """Test that create_message raises StatelessModeNotSupported in stateless mode.""" + with pytest.raises(StatelessModeNotSupported, match="sampling"): + await stateless_session.create_message( + messages=[ + types.SamplingMessage( + role="user", + content=types.TextContent(type="text", text="hello"), + ) + ], + max_tokens=100, + ) @pytest.mark.anyio -async def test_elicit_form_fails_in_stateless_mode(): - """Test that elicit_form raises RuntimeError in stateless mode.""" - ( - server_to_client_send, - server_to_client_receive, - client_to_server_send, - client_to_server_receive, - ) = create_test_streams() +async def test_elicit_form_fails_in_stateless_mode(stateless_session: ServerSession): + """Test that elicit_form raises StatelessModeNotSupported in stateless mode.""" + with pytest.raises(StatelessModeNotSupported, match="elicitation"): + await stateless_session.elicit_form( + message="Please provide input", + requestedSchema={"type": "object", "properties": {}}, + ) - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - create_init_options(), - stateless=True, - ) as session: - with pytest.raises(RuntimeError) as exc_info: - await session.elicit_form( - message="Please provide input", - requestedSchema={"type": "object", "properties": {}}, - ) - assert "stateless HTTP mode" in str(exc_info.value) - assert "elicitation" in str(exc_info.value) +@pytest.mark.anyio +async def test_elicit_url_fails_in_stateless_mode(stateless_session: ServerSession): + """Test that elicit_url raises StatelessModeNotSupported in stateless mode.""" + with pytest.raises(StatelessModeNotSupported, match="elicitation"): + await stateless_session.elicit_url( + message="Please authenticate", + url="https://example.com/auth", + elicitation_id="test-123", + ) @pytest.mark.anyio -async def test_elicit_url_fails_in_stateless_mode(): - """Test that elicit_url raises RuntimeError in stateless mode.""" - ( - server_to_client_send, - server_to_client_receive, - client_to_server_send, - client_to_server_receive, - ) = create_test_streams() +async def test_elicit_deprecated_fails_in_stateless_mode(stateless_session: ServerSession): + """Test that the deprecated elicit method also fails in stateless mode.""" + with pytest.raises(StatelessModeNotSupported, match="elicitation"): + await stateless_session.elicit( + message="Please provide input", + requestedSchema={"type": "object", "properties": {}}, + ) - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - create_init_options(), - stateless=True, - ) as session: - with pytest.raises(RuntimeError) as exc_info: - await session.elicit_url( - message="Please authenticate", - url="https://example.com/auth", - elicitation_id="test-123", - ) - assert "stateless HTTP mode" in str(exc_info.value) - assert "elicitation" in str(exc_info.value) +@pytest.mark.anyio +async def test_stateless_error_message_is_actionable(stateless_session: ServerSession): + """Test that the error message provides actionable guidance.""" + with pytest.raises(StatelessModeNotSupported) as exc_info: + await stateless_session.list_roots() + + error_message = str(exc_info.value) + # Should mention it's stateless mode + assert "stateless HTTP mode" in error_message + # Should explain why it doesn't work + assert "server-to-client requests" in error_message + # Should tell user how to fix it + assert "stateless_http=False" in error_message @pytest.mark.anyio -async def test_elicit_deprecated_fails_in_stateless_mode(): - """Test that the deprecated elicit method also fails in stateless mode.""" - ( - server_to_client_send, - server_to_client_receive, - client_to_server_send, - client_to_server_receive, - ) = create_test_streams() +async def test_exception_has_method_attribute(stateless_session: ServerSession): + """Test that the exception has a method attribute for programmatic access.""" + with pytest.raises(StatelessModeNotSupported) as exc_info: + await stateless_session.list_roots() - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - create_init_options(), - stateless=True, - ) as session: - with pytest.raises(RuntimeError) as exc_info: - await session.elicit( - message="Please provide input", - requestedSchema={"type": "object", "properties": {}}, - ) + assert exc_info.value.method == "list_roots" - assert "stateless HTTP mode" in str(exc_info.value) - assert "elicitation" in str(exc_info.value) +@pytest.fixture +async def stateful_session() -> AsyncGenerator[ServerSession, None]: + """Create a stateful ServerSession for testing.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) -@pytest.mark.anyio -async def test_require_stateful_mode_does_not_raise_in_stateful_mode(): - """Test that _require_stateful_mode does not raise in stateful mode.""" - ( - server_to_client_send, - server_to_client_receive, - client_to_server_send, - client_to_server_receive, - ) = create_test_streams() + init_options = InitializationOptions( + server_name="test", + server_version="0.1.0", + capabilities=ServerCapabilities(), + ) async with ( client_to_server_send, @@ -220,44 +146,32 @@ async def test_require_stateful_mode_does_not_raise_in_stateful_mode(): async with ServerSession( client_to_server_receive, server_to_client_send, - create_init_options(), - stateless=False, # Stateful mode + init_options, + stateless=False, ) as session: - # These should not raise - the check passes in stateful mode - session._require_stateful_mode("list_roots") - session._require_stateful_mode("sampling") - session._require_stateful_mode("elicitation") + yield session @pytest.mark.anyio -async def test_stateless_error_message_is_actionable(): - """Test that the error message provides actionable guidance.""" - ( - server_to_client_send, - server_to_client_receive, - client_to_server_send, - client_to_server_receive, - ) = create_test_streams() +async def test_stateful_mode_does_not_raise_stateless_error( + stateful_session: ServerSession, monkeypatch: pytest.MonkeyPatch +): + """Test that StatelessModeNotSupported is not raised in stateful mode. - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - create_init_options(), - stateless=True, - ) as session: - with pytest.raises(RuntimeError) as exc_info: - await session.list_roots() - - error_message = str(exc_info.value) - # Should mention it's stateless mode - assert "stateless HTTP mode" in error_message - # Should explain why it doesn't work - assert "server-to-client requests" in error_message - # Should tell user how to fix it - assert "stateless_http=False" in error_message + We mock send_request to avoid blocking on I/O while still verifying + that the stateless check passes. + """ + send_request_called = False + + async def mock_send_request(*_: Any, **__: Any) -> types.ListRootsResult: + nonlocal send_request_called + send_request_called = True + return types.ListRootsResult(roots=[]) + + monkeypatch.setattr(stateful_session, "send_request", mock_send_request) + + # This should NOT raise StatelessModeNotSupported + result = await stateful_session.list_roots() + + assert send_request_called + assert isinstance(result, types.ListRootsResult)