mlpack  master
ub_tree_split.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_UB_TREE_SPLIT_HPP
14 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_UB_TREE_SPLIT_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "../address.hpp"
18 
19 namespace mlpack {
20 namespace tree {
21 
28 template<typename BoundType, typename MatType = arma::mat>
30 {
31  public:
33  typedef typename std::conditional<sizeof(typename MatType::elem_type) * CHAR_BIT <= 32,
34  uint32_t,
35  uint64_t>::type AddressElemType;
36 
38  struct SplitInfo
39  {
41  std::vector<std::pair<arma::Col<AddressElemType>, size_t>>* addresses;
42  };
43 
55  bool SplitNode(BoundType& bound,
56  MatType& data,
57  const size_t begin,
58  const size_t count,
59  SplitInfo& splitInfo);
60 
70  static size_t PerformSplit(MatType& data,
71  const size_t begin,
72  const size_t count,
73  const SplitInfo& splitInfo);
74 
87  static size_t PerformSplit(MatType& data,
88  const size_t begin,
89  const size_t count,
90  const SplitInfo& splitInfo,
91  std::vector<size_t>& oldFromNew);
92 
93  private:
95  std::vector<std::pair<arma::Col<AddressElemType>, size_t>> addresses;
96 
102  void InitializeAddresses(const MatType& data);
103 
105  static bool ComparePair(
106  const std::pair<arma::Col<AddressElemType>, size_t>& p1,
107  const std::pair<arma::Col<AddressElemType>, size_t>& p2)
108  {
109  return bound::addr::CompareAddresses(p1.first, p2.first) < 0;
110  }
111 };
112 
113 } // namespace tree
114 } // namespace mlpack
115 
116 // Include implementation.
117 #include "ub_tree_split_impl.hpp"
118 
119 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
The core includes that mlpack expects; standard C++ includes and Armadillo.
size_t PerformSplit(MatType &data, const size_t begin, const size_t count, const typename SplitType::SplitInfo &splitInfo)
This function implements the default split behavior, i.e. ...
Split a node into two parts according to the median address of points contained in the node...
int CompareAddresses(const AddressType1 &addr1, const AddressType2 &addr2)
Compare two addresses.
Definition: address.hpp:233