mlpack  3.1.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
decision_tree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
15 
#include <mlpack/prereqs.hpp>
#include "gini_gain.hpp"
#include "information_gain.hpp"
#include "best_binary_numeric_split.hpp"
#include "all_categorical_split.hpp"
#include "all_dimension_select.hpp"
#include <type_traits>
23 
24 namespace mlpack {
25 namespace tree {
26 
34 template<typename FitnessFunction = GiniGain,
35  template<typename> class NumericSplitType = BestBinaryNumericSplit,
36  template<typename> class CategoricalSplitType = AllCategoricalSplit,
37  typename DimensionSelectionType = AllDimensionSelect,
38  typename ElemType = double,
39  bool NoRecursion = false>
40 class DecisionTree :
41  public NumericSplitType<FitnessFunction>::template
42  AuxiliarySplitInfo<ElemType>,
43  public CategoricalSplitType<FitnessFunction>::template
44  AuxiliarySplitInfo<ElemType>
45 {
46  public:
48  typedef NumericSplitType<FitnessFunction> NumericSplit;
50  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
52  typedef DimensionSelectionType DimensionSelection;
53 
71  template<typename MatType, typename LabelsType>
72  DecisionTree(MatType data,
73  const data::DatasetInfo& datasetInfo,
74  LabelsType labels,
75  const size_t numClasses,
76  const size_t minimumLeafSize = 10,
77  const double minimumGainSplit = 1e-7,
78  const size_t maximumDepth = 0,
79  DimensionSelectionType dimensionSelector =
80  DimensionSelectionType());
81 
98  template<typename MatType, typename LabelsType>
99  DecisionTree(MatType data,
100  LabelsType labels,
101  const size_t numClasses,
102  const size_t minimumLeafSize = 10,
103  const double minimumGainSplit = 1e-7,
104  const size_t maximumDepth = 0,
105  DimensionSelectionType dimensionSelector =
106  DimensionSelectionType());
107 
127  template<typename MatType, typename LabelsType, typename WeightsType>
128  DecisionTree(
129  MatType data,
130  const data::DatasetInfo& datasetInfo,
131  LabelsType labels,
132  const size_t numClasses,
133  WeightsType weights,
134  const size_t minimumLeafSize = 10,
135  const double minimumGainSplit = 1e-7,
136  const size_t maximumDepth = 0,
137  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
138  const std::enable_if_t<arma::is_arma_type<
139  typename std::remove_reference<WeightsType>::type>::value>* = 0);
140 
159  template<typename MatType, typename LabelsType, typename WeightsType>
160  DecisionTree(
161  const DecisionTree& other,
162  MatType data,
163  const data::DatasetInfo& datasetInfo,
164  LabelsType labels,
165  const size_t numClasses,
166  WeightsType weights,
167  const size_t minimumLeafSize = 10,
168  const double minimumGainSplit = 1e-7,
169  const std::enable_if_t<arma::is_arma_type<
170  typename std::remove_reference<WeightsType>::type>::value>* = 0);
189  template<typename MatType, typename LabelsType, typename WeightsType>
190  DecisionTree(
191  MatType data,
192  LabelsType labels,
193  const size_t numClasses,
194  WeightsType weights,
195  const size_t minimumLeafSize = 10,
196  const double minimumGainSplit = 1e-7,
197  const size_t maximumDepth = 0,
198  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
199  const std::enable_if_t<arma::is_arma_type<
200  typename std::remove_reference<WeightsType>::type>::value>* = 0);
201 
220  template<typename MatType, typename LabelsType, typename WeightsType>
221  DecisionTree(
222  const DecisionTree& other,
223  MatType data,
224  LabelsType labels,
225  const size_t numClasses,
226  WeightsType weights,
227  const size_t minimumLeafSize = 10,
228  const double minimumGainSplit = 1e-7,
229  const size_t maximumDepth = 0,
230  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
231  const std::enable_if_t<arma::is_arma_type<
232  typename std::remove_reference<WeightsType>::type>::value>* = 0);
233 
240  DecisionTree(const size_t numClasses = 1);
241 
248  DecisionTree(const DecisionTree& other);
249 
255  DecisionTree(DecisionTree&& other);
256 
263  DecisionTree& operator=(const DecisionTree& other);
264 
271 
275  ~DecisionTree();
276 
297  template<typename MatType, typename LabelsType>
298  double Train(MatType data,
299  const data::DatasetInfo& datasetInfo,
300  LabelsType labels,
301  const size_t numClasses,
302  const size_t minimumLeafSize = 10,
303  const double minimumGainSplit = 1e-7,
304  const size_t maximumDepth = 0,
305  DimensionSelectionType dimensionSelector =
306  DimensionSelectionType());
307 
326  template<typename MatType, typename LabelsType>
327  double Train(MatType data,
328  LabelsType labels,
329  const size_t numClasses,
330  const size_t minimumLeafSize = 10,
331  const double minimumGainSplit = 1e-7,
332  const size_t maximumDepth = 0,
333  DimensionSelectionType dimensionSelector =
334  DimensionSelectionType());
335 
357  template<typename MatType, typename LabelsType, typename WeightsType>
358  double Train(MatType data,
359  const data::DatasetInfo& datasetInfo,
360  LabelsType labels,
361  const size_t numClasses,
362  WeightsType weights,
363  const size_t minimumLeafSize = 10,
364  const double minimumGainSplit = 1e-7,
365  const size_t maximumDepth = 0,
366  DimensionSelectionType dimensionSelector =
367  DimensionSelectionType(),
368  const std::enable_if_t<arma::is_arma_type<typename
369  std::remove_reference<WeightsType>::type>::value>* = 0);
370 
390  template<typename MatType, typename LabelsType, typename WeightsType>
391  double Train(MatType data,
392  LabelsType labels,
393  const size_t numClasses,
394  WeightsType weights,
395  const size_t minimumLeafSize = 10,
396  const double minimumGainSplit = 1e-7,
397  const size_t maximumDepth = 0,
398  DimensionSelectionType dimensionSelector =
399  DimensionSelectionType(),
400  const std::enable_if_t<arma::is_arma_type<typename
401  std::remove_reference<WeightsType>::type>::value>* = 0);
402 
409  template<typename VecType>
410  size_t Classify(const VecType& point) const;
411 
421  template<typename VecType>
422  void Classify(const VecType& point,
423  size_t& prediction,
424  arma::vec& probabilities) const;
425 
433  template<typename MatType>
434  void Classify(const MatType& data,
435  arma::Row<size_t>& predictions) const;
436 
447  template<typename MatType>
448  void Classify(const MatType& data,
449  arma::Row<size_t>& predictions,
450  arma::mat& probabilities) const;
451 
455  template<typename Archive>
456  void serialize(Archive& ar, const unsigned int /* version */);
457 
459  size_t NumChildren() const { return children.size(); }
460 
462  const DecisionTree& Child(const size_t i) const { return *children[i]; }
464  DecisionTree& Child(const size_t i) { return *children[i]; }
465 
468  size_t SplitDimension() const { return splitDimension; }
469 
477  template<typename VecType>
478  size_t CalculateDirection(const VecType& point) const;
479 
483  size_t NumClasses() const;
484 
485  private:
487  std::vector<DecisionTree*> children;
489  size_t splitDimension;
492  size_t dimensionTypeOrMajorityClass;
500  arma::vec classProbabilities;
501 
505  typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
506  NumericAuxiliarySplitInfo;
507  typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
508  CategoricalAuxiliarySplitInfo;
509 
513  template<bool UseWeights, typename RowType, typename WeightsRowType>
514  void CalculateClassProbabilities(const RowType& labels,
515  const size_t numClasses,
516  const WeightsRowType& weights);
517 
535  template<bool UseWeights, typename MatType>
536  double Train(MatType& data,
537  const size_t begin,
538  const size_t count,
539  const data::DatasetInfo& datasetInfo,
540  arma::Row<size_t>& labels,
541  const size_t numClasses,
542  arma::rowvec& weights,
543  const size_t minimumLeafSize,
544  const double minimumGainSplit,
545  const size_t maximumDepth,
546  DimensionSelectionType& dimensionSelector);
547 
564  template<bool UseWeights, typename MatType>
565  double Train(MatType& data,
566  const size_t begin,
567  const size_t count,
568  arma::Row<size_t>& labels,
569  const size_t numClasses,
570  arma::rowvec& weights,
571  const size_t minimumLeafSize,
572  const double minimumGainSplit,
573  const size_t maximumDepth,
574  DimensionSelectionType& dimensionSelector);
575 };
576 
580 template<typename FitnessFunction = GiniGain,
581  template<typename> class NumericSplitType = BestBinaryNumericSplit,
582  template<typename> class CategoricalSplitType = AllCategoricalSplit,
583  typename DimensionSelectType = AllDimensionSelect,
584  typename ElemType = double>
585 using DecisionStump = DecisionTree<FitnessFunction,
586  NumericSplitType,
587  CategoricalSplitType,
588  DimensionSelectType,
589  ElemType,
590  false>;
591 
600  double,
602 } // namespace tree
603 } // namespace mlpack
604 
605 // Include implementation.
606 #include "decision_tree_impl.hpp"
607 
608 #endif
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
size_t NumClasses() const
Get the number of classes in the tree.
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a numeric dimension for the best binary split.
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:58
This class implements a generic decision tree learner.
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
The AllCategoricalSplit is a splitting function that will split categorical features into many children: one child for each category.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
The standard information gain criterion, used for calculating gain in decision trees.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
size_t NumChildren() const
Get the number of children.
DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectType, ElemType, false > DecisionStump
Convenience typedef for decision stumps (single level decision trees).
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
DecisionTree< InformationGain, BestBinaryNumericSplit, AllCategoricalSplit, AllDimensionSelect, double, true > ID3DecisionStump
Convenience typedef for ID3 decision stumps (single level decision trees made with the ID3 algorithm).
This dimension selection policy allows any dimension to be selected for splitting.
size_t CalculateDirection(const VecType &point) const
Given a point, and assuming this node is not a leaf, calculate the index of the child node that the point would go towards.
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
~DecisionTree()
Clean up memory.
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).