diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index 90c1ab8f8ed..b7066172ff8 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -163,7 +163,6 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps( template __device__ __forceinline__ void load_q_global_smem( @@ -175,6 +174,7 @@ __device__ __forceinline__ void load_q_global_smem( const uint32_t qo_h_stride) { constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + static_assert(HEAD_DIM % 64 == 0, ""); const uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t q_smem_offset_w = // [NUM_WARP_Q, num_frags_x, 16, head_dim] @@ -193,7 +193,7 @@ __device__ __forceinline__ void load_q_global_smem( const T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { + for (uint32_t fyo = 0; fyo < HEAD_DIM / 64; ++fyo) { q_smem->load_128b_async( q_smem_offset_w, q_ptr, n_offset < qo_upper_bound); q_smem_offset_w = @@ -202,7 +202,7 @@ __device__ __forceinline__ void load_q_global_smem( } q_smem_offset_w = q_smem->advance_offset_by_row<4, num_vecs_per_head>(q_smem_offset_w) - - 2 * num_frags_y; // num_frags_y / 4 * 8 + HEAD_DIM / 8; } } } @@ -228,7 +228,7 @@ __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale_multi_warps( } } -template +template __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale( smem_t* q_smem, // [num_frags_x * 16, num_frags_y * 16] const float sm_scale) { @@ -236,7 +236,6 @@ __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale( using LoadT = AlignedVector; LoadT tmp_vec; const uint32_t tx = threadIdx.x, ty = threadIdx.y; - constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); #pragma unroll diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index 0bdf88e0a56..c8399a0e1f7 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -129,7 +129,7 @@ __global__ void multi_query_append_attention_kernel( #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); #endif - load_q_global_smem( + load_q_global_smem( q_base_ptr, &qo_smem, q_base_seq_id_this_block, @@ -140,8 +140,7 @@ __global__ void multi_query_append_attention_kernel( wait_group<0>(); __syncthreads(); - q_smem_inplace_multiply_sm_scale(&qo_smem, - scale); + q_smem_inplace_multiply_sm_scale(&qo_smem, scale); smem_t k_smem(smem + num_rows_per_block * HEAD_DIM * sizeof(T)), v_smem(smem + (num_rows_per_block + BLOCK_SIZE) * HEAD_DIM * sizeof(T)); @@ -390,7 +389,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const float quant_min_bound, const float in_scale, const uint32_t chunk_size, - const int num_blocks_x_cpu, T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, // num_heads, head_dim] float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] @@ -1080,7 +1078,6 @@ void MultiQueryAppendAttention( quant_min_bound, in_scale, chunk_size, - num_blocks_x_cpu, nullptr, nullptr, nullptr, @@ -1137,7 +1134,6 @@ void MultiQueryAppendAttention( quant_min_bound, in_scale, chunk_size, - num_blocks_x_cpu, reinterpret_cast(tmp_workspace->ptr()), static_cast(tmp_m->ptr()), static_cast(tmp_d->ptr()), diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh index 8d4111a0a29..ee1821f13c1 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh @@ -190,7 +190,7 @@ __global__ void multi_query_append_attention_c4_kernel( uint32_t q_smem_offset_r = smem_t::get_permuted_offset( wid * num_frags_x * 16 + tid % 16, tid / 16); - load_q_global_smem( + load_q_global_smem( q_base_ptr, &qo_smem, q_base_seq_id_this_block, @@ -201,8 +201,7 @@ __global__ void multi_query_append_attention_c4_kernel( wait_group<0>(); __syncthreads(); - q_smem_inplace_multiply_sm_scale(&qo_smem, - scale); + q_smem_inplace_multiply_sm_scale(&qo_smem, scale); T cache_k_scale_frag[num_frags_y][4]; T cache_k_zp_frag[num_frags_y][4]; diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh index 611b5d66435..ba103a6e2f4 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh @@ -198,7 +198,7 @@ __global__ void multi_query_append_attention_c8_kernel( uint32_t q_smem_offset_r = smem_t::get_permuted_offset( wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16 - load_q_global_smem( + load_q_global_smem( q_base_ptr, &qo_smem, q_base_seq_id_this_block, @@ -209,8 +209,7 @@ __global__ void multi_query_append_attention_c8_kernel( wait_group<0>(); __syncthreads(); - q_smem_inplace_multiply_sm_scale(&qo_smem, - scale); + q_smem_inplace_multiply_sm_scale(&qo_smem, scale); smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); diff --git a/custom_ops/gpu_ops/append_attn/template_config.json b/custom_ops/gpu_ops/append_attn/template_config.json index b2590586206..b1932536859 100644 --- a/custom_ops/gpu_ops/append_attn/template_config.json +++ b/custom_ops/gpu_ops/append_attn/template_config.json @@ -90,7 +90,7 @@ ], "dispatch_params": { "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16, 24], - "HEAD_DIM": [64,128], + "HEAD_DIM": [64,128,192], "BLOCK_SIZE": [64], "CAUSAL": [0, 1], "BLOCK_SHAPE_Q": [16, 32, 64, 128],