mlpack  master
svd_complete_incremental_learning.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_AMF_SVD_COMPLETE_INCREMENTAL_LEARNING_HPP
13 #define MLPACK_METHODS_AMF_SVD_COMPLETE_INCREMENTAL_LEARNING_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack
18 {
19 namespace amf
20 {
21 
44 template <class MatType>
46 {
47  public:
57  double kw = 0,
58  double kh = 0)
59  : u(u), kw(kw), kh(kh)
60  {
61  // Nothing to do.
62  }
63 
72  void Initialize(const MatType& /* dataset */, const size_t /* rank */)
73  {
74  // Initialize the current score counters.
75  currentUserIndex = 0;
76  currentItemIndex = 0;
77  }
78 
87  inline void WUpdate(const MatType& V,
88  arma::mat& W,
89  const arma::mat& H)
90  {
91  arma::mat deltaW;
92  deltaW.zeros(1, W.n_cols);
93 
94  // Loop until a non-zero entry is found.
95  while(true)
96  {
97  const double val = V(currentItemIndex, currentUserIndex);
98  // Update feature vector if current entry is non-zero and break the loop.
99  if (val != 0)
100  {
101  deltaW += (val - arma::dot(W.row(currentItemIndex),
102  H.col(currentUserIndex))) * H.col(currentUserIndex).t();
103 
104  // Add regularization.
105  if (kw != 0)
106  deltaW -= kw * W.row(currentItemIndex);
107  break;
108  }
109  }
110 
111  W.row(currentItemIndex) += u * deltaW;
112  }
113 
123  inline void HUpdate(const MatType& V,
124  const arma::mat& W,
125  arma::mat& H)
126  {
127  arma::mat deltaH;
128  deltaH.zeros(H.n_rows, 1);
129 
130  const double val = V(currentItemIndex, currentUserIndex);
131 
132  // Update H matrix based on the non-zero entry found in WUpdate function.
133  deltaH += (val - arma::dot(W.row(currentItemIndex),
134  H.col(currentUserIndex))) * W.row(currentItemIndex).t();
135  // Add regularization.
136  if (kh != 0)
137  deltaH -= kh * H.col(currentUserIndex);
138 
139  // Move on to the next entry.
141  if (currentUserIndex == V.n_rows)
142  {
143  currentUserIndex = 0;
144  currentItemIndex = (currentItemIndex + 1) % V.n_cols;
145  }
146 
147  H.col(currentUserIndex++) += u * deltaH;
148  }
149 
150  private:
152  double u;
154  double kw;
156  double kh;
157 
162 };
163 
166 
168 template<>
170 {
171  public:
173  double kw = 0,
174  double kh = 0)
175  : u(u), kw(kw), kh(kh), it(NULL)
176  {}
177 
179  {
180  delete it;
181  }
182 
183  void Initialize(const arma::sp_mat& dataset, const size_t rank)
184  {
185  (void)rank;
186  n = dataset.n_rows;
187  m = dataset.n_cols;
188 
189  it = new arma::sp_mat::const_iterator(dataset.begin());
190  isStart = true;
191  }
192 
202  inline void WUpdate(const arma::sp_mat& V,
203  arma::mat& W,
204  const arma::mat& H)
205  {
206  if (!isStart) (*it)++;
207  else isStart = false;
208 
209  if (*it == V.end())
210  {
211  delete it;
212  it = new arma::sp_mat::const_iterator(V.begin());
213  }
214 
215  size_t currentUserIndex = it->col();
216  size_t currentItemIndex = it->row();
217 
218  arma::mat deltaW(1, W.n_cols);
219  deltaW.zeros();
220 
221  deltaW += (**it - arma::dot(W.row(currentItemIndex), H.col(currentUserIndex)))
222  * arma::trans(H.col(currentUserIndex));
223  if (kw != 0) deltaW -= kw * W.row(currentItemIndex);
224 
225  W.row(currentItemIndex) += u*deltaW;
226  }
227 
237  inline void HUpdate(const arma::sp_mat& V,
238  const arma::mat& W,
239  arma::mat& H)
240  {
241  (void)V;
242 
243  arma::mat deltaH(H.n_rows, 1);
244  deltaH.zeros();
245 
246  size_t currentUserIndex = it->col();
247  size_t currentItemIndex = it->row();
248 
249  deltaH += (**it - arma::dot(W.row(currentItemIndex), H.col(currentUserIndex)))
250  * arma::trans(W.row(currentItemIndex));
251  if (kh != 0) deltaH -= kh * H.col(currentUserIndex);
252 
253  H.col(currentUserIndex) += u * deltaH;
254  }
255 
256  private:
257  double u;
258  double kw;
259  double kh;
260 
261  size_t n;
262  size_t m;
263 
264  arma::sp_mat dummy;
265  arma::sp_mat::const_iterator* it;
266 
267  bool isStart;
268 }; // class SVDCompleteIncrementalLearning
269 
270 } // namespace amf
271 } // namespace mlpack
272 
273 #endif
274 
void HUpdate(const arma::sp_mat &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
This class computes SVD using complete incremental batch learning, as described in the following paper...
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
The core includes that mlpack expects; standard C++ includes and Armadillo.
double kw
Regularization parameter for matrix W.
void WUpdate(const arma::sp_mat &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
SVDCompleteIncrementalLearning(double u=0.0001, double kw=0, double kh=0)
Initialize the SVDCompleteIncrementalLearning class with the given parameters.
double kh
Regularization parameter for matrix H.
void Initialize(const MatType &, const size_t)
Initialize parameters before factorization.
void Initialize(const arma::sp_mat &dataset, const size_t rank)