diff --git a/source/source_lcao/module_gint/kernel/gemm_nn_vbatch.cuh b/source/source_lcao/module_gint/kernel/gemm_nn_vbatch.cuh index 02c92b134b..cac7abfd34 100644 --- a/source/source_lcao/module_gint/kernel/gemm_nn_vbatch.cuh +++ b/source/source_lcao/module_gint/kernel/gemm_nn_vbatch.cuh @@ -273,7 +273,7 @@ static __global__ void vbatched_gemm_nn_kernel(const int* M, const int* global_ldc, const T* alpha) { - extern __shared__ __align__(sizeof(T)) unsigned char smem[]; + extern __shared__ __align__(sizeof(double)) unsigned char smem[]; T* shared_mem = reinterpret_cast(smem); int batchid = blockIdx.z; diff --git a/source/source_lcao/module_gint/kernel/gemm_tn_vbatch.cuh b/source/source_lcao/module_gint/kernel/gemm_tn_vbatch.cuh index 8e56dc083e..380eb9efa8 100644 --- a/source/source_lcao/module_gint/kernel/gemm_tn_vbatch.cuh +++ b/source/source_lcao/module_gint/kernel/gemm_tn_vbatch.cuh @@ -273,7 +273,7 @@ static __global__ void vbatched_gemm_nt_kernel(const int* M, const int* global_ldc, const T* alpha) { - extern __shared__ __align__(sizeof(T)) unsigned char smem[]; + extern __shared__ __align__(sizeof(double)) unsigned char smem[]; T* shared_mem = reinterpret_cast(smem); int batchid = blockIdx.z;