From e1dc33dce077bd8c28a086e7fe63dcebb5b4a9db Mon Sep 17 00:00:00 2001 From: Clay Gifford Date: Mon, 18 May 2026 11:07:07 -0700 Subject: [PATCH] feat(mcp): Add McpServerInterceptor with scoped hooks Add hook-based server interceptor for McpService, following the established ClientInterceptor pattern. Interceptors observe or modify requests at specific pipeline stages without exposing async dispatch internals. Hooks: readBefore/modifyBefore Execution and ToolCall, readAfter/modifyAfter Execution and ToolCall. Also makes McpServerProxy's rpc(), start(), and shutdown() protected to allow custom proxy transport implementations. --- .../java/mcp/server/McpExecutionHook.java | 60 ++ .../java/mcp/server/McpServerBuilder.java | 16 + .../java/mcp/server/McpServerInterceptor.java | 217 +++++ .../mcp/server/McpServerInterceptorChain.java | 136 +++ .../java/mcp/server/McpServerProxy.java | 6 +- .../smithy/java/mcp/server/McpService.java | 284 +++++- .../java/mcp/server/McpToolCallHook.java | 67 ++ .../smithy/java/mcp/server/McpServerTest.java | 842 +++++++++++++++++- 8 files changed, 1603 insertions(+), 25 deletions(-) create mode 100644 mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpExecutionHook.java create mode 100644 mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerInterceptor.java create mode 100644 mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerInterceptorChain.java create mode 100644 mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpToolCallHook.java diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpExecutionHook.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpExecutionHook.java new file mode 100644 index 000000000..7a417b4d0 --- /dev/null +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpExecutionHook.java @@ -0,0 +1,60 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.mcp.server; + +import software.amazon.smithy.java.context.Context; +import software.amazon.smithy.java.mcp.model.JsonRpcRequest; +import software.amazon.smithy.utils.SmithyUnstableApi; + +/** + * Hook data available at the execution level. Passed to execution-scoped hooks in + * {@link McpServerInterceptor}. + * + *

