mlpack  master
cover_tree.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
13 #define MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
14 
15 #include <mlpack/prereqs.hpp>
17 
18 #include "../statistic.hpp"
19 #include "first_point_is_root.hpp"
20 
21 namespace mlpack {
22 namespace tree {
23 
95 template<typename MetricType = metric::LMetric<2, true>,
96  typename StatisticType = EmptyStatistic,
97  typename MatType = arma::mat,
98  typename RootPointPolicy = FirstPointIsRoot>
99 class CoverTree
100 {
101  public:
103  typedef MatType Mat;
105  typedef typename MatType::elem_type ElemType;
106 
117  CoverTree(const MatType& dataset,
118  const ElemType base = 2.0,
119  MetricType* metric = NULL);
120 
130  CoverTree(const MatType& dataset,
131  MetricType& metric,
132  const ElemType base = 2.0);
133 
141  CoverTree(MatType&& dataset,
142  const ElemType base = 2.0);
143 
152  CoverTree(MatType&& dataset,
153  MetricType& metric,
154  const ElemType base = 2.0);
155 
// Constructor for a single node, given the dataset, expansion base, the index
// of the point this node holds, its scale, its parent node and the distance
// to that parent.  `indices` and `distances` are working buffers passed by
// non-const reference, and `farSetSize` / `usedSetSize` are in/out counters.
// NOTE(review): presumably used internally while recursively building the
// tree -- confirm against cover_tree_impl.hpp.
//
// `metric` is a reference and therefore has no default value: a reference
// cannot be bound to NULL, so a `= NULL` default here could never be used
// (any call omitting the argument would fail to compile).  Callers must
// always supply an instantiated metric.
187  CoverTree(const MatType& dataset,
188  const ElemType base,
189  const size_t pointIndex,
190  const int scale,
191  CoverTree* parent,
192  const ElemType parentDistance,
193  arma::Col<size_t>& indices,
194  arma::vec& distances,
195  size_t nearSetSize,
196  size_t& farSetSize,
197  size_t& usedSetSize,
198  MetricType& metric);
199 
216  CoverTree(const MatType& dataset,
217  const ElemType base,
218  const size_t pointIndex,
219  const int scale,
220  CoverTree* parent,
221  const ElemType parentDistance,
222  const ElemType furthestDescendantDistance,
223  MetricType* metric = NULL);
224 
231  CoverTree(const CoverTree& other);
232 
239  CoverTree(CoverTree&& other);
240 
244  template<typename Archive>
245  CoverTree(
246  Archive& ar,
248 
252  ~CoverTree();
253 
256  template<typename RuleType>
258 
260  template<typename RuleType>
262 
263  template<typename RuleType>
265 
267  const MatType& Dataset() const { return *dataset; }
268 
270  size_t Point() const { return point; }
272  size_t Point(const size_t) const { return point; }
273 
274  bool IsLeaf() const { return (children.size() == 0); }
275  size_t NumPoints() const { return 1; }
276 
278  const CoverTree& Child(const size_t index) const { return *children[index]; }
280  CoverTree& Child(const size_t index) { return *children[index]; }
281 
282  CoverTree*& ChildPtr(const size_t index) { return children[index]; }
283 
285  size_t NumChildren() const { return children.size(); }
286 
288  const std::vector<CoverTree*>& Children() const { return children; }
290  std::vector<CoverTree*>& Children() { return children; }
291 
293  size_t NumDescendants() const;
294 
296  size_t Descendant(const size_t index) const;
297 
299  int Scale() const { return scale; }
301  int& Scale() { return scale; }
302 
304  ElemType Base() const { return base; }
306  ElemType& Base() { return base; }
307 
309  const StatisticType& Stat() const { return stat; }
311  StatisticType& Stat() { return stat; }
312 
317  template<typename VecType>
318  size_t GetNearestChild(
319  const VecType& point,
321 
326  template<typename VecType>
327  size_t GetFurthestChild(
328  const VecType& point,
330 
335  size_t GetNearestChild(const CoverTree& queryNode);
336 
341  size_t GetFurthestChild(const CoverTree& queryNode);
342 
344  ElemType MinDistance(const CoverTree& other) const;
345 
348  ElemType MinDistance(const CoverTree& other, const ElemType distance) const;
349 
351  ElemType MinDistance(const arma::vec& other) const;
352 
355  ElemType MinDistance(const arma::vec& other, const ElemType distance) const;
356 
358  ElemType MaxDistance(const CoverTree& other) const;
359 
362  ElemType MaxDistance(const CoverTree& other, const ElemType distance) const;
363 
365  ElemType MaxDistance(const arma::vec& other) const;
366 
369  ElemType MaxDistance(const arma::vec& other, const ElemType distance) const;
370 
373 
377  const ElemType distance) const;
378 
380  math::RangeType<ElemType> RangeDistance(const arma::vec& other) const;
381 
384  math::RangeType<ElemType> RangeDistance(const arma::vec& other,
385  const ElemType distance) const;
386 
388  CoverTree* Parent() const { return parent; }
390  CoverTree*& Parent() { return parent; }
391 
393  ElemType ParentDistance() const { return parentDistance; }
395  ElemType& ParentDistance() { return parentDistance; }
396 
398  ElemType FurthestPointDistance() const { return 0.0; }
399 
401  ElemType FurthestDescendantDistance() const
402  { return furthestDescendantDistance; }
406 
410 
412  void Center(arma::vec& center) const
413  {
414  center = arma::vec(dataset->col(point));
415  }
416 
418  MetricType& Metric() const { return *metric; }
419 
420  private:
422  const MatType* dataset;
424  size_t point;
426  std::vector<CoverTree*> children;
428  int scale;
430  ElemType base;
432  StatisticType stat;
438  ElemType parentDistance;
446  MetricType* metric;
447 
451  void CreateChildren(arma::Col<size_t>& indices,
452  arma::vec& distances,
453  size_t nearSetSize,
454  size_t& farSetSize,
455  size_t& usedSetSize);
456 
468  void ComputeDistances(const size_t pointIndex,
469  const arma::Col<size_t>& indices,
470  arma::vec& distances,
471  const size_t pointSetSize);
486  size_t SplitNearFar(arma::Col<size_t>& indices,
487  arma::vec& distances,
488  const ElemType bound,
489  const size_t pointSetSize);
490 
510  size_t SortPointSet(arma::Col<size_t>& indices,
511  arma::vec& distances,
512  const size_t childFarSetSize,
513  const size_t childUsedSetSize,
514  const size_t farSetSize);
515 
516  void MoveToUsedSet(arma::Col<size_t>& indices,
517  arma::vec& distances,
518  size_t& nearSetSize,
519  size_t& farSetSize,
520  size_t& usedSetSize,
521  arma::Col<size_t>& childIndices,
522  const size_t childFarSetSize,
523  const size_t childUsedSetSize);
524  size_t PruneFarSet(arma::Col<size_t>& indices,
525  arma::vec& distances,
526  const ElemType bound,
527  const size_t nearSetSize,
528  const size_t pointSetSize);
529 
534  void RemoveNewImplicitNodes();
535 
536  protected:
543  CoverTree();
544 
546  friend class boost::serialization::access;
547 
548  public:
552  template<typename Archive>
553  void Serialize(Archive& ar, const unsigned int /* version */);
554 
555  size_t DistanceComps() const { return distanceComps; }
556  size_t& DistanceComps() { return distanceComps; }
557 
558  private:
560 };
561 
562 } // namespace tree
563 } // namespace mlpack
564 
565 // Include implementation.
566 #include "cover_tree_impl.hpp"
567 
568 // Include the rest of the pieces, if necessary.
569 #include "../cover_tree.hpp"
570 
571 #endif
ElemType MaxDistance(const CoverTree &other) const
Return the maximum distance to another node.
void Center(arma::vec &center) const
Get the center of the node and store it in the given vector.
Definition: cover_tree.hpp:412
size_t point
Index of the point in the matrix which this node represents.
Definition: cover_tree.hpp:424
ElemType & FurthestDescendantDistance()
Modify the distance from the center of the node to the furthest descendant.
Definition: cover_tree.hpp:405
const MatType * dataset
Reference to the matrix which this tree is built on.
Definition: cover_tree.hpp:422
const std::vector< CoverTree * > & Children() const
Get the children.
Definition: cover_tree.hpp:288
CoverTree()
A default constructor.
StatisticType & Stat()
Modify the statistic for this node.
Definition: cover_tree.hpp:311
std::vector< CoverTree * > & Children()
Modify the children manually (maybe not a great idea).
Definition: cover_tree.hpp:290
size_t SortPointSet(arma::Col< size_t > &indices, arma::vec &distances, const size_t childFarSetSize, const size_t childUsedSetSize, const size_t farSetSize)
Assuming that the list of indices and distances is sorted as [ childFarSet | childUsedSet | farSet | ...
size_t Point(const size_t) const
For compatibility with other trees; the argument is ignored.
Definition: cover_tree.hpp:272
MetricType * metric
The metric used for this tree.
Definition: cover_tree.hpp:446
size_t Descendant(const size_t index) const
Get the index of a particular descendant point.
A dual-tree cover tree traverser; see dual_tree_traverser.hpp.
Definition: cover_tree.hpp:261
StatisticType stat
The instantiated statistic.
Definition: cover_tree.hpp:432
const MatType & Dataset() const
Get a reference to the dataset.
Definition: cover_tree.hpp:267
void RemoveNewImplicitNodes()
Take a look at the last child (the most recently created one) and remove any implicit nodes that have...
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
size_t NumPoints() const
Definition: cover_tree.hpp:275
The core includes that mlpack expects; standard C++ includes and Armadillo.
MatType::elem_type ElemType
The type held by the matrix type.
Definition: cover_tree.hpp:105
math::RangeType< ElemType > RangeDistance(const CoverTree &other) const
Return the minimum and maximum distance to another node.
size_t GetNearestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the nearest child node to the given query point.
ElemType FurthestPointDistance() const
Get the distance to the furthest point. This is always 0 for cover trees.
Definition: cover_tree.hpp:398
size_t NumDescendants() const
Get the number of descendant points.
ElemType & ParentDistance()
Modify the distance to the parent.
Definition: cover_tree.hpp:395
void ComputeDistances(const size_t pointIndex, const arma::Col< size_t > &indices, arma::vec &distances, const size_t pointSetSize)
Fill the vector of distances with the distances between the point specified by pointIndex and each po...
MetricType & Metric() const
Get the instantiated metric.
Definition: cover_tree.hpp:418
ElemType base
The base used to construct the tree.
Definition: cover_tree.hpp:430
ElemType MinDistance(const CoverTree &other) const
Return the minimum distance to another node.
MatType Mat
So that other classes can access the matrix type.
Definition: cover_tree.hpp:103
ElemType MinimumBoundDistance() const
Get the minimum distance from the center to any bound edge (this is the same as furthestDescendantDis...
Definition: cover_tree.hpp:409
void Serialize(Archive &ar, const unsigned int)
Serialize the tree.
ElemType Base() const
Get the base.
Definition: cover_tree.hpp:304
CoverTree *& Parent()
Modify the parent node.
Definition: cover_tree.hpp:390
size_t NumChildren() const
Get the number of children.
Definition: cover_tree.hpp:285
int & Scale()
Modify the scale of this node. Be careful...
Definition: cover_tree.hpp:301
const StatisticType & Stat() const
Get the statistic for this node.
Definition: cover_tree.hpp:309
CoverTree * parent
The parent node (NULL if this is the root of the tree).
Definition: cover_tree.hpp:436
ElemType parentDistance
Distance to the parent.
Definition: cover_tree.hpp:438
A single-tree cover tree traverser; see single_tree_traverser.hpp for implementation.
Definition: cover_tree.hpp:257
int scale
Scale level of the node.
Definition: cover_tree.hpp:428
size_t GetFurthestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the furthest child node to the given query point.
std::vector< CoverTree * > children
The list of children; the first is the self-child.
Definition: cover_tree.hpp:426
ElemType ParentDistance() const
Get the distance to the parent.
Definition: cover_tree.hpp:393
const CoverTree & Child(const size_t index) const
Get a particular child node.
Definition: cover_tree.hpp:278
size_t Point() const
Get the index of the point which this node represents.
Definition: cover_tree.hpp:270
ElemType furthestDescendantDistance
Distance to the furthest descendant.
Definition: cover_tree.hpp:440
Definition of the Range class, which represents a simple range with a lower and upper bound...
size_t SplitNearFar(arma::Col< size_t > &indices, arma::vec &distances, const ElemType bound, const size_t pointSetSize)
Split the given indices and distances into a near and a far set, returning the number of points in th...
size_t numDescendants
The number of descendant points.
Definition: cover_tree.hpp:434
bool localDataset
If true, we own the dataset and need to destroy it in the destructor.
Definition: cover_tree.hpp:444
size_t DistanceComps() const
Definition: cover_tree.hpp:555
size_t PruneFarSet(arma::Col< size_t > &indices, arma::vec &distances, const ElemType bound, const size_t nearSetSize, const size_t pointSetSize)
ElemType & Base()
Modify the base; don't do this, you'll break everything.
Definition: cover_tree.hpp:306
A cover tree is a tree specifically designed to speed up nearest-neighbor computation in high-dimensi...
Definition: cover_tree.hpp:99
void CreateChildren(arma::Col< size_t > &indices, arma::vec &distances, size_t nearSetSize, size_t &farSetSize, size_t &usedSetSize)
Create the children for this node.
CoverTree & Child(const size_t index)
Modify a particular child node.
Definition: cover_tree.hpp:280
CoverTree *& ChildPtr(const size_t index)
Definition: cover_tree.hpp:282
CoverTree * Parent() const
Get the parent node.
Definition: cover_tree.hpp:388
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35
ElemType FurthestDescendantDistance() const
Get the distance from the center of the node to the furthest descendant.
Definition: cover_tree.hpp:401
~CoverTree()
Delete this cover tree node and its children.
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:59
bool localMetric
Whether or not we need to destroy the metric in the destructor.
Definition: cover_tree.hpp:442
int Scale() const
Get the scale of this node.
Definition: cover_tree.hpp:299
void MoveToUsedSet(arma::Col< size_t > &indices, arma::vec &distances, size_t &nearSetSize, size_t &farSetSize, size_t &usedSetSize, arma::Col< size_t > &childIndices, const size_t childFarSetSize, const size_t childUsedSetSize)