17 changes: 16 additions & 1 deletion custom_ops/gpu_ops/cpp_extensions.cc
@@ -65,6 +65,7 @@ void FlashAttentionMask(const paddle::Tensor& q_input,
const int kv_head_num,
const int head_dim);

#ifdef ENABLE_APPEND_ATTENTION
std::vector<paddle::Tensor> AppendAttention(
const paddle::Tensor& qkv,
const paddle::Tensor& key_cache,
@@ -216,6 +217,7 @@ std::vector<paddle::Tensor> PreCacheLenConcat(
const paddle::Tensor& seq_lens_this_time,
const int max_dec_len,
const int block_size);
#endif // ENABLE_APPEND_ATTENTION

paddle::Tensor FusedExpertMoeFunc(
const paddle::Tensor& input,
@@ -386,6 +388,7 @@ paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank,
paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor& kv_signal_metadata,
const int layer_id);

#ifdef ENABLE_APPEND_ATTENTION
void GetBlockShapeAndSplitKVBlock(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
@@ -406,6 +409,7 @@ void GetBlockShapeAndSplitKVBlock(
const int decoder_block_shape_q,
const int group_size,
const int block_size);
#endif // ENABLE_APPEND_ATTENTION

std::vector<paddle::Tensor> GetPaddingOffset(
const paddle::Tensor& input_ids,
@@ -1170,9 +1174,11 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("wait_flag"),
"get_output_kv_signal function");

#ifdef ENABLE_BF16
m.def("moe_deepgemm_permute", &MoEDeepGEMMPermute, "MoEDeepGEMMPermute");
m.def(
"moe_deepgemm_depermute", &MoEDeepGEMMDePermute, "MoEDeepGEMMDePermute");
#endif
/**
* alloc_cache_pinned.cc
* cuda_host_alloc
@@ -1186,6 +1192,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def(
"cuda_host_free", &cuda_host_free, "Free pinned memory", py::arg("ptr"));
py::register_exception<CudaError>(m, "CudaError");
#ifdef ENABLE_APPEND_ATTENTION
/**
* append_attention.cu
* append_attention
@@ -1213,6 +1220,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("pre_cache_len_concat",
&PreCacheLenConcat,
"pre_cache len concat function");
#endif // ENABLE_APPEND_ATTENTION
/**
* moe/fused_moe/fused_moe.cu
* fused_moe
@@ -1242,7 +1250,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
"moe export dispatch function");

/**
* moe/fused_moe/ep_moe_prefill_func.cu
* moe/ep_moe_expert_dispatch.cu
* ep_moe_dispatch
*/
m.def("ep_moe_expert_dispatch",
@@ -1386,13 +1394,15 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&OpenShmAndGetMetaSignalFunc,
"open_shm_and_get_meta_signal function");

#ifdef ENABLE_APPEND_ATTENTION
/**
* append_attn/get_block_shape_and_split_kv_block.cu
* get_block_shape_and_split_kv_block
*/
m.def("get_block_shape_and_split_kv_block",
&GetBlockShapeAndSplitKVBlock,
"get_block_shape_and_split_kv_block function");
#endif // ENABLE_APPEND_ATTENTION

/**
* get_padding_offset.cu
@@ -1456,9 +1466,12 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
&TextImageGatherScatter,
"text_image_gather_scatter function");

// tritonmoe_preprocess_func does not depend on BF16, keep it unconditionally
// available
m.def("count_tokens_per_expert_func", &count_tokens_per_expert_func);
m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel);

#ifdef ENABLE_BF16
m.def("MoeWna16MarlinGemmApi",
&MoeWna16MarlinGemmApi,
py::arg("a"),
@@ -1549,6 +1562,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("noaux_tc_redundant",
&NoauxTcRedundant,
"noaux_tc_redundant for MoE compute");
#endif

#ifdef ENABLE_FP8
m.def("cutlass_fp8_fp8_half_gemm_fused",
@@ -1562,6 +1576,7 @@
py::arg("output_dtype"),
py::arg("activation_type"),
"cutlass_fp8_fp8_half_gemm_fused function");

m.def("moe_fused_hadamard_quant_fp8",
&MoeFusedHadamardQuantFp8Func,
py::arg("input"),
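Note on the compile-time guards added in cpp_extensions.cc above: both the forward declarations and the matching m.def(...) bindings now sit inside #ifdef ENABLE_APPEND_ATTENTION / #ifdef ENABLE_BF16, so in a build without those flags the corresponding Python functions are simply never registered on fastdeploy_ops. The sketch below shows the pattern in isolation; the example_ops module name, the stub op, and the -D compile flag are placeholders of mine, since the diff does not show how the macros are wired into the build scripts.

// Minimal sketch of the guard pattern (hypothetical module and op names;
// assume the macro is supplied by the build system, e.g. something like
// -DENABLE_APPEND_ATTENTION on the nvcc/compiler command line).
#include <pybind11/pybind11.h>

namespace py = pybind11;

#ifdef ENABLE_APPEND_ATTENTION
// The declaration/definition exists only when the feature is compiled in.
void AppendAttentionStub() { /* the real op would launch the attention kernel */ }
#endif

PYBIND11_MODULE(example_ops, m) {
#ifdef ENABLE_APPEND_ATTENTION
  // The binding lives under the same guard, so the attribute is absent
  // (rather than broken) in builds that exclude the feature.
  m.def("append_attention_stub", &AppendAttentionStub,
        "registered only when ENABLE_APPEND_ATTENTION is defined");
#endif
}

With this layout, Python callers can feature-detect via hasattr(fastdeploy_ops, "append_attention") rather than catching an AttributeError at call time.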
9 changes: 7 additions & 2 deletions custom_ops/gpu_ops/gelu_tanh.cu
@@ -12,15 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <math.h>
#include "helper.h"
#include "paddle/extension.h"

#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
__forceinline__ __device__ float tanh_ptx(float x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
// Use hardware tanh instruction for sm_75 and above
float y;
asm volatile("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
return y;
#else
// Fallback implementation for sm_70 and below
return tanhf(x);
#endif
}
#endif
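The tanh_ptx change above keeps the tanh.approx.f32 PTX path for __CUDA_ARCH__ >= 750 and falls back to tanhf() otherwise (hence the new <math.h> include), so the file also compiles for pre-Turing targets. A throwaway probe along these lines (not part of the PR; the kernel and launch are mine) can sanity-check that the two paths agree to within the approximate instruction's tolerance:

// Standalone check, built with e.g. nvcc -arch=sm_70 (fallback) or sm_75 (PTX path).
#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

__forceinline__ __device__ float tanh_ptx(float x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
  float y;
  asm volatile("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
  return y;
#else
  return tanhf(x);  // fallback path exercised on older architectures
#endif
}

__global__ void probe(float x, float* out) { *out = tanh_ptx(x); }

int main() {
  float h_out = 0.f, *d_out = nullptr;
  cudaMalloc(&d_out, sizeof(float));
  probe<<<1, 1>>>(0.5f, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("device tanh_ptx(0.5) = %f, host tanhf(0.5) = %f\n", h_out, tanhf(0.5f));
  cudaFree(d_out);
  return 0;
}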

@@ -89,7 +94,7 @@ std::vector<paddle::Tensor> GeluTanh(paddle::Tensor& input) {
DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), scalar_t, {
uint32_t vec_size = 16 / sizeof(scalar_t);
dim3 grid(num_tokens);
dim3 block(std::max(d / vec_size, 1024U));
dim3 block(std::min(d / vec_size, 1024U));

#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
gelu_tanh_kernel<scalar_t><<<grid, block, 0, stream>>>(
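The second gelu_tanh.cu hunk is the actual bug fix: a CUDA thread block is limited to 1024 threads, so the thread count has to be clamped downward with std::min. The previous std::max(d / vec_size, 1024U) requested an oversized block whenever d / vec_size exceeded 1024 (an invalid launch configuration), and for small inputs it always launched 1024 threads even when fewer were needed. A quick host-side illustration with made-up sizes (the real d and vec_size come from the input tensor in GeluTanh):

// Host-only illustration of the clamp; the sizes below are hypothetical.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>  // for dim3

int main() {
  const uint32_t d = 16384;     // e.g. a large last dimension
  const uint32_t vec_size = 8;  // 16 bytes / a 2-byte element type
  // d / vec_size == 2048 here:
  dim3 block_old(std::max(d / vec_size, 1024U));  // 2048 threads -> launch would fail
  dim3 block_new(std::min(d / vec_size, 1024U));  // capped at the 1024-thread limit
  printf("old block.x = %u, new block.x = %u\n", block_old.x, block_new.x);
  return 0;
}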