mlpack
master
An implementation of an LSTM network layer.
Public Member Functions
LSTM ()
Create the LSTM object.
LSTM (const size_t inSize, const size_t outSize, const size_t rho)
Create the LSTM layer object using the specified parameters.
template<typename eT >
void | Backward (const arma::Mat< eT > &&, arma::Mat< eT > &&gy, arma::Mat< eT > &&g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
OutputDataType const & | Delta () const
Get the delta.
OutputDataType & | Delta ()
Modify the delta.
bool | Deterministic () const
Get the value of the deterministic parameter.
bool & | Deterministic ()
Modify the value of the deterministic parameter.
template<typename eT >
void | Forward (arma::Mat< eT > &&input, arma::Mat< eT > &&output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
template<typename eT >
void | Gradient (arma::Mat< eT > &&input, arma::Mat< eT > &&, arma::Mat< eT > &&)
OutputDataType const & | Gradient () const
Get the gradient.
OutputDataType & | Gradient ()
Modify the gradient.
InputDataType const & | InputParameter () const
Get the input parameter.
InputDataType & | InputParameter ()
Modify the input parameter.
std::vector< LayerTypes > & | Model ()
Get the model modules.
OutputDataType const & | OutputParameter () const
Get the output parameter.
OutputDataType & | OutputParameter ()
Modify the output parameter.
OutputDataType const & | Parameters () const
Get the parameters.
OutputDataType & | Parameters ()
Modify the parameters.
size_t | Rho () const
Get the maximum number of steps to backpropagate through time (BPTT).
size_t & | Rho ()
Modify the maximum number of steps to backpropagate through time (BPTT).
template<typename Archive >
void | Serialize (Archive &ar, const unsigned int)
Serialize the layer.
Private Attributes
size_t | backwardStep
Locally-stored number of backward steps.
arma::mat | cellActivationError
Locally-stored cell activation error.
LayerTypes | cellActivationModule
Locally-stored cell activation module.
LayerTypes | cellModule
Locally-stored cell module.
std::vector< arma::mat > | cellParameter
Locally-stored cell parameters.
OutputDataType | delta
Locally-stored delta object.
DeltaVisitor | deltaVisitor
Locally-stored delta visitor.
bool | deterministic
If true, dropout and scaling are disabled.
arma::mat | forgetGateError
Locally-stored forget gate error.
LayerTypes | forgetGateModule
Locally-stored forget gate module.
size_t | forwardStep
Locally-stored number of forward steps.
OutputDataType | gradient
Locally-stored gradient object.
size_t | gradientStep
Locally-stored number of gradient steps.
LayerTypes | hiddenStateModule
Locally-stored hidden state module.
LayerTypes | input2GateModule
Locally-stored input 2 gate module.
LayerTypes | inputGateModule
Locally-stored input gate module.
InputDataType | inputParameter
Locally-stored input parameter object.
size_t | inSize
Locally-stored number of input units.
std::vector< LayerTypes > | network
Locally-stored list of network modules.
std::vector< arma::mat > | outParameter
Locally-stored output parameters.
LayerTypes | output2GateModule
Locally-stored output 2 gate module.
LayerTypes | outputGateModule
Locally-stored output gate module.
OutputDataType | outputParameter
Locally-stored output parameter object.
OutputParameterVisitor | outputParameterVisitor
Locally-stored output parameter visitor.
size_t | outSize
Locally-stored number of output units.
arma::mat | prevCell
Locally-stored previous cell state.
arma::mat | prevError
Locally-stored previous error.
arma::mat | prevOutput
Locally-stored previous output.
size_t | rho
Number of steps to backpropagate through time (BPTT).
OutputDataType | weights
Locally-stored weight object.
An implementation of an LSTM network layer.
This class allows specification of the type of the activation functions used for the gates and cells, and also of the type of the function used to initialize and update the peephole weights.
Template Parameters
InputDataType | Type of the input data (arma::colvec, arma::mat, arma::sp_mat or arma::cube).
OutputDataType | Type of the output data (arma::colvec, arma::mat, arma::sp_mat or arma::cube).
Definition at line 51 of file layer_types.hpp.
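For orientation, the recurrence below is the standard textbook LSTM formulation without peephole connections. It is a sketch, not necessarily the exact computation this class performs (the description says it can also use peephole weights). The symbols W, U and b (input, recurrent and bias weights), σ (logistic sigmoid) and ⊙ (element-wise product) are notation introduced here; x_t is an input of size inSize, and h_t and c_t are the output and cell state, each of size outSize.

```latex
\begin{aligned}
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
g_t &= \tanh(W_g x_t + U_g h_{t-1} + b_g) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```

Peephole variants additionally feed the cell state into the gate pre-activations.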
mlpack::ann::LSTM< InputDataType, OutputDataType >::LSTM ( )
Create the LSTM object.
mlpack::ann::LSTM< InputDataType, OutputDataType >::LSTM ( const size_t inSize, const size_t outSize, const size_t rho )
Create the LSTM layer object using the specified parameters.
Parameters
inSize | The number of input units.
outSize | The number of output units.
rho | Maximum number of steps to backpropagate through time (BPTT).
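As a rough guide to how inSize and outSize scale a layer like this one, the hypothetical helper below (not part of mlpack) counts the trainable values of a standard LSTM without peephole connections. In that formulation, each of the four blocks (input, forget and output gates plus the cell candidate) has an outSize x inSize input matrix, an outSize x outSize recurrent matrix, and an outSize bias. This is an assumption-based estimate: the size of this class's Parameters() matrix may differ because of its own weight layout and peephole weights. In the sketch, rho plays no role, since it limits how far errors are propagated back in time rather than the size of the weight matrices.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper (not part of mlpack): number of trainable values in a
// standard LSTM without peephole connections.
std::size_t StandardLstmParameterCount(std::size_t inSize, std::size_t outSize)
{
  const std::size_t perBlock = outSize * inSize   // input weights W
                             + outSize * outSize  // recurrent weights U
                             + outSize;           // bias b
  // Input gate, forget gate, output gate and cell candidate.
  return 4 * perBlock;
}
```

For example, StandardLstmParameterCount(10, 20) gives 4 x (200 + 400 + 20) = 2480.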
void mlpack::ann::LSTM< InputDataType, OutputDataType >::Backward ( const arma::Mat< eT > &&, arma::Mat< eT > && gy, arma::Mat< eT > && g )
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Parameters
input | The propagated input activation.
gy | The backpropagated error.
g | The calculated gradient.
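To illustrate what rho limits during the backward pass, the hypothetical function below propagates an error backwards along the direct cell-state path only. In the standard formulation c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t, holding the gates fixed gives ∂c_t/∂c_{t-1} = f_t, so the error reaching an earlier step is multiplied by each intervening forget gate value. The sketch treats rho as the number of time steps (including the last) that receive gradient. That counting convention, the single-unit cell, and the omission of the paths through the gates and hidden state are all simplifying assumptions, so this is not a complete LSTM backward pass and not mlpack's code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch: propagate dL/dc backwards along the direct cell-state
// path c_t = f_t * c_{t-1} + ..., visiting at most `rho` time steps.
// forget[k] is the forget gate value at step k; gates are treated as constants.
std::vector<double> TruncatedCellGradient(const std::vector<double>& forget,
                                          double lastCellGrad,
                                          std::size_t rho)
{
  const std::size_t steps = forget.size();
  std::vector<double> grad(steps, 0.0);  // steps outside the window stay 0
  if (steps == 0 || rho == 0)
    return grad;

  grad[steps - 1] = lastCellGrad;
  // Walk back until rho steps (counting the last one) have received gradient.
  for (std::size_t k = steps - 1; k > 0 && steps - k < rho; --k)
  {
    // dL/dc_{k-1} = dL/dc_k * f_k along the direct path.
    grad[k - 1] = grad[k] * forget[k];
  }
  return grad;
}
```

With four forget gate values of 0.5 and rho = 2, only the last two entries are non-zero ({0, 0, 0.5, 1}); with rho = 4 the result is {0.125, 0.25, 0.5, 1}.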
OutputDataType const & mlpack::ann::LSTM< InputDataType, OutputDataType >::Delta ( ) const [inline]
Get the delta.
Definition at line 123 of file lstm.hpp.
References mlpack::ann::LSTM< InputDataType, OutputDataType >::delta.
OutputDataType & mlpack::ann::LSTM< InputDataType, OutputDataType >::Delta ( ) [inline]
Modify the delta.
Definition at line 125 of file lstm.hpp.
References mlpack::ann::LSTM< InputDataType, OutputDataType >::delta.
bool mlpack::ann::LSTM< InputDataType, OutputDataType >::Deterministic ( ) const [inline]
Get the value of the deterministic parameter.
Definition at line 98 of file lstm.hpp.
References mlpack::ann::LSTM< InputDataType, OutputDataType >::deterministic.
bool & mlpack::ann::LSTM< InputDataType, OutputDataType >::Deterministic ( ) [inline]
Modify the value of the deterministic parameter.
Definition at line 100 of file lstm.hpp.
References mlpack::ann::LSTM< InputDataType, OutputDataType >::deterministic.
void mlpack::ann::LSTM< InputDataType, OutputDataType >::Forward ( arma::Mat< eT > && input, arma::Mat< eT > && output )
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Parameters
input | Input data used for evaluating the specified function.
output | Resulting output activation.
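Forward() maps one input to one output activation, and the private attributes prevOutput, prevCell and forwardStep suggest that the layer carries recurrent state from one call to the next. The hypothetical single-unit class below (inSize = outSize = 1, standard gates without peepholes) sketches that pattern: each call advances one time step using the stored h_{t-1} and c_{t-1}. The class name, its weight ordering and the Reset() method are illustrative assumptions, not mlpack's interface.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical single-unit LSTM that keeps its recurrent state between
// Forward() calls. Not mlpack code; weights are ordered as
// {input gate, forget gate, output gate, cell candidate}.
class ScalarLstm
{
 public:
  ScalarLstm(std::array<double, 4> inputWeights,
             std::array<double, 4> recurrentWeights,
             std::array<double, 4> biases) :
      w(inputWeights), u(recurrentWeights), b(biases) { }

  // Advances one time step and returns the new output h_t.
  double Forward(double x)
  {
    std::array<double, 4> z;
    for (std::size_t k = 0; k < 4; ++k)
      z[k] = w[k] * x + u[k] * prevOutput + b[k];

    const double i = Sigmoid(z[0]);
    const double f = Sigmoid(z[1]);
    const double o = Sigmoid(z[2]);
    prevCell = f * prevCell + i * std::tanh(z[3]);  // c_t
    prevOutput = o * std::tanh(prevCell);           // h_t
    ++forwardStep;
    return prevOutput;
  }

  // Clears the carried state before processing an unrelated sequence.
  void Reset() { prevOutput = 0.0; prevCell = 0.0; forwardStep = 0; }

  std::size_t ForwardStep() const { return forwardStep; }

 private:
  static double Sigmoid(double v) { return 1.0 / (1.0 + std::exp(-v)); }

  std::array<double, 4> w, u, b;  // input, recurrent and bias weights
  double prevOutput = 0.0;        // h_{t-1}
  double prevCell = 0.0;          // c_{t-1}
  std::size_t forwardStep = 0;    // steps taken since the last Reset()
};
```

Feeding the same input twice gives different outputs because the second step starts from the stored cell state; after Reset(), the first output is reproduced exactly.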
void mlpack::ann::LSTM< InputDataType, OutputDataType >::Gradient ( arma::Mat< eT > && input, arma::Mat< eT > &&, arma::Mat< eT > && )
OutputDataType const & mlpack::ann::LSTM< InputDataType, OutputDataType >::Gradient ( ) const [inline]
Get the gradient.
Definition at line 128 of file lstm.hpp.
References mlpack::ann::LSTM< InputDataType, OutputDataType >::gradient.
OutputDataType & mlpack::ann::LSTM< InputDataType, OutputDataType >::Gradient ( ) [inline]
Modify the gradient.
Definition at line 130 of file lstm.hpp.
References mlpack::ann::LSTM< InputDataType, OutputDataType >::gradient.
InputDataType const & mlpack::ann::LSTM< InputDataType, OutputDataType >::InputParameter ( ) const [inline]
Get the input parameter.
Definition at line 113 of file lstm.hpp.
References mlpack::ann::LSTM< InputDataType, OutputDataType >::inputParameter.
InputDataType & mlpack::ann::LSTM< InputDataType, OutputDataType >::InputParameter ( ) [inline]
Modify the input parameter.
Definition at line 115 of file lstm.hpp.
References mlpack::ann::LSTM< InputDataType, OutputDataType >::inputParameter.
std::vector< LayerTypes > & mlpack::ann::LSTM< InputDataType, OutputDataType >::Model ( ) [inline]
Get the model modules.
Definition at line 133 of file lstm.hpp.
References mlpack::ann::LSTM< InputDataType, OutputDataType >::network, and mlpack::ann::LSTM< InputDataType, OutputDataType >::Serialize().
OutputDataType const & mlpack::ann::LSTM< InputDataType, OutputDataType >::OutputParameter ( ) const [inline]
Get the output parameter.
Definition at line 118 of file lstm.hpp.
References mlpack::ann::LSTM< InputDataType, OutputDataType >::outputParameter.
OutputDataType & mlpack::ann::LSTM< InputDataType, OutputDataType >::OutputParameter ( ) [inline]
Modify the output parameter.
Definition at line 120 of file lstm.hpp.
References mlpack::ann::LSTM< InputDataType, OutputDataType >::outputParameter.
OutputDataType const & mlpack::ann::LSTM< InputDataType, OutputDataType >::Parameters ( ) const [inline]
Get the parameters.
Definition at line 108 of file lstm.hpp.
References mlpack::ann::LSTM< InputDataType, OutputDataType >::weights.
OutputDataType & mlpack::ann::LSTM< InputDataType, OutputDataType >::Parameters ( ) [inline]
Modify the parameters.
Definition at line 110 of file lstm.hpp.
References mlpack::ann::LSTM< InputDataType, OutputDataType >::weights.
size_t mlpack::ann::LSTM< InputDataType, OutputDataType >::Rho ( ) const [inline]
Get the maximum number of steps to backpropagate through time (BPTT).
Definition at line 103 of file lstm.hpp.
References mlpack::ann::LSTM< InputDataType, OutputDataType >::rho.
size_t & mlpack::ann::LSTM< InputDataType, OutputDataType >::Rho ( ) [inline]
Modify the maximum number of steps to backpropagate through time (BPTT).
Definition at line 105 of file lstm.hpp.
References mlpack::ann::LSTM< InputDataType, OutputDataType >::rho.
void mlpack::ann::LSTM< InputDataType, OutputDataType >::Serialize ( Archive & ar, const unsigned int )
Serialize the layer.
Referenced by mlpack::ann::LSTM< InputDataType, OutputDataType >::Model().
OutputDataType mlpack::ann::LSTM< InputDataType, OutputDataType >::delta [private]
Locally-stored delta object.
Definition at line 222 of file lstm.hpp.
Referenced by mlpack::ann::LSTM< InputDataType, OutputDataType >::Delta().
bool mlpack::ann::LSTM< InputDataType, OutputDataType >::deterministic [private]
If true, dropout and scaling are disabled.
Definition at line 219 of file lstm.hpp.
Referenced by mlpack::ann::LSTM< InputDataType, OutputDataType >::Deterministic().
OutputDataType mlpack::ann::LSTM< InputDataType, OutputDataType >::gradient [private]
Locally-stored gradient object.
Definition at line 225 of file lstm.hpp.
Referenced by mlpack::ann::LSTM< InputDataType, OutputDataType >::Gradient().
InputDataType mlpack::ann::LSTM< InputDataType, OutputDataType >::inputParameter [private]
Locally-stored input parameter object.
Definition at line 228 of file lstm.hpp.
Referenced by mlpack::ann::LSTM< InputDataType, OutputDataType >::InputParameter().
std::vector< LayerTypes > mlpack::ann::LSTM< InputDataType, OutputDataType >::network [private]
Locally-stored list of network modules.
Definition at line 192 of file lstm.hpp.
Referenced by mlpack::ann::LSTM< InputDataType, OutputDataType >::Model().
OutputDataType mlpack::ann::LSTM< InputDataType, OutputDataType >::outputParameter [private]
Locally-stored output parameter object.
Definition at line 231 of file lstm.hpp.
Referenced by mlpack::ann::LSTM< InputDataType, OutputDataType >::OutputParameter().
size_t mlpack::ann::LSTM< InputDataType, OutputDataType >::rho [private]
Number of steps to backpropagate through time (BPTT).
Definition at line 150 of file lstm.hpp.
Referenced by mlpack::ann::LSTM< InputDataType, OutputDataType >::Rho().
OutputDataType mlpack::ann::LSTM< InputDataType, OutputDataType >::weights [private]
Locally-stored weight object.
Definition at line 153 of file lstm.hpp.
Referenced by mlpack::ann::LSTM< InputDataType, OutputDataType >::Parameters().