The {@link #context()} provides a per-request key-value store for passing state + * between hooks. For example, a telemetry interceptor can stash a start timestamp in + * {@code readBeforeExecution} and retrieve it in {@code readAfterExecution}. + */ +@SmithyUnstableApi +public class McpExecutionHook { + + private final JsonRpcRequest request; + private final ProtocolVersion protocolVersion; + private final Context context; + + McpExecutionHook(JsonRpcRequest request, ProtocolVersion protocolVersion, Context context) { + this.request = request; + this.protocolVersion = protocolVersion; + this.context = context; + } + + /** + * The JSON-RPC request being handled. + */ + public JsonRpcRequest request() { + return request; + } + + /** + * Returns a new hook with the given request, or the same hook if unchanged. + */ + public McpExecutionHook withRequest(JsonRpcRequest request) { + return this.request == request ? this : new McpExecutionHook(request, protocolVersion, context); + } + + /** + * The MCP protocol version for this request. + */ + public ProtocolVersion protocolVersion() { + return protocolVersion; + } + + /** + * Per-request context for passing state between hooks. + */ + public Context context() { + return context; + } +} diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerBuilder.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerBuilder.java index fa40be0db..5777859c0 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerBuilder.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerBuilder.java @@ -24,6 +24,7 @@ public final class McpServerBuilder { OutputStream os; Map services = new HashMap<>(); List proxyList = new ArrayList<>(); + McpServerInterceptor interceptor; String name; String version; ToolFilter toolFilter = (server, tool) -> true; @@ -72,6 +73,10 @@ public Server build() { builder.version(version); } + if (interceptor != null) { + builder.interceptor(interceptor); + } + this.mcpService = builder.build(); return new McpServer(this); } @@ -101,6 +106,17 @@ public McpServerBuilder metricsObserver(McpMetricsObserver observer) { return this; } + /** + * Sets the server interceptor. Use {@link McpServerInterceptor#chain(List)} to compose + * multiple interceptors into one. + * + * @see McpServerInterceptor for hook descriptions and the execution lifecycle + */ + public McpServerBuilder interceptor(McpServerInterceptor interceptor) { + this.interceptor = Objects.requireNonNull(interceptor, "interceptor"); + return this; + } + private void validate() { Objects.requireNonNull(is, "MCP server input stream is required"); Objects.requireNonNull(os, "MCP server output stream is required"); diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerInterceptor.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerInterceptor.java new file mode 100644 index 000000000..e2c056441 --- /dev/null +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerInterceptor.java @@ -0,0 +1,217 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.mcp.server; + +import java.util.List; +import software.amazon.smithy.java.mcp.model.JsonRpcRequest; +import software.amazon.smithy.java.mcp.model.JsonRpcResponse; +import software.amazon.smithy.utils.SmithyUnstableApi; + +/** + * Interceptor for MCP server request processing. Interceptors inject code into the + * {@link McpService} request execution pipeline via hooks at specific stages. + * + *

Hooks are either "read" hooks (observe in-flight data) or "modify" hooks (transform + * in-flight data). All hooks have default no-op implementations; override only the hooks + * you need. + * + *

Execution lifecycle

+ * + *

For every request: + *

    + *
  1. {@link #readBeforeExecution} — observe the incoming request
  2. + *
  3. {@link #modifyBeforeExecution} — optionally transform the request
  4. + *
  5. For {@code tools/call} requests only: + *
      + *
    1. {@link #readBeforeToolCall} — observe before tool dispatch
    2. + *
    3. {@link #modifyBeforeToolCall} — optionally transform the request
    4. + *
    5. Tool dispatch (local or proxy)
    6. + *
    7. {@link #readAfterToolCall} — observe the tool result
    8. + *
    9. {@link #modifyAfterToolCall} — optionally transform the response
    10. + *
    + *
  6. + *
  7. {@link #readAfterExecution} — observe the final result (ALWAYS fires)
  8. + *
  9. {@link #modifyAfterExecution} — optionally transform the final response
  10. + *
+ * + *

Error handling

+ * + *

Any hook may throw a {@link RuntimeException}. When a hook throws, remaining hooks + * in that stage are skipped, and execution jumps to the after-execution hooks with the + * error. The {@code readAfterExecution} and {@code modifyAfterExecution} hooks ALWAYS fire, + * ensuring cleanup and telemetry logic runs regardless of errors. + * + *

Async tool calls

+ * + *

For proxy tool calls, the after-tool-call and after-execution hooks fire on the + * thread that receives the proxy response, not the original request thread. Hook + * implementations must be thread-safe. + * + *

Example: telemetry

+ *
{@code
+ * public class TelemetryInterceptor implements McpServerInterceptor {
+ *     private static final Context.Key START = Context.key("start");
+ *
+ *     @Override
+ *     public void readBeforeExecution(McpExecutionHook hook) {
+ *         hook.context().put(START, System.nanoTime());
+ *     }
+ *
+ *     @Override
+ *     public void readAfterExecution(McpExecutionHook hook,
+ *             JsonRpcResponse response, RuntimeException error) {
+ *         long duration = System.nanoTime() - hook.context().get(START);
+ *         emitMetrics(hook.request().getMethod(), duration, error == null);
+ *     }
+ * }
+ * }
+ * + *

Example: access control

+ *
{@code
+ * public class AccessControlInterceptor implements McpServerInterceptor {
+ *     @Override
+ *     public void readBeforeToolCall(McpToolCallHook hook) {
+ *         if (isBlocked(hook.toolName(), hook.serverId())) {
+ *             throw new RuntimeException("Access denied: " + hook.toolName());
+ *         }
+ *     }
+ * }
+ * }
+ */ +@SmithyUnstableApi +public interface McpServerInterceptor { + + /** + * An interceptor that does nothing. + */ + McpServerInterceptor NOOP = new McpServerInterceptor() {}; + + /** + * Combines multiple interceptors into a single interceptor that invokes each one + * in order. Hooks are called sequentially on each interceptor in list order. + * + * @param interceptors The interceptors to compose. + * @return A single interceptor that delegates to all provided interceptors. + */ + static McpServerInterceptor chain(List interceptors) { + return switch (interceptors.size()) { + case 0 -> NOOP; + case 1 -> interceptors.get(0); + default -> new McpServerInterceptorChain(List.copyOf(interceptors)); + }; + } + + /** + * Combines multiple interceptors into a single interceptor that invokes each one + * in order. Convenience overload of {@link #chain(List)}. + * + * @param interceptors The interceptors to compose. + * @return A single interceptor that delegates to all provided interceptors. + */ + static McpServerInterceptor chain(McpServerInterceptor... interceptors) { + return chain(List.of(interceptors)); + } + + // --- Execution-level hooks (fire for all requests) --- + + /** + * Called when a request is received, before any dispatch logic. + * + * @param hook Execution hook data containing the request, protocol version, and context. + */ + default void readBeforeExecution(McpExecutionHook hook) {} + + /** + * Called before dispatch. Can return a modified request. + * + * @param hook Execution hook data. + * @return The request to dispatch, or {@code hook.request()} to pass through unmodified. + */ + default JsonRpcRequest modifyBeforeExecution(McpExecutionHook hook) { + return hook.request(); + } + + /** + * Called when execution completes. ALWAYS fires, even if an earlier hook threw. + * + * @param hook Execution hook data. + * @param response The response, or {@code null} for notifications and async proxy calls + * still in flight. + * @param error The error if one occurred, or {@code null} on success. + */ + default void readAfterExecution(McpExecutionHook hook, JsonRpcResponse response, RuntimeException error) {} + + /** + * Called when execution completes. Can modify the response or handle errors. + * ALWAYS fires, even if an earlier hook threw. + * + * @param hook Execution hook data. + * @param response The response, or {@code null} for notifications. + * @param error The error if one occurred, or {@code null} on success. + * @return The final response. + * @throws RuntimeException to propagate or replace the error. + */ + default JsonRpcResponse modifyAfterExecution( + McpExecutionHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + if (error != null) { + throw error; + } + return response; + } + + // --- Tool-level hooks (fire only for tools/call) --- + + /** + * Called before a tool is invoked. + * + * @param hook Tool call hook data containing tool name, server ID, and proxy status. + */ + default void readBeforeToolCall(McpToolCallHook hook) {} + + /** + * Called before a tool is invoked. Can return a modified request. + * + * @param hook Tool call hook data. + * @return The request to use for tool invocation, or {@code hook.request()} to pass + * through unmodified. + */ + default JsonRpcRequest modifyBeforeToolCall(McpToolCallHook hook) { + return hook.request(); + } + + /** + * Called after a tool completes. For proxy tools, this fires on the callback thread. + * + * @param hook Tool call hook data. + * @param response The tool call response. + * @param error The error if one occurred, or {@code null} on success. + */ + default void readAfterToolCall(McpToolCallHook hook, JsonRpcResponse response, RuntimeException error) {} + + /** + * Called after a tool completes. Can modify the response or handle errors. + * For proxy tools, this fires on the callback thread. + * + * @param hook Tool call hook data. + * @param response The tool call response. + * @param error The error if one occurred, or {@code null} on success. + * @return The response to return. + * @throws RuntimeException to propagate or replace the error. + */ + default JsonRpcResponse modifyAfterToolCall( + McpToolCallHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + if (error != null) { + throw error; + } + return response; + } +} diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerInterceptorChain.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerInterceptorChain.java new file mode 100644 index 000000000..0af320392 --- /dev/null +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerInterceptorChain.java @@ -0,0 +1,136 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.mcp.server; + +import java.util.List; +import software.amazon.smithy.java.logging.InternalLogger; +import software.amazon.smithy.java.mcp.model.JsonRpcRequest; +import software.amazon.smithy.java.mcp.model.JsonRpcResponse; +import software.amazon.smithy.utils.SmithyUnstableApi; + +/** + * Composes multiple {@link McpServerInterceptor} instances into a single interceptor + * that delegates to each one in order. + */ +@SmithyUnstableApi +final class McpServerInterceptorChain implements McpServerInterceptor { + + private static final InternalLogger LOGGER = InternalLogger.getLogger(McpServerInterceptorChain.class); + private final McpServerInterceptor[] interceptors; + + McpServerInterceptorChain(List interceptors) { + this.interceptors = interceptors.toArray(McpServerInterceptor[]::new); + } + + @Override + public void readBeforeExecution(McpExecutionHook hook) { + RuntimeException error = null; + for (var interceptor : interceptors) { + try { + interceptor.readBeforeExecution(hook); + } catch (RuntimeException e) { + error = swapError("readBeforeExecution", error, e); + } + } + if (error != null) { + throw error; + } + } + + @Override + public JsonRpcRequest modifyBeforeExecution(McpExecutionHook hook) { + var current = hook; + for (var interceptor : interceptors) { + var req = interceptor.modifyBeforeExecution(current); + current = current.withRequest(req); + } + return current.request(); + } + + @Override + public void readAfterExecution(McpExecutionHook hook, JsonRpcResponse response, RuntimeException error) { + for (var interceptor : interceptors) { + try { + interceptor.readAfterExecution(hook, response, error); + } catch (RuntimeException e) { + error = swapError("readAfterExecution", error, e); + } + } + // Always throw the error even if it's the original error. + if (error != null) { + throw error; + } + } + + @Override + public JsonRpcResponse modifyAfterExecution( + McpExecutionHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + for (var interceptor : interceptors) { + response = interceptor.modifyAfterExecution(hook, response, error); + error = null; + } + return response; + } + + @Override + public void readBeforeToolCall(McpToolCallHook hook) { + RuntimeException error = null; + for (var interceptor : interceptors) { + try { + interceptor.readBeforeToolCall(hook); + } catch (RuntimeException e) { + error = swapError("readBeforeToolCall", error, e); + } + } + if (error != null) { + throw error; + } + } + + @Override + public JsonRpcRequest modifyBeforeToolCall(McpToolCallHook hook) { + var current = hook; + for (var interceptor : interceptors) { + var req = interceptor.modifyBeforeToolCall(current); + current = current.withRequest(req); + } + return current.request(); + } + + @Override + public void readAfterToolCall(McpToolCallHook hook, JsonRpcResponse response, RuntimeException error) { + for (var interceptor : interceptors) { + try { + interceptor.readAfterToolCall(hook, response, error); + } catch (RuntimeException e) { + error = swapError("readAfterToolCall", error, e); + } + } + // Always throw the error even if it's the original error. + if (error != null) { + throw error; + } + } + + @Override + public JsonRpcResponse modifyAfterToolCall(McpToolCallHook hook, JsonRpcResponse response, RuntimeException error) { + for (var interceptor : interceptors) { + response = interceptor.modifyAfterToolCall(hook, response, error); + error = null; + } + return response; + } + + private static RuntimeException swapError(String hook, RuntimeException oldE, RuntimeException newE) { + if (oldE != null && oldE != newE) { + LOGGER.trace("Replacing error after {}: {}", hook, newE.getClass().getName(), newE.getMessage()); + } + return newE; + } +} diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerProxy.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerProxy.java index 12722c1e6..22e61867f 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerProxy.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerProxy.java @@ -101,11 +101,11 @@ protected final ProtocolVersion getProtocolVersion() { return protocolVersion.get(); } - abstract CompletableFuture rpc(JsonRpcRequest request); + protected abstract CompletableFuture rpc(JsonRpcRequest request); - abstract void start(); + protected abstract void start(); - abstract CompletableFuture shutdown(); + protected abstract CompletableFuture shutdown(); protected CompletableFuture rpc(String method, ShapeBuilder builder) { JsonRpcRequest request = JsonRpcRequest.builder() diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java index 662020f8a..d22343b2d 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java @@ -20,6 +20,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.concurrent.CompletionException; import java.util.concurrent.ConcurrentHashMap; @@ -27,6 +28,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.stream.Collectors; +import software.amazon.smithy.java.context.Context; import software.amazon.smithy.java.core.schema.Schema; import software.amazon.smithy.java.core.schema.SchemaIndex; import software.amazon.smithy.java.core.schema.SerializableShape; @@ -74,6 +76,7 @@ public final class McpService { private static final InternalLogger LOG = InternalLogger.getLogger(McpService.class); + private static final Context.Key ASYNC_DISPATCH = Context.key("mcp.asyncDispatch"); private static final JsonCodec CODEC = JsonCodec.builder() .settings(JsonSettings.builder() @@ -95,6 +98,7 @@ public final class McpService { private final AtomicReference proxiesInitialized = new AtomicReference<>(false); private final McpMetricsObserver metricsObserver; private final SchemaIndex schemaIndex; + private final McpServerInterceptor interceptor; private Consumer notificationWriter; McpService( @@ -103,7 +107,8 @@ public final class McpService { String name, String version, ToolFilter toolFilter, - McpMetricsObserver metricsObserver + McpMetricsObserver metricsObserver, + McpServerInterceptor interceptor ) { this.services = services; this.schemaIndex = @@ -115,22 +120,82 @@ public final class McpService { this.proxies = proxyList.stream().collect(Collectors.toMap(McpServerProxy::name, p -> p)); this.toolFilter = toolFilter; this.metricsObserver = metricsObserver; + this.interceptor = interceptor; } /** - * Handles a JSON-RPC request synchronously and returns a response. - * For proxy tool calls, the response callback is invoked asynchronously and this method returns null. - * For local operations, the response is returned immediately. + * Handles a JSON-RPC request, invoking interceptor hooks at each stage of the pipeline. + * + *

Responses are delivered through one of two channels: + *

    + *
  • Synchronous (return value): For most requests, the response is returned directly.
  • + *
  • Asynchronous (callback): For proxy tool calls, returns {@code null} and the callback + * is invoked when the proxy responds.
  • + *
  • Neither: For notifications and unknown methods, returns {@code null} and the callback + * is never invoked.
  • + *
* * @param req The JSON-RPC request to handle * @param asyncResponseCallback Callback for async responses (used for proxy calls) * @param protocolVersion The protocol version for this request (may be null) - * @return The response for synchronous operations, or null for async operations + * @return The response for synchronous operations, or null for async/notification operations */ public JsonRpcResponse handleRequest( JsonRpcRequest req, Consumer asyncResponseCallback, ProtocolVersion protocolVersion + ) { + // Zero-interceptor fast path: skip Context creation, hook allocation, and all hook invocations. + if (interceptor == McpServerInterceptor.NOOP) { + return handleRequestDirect(req, asyncResponseCallback, protocolVersion); + } + + var hook = new McpExecutionHook(req, protocolVersion, Context.create()); + JsonRpcResponse response = null; + RuntimeException caughtError = null; + + try { + var currentReq = fireBeforeExecution(hook); + hook = hook.withRequest(currentReq); + + // Dispatch + validate(currentReq); + var method = currentReq.getMethod(); + response = switch (method) { + case "initialize" -> handleInitialize(currentReq); + case "ping" -> handlePing(currentReq); + default -> { + initializeProxies(rpcResponse -> {}); + yield switch (method) { + case "prompts/list" -> handlePromptsList(currentReq); + case "prompts/get" -> handlePromptsGet(currentReq); + case "tools/list" -> handleToolsList(currentReq, protocolVersion); + case "tools/call" -> + handleToolsCall(currentReq, asyncResponseCallback, protocolVersion, hook); + default -> null; + }; + } + }; + if (Boolean.TRUE.equals(hook.context().get(ASYNC_DISPATCH))) { + return null; + } + } catch (RuntimeException e) { + caughtError = e; + } catch (Exception e) { + caughtError = new RuntimeException(e); + } + + return fireAfterExecution(hook, response, caughtError); + } + + /** + * Direct dispatch path used when no interceptor is configured. Avoids Context creation, + * hook allocation, and all hook invocations. + */ + private JsonRpcResponse handleRequestDirect( + JsonRpcRequest req, + Consumer asyncResponseCallback, + ProtocolVersion protocolVersion ) { try { validate(req); @@ -144,7 +209,8 @@ yield switch (method) { case "prompts/list" -> handlePromptsList(req); case "prompts/get" -> handlePromptsGet(req); case "tools/list" -> handleToolsList(req, protocolVersion); - case "tools/call" -> handleToolsCall(req, asyncResponseCallback, protocolVersion); + case "tools/call" -> + handleToolsCallDirect(req, asyncResponseCallback, protocolVersion); default -> null; // Notifications or unknown methods }; } @@ -154,6 +220,38 @@ yield switch (method) { } } + private JsonRpcRequest fireBeforeExecution(McpExecutionHook hook) { + interceptor.readBeforeExecution(hook); + return interceptor.modifyBeforeExecution(hook); + } + + private JsonRpcRequest fireBeforeToolCall(McpToolCallHook hook) { + interceptor.readBeforeToolCall(hook); + return interceptor.modifyBeforeToolCall(hook); + } + + private JsonRpcResponse fireAfterExecution( + McpExecutionHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + try { + interceptor.readAfterExecution(hook, response, error); + } catch (RuntimeException e) { + error = swapError("readAfterExecution", error, e); + } + try { + response = interceptor.modifyAfterExecution(hook, response, error); + error = null; + } catch (RuntimeException e) { + error = e; + } + if (error != null) { + return createErrorResponse(hook.request(), error); + } + return response; + } + private JsonRpcResponse handleInitialize(JsonRpcRequest req) { if (metricsObserver != null) { var params = req.getParams(); @@ -262,6 +360,103 @@ private JsonRpcResponse handleToolsList(JsonRpcRequest req, ProtocolVersion prot } private JsonRpcResponse handleToolsCall( + JsonRpcRequest req, + Consumer asyncResponseCallback, + ProtocolVersion protocolVersion, + McpExecutionHook executionHook + ) { + if (metricsObserver != null) { + String toolName = req.getParams().getMember("name") != null + ? req.getParams().getMember("name").asString() + : null; + metricsObserver.onToolCall("tools/call", toolName); + } + + var operationName = req.getParams().getMember("name").asString(); + var tool = tools.get(operationName); + + if (tool == null) { + return createErrorResponse(req, "No such tool: " + operationName); + } + + var toolHook = new McpToolCallHook( + req, + protocolVersion, + executionHook.context(), + operationName, + tool.serverId(), + tool.proxy() != null); + + ToolResult result; + try { + var currentReq = fireBeforeToolCall(toolHook); + toolHook = toolHook.withRequest(currentReq); + + if (tool.proxy() != null) { + return dispatchProxy(tool, currentReq, toolHook, executionHook, asyncResponseCallback); + } + + result = dispatchLocal(tool, currentReq, protocolVersion); + } catch (RuntimeException e) { + result = ToolResult.failure(e); + } + + return fireAfterToolCall(toolHook, result.response(), result.error()); + } + + private JsonRpcResponse dispatchProxy( + Tool tool, + JsonRpcRequest currentReq, + McpToolCallHook toolHook, + McpExecutionHook executionHook, + Consumer asyncResponseCallback + ) { + JsonRpcRequest proxyRequest = JsonRpcRequest.builder() + .id(currentReq.getId()) + .method(currentReq.getMethod()) + .params(currentReq.getParams()) + .jsonrpc(currentReq.getJsonrpc()) + .build(); + + executionHook.context().put(ASYNC_DISPATCH, true); + + var finalToolHook = toolHook; + tool.proxy().rpc(proxyRequest).thenAccept(response -> { + var finalResponse = fireAfterToolCall(finalToolHook, response, null); + finalResponse = fireAfterExecution(executionHook, finalResponse, null); + asyncResponseCallback.accept(finalResponse); + }).exceptionally(ex -> { + var proxyError = new RuntimeException("Proxy error: " + ex.getMessage(), ex); + var errorResponse = fireAfterToolCall(finalToolHook, null, proxyError); + if (errorResponse == null) { + errorResponse = createErrorResponse(finalToolHook.request(), proxyError); + } + errorResponse = fireAfterExecution(executionHook, errorResponse, null); + asyncResponseCallback.accept(errorResponse); + return null; + }); + + return null; + } + + private ToolResult dispatchLocal(Tool tool, JsonRpcRequest req, ProtocolVersion protocolVersion) { + try { + var operation = tool.operation(); + var argumentsDoc = req.getParams().getMember("arguments"); + var adaptedDoc = adaptDocument(argumentsDoc, operation.getApiOperation().inputSchema()); + var input = adaptedDoc.asShape(operation.getApiOperation().inputBuilder()); + var output = operation.function().apply(input, null); + var result = formatStructuredContent(tool, (SerializableShape) output, protocolVersion); + return ToolResult.success(createSuccessResponse(req.getId(), result)); + } catch (RuntimeException e) { + return ToolResult.failure(e); + } + } + + /** + * Direct tool dispatch used when no interceptor is configured. No hooks are invoked. + */ + private JsonRpcResponse handleToolsCallDirect( JsonRpcRequest req, Consumer asyncResponseCallback, ProtocolVersion protocolVersion @@ -280,9 +475,7 @@ private JsonRpcResponse handleToolsCall( return createErrorResponse(req, "No such tool: " + operationName); } - // Check if this tool should be dispatched to a proxy if (tool.proxy() != null) { - // Forward the request to the proxy JsonRpcRequest proxyRequest = JsonRpcRequest.builder() .id(req.getId()) .method(req.getMethod()) @@ -290,18 +483,17 @@ private JsonRpcResponse handleToolsCall( .jsonrpc(req.getJsonrpc()) .build(); - // Get response asynchronously and invoke callback - tool.proxy().rpc(proxyRequest).thenAccept(asyncResponseCallback).exceptionally(ex -> { - LOG.error("Error from proxy RPC", ex); - asyncResponseCallback - .accept(createErrorResponse(req, new RuntimeException("Proxy error: " + ex.getMessage(), ex))); - return null; - }); - - // Return null to indicate async handling + tool.proxy() + .rpc(proxyRequest) + .thenAccept(asyncResponseCallback) + .exceptionally(ex -> { + LOG.error("Error from proxy RPC", ex); + asyncResponseCallback.accept( + createErrorResponse(req, new RuntimeException("Proxy error: " + ex.getMessage(), ex))); + return null; + }); return null; } else { - // Handle locally var operation = tool.operation(); var argumentsDoc = req.getParams().getMember("arguments"); var adaptedDoc = adaptDocument(argumentsDoc, operation.getApiOperation().inputSchema()); @@ -312,6 +504,38 @@ private JsonRpcResponse handleToolsCall( } } + private JsonRpcResponse fireAfterToolCall( + McpToolCallHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + try { + interceptor.readAfterToolCall(hook, response, error); + } catch (RuntimeException e) { + error = swapError("readAfterToolCall", error, e); + } + try { + response = interceptor.modifyAfterToolCall(hook, response, error); + error = null; + } catch (RuntimeException e) { + error = e; + } + if (error != null) { + return createErrorResponse(hook.request(), error); + } + return response; + } + + private static RuntimeException swapError(String hook, RuntimeException oldE, RuntimeException newE) { + if (oldE != null && oldE != newE) { + LOG.trace("Replacing error after {}: {} -> {}", + hook, + oldE.getClass().getName(), + newE.getClass().getName()); + } + return newE; + } + /** * Sets the notification writer for forwarding notifications from proxies. */ @@ -927,6 +1151,16 @@ private record Tool( } } + private record ToolResult(JsonRpcResponse response, RuntimeException error) { + static ToolResult success(JsonRpcResponse response) { + return new ToolResult(response, null); + } + + static ToolResult failure(RuntimeException error) { + return new ToolResult(null, error); + } + } + private static String appendSentences(String first, String second) { first = first.trim(); if (!first.endsWith(".")) { @@ -1165,6 +1399,7 @@ public static Builder builder() { public static class Builder { private Map services = new HashMap<>(); private List proxyList = new ArrayList<>(); + private McpServerInterceptor interceptor = McpServerInterceptor.NOOP; private String name = "mcp-server"; private String version = "1.0.0"; private ToolFilter toolFilter = (serverId, toolName) -> true; @@ -1200,8 +1435,19 @@ public Builder metricsObserver(McpMetricsObserver metricsObserver) { return this; } + /** + * Sets the server interceptor. Use {@link McpServerInterceptor#chain(List)} to compose + * multiple interceptors into one. + * + * @see McpServerInterceptor for hook descriptions and the execution lifecycle + */ + public Builder interceptor(McpServerInterceptor interceptor) { + this.interceptor = Objects.requireNonNull(interceptor, "interceptor"); + return this; + } + public McpService build() { - return new McpService(services, proxyList, name, version, toolFilter, metricsObserver); + return new McpService(services, proxyList, name, version, toolFilter, metricsObserver, interceptor); } } } diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpToolCallHook.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpToolCallHook.java new file mode 100644 index 000000000..cff5813af --- /dev/null +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpToolCallHook.java @@ -0,0 +1,67 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.mcp.server; + +import software.amazon.smithy.java.context.Context; +import software.amazon.smithy.java.mcp.model.JsonRpcRequest; +import software.amazon.smithy.utils.SmithyUnstableApi; + +/** + * Hook data available during tool call processing. Extends {@link McpExecutionHook} with + * tool-specific information. Passed to tool-scoped hooks in {@link McpServerInterceptor}. + */ +@SmithyUnstableApi +public class McpToolCallHook extends McpExecutionHook { + + private final String toolName; + private final String serverId; + private final boolean isProxy; + + McpToolCallHook( + JsonRpcRequest request, + ProtocolVersion protocolVersion, + Context context, + String toolName, + String serverId, + boolean isProxy + ) { + super(request, protocolVersion, context); + this.toolName = toolName; + this.serverId = serverId; + this.isProxy = isProxy; + } + + /** + * The name of the tool being invoked. + */ + public String toolName() { + return toolName; + } + + /** + * The server ID that owns this tool. + */ + public String serverId() { + return serverId; + } + + /** + * Whether this tool is dispatched to a remote proxy rather than handled locally. + */ + public boolean isProxy() { + return isProxy; + } + + /** + * Returns a new hook with the given request, or the same hook if unchanged. + */ + @Override + public McpToolCallHook withRequest(JsonRpcRequest request) { + return this.request() == request + ? this + : new McpToolCallHook(request, protocolVersion(), context(), toolName, serverId, isProxy); + } +} diff --git a/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java b/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java index aef941dec..3241b344e 100644 --- a/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java +++ b/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java @@ -1695,7 +1695,7 @@ public List listPrompts() { } @Override - CompletableFuture rpc(JsonRpcRequest request) { + protected CompletableFuture rpc(JsonRpcRequest request) { // Notifications have no ID if (request.getId() == null) { sentNotifications.add(request.getMethod()); @@ -1714,10 +1714,10 @@ List getSentNotifications() { } @Override - void start() {} + protected void start() {} @Override - CompletableFuture shutdown() { + protected CompletableFuture shutdown() { return CompletableFuture.completedFuture(null); } @@ -1730,4 +1730,840 @@ void sendNotification(JsonRpcRequest notification) { notify(notification); } } + + // ==================== Read-only hooks ==================== + + @Test + void testReadBeforeAndAfterExecution() { + var capturedMethod = new AtomicReference(); + var capturedResponse = new AtomicReference(); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .interceptor(new McpServerInterceptor() { + @Override + public void readBeforeExecution(McpExecutionHook hook) { + capturedMethod.set(hook.request().getMethod()); + } + + @Override + public void readAfterExecution( + McpExecutionHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + capturedResponse.set(response); + } + }) + .build(); + + server.start(); + write("ping", Document.of(Map.of())); + read(); + + assertEquals("ping", capturedMethod.get()); + assertNotNull(capturedResponse.get()); + } + + @Test + void testReadBeforeAndAfterToolCallLocal() { + var capturedToolName = new AtomicReference(); + var capturedServerId = new AtomicReference(); + var capturedIsProxy = new AtomicReference(); + var afterToolCallFired = new AtomicReference<>(false); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .interceptor(new McpServerInterceptor() { + @Override + public void readBeforeToolCall(McpToolCallHook hook) { + capturedToolName.set(hook.toolName()); + capturedServerId.set(hook.serverId()); + capturedIsProxy.set(hook.isProxy()); + } + + @Override + public void readAfterToolCall( + McpToolCallHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + afterToolCallFired.set(true); + } + }) + .build(); + + server.start(); + initializeWithProtocolVersion(null); + + write("tools/call", + Document.of(Map.of( + "name", + Document.of("NoIOOperation"), + "arguments", + Document.of(Map.of())))); + read(); + + assertEquals("NoIOOperation", capturedToolName.get()); + assertEquals("test-mcp", capturedServerId.get()); + assertFalse(capturedIsProxy.get()); + assertTrue(afterToolCallFired.get()); + } + + @Test + void testReadBeforeAndAfterToolCallProxy() { + var capturedToolName = new AtomicReference(); + var capturedIsProxy = new AtomicReference(); + var afterToolCallFired = new AtomicReference<>(false); + var mockProxy = new CacheTestProxy(new AtomicInteger(0)); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .addService(mockProxy) + .interceptor(new McpServerInterceptor() { + @Override + public void readBeforeToolCall(McpToolCallHook hook) { + capturedToolName.set(hook.toolName()); + capturedIsProxy.set(hook.isProxy()); + } + + @Override + public void readAfterToolCall( + McpToolCallHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + afterToolCallFired.set(true); + } + }) + .build(); + + server.start(); + initializeWithProtocolVersion(null); + + write("tools/call", + Document.of(Map.of( + "name", + Document.of("test-tool"), + "arguments", + Document.of(Map.of())))); + read(); + + assertEquals("test-tool", capturedToolName.get()); + assertTrue(capturedIsProxy.get()); + assertTrue(afterToolCallFired.get()); + } + + @Test + void testReadAfterExecutionAlwaysFires() { + var afterCount = new AtomicInteger(0); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .interceptor(new McpServerInterceptor() { + @Override + public void readAfterExecution( + McpExecutionHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + afterCount.incrementAndGet(); + } + }) + .build(); + + server.start(); + + write("ping", Document.of(Map.of())); + read(); + assertEquals(1, afterCount.get()); + + writeNotification("notifications/initialized", Document.of(Map.of())); + output.assertNoOutput(); + assertEquals(2, afterCount.get()); + + write("ping", Document.of(Map.of())); + read(); + assertEquals(3, afterCount.get()); + } + + @Test + void testReadAfterToolCallFiresWhenBeforeToolCallThrows() { + var afterToolCallFired = new AtomicReference<>(false); + var capturedErrorMessage = new AtomicReference(); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .interceptor(new McpServerInterceptor() { + @Override + public void readBeforeToolCall(McpToolCallHook hook) { + throw new RuntimeException("blocked"); + } + + @Override + public void readAfterToolCall( + McpToolCallHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + afterToolCallFired.set(true); + if (error != null) { + capturedErrorMessage.set(error.getMessage()); + } + } + }) + .build(); + + server.start(); + initializeWithProtocolVersion(null); + + write("tools/call", + Document.of(Map.of( + "name", + Document.of("NoIOOperation"), + "arguments", + Document.of(Map.of())))); + read(); + + assertTrue(afterToolCallFired.get()); + assertEquals("blocked", capturedErrorMessage.get()); + } + + @Test + void testReadAfterExecutionFiresForProxyToolCall() { + var afterExecutionFired = new AtomicReference<>(false); + var mockProxy = new CacheTestProxy(new AtomicInteger(0)); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .addService(mockProxy) + .interceptor(new McpServerInterceptor() { + @Override + public void readAfterExecution( + McpExecutionHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + afterExecutionFired.set(true); + } + }) + .build(); + + server.start(); + initializeWithProtocolVersion(null); + + write("tools/call", + Document.of(Map.of( + "name", + Document.of("test-tool"), + "arguments", + Document.of(Map.of())))); + read(); + + assertTrue(afterExecutionFired.get()); + } + + @Test + void testReadBeforeExecutionThrowSkipsToolHooks() { + var beforeToolCallFired = new AtomicReference<>(false); + var afterExecutionError = new AtomicReference(); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .interceptor(new McpServerInterceptor() { + @Override + public void readBeforeExecution(McpExecutionHook hook) { + if ("tools/call".equals(hook.request().getMethod())) { + throw new RuntimeException("execution-blocked"); + } + } + + @Override + public void readBeforeToolCall(McpToolCallHook hook) { + beforeToolCallFired.set(true); + } + + @Override + public void readAfterExecution( + McpExecutionHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + if ("tools/call".equals(hook.request().getMethod())) { + afterExecutionError.set(error); + } + } + }) + .build(); + + server.start(); + initializeWithProtocolVersion(null); + + write("tools/call", + Document.of(Map.of( + "name", + Document.of("NoIOOperation"), + "arguments", + Document.of(Map.of())))); + read(); + + assertFalse(beforeToolCallFired.get()); + assertNotNull(afterExecutionError.get()); + assertEquals("execution-blocked", afterExecutionError.get().getMessage()); + } + + @Test + void testContextPassesBetweenReadHooks() { + var duration = new AtomicReference(); + software.amazon.smithy.java.context.Context.Key START_KEY = + software.amazon.smithy.java.context.Context.key("start"); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .interceptor(new McpServerInterceptor() { + @Override + public void readBeforeExecution(McpExecutionHook hook) { + hook.context().put(START_KEY, System.nanoTime()); + } + + @Override + public void readAfterExecution( + McpExecutionHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + long start = hook.context().get(START_KEY); + duration.set(System.nanoTime() - start); + } + }) + .build(); + + server.start(); + write("ping", Document.of(Map.of())); + read(); + + assertNotNull(duration.get()); + assertTrue(duration.get() > 0); + } + + // ==================== Modify-only hooks ==================== + + @Test + void testModifyBeforeExecutionRewritesRequest() { + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .interceptor(new McpServerInterceptor() { + @Override + public JsonRpcRequest modifyBeforeExecution(McpExecutionHook hook) { + return JsonRpcRequest.builder() + .id(hook.request().getId()) + .method("ping") + .params(Document.of(Map.of())) + .jsonrpc("2.0") + .build(); + } + }) + .build(); + + server.start(); + write("tools/list", Document.of(Map.of())); + var response = read(); + assertTrue(response.getResult().asStringMap().isEmpty()); + } + + @Test + void testModifyBeforeToolCallModifiesRequest() { + var modifyHookCalled = new AtomicReference<>(false); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .interceptor(new McpServerInterceptor() { + @Override + public JsonRpcRequest modifyBeforeToolCall(McpToolCallHook hook) { + modifyHookCalled.set(true); + return hook.request(); + } + }) + .build(); + + server.start(); + initializeWithProtocolVersion(null); + + write("tools/call", + Document.of(Map.of( + "name", + Document.of("NoIOOperation"), + "arguments", + Document.of(Map.of())))); + var response = read(); + + assertTrue(modifyHookCalled.get()); + assertNotNull(response); + } + + @Test + void testModifyAfterExecutionTransformsResponse() { + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .interceptor(new McpServerInterceptor() { + @Override + public JsonRpcResponse modifyAfterExecution( + McpExecutionHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + if (error != null) { + throw error; + } + return JsonRpcResponse.builder() + .id(hook.request().getId()) + .result(Document.of(Map.of("modified", Document.of("true")))) + .jsonrpc("2.0") + .build(); + } + }) + .build(); + + server.start(); + write("ping", Document.of(Map.of())); + var response = read(); + + assertEquals("true", response.getResult().getMember("modified").asString()); + } + + @Test + void testModifyAfterToolCallTransformsResponse() { + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .interceptor(new McpServerInterceptor() { + @Override + public JsonRpcResponse modifyAfterToolCall( + McpToolCallHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + // Always return custom response, ignoring any tool error + return JsonRpcResponse.builder() + .id(hook.request().getId()) + .result(Document.of(Map.of("tool-modified", Document.of("true")))) + .jsonrpc("2.0") + .build(); + } + }) + .build(); + + server.start(); + initializeWithProtocolVersion(null); + + write("tools/call", + Document.of(Map.of( + "name", + Document.of("NoIOOperation"), + "arguments", + Document.of(Map.of())))); + var response = read(); + + assertEquals("true", response.getResult().getMember("tool-modified").asString()); + } + + // ==================== Error handling ==================== + + @Test + void testReadBeforeExecutionThrowShortCircuits() { + var afterExecutionError = new AtomicReference(); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .interceptor(new McpServerInterceptor() { + @Override + public void readBeforeExecution(McpExecutionHook hook) { + throw new RuntimeException("blocked"); + } + + @Override + public void readAfterExecution( + McpExecutionHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + afterExecutionError.set(error); + } + }) + .build(); + + server.start(); + write("ping", Document.of(Map.of())); + var response = read(); + + assertNotNull(response.getError()); + assertTrue(response.getError().getMessage().contains("blocked")); + assertNotNull(afterExecutionError.get()); + assertEquals("blocked", afterExecutionError.get().getMessage()); + } + + @Test + void testModifyAfterExecutionCanRecoverFromError() { + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .interceptor(new McpServerInterceptor() { + @Override + public void readBeforeExecution(McpExecutionHook hook) { + throw new RuntimeException("original-error"); + } + + @Override + public JsonRpcResponse modifyAfterExecution( + McpExecutionHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + // Recover from the error by returning a success response + return JsonRpcResponse.builder() + .id(hook.request().getId()) + .result(Document.of(Map.of())) + .jsonrpc("2.0") + .build(); + } + }) + .build(); + + server.start(); + write("ping", Document.of(Map.of())); + var response = read(); + + assertNull(response.getError()); + assertNotNull(response.getResult()); + } + + // ==================== Chain composition ==================== + + @Test + void testChainReadHooksInvokedInOrder() { + var order = new ArrayList(); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .interceptor(McpServerInterceptor.chain(List.of( + new McpServerInterceptor() { + @Override + public void readBeforeExecution(McpExecutionHook hook) { + order.add("A-before"); + } + + @Override + public void readAfterExecution( + McpExecutionHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + order.add("A-after"); + } + }, + new McpServerInterceptor() { + @Override + public void readBeforeExecution(McpExecutionHook hook) { + order.add("B-before"); + } + + @Override + public void readAfterExecution( + McpExecutionHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + order.add("B-after"); + } + }))) + .build(); + + server.start(); + write("ping", Document.of(Map.of())); + read(); + + assertEquals(List.of("A-before", "B-before", "A-after", "B-after"), order); + } + + @Test + void testChainModifyBeforeToolCallPropagatesRequest() { + var capturedRequest = new AtomicReference(); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .interceptor(McpServerInterceptor.chain(List.of( + new McpServerInterceptor() { + @Override + public JsonRpcRequest modifyBeforeToolCall(McpToolCallHook hook) { + var params = hook.request().getParams().asStringMap(); + var newParams = new java.util.HashMap<>(params); + newParams.put("injected", Document.of("from-first")); + return JsonRpcRequest.builder() + .id(hook.request().getId()) + .method(hook.request().getMethod()) + .params(Document.of(newParams)) + .jsonrpc(hook.request().getJsonrpc()) + .build(); + } + }, + new McpServerInterceptor() { + @Override + public JsonRpcRequest modifyBeforeToolCall(McpToolCallHook hook) { + capturedRequest.set(hook.request()); + return hook.request(); + } + }))) + .build(); + + server.start(); + initializeWithProtocolVersion(null); + + write("tools/call", + Document.of(Map.of( + "name", + Document.of("NoIOOperation"), + "arguments", + Document.of(Map.of())))); + read(); + + assertNotNull(capturedRequest.get()); + var injected = capturedRequest.get().getParams().getMember("injected"); + assertNotNull(injected); + assertEquals("from-first", injected.asString()); + } + + @Test + void testChainModifyAfterExecutionErrorPropagates() { + // When a modify hook throws, the exception propagates immediately (no try/catch in + // chain modify hooks, matching ClientInterceptorChain). The caller converts it to + // an error response. + var secondInterceptorCalled = new AtomicReference<>(false); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .interceptor(McpServerInterceptor.chain(List.of( + new McpServerInterceptor() { + @Override + public JsonRpcResponse modifyAfterExecution( + McpExecutionHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + throw new RuntimeException("first-error"); + } + }, + new McpServerInterceptor() { + @Override + public JsonRpcResponse modifyAfterExecution( + McpExecutionHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + secondInterceptorCalled.set(true); + return response; + } + }))) + .build(); + + server.start(); + write("ping", Document.of(Map.of())); + var response = read(); + + // Second interceptor never runs — exception propagates immediately + assertFalse(secondInterceptorCalled.get()); + assertNotNull(response.getError()); + assertTrue(response.getError().getMessage().contains("first-error")); + } + + @Test + void testChainModifyAfterToolCallErrorPropagates() { + // When a modify hook throws, the exception propagates immediately (no try/catch in + // chain modify hooks, matching ClientInterceptorChain). The caller converts it to + // an error response. + var secondInterceptorCalled = new AtomicReference<>(false); + + server = McpServer.builder() + .name("smithy-mcp-server") + .input(input) + .output(output) + .addService("test-mcp", + ProxyService.builder() + .service(ShapeId.from("smithy.test#TestService")) + .proxyEndpoint("http://localhost") + .model(MODEL) + .build()) + .interceptor(McpServerInterceptor.chain(List.of( + new McpServerInterceptor() { + @Override + public JsonRpcResponse modifyAfterToolCall( + McpToolCallHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + throw new RuntimeException("tool-error"); + } + }, + new McpServerInterceptor() { + @Override + public JsonRpcResponse modifyAfterToolCall( + McpToolCallHook hook, + JsonRpcResponse response, + RuntimeException error + ) { + secondInterceptorCalled.set(true); + return response; + } + }))) + .build(); + + server.start(); + initializeWithProtocolVersion(null); + + write("tools/call", + Document.of(Map.of( + "name", + Document.of("NoIOOperation"), + "arguments", + Document.of(Map.of())))); + var response = read(); + + // Second interceptor never runs — exception propagates immediately + assertFalse(secondInterceptorCalled.get()); + assertNotNull(response.getError()); + assertTrue(response.getError().getMessage().contains("tool-error")); + } }