
[PyTorch] [torch.compile] Remove module reference from autograd function args#2791

Open
pggPL wants to merge 9 commits into NVIDIA:main from pggPL:remove_module_from_autograd_args

Conversation

@pggPL
Collaborator

@pggPL pggPL commented Mar 23, 2026

The torch.autograd.Functions in TE modules take a module argument, which is used for the weight cache.
This does not work with torch.compile. This PR changes that: the workspace tensor is passed to and returned from the operator directly, and the cache update happens outside the torch.autograd.Function.
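In code, the pattern described above might look like the following self-contained sketch. This is plain PyTorch with a toy `copy_`-based stand-in for FP8 quantization; names such as `_Linear` and `_fp8_workspaces` follow the summary of this PR, but the real TransformerEngine signatures and quantization logic differ.

```python
import torch

class _Linear(torch.autograd.Function):
    """Toy autograd function: takes/returns a plain workspace tensor
    instead of holding a reference to the nn.Module."""

    @staticmethod
    def forward(ctx, inp, weight, workspace):
        if workspace is None:
            # Cache miss: build a fresh workspace (stand-in for quantization).
            weightmat = weight.detach().clone()
            new_workspace = weightmat
        else:
            # Cache hit: refresh the existing workspace in place and signal
            # "nothing new to cache" by returning None in its slot.
            workspace.copy_(weight.detach())
            weightmat = workspace
            new_workspace = None
        ctx.save_for_backward(inp, weightmat)
        return inp @ weightmat.t(), new_workspace

    @staticmethod
    def backward(ctx, grad_out, _grad_workspace):
        inp, weightmat = ctx.saved_tensors
        # None gradient for the workspace input, at the matching position.
        return grad_out @ weightmat, grad_out.t() @ inp, None

class Linear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        self._fp8_workspaces = {}

    def forward(self, inp):
        # Cache lookup happens on the module side, outside the autograd function.
        workspace = self._fp8_workspaces.get("weight")
        out, new_workspace = _Linear.apply(inp, self.weight, workspace)
        if new_workspace is not None:
            # Cache update also happens outside the autograd function.
            self._fp8_workspaces["weight"] = new_workspace
        return out
```

Because the autograd function now sees only tensors (and None), there is no module object for torch.compile to trace through, which is the point of the refactor.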

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL and others added 4 commits March 23, 2026 14:55
Extract weight quantization into standalone `quantize_weight()` function
in base.py, eliminating the need to pass `self` (nn.Module) into
autograd functions. Each op's autograd function now receives/returns
Optional[Tensor] weight workspaces instead, with cache management
handled by the nn.Module before/after the autograd call.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
…autograd_args

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor

# Conflicts:
#	transformer_engine/pytorch/module/base.py
No callers remain after the quantize_weight refactor.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
@greptile-apps
Contributor

greptile-apps bot commented Mar 23, 2026

Greptile Summary

This PR removes module references from torch.autograd.Function forward arguments across all five PyTorch modules in TransformerEngine, replacing side-effecting module.get_weight_workspace() calls with explicit tensor passing and returning. The new standalone quantize_weight() helper in base.py faithfully replicates the old method's cache-hit, cache-miss, and workspace-invalidation logic. Backward return orders, gradient counts, and workspace semantics are all preserved correctly throughout.

Confidence Score: 5/5

Safe to merge — the refactoring is logically correct with backward signatures, gradient counts, and workspace semantics all preserved; remaining findings are style-level.

All five changed files verified: backward return orders match new forward parameter positions, quantize_weight faithfully replicates old cache logic, and get_weight_workspace is fully deleted with no remaining callers. The only open items are a dead isinstance guard (harmless) and the list-vs-tuple question in _GroupedLinear (style/possible future compile nuance). Neither blocks correctness today.

grouped_linear.py — Python list return from forward deserves a second look for torch.compile; linear.py, layernorm_linear.py, layernorm_mlp.py share the dead .detach() guard.

Important Files Changed

  • transformer_engine/pytorch/module/base.py: Adds the standalone quantize_weight() helper and _is_weight_workspace_valid(), extracted from the removed TransformerEngineBaseModule.get_weight_workspace() method; the logic is semantically equivalent to the old method, with clear documentation on the (workspace, new_workspace) return contract.
  • transformer_engine/pytorch/module/linear.py: Adds a weight_workspace tensor arg to _Linear.forward, returns an (out, new_weight_workspace) tuple, and moves cache-update logic to Linear.forward; the backward signature is updated with _grad_weight_workspace and returns None at the correct position. A dead isinstance(..., torch.Tensor) guard before .detach() is harmless but misleading.
  • transformer_engine/pytorch/module/layernorm_linear.py: Mirrors the _Linear pattern: _LayerNormLinear.forward gains a weight_workspace arg and always returns a 3-tuple (out, ln_out_for_return, new_weight_workspace); the backward return order is correct; the same dead .detach() guard is present but harmless.
  • transformer_engine/pytorch/module/layernorm_mlp.py: Both _LayerNormMLP._forward and .forward gain fc1_weight_workspace/fc2_weight_workspace args and return a 4-tuple; fp8_meta is passed directly instead of through the module; the recompute path correctly passes None for both workspaces and sets cache_weight=False. The same dead .detach() guard is present.
  • transformer_engine/pytorch/module/grouped_linear.py: Replaces the module arg with a weight_workspaces list plus a cache_weight bool; forward returns (out, new_workspaces_list); the cache update is done in GroupedLinear.forward. Returning a Python list as one forward output diverges from the tensor/None pattern in the other modules and may present torch.compile tracing edge cases.
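Based on the (weightmat, new_workspace) contract described above, a minimal sketch of such a helper could look like this. The shape check and the `copy_` call are stand-ins for TE's real quantizer-state validation and in-place `quantize_()`, which are not modeled here, and the actual helper takes additional quantizer/usage arguments.

```python
from typing import Optional, Tuple
import torch

def _is_weight_workspace_valid(workspace: Optional[torch.Tensor],
                               tensor: torch.Tensor) -> bool:
    # Stand-in validity check; the real helper also inspects quantizer state.
    return workspace is not None and workspace.shape == tensor.shape

def quantize_weight(tensor: torch.Tensor,
                    workspace: Optional[torch.Tensor],
                    cache: bool) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Return (weightmat, new_workspace): new_workspace is None on a cache
    hit (workspace updated in place) and when caching is disabled."""
    if cache and _is_weight_workspace_valid(workspace, tensor):
        # Cache hit: update the existing workspace in place.
        with torch.no_grad():
            workspace.copy_(tensor)  # stand-in for quantize_()
        return workspace, None
    # Cache miss (or caching disabled): build a fresh quantized tensor.
    new_workspace = tensor.detach().clone()  # stand-in for quantizer.quantize()
    return new_workspace, (new_workspace if cache else None)
```

The caller (the nn.Module) stores a non-None new_workspace in its cache; a None return in that slot means the cached tensor was reused, matching the cache-hit branch in the sequence diagram below.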

Sequence Diagram

sequenceDiagram
    participant M as Module.forward()
    participant FP8C as _fp8_workspaces cache
    participant AF as _Linear.forward()
    participant QW as quantize_weight()

    M->>FP8C: lookup weight_workspace (if cache_weight=True)
    FP8C-->>M: cached workspace or None
    M->>AF: weight, weight_workspace, inp, bias, non_tensor_args
    AF->>QW: tensor=weight, workspace=cached_ws, cache=cache_weight
    alt cache hit (workspace valid)
        QW->>QW: update in-place via quantize_()
        QW-->>AF: (weightmat, None)
    else cache miss / invalid
        QW->>QW: quantizer.quantize(tensor)
        QW-->>AF: (weightmat, new_workspace)
    end
    AF->>AF: GEMM computation
    AF-->>M: (out, new_weight_workspace)
    alt new_weight_workspace is not None
        M->>FP8C: store new_weight_workspace under cache_name
    end
    M-->>M: return out

Reviews (4): Last reviewed commit: "Merge branch 'main' into remove_module_f..."

pggPL and others added 5 commits March 23, 2026 15:31
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL
Collaborator Author

pggPL commented Mar 30, 2026

/te-ci pytorch

