mlpack  master
decision_tree.hpp
Go to the documentation of this file.
1 
8 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
9 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
10 
11 #include <mlpack/prereqs.hpp>
12 #include "gini_gain.hpp"
15 
16 namespace mlpack {
17 namespace tree {
18 
/**
 * This class implements a generic decision tree learner.  It publicly inherits
 * from the AuxiliarySplitInfo<ElemType> types of both the numeric and the
 * categorical split policies, so each node also carries whatever members those
 * auxiliary types declare.
 *
 * @tparam FitnessFunction Fitness function used to score candidate splits
 *     (defaults to GiniGain).
 * @tparam NumericSplitType Split policy for numeric dimensions, instantiated
 *     with FitnessFunction (defaults to BestBinaryNumericSplit).
 * @tparam CategoricalSplitType Split policy for categorical dimensions,
 *     instantiated with FitnessFunction (defaults to AllCategoricalSplit).
 * @tparam ElemType Element type used to instantiate the AuxiliarySplitInfo
 *     base classes.
 * @tparam NoRecursion Compile-time flag whose effect lives in
 *     decision_tree_impl.hpp (not visible here).  NOTE(review): judging by the
 *     name, true presumably stops training after the root split -- confirm.
 *
 * NOTE(review): this listing is a rendered Doxygen view.  The original comment
 * lines are stripped, and several code lines are missing as well (e.g. header
 * lines 110-115 and 234-246).
 */
26 template<typename FitnessFunction = GiniGain,
27  template<typename> class NumericSplitType = BestBinaryNumericSplit,
28  template<typename> class CategoricalSplitType = AllCategoricalSplit,
29  typename ElemType = double,
30  bool NoRecursion = false>
31 class DecisionTree :
32  public NumericSplitType<FitnessFunction>::template
33  AuxiliarySplitInfo<ElemType>,
34  public CategoricalSplitType<FitnessFunction>::template
35  AuxiliarySplitInfo<ElemType>
36 {
37  public:
  //! Allow access to the numeric split type.
39  typedef NumericSplitType<FitnessFunction> NumericSplit;
  //! Allow access to the categorical split type.
41  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
42 
  /**
   * Construct the decision tree on the given data and labels, where the data
   * can be both numeric and categorical (as described by datasetInfo).
   *
   * @param data Dataset to train on.
   * @param datasetInfo Type information for each dimension of the dataset.
   * @param labels Labels for each training point.
   * @param numClasses Number of classes in the dataset.
   * @param minimumLeafSize Leaf-size threshold (default 10).  NOTE(review):
   *     presumably the minimum number of points per leaf -- confirm in impl.
   */
55  template<typename MatType>
56  DecisionTree(const MatType& data,
57  const data::DatasetInfo& datasetInfo,
58  const arma::Row<size_t>& labels,
59  const size_t numClasses,
60  const size_t minimumLeafSize = 10);
61 
  /**
   * Construct the decision tree on the given data and labels, without any
   * DatasetInfo.  NOTE(review): presumably every dimension is then treated as
   * numeric -- confirm in decision_tree_impl.hpp.
   *
   * @param data Dataset to train on.
   * @param labels Labels for each training point.
   * @param numClasses Number of classes in the dataset.
   * @param minimumLeafSize Leaf-size threshold (default 10).
   */
73  template<typename MatType>
74  DecisionTree(const MatType& data,
75  const arma::Row<size_t>& labels,
76  const size_t numClasses,
77  const size_t minimumLeafSize = 10);
78 
  /**
   * Construct an untrained tree.
   *
   * @param numClasses Number of classes (defaults to 1).
   */
85  DecisionTree(const size_t numClasses = 1);
86 
  //! Construct this tree as a copy of another tree.
93  DecisionTree(const DecisionTree& other);
94 
  //! Move-construct this tree from another tree.
100  DecisionTree(DecisionTree&& other);
101 
  //! Copy another tree.
108  DecisionTree& operator=(const DecisionTree& other);
109 
  // NOTE(review): header lines 110-115 are missing from this rendering; their
  // content cannot be recovered from this view.
116 
  //! Clean up memory.  NOTE(review): presumably frees the heap-allocated
  //! nodes held in `children` -- confirm in decision_tree_impl.hpp.
120  ~DecisionTree();
121 
  /**
   * Train the decision tree on the given data, which can be both numeric and
   * categorical (as described by datasetInfo).
   *
   * @param data Dataset to train on.
   * @param datasetInfo Type information for each dimension of the dataset.
   * @param labels Labels for each training point.
   * @param numClasses Number of classes in the dataset.
   * @param minimumLeafSize Leaf-size threshold (default 10).
   */
134  template<typename MatType>
135  void Train(const MatType& data,
136  const data::DatasetInfo& datasetInfo,
137  const arma::Row<size_t>& labels,
138  const size_t numClasses,
139  const size_t minimumLeafSize = 10);
140 
  /**
   * Train the decision tree on the given data, without any DatasetInfo.
   *
   * @param data Dataset to train on.
   * @param labels Labels for each training point.
   * @param numClasses Number of classes in the dataset.
   * @param minimumLeafSize Leaf-size threshold (default 10).
   */
152  template<typename MatType>
153  void Train(const MatType& data,
154  const arma::Row<size_t>& labels,
155  const size_t numClasses,
156  const size_t minimumLeafSize = 10);
157 
  //! Classify the given point, using the entire tree; returns the class.
164  template<typename VecType>
165  size_t Classify(const VecType& point) const;
166 
  /**
   * Classify the given point, writing the predicted class into `prediction`
   * and filling `probabilities`.  NOTE(review): presumably per-class
   * probabilities -- confirm in impl.
   */
176  template<typename VecType>
177  void Classify(const VecType& point,
178  size_t& prediction,
179  arma::vec& probabilities) const;
180 
  //! Classify every point in `data`, writing the results into `predictions`.
188  template<typename MatType>
189  void Classify(const MatType& data,
190  arma::Row<size_t>& predictions) const;
191 
  /**
   * Classify every point in `data`, writing the predicted classes into
   * `predictions` and filling the `probabilities` matrix.
   */
202  template<typename MatType>
203  void Classify(const MatType& data,
204  arma::Row<size_t>& predictions,
205  arma::mat& probabilities) const;
206 
  //! Serialize the tree.
210  template<typename Archive>
211  void Serialize(Archive& ar, const unsigned int /* version */);
212 
  //! Get the number of children.
214  size_t NumChildren() const { return children.size(); }
215 
  //! Get the child of the given index.  Uses std::vector::operator[], so the
  //! index is not bounds-checked.
217  const DecisionTree& Child(const size_t i) const { return *children[i]; }
  //! Modify the child of the given index (be careful!).  Not bounds-checked.
219  DecisionTree& Child(const size_t i) { return *children[i]; }
220 
  /**
   * Given a point and that this node is not a leaf, calculate the index of the
   * child node this point would go to.
   */
228  template<typename VecType>
229  size_t CalculateDirection(const VecType& point) const;
230 
231  private:
  //! The vector of children (stored as raw pointers).
233  std::vector<DecisionTree*> children;
  // NOTE(review): header lines 234-246 are missing from this rendering.  The
  // tooltip text in this file names the data members splitDimension (size_t),
  // dimensionTypeOrMajorityClass (size_t) and classProbabilities (arma::vec).
247 
  // NOTE(review): the next two typedefs are truncated in this rendering.  Per
  // the tooltip text, the aliases are named NumericAuxiliarySplitInfo and
  // CategoricalAuxiliarySplitInfo.
251  typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
253  typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
255 
  //! Calculate the class probabilities of the given labels.
259  template<typename RowType>
260  void CalculateClassProbabilities(const RowType& labels,
261  const size_t numClasses);
262 };
263 
267 template<typename FitnessFunction = GiniGain,
268  template<typename> class NumericSplitType = BestBinaryNumericSplit,
269  template<typename> class CategoricalSplitType = AllCategoricalSplit,
270  typename ElemType = double>
271 using DecisionStump = DecisionTree<FitnessFunction,
272  NumericSplitType,
273  CategoricalSplitType,
274  ElemType,
275  false>;
276 
277 } // namespace tree
278 } // namespace mlpack
279 
280 // Include implementation.
281 #include "decision_tree_impl.hpp"
282 
283 #endif
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point would go to.
Auxiliary information for a dataset, including mappings to/from strings and the datatype of each dimension.
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a...
void CalculateClassProbabilities(const RowType &labels, const size_t numClasses)
Calculate the class probabilities of the given labels.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
This class implements a generic decision tree learner.
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
The AllCategoricalSplit is a splitting function that will split categorical features into many children.
size_t NumChildren() const
Get the number of children.
arma::vec classProbabilities
This vector may hold different things.
std::vector< DecisionTree * > children
The vector of children.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
size_t splitDimension
The dimension this node splits on.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
Definition: gini_gain.hpp:27
NumericSplit::template AuxiliarySplitInfo< ElemType > NumericAuxiliarySplitInfo
Note that this class will also hold the members of the NumericSplit and CategoricalSplit AuxiliarySplitInfo classes, since it inherits from them.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
~DecisionTree()
Clean up memory.
CategoricalSplit::template AuxiliarySplitInfo< ElemType > CategoricalAuxiliarySplitInfo
void Serialize(Archive &ar, const unsigned int)
Serialize the tree.
size_t dimensionTypeOrMajorityClass
The type of the dimension that we have split on (if we are not a leaf).
DecisionTree(const MatType &data, const data::DatasetInfo &datasetInfo, const arma::Row< size_t > &labels, const size_t numClasses, const size_t minimumLeafSize=10)
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
void Train(const MatType &data, const data::DatasetInfo &datasetInfo, const arma::Row< size_t > &labels, const size_t numClasses, const size_t minimumLeafSize=10)
Train the decision tree on the given data.