mlpack  master
validation_RMSE_termination.hpp
Go to the documentation of this file.
1 
12 #ifndef _MLPACK_METHODS_AMF_VALIDATIONRMSETERMINATION_HPP_INCLUDED
13 #define _MLPACK_METHODS_AMF_VALIDATIONRMSETERMINATION_HPP_INCLUDED
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack
18 {
19 namespace amf
20 {
21 
36 template <class MatType>
38 {
39  public:
50  size_t num_test_points,
51  double tolerance = 1e-5,
52  size_t maxIterations = 10000,
53  size_t reverseStepTolerance = 3)
56  num_test_points(num_test_points),
58  {
59  size_t n = V.n_rows;
60  size_t m = V.n_cols;
61 
62  // initialize validation set matrix
63  test_points.zeros(num_test_points, 3);
64 
65  // fill validation set matrix with random chosen entries
66  for(size_t i = 0; i < num_test_points; i++)
67  {
68  double t_val;
69  size_t t_row;
70  size_t t_col;
71 
72  // pick a random non-zero entry
73  do
74  {
75  t_row = rand() % n;
76  t_col = rand() % m;
77  } while((t_val = V(t_row, t_col)) == 0);
78 
79  // add the entry to the validation set
80  test_points(i, 0) = t_row;
81  test_points(i, 1) = t_col;
82  test_points(i, 2) = t_val;
83 
84  // nullify the added entry from data matrix (training set)
85  V(t_row, t_col) = 0;
86  }
87  }
88 
94  void Initialize(const MatType& /* V */)
95  {
96  iteration = 1;
97 
98  rmse = DBL_MAX;
99  rmseOld = DBL_MAX;
100 
101  c_index = 0;
102  c_indexOld = 0;
103 
104  reverseStepCount = 0;
105  isCopy = false;
106  }
107 
114  bool IsConverged(arma::mat& W, arma::mat& H)
115  {
116  arma::mat WH;
117 
118  WH = W * H;
119 
120  // compute validation RMSE
121  if (iteration != 0)
122  {
123  rmseOld = rmse;
124  rmse = 0;
125  for(size_t i = 0; i < num_test_points; i++)
126  {
127  size_t t_row = test_points(i, 0);
128  size_t t_col = test_points(i, 1);
129  double t_val = test_points(i, 2);
130  double temp = (t_val - WH(t_row, t_col));
131  temp *= temp;
132  rmse += temp;
133  }
135  rmse = sqrt(rmse);
136  }
137 
138  // increment iteration count
139  iteration++;
140 
141  // if RMSE tolerance is not satisfied
142  if ((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
143  {
144  // check if this is a first of successive drops
145  if (reverseStepCount == 0 && isCopy == false)
146  {
147  // store a copy of W and H matrix
148  isCopy = true;
149  this->W = W;
150  this->H = H;
151  // store residue values
153  c_index = rmse;
154  }
155  // increase successive drop count
157  }
158  // if tolerance is satisfied
159  else
160  {
161  // initialize successive drop count
162  reverseStepCount = 0;
163  // if residue is droped below minimum scrap stored values
164  if (rmse <= c_indexOld && isCopy == true)
165  {
166  isCopy = false;
167  }
168  }
169 
170  // check if termination criterion is met
172  {
173  // if stored values are present replace them with current value as they
174  // represent the minimum residue point
175  if (isCopy)
176  {
177  W = this->W;
178  H = this->H;
179  rmse = c_index;
180  }
181  return true;
182  }
183  else return false;
184  }
185 
187  const double& Index() const { return rmse; }
188 
190  const size_t& Iteration() const { return iteration; }
191 
193  const size_t& NumTestPoints() const { return num_test_points; }
194 
196  const size_t& MaxIterations() const { return maxIterations; }
197  size_t& MaxIterations() { return maxIterations; }
198 
200  const double& Tolerance() const { return tolerance; }
201  double& Tolerance() { return tolerance; }
202 
203  private:
205  double tolerance;
210 
212  size_t iteration;
213 
215  arma::mat test_points;
216 
218  double rmseOld;
219  double rmse;
220 
225 
228  bool isCopy;
229 
231  arma::mat W;
232  arma::mat H;
233  double c_indexOld;
234  double c_index;
235 }; // class ValidationRMSETermination
236 
237 } // namespace amf
238 } // namespace mlpack
239 
240 
241 #endif // _MLPACK_METHODS_AMF_VALIDATIONRMSETERMINATION_HPP_INCLUDED
size_t reverseStepTolerance
tolerance on successive residue drops
bool isCopy
indicates whether a copy of information is available which corresponds to minimum residue point ...
const size_t & MaxIterations() const
Access upper limit of iteration count.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
size_t num_test_points
number of validation test points
The core includes that mlpack expects; standard C++ includes and Armadillo.
const size_t & Iteration() const
Get current iteration count.
This class implements validation termination policy based on RMSE index.
const double & Index() const
Get current value of residue.
const double & Tolerance() const
Access tolerance value.
arma::mat W
variables to store information of minimum residue point
ValidationRMSETermination(MatType &V, size_t num_test_points, double tolerance=1e-5, size_t maxIterations=10000, size_t reverseStepTolerance=3)
Create a validation set according to the given parameters and nullify this set in the data matrix (training ...
const size_t & NumTestPoints() const
Get number of validation points.
void Initialize(const MatType &)
Initializes the termination policy before starting the factorization.
bool IsConverged(arma::mat &W, arma::mat &H)
Check if the termination criterion is met.