From d284ba89d30616dff28b27a01d07920894e0e4f7 Mon Sep 17 00:00:00 2001 From: Zhenxu Tian Date: Thu, 4 Jun 2026 15:16:59 +0800 Subject: [PATCH 1/4] Integrate elastic attention Signed-off-by: Zhenxu Tian --- .../at_shim/ATen/cuda/CUDAGeneratorImpl.h | 25 ++ .../at_shim/ATen/cuda/CUDAGraphsUtils.cuh | 20 + .../at_shim/c10/cuda/CUDAException.h | 31 ++ .../block_sparse_attn_fwd.cu | 305 +++++++++++++ custom_ops/gpu_ops/block_sparse_attn/setup.py | 169 +++++++ .../layers/attention/elastic_attn_backend.py | 365 ++++++++++++++++ .../models/qwen3_elastic/__init__.py | 169 +++++++ .../models/qwen3_elastic/config_elastic.py | 60 +++ .../models/qwen3_elastic/kernels/__init__.py | 5 + .../kernels/block_sparse_attn.py | 159 +++++++ .../qwen3_elastic/kernels/find_blocks.py | 137 ++++++ .../qwen3_elastic/kernels/xattention.py | 218 +++++++++ .../kernels/xattention_triton.py | 268 ++++++++++++ .../qwen3_elastic/modeling_elastic_qwen3.py | 412 ++++++++++++++++++ .../models/qwen3_elastic/utils.py | 128 ++++++ 15 files changed, 2471 insertions(+) create mode 100644 custom_ops/gpu_ops/block_sparse_attn/at_shim/ATen/cuda/CUDAGeneratorImpl.h create mode 100644 custom_ops/gpu_ops/block_sparse_attn/at_shim/ATen/cuda/CUDAGraphsUtils.cuh create mode 100644 custom_ops/gpu_ops/block_sparse_attn/at_shim/c10/cuda/CUDAException.h create mode 100644 custom_ops/gpu_ops/block_sparse_attn/block_sparse_attn_fwd.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/setup.py create mode 100644 fastdeploy/model_executor/layers/attention/elastic_attn_backend.py create mode 100644 fastdeploy/model_executor/models/qwen3_elastic/__init__.py create mode 100644 fastdeploy/model_executor/models/qwen3_elastic/config_elastic.py create mode 100644 fastdeploy/model_executor/models/qwen3_elastic/kernels/__init__.py create mode 100644 fastdeploy/model_executor/models/qwen3_elastic/kernels/block_sparse_attn.py create mode 100644 fastdeploy/model_executor/models/qwen3_elastic/kernels/find_blocks.py create mode 
// (git patch summary, continued from the previous line)
//   100644 fastdeploy/model_executor/models/qwen3_elastic/kernels/xattention.py
//   create mode 100644 fastdeploy/model_executor/models/qwen3_elastic/kernels/xattention_triton.py
//   create mode 100644 fastdeploy/model_executor/models/qwen3_elastic/modeling_elastic_qwen3.py
//   create mode 100644 fastdeploy/model_executor/models/qwen3_elastic/utils.py

// ===========================================================================
// File: custom_ops/gpu_ops/block_sparse_attn/at_shim/ATen/cuda/CUDAGeneratorImpl.h
// (new file, mode 100644)
// ===========================================================================

// Stub for ATen/cuda/CUDAGeneratorImpl.h
// Purpose: satisfy Block-Sparse-Attention's `flash.h` which only needs
// the type `at::PhiloxCudaState`. Dropout RNG is never used in inference
// (p_dropout==0), so any non-degenerate definition is fine; we only
// need it to compile.
#pragma once
#include <cstdint>  // uint64_t

namespace at {

// Minimal stand-in for the Philox RNG state that `flash.h` stores by value.
// Every member is zero/null initialised so a default-constructed instance is
// well defined even though nothing in the inference path reads it.
struct PhiloxCudaState {
  uint64_t seed_ = 0;
  uint64_t offset_ = 0;
  bool captured_ = false;

  // Dummy fields carrying the upstream member names.
  // NOTE(review): the original comment says these "mimic upstream layout";
  // the byte layout is not verified against ATen here, only the names are
  // provided so code that touches them compiles.
  struct Payload {
    uint64_t* ptr = nullptr;
    uint64_t val = 0;
  };
  Payload seed{};
  Payload offset{};
  uint64_t offset_intragraph_ = 0;

  PhiloxCudaState() = default;

  // Builds a non-captured state from a plain seed / offset pair.
  PhiloxCudaState(uint64_t s, uint64_t o) : seed_(s), offset_(o) {}
};

}  // namespace at

// ===========================================================================
// File: custom_ops/gpu_ops/block_sparse_attn/at_shim/ATen/cuda/CUDAGraphsUtils.cuh
// (new file, mode 100644)
// ===========================================================================

// Stub for ATen/cuda/CUDAGraphsUtils.cuh
// Provides at::cuda::philox::unpack with the correct *signature*. Inference
// never enters the dropout path (p_dropout==0), so the returned values are
// irrelevant — they just need to compile.
+#pragma once + +#include +#include + +namespace at { namespace cuda { namespace philox { + +// Templated to accept whatever PhiloxCudaState definition wins at the call +// site (paddle compat layer may define its own). +template +__host__ __device__ inline std::tuple +unpack(const T& /*arg*/) { + return std::make_tuple(uint64_t(0), uint64_t(0)); +} + +}}} // namespace at::cuda::philox diff --git a/custom_ops/gpu_ops/block_sparse_attn/at_shim/c10/cuda/CUDAException.h b/custom_ops/gpu_ops/block_sparse_attn/at_shim/c10/cuda/CUDAException.h new file mode 100644 index 00000000000..4c398989c64 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/at_shim/c10/cuda/CUDAException.h @@ -0,0 +1,31 @@ +// Stub for torch's c10/cuda/CUDAException.h to allow BSA compilation under Paddle. +// Inference path never triggers these checks in error state; they only need to +// resolve at compile time. +#pragma once + +#include +#include +#include +#include + +#ifndef C10_CUDA_CHECK +#define C10_CUDA_CHECK(EXPR) \ + do { \ + cudaError_t __err = (EXPR); \ + if (__err != cudaSuccess) { \ + throw std::runtime_error(std::string("CUDA error: ") + \ + cudaGetErrorString(__err)); \ + } \ + } while (0) +#endif + +#ifndef C10_CUDA_KERNEL_LAUNCH_CHECK +#define C10_CUDA_KERNEL_LAUNCH_CHECK() \ + do { \ + cudaError_t __err = cudaGetLastError(); \ + if (__err != cudaSuccess) { \ + throw std::runtime_error(std::string("CUDA kernel launch error: ") + \ + cudaGetErrorString(__err)); \ + } \ + } while (0) +#endif diff --git a/custom_ops/gpu_ops/block_sparse_attn/block_sparse_attn_fwd.cu b/custom_ops/gpu_ops/block_sparse_attn/block_sparse_attn_fwd.cu new file mode 100644 index 00000000000..e316d0ec76b --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/block_sparse_attn_fwd.cu @@ -0,0 +1,305 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Paddle port of mha_varlen_fwd_block from +// Block-Sparse-Attention/csrc/block_sparse_attn/flash_api.cpp (Tri Dao / J. Guo). 
// Forward-only (inference). Backward kernels are not compiled.
//
// Note: the directory ./src is a symlink to
//   <path-to>/Block-Sparse-Attention/csrc/block_sparse_attn/src
// so the headers below resolve through the symlink without copying source.

#include "paddle/extension.h"

// NOTE(review): the targets of the angle-bracket includes were lost when this
// patch was extracted; the list below is what the code in this file needs.
#include <cuda_runtime.h>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <tuple>
#include <vector>

#include <cutlass/numeric_types.h>

#include "src/namespace_config.h"
#include "src/hardware_info.h"
#include "src/flash.h"
#include "src/static_switch.h"

#define BSA_CHECK_GPU(x) PD_CHECK((x).is_gpu(), #x " must be on GPU")
#define BSA_CHECK_DTYPE(x, dt) PD_CHECK((x).dtype() == dt, #x " has wrong dtype")

namespace FLASH_NAMESPACE {

// Block-dim granularity required of m_block_dim / n_block_dim.
static constexpr int SPARSE_SIZE = 128;

// Forward kernel template, instantiated by src/flash_fwd_block_hdim*_*_sm80.cu.
// NOTE(review): template parameter list reconstructed (element type, head dim,
// causal flag) to match the three values the dispatch switches below provide.
template <typename T, int Headdim, bool Is_causal>
void run_mha_fwd_block_(Flash_fwd_params& params, cudaStream_t stream);

// Fills `params` for a varlen forward pass.
//
// Layout expected of q/k/v/out: [total_seq, num_heads, head_size] with the
// last dimension contiguous (the caller must guarantee contiguity); strides
// are expressed in elements.
//
//   b, h, h_k, d            batch size, query heads, key/value heads, head size
//   seqlen_*                maximum sequence lengths and their rounded versions
//   cu_seqlens_*_d          device pointers to int32 cumulative sequence lengths
//   p_d / softmax_lse_d     optional P buffer / log-sum-exp output buffer
//   p_dropout               dropout probability, must be < 1
//   window_size_left/right  sliding-window extents; negative means unbounded
static void set_params_fprop(Flash_fwd_params& params,
                             size_t b, size_t seqlen_q, size_t seqlen_k,
                             size_t seqlen_q_rounded, size_t seqlen_k_rounded,
                             size_t h, size_t h_k, size_t d, size_t d_rounded,
                             const paddle::Tensor& q,
                             const paddle::Tensor& k,
                             const paddle::Tensor& v,
                             paddle::Tensor& out,
                             void* cu_seqlens_q_d,
                             void* cu_seqlens_k_d,
                             void* p_d,
                             void* softmax_lse_d,
                             float p_dropout,
                             float softmax_scale,
                             int window_size_left,
                             int window_size_right) {
  // Validate before any arithmetic that depends on these values: the keep
  // probability is inverted below and h is divided by h_k.
  PD_CHECK(p_dropout < 1.f, "p_dropout must be < 1");
  PD_CHECK(h_k > 0, "number of key/value heads must be positive");

  params = {};
  params.is_bf16 = (q.dtype() == paddle::DataType::BFLOAT16);

  // Pointers
  params.q_ptr = const_cast<void*>(q.data());
  params.k_ptr = const_cast<void*>(k.data());
  params.v_ptr = const_cast<void*>(v.data());
  params.o_ptr = out.data();

  // Strides (in elements). Varlen layout [total_seq, num_heads, head_size]
  // with last-dim contiguous. Caller must guarantee contiguity.
  // The stride member type is taken from the struct rather than hard-coded.
  using stride_t = decltype(params.q_row_stride);
  params.q_row_stride = static_cast<stride_t>(h * d);
  params.k_row_stride = static_cast<stride_t>(h_k * d);
  params.v_row_stride = static_cast<stride_t>(h_k * d);
  params.q_head_stride = static_cast<stride_t>(d);
  params.k_head_stride = static_cast<stride_t>(d);
  params.v_head_stride = static_cast<stride_t>(d);
  params.o_row_stride = static_cast<stride_t>(h * d);
  params.o_head_stride = static_cast<stride_t>(d);

  params.cu_seqlens_q = static_cast<int*>(cu_seqlens_q_d);
  params.cu_seqlens_k = static_cast<int*>(cu_seqlens_k_d);
  params.seqused_k = nullptr;

  params.p_ptr = p_d;
  params.softmax_lse_ptr = softmax_lse_d;

  params.b = static_cast<int>(b);
  params.h = static_cast<int>(h);
  params.h_k = static_cast<int>(h_k);
  params.h_h_k_ratio = static_cast<int>(h / h_k);
  params.seqlen_q = static_cast<int>(seqlen_q);
  params.seqlen_k = static_cast<int>(seqlen_k);
  params.seqlen_q_rounded = static_cast<int>(seqlen_q_rounded);
  params.seqlen_k_rounded = static_cast<int>(seqlen_k_rounded);
  params.d = static_cast<int>(d);
  params.d_rounded = static_cast<int>(d_rounded);

  params.scale_softmax = softmax_scale;
  params.scale_softmax_log2 = softmax_scale * static_cast<float>(M_LOG2E);

  // `params.p_dropout` stores the *keep* probability, not the drop probability.
  params.p_dropout = 1.f - p_dropout;
  params.p_dropout_in_uint8_t =
      static_cast<uint8_t>(std::floor(params.p_dropout * 255.0f));
  params.rp_dropout = 1.f / params.p_dropout;
  params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;

  // Causal is encoded as "unbounded left window, zero right window". A
  // one-sided window gets its open side clamped to seqlen_k.
  params.is_causal = (window_size_left < 0) && (window_size_right == 0);
  if (window_size_left < 0 && window_size_right >= 0) {
    window_size_left = static_cast<int>(seqlen_k);
  }
  if (window_size_left >= 0 && window_size_right < 0) {
    window_size_right = static_cast<int>(seqlen_k);
  }
  params.window_size_left = window_size_left;
  params.window_size_right = window_size_right;

  params.is_seqlens_k_cumulative = true;
}

// Dispatches to the kernel instantiation matching dtype, head dim and causality.
// `elem_type` and `kHeadDim` are introduced by the FP16_SWITCH / HEADDIM_SWITCH
// macros from src/static_switch.h.
static void run_mha_fwd_block(Flash_fwd_params& params, cudaStream_t stream) {
  FP16_SWITCH(!params.is_bf16, [&] {
    HEADDIM_SWITCH(params.d, [&] {
      BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        run_mha_fwd_block_<elem_type, kHeadDim, Is_causal>(params, stream);
      });
    });
  });
}

// Block-sparse varlen attention forward (inference only).
//
// Returns {out, softmax_lse}: `out` has q's shape and dtype, `softmax_lse` is
// float32 [batch, num_heads, max_seqlen_q]. The kernel is enqueued on q's
// stream; no host synchronisation is performed here.
//
// Preconditions checked below: all tensors on GPU, q/k/v fp16 or bf16 with the
// same dtype, int32 index tensors, head_size a multiple of 8 and <= 256,
// num_heads divisible by num_heads_k, compute capability >= 8.0.
// `return_softmax` is accepted for interface compatibility and is unused.
std::vector<paddle::Tensor> BlockSparseAttnFwd(
    const paddle::Tensor& q,               // [total_q, num_heads, head_size]
    const paddle::Tensor& k,               // [total_k, num_heads_k, head_size]
    const paddle::Tensor& v,               // [total_k, num_heads_k, head_size]
    const paddle::Tensor& cu_seqlens_q,    // int32 [b+1]
    const paddle::Tensor& cu_seqlens_k,    // int32 [b+1]
    const paddle::Tensor& head_mask_type,  // int32 [num_heads]
    const paddle::optional<paddle::Tensor>& streaming_info,  // int32 [num_heads*2]
    const paddle::optional<paddle::Tensor>& base_blockmask,  // int32 [b, n_bs_h, sm/m, sk/n]
    int max_seqlen_q,
    int max_seqlen_k,
    float p_dropout,
    float softmax_scale,
    bool is_causal,
    int window_size_left,
    int window_size_right,
    int m_block_dim,
    int n_block_dim,
    bool exact_streaming,
    bool return_softmax) {
  (void)return_softmax;

  BSA_CHECK_GPU(q); BSA_CHECK_GPU(k); BSA_CHECK_GPU(v);
  BSA_CHECK_GPU(cu_seqlens_q); BSA_CHECK_GPU(cu_seqlens_k);
  BSA_CHECK_GPU(head_mask_type);

  auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
  (void)cc_minor;
  PD_CHECK(cc_major >= 8, "BlockSparseAttention requires Ampere (sm80) or newer");

  const auto q_dtype = q.dtype();
  PD_CHECK(q_dtype == paddle::DataType::FLOAT16 || q_dtype == paddle::DataType::BFLOAT16,
           "BlockSparseAttention only supports fp16/bf16");
  BSA_CHECK_DTYPE(k, q_dtype);
  BSA_CHECK_DTYPE(v, q_dtype);
  BSA_CHECK_DTYPE(cu_seqlens_q, paddle::DataType::INT32);
  BSA_CHECK_DTYPE(cu_seqlens_k, paddle::DataType::INT32);
  BSA_CHECK_DTYPE(head_mask_type, paddle::DataType::INT32);

  const bool has_blockmask = base_blockmask.is_initialized();
  const bool has_streaming = streaming_info.is_initialized();
  if (has_blockmask) {
    BSA_CHECK_GPU(base_blockmask.get());
    BSA_CHECK_DTYPE(base_blockmask.get(), paddle::DataType::INT32);
  }
  if (has_streaming) {
    BSA_CHECK_GPU(streaming_info.get());
    BSA_CHECK_DTYPE(streaming_info.get(), paddle::DataType::INT32);
  }
  if (has_blockmask || has_streaming) {
    // Both sparse modes consume the block dims, so validate them once.
    PD_CHECK(m_block_dim > 0 && m_block_dim % SPARSE_SIZE == 0,
             "m_block_dim must be a multiple of 128");
    PD_CHECK(n_block_dim > 0 && n_block_dim % SPARSE_SIZE == 0,
             "n_block_dim must be a multiple of 128");
  }

  const auto& q_shape = q.shape();
  const auto& k_shape = k.shape();
  const auto& v_shape = v.shape();
  PD_CHECK(q_shape.size() == 3, "q must be 3D [total_q, num_heads, head_size]");
  PD_CHECK(k_shape.size() == 3, "k must be 3D [total_k, num_heads_k, head_size]");
  PD_CHECK(v_shape == k_shape, "v must have the same shape as k");
  PD_CHECK(!cu_seqlens_q.shape().empty(), "cu_seqlens_q must be 1D [b+1]");
  PD_CHECK(cu_seqlens_k.shape() == cu_seqlens_q.shape(),
           "cu_seqlens_k must have the same shape as cu_seqlens_q");

  const int total_q = static_cast<int>(q_shape[0]);
  const int num_heads = static_cast<int>(q_shape[1]);
  const int head_size = static_cast<int>(q_shape[2]);
  const int num_heads_k = static_cast<int>(k_shape[1]);
  const int batch_size = static_cast<int>(cu_seqlens_q.shape()[0]) - 1;

  PD_CHECK(batch_size > 0, "batch size must be positive");
  PD_CHECK(k_shape[2] == q_shape[2], "k must have the same head_size as q");
  PD_CHECK(head_size <= 256, "head_size > 256 is not supported");
  PD_CHECK(head_size % 8 == 0, "head_size must be a multiple of 8");
  // Guard the modulus below against a zero divisor.
  PD_CHECK(num_heads_k > 0, "num_heads_k must be positive");
  PD_CHECK(num_heads % num_heads_k == 0, "num_heads must be divisible by num_heads_k");
  PD_CHECK(head_mask_type.numel() == num_heads,
           "head_mask_type must have one entry per query head");

  // A window at least as long as the key sequence is the same as no window.
  if (window_size_left >= max_seqlen_k) window_size_left = -1;
  if (window_size_right >= max_seqlen_k) window_size_right = -1;
  if (is_causal) window_size_right = 0;

  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
  const int seqlen_q_rounded = round_multiple(max_seqlen_q, SPARSE_SIZE);
  const int seqlen_k_rounded = round_multiple(max_seqlen_k, SPARSE_SIZE);

  auto out = paddle::empty(q_shape, q_dtype, q.place());
  auto softmax_lse = paddle::empty({batch_size, num_heads, max_seqlen_q},
                                   paddle::DataType::FLOAT32, q.place());

  Flash_fwd_params params;
  set_params_fprop(params,
                   batch_size, max_seqlen_q, max_seqlen_k,
                   seqlen_q_rounded, seqlen_k_rounded,
                   num_heads, num_heads_k, head_size, head_size_rounded,
                   q, k, v, out,
                   const_cast<void*>(cu_seqlens_q.data()),
                   const_cast<void*>(cu_seqlens_k.data()),
                   /*p_d=*/nullptr,
                   softmax_lse.data(),
                   p_dropout, softmax_scale,
                   is_causal ? -1 : window_size_left,
                   is_causal ? 0 : window_size_right);
  params.total_q = total_q;

  params.head_mask_type =
      static_cast<int*>(const_cast<void*>(head_mask_type.data()));

  if (has_blockmask) {
    params.blockmask =
        static_cast<int*>(const_cast<void*>(base_blockmask.get().data()));
    params.m_block_dim = m_block_dim;
    params.n_block_dim = n_block_dim;
    params.num_blocksparse_heads = static_cast<int>(base_blockmask.get().shape()[1]);
  } else {
    params.blockmask = nullptr;
  }

  if (has_streaming) {
    params.streaming_info =
        static_cast<int*>(const_cast<void*>(streaming_info.get().data()));
    params.is_exact_streaming = exact_streaming;
    params.m_block_dim = m_block_dim;
    params.n_block_dim = n_block_dim;
  } else {
    params.streaming_info = nullptr;
    params.is_exact_streaming = false;
  }

  // Inference: dropout disabled, but the kernel is handed an rng_state pointer
  // it may write through. That pointer is dereferenced on the device, so it
  // must be device memory — a host array here would fault as soon as the
  // kernel touches it. Two 64-bit slots: seed and offset.
  auto rng_state = paddle::empty({2}, paddle::DataType::INT64, q.place());
  params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data<int64_t>());

  cudaStream_t stream = q.stream();
  if (max_seqlen_k > 0) {
    run_mha_fwd_block(params, stream);
  } else {
    // No keys at all: the output is defined as zeros. The dtype was checked to
    // be fp16/bf16 above, so every element is 2 bytes (avoids
    // paddle::experimental::SizeOf, not available in all Paddle versions).
    // NOTE(review): softmax_lse is left uninitialised on this path.
    constexpr size_t kElemSize = 2;
    const cudaError_t memset_err =
        cudaMemsetAsync(out.data(), 0, out.numel() * kElemSize, stream);
    PD_CHECK(memset_err == cudaSuccess,
             "cudaMemsetAsync failed: ", cudaGetErrorString(memset_err));
  }

  return {out, softmax_lse};
}

}  // namespace FLASH_NAMESPACE

// ---- Paddle op registration --------------------------------------------------

// Static shape inference: `out` mirrors q; `softmax_lse` is [b, h, 1] where the
// last dimension is only a placeholder.
// NOTE(review): the real kernel allocates [b, h, max_seqlen_q]; max_seqlen_q is
// an attribute and is not available to this function as declared.
std::vector<std::vector<int64_t>> BlockSparseAttnFwdInferShape(
    const std::vector<int64_t>& q_shape,
    const std::vector<int64_t>& /*k*/,
    const std::vector<int64_t>& /*v*/,
    const std::vector<int64_t>& cu_q,
    const std::vector<int64_t>& /*cu_k*/,
    const std::vector<int64_t>& /*hmt*/) {
  const int64_t b = (cu_q.empty() ? 1 : cu_q[0] - 1);
  const int64_t h = q_shape.size() >= 2 ? q_shape[1] : 1;
  std::vector<int64_t> lse_shape{b, h, /*seqlen_q*/ 1};
  return {q_shape, lse_shape};
}

// Static dtype inference: `out` takes q's dtype, `softmax_lse` is float32.
std::vector<paddle::DataType> BlockSparseAttnFwdInferDtype(
    const paddle::DataType& q_dtype,
    const paddle::DataType& /*k*/,
    const paddle::DataType& /*v*/,
    const paddle::DataType& /*cu_q*/,
    const paddle::DataType& /*cu_k*/,
    const paddle::DataType& /*hmt*/) {
  return {q_dtype, paddle::DataType::FLOAT32};
}

