FairRoot/PandaRoot
PndMultiClassBdtTrain.h
Go to the documentation of this file.
1 /* ***************************************
2  * MultiClass BDTG Training functions *
3  * Author: M.Babai@rug.nl *
4  * Version: *
5  * LICENSE: *
6  * ***************************************
7  */
8 /*
9  * Note: This is just an interface to the original TMVA
10  * implementation. For the available options, please read the TMVA
11  * manual. In case of errors or wrong outputs produced by TMVA
12  * classifiers, consult the TMVA mailing list and send your questions
13  * to that list.
14  ******* VERY IMPORTANT ****
15  * You NEED a TMVA version > 4.1.X for this to work.
16  */
17 //#pragma once
18 #ifndef PND_MULTICLASS_BDT_TRAIN_H
19 #define PND_MULTICLASS_BDT_TRAIN_H
20 
21 //Local includes
22 #include "PndMvaTrainer.h"
23 
24 // TMVA && ROOT
25 #include "TMVA/Factory.h"
26 #include "TMVA/Config.h"
27 
28 // Interface definition for multiclass BDT (BDTG) trainers.
30 {
31  //----------------------------------------
32  //================== public ==============
33  public:
41  explicit PndMultiClassBdtTrain(std::string const& InPut,
42  std::vector<std::string> const& ClassNames,
43  std::vector<std::string> const& VarNames,
44  bool trim = true);
48  virtual ~PndMultiClassBdtTrain();
49 
53  void Train();
54 
59  void storeWeights();
60 
64  void Initialize();
65 
66  //______________________________________________
67  //====== Getters and setters.
68  // Set the name of the current job
69  inline void SetJobName (std::string const& name);
70 
71  // Set data Transformation scheme
72  inline void SetTransformation (std::string const& tran);
73 
74  // Set the options for the BDT algorithm. See the TMVA manual.
75  inline void SetBdtOptions (std::string const& opts);
76 
77  // Set the file name to store evaluation outputs.
78  inline void SetEvalFileName (std::string const& fname);
79 
80  // Set the directory where weights are stored.
81  inline void SetWeightsOutDir (std::string const& dirName);
82 
83  // Evaluate the classifier?
84  inline void SetEvaluation (bool evaluate);
85 
86  // Get the current job name.
87  inline std::string const& GetJobName() const;
88 
89  // Get the current transformation info.
90  inline std::string const& GetTransformation() const;
91 
92  // Get the classifier options.
93  inline std::string const& GetBdtOptions() const;
94 
95  // Get the name of the evaluation output file.
96  inline std::string const& GetEvalFileName() const;
97 
98  // Get the directory where the weights are stored.
99  inline std::string const& GetWeightsOutDir() const;
100  //----------------------------------------
101 
102  //================== protected ============
103  //protected:
104  //----------------------------------------
105 
106  //================== private =============
107  private:
108  // To avoid mistakes.
111 
112  // Initialize the BDT object and set its options.
113  void InitBdt();
114  // Add the variables to the TMVA factory object.
115  void AddVariables();
116 
117  //==============================
118  TMVA::Factory* m_factory;// TMVA factory
119  TFile* EvalFile; // To store evaluation file
120  std::string m_JName; //Job name
121  std::string m_transform;// Transformation opt.
122  std::string m_BdtOptions; // Bdt options.
123  std::string m_evalFileName; //evaluation file name.
124  std::string m_weightDirName;// Directory name to store weights.
126 };// End of interface definition.
127 //=============== inline functions implementation. ========
128 //__________________________________________
129 inline void PndMultiClassBdtTrain::SetJobName(std::string const& name)
130 {
131  this->m_JName = name;
132 };
133 
134 inline void PndMultiClassBdtTrain::SetTransformation(std::string const& tr)
135 {
136  this->m_transform = tr;
137 };
138 
139 inline void PndMultiClassBdtTrain::SetBdtOptions(std::string const& opt)
140 {
141  this->m_BdtOptions = opt;
142 };
143 
144 inline std::string const& PndMultiClassBdtTrain::GetJobName() const
145 {
146  return m_JName;
147 };
148 
149 inline std::string const& PndMultiClassBdtTrain::GetTransformation() const
150 {
151  return m_transform;
152 };
153 
154 inline std::string const& PndMultiClassBdtTrain::GetBdtOptions() const
155 {
156  return m_BdtOptions;
157 };
158 
159 inline void PndMultiClassBdtTrain::SetEvalFileName(std::string const& fname)
160 {
161  this->m_evalFileName = fname;
162 };
163 
164 inline std::string const& PndMultiClassBdtTrain::GetEvalFileName() const
165 {
166  return m_evalFileName;
167 };
168 
169 inline void PndMultiClassBdtTrain::SetWeightsOutDir (std::string const& dirName)
170 {
171  this->m_weightDirName = dirName;
172 };
173 
174 inline std::string const& PndMultiClassBdtTrain::GetWeightsOutDir() const
175 {
176  return m_weightDirName;
177 };
178 
179 inline void PndMultiClassBdtTrain::SetEvaluation(bool evaluate)
180 {
181  this->m_Evaluate = evaluate;
182 };
183 #endif
std::string const & GetBdtOptions() const
PndMultiClassBdtTrain & operator=(PndMultiClassBdtTrain const &oth)
void SetEvaluation(bool evaluate)
PndMultiClassBdtTrain(std::string const &InPut, std::vector< std::string > const &ClassNames, std::vector< std::string > const &VarNames, bool trim=true)
void SetTransformation(std::string const &tran)
void SetEvalFileName(std::string const &fname)
void SetJobName(std::string const &name)
TString name
std::string const & GetTransformation() const
void SetBdtOptions(std::string const &opts)
void SetWeightsOutDir(std::string const &dirName)
std::string const & GetEvalFileName() const
std::string const & GetJobName() const
virtual ~PndMultiClassBdtTrain()
std::string const & GetWeightsOutDir() const