FairRoot/PandaRoot
PndMvaDataSet.h
Go to the documentation of this file.
1 /***************************************
2  * Class interface of DataSet class. *
3  * Author: M.Babai (M.Babai@rug.nl) *
4  * License: *
5  * Version: *
6  ***************************************/
7 //#pragma once
8 #ifndef PND_MVA_DATASET_H
9 #define PND_MVA_DATASET_H
10 
11 // C++ includes
12 #include <iostream>
13 #include <fstream>
14 #include <string>
15 #include <vector>
16 #include <map>
17 #include <set>
18 #include <algorithm>
19 #include <cmath>
20 #include <cassert>
21 #include <limits>
22 #include <typeinfo>
23 #include <exception>
24 #include <utility>
25 
26 // ROOT
27 #include "TFile.h"
28 #include "TTree.h"
29 #include "TRandom3.h"
30 
31 // Local includes
32 #include "PndMvaClass.h"
33 #include "PndMvaVariable.h"
34 #include "PndMvaVarPCATransform.h"
35 
36 // ========================================================================
37 // Application type
/// Application type.
enum AppType
{
  UNKAPP        = 0,   ///< NOTE(review): presumably "unknown application" - confirm.
  TRAIN         = 1,   ///< Training algorithm.
  CLASSIFY      = 2,   ///< Read weights to do classification.
  TMVATRAIN     = 10,  ///< Provide input for TMVA training.
  TMVACLS       = 20,  ///< TMVA classification.
  PRE_INIT_EVTS = 30   ///< Pre-initialized event data.
};
46 
47 // Normalization schemes
/// Normalization schemes.
enum NormType
{
  NONORM  = 0,  ///< Do nothing.
  VARX    = 1,  ///< Use sample variance.
  MINMAX  = 2,  ///< Use sample min and max.
  MEDIAN  = 3,  ///< Use median and interquartile range (IQR).
  VARNORM = 4   ///< Variable normalize transform.
};
55 
57 
/**
 * Exception class carrying a text message.
 *
 * Derives from std::exception, so it can be caught through the standard
 * base class. The message is copied into the object at construction.
 */
class PndMvaDataSetException: public std::exception
{
 public:
  /// Default constructor: uses the message "UNKNOWN_MvaDataSetException".
  PndMvaDataSetException()
    : m_message("UNKNOWN_MvaDataSetException")
  {}

  /**
   * Construct with a caller supplied message.
   * @param val Message text; a copy is stored in the exception.
   */
  explicit PndMvaDataSetException(std::string const& val)
    : m_message(val)
  {}

  /// Destructor (non-throwing, as required for std::exception subclasses).
  virtual ~PndMvaDataSetException() throw()
  {}

  /**
   * std::exception interface.
   * @return The stored message as a C string. The pointer refers to the
   *         internal std::string and is only valid while this object lives.
   */
  virtual char const* what() const throw()
  {
    return m_message.c_str();
  }

  /**
   * Non-const overload returning the message as std::string.
   * Overload resolution selects this one when what() is called on a
   * non-const PndMvaDataSetException; calls through a const reference or
   * through std::exception use the const overload above.
   */
  virtual std::string const& what()
  {
    return m_message;
  }