PD_BUILD_OP(block_sparse_attn_fwd)
    .Inputs({"q", "k", "v",
             "cu_seqlens_q", "cu_seqlens_k", "head_mask_type",
             paddle::Optional("streaming_info"),
             paddle::Optional("base_blockmask")})
    .Outputs({"out", "softmax_lse"})
    .Attrs({"max_seqlen_q: int",
            "max_seqlen_k: int",
            "p_dropout: float",
            "softmax_scale: float",
            "is_causal: bool",
            "window_size_left: int",
            "window_size_right: int",
            "m_block_dim: int",
            "n_block_dim: int",
            "exact_streaming: bool",
            "return_softmax: bool"})
    .SetKernelFn(PD_KERNEL(FLASH_NAMESPACE::BlockSparseAttnFwd))
    .SetInferShapeFn(PD_INFER_SHAPE(BlockSparseAttnFwdInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(BlockSparseAttnFwdInferDtype));

// ===========================================================================
// File: custom_ops/gpu_ops/block_sparse_attn/setup.py
// (new file, mode 100644) — license header of the Python build script follows
// ===========================================================================
// # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
// #
// # Licensed under the Apache License, Version 2.0 (the "License");
// # you may not use this file except in compliance with the License.
// # You may obtain a copy of the License at
// #
// #     http://www.apache.org/licenses/LICENSE-2.0
// #
// # Unless required by applicable law or agreed to in writing, software
// # distributed under the License is distributed on an "AS IS" BASIS,
// # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// # See the License for the specific language governing permissions and
// # limitations under the License.
"""
Standalone build for Block-Sparse-Attention (BSA) Paddle custom op.

BSA bundles its own CUTLASS 3.3 in
    <Block-Sparse-Attention checkout>/csrc/cutlass/include
which has API conflicts with FastDeploy's newer CUTLASS in
    FastDeploy/custom_ops/third_party/cutlass
Since nvcc -I flags are global per compilation, we build BSA in its own
extension (independent .so) so that it sees ONLY BSA's CUTLASS headers.

Supported GPUs (auto-detected from CUDA toolkit version, see
``_build_gencode_flags`` below):
  - sm_80  : Ampere (A100, A800)            -- always emitted
  - sm_90  : Hopper (H100, H800)            -- CUDA >= 11.8
  - sm_100 : Blackwell (B100, B200, GB200)  -- CUDA >= 12.8
The BSA kernels themselves are sm_80-native (m16n8k16 mma, no wgmma /
tcgen05), so adding sm_90 / sm_100 only requires the gencode flags here;
no kernel rewrite is needed when migrating from H800 -> B200.

Usage:
    cd FastDeploy/custom_ops/gpu_ops/block_sparse_attn
    python setup.py build_ext --inplace   # local build for dev
    # or
    python setup.py install               # system install

Override architectures explicitly:
    BLOCK_SPARSE_ATTN_CUDA_ARCHS="80;90;100" python setup.py install

Output: block_sparse_attn_ops*.so containing the `block_sparse_attn_fwd` op.
"""
import glob
import os
import subprocess
from pathlib import Path

from packaging.version import Version, parse

from paddle.utils.cpp_extension import CUDAExtension, setup

# Minimum CUDA toolkit able to target each architecture. Architectures that
# are not listed are assumed to be supported by whatever toolkit is installed.
_MIN_CUDA_FOR_ARCH = {
    "90": Version("11.8"),
    "100": Version("12.8"),
}

# Toolkit version assumed when nvcc cannot be queried.
_FALLBACK_CUDA_VERSION = "12.0"


def _nvcc_version() -> Version:
    """Return CUDA toolkit version reported by ``nvcc -V`` (e.g. 12.8).

    nvcc is looked up under ``$CUDA_HOME``, then ``$CUDA_PATH``, then
    ``/usr/local/cuda``. If it cannot be run or its output cannot be parsed,
    version 12.0 is returned so that only sm_80 / sm_90 are targeted.
    """
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
    try:
        out = subprocess.check_output([f"{cuda_home}/bin/nvcc", "-V"], universal_newlines=True)
        tok = out.split()
        # Output contains "... release 12.8, V12.8.61": take the token after
        # "release" and drop the trailing comma.
        idx = tok.index("release") + 1
        return parse(tok[idx].split(",")[0])
    except (OSError, subprocess.CalledProcessError, ValueError, IndexError):
        # nvcc missing / failing (OSError, CalledProcessError) or output in an
        # unexpected format (ValueError covers a missing "release" token and
        # packaging's InvalidVersion; IndexError a truncated output).
        return parse(_FALLBACK_CUDA_VERSION)


def _build_gencode_flags() -> list[str]:
    """Build the ``-gencode`` flags, mirroring the upstream BSA repo's add_cuda_gencodes.

    - sm_80 always (Ampere baseline; kernels are sm80-native).
    - sm_90 (Hopper, e.g. H100/H800) when CUDA >= 11.8.
    - sm_100 (Blackwell, e.g. B100/B200) when CUDA >= 12.8.
        * On CUDA >= 12.9 use the family-specific arch=compute_100f.
    - PTX for the newest requested arch *that the toolkit can compile*, for
      forward compatibility.

    The requested set comes from ``BLOCK_SPARSE_ATTN_CUDA_ARCHS``
    (``;``-separated, default ``80;90;100``).

    BSA kernel sources are written with sm_80 baseline instructions
    (m16n8k16 mma, no wgmma / tcgen05), so they compile cleanly for
    sm_90 and sm_100 with these gencodes — no kernel rewrite needed.
    """
    archs = os.environ.get("BLOCK_SPARSE_ATTN_CUDA_ARCHS", "80;90;100").split(";")
    archs = {a.strip() for a in archs if a.strip()}

    cuda_ver = _nvcc_version()

    def _toolkit_supports(arch: str) -> bool:
        return cuda_ver >= _MIN_CUDA_FOR_ARCH.get(arch, Version("0"))

    flags: list[str] = []

    if "80" in archs:
        flags += ["-gencode", "arch=compute_80,code=sm_80"]
    if "90" in archs and _toolkit_supports("90"):
        flags += ["-gencode", "arch=compute_90,code=sm_90"]
    if "100" in archs and _toolkit_supports("100"):
        if cuda_ver >= Version("12.9"):
            # Blackwell family-specific (introduced in CUDA 12.9).
            flags += ["-gencode", "arch=compute_100f,code=sm_100"]
        else:
            flags += ["-gencode", "arch=compute_100,code=sm_100"]

    # Embed PTX of the newest selected arch so future GPUs JIT-compile. Only
    # archs the installed toolkit can actually target are candidates;
    # otherwise nvcc would reject the flag (e.g. compute_90 on CUDA < 11.8).
    numeric = [a for a in archs if a.isdigit()]
    if numeric:
        usable = [a for a in numeric if _toolkit_supports(a)]
        # Nothing requested is buildable -> fall back to the sm_80 baseline.
        newest = max(usable, key=int) if usable else "80"
        flags += ["-gencode", f"arch=compute_{newest},code=compute_{newest}"]
    return flags


THIS_DIR = Path(os.path.dirname(os.path.abspath(__file__)))

# BSA upstream repo provides its own CUTLASS 3.3 (incompatible with newer
# CUTLASS used elsewhere in FastDeploy). Resolve relative to this file via
# the existing `src` symlink target, so the build keeps working if the BSA
# repo lives at any path on disk.
BSA_REPO_ROOT = (THIS_DIR / "src").resolve().parents[1]  # .../Block-Sparse-Attention/csrc
BSA_CUTLASS_INCLUDE = BSA_REPO_ROOT / "cutlass" / "include"
# Raised explicitly instead of `assert`, which is stripped under `python -O`.
if not BSA_CUTLASS_INCLUDE.exists():
    raise FileNotFoundError(
        f"BSA bundled CUTLASS not found at {BSA_CUTLASS_INCLUDE}; "
        f"expected the upstream Block-Sparse-Attention checkout to include csrc/cutlass/include."
    )

# --- Sources: wrapper + 12 forward kernels (skip backward) -------------------
sources = [str(THIS_DIR / "block_sparse_attn_fwd.cu")]
sources += [
    s for s in sorted(glob.glob(str(THIS_DIR / "src" / "*.cu")))
    if "flash_bwd_" not in os.path.basename(s)
]

# --- Compile flags -----------------------------------------------------------
nvcc_flags = [
    "-O3",
    "-std=c++17",
    # Enable __half/__bfloat16 native arithmetic ops used by BSA kernels.
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
    "--expt-relaxed-constexpr",
    "--expt-extended-lambda",
    "--use_fast_math",
    # Paddle's compat layer provides C10_CUDA_CHECK / CompatException but
    # does NOT define C10_CUDA_KERNEL_LAUNCH_CHECK. Define it in terms of
    # C10_CUDA_CHECK so BSA's flash_fwd_launch_template.h compiles AND a
    # failed kernel launch is still reported; defining it as a no-op would
    # silently swallow launch-configuration errors.
    "-DC10_CUDA_KERNEL_LAUNCH_CHECK()=C10_CUDA_CHECK(cudaGetLastError())",
    "-DENABLE_BF16",
    # GPU arch gencodes. sm_80 (A100), sm_90 (H100/H800), sm_100 (B100/B200)
    # are emitted automatically based on the installed CUDA toolkit. Override
    # with `BLOCK_SPARSE_ATTN_CUDA_ARCHS="80;90;100"` if needed.
    *_build_gencode_flags(),
]

cxx_flags = ["-O3", "-std=c++17"]

include_dirs = [
    str(THIS_DIR),              # for at_shim/, src/, headers
    str(THIS_DIR / "src"),
    str(THIS_DIR / "at_shim"),  # torch -> paddle stubs
    str(BSA_CUTLASS_INCLUDE),   # BSA bundled CUTLASS 3.3
]

setup(
    name="block_sparse_attn_ops",
    ext_modules=CUDAExtension(
        sources=sources,
        include_dirs=include_dirs,
        extra_compile_args={
            "cxx": cxx_flags,
            "nvcc": nvcc_flags,
        },
    ),
)

# ===========================================================================
# File: fastdeploy/model_executor/layers/attention/elastic_attn_backend.py
# (new file, mode 100644) — start of that module's docstring follows
# ===========================================================================
# """``Qwen3ElasticAttentionBackend`` -- prefill goes through Block-Sparse-Attention,
# decode keeps ``append_attention``.
#
# This is a thin subclass of :class:`FlashAttentionBackend` that replaces the
# prefill leg of :py:meth:`forward_mixed` with the Elastic-Attention path:
#
#     ``gqa_rope_write_cache`` → router decision (per layer) → repeat_interleave →
#     ``Xattention_prefill_dim4`` (Triton + BSA).
#
# Decode leg, ``merge_prefill_decode_output`` and all of ``init_attention_metadata``
# are reused unchanged from the parent, so chunked prefill / continuous batching
# work with no extra wiring.
+""" + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING + +import paddle + +from fastdeploy.model_executor.layers.attention.attention import Attention +from fastdeploy.model_executor.layers.attention.flash_attn_backend import ( + FLASH_ATTN_VERSION, + FlashAttentionBackend, +) +from fastdeploy.model_executor.layers.attention.ops import ( + append_attention, + get_attn_mask_q, + get_block_shape_and_split_kv_block, + gqa_rope_write_cache, + init_signal_layerwise, + pre_cache_len_concat, +) +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output +else: # pragma: no cover + merge_prefill_decode_output = None + +if TYPE_CHECKING: + from fastdeploy.model_executor.forward_meta import ForwardMeta + + +class Qwen3ElasticAttentionBackend(FlashAttentionBackend): + """Elastic-Attention backend for PawQwen3. + + The prefill leg is overridden to call ``Xattention_prefill_dim4`` (Triton + estimator + Block-Sparse-Attention CUDA kernel) rather than dense + ``flash_attn_func``. Decode leg, KV-cache shape and PD signals are inherited. + """ + + def forward_mixed( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, + layer: Attention, + forward_meta: "ForwardMeta", + ): + # Lazy imports to avoid circular import + to keep BSA-build optional + # for the rest of the package. Note: `kernels/` and `utils.py` live + # under `fastdeploy.model_executor.models.qwen3_elastic`, NOT next to + # this backend (which sits under `layers/attention/`), so we must use + # absolute imports here. 
+ from fastdeploy.model_executor.models.qwen3_elastic.kernels import ( + Xattention_prefill_dim4, + ) + from fastdeploy.model_executor.models.qwen3_elastic.utils import ( + ctx_q_pool, + derive_head_mask_type, + ) + + metadata = self.attention_metadata + + # ---- Same as parent: PD signals + cache addr + layer-0 metadata ---- + if self.pd_disaggregation_mode == "per_query": + metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( + metadata.kv_signal_metadata, + layer.layer_id + self.start_layer_index, + ) + + if int(os.getenv("USE_TBO", "0")) == 1: + if hasattr(forward_meta, "tbo_microbatch_id"): + if forward_meta.tbo_microbatch_id == 0: + os.environ["FLAGS_fmt_write_cache_completed_signal"] = "0" + elif forward_meta.tbo_microbatch_id == 1: + os.environ["FLAGS_fmt_write_cache_completed_signal"] = "1" + + norm_after_rope_in_kernel = not getattr(layer, "qk_norm_before_rope", False) + q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None + k_norm_weight = getattr(layer, "k_norm_weight", None) if norm_after_rope_in_kernel else None + + cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none") + if cache_quant_type_str == "block_wise_fp8": + cache_k = forward_meta.caches[4 * layer.layer_id] + cache_v = forward_meta.caches[4 * layer.layer_id + 1] + cache_k_scales = forward_meta.caches[4 * layer.layer_id + 2] + cache_v_scales = forward_meta.caches[4 * layer.layer_id + 3] + else: + cache_k = forward_meta.caches[2 * layer.layer_id] + cache_v = forward_meta.caches[2 * layer.layer_id + 1] + cache_k_scales = getattr(layer, "cache_k_scale", None) + cache_v_scales = getattr(layer, "cache_v_scale", None) + + if layer.layer_id == 0: + get_block_shape_and_split_kv_block( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.decoder_batch_ids, + forward_meta.decoder_tile_ids_per_batch, + forward_meta.decoder_num_blocks_cpu, + 
forward_meta.decoder_num_blocks_device, + forward_meta.decoder_chunk_size_device, + forward_meta.max_len_tensor_cpu, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + self.encoder_block_shape_q, + self.decoder_block_shape_q, + self.group_size, + self.block_size, + ) + + if forward_meta.max_len_tensor_cpu[1].item() > 0: + forward_meta.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu) + forward_meta.max_len_tensor_cpu_decoder[1] = 0 + + ( + forward_meta.cu_seqlens_k, + forward_meta.pre_cache_batch_ids, + forward_meta.pre_cache_tile_ids_per_batch, + forward_meta.pre_cache_num_blocks_cpu, + forward_meta.kv_token_num_cpu, + ) = pre_cache_len_concat( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.max_len_tensor_cpu[2], + self.block_size, + ) + # Elastic prefill path doesn't use FA4 / attn_mask_q. + forward_meta.attn_mask_q = None + + use_fa_do_prefill = forward_meta.max_len_tensor_cpu[1].item() > 0 + + # ----------------------- prefill leg : BSA ----------------------- + if use_fa_do_prefill: + # ---- Extract pre-RoPE K for router (BEFORE gqa_rope_write_cache) ---- + # The router MLP (mask_allocator) was trained on K post-q_norm/k_norm + # but PRE-RoPE (see reference modeling_flash_qwen.py L1582-1650: + # q_norm/k_norm -> router(k) -> RoPE). Here ``qkv`` is already + # post-norm because Qwen3ElasticAttention.forward applies + # ``self.qk_norm(qkv_out)`` before calling ``self.attn``. We slice + # K out of the fused QKV tensor while it is still pre-RoPE. + # qkv layout: [T, q_size + 2*kv_size] with Q | K | V contiguous. 
+ T_pre = qkv.shape[0] + q_size_local = layer.num_heads * layer.head_dim + kv_size_local = layer.kv_num_heads * layer.head_dim + k_pre_rope = qkv[:, q_size_local : q_size_local + kv_size_local].reshape( + [T_pre, layer.kv_num_heads, layer.head_dim] + ) + + q, k, v, _ = gqa_rope_write_cache( + qkv, + cache_k, + cache_v, + forward_meta.cu_seqlens_q, + forward_meta.cu_seqlens_k, + forward_meta.rotary_embs, + forward_meta.seq_lens_this_time, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.batch_id_per_token, + forward_meta.block_tables, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + forward_meta.pre_cache_batch_ids, + forward_meta.pre_cache_tile_ids_per_batch, + forward_meta.pre_cache_num_blocks_cpu, + q_norm_weight, + k_norm_weight, + cache_k_scales, + cache_v_scales, + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + metadata.kv_signal_data_list[layer.layer_id], + forward_meta.kv_token_num_cpu[0].item(), + self.max_seq_len, + getattr(layer, "rms_norm_eps", 1e-6), + layer.use_neox_rotary_style, + getattr(layer, "cache_quant_type_str", "none"), + self.rope_3d, + ) + # q: [total_T, num_heads, head_dim]; k, v: [total_T, kv_num_heads, head_dim] + T_total = q.shape[0] + assert T_total == T_pre, (T_total, T_pre) + + # GQA expand to match Q heads (BSA requires H_q == H_kv) + k_full = paddle.repeat_interleave(k, self.group_size, axis=1) + v_full = paddle.repeat_interleave(v, self.group_size, axis=1) + + # Number of prefill segments. ``cu_seqlens_k`` has shape [B_prefill+1]. + B_prefill = int(forward_meta.cu_seqlens_k.shape[0]) - 1 + prefill_cu = forward_meta.cu_seqlens_q[: B_prefill + 1] + + # Process each prefill segment independently: per-segment router + + # BSA call. 
This is required for varlen profile-runs and multi-seq + # batching; the BS=1 production path collapses to a single iteration. + seg_outs = [] + for i in range(B_prefill): + s = int(prefill_cu[i].item()) + e = int(prefill_cu[i + 1].item()) + Ti = e - s + if Ti <= 0: + continue + + seg_k_pre = k_pre_rope[s:e] # [Ti, H_kv, D] + seg_pool = ctx_q_pool(seg_k_pre) # [1, H_kv, D] + seg_z = layer.mask_allocator(seg_pool).reshape([-1]) + seg_mask = derive_head_mask_type( + seg_z, + retrieval_mode=layer.retrieval_mode, + toggle_type=layer.toggle_type, + group_size=self.group_size, + ) + + seg_q = q[s:e].transpose([1, 0, 2]).unsqueeze(0) # [1, H, Ti, D] + seg_k = k_full[s:e].transpose([1, 0, 2]).unsqueeze(0) + seg_v = v_full[s:e].transpose([1, 0, 2]).unsqueeze(0) + + seg_out = Xattention_prefill_dim4( + seg_q, seg_k, seg_v, + stride=layer.xattn_stride, + cu_seq_lens=paddle.to_tensor([0, Ti], dtype="int32"), + norm=layer.xattn_norm, + threshold=layer.xattn_threshold, + block_size=layer.block_size, + use_triton=True, + head_mask_type=seg_mask, + sink_num=layer.sink_blocks, + local_num=layer.local_blocks, + causal=True, + ) # [1, H, Ti, D] + seg_outs.append( + seg_out.transpose([0, 2, 1, 3]).reshape([Ti, self.attn_outputsize_tp]) + ) + + # Cache the *last* segment's router decision for downstream + # use (debug / metric). Sufficient for BS=1 production. + if hasattr(layer, "_z_kv_cache"): + layer._z_kv_cache.set_value(seg_z) + if hasattr(layer, "_head_mask_type_cache"): + layer._head_mask_type_cache.set_value(seg_mask) + + # # ---------- BEGIN TEMP: dump per-layer router decisions ---------- + # # Triggered only when FD_ELASTIC_DUMP_ROUTER points to a JSONL + # # path. Each prefill segment appends one record. Remove this + # # block (and the END TEMP marker below) when no longer needed. + # _dump_path = os.getenv("FD_ELASTIC_DUMP_ROUTER", "") + # if _dump_path: + # # Skip FD's profile-run / warmup prefill (which has seg_len + # # ~= max_num_batched_tokens-1, e.g. 
65534 for max_model_len + # # =65536). The runner sets FD_ELASTIC_DUMP_SKIP_SEGLEN_GE + # # to a threshold above the real prompt but below the + # # warmup length. + # _skip_ge = os.getenv("FD_ELASTIC_DUMP_SKIP_SEGLEN_GE", "") + # _skip = bool(_skip_ge) and int(Ti) >= int(_skip_ge) + # if not _skip: + # import json as _json + # _record = { + # "layer_id": int(getattr(layer, "layer_id", -1)), + # "seg_idx": int(i), + # "seg_len": int(Ti), + # "z_kv": seg_z.tolist(), + # "head_mask_type": seg_mask.tolist(), + # "retrieval_mode": layer.retrieval_mode, + # "toggle_type": layer.toggle_type, + # } + # with open(_dump_path, "a", encoding="utf-8") as _df: + # _df.write(_json.dumps(_record) + "\n") + # # ---------- END TEMP: dump per-layer router decisions ---------- + + res_encoder = paddle.concat(seg_outs, axis=0) if seg_outs else paddle.zeros( + [0, self.attn_outputsize_tp], dtype=q.dtype + ) + + # ----------------------- decode leg : append_attention ----------------------- + res_decoder = append_attention( + qkv, cache_k, cache_v, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + forward_meta.block_tables, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + forward_meta.decoder_batch_ids, + forward_meta.decoder_tile_ids_per_batch, + forward_meta.decoder_num_blocks_cpu, + forward_meta.max_len_tensor_cpu_decoder if use_fa_do_prefill else forward_meta.max_len_tensor_cpu, + forward_meta.rotary_embs, + forward_meta.attn_mask, + layer.qkv_bias, + layer.qkv_scale, + cache_k_scales, + cache_v_scales, + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + layer.linear_shift, + 
layer.linear_smooth, + forward_meta.attn_mask_offsets, + metadata.kv_signal_data_list[layer.layer_id], + q_norm_weight, + k_norm_weight, + getattr(layer, "sinks", None), + getattr(layer, "rms_norm_eps", 1e-6), + metadata._fuse_kernel_compute_dtype, + getattr(layer, "cache_quant_type_str", "none"), + layer.use_neox_rotary_style, + self.rope_3d, + self.max_seq_len, + getattr(layer, "quant_max_bound", 0.0), + getattr(layer, "quant_min_bound", 0.0), + getattr(layer, "out_scale", -1.0), + self.encoder_block_shape_q, + self.decoder_block_shape_q, + self.max_partition_size, + self.max_seq_len, + self.speculate_max_draft_token_num + 1, + self.causal, + self.speculative_method is not None, + ) + + if use_fa_do_prefill: + merge_prefill_decode_output( + res_encoder, res_decoder, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.cu_seqlens_q, + self.num_heads, + self.head_dim, + self.speculate_max_draft_token_num + 1, + ) + return res_encoder + return res_decoder diff --git a/fastdeploy/model_executor/models/qwen3_elastic/__init__.py b/fastdeploy/model_executor/models/qwen3_elastic/__init__.py new file mode 100644 index 00000000000..94b9058ac41 --- /dev/null +++ b/fastdeploy/model_executor/models/qwen3_elastic/__init__.py @@ -0,0 +1,169 @@ +""" +Elastic-Attention integration for Qwen3 on FastDeploy. + +Importing this package registers the ``PawQwen3ForCausalLM`` architecture +in :class:`fastdeploy.model_executor.models.model_base.ModelRegistry`. + +It also patches ``attention_selecter.get_attention_backend`` / +``_get_attn_backend`` so that :class:`Qwen3ElasticAttentionBackend` is +returned **only when the caller's fd_config corresponds to a PawQwen3 +model**. For any other architecture the selector falls through to its +original behaviour, which means dense Qwen3 (or any other FD model) is +unaffected even though FastDeploy's ``auto_models_registry`` always +imports this package on startup. 
+""" + +from .modeling_elastic_qwen3 import ( # noqa: F401 + AttentionRouter, + PawQwen3ForCausalLM, + Qwen3ElasticAttention, + Qwen3ElasticDecoderLayer, + Qwen3ElasticModel, +) + +# ---- Architecture-aware elastic attention backend patch ---- +# FastDeploy's default dispatch only consults +# ``current_platform.get_attention_backend_cls`` (which on CUDA returns +# ``AppendAttentionBackend``) and ignores the model class's +# ``_get_attn_backend_cls``. Without this patch the elastic backend is +# dead code and PawQwen3 silently runs as plain dense Qwen3. +# +# However, ``fastdeploy.model_executor.models.__init__.auto_models_registry`` +# walks every model package and triggers this package's import for ANY +# launch (including dense Qwen3). A blind global patch would therefore +# return the elastic backend for dense models too -- whose Attention +# layers lack ``mask_allocator`` -- causing AttributeError. +# +# Solution: patch the selector but make it architecture-aware. We walk +# the call stack to find the caller's ``self.fd_config`` (every call site +# in FD lives on an object that owns ``fd_config``), and only return the +# elastic backend when ``architectures[0] == "PawQwen3ForCausalLM"``. +# Other models fall through to the original selector unchanged. 
# Handles to the unpatched selector entry points; the patched versions below
# delegate to these whenever the caller is not a PawQwen3 model.
_orig_get_attn_backend = _selecter._get_attn_backend
_orig_get_attention_backend = _selecter.get_attention_backend


def _caller_arch():
    """Return ``architectures[0]`` of the nearest caller that owns an ``fd_config``.

    The search starts two frames above this helper (frame 0 is this function,
    frame 1 is the patched selector that called it) and walks outward. For each
    frame the local variable ``self`` is inspected; the first one whose
    ``self.fd_config.model_config.architectures`` is non-empty decides the
    result. ``None`` is returned when no frame qualifies.
    """
    current = _sys._getframe(2)
    while current is not None:
        owner = current.f_locals.get("self")
        holder = None if owner is None else getattr(owner, "fd_config", None)
        if holder is not None:
            names = getattr(getattr(holder, "model_config", None), "architectures", None)
            if names:
                return names[0]
        current = current.f_back
    return None


def _patched_get_attn_backend(selected_backend=None):
    """Stand-in for ``attention_selecter._get_attn_backend``.

    Returns the elastic backend class when the calling object's architecture
    is ``PawQwen3ForCausalLM``; otherwise forwards ``selected_backend`` to the
    original selector.
    """
    # ``_caller_arch`` must be invoked directly from this body: it skips a
    # fixed number of frames to reach whoever called the selector.
    wants_elastic = _caller_arch() == "PawQwen3ForCausalLM"
    if not wants_elastic:
        return _orig_get_attn_backend(selected_backend)
    return _ElasticBackend


def _patched_get_attention_backend():
    """Stand-in for ``attention_selecter.get_attention_backend``.

    Same dispatch rule as ``_patched_get_attn_backend``, for the
    argument-less public entry point.
    """
    wants_elastic = _caller_arch() == "PawQwen3ForCausalLM"
    if not wants_elastic:
        return _orig_get_attention_backend()
    return _ElasticBackend


# The original ``_get_attn_backend`` may expose ``cache_clear`` (e.g. when it
# is wrapped by a caching decorator); clear it if so, ignore it otherwise.
try:
    _orig_get_attn_backend.cache_clear()
except AttributeError:
    pass

# Install the architecture-aware selectors on the module.
_selecter._get_attn_backend = _patched_get_attn_backend
_selecter.get_attention_backend = _patched_get_attention_backend
# Original factory, kept so non-PawQwen3 models (and the non-YaRN fallback)
# can still reach the upstream implementation.
_orig_get_rope_impl = _rope_mod.get_rope_impl


def _is_pawqwen3(model_config):
    """Tell whether ``architectures[0]`` contains ``"Qwen"`` without starting with it.

    Returns ``False`` when ``model_config`` has no (or an empty)
    ``architectures`` attribute.
    """
    names = getattr(model_config, "architectures", None)
    if not names:
        return False
    first = names[0]
    return "Qwen" in first and not first.startswith("Qwen")


def _yarn_rope_scaling(model_config):
    """Return the ``rope_scaling`` dict when it describes YaRN, else ``None``.

    The scaling kind is read from ``rope_type`` and, if that is missing or
    falsy, from ``type``.
    """
    scaling = getattr(model_config, "rope_scaling", None)
    if isinstance(scaling, dict):
        kind = scaling.get("rope_type") or scaling.get("type")
        if kind == "yarn":
            return scaling
    return None


def _patched_get_rope_impl(rotary_dim, base, position_ids, model_config=None, partial_rotary_factor=1):
    """Architecture-aware replacement for ``rotary_embedding.get_rope_impl``.

    * Architectures rejected by ``_is_pawqwen3`` go straight to the original
      factory with the arguments unchanged.
    * Accepted architectures with a YaRN ``rope_scaling`` get their embedding
      from ``GptOssScalingRotaryEmbedding`` (neox rotary style), configured
      from ``model_config.head_dim``, ``model_config.rope_theta`` and the
      scaling dict; ``rotary_dim``, ``base`` and ``partial_rotary_factor`` are
      not used on this path.
    * Accepted architectures without YaRN call the original factory while
      ``architectures[0]`` is temporarily prefixed with ``"Qwen3"``; the
      original name is restored even if the call raises.
    """
    if not _is_pawqwen3(model_config):
        return _orig_get_rope_impl(rotary_dim, base, position_ids, model_config, partial_rotary_factor)

    yarn_cfg = _yarn_rope_scaling(model_config)
    if yarn_cfg is None:
        # Temporarily rename the architecture for the upstream factory's
        # name-based dispatch, then restore the real name.
        saved_name = model_config.architectures[0]
        try:
            model_config.architectures[0] = "Qwen3" + saved_name
            return _orig_get_rope_impl(rotary_dim, base, position_ids, model_config, partial_rotary_factor)
        finally:
            model_config.architectures[0] = saved_name

    # Build the YaRN cos/sin cache on the fly from the checkpoint's settings.
    embedding = _rope_mod.GptOssScalingRotaryEmbedding(
        rotary_dim=model_config.head_dim,
        base=model_config.rope_theta,
        original_max_position_embeddings=int(yarn_cfg["original_max_position_embeddings"]),
        scale=float(yarn_cfg["factor"]),
        mscale=float(yarn_cfg.get("mscale", 1.0)),
        attn_factor=float(yarn_cfg.get("attn_factor", 1.0)),
        beta_fast=int(yarn_cfg.get("beta_fast", 32)),
        beta_slow=int(yarn_cfg.get("beta_slow", 1)),
        extrapolation_factor=float(yarn_cfg.get("extrapolation_factor", 1.0)),
        use_neox_rotary_style=True,
    )
    return embedding(position_ids)


# Install the patched factory on the rotary-embedding module.
_rope_mod.get_rope_impl = _patched_get_rope_impl
# model_config attribute name -> (checkpoint field name, default value)
ELASTIC_CONFIG_FIELDS = {
    "local_window_size": ("local_window_size", 2048),
    "sink_size": ("sink_size", 128),
    "toggle_type": ("toggle_type", "xattn"),
    "retrieval_mode": ("retrieval_mode", "full"),
    "enable_ada_sparsity": ("enable_ada_sparsity", True),
    "pooling_mode": ("pooling_mode", "ctx_q"),
    "use_softmax": ("use_softmax", True),
    "xattn_stride": ("xattn_stride", 16),
    "xattn_threshold": ("xattn_threshold", 0.9),
    "xattn_norm": ("xattn_norm", 1),
    "block_size": ("block_size", 128),
}

# Checkpoint fields used only during training; the runtime ignores them.
ELASTIC_TRAIN_ONLY = {
    "enable_lambda_task",
    "enable_layerwise_sparsity",
    "disable_linear_regularization_term",
    "layerwise_sparsity_first",
    "layerwise_sparsity_last",
    "layerwise_sparsity_pattern",
    "erank_analysis_path",
    "suggested_sparsity",
    "suggested_threshold",
    "topk_k",
    "triangle_n_last",
    "use_task_emb_for_mask",
    "pooling_seq",
    "max_window_layers",
}


def populate_elastic_fields(model_config) -> None:
    """Copy the elastic-attention settings onto ``model_config``.

    Values are looked up on ``model_config.pretrained_config`` when that
    attribute exists and is truthy, otherwise on ``model_config`` itself.
    Attributes already present on ``model_config`` are left untouched, and
    missing checkpoint fields fall back to the defaults in
    ``ELASTIC_CONFIG_FIELDS``, so calling this repeatedly is harmless.
    """
    source = getattr(model_config, "pretrained_config", None)
    if not source:
        source = model_config
    for target_name in ELASTIC_CONFIG_FIELDS:
        if not hasattr(model_config, target_name):
            ckpt_name, fallback = ELASTIC_CONFIG_FIELDS[target_name]
            setattr(model_config, target_name, getattr(source, ckpt_name, fallback))
def _import_bsa():
    """Import and return the ``block_sparse_attn_fwd`` custom op on demand.

    The import is deferred to call time so that this package can still be
    imported when the standalone ``block_sparse_attn_ops`` extension has not
    been built. Any import failure is re-raised as ``RuntimeError`` with build
    instructions.
    """
    try:
        # Standalone Paddle extension built by
        # custom_ops/gpu_ops/block_sparse_attn/setup.py.
        from block_sparse_attn_ops import block_sparse_attn_fwd as fwd_op
    except Exception as err:  # pragma: no cover - depends on build
        raise RuntimeError(
            "block_sparse_attn_fwd custom op not available. Build & install it via "
            "`cd FastDeploy/custom_ops/gpu_ops/block_sparse_attn && python setup.py install`."
        ) from err
    return fwd_op


def _replace_ones_with_count(head_mask_type: paddle.Tensor):
    """Number the entries equal to 1 as 1, 2, 3, ... from left to right.

    Entries that are not 1 keep their value. When there is no entry equal to
    1 the input tensor is returned unchanged.

    Returns: (renumbered_head_mask_type, number_of_entries_equal_to_1)
    """
    is_sparse = head_mask_type == 1
    sparse_total = int(is_sparse.sum().item())
    if sparse_total == 0:
        return head_mask_type, 0
    is_sparse_i32 = is_sparse.astype("int32")
    # Running count is non-zero only where the original value was 1.
    running = paddle.cumsum(is_sparse_i32, axis=-1).astype("int32") * is_sparse_i32
    return paddle.where(is_sparse, running, head_mask_type), sparse_total


def _convert_blockmask_row_reverse(blockmask: paddle.Tensor) -> paddle.Tensor:
    """Turn a boolean block mask into descending lists of selected indices.

    Along the last axis every row becomes the indices of its True entries in
    descending order, followed by -1 padding for the False entries. The
    result is a contiguous int32 tensor with the input's shape.
    """
    as_int = blockmask.astype("int32")
    # Stable ascending sort of 0/1 values: unselected entries come first,
    # selected ones follow in ascending index order.
    order = paddle.argsort(as_int, axis=-1, stable=True, descending=False)
    ordered = paddle.sort(as_int, axis=-1, stable=True, descending=False)
    padding = paddle.full_like(order, -1)
    kept = paddle.where(ordered == 0, padding, order)
    # Reversing puts the largest selected index first and the -1s last.
    return paddle.flip(kept, axis=[-1]).astype("int32").contiguous()


@paddle.no_grad()
def block_sparse_attn_paddle(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    cu_seqlens_k: paddle.Tensor,
    head_mask_type: paddle.Tensor,
    streaming_info,
    base_blockmask: paddle.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    p_dropout: float = 0.0,
    softmax_scale: float | None = None,
    is_causal: bool = True,
    window_size_left: int = -1,
    window_size_right: int = -1,
    m_block_dim: int = 128,
    n_block_dim: int = 128,
    exact_streaming: bool = False,
    deterministic: bool = True,
    return_softmax: bool = False,
):
    """Forward-only wrapper around the block-sparse attention custom op.

    Before the op is called:
      * ``softmax_scale`` defaults to ``q.shape[-1] ** -0.5``;
      * entries of ``head_mask_type`` equal to 1 are renumbered 1, 2, 3, ...;
      * a non-``None`` ``base_blockmask`` is converted from boolean form to
        descending index lists padded with -1;
      * ``window_size_right`` is forced to 0 when ``is_causal`` is true;
      * ``q``/``k``/``v`` are made contiguous if they are not already.

    ``deterministic`` is accepted for signature compatibility but is not
    forwarded to the op by this wrapper.

    Returns the op's first output when it returns a list/tuple, otherwise the
    op's return value as-is.
    """

    def _dense(tensor):
        # Copy only when the tensor is not already contiguous.
        return tensor if tensor.is_contiguous() else tensor.contiguous()

    if softmax_scale is None:
        softmax_scale = float(q.shape[-1]) ** -0.5
    fwd_op = _import_bsa()

    head_mask_type, _ = _replace_ones_with_count(head_mask_type)

    if base_blockmask is not None:
        base_blockmask = _convert_blockmask_row_reverse(base_blockmask)

    if is_causal:
        window_size_right = 0

    result = fwd_op(
        _dense(q),
        _dense(k),
        _dense(v),
        cu_seqlens_q,
        cu_seqlens_k,
        head_mask_type,
        streaming_info,
        base_blockmask,
        int(max_seqlen_q),
        int(max_seqlen_k),
        float(p_dropout),
        float(softmax_scale),
        bool(is_causal),
        int(window_size_left),
        int(window_size_right),
        int(m_block_dim),
        int(n_block_dim),
        bool(exact_streaming),
        bool(return_softmax),
    )
    # Only the attention output is needed by callers.
    return result[0] if isinstance(result, (list, tuple)) else result
def find_blocks_chunked(
    input_tensor: paddle.Tensor,
    current_index: int,
    threshold,
    num_to_choose,
    decoding: bool,
    mode: str = "both",
    causal: bool = True,
) -> paddle.Tensor:
    """Threshold-cumulative block selector.

    For each row along the last axis, entries are visited in descending order
    and an entry is selected while the mass accumulated *before* it is still
    below ``threshold`` times the row total. With ``causal=True`` column 0 and
    the diagonal columns starting at ``current_index`` are forced on first and
    their mass counts towards the target.

    Args:
        input_tensor: [B, H, Qchunk, Kblocks] attention sums per block.
        current_index: index of the first query block w.r.t. K.
        threshold: float in (0,1] -- min cumulative mass to keep. May also be
            a ``paddle.Tensor``, which is broadcast over batch and query
            chunk (presumably one value per head -- TODO confirm).
        num_to_choose: alternative to threshold (unsupported here).
        decoding: True if running decode path.
        mode: 'both' / 'prefill' / 'decode'.
        causal: apply causal block mask.

    Returns:
        bool tensor [B, H, Qchunk, Kblocks]. An all-True mask is returned
        early for ``mode='prefill'`` with ``decoding=True`` and for
        ``mode='decode'`` with ``decoding=False``.

    Raises:
        AssertionError: if both ``threshold`` and ``num_to_choose`` are set.
        NotImplementedError: if ``threshold`` is None.
    """
    assert threshold is None or num_to_choose is None
    batch_size, head_num, chunk_num, block_num = input_tensor.shape

    # Mode/phase mismatch: nothing to prune, keep every block.
    if mode == "prefill" and decoding:
        return paddle.ones_like(input_tensor, dtype="bool")
    if mode == "decode" and not decoding:
        mask = paddle.ones_like(input_tensor, dtype="bool")
        return mask

    input_tensor = input_tensor.astype("float32")

    if threshold is None:
        raise NotImplementedError("block num chunk prefill not implemented")

    # Target mass per row = threshold * row total.
    total_sum = input_tensor.sum(axis=-1, keepdim=True)
    if isinstance(threshold, paddle.Tensor):
        thr = threshold.astype("float32")
        required_sum = total_sum * thr.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(
            [batch_size, head_num, chunk_num, 1]
        )
    else:
        required_sum = total_sum * float(threshold)

    if causal:
        mask = paddle.zeros_like(input_tensor, dtype="bool")
        # Always keep block 0 (sink) and the diagonal block.
        mask[:, :, :, 0] = True
        eye = paddle.eye(chunk_num, dtype="int32").astype("bool").unsqueeze(0).unsqueeze(0)
        eye = eye.expand([batch_size, head_num, chunk_num, chunk_num])
        # set mask[:, :, :, current_index : current_index + chunk_num] = eye
        diag_slice = paddle.zeros_like(input_tensor, dtype="bool")
        diag_slice[:, :, :, current_index : current_index + chunk_num] = eye
        mask = mask | diag_slice

        # zero-out mass that's already covered by `mask`, then sort the rest.
        other_values = paddle.where(mask, paddle.zeros_like(input_tensor), input_tensor)
        sorted_values = paddle.sort(other_values, axis=-1, descending=True)

        # Prepend a column of zeros and the mass already retained.
        # The last two sorted columns are dropped so the width stays block_num.
        retained_mass = paddle.where(mask, input_tensor, paddle.zeros_like(input_tensor)).sum(
            axis=-1, keepdim=True
        )
        zeros_col = paddle.zeros(
            [batch_size, head_num, chunk_num, 1], dtype="float32"
        )
        sorted_values = paddle.concat(
            [zeros_col, retained_mass, sorted_values[:, :, :, :-2]], axis=-1
        )

        # Argsort indices: force already-selected entries to the front.
        # NOTE(review): the two leading slots above appear to pair with the
        # (up to two) forced entries per row -- confirm against upstream.
        boosted = paddle.where(mask, 100000.0 * (1.0 + input_tensor), input_tensor)
        index = paddle.argsort(boosted, axis=-1, descending=True)

        # Exclusive prefix sum: mass accumulated before each sorted position.
        cumulative_sum = paddle.concat(
            [zeros_col, sorted_values[:, :, :, :-1]], axis=-1
        ).cumsum(axis=-1)
        index_mask = cumulative_sum < required_sum
        # zero out indices we don't keep -> default to block 0 (already True)
        index = paddle.where(index_mask, index, paddle.zeros_like(index))

        # Scatter: mask[b,h,q,index[b,h,q,:]] = True.
        # NOTE: paddle GPU put_along_axis has no bool kernel; do the scatter
        # in int32 then cast back.
        flat_mask = mask.reshape([batch_size, head_num * chunk_num, block_num]).astype("int32")
        flat_idx = index.reshape([batch_size, head_num * chunk_num, block_num])
        true_vals = paddle.ones_like(flat_idx, dtype="int32")
        flat_mask = paddle.put_along_axis(
            flat_mask, flat_idx, true_vals, axis=-1, reduce="assign"
        )
        mask = flat_mask.reshape([batch_size, head_num, chunk_num, block_num]).astype("bool")
    else:
        # Non-causal: same threshold walk, but nothing is forced on up front.
        mask = paddle.zeros_like(input_tensor, dtype="bool")
        sorted_values = paddle.sort(input_tensor, axis=-1, descending=True)
        index = paddle.argsort(input_tensor, axis=-1, descending=True)
        zeros_col = paddle.zeros(
            [batch_size, head_num, chunk_num, 1], dtype="float32"
        )
        # Exclusive prefix sum over the sorted values.
        cumulative_sum = paddle.concat(
            [zeros_col, sorted_values[:, :, :, :-1]], axis=-1
        ).cumsum(axis=-1)
        index_mask = cumulative_sum < required_sum
        # Rejected slots are redirected to index 0, so the scatter below also
        # marks block 0 whenever at least one slot is rejected.
        index = paddle.where(index_mask, index, paddle.zeros_like(index))

        # Scatter through int32, as in the causal branch.
        flat_mask = mask.reshape([batch_size, head_num * chunk_num, block_num]).astype("int32")
        flat_idx = index.reshape([batch_size, head_num * chunk_num, block_num])
        true_vals = paddle.ones_like(flat_idx, dtype="int32")
        flat_mask = paddle.put_along_axis(
            flat_mask, flat_idx, true_vals, axis=-1, reduce="assign"
        )
        mask = flat_mask.reshape([batch_size, head_num, chunk_num, block_num]).astype("bool")

    if causal:
        # any out-of-causal entries set to False
        # (columns at or beyond current_index + chunk_num are replaced by zeros)
        if current_index + chunk_num < block_num:
            zero_pad = paddle.zeros(
                [batch_size, head_num, chunk_num, block_num - (current_index + chunk_num)],
                dtype="bool",
            )
            mask = paddle.concat(
                [mask[:, :, :, : current_index + chunk_num], zero_pad], axis=-1
            )
    return mask
def _pad_seq(x: paddle.Tensor, num_to_pad: int) -> paddle.Tensor:
    """Pad seq dim (axis=2) of [B,H,T,D] with zeros.

    Returns ``x`` itself when ``num_to_pad`` is zero or negative.
    """
    if num_to_pad <= 0:
        return x
    return F.pad(x, [0, 0, 0, num_to_pad], value=0, data_format="NCHW")


def xattn_estimate(
    query_states: paddle.Tensor,  # [1, H, q_len, D]
    key_states: paddle.Tensor,  # [1, H, k_len, D]
    block_size: int,
    stride: int,
    norm: float = 1.0,
    threshold: float = 0.9,
    chunk_size: int = 16384,
    use_triton: bool = True,
    causal: bool = True,
    keep_sink: bool = False,
    keep_recent: bool = False,
):
    """Estimate per-block attention mass and derive a block-selection mask.

    Queries and keys are zero-padded along the sequence axis to a multiple of
    ``chunk_size``. Each query chunk is then run through
    ``flat_group_gemm_fuse_reshape`` -> ``softmax_fuse_block_sum`` ->
    ``find_blocks_chunked`` and the per-chunk results are concatenated along
    axis -2.

    Args:
        query_states / key_states: see the inline shape comments; both must
            have the same number of heads.
        block_size: block length in tokens.
        stride: divisor applied to ``chunk_size`` / ``block_size`` to obtain
            the "reshaped" sizes passed to the Triton wrappers.
        norm: extra divisor folded into the softmax scale.
        threshold: forwarded to ``find_blocks_chunked``.
        chunk_size: number of tokens processed per iteration.
        use_triton: must be truthy; no other path is implemented here.
        causal: forwarded to the kernels and the block selector; also applies
            a lower-triangular block mask to the trailing square of the result.
        keep_sink: if set, forces index 0 of the third axis to True for all
            columns. NOTE(review): the name suggests key block 0 -- confirm
            the intended axis against upstream.
        keep_recent: if set, forces the diagonal of the trailing
            ``q_block_num`` x ``q_block_num`` square to True.

    Returns:
        ``(attn_sums, simple_masks)``: the concatenated block sums and the
        concatenated boolean selection masks.

    Raises:
        AssertionError: if ``use_triton`` is falsy, the head counts differ, or
            there are fewer key chunks than query chunks.
    """
    assert use_triton, "paddle port only supports the Triton path"
    batch_size, num_q_head, q_len, head_dim = query_states.shape
    _, num_kv_head, k_len, _ = key_states.shape
    assert num_q_head == num_kv_head

    # Round both lengths up to a multiple of chunk_size.
    k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len
    q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len
    k_chunk_num = (k_len + k_num_to_pad) // chunk_size
    k_block_num = (k_len + k_num_to_pad) // block_size
    q_chunk_num = (q_len + q_num_to_pad) // chunk_size
    q_block_num = (q_len + q_num_to_pad) // block_size

    assert k_chunk_num >= q_chunk_num

    pad_q = _pad_seq(query_states, q_num_to_pad) if q_num_to_pad > 0 else query_states
    pad_k = _pad_seq(key_states, k_num_to_pad) if k_num_to_pad > 0 else key_states

    # Sizes expressed in units of `stride` tokens.
    reshaped_chunk_size = chunk_size // stride
    reshaped_block_size = block_size // stride
    k_reshaped_num_to_pad = k_num_to_pad // stride
    k_reshaped_seq_len = (k_len + k_num_to_pad) // stride
    num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size

    attn_sum_list = []
    simple_mask_list = []

    # 1.4426950408889634 is log2(e); the remaining factors are the usual
    # 1/sqrt(head_dim) plus the stride and norm divisors.
    scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm

    for chunk_idx in range(q_chunk_num):
        # Token range of this query chunk inside the padded query tensor.
        q_start = chunk_idx * reshaped_chunk_size * stride
        q_end = q_start + reshaped_chunk_size * stride
        chunk_q = pad_q[:, :, q_start:q_end, :]

        # The 4th/5th arguments are the chunk's start/end offsets in reshaped
        # units, shifted by the key/query block-count difference.
        attn_weights_slice = flat_group_gemm_fuse_reshape(
            chunk_q,
            pad_k,
            stride,
            (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
            (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
            is_causal=causal,
        )
        attn_sum = softmax_fuse_block_sum(
            attn_weights_slice,
            reshaped_block_size,
            min(4096, reshaped_block_size),
            (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
            (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
            k_reshaped_seq_len - k_reshaped_num_to_pad,  # unpadded key length, reshaped units
            scale,
            is_causal=causal,
        )

        # Threshold-based selection for this chunk; the second argument is the
        # block index of the chunk's first query block relative to the keys.
        simple_mask = find_blocks_chunked(
            attn_sum,
            k_block_num - q_block_num + chunk_idx * num_blocks_per_chunk,
            threshold,
            None,
            decoding=False,
            mode="prefill",
            causal=causal,
        )
        attn_sum_list.append(attn_sum)
        simple_mask_list.append(simple_mask)

    attn_sums = paddle.concat(attn_sum_list, axis=-2)
    simple_masks = paddle.concat(simple_mask_list, axis=-2)

    if causal:
        # AND the trailing square of the mask with a lower-triangular
        # (diagonal included) block mask.
        mask_size = min(q_block_num, simple_masks.shape[-1])
        if mask_size > 0:
            tri = paddle.triu(
                paddle.ones([mask_size, mask_size], dtype="bool"), diagonal=1
            )
            causal_block_mask = paddle.logical_not(tri)
            sub = simple_masks[:, :, -mask_size:, -mask_size:]
            simple_masks[:, :, -mask_size:, -mask_size:] = paddle.logical_and(sub, causal_block_mask)
    if keep_sink:
        simple_masks[:, :, 0, :] = True
    if keep_recent:
        # Force the diagonal of the trailing square to True.
        eye = paddle.eye(q_block_num, dtype="int32").astype("bool").unsqueeze(0).unsqueeze(0)
        eye = eye.expand([1, num_kv_head, q_block_num, q_block_num])
        sub = simple_masks[:, :, -q_block_num:, -q_block_num:]
        simple_masks[:, :, -q_block_num:, -q_block_num:] = paddle.where(eye, paddle.ones_like(sub), sub)

    return attn_sums, simple_masks


@paddle.no_grad()
def Xattention_prefill_dim4(
    query_states: paddle.Tensor,  # [1, H, T, D]
    key_states: paddle.Tensor,  # [1, H, T, D]
    value_states: paddle.Tensor,  # [1, H, T, D]
    stride: int,
    cu_seq_lens: paddle.Tensor,  # int32 [B+1]; BS=1 => [0, T]
    norm: float = 1.0,
    threshold: float = 0.8,
    block_size: int = 128,
    use_triton: bool = True,
    causal: bool = True,
    chunk_size: int | None = None,
    keep_sink: bool = False,
    keep_recent: bool = False,
    head_mask_type: paddle.Tensor | None = None,
    sink_num: int = 1,
    local_num: int = 16,
) -> paddle.Tensor:
    """Block-sparse prefill attention for a single sequence.

    Steps:
      1. ``xattn_estimate`` is run on the first ``valid_len`` tokens of the
         query/key tensors to obtain a boolean block mask.
      2. Mask rows/columns beyond the valid block count are cleared.
      3. Q/K/V are rearranged to ``[total_T, H, D]`` and passed to
         ``block_sparse_attn_paddle`` together with the mask rows of the
         heads whose ``head_mask_type`` equals 1.
      4. The result is written back into a zero-filled ``[1, H, T, D]`` tensor.

    Notes:
      * ``use_triton`` is accepted but ``True`` is always passed on.
      * ``head_mask_type`` defaults to all ones (int32) when ``None``.
      * ``sink_num`` / ``local_num`` are repeated once per head to build the
        ``streaming_info`` tensor.
      * When ``chunk_size`` is ``None`` it is derived from the key length
        rounded up to a power of two, and is never below 2048.

    Raises:
        AssertionError: if the batch dimension is not 1.
    """
    batch_size, num_heads, max_q_len, head_dim = query_states.shape
    _, _, max_k_len, _ = key_states.shape
    assert batch_size == 1, "this paddle port targets BS=1 only (FastDeploy varlen)"

    # Length of the (single) sequence described by cu_seq_lens.
    valid_len = int(cu_seq_lens[1].item()) - int(cu_seq_lens[0].item())

    cur_q = query_states[:, :, :valid_len, :]
    cur_k = key_states[:, :, :valid_len, :]
    cur_klen = cur_k.shape[2]
    if chunk_size is None:
        # `1 << (n - 1).bit_length()` rounds n up to a power of two.
        chunk_size = max(
            min(
                max(2048, 1 << (cur_klen - 1).bit_length()),
                128 * 1024 * 2048 // (1 << (cur_klen - 1).bit_length()),
            ),
            2048,
        )

    # Only the mask is used; the block sums are discarded.
    _, approx_mask = xattn_estimate(
        cur_q, cur_k,
        block_size=block_size, stride=stride, norm=norm, threshold=threshold,
        use_triton=True, causal=causal, chunk_size=chunk_size,
        keep_sink=keep_sink, keep_recent=keep_recent,
    )

    # Clear mask entries that fall outside the valid (unpadded) blocks.
    valid_q_blocks = (valid_len + block_size - 1) // block_size
    valid_k_blocks = (valid_len + block_size - 1) // block_size
    approx_mask[:, :, valid_q_blocks:, :] = False
    approx_mask[:, :, :, valid_k_blocks:] = False

    # ---- BSA expects [total_T, H, D] varlen layout ----
    total_T = int(cu_seq_lens[-1].item())
    # query/key/value [1,H,T,D] -> [T,H,D]
    q_var = query_states.squeeze(0).transpose([1, 0, 2])[:total_T].contiguous()
    k_var = key_states.squeeze(0).transpose([1, 0, 2])[:total_T].contiguous()
    v_var = value_states.squeeze(0).transpose([1, 0, 2])[:total_T].contiguous()

    if head_mask_type is None:
        head_mask_type = paddle.ones([num_heads], dtype="int32")

    max_q_block_num = (max_q_len + block_size - 1) // block_size
    max_k_block_num = (max_k_len + block_size - 1) // block_size

    # Keep only the mask rows of heads marked with type 1.
    sparse_mask_idx = paddle.nonzero(head_mask_type == 1).reshape([-1])
    if sparse_mask_idx.shape[0] > 0:
        blockmask = paddle.index_select(approx_mask, sparse_mask_idx, axis=1)
        blockmask = blockmask[:, :, :max_q_block_num, :max_k_block_num].contiguous()
    else:
        # No sparse heads -- BSA still wants a tensor; pass empty.
        blockmask = paddle.ones(
            [1, 0, max_q_block_num, max_k_block_num], dtype="bool"
        )

    # Flat [sink_num, local_num, sink_num, local_num, ...], one pair per head.
    streaming_info = paddle.to_tensor(
        [sink_num, local_num] * num_heads, dtype="int32"
    )

    attn_out = block_sparse_attn_paddle(
        q_var, k_var, v_var,
        cu_seq_lens, cu_seq_lens,
        head_mask_type, streaming_info, blockmask,
        max_q_len, max_k_len,
        p_dropout=0.0, deterministic=True, is_causal=causal,
        m_block_dim=block_size, n_block_dim=block_size,
    )  # [total_T, H, D]

    # Back to [1,H,T,D] padded.
    out = paddle.zeros([1, num_heads, max_q_len, head_dim], dtype=attn_out.dtype)
    out[0, :, :total_T, :] = attn_out.transpose([1, 0, 2])
    return out
+ +Triton accepts any tensor exposing ``data_ptr()`` (paddle Tensor does), so +the kernels themselves are framework-agnostic; only the Python wrappers +(``flat_group_gemm_fuse_reshape`` / ``softmax_fuse_block_sum``) need to +target paddle. +""" + +from __future__ import annotations + +import math + +import paddle +import triton +import triton.language as tl + + +@triton.jit +def softmax_fuse_block_sum_kernel_causal( + In, Out, scale, + input_stride_0, input_stride_1, input_stride_2, + output_stride_0, output_stride_1, output_stride_2, + real_q_len, k_len, chunk_start, chunk_end, + segment_size: tl.constexpr, block_size: tl.constexpr, +): + block_id = tl.program_id(0) + head_id = tl.program_id(1) + batch_id = tl.program_id(2) + + offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size + offs_k = tl.arange(0, segment_size) + + num_iters = k_len // segment_size + num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1) // segment_size + + m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf") + l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0 + + input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2 + input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2 + + output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2 + output_ptr = output_ptr + tl.arange(0, segment_size // block_size) + + for it in range(0, num_iters_before_causal): + X = tl.load(input_ptr + it * segment_size).to(tl.float32) * scale + m_local = tl.max(X, 1) + m_new = tl.maximum(m_i, m_local) + alpha = tl.math.exp2(m_i - m_new) + X = X - m_new[:, None] + l_local = tl.sum(tl.math.exp2(X), 1) + l_i = l_i * alpha + l_local + m_i = m_new + + for it in range(num_iters_before_causal, num_iters_before_causal + 1): + X = tl.load(input_ptr + it * segment_size).to(tl.float32) * scale + mask = offs_q[:, None] >= 
(offs_k[None, :] + it * segment_size) + X = tl.where(mask, X, -1.0e6) + m_local = tl.max(X, 1) + m_new = tl.maximum(m_i, m_local) + alpha = tl.math.exp2(m_i - m_new) + X = X - m_new[:, None] + l_local = tl.sum(tl.math.exp2(X), 1) + l_i = l_i * alpha + l_local + m_i = m_new + + l_i_inv = 1.0 / l_i + sum_mask = offs_q[:, None] < real_q_len + + for it in range(0, num_iters_before_causal): + X = tl.load(input_ptr + it * segment_size).to(tl.float32) * scale + X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None] + X = tl.where(sum_mask, X, 0) + X = tl.reshape(X, (block_size, segment_size // block_size, block_size)) + X = tl.sum(X, 2) + X = tl.sum(X, 0) + tl.store(output_ptr + it * segment_size // block_size, X.to(Out.type.element_ty)) + + for it in range(num_iters_before_causal, num_iters_before_causal + 1): + X = tl.load(input_ptr + it * segment_size).to(tl.float32) * scale + mask = offs_q[:, None] >= (offs_k[None, :] + it * segment_size) + X = tl.where(mask, X, -1.0e6) + X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None] + X = tl.where(sum_mask, X, 0) + X = tl.reshape(X, (block_size, segment_size // block_size, block_size)) + X = tl.sum(X, 2) + X = tl.sum(X, 0) + tl.store(output_ptr + it * segment_size // block_size, X.to(Out.type.element_ty)) + + for it in range(num_iters_before_causal + 1, num_iters): + X = tl.zeros([segment_size // block_size], dtype=tl.float32) + tl.store(output_ptr + it * segment_size // block_size, X.to(Out.type.element_ty)) + + +@triton.jit +def softmax_fuse_block_sum_kernel_non_causal( + In, Out, scale, + input_stride_0, input_stride_1, input_stride_2, + output_stride_0, output_stride_1, output_stride_2, + real_q_len, k_len, chunk_start, chunk_end, + segment_size: tl.constexpr, block_size: tl.constexpr, +): + block_id = tl.program_id(0) + head_id = tl.program_id(1) + batch_id = tl.program_id(2) + + offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size + num_iters = k_len // segment_size + + m_i = tl.zeros([block_size], 
dtype=tl.float32) - float("inf") + l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0 + + input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2 + input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2 + + output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2 + output_ptr = output_ptr + tl.arange(0, segment_size // block_size) + + for it in range(0, num_iters): + X = tl.load(input_ptr + it * segment_size).to(tl.float32) * scale + m_local = tl.max(X, 1) + m_new = tl.maximum(m_i, m_local) + alpha = tl.math.exp2(m_i - m_new) + X = X - m_new[:, None] + l_local = tl.sum(tl.math.exp2(X), 1) + l_i = l_i * alpha + l_local + m_i = m_new + + l_i_inv = 1.0 / l_i + sum_mask = offs_q[:, None] < real_q_len + + for it in range(0, num_iters): + X = tl.load(input_ptr + it * segment_size).to(tl.float32) * scale + X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None] + X = tl.where(sum_mask, X, 0) + X = tl.reshape(X, (block_size, segment_size // block_size, block_size)) + X = tl.sum(X, 2) + X = tl.sum(X, 0) + tl.store(output_ptr + it * segment_size // block_size, X.to(Out.type.element_ty)) + + +@triton.jit +def flat_group_gemm_fuse_reshape_kernel( + Q, K, Out, + stride_qz, stride_qh, stride_qn, + stride_kz, stride_kh, stride_kn, + stride_oz, stride_oh, stride_on, + chunk_start, chunk_end, + H: tl.constexpr, STRIDE: tl.constexpr, HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, is_causal: tl.constexpr, +): + block_m = tl.program_id(0).to(tl.int64) + block_n = tl.program_id(1).to(tl.int64) + batch_id = tl.program_id(2).to(tl.int64) // H + head_id = tl.program_id(2).to(tl.int64) % H + + if is_causal: + if chunk_start + (block_m + 1) * BLOCK_M <= block_n * BLOCK_N: + return + + Q_ptrs = Q + batch_id * stride_qz + head_id * stride_qh + block_m * BLOCK_M * STRIDE * stride_qn + K_ptrs = K + batch_id * stride_kz + 
head_id * stride_kh + block_n * BLOCK_N * STRIDE * stride_kn + + Q_ptrs = ( + Q_ptrs + + tl.arange(0, BLOCK_M)[:, None] * (stride_qn * STRIDE) + + tl.arange(0, HEAD_DIM)[None, :] + + stride_qn * (STRIDE - 1) + ) + K_ptrs = ( + K_ptrs + + tl.arange(0, BLOCK_N)[None, :] * (stride_kn * STRIDE) + + tl.arange(0, HEAD_DIM)[:, None] + ) + + o = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + for it in range(STRIDE): + q = tl.load(Q_ptrs - it * stride_qn) + k = tl.load(K_ptrs + it * stride_kn) + o += tl.dot(q, k) + + O_ptrs = Out + batch_id * stride_oz + head_id * stride_oh + block_m * BLOCK_M * stride_on + block_n * BLOCK_N + O_ptrs = O_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_on + tl.arange(0, BLOCK_N)[None, :] + tl.store(O_ptrs, o.to(Out.type.element_ty)) + + +def _ensure_contig(t: paddle.Tensor) -> paddle.Tensor: + return t if t.is_contiguous() else t.contiguous() + + +def softmax_fuse_block_sum( + attn_weights_slice: paddle.Tensor, + reshaped_block_size: int, + segment_size: int, + chunk_start: int, + chunk_end: int, + real_q_len: int, + scale: float, + is_causal: bool = True, +) -> paddle.Tensor: + batch_size, num_heads, q_len, k_len = attn_weights_slice.shape + assert q_len % reshaped_block_size == 0 + assert k_len % segment_size == 0 + assert segment_size % reshaped_block_size == 0 + attn_weights_slice = _ensure_contig(attn_weights_slice) + + output = paddle.empty( + [batch_size, num_heads, q_len // reshaped_block_size, k_len // reshaped_block_size], + dtype=attn_weights_slice.dtype, + ) + grid = (q_len // reshaped_block_size, num_heads, batch_size) + + s0, s1, s2, s3 = attn_weights_slice.strides + o0, o1, o2, _ = output.strides + + if is_causal: + softmax_fuse_block_sum_kernel_causal[grid]( + attn_weights_slice, output, scale, + s0, s1, s2, o0, o1, o2, + real_q_len, k_len, chunk_start, chunk_end, + segment_size, reshaped_block_size, + ) + else: + softmax_fuse_block_sum_kernel_non_causal[grid]( + attn_weights_slice, output, scale, + s0, s1, s2, o0, o1, o2, + 
real_q_len, k_len, chunk_start, chunk_end, + segment_size, reshaped_block_size, + ) + return output + + +def flat_group_gemm_fuse_reshape( + query_states: paddle.Tensor, + key_states: paddle.Tensor, + stride: int, + chunk_start: int, + chunk_end: int, + is_causal: bool = True, +) -> paddle.Tensor: + batch_size, num_heads, q_len, head_dim = query_states.shape + kv_len = key_states.shape[2] + + query_states = _ensure_contig(query_states) + key_states = _ensure_contig(key_states) + + output = paddle.empty( + [batch_size, num_heads, q_len // stride, kv_len // stride], + dtype=query_states.dtype, + ) + BLOCK_M = 64 + BLOCK_N = 64 + assert q_len % (stride * BLOCK_M) == 0 + assert kv_len % (stride * BLOCK_N) == 0 + + grid = (q_len // stride // BLOCK_M, kv_len // stride // BLOCK_N, batch_size * num_heads) + + qs0, qs1, qs2, _ = query_states.strides + ks0, ks1, ks2, _ = key_states.strides + os0, os1, os2, _ = output.strides + + flat_group_gemm_fuse_reshape_kernel[grid]( + query_states, key_states, output, + qs0, qs1, qs2, ks0, ks1, ks2, os0, os1, os2, + chunk_start, chunk_end, + num_heads, stride, head_dim, BLOCK_M, BLOCK_N, is_causal, + ) + return output diff --git a/fastdeploy/model_executor/models/qwen3_elastic/modeling_elastic_qwen3.py b/fastdeploy/model_executor/models/qwen3_elastic/modeling_elastic_qwen3.py new file mode 100644 index 00000000000..734bbac75c9 --- /dev/null +++ b/fastdeploy/model_executor/models/qwen3_elastic/modeling_elastic_qwen3.py @@ -0,0 +1,412 @@ +"""Elastic-Attention port of Qwen3 (PawQwen3ForCausalLM). + +Reuses Qwen3 weight-loading / TP mappings; replaces ``Qwen3Attention`` with +:class:`Qwen3ElasticAttention` which (a) hosts the per-layer +:class:`AttentionRouter` MLP and config knobs, and (b) routes the actual +attention compute through :class:`Qwen3ElasticAttentionBackend` (see +``fastdeploy/model_executor/layers/attention/elastic_attn_backend.py``). 
+ +The integration spec lives in ``ELASTIC_FASTDEPLOY_INTEGRATION.md``; in +particular §4 (model registration), §5 (router + utils) and §8 (config). +""" + +from __future__ import annotations + +import re +from functools import partial +from typing import Dict + +import paddle +from paddle import nn +from paddleformers.transformers import PretrainedModel +from paddleformers.utils.log import logger + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.forward_meta import ForwardMeta +from fastdeploy.model_executor.graph_optimization.decorator import ( + support_graph_optimization, +) +from fastdeploy.model_executor.layers.attention.attention import Attention +from fastdeploy.model_executor.layers.attention.elastic_attn_backend import ( + Qwen3ElasticAttentionBackend, +) +from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding +from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear +from fastdeploy.model_executor.layers.lm_head import ParallelLMHead +from fastdeploy.model_executor.layers.normalization import QKRMSNorm, RMSNorm +from fastdeploy.model_executor.models.model_base import ( + ModelCategory, + ModelForCasualLM, + ModelRegistry, +) +from fastdeploy.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2MLP + +from .config_elastic import populate_elastic_fields +from .utils import AttentionRouter + + +class Qwen3ElasticMLP(Qwen2MLP): + pass + + +class Qwen3ElasticAttention(nn.Layer): + """Qwen3 attention with the Elastic-Attention router head. + + Backend-side ``Qwen3ElasticAttentionBackend.forward_mixed`` reads the + per-layer config knobs and the ``mask_allocator`` MLP to decide + head_mask_type and dispatch to BSA on the prefill leg; on decode it + behaves identically to vanilla Qwen3 (``append_attention``). 
+ """ + + def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None: + super().__init__() + populate_elastic_fields(fd_config.model_config) + + self.fd_config = fd_config + self.layer_id = layer_id + self.head_dim = fd_config.model_config.head_dim + tp_size = fd_config.parallel_config.tensor_parallel_size + num_kv_heads_replicas = max(1, tp_size // fd_config.model_config.num_key_value_heads) + self.num_heads_local = fd_config.model_config.num_attention_heads // tp_size + self.num_kv_heads_local = max( + 1, fd_config.model_config.num_key_value_heads * num_kv_heads_replicas // tp_size + ) + self.q_size = self.num_heads_local * self.head_dim + self.kv_size = self.num_kv_heads_local * self.head_dim + + self.qkv_proj = QKVParallelLinear(fd_config, prefix=f"{prefix}.qkv_proj", with_bias=False) + self.o_proj = RowParallelLinear( + fd_config, + prefix=f"{prefix}.o_proj", + input_size=self.head_dim * fd_config.model_config.num_attention_heads, + output_size=fd_config.model_config.hidden_size, + layer_id=layer_id, + ) + self.attn = Attention( + fd_config, + layer_id=layer_id, + prefix=prefix, + use_neox_rotary_style=True, + ) + self.qk_norm = QKRMSNorm( + fd_config, + head_dim=self.head_dim, + q_size=self.q_size, + kv_size=self.kv_size, + eps=fd_config.model_config.rms_norm_eps, + prefix=prefix, + begin_norm_axis=2, + ) + + # ---- Elastic-Attention specific ---- + # Router MLP (loaded from ckpt weights ``mask_allocator.*``) + self.mask_allocator = AttentionRouter( + num_kv_heads=self.num_kv_heads_local, + d_feature=self.head_dim, + ) + # Trained scalar bias (kept for strict-load compat; not used at inference). 
+ self.attn_mask_log_alphas = self.create_parameter( + shape=[self.num_kv_heads_local], + default_initializer=nn.initializer.Constant(0.0), + ) + + mc = fd_config.model_config + self.sink_size = int(mc.sink_size) + self.local_window_size = int(mc.local_window_size) + self.toggle_type = str(mc.toggle_type) + self.retrieval_mode = str(mc.retrieval_mode) + self.enable_ada_sparsity = bool(mc.enable_ada_sparsity) + self.pooling_mode = str(mc.pooling_mode) + # IMPORTANT: read elastic ``block_size`` (xattn / BSA granularity, default + # 128, matching PyTorch reference's ``self.granularity = getattr(config, + # "block_size", 128)``) directly from the ckpt's pretrained_config. + # We MUST NOT read from ``model_config`` here because FastDeploy's + # ``cache_config.block_size = 64`` (KV-cache block size) leaks onto + # ``model_config`` via attribute proxying in some configs, which + # silently corrupts the xattn/BSA block grid (sink_blocks/local_blocks + # halve, BSA mask grid no longer aligns with token blocks -> garbled + # output). The ckpt's config.json has no ``block_size`` field, so we + # fall back to 128. + _pc = getattr(mc, "pretrained_config", None) or mc + self.block_size = int(getattr(_pc, "block_size", 128)) + self.sink_blocks = (self.sink_size + self.block_size - 1) // self.block_size + self.local_blocks = (self.local_window_size + self.block_size - 1) // self.block_size + self.xattn_stride = int(mc.xattn_stride) + self.xattn_threshold = float(mc.xattn_threshold) + self.xattn_norm = float(mc.xattn_norm) + + # router decision cache (filled by backend on prefill) + self._z_kv_cache = paddle.zeros([self.num_kv_heads_local], dtype="int32") + self._head_mask_type_cache = paddle.zeros([self.num_heads_local], dtype="int32") + + # ---- Inject elastic attrs onto self.attn ---- + # The attention backend's ``forward_mixed`` receives + # ``layer = self.attn`` (the inner ``Attention`` instance, see + # ``layers/attention/attention.py:280-289``), NOT this parent. 
The + # elastic backend reads ``layer.mask_allocator`` / ``layer.toggle_type`` + # etc., so we mirror these handles onto the ``Attention`` instance. + # Using object.__setattr__ to avoid triggering nn.Layer's + # parameter/sublayer registration twice. + for _name in ( + "mask_allocator", + "toggle_type", + "retrieval_mode", + "enable_ada_sparsity", + "pooling_mode", + "block_size", + "sink_blocks", + "local_blocks", + "xattn_stride", + "xattn_threshold", + "xattn_norm", + "_z_kv_cache", + "_head_mask_type_cache", + ): + object.__setattr__(self.attn, _name, getattr(self, _name)) + + def load_state_dict(self, state_dict): + self.qkv_proj.load_state_dict(state_dict) + self.o_proj.load_state_dict(state_dict) + self.qk_norm.load_state_dict(state_dict) + self.attn.load_state_dict(state_dict) + + def forward(self, forward_meta: ForwardMeta, hidden_states: paddle.Tensor): + qkv_out = self.qkv_proj(hidden_states) + qkv_out = self.qk_norm(qkv_out, forward_meta) + atten_out = self.attn(qkv=qkv_out, forward_meta=forward_meta) + return self.o_proj(atten_out) + + +class Qwen3ElasticDecoderLayer(Qwen2DecoderLayer): + def __init__(self, fd_config: FDConfig, prefix: str = "") -> None: + super().__init__(fd_config, prefix) + layer_id = int(prefix.split(sep=".")[-1]) + self.self_attn = Qwen3ElasticAttention( + fd_config=fd_config, layer_id=layer_id, prefix=f"{prefix}.self_attn" + ) + + +@support_graph_optimization +class Qwen3ElasticModel(nn.Layer): + def __init__(self, fd_config: FDConfig | None = None): + super().__init__() + self.num_layers = fd_config.model_config.num_hidden_layers + fd_config.model_config.pretrained_config.prefix_name = "model" + + self.embed_tokens = VocabParallelEmbedding( + fd_config=fd_config, + num_embeddings=fd_config.model_config.vocab_size, + embedding_dim=fd_config.model_config.hidden_size, + params_dtype=paddle.get_default_dtype, + prefix=(f"{fd_config.model_config.pretrained_config.prefix_name}.embed_tokens"), + ) + self.layers = nn.LayerList( + [ + 
Qwen3ElasticDecoderLayer( + fd_config=fd_config, + prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}", + ) + for i in range(self.num_layers) + ] + ) + self.norm = RMSNorm( + fd_config, + hidden_size=fd_config.model_config.hidden_size, + eps=fd_config.model_config.rms_norm_eps, + prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.norm", + ) + + def load_state_dict(self, state_dict): + self.embed_tokens.load_state_dict(state_dict) + self.norm.load_state_dict(state_dict) + for i in range(self.num_layers): + logger.info(f"Start load layer {i}") + self.layers[i].load_state_dict(state_dict) + + def forward(self, ids_remove_padding: paddle.Tensor, forward_meta: ForwardMeta): + hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta) + residual = None + for i in range(self.num_layers): + hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual) + return self.norm(hidden_states, residual)[0] + + +@ModelRegistry.register_model_class( + architecture="PawQwen3ForCausalLM", + module_name="qwen3_elastic", + category=[ModelCategory.TEXT_GENERATION], + primary_use=ModelCategory.TEXT_GENERATION, +) +class PawQwen3ForCausalLM(ModelForCasualLM): + """Elastic-Attention Qwen3 (full_xattn / streaming) for inference.""" + + def __init__(self, fd_config: FDConfig): + super().__init__(fd_config) + self.fd_config = fd_config + populate_elastic_fields(fd_config.model_config) + + # Force ``architectures[0]`` to start with "Qwen" so that + # ``rotary_embedding.get_rope_impl()`` picks ``QwenRotaryEmbedding`` + # (rotary_dim=128, neox-style) instead of falling through to + # ``ErnieRotaryEmbedding`` (rotary_dim//2=64). Mismatched RoPE shape + # silently corrupts every layer and the model collapses to emitting + # token id 0 ("!" repeatedly). The original architecture name is kept + # in ``ModelRegistry`` because dispatch already happened before this. 
+ archs = list(getattr(fd_config.model_config, "architectures", []) or []) + if archs and not archs[0].startswith("Qwen"): + archs[0] = "Qwen3" + archs[0] + fd_config.model_config.architectures = archs + + self.model = Qwen3ElasticModel(fd_config=fd_config) + self.ori_vocab_size = fd_config.model_config.ori_vocab_size + self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings + self.lm_head = ParallelLMHead( + fd_config=fd_config, + embedding_dim=fd_config.model_config.hidden_size, + num_embeddings=fd_config.model_config.vocab_size, + prefix="lm_head", + ) + + @classmethod + def name(cls): + return "PawQwen3ForCausalLM" + + @classmethod + def _get_attn_backend_cls(cls, *args, **kwargs): + return Qwen3ElasticAttentionBackend + + @paddle.no_grad() + def load_weights(self, weights_iterator) -> None: + from fastdeploy.model_executor.utils import ( + default_weight_loader, + process_weights_after_loading, + ) + + is_pooling_model = hasattr(self, "is_pooling_model") and self.is_pooling_model + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("up_gate_proj", "gate_proj", "gate"), + ("up_gate_proj", "up_proj", "up"), + ("embed_tokens.embeddings", "embed_tokens", None), + ("lm_head.linear", "lm_head", None), + ("qk_norm.q_norm", "q_norm", None), + ("qk_norm.k_norm", "k_norm", None), + ] + + params_dict = dict(self.named_parameters()) + process_weights_after_loading_fn = process_weights_after_loading( + dict(self.named_sublayers()), self.fd_config + ) + + # Training-only keys we silently drop. 
+ skip_substrings = ( + ".mask_allocator.log_temp", + ) + + for loaded_weight_name, loaded_weight in weights_iterator: + if any(s in loaded_weight_name for s in skip_substrings): + continue + logger.debug(f"Loading weight: {loaded_weight_name}") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in loaded_weight_name: + continue + model_param_name = loaded_weight_name.replace(weight_name, param_name) + if model_param_name not in params_dict: + continue + param = params_dict[model_param_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) + weight_loader(param, loaded_weight, shard_id) + break + else: + model_param_name = loaded_weight_name + if model_param_name not in params_dict: + continue + param = params_dict[model_param_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) + weight_loader(param, loaded_weight) + + model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name) + process_weights_after_loading_fn(model_sublayer_name, param) + + if self.tie_word_embeddings and not is_pooling_model: + self.lm_head.linear.weight.set_value( + self.model.embed_tokens.embeddings.weight.transpose([1, 0]).astype( + self.lm_head.linear.weight.dtype + ) + ) + + @paddle.no_grad() + def set_state_dict(self, state_dict): + self.model.load_state_dict(state_dict) + if self.tie_word_embeddings: + self.lm_head.load_state_dict( + {self.lm_head.weight_key: self.model.embed_tokens.embeddings.weight} + ) + else: + self.lm_head.load_state_dict(state_dict) + + def compute_logits(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta = None): + logits = self.lm_head(hidden_states) + logits = logits.astype(paddle.float32) + logits[:, self.ori_vocab_size :] = -float("inf") + return logits + + def forward(self, inputs: Dict, forward_meta: ForwardMeta): + ids_remove_padding = inputs["ids_remove_padding"] + return 
self.model(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta) + + def clear_graph_opt_backend(self): + self.model.clear_graph_opt_backend(fd_config=self.fd_config) + + +class PawQwen3PretrainedModel(PretrainedModel): + """TP mapping: identical to Qwen3 + extra mask_allocator router (replicated).""" + + config_class = FDConfig + + def _init_weight(self, layer): + return None + + @classmethod + def arch_name(cls): + return "PawQwen3ForCausalLM" + + @classmethod + def _get_tensor_parallel_mappings(cls, config, is_split=True): + from paddleformers.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_model_parallel_size=config.tensor_model_parallel_size, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + base_actions = { + "lm_head.weight": partial(fn, is_column=True), + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), + "layers.0.self_attn.q_proj.weight": partial(fn, is_column=True), + "layers.0.self_attn.q_proj.bias": partial(fn, is_column=True), + "layers.0.mlp.gate_proj.weight": partial(fn, is_column=True), + "layers.0.mlp.up_proj.weight": partial(fn, is_column=True), + } + if config.num_key_value_heads % config.tensor_model_parallel_size == 0: + base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) + + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + # Router MLP is small -- replicate (no TP split) by simply not adding mappings. 
+ return final_actions + + return get_tensor_parallel_split_mappings(config.num_hidden_layers) diff --git a/fastdeploy/model_executor/models/qwen3_elastic/utils.py b/fastdeploy/model_executor/models/qwen3_elastic/utils.py new file mode 100644 index 00000000000..64e9e1515d3 --- /dev/null +++ b/fastdeploy/model_executor/models/qwen3_elastic/utils.py @@ -0,0 +1,128 @@ +"""Misc paddle helpers for Elastic-Attention. + +Contains: +- ``AttentionRouter`` : 3-layer MLP head router (per KV-head 0/1) +- ``derive_head_mask_type`` : retrieval/toggle -> {1, 0, -1} per Q-head +- ``ctx_q_pool`` : per-sequence mean-pooling of K (post k_norm, pre RoPE) +""" + +from __future__ import annotations + +import paddle +from paddle import nn + +from fastdeploy.model_executor.utils import set_weight_attrs + + +class _LinearTransposed(nn.Linear): + """nn.Linear that flags its weight for HF -> paddle transpose at load time.""" + + def __init__(self, in_features: int, out_features: int, bias: bool = True): + super().__init__(in_features, out_features, bias_attr=bias) + set_weight_attrs(self.weight, {"weight_need_transpose": True}) + + +class AttentionRouter(nn.Layer): + """Inference-only 3-layer MLP router (matches PawQwen3 ``AttentionRouter`` + with ``use_softmax=True``).""" + + def __init__(self, num_kv_heads: int, d_feature: int = 128): + super().__init__() + self.num_kv_heads = num_kv_heads + self.d_feature = d_feature + mid = 4 * d_feature + + self.cls_feat_extractor = nn.Sequential( + _LinearTransposed(d_feature, mid), + nn.Silu(), + _LinearTransposed(mid, d_feature), + ) + self.cls_router_head_agnostic = nn.Sequential( + _LinearTransposed(d_feature, mid), + nn.Silu(), + _LinearTransposed(mid, d_feature), + nn.Silu(), + _LinearTransposed(d_feature, 2), + ) + + @paddle.no_grad() + def forward(self, k_pooled: paddle.Tensor) -> paddle.Tensor: + """Args: + k_pooled: [B, H_kv, D] -- post-k_norm, pre-RoPE, seq-mean. + Returns: + z_kv: [B, H_kv] int32 with values in {0,1}. 
+ """ + h = self.cls_feat_extractor(k_pooled) + logits = self.cls_router_head_agnostic(h) + return logits.argmax(axis=-1).astype("int32") + + +def ctx_q_pool(k_post_norm: paddle.Tensor, cu_seq_lens: paddle.Tensor | None = None) -> paddle.Tensor: + """Pool K over the sequence axis per request. + + Mirrors the HF reference ``AttentionRouter`` else-branch (eval path with + ``cu_seq_len is None``) at ``modeling_flash_qwen.py``: + + target = torch.concat([x[:, :100, :], x[:, -100:, :]], dim=1).mean(dim=1) + + i.e. mean over the first 100 + last 100 tokens (with overlap when + ``T < 200``, matching HF byte-for-byte). + + BS=1 fast path: ``k_post_norm`` is ``[T, H_kv, D]`` -> returns + ``[1, H_kv, D]``. For general varlen, use ``cu_seq_lens`` ``[B+1]``. + """ + HEAD = 100 + TAIL = 100 + + def _pool_segment(seg: paddle.Tensor) -> paddle.Tensor: + # seg: [Ti, H, D] + Ti = seg.shape[0] + head = seg[: min(HEAD, Ti)] + tail = seg[-min(TAIL, Ti) :] + cat = paddle.concat([head, tail], axis=0) # [head+tail, H, D] + return cat.astype("float32").mean(axis=0, keepdim=True).astype(seg.dtype) + + if k_post_norm.ndim == 3 and cu_seq_lens is None: + return _pool_segment(k_post_norm) + if cu_seq_lens is None: + raise ValueError("cu_seq_lens required for varlen ctx_q_pool") + B = int(cu_seq_lens.shape[0]) - 1 + out = [] + for i in range(B): + s = int(cu_seq_lens[i].item()) + e = int(cu_seq_lens[i + 1].item()) + out.append(_pool_segment(k_post_norm[s:e])) + return paddle.concat(out, axis=0) + + +def derive_head_mask_type( + z_kv: paddle.Tensor, + retrieval_mode: str, + toggle_type: str, + group_size: int = 1, +) -> paddle.Tensor: + """Return ``head_mask_type`` for BSA: {1=block-sparse, 0=full, -1=streaming}. + + ``z_kv`` is ``[H_kv]`` int. If ``group_size>1`` (GQA), the result is + ``repeat_interleave``'d to ``[H_kv*group_size]`` to match Q-heads. 
+ """ + z = z_kv.astype("int32") + zero = paddle.zeros_like(z) + one = paddle.ones_like(z) + neg = -one + key = (retrieval_mode, toggle_type) + if key == ("full", "xattn"): + out = (1 - z).astype("int32") + elif key == ("full", "streaming"): + out = paddle.where(z == 1, zero, neg).astype("int32") + elif key == ("xattn", "streaming"): + out = paddle.where(z == 1, one, neg).astype("int32") + elif key == ("xattn", "xattn"): + out = one + elif key == ("full", "full"): + out = zero + else: + raise NotImplementedError(f"unsupported (retrieval_mode, toggle_type) = {key}") + if group_size > 1: + out = paddle.repeat_interleave(out, group_size, axis=0) + return out From 0be983176f896ecf268816dd12b7b3e5011e955a Mon Sep 17 00:00:00 2001 From: Zhenxu Tian Date: Thu, 4 Jun 2026 20:29:59 +0800 Subject: [PATCH 2/4] add UT --- .../layers/test_elastic_attention_backend.py | 199 +++++++++++++++ .../test_elastic_qwen3_config.py | 123 ++++++++++ .../test_elastic_qwen3_patches.py | 162 ++++++++++++ .../test_elastic_qwen3_utils.py | 185 ++++++++++++++ .../test_block_sparse_attn_paddle.py | 171 +++++++++++++ .../attention/test_xattention_estimate.py | 232 ++++++++++++++++++ 6 files changed, 1072 insertions(+) create mode 100644 tests/layers/test_elastic_attention_backend.py create mode 100644 tests/model_executor/test_elastic_qwen3_config.py create mode 100644 tests/model_executor/test_elastic_qwen3_patches.py create mode 100644 tests/model_executor/test_elastic_qwen3_utils.py create mode 100644 tests/operators/attention/test_block_sparse_attn_paddle.py create mode 100644 tests/operators/attention/test_xattention_estimate.py diff --git a/tests/layers/test_elastic_attention_backend.py b/tests/layers/test_elastic_attention_backend.py new file mode 100644 index 00000000000..de71490e27c --- /dev/null +++ b/tests/layers/test_elastic_attention_backend.py @@ -0,0 +1,199 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Layer-level tests for ``Qwen3ElasticAttention`` and its backend. + +These tests are intentionally **construction-time / contract-level**: +end-to-end ``forward_mixed`` requires a fully initialised distributed env +(``init_distributed_environment``), KV-cache buffers and a built BSA op. +That heavy path is covered by the integration smoke +``models/qwen3_elastic/run_elastic_qwen3_4b.py``. Here we guard the +documented invariants that are easy to break and hard to notice: + +1. Elastic config knobs are mirrored onto ``self.attn`` so the backend can + read them (§4.1 "key trick"). +2. ``block_size`` is read from ``pretrained_config``, NOT from + ``model_config``, so a leaking ``cache_config.block_size = 64`` cannot + silently halve sink_blocks / local_blocks. +3. The router decision caches (``_z_kv_cache`` / ``_head_mask_type_cache``) + exist with the right shape / dtype. +4. ``Qwen3ElasticAttentionBackend`` is the class that ``PawQwen3ForCausalLM`` + advertises via ``_get_attn_backend_cls``. +""" + +import unittest +from unittest import mock + +import paddle + +# Like test_elastic_qwen3_patches, this test cannot side-step the parent +# fastdeploy package init: ``Qwen3ElasticAttention`` constructs FD's real +# ``Attention`` layer (mocked here, but the import path must still resolve) +# and the elastic backend module imports FD's attention base classes. So +# the same fully-built fastdeploy_ops requirement applies. 
# Importing the elastic package pulls in FastDeploy's compiled custom ops.
# Record the failure instead of raising so the test classes below can skip
# cleanly on partial builds.
try:
    import fastdeploy.model_executor.models.qwen3_elastic  # noqa: F401
    from fastdeploy.model_executor.models.qwen3_elastic.modeling_elastic_qwen3 import (  # noqa: F401
        Qwen3ElasticAttention,
    )

    _FD_FULLY_BUILT = True
    _FD_IMPORT_ERR = None
except Exception as _e:  # noqa: BLE001
    _FD_FULLY_BUILT = False
    _FD_IMPORT_ERR = _e


# Dotted path of the module whose heavy children are mocked during construction.
_MODELING_MODULE = "fastdeploy.model_executor.models.qwen3_elastic.modeling_elastic_qwen3"


class _Cfg:
    """Bag-of-attrs stand-in for FDConfig sub-configs."""

    def __init__(self, **kw):
        for k, v in kw.items():
            setattr(self, k, v)


def _make_minimal_fd_config():
    """Build the smallest FDConfig-like object that ``Qwen3ElasticAttention``
    needs at __init__ time (we never run forward in this test).

    Returns:
        A ``_Cfg`` exposing ``model_config`` (with a nested
        ``pretrained_config``) and ``parallel_config``.
    """
    pc = _Cfg(block_size=128)  # elastic granularity (the §4.1 trick)
    mc = _Cfg(
        head_dim=64,
        hidden_size=256,
        num_attention_heads=4,
        num_key_value_heads=2,
        rms_norm_eps=1e-6,
        # elastic fields populated by populate_elastic_fields with defaults
        pretrained_config=pc,
    )
    parallel = _Cfg(tensor_parallel_size=1)
    return _Cfg(model_config=mc, parallel_config=parallel)


@unittest.skipUnless(
    _FD_FULLY_BUILT,
    f"fastdeploy custom-ops not fully built (got: {_FD_IMPORT_ERR!r})",
)
@unittest.skipIf(
    not paddle.device.is_compiled_with_cuda(),
    "Qwen3ElasticAttention pulls in FD layers that require a CUDA build.",
)
class TestQwen3ElasticAttentionConstruction(unittest.TestCase):
    """Layer __init__ contract; no forward."""

    def setUp(self):
        # We only need the construction path. Patch the heavy children with
        # ``MagicMock`` so we don't depend on a fully-initialised distributed
        # env or KV cache pool.
        for child in ("QKVParallelLinear", "RowParallelLinear", "QKRMSNorm", "Attention"):
            patcher = mock.patch(f"{_MODELING_MODULE}.{child}")
            patcher.start()
            # ``addCleanup`` (unlike ``tearDown``) also runs when a later
            # ``start()`` in this loop raises, so no patch can leak into
            # other tests.
            self.addCleanup(patcher.stop)

    def _assert_mirrored(self, layer, name):
        """Assert that ``layer.attn.<name>`` carries the same value as ``layer.<name>``.

        Tensors are compared by identity first, then element-wise through
        ``paddle.equal_all``: ``assertEqual`` would evaluate
        ``not (tensor == tensor)``, which has no truth value for a
        multi-element paddle tensor.
        """
        mirrored = getattr(layer.attn, name)
        expected = getattr(layer, name)
        msg = f"layer.attn.{name} != layer.{name}"
        if isinstance(expected, paddle.Tensor):
            same = mirrored is expected or (
                isinstance(mirrored, paddle.Tensor)
                and mirrored.dtype == expected.dtype
                and list(mirrored.shape) == list(expected.shape)
                and bool(paddle.equal_all(mirrored, expected).item())
            )
            self.assertTrue(same, msg=msg)
        else:
            self.assertEqual(mirrored, expected, msg=msg)

    def test_elastic_attrs_mirrored_onto_self_attn(self):
        fd = _make_minimal_fd_config()
        layer = Qwen3ElasticAttention(fd_config=fd, layer_id=0, prefix="model.layers.0.self_attn")

        # The backend's ``forward_mixed`` receives ``self.attn`` (NOT this
        # parent), so every elastic knob must be reachable from there.
        for name in (
            "mask_allocator",
            "toggle_type",
            "retrieval_mode",
            "enable_ada_sparsity",
            "pooling_mode",
            "block_size",
            "sink_blocks",
            "local_blocks",
            "xattn_stride",
            "xattn_threshold",
            "xattn_norm",
            "_z_kv_cache",
            "_head_mask_type_cache",
        ):
            # NOTE: ``Attention`` is mocked in setUp, so ``hasattr`` is always
            # true here; the value comparison below is what actually detects a
            # missing mirror (an auto-created child mock never equals the
            # real value).
            self.assertTrue(
                hasattr(layer.attn, name),
                msg=f"layer.attn is missing mirrored attr {name!r}",
            )
            self._assert_mirrored(layer, name)

    def test_block_size_comes_from_pretrained_config(self):
        """Regression test for §4.1: ``cache_config.block_size`` (e.g. 64)
        leaking onto model_config must NOT win over ``pretrained_config.block_size``.
        """
        fd = _make_minimal_fd_config()
        # Simulate cache_config.block_size leaking onto model_config.
        fd.model_config.block_size = 64
        # ckpt elastic block_size (the source of truth) is 128.
        fd.model_config.pretrained_config.block_size = 128

        layer = Qwen3ElasticAttention(fd_config=fd, layer_id=0, prefix="x")
        self.assertEqual(layer.block_size, 128, "block_size MUST be read from pretrained_config")
        # Derived counters use elastic block_size (128), not 64.
        self.assertEqual(layer.sink_blocks, (layer.sink_size + 128 - 1) // 128)
        self.assertEqual(layer.local_blocks, (layer.local_window_size + 128 - 1) // 128)

    def test_router_cache_shapes(self):
        fd = _make_minimal_fd_config()
        layer = Qwen3ElasticAttention(fd_config=fd, layer_id=0, prefix="x")
        self.assertEqual(list(layer._z_kv_cache.shape), [layer.num_kv_heads_local])
        self.assertEqual(layer._z_kv_cache.dtype, paddle.int32)
        self.assertEqual(list(layer._head_mask_type_cache.shape), [layer.num_heads_local])
        self.assertEqual(layer._head_mask_type_cache.dtype, paddle.int32)

    def test_mask_allocator_is_router(self):
        from fastdeploy.model_executor.models.qwen3_elastic.utils import AttentionRouter

        fd = _make_minimal_fd_config()
        layer = Qwen3ElasticAttention(fd_config=fd, layer_id=0, prefix="x")
        self.assertIsInstance(layer.mask_allocator, AttentionRouter)
        self.assertEqual(layer.mask_allocator.num_kv_heads, layer.num_kv_heads_local)
        self.assertEqual(layer.mask_allocator.d_feature, layer.head_dim)


@unittest.skipUnless(
    _FD_FULLY_BUILT,
    f"fastdeploy custom-ops not fully built (got: {_FD_IMPORT_ERR!r})",
)
@unittest.skipIf(
    not paddle.device.is_compiled_with_cuda(),
    "Qwen3ElasticAttentionBackend imports CUDA-only modules.",
)
class TestModelDeclaresElasticBackend(unittest.TestCase):
    """The model class must advertise the elastic attention backend."""

    def test_paw_qwen3_backend_class(self):
        from fastdeploy.model_executor.layers.attention.elastic_attn_backend import (
            Qwen3ElasticAttentionBackend,
        )
        from fastdeploy.model_executor.models.qwen3_elastic.modeling_elastic_qwen3 import (
            PawQwen3ForCausalLM,
        )

        # The model class advertises the elastic backend via the public hook
        # that FastDeploy's selector consults.
        self.assertIs(
            PawQwen3ForCausalLM._get_attn_backend_cls(),
            Qwen3ElasticAttentionBackend,
        )


if __name__ == "__main__":
    unittest.main()
import importlib.util
import os
import unittest

# See test_elastic_qwen3_utils.py for why we file-load this module instead of
# importing ``fastdeploy.model_executor.models.qwen3_elastic.config_elastic``
# the regular way (parent ``models/__init__`` -> attention.ops chain pulls
# in compiled custom-op symbols that may be missing in some builds).
_HERE = os.path.dirname(os.path.abspath(__file__))
_CFG_PATH = os.path.normpath(
    os.path.join(_HERE, "..", "..", "fastdeploy", "model_executor", "models", "qwen3_elastic", "config_elastic.py")
)
_spec = importlib.util.spec_from_file_location("qwen3_elastic_config_under_test", _CFG_PATH)
_cfg = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(_cfg)
ELASTIC_CONFIG_FIELDS = _cfg.ELASTIC_CONFIG_FIELDS
populate_elastic_fields = _cfg.populate_elastic_fields


class _NS:
    """Lightweight namespace mimicking the model_config / pretrained_config object."""

    pass


class TestPopulateElasticFields(unittest.TestCase):
    """Contract tests for ``populate_elastic_fields(model_config)``."""

    def test_defaults_when_ckpt_empty(self):
        """Every field in ELASTIC_CONFIG_FIELDS gets its documented default."""
        mc = _NS()
        mc.pretrained_config = _NS()
        populate_elastic_fields(mc)
        for attr, (_, default) in ELASTIC_CONFIG_FIELDS.items():
            self.assertEqual(getattr(mc, attr), default, msg=f"attr={attr}")

    def test_ckpt_override(self):
        """Values present on ``pretrained_config`` win over the defaults."""
        mc = _NS()
        pc = _NS()
        pc.local_window_size = 4096
        pc.sink_size = 256
        pc.toggle_type = "streaming"
        pc.retrieval_mode = "xattn"
        pc.xattn_threshold = 0.5
        pc.block_size = 64
        mc.pretrained_config = pc

        populate_elastic_fields(mc)

        self.assertEqual(mc.local_window_size, 4096)
        self.assertEqual(mc.sink_size, 256)
        self.assertEqual(mc.toggle_type, "streaming")
        self.assertEqual(mc.retrieval_mode, "xattn")
        self.assertAlmostEqual(mc.xattn_threshold, 0.5)
        self.assertEqual(mc.block_size, 64)
        # Other fields fall back to defaults.
        self.assertTrue(mc.enable_ada_sparsity)
        self.assertEqual(mc.pooling_mode, "ctx_q")

    def test_idempotent(self):
        """A second call must not clobber a value set after the first call."""
        mc = _NS()
        mc.pretrained_config = _NS()
        populate_elastic_fields(mc)
        # Pretend user has set a custom value AFTER the first population.
        mc.toggle_type = "custom"
        # Second call must NOT overwrite it.
        populate_elastic_fields(mc)
        self.assertEqual(mc.toggle_type, "custom")

    def test_block_size_from_pretrained_config_not_model_config(self):
        """FD's ``cache_config.block_size`` (64) often leaks onto model_config
        via attribute proxying. ``populate_elastic_fields`` must read from
        ``pretrained_config`` so the BSA block grid uses the elastic 128.
        """
        mc = _NS()
        pc = _NS()
        # ckpt elastic block_size (granularity used by xattn / BSA)
        pc.block_size = 128
        mc.pretrained_config = pc
        # NOTE(review): this test does not pre-set a conflicting
        # ``mc.block_size``; it only checks that the value found on ``pc``
        # ends up on ``mc``. Whether a pre-existing ``mc.block_size`` would be
        # overridden depends on how populate_elastic_fields implements the
        # idempotence checked in ``test_idempotent`` -- confirm before
        # tightening this test.
        populate_elastic_fields(mc)
        self.assertEqual(mc.block_size, 128)

    def test_no_pretrained_config_falls_back_to_model_config(self):
        """When pretrained_config is missing, populate from model_config itself."""
        mc = _NS()
        mc.pretrained_config = None
        mc.toggle_type = "xattn"
        populate_elastic_fields(mc)
        self.assertEqual(mc.toggle_type, "xattn")
        # defaults still fill remaining attrs
        self.assertEqual(mc.sink_size, 128)


if __name__ == "__main__":
    # ``unittest`` is already imported at module top; no aliased re-import needed.
    unittest.main()
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Regression tests for the global patches applied by +``fastdeploy.model_executor.models.qwen3_elastic.__init__``. + +The package patches two pieces of FastDeploy global state on import: + +1. ``attention_selecter.get_attention_backend`` / + ``attention_selecter._get_attn_backend`` -- must return the elastic + backend ONLY for ``PawQwen3ForCausalLM`` and fall through to the + original selector for every other architecture. + +2. ``rotary_embedding.get_rope_impl`` -- must route PawQwen3 + yarn + rope_scaling through ``GptOssScalingRotaryEmbedding`` and leave every + non-PawQwen3 caller untouched. + +Because ``auto_models_registry`` imports this package on every FastDeploy +launch (including dense Qwen3, ERNIE, GLM, ...), a leaky patch would +silently break unrelated models. These tests guard that. +""" + +import unittest + +# These tests verify monkey-patches applied by importing the qwen3_elastic +# package onto fastdeploy's REAL ``attention_selecter`` and +# ``rotary_embedding`` modules. Unlike utils/config/kernels tests, there's +# no way to file-load past this -- the whole point is that the patch hits +# the real fastdeploy globals. So they require the same fully-built +# fastdeploy_ops as the integration smoke (run_elastic_qwen3_4b.py). On +# stale / partial builds (e.g. older fastdeploy_ops_pd_.so missing +# config_for_attention) the import will fail; skip cleanly in that case. 
# Importing the elastic package applies the monkey-patches under test and
# requires FastDeploy's compiled custom ops. Record the failure instead of
# raising so the test classes below can skip cleanly on partial builds.
try:
    import fastdeploy.model_executor.models.qwen3_elastic  # noqa: F401
    from fastdeploy.model_executor.layers.attention import (  # noqa: F401
        attention_selecter,
    )
    from fastdeploy.model_executor.layers.attention.elastic_attn_backend import (  # noqa: F401
        Qwen3ElasticAttentionBackend,
    )

    _FD_FULLY_BUILT = True
    _FD_IMPORT_ERR = None
except Exception as _e:  # noqa: BLE001
    _FD_FULLY_BUILT = False
    _FD_IMPORT_ERR = _e


def _require_full_build(test):
    """Skip ``test`` (a function or a TestCase class) unless the import above succeeded."""
    return unittest.skipUnless(
        _FD_FULLY_BUILT,
        f"fastdeploy custom-ops not fully built: {_FD_IMPORT_ERR!r}",
    )(test)


class _Cfg:
    """Bag-of-attrs config stand-in; every keyword becomes an attribute."""

    def __init__(self, **kw):
        for k, v in kw.items():
            setattr(self, k, v)


class _CallerWithFDConfig:
    """Stack-frame stand-in: the real selector walks ``frame.f_locals['self']``
    and reads ``self.fd_config.model_config.architectures``.
    """

    def __init__(self, archs):
        self.fd_config = _Cfg(model_config=_Cfg(architectures=archs))

    def call_get_attention_backend(self):
        return attention_selecter.get_attention_backend()

    def call_get_attn_backend(self, sb=None):
        return attention_selecter._get_attn_backend(sb)


@_require_full_build
class TestAttentionSelectorPatch(unittest.TestCase):
    """The patched selector must redirect PawQwen3 only."""

    def test_pawqwen3_caller_gets_elastic_backend(self):
        caller = _CallerWithFDConfig(archs=["PawQwen3ForCausalLM"])
        cls1 = caller.call_get_attention_backend()
        cls2 = caller.call_get_attn_backend()
        self.assertIs(cls1, Qwen3ElasticAttentionBackend)
        self.assertIs(cls2, Qwen3ElasticAttentionBackend)

    def test_other_models_untouched(self):
        """Dense Qwen3 / ERNIE / etc must NOT receive the elastic backend."""
        for arch in ("Qwen3ForCausalLM", "Qwen2ForCausalLM", "Ernie4_5_MoeForCausalLM", "GLM4ForCausalLM"):
            caller = _CallerWithFDConfig(archs=[arch])
            try:
                cls1 = caller.call_get_attention_backend()
            except Exception:
                # The original selector may itself fail on this dummy fd_config
                # (e.g. need a CUDA platform). What matters is it did NOT
                # short-circuit to Qwen3ElasticAttentionBackend, which would
                # always succeed -- so a raised error is also a pass.
                continue
            self.assertIsNot(
                cls1,
                Qwen3ElasticAttentionBackend,
                msg=f"{arch} must not be redirected to elastic backend",
            )


@_require_full_build
class TestRopeImplPatch(unittest.TestCase):
    """Checks the predicates behind ``_patched_get_rope_impl``.

    The assertions below pin ``_is_pawqwen3`` to accept ``PawQwen3ForCausalLM``
    and reject names that start with ``Qwen``, non-Qwen names and an empty
    architecture list.
    """

    def test_predicate_matches_pawqwen3(self):
        # ``from <pkg> import __init__`` would bind the module object's
        # ``__init__`` method-wrapper rather than the package, so the package
        # itself is imported under the alias instead.
        import fastdeploy.model_executor.models.qwen3_elastic as elastic_init

        is_paw = elastic_init._is_pawqwen3
        self.assertTrue(is_paw(_Cfg(architectures=["PawQwen3ForCausalLM"])))
        # Does not match dense Qwen3 (the architecture starts with "Qwen").
        self.assertFalse(is_paw(_Cfg(architectures=["Qwen3ForCausalLM"])))
        self.assertFalse(is_paw(_Cfg(architectures=["Qwen2ForCausalLM"])))
        # Non-Qwen architectures are unaffected.
        self.assertFalse(is_paw(_Cfg(architectures=["Ernie4_5ForCausalLM"])))
        self.assertFalse(is_paw(_Cfg(architectures=[])))

    def test_yarn_rope_scaling_extraction(self):
        # Same import form as above: the alias must be the package module.
        import fastdeploy.model_executor.models.qwen3_elastic as elastic_init

        get = elastic_init._yarn_rope_scaling
        # ``type=yarn`` -> returns the dict.
        rs = {"type": "yarn", "factor": 8.0, "original_max_position_embeddings": 40960}
        self.assertEqual(get(_Cfg(rope_scaling=rs)), rs)
        # ``rope_type=yarn`` (newer key) also works.
        rs2 = {"rope_type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768}
        self.assertEqual(get(_Cfg(rope_scaling=rs2)), rs2)
        # Non-yarn / missing -> None.
        self.assertIsNone(get(_Cfg(rope_scaling={"type": "linear"})))
        self.assertIsNone(get(_Cfg(rope_scaling=None)))
        self.assertIsNone(get(_Cfg()))


if __name__ == "__main__":
    unittest.main()
import importlib.util
import os
import unittest

import numpy as np
import paddle

# The helpers under test are loaded straight from their source file rather
# than through the ``fastdeploy.model_executor.models.qwen3_elastic`` package:
# the package route runs the parent ``models/__init__.py`` -> attention.ops
# chain, which needs compiled custom-op symbols that some builds lack. The
# helpers themselves are pure paddle, so the package init can be bypassed.
_HERE = os.path.dirname(os.path.abspath(__file__))
_UTILS_PATH = os.path.normpath(
    os.path.join(_HERE, "..", "..", "fastdeploy", "model_executor", "models", "qwen3_elastic", "utils.py")
)
_spec = importlib.util.spec_from_file_location("qwen3_elastic_utils_under_test", _UTILS_PATH)
_utils = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(_utils)
AttentionRouter = _utils.AttentionRouter
_LinearTransposed = _utils._LinearTransposed
ctx_q_pool = _utils.ctx_q_pool
derive_head_mask_type = _utils.derive_head_mask_type

paddle.seed(0)
np.random.seed(0)


class TestLinearTransposed(unittest.TestCase):
    """``_LinearTransposed`` keeps paddle's weight layout and carries the transpose flag."""

    def test_weight_need_transpose_flag(self):
        linear = _LinearTransposed(in_features=8, out_features=16, bias=True)
        # default_weight_loader consults this flag to decide whether an HF
        # [out, in] weight has to be transposed into paddle's [in, out].
        flag = getattr(linear.weight, "weight_need_transpose", False)
        self.assertTrue(flag)
        # The parameter itself stays in paddle layout: [in, out].
        self.assertEqual(list(linear.weight.shape), [8, 16])

    def test_forward_shape(self):
        linear = _LinearTransposed(8, 16, bias=True)
        result = linear(paddle.randn([4, 8]))
        self.assertEqual(list(result.shape), [4, 16])


class TestAttentionRouter(unittest.TestCase):
    """Shape, dtype and argmax semantics of the head router."""

    def setUp(self):
        paddle.seed(123)
        self.num_kv_heads = 8
        self.d_feature = 16  # kept tiny; only the semantics are checked
        self.router = AttentionRouter(num_kv_heads=self.num_kv_heads, d_feature=self.d_feature)

    def _pooled(self, batch):
        """Random pooled-K input of shape [batch, num_kv_heads, d_feature]."""
        return paddle.randn([batch, self.num_kv_heads, self.d_feature])

    def test_output_shape_dtype_and_range(self):
        decisions = self.router(self._pooled(1))
        # Contract: [B, H_kv], int32, every entry either 0 or 1.
        self.assertEqual(list(decisions.shape), [1, self.num_kv_heads])
        self.assertEqual(decisions.dtype, paddle.int32)
        self.assertTrue(np.isin(decisions.numpy(), [0, 1]).all())

    def test_batch_2(self):
        decisions = self.router(self._pooled(2))
        self.assertEqual(list(decisions.shape), [2, self.num_kv_heads])

    def test_argmax_matches_logits(self):
        # Re-run the router's two stages by hand and take the argmax over the
        # last axis of the logits; the router output must agree exactly.
        pooled = self._pooled(1)
        features = self.router.cls_feat_extractor(pooled)
        expected = self.router.cls_router_head_agnostic(features).argmax(axis=-1).astype("int32")
        actual = self.router(pooled)
        np.testing.assert_array_equal(actual.numpy(), expected.numpy())


class TestCtxQPool(unittest.TestCase):
    """ctx_q_pool == mean of first 100 + last 100 K tokens (with overlap when T<200)."""

    def _ref_pool(self, k):  # [T, H, D] -> [1, H, D]
        span = min(100, k.shape[0])
        stacked = paddle.concat([k[:span], k[-span:]], axis=0).astype("float32")
        return stacked.mean(axis=0, keepdim=True).astype(k.dtype)

    def _check_single_sequence(self, T, H, D):
        """Pool one random [T, H, D] sequence and compare with the reference."""
        k = paddle.randn([T, H, D])
        pooled = ctx_q_pool(k)
        expected = self._ref_pool(k)
        self.assertEqual(list(pooled.shape), [1, H, D])
        np.testing.assert_allclose(pooled.numpy(), expected.numpy(), rtol=1e-5, atol=1e-5)

    def test_short_sequence_overlap(self):
        # With T < 200 the head and tail windows overlap, matching the HF eval path.
        self._check_single_sequence(50, 4, 8)

    def test_long_sequence_no_overlap(self):
        self._check_single_sequence(1024, 4, 8)

    def test_varlen_two_segments(self):
        T1, T2 = 30, 300
        H, D = 4, 8
        k = paddle.randn([T1 + T2, H, D])
        cu = paddle.to_tensor([0, T1, T1 + T2], dtype="int32")
        pooled = ctx_q_pool(k, cu_seq_lens=cu)
        self.assertEqual(list(pooled.shape), [2, H, D])
        # Each packed segment is pooled on its own.
        references = (self._ref_pool(k[:T1]), self._ref_pool(k[T1:]))
        for row, expected in enumerate(references):
            np.testing.assert_allclose(
                pooled[row : row + 1].numpy(), expected.numpy(), rtol=1e-5, atol=1e-5
            )


class TestDeriveHeadMaskType(unittest.TestCase):
    """Enumerate the 5 documented (retrieval_mode, toggle_type) pairs (§5.4)."""

    def setUp(self):
        # Four KV heads with alternating 0/1 so each branch produces a mixed result.
        self.z = paddle.to_tensor([0, 1, 0, 1], dtype="int32")

    def _expect(self, retrieval_mode, toggle_type, expected, group_size=1):
        """Run ``derive_head_mask_type`` on ``self.z`` and compare with ``expected``."""
        actual = derive_head_mask_type(self.z, retrieval_mode, toggle_type, group_size=group_size)
        np.testing.assert_array_equal(actual.numpy(), np.array(expected, dtype=np.int32))

    def test_full_xattn(self):
        self._expect("full", "xattn", [1, 0, 1, 0])

    def test_full_streaming(self):
        self._expect("full", "streaming", [-1, 0, -1, 0])

    def test_xattn_streaming(self):
        self._expect("xattn", "streaming", [-1, 1, -1, 1])

    def test_xattn_xattn_all_one(self):
        self._expect("xattn", "xattn", [1, 1, 1, 1])

    def test_full_full_all_zero(self):
        self._expect("full", "full", [0, 0, 0, 0])

    def test_gqa_repeat_interleave(self):
        # group_size=2: every KV-head decision is duplicated for its query heads,
        # i.e. [1, 0, 1, 0] -> [1, 1, 0, 0, 1, 1, 0, 0].
        self._expect("full", "xattn", [1, 1, 0, 0, 1, 1, 0, 0], group_size=2)

    def test_unsupported_pair_raises(self):
        with self.assertRaises(NotImplementedError):
            derive_head_mask_type(self.z, "streaming", "full", group_size=1)


if __name__ == "__main__":
    unittest.main()
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the paddle wrapper around ``block_sparse_attn_ops``. + +The CUDA op itself ships with its own UTs in +``custom_ops/gpu_ops/block_sparse_attn``. This file targets the FastDeploy- +side paddle helpers that the wrapper applies BEFORE calling into the CUDA +binary: + +- ``_replace_ones_with_count`` -- assigns unique 1-based indices to each + sparse head so the kernel can read its + own blockmask row. +- ``_convert_blockmask_row_reverse`` -- bool [B,H,Qb,Kb] -> int32 sorted-desc + K-block index list (-1 padding) that + the kernel binary-searches over. + +A separate GPU smoke test asserts a tiny dense-equivalent call (no sparse +heads) returns finite, correctly-shaped output -- skipped if the standalone +``block_sparse_attn_ops`` extension is not yet built. +""" + +import importlib +import importlib.util +import os +import unittest + +import numpy as np +import paddle +import pytest + +# File-load to avoid the ``models/__init__.py`` -> attention.ops chain that +# pulls in compiled custom-op symbols which may be missing in some builds. 
_HERE = os.path.dirname(os.path.abspath(__file__))
_BSA_PATH = os.path.normpath(
    os.path.join(
        _HERE,
        "..",
        "..",
        "..",
        "fastdeploy",
        "model_executor",
        "models",
        "qwen3_elastic",
        "kernels",
        "block_sparse_attn.py",
    )
)
_spec = importlib.util.spec_from_file_location("qwen3_elastic_bsa_under_test", _BSA_PATH)
_bsa = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(_bsa)
_convert_blockmask_row_reverse = _bsa._convert_blockmask_row_reverse
_replace_ones_with_count = _bsa._replace_ones_with_count


class TestReplaceOnesWithCount(unittest.TestCase):
    """``_replace_ones_with_count`` numbers the 1-entries 1..n and returns n."""

    def _run(self, head_types):
        """Call the helper on an int32 tensor; return (output ndarray, count)."""
        renumbered, count = _replace_ones_with_count(paddle.to_tensor(head_types, dtype="int32"))
        return renumbered.numpy(), count

    def test_no_sparse_head(self):
        values = [0, -1, 0, -1]
        renumbered, count = self._run(values)
        # Nothing to renumber: the output equals the input.
        np.testing.assert_array_equal(renumbered, paddle.to_tensor(values, dtype="int32").numpy())
        self.assertEqual(count, 0)

    def test_mixed(self):
        # The 1s at positions 0, 2 and 5 become 1, 2 and 3 respectively.
        renumbered, count = self._run([1, 0, 1, -1, 0, 1])
        np.testing.assert_array_equal(
            renumbered,
            np.array([1, 0, 2, -1, 0, 3], dtype=np.int32),
        )
        self.assertEqual(count, 3)

    def test_all_sparse(self):
        renumbered, count = self._run([1, 1, 1, 1])
        np.testing.assert_array_equal(renumbered, np.array([1, 2, 3, 4], dtype=np.int32))
        self.assertEqual(count, 4)


class TestConvertBlockmaskRowReverse(unittest.TestCase):
    """``_convert_blockmask_row_reverse`` on single-row (B=1, H=1, Qb=1) masks."""

    @staticmethod
    def _row(flags):
        """Convert one row of bool flags and return the resulting index row."""
        mask = paddle.to_tensor([[[[*flags]]]], dtype="bool")
        return _convert_blockmask_row_reverse(mask).numpy()[0, 0, 0]

    def test_descending_indices_with_padding(self):
        # Five K-blocks, of which {0, 2, 4} are kept.
        row = self._row([True, False, True, False, True])
        # The kept indices, once sorted descending, are 4, 2, 0.
        valid = [idx for idx in row.tolist() if idx != -1]
        valid.sort(reverse=True)
        self.assertEqual(valid, [4, 2, 0])
        # The row keeps its full length.
        self.assertEqual(len(row), 5)
        # From the first -1 onward, everything is padding.
        pad_start = len(row)
        for position, value in enumerate(row):
            if value == -1:
                pad_start = position
                break
        self.assertTrue(all(value == -1 for value in row[pad_start:]))

    def test_all_kept(self):
        row = self._row([True, True, True, True])
        # Every index appears and no padding is present.
        self.assertEqual(sorted(row.tolist(), reverse=True), [3, 2, 1, 0])
        self.assertFalse((row == -1).any())

    def test_all_dropped(self):
        row = self._row([False, False, False])
        self.assertTrue((row == -1).all())


@pytest.mark.gpu
class TestBlockSparseAttnSmoke(unittest.TestCase):
    """Smoke test of ``block_sparse_attn_paddle`` with every head marked full.

    Only output shape and finiteness are asserted. The test skips itself when
    the standalone ``block_sparse_attn_ops`` CUDA extension cannot be imported.
    """

    def setUp(self):
        try:
            importlib.import_module("block_sparse_attn_ops")
        except Exception:
            self.skipTest("block_sparse_attn_ops not built; skip smoke test")

    def test_full_heads_match_dense(self):
        # The wrapper comes from the file-loaded module at the top of this file.
        run_bsa = _bsa.block_sparse_attn_paddle

        T, H, D = 256, 4, 64
        block = 128
        Qb = (T + block - 1) // block
        Kb = Qb

        paddle.seed(0)
        # Draw q, k, v in this order so the RNG stream is consumed identically.
        q, k, v = (paddle.randn([T, H, D]).astype("bfloat16") for _ in range(3))
        cu = paddle.to_tensor([0, T], dtype="int32")
        # Every head gets mask type 0 (full).
        hmt = paddle.zeros([H], dtype="int32")
        streaming_info = paddle.to_tensor([1, 16] * H, dtype="int32")
        # No sparse heads, but the wrapper still takes a blockmask argument.
        blockmask = paddle.ones([1, 0, Qb, Kb], dtype="bool")

        out = run_bsa(
            q,
            k,
            v,
            cu,
            cu,
            hmt,
            streaming_info,
            blockmask,
            max_seqlen_q=T,
            max_seqlen_k=T,
            is_causal=True,
            m_block_dim=block,
            n_block_dim=block,
        )
        self.assertEqual(list(out.shape), [T, H, D])
        self.assertTrue(paddle.isfinite(out).all().item())


if __name__ == "__main__":
    unittest.main()
+""" + +import importlib.util +import os +import unittest + +import paddle +import pytest + +# File-load to side-step ``models/__init__.py`` -> attention.ops chain that +# pulls in compiled custom-op symbols (e.g. ``config_for_attention``) which +# may be missing in older fastdeploy_ops builds. The kernels themselves are +# pure paddle / triton and have no fastdeploy package dependencies. +_HERE = os.path.dirname(os.path.abspath(__file__)) +_FB_PATH = os.path.normpath( + os.path.join( + _HERE, "..", "..", "..", "fastdeploy", "model_executor", "models", "qwen3_elastic", "kernels", "find_blocks.py" + ) +) +_spec = importlib.util.spec_from_file_location("qwen3_elastic_find_blocks_under_test", _FB_PATH) +_fb = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_fb) +find_blocks_chunked = _fb.find_blocks_chunked + + +@pytest.mark.gpu +class TestFindBlocksChunked(unittest.TestCase): + def setUp(self): + paddle.seed(7) + + def test_decode_path_returns_all_true(self): + # ``mode='prefill'`` + ``decoding=True`` is the early-return shortcut. + x = paddle.rand([1, 2, 4, 8]) + out = find_blocks_chunked( + x, current_index=0, threshold=0.5, num_to_choose=None, decoding=True, mode="prefill", causal=True + ) + self.assertEqual(list(out.shape), [1, 2, 4, 8]) + self.assertEqual(out.dtype, paddle.bool) + self.assertTrue(out.all().item()) + + def test_threshold_one_keeps_everything_under_causal(self): + # threshold close to 1.0 forces ``cumulative_sum < total_sum`` -> all + # blocks under the causal envelope are picked. Note: the kernel + # applies CHUNK-level causal envelope (zeros out columns + # >= current_index + chunk_num), NOT per-row diagonal causal -- rows + # within a chunk can attend to any column up to the chunk's right + # edge. See find_blocks.py L127-136. 
+ Qb, Kb = 2, 4 + # use uniform attn -> any threshold below 1 picks ~all + x = paddle.ones([1, 1, Qb, Kb]) + # Use current_index < Kb - Qb so the envelope strictly excludes the + # last column, otherwise the post-causal pad would be empty and the + # test would be vacuous. + current_index = 1 + out = find_blocks_chunked( + x, + current_index=current_index, + threshold=0.999, + num_to_choose=None, + decoding=False, + mode="both", + causal=True, + ) + np_out = out.numpy()[0, 0] + envelope = current_index + Qb # cols [0, envelope) are reachable + # Inside the envelope: with uniform attention + threshold ~1, every + # column should be selected. + for i in range(Qb): + self.assertTrue(np_out[i, :envelope].all(), f"row {i} cols<{envelope} should all be True") + # Strictly out-of-envelope cols must be False. + self.assertFalse(np_out[i, envelope:].any(), f"row {i} cols>={envelope} must be False") + + def test_sink_and_diagonal_always_kept(self): + # Even with attention concentrated in one block, sink (col 0) and the + # diagonal block must be retained. + Qb, Kb = 2, 4 + x = paddle.zeros([1, 1, Qb, Kb]) + # All mass on the last column (rightmost). After causal masking, the + # diagonal still gets kept by the algorithm regardless of mass. + x[:, :, :, -1] = 1.0 + out = find_blocks_chunked( + x, current_index=Kb - Qb, threshold=0.5, num_to_choose=None, decoding=False, mode="both", causal=True + ).numpy()[0, 0] + for i in range(Qb): + self.assertTrue(out[i, 0], f"row {i} sink (col 0) must be True") + # diagonal column for row i is current_index + i = (Kb - Qb) + i + self.assertTrue(out[i, (Kb - Qb) + i], f"row {i} diagonal must be True") + + +@pytest.mark.gpu +class TestXattnEstimate(unittest.TestCase): + """Smoke-test the Triton-backed ``xattn_estimate`` shape / dtype contract. + + Numerical correctness vs. 
dense softmax-block-sum is non-trivial to + re-derive here; we restrict ourselves to the contract documented in + ELASTIC_FASTDEPLOY_INTEGRATION.md §3.3 / §5.5: outputs are + [B, H, Qb, Kb] with bool simple_masks and float32 attn_sums. + """ + + def setUp(self): + paddle.seed(11) + # If GPU exists, use bf16 to mirror prod; else fp16 still works on Triton. + try: + paddle.set_default_dtype("bfloat16") + except Exception: + paddle.set_default_dtype("float16") + + def tearDown(self): + paddle.set_default_dtype("float32") + + def test_output_shape_and_dtype(self): + # File-load xattention.py directly so we don't go through + # ``fastdeploy.model_executor.models.qwen3_elastic.kernels`` (which + # is gated by the parent fastdeploy package init -> attention.ops + # chain). xattention.py uses RELATIVE imports of sibling kernels + # (find_blocks, block_sparse_attn, xattention_triton); to make those + # resolve we register the parent kernels dir as a package first. + try: + import sys as _sys + + _kdir = os.path.normpath( + os.path.join( + _HERE, + "..", + "..", + "..", + "fastdeploy", + "model_executor", + "models", + "qwen3_elastic", + "kernels", + ) + ) + _kpkg_name = "qwen3_elastic_kernels_under_test" + if _kpkg_name not in _sys.modules: + _kspec = importlib.util.spec_from_file_location( + _kpkg_name, + os.path.join(_kdir, "__init__.py"), + submodule_search_locations=[_kdir], + ) + _kpkg = importlib.util.module_from_spec(_kspec) + _sys.modules[_kpkg_name] = _kpkg + # Don't exec __init__ (it imports xattention which needs + # block_sparse_attn -> may fail if BSA op missing). We just + # need the package object so relative imports inside + # xattention.py can resolve. 
+ for _sub in ("find_blocks", "block_sparse_attn", "xattention_triton"): + _mod_name = f"{_kpkg_name}.{_sub}" + if _mod_name in _sys.modules: + continue + _ss = importlib.util.spec_from_file_location(_mod_name, os.path.join(_kdir, f"{_sub}.py")) + _sm = importlib.util.module_from_spec(_ss) + _sys.modules[_mod_name] = _sm + _ss.loader.exec_module(_sm) + _xs = importlib.util.spec_from_file_location( + f"{_kpkg_name}.xattention", + os.path.join(_kdir, "xattention.py"), + ) + _xm = importlib.util.module_from_spec(_xs) + _sys.modules[f"{_kpkg_name}.xattention"] = _xm + _xs.loader.exec_module(_xm) + xattn_estimate = _xm.xattn_estimate + except Exception as e: + self.skipTest(f"xattention deps not loadable: {e!r}") + + H = 4 + T = 2048 + D = 128 + block_size = 128 + stride = 16 + chunk_size = 2048 + + q = paddle.randn([1, H, T, D]) + k = paddle.randn([1, H, T, D]) + + try: + attn_sums, simple_masks = xattn_estimate( + q, + k, + block_size=block_size, + stride=stride, + norm=1.0, + threshold=0.9, + chunk_size=chunk_size, + use_triton=True, + causal=True, + ) + except RuntimeError as e: + # Triton requires an active CUDA driver matching the installed + # paddle build. In stripped-down test envs (e.g. CPU-only CI or + # paddle/triton CUDA mismatch) this fails before kernel launch; + # the gpu-marker contract is best-effort, so skip rather than + # red. + self.skipTest(f"triton driver unavailable: {e!r}") + + Qb = T // block_size + Kb = T // block_size + self.assertEqual(list(simple_masks.shape), [1, H, Qb, Kb]) + self.assertEqual(simple_masks.dtype, paddle.bool) + # attn_sums is float32 per softmax_fuse_block_sum kernel + self.assertEqual(list(attn_sums.shape)[:2], [1, H]) + # Causal: upper-triangular (strictly above the diagonal) must be all False. 
+ sm = simple_masks.numpy()[0, 0] + for i in range(Qb): + self.assertFalse(sm[i, i + 1 :].any(), f"causal violated at row {i}") + + +if __name__ == "__main__": + unittest.main() From 4e5a549b809db92274ba868e5aef21d8e79ea0a7 Mon Sep 17 00:00:00 2001 From: Zhenxu Tian Date: Fri, 5 Jun 2026 11:00:52 +0800 Subject: [PATCH 3/4] Add block_sparse_attn CUDA kernels --- custom_ops/gpu_ops/block_sparse_attn/src | 1 + 1 file changed, 1 insertion(+) create mode 120000 custom_ops/gpu_ops/block_sparse_attn/src diff --git a/custom_ops/gpu_ops/block_sparse_attn/src b/custom_ops/gpu_ops/block_sparse_attn/src new file mode 120000 index 00000000000..a46cad00c9f --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src @@ -0,0 +1 @@ +/root/paddlejob/share-storage/gpfs/system-public/tzx/SongGuo/Block-Sparse-Attention/csrc/block_sparse_attn/src \ No newline at end of file From 1000e82599d879fbc8fb13a3d0c3536ac1388858 Mon Sep 17 00:00:00 2001 From: Zhenxu Tian Date: Fri, 5 Jun 2026 11:38:21 +0800 Subject: [PATCH 4/4] fix src link Signed-off-by: Zhenxu Tian --- custom_ops/gpu_ops/block_sparse_attn/src | 1 - .../gpu_ops/block_sparse_attn/src/alibi.h | 63 + .../block_sparse_attn/src/block_info.h | 47 + .../gpu_ops/block_sparse_attn/src/flash.h | 193 ++ .../block_sparse_attn/src/flash_blockmask.h | 424 ++++ ...lash_bwd_block_hdim128_bf16_causal_sm80.cu | 15 + .../src/flash_bwd_block_hdim128_bf16_sm80.cu | 15 + ...lash_bwd_block_hdim128_fp16_causal_sm80.cu | 15 + .../src/flash_bwd_block_hdim128_fp16_sm80.cu | 15 + ...flash_bwd_block_hdim32_bf16_causal_sm80.cu | 15 + .../src/flash_bwd_block_hdim32_bf16_sm80.cu | 15 + ...flash_bwd_block_hdim32_fp16_causal_sm80.cu | 15 + .../src/flash_bwd_block_hdim32_fp16_sm80.cu | 15 + ...flash_bwd_block_hdim64_bf16_causal_sm80.cu | 15 + .../src/flash_bwd_block_hdim64_bf16_sm80.cu | 15 + ...flash_bwd_block_hdim64_fp16_causal_sm80.cu | 15 + .../src/flash_bwd_block_hdim64_fp16_sm80.cu | 15 + .../block_sparse_attn/src/flash_bwd_kernel.h | 1884 
+++++++++++++++++ .../src/flash_bwd_launch_template.h | 224 ++ ...lash_fwd_block_hdim128_bf16_causal_sm80.cu | 15 + .../src/flash_fwd_block_hdim128_bf16_sm80.cu | 15 + ...lash_fwd_block_hdim128_fp16_causal_sm80.cu | 15 + .../src/flash_fwd_block_hdim128_fp16_sm80.cu | 15 + ...flash_fwd_block_hdim32_bf16_causal_sm80.cu | 15 + .../src/flash_fwd_block_hdim32_bf16_sm80.cu | 15 + ...flash_fwd_block_hdim32_fp16_causal_sm80.cu | 15 + .../src/flash_fwd_block_hdim32_fp16_sm80.cu | 15 + ...flash_fwd_block_hdim64_bf16_causal_sm80.cu | 15 + .../src/flash_fwd_block_hdim64_bf16_sm80.cu | 15 + ...flash_fwd_block_hdim64_fp16_causal_sm80.cu | 15 + .../src/flash_fwd_block_hdim64_fp16_sm80.cu | 15 + .../block_sparse_attn/src/flash_fwd_kernel.h | 1297 ++++++++++++ .../src/flash_fwd_launch_template.h | 113 + .../block_sparse_attn/src/generate_kernels.py | 102 + .../block_sparse_attn/src/hardware_info.h | 41 + .../block_sparse_attn/src/kernel_traits.h | 397 ++++ .../src/kernel_traits_sm90.h | 159 ++ .../block_sparse_attn/src/namespace_config.h | 67 + .../gpu_ops/block_sparse_attn/src/philox.cuh | 167 ++ .../gpu_ops/block_sparse_attn/src/softmax.h | 323 +++ .../block_sparse_attn/src/static_switch.h | 53 + .../gpu_ops/block_sparse_attn/src/utils.h | 407 ++++ 42 files changed, 6321 insertions(+), 1 deletion(-) delete mode 120000 custom_ops/gpu_ops/block_sparse_attn/src create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/alibi.h create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/block_info.h create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash.h create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_blockmask.h create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim128_bf16_causal_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim128_bf16_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim128_fp16_causal_sm80.cu create mode 100644 
custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim128_fp16_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim32_bf16_causal_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim32_bf16_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim32_fp16_causal_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim32_fp16_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim64_bf16_causal_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim64_bf16_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim64_fp16_causal_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim64_fp16_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_kernel.h create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_launch_template.h create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_causal_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_causal_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim32_bf16_causal_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim32_bf16_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim32_fp16_causal_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim32_fp16_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim64_bf16_causal_sm80.cu create mode 100644 
custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim64_bf16_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim64_fp16_causal_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim64_fp16_sm80.cu create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_kernel.h create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_launch_template.h create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/generate_kernels.py create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/hardware_info.h create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/kernel_traits.h create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/kernel_traits_sm90.h create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/namespace_config.h create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/philox.cuh create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/softmax.h create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/static_switch.h create mode 100644 custom_ops/gpu_ops/block_sparse_attn/src/utils.h diff --git a/custom_ops/gpu_ops/block_sparse_attn/src b/custom_ops/gpu_ops/block_sparse_attn/src deleted file mode 120000 index a46cad00c9f..00000000000 --- a/custom_ops/gpu_ops/block_sparse_attn/src +++ /dev/null @@ -1 +0,0 @@ -/root/paddlejob/share-storage/gpfs/system-public/tzx/SongGuo/Block-Sparse-Attention/csrc/block_sparse_attn/src \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/alibi.h b/custom_ops/gpu_ops/block_sparse_attn/src/alibi.h new file mode 100644 index 00000000000..b0132fd3d88 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/alibi.h @@ -0,0 +1,63 @@ +#include + +#include + +#include +#include + +#include "utils.h" + +#include "namespace_config.h" +namespace FLASH_NAMESPACE { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+template +inline __device__ void apply_alibi(Tensor &tensor, + const int col_idx_offset_, + const int max_seqlen_k, + const int row_idx_offset, + const int max_seqlen_q, + const int warp_row_stride, + const float alibi_slope) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + } + } + } else { // Bias depends on both row_idx and col_idx + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + } + } + } + } + } +} + +} // namespace FLASH_NAMESPACE diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/block_info.h b/custom_ops/gpu_ops/block_sparse_attn/src/block_info.h new file mode 100644 index 00000000000..1e801e7393f --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/block_info.h @@ -0,0 +1,47 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "namespace_config.h" +namespace FLASH_NAMESPACE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfo { + + template + __device__ BlockInfo(const Params &params, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) + , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]) + , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) + , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) + { + } + + template + __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template + __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + } + + const int sum_s_q; + const int sum_s_k; + const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+ const int seqlen_k_cache; + const int actual_seqlen_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace FLASH_NAMESPACE diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash.h b/custom_ops/gpu_ops/block_sparse_attn/src/flash.h new file mode 100644 index 00000000000..4f019d5bb6d --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash.h @@ -0,0 +1,193 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +/****************************************************************************** + * Adapted by Junxian Guo from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash.h + ******************************************************************************/ + +#pragma once + +#include "namespace_config.h" + +#include +#include + +#include // For at::Generator and at::PhiloxCudaState + +#include // For at::cuda::philox::unpack + +namespace FLASH_NAMESPACE { +constexpr int TOTAL_DIM = 0; +constexpr int H_DIM = 1; +constexpr int D_DIM = 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = int64_t; + // The QKV matrices. + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). 
+ int h_h_k_ratio; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + + // The O matrix (output). + void * __restrict__ o_ptr; + void * __restrict__ oaccum_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the P matrix. + void * __restrict__ p_ptr; + + // The pointer to the softmax sum. + void * __restrict__ softmax_lse_ptr; + void * __restrict__ softmax_lseaccum_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + + // array of length b+1 holding starting offset of each sequence. + int * __restrict__ cu_seqlens_q; + int * __restrict__ cu_seqlens_k; + + // If provided, the actual length of each k sequence. + int * __restrict__ seqused_k; + + int *__restrict__ blockmask; + int *__restrict__ streaming_info; + int *__restrict__ head_mask_type; + // add by JXGuo + int m_block_dim, n_block_dim, num_blocksparse_heads; + + // The K_new and V_new matrices. + void * __restrict__ knew_ptr; + void * __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + // The cos and sin matrices for rotary embedding. + void * __restrict__ rotary_cos_ptr; + void * __restrict__ rotary_sin_ptr; + + // The indices to index into the KV cache. + int * __restrict__ cache_batch_idx; + + // The dropout probability (probability of keeping an activation). 
+ float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_softmax_rp_dropout; + + // Local window size + int window_size_left, window_size_right; + + // Random state. + at::PhiloxCudaState philox_args; + + // Pointer to the RNG seed (idx 0) and offset (idx 1). + uint64_t * rng_state; + + bool is_bf16; + bool is_causal; + bool is_exact_streaming; + + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + bool is_seqlens_k_cumulative; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version + + void * __restrict__ alibi_slopes_ptr; + index_t alibi_slopes_batch_stride; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_bwd_params : public Flash_fwd_params { + + // The dO and dQKV matrices. + void *__restrict__ do_ptr; + void *__restrict__ dq_ptr; + void *__restrict__ dk_ptr; + void *__restrict__ dv_ptr; + + // To accumulate dQ + void *__restrict__ dq_accum_ptr; + void *__restrict__ dk_accum_ptr; + void *__restrict__ dv_accum_ptr; + + // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q + // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ + // dv_accum_ptr; + + // The stride between rows of the dO, dQ, dK and dV matrices. + // TD [2022-04-16]: We're using 32-bit indexing to save registers. + // The code probably won't work for arrays larger than 2GB. 
+ index_t do_batch_stride; + index_t do_row_stride; + index_t do_head_stride; + index_t dq_batch_stride; + index_t dk_batch_stride; + index_t dv_batch_stride; + index_t dq_row_stride; + index_t dk_row_stride; + index_t dv_row_stride; + index_t dq_head_stride; + index_t dk_head_stride; + index_t dv_head_stride; + + // The pointer to the softmax d sum. + void *__restrict__ dsoftmax_sum; + + bool deterministic; + index_t dq_accum_split_stride; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template void run_mha_fwd_block_(Flash_fwd_params &params, cudaStream_t stream); +template void run_mha_bwd_block_(Flash_bwd_params &params, cudaStream_t stream); + +} // namespace FLASH_NAMESPACE diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_blockmask.h b/custom_ops/gpu_ops/block_sparse_attn/src/flash_blockmask.h new file mode 100644 index 00000000000..67d497b7c59 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_blockmask.h @@ -0,0 +1,424 @@ +/****************************************************************************** + * Copyright (c) 2024, Junxian Guo.
+ ******************************************************************************/ + +#pragma once + +#include "namespace_config.h" +namespace FLASH_NAMESPACE { + +class fwdIteratorBase{ +}; + + +// //////////////////////////////////////////////////////////////////////////////////////////////////// +class fwdStreaming: public fwdIteratorBase{ + public: + template + __device__ fwdStreaming(const Params &params, const BlockInfo &binfo, const int kBlockM, const int kBlockN, const int batch_idx, const int head_idx, const int loop_step_idx, int n_block_min, int n_block_max) {//row first + this -> row_factor = params.m_block_dim / kBlockM; + this -> col_factor = params.n_block_dim / kBlockN; + this -> sink_block_num = params.streaming_info[head_idx * 2] * col_factor; + this -> local_block_num = params.streaming_info[head_idx * 2 + 1] * col_factor; + this -> m_block_dim = params.m_block_dim; + this -> n_block_dim = params.n_block_dim; + this -> mask_type = params.head_mask_type[head_idx]; + this -> n_block_min = n_block_min; + this -> n_block_max = n_block_max; + int act_k = binfo.actual_seqlen_k; + int act_q = binfo.actual_seqlen_q; + bool causal = params.is_causal; + if (causal){ + int start_row_idx = max(int((act_q-act_k)/m_block_dim), 0); + this -> start_block_val = (cute::ceil_div(max(act_k - act_q, 0), n_block_dim) + 1 + loop_step_idx/row_factor - start_row_idx) * col_factor; + }else{ + this -> start_block_val = max(cute::ceil_div(n_block_max * kBlockN, n_block_dim) * col_factor, 0); + }; + this -> no_gap = start_block_val - n_block_min < sink_block_num + local_block_num; + this -> max_block_idx = min(sink_block_num + local_block_num, start_block_val - n_block_min); + + assert(mask_type < 0); + assert(params.m_block_dim % kBlockM == 0); + assert(params.n_block_dim % kBlockN == 0); + }; + + __device__ int mask_val(int block_col_idx) const { + if (block_col_idx > max_block_idx || block_col_idx < 0){ + return -1; + }; + int ret = 0; + if (no_gap){ + ret = start_block_val -
1 - block_col_idx; + return ret >= n_block_min ? ret : -1; + }else{ + if (block_col_idx < local_block_num){ + return start_block_val - 1 - block_col_idx; + }else{ + ret = sink_block_num - 1 - (block_col_idx - local_block_num); + return ret >= n_block_min ? ret : -1; + }; + }; + }; + + __device__ int max_no_larger(int target) const { + if(max_block_idx == 0){ + return -1; + }; + int left = 0; + int right = max_block_idx - 1; + while (left <= right) { + int mid = left + (right - left) / 2; + if (mask_val(mid) > target) { + left = mid + 1; + } else { + right = mid - 1; + }; + }; + return (left < max_block_idx && mask_val(left) <= target) ? left : -1; + }; + + int sink_block_num, local_block_num; + int start_block_val; + bool no_gap; + + int max_block_idx; + int m_block_dim, n_block_dim; + int mask_type; + int n_block_min, n_block_max; + int row_factor, col_factor; +}; + + +class fwdExactStreaming: public fwdIteratorBase{ + public: + template + __device__ fwdExactStreaming(const Params &params, const BlockInfo &binfo, const int kBlockM, const int kBlockN, const int batch_idx, const int head_idx, const int loop_step_idx, int n_block_min, int n_block_max) {//row first + this -> row_factor = params.m_block_dim / kBlockM; + this -> col_factor = params.n_block_dim / kBlockN; + int sink_num = params.streaming_info[head_idx * 2]; + int local_num = params.streaming_info[head_idx * 2 + 1]; + this -> m_block_dim = params.m_block_dim; + this -> n_block_dim = params.n_block_dim; + this -> sink_block_num = cute::ceil_div(sink_num, n_block_dim) * col_factor; + this -> local_block_num = (cute::ceil_div(local_num, n_block_dim)+2) * col_factor; + + + + this -> mask_type = params.head_mask_type[head_idx]; + this -> n_block_min = n_block_min; + this -> n_block_max = n_block_max; + int act_k = binfo.actual_seqlen_k; + int act_q = binfo.actual_seqlen_q; + bool causal = params.is_causal; + if (causal){ + int start_row_idx = max(int((act_q-act_k)/m_block_dim), 0); + this -> start_block_val =
(cute::ceil_div(max(act_k - act_q, 0), n_block_dim) + 1 + loop_step_idx/row_factor - start_row_idx) * col_factor; + }else{ + this -> start_block_val = max(cute::ceil_div(n_block_max * kBlockN, n_block_dim) * col_factor, 0); + }; + this -> no_gap = start_block_val - n_block_min < sink_block_num + local_block_num; + this -> max_block_idx = min(sink_block_num + local_block_num, start_block_val - n_block_min); + + assert(mask_type < 0); + assert(params.m_block_dim % kBlockM == 0); + assert(params.n_block_dim % kBlockN == 0); + }; + + __device__ int mask_val(int block_col_idx) const { + if (block_col_idx > max_block_idx || block_col_idx < 0){ + return -1; + }; + int ret = 0; + if (no_gap){ + ret = start_block_val - 1 - block_col_idx; + return ret >= n_block_min ? ret : -1; + }else{ + if (block_col_idx < local_block_num){ + return start_block_val - 1 - block_col_idx; + }else{ + ret = sink_block_num - 1 - (block_col_idx - local_block_num); + return ret >= n_block_min ? ret : -1; + }; + }; + }; + + __device__ int max_no_larger(int target) const { + if(max_block_idx == 0){ + return -1; + }; + int left = 0; + int right = max_block_idx - 1; + while (left <= right) { + int mid = left + (right - left) / 2; + if (mask_val(mid) > target) { + left = mid + 1; + } else { + right = mid - 1; + }; + }; + return (left < max_block_idx && mask_val(left) <= target) ? 
left : -1; + }; + + int sink_block_num, local_block_num; + int start_block_val; + bool no_gap; + + int max_block_idx; + int m_block_dim, n_block_dim; + int mask_type; + int n_block_min, n_block_max; + int row_factor, col_factor; +}; + +// //////////////////////////////////////////////////////////////////////////////////////////////////// + +class fwdBlockmask: public fwdIteratorBase{ + public: + template + __device__ fwdBlockmask(const Params &params, const BlockInfo &binfo, const int kBlockM, const int kBlockN, const int batch_idx, const int head_idx, const int loop_step_idx, int n_block_min, int n_block_max) {//row first + this -> row_factor = params.m_block_dim / kBlockM; + this -> col_factor = params.n_block_dim / kBlockN; + this -> max_block_idx = cute::ceil_div(binfo.actual_seqlen_k, params.n_block_dim) * col_factor; + this -> m_block_dim = params.m_block_dim; + this -> n_block_dim = params.n_block_dim; + this -> mask_type = params.head_mask_type[head_idx]; + this -> n_block_min = n_block_min; + this -> n_block_max = n_block_max; + + assert(mask_type > 0); + assert(params.m_block_dim % kBlockM == 0); + assert(params.n_block_dim % kBlockN == 0); + + blockmask_ptr = params.blockmask + (batch_idx * params.num_blocksparse_heads + mask_type - 1) * int(params.seqlen_q_rounded / m_block_dim) * int(params.seqlen_k_rounded / n_block_dim) + int(loop_step_idx / row_factor) * int(params.seqlen_k_rounded / n_block_dim); + }; + + __device__ int mask_val(int block_col_idx) const { + if (block_col_idx >= max_block_idx || block_col_idx < 0){ // valid slots are [0, max_block_idx); slot max_block_idx maps to key block ceil_div(actual_seqlen_k, n_block_dim), one past this sequence's last block, so blockmask_ptr must not be read there + return -1; + }; + int real_block_idx = block_col_idx / col_factor; + int block_col_offset = block_col_idx % col_factor; + int mask_val = blockmask_ptr[real_block_idx]; + return mask_val == -1 ?
-1 : col_factor * mask_val + col_factor - 1 - block_col_offset; + }; + + __device__ int max_no_larger(int target) const { + if(max_block_idx == 0){ + return -1; + }; + int left = 0; + int right = max_block_idx - 1; + while (left <= right) { + int mid = left + (right - left) / 2; + if (mask_val(mid) > target) { + left = mid + 1; + } else { + right = mid - 1; + }; + }; + return (left < max_block_idx && mask_val(left) <= target) ? left : -1; + }; + + int *blockmask_ptr; + int max_block_idx; + int m_block_dim, n_block_dim; + int mask_type; + int n_block_min, n_block_max; + int row_factor, col_factor; +}; + +// //////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class fwdIterator{}; + +template<> +struct fwdIterator: public fwdBlockmask{ + template + __device__ fwdIterator(const Params &params, const BlockInfo &binfo, const int kBlockM, const int kBlockN, const int batch_idx, const int head_idx, const int loop_step_idx, int n_block_min, int n_block_max): fwdBlockmask(params, binfo, kBlockM, kBlockN, batch_idx, head_idx, loop_step_idx, n_block_min, n_block_max) {}; +}; + +template<> +struct fwdIterator: public fwdStreaming{ + template + __device__ fwdIterator(const Params &params, const BlockInfo &binfo, const int kBlockM, const int kBlockN, const int batch_idx, const int head_idx, const int loop_step_idx, int n_block_min, int n_block_max): fwdStreaming(params, binfo, kBlockM, kBlockN, batch_idx, head_idx, loop_step_idx, n_block_min, n_block_max) {}; +}; + +template<> +struct fwdIterator: public fwdExactStreaming{ + template + __device__ fwdIterator(const Params &params, const BlockInfo &binfo, const int kBlockM, const int kBlockN, const int batch_idx, const int head_idx, const int loop_step_idx, int n_block_min, int n_block_max): fwdExactStreaming(params, binfo, kBlockM, kBlockN, batch_idx, head_idx, loop_step_idx, n_block_min, n_block_max) {}; +}; +
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +class bwdIteratorBase{ +}; + + +struct bwdStreaming: public bwdIteratorBase{ + public: + template + __device__ bwdStreaming(const Params ¶ms, const BlockInfo &binfo, const int kBlockM, const int kBlockN, const int batch_idx, const int head_idx, const int loop_step_idx, int m_block_min, int m_block_max) {// col first + this -> row_factor = params.m_block_dim / kBlockM; + this -> col_factor = params.n_block_dim / kBlockN; + + this -> m_block_dim = params.m_block_dim; + this -> n_block_dim = params.n_block_dim; + this -> mask_type = params.head_mask_type[head_idx]; + this -> m_block_min = m_block_min; + this -> m_block_max = m_block_max; + + int mask_block_col = cute::ceil_div(loop_step_idx+1, col_factor); + int sink = (this -> mask_type) < 0 ? params.streaming_info[head_idx * 2]: cute::ceil_div(binfo.actual_seqlen_k, this -> n_block_dim); + int local = (this -> mask_type) < 0 ? params.streaming_info[head_idx * 2 + 1]: 0; + this -> sink_block_num = sink * col_factor; + this -> local_block_num = local * col_factor; + int act_q = binfo.actual_seqlen_q; + int act_k = binfo.actual_seqlen_k; + bool causal = params.is_causal; + + if(mask_block_col <= sink){ + this -> start_block_val = m_block_max; + this -> max_block_idx = m_block_max - m_block_min; + }else{ + if (causal){ + int free_token_num = act_q - min(act_q, act_k - loop_step_idx * kBlockN); + int end_mask_block_row_idx = free_token_num / params.m_block_dim;//zero based + int num_mask_block_in_end_row = max(0, cute::ceil_div(act_k - act_q + (end_mask_block_row_idx + 1) * params.m_block_dim, params.n_block_dim)); + int local_col_mask_block_num = max(0, local - (num_mask_block_in_end_row - mask_block_col)); + if(local_col_mask_block_num > 0){ + this -> start_block_val = min((end_mask_block_row_idx + local_col_mask_block_num) * row_factor, m_block_max); + this -> max_block_idx = min(local_col_mask_block_num * 
row_factor, m_block_max - m_block_min); + }else{ + this -> start_block_val = 0; + this -> max_block_idx = 0; + }; + }else{ + int n_mask_block_col = max(cute::ceil_div(act_k, n_block_dim), 0); + bool in_none_causal_local = !causal && mask_block_col <= n_mask_block_col && mask_block_col > n_mask_block_col - local; + if(in_none_causal_local){ + this -> start_block_val = m_block_max; + this -> max_block_idx = m_block_max - m_block_min; + }else{ + this -> start_block_val = 0; + this -> max_block_idx = 0; + }; + }; + } + + assert(mask_type <= 0); //for blocksparse, mask_type > 0; for streaming, mask_type < 0; for dense, mask_type = 0 + assert(params.m_block_dim % kBlockM == 0); + assert(params.n_block_dim % kBlockN == 0); + }; + + __device__ int mask_val(int block_row_idx) const { + if (block_row_idx > max_block_idx || block_row_idx < 0){ + return -1; + }; + int ret = start_block_val - 1 - block_row_idx; + return ret >= m_block_min ? ret : -1; + }; + + __device__ int max_no_larger(int target) const { + if(max_block_idx == 0){ + return -1; + }; + int left = 0; + int right = max_block_idx - 1; + while (left <= right) { + int mid = left + (right - left) / 2; + if (mask_val(mid) > target) { + left = mid + 1; + } else { + right = mid - 1; + }; + }; + return (left < max_block_idx && mask_val(left) <= target) ? 
left : -1; + }; + + int sink_block_num, local_block_num; + int start_block_val; + + int max_block_idx; + int m_block_dim, n_block_dim; + int mask_type; + int m_block_min, m_block_max; + int row_factor, col_factor; +}; + +struct bwdBlockmask: public bwdIteratorBase{ + public: + template + __device__ bwdBlockmask(const Params ¶ms, const BlockInfo &binfo, const int kBlockM, const int kBlockN, const int batch_idx, const int head_idx, const int loop_step_idx, int m_block_min, int m_block_max) { + this -> row_factor = params.m_block_dim / kBlockM; + this -> col_factor = params.n_block_dim / kBlockN; + this -> max_block_idx = cute::ceil_div(binfo.actual_seqlen_q, params.m_block_dim) * row_factor; + this -> m_block_dim = params.m_block_dim; + this -> n_block_dim = params.n_block_dim; + this -> mask_type = params.head_mask_type[head_idx]; + this -> m_block_min = m_block_min; + this -> m_block_max = m_block_max; + assert(mask_type > 0); + assert(params.m_block_dim % kBlockM == 0); + assert(params.n_block_dim % kBlockN == 0); + + blockmask_ptr = params.blockmask + (batch_idx * params.num_blocksparse_heads + mask_type - 1) * int(params.seqlen_k_rounded / n_block_dim) * int(params.seqlen_q_rounded / m_block_dim) + int(loop_step_idx / col_factor) * int(params.seqlen_q_rounded / m_block_dim); + }; + + __device__ int mask_val(int block_row_idx) const { + if (block_row_idx > max_block_idx || block_row_idx < 0){ + return -1; + }; + int real_block_idx = block_row_idx / row_factor; + int block_row_offset = block_row_idx % row_factor; + int mask_val = blockmask_ptr[real_block_idx]; + return mask_val == -1 ? 
-1 : row_factor * mask_val + row_factor - 1 - block_row_offset; + }; + + __device__ int max_no_larger(int target) const { + if(max_block_idx == 0){ + return -1; + }; + int left = 0; + int right = max_block_idx - 1; + while (left <= right) { + int mid = left + (right - left) / 2; + if (mask_val(mid) > target) { + left = mid + 1; + } else { + right = mid - 1; + }; + }; + return (left < max_block_idx && mask_val(left) <= target) ? left : -1; + }; + + int *blockmask_ptr; + int max_block_idx; + int m_block_dim, n_block_dim; + int mask_type; + int m_block_min, m_block_max; + int row_factor, col_factor; +}; + + + +template +class bwdIterator{}; + +template<> +struct bwdIterator: public bwdBlockmask{ + template + __device__ bwdIterator(const Params ¶ms, const BlockInfo &binfo, const int kBlockM, const int kBlockN, const int batch_idx, const int head_idx, const int loop_step_idx, int m_block_min, int m_block_max): bwdBlockmask(params, binfo, kBlockM, kBlockN, batch_idx, head_idx, loop_step_idx, m_block_min, m_block_max) {}; +}; + +template<> +struct bwdIterator: public bwdStreaming{ + template + __device__ bwdIterator(const Params ¶ms, const BlockInfo &binfo, const int kBlockM, const int kBlockN, const int batch_idx, const int head_idx, const int loop_step_idx, int m_block_min, int m_block_max): bwdStreaming(params, binfo, kBlockM, kBlockN, batch_idx, head_idx, loop_step_idx, m_block_min, m_block_max) {}; +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim128_bf16_causal_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim128_bf16_causal_sm80.cu new file mode 100644 index 00000000000..307a9199f35 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim128_bf16_causal_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. 
+// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_block_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_block_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim128_bf16_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim128_bf16_sm80.cu new file mode 100644 index 00000000000..bf42cedb92f --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim128_bf16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_block_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_block_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim128_fp16_causal_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim128_fp16_causal_sm80.cu new file mode 100644 index 00000000000..28bd14a17bc --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim128_fp16_causal_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_block_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_block_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim128_fp16_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim128_fp16_sm80.cu new file mode 100644 index 00000000000..21139acc95a --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim128_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_block_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_block_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim32_bf16_causal_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim32_bf16_causal_sm80.cu new file mode 100644 index 00000000000..d3f1602f9ef --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim32_bf16_causal_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_block_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_block_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim32_bf16_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim32_bf16_sm80.cu new file mode 100644 index 00000000000..6cb3c3d8a41 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim32_bf16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_block_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_block_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim32_fp16_causal_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim32_fp16_causal_sm80.cu new file mode 100644 index 00000000000..28fad876e74 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim32_fp16_causal_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_block_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_block_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim32_fp16_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim32_fp16_sm80.cu new file mode 100644 index 00000000000..64bfaa84a45 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim32_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_block_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_block_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim64_bf16_causal_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim64_bf16_causal_sm80.cu new file mode 100644 index 00000000000..7a9d3b3933d --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim64_bf16_causal_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_block_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_block_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim64_bf16_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim64_bf16_sm80.cu new file mode 100644 index 00000000000..f52e73c52e6 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim64_bf16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_block_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_block_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim64_fp16_causal_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim64_fp16_causal_sm80.cu new file mode 100644 index 00000000000..602faccd5b5 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim64_fp16_causal_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_block_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_block_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim64_fp16_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim64_fp16_sm80.cu new file mode 100644 index 00000000000..b7d64edfce2 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_block_hdim64_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_bwd_block_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_block_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_kernel.h b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_kernel.h new file mode 100644 index 00000000000..9f0ce146726 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_kernel.h @@ -0,0 +1,1884 @@ +/*************************************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ +/****************************************************************************** + * Adapted by Junxian Guo from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_bwd_kernel.h + ******************************************************************************/ + +#pragma once + +#include "namespace_config.h" +#include + +#include +#include +#include + +#include "block_info.h" +#include "kernel_traits.h" +#include "utils.h" +#include "softmax.h" + +#include "alibi.h" + +#include "flash_blockmask.h" + +namespace FLASH_NAMESPACE { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_B_warpcontiguousN(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) { + using TileShape_MNK = typename TiledMMA::TiledShape_MNK; + using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; + constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value; + // Divide by 2 because right now we always use 2 for the ValLayout + constexpr int kNWarpsN = decltype(size<1>(TileShape_MNK{}))::value / AtomShape_N / 2; + constexpr int MMAStride_N = MMA_N * AtomShape_N * 2; + // This gives the correct layout, idk why. 
+ // auto t = make_tile(Layout, _2>, + // Stride, _8> >{}, + // auto t = make_tile(Layout, + // Stride<_1, _64, _8> >{}, + auto t = make_tile(Layout, Int, _2>, // (8, 2, 2) or (8, 4, 2) + Stride<_1, Int, _8> >{}, // (1, 64, 8) or (1, 32, 8) + make_layout(size<2>(TileShape_MNK{}))); + // if (cute::thread0()) {printf("make_tiled_copy_B_warpcontiguousN "); print(t); printf("\n"); } + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), t); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) { + using TileShape_MNK = typename TiledMMA::TiledShape_MNK; + using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; + constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value; + // Divide by 2 because right now we always use 2 for the ValLayout + constexpr int kNWarpsN = decltype(size<1>(TileShape_MNK{}))::value / AtomShape_N / 2; + constexpr int MMAStride_N = MMA_N * AtomShape_N * 2; + auto t = make_tile(make_layout(size<0>(TileShape_MNK{})), + Layout, Int, _2>, // (8, 2, 2) or (8, 4, 2) + Stride<_1, Int, _8> >{}); // (1, 64, 8) or (1, 32, 8) + // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousN "); print(t); printf("\n"); } + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void dot_do_o(Tensor const &do_, Tensor const &o, + Tensor &dP_sum, const int gdP_col_stride, const float scale) { + static_assert(Layout0::rank == 3, "Only support 3D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(do_.layout() == o.layout()); + // Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64) + // The last coordinate is the "page". 
+ Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()), + make_layout(get<0>(do_.layout()), + get<2>(do_.layout())))); + Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout()); + Tensor do_fp32 = FLASH_NAMESPACE::convert_type(do_reshaped); + Tensor o_fp32 = FLASH_NAMESPACE::convert_type(o_reshaped); + #pragma unroll + for (int mi = 0; mi < size<0>(do_reshaped); ++mi) { + float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0); + #pragma unroll + for (int ni = 1; ni < size<1>(do_reshaped); ni++) { + dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni); + } + FLASH_NAMESPACE::SumOp sum_op; + dP_sum_cur = FLASH_NAMESPACE::Allreduce::run(dP_sum_cur, sum_op) * scale; + if (threadIdx.x % THREADS_PER_ROW == 0) { + dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel. +// This is used in the case where we want to parallelize the backward across seqlen_k. +template +inline __device__ void compute_dot_do_o(const Params ¶ms) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. 
+ const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) + + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride; + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) + + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; + const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM; + + Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), + Shape, Int>{}, + make_stride(params.do_row_stride, _1{})); + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), + Shape, Int>{}, + make_stride(params.h * params.d_rounded, _1{})); + Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO; + auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx); + // TODO: careful, we're zeroing out dQaccum with type float4, but when + // we do atomicAdds, we use type float. The layouts are different. Check this. 
+ typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum; + auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx); + + Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO); + Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO); + Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); + + Tensor cdO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO); + + // Allocate predicate tensors for k + Tensor tdOpdO = make_tensor(make_shape(size<2>(tdOgdO))); + // Set predicates for k bounds + #pragma unroll + for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;} + + Tensor tdOrdO = make_fragment_like(tdOgdO); + Tensor tdOrO = make_fragment_like(tdOgO); + FLASH_NAMESPACE::copy( + gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM + ); + FLASH_NAMESPACE::copy( + gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM + ); + // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final + // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here, + // so that (dP - dP_sum) is on the same scale. + dot_do_o(tdOrdO, tdOrO, dP_sum, + Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); + if (Clear_dQaccum) { + // We're actually not zero'ing out all of dQaccum, but only the part that we're going to + // do atomicAdds on. 
+ Tensor zero = make_fragment_like(tdQgdQaccum); + clear(zero); + cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void clear_dKVaccum(const Params ¶ms) { + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + const int n_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + const BlockInfo binfo(params, bidb); + if (n_block * kBlockN >= binfo.actual_seqlen_k) return; + + const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded; + + Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, Stride, _1>{}); + Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, Stride, _1>{}); + + typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum; + auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx); + Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum); + Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum); + Tensor zero = make_fragment_like(tdKgdKaccum); + clear(zero); + cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum); + cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert dQ from dQaccum (in float) to fp16/bf16. +// This is used in the case where we want to parallelize the backward across seqlen_k. 
+template +inline __device__ void convert_dQ(const Params ¶ms, const int nsplits) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; + const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) + + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 
0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; + + Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), + Shape, Int>{}, + make_stride(params.dq_row_stride, _1{})); + Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), + Shape, Int>{}, + make_stride(params.h * params.d_rounded, _1{})); + + Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutdQ{}); + + typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ; + auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum; + auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx); + + typename Kernel_traits::TiledMmadQ tiled_mma_dq; + auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq); + auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx); + Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); + Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum); + + Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_N, MMA_K + CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum)); + + Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum); + clear(acc_dq); + for (int s = 0; s < nsplits; ++s) { + cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum); + #pragma unroll + for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); } + tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride; + } + #pragma unroll + for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; } + // Convert acc_dq from fp32 to fp16 + Tensor rdQ = 
FLASH_NAMESPACE::convert_type(acc_dq);
+    Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)
+    cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
+    __syncthreads();
+    Tensor tdQrdQ = make_tensor(shape(tdQgdQ));
+    cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
+
+    Tensor cdQ = make_identity_tensor(Shape, Int>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
+    Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
+    Tensor tdQpdQ = make_tensor(make_shape(size<2>(tdQgdQ)));
+    #pragma unroll
+    for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
+    // Clear_OOB_K must be false since we don't want to write zeros to gmem
+    FLASH_NAMESPACE::copy(
+        gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
+    );
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16.
+// This is used in the case where we want to parallelize the backward across seqlen_q.
+//
+// Each thread block handles one (n_block, batch, head) tile, taken from blockIdx.x, blockIdx.y
+// and blockIdx.z respectively, and returns early when the tile starts at or beyond
+// binfo.actual_seqlen_k. The accumulator tiles are read into registers, dK is scaled by
+// params.scale_softmax_rp_dropout and dV by params.rp_dropout, both are converted with
+// FLASH_NAMESPACE::convert_type, staged through the dynamic shared memory buffer smem_
+// (sdK first, sdV directly after it), and finally written to params.dk_ptr / params.dv_ptr
+// with column predication against params.d and a row bound of
+// actual_seqlen_k - n_block * kBlockN.
+//
+// NOTE(review): the template parameter list after `template` and several explicit template
+// arguments in the body (e.g. on reinterpret_cast, Shape, Stride, make_tensor, convert_type
+// and copy) appear to have been stripped from this text. They are left exactly as found
+// because their contents cannot be confirmed from here -- restore them from the original file.
+template
+inline __device__ void convert_dKV(const Params &params) {
+    using Element = typename Kernel_traits::Element;
+    using ElementAccum = typename Kernel_traits::ElementAccum;
+    using index_t = typename Kernel_traits::index_t;
+
+    // Shared memory.
+    extern __shared__ char smem_[];
+
+    // The block index along the key/value sequence.
+    const int n_block = blockIdx.x;
+    // The block index for the batch.
+    const int bidb = blockIdx.y;
+    // The block index for the head.
+    const int bidh = blockIdx.z;
+    // The thread index.
+    const int tidx = threadIdx.x;
+
+    constexpr int kBlockN = Kernel_traits::kBlockN;
+    constexpr int kHeadDim = Kernel_traits::kHeadDim;
+
+    const BlockInfo binfo(params, bidb);
+    // The condition depends only on block-level values, so the whole block leaves together
+    // and no thread is left waiting at the __syncthreads() further down.
+    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
+
+    const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+        + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
+    const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+        + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
+    // NOTE(review): this product is evaluated in the operand types before it is stored in
+    // index_t. If params.h_k / seqlen_k_rounded / d_rounded are 32-bit ints it could overflow
+    // for large batch * heads * seqlen * head_dim -- confirm against the Params declaration.
+    const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded
+                                          + n_block * kBlockN) * params.d_rounded;
+
+    // Global-memory views of the output tiles (strided by the dk/dv row stride) ...
+    Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk),
+                             Shape, Int>{},
+                             make_stride(params.dk_row_stride, _1{}));
+    Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv),
+                             Shape, Int>{},
+                             make_stride(params.dv_row_stride, _1{}));
+    // ... and of the accumulator tiles, which share a single offset for dK and dV.
+    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum),
+                                  Shape, Int>{},
+                                  Stride, _1>{});
+    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum),
+                                  Shape, Int>{},
+                                  Stride, _1>{});
+
+    // Shared-memory staging tiles: sdV is placed immediately after sdK in smem_.
+    Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast(smem_)),
+                             typename Kernel_traits::SmemLayoutdKV{});
+    Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
+
+    // NOTE(review): the dQ tiled-copy types are reused here for the dK/dV tiles -- confirm
+    // that this is intended.
+    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
+    auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
+    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
+
+    typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
+    auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
+    auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
+    Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
+    Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
+
+    Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK);    // ((Atom,AtomNum),ATOM_M,ATOM_N)
+    Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
+    Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV);    // ((Atom,AtomNum),ATOM_M,ATOM_N)
+    Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
+    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
+    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);
+
+    Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{});  // MMA, MMA_N, MMA_K
+    Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{});  // MMA, MMA_N, MMA_K
+    // The element-wise loops below index both tensors with the same i, so sizes must agree.
+    CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum));
+    CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum));
+
+    // Load this thread's share of the accumulators into register fragments.
+    Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
+    Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
+    cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
+    cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
+    // dK and dV use different scale factors.
+    #pragma unroll
+    for (int i = 0; i < size(acc_dk); ++i) {
+        acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;
+    }
+    #pragma unroll
+    for (int i = 0; i < size(acc_dv); ++i) {
+        acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout;
+    }
+    // Convert the scaled dK and dV accumulators from float to the output precision
+    // (fp16/bf16, per the comment at the top of this function).
+    Tensor rdK = FLASH_NAMESPACE::convert_type(acc_dk);
+    Tensor rdV = FLASH_NAMESPACE::convert_type(acc_dv);
+    Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK);  // ((Atom,AtomNum), MMA_N, MMA_N)
+    Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);  // ((Atom,AtomNum), MMA_N, MMA_N)
+    cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
+    cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
+    // sdK/sdV are written with one partitioning and read back with another, so every
+    // thread's writes must be complete before any thread reads.
+    __syncthreads();
+    Tensor tdKrdK = make_tensor(shape(tdKgdK));
+    Tensor tdVrdV = make_tensor(shape(tdVgdV));
+    cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
+    cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
+
+    Tensor cdKV = make_identity_tensor(Shape, Int>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
+    Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
+    Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK)));
+    // Column predicate: only columns below params.d are enabled for the final copy.
+    #pragma unroll
+    for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
+    // Clear_OOB_K must be false since we don't want to write zeros to gmem
+    FLASH_NAMESPACE::copy(
+        gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
+    );
+    FLASH_NAMESPACE::copy(
+        gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
+    );
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) {
+
+    using Element = typename Kernel_traits::Element;
+    using ElementAccum = typename Kernel_traits::ElementAccum;
+    using index_t = typename Kernel_traits::index_t;
+
+    // Shared memory.
+    extern __shared__ char smem_[];
+
+    // The thread index.
+ const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + // constexpr int kNWarps = Kernel_traits::kNWarps; + constexpr int MMA_N_SdP = kBlockN / decltype(size<1>(typename Kernel_traits::TiledMmaSdP::TiledShape_MNK{}))::value; + constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP; + constexpr bool Double_buffer = !Kernel_traits::No_double_buffer; + + const BlockInfo binfo(params, bidb); + if (n_block * kBlockN >= binfo.actual_seqlen_k) return; + + int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM); + if (Is_local) { + m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM)); + } + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + + n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) + + (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride; + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + (m_block_max - 1) * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; + const index_t 
row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) + + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded + // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. + + (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride); + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + + (m_block_max - 1) * kBlockM; + const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + + (m_block_max - 1) * kBlockM; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), + Shape, Int>{}, + make_stride(params.do_row_stride, _1{})); + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), + Shape, Int>{}, + make_stride(params.dq_row_stride, _1{})); + Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), + Shape, Int>{}, + make_stride(params.h * params.d_rounded, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), + Shape>{}, Stride<_1>{}); + 
+ Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQdO{}); + Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); + Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); + // Double buffer for sQ + Tensor sdO = make_tensor(sQ.data() + (Double_buffer ? 2 : 1) * size(sQ), typename Kernel_traits::SmemLayoutQdO{}); + Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); + Tensor sdOtransposedNoSwizzle = make_tensor(sdO.data(), + typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); + Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposed{}); + Tensor sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{}); + Tensor sdS = make_tensor(!Kernel_traits::Is_V_in_regs ? 
sV.data() + size(sV) : sK.data() + size(sK), + typename Kernel_traits::SmemLayoutPdS{}); + Tensor sdSt = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposed{}); + Tensor sdStNoSwizzle = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}); + Tensor sP = make_tensor(sdS.data() + size(sdS), typename Kernel_traits::SmemLayoutPdS{}); + Tensor sPt = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposed{}); + Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}); + // sP and sdQ share the same memory so be careful + Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + using GmemTiledCopydO = std::conditional_t< + Is_first, + typename Kernel_traits::GmemTiledCopydO, + typename Kernel_traits::GmemTiledCopyQKV + >; + GmemTiledCopydO gmem_tiled_copy_dO; + auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ; + auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx); + using GmemLayoutAtomdQaccum = std::conditional_t< + !Seq_parallel, + typename Kernel_traits::GmemTiledCopydQaccum, + typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd + >; + GmemLayoutAtomdQaccum gmem_tiled_copy_dQaccum; + auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO); + Tensor tdOsdO = gmem_thr_copy_dO.partition_D(sdO); + Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, 
VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); + Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); + // if (cute::thread0()) { print(tdQgdQaccum.layout()); printf("\n"); } + // __syncthreads(); + // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx < 64) { + // printf("tidx = %d, tdQgdQaccum = 0x%p\n", tidx, tdQgdQaccum.data()); + // } + + typename Kernel_traits::TiledMmaSdP tiled_mma_sdp; + auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx); + Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ); // (MMA,MMA_N,MMA_K) + Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO); // (MMA,MMA_N,MMA_K) + Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV); // (MMA,MMA_N,MMA_K) + + typename Kernel_traits::TiledMmadKV tiled_mma_dkv; + auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx); + Tensor tdKrdSt = thr_mma_dkv.partition_fragment_A(sdStNoSwizzle); // (MMA, MMA_N, MMA_N) + Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle); // (MMA, MMA_K, MMA_N) + Tensor tdVrPt = thr_mma_dkv.partition_fragment_A(sPtNoSwizzle); // (MMA, MMA_N, MMA_N) + Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtransposedNoSwizzle); // (MMA, MMA_K, MMA_N) + + typename Kernel_traits::TiledMmadQ tiled_mma_dq; + auto thr_mma_dq = tiled_mma_dq.get_thread_slice(tidx); + Tensor tdQrdS = thr_mma_dq.partition_fragment_A(sdS); // (MMA, MMA_N, MMA_N) + Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle); // (MMA, MMA_K, MMA_N) + + Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_QdO = make_tiled_copy_A(typename 
Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp); + auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ); + Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO); + + // auto smem_thr_copy_KV = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx); + auto smem_tiled_copy_KV = make_tiled_copy_B_warpcontiguousN(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp); + auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_KV.partition_S(sK); + // if (cute::thread(0, 0) && n_block == 0) { printf("sK layout: "); print(sK.layout()); printf("\n"); } + // if (cute::thread(0, 0) && n_block == 0) { print(tSsK.layout()); printf("\n"); } + Tensor tdPsV = smem_thr_copy_KV.partition_S(sV); + + // Partition sP and sdS to match the accumulator partitioning + // This has to be tiled_mma_sdp, not tiled_mma_dkv + // auto smem_thr_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx); + auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp); + auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx); + Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N) + // if (cute::thread(0, 0) && n_block == 0) { printf("sP layout: "); print(sP.layout()); printf("\n"); } + // if (cute::thread(0, 0) && n_block == 0) { print(tPsP.layout()); printf("\n"); } + // if (n_block == 0 && blockIdx.x == 0 && blockIdx.y == 0 && tidx < 64) { + // printf("tidx=%d, tPsP = 0x%p\n", tidx, tPsP.data()); + // } + Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + auto smem_tiled_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv); + auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(tidx); + Tensor tdVsPt = 
smem_thr_copy_PdSt.partition_S(sPt); + Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt); + + auto smem_tiled_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv); + auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(tidx); + Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt); + Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt); + + auto smem_tiled_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq); + auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(tidx); + Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS); + + auto smem_tiled_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq); + auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(tidx); + Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt); + + auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq); + auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx); + Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // + // PREDICATES + // + + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tQcQ = gmem_thr_copy_QKV.partition_D(cQ); + Tensor tKVcKV = gmem_thr_copy_QKV.partition_D(cKV); + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + // We'll advance gdQ and gdQaccum before the 1st read/write. 
+ tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride; + tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded; + + int m_block = m_block_max - 1; + int m_block_min = (!Is_causal && !Is_local) + ? 0 + : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM); + // If not local, we're guaranteed that m_block_min <= m_block: + // We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case, + // n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q. + // So m_block_min <= (actual_seqlen_q - 1) / kBlockM. + // Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM. + // So m_block_m - 1 = (actual_seqlen_q - 1) / kBlockM. + // We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop. + // However, if local, then this possible to have some blocks of K & V not attending to any query. + // We might need to exit early and write 0 to dK and dV for those blocks. + // Otherwise we get wrong result for the case where we don't enter the for loop. + // And we might read OOB elements from gQ and gdO. 
+ // This also covers the case where actual_seqlen_q == 0 + if ((Is_local || !Is_even_MN) && m_block < m_block_min) { + const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) + + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; + const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) + + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; + Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), + Shape, Int>{}, + make_stride(params.dk_row_stride, _1{})); + Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), + Shape, Int>{}, + make_stride(params.dv_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV; + auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx); + Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); + Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); + Tensor tdKrdK = make_tensor(shape(tdKgdK)); + Tensor tdVrdV = make_tensor(shape(tdVgdV)); + clear(tdKrdK); + clear(tdVrdV); + Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); + Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); + #pragma unroll + for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + FLASH_NAMESPACE::copy( + gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + FLASH_NAMESPACE::copy( + gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + return; + } + + if (Double_buffer && m_block % 2 == 1) { // Double buffer for sQ + tQsQ.data() = tQsQ.data() + size(sQ); + tSsQ.data() = tSsQ.data() + size(sQ); + 
tdKsQt.data() = tdKsQt.data() + size(sQ); + } + + if ((!Is_first && !Seq_parallel) || params.deterministic) { __syncthreads(); } + + if (Kernel_traits::Is_V_in_regs) { + // Clear the smem tiles to account for predicated off loads + FLASH_NAMESPACE::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + FLASH_NAMESPACE::cp_async_fence(); + } + + Tensor tdOrdO = make_fragment_like(tdOgdO); + Tensor tdOrO = make_fragment_like(tdOgO); + if (!Is_first) { + // Clear the smem tiles to account for predicated off loads + FLASH_NAMESPACE::copy( + gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM + ); + } else { + FLASH_NAMESPACE::copy( + gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM + ); + FLASH_NAMESPACE::copy( + gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM + ); + } + FLASH_NAMESPACE::copy( + gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM + ); + + Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor taccScS = thr_mma_sdp.partition_C(caccS); // (MMA,MMA_N,MMA_N) + static_assert(decltype(size<0>(taccScS))::value == 4); + // Convert to ((2, 2), MMA_N, MMA_N) then take only the row indices. + Tensor taccScS_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0); + Tensor lse = make_tensor(Shape>{}); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccScS_row(mi)); + lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY; + } + // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero, + // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply + // with V (which would be zero), we're fine. However, with ALiBi, we might modify these + // scores, and probs can become NaN. 
Instead if we set LSE = inf for OOB rows, probs are always 0. + + // Tensor tKrK = make_fragment_like(tKsK); + // // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK); + // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK); + // // if (cute::thread(1, 0)) { print(tKrK); } + + FLASH_NAMESPACE::copy( + gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + if (!Kernel_traits::Is_V_in_regs) { + FLASH_NAMESPACE::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + FLASH_NAMESPACE::cp_async_fence(); + + // if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); } + if (Is_first) { + cute::copy(tdOrdO, tdOsdO); + dot_do_o(tdOrdO, tdOrO, gdPsum, + Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); + } + + if (Kernel_traits::Is_V_in_regs) { + cute::cp_async_wait<1>(); + __syncthreads(); + Tensor tdPrV_copy_view = smem_thr_copy_KV.retile_D(tdPrV); + CUTE_STATIC_ASSERT_V(size<1>(tdPsV) == size<1>(tdPrV_copy_view)); // M + cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view); + } + + auto seed = params.rng_state[0]; + auto offset = params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32; + + clear(acc_dv); + clear(acc_dk); + + float alibi_slope = !Has_alibi ? 
0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + + for (; m_block >= m_block_min; --m_block) { + Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_N, MMA_N) + clear(acc_s); + cute::cp_async_wait<0>(); + __syncthreads(); + + Tensor dP_sum = make_fragment_like(lse); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); } + + // if (cute::thread0()) { print(sK); } + // Tensor tSrK_copy_view = smem_thr_copy_KV.retile_D(tSrK); + // #pragma unroll + // for (int k = 0; k < size<2>(tSrK_copy_view); ++k) { + // cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k)); + // } + // if (cute::thread0()) { print(tSrK); } + FLASH_NAMESPACE::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp, + smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV); + + // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); + // if (cute::thread(32, 0)) { print(scores); } + + if (Has_alibi) { + FLASH_NAMESPACE::apply_alibi( + scores, + n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, + m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, + AtomLayoutMS * 16, + alibi_slope + ); + } + + // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond + // actual_seqlen_k, because acc_s would be some finite value for those indices. + // In the end when we multiply with K to get dQ, the corresponding values of K would be 0, + // so the result would still be correct. + // However, it's possible that the values in acc_s are so large that they overflow + // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ. + // So we need to mask out the elements beyond actual_seqlen_k. 
+ if (!Is_causal && !Is_local) { + if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) { + FLASH_NAMESPACE::apply_mask(scores, binfo.actual_seqlen_k, + n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16); + } + } else if (Is_causal) { + // Putting this causal masking right after acc_s is *much* slower for some reason. + // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short + // (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking. + // But we still want to mask out elements beyond actual_seqlen_k. + if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) { + FLASH_NAMESPACE::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, + // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, + AtomLayoutMS * 16); + } + } else if (Is_local) { + if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right + || (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left + || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) { + FLASH_NAMESPACE::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, AtomLayoutMS * 16, + params.window_size_left, params.window_size_right); + } + + } + + // if (cute::thread(32, 0)) { print(scores); } + // Compute the exponential value. 
+ FLASH_NAMESPACE::scale_apply_exp2(scores, lse, params.scale_softmax_log2); + if (Is_dropout) { + int warp_id = tidx / 32; + int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS; + // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32 + static_assert(MMA_N_SdP % 2 == 0); + int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2); + Tensor scores_dropped = make_tensor(scores.data(), FLASH_NAMESPACE::convert_layout_rowcol_Aregs(scores.layout())); + FLASH_NAMESPACE::apply_dropout( + scores_dropped, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, AtomLayoutMS + ); + } + // Convert scores from fp32 to fp16/bf16 + Tensor rP = !Is_dropout + ? FLASH_NAMESPACE::convert_type(scores) + : FLASH_NAMESPACE::convert_type_relu(scores); + // Reshape rP from (nrow=(2, MMA_N), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_N, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_N, MMA_N) if using m16n8k8. + Tensor tPrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_rowcol_Aregs(rP.layout())); + Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); + // if (cute::thread0()) { print(tPaP); } + // __syncthreads(); + // if (cute::thread0()) { print(sP); } + + Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_N, MMA_N) + CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(acc_dp) == size<1>(acc_s)); // MMA + CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA + + clear(acc_dp); + + FLASH_NAMESPACE::gemm( + acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp, + smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV + ); + + // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) + Tensor dS = make_tensor(acc_dp.data(), scores.layout()); + auto pointwise_mult = [](float p, float 
dp, float d) { + return p * (!Is_dropout || p >= 0 ? dp - d : d); + }; + #pragma unroll + for (int mi = 0; mi < size<0>(dS); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(dS); ++ni) { + dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); + } + } + // if (cute::thread0()) { print(dS); } + + Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_N, MMA_K + tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded)); + if (Is_first || Seq_parallel) { + clear(acc_dq); + } else { + // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum + Tensor acc_dq_reshaped = make_tensor(acc_dq.data(), + make_layout(get<0>(acc_dq.layout()), + get<2>(acc_dq.layout()), + get<1>(acc_dq.layout()))); + cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, acc_dq_reshaped); + } + + if (Double_buffer && m_block > m_block_min) { + // Double buffer for sQ + const int sQ_offset = m_block % 2 == 0 ? size(sQ) : -size(sQ); + tQsQ.data() = tQsQ.data() + sQ_offset; + tSsQ.data() = tSsQ.data() + sQ_offset; + // Advance gQ + tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride)); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ); + FLASH_NAMESPACE::cp_async_fence(); + } + + Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout()); + // Convert dS from fp32 to fp16 + Tensor tdSrdS = FLASH_NAMESPACE::convert_type(dS_reshaped); + // if (cute::thread0()) { print(tPrP); } + Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); + __syncthreads(); + + // Layout p_l = tPrP.layout(); + // Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l))); + // FLASH_NAMESPACE::gemm_A_in_regs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt); + // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout()); + // FLASH_NAMESPACE::gemm_A_in_regs(acc_dk, tdKrdSt, 
tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt); + FLASH_NAMESPACE::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv, + smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); + // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); } + // if (cute::thread0()) { print(acc_dv); } + + __syncthreads(); // Need syncthreads since we're writing to the same sdO location + + if (m_block > m_block_min) { + // Advance gdO + tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride)); + if (Is_first) { + tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride)); + FLASH_NAMESPACE::copy(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ); + FLASH_NAMESPACE::copy(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ); + } else { + FLASH_NAMESPACE::copy(gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ); + FLASH_NAMESPACE::cp_async_fence(); + } + } + + FLASH_NAMESPACE::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq, + smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt); + // if (cute::thread0()) { print(acc_dq); } + + if (m_block > m_block_min) { + gLSE.data() = gLSE.data() + (-int(kBlockM)); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); } + gdPsum.data() = gdPsum.data() + (-int(kBlockM)); + } + + if (!Is_last) { + // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum + Tensor acc_dq_reshaped = make_tensor(acc_dq.data(), + make_layout(get<0>(acc_dq.layout()), + get<2>(acc_dq.layout()), + get<1>(acc_dq.layout()))); + if (!Seq_parallel) { + cute::copy(gmem_tiled_copy_dQaccum, acc_dq_reshaped, tdQgdQaccum); + } else { + // if (cute::thread0()) { print(acc_dq.layout()); printf("\n"); print(acc_dq_reshaped.layout()); printf("\n"); print(tdQgdQaccum.layout()); printf("\n"); } + CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum)); + #pragma unroll + for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), 
acc_dq(i)); } + } + } else { + #pragma unroll + for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; } + // Convert acc_dq from fp32 to fp16 + Tensor rdQ = FLASH_NAMESPACE::convert_type(acc_dq); + Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); + } + + FLASH_NAMESPACE::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv, + smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); + // if (cute::thread0()) { print(acc_dk); } + if (Double_buffer) { // Double buffer for sQ + tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ)); + } + if (!Double_buffer && m_block > m_block_min) { + __syncthreads(); + // Advance gQ + tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride)); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ); + FLASH_NAMESPACE::cp_async_fence(); + } + + if (Is_first && m_block > m_block_min) { + cute::copy(tdOrdO, tdOsdO); + dot_do_o(tdOrdO, tdOrO, gdPsum, + Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); + } + + if (Is_last) { + __syncthreads(); + Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); + cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ); + tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride)); + Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); + #pragma unroll + for (int m = 0; m < size<1>(tdQgdQ); ++m) { + if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) { + cute::copy(gmem_tiled_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _)); + } + } + } + + } + + // Epilogue + + if (Is_dropout) { + #pragma unroll + for (int i = 0; i < size(acc_dv); ++i) { acc_dv(i) *= params.rp_dropout; } + } + #pragma unroll + for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= 
params.scale_softmax_rp_dropout; } + + // Convert acc_dv from fp32 to fp16 + Tensor rdK = FLASH_NAMESPACE::convert_type(acc_dk); + Tensor rdV = FLASH_NAMESPACE::convert_type(acc_dv); + + Tensor sdK = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) + Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) + + // Partition sdV and sdK to match the accumulator partitioning + auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv); + auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx); + Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N) + Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N) + Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // We need syncthreads here since we're writing to the same location as sK and sV. + // Without syncthreads, some thread might modify the location of sK while another thread + // is reading it for dQ gemm, leading to a race condition. + // If Is_last, there's already a __syncthreads() at the end of the loop. 
+ if (!Is_last) { __syncthreads(); } + + cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK); + cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV); + + const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) + + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; + const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) + + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; + Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), + Shape, Int>{}, + make_stride(params.dk_row_stride, _1{})); + Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), + Shape, Int>{}, + make_stride(params.dv_row_stride, _1{})); + + typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV; + auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx); + Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); + Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); + + __syncthreads(); + Tensor tdKrdK = make_tensor(shape(tdKgdK)); + cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK); + Tensor tdVrdV = make_tensor(shape(tdVgdV)); + cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV); + Tensor cdKV = make_identity_tensor(make_shape(size<0>(sdK), size<1>(sdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); + Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); + #pragma unroll + for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + FLASH_NAMESPACE::copy( + gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); 
+ FLASH_NAMESPACE::copy( + gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + +} + + +// for blocksparse +template +inline __device__ void compute_block_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { + // if (bidb == 0 && threadIdx.x == 0) printf("[compute_block_dq_dk_dv_1colblock] \n"); + // printf("[early return]\n"); + // return; + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + // constexpr int kNWarps = Kernel_traits::kNWarps; + constexpr int MMA_N_SdP = kBlockN / decltype(size<1>(typename Kernel_traits::TiledMmaSdP::TiledShape_MNK{}))::value; + constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP; + constexpr bool Double_buffer = !Kernel_traits::No_double_buffer; + + const BlockInfo binfo(params, bidb); + if (n_block * kBlockN >= binfo.actual_seqlen_k) return; + + int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM); + // for causal blocksparse + + // int blockmask_rounded_length = cute::ceil_div(binfo.actual_seqlen_q, params.m_block_dim) * params.m_block_dim; + // int max_block_idx = cute::ceil_div(blockmask_rounded_length, kBlockM); + if (Is_local) { + m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM)); + } + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, 
params.k_row_stride, bidb) + + n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) + + (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride; + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + (m_block_max - 1) * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; + const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) + + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded + // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. + + (!params.deterministic ? 
0 : blockIdx.x * params.dq_accum_split_stride); + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + + (m_block_max - 1) * kBlockM; + const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + + (m_block_max - 1) * kBlockM; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), + Shape, Int>{}, + make_stride(params.do_row_stride, _1{})); + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), + Shape, Int>{}, + make_stride(params.dq_row_stride, _1{})); + Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), + Shape, Int>{}, + make_stride(params.h * params.d_rounded, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), + Shape>{}, Stride<_1>{}); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQdO{}); + Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); + Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); + // Double buffer for sQ + Tensor sdO = make_tensor(sQ.data() + (Double_buffer 
? 2 : 1) * size(sQ), typename Kernel_traits::SmemLayoutQdO{}); + Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); + Tensor sdOtransposedNoSwizzle = make_tensor(sdO.data(), + typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); + Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposed{}); + Tensor sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{}); + Tensor sdS = make_tensor(!Kernel_traits::Is_V_in_regs ? sV.data() + size(sV) : sK.data() + size(sK), + typename Kernel_traits::SmemLayoutPdS{}); + Tensor sdSt = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposed{}); + Tensor sdStNoSwizzle = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}); + Tensor sP = make_tensor(sdS.data() + size(sdS), typename Kernel_traits::SmemLayoutPdS{}); + Tensor sPt = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposed{}); + Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}); + // sP and sdQ share the same memory so be careful + Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + using GmemTiledCopydO = std::conditional_t< + Is_first, + typename Kernel_traits::GmemTiledCopydO, + typename Kernel_traits::GmemTiledCopyQKV + >; + GmemTiledCopydO gmem_tiled_copy_dO; + auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ; + auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx); + using GmemLayoutAtomdQaccum = std::conditional_t< + 
!Seq_parallel, + typename Kernel_traits::GmemTiledCopydQaccum, + typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd + >; + GmemLayoutAtomdQaccum gmem_tiled_copy_dQaccum; + auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO); + Tensor tdOsdO = gmem_thr_copy_dO.partition_D(sdO); + Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); + Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); + // if (cute::thread0()) { print(tdQgdQaccum.layout()); printf("\n"); } + // __syncthreads(); + // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx < 64) { + // printf("tidx = %d, tdQgdQaccum = 0x%p\n", tidx, tdQgdQaccum.data()); + // } + + typename Kernel_traits::TiledMmaSdP tiled_mma_sdp; + auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx); + Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ); // (MMA,MMA_N,MMA_K) + Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO); // (MMA,MMA_N,MMA_K) + Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV); // (MMA,MMA_N,MMA_K) + + typename Kernel_traits::TiledMmadKV tiled_mma_dkv; + auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx); + Tensor tdKrdSt = thr_mma_dkv.partition_fragment_A(sdStNoSwizzle); // (MMA, MMA_N, MMA_N) + Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle); // (MMA, MMA_K, MMA_N) + Tensor tdVrPt = 
thr_mma_dkv.partition_fragment_A(sPtNoSwizzle); // (MMA, MMA_N, MMA_N) + Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtransposedNoSwizzle); // (MMA, MMA_K, MMA_N) + + typename Kernel_traits::TiledMmadQ tiled_mma_dq; + auto thr_mma_dq = tiled_mma_dq.get_thread_slice(tidx); + Tensor tdQrdS = thr_mma_dq.partition_fragment_A(sdS); // (MMA, MMA_N, MMA_N) + Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle); // (MMA, MMA_K, MMA_N) + + Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp); + auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ); + Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO); + + // auto smem_thr_copy_KV = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx); + auto smem_tiled_copy_KV = make_tiled_copy_B_warpcontiguousN(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp); + auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_KV.partition_S(sK); + // if (cute::thread(0, 0) && n_block == 0) { printf("sK layout: "); print(sK.layout()); printf("\n"); } + // if (cute::thread(0, 0) && n_block == 0) { print(tSsK.layout()); printf("\n"); } + Tensor tdPsV = smem_thr_copy_KV.partition_S(sV); + + // Partition sP and sdS to match the accumulator partitioning + // This has to be tiled_mma_sdp, not tiled_mma_dkv + // auto smem_thr_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx); + auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp); + auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx); + 
Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N) + // if (cute::thread(0, 0) && n_block == 0) { printf("sP layout: "); print(sP.layout()); printf("\n"); } + // if (cute::thread(0, 0) && n_block == 0) { print(tPsP.layout()); printf("\n"); } + // if (n_block == 0 && blockIdx.x == 0 && blockIdx.y == 0 && tidx < 64) { + // printf("tidx=%d, tPsP = 0x%p\n", tidx, tPsP.data()); + // } + Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + auto smem_tiled_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv); + auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(tidx); + Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt); + Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt); + + auto smem_tiled_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv); + auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(tidx); + Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt); + Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt); + + auto smem_tiled_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq); + auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(tidx); + Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS); + + auto smem_tiled_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq); + auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(tidx); + Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt); + + auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq); + auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx); + Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // + // PREDICATES + // + + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = 
make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tQcQ = gmem_thr_copy_QKV.partition_D(cQ); + Tensor tKVcKV = gmem_thr_copy_QKV.partition_D(cKV); + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + // We'll advance gdQ and gdQaccum before the 1st read/write. + // tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride; + // tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded; + + int m_block = m_block_max - 1; + + int m_block_min = (!Is_causal && !Is_local) + ? 0 + : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM); + + // If not local, we're guaranteed that m_block_min <= m_block: + // We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case, + // n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q. + // So m_block_min <= (actual_seqlen_q - 1) / kBlockM. + // Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM. + // So m_block_m - 1 = (actual_seqlen_q - 1) / kBlockM. + // We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop. + // However, if local, then this possible to have some blocks of K & V not attending to any query. + // We might need to exit early and write 0 to dK and dV for those blocks. + // Otherwise we get wrong result for the case where we don't enter the for loop. + // And we might read OOB elements from gQ and gdO. 
+ // This also covers the case where actual_seqlen_q == 0 + + + // add by JXGuo + bwdIterator blockmask(params, binfo, kBlockM, kBlockN, bidb, bidh, n_block, m_block_min, m_block_max); + int max_block_idx = blockmask.max_block_idx; + bool empty_col_flag = m_block_max <= m_block_min; + int max_no_larger_idx = blockmask.max_no_larger(m_block_max-1); + empty_col_flag = empty_col_flag || max_no_larger_idx == -1 || blockmask.mask_val(max_no_larger_idx) < m_block_min; + + __syncthreads(); + + if (empty_col_flag) { + const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) + + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; + const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) + + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; + Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), + Shape, Int>{}, + make_stride(params.dk_row_stride, _1{})); + Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), + Shape, Int>{}, + make_stride(params.dv_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV; + auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx); + Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); + Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); + Tensor tdKrdK = make_tensor(shape(tdKgdK)); + Tensor tdVrdV = make_tensor(shape(tdVgdV)); + clear(tdKrdK); + clear(tdVrdV); + Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); + Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); + #pragma unroll + for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + FLASH_NAMESPACE::copy( + 
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + FLASH_NAMESPACE::copy( + gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + return; + } + + int mask_block_idx = max_no_larger_idx; + int mask_val = mask_block_idx == -1 ? -1 : blockmask.mask_val(mask_block_idx); + int next_block_row_idx = mask_val; + + int leap = m_block - next_block_row_idx; + int next_leap = 0; + + if (Double_buffer && mask_block_idx % 2 == 1) { // Double buffer for sQ + tQsQ.data() = tQsQ.data() + size(sQ); + tSsQ.data() = tSsQ.data() + size(sQ); + tdKsQt.data() = tdKsQt.data() + size(sQ); + } + + if ((!Is_first && !Seq_parallel) || params.deterministic) { __syncthreads(); } + + if (Kernel_traits::Is_V_in_regs) { + // Clear the smem tiles to account for predicated off loads + FLASH_NAMESPACE::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + FLASH_NAMESPACE::cp_async_fence(); + } + + Tensor tdOrdO = make_fragment_like(tdOgdO); + Tensor tdOrO = make_fragment_like(tdOgO); + + if (leap > 0){ + tdOgdO.data() = tdOgdO.data() + (-int(leap * kBlockM * params.do_row_stride)); + FLASH_NAMESPACE::copy(gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ); + }else{ + if (!Is_first) {// add by JXGuo: Is_first is always false + // Clear the smem tiles to account for predicated off loads + FLASH_NAMESPACE::copy( + gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM + ); + } else { + FLASH_NAMESPACE::copy( + gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM + ); + FLASH_NAMESPACE::copy( + gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM + ); + } + } + + if (leap > 0){ + tQgQ.data() = tQgQ.data() + (-int(leap * kBlockM * params.q_row_stride)); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ); + }else{ + 
FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM); + } + + + Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor taccScS = thr_mma_sdp.partition_C(caccS); // (MMA,MMA_N,MMA_N) + static_assert(decltype(size<0>(taccScS))::value == 4); + // Convert to ((2, 2), MMA_N, MMA_N) then take only the row indices. + Tensor taccScS_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0); + Tensor lse = make_tensor(Shape>{}); + + if (leap > 0){ + gLSE.data() = gLSE.data() + (-int(leap * kBlockM)); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); } + }else{ + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccScS_row(mi)); + lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY; + } + } + + FLASH_NAMESPACE::copy( + gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + if (!Kernel_traits::Is_V_in_regs) { + FLASH_NAMESPACE::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + FLASH_NAMESPACE::cp_async_fence(); + + if (Is_first) { + cute::copy(tdOrdO, tdOsdO); + dot_do_o(tdOrdO, tdOrO, gdPsum, + Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); + } + + if (Kernel_traits::Is_V_in_regs) { + cute::cp_async_wait<1>(); + __syncthreads(); + Tensor tdPrV_copy_view = smem_thr_copy_KV.retile_D(tdPrV); + CUTE_STATIC_ASSERT_V(size<1>(tdPsV) == size<1>(tdPrV_copy_view)); // M + cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view); + } + + auto seed = params.rng_state[0]; + auto offset = params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32; + + clear(acc_dv); + clear(acc_dk); + + + if(leap > 0){ + gdPsum.data() = gdPsum.data() + (-int(leap * kBlockM)); + m_block = next_block_row_idx; + } + + bool 
current_is_last_block = false; + + for(; !current_is_last_block && m_block >= m_block_min; m_block = next_block_row_idx){ + current_is_last_block = m_block <= m_block_min || mask_block_idx >= (max_block_idx - 1); + next_leap = 0; + if(!current_is_last_block){ + ++mask_block_idx; + mask_val = blockmask.mask_val(mask_block_idx); + next_block_row_idx = mask_val; + next_leap = m_block - next_block_row_idx; + current_is_last_block = current_is_last_block || mask_val == -1; + } + + Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_N, MMA_N) + clear(acc_s); + cute::cp_async_wait<0>(); + __syncthreads(); + + Tensor dP_sum = make_fragment_like(lse); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); } + + FLASH_NAMESPACE::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp, + smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV); + + // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); + + if (!Is_causal && !Is_local) { + if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) { + FLASH_NAMESPACE::apply_mask(scores, binfo.actual_seqlen_k, + n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16); + } + } else if (Is_causal) { + // Putting this causal masking right after acc_s is *much* slower for some reason. + // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short + // (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking. + // But we still want to mask out elements beyond actual_seqlen_k. 
+ if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) { + FLASH_NAMESPACE::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, + // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, + AtomLayoutMS * 16); + } + } else if (Is_local) { + if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right + || (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left + || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) { + FLASH_NAMESPACE::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, AtomLayoutMS * 16, + params.window_size_left, params.window_size_right); + } + } + + FLASH_NAMESPACE::scale_apply_exp2(scores, lse, params.scale_softmax_log2); + + if (Is_dropout) { + int warp_id = tidx / 32; + int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS; + static_assert(MMA_N_SdP % 2 == 0); + int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2); + Tensor scores_dropped = make_tensor(scores.data(), FLASH_NAMESPACE::convert_layout_rowcol_Aregs(scores.layout())); + FLASH_NAMESPACE::apply_dropout( + scores_dropped, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, AtomLayoutMS + ); + } + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = !Is_dropout + ? 
FLASH_NAMESPACE::convert_type(scores) + : FLASH_NAMESPACE::convert_type_relu(scores); + Tensor tPrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_rowcol_Aregs(rP.layout())); + Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); + + Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_N, MMA_N) + CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(acc_dp) == size<1>(acc_s)); // MMA + CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA + + clear(acc_dp); + + FLASH_NAMESPACE::gemm( + acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp, + smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV + ); + + Tensor dS = make_tensor(acc_dp.data(), scores.layout()); + auto pointwise_mult = [](float p, float dp, float d) { + return p * (!Is_dropout || p >= 0 ? dp - d : d); + }; + #pragma unroll + for (int mi = 0; mi < size<0>(dS); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(dS); ++ni) { + dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); + } + } + + Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_N, MMA_K + tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(leap * kBlockM * params.h * params.d_rounded)); + if (Is_first || Seq_parallel) { + clear(acc_dq); + } else { + Tensor acc_dq_reshaped = make_tensor(acc_dq.data(), + make_layout(get<0>(acc_dq.layout()), + get<2>(acc_dq.layout()), + get<1>(acc_dq.layout()))); + cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, acc_dq_reshaped); + } + + if (Double_buffer && !current_is_last_block) { + // Double buffer for sQ + const int sQ_offset = (mask_block_idx - 1) % 2 == 0 ? 
size(sQ) : -size(sQ); + tQsQ.data() = tQsQ.data() + sQ_offset; + tSsQ.data() = tSsQ.data() + sQ_offset; + // Advance gQ + tQgQ.data() = tQgQ.data() + (-int(next_leap * kBlockM * params.q_row_stride)); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ); + FLASH_NAMESPACE::cp_async_fence(); + } + + Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout()); + Tensor tdSrdS = FLASH_NAMESPACE::convert_type(dS_reshaped); + Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); + cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); + __syncthreads(); + + FLASH_NAMESPACE::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv, + smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); + + __syncthreads(); + + if (!current_is_last_block) { + // Advance gdO + tdOgdO.data() = tdOgdO.data() + (-int(next_leap * kBlockM * params.do_row_stride)); + if (Is_first) { + tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride)); + FLASH_NAMESPACE::copy(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ); + FLASH_NAMESPACE::copy(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ); + } else { + FLASH_NAMESPACE::copy(gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ); + FLASH_NAMESPACE::cp_async_fence(); + } + } + + FLASH_NAMESPACE::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq, + smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt); + + if (!current_is_last_block) { + gLSE.data() = gLSE.data() + (-int(next_leap * kBlockM)); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); } + gdPsum.data() = gdPsum.data() + (-int(next_leap * kBlockM)); + } + + if (!Is_last) { + Tensor acc_dq_reshaped = make_tensor(acc_dq.data(), + make_layout(get<0>(acc_dq.layout()), + get<2>(acc_dq.layout()), + get<1>(acc_dq.layout()))); + if (!Seq_parallel) { + cute::copy(gmem_tiled_copy_dQaccum, acc_dq_reshaped, tdQgdQaccum); + } else { + CUTE_STATIC_ASSERT_V(size(acc_dq) == 
size(tdQgdQaccum)); + #pragma unroll + for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); } + } + } else { + #pragma unroll + for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; } + Tensor rdQ = FLASH_NAMESPACE::convert_type(acc_dq); + Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); + cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); + } + + FLASH_NAMESPACE::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv, + smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); + if (Double_buffer) { + tdKsQt.data() = tdKsQt.data() + ((mask_block_idx - 1) % 2 == 0 ? size(sQ) : -size(sQ)); + } + if (!Double_buffer && !current_is_last_block) { + __syncthreads(); + tQgQ.data() = tQgQ.data() + (-int(next_leap * kBlockM * params.q_row_stride)); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ); + FLASH_NAMESPACE::cp_async_fence(); + } + + if (Is_first && m_block > m_block_min) { + cute::copy(tdOrdO, tdOsdO); + dot_do_o(tdOrdO, tdOrO, gdPsum, + Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); + } + + if (Is_last) { + __syncthreads(); + Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); + cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ); + tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride)); + Tensor cdQ = make_identity_tensor(Shape, Int>{}); + Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); + #pragma unroll + for (int m = 0; m < size<1>(tdQgdQ); ++m) { + if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) { + cute::copy(gmem_tiled_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _)); + } + } + } + + leap = next_leap; + } + + if (Is_dropout) { + #pragma unroll + for (int i = 0; i < size(acc_dv); ++i) { acc_dv(i) *= params.rp_dropout; } + } + + #pragma unroll + for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax_rp_dropout; } + + Tensor rdK = 
FLASH_NAMESPACE::convert_type(acc_dk); + Tensor rdV = FLASH_NAMESPACE::convert_type(acc_dv); + + Tensor sdK = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) + Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) + + // Partition sdV and sdK to match the accumulator partitioning + auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv); + auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx); + Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N) + Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N) + Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // We need syncthreads here since we're writing to the same location as sK and sV. + // Without syncthreads, some thread might modify the location of sK while another thread + // is reading it for dQ gemm, leading to a race condition. + // If Is_last, there's already a __syncthreads() at the end of the loop. 
+ if (!Is_last) { __syncthreads(); } + + cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK); + cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV); + + const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) + + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; + const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) + + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; + Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), + Shape, Int>{}, + make_stride(params.dk_row_stride, _1{})); + Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), + Shape, Int>{}, + make_stride(params.dv_row_stride, _1{})); + + typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV; + auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx); + Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); + Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); + + __syncthreads(); + Tensor tdKrdK = make_tensor(shape(tdKgdK)); + cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK); + Tensor tdVrdV = make_tensor(shape(tdVgdV)); + cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV); + Tensor cdKV = make_identity_tensor(make_shape(size<0>(sdK), size<1>(sdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); + Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); + #pragma unroll + for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + FLASH_NAMESPACE::copy( + gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); 
+ FLASH_NAMESPACE::copy( + gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { + + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + + // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. + for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) { + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + } +} + + +// for blocksparse +template +inline __device__ void compute_block_dq_dk_dv_seqk_parallel(const Params ¶ms) { + + // const int n_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. 
+ const int bidh = blockIdx.z; + + const int head_mask_type = params.head_mask_type[bidh]; + + for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) { + if (head_mask_type > 0){ + compute_block_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + // }else if (head_mask_type > 0){ + // compute_block_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + }else{ + compute_block_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + }; + }; +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace flash diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_launch_template.h b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_launch_template.h new file mode 100644 index 00000000000..87283163b12 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_bwd_launch_template.h @@ -0,0 +1,224 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ * Adapted by Junxian Guo from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_bwd_launch_template.h + ******************************************************************************/ + +#pragma once + +#include "namespace_config.h" +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include "static_switch.h" +#include "hardware_info.h" +#include "flash.h" +#include "flash_bwd_kernel.h" + +namespace FLASH_NAMESPACE { + +template +__global__ void flash_bwd_dot_do_o_kernel(Flash_bwd_params params) { + FLASH_NAMESPACE::compute_dot_do_o(params); +} + +//add by JXGuo: not used +template +__global__ void flash_bwd_clear_dkvaccum_kernel(Flash_bwd_params params) { + FLASH_NAMESPACE::clear_dKVaccum(params); +} + + +template +__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) { + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false + FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel(params); +} + + +// for blocksparse-flash-attention2 +template +__global__ void flash_bwd_block_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) { + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false + FLASH_NAMESPACE::compute_block_dq_dk_dv_seqk_parallel(params); +} + + +template +__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params, const int nsplits) { + FLASH_NAMESPACE::convert_dQ(params, nsplits); +} + +// add by JXGuo: not used +template +__global__ void flash_bwd_convert_dkv_kernel(Flash_bwd_params params) { + FLASH_NAMESPACE::convert_dKV(params); +} + + +// for blocksparse-flash-attention2 +template +void run_flash_bwd_block_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) { + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid_m(num_m_block, params.b, params.h); + const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / 
Kernel_traits::kBlockN; + int gridDimx = num_n_block; + if (params.deterministic) { + int num_sm = get_num_sm(get_current_device()); + gridDimx = (num_sm + params.b * params.h - 1) / (params.b * params.h); + } + dim3 grid_n(gridDimx, params.b, params.h); + + if (!params.deterministic) { + flash_bwd_dot_do_o_kernel<<>>(params); + } else { + flash_bwd_dot_do_o_kernel<<>>(params); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not + // a multiple of kBlockN, we'll need to apply mask in the loop. + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock; + // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv); + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] { + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_bwd_block_dq_dk_dv_loop_seqk_parallel_kernel; + if (smem_size_dq_dk_dv >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + + auto kernel_dq = &flash_bwd_convert_dq_kernel; + if (Kernel_traits::kSmemdQSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize)); + } + kernel_dq<<>>(params, !params.deterministic ? 
1 : gridDimx); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + + +// for blocksparse-flash-attention2 +template +void run_flash_bwd_block(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_flash_bwd_block_seqk_parallel(params, stream); +} + +// for blocksparse-flash-attention2 +template +void run_mha_bwd_block_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 32; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB + if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers + run_flash_bwd_block, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_bwd_block, Is_dropout, Is_causal>(params, stream); + } + } else { // 96 KB + run_flash_bwd_block, Is_dropout, Is_causal>(params, stream); + } + }); +} + +template +void run_mha_bwd_block_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 64; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // Changing AtomLayoutMdQ from 2 to 4 takes the same time + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream, configure); + // This is slightly faster. We want to split M more so we need fewer registers to store LSE. 
+ if (max_smem_per_block >= 144 * 1024) { + run_flash_bwd_block, Is_dropout, Is_causal>(params, stream); + // This has a lot of register spilling + // run_flash_bwd, Is_dropout>(params, stream, configure); + } else { + // if (params.h == params.h_k) { + // run_flash_bwd, Is_dropout>(params, stream, configure); + run_flash_bwd_block, Is_dropout, Is_causal>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream, configure); + // } else { + // run_flash_bwd_seqq_parallel, Is_dropout>(params, stream, configure); + // } + } + }); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream, configure); + + // run_flash_bwd>(params, stream, configure); +} + +template +void run_mha_bwd_block_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 128; + int device; + cudaGetDevice(&device); + int max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // if (params.h == params.h_k) { + // run_flash_bwd>(params, stream, configure); + // This is faster, in the case of sequence-parallel bwd (where we need fewer registers). + // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why. 
+ // run_flash_bwd>(params, stream, configure); + if (max_smem_per_block >= 144 * 1024) { + run_flash_bwd_block, Is_dropout, Is_causal>(params, stream); + // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream, configure); + // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream, configure); + // run_flash_bwd_seqq_parallel, Is_dropout>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream, configure); + } else { + // run_flash_bwd, Is_dropout>(params, stream, configure); + run_flash_bwd_block, Is_dropout, Is_causal>(params, stream); + } + // run_flash_bwd>(params, stream, configure); + + // run_flash_bwd>(params, stream, configure); + // } else { + // run_flash_bwd_seqq_parallel>(params, stream, configure); + // } + }); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_causal_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_causal_sm80.cu new file mode 100644 index 00000000000..c637ab2841c --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_causal_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_block_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_block_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu new file mode 100644 index 00000000000..b944de659e7 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_block_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_block_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_causal_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_causal_sm80.cu new file mode 100644 index 00000000000..f5c836be11d --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_causal_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_block_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_block_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu new file mode 100644 index 00000000000..6cf37a42040 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_block_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_block_hdim128(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim32_bf16_causal_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim32_bf16_causal_sm80.cu new file mode 100644 index 00000000000..67076ef16cf --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim32_bf16_causal_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_block_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_block_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim32_bf16_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim32_bf16_sm80.cu new file mode 100644 index 00000000000..c5544526417 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim32_bf16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_block_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_block_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim32_fp16_causal_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim32_fp16_causal_sm80.cu new file mode 100644 index 00000000000..878e527d005 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim32_fp16_causal_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_block_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_block_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim32_fp16_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim32_fp16_sm80.cu new file mode 100644 index 00000000000..b9e99167d4b --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim32_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_block_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_block_hdim32(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim64_bf16_causal_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim64_bf16_causal_sm80.cu new file mode 100644 index 00000000000..18fac5e737a --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim64_bf16_causal_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_block_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_block_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim64_bf16_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim64_bf16_sm80.cu new file mode 100644 index 00000000000..517c1c17bab --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim64_bf16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_block_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_block_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim64_fp16_causal_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim64_fp16_causal_sm80.cu new file mode 100644 index 00000000000..89580e8ede8 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim64_fp16_causal_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_block_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_block_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim64_fp16_sm80.cu b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim64_fp16_sm80.cu new file mode 100644 index 00000000000..2900dea5e25 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_block_hdim64_fp16_sm80.cu @@ -0,0 +1,15 @@ +// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" +#include "namespace_config.h" +#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE { + +template<> +void run_mha_fwd_block_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_block_hdim64(params, stream); +} + +} // namespace FLASH_NAMESPACE \ No newline at end of file diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_kernel.h b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_kernel.h new file mode 100644 index 00000000000..675741f9d96 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_kernel.h @@ -0,0 +1,1297 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ +/****************************************************************************** + * Adapted by Junxian Guo from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_fwd_kernel.h + ******************************************************************************/ + +#pragma once + +#include "namespace_config.h" +// #include "philox_unpack.cuh" // For at::cuda::philox::unpack + +#include + +#include +#include +#include + +#include "block_info.h" +#include "kernel_traits.h" +#include "utils.h" +#include "softmax.h" + +#include "alibi.h" + +#include "flash_blockmask.h" + +namespace FLASH_NAMESPACE { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum, + Tensor2 &acc_o, float softmax_scale_log2) { + if (Is_first) { + FLASH_NAMESPACE::template reduce_max(scores, scores_max); + FLASH_NAMESPACE::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + FLASH_NAMESPACE::reduce_sum(scores, scores_sum); + } else { + Tensor scores_max_prev = make_fragment_like(scores_max); + cute::copy(scores_max, scores_max_prev); + FLASH_NAMESPACE::template reduce_max(scores, scores_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_o.layout())); + #pragma unroll + for (int mi = 0; mi < size(scores_max); ++mi) { + float scores_max_cur = !Check_inf + ? scores_max(mi) + : (scores_max(mi) == -INFINITY ? 
0.0f : scores_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + scores_sum(mi) *= scores_scale; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } + } + FLASH_NAMESPACE::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + Tensor scores_sum_cur = make_fragment_like(scores_sum); + FLASH_NAMESPACE::reduce_sum(scores, scores_sum_cur); + #pragma unroll + for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); } + } +}; + + +template +inline __device__ void softmax_rescale_o_block(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum, + Tensor2 &acc_o, float softmax_scale_log2, bool Is_blocksparse_skip) { + if (Is_first) { + FLASH_NAMESPACE::template reduce_max(scores, scores_max); + FLASH_NAMESPACE::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + FLASH_NAMESPACE::reduce_sum(scores, scores_sum); + } else { + Tensor scores_max_prev = make_fragment_like(scores_max); + cute::copy(scores_max, scores_max_prev); + FLASH_NAMESPACE::template reduce_max(scores, scores_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_o.layout())); + #pragma unroll + for (int mi = 0; mi < size(scores_max); ++mi) { + float scores_max_cur = !(Check_inf || Is_blocksparse_skip) + ? scores_max(mi) + : (scores_max(mi) == -INFINITY ? 
0.0f : scores_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + scores_sum(mi) *= scores_scale; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } + } + FLASH_NAMESPACE::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + Tensor scores_sum_cur = make_fragment_like(scores_sum); + FLASH_NAMESPACE::reduce_sum(scores, scores_sum_cur); + #pragma unroll + for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void write_softmax_to_gmem( + Tensor const &tOrP, Tensor &tPgP, TiledCopy gmem_tiled_copy_P +) { + // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) + Layout l = tOrP.layout(); + Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); + CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{}); + CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP)); + #pragma unroll + for (int mi = 0; mi < size<1>(tPrP); ++mi) { + cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. 
+ const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + } + // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. + // Otherwise we might read OOB elements from gK and gV. + if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { + // Save seed and offset for backward. If we don't have this here, the 0-th thread block might + // exit early and no one saves the rng state. 
+ if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { + auto seeds = at::cuda::philox::unpack(params.philox_args); + params.rng_state[0] = std::get<0>(seeds); + params.rng_state[1] = std::get<1>(seeds); + } + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } + } + return; + } + // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } + + // We 
iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + Shape, Int>{}, + make_stride(params.seqlen_k_rounded, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 
0 : size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P; + auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + Tensor tPgP = gmem_thr_copy_P.partition_D(gP); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + // if (cute::thread0()) {smem_thr_copy_Q.print_all();} + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = 
smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // TODO: this might need to change if we change the mma instruction in SM70 + Tensor scores_max = make_tensor(Shape(acc_o)>>{}); + Tensor scores_sum = make_fragment_like(scores_max); + + // + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) + // if (cute::thread0()) { + // print(tScQ.layout()); printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<0>(tScQ(i))); + // } + // printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<1>(tScQ(i))); + // } + // printf("\n"); + // } + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + 
// Prologue + + Tensor tQrQ = make_fragment_like(tQgQ); + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } + + // // Copy rmem to smem + // // copy(tQrQ, tQsQ); + // flash::cp_async_wait<0>(); + // __syncthreads(); + // // if (cute::thread(1, 0)) { print(tQsQ); } + // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{}); + // // if (cute::thread0()) { print(sQNoSwizzle); } + + if (Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<0>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<1>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + } + + auto seeds = at::cuda::philox::unpack(params.philox_args); + unsigned long long seed = std::get<0>(seeds); + unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32; + + // Save seed and offset for backward. + if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { + params.rng_state[0] = seed; + params.rng_state[1] = std::get<1>(seeds); + } + + clear(acc_o); + + float alibi_slope = !Has_alibi ? 
0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + // if (cute::thread0()) { print_tensor(scores); } + // We don't put the masking before the matmul S = Q K^T because we don't clear sK + // for rows 
outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. + + if (Has_alibi) { + flash::apply_alibi( + scores, + n_block * kBlockN, + binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, + kNWarps * 16, + alibi_slope + ); + } + + if (!Is_causal && !Is_local) { + if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } + } else { + // Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) + // Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N) + // static_assert(decltype(size<0>(taccScS))::value == 4); + // // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices. + // Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0); + // Tensor idx_rowcol = make_tensor(taccScS.data(), flash::convert_layout_acc_rowcol(taccScS.layout())); + // flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM); + // Idk why it's get<1> and not get<0> of the stride. 
+ // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); } + // I can't get the stride from idx_row + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right + // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16 + // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16 + ); + // if (cute::thread0()) { print_tensor(scores); } + } + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + Tensor tOrP_copy = make_fragment_like(tOrP); + cute::copy(tOrP, tOrP_copy); + flash::apply_dropout( + tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps + ); + flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P); + tPgP.data() = tPgP.data() + (-kBlockN); + } + if (Is_dropout) { + flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps); + } + // if (cute::thread0()) { print(tOrP); } + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. 
+ cute::cp_async_fence(); + } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + + if (Has_alibi) { + flash::apply_alibi( + scores, + n_block * kBlockN, + binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, + kNWarps * 16, + alibi_slope + ); + } + + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right + ); + } + + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + Tensor tOrP_copy = make_fragment_like(tOrP); + cute::copy(tOrP, tOrP_copy); + flash::apply_dropout( + tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps + ); + flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P); + tPgP.data() = tPgP.data() + (-kBlockN); + } + if (Is_dropout) { + flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps); + } + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + Tensor lse = make_fragment_like(scores_sum); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = !Is_dropout ? 
inv_sum : inv_sum * params.rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + + // if (cute::thread0()) { print(acc_o_rowcol); } + + // Convert acc_o from fp32 to fp16/bf16 + Tensor rO = flash::convert_type(acc_o); + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. + if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } + + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + Tensor tOrO = make_tensor(shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = 
thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + + +template +inline __device__ void compute_block_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. 
+ const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); + // } + } + + fwdIterator blockmask(params, binfo, kBlockM, kBlockN, bidb, bidh, m_block, n_block_min, n_block_max); + // const bool Is_blocksparse = true; + // for causal blocksparse + // const int sink_num = Is_exact_streaming? params.streaming_info[bidh * 2] : 0; + // const int local_num = Is_exact_streaming? 
params.streaming_info[bidh * 2 + 1] : 0; + int max_block_idx = blockmask.max_block_idx; + bool empty_line_flag = n_block_max <= n_block_min; + int max_no_larger_idx = blockmask.max_no_larger(n_block_max-1); + empty_line_flag = empty_line_flag || max_no_larger_idx == -1 || blockmask.mask_val(max_no_larger_idx) < n_block_min; + // for (int i = 0; i < max_block_idx; i++) { + // int tmp_mask_val = blockmask.mask_val(i); + // if (tmp_mask_val == -1){ + // empty_line_flag = true; + // break; + // } + // if (tmp_mask_val < n_block_max && tmp_mask_val >= n_block_min) { + // empty_line_flag = false; + // break; + // } + // } // use function: between(min, max) + + __syncthreads(); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("m_block = %d, n_block_min = %d, n_block_max = %d, empty_line = %d, params.window_size_right = %d, devidee = %d, devider = %d, kBlcokM = %d, actual_seqlen_k = %d, actual_seqlen_q = %d\n", m_block, n_block_min, n_block_max, empty_line_flag, params.window_size_right, (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN, kBlockM, binfo.actual_seqlen_k, binfo.actual_seqlen_q); + // } + // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. + // Otherwise we might read OOB elements from gK and gV. + if (empty_line_flag) { + // Save seed and offset for backward. If we don't have this here, the 0-th thread block might + // exit early and no one saves the rng state. 
+ if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { + auto seeds = at::cuda::philox::unpack(params.philox_args); + params.rng_state[0] = std::get<0>(seeds); + params.rng_state[1] = std::get<1>(seeds); + } + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + FLASH_NAMESPACE::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } + } + return; + } + // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } + + 
// We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + Shape, Int>{}, + make_stride(params.seqlen_k_rounded, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 
0 : size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P; + auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + Tensor tPgP = gmem_thr_copy_P.partition_D(gP); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + // if (cute::thread0()) {smem_thr_copy_Q.print_all();} + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = 
smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // TODO: this might need to change if we change the mma instruction in SM70 + Tensor scores_max = make_tensor(Shape(acc_o)>>{}); + Tensor scores_sum = make_fragment_like(scores_max); + + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + Tensor tQrQ = make_fragment_like(tQgQ); + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } + + if (Kernel_traits::Share_Q_K_smem) { + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); + } + + int n_block = n_block_max - 1; + // We don't need 
to clear the sK smem tiles since we'll mask out the scores anyway. + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + FLASH_NAMESPACE::cp_async_wait<1>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + } + + auto seeds = at::cuda::philox::unpack(params.philox_args); + unsigned long long seed = std::get<0>(seeds); + unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32; + + // Save seed and offset for backward. + if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { + params.rng_state[0] = seed; + params.rng_state[1] = std::get<1>(seeds); + } + + clear(acc_o); + + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + + int mask_block_idx = max_no_larger_idx; + int mask_val = mask_block_idx == -1 ? -1 : blockmask.mask_val(mask_block_idx); + bool is_last_block = mask_val == -1; + int next_block_col_idx = mask_val; + int leap = 0; + + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + bool is_skip = n_block != next_block_col_idx; + if(is_skip){ + leap = (masking_step + 1 == n_masking_steps) ? n_block - next_block_col_idx : 1; + leap = is_last_block ? 
0 : leap; + + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + FLASH_NAMESPACE::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); + + FLASH_NAMESPACE::apply_mask(scores, 0); + + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min && !is_last_block) { + // Advance gK + tKgK.data() = tKgK.data() + (-index_t(kBlockN * leap * params.k_row_stride)); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + cute::cp_async_fence(); + } + masking_step == 0 + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = FLASH_NAMESPACE::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_rowcol_Aregs(rP.layout())); + // int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + // int block_col_idx = n_block * (kBlockN / 32); + + if (Return_softmax) { + tPgP.data() = tPgP.data() + (-index_t(kBlockN * leap)); + } + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + }else{ + if(!is_last_block){ //is_skip==false + mask_block_idx++; + mask_val = blockmask.mask_val(mask_block_idx); + is_last_block = mask_block_idx >= max_block_idx || mask_val == -1; + next_block_col_idx = is_last_block ? -1 : mask_val; + } + leap = (masking_step + 1 == n_masking_steps) ? n_block - next_block_col_idx : 1; + leap = is_last_block ? 0 : leap; + + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + FLASH_NAMESPACE::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + FLASH_NAMESPACE::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); + // if (cute::thread0()) { print_tensor(scores); } + // clear(scores); + // if (cute::thread0()) { print_tensor(scores); } + // We don't put the masking before the matmul S = Q K^T because we don't clear 
sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. + + if (!Is_causal && !Is_local && !Is_exact_streaming) { + if (!Is_even_MN) { FLASH_NAMESPACE::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } + } else { + // Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) + // Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N) + // static_assert(decltype(size<0>(taccScS))::value == 4); + // // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices. + // Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0); + // Tensor idx_rowcol = make_tensor(taccScS.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(taccScS.layout())); + // FLASH_NAMESPACE::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM); + // Idk why it's get<1> and not get<0> of the stride. + // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); } + // I can't get the stride from idx_row + FLASH_NAMESPACE::apply_mask_streaming( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.streaming_info[bidh * 2 + 1], params.streaming_info[bidh * 2] + ); + } + + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min && !is_last_block) { + // Advance gK + tKgK.data() = tKgK.data() + (-index_t(kBlockN * leap * params.k_row_stride)); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? 
softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = FLASH_NAMESPACE::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_rowcol_Aregs(rP.layout())); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); + + if (Return_softmax) { + // if (!is_skip){ + Tensor tOrP_copy = make_fragment_like(tOrP); + cute::copy(tOrP, tOrP_copy); + FLASH_NAMESPACE::apply_dropout( + tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps + ); + FLASH_NAMESPACE::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P); + // } + // if (!is_last_block) { + // tPgP.data() = tPgP.data() + (-int(kBlockN * leap)); + // } + tPgP.data() = tPgP.data() + (-index_t(kBlockN * leap)); + } + if (Is_dropout) { + FLASH_NAMESPACE::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps); + } + // if (cute::thread0()) { print(tOrP); } + + // if(!is_skip){ + FLASH_NAMESPACE::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // } + + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + } + + // These are the iterations where we don't need masking on S + leap = n_block - next_block_col_idx; + if(!is_last_block){ + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + // if(n_block != next_block_col_idx){ + // // tKgK.data() = tKgK.data() + (-int(kBlockN * (n_block - 
next_block_col_idx) * params.k_row_stride)); + // tPgP.data() = tPgP.data() + (-int(kBlockN * (n_block - next_block_col_idx))); + // // FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // // cute::cp_async_fence(); + // n_block = next_block_col_idx; + // } + n_block = next_block_col_idx; + } + + // if(bidb == 0 && bidh == 0 && m_block == 0 && threadIdx.x == 0){ + // printf("[compute_block_attn_1rowblock] out bidb = %d, bidh = %d, m_block = %d, n_block = %d, mask_val = %d, leap = %d\n", bidb, bidh, m_block, n_block, mask_val, leap); + // } + + // if(bidb == 0 && bidh == 0 && m_block == 0 && threadIdx.x == 0){ + // printf("[compute_block_attn_1rowblock] early return\n"); + // } + + // return; + for(; !is_last_block && n_block >= n_block_min; n_block = next_block_col_idx){ + // if(bidb == 0 && bidh == 0 && m_block == 0 && threadIdx.x == 0){ + // printf("[compute_block_attn_1rowblock] in bidb = %d, bidh = %d, m_block = %d, n_block = %d, mask_val = %d, leap = %d\n", bidb, bidh, m_block, n_block, mask_val, leap); + // } + ++mask_block_idx; + mask_val = blockmask.mask_val(mask_block_idx); + is_last_block = mask_block_idx >= max_block_idx || mask_val == -1; + next_block_col_idx = mask_val; + + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + tVgV.data() = tVgV.data() + (-index_t(kBlockN * leap * params.v_row_stride)); + + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + FLASH_NAMESPACE::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + + leap = n_block - next_block_col_idx; + + if (!is_last_block) { + // Advance gK + tKgK.data() = tKgK.data() + (-index_t(kBlockN * leap * params.k_row_stride)); + + 
FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + cute::cp_async_fence(); + } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); + + if (Is_exact_streaming && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + FLASH_NAMESPACE::apply_mask_streaming( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.streaming_info[bidh * 2 + 1], params.streaming_info[bidh * 2] + ); + } + + // FLASH_NAMESPACE::cp_async_wait<0>(); + // __syncthreads(); + + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + + Tensor rP = FLASH_NAMESPACE::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_rowcol_Aregs(rP.layout())); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + Tensor tOrP_copy = make_fragment_like(tOrP); + cute::copy(tOrP, tOrP_copy); + FLASH_NAMESPACE::apply_dropout( + tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps + ); + FLASH_NAMESPACE::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P); + tPgP.data() = tPgP.data() + (-index_t(kBlockN * leap)); + } + if (Is_dropout) { + FLASH_NAMESPACE::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, + block_row_idx, block_col_idx, kNWarps); + } + + FLASH_NAMESPACE::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + + + + // Epilogue + + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_o.layout())); + Tensor lse = make_fragment_like(scores_sum); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = !Is_dropout ? 
inv_sum : inv_sum * params.rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + + // if (cute::thread0()) { print(acc_o_rowcol); } + + // Convert acc_o from fp32 to fp16/bf16 + Tensor rO = FLASH_NAMESPACE::convert_type(acc_o); + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. + if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } + + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + Tensor tOrO = make_tensor(shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = 
thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + FLASH_NAMESPACE::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template +inline __device__ void compute_block_attn(const Params ¶ms) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + + // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting + // them to have the same number of threads or have to traverse the attention matrix + // in the same order. + // In the Philox RNG, we use the offset to store the batch, head, and the lane id + // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within + // the attention matrix. 
This way, as long as we have the batch, head, and the location of + // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. + + // FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); + const int head_mask_type = params.head_mask_type[bidh]; + if (head_mask_type == 0){ + FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); + }else if (head_mask_type > 0){ + FLASH_NAMESPACE::compute_block_attn_1rowblock(params, bidb, bidh, m_block); //false for blocksparse + }else{ + FLASH_NAMESPACE::compute_block_attn_1rowblock(params, bidb, bidh, m_block); // true for streaming + } + +} + +} // namespace flash diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_launch_template.h b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_launch_template.h new file mode 100644 index 00000000000..3a26f21cbe7 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/flash_fwd_launch_template.h @@ -0,0 +1,113 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ +/****************************************************************************** + * Adapted by Junxian Guo from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_fwd_launch_template.h + ******************************************************************************/ + +#pragma once +#include "namespace_config.h" +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include "static_switch.h" +#include "hardware_info.h" +#include "flash.h" +#include "flash_fwd_kernel.h" + +namespace FLASH_NAMESPACE { +template +__global__ void flash_fwd_block_kernel(Flash_fwd_params params) { + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false + FLASH_NAMESPACE::compute_block_attn(params); +} + +// blocksparse +template +void run_flash_fwd_block(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; + + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. 
+ // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.b, params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + const bool return_softmax = params.p_ptr != nullptr; + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + + BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + + BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { + BOOL_SWITCH(params.is_exact_streaming, Is_exact_streaming, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
+ // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + + auto kernel = &flash_fwd_block_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); + }); +} + + +template +void run_mha_fwd_block_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 32; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_fwd_block, Is_dropout, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_block_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 64; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if constexpr(!Is_dropout) { + run_flash_fwd_block, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd_block, Is_dropout, Is_causal>(params, stream); + } + }); +} + + +template +void run_mha_fwd_block_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 128; + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + bool is_sm8x = cc_major == 8 && cc_minor > 0; + BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if constexpr(!Is_dropout) { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. 
+ if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd_block, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd_block, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd_block, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd_block, Is_dropout, Is_causal>(params, stream); + } + }); +} +} // namespace FLASH_NAMESPACE diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/generate_kernels.py b/custom_ops/gpu_ops/block_sparse_attn/src/generate_kernels.py new file mode 100644 index 00000000000..a029e8474f6 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/generate_kernels.py @@ -0,0 +1,102 @@ +import argparse +import itertools +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional + +DTYPE_MAP = { + "fp16": "cutlass::half_t", + "bf16": "cutlass::bfloat16_t", +} + +SM = [80] # Sm80 kernels support up to +HEAD_DIMENSIONS = [32, 64, 128] +IS_CAUSAL = ["false", "true"] +NAMESPACE_INCLUDE = '#include "namespace_config.h"\n' + +def get_fwd_block_template() -> str: + return NAMESPACE_INCLUDE + """#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE {{ + +template<> +void run_mha_fwd_block_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{ + run_mha_fwd_block_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); +}} + +}} // namespace FLASH_NAMESPACE""" + + +def get_bwd_block_template() -> str: + return NAMESPACE_INCLUDE + """#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE {{ + +template<> +void run_mha_bwd_block_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{ + run_mha_bwd_block_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); +}} + +}} // namespace FLASH_NAMESPACE""" + +@dataclass +class Kernel: + sm: int + dtype: str + head_dim: int + is_causal: bool + direction: str + + @property + def template(self) -> str: + template_funcs = { + "fwd_block": 
get_fwd_block_template, + "bwd_block": get_bwd_block_template, + } + template_func = template_funcs[self.direction] + return template_func().format( + DTYPE=DTYPE_MAP[self.dtype], + HEAD_DIM=self.head_dim, + IS_CAUSAL=self.is_causal + ) + + @property + def filename(self) -> str: + return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu" + +def get_all_kernels() -> List[Kernel]: + for direction in ["fwd_block", "bwd_block"]: + for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM): + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction) + +def write_kernel(kernel: Kernel, autogen_dir: Path) -> None: + prelude = """// Copyright (c) 2024, Tri Dao. +// Adapted by Junxian Guo. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"\n""" + content = prelude + kernel.template + (autogen_dir / kernel.filename).write_text(content) + +def main(output_dir: Optional[str]) -> None: + if output_dir is None: + output_dir = Path(__file__).parent + else: + output_dir = Path(output_dir) + + for kernel in get_all_kernels(): + write_kernel(kernel, output_dir) + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate_kernels", + description="Generate the flash_attention kernels template instantiations", + ) + parser.add_argument( + "-o", + "--output_dir", + required=False, + help="Where to generate the kernels " + " will default to the current directory ", + ) + args = parser.parse_args() + main(args.output_dir) diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/hardware_info.h b/custom_ops/gpu_ops/block_sparse_attn/src/hardware_info.h new file mode 100644 index 00000000000..b218a29b3bb --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/hardware_info.h @@ -0,0 +1,41 @@ 
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <cstdio>   // std::fprintf
+#include <cstdlib>  // std::exit
+#include <tuple>    // std::tuple
+#if !defined(__CUDACC_RTC__)
+#include "cuda_runtime.h"
+#endif
+// Evaluates `call` exactly once; on failure prints file, line and the CUDA error string, then exits with status 1.
+#define CHECK_CUDA(call)                                                      \
+  do {                                                                        \
+    cudaError_t status_ = (call);                                             \
+    if (status_ != cudaSuccess) {                                             \
+      std::fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__,    \
+                   cudaGetErrorString(status_));                              \
+      std::exit(1);                                                           \
+    }                                                                         \
+  } while (0)
+
+inline int get_current_device() {  // device index reported by cudaGetDevice for the calling host thread
+  int device;
+  CHECK_CUDA(cudaGetDevice(&device));
+  return device;
+}
+
+inline std::tuple<int, int> get_compute_capability(int device) {  // {major, minor} of `device`
+  int capability_major, capability_minor;
+  CHECK_CUDA(cudaDeviceGetAttribute(&capability_major, cudaDevAttrComputeCapabilityMajor, device));
+  CHECK_CUDA(cudaDeviceGetAttribute(&capability_minor, cudaDevAttrComputeCapabilityMinor, device));
+  return {capability_major, capability_minor};
+}
+
+inline int get_num_sm(int device) {  // multiprocessor count of `device`
+  int multiprocessor_count;
+  CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device));
+  return multiprocessor_count;
+}
diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/kernel_traits.h b/custom_ops/gpu_ops/block_sparse_attn/src/kernel_traits.h
new file mode 100644
index 00000000000..a75fea701bc
--- /dev/null
+++ b/custom_ops/gpu_ops/block_sparse_attn/src/kernel_traits.h
@@ -0,0 +1,397 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/ + +#pragma once + +#include "cute/algorithm/copy.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +template +struct Flash_kernel_traits { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = int64_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >; + using ValLayoutMNK = Layout>; +#else + using MMA_Atom_Arch = MMA_Atom; + using ValLayoutMNK = Layout>; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. 
+ static constexpr int kNWarps = kNWarps_; //number of warps in a thread block + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQ = decltype( + composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + using SmemLayoutAtomVtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomVtransposed = decltype( + composition(Swizzle{}, SmemLayoutAtomVtransposedNoSwizzle{})); + using SmemLayoutVtransposed = decltype(tile_to_shape( + SmemLayoutAtomVtransposed{}, + Shape, Int>{})); + // Maybe the VtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? 
+ using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomVtransposedNoSwizzle{}, + Shape, Int>{})); + // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom; + + static constexpr int kSmemQCount = size(SmemLayoutQ{}); + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. 
+ using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); + using GmemLayoutAtomP = Layout, Int>, + Stride, _1>>; + + using GmemTiledCopyP = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomP{}, + Layout>{})); // Val layout, 8 vals per store + + using GmemLayoutAtomOaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load +}; + +// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. +// No_double_buffer is another option to reduce smem usage, but will slow things down. 
+template > +struct Flash_bwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Is_V_in_regs = Is_V_in_regs_; + static constexpr bool No_double_buffer = No_double_buffer_; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; + + static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; + static_assert(kNWarps % AtomLayoutMSdP == 0); + static_assert(kNWarps % AtomLayoutNdKV == 0); + static_assert(kNWarps % AtomLayoutMdQ == 0); + + using TiledMmaSdP = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using TiledMmadKV = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using TiledMmadQ = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQdO = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutQdO = decltype(tile_to_shape( + SmemLayoutAtomQdO{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutAtomKV = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutKV = decltype(tile_to_shape( + // SmemLayoutAtomQdO{}, + SmemLayoutAtomKV{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutAtomKtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomKtransposed = decltype( + composition(Swizzle{}, SmemLayoutAtomKtransposedNoSwizzle{})); + using SmemLayoutKtransposed = decltype(tile_to_shape( + SmemLayoutAtomKtransposed{}, + make_shape(Int{}, Int{}))); + // Maybe the KtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? + using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomKtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); + + // TODO: generalize to other values of kBlockN + // TODO: what should be the Swizzle here? 
3 is faster than 1, and 1 is faster than 2 + // static constexpr int kPBlockN = kBlockN; + static_assert(kBlockN >= 64); + // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. + static constexpr int kPBlockN = 64; + static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); + // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); + static constexpr int kSwizzlePdS = 3; + using SmemLayoutAtomPdS = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + make_shape(Int{}, Int{}))); + using SmemLayoutAtomPdStransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomPdStransposed = decltype( + composition(Swizzle{}, SmemLayoutAtomPdStransposedNoSwizzle{})); + using SmemLayoutPdStransposed = decltype(tile_to_shape( + SmemLayoutAtomPdStransposed{}, + make_shape(Int{}, Int{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomPdStransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); + using SmemCopyAtomPdS = Copy_Atom; + + using SmemLayoutAtomQdOtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomQdOtransposed = decltype( + composition(Swizzle{}, SmemLayoutAtomQdOtransposedNoSwizzle{})); + using SmemLayoutQdOtransposed = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposed{}, + make_shape(Int{}, Int{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); + + using SmemLayoutAtomdKV = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdKV = decltype(tile_to_shape( + SmemLayoutAtomdKV{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdKV = 
Copy_Atom; + + using SmemLayoutAtomdQ = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape( + SmemLayoutAtomdQ{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdQ = Copy_Atom; + + static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemdSCount = size(SmemLayoutPdS{}); + static constexpr int kSmemPCount = size(SmemLayoutPdS{}); + static constexpr int kSmemdQCount = size(SmemLayoutdQ{}); + static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); + static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); + static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); + static constexpr int kSmemSize = kSmemQdOSize + + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); + static constexpr int kSmemSize1colblock = kSmemQdOSize + + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + kSmemPSize + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3 + + kSmemdSSize + kSmemPSize; + + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem + // to affect speed in practice. 
+ static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopydO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydQ = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomdQaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + using GmemTiledCopydQaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 4 vals per store + + using GmemTiledCopydQaccumAtomicAdd = decltype( + make_tiled_copy(Copy_Atom{}, + Layout, // Thread layout, 8 threads per row + Stride<_32, _1>>{}, + Layout>{})); // Val layout, 1 val per store + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/kernel_traits_sm90.h b/custom_ops/gpu_ops/block_sparse_attn/src/kernel_traits_sm90.h new file mode 100644 index 00000000000..ead5bc0ea2f --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/kernel_traits_sm90.h @@ 
-0,0 +1,159 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cute/algorithm/copy.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +template +struct Flash_kernel_traits_sm90 { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = int64_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >; + using ValLayoutMNK = Layout>; +#else + using MMA_Atom_Arch = MMA_Atom; + using ValLayoutMNK = Layout>; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. 
+ static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQ = decltype( + composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutAtomVtransposed = decltype( + composition(Swizzle{}, + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + Layout, Int>, + Stride<_1, Int>>{})); + using SmemLayoutVtransposed = decltype(tile_to_shape( + SmemLayoutAtomVtransposed{}, + Shape, Int>{})); + // Maybe the VtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? 
+ using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + + static constexpr int kSmemQCount = size(SmemLayoutQ{}); + static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; + static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. 
+ using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); + using GmemLayoutAtomP = Layout, Int>, + Stride, _1>>; + + using GmemTiledCopyP = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomP{}, + Layout>{})); // Val layout, 8 vals per store + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/namespace_config.h b/custom_ops/gpu_ops/block_sparse_attn/src/namespace_config.h new file mode 100644 index 00000000000..a6fad57b154 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/namespace_config.h @@ -0,0 +1,67 @@ +/** + * @file flash_namespace_config.h + * @brief Configuration file for Flash namespace management and isolation + * + * This header provides configuration macros for managing the Flash namespace + * across a codebase. It allows for flexible namespace naming and provides + * utilities for namespace declaration and scoping. + * + * Usage Examples: + * + * 1. Basic namespace wrapping: + * @code + * BEGIN_FLASH_NAMESPACE + * class FlashDevice { + * // Implementation + * }; + * END_FLASH_NAMESPACE + * @endcode + * + * 2. Accessing types within the namespace: + * @code + * FLASH_NAMESPACE_ALIAS(FlashDevice) device; + * @endcode + * + * 3. Defining content within namespace scope: + * @code + * FLASH_NAMESPACE_SCOPE( + * struct Configuration { + * uint32_t size; + * bool enabled; + * }; + * ) + * @endcode + * + * 4. 
Custom namespace name:
+ * @code
+ * #define FLASH_NAMESPACE custom_flash
+ * #include "namespace_config.h"
+ * @endcode
+ *
+ * Configuration:
+ * - The default namespace is 'flash' if FLASH_NAMESPACE is not defined
+ * - Define FLASH_NAMESPACE before including this header to customize the
+ *   namespace name
+ *
+ * Best Practices:
+ * - Include this header in all files that need access to the Flash namespace
+ *
+ */
+#pragma once
+
+#ifndef FLASH_NAMESPACE_CONFIG_H
+#define FLASH_NAMESPACE_CONFIG_H
+
+// Fall back to the default namespace name `flash` unless the includer already defined one.
+#ifndef FLASH_NAMESPACE
+#define FLASH_NAMESPACE flash
+#endif
+// Qualifies `name` with the configured namespace, i.e. expands to FLASH_NAMESPACE::name.
+#define FLASH_NAMESPACE_ALIAS(name) FLASH_NAMESPACE::name
+// Wraps `content` in a `namespace FLASH_NAMESPACE { ... }` block.
+#define FLASH_NAMESPACE_SCOPE(content) \
+  namespace FLASH_NAMESPACE {          \
+  content                              \
+  }
+
+#endif // FLASH_NAMESPACE_CONFIG_H
diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/philox.cuh b/custom_ops/gpu_ops/block_sparse_attn/src/philox.cuh
new file mode 100644
index 00000000000..8bf90fcdacc
--- /dev/null
+++ b/custom_ops/gpu_ops/block_sparse_attn/src/philox.cuh
@@ -0,0 +1,167 @@
+// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h
+#pragma once
+// Philox CUDA.
+
+#include "namespace_config.h"
+
+namespace FLASH_NAMESPACE {
+
+struct ull2 {
+    unsigned long long x;
+    unsigned long long y;
+};
+
+// Full 32x32 -> 64-bit product of a and b (PTX mul.wide.u32).
+// The 64-bit result is reinterpreted in place as the two 32-bit words of a uint2.
+__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
+    uint2 *res;
+    unsigned long long tmp;
+    asm ("mul.wide.u32 %0, %1, %2;\n\t"
+          : "=l"(tmp)
+          : "r"(a), "r"(b));
+    res = (uint2*)(&tmp);
+    return *res;
+}
+
+// One Philox round: mixes the 128-bit counter `ctr` with the 64-bit `key`.
+__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
+    constexpr unsigned long kPhiloxSA = 0xD2511F53;
+    constexpr unsigned long kPhiloxSB = 0xCD9E8D57;
+    uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
+    uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
+    uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
+    return ret;
+}
+
+// Stateless 7-round Philox: returns four 32-bit words for (seed, subsequence, offset).
+// The counter is built from `offset` (first 64 bits) and `subsequence` (last 64 bits); the key
+// starts as the two 32-bit halves of `seed` and is bumped after each of the first six rounds.
+__forceinline__ __device__ uint4 philox(unsigned long long seed,
+                                        unsigned long long subsequence,
+                                        unsigned long long offset) {
+    constexpr unsigned long kPhilox10A = 0x9E3779B9;
+    constexpr unsigned long kPhilox10B = 0xBB67AE85;
+    uint2 key = reinterpret_cast<uint2&>(seed);
+    uint4 counter;
+    ull2 *tmp = reinterpret_cast<ull2*>(&counter);
+    tmp->x = offset;
+    tmp->y = subsequence;
+    #pragma unroll
+    for (int i = 0; i < 6; i++) {
+        counter = philox_single_round(counter, key);
+        key.x += (kPhilox10A);
+        key.y += (kPhilox10B);
+    }
+    uint4 output = philox_single_round(counter, key);
+    return output;
+}
+
+} // namespace FLASH_NAMESPACE
+
+// NOTE: unnamed namespace in a header -- each translation unit including this file gets its own Philox type.
+namespace {
+
+// Object-style wrapper around FLASH_NAMESPACE::philox.
+// operator() is stateless: it reads only seed_ and offset_, so repeated calls on one
+// object return the same four words.
+class Philox {
+public:
+  // Records seed/offset and fills `counter` with {offset / 4, subsequence}.
+  __device__ inline Philox(unsigned long long seed,
+                           unsigned long long subsequence,
+                           unsigned long long offset)
+      : offset_(offset)  // initializers listed in member declaration order
+      , seed_(seed)
+      , key(reinterpret_cast<const uint2&>(seed))
+      , STATE(0) {
+    ull2 * tmp = reinterpret_cast<ull2*>(&counter);
+    tmp->x = offset / 4;
+    tmp->y = subsequence;
+  }
+  // NOTE(review): `offset_` is passed as both subsequence and offset; the constructor's
+  // `subsequence` only reaches `counter`, which this call does not read. Confirm this is intended.
+  __device__ inline uint4 operator()() {
+    // // if (STATE == 0) {
+    //   uint4 counter_ = counter;
+    //   uint2 key_ = key;
+    //   // 7-round philox
+    //   #pragma unroll
+    //   for (int i = 0; i < 6; i++) {
+    //     counter_ = flash::philox_single_round(counter_, key_);
+    //     key_.x += (kPhilox10A);
+    //     key_.y += (kPhilox10B);
+    //   }
+    //   // output = philox_single_round(counter_, key_);
+    //   uint4 output = flash::philox_single_round(counter_, key_);
+    //   // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+    //   //     printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
+    //   //     printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
+    //   // }
+    //   incr();
+    // // }
+    // // return a float4 directly
+    // // unsigned long ret;
+    // // switch(STATE) {
+    // //  case 0: ret = output.x; break;
+    // //  case 1: ret = output.y; break;
+    // //  case 2: ret = output.z; break;
+    // //  case 3: ret = output.w; break;
+    // //}
+    // // STATE = (STATE + 1) % 4;
+    // return output;
+    return FLASH_NAMESPACE::philox(seed_, offset_, offset_);
+  }
+
+private:
+  // Only offset_ and seed_ are read by operator(); the rest is written by the constructor or the helpers below.
+  unsigned long long offset_, seed_;
+  struct ull2 {
+    uint64_t x;
+    uint64_t y;
+  };
+  uint4 counter;
+  // uint4 output;
+  const uint2 key;
+  unsigned int STATE;
+  // Adds the 64-bit value n to the 128-bit counter (x is the least significant word),
+  // propagating the carry by hand. No live code in this class calls it.
+  __device__ inline void incr_n(unsigned long long n) {
+    unsigned int nlo = (unsigned int)(n);
+    unsigned int nhi = (unsigned int)(n >> 32);
+    counter.x += nlo;
+    if (counter.x < nlo)
+      nhi++;
+    counter.y += nhi;
+    if (nhi <= counter.y)
+      return;
+    if (++counter.z)
+      return;
+    ++counter.w;
+  }
+
+  // Returns ctr + 1, treating ctr as a 128-bit integer, via the PTX add.cc / addc carry chain.
+  __device__ uint4 incr128 (uint4 ctr)
+  {
+    uint4 res;
+    asm ("add.cc.u32      %0, %4, %8;\n\t"
+         "addc.cc.u32     %1, %5, %9;\n\t"
+         "addc.cc.u32     %2, %6, %10;\n\t"
+         "addc.u32        %3, %7, %11;\n\t"
+         : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w)
+         : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w),
+           "n"(1), "n"(0), "n"(0), "n"(0));
+    return res;
+  }
+
+  // Advances `counter` by one.
+  __device__ inline void incr() {
+    counter = incr128(counter);
+  }
+
+  static const unsigned long kPhilox10A = 0x9E3779B9;
+  static const unsigned long kPhilox10B = 0xBB67AE85;
+  // static const unsigned long kPhiloxSA = 0xD2511F53;
+  // static const unsigned long kPhiloxSB = 0xCD9E8D57;
+};
+
+} // namespace
diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/softmax.h b/custom_ops/gpu_ops/block_sparse_attn/src/softmax.h
new file mode 100644
index 00000000000..2e2b4211520
--- /dev/null
+++ b/custom_ops/gpu_ops/block_sparse_attn/src/softmax.h
@@ -0,0 +1,323 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/ +/****************************************************************************** + * Adapted by Junxian Guo from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h + ******************************************************************************/ + +#pragma once + +#include + +#include + +#include + +#include "namespace_config.h" +#include "philox.cuh" +#include "utils.h" + +namespace FLASH_NAMESPACE { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++){ + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ inline void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. 
+template +__forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)). This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +// Compute the per-row max (quad all-reduced), apply the exp to all the elements, and accumulate the per-row sum (quad all-reduced). +template +__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ?
0.f : max(mi) * scale; + sum(mi) = 0; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)). This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +template +inline __device__ void apply_mask(Tensor &tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= max_seqlen_k) { + // Without the "make_coord" we get wrong results + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } +} + +template +inline __device__ void apply_mask_local(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int
col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } + } + } +} + + + +template +inline __device__ void apply_mask_streaming(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride, + const int local_size, const int sink_size) { + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - (local_size-1)); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q); + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if
(col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left && col_idx >= sink_size)) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + } + } +} + + + +template +inline __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset, + max_seqlen_q, warp_row_stride, -1, 0); +} + +template +inline __device__ void apply_mask_causal_w_idx( + Tensor &tensor, Tensor const &idx_rowcol, + const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset) +{ + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 2, "Only support 2D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); + CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0))); + #pragma unroll + for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { + if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { + tensor(mi, ni) = -INFINITY; + } + } + // if (cute::thread0()) { + // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); + // print(tensor(_, make_coord(j, ni))); + // // print(tensor(_, j + ni * size<1, 0>(tensor))); + // } + } +} + +template +inline __device__ void apply_dropout(Tensor &tensor, uint8_t p_dropout_in_uint8_t, + unsigned long long seed, unsigned long long offset, + int block_row_start, int block_col_start, + int block_row_stride) { + // tensor has shape (8, MMA_M, MMA_N / 2) + using T = typename
Engine::value_type; + auto encode_dropout = [](bool keep, T val) { + return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0)); + }; + static_assert(decltype(size<2>(tensor))::value % 2 == 0); + const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t); + const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t); + // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); } + #pragma unroll + for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) { + uint2 rowcol = make_uint2(block_row_start, block_col_start); + #pragma unroll + for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { + // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} + uint4 random_uint4 = flash::philox(seed, reinterpret_cast(rowcol), offset); + // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} + uint8_t (&rnd_8)[16] = reinterpret_cast(random_uint4); + // Special implementation for 16-bit types: we duplicate the threshold to the + // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction + // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000, + // and the high 16 bits will be either 0xffff or 0x0000, depending on whether + // the random value is less than the threshold. + // We then do a bit-wise AND between the mask and the original value (in 32-bit). + // We're exploiting the fact that floating point comparison is equivalent to integer + // comparison, since we're comparing unsigned integers whose top 8-bits are zero. 
+ if (!encode_dropout_in_sign_bit + && (std::is_same::value || std::is_same::value)) { + uint16_t rnd_16[16]; + #pragma unroll + for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); } + uint32_t (&rnd_32)[8] = reinterpret_cast(rnd_16); + #pragma unroll + for (int j = 0; j < 2; j++) { + Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); + // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); } + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + #pragma unroll + for (int i = 0; i < 4; i++) { + uint32_t mask; + asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t)); + tensor_uint32(i) &= mask; + } + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + } + } else { + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < 8; i++) { + tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j)); + } + Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + } + } + // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w); + // // } + } + } +} + +} // namespace FLASH_NAMESPACE diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/static_switch.h b/custom_ops/gpu_ops/block_sparse_attn/src/static_switch.h new file mode 100644 index 00000000000..40cfd6fe870 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/static_switch.h @@ -0,0 +1,53 @@ +// Inspired by +//
https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + } \ + }() + + + #define HEADDIM_SWITCH(HEADDIM, ...)\ + [&] { \ + if (HEADDIM <= 32) { \ + constexpr static int kHeadDim = 32; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 64) { \ + constexpr static int kHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 128) { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } \ + }() + diff --git a/custom_ops/gpu_ops/block_sparse_attn/src/utils.h b/custom_ops/gpu_ops/block_sparse_attn/src/utils.h new file mode 100644 index 00000000000..238414ad5e8 --- /dev/null +++ b/custom_ops/gpu_ops/block_sparse_attn/src/utils.h @@ -0,0 +1,407 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include + +#include +#include +#include +#include + +#include "namespace_config.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace FLASH_NAMESPACE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ uint32_t relu2(const uint32_t x); + +template<> +__forceinline__ __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); +#else + asm volatile( \ + "{\n" \ + "\t .reg .f16x2 sela;\n" \ + "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ + "\t and.b32 %0, sela, %1;\n" + "}\n" : "=r"(res) : "r"(x), "r"(zero)); +#endif + return res; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template<> +__forceinline__ __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; + asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); + return res; +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +template +__forceinline__ __device__ uint32_t convert_relu2(const float2 x); + +template<> +__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +template<> +__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = 
reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ __forceinline__ T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, + Tensor4 const& tCsB, TiledMma tiled_mma, + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // 
MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to 
(nrow=(2, MMA_M), ncol=(2, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting + // "int_tuple.hpp(74): error: conversion to inaccessible base class" + // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. +template +inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { + using X = Underscore; + static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); + static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + constexpr int MMA_N_divisor = mma_shape_K == 8 ?
1 : 2; + auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) + // TD [2023-08-13]: Same error as above on Cutlass 3.2 + // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), + // get<0, 1>(l), + // get<1, 1, 1>(l)); + return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))), + get<1>(get<0>(l)), + get<1>(get<1>(get<1>(l)))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast *>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void relu_(Tensor &tensor) { + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); + using value_t = typename Engine::value_type; + // HACK: this requires tensor to be "contiguous" + Tensor tensor_uint32 = recast(tensor); + #pragma unroll + for (int i = 0; i < size(tensor_uint32); ++i) { + tensor_uint32(i) = relu2(tensor_uint32(i)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction +template +__forceinline__ __device__ auto convert_type_relu(Tensor const &tensor) { + using From_type = typename Engine::value_type; + static_assert(std::is_same_v || std::is_same_v); + static_assert(std::is_same_v); + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); +#if defined(__CUDA_ARCH__) && 
__CUDA_ARCH__ >= 800 + // HACK: this requires tensor to be "contiguous" + Tensor tensor_float2 = recast(tensor); + Tensor out_uint32 = make_tensor(tensor_float2.layout()); + #pragma unroll + for (int i = 0; i < size(out_uint32); ++i) { + out_uint32(i) = convert_relu2(tensor_float2(i)); + } + Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout()); +#else + Tensor out = FLASH_NAMESPACE::convert_type(tensor); + FLASH_NAMESPACE::relu_(out); +#endif + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. +// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE +void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, const int max_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if 
(Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } + // TD [2023-04-13]: Strange that the code below can cause race condition. + // I think it's because the copies are under an if statement. + // if (Is_even_K) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(tiled_copy, S(_, m, _), D(_, m, _)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, _)); + // } + // } + // } else { // It's slightly faster in this case if iterate over K first + // #pragma unroll + // for (int k = 0; k < size<2>(S); ++k) { + // if (predicate_K(k)) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(tiled_copy, S(_, m, k), D(_, m, k)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, k)); + // } + // } + // } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN + // if (Clear_OOB_MN || Is_even_MN) { + // clear(D(_, _, k)); + // } else { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) { + // clear(D(_, m, k)); + // } + // } + // } + // } + // } + // } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy_w_min_idx(Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, + const int max_MN=0, const int min_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // if (threadIdx.x == 0 && blockIdx.z == 0) { 
printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(S(_, m, k), D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void apply_softcap(Tensor &tensor, const float softcap){ + #pragma unroll + for (int i = 0; i < size(tensor); ++i) { + tensor(i) = cutlass::fast_tanh(tensor(i) * softcap); + } +} + +template +__forceinline__ __device__ void calculate_dtanh(Tensor &src_tensor, Tensor &dst_tensor, const float softcap){ + #pragma unroll + for (int i = 0; i < size(src_tensor); ++i) { + dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace FLASH_NAMESPACE