mlpack  master
projection_vector.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_TREE_SPILL_TREE_PROJECTION_VECTOR_HPP
13 #define MLPACK_CORE_TREE_SPILL_TREE_PROJECTION_VECTOR_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 #include "../bounds.hpp"
17 namespace mlpack {
18 namespace tree {
19 
25 {
27  size_t dim;
28 
29  public:
35  AxisParallelProjVector(size_t dim = 0) :
36  dim(dim)
37  {};
38 
44  template<typename VecType>
45  double Project(const VecType& point,
46  typename std::enable_if_t<IsVector<VecType>::value>* = 0) const
47  {
48  return point[dim];
49  };
50 
57  template<typename MetricType, typename ElemType>
59  const bound::HRectBound<MetricType, ElemType>& bound) const
60  {
61  return bound[dim];
62  };
63 
70  template<typename MetricType, typename VecType>
72  const bound::BallBound<MetricType, VecType>& bound) const
73  {
74  return bound[dim];
75  };
76 
80  template<typename Archive>
81  void Serialize(Archive& ar, const unsigned int /* version */)
82  {
83  ar & data::CreateNVP(dim, "dim");
84  };
85 };
86 
92 {
94  arma::vec projVect;
95 
96  public:
101  projVect()
102  {};
103 
109  ProjVector(const arma::vec& vect) :
110  projVect(arma::normalise(vect))
111  {};
112 
118  template<typename VecType>
119  double Project(const VecType& point,
120  typename std::enable_if_t<IsVector<VecType>::value>* = 0) const
121  {
122  return arma::dot(point, projVect);
123  };
124 
131  template<typename MetricType, typename VecType>
133  const bound::BallBound<MetricType, VecType>& bound) const
134  {
135  typedef typename VecType::elem_type ElemType;
136  const double center = Project(bound.Center());
137  const ElemType radius = bound.Radius();
138  return math::RangeType<ElemType>(center - radius, center + radius);
139  };
140 
144  template<typename Archive>
145  void Serialize(Archive& ar, const unsigned int /* version */)
146  {
147  ar & data::CreateNVP(projVect, "projVect");
148  };
149 };
150 
151 } // namespace tree
152 } // namespace mlpack
153 
154 #endif
double Project(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Project the given point on the projection vector.
AxisParallelProjVector defines an axis-parallel projection vector.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
arma::vec projVect
Projection vector.
ElemType Radius() const
Get the radius of the ball.
Definition: ballbound.hpp:89
The core includes that mlpack expects; standard C++ includes and Armadillo.
math::RangeType< typename VecType::elem_type > Project(const bound::BallBound< MetricType, VecType > &bound) const
Project the given ball bound on the projection vector.
FirstShim< T > CreateNVP(T &t, const std::string &name, typename std::enable_if_t< HasSerialize< T >::value > *=0)
Call this function to produce a name-value pair; this is similar to BOOST_SERIALIZATION_NVP(), but should be used for types that have a Serialize() function (or contain a type that has a Serialize() function) instead of a serialize() function.
size_t dim
Dimension considered.
ProjVector defines a general projection vector (not necessarily axis-parallel).
void Serialize(Archive &ar, const unsigned int)
Serialization.
Ball bound encloses a set of points within a specific distance (radius) of a specific point (center).
Definition: ballbound.hpp:32
ProjVector(const arma::vec &vect)
Create the projection vector based on the specified vector.
void Serialize(Archive &ar, const unsigned int)
Serialization.
Hyper-rectangle bound for an L-metric.
Definition: hrectbound.hpp:54
ProjVector()
Empty Constructor.
AxisParallelProjVector(size_t dim=0)
Create the projection vector based on the specified dimension.
math::RangeType< typename VecType::elem_type > Project(const bound::BallBound< MetricType, VecType > &bound) const
Project the given ball bound on the projection vector.
math::RangeType< ElemType > Project(const bound::HRectBound< MetricType, ElemType > &bound) const
Project the given hrect bound on the projection vector.
const VecType & Center() const
Get the center point of the ball.
Definition: ballbound.hpp:94
double Project(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Project the given point on the projection vector.
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:59