mlpack  master
mlpack::ann::RecurrentAttention< InputDataType, OutputDataType > Class Template Reference

This class implements the Recurrent Model for Visual Attention, using a variety of possible layer implementations.

Public Member Functions

template<typename RNNModuleType , typename ActionModuleType >
 RecurrentAttention (const size_t outSize, const RNNModuleType &rnn, const ActionModuleType &action, const size_t rho)
 Create the RecurrentAttention object using the specified modules.
 
template<typename eT >
void Backward (const arma::Mat< eT > &&, arma::Mat< eT > &&gy, arma::Mat< eT > &&g)
 Ordinary feed backward pass of a neural network, propagating the error backwards through f using the results from the feed forward pass.
 
OutputDataType const & Delta () const
 Get the delta.
 
OutputDataType & Delta ()
 Modify the delta.
 
bool Deterministic () const
 The value of the deterministic parameter.
 
bool & Deterministic ()
 Modify the value of the deterministic parameter.
 
template<typename eT >
void Forward (arma::Mat< eT > &&input, arma::Mat< eT > &&output)
 Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
 
template<typename eT >
void Gradient (arma::Mat< eT > &&, arma::Mat< eT > &&, arma::Mat< eT > &&)
 
OutputDataType const & Gradient () const
 Get the gradient.
 
OutputDataType & Gradient ()
 Modify the gradient.
 
InputDataType const & InputParameter () const
 Get the input parameter.
 
InputDataType & InputParameter ()
 Modify the input parameter.
 
std::vector< LayerTypes > & Model ()
 Get the model modules.
 
OutputDataType const & OutputParameter () const
 Get the output parameter.
 
OutputDataType & OutputParameter ()
 Modify the output parameter.
 
OutputDataType const & Parameters () const
 Get the parameters.
 
OutputDataType & Parameters ()
 Modify the parameters.
 
template<typename Archive >
void Serialize (Archive &ar, const unsigned int)
 Serialize the layer.
 

Private Member Functions

void IntermediateGradient ()
 Calculate the gradient of the attention module.
 

Private Attributes

arma::mat actionDelta
 Locally-stored action delta.
 
arma::mat actionError
 Locally-stored action error parameter.
 
LayerTypes actionModule
 Locally-stored action module.
 
arma::mat attentionGradient
 Locally-stored attention gradient.
 
size_t backwardStep
 Locally-stored number of backward steps.
 
OutputDataType delta
 Locally-stored delta object.
 
DeltaVisitor deltaVisitor
 Locally-stored delta visitor.
 
bool deterministic
 If true, dropout and scaling are disabled.
 
std::vector< arma::mat > feedbackOutputParameter
 Locally-stored feedback output parameters.
 
size_t forwardStep
 Locally-stored number of forward steps.
 
OutputDataType gradient
 Locally-stored gradient object.
 
arma::mat initialInput
 Locally-stored initial action input.
 
LayerTypes initialModule
 Locally-stored initial module.
 
InputDataType inputParameter
 Locally-stored input parameter object.
 
arma::mat intermediateGradient
 Locally-stored intermediate gradient for the attention module.
 
LayerTypes mergeModule
 Locally-stored merge module.
 
std::vector< arma::mat > moduleOutputParameter
 List of all module output parameters for the backward pass (BPTT).
 
std::vector< LayerTypes > network
 Locally-stored model modules.
 
OutputDataType outputParameter
 Locally-stored output parameter object.
 
OutputParameterVisitor outputParameterVisitor
 Locally-stored output parameter visitor.
 
size_t outSize
 Locally-stored module output size.
 
OutputDataType parameters
 Locally-stored weight object.
 
arma::mat recurrentError
 Locally-stored recurrent error parameter.
 
LayerTypes recurrentModule
 Locally-stored recurrent module.
 
ResetVisitor resetVisitor
 Locally-stored reset visitor.
 
size_t rho
 Number of steps to backpropagate through time (BPTT).
 
arma::mat rnnDelta
 Locally-stored recurrent delta.
 
LayerTypes rnnModule
 Locally-stored recurrent neural network (RNN) module.
 
WeightSizeVisitor weightSizeVisitor
 Locally-stored weight size visitor.
 

Detailed Description

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
class mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >

This class implements the Recurrent Model for Visual Attention, using a variety of possible layer implementations.

For more information, see the following paper.

@article{MnihHGK14,
title={Recurrent Models of Visual Attention},
author={Volodymyr Mnih and Nicolas Heess and Alex Graves and Koray Kavukcuoglu},
journal={CoRR},
volume={abs/1406.6247},
year={2014}
}
Template Parameters
InputDataType: Type of the input data (arma::colvec, arma::mat, arma::sp_mat or arma::cube).
OutputDataType: Type of the output data (arma::colvec, arma::mat, arma::sp_mat or arma::cube).

Definition at line 76 of file layer_types.hpp.

Constructor & Destructor Documentation

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
template<typename RNNModuleType , typename ActionModuleType >
mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::RecurrentAttention ( const size_t  outSize,
const RNNModuleType &  rnn,
const ActionModuleType &  action,
const size_t  rho 
)

Create the RecurrentAttention object using the specified modules.

Parameters
outSize: The module output size.
rnn: The recurrent neural network module.
action: The action module.
rho: Maximum number of steps to backpropagate through time (BPTT).

Member Function Documentation

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
template<typename eT >
void mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::Backward ( const arma::Mat< eT > &&  ,
arma::Mat< eT > &&  gy,
arma::Mat< eT > &&  g 
)

Ordinary feed backward pass of a neural network, propagating the error gy backwards through f using the results from the feed forward pass.

Parameters
input: The propagated input activation.
gy: The backpropagated error.
g: The calculated gradient.
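Backward relies on backpropagation through time (BPTT) over the unrolled steps. The following is a generic BPTT sketch, not mlpack's implementation: it assumes a hypothetical scalar recurrence h_t = w * h_{t-1} + x_t with h_0 = 0 and the loss L = 0.5 * (h_rho - y)^2, stores every hidden state in the forward pass, and then walks the steps in reverse to accumulate dL/dw. The function name BpttGradient is illustrative only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Generic BPTT sketch (hypothetical, not mlpack code) for the scalar
// recurrence h_t = w * h_{t-1} + x_t with h_0 = 0 and the loss
// L = 0.5 * (h_rho - y)^2, where rho = xs.size(). Returns dL/dw.
double BpttGradient(double w, const std::vector<double>& xs, double y)
{
  assert(!xs.empty());

  // Forward pass: keep every hidden state, because step t of the backward
  // pass needs h_{t-1}.
  std::vector<double> h(xs.size() + 1, 0.0);
  for (std::size_t t = 1; t <= xs.size(); ++t)
    h[t] = w * h[t - 1] + xs[t - 1];

  // Backward pass: delta holds dL/dh_t, starting at the last step.
  double delta = h.back() - y;
  double grad = 0.0;
  for (std::size_t t = xs.size(); t >= 1; --t)
  {
    grad += delta * h[t - 1];  // Direct dependence of h_t on w is h_{t-1}.
    delta *= w;                // dL/dh_{t-1} = dL/dh_t * w.
  }
  return grad;
}
```

For w = 0.5, inputs {1, 2, 3} and target 0, the stored states are 1, 2.5 and 4.25, and the accumulated gradient is 12.75.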
template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
OutputDataType const& mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::Delta ( ) const
inline
template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
OutputDataType& mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::Delta ( )
inline
template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
bool mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::Deterministic ( ) const
inline

The value of the deterministic parameter.

Definition at line 112 of file recurrent_attention.hpp.

References mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::deterministic.

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
bool& mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::Deterministic ( )
inline

Modify the value of the deterministic parameter.

Definition at line 114 of file recurrent_attention.hpp.

References mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::deterministic.
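A common reason for such a flag in the Recurrent Model for Visual Attention is that glimpse locations are sampled stochastically during training but taken as the mean at test time. The sketch below illustrates that idea with a hypothetical LocationSampler class whose accessor pair mimics the Deterministic() getter and modifier documented here; the sampling behaviour is an assumption for illustration, not a description of what RecurrentAttention does with the flag.

```cpp
#include <cassert>
#include <random>

// Hypothetical sampler for a one-dimensional glimpse location (not mlpack
// code). In training mode it draws from N(mean, stdDev^2); in deterministic
// mode it returns the mean unchanged.
class LocationSampler
{
 public:
  explicit LocationSampler(double stdDevIn, unsigned int seed = 42)
      : stdDev(stdDevIn), generator(seed)
  {
    assert(stdDevIn > 0.0);
  }

  //! The value of the deterministic flag.
  bool Deterministic() const { return deterministic; }
  //! Modify the value of the deterministic flag.
  bool& Deterministic() { return deterministic; }

  double Sample(double mean)
  {
    if (deterministic)
      return mean;
    std::normal_distribution<double> dist(mean, stdDev);
    return dist(generator);
  }

 private:
  double stdDev;
  std::mt19937 generator;
  bool deterministic = false;
};
```

Returning a non-const reference from the modifier, as the documented bool& Deterministic() does, lets callers write sampler.Deterministic() = true without a separate setter.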

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
template<typename eT >
void mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::Forward ( arma::Mat< eT > &&  input,
arma::Mat< eT > &&  output 
)

Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.

Parameters
input: Input data used for evaluating the specified function.
output: Resulting output activation.
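The forward pass unrolls the model for rho steps. A minimal conceptual sketch, assuming toy modules written as std::function over std::vector<double>: the hypothetical action module maps the current recurrent state to a glimpse location, the hypothetical RNN module combines the input, that location and the previous state into a new state of size outSize, and the final state is returned as the output. The names UnrolledForward, ActionFn and RnnFn are illustrative, and the exact order in which mlpack feeds its action and RNN modules may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Vec = std::vector<double>;

// Hypothetical stand-ins for the two constructor modules.
using ActionFn = std::function<Vec(const Vec& state)>;
using RnnFn = std::function<Vec(const Vec& input, const Vec& location,
                                const Vec& state)>;

// Unroll the model for rho steps and return the final recurrent state,
// which plays the role of the layer output (of size outSize).
Vec UnrolledForward(const Vec& input, std::size_t outSize,
                    const RnnFn& rnn, const ActionFn& action, std::size_t rho)
{
  assert(rho > 0);
  Vec state(outSize, 0.0);  // Initial recurrent state.
  for (std::size_t step = 0; step < rho; ++step)
  {
    const Vec location = action(state);   // Where to look next.
    state = rnn(input, location, state);  // Integrate the new glimpse.
    assert(state.size() == outSize);
  }
  return state;
}
```

Keeping the modules as callables separates the unrolling logic from the concrete layer types, which is roughly the role of the rnn and action constructor arguments.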
template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
template<typename eT >
void mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::Gradient ( arma::Mat< eT > &&  ,
arma::Mat< eT > &&  ,
arma::Mat< eT > &&   
)
template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
OutputDataType const& mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::Gradient ( ) const
inline
template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
OutputDataType& mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::Gradient ( )
inline
template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
InputDataType const& mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::InputParameter ( ) const
inline
template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
InputDataType& mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::InputParameter ( )
inline

Modify the input parameter.

Definition at line 124 of file recurrent_attention.hpp.

References mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::inputParameter.

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
void mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::IntermediateGradient ( )
inline private
template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
std::vector<LayerTypes>& mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::Model ( )
inline

Get the model modules.

Definition at line 109 of file recurrent_attention.hpp.

References mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::network.
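Exposing the internal module list lets code outside the layer reach the wrapped modules, for example to push a setting such as the deterministic flag into every nested layer. The following sketch shows that pattern with a hypothetical Layer/Container hierarchy and a recursive SetDeterministic helper; it is an assumed illustration and does not use mlpack's visitor classes.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical layer interface (not mlpack's API). Leaf layers expose no
// child modules; container layers return their children from Model().
struct Layer
{
  virtual ~Layer() = default;
  virtual std::vector<Layer*> Model() { return {}; }

  bool deterministic = false;
};

struct Container : Layer
{
  std::vector<Layer*> Model() override
  {
    std::vector<Layer*> modules;
    for (const std::unique_ptr<Layer>& child : children)
      modules.push_back(child.get());
    return modules;
  }

  std::vector<std::unique_ptr<Layer>> children;
};

// Set the flag on a layer and recursively on every module it exposes.
void SetDeterministic(Layer& layer, const bool value)
{
  layer.deterministic = value;
  for (Layer* module : layer.Model())
    SetDeterministic(*module, value);
}
```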

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
OutputDataType const& mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::OutputParameter ( ) const
inline
template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
OutputDataType& mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::OutputParameter ( )
inline

Modify the output parameter.

Definition at line 129 of file recurrent_attention.hpp.

References mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::outputParameter.

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
OutputDataType const& mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::Parameters ( ) const
inline
template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
OutputDataType& mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::Parameters ( )
inline
template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
template<typename Archive >
void mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::Serialize ( Archive &  ar,
const unsigned int 
)

Member Data Documentation

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
arma::mat mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::actionDelta
private

Locally-stored action delta.

Definition at line 244 of file recurrent_attention.hpp.

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
arma::mat mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::actionError
private

Locally-stored action error parameter.

Definition at line 241 of file recurrent_attention.hpp.

Referenced by mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::IntermediateGradient().

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
LayerTypes mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::actionModule
private

Locally-stored action module.

Definition at line 181 of file recurrent_attention.hpp.

Referenced by mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::IntermediateGradient().

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
arma::mat mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::attentionGradient
private

Locally-stored attention gradient.

Definition at line 256 of file recurrent_attention.hpp.

Referenced by mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::IntermediateGradient().

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
size_t mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::backwardStep
private

Locally-stored number of backward steps.

Definition at line 190 of file recurrent_attention.hpp.

Referenced by mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::IntermediateGradient().

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
OutputDataType mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::delta
private

Locally-stored delta object.

Definition at line 226 of file recurrent_attention.hpp.

Referenced by mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::Delta().

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
DeltaVisitor mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::deltaVisitor
private

Locally-stored delta visitor.

Definition at line 214 of file recurrent_attention.hpp.

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
bool mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::deterministic
private

If true, dropout and scaling are disabled.

Definition at line 193 of file recurrent_attention.hpp.

Referenced by mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::Deterministic().

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
std::vector<arma::mat> mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::feedbackOutputParameter
private

Locally-stored feedback output parameters.

Definition at line 220 of file recurrent_attention.hpp.

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
size_t mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::forwardStep
private

Locally-stored number of forward steps.

Definition at line 187 of file recurrent_attention.hpp.

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
OutputDataType mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::gradient
private

Locally-stored gradient object.

Definition at line 229 of file recurrent_attention.hpp.

Referenced by mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::Gradient().

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
arma::mat mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::initialInput
private

Locally-stored initial action input.

Definition at line 250 of file recurrent_attention.hpp.

Referenced by mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::IntermediateGradient().

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
LayerTypes mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::initialModule
private

Locally-stored initial module.

Definition at line 199 of file recurrent_attention.hpp.

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
InputDataType mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::inputParameter
private

Locally-stored input parameter object.

Definition at line 232 of file recurrent_attention.hpp.

Referenced by mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::InputParameter().

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
arma::mat mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::intermediateGradient
private

Locally-stored intermediate gradient for the attention module.

Definition at line 259 of file recurrent_attention.hpp.

Referenced by mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::IntermediateGradient().

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
LayerTypes mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::mergeModule
private

Locally-stored merge module.

Definition at line 208 of file recurrent_attention.hpp.

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
std::vector<arma::mat> mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::moduleOutputParameter
private

List of all module output parameters for the backward pass (BPTT).

Definition at line 223 of file recurrent_attention.hpp.
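moduleOutputParameter, together with the forwardStep and backwardStep counters, is described as storage for the backward pass. A minimal sketch of that bookkeeping, assuming a hypothetical StepStore class that is not part of mlpack: each forward step appends its output, and each backward step reads the outputs back in reverse order, starting from the most recent step.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical BPTT bookkeeping (not mlpack code): outputs are saved in
// forward order and consumed in reverse order during the backward pass.
class StepStore
{
 public:
  void SaveForward(const std::vector<double>& output)
  {
    saved.push_back(output);
    ++forwardStep;
  }

  // Return the output of the step currently being backpropagated, starting
  // from the last forward step.
  const std::vector<double>& NextBackward()
  {
    assert(backwardStep < forwardStep);
    const std::size_t index = forwardStep - 1 - backwardStep;
    ++backwardStep;
    return saved[index];
  }

  // Clear the stored outputs and counters before the next sequence.
  void Reset()
  {
    saved.clear();
    forwardStep = 0;
    backwardStep = 0;
  }

 private:
  std::vector<std::vector<double>> saved;
  std::size_t forwardStep = 0;
  std::size_t backwardStep = 0;
};
```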

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
std::vector<LayerTypes> mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::network
private

Locally-stored model modules.

Definition at line 205 of file recurrent_attention.hpp.

Referenced by mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::Model().

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
OutputDataType mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::outputParameter
private

Locally-stored output parameter object.

Definition at line 235 of file recurrent_attention.hpp.

Referenced by mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::OutputParameter().

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
OutputParameterVisitor mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::outputParameterVisitor
private

Locally-stored output parameter visitor.

Definition at line 217 of file recurrent_attention.hpp.

Referenced by mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::IntermediateGradient().

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
size_t mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::outSize
private

Locally-stored module output size.

Definition at line 175 of file recurrent_attention.hpp.

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
OutputDataType mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::parameters
private

Locally-stored weight object.

Definition at line 196 of file recurrent_attention.hpp.

Referenced by mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::Parameters().

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
arma::mat mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::recurrentError
private

Locally-stored recurrent error parameter.

Definition at line 238 of file recurrent_attention.hpp.

Referenced by mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::IntermediateGradient().

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
LayerTypes mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::recurrentModule
private

Locally-stored recurrent module.

Definition at line 202 of file recurrent_attention.hpp.

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
ResetVisitor mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::resetVisitor
private

Locally-stored reset visitor.

Definition at line 253 of file recurrent_attention.hpp.

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
size_t mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::rho
private

Number of steps to backpropagate through time (BPTT).

Definition at line 184 of file recurrent_attention.hpp.

Referenced by mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::IntermediateGradient().

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
arma::mat mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::rnnDelta
private

Locally-stored recurrent delta.

Definition at line 247 of file recurrent_attention.hpp.

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
LayerTypes mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::rnnModule
private

Locally-stored recurrent neural network (RNN) module.

Definition at line 178 of file recurrent_attention.hpp.

Referenced by mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::IntermediateGradient().

template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
WeightSizeVisitor mlpack::ann::RecurrentAttention< InputDataType, OutputDataType >::weightSizeVisitor
private

Locally-stored weight size visitor.

Definition at line 211 of file recurrent_attention.hpp.

