mlpack  master
svd_batch_learning.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP
13 #define MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace amf {
19 
42 {
43  public:
52  SVDBatchLearning(double u = 0.0002,
53  double kw = 0,
54  double kh = 0,
55  double momentum = 0.9)
56  : u(u), kw(kw), kh(kh), momentum(momentum)
57  {
58  // empty constructor
59  }
60 
68  template<typename MatType>
69  void Initialize(const MatType& dataset, const size_t rank)
70  {
71  const size_t n = dataset.n_rows;
72  const size_t m = dataset.n_cols;
73 
74  mW.zeros(n, rank);
75  mH.zeros(rank, m);
76  }
77 
87  template<typename MatType>
88  inline void WUpdate(const MatType& V,
89  arma::mat& W,
90  const arma::mat& H)
91  {
92  size_t n = V.n_rows;
93  size_t m = V.n_cols;
94 
95  size_t r = W.n_cols;
96 
97  // initialize the momentum of this iteration.
98  mW = momentum * mW;
99 
100  // Compute the step.
101  arma::mat deltaW;
102  deltaW.zeros(n, r);
103  for (size_t i = 0; i < n; i++)
104  {
105  for (size_t j = 0; j < m; j++)
106  {
107  const double val = V(i, j);
108  if (val != 0)
109  deltaW.row(i) += (val - arma::dot(W.row(i), H.col(j))) *
110  arma::trans(H.col(j));
111  }
112  // Add regularization.
113  if (kw != 0)
114  deltaW.row(i) -= kw * W.row(i);
115  }
116 
117  // Add the step to the momentum.
118  mW += u * deltaW;
119  // Add the momentum to the W matrix.
120  W += mW;
121  }
122 
132  template<typename MatType>
133  inline void HUpdate(const MatType& V,
134  const arma::mat& W,
135  arma::mat& H)
136  {
137  size_t n = V.n_rows;
138  size_t m = V.n_cols;
139 
140  size_t r = W.n_cols;
141 
142  // Initialize the momentum of this iteration.
143  mH = momentum * mH;
144 
145  // Compute the step.
146  arma::mat deltaH;
147  deltaH.zeros(r, m);
148  for (size_t j = 0; j < m; j++)
149  {
150  for (size_t i = 0; i < n; i++)
151  {
152  const double val = V(i, j);
153  if (val != 0)
154  deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) * W.row(i).t();
155  }
156  // Add regularization.
157  if (kh != 0)
158  deltaH.col(j) -= kh * H.col(j);
159  }
160 
161  // Add this step to the momentum.
162  mH += u * deltaH;
163  // Add the momentum to H.
164  H += mH;
165  }
166 
  /**
   * Serialize the SVDBatchLearning object: the four hyperparameters (u, kw,
   * kh, momentum) followed by the two momentum matrices (mW, mH).
   */
  template<typename Archive>
  void Serialize(Archive& ar, const unsigned int /* version */)
  {
    using data::CreateNVP;
    // NOTE(review): the archive is presumably order-sensitive (e.g. binary
    // archives), so keep this field order stable to stay compatible with
    // previously saved objects.
    ar & CreateNVP(u, "u");
    ar & CreateNVP(kw, "kw");
    ar & CreateNVP(kh, "kh");
    ar & CreateNVP(momentum, "momentum");
    ar & CreateNVP(mW, "mW");
    ar & CreateNVP(mH, "mH");
  }
179 
 private:
  //! Step size of the algorithm.
  double u;
  //! Regularization parameter for matrix W.
  double kw;
  //! Regularization parameter for matrix H.
  double kh;
  //! Momentum value (between 0 and 1).
  double momentum;

  //! Momentum matrix for matrix W; sized in Initialize(), added to W in
  //! WUpdate().
  arma::mat mW;
  //! Momentum matrix for matrix H; sized in Initialize(), added to H in
  //! HUpdate().
  arma::mat mH;
194 }; // class SVDBatchLearning
195 
198 
202 template<>
203 inline void SVDBatchLearning::WUpdate<arma::sp_mat>(const arma::sp_mat& V,
204  arma::mat& W,
205  const arma::mat& H)
206 {
207  const size_t n = V.n_rows;
208  const size_t r = W.n_cols;
209 
210  mW = momentum * mW;
211 
212  arma::mat deltaW;
213  deltaW.zeros(n, r);
214 
215  for (arma::sp_mat::const_iterator it = V.begin(); it != V.end(); ++it)
216  {
217  const size_t row = it.row();
218  const size_t col = it.col();
219  deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
220  arma::trans(H.col(col));
221  }
222 
223  if (kw != 0)
224  deltaW -= kw * W;
225 
226  mW += u * deltaW;
227  W += mW;
228 }
229 
230 template<>
231 inline void SVDBatchLearning::HUpdate<arma::sp_mat>(const arma::sp_mat& V,
232  const arma::mat& W,
233  arma::mat& H)
234 {
235  const size_t m = V.n_cols;
236  const size_t r = W.n_cols;
237 
238  mH = momentum * mH;
239 
240  arma::mat deltaH;
241  deltaH.zeros(r, m);
242 
243  for (arma::sp_mat::const_iterator it = V.begin(); it != V.end(); ++it)
244  {
245  const size_t row = it.row();
246  const size_t col = it.col();
247  deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
248  W.row(row).t();
249  }
250 
251  if (kh != 0)
252  deltaH -= kh * H;
253 
254  mH += u * deltaH;
255  H += mH;
256 }
257 
258 } // namespace amf
259 } // namespace mlpack
260 
261 #endif // MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
double kh
Regularization parameter for matrix H.
void Serialize(Archive &ar, const unsigned int)
Serialize the SVDBatchLearning object.
double u
Step size of the algorithm.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
The core includes that mlpack expects; standard C++ includes and Armadillo.
FirstShim< T > CreateNVP(T &t, const std::string &name, typename std::enable_if_t< HasSerialize< T >::value > *=0)
Call this function to produce a name-value pair; this is similar to BOOST_SERIALIZATION_NVP(), but should be used for types that have a Serialize() function (or contain a type that has a Serialize() function) instead of a serialize() function.
SVDBatchLearning(double u=0.0002, double kw=0, double kh=0, double momentum=0.9)
SVD Batch learning constructor.
void Initialize(const MatType &dataset, const size_t rank)
Initialize parameters before factorization.
double kw
Regularization parameter for matrix W.
This class implements SVD batch learning with momentum.
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
arma::mat mW
Momentum matrix for matrix W.
arma::mat mH
Momentum matrix for matrix H.
double momentum
Momentum value (between 0 and 1).