mlpack  master
ns_model.hpp
Go to the documentation of this file.
1 
15 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
16 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
17 
23 #include <boost/variant.hpp>
24 #include "neighbor_search.hpp"
25 
26 namespace mlpack {
27 namespace neighbor {
28 
32 template<typename SortPolicy,
33  template<typename TreeMetricType,
34  typename TreeStatType,
35  typename TreeMatType> class TreeType>
36 using NSType = NeighborSearch<SortPolicy,
38  arma::mat,
39  TreeType,
41  NeighborSearchStat<SortPolicy>,
42  arma::mat>::template DualTreeTraverser>;
43 
//! Primary template: Name() returns the string "neighbor_search_model" for
//! any SortPolicy that has no more specific definition.  Presumably this name
//! identifies the model (e.g. when serialized) -- confirm against callers.
// NOTE(review): the struct declarator (listing line 45) was lost when this
// file was extracted -- the numbering jumps from 44 to 46 -- so this block
// does not compile as-is; restore the declarator from the upstream header.
// NOTE(review): returning `const std::string` by value prevents callers from
// moving the result; plain `std::string` would be preferable.
44 template<typename SortPolicy>
46 {
47  static const std::string Name() { return "neighbor_search_model"; }
48 };
49 
//! Explicit specialization whose Name() returns
//! "nearest_neighbor_search_model".
// NOTE(review): the specialization's declarator (listing line 51, naming the
// struct and its SortPolicy argument) is missing from this extracted copy;
// restore it from the upstream header.  Presumably the argument is the
// nearest-neighbor sort policy -- confirm.
50 template<>
52 {
53  static const std::string Name() { return "nearest_neighbor_search_model"; }
54 };
55 
//! Explicit specialization whose Name() returns
//! "furthest_neighbor_search_model".
// NOTE(review): the specialization's declarator (listing line 57, naming the
// struct and its SortPolicy argument) is missing from this extracted copy;
// restore it from the upstream header.  Presumably the argument is the
// furthest-neighbor sort policy -- confirm.
56 template<>
58 {
59  static const std::string Name() { return "furthest_neighbor_search_model"; }
60 };
61 
66 class MonoSearchVisitor : public boost::static_visitor<void>
67 {
68  private:
70  const size_t k;
72  arma::Mat<size_t>& neighbors;
74  arma::mat& distances;
75 
76  public:
78  template<typename NSType>
79  void operator()(NSType* ns) const;
80 
82  MonoSearchVisitor(const size_t k,
83  arma::Mat<size_t>& neighbors,
84  arma::mat& distances) :
85  k(k),
86  neighbors(neighbors),
87  distances(distances)
88  {};
89 };
90 
97 template<typename SortPolicy>
98 class BiSearchVisitor : public boost::static_visitor<void>
99 {
100  private:
102  const arma::mat& querySet;
104  const size_t k;
106  arma::Mat<size_t>& neighbors;
108  arma::mat& distances;
110  const size_t leafSize;
112  const double tau;
114  const double rho;
115 
117  template<typename NSType>
118  void SearchLeaf(NSType* ns) const;
119 
120  public:
122  template<template<typename TreeMetricType,
123  typename TreeStatType,
124  typename TreeMatType> class TreeType>
126 
128  template<template<typename TreeMetricType,
129  typename TreeStatType,
130  typename TreeMatType> class TreeType>
131  void operator()(NSTypeT<TreeType>* ns) const;
132 
134  void operator()(NSTypeT<tree::KDTree>* ns) const;
135 
137  void operator()(NSTypeT<tree::BallTree>* ns) const;
138 
140  void operator()(SpillKNN* ns) const;
141 
143  void operator()(NSTypeT<tree::Octree>* ns) const;
144 
146  BiSearchVisitor(const arma::mat& querySet,
147  const size_t k,
148  arma::Mat<size_t>& neighbors,
149  arma::mat& distances,
150  const size_t leafSize,
151  const double tau,
152  const double rho);
153 };
154 
161 template<typename SortPolicy>
162 class TrainVisitor : public boost::static_visitor<void>
163 {
164  private:
166  arma::mat&& referenceSet;
168  size_t leafSize;
170  const double tau;
172  const double rho;
173 
175  template<typename NSType>
176  void TrainLeaf(NSType* ns) const;
177 
178  public:
180  template<template<typename TreeMetricType,
181  typename TreeStatType,
182  typename TreeMatType> class TreeType>
184 
186  template<template<typename TreeMetricType,
187  typename TreeStatType,
188  typename TreeMatType> class TreeType>
189  void operator()(NSTypeT<TreeType>* ns) const;
190 
192  void operator()(NSTypeT<tree::KDTree>* ns) const;
193 
195  void operator()(NSTypeT<tree::BallTree>* ns) const;
196 
198  void operator()(SpillKNN* ns) const;
199 
201  void operator()(NSTypeT<tree::Octree>* ns) const;
202 
205  TrainVisitor(arma::mat&& referenceSet,
206  const size_t leafSize,
207  const double tau,
208  const double rho);
209 };
210 
214 class SearchModeVisitor : public boost::static_visitor<NeighborSearchMode&>
215 {
216  public:
218  template<typename NSType>
219  NeighborSearchMode& operator()(NSType* ns) const;
220 };
221 
225 class EpsilonVisitor : public boost::static_visitor<double&>
226 {
227  public:
229  template<typename NSType>
230  double& operator()(NSType *ns) const;
231 };
232 
236 class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&>
237 {
238  public:
240  template<typename NSType>
241  const arma::mat& operator()(NSType *ns) const;
242 };
243 
247 class DeleteVisitor : public boost::static_visitor<void>
248 {
249  public:
251  template<typename NSType>
252  void operator()(NSType *ns) const;
253 };
254 
//! The NSModel class provides an easy way to serialize a model and abstracts
//! away the different types of trees used for neighbor search.
//! @tparam SortPolicy Sort policy passed through to every NSType in nSearch.
// NOTE(review): this copy of the class was extracted from a Doxygen source
// listing.  Every line keeps its original listing number, most original
// comments are missing, and several code lines were dropped (see the gaps in
// the numbering), so the class does not compile as-is.  Restore it from the
// upstream ns_model.hpp rather than repairing this copy by hand.
265 template<typename SortPolicy>
266 class NSModel
267 {
268  public:
  // NOTE(review): listing lines 269-270 are missing; they held the
  // declarator of the TreeTypes enum (used to identify each accepted tree
  // type).
271  {
  // NOTE(review): listing lines 272-285 are missing: 14 enumerators precede
  // OCTREE.  KD_TREE is one of them (it is the constructor default below).
  // The nSearch variant lists 15 tree types, consistent with 15 enumerators.
286  OCTREE
287  };
288 
289  private:
  // NOTE(review): listing lines 290-291 are missing; they declared
  // `TreeTypes treeType` (the tree type considered for neighbor search),
  // which the TreeType() accessors below return.
292 
  //! For tree types that accept the maxLeafSize parameter.
294  size_t leafSize;
295 
  //! Overlapping size (for spill trees).
297  double tau;
  //! Balance threshold (for spill trees).
299  double rho;
300 
  // NOTE(review): listing lines 301-303 are missing; they declared
  // `bool randomBasis` (if true, random projections are used), which the
  // RandomBasis() accessors below return.
  //! The random projection matrix; only used if randomBasis is true.
304  arma::mat q;
305 
  //! nSearch holds an instance of the NeighborSearch class for the current
  //! treeType.  Each alternative is a raw pointer to one tree type's search
  //! object.
  // NOTE(review): listing lines 312-322 and 324-325 are missing.  The
  // complete member is boost::variant<...> nSearch, whose alternatives are
  // NSType<SortPolicy, X>* for X = KDTree, StandardCoverTree, RTree,
  // RStarTree, BallTree, XTree, HilbertRTree, RPlusTree, RPlusPlusTree,
  // VPTree, RPTree, MaxRPTree, then SpillKNN*, then X = UBTree, Octree.
311  boost::variant<NSType<SortPolicy, tree::KDTree>*,
323  SpillKNN*,
326 
327  public:
  //! Initialize the model with the given tree type (KD_TREE by default) and
  //! random-basis flag (false by default).
  // NOTE(review): this constructor is callable with one TreeTypes argument
  // and is not `explicit`, so it also acts as an implicit conversion from
  // TreeTypes to NSModel; consider marking it explicit.
336  NSModel(TreeTypes treeType = TreeTypes::KD_TREE, bool randomBasis = false);
337 
  //! Copy constructor.
343  NSModel(const NSModel& other);
344 
  //! Move constructor.
  // NOTE(review): consider `noexcept` here and on the move assignment below
  // if the implementation cannot throw, so containers can move instead of
  // copy.
350  NSModel(NSModel&& other);
351 
  //! Copy assignment operator.
357  NSModel& operator=(const NSModel& other);
358 
  //! Move assignment operator.
364  NSModel& operator=(NSModel&& other);
365 
  //! Destructor.  nSearch stores raw pointers, so presumably this releases
  //! the held search object -- confirm in the implementation.
367  ~NSModel();
368 
  //! Serialize the neighbor search model.  The class serialization version is
  //! set by the BOOST_TEMPLATE_CLASS_VERSION invocation after the namespace.
370  template<typename Archive>
371  void Serialize(Archive& ar, const unsigned int /* version */);
372 
  //! Expose the reference dataset (presumably read through the held search
  //! object -- confirm in the implementation).
374  const arma::mat& Dataset() const;
375 
  //! Expose the search mode (read-only and modifiable overloads).
377  NeighborSearchMode SearchMode() const;
378  NeighborSearchMode& SearchMode();
379 
  //! Expose epsilon (read-only and modifiable overloads).
381  double Epsilon() const;
382  double& Epsilon();
383 
  //! Expose leafSize.
385  size_t LeafSize() const { return leafSize; }
386  size_t& LeafSize() { return leafSize; }
387 
  //! Expose tau.
389  double Tau() const { return tau; }
390  double& Tau() { return tau; }
391 
  //! Expose rho.
393  double Rho() const { return rho; }
394  double& Rho() { return rho; }
395 
  //! Expose treeType.
397  TreeTypes TreeType() const { return treeType; }
398  TreeTypes& TreeType() { return treeType; }
399 
  //! Expose randomBasis.
401  bool RandomBasis() const { return randomBasis; }
402  bool& RandomBasis() { return randomBasis; }
403 
  //! Build the model on the given reference set (taken by rvalue reference)
  //! with the given leaf size, search mode and epsilon (0 by default).
405  void BuildModel(arma::mat&& referenceSet,
406  const size_t leafSize,
407  const NeighborSearchMode searchMode,
408  const double epsilon = 0);
409 
  //! Search for the k neighbors of each point in the given query set (taken
  //! by rvalue reference), writing results into neighbors and distances.
  //! Presumably this is the bichromatic search -- confirm.
411  void Search(arma::mat&& querySet,
412  const size_t k,
413  arma::Mat<size_t>& neighbors,
414  arma::mat& distances);
415 
  //! Search for k neighbors without a separate query set, writing results
  //! into neighbors and distances.  Presumably this is the monochromatic
  //! search of the reference set -- confirm.
417  void Search(const size_t k,
418  arma::Mat<size_t>& neighbors,
419  arma::mat& distances);
420 
  //! Return a name for the current tree type (presumably human-readable --
  //! confirm in the implementation).
422  std::string TreeName() const;
423 };
424 
425 } // namespace neighbor
426 } // namespace mlpack
427 
429 BOOST_TEMPLATE_CLASS_VERSION(template<typename SortPolicy>,
431 
432 // Include implementation.
433 #include "ns_model_impl.hpp"
434 
435 #endif
static const std::string Name()
Definition: ns_model.hpp:47
NeighborSearch< SortPolicy, metric::EuclideanDistance, arma::mat, TreeType, TreeType< metric::EuclideanDistance, NeighborSearchStat< SortPolicy >, arma::mat >::template DualTreeTraverser > NSType
Alias template for Euclidean neighbor search.
Definition: ns_model.hpp:42
EpsilonVisitor exposes the Epsilon method of the given NSType.
Definition: ns_model.hpp:225
arma::mat && referenceSet
The reference set to use for training.
Definition: ns_model.hpp:166
NeighborSearchMode
NeighborSearchMode represents the different neighbor search modes available.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
double Rho() const
Expose rho.
Definition: ns_model.hpp:393
const double rho
Balance threshold (for spill trees).
Definition: ns_model.hpp:172
LMetric< 2, true > EuclideanDistance
The Euclidean (L2) distance.
Definition: lmetric.hpp:112
const size_t k
Number of neighbors to search for.
Definition: ns_model.hpp:70
arma::mat q
This is the random projection matrix; only used if randomBasis is true.
Definition: ns_model.hpp:304
BOOST_TEMPLATE_CLASS_VERSION(template< typename SortPolicy >, mlpack::neighbor::NSModel< SortPolicy >, 1)
Set the serialization version of the NSModel class.
TreeTypes
Enum type to identify each accepted tree type.
Definition: ns_model.hpp:270
const size_t leafSize
The number of points in a leaf (for BinarySpaceTrees).
Definition: ns_model.hpp:110
ReferenceSetVisitor exposes the referenceSet of the given NSType.
Definition: ns_model.hpp:236
SearchModeVisitor exposes the SearchMode() method of the given NSType.
Definition: ns_model.hpp:214
The NeighborSearch class is a template class for performing distance-based neighbor searches...
arma::Mat< size_t > & neighbors
Result matrix for neighbors.
Definition: ns_model.hpp:72
arma::mat & distances
Result matrix for distances.
Definition: ns_model.hpp:74
const size_t k
The number of neighbors to search for.
Definition: ns_model.hpp:104
MonoSearchVisitor(const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
Construct the MonoSearchVisitor object with the given parameters.
Definition: ns_model.hpp:82
size_t leafSize
The leaf size, used only by BinarySpaceTree.
Definition: ns_model.hpp:168
const double tau
Overlapping size (for spill trees).
Definition: ns_model.hpp:170
This class implements the necessary methods for the SortPolicy template parameter of the NeighborSearch class.
This class implements the necessary methods for the SortPolicy template parameter of the NeighborSearch class.
BiSearchVisitor executes a bichromatic neighbor search on the given NSType.
Definition: ns_model.hpp:98
size_t leafSize
For tree types that accept the maxLeafSize parameter.
Definition: ns_model.hpp:294
arma::Mat< size_t > & neighbors
The result matrix for neighbors.
Definition: ns_model.hpp:106
double tau
Overlapping size (for spill trees).
Definition: ns_model.hpp:297
The NSModel class provides an easy way to serialize a model, abstracts away the different types of tr...
Definition: ns_model.hpp:266
bool RandomBasis() const
Expose randomBasis.
Definition: ns_model.hpp:401
TrainVisitor sets the reference set to a new reference set on the given NSType.
const double tau
Overlapping size (for spill trees).
Definition: ns_model.hpp:112
const double rho
Balance threshold (for spill trees).
Definition: ns_model.hpp:114
const arma::mat & querySet
The query set for the bichromatic search.
Definition: ns_model.hpp:102
boost::variant< NSType< SortPolicy, tree::KDTree > *, NSType< SortPolicy, tree::StandardCoverTree > *, NSType< SortPolicy, tree::RTree > *, NSType< SortPolicy, tree::RStarTree > *, NSType< SortPolicy, tree::BallTree > *, NSType< SortPolicy, tree::XTree > *, NSType< SortPolicy, tree::HilbertRTree > *, NSType< SortPolicy, tree::RPlusTree > *, NSType< SortPolicy, tree::RPlusPlusTree > *, NSType< SortPolicy, tree::VPTree > *, NSType< SortPolicy, tree::RPTree > *, NSType< SortPolicy, tree::MaxRPTree > *, SpillKNN *, NSType< SortPolicy, tree::UBTree > *, NSType< SortPolicy, tree::Octree > * > nSearch
nSearch holds an instance of the NeighborSearch class for the current treeType.
Definition: ns_model.hpp:325
size_t LeafSize() const
Expose leafSize.
Definition: ns_model.hpp:385
TreeTypes TreeType() const
Expose treeType.
Definition: ns_model.hpp:397
TreeTypes & TreeType()
Definition: ns_model.hpp:398
double Tau() const
Expose tau.
Definition: ns_model.hpp:389
MonoSearchVisitor executes a monochromatic neighbor search on the given NSType.
Definition: ns_model.hpp:66
arma::mat & distances
The result matrix for distances.
Definition: ns_model.hpp:108
TreeTypes treeType
Tree type considered for neighbor search.
Definition: ns_model.hpp:291
test cpp RESULT_VARIABLE MEX_RESULT_TRASH OUTPUT_VARIABLE MEX_OUTPUT ERROR_VARIABLE MEX_ERROR_TRASH string(REGEX MATCH"Warning: You are using"MEX_WARNING"${MEX_OUTPUT}") if(MEX_WARNING) string(REGEX REPLACE".*using [a-zA-Z]* version \"([0-9.]*)[^\"]*\".*""\\1"OTHER_COMPILER_VERSION"$
Definition: CMakeLists.txt:18
DeleteVisitor deletes the given NSType instance.
Definition: ns_model.hpp:247
bool randomBasis
If true, random projections are used.
Definition: ns_model.hpp:302
double rho
Balance threshold (for spill trees).
Definition: ns_model.hpp:299