FairRoot/PandaRoot
PndMvaTrainer.h
Go to the documentation of this file.
1 /* **********************************************
2  * MVA classifier trainer interface. *
3  * Author: M. Babai *
4  * M.Babai@rug.nl *
5  * Version: 0.1 beta1. *
6  * LICENSE: *
7  * **********************************************
8  */
9 //#pragma once
10 #ifndef PND_MVA_TRAINER_H
11 #define PND_MVA_TRAINER_H
12 
13 // C++ includes
14 #include <cassert>
15 #include <limits>
16 #include <ctime>
17 #include <iomanip>
18 #include <cstdlib>
19 
20 // ROOT and PandaRoot
21 #include "TMVA/Tools.h"
22 #include "TMVA/PDEFoam.h"
23 #include "TMVA/Event.h"
24 
25 class TRandom3;
26 
27 // Local includes
28 #include "PndMvaDataSet.h"
29 #include "PndMvaUtil.h"
30 
31 #define TRAIN_INC_FOAM 0
32 
34 {
35  //==============================================
36  //================ Public =======================
37  public:
38 
45  explicit PndMvaTrainer(std::vector< std::pair<std::string, std::vector<float>*> > const& InputEvtsParam,
46  std::vector<std::string> const& ClassNames,
47  std::vector<std::string> const& VarNames,
48  bool trim = true);
49 
56  explicit PndMvaTrainer(std::string const& InPut,
57  std::vector<std::string> const& ClassNames,
58  std::vector<std::string> const& VarNames,
59  bool trim = true);
60 
62  virtual ~PndMvaTrainer();
63 
65  virtual void Train() = 0;
66 
71  virtual void storeWeights() = 0;
72 
77  void SetTestSetSize(size_t percent = 50);
78 
83  void SetTestSet(std::set <size_t> const& testSet);
84 
89 
94  void PCATransForm();
95 
100  inline void SetOutPutFile(std::string const& outFile);
101 
106  void WriteErroVect(std::string const& FileName) const;
107 
113  inline std::vector <StepError> const& GetErrorValues() const;
114 
118  virtual void Initialize();
119 
124  inline std::set <size_t> const& GetTestEvetIdx() const;
125 
130  inline std::vector<PndMvaClass> const& GetClasses() const;
131 
136  inline std::vector<PndMvaVariable> const& GetVariables() const;
137 
141  virtual void EvalClassifierError();
142 
143  inline size_t GetRndSeed() const;
144  inline void SetRndSeed(size_t const sd);
145 
146  //______________________________________________
147  //================ Protected ===================
148  protected:
152  inline void SetAppType(AppType t);
153 
157  void WriteToWeightFile(std::vector< std::pair<std::string,
158  std::vector<float>*> > const& weights) const;
159 
160 #if (TRAIN_INC_FOAM > 0)
161 
167  void WriteToWeightFile(std::vector<TMVA::PDEFoam*> const& foams) const;
168 #endif
169  //void WriteDataSetToOutFile();
170 
172  std::set <size_t> m_testSet_indices;
173 
176 
178  std::vector <StepError> m_StepErro;
179 
181  std::string m_outFile;
182 
184  size_t m_RND_seed;
185 
186  void splitTetsSet();
187  //______________________________________________
188  //================ Private =====================
189  private:
191  PndMvaTrainer(PndMvaTrainer const& other);
192  PndMvaTrainer& operator=(PndMvaTrainer const& other);
193 
194  // trim or not
195  bool m_trim;
197 };// End of class definition.
198 
199 //========================= Inline implementations =================
200 inline size_t PndMvaTrainer::GetRndSeed() const
201 {
202  return this->m_RND_seed;
203 };
204 
205 inline void PndMvaTrainer::SetRndSeed(size_t const sd)
206 {
207  this->m_RND_seed = sd;
208 };
209 
210 inline void PndMvaTrainer::SetOutPutFile(std::string const& outFile)
211 {
212  m_outFile = outFile;
213 };
214 
216 {
218 };
219 
220 inline std::set <size_t> const& PndMvaTrainer::GetTestEvetIdx() const
221 {
222  return m_testSet_indices;
223 };
224 
226 inline std::vector<PndMvaClass> const& PndMvaTrainer::GetClasses() const
227 {
228  return m_dataSets.GetClasses();
229 };
230 
232 inline std::vector<PndMvaVariable> const& PndMvaTrainer::GetVariables() const
233 {
234  return m_dataSets.GetVars();
235 };
236 
237 // @return List of evaluation objects.
238 inline std::vector <StepError> const& PndMvaTrainer::GetErrorValues() const
239 {
240  return m_StepErro;
241 };
242 #endif
size_t m_RND_seed
Random seed.
std::vector< StepError > const & GetErrorValues() const
TString outFile
Definition: hit_dirc.C:17
void SetTestSetSize(size_t percent=50)
void PCATransForm()
void SetRndSeed(size_t const sd)
PndMvaDataSet m_dataSets
Data set. Holds event values.
size_t m_testSetSize
void NormalizeData(NormType t=NONORM)
virtual void EvalClassifierError()
TString FileName
virtual void Initialize()
virtual ~PndMvaTrainer()
Destructor.
void SetAppType(AppType t)
virtual void storeWeights()=0
void SetAppType(AppType t)
void WriteErroVect(std::string const &FileName) const
NormType
Definition: PndMvaDataSet.h:48
virtual void Train()=0
Derived classes must implement this method.
std::vector< PndMvaClass > const & GetClasses() const
Get the list of available classes (labels).
std::set< size_t > m_testSet_indices
Indices of the test set.
PndMvaTrainer & operator=(PndMvaTrainer const &other)
std::string m_outFile
Output filename.
std::set< size_t > const & GetTestEvetIdx() const
std::vector< StepError > m_StepErro
Container holding the per-step error values.
void WriteToWeightFile(std::vector< std::pair< std::string, std::vector< float > * > > const &weights) const
std::vector< PndMvaClass > const & GetClasses() const
Get the list of available classes (labels).
TTree * t
Definition: bump_analys.C:13
void SetOutPutFile(std::string const &outFile)
std::vector< PndMvaVariable > const & GetVariables() const
Get the list of available variables.
size_t GetRndSeed() const
void splitTetsSet()
std::vector< PndMvaVariable > const & GetVars() const
Get the list of available variables.
AppType
Definition: PndMvaDataSet.h:38
PndMvaTrainer(std::vector< std::pair< std::string, std::vector< float > * > > const &InputEvtsParam, std::vector< std::string > const &ClassNames, std::vector< std::string > const &VarNames, bool trim=true)
void SetTestSet(std::set< size_t > const &testSet)