diff --git a/src/main/java/com/google/cloud/mcp/HttpMcpToolboxClient.java b/src/main/java/com/google/cloud/mcp/HttpMcpToolboxClient.java index fd699fe..40b824d 100644 --- a/src/main/java/com/google/cloud/mcp/HttpMcpToolboxClient.java +++ b/src/main/java/com/google/cloud/mcp/HttpMcpToolboxClient.java @@ -346,13 +346,20 @@ private Map handleListToolsResponse(HttpResponse } } + Object defaultValue = null; + if (propNode.has("default")) { + JsonNode defNode = propNode.get("default"); + defaultValue = objectMapper.treeToValue(defNode, Object.class); + } + params.add( new ToolDefinition.Parameter( paramName, paramType, requiredSet.contains(paramName), paramDesc, - authSources)); + authSources, + defaultValue)); } } diff --git a/src/main/java/com/google/cloud/mcp/Tool.java b/src/main/java/com/google/cloud/mcp/Tool.java index e32d87e..dc0bebd 100644 --- a/src/main/java/com/google/cloud/mcp/Tool.java +++ b/src/main/java/com/google/cloud/mcp/Tool.java @@ -180,6 +180,11 @@ private void validateAndSanitizeArgs(Map args) { for (ToolDefinition.Parameter param : definition.parameters()) { Object value = args.get(param.name()); + if (value == null && param.defaultValue() != null) { + value = param.defaultValue(); + args.put(param.name(), value); + } + // A. Check Required Parameters if (param.required() && value == null) { throw new IllegalArgumentException( diff --git a/src/main/java/com/google/cloud/mcp/ToolDefinition.java b/src/main/java/com/google/cloud/mcp/ToolDefinition.java index 586bc1d..07ff42c 100644 --- a/src/main/java/com/google/cloud/mcp/ToolDefinition.java +++ b/src/main/java/com/google/cloud/mcp/ToolDefinition.java @@ -17,6 +17,7 @@ package com.google.cloud.mcp; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; /** @@ -36,6 +37,7 @@ public record ToolDefinition( * @param required Whether the parameter is required. * @param description A description of the parameter. * @param authSources A list of authentication sources for this parameter. + * @param defaultValue The default value for the parameter. */ @JsonIgnoreProperties(ignoreUnknown = true) public record Parameter( @@ -43,6 +45,6 @@ public record Parameter( String type, boolean required, String description, - List authSources // Maps services to parameters - ) {} + List authSources, // Maps services to parameters + @JsonProperty("default") Object defaultValue) {} } diff --git a/src/test/java/com/google/cloud/mcp/ToolTest.java b/src/test/java/com/google/cloud/mcp/ToolTest.java new file mode 100644 index 0000000..15ec7cd --- /dev/null +++ b/src/test/java/com/google/cloud/mcp/ToolTest.java @@ -0,0 +1,113 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +class ToolTest { + + @Test + void testDefaultValueInjection() throws Exception { + McpToolboxClient mockClient = mock(McpToolboxClient.class); + + ToolDefinition.Parameter paramWithDefault = + new ToolDefinition.Parameter( + "param1", "string", false, "A parameter", null, "default_value"); + ToolDefinition.Parameter paramNoDefault = + new ToolDefinition.Parameter("param2", "string", false, "Another parameter", null, null); + + ToolDefinition def = + new ToolDefinition("A test tool", List.of(paramWithDefault, paramNoDefault), null); + + Tool tool = new Tool("testTool", def, mockClient); + + when(mockClient.invokeTool(eq("testTool"), any(), any())) + .thenReturn( + CompletableFuture.completedFuture(new ToolResult(Collections.emptyList(), false))); + + Map args = new HashMap<>(); + args.put("param2", "provided_value"); + + CompletableFuture future = tool.execute(args); + future.join(); // Wait for execution + + @SuppressWarnings("unchecked") + ArgumentCaptor> argsCaptor = ArgumentCaptor.forClass(Map.class); + @SuppressWarnings("unchecked") + ArgumentCaptor> headersCaptor = ArgumentCaptor.forClass(Map.class); + + verify(mockClient).invokeTool(eq("testTool"), argsCaptor.capture(), headersCaptor.capture()); + + Map capturedArgs = argsCaptor.getValue(); + + assertEquals( + "default_value", + capturedArgs.get("param1"), + "Default value should be injected when not provided"); + assertEquals("provided_value", capturedArgs.get("param2"), "Provided value should be kept"); + } + + @Test + void testDefaultValueNotOverwritten() throws Exception { + McpToolboxClient mockClient = mock(McpToolboxClient.class); + + ToolDefinition.Parameter paramWithDefault = + new ToolDefinition.Parameter( + "param1", "string", false, "A parameter", null, "default_value"); + + ToolDefinition def = new ToolDefinition("A test tool", List.of(paramWithDefault), null); + + Tool tool = new Tool("testTool", def, mockClient); + + when(mockClient.invokeTool(eq("testTool"), any(), any())) + .thenReturn( + CompletableFuture.completedFuture(new ToolResult(Collections.emptyList(), false))); + + Map args = new HashMap<>(); + args.put("param1", "custom_value"); + + CompletableFuture future = tool.execute(args); + future.join(); // Wait for execution + + @SuppressWarnings("unchecked") + ArgumentCaptor> argsCaptor = ArgumentCaptor.forClass(Map.class); + @SuppressWarnings("unchecked") + ArgumentCaptor> headersCaptor = ArgumentCaptor.forClass(Map.class); + + verify(mockClient).invokeTool(eq("testTool"), argsCaptor.capture(), headersCaptor.capture()); + + Map capturedArgs = argsCaptor.getValue(); + + assertEquals( + "custom_value", + capturedArgs.get("param1"), + "Provided value should not be overwritten by default value"); + } +}