39 return fROCCurves.get();
46 for(
auto &roc:fROCs) avg+=roc.second;
47 return avg/fROCs.size();
67 fLogger << kHEADER <<
" ==== Results ====" <<
Endl;
69 fLogger << kINFO <<
Form(
"Fold %i ROC-Int : %.4f",item.first,item.second) << std::endl;
71 fLogger << kINFO <<
"------------------------" <<
Endl;
72 fLogger << kINFO <<
Form(
"Average ROC-Int : %.4f",GetROCAverage()) <<
Endl;
73 fLogger << kINFO <<
Form(
"Std-Dev ROC-Int : %.4f",GetROCStandardDeviation()) <<
Endl;
82 fROCCurves->
Draw(
"AL");
83 fROCCurves->GetXaxis()->SetTitle(
" Signal Efficiency ");
84 fROCCurves->GetYaxis()->SetTitle(
" Background Rejection ");
85 Float_t adjust=1+fROCs.size()*0.01;
87 c->
SetTitle(
"Cross Validation ROC Curves");
94 fNumFolds(5),fClassifier(new
TMVA::
Factory(
"CrossValidation",
"!V:!ROC:Silent:!ModelPersistence:!Color:!DrawProgressBar:AnalysisType=Classification"))
110 fDataLoader->MakeKFoldDataSet(fNumFolds);
117 fResults.resize(fMethods.size());
118 for (
UInt_t j = 0; j < fMethods.size(); j++) {
121 TString methodTitle = fMethods[j].GetValue<
TString>(
"MethodTitle");
122 TString methodOptions = fMethods[j].GetValue<
TString>(
"MethodOptions");
123 if (methodName ==
"")
124 Log() << kFATAL <<
"No method booked for cross-validation" <<
Endl;
128 Log() << kINFO <<
"Evaluate method: " << methodTitle <<
Endl;
133 fDataLoader->MakeKFoldDataSet(fNumFolds);
138 for (
UInt_t i = 0; i < fNumFolds; ++i) {
139 Log() << kDEBUG <<
"Fold (" << methodTitle <<
"): " << i <<
Endl;
141 TString foldTitle = methodTitle;
142 foldTitle +=
"_fold";
146 MethodBase *smethod = fClassifier->BookMethod(fDataLoader.get(), methodName, methodTitle, methodOptions);
158 fResults[j].fROCs[i] = fClassifier->GetROCIntegral(fDataLoader->GetName(), methodTitle);
160 TGraph *
gr = fClassifier->GetROCCurve(fDataLoader->GetName(), methodTitle,
true);
164 fResults[j].fROCCurves->Add(
gr);
181 fClassifier->DeleteAllMethods();
182 fClassifier->fMethodsMap.clear();
186 Log() << kINFO <<
"Evaluation done." <<
Endl;
193 if (fResults.size() == 0)
194 Log() << kFATAL <<
"No cross-validation results available" <<
Endl;
Float_t GetROCAverage() const
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
MsgLogger & Endl(MsgLogger &ml)
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
void SetTitle(const char *title="")
Set canvas title.
A TMultiGraph is a collection of TGraph (or derived) objects.
Virtual base class for all MVA methods.
Class to save the results of cross validation; the metric for classification is the ROC, and you can ...
virtual void SetTitle(const char *title="")
Set graph title.
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
void SetNumFolds(UInt_t i)
Abstract base class for all high-level ML algorithms; you can book ML methods like BDT...
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual void SetLineColor(Color_t lcolor)
Set the line color.
virtual void ParseOptions()
Method to parse the internal option string.
void DeleteResults(const TString &, Types::ETreeType type, Types::EAnalysisType analysistype)
delete the results stored for this particular Method instance.
char * Form(const char *fmt,...)
Float_t GetROCStandardDeviation() const
const TString & GetMethodName() const
const std::vector< CrossValidationResult > & GetResults() const
This is the main MVA steering class.
virtual Double_t GetSignificance() const
compute significance of mean difference
CrossValidation(DataLoader *loader)
virtual void Evaluate()
Virtual method to be implemented with your algorithm.
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
ostringstream derivative to redirect and format output
virtual void Draw(Option_t *option="")
Draw a canvas.
Abstract ClassifierFactory template that handles arbitrary types.
virtual TLegend * BuildLegend(Double_t x1=0.3, Double_t y1=0.21, Double_t x2=0.3, Double_t y2=0.21, const char *title="", Option_t *option="")
Build a legend from the graphical objects in the pad.
std::map< UInt_t, Float_t > fROCs
A Graph is a graphics object made of two arrays X and Y with npoints each.
TCanvas * Draw(const TString name="CrossValidation") const
virtual Double_t GetTrainingEfficiency(const TString &)
Types::EAnalysisType GetAnalysisType() const
Double_t Sqrt(Double_t x)
static void EnableOutput()
virtual void TestClassification()
initialization
std::shared_ptr< TMultiGraph > fROCCurves
const char * Data() const