@@ -60,7 +60,7 @@ struct ggml_cgraph *sense_voice_build_graph_ctc_decoder(sense_voice_context &ctx
6060 ggml_cgraph *gf = ggml_new_graph_custom (ctx0, SENSEVOICE_DECODER_MAX_NODES, false );
6161
6262 ggml_tensor *encoder_out = ggml_new_tensor_3d (ctx0, state.encoder_out ->type ,
63- state.encoder_out ->ne [0 ], state.encoder_out ->ne [1 ],
63+ state.encoder_out ->ne [0 ], state.encoder_out ->ne [1 ],
6464 state.encoder_out ->ne [2 ]);
6565 ggml_set_name (encoder_out, " encoder_out" );
6666 ggml_set_input (encoder_out);
@@ -76,11 +76,9 @@ struct ggml_cgraph *sense_voice_build_graph_ctc_decoder(sense_voice_context &ctx
7676 }
7777 ggml_tensor * probs = ggml_soft_max (ctx0, cur);
7878 probs = ggml_reshape_2d (ctx0, probs, probs->ne [0 ], probs->ne [1 ] * probs->ne [2 ] * probs->ne [3 ]);
79- ggml_tensor * argmax_logit = ggml_argmax (ctx0, probs);
80- argmax_logit = ggml_reshape_3d (ctx0, argmax_logit, cur->ne [1 ], cur->ne [2 ], cur->ne [3 ]);
79+ probs = ggml_reshape_3d (ctx0, ggml_argmax (ctx0, probs), cur->ne [1 ], cur->ne [2 ], cur->ne [3 ]); // argmax_logit
8180 ggml_set_output (probs);
82- ggml_set_output (argmax_logit);
83- ggml_build_forward_expand (gf, argmax_logit);
81+ ggml_build_forward_expand (gf, probs);
8482 ggml_free (ctx0);
8583 return gf;
8684}
@@ -124,14 +122,11 @@ bool sense_voice_decode_internal(sense_voice_context &ctx,
124122 ggml_backend_tensor_get (argmax_logit, state.ids .data (), 0 , sizeof (int ) * argmax_logit->ne [0 ]);
125123 }
126124 else {
127- const int32_t n_logits = argmax_logit->ne [0 ] * argmax_logit->ne [1 ];
128125 // Get the tensor data into a temporary buffer
129- std::vector<int > temp_buffer (n_logits);
130- ggml_backend_tensor_get (argmax_logit, temp_buffer.data (), 0 , sizeof (int ) * n_logits);
131126 for (int32_t i = 0 ; i < argmax_logit->ne [1 ]; i++)
132127 {
133- int posL = i * argmax_logit->ne [0 ];
134- state.result_all [state.segmentIDs [i]].tokens = std::vector< int >(temp_buffer. begin () + posL, temp_buffer. begin () + posL + argmax_logit->ne [0 ]);
128+ state. result_all [state. segmentIDs [i]]. tokens . resize ( argmax_logit->ne [0 ]) ;
129+ ggml_backend_tensor_get (argmax_logit, state.result_all [state.segmentIDs [i]].tokens . data (), sizeof ( int ) * i * argmax_logit-> ne [ 0 ], sizeof ( int ) * argmax_logit->ne [0 ]);
135130 }
136131 }
137132 }
@@ -141,4 +136,4 @@ bool sense_voice_decode_internal(sense_voice_context &ctx,
141136 state.t_decode_us += ggml_time_us () - t_start_us;
142137
143138 return true ;
144- }
139+ }
0 commit comments