mlpack  master
binary_space_tree.hpp
Go to the documentation of this file.
1 
11 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_HPP
12 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_HPP
13 
14 #include <mlpack/prereqs.hpp>
15 
16 #include "../statistic.hpp"
17 #include "midpoint_split.hpp"
18 
19 namespace mlpack {
20 namespace tree {
21 
47 template<typename MetricType,
48  typename StatisticType = EmptyStatistic,
49  typename MatType = arma::mat,
50  template<typename BoundMetricType, typename...> class BoundType =
52  template<typename SplitBoundType, typename SplitMatType>
53  class SplitType = MidpointSplit>
55 {
56  public:
58  typedef MatType Mat;
60  typedef typename MatType::elem_type ElemType;
61 
62  typedef SplitType<BoundType<MetricType>, MatType> Split;
63 
64  private:
73  size_t begin;
76  size_t count;
78  BoundType<MetricType> bound;
80  StatisticType stat;
82  ElemType parentDistance;
90  MatType* dataset;
91 
92  public:
95  template<typename RuleType>
97 
99  template<typename RuleType>
101 
102  template<typename RuleType>
104 
113  BinarySpaceTree(const MatType& data, const size_t maxLeafSize = 20);
114 
127  BinarySpaceTree(const MatType& data,
128  std::vector<size_t>& oldFromNew,
129  const size_t maxLeafSize = 20);
130 
146  BinarySpaceTree(const MatType& data,
147  std::vector<size_t>& oldFromNew,
148  std::vector<size_t>& newFromOld,
149  const size_t maxLeafSize = 20);
150 
160  BinarySpaceTree(MatType&& data,
161  const size_t maxLeafSize = 20);
162 
175  BinarySpaceTree(MatType&& data,
176  std::vector<size_t>& oldFromNew,
177  const size_t maxLeafSize = 20);
178 
194  BinarySpaceTree(MatType&& data,
195  std::vector<size_t>& oldFromNew,
196  std::vector<size_t>& newFromOld,
197  const size_t maxLeafSize = 20);
198 
212  const size_t begin,
213  const size_t count,
214  SplitType<BoundType<MetricType>, MatType>& splitter,
215  const size_t maxLeafSize = 20);
216 
237  const size_t begin,
238  const size_t count,
239  std::vector<size_t>& oldFromNew,
240  SplitType<BoundType<MetricType>, MatType>& splitter,
241  const size_t maxLeafSize = 20);
242 
265  const size_t begin,
266  const size_t count,
267  std::vector<size_t>& oldFromNew,
268  std::vector<size_t>& newFromOld,
269  SplitType<BoundType<MetricType>, MatType>& splitter,
270  const size_t maxLeafSize = 20);
271 
278  BinarySpaceTree(const BinarySpaceTree& other);
279 
285 
291  template<typename Archive>
293  Archive& ar,
295 
302 
304  const BoundType<MetricType>& Bound() const { return bound; }
306  BoundType<MetricType>& Bound() { return bound; }
307 
309  const StatisticType& Stat() const { return stat; }
311  StatisticType& Stat() { return stat; }
312 
314  bool IsLeaf() const;
315 
317  BinarySpaceTree* Left() const { return left; }
319  BinarySpaceTree*& Left() { return left; }
320 
322  BinarySpaceTree* Right() const { return right; }
324  BinarySpaceTree*& Right() { return right; }
325 
327  BinarySpaceTree* Parent() const { return parent; }
329  BinarySpaceTree*& Parent() { return parent; }
330 
332  const MatType& Dataset() const { return *dataset; }
334  MatType& Dataset() { return *dataset; }
335 
337  MetricType Metric() const { return MetricType(); }
338 
340  size_t NumChildren() const;
341 
346  template<typename VecType>
347  size_t GetNearestChild(
348  const VecType& point,
350 
355  template<typename VecType>
356  size_t GetFurthestChild(
357  const VecType& point,
359 
364  size_t GetNearestChild(const BinarySpaceTree& queryNode);
365 
370  size_t GetFurthestChild(const BinarySpaceTree& queryNode);
371 
376  ElemType FurthestPointDistance() const;
377 
385  ElemType FurthestDescendantDistance() const;
386 
388  ElemType MinimumBoundDistance() const;
389 
392  ElemType ParentDistance() const { return parentDistance; }
395  ElemType& ParentDistance() { return parentDistance; }
396 
403  BinarySpaceTree& Child(const size_t child) const;
404 
405  BinarySpaceTree*& ChildPtr(const size_t child)
406  { return (child == 0) ? left : right; }
407 
409  size_t NumPoints() const;
410 
416  size_t NumDescendants() const;
417 
425  size_t Descendant(const size_t index) const;
426 
435  size_t Point(const size_t index) const;
436 
438  ElemType MinDistance(const BinarySpaceTree& other) const
439  {
440  return bound.MinDistance(other.Bound());
441  }
442 
444  ElemType MaxDistance(const BinarySpaceTree& other) const
445  {
446  return bound.MaxDistance(other.Bound());
447  }
448 
451  {
452  return bound.RangeDistance(other.Bound());
453  }
454 
456  template<typename VecType>
457  ElemType MinDistance(const VecType& point,
459  const
460  {
461  return bound.MinDistance(point);
462  }
463 
465  template<typename VecType>
466  ElemType MaxDistance(const VecType& point,
468  const
469  {
470  return bound.MaxDistance(point);
471  }
472 
474  template<typename VecType>
476  RangeDistance(const VecType& point,
477  typename std::enable_if_t<IsVector<VecType>::value>* = 0) const
478  {
479  return bound.RangeDistance(point);
480  }
481 
483  size_t Begin() const { return begin; }
485  size_t& Begin() { return begin; }
486 
488  size_t Count() const { return count; }
490  size_t& Count() { return count; }
491 
493  void Center(arma::vec& center) const { bound.Center(center); }
494 
495  private:
502  void SplitNode(const size_t maxLeafSize,
503  SplitType<BoundType<MetricType>, MatType>& splitter);
504 
513  void SplitNode(std::vector<size_t>& oldFromNew,
514  const size_t maxLeafSize,
515  SplitType<BoundType<MetricType>, MatType>& splitter);
516 
523  template<typename BoundType2>
524  void UpdateBound(BoundType2& boundToUpdate);
525 
532  void UpdateBound(bound::HollowBallBound<MetricType>& boundToUpdate);
533 
534  protected:
541  BinarySpaceTree();
542 
544  friend class boost::serialization::access;
545 
546  public:
550  template<typename Archive>
551  void Serialize(Archive& ar, const unsigned int version);
552 };
553 
554 } // namespace tree
555 } // namespace mlpack
556 
557 // Include implementation.
558 #include "binary_space_tree_impl.hpp"
559 
560 // Include everything else, if necessary.
561 #include "../binary_space_tree.hpp"
562 
563 #endif
void UpdateBound(BoundType2 &boundToUpdate)
Update the bound of the current node.
BinarySpaceTree *& Parent()
Modify the parent of this node.
BinarySpaceTree * left
The left child node.
void Serialize(Archive &ar, const unsigned int version)
Serialize the tree.
MatType * dataset
The dataset.
A dual-tree traverser for binary space trees; see dual_tree_traverser.hpp.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
MatType Mat
So other classes can use TreeType::Mat.
MatType::elem_type ElemType
The type of element held in MatType.
const BoundType< MetricType > & Bound() const
Return the bound object for this node.
ElemType & ParentDistance()
Modify the distance from the center of this node to the center of the parent node.
BinarySpaceTree * right
The right child node.
bool IsLeaf() const
Return whether or not this node is a leaf (true if it has no children).
size_t NumDescendants() const
Return the number of descendants of this node.
ElemType parentDistance
The distance from the centroid of this node to the centroid of the parent.
The core includes that mlpack expects; standard C++ includes and Armadillo.
BinarySpaceTree * Left() const
Gets the left child of this node.
BoundType< MetricType > bound
The bound object for this node.
ElemType MinDistance(const BinarySpaceTree &other) const
Return the minimum distance to another node.
MatType & Dataset()
Modify the dataset which the tree is built on. Be careful!
A binary space partitioning tree, such as a KD-tree or a ball tree.
ElemType MaxDistance(const BinarySpaceTree &other) const
Return the maximum distance to another node.
BinarySpaceTree * Right() const
Gets the right child of this node.
void Center(arma::vec &center) const
Store the center of the bounding region in the given vector.
ElemType furthestDescendantDistance
The worst possible distance to the furthest descendant, cached to speed things up.
~BinarySpaceTree()
Deletes this node, deallocating the memory for the children and calling their destructors in turn.
A binary space partitioning tree node is split into its left and right children.
math::RangeType< ElemType > RangeDistance(const BinarySpaceTree &other) const
Return the minimum and maximum distance to another node.
ElemType ParentDistance() const
Return the distance from the center of this node to the center of the parent node.
BinarySpaceTree *& Right()
Modify the right child of this node.
Hyper-rectangle bound for an L-metric.
Definition: hrectbound.hpp:54
ElemType MinimumBoundDistance() const
Return the minimum distance from the center of the node to any bound edge.
MetricType Metric() const
Get the metric that the tree uses.
size_t GetFurthestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the furthest child node to the given query point.
size_t Point(const size_t index) const
Return the index (with reference to the dataset) of a particular point in this node.
const MatType & Dataset() const
Get the dataset which the tree is built on.
StatisticType & Stat()
Modify the statistic object for this node.
ElemType minimumBoundDistance
The minimum distance from the center to any edge of the bound.
A single-tree traverser for binary space trees; see single_tree_traverser.hpp for the implementation.
BinarySpaceTree()
A default constructor.
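BinarySpaceTree *& ChildPtr(const size_t child)
Modify the pointer to the specified child (0 gives the left child; any other index gives the right child).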
size_t Descendant(const size_t index) const
Return the index (with reference to the dataset) of a particular descendant of this node.
ElemType FurthestDescendantDistance() const
Return the furthest possible descendant distance.
ElemType MinDistance(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Return the minimum distance to another point.
size_t NumPoints() const
Return the number of points in this node (0 if not a leaf).
const StatisticType & Stat() const
Return the statistic object for this node.
size_t GetNearestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the nearest child node to the given query point.
BinarySpaceTree * Parent() const
Gets the parent of this node.
BinarySpaceTree & Child(const size_t child) const
Return the specified child (0 will be left, 1 will be right).
size_t Begin() const
Return the index of the beginning point of this subset.
math::RangeType< ElemType > RangeDistance(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Return the minimum and maximum distance to another point.
BoundType< MetricType > & Bound()
Modify the bound object for this node.
size_t & Count()
Modify the number of points in this subset.
void SplitNode(const size_t maxLeafSize, SplitType< BoundType< MetricType >, MatType > &splitter)
Splits the current node, assigning its left and right children recursively.
ElemType MaxDistance(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0) const
Return the maximum distance to another point.
size_t begin
The index of the first point in the dataset contained in this node (and its children).
size_t Count() const
Return the number of points in this subset.
size_t count
The number of points of the dataset contained in this node (and its children).
SplitType< BoundType< MetricType >, MatType > Split
StatisticType stat
Any extra data contained in the node.
BinarySpaceTree *& Left()
Modify the left child of this node.
Hollow ball bound encloses a set of points at a specific distance (radius) from a specific point (center).
ElemType FurthestPointDistance() const
Return the furthest distance to a point held in this node.
size_t & Begin()
Modify the index of the beginning point of this subset.
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35
size_t NumChildren() const
Return the number of children in this node.
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:59
BinarySpaceTree * parent
The parent node (NULL if this is the root of the tree).
Empty statistic if you are not interested in storing statistics in your tree.
Definition: statistic.hpp:24