FairRoot/PandaRoot
PndPidMvaAssociatorTask.cxx
Go to the documentation of this file.
1 /* ************************************
2  * Author: M. Babai (M.Babai@rug.nl) *
3  * *
4  * pid classifier *
5  * *
6  * Created: 23-03-2010 *
7  * Modified: *
8  * *
9  * ************************************/
10 //===================
12 
13 // Standard C++ includes
14 #include <iostream>
15 
16 // Root includes.
17 #include "TClonesArray.h"
18 
19 // PANDA and Fair includes.
20 #include "FairTask.h"
21 #include "FairRootManager.h"
22 #include "PndPidCandidate.h"
23 #include "PndPidProbability.h"
24 
25 // MVA Headers.
26 #include "PndMvaClassifier.h"
27 #include "PndKnnClassify.h"
28 #include "PndLVQClassify.h"
31 
33 
34 //====================
35 #define PIDMVA_ASSOCIATORT_DEBUG 0
36 
37 //==========================================================
38 #if (PIDMVA_ASSOCIATORT_DEBUG != 0)
39 // Function to use for debugging
// Debug helper: print every (label => value) pair of a classifier result map
// to std::cout, framed by separator lines.
// The map is taken by const reference because it is only read; this also lets
// callers pass const maps or temporaries (a non-const reference rejects both).
void printResult(std::map<std::string, float> const& res)
{
  std::cout << "\n\t================================== \n";
  for (std::map<std::string, float>::const_iterator it = res.begin();
       it != res.end(); ++it)
  {
    std::cout << "\t" << it->first
              << "\t=> " << it->second
              << '\n';
  }
  std::cout << "\n\t================================== \n";
}
52 #endif
53 
54 //==========================================================
55 
60  : FairTask("PndPidMvaAssociatorTaskSTD"),
61  fManager(0),
62  fVarNames (std::vector<std::string>()),
63  fClassNames(std::vector<std::string>()),
64  fWeightsFileName(std::string(getenv("VMCWORKDIR")) + std::string("/tools/MVA/PndMVAWeights/")),
65  fNumNeigh(200),
66  fScFact(0.8),
67  fWeight(1.00),
68  fClassifier(0),
69  fMethodType(UNKNOWN_METHOD),
70  fPidChargedCand(0),
71  fPidChargedProb(new TClonesArray("PndPidProbability")),
72  fMCTrack(0),
73  fMethodName("UNKNOWN_METHOD")
74 {
75  // Init neutral probab. containers.
76  // fPidNeutralProb = new TClonesArray("PndPidProbability");
77 
78  // Set Default path to the weight file
79  // SetDefaultWeightsPath();
80 }
81 
82 //___________________________________________________________
87  : FairTask(name),
88  fManager(0),
89  fVarNames (std::vector<std::string>()),
90  fClassNames(std::vector<std::string>()),
91  fWeightsFileName(std::string(getenv("VMCWORKDIR")) + std::string("/tools/MVA/PndMVAWeights/")),
92  fNumNeigh(200),
93  fScFact(0.8),
94  fWeight(1.00),
95  fClassifier(0),
96  fMethodType(UNKNOWN_METHOD),
97  fPidChargedCand(0),
98  fPidChargedProb(new TClonesArray("PndPidProbability")),
99  fMCTrack(0),
100  fMethodName("UNKNOWN_METHOD")
101 {
102  // Init neutral probab. containers.
103  // fPidNeutralProb = new TClonesArray("PndPidProbability");
104 
105  // Set Default path to the weight file
106  //SetDefaultWeightsPath();
107 }
108 
109 /*
110  * Set the default path where the weights are stored.
111  */
113 {
114  fWeightsFileName = std::string(getenv("VMCWORKDIR"));
115  fWeightsFileName += std::string("/tools/MVA/PndMVAWeights/");
116  std::cout<<"<INFO> Default Weights path is set to "
118  << '\n';
119 }
120 
121 //___________________________________________________________
126 {
127  // Clean-up allocated stuff.
128  if(fManager) {
129  fManager->Write();
130  delete fManager;
131  }
132 
133  if(fPidChargedCand) {
134  delete fPidChargedCand;
135  }
136 
137  if(fPidChargedProb) {
138  delete fPidChargedProb;
139  }
140  //if(fPidNeutralCand)
141  //delete fPidNeutralCand;
142 
143  //if(fPidNeutralProb)
144  //delete fPidNeutralProb;
145 
146  if(fMCTrack) {
147  delete fMCTrack;
148  }
149 
150  if(fClassifier) {
151  delete fClassifier;
152  }
153 }
154 
155 //___________________________________________________________
157 {
158  std::cout << "<-I-> InitStatus PndPidMvaAssociatorTask::Init()\n";
159 
160  fManager = FairRootManager::Instance();
161  if( !fManager ) {
162  std::cerr << "<ERROR> PndPidMvaAssociatorTask::Init:\n"
163  << "\t Could not init FairRootManager."
164  << std::endl;
165 
166  return kERROR;
167  }
168 
169  // Get charged candidates.
170  fPidChargedCand = (TClonesArray *)fManager->GetObject("PidChargedCand");
171  if ( !fPidChargedCand) {
172  std::cerr << "<ERROR> PndPidMvaAssociatorTask::Init: No PidChargedCand there!"
173  << std::endl;
174  return kERROR;
175  }
176 
177  // Get Neutral candidates.
178  /*
179  fPidNeutralCand = (TClonesArray *)fManager->GetObject("PidNeutralCand");
180  if ( ! fPidNeutralCand)
181  {
182  std::cerr << "<ERROR> PndPidMvaAssociatorTask::Init: No PidNeutralCand there!"
183  << std::endl;
184  return kERROR;
185  }
186  */
187 
188  std::cout << "<INFO> Using weight file " << fWeightsFileName
189  << '\n';
190 
191  // Init Classifier object
192  switch(fMethodType)
193  {
194  case TMVA_MLP:// Multi label MLP classifier from TMVA.
195  {
196  PndMultiClassMlpClassify* TmvaMlpCls = new PndMultiClassMlpClassify(fWeightsFileName, fClassNames,
197  fVarNames);
198  if(!TmvaMlpCls)
199  {
200  std::cerr << "<Error> Failed to initialize TMVA_MLP classifier."
201  << std::endl;
202  return kERROR;
203  }
204  // Init
205  TmvaMlpCls->Initialize();
206 
207  //fClassifier = dynamic_cast<PndMultiClassMlpClassify*>(TmvaMlpCls);
208  fClassifier = TmvaMlpCls;
209  std::cout << "<INFO> TMVA_MLP initialized using " << fWeightsFileName << '\n';
210 
211  fMethodName = "TMVAMLP";
212  }
213  break;
214 
215  case TMVA_BDT:// Multi label BDT classifier from TMVA.
216  {
217  PndMultiClassBdtClassify* TmvaBdtCls = new PndMultiClassBdtClassify(fWeightsFileName, fClassNames,
218  fVarNames);
219  if(!TmvaBdtCls)
220  {
221  std::cerr << "<Error> Failed to initialize TMVA_BDT classifier."
222  << std::endl;
223  return kERROR;
224  }
225  // INIT
226  TmvaBdtCls->Initialize();
227 
228  //fClassifier = dynamic_cast<PndMultiClassBdtClassify*>(TmvaBdtCls);
229  fClassifier = TmvaBdtCls;
230  std::cout << "<INFO> TMVA_BDT initialized using " << fWeightsFileName << '\n';
231 
232  fMethodName = "TMVABDT";
233  }
234  break;
235 
236  case LVQ:
237  {
238  PndLVQClassify* LvqCls = new PndLVQClassify(fWeightsFileName, fClassNames, fVarNames);
239 
240  if(!LvqCls)
241  {
242  std::cerr << "<Error> Failed to initialize LVQ classifier."
243  << std::endl;
244  return kERROR;
245  }
246  // Init
247  LvqCls->Initialize();
248 
249  //fClassifier = dynamic_cast<PndMvaClassifier*>(LvqCls);
250  fClassifier = LvqCls;
251  std::cout << "<INFO> LVQ initialized using " << fWeightsFileName << '\n';
252 
253  fMethodName = "LVQ";
254  }
255  break;
256 
257  case KNN:
258  default:
259  {
260  PndKnnClassify* KnnCls = new PndKnnClassify(fWeightsFileName, fClassNames, fVarNames);
261 
262  if(!KnnCls)
263  {
264  std::cerr << "<Error> Failed to initialize KNN classifier."
265  << std::endl;
266  return kERROR;
267  }
268 
269  // Set parameters.
270  KnnCls->SetEvtParam(fScFact, fWeight);
271  KnnCls->SetKnn(fNumNeigh);
272  KnnCls->Initialize();
273 
274  //fClassifier = dynamic_cast<PndMvaClassifier*>(KnnCls);
275  fClassifier = KnnCls;
276  std::cout << "<INFO> KNN initialized using " << fWeightsFileName << '\n';
277 
278  fMethodName = "KNN";
279  }
280  break;
281  }// End of switch(fMethodType)
282 
283  // Register objects in the output chain
284  Register();
285 
286  std::cout << "<INFO> PndPidMvaAssociatorTask::Init: Success!\n";
287  return kSUCCESS;
288 }
289 
290 //______________________________________________________
292 {}
293 
294 void PndPidMvaAssociatorTask::SetClassifier(std::string const& methodNameStr)
295 {
296  if(methodNameStr == "KNN")
297  {
298  fMethodType = KNN;
299  fMethodName = "KNN";
300  }
301  else if(methodNameStr == "LVQ")
302  {
303  fMethodType = LVQ;
304  fMethodName = "LVQ";
305  }
306  else if(methodNameStr == "TMVA_MLP")
307  {
309  fMethodName = "TMVAMLP";
310  }
311  else if(methodNameStr == "TMVA_BDT")
312  {
314  fMethodName = "TMVABDT";
315  }
316  else
317  {
318  std::cerr << "<ERROR> Unknown Method."
319  << std::endl;
320  }
321 };
322 //______________________________________________________
323 void PndPidMvaAssociatorTask::Exec(Option_t* option)
324 {
325  std::cout << option << '\n';
326 
327  if (fPidChargedProb->GetEntriesFast() != 0)
328  {
329  fPidChargedProb->Delete();
330  }
331 
332 #if ( PIDMVA_ASSOCIATORT_DEBUG != 0 )
333  std::cout << "<INFO> Call to Exec with options = " << option
334  << "___\n";
335 #endif
336 
337  if(fVerbose > 1)
338  {
339  std::cout << "-I- Start PndPidMvaAssociatorTask.\n";
340  }
341 
342  // Charged Candidates Loop
343  for(int i = 0; i < fPidChargedCand->GetEntriesFast(); i++)
344  {
346  TClonesArray& pidRef = *fPidChargedProb;
347 
348  // initializes with zeros
349  PndPidProbability* prob = new(pidRef[i]) PndPidProbability();
350 
351  if(fVerbose > 1)
352  {
353  std::cout << "-I- PndPidMVAAssociatorTask Ch BEFORE "
354  << pidcand->GetLorentzVector().M()
355  << '\n';
356  }
357  // Classify
358  DoPidMatch(*pidcand, *prob);
359 
360  if(fVerbose > 1)
361  {
362  std::cout << "-I- PndPidMVAAssociatorTask Ch AFTER "
363  << pidcand->GetLorentzVector().M()
364  << '\n';
365  }
366  }
367 
368  // Get the Neutral Candidates
369  /*
370  for(int i = 0; i < fPidNeutralCand->GetEntriesFast(); i++)
371  {
372  PndPidCandidate* pidcand = (PndPidCandidate*)fPidNeutralCand->At(i);
373  TClonesArray& pidRef = *fPidNeutralProb;
374  // initializes with zeros
375  PndPidProbability* prob = new(pidRef[i]) PndPidProbability();
376  // Classify
377  DoPidMatch(*pidcand, *prob);
378  }
379  */
380 }
381 
388  PndPidProbability& prob)
389 {
390 
391  std::map<std::string, float> out;
392  std::vector<float> const* evtPidData = PrepareEvtVect(pidcand);
393 
394  // Perform Recognition.
395  if( evtPidData ) {
396  fClassifier->GetMvaValues( *evtPidData, out);
397  delete evtPidData;
398  }
399  else {
400  // Feature vector is empty or damaged.
401  delete evtPidData;
402  evtPidData = 0;
403  return;
404  }
405 
406 #if ( PIDMVA_ASSOCIATORT_DEBUG != 0 )
407  std::cout << "****************************************************\n"
408  << "Momentum = " << (pidcand.GetMomentum()).Mag()
409  << "\nGetEnergy = " << pidcand.GetEnergy()
410  << "\nEMC = " << pidcand.GetEmcCalEnergy()
411  << "\nEMC/P = "
412  << (pidcand.GetEmcCalEnergy())/((pidcand.GetMomentum()).Mag())
413  << "\nEMCZ20 = " << pidcand.GetEmcClusterZ20()
414  << "\nEMCZ53 = " << pidcand.GetEmcClusterZ53()
415  << "\nEMCLAT = " << pidcand.GetEmcClusterLat()
416  << "\nEmcE1 = " << pidcand.GetEmcClusterE1()
417  << "\nEmcE9 = " << pidcand.GetEmcClusterE9()
418  << "\nEmcE25 = " << pidcand.GetEmcClusterE25()
419  << "\nSTT = " << pidcand.GetSttMeanDEDX()
420  << "\nMVD = " << pidcand.GetMvdDEDX()
421  << "\nDRC_TC = " << pidcand.GetDrcThetaC()
422  << '\n';
423  printResult(out);
424  std::cout << "====================================================\n";
425 #endif
426 
427  // Set probs.
428  for(size_t i = 0; i < fClassNames.size(); i++)
429  {
430  std::string name = fClassNames[i];
431 
432  if(name == "electron")
433  {
434  prob.SetElectronPdf(out[name]);
435  }
436  else if(name == "muon")
437  {
438  prob.SetMuonPdf(out[name]);
439  }
440  else if(name == "pion")
441  {
442  prob.SetPionPdf(out[name]);
443  }
444  else if(name == "kaon")
445  {
446  prob.SetKaonPdf(out[name]);
447  }
448  else if(name == "proton")
449  {
450  prob.SetProtonPdf(out[name]);
451  }
452  else
453  {
454  std::cerr << "<ERROR> Unknown label (class Name).\n"
455  << std::flush;
456  }
457  }
458 }
459 
460 std::vector<float> const* PndPidMvaAssociatorTask::PrepareEvtVect(PndPidCandidate const& pidcand) const
461 {
462  std::vector<float>* vect = new std::vector<float>();
463  float mom = (pidcand.GetMomentum()).Mag();
464 
465  for(size_t i = 0; i < fVarNames.size(); i++)
466  {
467  if( fVarNames[i] == "p" )
468  {
469  vect->push_back((pidcand.GetMomentum()).Mag());
470  }
471  else if(fVarNames[i] == "emc")
472  {
473  if(mom > 0.00) { // E/p
474  vect->push_back( (pidcand.GetEmcCalEnergy())/mom);
475  }
476  else {
477  std::cerr << "<WARNING> (p > 0) failed. The event is skipped.\n"
478  << "<ER-I> p = " << mom << std::endl;
479  delete vect;
480  vect = 0;
481  return 0;// Can not proceed. Break the procedure
482  }
483  }
484  // Cluster Ex parameters.
485  else if( (fVarNames[i] == "e1") || (fVarNames[i] == "E1") )
486  {
487  vect->push_back(pidcand.GetEmcClusterE1());
488  }
489  else if( (fVarNames[i] == "e9") || (fVarNames[i] == "E9") )
490  {
491  vect->push_back(pidcand.GetEmcClusterE9());
492  }
493  else if( (fVarNames[i] == "e25") || (fVarNames[i] == "E25") )
494  {
495  vect->push_back(pidcand.GetEmcClusterE25());
496  }
497  else if( (fVarNames[i] == "e1e9") || (fVarNames[i] == "E1E9") )
498  {
499  if( pidcand.GetEmcClusterE9() > 0 ) {
500  vect->push_back(pidcand.GetEmcClusterE1()/pidcand.GetEmcClusterE9());
501  }
502  else {
503  std::cerr << "<WARNING> (EmcClusterE9 > 0) failed. The event is skipped.\n"
504  << std::flush;
505  delete vect;
506  vect = 0;
507  return 0;// Can not proceed. Break the procedure
508  }
509  }
510  else if( (fVarNames[i] == "e9e25") || (fVarNames[i] == "E9E25") )
511  {
512  if( pidcand.GetEmcClusterE25() > 0 ) {
513  vect->push_back(pidcand.GetEmcClusterE9()/pidcand.GetEmcClusterE25());
514  }
515  else {
516  std::cerr << "<WARNING> (EmcClusterE25 > 0) failed. The event is skipped.\n"
517  << std::flush;
518  delete vect;
519  vect = 0;
520  return 0;// Can not proceed. Break the procedure
521  }
522  }
523  //======== Zernike & moments
524  else if(fVarNames[i] == "z20")
525  {
526  vect->push_back(pidcand.GetEmcClusterZ20());
527  }
528  else if(fVarNames[i] == "z53")
529  {
530  vect->push_back(pidcand.GetEmcClusterZ53());
531  }
532  // Cluster Second lat. moment
533  else if(fVarNames[i] == "lat")
534  {
535  vect->push_back(pidcand.GetEmcClusterLat());
536  }
537  // ========== other detectors
538  else if(fVarNames[i] == "stt")
539  {
540  vect->push_back(pidcand.GetSttMeanDEDX());
541  }
542  else if(fVarNames[i] == "mvd")
543  {
544  vect->push_back(pidcand.GetMvdDEDX());
545  }
546  else if(fVarNames[i] == "thetaC")
547  {
548  vect->push_back(pidcand.GetDrcThetaC());
549  }
550  }
551  return vect;
552 }
553 
554 //_________________________________________________________________
556 {
557  std::string tcaName = fMethodName + "MvaProb";
558  //---
559  FairRootManager::Instance()->Register(tcaName.c_str(),"Pid", fPidChargedProb, kTRUE);
560  // FairRootManager::Instance()->Register("MvaNeutralProb","Pid", fPidNeutralProb, kTRUE);
561 }
562 
563 //_________________________________________________________________
565 {}
566 //_________________________________________________________________
568 {}
Double_t GetEmcClusterE25() const
TClonesArray * fMCTrack
PndPidProbability TCA for charged particles.
int fVerbose
Definition: poormantracks.C:24
Int_t res
Definition: anadigi.C:166
Int_t i
Definition: run_full.C:25
void SetPionPdf(Double_t val)
PndMvaClassifier * fClassifier
MVA classifier object.
Float_t GetSttMeanDEDX() const
Float_t GetEmcCalEnergy() const
virtual void GetMvaValues(std::vector< float > EvtData, std::map< std::string, float > &result)=0
TLorentzVector GetLorentzVector() const
std::vector< float > const * PrepareEvtVect(PndPidCandidate const &pidcand) const
virtual void Exec(Option_t *option)
Double_t GetEnergy() const
void SetKaonPdf(Double_t val)
Double_t GetEmcClusterE9() const
TClonesArray * fPidChargedProb
PndPidCandidate TCA for charged particles.
Double_t mom
Definition: plot_dirc.C:14
!< Type definition of the neighbour list.
Float_t GetDrcThetaC() const
void SetElectronPdf(Double_t val)
void SetMuonPdf(Double_t val)
Double_t GetEmcClusterE1() const
std::string fMethodName
Monte-Carlo Truth track TCA.
Double_t GetEmcClusterZ53() const
void SetEvtParam(float const scFact, double const weight)
Double_t GetEmcClusterLat() const
Float_t GetMvdDEDX() const
Double_t GetEmcClusterZ20() const
TFile * out
Definition: reco_muo.C:20
TString name
std::vector< std::string > fVarNames
Variable names container.
void SetKnn(size_t const N)
Set the number of neighbours.
size_t fNumNeigh
Number of neighbors.
void SetProtonPdf(Double_t val)
void SetClassifier(Mva_MethodType const &methodT)
virtual void Initialize()
std::vector< std::string > fClassNames
Class names container.
ClassImp(PndAnaContFact)
Interface definition of the LVQ classifier.
void DoPidMatch(PndPidCandidate &pidcand, PndPidProbability &prob)
std::string fWeightsFileName
Path to the file holding weights (proto-types, examples, ...)
virtual void Initialize()
TVector3 GetMomentum() const
Mva_MethodType fMethodType
MVA Method name.