FairRoot/PandaRoot
analysis/TMVATrainer.C
Go to the documentation of this file.
1 #include "TMVA/Factory.h"
2 #include "TFile.h"
3 #include "TTree.h"
4 #include "TString.h"
5 #include "TRegexp.h"
6 #include "TEventList.h"
7 #include "TLeaf.h"
8 
9 #include <iostream>
10 #include <map>
11 #include <vector>
12 #include <algorithm>
13 #include <utility>
14 
15 #ifndef TMVAgettype
16 #define TMVAgettype
17 
18 int gettype(TTree *t, TString varname)
19 {
20  if (t->GetBranch(varname)==0) return -1;
21 
22  TString leaftype = t->GetLeaf(varname)->GetTypeName();
23 
24  if (leaftype=="Float_t") return 0;
25  else if (leaftype=="Int_t") return 1;
26  else if (leaftype=="Bool_t") return 2;
27 
28  return -1;
29 }
30 
31 // ---------------------------------------------------------------
32 
33 int SplitString(TString s, TString delim, TString *toks, int maxtoks)
34 {
35  TObjArray *tok = s.Tokenize(delim);
36  int N = tok->GetEntries();
37  for (int i=0;i<N;++i)
38  if (i<maxtoks)
39  {
40  toks[i] = ((TObjString*)tok->At(i))->String();
41  toks[i].ReplaceAll("\t","");
42  toks[i] = toks[i].Strip(TString::kBoth);
43  }
44  return N;
45 }
46 
47 // ---------------------------------------------------------------
48 // for a string containing a cut, returns only the variables used,
49 // separated by blanks
51 {
52  TString toks[50];
53  int n=SplitString(vars, "&&", toks, 50);
54  TRegexp rvar("[_a-zA-Z][_a-zA-Z0-9]*");
55 
56  TString res=" ";
57 
58  for (int i=0;i<n;++i)
59  {
60  TString v = toks[i](rvar);
61  if (v!="")
62  {
63  if (res.Contains(" "+v+" ")) continue;
64  res+=v+" ";
65  }
66  }
67 
68  res = res.Strip(TString::kBoth);
69 
70  return res;
71 }
72 #endif
73 
74 // ---------------------------------------------------------------
75 
76 void TMVATrainer(TString fname="", TString treename="", TString sigcut="", TString vars="", TString algo="BDT", TString precut="")
77 {
78  if ( fname=="" || treename=="" || sigcut=="" || vars=="" )
79  {
80  cout << "USAGE:\n";
81  cout << "TMVATrainer.C( <input>, <tree>, <sigcut>, <vars>, [method], [precut] )\n\n";
82  cout << " <input> : input file name containing TTree <tree>\n";
83  cout << " <tree> : name of the TTree containing signal and background\n";
84  cout << " <sigcut> : cut separating signal from background -> bgcut = !(sigcut)\n";
85  cout << " <vars> : blank separated list with variables for training\n";
86  cout << " [method] : optional method: 'BDT' (default), 'MLP' or 'LH'\n";
87  cout << " [precut] : optional precut before training; has to be applied also before testing!\n\n";
88  cout << "EXAMPLE:\n";
89  cout << "root -l -b -q 'TMVATrainer.C(\"demodata.root\",\"ntp\",\"signal>0\",\"v1 v2 v3 v4 v5\",\"\",\"MLP\")'\n\n";
90  return;
91  }
92 
93  if (algo=="LH") algo="Likelihood";
94 
95  if (vars.Contains("&&")) vars = getFromCut(vars);
96  cout <<"Vars : "<<vars<<endl;
97 
98  TString bkgcut="!("+sigcut+")";
99 
100  if (precut!="")
101  {
102  cout <<"Precut : "<<precut<<endl;
103  sigcut += "&&" + precut;
104  bkgcut += "&&" + precut;
105  }
106 
107  TFile *f = TFile::Open(fname);
108  TTree *t =(TTree*) f->Get(treename);
109 
110  TString outfname = fname;
111  outfname.ReplaceAll(".root","");
112  TString tmvaname = outfname+"_"+treename;
113  outfname=outfname+"_tmva.root";
114 
115  TFile* outputFile = TFile::Open(outfname, "RECREATE" );
116 
117  TMVA::Factory *factory = new TMVA::Factory( tmvaname, outputFile, "Silent:!V:Transformations=I;N;D");
118 
119  TString toks[30];
120  int N = SplitString(vars," ",toks,30);
121 
122  for (int i=0;i<N;++i)
123  {
124  int btype = gettype(t, toks[i]);
125  if (btype==0 || btype==1)
126  factory->AddVariable(toks[i], btype==0?'F':'I');
127  }
128 
129  factory->AddTree( t, "Signal", 1.0, sigcut.Data());
130  factory->AddTree( t, "Background", 1.0, bkgcut.Data());
131 
132  int nsig = t->GetEntries(sigcut);
133  int nbkg = t->GetEntries(bkgcut);
134 
135  factory->PrepareTrainingAndTestTree( "", int(nsig*0.8), int(nbkg*0.8), int(nsig*0.19), int(nbkg*0.19));
136 
137  // old verion 4.1.3
138  if (algo=="BDT") factory->BookMethod( TMVA::Types::kBDT, "BDT", "!V:nTrees=400:BoostType=AdaBoost:nCuts=10:NNodesMax=10" );
139  // new verion 4.2.0
140  // if (algo=="BDT") factory->BookMethod( TMVA::Types::kBDT, "BDT", "!V:nTrees=400:BoostType=AdaBoost:nCuts=10:MaxDepth=3:UseFisherCuts:DoPreselection" );
141  else if (algo=="MLP") factory->BookMethod( TMVA::Types::kMLP, "MLP", "!V:NCycles=50:HiddenLayers=10,10:TestRate=5" );
142  else if (algo=="Likelihood") factory->BookMethod( TMVA::Types::kLikelihood, "Likelihood","!V:NAvEvtPerBin=50" );
143  else {cout <<"Unconfigured algorithm! Exiting..."<<endl; return;}
144 
145  factory->TrainAllMethods();
146  //factory->TestAllMethods();
147  //factory->EvaluateAllMethods();
148 
149  outputFile->Close();
150  delete factory;
151 }
Int_t res
Definition: anadigi.C:166
Int_t i
Definition: run_full.C:25
TLorentzVector s
Definition: Pnd2DStar.C:50
int n
__m128 v
Definition: P4_F32vec4.h:4
void TMVATrainer(TString fname="", TString treename="", TString sigcut="", TString vars="", TString algo="BDT", TString precut="")
int SplitString(TString s, TString delim, TString *toks, int maxtoks)
int gettype(TTree *t, TString varname)
TString vars[MAX]
Definition: autocutx.C:34
TFile * f
Definition: bump_analys.C:12
int nsig
Definition: toy_core.C:46
TString getFromCut(TString vars)
TTree * t
Definition: bump_analys.C:13