10 #ifndef PND_MVA_TRAINER_H
11 #define PND_MVA_TRAINER_H
21 #include "TMVA/Tools.h"
22 #include "TMVA/PDEFoam.h"
23 #include "TMVA/Event.h"
31 #define TRAIN_INC_FOAM 0
45 explicit PndMvaTrainer(std::vector< std::pair<std::string, std::vector<float>*> >
const& InputEvtsParam,
46 std::vector<std::string>
const& ClassNames,
47 std::vector<std::string>
const& VarNames,
57 std::vector<std::string>
const& ClassNames,
58 std::vector<std::string>
const& VarNames,
65 virtual void Train() = 0;
83 void SetTestSet(std::set <size_t>
const& testSet);
130 inline std::vector<PndMvaClass>
const&
GetClasses()
const;
136 inline std::vector<PndMvaVariable>
const&
GetVariables()
const;
158 std::vector<float>*> >
const& weights)
const;
160 #if (TRAIN_INC_FOAM > 0)
size_t m_RND_seed
Random seed.
std::vector< StepError > const & GetErrorValues() const
void SetTestSetSize(size_t percent=50)
void SetRndSeed(size_t const sd)
PndMvaDataSet m_dataSets
Data set. Holds event values.
void NormalizeData(NormType t=NONORM)
virtual void EvalClassifierError()
virtual void Initialize()
virtual ~PndMvaTrainer()
Destructor.
void SetAppType(AppType t)
virtual void storeWeights()=0
void SetAppType(AppType t)
void WriteErroVect(std::string const &FileName) const
virtual void Train()=0
Derived classes need to implement this methode.
std::vector< PndMvaClass > const & GetClasses() const
Get the list of available classes (labels).
std::set< size_t > m_testSet_indices
Indices of the test set.
PndMvaTrainer & operator=(PndMvaTrainer const &other)
std::string m_outFile
Output filename.
std::set< size_t > const & GetTestEvetIdx() const
std::vector< StepError > m_StepErro
Container to keep per step error values.
void WriteToWeightFile(std::vector< std::pair< std::string, std::vector< float > * > > const &weights) const
std::vector< PndMvaClass > const & GetClasses() const
Get the list of available classes (labels).
void SetOutPutFile(std::string const &outFile)
std::vector< PndMvaVariable > const & GetVariables() const
Get the list of available variables.
size_t GetRndSeed() const
std::vector< PndMvaVariable > const & GetVars() const
Get the list of available variables.
PndMvaTrainer(std::vector< std::pair< std::string, std::vector< float > * > > const &InputEvtsParam, std::vector< std::string > const &ClassNames, std::vector< std::string > const &VarNames, bool trim=true)
void SetTestSet(std::set< size_t > const &testSet)