mlpack  master
mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion > Class Template Reference

This class implements a generic decision tree learner.

Public Types

typedef CategoricalSplitType< FitnessFunction > CategoricalSplit
 Allow access to the categorical split type.
 
typedef NumericSplitType< FitnessFunction > NumericSplit
 Allow access to the numeric split type.
 

Public Member Functions

template<typename MatType >
 DecisionTree (const MatType &data, const data::DatasetInfo &datasetInfo, const arma::Row< size_t > &labels, const size_t numClasses, const size_t minimumLeafSize=10)
 Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
 
template<typename MatType >
 DecisionTree (const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t minimumLeafSize=10)
 Construct the decision tree on the given data and labels, assuming that the data is all of the numeric type.
 
 DecisionTree (const size_t numClasses=1)
 Construct a decision tree without training it.
 
 DecisionTree (const DecisionTree &other)
 Copy another tree.
 
 DecisionTree (DecisionTree &&other)
 Take ownership of another tree.
 
 ~DecisionTree ()
 Clean up memory.
 
template<typename VecType >
size_t CalculateDirection (const VecType &point) const
 Given a point and that this node is not a leaf, calculate the index of the child node this point would go towards.
 
const DecisionTree & Child (const size_t i) const
 Get the child of the given index.
 
DecisionTree & Child (const size_t i)
 Modify the child of the given index (be careful!).
 
template<typename VecType >
size_t Classify (const VecType &point) const
 Classify the given point, using the entire tree.
 
template<typename VecType >
void Classify (const VecType &point, size_t &prediction, arma::vec &probabilities) const
 Classify the given point and also return estimates of the probability for each class in the given vector.
 
template<typename MatType >
void Classify (const MatType &data, arma::Row< size_t > &predictions) const
 Classify the given points, using the entire tree.
 
template<typename MatType >
void Classify (const MatType &data, arma::Row< size_t > &predictions, arma::mat &probabilities) const
 Classify the given points and also return estimates of the probabilities for each class in the given matrix.
 
size_t NumChildren () const
 Get the number of children.
 
DecisionTree & operator= (const DecisionTree &other)
 Copy another tree.
 
DecisionTree & operator= (DecisionTree &&other)
 Take ownership of another tree.
 
template<typename Archive >
void Serialize (Archive &ar, const unsigned int)
 Serialize the tree.
 
template<typename MatType >
void Train (const MatType &data, const data::DatasetInfo &datasetInfo, const arma::Row< size_t > &labels, const size_t numClasses, const size_t minimumLeafSize=10)
 Train the decision tree on the given data.
 
template<typename MatType >
void Train (const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t minimumLeafSize=10)
 Train the decision tree on the given data, assuming that all dimensions are numeric.
 

Private Types

typedef CategoricalSplit::template AuxiliarySplitInfo< ElemType > CategoricalAuxiliarySplitInfo
 
typedef NumericSplit::template AuxiliarySplitInfo< ElemType > NumericAuxiliarySplitInfo
 Note that this class will also hold the members of the NumericSplit and CategoricalSplit AuxiliarySplitInfo classes, since it inherits from them.
 

Private Member Functions

template<typename RowType >
void CalculateClassProbabilities (const RowType &labels, const size_t numClasses)
 Calculate the class probabilities of the given labels.
 

Private Attributes

std::vector< DecisionTree * > children
 The vector of children.
 
arma::vec classProbabilities
 This vector may hold different things.
 
size_t dimensionTypeOrMajorityClass
 The type of the dimension that we have split on (if we are not a leaf).
 
size_t splitDimension
 The dimension this node splits on.
 

Detailed Description

template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
class mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >

This class implements a generic decision tree learner.

Its behavior can be controlled via its template arguments.
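
As a rough illustration of how the template arguments select behavior, the sketch below mimics the parameter layout with toy policy classes. Every name in it (ToyGiniGain, ToyInformationGain, ToyNumericSplit, ToyCategoricalSplit, ToyDecisionTree) is a hypothetical stand-in and not an mlpack class; only the shape of the template parameter list and of the two public typedefs follows this page.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical fitness policies (stand-ins, not mlpack classes).
struct ToyGiniGain { };
struct ToyInformationGain { };

// A split policy is itself a template, parameterized on the fitness function.
template<typename FitnessFunction>
struct ToyNumericSplit { typedef FitnessFunction Fitness; };

template<typename FitnessFunction>
struct ToyCategoricalSplit { typedef FitnessFunction Fitness; };

// Toy tree with the same parameter layout as the class documented here
// (the NoRecursion flag is left out to keep the sketch short).
template<typename FitnessFunction = ToyGiniGain,
         template<typename> class NumericSplitType = ToyNumericSplit,
         template<typename> class CategoricalSplitType = ToyCategoricalSplit,
         typename ElemType = double>
class ToyDecisionTree
{
 public:
  // Each split policy is instantiated with the tree's fitness function.
  typedef NumericSplitType<FitnessFunction> NumericSplit;
  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
  typedef ElemType Elem;
};

// Swapping the first argument changes the fitness function used by both
// split policies without touching the tree itself.
typedef ToyDecisionTree<ToyInformationGain> ToyInfoGainTree;
```

With this layout, a user changes the splitting criterion by naming a different first template argument, and both split typedefs pick it up automatically.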

The class inherits from the auxiliary split information in order to prevent an empty auxiliary split information struct from taking any extra size.
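
The size saving described above comes from the empty base optimization. The following is a minimal sketch of the effect, using hypothetical types (EmptyAuxiliaryInfo, NodeWithMember, NodeWithBase) that are not part of mlpack; the exact sizes depend on the platform, but mainstream compilers behave as the static assertions state.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for a split type that needs no auxiliary state.
struct EmptyAuxiliaryInfo { };

// Storing the empty struct as a data member costs at least one byte (plus
// padding), because every member object needs its own address.
struct NodeWithMember
{
  std::size_t splitDimension;
  EmptyAuxiliaryInfo aux;
};

// Inheriting from the empty struct lets the compiler apply the empty base
// optimization, so the base contributes no storage.
struct NodeWithBase : public EmptyAuxiliaryInfo
{
  std::size_t splitDimension;
};

static_assert(sizeof(NodeWithMember) > sizeof(std::size_t),
    "an empty member still occupies storage");
static_assert(sizeof(NodeWithBase) == sizeof(std::size_t),
    "an empty base adds no storage");
```

Since a tree may contain many nodes, saving a padded word per node can add up.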

Definition at line 31 of file decision_tree.hpp.

Member Typedef Documentation

template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
typedef CategoricalSplit::template AuxiliarySplitInfo<ElemType> mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::CategoricalAuxiliarySplitInfo
private

Definition at line 254 of file decision_tree.hpp.

template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
typedef CategoricalSplitType<FitnessFunction> mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::CategoricalSplit

Allow access to the categorical split type.

Definition at line 41 of file decision_tree.hpp.

template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
typedef NumericSplit::template AuxiliarySplitInfo<ElemType> mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::NumericAuxiliarySplitInfo
private

Note that this class will also hold the members of the NumericSplit and CategoricalSplit AuxiliarySplitInfo classes, since it inherits from them.

We'll define some convenience typedefs here.

Definition at line 252 of file decision_tree.hpp.

template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
typedef NumericSplitType<FitnessFunction> mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::NumericSplit

Allow access to the numeric split type.

Definition at line 39 of file decision_tree.hpp.

Constructor & Destructor Documentation

template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
template<typename MatType >
mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::DecisionTree ( const MatType &  data,
const data::DatasetInfo &  datasetInfo,
const arma::Row< size_t > &  labels,
const size_t  numClasses,
const size_t  minimumLeafSize = 10 
)

Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.

Setting minimumLeafSize too small may cause the tree to overfit, but setting it too large may cause it to underfit.

Parameters
data             Dataset to train on.
datasetInfo      Type information for each dimension of the dataset.
labels           Labels for each training point.
numClasses       Number of classes in the dataset.
minimumLeafSize  Minimum number of points in each leaf node.
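
To make the role of minimumLeafSize concrete, here is a small sketch of how a minimum leaf size can restrict the candidate cut positions along one sorted dimension. AllowedSplitPositions is a hypothetical helper written for this page; it is an assumption about how such a constraint is typically enforced, not a copy of mlpack's split code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of mlpack). For n points sorted along one
// dimension, return every position i such that cutting into [0, i) and
// [i, n) leaves at least minimumLeafSize points on both sides.
std::vector<std::size_t> AllowedSplitPositions(
    const std::size_t n,
    const std::size_t minimumLeafSize)
{
  assert(minimumLeafSize > 0);

  std::vector<std::size_t> positions;
  // Too few points to form two valid leaves: the node must stay a leaf.
  if (n < 2 * minimumLeafSize)
    return positions;

  for (std::size_t i = minimumLeafSize; i + minimumLeafSize <= n; ++i)
    positions.push_back(i);

  return positions;
}
```

Under this assumption, a node holding 25 points with minimumLeafSize = 10 has only six admissible cut positions, and a node holding 19 points cannot be split at all. A small value leaves many positions open and allows very deep trees, while a large value stops splitting early.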
template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
template<typename MatType >
mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::DecisionTree ( const MatType &  data,
const arma::Row< size_t > &  labels,
const size_t  numClasses,
const size_t  minimumLeafSize = 10 
)

Construct the decision tree on the given data and labels, assuming that the data is all of the numeric type.

Setting minimumLeafSize too small may cause the tree to overfit, but setting it too large may cause it to underfit.

Parameters
data             Dataset to train on.
labels           Labels for each training point.
numClasses       Number of classes in the dataset.
minimumLeafSize  Minimum number of points in each leaf node.
template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::DecisionTree ( const size_t  numClasses = 1)

Construct a decision tree without training it.

It will be a leaf node with equal probabilities for each class.

Parameters
numClasses  Number of classes in the dataset.
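
A minimal sketch of what "equal probabilities for each class" means for an untrained leaf follows. UniformClassProbabilities is a hypothetical helper that uses std::vector in place of arma::vec so that it compiles without Armadillo.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of mlpack): the class distribution of an
// untrained leaf, where every class is equally likely.
std::vector<double> UniformClassProbabilities(const std::size_t numClasses)
{
  assert(numClasses > 0);
  return std::vector<double>(numClasses,
      1.0 / static_cast<double>(numClasses));
}
```

With the default numClasses = 1, the single class receives probability 1.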
template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::DecisionTree ( const DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion > &  other)

Copy another tree.

This may use a lot of memory—be sure that it's what you want to do.

Parameters
other  Tree to copy.
template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::DecisionTree ( DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion > &&  other)

Take ownership of another tree.

Parameters
other  Tree to take ownership of.
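
The difference in cost between copying a tree and taking ownership of one can be sketched with a toy node that owns its children through raw pointers, in the same way as the std::vector< DecisionTree * > member listed above. ToyTree and all of its members are hypothetical; this is an illustration of the ownership pattern, not mlpack's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical toy node (not mlpack code) that owns its children.
class ToyTree
{
 public:
  explicit ToyTree(const std::size_t label = 0) : label(label) { }

  // Copy: every descendant is allocated a second time.
  ToyTree(const ToyTree& other) : label(other.label)
  {
    for (const ToyTree* child : other.children)
      children.push_back(new ToyTree(*child));
  }

  // Move: the child pointers are taken over; nothing is duplicated.
  ToyTree(ToyTree&& other) :
      children(std::move(other.children)),
      label(other.label)
  {
    other.children.clear(); // Leave the source as a valid, childless node.
  }

  ~ToyTree()
  {
    for (ToyTree* child : children)
      delete child;
  }

  // One operator covers copy and move assignment: the by-value parameter is
  // built with the matching constructor and then swapped into place.
  ToyTree& operator=(ToyTree other)
  {
    children.swap(other.children);
    std::swap(label, other.label);
    return *this;
  }

  void AddChild(const std::size_t childLabel)
  {
    children.push_back(new ToyTree(childLabel));
  }

  std::size_t NumChildren() const { return children.size(); }
  const ToyTree& Child(const std::size_t i) const { return *children[i]; }
  std::size_t Label() const { return label; }

 private:
  std::vector<ToyTree*> children;
  std::size_t label;
};
```

After a copy, the two trees hold distinct child objects; after a move, the source no longer has any children.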
template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::~DecisionTree ( )

Clean up memory.

Member Function Documentation

template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
template<typename RowType >
void mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::CalculateClassProbabilities ( const RowType &  labels,
const size_t  numClasses 
)
private

Calculate the class probabilities of the given labels.
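
One plausible way to compute class probabilities from labels is to count the occurrences of each label and divide by the number of points. The sketch below does that with std::vector; ToyClassProbabilities is a hypothetical free function and may differ from the private member documented here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not mlpack code): fraction of points in each class.
std::vector<double> ToyClassProbabilities(
    const std::vector<std::size_t>& labels,
    const std::size_t numClasses)
{
  std::vector<double> probabilities(numClasses, 0.0);

  // Count how often each label occurs.
  for (const std::size_t label : labels)
  {
    assert(label < numClasses);
    probabilities[label] += 1.0;
  }

  // Normalize the counts so that they sum to one.
  if (!labels.empty())
  {
    for (double& p : probabilities)
      p /= static_cast<double>(labels.size());
  }

  return probabilities;
}
```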

template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
template<typename VecType >
size_t mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::CalculateDirection ( const VecType &  point) const

Given a point and that this node is not a leaf, calculate the index of the child node this point would go towards.

This method is primarily used by the Classify() function, but it can be used in a standalone sense too.

Parameters
point  Point to classify.

Referenced by mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::Child().

template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
const DecisionTree& mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::Child ( const size_t  i) const
inline

Get the child of the given index.

template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
DecisionTree& mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::Child ( const size_t  i)
inline

Modify the child of the given index (be careful!).

template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
template<typename VecType >
size_t mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::Classify ( const VecType &  point) const

Classify the given point, using the entire tree.

The predicted label is returned.

Parameters
point  Point to classify.
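
The interplay between CalculateDirection() and Classify() can be sketched as a walk from the root to a leaf. The toy node below is hypothetical: it assumes a binary numeric split with a stored threshold and uses std::vector for points, whereas the real class delegates the direction to its split types.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical toy node (not mlpack code).
struct ToyNode
{
  std::vector<std::unique_ptr<ToyNode>> children; // Empty for a leaf.
  std::size_t splitDimension = 0; // Used only by internal nodes.
  double threshold = 0.0;         // Assumed binary numeric split point.
  std::size_t majorityClass = 0;  // Used only by leaves.

  // Index of the child that the point belongs to (internal nodes only).
  std::size_t CalculateDirection(const std::vector<double>& point) const
  {
    assert(!children.empty());
    return (point[splitDimension] <= threshold) ? 0 : 1;
  }

  // Descend until a leaf is reached, then return its majority class.
  std::size_t Classify(const std::vector<double>& point) const
  {
    const ToyNode* node = this;
    while (!node->children.empty())
      node = node->children[node->CalculateDirection(point)].get();
    return node->majorityClass;
  }
};

// Build a one-level tree: a root with a single split and two leaves.
ToyNode MakeStump(const std::size_t dimension,
                  const double threshold,
                  const std::size_t leftClass,
                  const std::size_t rightClass)
{
  ToyNode root;
  root.splitDimension = dimension;
  root.threshold = threshold;
  root.children.emplace_back(new ToyNode());
  root.children.emplace_back(new ToyNode());
  root.children[0]->majorityClass = leftClass;
  root.children[1]->majorityClass = rightClass;
  return root;
}
```

For a stump built with MakeStump(1, 0.5, 3, 7), a point whose second coordinate is 0.2 is labeled 3 and a point whose second coordinate is 0.9 is labeled 7.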
template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
template<typename VecType >
void mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::Classify ( const VecType &  point,
size_t &  prediction,
arma::vec &  probabilities 
) const

Classify the given point and also return estimates of the probability for each class in the given vector.

Parameters
point          Point to classify.
prediction     This will be set to the predicted class of the point.
probabilities  This will be filled with class probabilities for the point.
template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
template<typename MatType >
void mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::Classify ( const MatType &  data,
arma::Row< size_t > &  predictions 
) const

Classify the given points, using the entire tree.

The predicted labels for each point are stored in the given vector.

Parameters
data         Set of points to classify.
predictions  This will be filled with predictions for each point.
template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
template<typename MatType >
void mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::Classify ( const MatType &  data,
arma::Row< size_t > &  predictions,
arma::mat &  probabilities 
) const

Classify the given points and also return estimates of the probabilities for each class in the given matrix.

The predicted labels for each point are stored in the given vector.

Parameters
data           Set of points to classify.
predictions    This will be filled with predictions for each point.
probabilities  This will be filled with class probabilities for each point.
template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
size_t mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::NumChildren ( ) const
inline

Get the number of children.

template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
DecisionTree& mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::operator= ( const DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion > &  other)

Copy another tree.

This may use a lot of memory—be sure that it's what you want to do.

Parameters
other  Tree to copy.
template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
DecisionTree& mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::operator= ( DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion > &&  other)

Take ownership of another tree.

Parameters
other  Tree to take ownership of.
template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
template<typename Archive >
void mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::Serialize ( Archive &  ar,
const unsigned int 
)

Serialize the tree.

template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
template<typename MatType >
void mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::Train ( const MatType &  data,
const data::DatasetInfo &  datasetInfo,
const arma::Row< size_t > &  labels,
const size_t  numClasses,
const size_t  minimumLeafSize = 10 
)

Train the decision tree on the given data.

This will overwrite the existing model. The data may have numeric and categorical types, specified by the datasetInfo parameter. Setting minimumLeafSize too small may cause the tree to overfit, but setting it too large may cause it to underfit.

Parameters
data             Dataset to train on.
datasetInfo      Type information for each dimension.
labels           Labels for each training point.
numClasses       Number of classes in the dataset.
minimumLeafSize  Minimum number of points in each leaf node.
template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
template<typename MatType >
void mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::Train ( const MatType &  data,
const arma::Row< size_t > &  labels,
const size_t  numClasses,
const size_t  minimumLeafSize = 10 
)

Train the decision tree on the given data, assuming that all dimensions are numeric.

This will overwrite the existing model. Setting minimumLeafSize too small may cause the tree to overfit, but setting it too large may cause it to underfit.

Parameters
data             Dataset to train on.
labels           Labels for each training point.
numClasses       Number of classes in the dataset.
minimumLeafSize  Minimum number of points in each leaf node.

Member Data Documentation

template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
std::vector<DecisionTree*> mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::children
private

The vector of children.

template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
arma::vec mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::classProbabilities
private

This vector may hold different things.

If the node has no children, then it is guaranteed to hold the probabilities of each class. If the node has children, then it may be used arbitrarily by the split type's CalculateDirection() function and may not necessarily hold class probabilities.

Definition at line 246 of file decision_tree.hpp.

template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
size_t mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::dimensionTypeOrMajorityClass
private

The type of the dimension that we have split on (if we are not a leaf).

If we are a leaf, then this is the index of the majority class.
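
For a leaf, the majority class is simply the class with the largest probability. A minimal sketch, using a hypothetical helper ToyMajorityClass over std::vector rather than arma::vec:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical helper (not mlpack code): index of the most probable class.
// On ties, the lowest class index wins, because std::max_element returns
// the first maximum it finds.
std::size_t ToyMajorityClass(const std::vector<double>& classProbabilities)
{
  assert(!classProbabilities.empty());
  return static_cast<std::size_t>(std::distance(
      classProbabilities.begin(),
      std::max_element(classProbabilities.begin(),
                       classProbabilities.end())));
}
```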

Definition at line 238 of file decision_tree.hpp.

template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename ElemType = double, bool NoRecursion = false>
size_t mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, ElemType, NoRecursion >::splitDimension
private

The dimension this node splits on.

Definition at line 235 of file decision_tree.hpp.


The documentation for this class was generated from the following file: