13 #ifndef MLPACK_METHODS_ANN_CONVOLUTION_RULES_FFT_CONVOLUTION_HPP 14 #define MLPACK_METHODS_ANN_CONVOLUTION_RULES_FFT_CONVOLUTION_HPP 36 template<
typename BorderMode = FullConvolution, const
bool padLastDim = false>
51 template<
typename eT,
typename Border = BorderMode>
52 static typename std::enable_if<
53 std::is_same<Border, ValidConvolution>::value,
void>::type
55 const arma::Mat<eT>& filter,
56 arma::Mat<eT>& output)
58 arma::Mat<eT> inputPadded = input;
59 arma::Mat<eT> filterPadded = filter;
62 inputPadded.resize(inputPadded.n_rows, inputPadded.n_cols + 1);
65 filterPadded.resize(inputPadded.n_rows, inputPadded.n_cols);
67 output = arma::real(ifft2(arma::fft2(inputPadded) % arma::fft2(
72 output = output.submat(filter.n_rows - 1, filter.n_cols - 1,
73 input.n_rows - 1, input.n_cols - 1);
86 template<
typename eT,
typename Border = BorderMode>
87 static typename std::enable_if<
88 std::is_same<Border, FullConvolution>::value,
void>::type
90 const arma::Mat<eT>& filter,
91 arma::Mat<eT>& output)
96 const size_t outputRows = input.n_rows + 2 * (filter.n_rows - 1);
97 size_t outputCols = input.n_cols + 2 * (filter.n_cols - 1);
103 arma::Mat<eT> inputPadded = arma::zeros<arma::Mat<eT> >(outputRows,
105 inputPadded.submat(filter.n_rows - 1, filter.n_cols - 1,
106 filter.n_rows - 1 + input.n_rows - 1,
107 filter.n_cols - 1 + input.n_cols - 1) = input;
109 arma::Mat<eT> filterPadded = filter;
110 filterPadded.resize(outputRows, outputCols);
113 output = arma::real(ifft2(arma::fft2(inputPadded) % arma::fft2(
118 output = output.submat(filter.n_rows - 1, filter.n_cols - 1,
119 2 * (filter.n_rows - 1) + input.n_rows - 1,
120 2 * (filter.n_cols - 1) + input.n_cols - 1);
134 template<
typename eT>
136 const arma::Cube<eT>& filter,
137 arma::Cube<eT>& output)
139 arma::Mat<eT> convOutput;
143 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
145 output.slice(0) = convOutput;
147 for (
size_t i = 1; i < input.n_slices; i++)
151 output.slice(i) = convOutput;
166 template<
typename eT>
168 const arma::Cube<eT>& filter,
169 arma::Cube<eT>& output)
171 arma::Mat<eT> convOutput;
175 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
177 output.slice(0) = convOutput;
179 for (
size_t i = 1; i < filter.n_slices; i++)
183 output.slice(i) = convOutput;
195 template<
typename eT>
197 const arma::Mat<eT>& filter,
198 arma::Cube<eT>& output)
200 arma::Mat<eT> convOutput;
204 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
206 output.slice(0) = convOutput;
208 for (
size_t i = 1; i < input.n_slices; i++)
212 output.slice(i) = convOutput;
Linear algebra utility functions, generally performed on matrices or vectors.
static std::enable_if< std::is_same< Border, FullConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output)
The core includes that mlpack expects; standard C++ includes and Armadillo.
static std::enable_if< std::is_same< Border, ValidConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output)
static void Convolution(const arma::Mat< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output)
Computes the two-dimensional convolution through the FFT (fast Fourier transform).
static void Convolution(const arma::Cube< eT > &input, const arma::Mat< eT > &filter, arma::Cube< eT > &output)
static void Convolution(const arma::Cube< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output)