-
Notifications
You must be signed in to change notification settings - Fork 607
[PyTorch] Support user-defined op fusions #2597
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order. Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
for more information, see https://pre-commit.ci
This comment was marked as outdated.
This comment was marked as outdated.
Greptile SummaryThis PR introduces a public API for registering custom op fusions through Key Changes:
Architecture: Confidence Score: 5/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant User
participant Sequential
participant OperationFuser
participant ForwardFusionFunc
participant BackwardFusionFunc
participant FusedOp
Note over User,FusedOp: Registration Phase (Module Import)
User->>OperationFuser: register_forward_fusion(CustomForwardFusion.fuse_ops)
User->>OperationFuser: register_backward_fusion(CustomBackwardFusion.fuse_ops)
Note over User,FusedOp: Forward Pass
User->>Sequential: forward(input)
Sequential->>OperationFuser: __call__(input)
OperationFuser->>OperationFuser: maybe_fuse_ops()
loop For each forward fusion function
OperationFuser->>ForwardFusionFunc: func(ops, recipe=recipe)
ForwardFusionFunc->>ForwardFusionFunc: Sliding window pattern matching
ForwardFusionFunc-->>OperationFuser: Updated ops list with fusions
end
OperationFuser->>OperationFuser: _fuse_ops() - determine basic op indices
loop For each fused forward op
OperationFuser->>FusedOp: fuser_forward(ctxs, input, ...)
FusedOp-->>OperationFuser: output
end
OperationFuser-->>Sequential: output
Sequential-->>User: output
Note over User,FusedOp: Backward Pass
User->>Sequential: backward(grad_output)
Sequential->>OperationFuser: backward()
loop For each backward fusion function
OperationFuser->>BackwardFusionFunc: func(ops, recipe=recipe)
BackwardFusionFunc->>BackwardFusionFunc: Sliding window pattern matching
BackwardFusionFunc-->>OperationFuser: Updated ops list with fusions
end
OperationFuser->>OperationFuser: _fuse_ops() - determine basic op indices
loop For each fused backward op (reversed)
OperationFuser->>FusedOp: fuser_backward(ctxs, grad_output, ...)
FusedOp-->>OperationFuser: grad_input
end
OperationFuser-->>Sequential: grad_input
Sequential-->>User: grad_input
|
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
for more information, see https://pre-commit.ci
This comment was marked as outdated.
This comment was marked as outdated.
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
This comment was marked as outdated.
This comment was marked as outdated.
|
/te-ci pytorch |
Description
This PR adds the
register_forward_fusionandregister_backward_fusionfunctions to the op fuser API, allowing users to register custom fusions.Type of change
Changes
Checklist: