mlpack  master
svd_incomplete_incremental_learning.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_AMF_SVD_INCOMPLETE_INCREMENTAL_LEARNING_HPP
13 #define MLPACK_METHODS_AMF_SVD_INCOMPLETE_INCREMENTAL_LEARNING_HPP
14 
15 namespace mlpack
16 {
17 namespace amf
18 {
19 
44 {
45  public:
54  double kw = 0,
55  double kh = 0)
56  : u(u), kw(kw), kh(kh)
57  {
58  // Nothing to do.
59  }
60 
69  template<typename MatType>
70  void Initialize(const MatType& /* dataset */, const size_t /* rank */)
71  {
72  // Set the current user to 0.
73  currentUserIndex = 0;
74  }
75 
85  template<typename MatType>
86  inline void WUpdate(const MatType& V,
87  arma::mat& W,
88  const arma::mat& H)
89  {
90  arma::mat deltaW;
91  deltaW.zeros(V.n_rows, W.n_cols);
92 
93  // Iterate through all the rating by this user to update corresponding item
94  // feature feature vector.
95  for (size_t i = 0; i < V.n_rows; ++i)
96  {
97  const double val = V(i, currentUserIndex);
98  // Update only if the rating is non-zero.
99  if (val != 0)
100  deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
101  H.col(currentUserIndex).t();
102  // Add regularization.
103  if (kw != 0)
104  deltaW.row(i) -= kw * W.row(i);
105  }
106 
107  W += u * deltaW;
108  }
109 
118  template<typename MatType>
119  inline void HUpdate(const MatType& V,
120  const arma::mat& W,
121  arma::mat& H)
122  {
123  arma::vec deltaH;
124  deltaH.zeros(H.n_rows);
125 
126  // Iterate through all the rating by this user to update corresponding item
127  // feature feature vector.
128  for (size_t i = 0; i < V.n_rows; ++i)
129  {
130  const double val = V(i, currentUserIndex);
131  // Update only if the rating is non-zero.
132  if (val != 0)
133  deltaH += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
134  W.row(i).t();
135  }
136  // Add regularization.
137  if (kh != 0)
138  deltaH -= kh * H.col(currentUserIndex);
139 
140  // Update H matrix and move on to the next user.
141  H.col(currentUserIndex++) += u * deltaH;
142  currentUserIndex = currentUserIndex % V.n_cols;
143  }
144 
145  private:
147  double u; // Step size: scales the accumulated delta before it is added to W or H.
149  double kw; // Regularization coefficient used in WUpdate(); 0 disables the term.
151  double kh; // Regularization coefficient used in HUpdate(); 0 disables the term.
152 
155 };
156 
159 
161 template<>
162 inline void SVDIncompleteIncrementalLearning::
163  WUpdate<arma::sp_mat>(const arma::sp_mat& V,
164  arma::mat& W,
165  const arma::mat& H)
166 {
167  arma::mat deltaW(V.n_rows, W.n_cols);
168  deltaW.zeros();
169  for(arma::sp_mat::const_iterator it = V.begin_col(currentUserIndex);
170  it != V.end_col(currentUserIndex);it++)
171  {
172  double val = *it;
173  size_t i = it.row();
174  deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
175  arma::trans(H.col(currentUserIndex));
176  if (kw != 0) deltaW.row(i) -= kw * W.row(i);
177  }
178 
179  W += u*deltaW;
180 }
181 
182 template<>
183 inline void SVDIncompleteIncrementalLearning::
184  HUpdate<arma::sp_mat>(const arma::sp_mat& V,
185  const arma::mat& W,
186  arma::mat& H)
187 {
188  arma::mat deltaH(H.n_rows, 1);
189  deltaH.zeros();
190 
191  for(arma::sp_mat::const_iterator it = V.begin_col(currentUserIndex);
192  it != V.end_col(currentUserIndex);it++)
193  {
194  double val = *it;
195  size_t i = it.row();
196  if ((val = V(i, currentUserIndex)) != 0)
197  deltaH += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
198  arma::trans(W.row(i));
199  }
200  if (kh != 0) deltaH -= kh * H.col(currentUserIndex);
201 
202  H.col(currentUserIndex++) += u * deltaH;
203  currentUserIndex = currentUserIndex % V.n_cols;
204 }
205 
206 } // namespace amf
207 } // namespace mlpack
208 
209 #endif
This class computes SVD using incomplete incremental batch learning, as described in the following pa...
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
void Initialize(const MatType &, const size_t)
Initialize parameters before factorization.
SVDIncompleteIncrementalLearning(double u=0.001, double kw=0, double kh=0)
Initialize the parameters of SVDIncompleteIncrementalLearning.