Skip to content

Commit e9d05be

Browse files
committed
Fix mul_mat_id_quant: use contiguous [M,D] layout for weight tensors
1 parent ec6c742 commit e9d05be

2 files changed

Lines changed: 235 additions & 86 deletions

File tree

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 235 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -3309,107 +3309,260 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context & ctx, ggml_tensor
33093309
* MoE architectures with potential sparse expert routing.
33103310
*/
33113311
static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
    // dst  [M, K, N, 1]
    ggml_tensor * src0 = dst->src[0];  // src0 [D, M, A, 1] — quantized expert weights
    ggml_tensor * src1 = dst->src[1];  // src1 [D, B, N, 1], B = K or B = 1 (broadcast)
    ggml_tensor * ids  = dst->src[2];  // ids  [K, N] — selected expert index per token

    // Only 3D problems are handled; the 4th dim must be a singleton.
    GGML_ASSERT(src0->ne[3] == 1);
    GGML_ASSERT(src1->ne[3] == 1);
    GGML_ASSERT(dst->ne[3] == 1);

    // One ids column per batch row of src1.
    int64_t batch = src1->ne[2];
    GGML_ASSERT(batch == ids->ne[1]);

    // Per-element weight size in bytes: Q4_0 packs two 4-bit values per byte,
    // so its element size is fractional (0.5) and kept as float.
    const enum ggml_type type = src0->type;
    float weight_elem_size;
    if (type == GGML_TYPE_Q4_0) {
        weight_elem_size = float(sizeof(uint8_t)) / 2;
    } else if (type == GGML_TYPE_Q8_0) {
        weight_elem_size = float(sizeof(uint8_t));
    } else {
        GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0");
    }

    // Calculate memory layout: all expert weight planes are stored contiguously,
    // followed by all scale planes (F16, one scale per QK8_0 group).
    size_t weight_stride = src0->ne[0] * src0->ne[1] * weight_elem_size;
    size_t weight_size   = weight_stride * src0->ne[2] * src0->ne[3];

    size_t scale_elem_size = sizeof(uint16_t);
    char * scale_offset    = (char *) src0->data + weight_size;

    // Allocate temporary buffers for the K selected experts' weights and scales.
    size_t export_weight_size = src0->ne[0] * src0->ne[1] * ids->ne[0] * weight_elem_size;
    ggml_cann_pool_alloc export_weight_allocator(ctx.pool(), export_weight_size);
    void * export_weight_ptr = export_weight_allocator.get();

    size_t export_scale_size = (src0->ne[0] / QK8_0) * src0->ne[1] * ids->ne[0] * scale_elem_size;
    ggml_cann_pool_alloc export_scale_allocator(ctx.pool(), export_scale_size);
    void * export_scale_ptr = export_scale_allocator.get();

    // Prepare input buffer: WeightQuantBatchMatmulV2 consumes F16 activations,
    // so cast src1 into a temporary F16 buffer when needed.
    size_t input_elem_size = sizeof(uint16_t);
    ggml_cann_pool_alloc input_allocator(ctx.pool());
    void * input_buffer = src1->data;

    if (src1->type != GGML_TYPE_F16) {
        size_t total_input_size = input_elem_size;
        for (int i = 0; i < GGML_MAX_DIMS; i++) {
            total_input_size *= src1->ne[i];
        }
        input_buffer = input_allocator.alloc(total_input_size);

        acl_tensor_ptr acl_src1_tensor = ggml_cann_create_tensor(src1);

        int64_t input_cast_ne[GGML_MAX_DIMS];
        size_t  input_cast_nb[GGML_MAX_DIMS];

        for (int i = 0; i < GGML_MAX_DIMS; i++) {
            input_cast_ne[i] = src1->ne[i];
        }

        // Contiguous F16 strides.
        input_cast_nb[0] = input_elem_size;
        for (int i = 1; i < GGML_MAX_DIMS; i++) {
            input_cast_nb[i] = input_cast_nb[i - 1] * input_cast_ne[i - 1];
        }

        acl_tensor_ptr acl_input_tensor = ggml_cann_create_tensor(
            input_buffer, ACL_FLOAT16, input_elem_size,
            input_cast_ne, input_cast_nb, GGML_MAX_DIMS);

        aclnn_cast(ctx, acl_src1_tensor.get(), acl_input_tensor.get(), ACL_FLOAT16);
    }

    // Prepare output buffer: the matmul produces F16; use a temporary buffer
    // and cast back at the end when dst is not F16.
    size_t output_elem_size = sizeof(uint16_t);
    ggml_cann_pool_alloc output_allocator(ctx.pool());
    void * output_buffer = dst->data;

    if (dst->type != GGML_TYPE_F16) {
        size_t total_output_size = output_elem_size;
        for (int i = 0; i < GGML_MAX_DIMS; i++) {
            total_output_size *= dst->ne[i];
        }
        output_buffer = output_allocator.alloc(total_output_size);
    }

    // Process each batch row: gather the K selected experts with IndexSelect,
    // then run one per-expert quantized matvec.
    for (int64_t i = 0; i < batch; i++) {
        // Index tensor for this batch row (K expert ids).
        acl_tensor_ptr select_index = ggml_cann_create_tensor(
            ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, i * ids->nb[1]);

        // IndexSelect over the expert axis of the raw quantized bytes (int8 view).
        int64_t weight_ne_for_select[3];
        if (type == GGML_TYPE_Q4_0) {
            weight_ne_for_select[0] = src0->ne[0] / 2;  // 2 Q4_0 values per byte
        } else {
            weight_ne_for_select[0] = src0->ne[0];      // Q8_0: 1 byte per value
        }
        weight_ne_for_select[1] = src0->ne[1];
        weight_ne_for_select[2] = src0->ne[2];

        size_t weight_nb_for_select[3];
        weight_nb_for_select[0] = sizeof(int8_t);
        weight_nb_for_select[1] = weight_nb_for_select[0] * weight_ne_for_select[0];
        weight_nb_for_select[2] = weight_nb_for_select[1] * weight_ne_for_select[1];

        acl_tensor_ptr export_weight = ggml_cann_create_tensor(
            src0->data, ACL_INT8, sizeof(int8_t),
            weight_ne_for_select, weight_nb_for_select, 3);

        int64_t select_export_weight_ne[3] = {
            weight_ne_for_select[0],
            weight_ne_for_select[1],
            ids->ne[0]
        };
        size_t select_export_weight_nb[3];
        select_export_weight_nb[0] = sizeof(int8_t);
        select_export_weight_nb[1] = select_export_weight_nb[0] * select_export_weight_ne[0];
        select_export_weight_nb[2] = select_export_weight_nb[1] * select_export_weight_ne[1];

        acl_tensor_ptr select_export_weight = ggml_cann_create_tensor(
            export_weight_ptr, ACL_INT8, sizeof(int8_t),
            select_export_weight_ne, select_export_weight_nb, 3);

        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect,
            export_weight.get(), 0, select_index.get(), select_export_weight.get());

        // IndexSelect for the per-group scales (same expert indices).
        int64_t scale_ne[3] = {
            src0->ne[0] / QK8_0,
            src0->ne[1],
            src0->ne[2]
        };
        size_t scale_nb[3];
        scale_nb[0] = scale_elem_size;
        scale_nb[1] = scale_nb[0] * scale_ne[0];
        scale_nb[2] = scale_nb[1] * scale_ne[1];

        acl_tensor_ptr export_scale = ggml_cann_create_tensor(
            scale_offset, ACL_FLOAT16, scale_elem_size,
            scale_ne, scale_nb, 3);

        int64_t select_export_scale_ne[3] = {
            scale_ne[0],
            scale_ne[1],
            ids->ne[0]
        };
        size_t select_export_scale_nb[3];
        select_export_scale_nb[0] = scale_elem_size;
        select_export_scale_nb[1] = select_export_scale_nb[0] * select_export_scale_ne[0];
        select_export_scale_nb[2] = select_export_scale_nb[1] * select_export_scale_ne[1];

        acl_tensor_ptr select_export_scale = ggml_cann_create_tensor(
            export_scale_ptr, ACL_FLOAT16, scale_elem_size,
            select_export_scale_ne, select_export_scale_nb, 3);

        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect,
            export_scale.get(), 0, select_index.get(), select_export_scale.get());

        // IndexSelect output is [D, M, K] in contiguous layout.
        // For WeightQuantBatchMatmulV2 each expert is viewed as [M, D] (M-major).
        for (int64_t k = 0; k < ids->ne[0]; k++) {
            // Input offset: when src1->ne[1] == 1 the activation row is
            // broadcast to all K experts; otherwise each k has its own row.
            size_t input_offset  = (i * src1->ne[1] + (src1->ne[1] == 1 ? 0 : k)) * src1->ne[0] * input_elem_size;
            size_t output_offset = (i * dst->ne[1] + k) * dst->ne[0] * output_elem_size;

            // View of the k-th gathered expert weight: [D, M, K] -> [M, D].
            // Memory is row-after-row (D0M0..D0M_{M-1}, D1M0, ...), so the
            // [M, D] view uses stride[0] = D*elemsize, stride[1] = elemsize.
            int64_t weight_view_ne[2] = {
                select_export_weight_ne[1],  // M = src0->ne[1]
                select_export_weight_ne[0]   // D = src0->ne[0] (halved for Q4_0)
            };
            // NOTE(review): float strides because weight_elem_size is 0.5 for
            // Q4_0 — presumably ggml_cann_create_tensor has an overload that
            // accepts fractional strides for sub-byte types; TODO confirm.
            float weight_view_nb[2] = {
                src0->ne[0] * weight_elem_size,  // M stride: one row = D * elemsize
                weight_elem_size                 // D stride: one element
            };
            size_t weight_view_offset = k * select_export_weight_nb[2];

            acl_tensor_ptr weight_view = ggml_cann_create_tensor(
                export_weight_ptr, ggml_cann_type_mapping(type), weight_elem_size,
                weight_view_ne, weight_view_nb, 2,
                ACL_FORMAT_ND, weight_view_offset);

            // View of the k-th expert scale: [D/QK8_0, M, K] -> [M, D/QK8_0].
            int64_t scale_view_ne[2] = {
                select_export_scale_ne[1],  // M = src0->ne[1]
                select_export_scale_ne[0]   // D = src0->ne[0] / QK8_0
            };
            size_t scale_view_nb[2] = {
                select_export_scale_nb[1],  // M stride
                select_export_scale_nb[0]   // D stride
            };
            size_t scale_view_offset = k * select_export_scale_nb[2];

            acl_tensor_ptr scale_view = ggml_cann_create_tensor(
                export_scale_ptr, ACL_FLOAT16, scale_elem_size,
                scale_view_ne, scale_view_nb, 2,
                ACL_FORMAT_ND, scale_view_offset);

            // Activation vector for this (batch, expert) pair: [D, 1].
            int64_t active_tensor_ne[2] = { src1->ne[0], 1 };
            size_t  active_tensor_nb[2] = { input_elem_size, src1->ne[0] * input_elem_size };

            acl_tensor_ptr active_tensor = ggml_cann_create_tensor(
                input_buffer, ACL_FLOAT16, input_elem_size,
                active_tensor_ne, active_tensor_nb, 2,
                ACL_FORMAT_ND, input_offset);

            // Output vector: [M, 1].
            int64_t dst_ne[2] = { dst->ne[0], 1 };
            size_t  dst_nb[2] = { output_elem_size, dst->ne[0] * output_elem_size };

            acl_tensor_ptr acl_dst = ggml_cann_create_tensor(
                output_buffer, ACL_FLOAT16, output_elem_size,
                dst_ne, dst_nb, 2,
                ACL_FORMAT_ND, output_offset);

            // y = x * antiquant(weight, scale); group size QK8_0, no offsets/bias.
            GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2,
                active_tensor.get(),
                weight_view.get(),
                scale_view.get(),
                nullptr,
                nullptr,
                nullptr,
                nullptr,
                QK8_0,
                acl_dst.get());
        }
    }

    // Cast output back to the target type if needed.
    if (dst->type != GGML_TYPE_F16) {
        int64_t output_cast_ne[GGML_MAX_DIMS];
        size_t  output_cast_nb[GGML_MAX_DIMS];

        for (int i = 0; i < GGML_MAX_DIMS; i++) {
            output_cast_ne[i] = dst->ne[i];
        }

        output_cast_nb[0] = output_elem_size;
        for (int i = 1; i < GGML_MAX_DIMS; i++) {
            output_cast_nb[i] = output_cast_nb[i - 1] * output_cast_ne[i - 1];
        }

        acl_tensor_ptr acl_output_tensor = ggml_cann_create_tensor(
            output_buffer, ACL_FLOAT16, output_elem_size,
            output_cast_ne, output_cast_nb, GGML_MAX_DIMS);

        acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);

        aclnn_cast(ctx, acl_output_tensor.get(), acl_dst_tensor.get(),
                   ggml_cann_type_mapping(dst->type));
    }
}
34143567

34153568
void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst) {

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2300,10 +2300,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
23002300
}
23012301
case GGML_OP_MUL_MAT_ID:
23022302
switch (op->src[0]->type) {
2303-
case GGML_TYPE_F16:
2304-
case GGML_TYPE_F32:
2305-
return true;
2306-
case GGML_TYPE_Q8_0:
23072303
case GGML_TYPE_Q4_0:
23082304
#ifdef ASCEND_310P
23092305
// Q4 && Q8 per group is not support on 310p device

0 commit comments

Comments
 (0)