13 #ifndef MLPACK_METHODS_ANN_CONVOLUTION_RULES_SVD_CONVOLUTION_HPP 14 #define MLPACK_METHODS_ANN_CONVOLUTION_RULES_SVD_CONVOLUTION_HPP 37 template<
typename BorderMode = FullConvolution>
57 const arma::Mat<eT>& filter,
58 arma::Mat<eT>& output)
62 if (filter.n_rows > input.n_rows || filter.n_cols > input.n_cols ||
63 filter.n_rows == 1 || filter.n_cols == 1)
69 arma::Mat<eT> U, V, subOutput;
72 arma::svd_econ(U, s, V, filter);
76 const size_t rank = arma::sum(s > (s.n_elem * arma::max(s) *
81 if (rank * (filter.n_rows + filter.n_cols) < filter.n_elem)
83 arma::Mat<eT> subFilter = V.unsafe_col(0) * s(0);
86 subOutput = subOutput.t();
90 for (
size_t r = 1; r < rank; r++)
92 subFilter = V.unsafe_col(r) * s(r);
97 subOutput = subOutput.t();
121 template<
typename eT>
123 const arma::Cube<eT>& filter,
124 arma::Cube<eT>& output)
126 arma::Mat<eT> convOutput;
130 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
132 output.slice(0) = convOutput;
134 for (
size_t i = 1; i < input.n_slices; i++)
138 output.slice(i) = convOutput;
152 template<
typename eT>
154 const arma::Cube<eT>& filter,
155 arma::Cube<eT>& output)
157 arma::Mat<eT> convOutput;
160 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
162 output.slice(0) = convOutput;
164 for (
size_t i = 1; i < filter.n_slices; i++)
168 output.slice(i) = convOutput;
182 template<
typename eT>
184 const arma::Mat<eT>& filter,
185 arma::Cube<eT>& output)
187 arma::Mat<eT> convOutput;
190 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
192 output.slice(0) = convOutput;
194 for (
size_t i = 1; i < input.n_slices; i++)
198 output.slice(i) = convOutput;
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects: standard C++ headers and Armadillo.
static std::enable_if< std::is_same< Border, ValidConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output)
static void Convolution(const arma::Cube< eT > &input, const arma::Mat< eT > &filter, arma::Cube< eT > &output)
static void Convolution(const arma::Cube< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output)
Computes the two-dimensional convolution using singular value decomposition.
static void Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output)
static std::enable_if< std::is_same< Border, ValidConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output, const size_t dW=1, const size_t dH=1)
static void Convolution(const arma::Mat< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output)