Add `NVTE_BACKWARD_OVERRIDE=high_precision|dequantized` #2644

zianglih wants to merge 62 commits into NVIDIA:main from zianglih:keep-bwd

Conversation
Greptile Summary

This PR adds an `NVTE_BACKWARD_OVERRIDE` environment variable that keeps the forward pass quantized while computing the backward pass (dgrad and wgrad) in high precision, with two modes: `high_precision` (backward uses the saved original fp16/bf16 input and weight) and `dequantized` (backward dequantizes the tensors that were quantized during fprop).

Confidence Score: 4/5

Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Forward Pass - FP8 enabled] --> B{backward_override?}
    B -- None --> C[Default: quantize input + weight\ncolumnwise for backward GEMM]
    B -- high_precision --> D[save_original_input=True\nSave original fp16/bf16 input + weight\ncolumnwise=False on quantizer]
    B -- dequantized --> E[Save quantized input rowwise-only\ncolumnwise=False on quantizer\ndisable optimize_for_gemm for MXFP8/NVFP4]
    C --> F[Backward: FP8 GEMMs\ndgrad = quantized_dy × quantized_w\nwgrad = quantized_dy × quantized_x]
    D --> G[Backward: ctx.fp8=False\ndgrad = fp16_dy × original_w\nwgrad = fp16_dy × original_x]
    E --> H[Backward: ctx.fp8=False\nDequantize saved tensors\ndgrad = fp16_dy × dequant_w\nwgrad = fp16_dy × dequant_x]
    F --> I[FP8 weight/input grad]
    G --> J[High-precision grad\nChain rule preserved]
    H --> K[High-precision grad\nDerived from fprop quantized values]
    style C fill:#f9f,stroke:#333
    style D fill:#9f9,stroke:#333
    style E fill:#9ff,stroke:#333
    style F fill:#f9f,stroke:#333
    style G fill:#9f9,stroke:#333
    style H fill:#9ff,stroke:#333
```
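The three backward paths in the flowchart can be sketched numerically with a toy numpy example. This is not TE code; `fake_quant` is a coarse-grid stand-in for an FP8 quantize-dequantize round trip, and all variable names are made up for illustration:

```python
import numpy as np

def fake_quant(x, step=0.25):
    """Toy stand-in for FP8 quantize->dequantize: snap to a coarse grid."""
    return np.round(x / step) * step

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8)).astype(np.float32)    # saved input
w = rng.normal(size=(8, 3)).astype(np.float32)    # weight
dy = rng.normal(size=(4, 3)).astype(np.float32)   # incoming gradient

xq, wq = fake_quant(x), fake_quant(w)             # values fprop computed with

# Default path: dy is quantized too, and both GEMMs run on quantized operands.
dyq = fake_quant(dy)
dgrad_default = dyq @ wq.T
wgrad_default = xq.T @ dyq

# high_precision: backward uses the saved original tensors; dy stays exact.
dgrad_hp = dy @ w.T
wgrad_hp = x.T @ dy

# dequantized: backward uses the exact values fprop saw, but the GEMMs
# themselves run in high precision and dy is never quantized.
dgrad_dq = dy @ wq.T
wgrad_dq = xq.T @ dy
```

The `dequantized` gradients are consistent with the quantized forward computation, while `high_precision` matches the unquantized chain rule exactly.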
Reviews (42). Last reviewed commit: "Merge branch 'main' into keep-bwd"

I'll work on potential unit test breakage.
```diff
 # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
 # requires column-wise usage
-if ctx.grad_output_quantizer is not None:
+if ctx.grad_output_quantizer is not None and use_fp8_bwd:
```
This line seems redundant, since you already skip the quantization step in base.py `grad_output_preprocess`?
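For context, the skip the comment refers to might look roughly like this. This is a hypothetical sketch; the function and argument names are assumed for illustration and are not the actual base.py code:

```python
def grad_output_preprocess(grad_output, grad_output_quantizer, use_fp8_bwd):
    """Hypothetical sketch of the early return: when the backward override
    disables FP8, the incoming gradient dy is left untouched."""
    if not use_fp8_bwd or grad_output_quantizer is None:
        return grad_output                          # keep dy in high precision
    return grad_output_quantizer(grad_output)       # quantize dy for FP8 GEMMs
```

If the quantization is already skipped here, a later `and use_fp8_bwd` guard on the same quantizer would indeed be redundant.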
```diff
 not ctx.use_bias
 and not ctx.requires_wgrad
 and ctx.grad_output_quantizer is not None
+and use_fp8_bwd
```
```python
recipe = cls.get_fp8_recipe()
if recipe is not None and recipe.delayed():
    # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used
    return False
```
Maybe it's better to assert an error for delayed scaling? Okay with both.

I agree. If the user specifies an unsupported combination, I think it's better to fail loudly than to silently disobey their instructions.
```diff
 # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
 # requires column-wise usage
-if ctx.grad_output_quantizer is not None:
+if ctx.grad_output_quantizer is not None and use_fp8_bwd:
```
This seems redundant too, if we skip quantization in `grad_output_preprocess`.
Signed-off-by: Ziang Li <ziangli@umich.edu>
/te-ci pytorch L1
ksivaman left a comment:
- Not a fan of `NVTE_BACKWARD_MODE`, it's too generic. I am still not sure if this feature should be allowed via environment toggle. It's easy for the users, but we should make it explicitly configurable via the recipe API and not an env var.
- Is there a reason to have the `dequant` mode? Is it just for memory saving? I can't imagine it being numerically better than `unquant`. Either way, `dequantized` and `high_precision` might be better names for these features.
On the naming I agree, but I have no strong opinion.
Hi @ksivaman, thanks for reviewing!

Currently the

Yes, we have very good reasons in RL use cases, since it best preserves the chain rule and serves as an STE. Our experiments showed clearly more stable gradient curves compared with

Yes, I can change the naming to
Can you clarify the dequant method here? For fprop, we quantize and get input_fp8 and weight_fp8, and then for the backward you dequantize both, is that right?
Hi @zhongbozhu,

This is exactly right. The fprop uses quantized compute as specified by the quantization recipe, with no behavioral changes. In bwd, input_fp8 is dequantized for a high-precision wgrad, weight_fp8 is dequantized for a high-precision dgrad, the gradient is always kept in high precision, and gradient quantization never happens.
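As a toy illustration of "dequantized directly from fprop quantized values": integer codes plus a per-tensor scale stand in for a real FP8 format here, and none of this is TE's actual implementation:

```python
import numpy as np

def quantize(x, scale):
    """Toy quantizer: int8 codes plus a per-tensor scale."""
    return np.clip(np.round(x / scale), -127, 127).astype(np.int8)

def dequantize(codes, scale):
    """Recover the exact values the quantized fprop computed with."""
    return codes.astype(np.float32) * scale

x = np.linspace(-1.0, 1.0, 8, dtype=np.float32).reshape(2, 4)
scale = 0.01
codes = quantize(x, scale)
x_dq = dequantize(codes, scale)
# In dequantized mode, the backward GEMMs then run in high precision on
# x_dq / w_dq, and dy itself is never quantized.
```

The key point is that `x_dq` matches what the forward GEMM actually consumed, so the backward is consistent with the quantized forward rather than with the original unquantized tensors.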
Title changed: `NVTE_BACKWARD_MODE=default|unquant|dequant` → `NVTE_BACKWARD_OVERRIDE=high_precision|dequantized`
@victordion I think you are describing the default TE 1d recipe, or requantized behavior.

Right, my mistake. My mental model assumed there was requantization happening. Thanks for responding!
Regarding the env var design: since this feature is mainly used by RL, there has to be a way for the user to directly override the bwd behavior from the RL framework, instead of plumbing it all the way through Megatron.
/te-ci L0 L1
All PyTorch CI passed. Some failed JAX tests are due to
Description

~~Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high precision wgrad & dgrad.~~
~~Add NVTE_BACKWARD_MODE=default|unquant|dequant env var.~~

Add an `NVTE_BACKWARD_OVERRIDE=high_precision|dequantized` env var:

- `high_precision`: quantized fprop + high-precision wgrad & dgrad using the unquantized activation and weight
- `dequantized`: quantized fprop + high-precision wgrad & dgrad using the activation and weight dequantized directly from the fprop quantized values

The motivation for this dequantized design is RL. Unlike pre-training, which only needs to preserve the coarse optimization direction and convergence, RL gradients are noisy, and the useful updates are small and delicate. If gradient quantization and chain-rule violations are present, noise dominates the true and fragile update signal and the model will collapse. This dequantized design avoids gradient quantization and effectively preserves the chain rule.
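A minimal sketch of how a consumer might read the variable. The variable name and its two values come from this PR; the parsing logic and function name are assumed for illustration:

```python
import os

_VALID_OVERRIDES = {"high_precision", "dequantized"}

def get_backward_override():
    """Read NVTE_BACKWARD_OVERRIDE; None means the default quantized backward."""
    val = os.environ.get("NVTE_BACKWARD_OVERRIDE")
    if not val:
        return None
    if val not in _VALID_OVERRIDES:
        raise ValueError(
            f"NVTE_BACKWARD_OVERRIDE must be one of {sorted(_VALID_OVERRIDES)}, "
            f"got {val!r}"
        )
    return val
```

Rejecting unknown values up front matches the "fail loudly" preference voiced in the review discussion.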
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: