Skip to content

Commit a3d52bb

Browse files
authored
[PID ML] Separate models for different momentum (#479)
* add separate models for different momentum * Add array with momentum threshold values * True -> true * Fix error with list initialization * chang <= to = * change initialization method
1 parent 137d5fe commit a3d52bb

1 file changed

Lines changed: 31 additions & 9 deletions

File tree

Tools/PIDML/qaPidML.cxx

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ struct pidml {
4747
// available particles: 211, 2212, 321
4848
static constexpr int particlesPdgCode[numParticles] = {211, 2212, 321};
4949

50+
// values of track momentum when to switch from only TPC signal to combined TPC and TOF signal
51+
// i-th momentum corresponds to the i-th particle
52+
static constexpr float pSwitchValue[numParticles] = {0.5, 0.8, 0.5};
53+
5054
HistogramRegistry histReg{
5155
"allHistograms",
5256
{{"MC/211", "MC #pi^{+};p_{T} (GeV/c);Counts", {HistType::kTH1F, {{binsNb, 0, maxP}}}},
@@ -321,9 +325,19 @@ struct pidml {
321325
template <std::size_t i, typename T>
322326
void pidML(const T& track, const int pdgCodeMC)
323327
{
324-
float pidLogits[3] = {model211.applyModel(track), model2212.applyModel(track), model321.applyModel(track)};
328+
float pidLogits[3];
329+
if (track.p() < pSwitchValue[i]) {
330+
pidLogits[0] = model211TPC.applyModel(track);
331+
pidLogits[1] = model2212TPC.applyModel(track);
332+
pidLogits[2] = model321TPC.applyModel(track);
333+
} else {
334+
pidLogits[0] = model211All.applyModel(track);
335+
pidLogits[1] = model2212All.applyModel(track);
336+
pidLogits[2] = model321All.applyModel(track);
337+
}
325338
int pid = getParticlePdg(pidLogits);
326-
if (pid == particlesPdgCode[i]) {
339+
// condition for sign: we want to work only with pi, p and K, without antiparticles
340+
if (pid == particlesPdgCode[i] && track.sign() == 1) {
327341
if (pdgCodeMC == particlesPdgCode[i]) {
328342
fillPidHistos<i>(track, pdgCodeMC, true);
329343
} else {
@@ -332,19 +346,27 @@ struct pidml {
332346
}
333347
}
334348

335-
PidONNXModel model211;
336-
PidONNXModel model2212;
337-
PidONNXModel model321;
349+
// one model for one particle; Model with all TPC and TOF signal
350+
PidONNXModel model211All;
351+
PidONNXModel model2212All;
352+
PidONNXModel model321All;
353+
// Model with only TPC signal model
354+
PidONNXModel model211TPC;
355+
PidONNXModel model2212TPC;
356+
PidONNXModel model321TPC;
338357

339-
Configurable<bool> cfgUseTOF{"useTOF", true, "Use ML model with TOF signal"};
340358
Configurable<std::string> cfgModelDir{"model-dir", "http://alice-ccdb.cern.ch/Users/m/mkabus/pidml/onnx_models", "base path to the directory with ONNX models"};
341359
Configurable<std::string> cfgScalingParamsFile{"scaling-params", "http://alice-ccdb.cern.ch/Users/m/mkabus/pidml/onnx_models/train_208_mc_with_beta_and_sigmas_scaling_params.json", "base path to the ccdb JSON file with scaling parameters from training"};
342360

343361
void init(InitContext const&)
344362
{
345-
model211 = PidONNXModel(cfgModelDir.value, cfgScalingParamsFile.value, 211, cfgUseTOF.value);
346-
model2212 = PidONNXModel(cfgModelDir.value, cfgScalingParamsFile.value, 2212, cfgUseTOF.value);
347-
model321 = PidONNXModel(cfgModelDir.value, cfgScalingParamsFile.value, 321, cfgUseTOF.value);
363+
model211All = PidONNXModel(cfgModelDir.value, cfgScalingParamsFile.value, 211, true);
364+
model2212All = PidONNXModel(cfgModelDir.value, cfgScalingParamsFile.value, 2212, true);
365+
model321All = PidONNXModel(cfgModelDir.value, cfgScalingParamsFile.value, 321, true);
366+
367+
model211TPC = PidONNXModel(cfgModelDir.value, cfgScalingParamsFile.value, 211, false);
368+
model2212TPC = PidONNXModel(cfgModelDir.value, cfgScalingParamsFile.value, 2212, false);
369+
model321TPC = PidONNXModel(cfgModelDir.value, cfgScalingParamsFile.value, 321, false);
348370
}
349371

350372
Filter trackFilter = aod::track::isGlobalTrack == static_cast<uint8_t>(true);

0 commit comments

Comments
 (0)