mlpack  master
weight_set_visitor.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_VISITOR_WEIGHT_SET_VISITOR_HPP
14 #define MLPACK_METHODS_ANN_VISITOR_WEIGHT_SET_VISITOR_HPP
15 
18 
19 #include <boost/variant.hpp>
20 
21 namespace mlpack {
22 namespace ann {
23 
27 class WeightSetVisitor : public boost::static_visitor<size_t>
28 {
29  public:
31  WeightSetVisitor(arma::mat&& weight, const size_t offset = 0);
32 
34  template<typename LayerType>
35  size_t operator()(LayerType* layer) const;
36 
37  private:
39  arma::mat&& weight;
40 
42  const size_t offset;
43 
46  template<typename T, typename P>
47  typename std::enable_if<
48  !HasParametersCheck<T, P&(T::*)()>::value &&
49  !HasModelCheck<T, std::vector<LayerTypes>&(T::*)()>::value, size_t>::type
50  LayerSize(T* layer, P&& input) const;
51 
53  template<typename T, typename P>
54  typename std::enable_if<
55  !HasParametersCheck<T, P&(T::*)()>::value &&
56  HasModelCheck<T, std::vector<LayerTypes>&(T::*)()>::value, size_t>::type
57  LayerSize(T* layer, P&& input) const;
58 
60  template<typename T, typename P>
61  typename std::enable_if<
62  HasParametersCheck<T, P&(T::*)()>::value &&
63  !HasModelCheck<T, std::vector<LayerTypes>&(T::*)()>::value, size_t>::type
64  LayerSize(T* layer, P&& input) const;
65 
68  template<typename T, typename P>
69  typename std::enable_if<
70  HasParametersCheck<T, P&(T::*)()>::value &&
71  HasModelCheck<T, std::vector<LayerTypes>&(T::*)()>::value, size_t>::type
72  LayerSize(T* layer, P&& input) const;
73 };
74 
75 } // namespace ann
76 } // namespace mlpack
77 
78 // Include implementation.
79 #include "weight_set_visitor_impl.hpp"
80 
81 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
std::enable_if< !HasParametersCheck< T, P &(T::*)()>::value &&!HasModelCheck< T, std::vector< LayerTypes > &(T::*)()>::value, size_t >::type LayerSize(T *layer, P &&input) const
Do not update the parameters if the module implements neither the Parameters() nor the Model() function...
arma::mat && weight
The parameters set.
WeightSetVisitor(arma::mat &&weight, const size_t offset=0)
Update the parameters given the parameters set and offset.
const size_t offset
The parameters offset.
size_t operator()(LayerType *layer) const
Update the parameters set.
WeightSetVisitor updates the module parameters given the parameters set.