mlpack  master
information_gain.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP
14 #define MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace tree {
20 
26 {
27  public:
34  static double Evaluate(const arma::Row<size_t>& labels,
35  const size_t numClasses)
36  {
37  // Edge case: if there are no elements, the gain is zero.
38  if (labels.n_elem == 0)
39  return 0.0;
40 
41  // Count the number of elements in each class.
42  arma::Col<size_t> counts(numClasses);
43  counts.zeros();
44  for (size_t i = 0; i < labels.n_elem; ++i)
45  counts[labels[i]]++;
46 
47  // Calculate the information gain.
48  double gain = 0.0;
49  for (size_t i = 0; i < numClasses; ++i)
50  {
51  const double f = ((double) counts[i] / (double) labels.n_elem);
52  if (f > 0.0)
53  gain += f * std::log2(f);
54  }
55 
56  return gain;
57  }
58 
/**
 * Return the range of the information gain for the given number of classes.
 *
 * The best possible case gives an information gain of 0.  The worst possible
 * case is an even distribution over all n classes, which gives
 * n * (1/n * log2(1/n)) = log2(1/n) = -log2(n).  So, the range is log2(n).
 *
 * @param numClasses Number of classes.
 * @return log2(numClasses) for two or more classes, and 0.0 otherwise.
 */
static double Range(const size_t numClasses)
{
  // With zero or one class the gain can only ever be 0, so the range is
  // empty.  Handling this explicitly also avoids log2(0), which would yield
  // negative infinity for numClasses == 0.
  if (numClasses <= 1)
    return 0.0;

  return std::log2(static_cast<double>(numClasses));
}
73 };
74 
75 } // namespace tree
76 } // namespace mlpack
77 
78 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
static double Range(const size_t numClasses)
Return the range of the information gain for the given number of classes.
The core includes that mlpack expects; standard C++ includes and Armadillo.
static double Evaluate(const arma::Row< size_t > &labels, const size_t numClasses)
Given a set of labels, calculate the information gain of those labels.
The standard information gain criterion, used for calculating gain in decision trees.