mlpack  master
svd_convolution.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_CONVOLUTION_RULES_SVD_CONVOLUTION_HPP
14 #define MLPACK_METHODS_ANN_CONVOLUTION_RULES_SVD_CONVOLUTION_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "border_modes.hpp"
18 #include "fft_convolution.hpp"
19 #include "naive_convolution.hpp"
20 
21 namespace mlpack {
22 namespace ann {
23 
37 template<typename BorderMode = FullConvolution>
39 {
40  public:
41  /*
42  * Perform a convolution (valid or full mode) using singular value
43  * decomposition. By using singular value decomposition of the filter matrix
44  * the convolution can be expressed as a sum of outer products. Each product
45  * can be computed efficiently as convolution with a row and a column vector.
46  * The individual convolutions are computed with the naive implementation
47  * which is fast if the filter is low-dimensional.
48  *
49  * @param input Input used to perform the convolution.
50  * @param filter Filter used to perform the conolution.
51  * @param output Output data that contains the results of the convolution.
52  * @param dW Stride of filter application in the x direction.
53  * @param dH Stride of filter application in the y direction.
54  */
55  template<typename eT>
56  static void Convolution(const arma::Mat<eT>& input,
57  const arma::Mat<eT>& filter,
58  arma::Mat<eT>& output)
59  {
60  // Use the naive convolution in case the filter isn't two dimensional or the
61  // filter is bigger than the input.
62  if (filter.n_rows > input.n_rows || filter.n_cols > input.n_cols ||
63  filter.n_rows == 1 || filter.n_cols == 1)
64  {
65  NaiveConvolution<BorderMode>::Convolution(input, filter, output);
66  }
67  else
68  {
69  arma::Mat<eT> U, V, subOutput;
70  arma::Col<eT> s;
71 
72  arma::svd_econ(U, s, V, filter);
73 
74  // Rank approximation using the singular values calculated with singular
75  // value decomposition of dense filter matrix.
76  const size_t rank = arma::sum(s > (s.n_elem * arma::max(s) *
77  arma::datum::eps));
78 
79  // Test for separability based on the rank of the kernel and take
80  // advantage of the low rank.
81  if (rank * (filter.n_rows + filter.n_cols) < filter.n_elem)
82  {
83  arma::Mat<eT> subFilter = V.unsafe_col(0) * s(0);
84  NaiveConvolution<BorderMode>::Convolution(input, subFilter, subOutput);
85 
86  subOutput = subOutput.t();
87  NaiveConvolution<BorderMode>::Convolution(subOutput, U.unsafe_col(0),
88  output);
89 
90  for (size_t r = 1; r < rank; r++)
91  {
92  subFilter = V.unsafe_col(r) * s(r);
94  subOutput);
95 
96  arma::Mat<eT> temp;
97  subOutput = subOutput.t();
98  NaiveConvolution<BorderMode>::Convolution(subOutput, U.unsafe_col(r),
99  temp);
100  output += temp;
101  }
102 
103  output = output.t();
104  }
105  else
106  {
107  FFTConvolution<BorderMode>::Convolution(input, filter, output);
108  }
109  }
110  }
111 
112  /*
113  * Perform a convolution using 3rd order tensors.
114  *
115  * @param input Input used to perform the convolution.
116  * @param filter Filter used to perform the conolution.
117  * @param output Output data that contains the results of the convolution.
118  * @param dW Stride of filter application in the x direction.
119  * @param dH Stride of filter application in the y direction.
120  */
121  template<typename eT>
122  static void Convolution(const arma::Cube<eT>& input,
123  const arma::Cube<eT>& filter,
124  arma::Cube<eT>& output)
125  {
126  arma::Mat<eT> convOutput;
127  SVDConvolution<BorderMode>::Convolution(input.slice(0), filter.slice(0),
128  convOutput);
129 
130  output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
131  input.n_slices);
132  output.slice(0) = convOutput;
133 
134  for (size_t i = 1; i < input.n_slices; i++)
135  {
136  SVDConvolution<BorderMode>::Convolution(input.slice(i), filter.slice(i),
137  convOutput);
138  output.slice(i) = convOutput;
139  }
140  }
141 
142  /*
143  * Perform a convolution using dense matrix as input and a 3rd order tensors
144  * as filter and output.
145  *
146  * @param input Input used to perform the convolution.
147  * @param filter Filter used to perform the conolution.
148  * @param output Output data that contains the results of the convolution.
149  * @param dW Stride of filter application in the x direction.
150  * @param dH Stride of filter application in the y direction.
151  */
152  template<typename eT>
153  static void Convolution(const arma::Mat<eT>& input,
154  const arma::Cube<eT>& filter,
155  arma::Cube<eT>& output)
156  {
157  arma::Mat<eT> convOutput;
158  SVDConvolution<BorderMode>::Convolution(input, filter.slice(0), convOutput);
159 
160  output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
161  filter.n_slices);
162  output.slice(0) = convOutput;
163 
164  for (size_t i = 1; i < filter.n_slices; i++)
165  {
166  SVDConvolution<BorderMode>::Convolution(input, filter.slice(i),
167  convOutput);
168  output.slice(i) = convOutput;
169  }
170  }
171 
172  /*
173  * Perform a convolution using a 3rd order tensors as input and output and a
174  * dense matrix as filter.
175  *
176  * @param input Input used to perform the convolution.
177  * @param filter Filter used to perform the conolution.
178  * @param output Output data that contains the results of the convolution.
179  * @param dW Stride of filter application in the x direction.
180  * @param dH Stride of filter application in the y direction.
181  */
182  template<typename eT>
183  static void Convolution(const arma::Cube<eT>& input,
184  const arma::Mat<eT>& filter,
185  arma::Cube<eT>& output)
186  {
187  arma::Mat<eT> convOutput;
188  SVDConvolution<BorderMode>::Convolution(input.slice(0), filter, convOutput);
189 
190  output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
191  input.n_slices);
192  output.slice(0) = convOutput;
193 
194  for (size_t i = 1; i < input.n_slices; i++)
195  {
196  SVDConvolution<BorderMode>::Convolution(input.slice(i), filter,
197  convOutput);
198  output.slice(i) = convOutput;
199  }
200  }
201 
202 }; // class SVDConvolution
203 
204 } // namespace ann
205 } // namespace mlpack
206 
207 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
The core includes that mlpack expects; standard C++ includes and Armadillo.
static std::enable_if< std::is_same< Border, ValidConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output)
static void Convolution(const arma::Cube< eT > &input, const arma::Mat< eT > &filter, arma::Cube< eT > &output)
static void Convolution(const arma::Cube< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output)
Computes the two-dimensional convolution using singular value decomposition.
static void Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output)
static std::enable_if< std::is_same< Border, ValidConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output, const size_t dW=1, const size_t dH=1)
static void Convolution(const arma::Mat< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output)