mlpack  master
recurrent.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_RECURRENT_HPP
14 #define MLPACK_METHODS_ANN_LAYER_RECURRENT_HPP
15 
16 #include <mlpack/core.hpp>
17 #include <boost/ptr_container/ptr_vector.hpp>
18 
19 #include "../visitor/delta_visitor.hpp"
20 #include "../visitor/output_parameter_visitor.hpp"
21 #include "../visitor/weight_size_visitor.hpp"
22 
23 #include "layer_types.hpp"
24 #include "add_merge.hpp"
25 #include "sequential.hpp"
26 
27 namespace mlpack {
28 namespace ann {
29 
/**
 * Recurrent layer assembled from four caller-supplied modules (start, input,
 * feedback and transfer; see the constructor) together with rho, the number
 * of steps to backpropagate through time (BPTT).
 *
 * @tparam InputDataType Type of the stored input parameter (defaults to
 *     arma::mat).
 * @tparam OutputDataType Type of the stored parameters, delta, gradient and
 *     output parameter (defaults to arma::mat).
 */
39 template <
40  typename InputDataType = arma::mat,
41  typename OutputDataType = arma::mat
42 >
43 class Recurrent
44 {
45  public:
    /**
     * Create the Recurrent object using the specified modules.
     *
     * @param start The start module.
     * @param input The input module.
     * @param feedback The feedback module.
     * @param transfer The transfer module.
     * @param rho Number of steps to backpropagate through time (BPTT).
     */
55  template<typename StartModuleType,
56  typename InputModuleType,
57  typename FeedbackModuleType,
58  typename TransferModuleType>
59  Recurrent(const StartModuleType& start,
60  const InputModuleType& input,
61  const FeedbackModuleType& feedback,
62  const TransferModuleType& transfer,
63  const size_t rho);
64 
    /**
     * Ordinary feed forward pass of a neural network, evaluating the function
     * f(x).
     *
     * @param input Input of the forward pass.
     * @param output Output of the forward pass.
     */
72  template<typename eT>
73  void Forward(arma::Mat<eT>&& input, arma::Mat<eT>&& output);
74 
    /**
     * Ordinary feed backward pass of a neural network, calculating the
     * function f(x) by propagating x backwards.
     *
     * The first parameter is left unnamed in this declaration.
     *
     * @param gy Presumably the backpropagated error -- TODO confirm against
     *     recurrent_impl.hpp.
     * @param g Presumably the resulting gradient -- TODO confirm against
     *     recurrent_impl.hpp.
     */
84  template<typename eT>
85  void Backward(const arma::Mat<eT>&& /* input */,
86  arma::Mat<eT>&& gy,
87  arma::Mat<eT>&& g);
88 
89  /**
90  * Calculate the gradient using the output delta and the input activation.
91  *
92  * @param input The input parameter used for calculating the gradient.
93  * @param error The calculated error.
94  * @param gradient The calculated gradient.
95  */
96  template<typename eT>
97  void Gradient(arma::Mat<eT>&& input,
98  arma::Mat<eT>&& error,
99  arma::Mat<eT>&& /* gradient */);
100 
    //! Get the model modules.
102  std::vector<LayerTypes>& Model() { return network; }
103 
    //! The value of the deterministic parameter.
105  bool Deterministic() const { return deterministic; }
    //! Modify the value of the deterministic parameter.
107  bool& Deterministic() { return deterministic; }
108 
    //! Get the parameters.
110  OutputDataType const& Parameters() const { return parameters; }
    //! Modify the parameters.
112  OutputDataType& Parameters() { return parameters; }
113 
    //! Get the input parameter.
115  InputDataType const& InputParameter() const { return inputParameter; }
    //! Modify the input parameter.
117  InputDataType& InputParameter() { return inputParameter; }
118 
    //! Get the output parameter.
120  OutputDataType const& OutputParameter() const { return outputParameter; }
    //! Modify the output parameter.
122  OutputDataType& OutputParameter() { return outputParameter; }
123 
    //! Get the delta.
125  OutputDataType const& Delta() const { return delta; }
    //! Modify the delta.
127  OutputDataType& Delta() { return delta; }
128 
    //! Get the gradient.
130  OutputDataType const& Gradient() const { return gradient; }
    //! Modify the gradient.
132  OutputDataType& Gradient() { return gradient; }
133 
    //! Serialize the layer.
137  template<typename Archive>
138  void Serialize(Archive& ar, const unsigned int /* version */);
139 
140  private:
    // NOTE(review): this listing omits several member declarations that the
    // accessors above and the member index at the end of this page refer to
    // (startModule, inputModule, feedbackModule, transferModule,
    // deterministic, initialModule, recurrentModule, mergeModule,
    // weightSizeVisitor, deltaVisitor and outputParameterVisitor); consult
    // the original header for them.
143 
146 
149 
152 
    //! Number of steps to backpropagate through time (BPTT).
154  size_t rho;
155 
    //! Locally-stored number of forward steps.
157  size_t forwardStep;
158 
    //! Locally-stored number of backward steps.
160  size_t backwardStep;
161 
    //! Locally-stored number of gradient steps.
163  size_t gradientStep;
164 
167 
    //! Locally-stored weight object.
169  OutputDataType parameters;
170 
173 
176 
    //! Locally-stored model modules.
178  std::vector<LayerTypes> network;
179 
182 
185 
188 
191 
    //! Locally-stored feedback output parameters.
193  std::vector<arma::mat> feedbackOutputParameter;
194 
    //! Locally-stored delta object.
196  OutputDataType delta;
197 
    //! Locally-stored gradient object.
199  OutputDataType gradient;
200 
    //! Locally-stored input parameter object.
202  InputDataType inputParameter;
203 
    //! Locally-stored output parameter object.
205  OutputDataType outputParameter;
206 
    //! Locally-stored recurrent error parameter.
208  arma::mat recurrentError;
209 }; // class Recurrent
210 
211 } // namespace ann
212 } // namespace mlpack
213 
214 // Include implementation.
215 #include "recurrent_impl.hpp"
216 
217 #endif
OutputDataType const & Gradient() const
Get the gradient.
Definition: recurrent.hpp:130
void Backward(const arma::Mat< eT > &&, arma::Mat< eT > &&gy, arma::Mat< eT > &&g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
OutputDataType gradient
Locally-stored gradient object.
Definition: recurrent.hpp:199
LayerTypes feedbackModule
Locally-stored feedback module.
Definition: recurrent.hpp:148
size_t forwardStep
Locally-stored number of forward steps.
Definition: recurrent.hpp:157
InputDataType inputParameter
Locally-stored input parameter object.
Definition: recurrent.hpp:202
OutputDataType const & Parameters() const
Get the parameters.
Definition: recurrent.hpp:110
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: recurrent.hpp:122
InputDataType & InputParameter()
Modify the input parameter.
Definition: recurrent.hpp:117
bool Deterministic() const
The value of the deterministic parameter.
Definition: recurrent.hpp:105
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
size_t backwardStep
Locally-stored number of backward steps.
Definition: recurrent.hpp:160
WeightSizeVisitor weightSizeVisitor
Locally-stored weight size visitor.
Definition: recurrent.hpp:184
bool & Deterministic()
Modify the value of the deterministic parameter.
Definition: recurrent.hpp:107
void Serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType parameters
Locally-stored weight object.
Definition: recurrent.hpp:169
LayerTypes mergeModule
Locally-stored merge module.
Definition: recurrent.hpp:181
WeightSizeVisitor returns the number of weights of the given module.
Recurrent(const StartModuleType &start, const InputModuleType &input, const FeedbackModuleType &feedback, const TransferModuleType &transfer, const size_t rho)
Create the Recurrent object using the specified modules.
OutputDataType & Delta()
Modify the delta.
Definition: recurrent.hpp:127
std::vector< LayerTypes > network
Locally-stored model modules.
Definition: recurrent.hpp:178
OutputParameterVisitor outputParameterVisitor
Locally-stored output parameter visitor.
Definition: recurrent.hpp:190
OutputDataType & Gradient()
Modify the gradient.
Definition: recurrent.hpp:132
OutputDataType outputParameter
Locally-stored output parameter object.
Definition: recurrent.hpp:205
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: recurrent.hpp:120
std::vector< arma::mat > feedbackOutputParameter
Locally-stored feedback output parameters.
Definition: recurrent.hpp:193
LayerTypes startModule
Locally-stored start module.
Definition: recurrent.hpp:142
OutputParameterVisitor exposes the output parameter of the given module.
OutputDataType const & Delta() const
Get the delta.
Definition: recurrent.hpp:125
boost::variant< Add< arma::mat, arma::mat > *, AddMerge< arma::mat, arma::mat > *, BaseLayer< LogisticFunction, arma::mat, arma::mat > *, BaseLayer< IdentityFunction, arma::mat, arma::mat > *, BaseLayer< TanhFunction, arma::mat, arma::mat > *, BaseLayer< RectifierFunction, arma::mat, arma::mat > *, Concat< arma::mat, arma::mat > *, ConcatPerformance< NegativeLogLikelihood< arma::mat, arma::mat >, arma::mat, arma::mat > *, Constant< arma::mat, arma::mat > *, Convolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, DropConnect< arma::mat, arma::mat > *, Dropout< arma::mat, arma::mat > *, Glimpse< arma::mat, arma::mat > *, HardTanH< arma::mat, arma::mat > *, Join< arma::mat, arma::mat > *, LeakyReLU< arma::mat, arma::mat > *, Linear< arma::mat, arma::mat > *, LinearNoBias< arma::mat, arma::mat > *, LogSoftMax< arma::mat, arma::mat > *, Lookup< arma::mat, arma::mat > *, LSTM< arma::mat, arma::mat > *, MaxPooling< arma::mat, arma::mat > *, MeanPooling< arma::mat, arma::mat > *, MeanSquaredError< arma::mat, arma::mat > *, MultiplyConstant< arma::mat, arma::mat > *, NegativeLogLikelihood< arma::mat, arma::mat > *, PReLU< arma::mat, arma::mat > *, Recurrent< arma::mat, arma::mat > *, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, Select< arma::mat, arma::mat > *, Sequential< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > * > LayerTypes
OutputDataType & Parameters()
Modify the parameters.
Definition: recurrent.hpp:112
std::vector< LayerTypes > & Model()
Get the model modules.
Definition: recurrent.hpp:102
DeltaVisitor deltaVisitor
Locally-stored delta visitor.
Definition: recurrent.hpp:187
arma::mat recurrentError
Locally-stored recurrent error parameter.
Definition: recurrent.hpp:208
Include all of the base components required to write MLPACK methods, and the main MLPACK Doxygen documentation.
LayerTypes initialModule
Locally-stored initial module.
Definition: recurrent.hpp:172
LayerTypes transferModule
Locally-stored transfer module.
Definition: recurrent.hpp:151
DeltaVisitor exposes the delta parameter of the given module.
bool deterministic
If true, dropout and scaling are disabled; see notes above.
Definition: recurrent.hpp:166
OutputDataType delta
Locally-stored delta object.
Definition: recurrent.hpp:196
LayerTypes inputModule
Locally-stored input module.
Definition: recurrent.hpp:145
InputDataType const & InputParameter() const
Get the input parameter.
Definition: recurrent.hpp:115
size_t gradientStep
Locally-stored number of gradient steps.
Definition: recurrent.hpp:163
size_t rho
Number of steps to backpropagate through time (BPTT).
Definition: recurrent.hpp:154
LayerTypes recurrentModule
Locally-stored recurrent module.
Definition: recurrent.hpp:175
void Forward(arma::Mat< eT > &&input, arma::Mat< eT > &&output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.