Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions Tools/PIDML/qaPidML.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ struct pidml {
// available particles: 211, 2212, 321
static constexpr int particlesPdgCode[numParticles] = {211, 2212, 321};

// values of track momentum when to switch from only TPC signal to combined TPC and TOF signal
// i-th momentum corresponds to the i-th particle
static constexpr float pSwitchValue[numParticles] = {0.5, 0.8, 0.5};

HistogramRegistry histReg{
"allHistograms",
{{"MC/211", "MC #pi^{+};p_{T} (GeV/c);Counts", {HistType::kTH1F, {{binsNb, 0, maxP}}}},
Expand Down Expand Up @@ -321,9 +325,19 @@ struct pidml {
template <std::size_t i, typename T>
void pidML(const T& track, const int pdgCodeMC)
{
float pidLogits[3] = {model211.applyModel(track), model2212.applyModel(track), model321.applyModel(track)};
float pidLogits[3];
if (track.p() < pSwitchValue[i]) {
pidLogits[0] = model211TPC.applyModel(track);
pidLogits[1] = model2212TPC.applyModel(track);
pidLogits[2] = model321TPC.applyModel(track);
} else {
pidLogits[0] = model211All.applyModel(track);
pidLogits[1] = model2212All.applyModel(track);
pidLogits[2] = model321All.applyModel(track);
}
int pid = getParticlePdg(pidLogits);
if (pid == particlesPdgCode[i]) {
// condition for sign: we want to work only with pi, p and K, without antiparticles
if (pid == particlesPdgCode[i] && track.sign() == 1) {
if (pdgCodeMC == particlesPdgCode[i]) {
fillPidHistos<i>(track, pdgCodeMC, true);
} else {
Expand All @@ -332,19 +346,27 @@ struct pidml {
}
}

PidONNXModel model211;
PidONNXModel model2212;
PidONNXModel model321;
// one model for one particle; Model with all TPC and TOF signal
PidONNXModel model211All;
PidONNXModel model2212All;
PidONNXModel model321All;
// Model with only TPC signal model
PidONNXModel model211TPC;
PidONNXModel model2212TPC;
PidONNXModel model321TPC;

Configurable<bool> cfgUseTOF{"useTOF", true, "Use ML model with TOF signal"};
Configurable<std::string> cfgModelDir{"model-dir", "http://alice-ccdb.cern.ch/Users/m/mkabus/pidml/onnx_models", "base path to the directory with ONNX models"};
Configurable<std::string> cfgScalingParamsFile{"scaling-params", "http://alice-ccdb.cern.ch/Users/m/mkabus/pidml/onnx_models/train_208_mc_with_beta_and_sigmas_scaling_params.json", "base path to the ccdb JSON file with scaling parameters from training"};

void init(InitContext const&)
{
model211 = PidONNXModel(cfgModelDir.value, cfgScalingParamsFile.value, 211, cfgUseTOF.value);
model2212 = PidONNXModel(cfgModelDir.value, cfgScalingParamsFile.value, 2212, cfgUseTOF.value);
model321 = PidONNXModel(cfgModelDir.value, cfgScalingParamsFile.value, 321, cfgUseTOF.value);
model211All = PidONNXModel(cfgModelDir.value, cfgScalingParamsFile.value, 211, true);
model2212All = PidONNXModel(cfgModelDir.value, cfgScalingParamsFile.value, 2212, true);
model321All = PidONNXModel(cfgModelDir.value, cfgScalingParamsFile.value, 321, true);

model211TPC = PidONNXModel(cfgModelDir.value, cfgScalingParamsFile.value, 211, false);
model2212TPC = PidONNXModel(cfgModelDir.value, cfgScalingParamsFile.value, 2212, false);
model321TPC = PidONNXModel(cfgModelDir.value, cfgScalingParamsFile.value, 321, false);
}

Filter trackFilter = aod::track::isGlobalTrack == static_cast<uint8_t>(true);
Expand Down