mlpack  master
all_categorical_split.hpp
Go to the documentation of this file.
1 
8 #ifndef MLPACK_METHODS_DECISION_TREE_ALL_CATEGORICAL_SPLIT_HPP
9 #define MLPACK_METHODS_DECISION_TREE_ALL_CATEGORICAL_SPLIT_HPP
10 
11 #include <mlpack/prereqs.hpp>
12 
13 namespace mlpack {
14 namespace tree {
15 
22 template<typename FitnessFunction>
24 {
25  public:
26  // No extra info needed for split.
27  template<typename ElemType>
28  class AuxiliarySplitInfo { };
29 
50  template<typename VecType>
51  static double SplitIfBetter(
52  const double bestGain,
53  const VecType& data,
54  const size_t numCategories,
55  const arma::Row<size_t>& labels,
56  const size_t numClasses,
57  const size_t minimumLeafSize,
58  arma::Col<typename VecType::elem_type>& classProbabilities,
60 
67  template<typename ElemType>
68  static size_t NumChildren(const arma::Col<ElemType>& classProbabilities,
69  const AuxiliarySplitInfo<ElemType>& /* aux */);
70 
77  template<typename ElemType>
78  static size_t CalculateDirection(
79  const ElemType& point,
80  const arma::Col<ElemType>& classProbabilities,
81  const AuxiliarySplitInfo<ElemType>& /* aux */);
82 };
83 
84 } // namespace tree
85 } // namespace mlpack
86 
87 // Include implementation.
88 #include "all_categorical_split_impl.hpp"
89 
90 #endif
91 
static double SplitIfBetter(const double bestGain, const VecType &data, const size_t numCategories, const arma::Row< size_t > &labels, const size_t numClasses, const size_t minimumLeafSize, arma::Col< typename VecType::elem_type > &classProbabilities, AuxiliarySplitInfo< typename VecType::elem_type > &aux)
Check if we can split a node.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
The core includes that mlpack expects; standard C++ includes and Armadillo.
static size_t NumChildren(const arma::Col< ElemType > &classProbabilities, const AuxiliarySplitInfo< ElemType > &)
Return the number of children in the split.
The AllCategoricalSplit is a splitting function that will split categorical features into many children: one child for each category.
static size_t CalculateDirection(const ElemType &point, const arma::Col< ElemType > &classProbabilities, const AuxiliarySplitInfo< ElemType > &)
Calculate the direction a point should percolate to.