diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index c1586945cc5..6bf19f3de65 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -177,59 +177,59 @@ void AppendAttentionKernel( sink_size); }; + int encoder_num_blocks_data = encoder_num_blocks.data()[0]; + int kv_num_blocks_data = kv_num_blocks.data()[0]; + + auto dispatch_EncoderWriteCacheWithRopeKernel = [&](auto temp_args) -> void { + DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, { + EncoderWriteCacheWithRopeKernel( + meta_data, + qkv, + seq_lens_this_time, + seq_lens_this_time, + seq_lens_decoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + kv_batch_ids, + kv_tile_ids_per_batch, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + kv_signal_data, + cache_quant_type_str, + kv_num_blocks_data, + max_input_length, + use_neox_rotary_style, + rope_3d, + main_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + }) + }; + + if (qkv_out_scales) { + int tmp; + dispatch_EncoderWriteCacheWithRopeKernel(tmp); + } else { + data_t tmp; + dispatch_EncoderWriteCacheWithRopeKernel(tmp); + } + if (max_enc_len_this_time > 0) { if (max_just_dec_len_this_time > 0) { cudaEventRecord(main_event, main_stream); } - int encoder_num_blocks_data = encoder_num_blocks.data()[0]; - int kv_num_blocks_data = kv_num_blocks.data()[0]; - - auto dispatch_EncoderWriteCacheWithRopeKernel = - [&](auto temp_args) -> void { - DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, { - EncoderWriteCacheWithRopeKernel( - meta_data, - qkv, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - kv_batch_ids, - kv_tile_ids_per_batch, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - kv_signal_data, - cache_quant_type_str, - kv_num_blocks_data, - max_input_length, - use_neox_rotary_style, - rope_3d, - main_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache), - q_norm_weight, - k_norm_weight, - rms_norm_eps); - }) - }; - - if (qkv_out_scales) { - int tmp; - dispatch_EncoderWriteCacheWithRopeKernel(tmp); - } else { - data_t tmp; - dispatch_EncoderWriteCacheWithRopeKernel(tmp); - } if (out_linear_in_scale > 0.0) { switch (fmha_out.dtype()) { case paddle::DataType::INT8: { @@ -289,119 +289,6 @@ void AppendAttentionKernel( } else { exec_stream = main_stream; } - DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, { - if (speculate_decoder) { - if (qkv_out_scales) { - SpeculateWriteCacheWithRoPEKernel( - meta_data, - qkv, // [token_num, num_heads, head_dim] - seq_lens_decoder, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - cache_quant_type_str, - use_neox_rotary_style, - rope_3d, - max_input_length, - exec_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache), - q_norm_weight, - k_norm_weight, - rms_norm_eps); - } else { - SpeculateWriteCacheWithRoPEKernel( - meta_data, - qkv_out, // [token_num, num_heads, head_dim] - seq_lens_decoder, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - cache_quant_type_str, - use_neox_rotary_style, - rope_3d, - max_input_length, - exec_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache), - q_norm_weight, - k_norm_weight, - rms_norm_eps); - } - } else { - if (qkv_out_scales) { - DecoderWriteCacheWithRoPEKernel( - meta_data, - qkv, // [token_num, num_heads, head_dim] - seq_lens_decoder, - seq_lens_encoder, - cu_seqlens_q, - block_tables, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - cache_quant_type_str, - use_neox_rotary_style, - rope_3d, - max_input_length, - exec_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache), - q_norm_weight, - k_norm_weight, - rms_norm_eps); - } else { - DecoderWriteCacheWithRoPEKernel( - meta_data, - qkv_out, // [token_num, num_heads, head_dim] - seq_lens_decoder, - seq_lens_encoder, - cu_seqlens_q, - block_tables, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - cache_quant_type_str, - use_neox_rotary_style, - rope_3d, - max_input_length, - exec_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache), - q_norm_weight, - k_norm_weight, - rms_norm_eps); - } - } - }) if (out_linear_in_scale > 0.0) { switch (fmha_out.dtype()) { diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index d61aa3c2313..13ae9e1348c 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -383,9 +383,8 @@ void GetBlockShapeAndSplitKVBlock( } // mla_backend not need run the following code. if (mla_backend) return; - - // encoder - if (max_enc_len_this_time > 0) { + // deal with write cache kv! + { const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size); const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv; @@ -400,7 +399,7 @@ void GetBlockShapeAndSplitKVBlock( split_kv_block<<<1, 32, 0, seq_lens_encoder.stream()>>>( seq_lens_decoder.data(), - seq_lens_encoder.data(), + seq_lens_this_time.data(), kv_batch_ids.data(), kv_tile_ids_per_batch.data(), kv_num_blocks_x.data(), @@ -408,8 +407,15 @@ void GetBlockShapeAndSplitKVBlock( block_size, block_size); - kv_num_blocks_x_cpu.copy_( - kv_num_blocks_x, kv_num_blocks_x_cpu.place(), true); +#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if (!phi::backends::gpu::IsCUDAGraphCapturing()) +#endif + kv_num_blocks_x_cpu.copy_( + kv_num_blocks_x, kv_num_blocks_x_cpu.place(), true); + } + + // encoder + if (max_enc_len_this_time > 0) { // Clear buffer const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);