Skip to content

Conversation

@timmoon10
Copy link
Collaborator

Description

This PR adds the register_forward_fusion and register_backward_fusion functions to the op fuser API, allowing users to register custom fusions.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add function to register custom op fusions
  • Refactor op fuser to have consistent op order in forward and backward pass
  • Refactor op fusion functions to avoid index bookkeeping
  • Add tests for user-defined ops

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.

Signed-off-by: Tim Moon <[email protected]>
@timmoon10 timmoon10 requested review from ksivaman and pggPL January 14, 2026 08:28
@timmoon10 timmoon10 added the enhancement New feature or request label Jan 14, 2026
@timmoon10

This comment was marked as outdated.

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 14, 2026

Greptile Summary

This PR introduces a public API for registering custom op fusions through register_forward_fusion() and register_backward_fusion() functions, enabling users to define domain-specific fusion patterns. The implementation refactors the existing fusion infrastructure to use a registration-based pattern instead of hardcoded fusion calls.

Key Changes:

  • Added register_forward_fusion() and register_backward_fusion() to allow users to register custom fusion functions
  • Refactored OperationFuser to maintain lists of fusion functions and apply them dynamically via _fuse_ops()
  • Converted all built-in fusions to use the new registration API in fused/__init__.py
  • Refactored all fusion functions to use a consistent sliding window pattern that eliminates manual index bookkeeping
  • Fixed backward op ordering by adding reversed() to iterate backward ops in correct order (line 237 in fuser.py)
  • Added comprehensive tests demonstrating custom forward and backward fusions

Architecture:
The new design follows a plugin architecture where fusion functions are registered globally on the OperationFuser class. During maybe_fuse_ops(), the fuser iterates through registered functions, allowing each to pattern-match and fuse operations. This enables extensibility without modifying core fuser logic.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The refactoring maintains backward compatibility by preserving all existing fusion behavior while adding extensibility. The sliding window pattern is consistent across all fusion functions, making the code more maintainable. Comprehensive tests validate both built-in and custom fusions. The critical bug fix (adding reversed() for backward ops) ensures correct gradient computation order.
  • No files require special attention

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/fuser.py Refactored to support user-defined op fusions via registration API, simplified fusion logic with consistent op ordering
transformer_engine/pytorch/ops/init.py Exported new registration functions register_forward_fusion and register_backward_fusion in public API
transformer_engine/pytorch/ops/fused/init.py Switched from exporting standalone fusion functions to registering them via new registration API
tests/pytorch/test_fusible_ops.py Added comprehensive tests for custom forward and backward fusion registration

Sequence Diagram

sequenceDiagram
    participant User
    participant Sequential
    participant OperationFuser
    participant ForwardFusionFunc
    participant BackwardFusionFunc
    participant FusedOp

    Note over User,FusedOp: Registration Phase (Module Import)
    User->>OperationFuser: register_forward_fusion(CustomForwardFusion.fuse_ops)
    User->>OperationFuser: register_backward_fusion(CustomBackwardFusion.fuse_ops)
    
    Note over User,FusedOp: Forward Pass
    User->>Sequential: forward(input)
    Sequential->>OperationFuser: __call__(input)
    OperationFuser->>OperationFuser: maybe_fuse_ops()
    
    loop For each forward fusion function
        OperationFuser->>ForwardFusionFunc: func(ops, recipe=recipe)
        ForwardFusionFunc->>ForwardFusionFunc: Sliding window pattern matching
        ForwardFusionFunc-->>OperationFuser: Updated ops list with fusions
    end
    
    OperationFuser->>OperationFuser: _fuse_ops() - determine basic op indices
    
    loop For each fused forward op
        OperationFuser->>FusedOp: fuser_forward(ctxs, input, ...)
        FusedOp-->>OperationFuser: output
    end
    
    OperationFuser-->>Sequential: output
    Sequential-->>User: output
    
    Note over User,FusedOp: Backward Pass
    User->>Sequential: backward(grad_output)
    Sequential->>OperationFuser: backward()
    
    loop For each backward fusion function
        OperationFuser->>BackwardFusionFunc: func(ops, recipe=recipe)
        BackwardFusionFunc->>BackwardFusionFunc: Sliding window pattern matching
        BackwardFusionFunc-->>OperationFuser: Updated ops list with fusions
    end
    
    OperationFuser->>OperationFuser: _fuse_ops() - determine basic op indices
    
    loop For each fused backward op (reversed)
        OperationFuser->>FusedOp: fuser_backward(ctxs, grad_output, ...)
        FusedOp-->>OperationFuser: grad_input
    end
    
    OperationFuser-->>Sequential: grad_input
    Sequential-->>User: grad_input
Loading

greptile-apps[bot]

This comment was marked as outdated.

@timmoon10

This comment was marked as outdated.

greptile-apps[bot]

This comment was marked as outdated.

@timmoon10

This comment was marked as outdated.

@timmoon10 timmoon10 closed this Jan 15, 2026
@timmoon10 timmoon10 reopened this Jan 15, 2026
@timmoon10
Copy link
Collaborator Author

/te-ci pytorch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant