[PyTorch] [torch.compile] Remove module reference from autograd function args #2791
pggPL wants to merge 9 commits into NVIDIA:main
Conversation
Extract weight quantization into a standalone `quantize_weight()` function in base.py, eliminating the need to pass `self` (an nn.Module) into autograd functions. Each op's autograd function now receives/returns an Optional[Tensor] weight workspace instead, with cache management handled by the nn.Module before/after the autograd call. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
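The commit message describes the shape of the extracted function: take a tensor and an optional cached workspace, return the quantized tensor plus any newly built workspace for the caller to store. A minimal sketch of that contract, assuming a quantizer with `quantize()`/`quantize_()` methods as named in the review's sequence diagram (the `ToyQuantizer` here is purely illustrative, not TE's actual quantizer API):

```python
from typing import Optional, Tuple

import torch


class ToyQuantizer:
    """Illustrative stand-in for a TE quantizer (not the real API)."""

    def quantize(self, tensor: torch.Tensor) -> torch.Tensor:
        # Real code would produce an FP8 tensor; a clone suffices here.
        return tensor.detach().clone()

    def quantize_(self, workspace: torch.Tensor, tensor: torch.Tensor) -> None:
        # In-place update of an existing quantized workspace.
        workspace.copy_(tensor)


def quantize_weight(
    tensor: torch.Tensor,
    quantizer: ToyQuantizer,
    workspace: Optional[torch.Tensor] = None,
    cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Quantize `tensor`, reusing `workspace` when possible.

    Returns (quantized_tensor, new_workspace). new_workspace is None when
    the cached workspace was reused, so the caller has nothing to store.
    """
    if cache and workspace is not None:
        # Cache hit: refresh the cached quantized tensor in place.
        quantizer.quantize_(workspace, tensor)
        return workspace, None
    # Cache miss (or caching disabled): build a fresh quantized tensor.
    quantized = quantizer.quantize(tensor)
    return quantized, (quantized if cache else None)
```

Because only tensors (or None) cross the function boundary, there is no nn.Module reference for torch.compile to trace through.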
…autograd_args Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor # Conflicts: # transformer_engine/pytorch/module/base.py
for more information, see https://pre-commit.ci
No callers remain after the quantize_weight refactor. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
Greptile Summary: This PR removes the module reference from autograd function arguments. Confidence Score: 5/5 — safe to merge; the refactoring is logically correct, with backward signatures, gradient counts, and workspace semantics all preserved; remaining findings are style-level. All five changed files were verified: backward return orders match the new forward parameter positions.
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant M as Module.forward()
    participant FP8C as _fp8_workspaces cache
    participant AF as _Linear.forward()
    participant QW as quantize_weight()
    M->>FP8C: lookup weight_workspace (if cache_weight=True)
    FP8C-->>M: cached workspace or None
    M->>AF: weight, weight_workspace, inp, bias, non_tensor_args
    AF->>QW: tensor=weight, workspace=cached_ws, cache=cache_weight
    alt cache hit (workspace valid)
        QW->>QW: update in-place via quantize_()
        QW-->>AF: (weightmat, None)
    else cache miss / invalid
        QW->>QW: quantizer.quantize(tensor)
        QW-->>AF: (weightmat, new_workspace)
    end
    AF->>AF: GEMM computation
    AF-->>M: (out, new_weight_workspace)
    alt new_weight_workspace is not None
        M->>FP8C: store new_weight_workspace under cache_name
    end
    M-->>M: return out
```
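The module-side half of the diagram — cache lookup before the autograd call, cache store after it — can be sketched as follows. This is a simplified illustration under assumed names (`_linear_forward` stands in for the autograd function, a plain dict for the `_fp8_workspaces` cache), not TE's actual implementation:

```python
import torch


def _linear_forward(weight, workspace, inp):
    """Toy stand-in for _Linear.forward(): it never sees the module, only an
    optional cached workspace, and returns any newly built workspace."""
    if workspace is None:
        # Cache miss: build a fresh "quantized" weight (a clone, for the toy).
        workspace = weight.detach().clone()
        new_workspace = workspace
    else:
        # Cache hit: nothing new for the caller to store.
        new_workspace = None
    return inp @ workspace.t(), new_workspace


def module_forward(module_state, weight, inp, cache_weight=True):
    """Cache management happens here, outside the autograd function."""
    cached = module_state.get("weight_workspace") if cache_weight else None
    out, new_ws = _linear_forward(weight, cached, inp)
    if new_ws is not None:
        module_state["weight_workspace"] = new_ws
    return out
```

Since the autograd function's inputs and outputs are now plain tensors (or None), torch.compile can trace it without graph breaks from the module reference.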
Reviews (4): Last reviewed commit: "Merge branch 'main' into remove_module_f..."
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
for more information, see https://pre-commit.ci
/te-ci pytorch
The `torch.autograd.Function`s in TE modules have a `module` argument, which is used for the weight cache. This will not work with `torch.compile`. This PR changes that: tensors are passed to and returned from the operator directly, and the cache update happens outside the `torch.autograd.Function`.

Type of change
Checklist: