Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 49 additions & 162 deletions custom_ops/gpu_ops/append_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -177,59 +177,59 @@ void AppendAttentionKernel(
sink_size);
};

// Read element 0 of the encoder_num_blocks and kv_num_blocks tensors into
// plain ints.
// NOTE(review): data<int>() is dereferenced directly, so both tensors are
// presumably in host-accessible memory — confirm against the caller.
int encoder_num_blocks_data = encoder_num_blocks.data<int>()[0];
int kv_num_blocks_data = kv_num_blocks.data<int>()[0];

// Generic lambda that calls EncoderWriteCacheWithRopeKernel, passing
// main_stream as the stream argument. temp_args is never read: only its type
// is used (via decltype) as the second template argument. Every other value
// is captured by reference from the enclosing scope.
auto dispatch_EncoderWriteCacheWithRopeKernel = [&](auto temp_args) -> void {
// EnforceFmulRN is used as a template argument below, so DISPATCH_BOOL_DTYPE
// presumably binds it to a compile-time bool mirroring enforce_fmul_rn —
// confirm against the macro definition.
DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, {
EncoderWriteCacheWithRopeKernel<data_t,
decltype(temp_args),
EnforceFmulRN>(
meta_data,
qkv,
// NOTE(review): seq_lens_this_time fills two consecutive parameters here.
// Verify against the EncoderWriteCacheWithRopeKernel signature that the
// second slot is meant to receive it rather than a separate length tensor.
seq_lens_this_time,
seq_lens_this_time,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_batch_ids,
kv_tile_ids_per_batch,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
kv_signal_data,
cache_quant_type_str,
kv_num_blocks_data,
max_input_length,
use_neox_rotary_style,
rope_3d,
main_stream,
&qkv_out,
// The caches are handed over as mutable pointers; presumably key_cache and
// value_cache are const in the enclosing signature and are written in
// place by the callee — confirm.
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
})
};

// Select the second template argument of
// dispatch_EncoderWriteCacheWithRopeKernel through the type of a dummy
// argument: int when qkv_out_scales is set, data_t otherwise. The lambda
// takes its parameter by value, so a value-initialized temporary is passed
// instead of a default-initialized local; copying an uninitialized int would
// read an indeterminate value.
if (qkv_out_scales) {
  dispatch_EncoderWriteCacheWithRopeKernel(int{});
} else {
  dispatch_EncoderWriteCacheWithRopeKernel(data_t{});
}

if (max_enc_len_this_time > 0) {
if (max_just_dec_len_this_time > 0) {
cudaEventRecord(main_event, main_stream);
}
// Read element 0 of the encoder_num_blocks and kv_num_blocks tensors into
// plain ints.
// NOTE(review): data<int>() is dereferenced directly, so both tensors are
// presumably in host-accessible memory — confirm against the caller.
int encoder_num_blocks_data = encoder_num_blocks.data<int>()[0];
int kv_num_blocks_data = kv_num_blocks.data<int>()[0];

// Generic lambda that calls EncoderWriteCacheWithRopeKernel, passing
// main_stream as the stream argument. temp_args is never read: only its type
// is used (via decltype) as the second template argument. Every other value
// is captured by reference from the enclosing scope.
auto dispatch_EncoderWriteCacheWithRopeKernel =
[&](auto temp_args) -> void {
// EnforceFmulRN is used as a template argument below, so DISPATCH_BOOL_DTYPE
// presumably binds it to a compile-time bool mirroring enforce_fmul_rn —
// confirm against the macro definition.
DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, {
EncoderWriteCacheWithRopeKernel<data_t,
decltype(temp_args),
EnforceFmulRN>(
meta_data,
qkv,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_batch_ids,
kv_tile_ids_per_batch,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
kv_signal_data,
cache_quant_type_str,
kv_num_blocks_data,
max_input_length,
use_neox_rotary_style,
rope_3d,
main_stream,
&qkv_out,
// The caches are handed over as mutable pointers; presumably key_cache and
// value_cache are const in the enclosing signature and are written in
// place by the callee — confirm.
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
})
};

// Select the second template argument of
// dispatch_EncoderWriteCacheWithRopeKernel through the type of a dummy
// argument: int when qkv_out_scales is set, data_t otherwise. The lambda
// takes its parameter by value, so a value-initialized temporary is passed
// instead of a default-initialized local; copying an uninitialized int would
// read an indeterminate value.
if (qkv_out_scales) {
  dispatch_EncoderWriteCacheWithRopeKernel(int{});
} else {
  dispatch_EncoderWriteCacheWithRopeKernel(data_t{});
}
if (out_linear_in_scale > 0.0) {
switch (fmha_out.dtype()) {
case paddle::DataType::INT8: {
Expand Down Expand Up @@ -289,119 +289,6 @@ void AppendAttentionKernel(
} else {
exec_stream = main_stream;
}
// Write-cache-with-RoPE dispatch on exec_stream. Three choices are made here:
//   1. enforce_fmul_rn -> EnforceFmulRN template argument (through
//      DISPATCH_BOOL_DTYPE; presumably a compile-time bool — confirm against
//      the macro definition).
//   2. speculate_decoder -> SpeculateWriteCacheWithRoPEKernel, otherwise
//      DecoderWriteCacheWithRoPEKernel. Only the speculative variant receives
//      batch_id_per_token.
//   3. qkv_out_scales set -> second template argument int with qkv as the
//      input tensor; otherwise second template argument data_t with qkv_out
//      as the input tensor.
// All four calls pass &qkv_out as the output and the caches as mutable
// pointers obtained with const_cast.
DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, {
if (speculate_decoder) {
if (qkv_out_scales) {
SpeculateWriteCacheWithRoPEKernel<data_t, int, EnforceFmulRN>(
meta_data,
qkv, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
} else {
// No output scales: the input tensor is qkv_out itself.
SpeculateWriteCacheWithRoPEKernel<data_t, data_t, EnforceFmulRN>(
meta_data,
qkv_out, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
} else {
if (qkv_out_scales) {
DecoderWriteCacheWithRoPEKernel<data_t, int, EnforceFmulRN>(
meta_data,
qkv, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
} else {
// No output scales: the input tensor is qkv_out itself.
DecoderWriteCacheWithRoPEKernel<data_t, data_t, EnforceFmulRN>(
meta_data,
qkv_out, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
qkv_bias,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_zp,
cache_v_zp,
cache_quant_type_str,
use_neox_rotary_style,
rope_3d,
max_input_length,
exec_stream,
&qkv_out,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache),
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
}
})

if (out_linear_in_scale > 0.0) {
switch (fmha_out.dtype()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,8 @@ void GetBlockShapeAndSplitKVBlock(
}
// When mla_backend is set, the code below does not need to run.
if (mla_backend) return;

// encoder
if (max_enc_len_this_time > 0) {
// deal with write cache kv!
{
const uint32_t max_tile_size_per_bs_kv =
div_up(max_enc_dec_len_this_time, block_size);
const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
Expand All @@ -400,16 +399,23 @@ void GetBlockShapeAndSplitKVBlock(

split_kv_block<<<1, 32, 0, seq_lens_encoder.stream()>>>(
seq_lens_decoder.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_this_time.data<int>(),
kv_batch_ids.data<int>(),
kv_tile_ids_per_batch.data<int>(),
kv_num_blocks_x.data<int>(),
bsz,
block_size,
block_size);

kv_num_blocks_x_cpu.copy_(
kv_num_blocks_x, kv_num_blocks_x_cpu.place(), true);
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if (!phi::backends::gpu::IsCUDAGraphCapturing())
#endif
kv_num_blocks_x_cpu.copy_(
kv_num_blocks_x, kv_num_blocks_x_cpu.place(), true);
}

// encoder
if (max_enc_len_this_time > 0) {
// Clear buffer
const uint32_t encoder_max_tile_size_per_bs_q =
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
Expand Down
Loading