mlpack  master
base_layer.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
14 #define MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
15 
16 #include <mlpack/prereqs.hpp>
21 
22 namespace mlpack {
23 namespace ann {
24 
42 template <
43  class ActivationFunction = LogisticFunction,
44  typename InputDataType = arma::mat,
45  typename OutputDataType = arma::mat
46 >
47 class BaseLayer
48 {
49  public:
54  {
55  // Nothing to do here.
56  }
57 
65  template<typename InputType, typename OutputType>
66  void Forward(const InputType&& input, OutputType&& output)
67  {
68  ActivationFunction::fn(input, output);
69  }
70 
80  template<typename eT>
81  void Backward(const arma::Mat<eT>&& input,
82  arma::Mat<eT>&& gy,
83  arma::Mat<eT>&& g)
84  {
85  arma::Mat<eT> derivative;
86  ActivationFunction::deriv(input, derivative);
87  g = gy % derivative;
88  }
89 
91  InputDataType const& InputParameter() const { return inputParameter; }
93  InputDataType& InputParameter() { return inputParameter; }
94 
96  OutputDataType const& OutputParameter() const { return outputParameter; }
98  OutputDataType& OutputParameter() { return outputParameter; }
99 
101  OutputDataType const& Delta() const { return delta; }
103  OutputDataType& Delta() { return delta; }
104 
108  template<typename Archive>
109  void Serialize(Archive& /* ar */, const unsigned int /* version */)
110  {
111  /* Nothing to do here */
112  }
113 
114  private:
116  OutputDataType delta;
117 
119  InputDataType inputParameter;
120 
122  OutputDataType outputParameter;
123 }; // class BaseLayer
124 
125 // Convenience typedefs.
126 
130 template <
131  class ActivationFunction = LogisticFunction,
132  typename InputDataType = arma::mat,
133  typename OutputDataType = arma::mat
134 >
135 using SigmoidLayer = BaseLayer<
136  ActivationFunction, InputDataType, OutputDataType>;
137 
141 template <
142  class ActivationFunction = IdentityFunction,
143  typename InputDataType = arma::mat,
144  typename OutputDataType = arma::mat
145 >
146 using IdentityLayer = BaseLayer<
147  ActivationFunction, InputDataType, OutputDataType>;
148 
152 template <
153  class ActivationFunction = RectifierFunction,
154  typename InputDataType = arma::mat,
155  typename OutputDataType = arma::mat
156 >
157 using ReLULayer = BaseLayer<
158  ActivationFunction, InputDataType, OutputDataType>;
159 
163 template <
164  class ActivationFunction = TanhFunction,
165  typename InputDataType = arma::mat,
166  typename OutputDataType = arma::mat
167 >
168 using TanHLayer = BaseLayer<
169  ActivationFunction, InputDataType, OutputDataType>;
170 
171 } // namespace ann
172 } // namespace mlpack
173 
174 #endif
The identity function, defined by $f(x) = x$, with derivative $f'(x) = 1$.
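The tanh function, defined by $f(x) = \tanh(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$, with derivative $f'(x) = 1 - \tanh^{2}(x)$.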
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
The core includes that mlpack expects; standard C++ includes and Armadillo.
OutputDataType outputParameter
Locally-stored output parameter object.
Definition: base_layer.hpp:122
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: base_layer.hpp:96
InputDataType & InputParameter()
Modify the input parameter.
Definition: base_layer.hpp:93
Implementation of the base layer.
Definition: base_layer.hpp:47
OutputDataType delta
Locally-stored delta object.
Definition: base_layer.hpp:116
void Forward(const InputType &&input, OutputType &&output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: base_layer.hpp:66
OutputDataType const & Delta() const
Get the delta.
Definition: base_layer.hpp:101
InputDataType inputParameter
Locally-stored input parameter object.
Definition: base_layer.hpp:119
The logistic function, defined by $f(x) = \frac{1}{1 + e^{-x}}$, with derivative $f'(x) = f(x)\,(1 - f(x))$.
BaseLayer()
Create the BaseLayer object.
Definition: base_layer.hpp:53
InputDataType const & InputParameter() const
Get the input parameter.
Definition: base_layer.hpp:91
void Serialize(Archive &, const unsigned int)
Serialize the layer.
Definition: base_layer.hpp:109
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: base_layer.hpp:98
The rectifier function, defined by $f(x) = \max(0, x)$, with derivative $f'(x) = 1$ for $x > 0$ and $f'(x) = 0$ otherwise.
OutputDataType & Delta()
Modify the delta.
Definition: base_layer.hpp:103
void Backward(const arma::Mat< eT > &&input, arma::Mat< eT > &&gy, arma::Mat< eT > &&g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: base_layer.hpp:81