Skip to content

Commit 70406db

Browse files
authored
[Tools] MlResponse: bugfix assure equality of number of checked features to that in the .onnx file (#16722)
1 parent 1c1f6ab commit 70406db

1 file changed

Lines changed: 7 additions & 6 deletions

File tree

Tools/ML/MlResponse.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,8 @@ class MlResponse
151151
void init(bool enableOptimizations = false, int threads = 0)
152152
{
153153
uint8_t counterModel{0};
154-
const int numCachedIndices = static_cast<int>(mCachedIndices.size());
155154
for (const auto& path : mPaths) {
156155
mModels[counterModel].initModel(path, enableOptimizations, threads);
157-
const int numInputNodes = mModels[counterModel].getNumInputNodes();
158-
if (numInputNodes != numCachedIndices) {
159-
LOG(fatal) << "Number of input nodes in the model " << path << " is different from the number of input features indices (" << numInputNodes << " vs " << numCachedIndices << ")";
160-
return;
161-
}
162156
++counterModel;
163157
}
164158
}
@@ -188,6 +182,13 @@ class MlResponse
188182
LOG(fatal) << "Model index " << nModel << " is out of range! The number of initialised models is " << mModels.size() << ". Please check your configurables.";
189183
}
190184

185+
const int numInputNodes = mModels[nModel].getNumInputNodes();
186+
const int numInputFeatures = static_cast<int>(input.size());
187+
188+
if (numInputNodes != numInputFeatures) {
189+
LOG(fatal) << "Number of input nodes in the model " << mPaths[nModel] << " is different from the number of input features to be tested (" << numInputNodes << " vs " << numInputFeatures << ")";
190+
}
191+
191192
TypeOutputScore* outputPtr = mModels[nModel].template evalModel<TypeOutputScore>(input);
192193
return std::vector<TypeOutputScore>{outputPtr, outputPtr + mNClasses};
193194
}

0 commit comments

Comments
 (0)