mlpack  master
fft_convolution.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_CONVOLUTION_RULES_FFT_CONVOLUTION_HPP
14 #define MLPACK_METHODS_ANN_CONVOLUTION_RULES_FFT_CONVOLUTION_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "border_modes.hpp"
18 
19 namespace mlpack {
20 namespace ann {
21 
36 template<typename BorderMode = FullConvolution, const bool padLastDim = false>
38 {
39  public:
40  /*
41  * Perform a convolution through fft (valid mode). This method only supports
42  * input which is even on the last dimension. In case of an odd input width, a
43  * user can manually pad the imput or specify the padLastDim parameter which
44  * takes care of the padding. The filter instead can have any size. When using
45  * the valid mode the filters has to be smaller than the input.
46  *
47  * @param input Input used to perform the convolution.
48  * @param filter Filter used to perform the conolution.
49  * @param output Output data that contains the results of the convolution.
50  */
51  template<typename eT, typename Border = BorderMode>
52  static typename std::enable_if<
53  std::is_same<Border, ValidConvolution>::value, void>::type
54  Convolution(const arma::Mat<eT>& input,
55  const arma::Mat<eT>& filter,
56  arma::Mat<eT>& output)
57  {
58  arma::Mat<eT> inputPadded = input;
59  arma::Mat<eT> filterPadded = filter;
60 
61  if (padLastDim)
62  inputPadded.resize(inputPadded.n_rows, inputPadded.n_cols + 1);
63 
64  // Pad filter and input to the output shape.
65  filterPadded.resize(inputPadded.n_rows, inputPadded.n_cols);
66 
67  output = arma::real(ifft2(arma::fft2(inputPadded) % arma::fft2(
68  filterPadded)));
69 
70  // Extract the region of interest. We don't need to handle the padLastDim in
71  // a special way we just cut it out from the output matrix.
72  output = output.submat(filter.n_rows - 1, filter.n_cols - 1,
73  input.n_rows - 1, input.n_cols - 1);
74  }
75 
76  /*
77  * Perform a convolution through fft (full mode). This method only supports
78  * input which is even on the last dimension. In case of an odd input width, a
79  * user can manually pad the imput or specify the padLastDim parameter which
80  * takes care of the padding. The filter instead can have any size.
81  *
82  * @param input Input used to perform the convolution.
83  * @param filter Filter used to perform the conolution.
84  * @param output Output data that contains the results of the convolution.
85  */
86  template<typename eT, typename Border = BorderMode>
87  static typename std::enable_if<
88  std::is_same<Border, FullConvolution>::value, void>::type
89  Convolution(const arma::Mat<eT>& input,
90  const arma::Mat<eT>& filter,
91  arma::Mat<eT>& output)
92  {
93  // In case of the full convolution outputRows and outputCols doesn't
94  // represent the true output size when the padLastDim parameter is set,
95  // instead it's the working size.
96  const size_t outputRows = input.n_rows + 2 * (filter.n_rows - 1);
97  size_t outputCols = input.n_cols + 2 * (filter.n_cols - 1);
98 
99  if (padLastDim)
100  outputCols++;
101 
102  // Pad filter and input to the working output shape.
103  arma::Mat<eT> inputPadded = arma::zeros<arma::Mat<eT> >(outputRows,
104  outputCols);
105  inputPadded.submat(filter.n_rows - 1, filter.n_cols - 1,
106  filter.n_rows - 1 + input.n_rows - 1,
107  filter.n_cols - 1 + input.n_cols - 1) = input;
108 
109  arma::Mat<eT> filterPadded = filter;
110  filterPadded.resize(outputRows, outputCols);
111 
112  // Perform FFT and IFFT
113  output = arma::real(ifft2(arma::fft2(inputPadded) % arma::fft2(
114  filterPadded)));
115 
116  // Extract the region of interest. We don't need to handle the padLastDim
117  // parameter in a special way we just cut it out from the output matrix.
118  output = output.submat(filter.n_rows - 1, filter.n_cols - 1,
119  2 * (filter.n_rows - 1) + input.n_rows - 1,
120  2 * (filter.n_cols - 1) + input.n_cols - 1);
121  }
122 
123  /*
124  * Perform a convolution through fft using 3rd order tensors. This method only
125  * supports input which is even on the last dimension. In case of an odd input
126  * width, a user can manually pad the imput or specify the padLastDim
127  * parameter which takes care of the padding. The filter instead can have any
128  * size.
129  *
130  * @param input Input used to perform the convolution.
131  * @param filter Filter used to perform the conolution.
132  * @param output Output data that contains the results of the convolution.
133  */
134  template<typename eT>
135  static void Convolution(const arma::Cube<eT>& input,
136  const arma::Cube<eT>& filter,
137  arma::Cube<eT>& output)
138  {
139  arma::Mat<eT> convOutput;
140  FFTConvolution<BorderMode>::Convolution(input.slice(0), filter.slice(0),
141  convOutput);
142 
143  output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
144  input.n_slices);
145  output.slice(0) = convOutput;
146 
147  for (size_t i = 1; i < input.n_slices; i++)
148  {
149  FFTConvolution<BorderMode>::Convolution(input.slice(i), filter.slice(i),
150  convOutput);
151  output.slice(i) = convOutput;
152  }
153  }
154 
155  /*
156  * Perform a convolution through fft using dense matrix as input and a 3rd
157  * order tensors as filter and output. This method only supports input which
158  * is even on the last dimension. In case of an odd input width, a user can
159  * manually pad the imput or specify the padLastDim parameter which takes care
160  * of the padding. The filter instead can have any size.
161  *
162  * @param input Input used to perform the convolution.
163  * @param filter Filter used to perform the conolution.
164  * @param output Output data that contains the results of the convolution.
165  */
166  template<typename eT>
167  static void Convolution(const arma::Mat<eT>& input,
168  const arma::Cube<eT>& filter,
169  arma::Cube<eT>& output)
170  {
171  arma::Mat<eT> convOutput;
172  FFTConvolution<BorderMode>::Convolution(input, filter.slice(0),
173  convOutput);
174 
175  output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
176  filter.n_slices);
177  output.slice(0) = convOutput;
178 
179  for (size_t i = 1; i < filter.n_slices; i++)
180  {
181  FFTConvolution<BorderMode>::Convolution(input, filter.slice(i),
182  convOutput);
183  output.slice(i) = convOutput;
184  }
185  }
186 
187  /*
188  * Perform a convolution using a 3rd order tensors as input and output and a
189  * dense matrix as filter.
190  *
191  * @param input Input used to perform the convolution.
192  * @param filter Filter used to perform the conolution.
193  * @param output Output data that contains the results of the convolution.
194  */
195  template<typename eT>
196  static void Convolution(const arma::Cube<eT>& input,
197  const arma::Mat<eT>& filter,
198  arma::Cube<eT>& output)
199  {
200  arma::Mat<eT> convOutput;
201  FFTConvolution<BorderMode>::Convolution(input.slice(0), filter,
202  convOutput);
203 
204  output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
205  input.n_slices);
206  output.slice(0) = convOutput;
207 
208  for (size_t i = 1; i < input.n_slices; i++)
209  {
210  FFTConvolution<BorderMode>::Convolution(input.slice(i), filter,
211  convOutput);
212  output.slice(i) = convOutput;
213  }
214  }
215 
216 }; // class FFTConvolution
217 
218 } // namespace ann
219 } // namespace mlpack
220 
221 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
static std::enable_if< std::is_same< Border, FullConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output)
The core includes that mlpack expects; standard C++ includes and Armadillo.
static std::enable_if< std::is_same< Border, ValidConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output)
static void Convolution(const arma::Mat< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output)
Computes the two-dimensional convolution through fft.
static void Convolution(const arma::Cube< eT > &input, const arma::Mat< eT > &filter, arma::Cube< eT > &output)
static void Convolution(const arma::Cube< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output)