Merged
101 changes: 100 additions & 1 deletion backends/cadence/aot/tests/test_quantizer_ops.py
@@ -33,6 +33,7 @@
CadenceWithSoftmaxQuantizer,
qconfig_A16,
qconfig_A8W8,
qconfig_A8W8sym,
)
from executorch.exir.pass_base import NodeMetadata
from parameterized import parameterized
@@ -53,7 +54,6 @@
# Quantizers intentionally excluded from annotation testing.
# These should be explicitly justified when added.
EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
CadenceDefaultQuantizer, # TODO: T247438143 Add test coverage
CadenceFusedConvReluQuantizer, # TODO: T247438151 Add test coverage
CadenceNopQuantizer, # No-op quantizer, doesn't annotate anything
CadenceW8A32MixedQuantizer, # TODO: T247438158 Add test coverage
@@ -137,6 +137,61 @@
# For add: both inputs are activations
[qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
),
# CadenceDefaultQuantizer test cases
(
"default_matmul_A8W8",
lambda self: self._build_matmul_graph(),
CadenceDefaultQuantizer(),
torch.ops.aten.matmul.default,
qconfig_A8W8.output_activation,
# For matmul: both inputs are activations
[qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
),
(
"default_linear_A8W8",
lambda self: self._build_linear_graph(),
CadenceDefaultQuantizer(),
torch.ops.aten.linear.default,
qconfig_A8W8.output_activation,
# For linear: [input_activation, weight]
[qconfig_A8W8.input_activation, qconfig_A8W8.weight],
),
(
"default_conv1d_A8W8sym",
lambda self: self._build_conv1d_graph(),
CadenceDefaultQuantizer(),
torch.ops.aten.conv1d.default,
qconfig_A8W8sym.output_activation,
# For conv1d: [input_activation, weight] with symmetric weights
[qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
),
(
"default_conv2d_A8W8sym",
lambda self: self._build_conv2d_graph(),
CadenceDefaultQuantizer(),
torch.ops.aten.conv2d.default,
qconfig_A8W8sym.output_activation,
# For conv2d: [input_activation, weight] with symmetric weights
[qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
),
(
"default_bmm_A8W8",
lambda self: self._build_bmm_graph(),
CadenceDefaultQuantizer(),
torch.ops.aten.bmm.default,
qconfig_A8W8.output_activation,
# For bmm: both inputs are activations
[qconfig_A8W8.input_activation, qconfig_A8W8.input_activation],
),
(
"default_relu_A8W8",
lambda self: self._build_relu_graph(),
CadenceDefaultQuantizer(),
torch.ops.aten.relu.default,
qconfig_A8W8.output_activation,
# For relu: only input_activation
[qconfig_A8W8.input_activation],
),
]
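# Illustrative sketch, assuming the PT2E convention that annotations land in
# node.meta["quantization_annotation"]: each tuple above is
# (name, build_graph, quantizer, target_op, output_qspec, input_qspecs),
# and a harness can consume a case roughly as
#
#     gm, node = build_graph(self)
#     quantizer.annotate(gm)
#     annotation = node.meta["quantization_annotation"]
#     self.assertEqual(annotation.output_qspec, output_qspec)
#     for arg, expected in zip(node.args, input_qspecs):
#         self.assertEqual(annotation.input_qspec_map[arg], expected)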

# Derive the set of tested quantizer classes from the test cases.
@@ -309,6 +364,50 @@ def _build_add_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
self.assertEqual(len(add_nodes), 1, "Should find exactly one add node")
return gm, add_nodes[0]

def _build_bmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a simple graph with a bmm (batch matrix multiply) operation."""
builder = GraphBuilder()
# BMM requires 3D tensors: (batch, n, m) @ (batch, m, p) -> (batch, n, p)
x = builder.placeholder("x", torch.randn(2, 4, 8))
y = builder.placeholder("y", torch.randn(2, 8, 4))
bmm = builder.call_operator(
op=torch.ops.aten.bmm.default,
args=(x, y),
meta=NodeMetadata(
{"source_fn_stack": [("bmm", torch.ops.aten.bmm.default)]}
),
)
builder.output([bmm])
gm = builder.get_graph_module()

bmm_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.bmm.default,
)
self.assertEqual(len(bmm_nodes), 1, "Should find exactly one bmm node")
return gm, bmm_nodes[0]

def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a simple graph with a relu operation."""
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(1, 10))
relu = builder.call_operator(
op=torch.ops.aten.relu.default,
args=(x,),
meta=NodeMetadata(
{"source_fn_stack": [("relu", torch.ops.aten.relu.default)]}
),
)
builder.output([relu])
gm = builder.get_graph_module()

relu_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.relu.default,
)
self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
return gm, relu_nodes[0]

@parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
def test_quantizer_annotation(
self,
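Side note on the new `qconfig_A8W8sym` import: it differs from `qconfig_A8W8` in that its weight spec uses a symmetric qscheme, which the `default_conv1d_A8W8sym` and `default_conv2d_A8W8sym` cases above rely on. The real definitions live in the Cadence quantizer module; the following is only a hypothetical sketch of what a symmetric 8-bit weight spec generally looks like in PT2E (the observer choice and the quant_min/quant_max bounds are assumptions, not the actual values used):

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec

# Hypothetical stand-in for qconfig_A8W8sym's weight spec: 8-bit symmetric,
# so the zero_point is fixed at 0 and the range is symmetric around it.
weight_qspec_sym = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    is_dynamic=False,
    observer_or_fake_quant_ctr=MinMaxObserver,
)
```

For per-channel symmetric weights one would instead pass `qscheme=torch.per_channel_symmetric` together with a `ch_axis`.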