@@ -3309,107 +3309,260 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context & ctx, ggml_tensor
  * MoE architectures with potential sparse expert routing.
  */
 static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
-    // TODO: Use aclnnGroupedMatMul
     // dst  [M, K, N, 1]
     ggml_tensor * src0 = dst->src[0];  // src0 [D, M, A, 1]
     ggml_tensor * src1 = dst->src[1];  // src1 [D, B, N, 1], B = K or B = 1
     ggml_tensor * ids  = dst->src[2];  // ids  [K, N]

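+    // Gather the K routed experts' packed weights and scales with IndexSelect,
+    // then run one WeightQuantBatchMatmulV2 per (batch, expert-slot) pair below.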
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    // copy index from npu to cpu
-    int64_t n_as  = ne02;        // A
-    int64_t n_ids = ids->ne[0];  // K
-
-    std::vector<char> ids_host(ggml_nbytes(ids));
-    ACL_CHECK(aclrtMemcpyAsync(ids_host.data(), ggml_nbytes(ids), ids->data, ggml_nbytes(ids),
-                               ACL_MEMCPY_DEVICE_TO_HOST, ctx.stream()));
-    ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
-
-    char * src0_original = (char *) src0->data;
-    char * src1_original = (char *) src1->data;
-    char * dst_original  = (char *) dst->data;
+    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_ASSERT(src1->ne[3] == 1);
+    GGML_ASSERT(dst->ne[3] == 1);

-    ggml_tensor src0_row = *src0;
-    ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row  = *dst;
+    int64_t batch = src1->ne[2];
+    GGML_ASSERT(batch == ids->ne[1]);
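+    // ids holds one row of K int32 expert indices per batch element.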

-    const enum ggml_type type = dst->src[0]->type;
-    float                weight_elem_size;
+    const enum ggml_type type = src0->type;
+    float weight_elem_size;
     if (type == GGML_TYPE_Q4_0) {
         weight_elem_size = float(sizeof(uint8_t)) / 2;
     } else if (type == GGML_TYPE_Q8_0) {
         weight_elem_size = float(sizeof(uint8_t));
     } else {
-        GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 ");
+        GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0");
     }

-    // src0_row [D, M, 1, 1] weight without permute
-    src0_row.ne[2] = 1;
-    src0_row.ne[3] = 1;
-    src0_row.nb[0] = weight_elem_size;
-    src0_row.nb[1] = weight_elem_size * ne00;
-    src0_row.nb[2] = weight_elem_size * ne00;
-    src0_row.nb[3] = weight_elem_size * ne00;
-    size_t weight_stride = ne00 * ne01 * weight_elem_size;
-    size_t weight_size   = weight_stride * ne02 * ne03;
+    // Calculate memory layout
+    size_t weight_stride = src0->ne[0] * src0->ne[1] * weight_elem_size;
+    size_t weight_size   = weight_stride * src0->ne[2] * src0->ne[3];

-    // scale [D, M, 1, 1] -> scale && permute
     size_t scale_elem_size = sizeof(uint16_t);
-    size_t scale_stride    = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
+    char * scale_offset = (char *) src0->data + weight_size;
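+    // Quantized ggml tensors store all packed weight bytes first, followed by
+    // one fp16 scale per QK8_0 block, so the scales start at data + weight_size.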
+
+    // Allocate temporary buffers for selected weights and scales
+    size_t export_weight_size = src0->ne[0] * src0->ne[1] * ids->ne[0] * weight_elem_size;
+    ggml_cann_pool_alloc export_weight_allocator(ctx.pool(), export_weight_size);
+    void * export_weight_ptr = export_weight_allocator.get();
+
+    size_t export_scale_size = (src0->ne[0] / QK8_0) * src0->ne[1] * ids->ne[0] * scale_elem_size;
+    ggml_cann_pool_alloc export_scale_allocator(ctx.pool(), export_scale_size);
+    void * export_scale_ptr = export_scale_allocator.get();
+
+    // Prepare input buffer (convert to F16 if needed)
+    size_t input_elem_size = sizeof(uint16_t);
+    ggml_cann_pool_alloc input_allocator(ctx.pool());
+    void * input_buffer = src1->data;
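+    // The matmul below runs on fp16 activations, so non-F16 inputs are first
+    // cast into a temporary fp16 buffer.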

-    // src1_row [D, 1, 1, 1] -> input
-    src1_row.ne[1] = 1;
-    src1_row.ne[2] = 1;
-    src1_row.ne[3] = 1;
-    src1_row.nb[2] = nb11;
-    src1_row.nb[3] = nb11;
-
-    // dst_row [M, 1, 1, 1] -> out
-    dst_row.ne[1] = 1;
-    dst_row.ne[2] = 1;
-    dst_row.ne[3] = 1;
-    dst_row.nb[2] = nb1;
-    dst_row.nb[3] = nb1;
-
-    // create weight for one row
-    ggml_cann_pool_alloc weight_allocator(ctx.pool());
-    void * weight_buffer = weight_allocator.alloc(nb02);
-    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-        for (int64_t id = 0; id < n_ids; id++) {
-            // expert index
-            int32_t i02 = *(int32_t *) (ids_host.data() + iid1 * ids->nb[1] + id * ids->nb[0]);
-            GGML_ASSERT(i02 >= 0 && i02 < n_as);
-
-            // If B = 1 (broadcast), always use 0; otherwise, use id.
-            int64_t i11 = (ne11 == 1 ? 0 : id);
-            int64_t i12 = iid1;
-
-            int64_t i1 = id;
-            int64_t i2 = i12;
-
-            void * src0_tmp_ptr  = src0_original + i02 * weight_stride;
-            void * scale_tmp_ptr = src0_original + weight_size + i02 * scale_stride;
-            void * src1_tmp_ptr  = src1_original + i11 * nb11 + i12 * nb12;
-            void * dst_tmp_ptr   = dst_original + i1 * nb1 + i2 * nb2;
-
-            // mem cpy
-            ACL_CHECK(aclrtMemcpyAsync(weight_buffer, weight_stride, src0_tmp_ptr, weight_stride,
-                                       ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
-            void * scale_buffer = (char *) weight_buffer + weight_stride;
-            ACL_CHECK(aclrtMemcpyAsync(scale_buffer, scale_stride, scale_tmp_ptr, scale_stride,
-                                       ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
-
-            src0_row.data = weight_buffer;
-            src1_row.data = src1_tmp_ptr;
-            dst_row.data  = dst_tmp_ptr;
-            dst_row.src[0] = &src0_row;
-            dst_row.src[1] = &src1_row;
-
-            ggml_cann_mul_mat(ctx, &dst_row);
+    if (src1->type != GGML_TYPE_F16) {
+        size_t total_input_size = input_elem_size;
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            total_input_size *= src1->ne[i];
+        }
+        input_buffer = input_allocator.alloc(total_input_size);
+
+        acl_tensor_ptr acl_src1_tensor = ggml_cann_create_tensor(src1);
+
+        int64_t input_cast_ne[GGML_MAX_DIMS];
+        size_t  input_cast_nb[GGML_MAX_DIMS];
+
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            input_cast_ne[i] = src1->ne[i];
+        }
+
+        input_cast_nb[0] = input_elem_size;
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            input_cast_nb[i] = input_cast_nb[i - 1] * input_cast_ne[i - 1];
         }
+
+        acl_tensor_ptr acl_input_tensor = ggml_cann_create_tensor(
+            input_buffer, ACL_FLOAT16, input_elem_size,
+            input_cast_ne, input_cast_nb, GGML_MAX_DIMS);
+
+        aclnn_cast(ctx, acl_src1_tensor.get(), acl_input_tensor.get(), ACL_FLOAT16);
+    }
+
+    // Prepare output buffer (use temp buffer if not F16)
+    size_t output_elem_size = sizeof(uint16_t);
+    ggml_cann_pool_alloc output_allocator(ctx.pool());
+    void * output_buffer = dst->data;
+
+    if (dst->type != GGML_TYPE_F16) {
+        size_t total_output_size = output_elem_size;
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            total_output_size *= dst->ne[i];
+        }
+        output_buffer = output_allocator.alloc(total_output_size);
+    }
+
+    // Process each batch
+    for (int64_t i = 0; i < batch; i++) {
+        // Create index tensor for this batch
+        acl_tensor_ptr select_index = ggml_cann_create_tensor(
+            ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, i * ids->nb[1]);
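+        // A 1-D view over row i of ids: the K expert indices for this batch element.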
+
+        // IndexSelect for quantized weights (using int8 type)
+        int64_t weight_ne_for_select[3];
+        if (type == GGML_TYPE_Q4_0) {
+            weight_ne_for_select[0] = src0->ne[0] / 2;  // 2 Q4_0 values per byte
+        } else {
+            weight_ne_for_select[0] = src0->ne[0];  // Q8_0
+        }
+        weight_ne_for_select[1] = src0->ne[1];
+        weight_ne_for_select[2] = src0->ne[2];
+
+        size_t weight_nb_for_select[3];
+        weight_nb_for_select[0] = sizeof(int8_t);
+        weight_nb_for_select[1] = weight_nb_for_select[0] * weight_ne_for_select[0];
+        weight_nb_for_select[2] = weight_nb_for_select[1] * weight_ne_for_select[1];
+
+        acl_tensor_ptr export_weight = ggml_cann_create_tensor(
+            src0->data, ACL_INT8, sizeof(int8_t),
+            weight_ne_for_select, weight_nb_for_select, 3);
+
+        int64_t select_export_weight_ne[3] = {
+            weight_ne_for_select[0],
+            weight_ne_for_select[1],
+            ids->ne[0]
+        };
+        size_t select_export_weight_nb[3];
+        select_export_weight_nb[0] = sizeof(int8_t);
+        select_export_weight_nb[1] = select_export_weight_nb[0] * select_export_weight_ne[0];
+        select_export_weight_nb[2] = select_export_weight_nb[1] * select_export_weight_ne[1];
+
+        acl_tensor_ptr select_export_weight = ggml_cann_create_tensor(
+            export_weight_ptr, ACL_INT8, sizeof(int8_t),
+            select_export_weight_ne, select_export_weight_nb, 3);
+
+        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect,
+            export_weight.get(), 0, select_index.get(), select_export_weight.get());
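+        // dim 0 here is the expert dimension: ggml's ne order is reversed when
+        // building ACL tensors, so ggml ne[2] (experts) becomes ACL dim 0.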
+
+        // IndexSelect for scales
+        int64_t scale_ne[3] = {
+            src0->ne[0] / QK8_0,
+            src0->ne[1],
+            src0->ne[2]
+        };
+        size_t scale_nb[3];
+        scale_nb[0] = scale_elem_size;
+        scale_nb[1] = scale_nb[0] * scale_ne[0];
+        scale_nb[2] = scale_nb[1] * scale_ne[1];
+
+        acl_tensor_ptr export_scale = ggml_cann_create_tensor(
+            scale_offset, ACL_FLOAT16, scale_elem_size,
+            scale_ne, scale_nb, 3);
+
+        int64_t select_export_scale_ne[3] = {
+            scale_ne[0],
+            scale_ne[1],
+            ids->ne[0]
+        };
+        size_t select_export_scale_nb[3];
+        select_export_scale_nb[0] = scale_elem_size;
+        select_export_scale_nb[1] = select_export_scale_nb[0] * select_export_scale_ne[0];
+        select_export_scale_nb[2] = select_export_scale_nb[1] * select_export_scale_ne[1];
+
+        acl_tensor_ptr select_export_scale = ggml_cann_create_tensor(
+            export_scale_ptr, ACL_FLOAT16, scale_elem_size,
+            select_export_scale_ne, select_export_scale_nb, 3);
+
+        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect,
+            export_scale.get(), 0, select_index.get(), select_export_scale.get());
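+        // Same gather for the per-block scales, so they stay aligned with the
+        // selected expert weights.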
+
+        // IndexSelect output is [D, M, K] in contiguous layout.
+        // For WeightQuantBatchMatmulV2, we need each expert as [M, D] with M-major stride.
+        for (int64_t k = 0; k < ids->ne[0]; k++) {
+            // Input offset: if src1->ne[1] == 1, broadcast (all k use the same input);
+            // otherwise each k has its own input row.
+            size_t input_offset  = (i * src1->ne[1] + (src1->ne[1] == 1 ? 0 : k)) * src1->ne[0] * input_elem_size;
+            size_t output_offset = (i * dst->ne[1] + k) * dst->ne[0] * output_elem_size;
+
+            // Create a view of the k-th expert weight from [D, M, K] -> [M, D].
+            // The data layout in memory is [D0M0, D0M1, ..., D0M_{M-1}, D1M0, D1M1, ...],
+            // so we need an [M, D] view with stride[0] = D * elemsize, stride[1] = elemsize.
+            int64_t weight_view_ne[2] = {
+                select_export_weight_ne[1],  // M = src0->ne[1]
+                select_export_weight_ne[0]   // D = src0->ne[0] (adjusted for Q4_0/Q8_0)
+            };
+            float weight_view_nb[2] = {
+                src0->ne[0] * weight_elem_size,  // M stride: one row = D * elemsize
+                weight_elem_size                 // D stride: one element
+            };
+            size_t weight_view_offset = k * select_export_weight_nb[2];
+
+            acl_tensor_ptr weight_view = ggml_cann_create_tensor(
+                export_weight_ptr, ggml_cann_type_mapping(type), weight_elem_size,
+                weight_view_ne, weight_view_nb, 2,
+                ACL_FORMAT_ND, weight_view_offset);
+
+            // Create a view of the k-th expert scale from [D, M, K] -> [M, D].
+            int64_t scale_view_ne[2] = {
+                select_export_scale_ne[1],  // M = src0->ne[1]
+                select_export_scale_ne[0]   // D = src0->ne[0] / QK8_0
+            };
+            size_t scale_view_nb[2] = {
+                select_export_scale_nb[1],  // M stride
+                select_export_scale_nb[0]   // D stride
+            };
+            size_t scale_view_offset = k * select_export_scale_nb[2];
+
+            acl_tensor_ptr scale_view = ggml_cann_create_tensor(
+                export_scale_ptr, ACL_FLOAT16, scale_elem_size,
+                scale_view_ne, scale_view_nb, 2,
+                ACL_FORMAT_ND, scale_view_offset);
+
+            // Prepare input tensor [D, 1]
+            int64_t active_tensor_ne[2] = { src1->ne[0], 1 };
+            size_t  active_tensor_nb[2] = { input_elem_size, src1->ne[0] * input_elem_size };
+
+            acl_tensor_ptr active_tensor = ggml_cann_create_tensor(
+                input_buffer, ACL_FLOAT16, input_elem_size,
+                active_tensor_ne, active_tensor_nb, 2,
+                ACL_FORMAT_ND, input_offset);
+
+            // Prepare output tensor [M, 1]
+            int64_t dst_ne[2] = { dst->ne[0], 1 };
+            size_t  dst_nb[2] = { output_elem_size, dst->ne[0] * output_elem_size };
+
+            acl_tensor_ptr acl_dst = ggml_cann_create_tensor(
+                output_buffer, ACL_FLOAT16, output_elem_size,
+                dst_ne, dst_nb, 2,
+                ACL_FORMAT_ND, output_offset);
+
+            // Call WeightQuantBatchMatmulV2
+            GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2,
+                active_tensor.get(),
+                weight_view.get(),
+                scale_view.get(),
+                nullptr,   // antiquant offset
+                nullptr,   // quant scale
+                nullptr,   // quant offset
+                nullptr,   // bias
+                QK8_0,     // antiquant group size
+                acl_dst.get());
+        }
+    }
+
+    // Cast output back to target type if needed
+    if (dst->type != GGML_TYPE_F16) {
+        int64_t output_cast_ne[GGML_MAX_DIMS];
+        size_t  output_cast_nb[GGML_MAX_DIMS];
+
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            output_cast_ne[i] = dst->ne[i];
+        }
+
+        output_cast_nb[0] = output_elem_size;
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            output_cast_nb[i] = output_cast_nb[i - 1] * output_cast_ne[i - 1];
+        }
+
+        acl_tensor_ptr acl_output_tensor = ggml_cann_create_tensor(
+            output_buffer, ACL_FLOAT16, output_elem_size,
+            output_cast_ne, output_cast_nb, GGML_MAX_DIMS);
+
+        acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
+
+        aclnn_cast(ctx, acl_output_tensor.get(), acl_dst_tensor.get(),
+                   ggml_cann_type_mapping(dst->type));
     }
-    return;
 }

 void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst) {