Skip to content

Commit 18183b2

Browse files
committed
内存第一批优化与调整,916M->850M
1 parent 5fa8bb1 commit 18183b2

3 files changed

Lines changed: 64 additions & 97 deletions

File tree

sense-voice/csrc/sense-voice-decoder.cc

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ struct ggml_cgraph *sense_voice_build_graph_ctc_decoder(sense_voice_context &ctx
6060
ggml_cgraph *gf = ggml_new_graph_custom(ctx0, SENSEVOICE_DECODER_MAX_NODES, false);
6161

6262
ggml_tensor *encoder_out = ggml_new_tensor_3d(ctx0, state.encoder_out->type,
63-
state.encoder_out->ne[0], state.encoder_out->ne[1],
63+
state.encoder_out->ne[0], state.encoder_out->ne[1],
6464
state.encoder_out->ne[2]);
6565
ggml_set_name(encoder_out, "encoder_out");
6666
ggml_set_input(encoder_out);
@@ -76,11 +76,9 @@ struct ggml_cgraph *sense_voice_build_graph_ctc_decoder(sense_voice_context &ctx
7676
}
7777
ggml_tensor * probs = ggml_soft_max(ctx0, cur);
7878
probs = ggml_reshape_2d(ctx0, probs, probs->ne[0], probs->ne[1] * probs->ne[2] * probs->ne[3]);
79-
ggml_tensor * argmax_logit = ggml_argmax(ctx0, probs);
80-
argmax_logit = ggml_reshape_3d(ctx0, argmax_logit, cur->ne[1], cur->ne[2], cur->ne[3]);
79+
probs = ggml_reshape_3d(ctx0, ggml_argmax(ctx0, probs), cur->ne[1], cur->ne[2], cur->ne[3]); // argmax_logit
8180
ggml_set_output(probs);
82-
ggml_set_output(argmax_logit);
83-
ggml_build_forward_expand(gf, argmax_logit);
81+
ggml_build_forward_expand(gf, probs);
8482
ggml_free(ctx0);
8583
return gf;
8684
}
@@ -124,14 +122,11 @@ bool sense_voice_decode_internal(sense_voice_context &ctx,
124122
ggml_backend_tensor_get(argmax_logit, state.ids.data(), 0, sizeof(int) * argmax_logit->ne[0]);
125123
}
126124
else {
127-
const int32_t n_logits = argmax_logit->ne[0] * argmax_logit->ne[1];
128125
// Get the tensor data into a temporary buffer
129-
std::vector<int> temp_buffer(n_logits);
130-
ggml_backend_tensor_get(argmax_logit, temp_buffer.data(), 0, sizeof(int) * n_logits);
131126
for(int32_t i = 0; i < argmax_logit->ne[1]; i++)
132127
{
133-
int posL = i * argmax_logit->ne[0];
134-
state.result_all[state.segmentIDs[i]].tokens = std::vector<int>(temp_buffer.begin() + posL, temp_buffer.begin() + posL + argmax_logit->ne[0]);
128+
state.result_all[state.segmentIDs[i]].tokens.resize(argmax_logit->ne[0]);
129+
ggml_backend_tensor_get(argmax_logit, state.result_all[state.segmentIDs[i]].tokens.data(), sizeof(int) * i * argmax_logit->ne[0], sizeof(int) * argmax_logit->ne[0]);
135130
}
136131
}
137132
}
@@ -141,4 +136,4 @@ bool sense_voice_decode_internal(sense_voice_context &ctx,
141136
state.t_decode_us += ggml_time_us() - t_start_us;
142137

143138
return true;
144-
}
139+
}

0 commit comments

Comments
 (0)