
Add FoldQATConvBNPass to fold QAT Conv-BN simulation chains into conv bias #19315

Open

rezaasjd wants to merge 1 commit into pytorch:main from rezaasjd:export-D103949573

Conversation


rezaasjd (Contributor) commented May 5, 2026

Summary:

Add `FoldQATConvBNPass` to the Cadence AOT compiler pipeline to handle QAT Conv-BN simulated-fusion patterns that survive into the exported graph.

When a model is exported after QAT, the Conv-BN simulation chain (`add(var + eps) -> sqrt -> div(bn_weight) -> div(conv_out / scale) -> add(orig_bias) -> batch_norm`) may not be folded by TorchAO's `_fold_conv_bn_qat` due to a pattern mismatch. This leaves non-quantized add/div/sqrt nodes in the graph, which cause QuantFusion to crash when it tries to fuse them as quantized add ops.
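For reference, here is a minimal sketch of what the surviving chain computes, assuming the conv2d case; tensor names are illustrative, and in the exported graph each step appears as a plain (non-quantized) aten node:

```python
import torch
import torch.nn.functional as F

def qat_conv_bn_chain(x, scaled_weight, orig_bias, bn_weight, bn_bias,
                      running_mean, running_var, eps=1e-5):
    # add(var + eps) -> sqrt -> div(bn_weight): the per-channel scale
    # factor that QAT bakes into the conv weight.
    scale = bn_weight / torch.sqrt(running_var + eps)
    out = F.conv2d(x, scaled_weight)            # conv with BN-scaled weight, no bias
    out = out / scale.reshape(1, -1, 1, 1)      # div(conv_out / scale)
    out = out + orig_bias.reshape(1, -1, 1, 1)  # add(orig_bias)
    return F.batch_norm(out, running_mean, running_var,
                        bn_weight, bn_bias, eps=eps)  # batch_norm
```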

The fix has three parts:

  1. Add `conv1d.default` to the `QuantizeFusedConvBnBiasAtenPass` `conv_targets` so the pass matches conv1d ops and can create zero biases for convs that lack one (mirroring the existing conv2d support).

  2. Add `FoldQATConvBNPass`, which matches the QAT simulation chain, computes the BN correction constant C = (orig_bias - running_mean) * bn_weight / sqrt(running_var + eps) + bn_bias, folds C into the conv's quantized bias tensor, and removes the simulation-chain and batch_norm nodes; no new graph nodes are created (a numerical sketch of this folding follows the list).

  3. Apply these passes in the correct order in both the `get_fake_quant_model` (pre-export, on GraphModule) and `apply_pre_edge_transform_passes` (post-export, on ExportedProgram) pipelines: first `QuantizeFusedConvBnBiasAtenPass` to create zero biases for convs that lack one, then `FoldQATConvBNPass` to fold the simulation chain into those biases.
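For intuition on step 2, here is a minimal numerical sketch of the folding identity, assuming the conv2d case and a conv that already carries the BN-scaled weight. All tensor names are illustrative; the real pass rewrites aten graph nodes and a quantized bias tensor rather than eager tensors:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C_out, C_in, K, eps = 4, 3, 3, 1e-5
x = torch.randn(1, C_in, 8, 8)
w = torch.randn(C_out, C_in, K, K)
orig_bias = torch.randn(C_out)
bn_weight = torch.rand(C_out) + 0.5  # positive, for a well-conditioned scale
bn_bias = torch.randn(C_out)
running_mean = torch.randn(C_out)
running_var = torch.rand(C_out) + 0.5

# Per-channel scale that the QAT rewrite bakes into the conv weight.
scale = bn_weight / torch.sqrt(running_var + eps)

# The surviving simulation chain: scaled-weight conv, un-scale, bias, BN.
conv_out = F.conv2d(x, w * scale.reshape(-1, 1, 1, 1))
chain = F.batch_norm(
    conv_out / scale.reshape(1, -1, 1, 1) + orig_bias.reshape(1, -1, 1, 1),
    running_mean, running_var, bn_weight, bn_bias, eps=eps,
)

# Folded form: the same conv plus the correction constant C from step 2.
C = (orig_bias - running_mean) * scale + bn_bias
folded = conv_out + C.reshape(1, -1, 1, 1)

assert torch.allclose(chain, folded, atol=1e-4)
```

The ordering in step 3 matters because C can only be written into a bias tensor that already exists, which is why `QuantizeFusedConvBnBiasAtenPass` runs first to materialize zero biases.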

Differential Revision: D103949573

pytorch-bot Bot commented May 5, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19315

Note: Links to docs will display an error until the docs builds have been completed.

⚠️ 1 Awaiting Approval, 20 Pending

As of commit 3c1ea80 with merge base bf8abb6:

AWAITING APPROVAL - 1 workflow needs approval before CI can run.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla Bot added the CLA Signed label May 5, 2026
meta-codesync Bot commented May 5, 2026

@rezaasjd has exported this pull request. If you are a Meta employee, you can view the originating Diff in D103949573.

github-actions Bot commented May 5, 2026

This PR needs a `release notes:` label

If your change should be included in the release notes (i.e., would users of this library care about this change?), please use a label starting with `release notes:`. This helps us keep track of your work and include it in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

rezaasjd force-pushed the export-D103949573 branch from 0758007 to 1c13d6d May 7, 2026 18:22
rezaasjd requested a review from kimishpatel as a code owner May 7, 2026 18:22
meta-codesync Bot changed the title from "bypass add and addRelu that have inputs from non-dq nodes" to "bypass add and addRelu that have inputs from non-dq nodes (#19315)" May 7, 2026
rezaasjd pushed a commit to rezaasjd/executorch that referenced this pull request May 7, 2026
rezaasjd force-pushed the export-D103949573 branch from 1c13d6d to 9e5e0e7 May 7, 2026 18:23
meta-codesync Bot changed the title from "bypass add and addRelu that have inputs from non-dq nodes (#19315)" to "Add FoldQATConvBNPass to fold QAT Conv-BN simulation chains into conv bias (#19315)" May 7, 2026
rezaasjd force-pushed the export-D103949573 branch from 9e5e0e7 to 3c1ea80 May 7, 2026 19:53

Labels

CLA Signed · fb-exported · meta-exported
