mlpack  master
gini_impurity.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_HOEFFDING_TREES_GINI_INDEX_HPP
14 #define MLPACK_METHODS_HOEFFDING_TREES_GINI_INDEX_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace tree {
20 
22 {
23  public:
24  static double Evaluate(const arma::Mat<size_t>& counts)
25  {
26  // We need to sum over the difference between the un-split node and the
27  // split nodes. First we'll calculate the number of elements in each split
28  // and total.
29  size_t numElem = 0;
30  arma::vec splitCounts(counts.n_cols);
31  for (size_t i = 0; i < counts.n_cols; ++i)
32  {
33  splitCounts[i] = arma::accu(counts.col(i));
34  numElem += splitCounts[i];
35  }
36 
37  // Corner case: if there are no elements, the impurity is zero.
38  if (numElem == 0)
39  return 0.0;
40 
41  arma::Col<size_t> classCounts = arma::sum(counts, 1);
42 
43  // Calculate the Gini impurity of the un-split node.
44  double impurity = 0.0;
45  for (size_t i = 0; i < classCounts.n_elem; ++i)
46  {
47  const double f = ((double) classCounts[i] / (double) numElem);
48  impurity += f * (1.0 - f);
49  }
50 
51  // Now calculate the impurity of the split nodes and subtract them from the
52  // overall impurity.
53  for (size_t i = 0; i < counts.n_cols; ++i)
54  {
55  if (splitCounts[i] > 0)
56  {
57  double splitImpurity = 0.0;
58  for (size_t j = 0; j < counts.n_rows; ++j)
59  {
60  const double f = ((double) counts(j, i) / (double) splitCounts[i]);
61  splitImpurity += f * (1.0 - f);
62  }
63 
64  impurity -= ((double) splitCounts[i] / (double) numElem) *
65  splitImpurity;
66  }
67  }
68 
69  return impurity;
70  }
71 
/**
 * Return the range of the Gini impurity for the given number of classes.
 *
 * The best possible case is that only one class exists, which gives a Gini
 * impurity of 0.  The worst possible case is that the classes are evenly
 * distributed, which gives n * (1/n * (1 - 1/n)) = 1 - 1/n.
 *
 * @param numClasses Number of classes in the dataset.
 * @return 1 - 1/numClasses, or 0.0 when numClasses is 0.
 */
static double Range(const size_t numClasses)
{
  // Without this guard, zero classes would divide by zero and yield -inf;
  // with no classes there is no impurity, so report an empty range instead.
  if (numClasses == 0)
    return 0.0;

  return 1.0 - (1.0 / static_cast<double>(numClasses));
}
84 };
85 
86 } // namespace tree
87 } // namespace mlpack
88 
89 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
The core includes that mlpack expects; standard C++ includes and Armadillo.
static double Evaluate(const arma::Mat< size_t > &counts)
static double Range(const size_t numClasses)
Return the range of the Gini impurity for the given number of classes.