47 return fROCCurves.get();
56 MsgLogger fLogger(
"HyperParameterOptimisation");
58 for(
UInt_t j=0; j<fFoldParameters.size(); ++j) {
59 fLogger<<kHEADER<<
"===========================================================" <<
Endl;
60 fLogger<<kINFO<<
"Optimisation for " << fMethodName <<
" fold " << j+1 <<
Endl;
62 for(
auto &it : fFoldParameters.at(j)) {
63 fLogger<<kINFO<< it.first <<
" " << it.second <<
Endl;
73 fFomType(
"Separation"),
77 fClassifier(new
TMVA::
Factory(
"HyperParameterOptimisation",
"!V:!ROC:Silent:!ModelPersistence:!Color:!DrawProgressBar:AnalysisType=Classification"))
92 fDataLoader->MakeKFoldDataSet(fNumFolds);
99 for (
auto &meth : fMethods) {
106 fDataLoader->MakeKFoldDataSet(fNumFolds);
109 fResults.fMethodName = methodName;
111 for (
UInt_t i = 0; i < fNumFolds; ++i) {
113 TString foldTitle = methodTitle;
120 auto smethod = fClassifier->BookMethod(fDataLoader.get(), methodName, methodTitle, methodOptions);
122 auto params = smethod->OptimizeTuningParameters(fFomType, fFitType);
123 fResults.fFoldParameters.push_back(params);
127 fClassifier->DeleteAllMethods();
129 fClassifier->fMethodsMap.clear();
HyperParameterOptimisationResult()
MsgLogger & Endl(MsgLogger &ml)
A TMultiGraph is a collection of TGraph (or derived) objects.
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Abstract base class for all high level ml algorithms, you can book ml methods like BDT...
virtual void Evaluate()
Virtual method to be implemented with your algorithm.
void SetNumFolds(UInt_t folds)
HyperParameterOptimisation(DataLoader *dataloader)
~HyperParameterOptimisation()
This is the main MVA steering class.
ostringstream derivative to redirect and format output
Abstract ClassifierFactory template that handles arbitrary types.
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
~HyperParameterOptimisationResult()
static void EnableOutput()