
Add NVTE_BACKWARD_OVERRIDE=high_precision|dequantized #2644

Open

zianglih wants to merge 62 commits into NVIDIA:main from zianglih:keep-bwd

Conversation

@zianglih

@zianglih zianglih commented Feb 3, 2026

Description

Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high-precision wgrad & dgrad. (superseded by the naming below)

Add NVTE_BACKWARD_MODE=default|unquant|dequant env var (superseded by the naming below)

Add NVTE_BACKWARD_OVERRIDE=high_precision|dequantized env var:

  • Not set: existing default quantization behavior
  • high_precision: quantized fprop + high-precision wgrad & dgrad using the unquantized activation and weight
    • (image)
  • dequantized: quantized fprop + high-precision wgrad & dgrad using the activation and weight dequantized directly from the fprop quantized values
    • (image)
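Because the override is an environment variable, it can be selected with no code changes. A hypothetical usage sketch (`train.py` is a placeholder for your own entry point, and this assumes the variable is read when Transformer Engine initializes):

```shell
# Hypothetical usage sketch: pick the backward override via the environment.
# "NVTE_BACKWARD_OVERRIDE=dequantized python train.py" would be the one-liner form.
export NVTE_BACKWARD_OVERRIDE=dequantized
python -c 'import os; print(os.environ["NVTE_BACKWARD_OVERRIDE"])'  # sanity check
```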

The motivation for the dequantized design is RL. Unlike pre-training, which only needs to preserve the coarse optimization direction and convergence, RL gradients are noisy and the useful updates are small and delicate. If gradient quantization and chain-rule violation are present, noise dominates the true, fragile update signal and the model will collapse. The dequantized design avoids gradient quantization and effectively preserves the chain rule.
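The effect of each mode can be sketched with a scalar toy example (pure Python, not TE code; the rounding quantizer and its scale are illustrative assumptions) for a layer y = w * x:

```python
# Toy sketch of the three backward behaviors for y = w * x.
# The quantizer below is a stand-in for FP8 quantization, not TE code.

def quantize(v, scale=0.25):
    """Round to the nearest multiple of `scale` (toy quantizer)."""
    return round(v / scale) * scale

x, w, dy = 1.1, 0.9, 0.01            # activation, weight, upstream grad
x_q, w_q = quantize(x), quantize(w)  # both round to 1.0 here

y = x_q * w_q  # quantized fprop: identical in all three modes

# Not set (default): backward operands and the gradient are quantized.
# Note the small upstream grad rounds away to 0 -- the "noise dominates
# the fragile update signal" failure mode described above.
dgrad_default, wgrad_default = quantize(dy) * w_q, quantize(dy) * x_q

# high_precision: backward uses the original unquantized operands.
dgrad_hp, wgrad_hp = dy * w, dy * x

# dequantized: backward uses operands dequantized from the fprop
# quantized values, so backward differentiates exactly the function the
# forward computed (chain rule preserved); grads are never quantized.
dgrad_dq, wgrad_dq = dy * w_q, dy * x_q
```

In this toy setting the default path rounds the tiny update to zero, while both overrides keep it; only dequantized keeps it consistent with the quantized forward.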

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Feb 3, 2026

Greptile Summary

This PR adds NVTE_BACKWARD_OVERRIDE=high_precision|dequantized (and the equivalent recipe.backward_override field) to give users control over backward-pass precision when using quantized forward passes. The primary motivation is RL fine-tuning, where gradient noise from FP8 quantization can dominate fragile update signals and cause model collapse.

How the two modes work:

  • high_precision: saves the original bf16/fp16 input and weight tensors before any quantization; the quantized forward GEMM still runs in FP8, but backward GEMMs use the clean unquantized originals. The ctx.fp8=False override routes wgrad preparation to the inputmat.dequantize() branch, and the dgrad weight is taken directly from the parameter.
  • dequantized: saves the FP8 quantized tensors from fprop but with rowwise-only layout (columnwise=False, optimize_for_gemm=False for MXFP8/NVFP4); during backward, ctx.fp8=False causes these to be dequantized via maybe_dequantize / explicit .dequantize() calls before the high-precision GEMMs.
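Under the behavior summarized above, both override modes funnel into the same high-precision backward path; a rough pure-Python sketch (the class and functions here are invented for illustration and only echo the maybe_dequantize / ctx.fp8 names from the summary — they are not TE internals):

```python
# Illustrative simplification, not TE code: both override modes reach the
# non-FP8 backward; they differ only in what was saved during forward.

class ToyFP8Tensor:
    """Stand-in for a quantized tensor saved from fprop (dequantized mode)."""
    def __init__(self, hp_value, scale=0.25):
        self._q = round(hp_value / scale) * scale  # quantize once, at fprop
    def dequantize(self):
        return self._q

def maybe_dequantize(t):
    # High-precision saves (high_precision mode) pass through unchanged;
    # quantized saves (dequantized mode) are expanded to high precision.
    return t.dequantize() if isinstance(t, ToyFP8Tensor) else t

def backward(saved_input, saved_weight, dy):
    # Corresponds to the ctx.fp8=False path: plain high-precision GEMMs.
    x = maybe_dequantize(saved_input)
    w = maybe_dequantize(saved_weight)
    return dy * w, dy * x  # (dgrad, wgrad), scalars standing in for GEMMs
```

Here `backward(ToyFP8Tensor(1.1), ToyFP8Tensor(0.9), 0.01)` models the dequantized mode and `backward(1.1, 0.9, 0.01)` models high_precision.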

Key implementation details:

  • All context FP8/UB/quantizer flags are zeroed when backward_override is not None, cleanly separating the override backward from the standard FP8 backward infrastructure.
  • LayerNormMLP is explicitly unsupported with a clear assertion and actionable error message.
  • DelayedScaling is gated at the recipe level with an assertion in __post_init__.
  • Empty-grouped-split edge cases for MXFP8/NVFP4 are handled by materialising explicit zero-size tensors before dequantization.
  • Userbuffers fusion, BackwardActivationBias fusion, and OperationFuser cache key are all updated to respect the new override.
  • The new test_backward_override.py (1848 lines) exhaustively covers both modes for Linear, LayerNormLinear, GroupedLinear, and all fused op patterns with saved-operand invariant checks.

One minor inconsistency: BasicLinear.reset_recipe_state (for te_ops.Linear) disables optimize_for_gemm for both high_precision and dequantized, while Linear._get_quantizers / LayerNormLinear._get_quantizers correctly only disable it for dequantized. For high_precision, optimize_for_gemm is irrelevant to backward correctness (unquantized tensors are saved), so the te_ops.Linear path is slightly over-conservative but not incorrect.

Confidence Score: 4/5

  • Safe to merge after addressing the minor optimize_for_gemm inconsistency in BasicLinear.reset_recipe_state; no correctness bugs found.
  • Prior review concerns have been substantially resolved: LayerNormMLP has a clear informative assertion, DelayedScaling is gated at recipe level, saved-tensor logic properly nulls out unused tensors, and the two backward modes are correctly wired through both the classic-module (te.Linear) and fusible-ops (te_ops.Linear) code paths. The implementation is correct and well-tested with 1848 lines of new tests. The only remaining item is the minor optimize_for_gemm inconsistency between BasicLinear.reset_recipe_state and the module-level _get_quantizers for high_precision mode — this does not affect correctness but is worth fixing for consistency.
  • transformer_engine/pytorch/ops/basic/basic_linear.py (reset_recipe_state optimize_for_gemm condition for high_precision mode)

Important Files Changed

Filename Overview
transformer_engine/common/recipe/__init__.py Adds backward_override field to all recipe classes (DelayedScaling, Float8CurrentScaling, MXFP8BlockScaling, Float8BlockScaling, NVFP4BlockScaling, CustomRecipe) with _BACKWARD_OVERRIDES tuple validation. DelayedScaling correctly asserts backward_override is None, preventing the incompatible combination at the recipe level before any module code runs.
transformer_engine/pytorch/module/linear.py Core implementation for te.Linear. high_precision mode forces save_original_input=True to retain unquantized input; dequantized mode retains quantized forward tensors with rowwise-only layout. Backward correctly sets ctx.fp8=False (triggering inputmat.dequantize() in wgrad prep) and explicitly dequantizes weight for dgrad. Context FP8/UB flags are zeroed for both override modes.
transformer_engine/pytorch/module/layernorm_linear.py Saves ln_out_hp (high-precision LN output before quantization) for high_precision mode, and the quantized ln_out for dequantized mode. Backward correctly dequantizes ln_out and weight for the dequantized path. optimize_for_gemm is disabled only for dequantized + MXFP8/NVFP4, consistent with linear.py.
transformer_engine/pytorch/module/layernorm_mlp.py Explicitly unsupported: asserts backward_override is None with a clear error message directing users to use LayerNormLinear + Linear instead. The assertion message is informative and actionable. ctx.backward_override is still saved for completeness.
transformer_engine/pytorch/module/grouped_linear.py Comprehensive implementation for grouped GEMM. dequantized wgrad correctly handles empty m-splits (zero-size grouped chunks) by materialising explicit empty tensors instead of calling dequantize() on zero-numel inputs. optimize_for_gemm disabled for dequantized + MXFP8/NVFP4.
transformer_engine/pytorch/ops/basic/basic_linear.py Core te_ops.Linear implementation. high_precision saves original unquantized input/weight; dequantized saves quantized tensors with rowwise-only layout which are then dequantized via maybe_dequantize in _functional_backward. Minor inconsistency: reset_recipe_state disables optimize_for_gemm for BOTH override modes, while Linear._get_quantizers correctly only disables it for dequantized.
tests/pytorch/test_backward_override.py Comprehensive new test file (1848 lines) covering both override modes for Linear, LayerNormLinear, GroupedLinear, and all fused op patterns across all quantized recipes. Uses saved-operand snapshot invariant checks for dequantized mode to verify quantized tensors are not modified by backward. Minor: redundant NVFP4 divisibility skip logic (see inline comment).
transformer_engine/pytorch/module/base.py Adds use_fp8_bwd = ctx.fp8 and ctx.backward_override is None gate. This correctly routes dequantized/high_precision to the non-FP8 backward path (including the inputmat.dequantize() branch in wgrad preparation) while leaving the default quantized backward unchanged.

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Forward Pass - FP8 enabled] --> B{backward_override?}
    B -- None --> C[Default: quantize input + weight\ncolumnwise for backward GEMM]
    B -- high_precision --> D[save_original_input=True\nSave original fp16/bf16 input + weight\ncolumnwise=False on quantizer]
    B -- dequantized --> E[Save quantized input rowwise-only\ncolumnwise=False on quantizer\ndisable optimize_for_gemm for MXFP8/NVFP4]

    C --> F[Backward: FP8 GEMMs\ndgrad = quantized_dy × quantized_w\nwgrad = quantized_dy × quantized_x]
    D --> G[Backward: ctx.fp8=False\ndgrad = fp16_dy × original_w\nwgrad = fp16_dy × original_x]
    E --> H[Backward: ctx.fp8=False\nDequantize saved tensors\ndgrad = fp16_dy × dequant_w\nwgrad = fp16_dy × dequant_x]

    F --> I[FP8 weight/input grad]
    G --> J[High-precision grad\nChain rule preserved]
    H --> K[High-precision grad\nDerived from fprop quantized values]

    style C fill:#f9f,stroke:#333
    style D fill:#9f9,stroke:#333
    style E fill:#9ff,stroke:#333
    style F fill:#f9f,stroke:#333
    style G fill:#9f9,stroke:#333
    style H fill:#9ff,stroke:#333
```

Reviews (42): Last reviewed commit: "Merge branch 'main' into keep-bwd"

@greptile-apps greptile-apps bot left a comment: 17 files reviewed, 3 comments

@greptile-apps greptile-apps bot left a comment: 6 files reviewed, no comments

@zianglih
Author

zianglih commented Feb 3, 2026

I'll work on potential unit test breakage.

@greptile-apps greptile-apps bot left further comments on subsequent commits: 5 files reviewed, no comments; 5 files reviewed, 4 comments; 4 files reviewed, 1 comment; 5 files reviewed, 1 comment; 5 files reviewed, 2 comments; 4 files reviewed, no comments; 5 files reviewed, no comments; 5 files reviewed, no comments; 5 files reviewed, no comments

```diff
 # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
 # requires column-wise usage
-if ctx.grad_output_quantizer is not None:
+if ctx.grad_output_quantizer is not None and use_fp8_bwd:
```

Collaborator:

this line seems redundant since you already skip the quantization step in base.py grad_output_preprocess?

```python
not ctx.use_bias
and not ctx.requires_wgrad
and ctx.grad_output_quantizer is not None
and use_fp8_bwd
```

Collaborator:

same comment as above

```python
recipe = cls.get_fp8_recipe()
if recipe is not None and recipe.delayed():
    # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used
    return False
```

Collaborator:

Maybe it's better to assert an error for delayed scaling? Okay with both.

Collaborator:

I agree. If the user specifies an unsupported combination, I think it's better to fail loudly than to secretly disobey their instructions.

```diff
 # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
 # requires column-wise usage
-if ctx.grad_output_quantizer is not None:
+if ctx.grad_output_quantizer is not None and use_fp8_bwd:
```

Collaborator:

this seems redundant too if we skip quant in grad_output_preprocess

Signed-off-by: Ziang Li <ziangli@umich.edu>
zhongbozhu previously approved these changes Mar 13, 2026
@zhongbozhu (Collaborator) left a comment:

LGTM, pending CI

@zhongbozhu

/te-ci pytorch L1

@ksivaman (Member) left a comment:

  1. Not a fan of NVTE_BACKWARD_MODE, it's too generic. I am still not sure if this feature should be allowed via environment toggle. It's easy for the users but we should make it explicitly configurable via recipe API and not envvar.
  2. Is there a reason to have the dequant mode? Is it just for memory saving? Can't imagine it being numerically better than unquant. Either way, dequantized and high_precision might be better names for these features.

@zhongbozhu

  1. Not a fan of NVTE_BACKWARD_MODE, it's too generic. I am still not sure if this feature should be allowed via environment toggle. It's easy for the users but we should make it explicitly configurable via recipe API and not envvar.
  2. Is there a reason to have the dequant mode? Is it just for memory saving? Can't imagine it being numerically better than unquant. Either way, dequantized and high_precision might be better names for these features.

Naming part I agree but I have no strong opinion.

@zianglih
Author

zianglih commented Mar 13, 2026

Hi @ksivaman , thanks for reviewing!

we should make it explicitly configurable via recipe API and not envvar

Currently the backward_mode is a configurable recipe member, not a global toggle. It is set by the NVTE_BACKWARD_MODE envvar. I can work on a better interface.

Is there a reason to have the dequant mode?

Yes, we have very good reasons in RL use cases, since it best preserves the chain rule and serves as an STE (straight-through estimator). Our experiments showed clearly more stable gradient curves compared with the default and unquant modes. unquant seems to have good numerics but violates the chain rule more, which is acceptable in pre-training but not in RL.

dequantized and high_precision might be better names for these features

Yes I can change naming to default|high_precision|dequantized.

@zhongbozhu
Collaborator

Can you clarify the dequant method here? For fprop, we quantize and get input_fp8, and weight_fp8, and then for dequantize you also dequantize both, is that right?

@zianglih
Author

zianglih commented Mar 13, 2026

Hi @zhongbozhu ,

For fprop, we quantize and get input_fp8, and weight_fp8, and then for dequantize you also dequantize both

This is exactly right. The fprop uses quantized compute as specified by the quantization recipe, with no behavioral changes. In bwd, input_fp8 is dequantized for the high-precision wgrad, weight_fp8 is dequantized for the high-precision dgrad, the gradient is always kept in high precision, and gradient quantization never happens.
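A tiny worked example of that backward (plain Python lists; the 0.5 quantization step and the shapes are illustrative assumptions, not TE behavior):

```python
# Toy dequantized-backward for y = x @ W^T: operands are dequantized from
# their fprop quantized values; the gradient dy is never quantized.

def quantize_dequantize(m, scale=0.5):
    """Round each entry to the nearest multiple of `scale` (toy FP8 stand-in)."""
    return [[round(v / scale) * scale for v in row] for row in m]

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

def transpose(m):
    return [list(r) for r in zip(*m)]

x  = [[1.2, -0.4]]   # activation (1 x 2)
W  = [[0.3, 0.7]]    # weight (1 x 2); fprop computes y = x @ W^T
dy = [[0.01]]        # upstream gradient, kept in high precision

x_dq = quantize_dequantize(x)   # dequantized fprop activation: [[1.0, -0.5]]
W_dq = quantize_dequantize(W)   # dequantized fprop weight:     [[0.5, 0.5]]

dgrad = matmul(dy, W_dq)             # dL/dx = dy @ W_dq
wgrad = matmul(transpose(dy), x_dq)  # dL/dW = dy^T @ x_dq
```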

The motivation for the dequantized design is RL. Unlike pre-training, which only needs to preserve the coarse optimization direction and convergence, RL gradients are noisy and the useful updates are small and delicate. If gradient quantization and chain-rule violation are present, noise dominates the true, fragile update signal and the model will collapse. The dequantized design avoids gradient quantization and effectively preserves the chain rule.

(image)

@zianglih zianglih changed the title from Add NVTE_BACKWARD_MODE=default|unquant|dequant to Add NVTE_BACKWARD_OVERRIDE=high_precision|dequantized Mar 14, 2026
@zianglih
Author

zianglih commented Mar 16, 2026

using "dequantized" in bwd still does not preserve the chain rule 100%, as the quantization in fwd and bwd happens along different dims

@victordion I think you are describing the default TE 1d recipe or requantized behavior.

@victordion

using "dequantized" in bwd still does not preserve the chain rule 100%, as the quantization in fwd and bwd happens along different dims

@victordion I think you are describing the default TE 1d recipe or requantized behavior.

Right. My mistake. My mental model assumed there is requantize happening. Thanks for responding!

@zianglih zianglih requested review from ksivaman and zhongbozhu March 17, 2026 05:42
@zianglih
Author

Regarding the env var design: since this feature is mainly used by RL, there has to be a way for the user to directly override the bwd behavior in the RL framework instead of plumbing it all the way through Megatron.

@ksivaman
Member

/te-ci L0 L1

@zianglih
Author

All PyTorch CI passed.

Some failed JAX tests are due to FileExistsError: [Errno 17] File exists: '/logs'.


Labels

community-contribution: PRs from external contributors outside the core maintainers, representing community-driven work.


8 participants