mlpack  master
decision_stump.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
13 #define MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace decision_stump {
19 
33 template<typename MatType = arma::mat>
35 {
36  public:
46  DecisionStump(const MatType& data,
47  const arma::Row<size_t>& labels,
48  const size_t classes,
49  const size_t bucketSize = 10);
50 
62  DecisionStump(const DecisionStump<>& other,
63  const MatType& data,
64  const arma::Row<size_t>& labels,
65  const arma::rowvec& weights);
66 
72  DecisionStump();
73 
84  void Train(const MatType& data,
85  const arma::Row<size_t>& labels,
86  const size_t classes,
87  const size_t bucketSize);
88 
100  void Train(const MatType& data,
101  const arma::Row<size_t>& labels,
102  const arma::rowvec& weights,
103  const size_t classes,
104  const size_t bucketSize);
105 
114  void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
115 
117  size_t SplitDimension() const { return splitDimension; }
119  size_t& SplitDimension() { return splitDimension; }
120 
122  const arma::vec& Split() const { return split; }
124  arma::vec& Split() { return split; }
125 
127  const arma::Col<size_t> BinLabels() const { return binLabels; }
129  arma::Col<size_t>& BinLabels() { return binLabels; }
130 
132  template<typename Archive>
133  void Serialize(Archive& ar, const unsigned int /* version */);
134 
135  private:
137  size_t classes;
139  size_t bucketSize;
140 
144  arma::vec split;
146  arma::Col<size_t> binLabels;
147 
156  template<bool UseWeights, typename VecType>
157  double SetupSplitDimension(const VecType& dimension,
158  const arma::Row<size_t>& labels,
159  const arma::rowvec& weightD);
160 
168  template<typename VecType>
169  void TrainOnDim(const VecType& dimension,
170  const arma::Row<size_t>& labels);
171 
176  void MergeRanges();
177 
184  template<typename VecType>
185  double CountMostFreq(const VecType& subCols);
186 
192  template<typename VecType>
193  int IsDistinct(const VecType& featureRow);
194 
204  template<bool UseWeights, typename VecType, typename WeightVecType>
205  double CalculateEntropy(const VecType& labels,
206  const WeightVecType& weights);
207 
217  template<bool UseWeights>
218  void Train(const MatType& data,
219  const arma::Row<size_t>& labels,
220  const arma::rowvec& weights);
221 };
222 
223 } // namespace decision_stump
224 } // namespace mlpack
225 
226 #include "decision_stump_impl.hpp"
227 
228 #endif
void MergeRanges()
After the "split" vector has been set up, merge ranges with identical class labels.
void Classify(const MatType &test, arma::Row< size_t > &predictedLabels)
Classification function.
size_t splitDimension
Stores the value of the dimension on which to split.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
arma::Col< size_t > & BinLabels()
Modify the labels for each split bin (be careful!).
size_t SplitDimension() const
Access the splitting dimension.
The core includes that mlpack expects; standard C++ includes and Armadillo.
int IsDistinct(const VecType &featureRow)
Returns 1 if the values in featureRow are not all the same.
This class implements a decision stump.
double SetupSplitDimension(const VecType &dimension, const arma::Row< size_t > &labels, const arma::rowvec &weightD)
Sets up the given dimension as if splitting on it, and finds the entropy of splitting on that dimension...
DecisionStump()
Create a decision stump without training.
const arma::vec & Split() const
Access the splitting values.
void Train(const MatType &data, const arma::Row< size_t > &labels, const size_t classes, const size_t bucketSize)
Train the decision stump on the given data.
arma::Col< size_t > binLabels
Stores the labels for each splitting bin.
void Serialize(Archive &ar, const unsigned int)
Serialize the decision stump.
size_t bucketSize
The minimum number of points in a bucket.
double CalculateEntropy(const VecType &labels, const WeightVecType &weights)
Calculate the entropy of the given dimension.
arma::vec & Split()
Modify the splitting values (be careful!).
void TrainOnDim(const VecType &dimension, const arma::Row< size_t > &labels)
After having decided the dimension on which to split, train on that dimension.
size_t & SplitDimension()
Modify the splitting dimension (be careful!).
arma::vec split
Stores the splitting values after training.
const arma::Col< size_t > BinLabels() const
Access the labels for each split bin.
size_t classes
The number of classes (we must store this for boosting).
double CountMostFreq(const VecType &subCols)
Count the most frequently occurring element in subCols.