diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index ce15dd1421..b97e4cc3b7 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2329,13 +2329,13 @@ def test_backward_activation_bias( backward_ops = model._module_groups[0]._backward_ops if with_quantization: assert len(backward_ops) == 2 - assert isinstance(backward_ops[0][0], BackwardActivationBias) - assert isinstance(backward_ops[1][0], te_ops.Quantize) + assert isinstance(backward_ops[0][0], te_ops.Quantize) + assert isinstance(backward_ops[1][0], BackwardActivationBias) else: assert len(backward_ops) == 3 - assert isinstance(backward_ops[0][0], act_type) + assert isinstance(backward_ops[0][0], te_ops.Quantize) assert isinstance(backward_ops[1][0], te_ops.Bias) - assert isinstance(backward_ops[2][0], te_ops.Quantize) + assert isinstance(backward_ops[2][0], act_type) # Expected numerical error tols = dtype_tols(dtype) @@ -2849,3 +2849,317 @@ def test_layernorm_mlp( with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) + + +class TestCustomOps: + """Test with ops that are defined externally""" + + def test_custom_basic_op( + self, + *, + shape: Iterable[int] = (7, 5), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + ) -> None: + """Custom basic op""" + + class CustomScaleOp(te.ops.BasicOperation): + """Custom op that applies a learnable scale""" + + def __init__(self) -> None: + super().__init__() + self.scale: torch.nn.Parameter + scale = torch.ones((), dtype=dtype, device=device) + scale = torch.nn.Parameter(scale) + self.register_parameter("scale", scale) + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + ) -> torch.Tensor: + ctx.save_for_backward(self.scale, input_) + return self.scale * input_ + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> torch.Tensor: + ( + scale, + input_, + ) = ctx.saved_tensors + grad_scale = torch.inner(input_.reshape(-1), grad_output.reshape(-1)) + grad_scale = grad_scale.reshape(()) + grad_input = scale * grad_output + return grad_input, (grad_scale,) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + ) + w_ref, w_test = make_reference_and_test_tensors( + (), + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = w_ref * x_ref + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = CustomScaleOp() + forward = te.ops.Sequential(te.ops.Identity(), op, te.ops.Identity()) + with torch.no_grad(): + op.scale.copy_(w_test) + del w_test + y_test = forward(x_test) + y_test.backward(dy_test) + + # Check results + tols = dtype_tols(dtype) + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = op.scale.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + + def test_custom_forward_fused_op( + self, + *, + shape: Iterable[int] = (7, 11), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + ): + 
"""Custom fused op in forward pass""" + + class CustomForwardLinearSiLU(te.ops.FusedOperation): + """Custom fused op for GEMM + SiLU""" + + _enabled = True + + def __init__(self, *, linear, silu) -> None: + super().__init__((linear, silu)) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + **unused, + ) -> torch.Tensor: + weight = self.basic_ops[0].weight + dtype = weight.dtype + device = weight.device + + # Perform compute on CPU, because why not? + x = input_.cpu() + w = weight.cpu() + y = torch.matmul(x, w.T) + z = torch.nn.functional.silu(y) + out = z.to(device=device) + + # Save state for linear backward + linear_op_ctx = basic_op_ctxs[0] + linear_op_ctx.save_for_backward(input_, weight) + linear_op_ctx.with_quantized_compute = False + linear_op_ctx.input_quantizer = None + linear_op_ctx.weight_quantizer = None + linear_op_ctx.grad_output_quantizer = None + linear_op_ctx.grad_input_quantizer = None + linear_op_ctx.dtype = dtype + linear_op_ctx.input_requires_grad = True + linear_op_ctx.weight_requires_grad = True + + # Save state for SiLU backward + silu_op_ctx = basic_op_ctxs[1] + silu_op_ctx.save_for_backward(y.to(device=device)) + silu_op_ctx.dtype = dtype + silu_op_ctx.prev_op_grad_output_quantizer = None + + return out, [(), ()] + + @staticmethod + def fuse_ops( + ops: list[FusibleOperation], + **unused, + ) -> list[FusibleOperation]: + """Apply fusion the first time this function is called""" + if CustomForwardLinearSiLU._enabled: + CustomForwardLinearSiLU._enabled = False + op = CustomForwardLinearSiLU(linear=ops[0], silu=ops[1]) + return [op] + ops[2:] + return ops + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + ) + w_ref, w_test = make_reference_and_test_tensors( + (shape[-1], shape[-1]), + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x_ref, w_ref) + y_ref = torch.nn.functional.silu(y_ref) + y_ref.backward(dy_ref) + + # Implementation with fusible operation + te.ops.register_forward_fusion(CustomForwardLinearSiLU.fuse_ops) + model = te.ops.Sequential( + te.ops.Linear(shape[-1], shape[-1], bias=False), + te.ops.SiLU(), + ) + with torch.no_grad(): + model[0].weight.copy_(w_test) + del w_test + y_test = model(x_test) + y_test.backward(dy_test) + + # Check that forward operations have been fused + forward_ops = model._module_groups[0]._forward_ops + assert len(forward_ops) == 1 + assert isinstance(forward_ops[0][0], CustomForwardLinearSiLU) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + + def test_custom_backward_fused_op( + self, + *, + shape: Iterable[int] = (13, 5), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + ): + """Custom fused op in backward pass""" + + class CustomBackwardLinearScale(te.ops.FusedOperation): + """Custom fused op for backward linear + scale""" + + 
_enabled: bool = True + + def __init__(self, *, scale, linear) -> None: + super().__init__((scale, linear)) + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + **unused, + ) -> torch.Tensor: + + # Load state from linear forward + linear_op_ctx = basic_op_ctxs[1] + x, w = linear_op_ctx.saved_tensors + dtype = linear_op_ctx.dtype + device = w.device + + # Perform compute in FP64 and apply scale before dgrad + # GEMM instead of after + scale = self.basic_ops[0].scale + dy = grad_output.double() + x = x.double() + w = w.double() + dx = torch.matmul(dy, scale * w) + dw = torch.matmul(dy.T, x) + dx = dx.to(dtype=dtype) + dw = dw.to(dtype=dtype) + + return dx, [(), (dw,)], [(), ()] + + @staticmethod + def fuse_ops( + ops: list[FusibleOperation], + **unused, + ) -> list[FusibleOperation]: + """Apply fusion the first time this function is called""" + if CustomBackwardLinearScale._enabled: + CustomBackwardLinearScale._enabled = False + op = CustomBackwardLinearScale(scale=ops[0], linear=ops[1]) + return [op] + ops[2:] + return ops + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + ) + w_ref, w_test = make_reference_and_test_tensors( + (shape[-1], shape[-1]), + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + scale = 1.234 + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(scale * x_ref, w_ref) + y_ref.backward(dy_ref) + + # Implementation with fusible operation + te.ops.register_backward_fusion(CustomBackwardLinearScale.fuse_ops, prepend=True) + model = te.ops.Sequential( + te.ops.ConstantScale(scale), + te.ops.Linear(shape[-1], shape[-1], bias=False), + ) + with torch.no_grad(): + model[1].weight.copy_(w_test) + del w_test + y_test = model(x_test) + y_test.backward(dy_test) + + # Check that forward operations have been fused + backward_ops = model._module_groups[0]._backward_ops + assert len(backward_ops) == 1 + assert isinstance(backward_ops[0][0], CustomBackwardLinearScale) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) diff --git a/transformer_engine/pytorch/ops/__init__.py b/transformer_engine/pytorch/ops/__init__.py index 2b270ea3de..c61b50417d 100644 --- a/transformer_engine/pytorch/ops/__init__.py +++ b/transformer_engine/pytorch/ops/__init__.py @@ -8,7 +8,8 @@ """ -from transformer_engine.pytorch.ops.basic import * -from transformer_engine.pytorch.ops.linear import Linear -from transformer_engine.pytorch.ops.op import FusibleOperation -from transformer_engine.pytorch.ops.sequential import Sequential +from .basic import * +from .fuser import register_backward_fusion, register_forward_fusion +from .linear import Linear +from .op import BasicOperation, FusedOperation, FusibleOperation +from .sequential import Sequential diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index f4568ff25d..1ebfe23060 100644 --- 
a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -4,39 +4,26 @@ """Compound tensor operation supported by the operation fuser.""" -from .backward_activation_bias import ( - BackwardActivationBias, - fuse_backward_activation_bias, -) -from .backward_add_rmsnorm import ( - BackwardAddRMSNorm, - fuse_backward_add_rmsnorm, -) -from .backward_linear_add import ( - BackwardLinearAdd, - fuse_backward_linear_add, -) -from .backward_linear_scale import ( - BackwardLinearScale, - fuse_backward_linear_scale, -) -from .forward_linear_bias_activation import ( - ForwardLinearBiasActivation, - fuse_forward_linear_bias_activation, -) -from .forward_linear_bias_add import ( - ForwardLinearBiasAdd, - fuse_forward_linear_bias_add, -) -from .forward_linear_scale_add import ( - ForwardLinearScaleAdd, - fuse_forward_linear_scale_add, -) -from .userbuffers_backward_linear import ( - UserbuffersBackwardLinear, - fuse_userbuffers_backward_linear, -) -from .userbuffers_forward_linear import ( - UserbuffersForwardLinear, - fuse_userbuffers_forward_linear, -) +from ..fuser import register_backward_fusion, register_forward_fusion +from .backward_activation_bias import BackwardActivationBias +from .backward_add_rmsnorm import BackwardAddRMSNorm +from .backward_linear_add import BackwardLinearAdd +from .backward_linear_scale import BackwardLinearScale +from .forward_linear_bias_activation import ForwardLinearBiasActivation +from .forward_linear_bias_add import ForwardLinearBiasAdd +from .forward_linear_scale_add import ForwardLinearScaleAdd +from .userbuffers_backward_linear import UserbuffersBackwardLinear +from .userbuffers_forward_linear import UserbuffersForwardLinear + +# Register forward fusions +register_forward_fusion(UserbuffersForwardLinear.fuse_forward_ops) +register_forward_fusion(ForwardLinearBiasAdd.fuse_forward_ops) +register_forward_fusion(ForwardLinearBiasActivation.fuse_forward_ops) +register_forward_fusion(ForwardLinearScaleAdd.fuse_forward_ops) + +# Register backward fusions +register_backward_fusion(UserbuffersBackwardLinear.fuse_backward_ops) +register_backward_fusion(BackwardLinearAdd.fuse_backward_ops) +register_backward_fusion(BackwardLinearScale.fuse_backward_ops) +register_backward_fusion(BackwardActivationBias.fuse_backward_ops) +register_backward_fusion(BackwardAddRMSNorm.fuse_backward_ops) diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index d5b9ce0e96..4ab082d32b 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -53,8 +53,8 @@ def fuser_backward( ]: # Get basic operation contexts - activation_op_ctx = basic_op_ctxs[0] - bias_op_ctx = basic_op_ctxs[1] + bias_op_ctx = basic_op_ctxs[0] + activation_op_ctx = basic_op_ctxs[1] # Saved tensors from forward pass (act_input,) = activation_op_ctx.saved_tensors @@ -79,68 +79,59 @@ def fuser_backward( # Clear activation input tensor clear_tensor_data(act_input) - return dx, [(), (db,)], [(), ()] + return dx, [(db,), ()], [(), ()] - -def fuse_backward_activation_bias( - ops: list[tuple[FusibleOperation, list[int]]], - recipe: Optional[Recipe], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fused backward dact + dbias + quantize - - Parameters - ---------- - ops : list of tuples - Backward pass operations and the indices of the corresponding - basic operations. 
- recipe : Recipe, optional - Used quantization recipe - - Returns - ------- - ops : list of tuples - Updated backward pass operations - - """ - - # Check if recipe supports bias activation fusion - if recipe is None: - return ops - - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 3: + @staticmethod + def fuse_backward_ops( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for backward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Backward pass operations. + recipe : Recipe, optional + Quantization recipe. + + Returns + ------- + ops : list of FusibleOperation + Updated backward pass operations + + """ + + # Check if recipe supports bias activation fusion + if recipe is None: + return ops + + # Scan through ops, fusing if possible + out = [] + window, ops = ops[:3], ops[3:] + while len(window) == 3: + if ( + isinstance(window[2], _fusible_activations) + and isinstance(window[1], Bias) + and window[0].get_grad_output_quantizer() is not None + ): + # Construct fused op if window matches pattern + op = BackwardActivationBias(bias=window[1], activation=window[2]) + window = [window[0], op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-2]) + window = window[-2:] + + # Adjust window to expected size + out.extend(window[:-3]) + window = window[-3:] + while ops and len(window) < 3: + window.append(ops[0]) + ops = ops[1:] + + # Return list of ops out.extend(window) - - # Check if first op is a supported activation - window, ops = ops[:1], ops[1:] - op, _ = window[0] - if not isinstance(op, _fusible_activations): - continue - - # Check if second op is bias - op, _ = ops[0] - if not isinstance(op, Bias): - continue - - # Check if third op has a grad input quantizer - op, _ = ops[1] - if not op.num_quantizers("backward") > 0: - continue - - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = BackwardActivationBias( - activation=window[0][0], - bias=window[1][0], - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py b/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py index 186619caae..a3c81e60c8 100644 --- a/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py +++ b/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py @@ -42,7 +42,7 @@ def fuser_backward( # Get basic operations rmsnorm_op = self.basic_ops[1] - rmsnorm_op_ctx = basic_op_ctxs[0] + rmsnorm_op_ctx = basic_op_ctxs[1] # Saved tensors from forward pass x, rstdevs = rmsnorm_op_ctx.saved_tensors @@ -53,7 +53,7 @@ def fuser_backward( # Check input tensors dtype = rmsnorm_op_ctx.dtype - extra_grad = basic_op_grad_extra_outputs[1][0] + extra_grad = basic_op_grad_extra_outputs[0][0] dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size()) w = maybe_dequantize(rmsnorm_op.weight, dtype).view((inner_dim,)) add = maybe_dequantize(extra_grad.contiguous(), dtype).view(x.size()) @@ -77,57 +77,51 @@ def fuser_backward( grad_input = dx.view(grad_output.size()) grad_weight = dw.view(weight_dims) - return grad_input, [(grad_weight,), ()], [(), ()] - - -def fuse_backward_add_rmsnorm( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, 
list[int]]]: - """Fused backward RMNorm + add - - Parameters - ---------- - ops : list of tuples - Backward pass operations and the indices of the corresponding - basic operations. - - Returns - ------- - ops : list of tuples - Updated backward pass operations - - """ - - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 2: + return grad_input, [(), (grad_weight,)], [(), ()] + + @staticmethod + def fuse_backward_ops( + ops: list[FusibleOperation], + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for backward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Backward pass operations. + + Returns + ------- + ops : list of FusibleOperation + Updated backward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window, ops = ops[:2], ops[2:] + while len(window) == 2: + if ( + isinstance(window[0], MakeExtraOutput) + and isinstance(window[1], RMSNorm) + and not window[0]._in_place + ): + # Construct fused op if window matches pattern + op = BackwardAddRMSNorm(add=window[0], rmsnorm=window[1]) + window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-1]) + window = window[-1:] + + # Adjust window to expected size + out.extend(window[:-2]) + window = window[-2:] + while ops and len(window) < 2: + window.append(ops[0]) + ops = ops[1:] + + # Return list of ops out.extend(window) - - # Check if first op is linear - window, ops = ops[:1], ops[1:] - op, _ = window[0] - if not isinstance(op, RMSNorm): - continue - - # Check if second op is "make extra output" - op, _ = ops[0] - if not isinstance(op, MakeExtraOutput): - continue - if op._in_place: - continue - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = BackwardAddRMSNorm( - rmsnorm=window[0][0], - add=window[1][0], - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index 5e7339db85..c06e212e87 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -45,7 +45,7 @@ def fuser_backward( # Get basic operations linear_op = self.basic_ops[1] - linear_op_ctx = basic_op_ctxs[0] + linear_op_ctx = basic_op_ctxs[1] # Saved tensors from forward pass (x_local, w) = linear_op_ctx.saved_tensors @@ -71,7 +71,7 @@ def fuser_backward( accumulate_into_main_grad = False # Linear backward pass - grad_input = basic_op_grad_extra_outputs[1][0] + grad_input = basic_op_grad_extra_outputs[0][0] grad_input, grad_weight = BasicLinear._functional_backward( grad_output=grad_output, input=x_local, @@ -109,61 +109,60 @@ def fuser_backward( zero=getattr(weight_param, "zero_out_wgrad", False), ) - return grad_input, [(grad_weight,), ()], [(), ()] - - -def fuse_backward_linear_add( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fused backward dgrad GEMM + add - - Parameters - ---------- - ops : list of tuples - Backward pass operations and the indices of the corresponding - basic operations. 
- - Returns - ------- - ops : list of tuples - Updated backward pass operations - - """ - - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 2: + return grad_input, [(), (grad_weight,)], [(), ()] + + @staticmethod + def fuse_backward_ops( + ops: list[FusibleOperation], + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for backward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Backward pass operations. + + Returns + ------- + ops : list of FusibleOperation + Updated backward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window, ops = ops[:2], ops[2:] + while len(window) == 2: + + # Check if window matches pattern + matches_pattern = True + if not (isinstance(window[0], MakeExtraOutput) and isinstance(window[1], BasicLinear)): + matches_pattern = False + elif not window[0]._in_place: + # Fused op accumulates grad input in-place + matches_pattern = False + elif window[1].tensor_parallel_mode == "column": + # Column tensor-parallelism requires communication + # after the dgrad GEMM + matches_pattern = False + + if matches_pattern: + # Construct fused op if window matches pattern + op = BackwardLinearAdd(backward_add=window[0], linear=window[1]) + window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-1]) + window = window[-1:] + + # Adjust window to expected size + out.extend(window[:-2]) + window = window[-2:] + while ops and len(window) < 2: + window.append(ops[0]) + ops = ops[1:] + + # Return list of ops out.extend(window) - - # Check if first op is linear - window, ops = ops[:1], ops[1:] - op, _ = window[0] - if not isinstance(op, BasicLinear): - continue - if op.tensor_parallel_mode == "column": - # Row tensor-parallelism requires communication after the - # GEMM - continue - - # Check if second op is "make extra output" - op, _ = ops[0] - if not isinstance(op, MakeExtraOutput): - continue - if not op._in_place: - continue - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = BackwardLinearAdd( - linear=window[0][0], - backward_add=window[1][0], - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py index f7f59e65c9..709073e6f8 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py @@ -45,7 +45,7 @@ def fuser_backward( # Get basic operations linear_op = self.basic_ops[0] - linear_op_ctx = basic_op_ctxs[1] + linear_op_ctx = basic_op_ctxs[0] scale_op = self.basic_ops[1] # Saved tensors from forward pass @@ -109,58 +109,57 @@ def fuser_backward( zero=getattr(weight_param, "zero_out_wgrad", False), ) - return grad_input, [(), (grad_weight,)], [(), ()] - - -def fuse_backward_linear_scale( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fused backward dgrad GEMM + constant scale - - Parameters - ---------- - ops : list of tuples - Backward pass operations and the indices of the corresponding - basic operations. 
- - Returns - ------- - ops : list of tuples - Updated backward pass operations - - """ - - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 2: + return grad_input, [(grad_weight,), ()], [(), ()] + + @staticmethod + def fuse_backward_ops( + ops: list[FusibleOperation], + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for backward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Backward pass operations. + + Returns + ------- + ops : list of FusibleOperation + Updated backward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window, ops = ops[:2], ops[2:] + while len(window) == 2: + + # Check if window matches pattern + matches_pattern = True + if not (isinstance(window[0], BasicLinear) and isinstance(window[1], ConstantScale)): + matches_pattern = False + elif window[0].tensor_parallel_mode == "column": + # Column tensor-parallelism requires communication + # after the dgrad GEMM + matches_pattern = False + + if matches_pattern: + # Construct fused op if window matches pattern + op = BackwardLinearScale(linear=window[0], scale=window[1]) + window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-1]) + window = window[-1:] + + # Adjust window to expected size + out.extend(window[:-2]) + window = window[-2:] + while ops and len(window) < 2: + window.append(ops[0]) + ops = ops[1:] + + # Return list of ops out.extend(window) - - # Check if first op is constant scale - window, ops = ops[:1], ops[1:] - op, _ = window[0] - if not isinstance(op, ConstantScale): - continue - - # Check if second op is linear - op, _ = ops[0] - if not isinstance(op, BasicLinear): - continue - if op.tensor_parallel_mode == "column": - # Column tensor-parallelism requires communication after the dgrad GEMM - continue - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = BackwardLinearScale( - scale=window[0][0], - linear=window[1][0], - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 1c5edfcfcb..dfc11a19e7 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -134,62 +134,63 @@ def fuser_forward( return output, [() for _ in range(len(self.basic_ops))] - -def fuse_forward_linear_bias_activation( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fuse forward GEMM + bias + activation - - Parameters - ---------- - ops : list of tuples - Forward pass operations and the indices of the corresponding - basic operations. - - Returns - ------- - ops : list of tuples - Updated forward pass operations - - """ - - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 2: + @staticmethod + def fuse_forward_ops( + ops: list[FusibleOperation], + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for forward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Forward pass operations. 
+ + Returns + ------- + ops : list of FusibleOperation + Updated forward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window, ops = ops[:2], ops[2:] + while len(window) == 2: + + # Check if window matches pattern + matches_pattern = True + if not (isinstance(window[0], BasicLinear) and isinstance(window[1], Bias)): + matches_pattern = False + elif window[0].tensor_parallel_mode == "row": + # Row tensor-parallelism requires communication after + # the GEMM + matches_pattern = False + elif window[0].weight.dtype not in (torch.float16, torch.bfloat16): + # cuBLAS only supports fused GEMM+bias+activation with + # FP16 and BF16 output + matches_pattern = False + + if matches_pattern: + # Construct fused op if window matches pattern + op = ForwardLinearBiasActivation( + linear=window[0], + bias=window[1], + activation=None, + ) + window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-1]) + window = window[-1:] + + # Adjust window to expected size + out.extend(window[:-2]) + window = window[-2:] + while ops and len(window) < 2: + window.append(ops[0]) + ops = ops[1:] + + # Return list of ops out.extend(window) - - # Check if first op is linear - window, ops = ops[:1], ops[1:] - op1, _ = window[0] - if not isinstance(op1, BasicLinear): - continue - if op1.tensor_parallel_mode == "row": - # Row tensor-parallelism requires communication after the - # GEMM - continue - if op1.weight.dtype not in (torch.float16, torch.bfloat16): - # cuBLAS only supports fused GEMM+bias+activation with - # FP16 and BF16 output - continue - - # Check if second op is bias - op2, _ = ops[0] - if not isinstance(op2, Bias): - continue - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = ForwardLinearBiasActivation( - linear=window[0][0], - bias=window[1][0], - activation=None, - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 4efb33e037..2dfc0566b7 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -131,72 +131,63 @@ def fuser_forward( return output, [() for _ in range(len(self.basic_ops))] + @staticmethod + def fuse_forward_ops( + ops: list[FusibleOperation], + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for forward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Forward pass operations. + + Returns + ------- + ops : list of FusibleOperation + Updated forward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while ops: + + # Shift window + out.extend(window) + window = [ops[0]] + ops = ops[1:] -def fuse_forward_linear_bias_add( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fuse forward GEMM + bias + add - - Parameters - ---------- - ops : list of tuples - Forward pass operations and the indices of the corresponding - basic operations. 
+ # Check if first op is linear + if not isinstance(window[0], BasicLinear): + continue + if window[0].tensor_parallel_mode == "row": + # Row tensor-parallelism requires communication after + # the GEMM + continue + linear = window[0] - Returns - ------- - ops : list of tuples - Updated forward pass operations + # Check if next op is bias + bias = None + if ops and isinstance(ops[0], Bias): + window.append(ops[0]) + ops = ops[1:] + bias = window[-1] + + # Check if next op is in-place add extra input + if ops and isinstance(ops[0], AddExtraInput) and ops[0]._in_place: + window.append(ops[0]) + ops = ops[1:] + add = window[-1] + else: + continue - """ + # Replace window with fused op + op = ForwardLinearBiasAdd(linear=linear, bias=bias, add=add) + window = [op] - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 2: + # Return list of ops out.extend(window) - - # Check if first op is linear - window, ops = ops[:1], ops[1:] - op, _ = window[0] - if not isinstance(op, BasicLinear): - continue - if op.tensor_parallel_mode == "row": - # Row tensor-parallelism requires communication after the - # GEMM - continue - linear = op - op, _ = ops[0] - - # Check if next op is bias - bias = None - if isinstance(op, Bias): - bias = op - window.extend(ops[:1]) - ops = ops[1:] - if len(ops) == 0: - continue - op, _ = ops[0] - - # Check if next op is in-place add extra input - if not isinstance(op, AddExtraInput): - continue - if not op._in_place: - continue - add = op - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = ForwardLinearBiasAdd( - linear=linear, - bias=bias, - add=add, - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 25b40f76e3..ae4bdd4b19 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -110,70 +110,66 @@ def fuser_forward( return output, [() for _ in range(len(self.basic_ops))] - -def fuse_forward_linear_scale_add( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Fuse forward GEMM + scale + add - - Parameters - ---------- - ops : list of tuples - Forward pass operations and the indices of the corresponding - basic operations. - - Returns - ------- - ops : list of tuples - Updated forward pass operations - - """ - - # Scan through ops, fusing if possible - out = [] - window = [] - while len(ops) >= 3: + @staticmethod + def fuse_forward_ops( + ops: list[FusibleOperation], + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for forward pass. + + Parameters + ---------- + ops : list of FusibleOperation + Forward pass operations. 
+ + Returns + ------- + ops : list of FusibleOperation + Updated forward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window, ops = ops[:3], ops[3:] + while len(window) == 3: + + # Check if window matches pattern + matches_pattern = True + if not ( + isinstance(window[0], BasicLinear) + and isinstance(window[1], ConstantScale) + and isinstance(window[2], AddExtraInput) + ): + matches_pattern = False + elif window[0].tensor_parallel_mode == "row": + # Row tensor-parallelism requires communication after + # the GEMM + matches_pattern = False + elif not window[2]._in_place: + # Fused op accumulates output in-place + matches_pattern = False + + if matches_pattern: + # Construct fused op if window matches pattern + op = ForwardLinearScaleAdd( + linear=window[0], + scale=window[1], + add=window[2], + ) + window = [op] + else: + # Shift window if window doesn't match pattern + out.extend(window[:-2]) + window = window[-2:] + + # Adjust window to expected size + out.extend(window[:-3]) + window = window[-3:] + while ops and len(window) < 3: + window.append(ops[0]) + ops = ops[1:] + + # Return list of ops out.extend(window) - - # Check if first op is linear - window, ops = ops[:1], ops[1:] - op, _ = window[0] - if not isinstance(op, BasicLinear): - continue - if op.tensor_parallel_mode == "row": - # Row tensor-parallelism requires communication after the - # GEMM - continue - linear = op - op, _ = ops[0] - - # Check if next op is constant scale - if not isinstance(op, ConstantScale): - continue - scale = op - window.extend(ops[:1]) - ops = ops[1:] - op, _ = ops[0] - - # Check if next op is in-place add extra input - if not isinstance(op, AddExtraInput): - continue - if not op._in_place: - continue - add = op - window.extend(ops[:1]) - ops = ops[1:] - - # Replace window with fused op - op = ForwardLinearScaleAdd( - linear=linear, - scale=scale, - add=add, - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out.extend(window) - out.extend(ops) - return out + return out diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 4943ffb1bd..077f2758cd 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -502,7 +502,7 @@ def fuser_backward( # Get basic operations idx = self._op_idxs["linear"] linear_op = self.basic_ops[idx] - linear_op_ctx = basic_op_ctxs[-1] + linear_op_ctx = basic_op_ctxs[0] bias_op = None if self._op_idxs["bias"] is not None: idx = self._op_idxs["bias"] @@ -577,99 +577,84 @@ def fuser_backward( grad_params[self._op_idxs["linear"]] = (grad_weight,) if bias_op is not None: grad_params[self._op_idxs["bias"]] = (grad_bias,) - grad_params.reverse() grad_extra_inputs = [() for _ in range(len(self.basic_ops))] return grad_input, grad_params, grad_extra_inputs + @staticmethod + def fuse_backward_ops( + ops: list[FusibleOperation], + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for backward pass. -def fuse_userbuffers_backward_linear( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Substitute linear operations with Userbuffers implementation + Parameters + ---------- + ops : list of FusibleOperation + Backward pass operations. 
+ recipe : Recipe, optional + Quantization recipe. - Parameters - ---------- - ops : list of tuples - Backward pass operations and the indices of the corresponding - basic operations. + Returns + ------- + ops : list of FusibleOperation + Updated backward pass operations - Returns - ------- - ops : list of tuples - Updated backward pass operations + """ - """ + # Return immediately if environment is not distributed + if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: + return ops - # Return immediately if environment is not distributed - if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: - return ops - - # Sliding window in list of ops - window = [] - - def peek_next_op() -> Optional[FusibleOperation]: - """Get next op in list of ops""" - nonlocal ops - if not ops: - return None - return ops[-1][0] - - def pop_next_op() -> FusibleOperation: - """Remove next op from list of ops and add to sliding window""" - nonlocal ops, window - window.insert(0, ops[-1]) - ops = ops[:-1] - return window[0][0] - - # Scan through ops in reverse order, fusing if possible - out_reversed = [] - while ops: - out_reversed.extend(reversed(window)) - window.clear() - - # Check if next op is linear - next_op = pop_next_op() - if not isinstance(next_op, BasicLinear): - continue - linear = next_op - if linear._userbuffers_options is None: - continue - - # Check if next op is bias - bias = None - if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias): - bias = pop_next_op() - - # Check if next op is reduce-scatter - reduce_scatter = None - if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter): - reduce_scatter = pop_next_op() - - # Check for invalid combinations - if reduce_scatter is None: - if linear.tensor_parallel_mode is None: - continue - if linear.tensor_parallel_size == 1: - continue - if linear.tensor_parallel_mode == "row" and bias is not None: - continue - else: - if linear.tensor_parallel_mode is not None: + # Scan through ops, fusing if possible + out = [] + window = [] + while ops: + + # Shift window + out.extend(window) + window, ops = ops[:1], ops[1:] + + # Check if first op is linear + if not isinstance(window[0], BasicLinear): continue - if reduce_scatter.process_group_size == 1: + linear = window[0] + if linear._userbuffers_options is None: continue - # Replace window with fused op - op = UserbuffersBackwardLinear( - linear=linear, - bias=bias, - reduce_scatter=reduce_scatter, - ) - basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] - window = [(op, basic_op_idxs)] - - # Return list of ops - out_reversed.extend(reversed(window)) - out = out_reversed - out.reverse() - return out + # Check if next op is bias + bias = None + if linear.tensor_parallel_mode != "row" and ops and isinstance(ops[0], Bias): + bias, ops = ops[0], ops[1:] + window.append(bias) + + # Check if next op is reduce-scatter + reduce_scatter = None + if linear.tensor_parallel_mode is None and ops and isinstance(ops[0], ReduceScatter): + reduce_scatter, ops = ops[0], ops[1:] + window.append(reduce_scatter) + + # Check for invalid combinations + if reduce_scatter is None: + if linear.tensor_parallel_mode is None: + continue + if linear.tensor_parallel_size == 1: + continue + if linear.tensor_parallel_mode == "row" and bias is not None: + continue + else: + if linear.tensor_parallel_mode is not None: + continue + if reduce_scatter.process_group_size == 1: + continue + + # Replace window with fused op + 
op = UserbuffersBackwardLinear( + linear=linear, + bias=bias, + reduce_scatter=reduce_scatter, + ) + window = [op] + + # Return list of ops + out.extend(window) + return out diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index fe04aa1e0b..6ef9bf083b 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -369,93 +369,79 @@ def fuser_forward( return output, [() for _ in range(len(self.basic_ops))] + @staticmethod + def fuse_forward_ops( + ops: list[FusibleOperation], + **unused, # pylint: disable=unused-argument + ) -> list[FusibleOperation]: + """Apply operation fusion for forward pass. -def fuse_userbuffers_forward_linear( - ops: list[tuple[FusibleOperation, list[int]]], -) -> list[tuple[FusibleOperation, list[int]]]: - """Substitute linear operations with Userbuffers implementation - - Parameters - ---------- - ops : list of tuples - Forward pass operations and the indices of the corresponding - basic operations. - - Returns - ------- - ops : list of tuples - Updated forward pass operations + Parameters + ---------- + ops : list of FusibleOperation + Forward pass operations. - """ + Returns + ------- + ops : list of FusibleOperation + Updated forward pass operations - # Return immediately if environment is not distributed - if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: - return ops - - # Sliding window in list of ops - window = [] - - def peek_next_op() -> Optional[FusibleOperation]: - """Get next op in list of ops""" - nonlocal ops - if not ops: - return None - return ops[0][0] - - def pop_next_op() -> FusibleOperation: - """Remove next op from list of ops and add to sliding window""" - nonlocal ops, window - window.append(ops[0]) - ops = ops[1:] - return window[-1][0] - - # Scan through ops, fusing if possible - out = [] - while ops: - out.extend(window) - window.clear() + """ - # Check if next op is linear - next_op = pop_next_op() - if not isinstance(next_op, BasicLinear): - continue - linear = next_op - if linear._userbuffers_options is None: - continue + # Return immediately if environment is not distributed + if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: + return ops - # Check if next op is bias - bias = None - if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias): - bias = pop_next_op() + # Scan through ops, fusing if possible + out = [] + window = [] + while ops: - # Check if next op is reduce-scatter - reduce_scatter = None - if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter): - reduce_scatter = pop_next_op() + # Shift window + out.extend(window) + window, ops = ops[:1], ops[1:] - # Check for invalid combinations - if reduce_scatter is None: - if linear.tensor_parallel_mode is None: - continue - if linear.tensor_parallel_size == 1: - continue - if linear.tensor_parallel_mode == "row" and bias is not None: - continue - else: - if linear.tensor_parallel_mode is not None: + # Check if first op is linear + if not isinstance(window[0], BasicLinear): continue - if reduce_scatter.process_group_size == 1: + linear = window[0] + if linear._userbuffers_options is None: continue - # Replace window with fused op - op = UserbuffersForwardLinear( - linear=linear, - bias=bias, - reduce_scatter=reduce_scatter, - ) - basic_op_idxs = [basic_op_idxs[0] for _, 
basic_op_idxs in window] - window = [(op, basic_op_idxs)] + # Check if next op is bias + bias = None + if linear.tensor_parallel_mode != "row" and ops and isinstance(ops[0], Bias): + bias, ops = ops[0], ops[1:] + window.append(bias) + + # Check if next op is reduce-scatter + reduce_scatter = None + if linear.tensor_parallel_mode is None and ops and isinstance(ops[0], ReduceScatter): + reduce_scatter, ops = ops[0], ops[1:] + window.append(reduce_scatter) + + # Check for invalid combinations + if reduce_scatter is None: + if linear.tensor_parallel_mode is None: + continue + if linear.tensor_parallel_size == 1: + continue + if linear.tensor_parallel_mode == "row" and bias is not None: + continue + else: + if linear.tensor_parallel_mode is not None: + continue + if reduce_scatter.process_group_size == 1: + continue + + # Replace window with fused op + op = UserbuffersForwardLinear( + linear=linear, + bias=bias, + reduce_scatter=reduce_scatter, + ) + window = [op] - # Return list of ops - out.extend(window) - return out + # Return list of ops + out.extend(window) + return out diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index bf7af48d03..7fe6ea37ed 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -5,33 +5,20 @@ """Manager class for a pipeline of fusible operations.""" from __future__ import annotations -from collections.abc import Callable, Iterable -from typing import Any, Optional +from collections.abc import Callable, Iterable, Sequence import itertools +from typing import Any, Optional, TypeAlias import torch -from transformer_engine.pytorch.quantization import FP8GlobalStateManager, Recipe, DelayedScaling -from transformer_engine.pytorch.ops.op import ( +from ..quantization import FP8GlobalStateManager, Recipe, DelayedScaling +from ..quantized_tensor import prepare_for_saving, restore_from_saved +from .op import ( BasicOperation, FusibleOperation, + FusedOperation, OperationContext, ) -from transformer_engine.pytorch.ops.fused import ( - fuse_backward_activation_bias, - fuse_backward_add_rmsnorm, - fuse_backward_linear_add, - fuse_backward_linear_scale, - fuse_forward_linear_bias_activation, - fuse_forward_linear_bias_add, - fuse_forward_linear_scale_add, - fuse_userbuffers_backward_linear, - fuse_userbuffers_forward_linear, -) -from transformer_engine.pytorch.quantized_tensor import ( - prepare_for_saving, - restore_from_saved, -) def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]: @@ -57,6 +44,12 @@ def _is_graph_capturing() -> bool: return _is_graph_capturing_function() +# Type alias for a function that may perform operation fusion +OperationFusionFunction: TypeAlias = ( + "Callable[tuple[list[FusibleOperation], ...], list[FusibleOperation]]" +) + + class _OperationFuserAutogradFunction(torch.autograd.Function): """Autograd function for a pipeline of operations @@ -241,7 +234,7 @@ def backward( dx = grad_output grad_params = [None for _ in range(len(basic_ops))] grad_extra_inputs = [None for _ in range(len(basic_ops))] - for op, basic_op_idxs in backward_ops: + for op, basic_op_idxs in reversed(backward_ops): # Stop if no more gradients are required if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs): @@ -315,6 +308,10 @@ class OperationFuser: """ + # Functions to perform operation fusion + forward_fusion_functions: list[OperationFusionFunction] = [] + backward_fusion_functions: list[OperationFusionFunction] = [] + def __init__( self, ops: 
list[FusibleOperation], @@ -334,7 +331,7 @@ def __init__( self._basic_op_num_extra_inputs: list[int] = list(op.num_extra_inputs for op in basic_ops) self.num_extra_inputs: int = sum(self._basic_op_num_extra_inputs) - # Ops for forward and backward pass, will be populated in fuse_ops + # Ops for forward and backward pass, will be populated in maybe_fuse_ops self._forward_ops: list[tuple[FusibleOperation, list[int]]] self._backward_ops: list[tuple[FusibleOperation, list[int]]] @@ -349,31 +346,48 @@ def __init__( self._flat_basic_op_params = sum(self._basic_op_params, []) @classmethod - def _fuse_forward_ops( - cls, - ops: list[tuple[FusibleOperation, list[int]]], - recipe: Optional[Recipe], # pylint: disable=unused-argument - ) -> list[tuple[FusibleOperation, list[int]]]: - """Attempt to fuse operations in forward pass""" - ops = fuse_userbuffers_forward_linear(ops) - ops = fuse_forward_linear_bias_add(ops) - ops = fuse_forward_linear_bias_activation(ops) - ops = fuse_forward_linear_scale_add(ops) - return ops - - @classmethod - def _fuse_backward_ops( + def _fuse_ops( cls, - ops: list[tuple[FusibleOperation, list[int]]], + basic_ops: Sequence[BasicOperation], + fusion_funcs: Iterable[OperationFusionFunction], recipe: Optional[Recipe], ) -> list[tuple[FusibleOperation, list[int]]]: - """Attempt to fuse operations in backward pass""" - ops = fuse_userbuffers_backward_linear(ops) - ops = fuse_backward_linear_add(ops) - ops = fuse_backward_linear_scale(ops) - ops = fuse_backward_activation_bias(ops, recipe) - ops = fuse_backward_add_rmsnorm(ops) - return ops + """Apply operation fusions""" + + # Apply op fusions + fused_ops = list(basic_ops) + for func in fusion_funcs: + fused_ops = func(fused_ops, recipe=recipe) + + def raise_mismatch_error() -> None: + """Throw error indicating invalid op fusion""" + raise RuntimeError( + "Found mismatch after fusing operations " + f"(basic_ops={[o.__class__.__name__ for o in basic_ops]}, " + f"fused_ops={[o.__class__.__name__ for o in fused_ops]})" + ) + + # Determine basic op indices corresponding to each op + out = [] + idx = 0 + for op in fused_ops: + if isinstance(op, FusedOperation): + idxs = [] + for basic_op in op.basic_ops: + if basic_op is not basic_ops[idx]: + raise_mismatch_error() + idxs.append(idx) + idx += 1 + out.append((op, idxs)) + else: + if op is not basic_ops[idx]: + raise_mismatch_error() + out.append((op, [idx])) + idx += 1 + if idx != len(basic_ops): + raise_mismatch_error() + + return out def maybe_fuse_ops( self, @@ -424,12 +438,16 @@ def maybe_fuse_ops( op.pre_first_fuser_forward() # Prepare basic op lists for fusions - forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)] - backward_ops = list(reversed(forward_ops[first_op_requiring_backward:])) - - # Fuse ops - self._forward_ops = self._fuse_forward_ops(forward_ops, recipe) - self._backward_ops = self._fuse_backward_ops(backward_ops, recipe) + self._forward_ops = OperationFuser._fuse_ops( + self._basic_ops, + OperationFuser.forward_fusion_functions, + recipe=recipe, + ) + self._backward_ops = OperationFuser._fuse_ops( + self._basic_ops, + OperationFuser.backward_fusion_functions, + recipe=recipe, + ) # Save current fusion params self.recipe_type, self.first_op_requiring_backward = fusion_params @@ -491,3 +509,55 @@ def __call__( *extra_inputs, ) return forward_func(*args) + + +def register_forward_fusion( + op_fusion_func: OperationFusionFunction, + prepend: bool = False, +) -> None: + """Register function to perform operation fusion for forward pass. 
+ + The fusion function should have the following signature: + + func(ops, *, recipe) -> updated ops + + Parameters + ---------- + op_fusion_func: function + Function that takes a list of operations and may substitute + them with fused operations. + prepend: bool, default = ``False`` + Whether the operation fuser should apply this fusion function + first. The default is to apply it last. + + """ + if prepend: + OperationFuser.forward_fusion_functions.insert(0, op_fusion_func) + else: + OperationFuser.forward_fusion_functions.append(op_fusion_func) + + +def register_backward_fusion( + op_fusion_func: OperationFusionFunction, + prepend: bool = False, +) -> None: + """Register function to perform operation fusion for backward pass. + + The fusion function should have the following signature: + + func(ops, *, recipe) -> updated ops + + Parameters + ---------- + op_fusion_func: function + Function that takes a list of operations and may substitute + them with fused operations. + prepend: bool, default = ``False`` + Whether the operation fuser should apply this fusion function + first. The default is to apply it last. + + """ + if prepend: + OperationFuser.backward_fusion_functions.insert(0, op_fusion_func) + else: + OperationFuser.backward_fusion_functions.append(op_fusion_func)
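
Usage note (not part of the diff): the sketch below shows how external code might use the new public fusion-registration API introduced by this change. It is a minimal illustration modeled on the `CustomForwardLinearSiLU` test above; the class `FusedScaleScale`, the model, and the tensor shapes are hypothetical and chosen only to demonstrate the registration pattern, not part of Transformer Engine.

```python
# Minimal sketch of registering a custom forward fusion with the new API.
# Assumes the API added in this PR: te.ops.register_forward_fusion,
# te.ops.FusedOperation, te.ops.ConstantScale (with a .scale attribute),
# and fusion functions with signature func(ops, *, recipe, ...) -> ops.
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.ops import FusedOperation, FusibleOperation
from transformer_engine.pytorch.ops.op import OperationContext


class FusedScaleScale(FusedOperation):
    """Fuse two adjacent ConstantScale ops into a single multiply (illustrative)."""

    def __init__(self, *, scale1, scale2) -> None:
        super().__init__((scale1, scale2))

    def fuser_forward(
        self,
        basic_op_ctxs: list[OperationContext],
        input_: torch.Tensor,
        **unused,
    ):
        # Apply both constant scales in one elementwise op.
        # NOTE: a complete implementation would also populate each basic op's
        # OperationContext (saved tensors, dtype, quantizers) so the unfused
        # backward pass still works, as CustomForwardLinearSiLU does above.
        out = input_ * (self.basic_ops[0].scale * self.basic_ops[1].scale)
        # No extra outputs for either basic op
        return out, [(), ()]

    @staticmethod
    def fuse_ops(ops: list[FusibleOperation], **unused) -> list[FusibleOperation]:
        """Replace adjacent ConstantScale pairs with the fused op."""
        out = []
        i = 0
        while i < len(ops):
            if (
                i + 1 < len(ops)
                and isinstance(ops[i], te.ops.ConstantScale)
                and isinstance(ops[i + 1], te.ops.ConstantScale)
            ):
                out.append(FusedScaleScale(scale1=ops[i], scale2=ops[i + 1]))
                i += 2
            else:
                out.append(ops[i])
                i += 1
        return out


# Register the fusion so every OperationFuser considers it (appended after the
# built-in fusions; pass prepend=True to run it first).
te.ops.register_forward_fusion(FusedScaleScale.fuse_ops)

model = te.ops.Sequential(te.ops.ConstantScale(2.0), te.ops.ConstantScale(3.0))
y = model(torch.randn(4, 8, device="cuda"))  # forward pass uses the fused op
```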