diff --git a/.bazelversion b/.bazelversion new file mode 100644 index 000000000..6d2890793 --- /dev/null +++ b/.bazelversion @@ -0,0 +1 @@ +8.5.0 diff --git a/policy/src/main/java/dev/cel/policy/CelPolicy.java b/policy/src/main/java/dev/cel/policy/CelPolicy.java index 9980d0cad..b73d9e0b1 100644 --- a/policy/src/main/java/dev/cel/policy/CelPolicy.java +++ b/policy/src/main/java/dev/cel/policy/CelPolicy.java @@ -27,6 +27,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -77,8 +78,7 @@ public abstract static class Builder { public abstract Builder setPolicySource(CelPolicySource policySource); - // This should stay package-private to encourage add/set methods to be used instead. - abstract ImmutableMap.Builder metadataBuilder(); + private final HashMap metadata = new HashMap<>(); public abstract Builder setMetadata(ImmutableMap value); @@ -90,6 +90,10 @@ public List imports() { return Collections.unmodifiableList(importList); } + public Map metadata() { + return Collections.unmodifiableMap(metadata); + } + @CanIgnoreReturnValue public Builder addImport(Import value) { importList.add(value); @@ -104,13 +108,13 @@ public Builder addImports(Collection values) { @CanIgnoreReturnValue public Builder putMetadata(String key, Object value) { - metadataBuilder().put(key, value); + metadata.put(key, value); return this; } @CanIgnoreReturnValue public Builder putMetadata(Map map) { - metadataBuilder().putAll(map); + metadata.putAll(map); return this; } diff --git a/policy/src/main/java/dev/cel/policy/testing/BUILD.bazel b/policy/src/main/java/dev/cel/policy/testing/BUILD.bazel new file mode 100644 index 000000000..6c847e0a6 --- /dev/null +++ b/policy/src/main/java/dev/cel/policy/testing/BUILD.bazel @@ -0,0 +1,29 @@ +load("@rules_java//java:defs.bzl", "java_library") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//policy/testing:__pkg__", + ], +) + +java_library( + name = "policy_test_suite_helper", + testonly = True, + srcs = [ + "PolicyTestSuiteHelper.java", + ], + deps = [ + "//bundle:cel", + "//common:cel_ast", + "//common:compiler_common", + "//common/formats:value_string", + "//policy", + "//policy:parser", + "//policy:parser_builder", + "//policy:policy_parser_context", + "//runtime:evaluation_exception", + "@maven//:com_google_guava_guava", + "@maven//:org_yaml_snakeyaml", + ], +) diff --git a/policy/src/main/java/dev/cel/policy/testing/PolicyTestSuiteHelper.java b/policy/src/main/java/dev/cel/policy/testing/PolicyTestSuiteHelper.java new file mode 100644 index 000000000..99bcab727 --- /dev/null +++ b/policy/src/main/java/dev/cel/policy/testing/PolicyTestSuiteHelper.java @@ -0,0 +1,192 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.policy.testing; + +import static com.google.common.base.Strings.isNullOrEmpty; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Ascii; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.Resources; +import dev.cel.bundle.Cel; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelValidationException; +import dev.cel.runtime.CelEvaluationException; +import java.io.IOException; +import java.net.URL; +import java.util.List; +import java.util.Map; +import org.yaml.snakeyaml.LoaderOptions; +import org.yaml.snakeyaml.Yaml; +import org.yaml.snakeyaml.constructor.Constructor; + +/** + * Helper to assist with policy testing. + * + **/ +public final class PolicyTestSuiteHelper { + + /** + * TODO + */ + public static PolicyTestSuite readTestSuite(String path) throws IOException { + Yaml yaml = new Yaml(new Constructor(PolicyTestSuite.class, new LoaderOptions())); + String testContent = readFile(path); + + return yaml.load(testContent); + } + + /** + * TODO + * @param yamlPath + * @return + * @throws IOException + */ + public static String readFromYaml(String yamlPath) throws IOException { + return readFile(yamlPath); + } + + /** + * TestSuite describes a set of tests divided by section. + * + *

Visibility must be public for YAML deserialization to work. This is effectively + * package-private since the outer class is. + */ + @VisibleForTesting + public static final class PolicyTestSuite { + private String description; + private List section; + + public void setDescription(String description) { + this.description = description; + } + + public void setSection(List section) { + this.section = section; + } + + public String getDescription() { + return description; + } + + public List getSection() { + return section; + } + + @VisibleForTesting + public static final class PolicyTestSection { + private String name; + private List tests; + + public void setName(String name) { + this.name = name; + } + + public void setTests(List tests) { + this.tests = tests; + } + + public String getName() { + return name; + } + + public List getTests() { + return tests; + } + + @VisibleForTesting + public static final class PolicyTestCase { + private String name; + private Map input; + private String output; + + public void setName(String name) { + this.name = name; + } + + public void setInput(Map input) { + this.input = input; + } + + public void setOutput(String output) { + this.output = output; + } + + public String getName() { + return name; + } + + public Map getInput() { + return input; + } + + public String getOutput() { + return output; + } + + @VisibleForTesting + public static final class PolicyTestInput { + private Object value; + private String expr; + + public Object getValue() { + return value; + } + + public void setValue(Object value) { + this.value = value; + } + + public String getExpr() { + return expr; + } + + public void setExpr(String expr) { + this.expr = expr; + } + } + + public ImmutableMap toInputMap(Cel cel) + throws CelValidationException, CelEvaluationException { + ImmutableMap.Builder inputBuilder = ImmutableMap.builderWithExpectedSize( + input.size()); + for (Map.Entry entry : input.entrySet()) { + String exprInput = entry.getValue().getExpr(); + if (isNullOrEmpty(exprInput)) { + inputBuilder.put(entry.getKey(), entry.getValue().getValue()); + } else { + CelAbstractSyntaxTree exprInputAst = cel.compile(exprInput).getAst(); + inputBuilder.put(entry.getKey(), cel.createProgram(exprInputAst).eval()); + } + } + + return inputBuilder.buildOrThrow(); + } + } + } + } + + + private static URL getResource(String path) { + return Resources.getResource(Ascii.toLowerCase(path)); + } + + private static String readFile(String path) throws IOException { + return Resources.toString(getResource(path), UTF_8); + } + + private PolicyTestSuiteHelper() {} +} diff --git a/policy/src/test/java/dev/cel/policy/BUILD.bazel b/policy/src/test/java/dev/cel/policy/BUILD.bazel index 9106caf70..d51b5dc3e 100644 --- a/policy/src/test/java/dev/cel/policy/BUILD.bazel +++ b/policy/src/test/java/dev/cel/policy/BUILD.bazel @@ -33,6 +33,7 @@ java_library( "//policy:policy_parser_context", "//policy:source", "//policy:validation_exception", + "//policy/testing:policy_test_suite_helper", "//runtime", "//runtime:function_binding", "//runtime:late_function_binding", diff --git a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java index fa0da8a9a..c38e1f8e0 100644 --- a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java +++ b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java @@ -14,9 +14,8 @@ package dev.cel.policy; -import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.truth.Truth.assertThat; -import static dev.cel.policy.PolicyTestHelper.readFromYaml; +import static dev.cel.policy.testing.PolicyTestSuiteHelper.readFromYaml; import static org.junit.Assert.assertThrows; import com.google.common.collect.ImmutableList; @@ -38,17 +37,15 @@ import dev.cel.parser.CelStandardMacro; import dev.cel.parser.CelUnparserFactory; import dev.cel.policy.PolicyTestHelper.K8sTagHandler; -import dev.cel.policy.PolicyTestHelper.PolicyTestSuite; -import dev.cel.policy.PolicyTestHelper.PolicyTestSuite.PolicyTestSection; -import dev.cel.policy.PolicyTestHelper.PolicyTestSuite.PolicyTestSection.PolicyTestCase; -import dev.cel.policy.PolicyTestHelper.PolicyTestSuite.PolicyTestSection.PolicyTestCase.PolicyTestInput; import dev.cel.policy.PolicyTestHelper.TestYamlPolicy; +import dev.cel.policy.testing.PolicyTestSuiteHelper.PolicyTestSuite; +import dev.cel.policy.testing.PolicyTestSuiteHelper.PolicyTestSuite.PolicyTestSection; +import dev.cel.policy.testing.PolicyTestSuiteHelper.PolicyTestSuite.PolicyTestSection.PolicyTestCase; import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelLateFunctionBindings; import dev.cel.testing.testdata.SingleFileProto.SingleFile; import dev.cel.testing.testdata.proto3.StandaloneGlobalEnum; import java.io.IOException; -import java.util.Map; import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; @@ -215,17 +212,8 @@ public void evaluateYamlPolicy_withCanonicalTestData( // Compile then evaluate the policy CelAbstractSyntaxTree compiledPolicyAst = CelPolicyCompilerFactory.newPolicyCompiler(cel).build().compile(policy); - ImmutableMap.Builder inputBuilder = ImmutableMap.builder(); - for (Map.Entry entry : testData.testCase.getInput().entrySet()) { - String exprInput = entry.getValue().getExpr(); - if (isNullOrEmpty(exprInput)) { - inputBuilder.put(entry.getKey(), entry.getValue().getValue()); - } else { - CelAbstractSyntaxTree exprInputAst = cel.compile(exprInput).getAst(); - inputBuilder.put(entry.getKey(), cel.createProgram(exprInputAst).eval()); - } - } - Object evalResult = cel.createProgram(compiledPolicyAst).eval(inputBuilder.buildOrThrow()); + ImmutableMap inputMap = testData.testCase.toInputMap(cel); + Object evalResult = cel.createProgram(compiledPolicyAst).eval(inputMap); // Assert // Note that policies may either produce an optional or a non-optional result, diff --git a/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java b/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java index 8d9e0084b..dab91afd7 100644 --- a/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java +++ b/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java @@ -1,42 +1,19 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package dev.cel.policy; -import static java.nio.charset.StandardCharsets.UTF_8; +import static dev.cel.policy.testing.PolicyTestSuiteHelper.readFromYaml; +import static dev.cel.policy.testing.PolicyTestSuiteHelper.readTestSuite; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Ascii; -import com.google.common.io.Resources; import dev.cel.common.formats.ValueString; import dev.cel.policy.CelPolicy.Match; import dev.cel.policy.CelPolicy.Match.Result; import dev.cel.policy.CelPolicy.Rule; import dev.cel.policy.CelPolicyParser.TagVisitor; +import dev.cel.policy.testing.PolicyTestSuiteHelper.PolicyTestSuite; import java.io.IOException; -import java.net.URL; -import java.util.List; -import java.util.Map; -import org.yaml.snakeyaml.LoaderOptions; -import org.yaml.snakeyaml.Yaml; -import org.yaml.snakeyaml.constructor.Constructor; import org.yaml.snakeyaml.nodes.Node; import org.yaml.snakeyaml.nodes.SequenceNode; -/** Package-private class to assist with policy testing. */ final class PolicyTestHelper { - enum TestYamlPolicy { NESTED_RULE( "nested_rule", @@ -135,128 +112,11 @@ String readConfigYamlContent() throws IOException { } PolicyTestSuite readTestYamlContent() throws IOException { - Yaml yaml = new Yaml(new Constructor(PolicyTestSuite.class, new LoaderOptions())); - String testContent = readFile(String.format("policy/%s/tests.yaml", name)); - - return yaml.load(testContent); - } - } - - static String readFromYaml(String yamlPath) throws IOException { - return readFile(yamlPath); - } - - /** - * TestSuite describes a set of tests divided by section. - * - *

Visibility must be public for YAML deserialization to work. This is effectively - * package-private since the outer class is. - */ - @VisibleForTesting - public static final class PolicyTestSuite { - private String description; - private List section; - - public void setDescription(String description) { - this.description = description; - } - - public void setSection(List section) { - this.section = section; - } - - public String getDescription() { - return description; - } - - public List getSection() { - return section; - } - - @VisibleForTesting - public static final class PolicyTestSection { - private String name; - private List tests; - - public void setName(String name) { - this.name = name; - } - - public void setTests(List tests) { - this.tests = tests; - } - - public String getName() { - return name; - } - - public List getTests() { - return tests; - } - - @VisibleForTesting - public static final class PolicyTestCase { - private String name; - private Map input; - private String output; - - public void setName(String name) { - this.name = name; - } - - public void setInput(Map input) { - this.input = input; - } - - public void setOutput(String output) { - this.output = output; - } - - public String getName() { - return name; - } - - public Map getInput() { - return input; - } - - public String getOutput() { - return output; - } - - @VisibleForTesting - public static final class PolicyTestInput { - private Object value; - private String expr; - - public Object getValue() { - return value; - } - - public void setValue(Object value) { - this.value = value; - } - - public String getExpr() { - return expr; - } - - public void setExpr(String expr) { - this.expr = expr; - } - } - } + String testPath = String.format("policy/%s/tests.yaml", name); + return readTestSuite(testPath); } } - private static URL getResource(String path) { - return Resources.getResource(Ascii.toLowerCase(path)); - } - - private static String readFile(String path) throws IOException { - return Resources.toString(getResource(path), UTF_8); - } - static class K8sTagHandler implements TagVisitor { @Override @@ -360,3 +220,5 @@ public void visitMatchTag( private PolicyTestHelper() {} } + + diff --git a/policy/testing/BUILD.bazel b/policy/testing/BUILD.bazel new file mode 100644 index 000000000..834c0a978 --- /dev/null +++ b/policy/testing/BUILD.bazel @@ -0,0 +1,12 @@ +load("@rules_java//java:defs.bzl", "java_library") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = ["//:internal"], +) + +java_library( + name = "policy_test_suite_helper", + testonly = True, + exports = ["//policy/src/main/java/dev/cel/policy/testing:policy_test_suite_helper"], +) diff --git a/tools/ai/BUILD.bazel b/tools/ai/BUILD.bazel new file mode 100644 index 000000000..97ee7aeef --- /dev/null +++ b/tools/ai/BUILD.bazel @@ -0,0 +1,17 @@ +load("@rules_java//java:defs.bzl", "java_library") + +package( + default_applicable_licenses = ["//:license"], + default_visibility = ["//visibility:public"], +) + +java_library( + name = "agentic_policy_compiler", + exports = ["//tools/src/main/java/dev/cel/tools/ai:agentic_policy_compiler"], +) + +alias( + name = "test_policies", + testonly = True, + actual = "//tools/src/test/resources:test_policies", +) diff --git a/tools/src/main/java/dev/cel/tools/ai/AgenticPolicyCompiler.java b/tools/src/main/java/dev/cel/tools/ai/AgenticPolicyCompiler.java new file mode 100644 index 000000000..778837f80 --- /dev/null +++ b/tools/src/main/java/dev/cel/tools/ai/AgenticPolicyCompiler.java @@ -0,0 +1,176 @@ +package dev.cel.tools.ai; + +import static dev.cel.common.formats.YamlHelper.assertYamlType; + +import dev.cel.bundle.Cel; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.formats.ValueString; +import dev.cel.common.formats.YamlHelper.YamlNodeType; +import dev.cel.policy.CelPolicy; +import dev.cel.policy.CelPolicy.Match; +import dev.cel.policy.CelPolicy.Match.Result; +import dev.cel.policy.CelPolicy.Rule; +import dev.cel.policy.CelPolicy.Variable; +import dev.cel.policy.CelPolicyCompiler; +import dev.cel.policy.CelPolicyCompilerFactory; +import dev.cel.policy.CelPolicyParser; +import dev.cel.policy.CelPolicyParser.TagVisitor; +import dev.cel.policy.CelPolicyParserFactory; +import dev.cel.policy.CelPolicyValidationException; +import dev.cel.policy.PolicyParserContext; +import java.util.ArrayList; +import java.util.List; +import org.yaml.snakeyaml.nodes.MappingNode; +import org.yaml.snakeyaml.nodes.Node; +import org.yaml.snakeyaml.nodes.NodeTuple; +import org.yaml.snakeyaml.nodes.ScalarNode; +import org.yaml.snakeyaml.nodes.SequenceNode; + +public final class AgenticPolicyCompiler { + + private static final CelPolicyParser POLICY_PARSER = + CelPolicyParserFactory.newYamlParserBuilder() + .addTagVisitor(new AgenticPolicyTagHandler()) + .build(); + + private final CelPolicyCompiler policyCompiler; + + public static AgenticPolicyCompiler newInstance(Cel cel) { + return new AgenticPolicyCompiler(cel); + } + + private AgenticPolicyCompiler(Cel cel) { + this.policyCompiler = CelPolicyCompilerFactory.newPolicyCompiler(cel).build(); + } + + public CelAbstractSyntaxTree compile(String policySource) throws CelPolicyValidationException { + CelPolicy policy = POLICY_PARSER.parse(policySource); + return policyCompiler.compile(policy); + } + + private static class AgenticPolicyTagHandler implements TagVisitor { + + @Override + public void visitPolicyTag( + PolicyParserContext ctx, + long id, + String tagName, + Node node, + CelPolicy.Builder policyBuilder) { + + switch (tagName) { + case "default": + if (assertYamlType(ctx, id, node, YamlNodeType.STRING)) { + policyBuilder.putMetadata("default_effect", ((ScalarNode) node).getValue()); + } + break; + + case "variables": + if (!assertYamlType(ctx, id, node, YamlNodeType.LIST)) return; + List parsedVariables = new ArrayList<>(); + SequenceNode varList = (SequenceNode) node; + + for (Node varNode : varList.getValue()) { + if (assertYamlType(ctx, ctx.collectMetadata(varNode), varNode, YamlNodeType.MAP)) { + MappingNode map = (MappingNode) varNode; + for (NodeTuple tuple : map.getValue()) { + String name = ((ScalarNode) tuple.getKeyNode()).getValue(); + String expr = ((ScalarNode) tuple.getValueNode()).getValue(); + parsedVariables.add(Variable.newBuilder() + .setName(ValueString.of(ctx.collectMetadata(tuple.getKeyNode()), name)) + .setExpression(ValueString.of(ctx.collectMetadata(tuple.getValueNode()), expr)) + .build()); + } + } + } + policyBuilder.putMetadata("top_level_variables", parsedVariables); + break; + + case "rules": + if (!assertYamlType(ctx, id, node, YamlNodeType.LIST)) return; + SequenceNode rulesNode = (SequenceNode) node; + Rule.Builder subRuleBuilder = Rule.newBuilder(ctx.collectMetadata(rulesNode)); + + if (policyBuilder.metadata().containsKey("top_level_variables")) { + List variables = (List) policyBuilder.metadata().get("top_level_variables"); + subRuleBuilder.addVariables(variables); + } + + for (Node ruleNode : rulesNode.getValue()) { + policyBuilder.putMetadata("effect", "deny"); + policyBuilder.putMetadata("message", ""); + policyBuilder.putMetadata("output_expr", null); + + Match subMatch = ctx.parseMatch(ctx, policyBuilder, ruleNode); + subRuleBuilder.addMatches(subMatch); + } + + if (policyBuilder.metadata().containsKey("default_effect")) { + String defaultEffect = policyBuilder.metadata().get("default_effect").toString(); + Match defaultMatch = Match.newBuilder(ctx.nextId()) + .setCondition(ValueString.of(ctx.nextId(), "true")) + .setResult(Result.ofOutput(ValueString.of(ctx.nextId(), generateMessageOutput(defaultEffect, "")))) + .build(); + subRuleBuilder.addMatches(defaultMatch); + } + policyBuilder.setRule(subRuleBuilder.build()); + break; + + default: + TagVisitor.super.visitPolicyTag(ctx, id, tagName, node, policyBuilder); + break; + } + } + + @Override + public void visitMatchTag( + PolicyParserContext ctx, + long id, + String tagName, + Node node, + CelPolicy.Builder policyBuilder, + Match.Builder matchBuilder) { + + switch (tagName) { + case "description": + if (assertYamlType(ctx, id, node, YamlNodeType.STRING)) { + matchBuilder.setExplanation(ValueString.of(ctx.nextId(), ((ScalarNode) node).getValue())); + } + break; + + case "effect": + case "message": + case "output_expr": + if (!assertYamlType(ctx, id, node, YamlNodeType.STRING)) return; + + String value = ((ScalarNode) node).getValue(); + policyBuilder.putMetadata(tagName, value); + + String currentEffect = (String) policyBuilder.metadata().get("effect"); + String currentMessage = (String) policyBuilder.metadata().get("message"); + String currentOutputExpr = (String) policyBuilder.metadata().get("output_expr"); + + String finalOutput = (currentOutputExpr != null) + ? generateDetailsOutput(currentEffect, currentOutputExpr) + : generateMessageOutput(currentEffect, currentMessage); + + matchBuilder.setResult(Result.ofOutput(ValueString.of(ctx.nextId(), finalOutput))); + break; + + default: + TagVisitor.super.visitMatchTag(ctx, id, tagName, node, policyBuilder, matchBuilder); + break; + } + } + + // The following will likely benefit from having a concrete output structure + private static String generateMessageOutput(String effect, String message) { + String safeMessage = message.replace("'", "\\'"); + return String.format("{'effect': '%s', 'message': '%s'}", effect, safeMessage); + } + + private static String generateDetailsOutput(String effect, String outputExpression) { + return String.format("{'effect': '%s', 'details': %s}", effect, outputExpression); + } + } +} diff --git a/tools/src/main/java/dev/cel/tools/ai/BUILD.bazel b/tools/src/main/java/dev/cel/tools/ai/BUILD.bazel new file mode 100644 index 000000000..6cbd4f62d --- /dev/null +++ b/tools/src/main/java/dev/cel/tools/ai/BUILD.bazel @@ -0,0 +1,48 @@ +load("@com_google_protobuf//bazel:java_proto_library.bzl", "java_proto_library") +load("@rules_java//java:defs.bzl", "java_library") + +package( + default_applicable_licenses = [ + "//:license", + ], + default_visibility = ["//visibility:public"], + # default_visibility = [ + # "//tools/ai:__pkg__", + # ], +) + +java_library( + name = "agentic_policy_compiler", + srcs = ["AgenticPolicyCompiler.java"], + deps = [ + ":agent_context_java_proto", + "//bundle:cel", + "//common:cel_ast", + "//common/formats:value_string", + "//common/formats:yaml_helper", + "//common/types", + "//policy", + "//policy:compiler", + "//policy:compiler_factory", + "//policy:parser", + "//policy:parser_factory", + "//policy:policy_parser_context", + "//policy:validation_exception", + "@maven//:com_google_protobuf_protobuf_java", + "@maven//:org_yaml_snakeyaml", + ], +) + +proto_library( + name = "agent_context_proto", + srcs = ["agent_context.proto"], + deps = [ + "@com_google_protobuf//:struct_proto", + "@com_google_protobuf//:timestamp_proto", + ], +) + +java_proto_library( + name = "agent_context_java_proto", + deps = [":agent_context_proto"], +) diff --git a/tools/src/main/java/dev/cel/tools/ai/agent_context.proto b/tools/src/main/java/dev/cel/tools/ai/agent_context.proto new file mode 100644 index 000000000..988841004 --- /dev/null +++ b/tools/src/main/java/dev/cel/tools/ai/agent_context.proto @@ -0,0 +1,390 @@ +syntax = "proto3"; + +package cel.expr.ai; + +import "google/protobuf/struct.proto"; +import "google/protobuf/timestamp.proto"; + +option java_package = "dev.cel.expr.ai"; +option java_multiple_files = true; +option java_outer_classname = "AgentContextProto"; + +// Agent represents the AI System or Service being governed. +// It encapsulates the static configuration (Manifests, Identity) and the +// dynamic runtime state (Context, Inputs, Outputs). +message Agent { + // The unique resource name of the agent. + // e.g. "agents/finance-helper" or "publishers/google/agents/gemini-pro" + string name = 1; + + // Human-readable description of the agent's purpose. + string description = 2; + + // The semantic version of the agent definition. + string version = 3; + + // The underlying model family backing this agent. + Model model = 4; + + // The provider or vendor responsible for hosting/managing this agent. + AgentProvider provider = 5; + + // Identity of the Agent itself (Service Account / Principal) + // Independent of 'request.auth.principal' which may be the end user + // credentials or the agent's identity + AgentAuth auth = 6; + + // The accumulated security context (Trust, Sensitivity, Data Sources). + AgentContext context = 7; + + // The current turn's input (Prompt + Attachments) + AgentMessage input = 8; + + // The pending response (if evaluating egress/output policies) + AgentMessage output = 9; +} + +// AgentAuth represents the identity of the Agent itself. +// Independent of 'request.auth.principal' which may be the end user +// credentials or the agent's identity +message AgentAuth { + // The principal of the agent, prefer SPIFFE format of: + // spiffe:///ns//sa/ + // See: https://spiffe.io/docs/latest/spiffe/concepts/#spiffe-identifiers + string principal = 1; + + // Map of string keys to structured claims about the agent. + // For example, with JWT-based tokens, the claims would include fields + // indicating the following: + // + // - The issuer 'iss' (e.g. url of the identity provider) + // - The audience(s) 'aud' (e.g. the intended recipient(s) of the token) + // - The token's expiration time ('exp') + // - The token's subject ('sub') + google.protobuf.Struct claims = 2; + + // The OAuth scopes granted to the agent. + // This is a list of strings, where each string is a valid OAuth scope + // (e.g. "https://www.googleapis.com/auth/cloud-platform"). + repeated string oauth_scopes = 3; +} + +// AgentContext represents the aggregate security and data governance state +// of the agent's context window. +message AgentContext { + // Aggregated view of data sensitivity in the window. + repeated Sensitivity sensitivities = 1; + + // Aggregated trust score (Min of all inputs). + Trust trust = 2; + + // Origin/Lineage tracking. + repeated DataSource data_sources = 3; + + // The flattened text content of the current prompt. + string prompt = 4; +} + +// AgentHistory represents the ordered sequence of messages representing the +// agent's conversation. +// +// AgentHistory is expected to be provided on-demand via helper methods +// associated with an Agent instance. +message AgentHistory { + // The name of the agent for whom this history is collected. + // + // This should match the `Agent.name` field. + string agent_name = 1; + + // The ordered sequence of messages representing the agent's conversation. + repeated AgentMessage messages = 2; +} + +// AgentMessage represents a single turn in the conversation. +// It acts as a container for multimodal content (Text, Files, Tool Results). +message AgentMessage { + // A discrete unit of content within the message. + message Part { + oneof type { + // User or System text input. + ContentPart prompt = 1; + + // A request to execute a specific tool. + // + // If a call has been completed, the call will have the result or + // error populated. Calls which have not yet been resolved will only have + // the intent (arguments) populated. + ToolCall tool_call = 2; + + // A file or multimodal object (Image, PDF). + ContentPart attachment = 3; + + // An error that occurred during processing. + ErrorPart error = 4; + } + } + + // The actor who constructed the message (e.g., "user", "model", "tool"). + string role = 1; + + // The ordered sequence of content parts. + // + // In the case of a tool call, the result or error will be populated within + // the `ToolCall` message rather than split into a separate `Part`. + repeated Part parts = 2; + + // Arbitrary metadata associated with the message turn. + optional google.protobuf.Struct metadata = 3; + + // Message creation time + google.protobuf.Timestamp time = 4; +} + +// ContentPart is a catch-all message type capable of encapsulating other +// messages within its `structured_content` field. +// +// For example, a series of sub-agent MCP tool calls and results may be +// encapsulated as an `AgentMessage` in JSON form within the +// `structured_content` field. +// +// The approach is unconventional, but indicates how the data representation +// provided to policy requires helper methods to help make agent policies +// sensible and with support to type-convert from json to proto perhaps being +// a necessary on-demand feature within agent policies. +message ContentPart { + // Unique identifier for this content part. + string id = 1; + + // The type of content. + // + // Common values include: "text", "file", "json" + string type = 2; + + // The MIME type of the content. + // + // Common values include: "text/plain", "application/json", "image/png" + string mime_type = 3; + + // The name of the content. + string name = 4; + + // The description of the content. + string description = 5; + + // The URI of the content. + optional string uri = 6; + + // The string seriralized representation of the content, either plain text or + // serialized JSON reflected from `structured_content`. + optional string content = 7; + + // The binary representation of the content. + // + // This field is used to represent binary data (e.g., images, PDFs) or + // serialized proto messages which come over the wire as base64-encoded string + // values that are expected to be decoded into binary data. + optional bytes data = 8; + + // The JSON object representation of the content, if applicable. + optional google.protobuf.Struct structured_content = 9; + + // Arbitrary metadata associated with the content part. + optional google.protobuf.Struct annotations = 10; + + // Timestamp associated with the content part. + google.protobuf.Timestamp time = 11; +} + +// ErrorPart represents a processing error within the agent loop. +message ErrorPart { + // The identifier of the specific ContentPart, ToolCall, or Message that + // caused this error. Used to correlate the failure back to the originating + // action (e.g., matching a failed tool call). + string id = 1; + + // Standardized error code (e.g., gRPC status code or HTTP status). + int64 code = 2; + + // Developer-facing error message describing the failure. + string error_message = 3; + + // Timestamp when the error occurred. + google.protobuf.Timestamp time = 4; +} + +// AgentProvider describes the entity responsible for the agent's operation. +message AgentProvider { + // The base URL or endpoint where the agent service is hosted. + string url = 1; + + // The name of the organization providing the agent (e.g. "Google", + // "Salesforce"). + optional string organization = 2; +} + +// Model describes the AI model backing the agent. +message Model { + // Identifier of the model family (ex: gemini-pro, gpt-4 ...) + string name = 1; +} + +// ToolManifest describes a collection of tools provided by a specific +// source. +message ToolManifest { + // Metadata about the tool provider itself, including authorization + // requirements. + ToolProvider provider = 1; + + // Collection of Tool instances specified by the provider. + repeated Tool tools = 2; +} + +// Tool describes a specific function or capability available to the agent. +message Tool { + // The unique name of the tool + string name = 1; // (e.g. "weather_lookup"). + + // Human readable description of what the tool does. + string description = 2; + + // JSON Schema defining the expected arguments. + optional google.protobuf.Struct input_schema = 3; + + // JSON Schema defining the expected output. + optional google.protobuf.Struct output_schema = 4; + + // Security and behavior hints for policy enforcement. + optional ToolAnnotations annotations = 5; + + // Arbitrary tool metadata. + optional google.protobuf.Struct metadata = 6; +} + +// Information about how the tools were provided and by whom. +message ToolProvider { + // URL where the tools were provided. + string url = 1; + + // Name of the tool provider. + string organization = 2; // e.g. "google-cloud" + + // URL for the OAuth authorization endpoint supported by this tool provider + optional string authorization_server_url = 3; + + // Repeated set of OAuth scopes for this tool provider. + repeated string supported_scopes = 4; +} + +// Additional properties describing a tool to clients. +// +// Informed by annotations common to the MCP spec and conventions common to +// other agent frameworks. +message ToolAnnotations { + // A human-readable title for the tool. + string title = 1; + + // If true, the tool does not modify its environment. + // Default: false + bool read_only = 2; + + // If true, the tool may perform destructive updates to its environment. + // If false, the tool performs only additive updates. + // NOTE: This property is meaningful only when `read_only_hint == false` + bool destructive = 3; + + // If true, calling the tool repeatedly with the same arguments will have no + // additional effect on its environment. + // NOTE: This property is meaningful only when `read_only_hint == false`. + bool idempotent = 4; + + // If true, this tool may interact with an "open world" of external entities. + // If false, the tools domain of interaction is closed. For example, the + // world of a web search tool is open, whereas that of a memory tool is not. + bool open_world = 5; + + // If true, this tool is intended to be called asynchronously. + // For example, a tool that starts a simulation process on a server and + // returns immediately. + bool async = 6; + + // Additional structured tags associated with the tool. + map tags = 7; + + // The OAuth scopes required to use this tool. If empty, the set of scopes + // required is inherited from ToolProvider.supported_scopes. + // + // This is a list of strings, where each string is a valid OAuth scope + // (e.g. "https://www.googleapis.com/auth/cloud-platform"). + repeated string required_auth_scopes = 8; + + // The OAuth scopes that are optional to use this tool. + repeated string optional_auth_scopes = 9; + + message DataAccessLevel { + Sensitivity sensitivity = 1; + + message AccessRole { + string role = 1; + google.protobuf.Struct metadata = 2; + } + } +} + +// Sensitivity describes the classification of data within the context. +message Sensitivity { + // Valid labels are 'pii', 'internal' + string label = 1; + + // The optional value associated with the label, e.g. 'credit card' + string value = 2; +} + +// Describes the integrity/veracity of the data. +message Trust { + // Valid trust labels are "untrusted" (default), "trusted", and + // "partially_trusted". + string label = 1; +} + +// Describes the provenance of a data chunk. +message DataSource { + // Unique id describing the originating data source. + string id = 1; // e.g. "bigquery:sales_table" + + // The category of origin for this data. + string provenance = 2; // e.g. "UserPrompt", "Database:Secure", "PublicWeb" +} + +// ToolCall represents a specific invocation of a tool by the agent. +// It captures the intent (arguments), the status (result/error), and +// governance metadata (confirmation). +message ToolCall { + // Unique identifier for this tool call. + // Used to correlate the call with its result or error in the history. + string id = 1; + + // The name of the tool being called (e.g., "weather_lookup"). + // This should match a tool defined in the agent's ToolManifest. + string name = 2; + + // The arguments provided to the tool call. + // Policies can inspect these values to enforce data safety (e.g. no PII). + google.protobuf.Struct arguments = 3; + + // The execution status of the tool call. + // This field is populated if the tool has already been executed (in history). + oneof status { + // The successful output of the tool. + ContentPart result = 4; + + // The error encountered during execution. + ErrorPart error = 5; + } + + // Timestamp when the tool call was initiated. + google.protobuf.Timestamp time = 6; + + // Indicates if the user explicitly confirmed this action. + // Useful for Human-in-the-Loop (HITL) policies. + bool user_confirmed = 7; +} \ No newline at end of file diff --git a/tools/src/test/java/dev/cel/tools/ai/AgenticPolicyCompilerTest.java b/tools/src/test/java/dev/cel/tools/ai/AgenticPolicyCompilerTest.java new file mode 100644 index 000000000..b9016969b --- /dev/null +++ b/tools/src/test/java/dev/cel/tools/ai/AgenticPolicyCompilerTest.java @@ -0,0 +1,293 @@ +package dev.cel.tools.ai; + +import static dev.cel.common.CelFunctionDecl.newFunctionDeclaration; +import static dev.cel.common.CelOverloadDecl.newGlobalOverload; +import static dev.cel.common.CelOverloadDecl.newMemberOverload; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.base.Ascii; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.Resources; +import com.google.common.truth.Expect; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import com.google.testing.junit.testparameterinjector.TestParameter; +import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelFactory; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelContainer; +import dev.cel.common.CelValidationException; +import dev.cel.common.types.ListType; +import dev.cel.common.types.SimpleType; +import dev.cel.common.types.StructTypeReference; +import dev.cel.expr.ai.Agent; +import dev.cel.expr.ai.AgentMessage; +import dev.cel.expr.ai.ContentPart; +import dev.cel.expr.ai.ToolCall; +import dev.cel.parser.CelStandardMacro; +import dev.cel.policy.testing.PolicyTestSuiteHelper; +import dev.cel.policy.testing.PolicyTestSuiteHelper.PolicyTestSuite; +import dev.cel.policy.testing.PolicyTestSuiteHelper.PolicyTestSuite.PolicyTestSection; +import dev.cel.policy.testing.PolicyTestSuiteHelper.PolicyTestSuite.PolicyTestSection.PolicyTestCase; +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.CelFunctionBinding; +import java.io.IOException; +import java.net.URL; +import java.util.List; +import java.util.Map; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(TestParameterInjector.class) +public class AgenticPolicyCompilerTest { + @Rule + public final Expect expect = Expect.create(); + + private static final Cel CEL = CelFactory.standardCelBuilder() + .setContainer(CelContainer.ofName("cel.expr.ai")) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addMessageTypes(Agent.getDescriptor()) + .addMessageTypes(ToolCall.getDescriptor()) + .addMessageTypes(AgentMessage.getDescriptor()) + + .addVar("agent", StructTypeReference.create("cel.expr.ai.Agent")) + .addVar("tool", StructTypeReference.create("cel.expr.ai.ToolCall")) + + .addFunctionDeclarations( + newFunctionDeclaration( + "history", + newMemberOverload( + "agent_history", + ListType.create(StructTypeReference.create("cel.expr.ai.AgentMessage")), + StructTypeReference.create("cel.expr.ai.Agent") + ) + ), + newFunctionDeclaration( + "isSensitive", + newMemberOverload( + "toolCall_isSensitive", + SimpleType.BOOL, + StructTypeReference.create("cel.expr.ai.ToolCall") + )), + newFunctionDeclaration( + "security.classifyInjection", + newGlobalOverload( + "classifyInjection_string", + SimpleType.DOUBLE, + SimpleType.STRING + )), + newFunctionDeclaration( + "security.computePrivilegedPlan", + newGlobalOverload( + "computePrivilegedPlan_agentMessage", + ListType.create(SimpleType.STRING), + ListType.create(StructTypeReference.create(AgentMessage.getDescriptor().getFullName())) + )), + newFunctionDeclaration( + "security.cascade_trust", + newGlobalOverload( + "security_cascade_trust", + SimpleType.DYN, + ListType.create(StructTypeReference.create(AgentMessage.getDescriptor().getFullName())) + )) + ) + // Mocked functions + .addFunctionBindings( + CelFunctionBinding.from( + "agent_history", + Agent.class, + (agent) -> { + String scenario = agent.getDescription(); + + if (scenario.startsWith("trust_cascading")) { + return getTrustCascadingHistory(scenario); + } + + if (scenario.startsWith("contextual_security")) { + return getContextualSecurityHistory(scenario); + } + + throw new IllegalArgumentException( + "Test requested 'agent.history()' but provided unsupported agent.description: " + scenario); + } + ), + CelFunctionBinding.from( + "toolCall_isSensitive", + ToolCall.class, + (tool) -> tool.getName().contains("PII")), + CelFunctionBinding.from( + "classifyInjection_string", + ImmutableList.of(String.class), + (args) -> { + String input = (String) args[0]; + if (input.contains("INJECTION_ATTACK")) return 0.95; + if (input.contains("SUSPICIOUS")) return 0.6; + return 0.1; + }), + CelFunctionBinding.from( + "computePrivilegedPlan_agentMessage", + ImmutableList.of(List.class), + (args) -> { + List history = (List) args[0]; + for (AgentMessage msg : history) { + // TODO: Filter by trust as well + if (msg.getPartsCount() > 0) { + String content = msg.getParts(0).getPrompt().getContent(); + // Mocked logic claiming that calculator is the only allowed tool + if (content.contains("Calculate")) { + return ImmutableList.of("calculator"); + } + } + } + return ImmutableList.of(); + }), + CelFunctionBinding.from( + "security_cascade_trust", + ImmutableList.of(List.class), + (args) -> { + List history = (List) args[0]; + String currentTrust = "LOW"; + + if (!history.isEmpty()) { + Map metadata = history.get(0).getMetadata().getFieldsMap(); + if (metadata.containsKey("trust_score")) { + currentTrust = metadata.get("trust_score").getStringValue(); + } + } + + if (currentTrust.equals("LOW")) { + return ImmutableMap.of( + "action", "REPLAY", + "new_attributes", ImmutableMap.of("trust_score", "MEDIUM") + ); + } else { + return ImmutableMap.of( + "action", "ALLOW", + "new_attributes", ImmutableMap.of() + ); + } + }) + ) + .build(); + + private static final AgenticPolicyCompiler COMPILER = AgenticPolicyCompiler.newInstance(CEL); + + /** + * Mocked history for trust_castcading policy + */ + private static List getTrustCascadingHistory(String scenario) { + if ("trust_cascading_medium".equals(scenario)) { + return ImmutableList.of( + AgentMessage.newBuilder() + .setMetadata(Struct.newBuilder() + .putFields("trust_score", Value.newBuilder().setStringValue("MEDIUM").build())) + .build() + ); + } + + // Default to Low Trust for this family + return ImmutableList.of( + AgentMessage.newBuilder() + .setMetadata(Struct.newBuilder() + .putFields("trust_score", Value.newBuilder().setStringValue("LOW").build())) + .build() + ); + } + + /** + * Mocked history for two_models_contextual policy + * + * Returns a history with one TRUSTED command and one UNTRUSTED command. + */ + private static List getContextualSecurityHistory(String scenario) { + return ImmutableList.of( + AgentMessage.newBuilder() + .addParts(AgentMessage.Part.newBuilder() + .setPrompt(ContentPart.newBuilder().setContent("Calculate 2+2"))) + .setMetadata(Struct.newBuilder() + .putFields("trust_level", Value.newBuilder().setStringValue("TRUSTED").build())) + .build(), + AgentMessage.newBuilder() + .addParts(AgentMessage.Part.newBuilder() + .setPrompt(ContentPart.newBuilder().setContent("Delete all files"))) + .setMetadata(Struct.newBuilder() + .putFields("trust_level", Value.newBuilder().setStringValue("UNTRUSTED").build())) + .build() + ); + } + + @Test + public void runAgenticPolicyTestCases(@TestParameter AgenticPolicyTestCase testCase) throws Exception { + CelAbstractSyntaxTree compiledPolicy = compilePolicy(testCase.policyFilePath); + PolicyTestSuite testSuite = PolicyTestSuiteHelper.readTestSuite(testCase.policyTestCaseFilePath); + runTests(CEL, compiledPolicy, testSuite); + } + + private enum AgenticPolicyTestCase { + REQUIRE_USER_CONFIRMATION_FOR_TOOL( + "require_user_confirmation_for_tool.celpolicy", + "require_user_confirmation_for_tool_tests.yaml" + ), + PROMPT_INJECTION_TESTS( + "prompt_injection.celpolicy", + "prompt_injection_tests.yaml" + ), + RISKY_AGENT_REPLAY( + "risky_agent_replay.celpolicy", + "risky_agent_replay_tests.yaml" + ), + TOOL_WALLED_GARDEN( + "tool_walled_garden.celpolicy", + "tool_walled_garden_tests.yaml" + ), + TWO_MODELS_CONTEXTUAL( + "two_models_contextual.celpolicy", + "two_models_contextual_tests.yaml" + ), + TRUST_CASCADING( + "trust_cascading.celpolicy", + "trust_cascading_tests.yaml" + ); + + private final String policyFilePath; + private final String policyTestCaseFilePath; + + AgenticPolicyTestCase(String policyFilePath, String policyTestCaseFilePath) { + this.policyFilePath = policyFilePath; + this.policyTestCaseFilePath = policyTestCaseFilePath; + } + } + + private static CelAbstractSyntaxTree compilePolicy(String policyPath) + throws Exception { + String policy = readFile(policyPath); + return COMPILER.compile(policy); + } + + private static String readFile(String path) throws IOException { + URL url = Resources.getResource(Ascii.toLowerCase(path)); + return Resources.toString(url, UTF_8); + } + + private void runTests(Cel cel, CelAbstractSyntaxTree ast, PolicyTestSuite testSuite) { + for (PolicyTestSection testSection : testSuite.getSection()) { + for (PolicyTestCase testCase : testSection.getTests()) { + String testName = String.format( + "%s: %s", testSection.getName(), testCase.getName()); + try { + ImmutableMap inputMap = testCase.toInputMap(cel); + Object evalResult = cel.createProgram(ast).eval(inputMap); + Object expectedOutput = cel.createProgram(cel.compile(testCase.getOutput()).getAst()).eval(); + expect.withMessage(testName).that(evalResult).isEqualTo(expectedOutput); + } catch (CelValidationException e) { + expect.withMessage("Failed to compile test case for " + testName + ". Reason:\n" + e.getMessage()).fail(); + } catch (CelEvaluationException e) { + expect.withMessage("Failed to evaluate test case for " + testName + ". Reason:\n" + e.getMessage()).fail(); + } + } + } + } +} diff --git a/tools/src/test/java/dev/cel/tools/ai/BUILD.bazel b/tools/src/test/java/dev/cel/tools/ai/BUILD.bazel new file mode 100644 index 000000000..47bd39549 --- /dev/null +++ b/tools/src/test/java/dev/cel/tools/ai/BUILD.bazel @@ -0,0 +1,41 @@ +load("@rules_java//java:defs.bzl", "java_library") +load("//:testing.bzl", "junit4_test_suites") + +package(default_applicable_licenses = ["//:license"]) + +java_library( + name = "tests", + testonly = True, + srcs = glob( + ["*.java"], + ), + resources = ["//tools/ai:test_policies"], + deps = [ + "//:java_truth", + "//bundle:cel", + "//common:cel_ast", + "//common:compiler_common", + "//common:container", + "//common/formats:value_string", + "//common/types", + "//parser:macro", + "//policy/testing:policy_test_suite_helper", + "//runtime:evaluation_exception", + "//runtime:function_binding", + "//tools/ai:agentic_policy_compiler", + "//tools/src/main/java/dev/cel/tools/ai:agent_context_java_proto", + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", + "@maven//:com_google_testparameterinjector_test_parameter_injector", + "@maven//:junit_junit", + ], +) + +junit4_test_suites( + name = "test_suites", + sizes = [ + "small", + ], + src_dir = "src/test/java", + deps = [":tests"], +) diff --git a/tools/src/test/resources/BUILD.bazel b/tools/src/test/resources/BUILD.bazel new file mode 100644 index 000000000..8fbb42fce --- /dev/null +++ b/tools/src/test/resources/BUILD.bazel @@ -0,0 +1,20 @@ +package( + default_applicable_licenses = [ + "//:license", + ], + default_testonly = True, + default_visibility = [ + "//tools/ai:__pkg__", + ], +) + +filegroup( + name = "test_policies", + testonly = True, + srcs = glob( + [ + "*.celpolicy", + "*.yaml", + ], + ), +) diff --git a/tools/src/test/resources/prompt_injection.celpolicy b/tools/src/test/resources/prompt_injection.celpolicy new file mode 100644 index 000000000..f61bea38d --- /dev/null +++ b/tools/src/test/resources/prompt_injection.celpolicy @@ -0,0 +1,15 @@ +name: "policy.safety.prompt.injection" +default: allow + +variables: + - injection_score: > + security.classifyInjection(agent.context.prompt) + +rules: + - condition: variables.injection_score > 0.9 + effect: deny + message: "Prompt injection detected with high confidence." + + - condition: variables.injection_score > 0.5 + effect: confirm + message: "Potential prompt injection detected. User confirmation required." \ No newline at end of file diff --git a/tools/src/test/resources/prompt_injection_tests.yaml b/tools/src/test/resources/prompt_injection_tests.yaml new file mode 100644 index 000000000..2a7bfecb2 --- /dev/null +++ b/tools/src/test/resources/prompt_injection_tests.yaml @@ -0,0 +1,49 @@ +description: "Prompt Injection Policy Tests" + +section: +- name: "Injection Classification Scenarios" + tests: + - name: "High Confidence Injection (Deny)" + input: + agent: + expr: > + Agent{ + context: AgentContext{ + prompt: "I'm attempting an INJECTION_ATTACK!" + } + } + output: > + { + "effect": "deny", + "message": "Prompt injection detected with high confidence." + } + + - name: "Medium Confidence Injection (Confirm)" + input: + agent: + expr: > + Agent{ + context: AgentContext{ + prompt: "This might be a SUSPICIOUS message, maybe safe" + } + } + output: > + { + "effect": "confirm", + "message": "Potential prompt injection detected. User confirmation required." + } + + - name: "Safe Input (Allow)" + input: + agent: + expr: > + Agent{ + context: AgentContext{ + prompt: "Just a normal user query" + } + } + output: > + { + "effect": "allow", + "message": "" + } \ No newline at end of file diff --git a/tools/src/test/resources/require_user_confirmation_for_tool.celpolicy b/tools/src/test/resources/require_user_confirmation_for_tool.celpolicy new file mode 100644 index 000000000..4c08538aa --- /dev/null +++ b/tools/src/test/resources/require_user_confirmation_for_tool.celpolicy @@ -0,0 +1,29 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: "require_user_confirmation_for_mcp_tool" + +default: deny + +rules: + - description: "Confirm tool calls with PII" + condition: > + tool.isSensitive() && !tool.user_confirmed + effect: confirm + message: "This tool call is sensitive and requires confirmation before the agent can execute. Ask for confirmation from the user" + + - description: "Allow insensitive tools or when user confirmed the tool invocation" + condition: > + !tool.isSensitive() || tool.user_confirmed + effect: allow \ No newline at end of file diff --git a/tools/src/test/resources/require_user_confirmation_for_tool_tests.yaml b/tools/src/test/resources/require_user_confirmation_for_tool_tests.yaml new file mode 100644 index 000000000..74e21f204 --- /dev/null +++ b/tools/src/test/resources/require_user_confirmation_for_tool_tests.yaml @@ -0,0 +1,31 @@ +description: "Require tool confirmation tests" + +section: +- name: "tool call test section" + tests: + - name: "reject_sensitive_tool_call" + input: + tool: + expr: > + ToolCall{ + name: "tool_with_PII", + user_confirmed: false + } + output: > + { + "effect": "confirm", + "message": "This tool call is sensitive and requires confirmation before the agent can execute. Ask for confirmation from the user", + } + - name: "allow_confirmed_tool" + input: + tool: + expr: > + ToolCall{ + name: "tool_with_PII", + user_confirmed: true + } + output: > + { + "effect": "allow", + "message": "", + } \ No newline at end of file diff --git a/tools/src/test/resources/risky_agent_replay.celpolicy b/tools/src/test/resources/risky_agent_replay.celpolicy new file mode 100644 index 000000000..86557a4e3 --- /dev/null +++ b/tools/src/test/resources/risky_agent_replay.celpolicy @@ -0,0 +1,13 @@ +name: "policy.risky.agent.replay" +default: allow + +rules: + - description: "Limit turn window for risky agents" + condition: | + tool.name in ["my_risky_agent1", "my_risky_agent2"] + effect: replay + output_expr: | + { + 'type': 'USER', + 'turn_window': 1 + } diff --git a/tools/src/test/resources/risky_agent_replay_tests.yaml b/tools/src/test/resources/risky_agent_replay_tests.yaml new file mode 100644 index 000000000..12ffa0e47 --- /dev/null +++ b/tools/src/test/resources/risky_agent_replay_tests.yaml @@ -0,0 +1,29 @@ +description: "Risky Agent Replay Policy Tests" + +section: +- name: "Risky Agent Checks" + tests: + - name: "Risky Agent 1 (Replay)" + input: + tool: + expr: > + ToolCall{ name: "my_risky_agent1" } + output: > + { + "effect": "replay", + "details": { + "type": "USER", + "turn_window": 1 + } + } + + - name: "Safe Agent (Allow)" + input: + tool: + expr: > + ToolCall{ name: "safe_agent" } + output: > + { + "effect": "allow", + "message": "" + } \ No newline at end of file diff --git a/tools/src/test/resources/tool_walled_garden.celpolicy b/tools/src/test/resources/tool_walled_garden.celpolicy new file mode 100644 index 000000000..cc4c5c19d --- /dev/null +++ b/tools/src/test/resources/tool_walled_garden.celpolicy @@ -0,0 +1,13 @@ +name: "tool.restrictions" +default: allow + +variables: + - allowed_tools: > + ['core_capabilities', 'google_search', 'image_generation', 'data_analysis', 'content_fetcher'] + +rules: + - description: "Limit tool access for restricted environment. Only specific tools are allowed." + condition: | + !(tool.name in variables.allowed_tools) + effect: deny + message: "Tool access restricted. This tool is not in the allowlist." diff --git a/tools/src/test/resources/tool_walled_garden_tests.yaml b/tools/src/test/resources/tool_walled_garden_tests.yaml new file mode 100644 index 000000000..23e75b89d --- /dev/null +++ b/tools/src/test/resources/tool_walled_garden_tests.yaml @@ -0,0 +1,26 @@ +description: "Tool Restriction Tests" + +section: +- name: "Allowlist Enforcement" + tests: + - name: "Allowed Tool (Google Search)" + input: + tool: + expr: > + ToolCall{ name: "google_search" } + output: > + { + "effect": "allow", + "message": "" + } + + - name: "Disallowed Tool (Random Tool)" + input: + tool: + expr: > + ToolCall{ name: "random_3p_tool" } + output: > + { + "effect": "deny", + "message": "Tool access restricted. This tool is not in the allowlist." + } \ No newline at end of file diff --git a/tools/src/test/resources/trust_cascading.celpolicy b/tools/src/test/resources/trust_cascading.celpolicy new file mode 100644 index 000000000..0563db5f6 --- /dev/null +++ b/tools/src/test/resources/trust_cascading.celpolicy @@ -0,0 +1,21 @@ +name: "policy.trust.cascading" +default: allow + +variables: + - trust_decision: > + security.cascade_trust(agent.history()) + +rules: + - description: "Elevate trust and replay model call if required" + condition: variables.trust_decision.action == 'REPLAY' + effect: replay + output_expr: | + { + 'append_attributes': variables.trust_decision.new_attributes, + 'reason': 'Trust elevation required for proper answer.' + } + + - description: "Trust sufficient, allow execution" + condition: variables.trust_decision.action == 'ALLOW' + effect: allow + message: "Trust level sufficient." \ No newline at end of file diff --git a/tools/src/test/resources/trust_cascading_tests.yaml b/tools/src/test/resources/trust_cascading_tests.yaml new file mode 100644 index 000000000..ccb13f17c --- /dev/null +++ b/tools/src/test/resources/trust_cascading_tests.yaml @@ -0,0 +1,34 @@ +description: "Trust Cascading Policy Tests" + +section: +- name: "Cascading Logic" + tests: + - name: "Elevation Required (Replay)" + input: + agent: + # Note: description is important below. It's used to fetch mocked history content. + expr: > + Agent{ + description: "trust_cascading_low" + } + output: > + { + "effect": "replay", + "details": { + "append_attributes": { "trust_score": "MEDIUM" }, + "reason": "Trust elevation required for proper answer." + } + } + + - name: "Trust Sufficient (Allow)" + input: + agent: + expr: > + Agent{ + description: "trust_cascading_medium" + } + output: > + { + "effect": "allow", + "message": "Trust level sufficient." + } \ No newline at end of file diff --git a/tools/src/test/resources/two_models_contextual.celpolicy b/tools/src/test/resources/two_models_contextual.celpolicy new file mode 100644 index 000000000..887df5c03 --- /dev/null +++ b/tools/src/test/resources/two_models_contextual.celpolicy @@ -0,0 +1,25 @@ +name: "policy.two.models.contextual" +default: allow + +variables: + - trusted_plan: > + security.computePrivilegedPlan( + agent.history().filter(msg, msg.metadata.trust_level == 'TRUSTED') + ) + +rules: + - description: "Enforce the privileged plan: Deny unauthorized tools" + condition: | + tool.name != "" && + variables.trusted_plan.size() > 0 && + !(tool.name in variables.trusted_plan) + effect: deny + message: "Tool call violated the privileged execution plan. This tool is not authorized for this context." + + - description: "Enforce the privileged plan: Allow authorized tools" + condition: | + tool.name != "" && + variables.trusted_plan.size() > 0 && + (tool.name in variables.trusted_plan) + effect: allow + message: "" \ No newline at end of file diff --git a/tools/src/test/resources/two_models_contextual_tests.yaml b/tools/src/test/resources/two_models_contextual_tests.yaml new file mode 100644 index 000000000..9193dc866 --- /dev/null +++ b/tools/src/test/resources/two_models_contextual_tests.yaml @@ -0,0 +1,37 @@ +description: "Contextual Security Tests" + +section: +- name: "Privileged Plan Enforcement" + tests: + - name: "Compliant Tool Call (Allow)" + input: + agent: + # Note: description is important below. It's used to fetch mocked history content. + expr: > + Agent{ + description: "contextual_security_mixed" + } + tool: + expr: > + ToolCall{ name: "calculator" } + output: > + { + "effect": "allow", + "message": "" + } + + - name: "Non-Compliant Tool Call (Deny)" + input: + agent: + expr: > + Agent{ + description: "contextual_security_mixed" + } + tool: + expr: > + ToolCall{ name: "file_deleter" } + output: > + { + "effect": "deny", + "message": "Tool call violated the privileged execution plan. This tool is not authorized for this context." + } \ No newline at end of file