mlpack  master
dtree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DET_DTREE_HPP
14 #define MLPACK_METHODS_DET_DTREE_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace det {
20 
44 template <typename MatType,
45  typename TagType = int>
46 class DTree
47 {
48  public:
52  typedef typename MatType::elem_type ElemType;
53  typedef typename MatType::vec_type VecType;
54  typedef typename arma::Col<ElemType> StatType;
55 
59  DTree();
60 
66  DTree(const DTree& obj);
67 
73  DTree& operator=(const DTree& obj);
74 
80  DTree(DTree&& obj);
81 
87  DTree& operator=(DTree&& obj);
88 
97  DTree(const StatType& maxVals,
98  const StatType& minVals,
99  const size_t totalPoints);
100 
109  DTree(MatType& data);
110 
123  DTree(const StatType& maxVals,
124  const StatType& minVals,
125  const size_t start,
126  const size_t end,
127  const double logNegError);
128 
140  DTree(const StatType& maxVals,
141  const StatType& minVals,
142  const size_t totalPoints,
143  const size_t start,
144  const size_t end);
145 
147  ~DTree();
148 
159  double Grow(MatType& data,
160  arma::Col<size_t>& oldFromNew,
161  const bool useVolReg = false,
162  const size_t maxLeafSize = 10,
163  const size_t minLeafSize = 5);
164 
173  double PruneAndUpdate(const double oldAlpha,
174  const size_t points,
175  const bool useVolReg = false);
176 
182  double ComputeValue(const VecType& query) const;
183 
191  TagType TagTree(const TagType& tag = 0);
192 
199  TagType FindBucket(const VecType& query) const;
200 
206  void ComputeVariableImportance(arma::vec& importances) const;
207 
214  double LogNegativeError(const size_t totalPoints) const;
215 
219  bool WithinRange(const VecType& query) const;
220 
221  private:
222  // The indices in the complete set of points
223  // (after all forms of swapping in the original data
224  // matrix to align all the points in a node
225  // consecutively in the matrix. The 'old_from_new' array
226  // maps the points back to their original indices.
227 
230  size_t start;
233  size_t end;
234 
236  StatType maxVals;
238  StatType minVals;
239 
241  size_t splitDim;
242 
244  ElemType splitValue;
245 
247  double logNegError;
248 
251 
254 
256  bool root;
257 
259  double ratio;
260 
262  double logVolume;
263 
265  TagType bucketTag;
266 
268  double alphaUpper;
269 
274 
275  public:
277  size_t Start() const { return start; }
279  size_t End() const { return end; }
281  size_t SplitDim() const { return splitDim; }
283  ElemType SplitValue() const { return splitValue; }
285  double LogNegError() const { return logNegError; }
289  size_t SubtreeLeaves() const { return subtreeLeaves; }
292  double Ratio() const { return ratio; }
294  double LogVolume() const { return logVolume; }
296  DTree* Left() const { return left; }
298  DTree* Right() const { return right; }
300  bool Root() const { return root; }
302  double AlphaUpper() const { return alphaUpper; }
304  TagType BucketTag() const { return subtreeLeaves == 1 ? bucketTag : -1; }
305 
307  const StatType& MaxVals() const { return maxVals; }
308 
310  const StatType& MinVals() const { return minVals; }
311 
315  template<typename Archive>
316  void Serialize(Archive& ar, const unsigned int /* version */);
317 
318  private:
319 
320  // Utility methods.
321 
325  bool FindSplit(const MatType& data,
326  size_t& splitDim,
327  ElemType& splitValue,
328  double& leftError,
329  double& rightError,
330  const size_t minLeafSize = 5) const;
331 
335  size_t SplitData(MatType& data,
336  const size_t splitDim,
337  const ElemType splitValue,
338  arma::Col<size_t>& oldFromNew) const;
339 
340 };
341 
342 } // namespace det
343 } // namespace mlpack
344 
345 #include "dtree_impl.hpp"
346 
347 #endif // MLPACK_METHODS_DET_DTREE_HPP
bool WithinRange(const VecType &query) const
Return whether a query point is within the range of this node.
double LogNegativeError(const size_t totalPoints) const
Compute the log-negative-error for this point, given the total number of points in the dataset...
arma::Col< ElemType > StatType
Definition: dtree.hpp:54
DTree * Right() const
Return the right child.
Definition: dtree.hpp:298
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
double Ratio() const
Return the ratio of points in this node to the points in the whole dataset.
Definition: dtree.hpp:292
const StatType & MaxVals() const
Return the maximum values.
Definition: dtree.hpp:307
size_t End() const
Return the first index of a point not contained in this node.
Definition: dtree.hpp:279
bool FindSplit(const MatType &data, size_t &splitDim, ElemType &splitValue, double &leftError, double &rightError, const size_t minLeafSize=5) const
Find the dimension to split on.
double SubtreeLeavesLogNegError() const
Return the log negative error of all descendants of this node.
Definition: dtree.hpp:287
double PruneAndUpdate(const double oldAlpha, const size_t points, const bool useVolReg=false)
Perform alpha pruning on a tree.
bool Root() const
Return whether or not this is the root of the tree.
Definition: dtree.hpp:300
double AlphaUpper() const
Return the upper part of the alpha sum.
Definition: dtree.hpp:302
The core includes that mlpack expects; standard C++ includes and Armadillo.
size_t splitDim
The splitting dimension for this node.
Definition: dtree.hpp:241
size_t end
The index of the last point in the dataset contained in this node (and its children).
Definition: dtree.hpp:233
double ratio
Ratio of the number of points in the node to the total number of points.
Definition: dtree.hpp:259
size_t SubtreeLeaves() const
Return the number of leaves which are descendants of this node.
Definition: dtree.hpp:289
DTree * right
The right child.
Definition: dtree.hpp:273
double logNegError
log-negative-L2-error of the node.
Definition: dtree.hpp:247
double ComputeValue(const VecType &query) const
Compute the logarithm of the density estimate of a given query point.
const StatType & MinVals() const
Return the minimum values.
Definition: dtree.hpp:310
double LogVolume() const
Return the logarithm of the volume of this node.
Definition: dtree.hpp:294
MatType::elem_type ElemType
The actual, underlying type we're working with.
Definition: dtree.hpp:52
TagType TagTree(const TagType &tag=0)
Index the buckets for possible usage later; this results in every leaf in the tree having a specific ...
size_t Start() const
Return the starting index of points contained in this node.
Definition: dtree.hpp:277
double LogNegError() const
Return the log negative error of this node.
Definition: dtree.hpp:285
StatType minVals
Lower half of bounding box for this node.
Definition: dtree.hpp:238
DTree * Left() const
Return the left child.
Definition: dtree.hpp:296
MatType::vec_type VecType
Definition: dtree.hpp:53
A density estimation tree is similar to both a decision tree and a space partitioning tree (like a kd...
Definition: dtree.hpp:46
size_t start
The index of the first point in the dataset contained in this node (and its children).
Definition: dtree.hpp:230
double logVolume
The logarithm of the volume of the node.
Definition: dtree.hpp:262
TagType BucketTag() const
Return the current bucket's ID, if leaf, or -1 otherwise.
Definition: dtree.hpp:304
ElemType SplitValue() const
Return the split value of this node.
Definition: dtree.hpp:283
~DTree()
Clean up memory allocated by the tree.
void ComputeVariableImportance(arma::vec &importances) const
Compute the variable importance of each dimension in the learned tree.
DTree()
Create an empty density estimation tree.
double alphaUpper
Upper part of alpha sum; used for pruning.
Definition: dtree.hpp:268
StatType maxVals
Upper half of bounding box for this node.
Definition: dtree.hpp:236
TagType FindBucket(const VecType &query) const
Return the tag of the leaf containing the query.
double subtreeLeavesLogNegError
Sum of the error of the leaves of the subtree.
Definition: dtree.hpp:250
void Serialize(Archive &ar, const unsigned int)
Serialize the density estimation tree.
bool root
If true, this node is the root of the tree.
Definition: dtree.hpp:256
size_t SplitDim() const
Return the split dimension of this node.
Definition: dtree.hpp:281
size_t SplitData(MatType &data, const size_t splitDim, const ElemType splitValue, arma::Col< size_t > &oldFromNew) const
Split the data, returning the number of points left of the split.
ElemType splitValue
The split value on the splitting dimension for this node.
Definition: dtree.hpp:244
TagType bucketTag
The tag for the leaf, used for hashing points.
Definition: dtree.hpp:265
double Grow(MatType &data, arma::Col< size_t > &oldFromNew, const bool useVolReg=false, const size_t maxLeafSize=10, const size_t minLeafSize=5)
Greedily expand the tree.
DTree & operator=(const DTree &obj)
Copy the given tree.
size_t subtreeLeaves
Number of leaves of the subtree.
Definition: dtree.hpp:253
DTree * left
The left child.
Definition: dtree.hpp:271