13 #ifndef MLPACK_METHODS_ANN_CONVOLUTION_RULES_NAIVE_CONVOLUTION_HPP 14 #define MLPACK_METHODS_ANN_CONVOLUTION_RULES_NAIVE_CONVOLUTION_HPP 34 template<
typename BorderMode = FullConvolution>
47 template<
typename eT,
typename Border = BorderMode>
48 static typename std::enable_if<
49 std::is_same<Border, ValidConvolution>::value,
void>::type
51 const arma::Mat<eT>& filter,
52 arma::Mat<eT>& output,
56 output = arma::zeros<arma::Mat<eT> >((input.n_rows - filter.n_rows + 1) /
57 dW, (input.n_cols - filter.n_cols + 1) / dH);
61 eT* outputPtr = output.memptr();
63 for (
size_t j = 0; j < output.n_cols; ++j)
65 for (
size_t i = 0; i < output.n_rows; ++i, outputPtr++)
67 const eT* kernelPtr = filter.memptr();
68 for (
size_t kj = 0; kj < filter.n_cols; ++kj)
70 const eT* inputPtr = input.colptr(kj + j * dW) + i * dH;
71 for (
size_t ki = 0; ki < filter.n_rows; ++ki, ++kernelPtr, ++inputPtr)
72 *outputPtr += *kernelPtr * (*inputPtr);
87 template<
typename eT,
typename Border = BorderMode>
88 static typename std::enable_if<
89 std::is_same<Border, FullConvolution>::value,
void>::type
91 const arma::Mat<eT>& filter,
92 arma::Mat<eT>& output,
96 const size_t outputRows = (input.n_rows + 2 * (filter.n_rows - 1)) * dW;
97 const size_t outputCols = (input.n_cols + 2 * (filter.n_cols - 1)) * dH;
100 arma::Mat<eT> inputPadded = arma::zeros<arma::Mat<eT> >(outputRows,
102 inputPadded.submat(filter.n_rows - 1, filter.n_cols - 1,
103 filter.n_rows - 1 + input.n_rows - 1,
104 filter.n_cols - 1 + input.n_cols - 1) = input;
119 template<
typename eT>
121 const arma::Cube<eT>& filter,
122 arma::Cube<eT>& output,
126 arma::Mat<eT> convOutput;
130 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
132 output.slice(0) = convOutput;
134 for (
size_t i = 1; i < input.n_slices; i++)
137 output.slice(i), dW, dH);
151 template<
typename eT>
153 const arma::Cube<eT>& filter,
154 arma::Cube<eT>& output,
158 arma::Mat<eT> convOutput;
162 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
164 output.slice(0) = convOutput;
166 for (
size_t i = 1; i < filter.n_slices; i++)
169 output.slice(i), dW, dH);
183 template<
typename eT>
185 const arma::Mat<eT>& filter,
186 arma::Cube<eT>& output,
190 arma::Mat<eT> convOutput;
194 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
196 output.slice(0) = convOutput;
198 for (
size_t i = 1; i < input.n_slices; i++)
201 output.slice(i), dW, dH);
static std::enable_if< std::is_same< Border, FullConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output, const size_t dW=1, const size_t dH=1)
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
static void Convolution(const arma::Cube< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output, const size_t dW=1, const size_t dH=1)
static void Convolution(const arma::Mat< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output, const size_t dW=1, const size_t dH=1)
static void Convolution(const arma::Cube< eT > &input, const arma::Mat< eT > &filter, arma::Cube< eT > &output, const size_t dW=1, const size_t dH=1)
Computes the two-dimensional convolution.
static std::enable_if< std::is_same< Border, ValidConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output, const size_t dW=1, const size_t dH=1)