Skip to content
Open

Jj #8013

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
af5c139
remove unused parameter in produce_kv_blockwise
zhoutianzi666 Jun 2, 2026
0447b7a
remove unused parameter in produce_kv_blockwise
zhoutianzi666 Jun 2, 2026
ed50b12
remove unused parameter in produce_kv_blockwise
zhoutianzi666 Jun 2, 2026
22f338b
Merge remote-tracking branch 'origin/develop' into develop
zhoutianzi666 Jun 2, 2026
635b4f3
simplify code in produce_kv_blockwise_c16
zhoutianzi666 Jun 2, 2026
7a327b6
simplify code in produce_kv_blockwise_c16
zhoutianzi666 Jun 2, 2026
ced1aa7
simplify code in produce_kv_blockwise_c16
zhoutianzi666 Jun 2, 2026
d055b0b
simplify code in produce_kv_blockwise_c16
zhoutianzi666 Jun 2, 2026
0f81ab8
simplify code in produce_kv_blockwise_c16
zhoutianzi666 Jun 3, 2026
1ec9ac1
Merge remote-tracking branch 'origin/develop' into develop
zhoutianzi666 Jun 3, 2026
6d27f9e
simplify code in produce_kv_blockwise_c16
zhoutianzi666 Jun 3, 2026
b1cdc85
simplify code in produce_kv_blockwise_c16
zhoutianzi666 Jun 3, 2026
3a36806
simplify code in produce_kv_blockwise_c16
zhoutianzi666 Jun 3, 2026
0dc803d
simplify code in produce_kv_blockwise_c16
zhoutianzi666 Jun 3, 2026
ee36c6f
Merge remote-tracking branch 'origin/develop' into develop
zhoutianzi666 Jun 3, 2026
9162f8d
simplify code in produce_kv_blockwise_c16
zhoutianzi666 Jun 4, 2026
58d0f0d
Merge remote-tracking branch 'origin/develop' into develop
zhoutianzi666 Jun 4, 2026
9d9eea3
support only_do_attn
zhoutianzi666 Jun 4, 2026
6daeb2e
support only_do_attn
zhoutianzi666 Jun 4, 2026
a18cb11
support only_do_attn
zhoutianzi666 Jun 4, 2026
c23d1a4
support only_do_attn
zhoutianzi666 Jun 4, 2026
459e85e
support only_do_attn
zhoutianzi666 Jun 4, 2026
8451480
support only_do_attn
zhoutianzi666 Jun 4, 2026
899ceb7
support only_do_attn
zhoutianzi666 Jun 4, 2026
f8d7da5
support only_do_attn
zhoutianzi666 Jun 4, 2026
c28b6d6
support only_do_attn
zhoutianzi666 Jun 4, 2026
9edfe26
support only_do_attn
zhoutianzi666 Jun 4, 2026
29914ab
support only_do_attn
zhoutianzi666 Jun 4, 2026
3be8f99
support only_do_attn
zhoutianzi666 Jun 4, 2026
49b4408
support only_do_attn
zhoutianzi666 Jun 4, 2026
29806bf
support only_do_attn
zhoutianzi666 Jun 6, 2026
afa91a5
Merge remote-tracking branch 'origin/develop' into develop
zhoutianzi666 Jun 6, 2026
ec9e650
support only_do_attn
zhoutianzi666 Jun 6, 2026
f49af7a
support only_do_attn
zhoutianzi666 Jun 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions custom_ops/gpu_ops/append_attn/append_attention_func.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps(

template <uint32_t group_size,
uint32_t num_frags_x,
uint32_t num_frags_y,
uint32_t HEAD_DIM,
typename T>
__device__ __forceinline__ void load_q_global_smem(
Expand All @@ -175,6 +174,7 @@ __device__ __forceinline__ void load_q_global_smem(
const uint32_t qo_h_stride) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();

static_assert(HEAD_DIM % 64 == 0, "");
const uint32_t tx = threadIdx.x, ty = threadIdx.y;

uint32_t q_smem_offset_w = // [NUM_WARP_Q, num_frags_x, 16, head_dim]
Expand All @@ -193,7 +193,7 @@ __device__ __forceinline__ void load_q_global_smem(
const T* q_ptr =
q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride;
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
for (uint32_t fyo = 0; fyo < HEAD_DIM / 64; ++fyo) {
q_smem->load_128b_async<SharedMemFillMode::kNoFill>(
q_smem_offset_w, q_ptr, n_offset < qo_upper_bound);
q_smem_offset_w =
Expand All @@ -202,7 +202,7 @@ __device__ __forceinline__ void load_q_global_smem(
}
q_smem_offset_w =
q_smem->advance_offset_by_row<4, num_vecs_per_head>(q_smem_offset_w) -
2 * num_frags_y; // num_frags_y / 4 * 8
HEAD_DIM / 8;
}
}
}
Expand All @@ -228,15 +228,14 @@ __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale_multi_warps(
}
}

template <uint32_t num_frags_x, uint32_t num_frags_y, typename T>
template <uint32_t num_frags_x, uint32_t head_dim, typename T>
__device__ __forceinline__ void q_smem_inplace_multiply_sm_scale(
smem_t* q_smem, // [num_frags_x * 16, head_dim]

This comment was marked as outdated.

const float sm_scale) {
constexpr int vec_size = 16 / sizeof(T);
using LoadT = AlignedVector<T, vec_size>;
LoadT tmp_vec;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b<T>();

#pragma unroll
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ __global__ void multi_query_append_attention_kernel(
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
load_q_global_smem<GROUP_SIZE, num_frags_x, num_frags_y, HEAD_DIM, T>(
load_q_global_smem<GROUP_SIZE, num_frags_x, HEAD_DIM, T>(
q_base_ptr,
&qo_smem,
q_base_seq_id_this_block,
Expand All @@ -140,8 +140,7 @@ __global__ void multi_query_append_attention_kernel(
wait_group<0>();
__syncthreads();

q_smem_inplace_multiply_sm_scale<num_frags_x, num_frags_y, T>(&qo_smem,
scale);
q_smem_inplace_multiply_sm_scale<num_frags_x, HEAD_DIM, T>(&qo_smem, scale);

smem_t k_smem(smem + num_rows_per_block * HEAD_DIM * sizeof(T)),
v_smem(smem + (num_rows_per_block + BLOCK_SIZE) * HEAD_DIM * sizeof(T));
Expand Down Expand Up @@ -390,7 +389,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const float quant_min_bound,
const float in_scale,
const uint32_t chunk_size,
const int num_blocks_x_cpu,
T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks,
// num_heads, head_dim]
float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads]
Expand Down Expand Up @@ -1080,7 +1078,6 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
nullptr,
nullptr,
nullptr,
Expand Down Expand Up @@ -1137,7 +1134,6 @@ void MultiQueryAppendAttention(
quant_min_bound,
in_scale,
chunk_size,
num_blocks_x_cpu,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ __global__ void multi_query_append_attention_c4_kernel(

uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
wid * num_frags_x * 16 + tid % 16, tid / 16);
load_q_global_smem<GROUP_SIZE, num_frags_x, num_frags_y, HEAD_DIM, T>(
load_q_global_smem<GROUP_SIZE, num_frags_x, HEAD_DIM, T>(
q_base_ptr,
&qo_smem,
q_base_seq_id_this_block,
Expand All @@ -201,8 +201,7 @@ __global__ void multi_query_append_attention_c4_kernel(
wait_group<0>();
__syncthreads();

q_smem_inplace_multiply_sm_scale<num_frags_x, num_frags_y, T>(&qo_smem,
scale);
q_smem_inplace_multiply_sm_scale<num_frags_x, HEAD_DIM, T>(&qo_smem, scale);

T cache_k_scale_frag[num_frags_y][4];
T cache_k_zp_frag[num_frags_y][4];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ __global__ void multi_query_append_attention_c8_kernel(

uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16
load_q_global_smem<GROUP_SIZE, num_frags_x, num_frags_y, HEAD_DIM, T>(
load_q_global_smem<GROUP_SIZE, num_frags_x, HEAD_DIM, T>(
q_base_ptr,
&qo_smem,
q_base_seq_id_this_block,
Expand All @@ -209,8 +209,7 @@ __global__ void multi_query_append_attention_c8_kernel(
wait_group<0>();
__syncthreads();

q_smem_inplace_multiply_sm_scale<num_frags_x, num_frags_y, T>(&qo_smem,
scale);
q_smem_inplace_multiply_sm_scale<num_frags_x, HEAD_DIM, T>(&qo_smem, scale);
smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
Expand Down
2 changes: 1 addition & 1 deletion custom_ops/gpu_ops/append_attn/template_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
],
"dispatch_params": {
"GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16, 24],
"HEAD_DIM": [64,128],
"HEAD_DIM": [64,128,192],

This comment was marked as outdated.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug HEAD_DIM=192 只加入了模板实例化配置,但运行时 append attention 入口仍不会 dispatch 到这个实例。

`CascadeAppendAttentionC16Kernel` 从 `meta_data.head_dims` 读取 `head_dim` 后走 `DISPATCH_HEAD_DIM`,而该宏目前只有 `case 64` / `case 128`,`192` 会进入 `default` 并 `PD_THROW("not support the head_dim")`。因此这里虽然生成了 `MultiQueryAppendAttention<..., 192, ...>`,实际 `cache_quant_type_str == "none"` 的 192 维请求仍会在入口处失败。

建议修复方式:
同步扩展运行时入口 dispatch,例如给 DISPATCH_HEAD_DIM 增加 case 192,或为 c16 append attention 使用一个包含 64/128/192 的专用 dispatch 宏;同时补一条 head_dim=192 的 append attention 编译/运行或精度用例,确认新实例能被实际调用。

"BLOCK_SIZE": [64],
"CAUSAL": [0, 1],
"BLOCK_SHAPE_Q": [16, 32, 64, 128],
Expand Down
Loading