mlpack  master
neighbor_search_rules.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
14 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
15 
17 
18 namespace mlpack {
19 namespace neighbor {
20 
32 template<typename SortPolicy, typename MetricType, typename TreeType>
34 {
35  public:
48  NeighborSearchRules(const typename TreeType::Mat& referenceSet,
49  const typename TreeType::Mat& querySet,
50  const size_t k,
51  MetricType& metric,
52  const double epsilon = 0,
53  const bool sameSet = false);
54 
62  void GetResults(arma::Mat<size_t>& neighbors, arma::mat& distances);
63 
72  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
73 
82  double Score(const size_t queryIndex, TreeType& referenceNode);
83 
90  size_t GetBestChild(const size_t queryIndex, TreeType& referenceNode);
91 
98  size_t GetBestChild(const TreeType& queryNode, TreeType& referenceNode);
99 
111  double Rescore(const size_t queryIndex,
112  TreeType& referenceNode,
113  const double oldScore) const;
114 
123  double Score(TreeType& queryNode, TreeType& referenceNode);
124 
136  double Rescore(TreeType& queryNode,
137  TreeType& referenceNode,
138  const double oldScore) const;
139 
141  size_t BaseCases() const { return baseCases; }
143  size_t& BaseCases() { return baseCases; }
144 
146  size_t Scores() const { return scores; }
148  size_t& Scores() { return scores; }
149 
152 
154  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
156  TraversalInfoType& TraversalInfo() { return traversalInfo; }
157 
158  protected:
160  const typename TreeType::Mat& referenceSet;
161 
163  const typename TreeType::Mat& querySet;
164 
166  typedef std::pair<double, size_t> Candidate;
167 
169  struct CandidateCmp {
170  bool operator()(const Candidate& c1, const Candidate& c2)
171  {
172  return !SortPolicy::IsBetter(c2.first, c1.first);
173  };
174  };
175 
177  typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
179 
181  std::vector<CandidateList> candidates;
182 
184  const size_t k;
185 
187  MetricType& metric;
188 
190  bool sameSet;
191 
193  const double epsilon;
194 
200  double lastBaseCase;
201 
203  size_t baseCases;
205  size_t scores;
206 
209  TraversalInfoType traversalInfo;
210 
214  double CalculateBound(TreeType& queryNode) const;
215 
223  void InsertNeighbor(const size_t queryIndex,
224  const size_t neighbor,
225  const double distance);
226 };
227 
228 } // namespace neighbor
229 } // namespace mlpack
230 
231 // Include implementation.
232 #include "neighbor_search_rules_impl.hpp"
233 
234 #endif // MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
std::priority_queue< Candidate, std::vector< Candidate >, CandidateCmp > CandidateList
Use a priority queue to represent the list of candidate neighbors.
std::vector< CandidateList > candidates
Set of candidate neighbors for each point.
double CalculateBound(TreeType &queryNode) const
Recalculate the bound for a given query node.
void InsertNeighbor(const size_t queryIndex, const size_t neighbor, const double distance)
Helper function to insert a point into the list of candidate points.
TraversalInfoType & TraversalInfo()
Modify the traversal info.
The TraversalInfo class holds traversal information which is used in dual-tree (and single-tree) traversals.
size_t lastQueryIndex
The last query point BaseCase() was called with.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
TraversalInfoType traversalInfo
Traversal info for the parent combination; this is updated by the traversal before each call to Score().
size_t BaseCases() const
Get the number of base cases that have been performed.
const TreeType::Mat & referenceSet
The reference set.
std::pair< double, size_t > Candidate
Candidate represents a possible candidate neighbor (distance, index).
const size_t k
Number of neighbors to search for.
const double epsilon
Relative error to be considered in approximate search.
void GetResults(arma::Mat< size_t > &neighbors, arma::mat &distances)
Store the list of candidates for each query point in the given matrices.
const TraversalInfoType & TraversalInfo() const
Get the traversal info.
size_t & Scores()
Modify the number of scores that have been performed.
tree::TraversalInfo< TreeType > TraversalInfoType
Convenience typedef.
double lastBaseCase
The last base case result.
size_t Scores() const
Get the number of scores that have been performed.
size_t lastReferenceIndex
The last reference point BaseCase() was called with.
The NeighborSearchRules class is a template helper class used by the NeighborSearch class when performing distance-based neighbor searches.
bool operator()(const Candidate &c1, const Candidate &c2)
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Get the distance from the query point to the reference point.
size_t & BaseCases()
Modify the number of base cases that have been performed.
NeighborSearchRules(const typename TreeType::Mat &referenceSet, const typename TreeType::Mat &querySet, const size_t k, MetricType &metric, const double epsilon=0, const bool sameSet=false)
Construct the NeighborSearchRules object.
size_t GetBestChild(const size_t queryIndex, TreeType &referenceNode)
Get the child node with the best score.
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore) const
Re-evaluate the score for recursion order.
MetricType & metric
The instantiated metric.
size_t scores
The number of scores that have been performed.
size_t baseCases
The number of base cases that have been performed.
const TreeType::Mat & querySet
The query set.
Compare two candidates based on the distance.
bool sameSet
Denotes whether or not the reference and query sets are the same.