8 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP 9 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP 26 template<
typename FitnessFunction = GiniGain,
27 template<
typename>
class NumericSplitType = BestBinaryNumericSplit,
28 template<
typename>
class CategoricalSplitType = AllCategoricalSplit,
29 typename ElemType = double,
30 bool NoRecursion =
false>
32 public NumericSplitType<FitnessFunction>::template
33 AuxiliarySplitInfo<ElemType>,
34 public CategoricalSplitType<FitnessFunction>::template
35 AuxiliarySplitInfo<ElemType>
55 template<
typename MatType>
58 const arma::Row<size_t>& labels,
59 const size_t numClasses,
60 const size_t minimumLeafSize = 10);
73 template<
typename MatType>
75 const arma::Row<size_t>& labels,
76 const size_t numClasses,
77 const size_t minimumLeafSize = 10);
134 template<
typename MatType>
135 void Train(
const MatType& data,
137 const arma::Row<size_t>& labels,
138 const size_t numClasses,
139 const size_t minimumLeafSize = 10);
/// Train the decision tree on the given data and labels.
///
/// This overload takes no data::DatasetInfo argument (unlike the overload
/// listed in the member summary further down); presumably every dimension is
/// then treated as numeric -- TODO confirm against the implementation.
///
/// @param data Training data (MatType is a template parameter).
/// @param labels One label per training point, as arma::Row<size_t>.
/// @param numClasses Number of classes in the dataset.
/// @param minimumLeafSize Defaults to 10; presumably the minimum number of
///     points allowed in a leaf -- confirm.
152 template<
typename MatType>
153 void Train(
const MatType& data,
154 const arma::Row<size_t>& labels,
155 const size_t numClasses,
156 const size_t minimumLeafSize = 10);
/// Classify the given point, using the entire tree.
///
/// Const member: it does not modify the tree.
///
/// @param point The point to classify (VecType is a template parameter).
/// @return A size_t; presumably the predicted class label -- confirm.
164 template<
typename VecType>
165 size_t Classify(
const VecType& point)
const;
176 template<
typename VecType>
179 arma::vec& probabilities)
const;
188 template<
typename MatType>
190 arma::Row<size_t>& predictions)
const;
202 template<
typename MatType>
204 arma::Row<size_t>& predictions,
205 arma::mat& probabilities)
const;
/// Serialize the tree.
///
/// @param ar Archive to serialize to or from (Archive is a template
///     parameter).
/// The second parameter is unnamed, so it cannot be referenced in the
/// definition; presumably it is the archive format version -- confirm.
210 template<
typename Archive>
211 void Serialize(Archive& ar,
const unsigned int );
228 template<
typename VecType>
251 typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
253 typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
259 template<
typename RowType>
261 const size_t numClasses);
267 template<
typename FitnessFunction =
GiniGain,
270 typename ElemType =
double>
273 CategoricalSplitType,
281 #include "decision_tree_impl.hpp" size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point would go towards.
Auxiliary information for a dataset, including mappings to/from strings and the datatype of each dimension.
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a numeric dimension for the best binary split.
void CalculateClassProbabilities(const RowType &labels, const size_t numClasses)
Calculate the class probabilities of the given labels.
Linear algebra utility functions, generally performed on matrices or vectors.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
This class implements a generic decision tree learner.
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
The AllCategoricalSplit is a splitting function that will split categorical features into many children: one child for each category.
size_t NumChildren() const
Get the number of children.
arma::vec classProbabilities
This vector may hold different things.
std::vector< DecisionTree * > children
The vector of children.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
size_t splitDimension
The dimension this node splits on.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
NumericSplit::template AuxiliarySplitInfo< ElemType > NumericAuxiliarySplitInfo
Note that this class will also hold the members of the NumericSplit and CategoricalSplit AuxiliarySplitInfo classes, since it inherits from them.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
~DecisionTree()
Clean up memory.
CategoricalSplit::template AuxiliarySplitInfo< ElemType > CategoricalAuxiliarySplitInfo
void Serialize(Archive &ar, const unsigned int)
Serialize the tree.
size_t dimensionTypeOrMajorityClass
The type of the dimension that we have split on (if we are not a leaf).
DecisionTree(const MatType &data, const data::DatasetInfo &datasetInfo, const arma::Row< size_t > &labels, const size_t numClasses, const size_t minimumLeafSize=10)
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
void Train(const MatType &data, const data::DatasetInfo &datasetInfo, const arma::Row< size_t > &labels, const size_t numClasses, const size_t minimumLeafSize=10)
Train the decision tree on the given data.