 private:
  /// Stored message text.
  std::string m_message;
};
86 
87 // ==================== Data set class ==========================
89 {
90  public:
91 
102  explicit PndMvaDataSet( std::vector< std::pair<std::string, std::vector<float>*> > const& InputEvtsParam,
103  std::vector<std::string> const& classNames,
104  std::vector<std::string> const& varNames,
105  AppType type);
113  explicit PndMvaDataSet( std::string const& WeightFile,
114  std::vector<std::string> const& classNames,
115  std::vector<std::string> const& varNames,
116  AppType type);
117 
119  virtual ~PndMvaDataSet();
120 
125  //void WriteDataSet(std::string const& outFile) __attribute__ ((deprecated));
126  virtual void WriteDataSet(std::string const& outFile);
127 
137  virtual void InitClsCondMeans(std::set <size_t> const& excludeIndxs);
138 
143  inline void SetTrim(bool t);
144 
146  inline std::vector< std::pair<std::string, std::vector<float>* > > const& GetData() const;
147 
149  inline std::vector<PndMvaClass> const& GetClasses() const;
150 
152  inline std::vector<PndMvaVariable> const& GetVars() const;
153 
155  inline std::map< std::string, std::vector<float>* > const& GetClassCondMeans() const;
156 
158  inline std::string const& GetInFileName() const;
159 
160  //========================= PCA =====================//
166  virtual void PCATransForm();
167 
171  inline bool Used_PCA() const;
172 
176  inline void Use_PCA(bool t);
177 
182  inline PndMvaVarPCATransform const& Get_PCA() const;
183 
184  //_________________________ PCA _____________________//
185 
189  inline NormType GetNormType() const;
190 
194  inline void SetNormType(NormType t);
195 
199  inline AppType GetAppType () const;
200 
203  inline void SetAppType(AppType t);
204 
209  virtual void Initialize();
210 
211  inline size_t GetRndSeed()const;
212  inline void SetRndSeed(size_t const sd);
213 
214  //______________________________________________________________
215  protected:
219  void ReadInput();
220 
224  void ReadWeightsFromFile();
225 
226  //==============================================================
227  private:
228  // Private to avoid mistakes.
229  PndMvaDataSet(PndMvaDataSet const& other);
230  PndMvaDataSet& operator=(PndMvaDataSet const& other);
231 
235  void Trim();
236 
240  void NormalizeDataSet();
241 
246  void InitClasses(std::vector<std::string> const& labels);
247 
252  void InitVariables(std::vector<std::string> const& variables);
253 
254  // Validate the input file
255  void ValidateWeightFile();
256 
263  void CompClsCondMean( std::string const& clsName, std::set <size_t> const& exCluds );
264 
269  void ComputeVariance();
270 
274  void DetermineMedian();
275 
279  void MinMaxDiff();
280 
284  void FindMinMax();
285 
289  void VarNormalize();
290 
291  // __________________________ Member parameters ___________
293  std::string m_input;
294 
296  std::vector<PndMvaClass> m_classes;
297 
299  std::vector<PndMvaVariable> m_vars;
300 
302  std::vector< std::pair<std::string, std::vector<float>*> > m_events;
303 
305  std::map< std::string, std::vector<float>* > m_ClassCondMeans;
306 
307  // PCA transformation.
309 
310  // If PCA was applied.
311  bool m_UsePCA;
312 
313  // Normalization scheme
315 
316  // Application type.
318  bool m_trim;
319  size_t m_RND_seed;
320 };
321 // End of class interface definition.
322 
323 // ============= Inline implementation ==================
324 
325 inline size_t PndMvaDataSet::GetRndSeed()const
326 {
327  return this->m_RND_seed;
328 };
329 
330 inline void PndMvaDataSet::SetRndSeed(size_t const sd)
331 {
332  this->m_RND_seed = sd;
333 };
334 
335 inline std::vector< std::pair<std::string, std::vector<float>*> > const& PndMvaDataSet::GetData() const
336 {
337  assert(m_events.size() != 0);
338  return m_events;
339 };
340 
341 inline std::vector<PndMvaClass> const& PndMvaDataSet::GetClasses() const
342 {
343  return m_classes;
344 };
345 
346 inline std::vector<PndMvaVariable> const& PndMvaDataSet::GetVars() const
347 {
348  return m_vars;
349 };
350 
351 inline std::map< std::string, std::vector<float>* > const& PndMvaDataSet::GetClassCondMeans() const
352 {
353  return m_ClassCondMeans;
354 };
355 
356 inline std::string const& PndMvaDataSet::GetInFileName() const
357 {
358  return m_input;
359 };
360 
361 inline bool PndMvaDataSet::Used_PCA() const
362 {
363  return m_UsePCA;
364 };
365 inline void PndMvaDataSet::Use_PCA(bool t)
366 {
367  m_UsePCA = t;
368 };
370 {
371  return m_PCA;
372 };
374 {
375  return m_NormType;
376 };
378 {
379  m_NormType = t;
380 };
382 {
383  return m_AppType;
384 };
386 {
387  m_AppType = t;
388 };
389 inline void PndMvaDataSet::SetTrim(bool t)
390 {
391  m_trim = t;
392 };
393 #endif
void ComputeVariance()
void ReadInput()
AppType m_AppType
std::vector< PndMvaClass > m_classes
Classes.
virtual void PCATransForm()
std::vector< std::pair< std::string, std::vector< float > * > > m_events
Container to keep the Event data feature vectors.
void MinMaxDiff()
AppType GetAppType() const
void InitVariables(std::vector< std::string > const &variables)
TString outFile
Definition: hit_dirc.C:17
Double_t val[nBoxes][nFEBox]
Definition: createCalib.C:11
virtual void WriteDataSet(std::string const &outFile)
virtual ~PndMvaDataSet()
Destructor.
void SetNormType(NormType t)
PndMvaDataSetException(std::string const &val)
Definition: PndMvaDataSet.h:65
void SetAppType(AppType t)
virtual void Initialize()
std::vector< std::pair< std::string, std::vector< float > * > > const & GetData() const
Get available data vectors.
std::vector< std::string > labels
std::string m_input
Input File name.
void SetRndSeed(size_t const sd)
NormType
Definition: PndMvaDataSet.h:48
NormType m_NormType
bool Used_PCA() const
std::string const & GetInFileName() const
Get name of input file name (weight/event file).
std::vector< PndMvaVariable > m_vars
Variables.
std::map< std::string, std::vector< float > * > const & GetClassCondMeans() const
Get classconditional means for all classes (labels).
void Use_PCA(bool t)
void DetermineMedian()
void VarNormalize()
virtual char const * what() const
Definition: PndMvaDataSet.h:72
PndMvaVarPCATransform const & Get_PCA() const
void ReadWeightsFromFile()
void FindMinMax()
std::map< std::string, std::vector< float > * > m_ClassCondMeans
Container to keep the Class Conditional means.
void NormalizeDataSet()
void SetTrim(bool t)
size_t GetRndSeed() const
PndMvaVarPCATransform m_PCA
virtual void InitClsCondMeans(std::set< size_t > const &excludeIndxs)
virtual std::string const & what()
Definition: PndMvaDataSet.h:77
std::vector< PndMvaClass > const & GetClasses() const
Get the list of available classes (labels).
TTree * t
Definition: bump_analys.C:13
PndMvaDataSet & operator=(PndMvaDataSet const &other)
void CompClsCondMean(std::string const &clsName, std::set< size_t > const &exCluds)
void InitClasses(std::vector< std::string > const &labels)
void ValidateWeightFile()
virtual ~PndMvaDataSetException()
Definition: PndMvaDataSet.h:69
std::vector< PndMvaVariable > const & GetVars() const
Get the list of available variables.
PndMvaDataSet(std::vector< std::pair< std::string, std::vector< float > * > > const &InputEvtsParam, std::vector< std::string > const &classNames, std::vector< std::string > const &varNames, AppType type)
NormType GetNormType() const
AppType
Definition: PndMvaDataSet.h:38