13 #ifndef MLPACK_METHODS_DET_DTREE_HPP 14 #define MLPACK_METHODS_DET_DTREE_HPP 44 template <
typename MatType,
45 typename TagType =
int>
53 typedef typename MatType::vec_type
VecType;
99 const size_t totalPoints);
109 DTree(MatType& data);
142 const size_t totalPoints,
159 double Grow(MatType& data,
160 arma::Col<size_t>& oldFromNew,
161 const bool useVolReg =
false,
162 const size_t maxLeafSize = 10,
163 const size_t minLeafSize = 5);
175 const bool useVolReg =
false);
191 TagType
TagTree(
const TagType& tag = 0);
199 TagType
FindBucket(
const VecType& query)
const;
//! Return this node's bucket tag if the subtree rooted here consists of a
//! single leaf (subtreeLeaves == 1, i.e. the node itself is a leaf);
//! otherwise return -1 to signal that the node has no bucket of its own.
304 TagType
BucketTag()
const {
return subtreeLeaves == 1 ? bucketTag : -1; }
315 template<
typename Archive>
316 void Serialize(Archive& ar,
const unsigned int );
327 ElemType& splitValue,
330 const size_t minLeafSize = 5)
const;
336 const size_t splitDim,
337 const ElemType splitValue,
338 arma::Col<size_t>& oldFromNew)
const;
345 #include "dtree_impl.hpp" 347 #endif // MLPACK_METHODS_DET_DTREE_HPP
bool WithinRange(const VecType &query) const
Return whether a query point is within the range of this node.
double LogNegativeError(const size_t totalPoints) const
Compute the log-negative-error for this node, given the total number of points in the dataset...
arma::Col< ElemType > StatType
DTree * Right() const
Return the right child.
Linear algebra utility functions, generally performed on matrices or vectors.
double Ratio() const
Return the ratio of points in this node to the points in the whole dataset.
const StatType & MaxVals() const
Return the maximum values.
size_t End() const
Return the first index of a point not contained in this node.
bool FindSplit(const MatType &data, size_t &splitDim, ElemType &splitValue, double &leftError, double &rightError, const size_t minLeafSize=5) const
Find the dimension to split on.
double SubtreeLeavesLogNegError() const
Return the log negative error of all descendants of this node.
double PruneAndUpdate(const double oldAlpha, const size_t points, const bool useVolReg=false)
Perform alpha pruning on a tree.
bool Root() const
Return whether or not this is the root of the tree.
double AlphaUpper() const
Return the upper part of the alpha sum.
The core includes that mlpack expects; standard C++ includes and Armadillo.
size_t splitDim
The splitting dimension for this node.
size_t end
The index one past the last point in the dataset contained in this node (and its children); i.e. the first index not contained in this node.
double ratio
Ratio of the number of points in the node to the total number of points.
size_t SubtreeLeaves() const
Return the number of leaves which are descendants of this node.
DTree * right
The right child.
double logNegError
log-negative-L2-error of the node.
double ComputeValue(const VecType &query) const
Compute the logarithm of the density estimate of a given query point.
const StatType & MinVals() const
Return the minimum values.
double LogVolume() const
Return the logarithm of the volume of this node.
MatType::elem_type ElemType
The actual, underlying type we're working with.
TagType TagTree(const TagType &tag=0)
Index the buckets for possible usage later; this results in every leaf in the tree having a specific ...
size_t Start() const
Return the starting index of points contained in this node.
double LogNegError() const
Return the log negative error of this node.
StatType minVals
Lower half of bounding box for this node.
DTree * Left() const
Return the left child.
MatType::vec_type VecType
A density estimation tree is similar to both a decision tree and a space partitioning tree (like a kd...
size_t start
The index of the first point in the dataset contained in this node (and its children).
double logVolume
The logarithm of the volume of the node.
TagType BucketTag() const
Return the current bucket's ID, if leaf, or -1 otherwise.
ElemType SplitValue() const
Return the split value of this node.
~DTree()
Clean up memory allocated by the tree.
void ComputeVariableImportance(arma::vec &importances) const
Compute the variable importance of each dimension in the learned tree.
DTree()
Create an empty density estimation tree.
double alphaUpper
Upper part of alpha sum; used for pruning.
StatType maxVals
Upper half of bounding box for this node.
TagType FindBucket(const VecType &query) const
Return the tag of the leaf containing the query.
double subtreeLeavesLogNegError
Sum of the error of the leaves of the subtree.
void Serialize(Archive &ar, const unsigned int)
Serialize the density estimation tree.
bool root
If true, this node is the root of the tree.
size_t SplitDim() const
Return the split dimension of this node.
size_t SplitData(MatType &data, const size_t splitDim, const ElemType splitValue, arma::Col< size_t > &oldFromNew) const
Split the data, returning the number of points left of the split.
ElemType splitValue
The split value on the splitting dimension for this node.
TagType bucketTag
The tag for the leaf, used for hashing points.
double Grow(MatType &data, arma::Col< size_t > &oldFromNew, const bool useVolReg=false, const size_t maxLeafSize=10, const size_t minLeafSize=5)
Greedily expand the tree.
DTree & operator=(const DTree &obj)
Copy the given tree.
size_t subtreeLeaves
Number of leaves of the subtree.
DTree * left
The left child.