mlpack  master
lsh_search.hpp
Go to the documentation of this file.
1 
43 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
44 #define MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
45 
#include <mlpack/prereqs.hpp>

#include <queue>
#include <utility>
#include <vector>
51 namespace mlpack {
52 namespace neighbor {
53 
61 template<typename SortPolicy = NearestNeighborSort>
62 class LSHSearch
63 {
64  public:
85  LSHSearch(const arma::mat& referenceSet,
86  const arma::cube& projections,
87  const double hashWidth = 0.0,
88  const size_t secondHashSize = 99901,
89  const size_t bucketSize = 500);
90 
112  LSHSearch(const arma::mat& referenceSet,
113  const size_t numProj,
114  const size_t numTables,
115  const double hashWidth = 0.0,
116  const size_t secondHashSize = 99901,
117  const size_t bucketSize = 500);
118 
123  LSHSearch();
124 
130  LSHSearch(const LSHSearch& other);
131 
137  LSHSearch(LSHSearch&& other);
138 
144  LSHSearch& operator=(const LSHSearch& other);
145 
151  LSHSearch& operator=(LSHSearch&& other);
152 
156  ~LSHSearch();
157 
182  void Train(const arma::mat& referenceSet,
183  const size_t numProj,
184  const size_t numTables,
185  const double hashWidth = 0.0,
186  const size_t secondHashSize = 99901,
187  const size_t bucketSize = 500,
188  const arma::cube& projection = arma::cube());
189 
211  void Search(const arma::mat& querySet,
212  const size_t k,
213  arma::Mat<size_t>& resultingNeighbors,
214  arma::mat& distances,
215  const size_t numTablesToSearch = 0,
216  const size_t T = 0);
217 
236  void Search(const size_t k,
237  arma::Mat<size_t>& resultingNeighbors,
238  arma::mat& distances,
239  const size_t numTablesToSearch = 0,
240  size_t T = 0);
241 
251  static double ComputeRecall(const arma::Mat<size_t>& foundNeighbors,
252  const arma::Mat<size_t>& realNeighbors);
253 
259  template<typename Archive>
260  void Serialize(Archive& ar, const unsigned int version);
261 
263  size_t DistanceEvaluations() const { return distanceEvaluations; }
266 
268  const arma::mat& ReferenceSet() const { return *referenceSet; }
269 
271  size_t NumProjections() const { return projections.n_slices; }
272 
274  const arma::mat& Offsets() const { return offsets; }
275 
277  const arma::vec& SecondHashWeights() const { return secondHashWeights; }
278 
280  size_t BucketSize() const { return bucketSize; }
281 
283  const std::vector<arma::Col<size_t>>& SecondHashTable() const
284  { return secondHashTable; }
285 
287  const arma::cube& Projections() { return projections; }
288 
290  void Projections(const arma::cube& projTables)
291  {
292  // Simply call Train() with the given projection tables.
293  Train(*referenceSet, numProj, numTables, hashWidth, secondHashSize,
294  bucketSize, projTables);
295  }
296 
297  private:
313  template<typename VecType>
314  void ReturnIndicesFromTable(const VecType& queryPoint,
315  arma::uvec& referenceIndices,
316  size_t numTablesToSearch,
317  const size_t T) const;
318 
332  void BaseCase(const size_t queryIndex,
333  const arma::uvec& referenceIndices,
334  const size_t k,
335  arma::Mat<size_t>& neighbors,
336  arma::mat& distances) const;
337 
352  void BaseCase(const size_t queryIndex,
353  const arma::uvec& referenceIndices,
354  const size_t k,
355  const arma::mat& querySet,
356  arma::Mat<size_t>& neighbors,
357  arma::mat& distances) const;
358 
373  void GetAdditionalProbingBins(const arma::vec& queryCode,
374  const arma::vec& queryCodeNotFloored,
375  const size_t T,
376  arma::mat& additionalProbingBins) const;
377 
385  double PerturbationScore(const std::vector<bool>& A,
386  const arma::vec& scores) const;
387 
395  bool PerturbationShift(std::vector<bool>& A) const;
396 
405  bool PerturbationExpand(std::vector<bool>& A) const;
406 
414  bool PerturbationValid(const std::vector<bool>& A) const;
415 
417  const arma::mat* referenceSet;
419  bool ownsSet;
420 
422  size_t numProj;
424  size_t numTables;
425 
427  arma::cube projections; // should be [numProj x dims] x numTables slices
428 
430  arma::mat offsets; // should be numProj x numTables
431 
433  double hashWidth;
434 
437 
439  arma::vec secondHashWeights;
440 
442  size_t bucketSize;
443 
446  std::vector<arma::Col<size_t>> secondHashTable;
447 
450  arma::Col<size_t> bucketContentSize;
451 
454  arma::Col<size_t> bucketRowInHashTable;
455 
458 
460  typedef std::pair<double, size_t> Candidate;
461 
463  struct CandidateCmp {
464  bool operator()(const Candidate& c1, const Candidate& c2)
465  {
466  return !SortPolicy::IsBetter(c2.first, c1.first);
467  };
468  };
469 
471  typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
473 
474 }; // class LSHSearch
475 
476 } // namespace neighbor
477 } // namespace mlpack
478 
480 BOOST_TEMPLATE_CLASS_VERSION(template<typename SortPolicy>,
482 
483 // Include implementation.
484 #include "lsh_search_impl.hpp"
485 
486 #endif
bool PerturbationExpand(std::vector< bool > &A) const
Inline function used by GetAdditionalProbingBins.
size_t numTables
The number of hash tables.
Definition: lsh_search.hpp:424
size_t DistanceEvaluations() const
Return the number of distance evaluations performed.
Definition: lsh_search.hpp:263
void Serialize(Archive &ar, const unsigned int version)
Serialize the LSH model.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
std::vector< arma::Col< size_t > > secondHashTable
The final hash table; should be (< secondHashSize) vectors each with (<= bucketSize) elements...
Definition: lsh_search.hpp:446
The core includes that mlpack expects; standard C++ includes and Armadillo.
LSHSearch()
Create an untrained LSH model.
void Search(const arma::mat &querySet, const size_t k, arma::Mat< size_t > &resultingNeighbors, arma::mat &distances, const size_t numTablesToSearch=0, const size_t T=0)
Compute the nearest neighbors of the points in the given query set and store the output in the given ...
size_t bucketSize
The bucket size of the second hash.
Definition: lsh_search.hpp:442
void ReturnIndicesFromTable(const VecType &queryPoint, arma::uvec &referenceIndices, size_t numTablesToSearch, const size_t T) const
This function takes a query and hashes it into each of the hash tables to get keys for the query and ...
arma::Col< size_t > bucketContentSize
The number of elements present in each hash bucket; should be secondHashSize.
Definition: lsh_search.hpp:450
LSHSearch & operator=(const LSHSearch &other)
Copy the given LSH model.
const arma::mat & Offsets() const
Get the offsets 'b' for each of the projections. (One 'b' per column.)
Definition: lsh_search.hpp:274
void BaseCase(const size_t queryIndex, const arma::uvec &referenceIndices, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances) const
This is a helper function that computes the distance of the query to the neighbor candidates and appr...
bool PerturbationValid(const std::vector< bool > &A) const
Return true if perturbation set A is valid.
const arma::cube & Projections()
Get the projection tables.
Definition: lsh_search.hpp:287
The LSHSearch class; this class builds a hash on the reference set and uses this hash to compute the ...
Definition: lsh_search.hpp:62
BOOST_TEMPLATE_CLASS_VERSION(template< typename SortPolicy >, mlpack::neighbor::LSHSearch< SortPolicy >, 1)
Set the serialization version of the LSHSearch class.
static double ComputeRecall(const arma::Mat< size_t > &foundNeighbors, const arma::Mat< size_t > &realNeighbors)
Compute the recall (% of neighbors found) given the neighbors returned by LSHSearch::Search and a "gr...
void Train(const arma::mat &referenceSet, const size_t numProj, const size_t numTables, const double hashWidth=0.0, const size_t secondHashSize=99901, const size_t bucketSize=500, const arma::cube &projection=arma::cube())
Train the LSH model on the given dataset.
arma::mat offsets
The list of the offsets 'b' for each of the projections for each table.
Definition: lsh_search.hpp:430
Compare two candidates based on the distance.
Definition: lsh_search.hpp:463
std::priority_queue< Candidate, std::vector< Candidate >, CandidateCmp > CandidateList
Use a priority queue to represent the list of candidate neighbors.
Definition: lsh_search.hpp:472
const std::vector< arma::Col< size_t > > & SecondHashTable() const
Get the second hash table.
Definition: lsh_search.hpp:283
size_t BucketSize() const
Get the bucket size of the second hash.
Definition: lsh_search.hpp:280
bool PerturbationShift(std::vector< bool > &A) const
Inline function used by GetAdditionalProbingBins.
double PerturbationScore(const std::vector< bool > &A, const arma::vec &scores) const
Returns the score of a perturbation vector generated by perturbation set A.
const arma::vec & SecondHashWeights() const
Get the weights of the second hash.
Definition: lsh_search.hpp:277
size_t & DistanceEvaluations()
Modify the number of distance evaluations performed.
Definition: lsh_search.hpp:265
size_t secondHashSize
The big prime representing the size of the second hash.
Definition: lsh_search.hpp:436
size_t numProj
The number of projections.
Definition: lsh_search.hpp:422
bool operator()(const Candidate &c1, const Candidate &c2)
Definition: lsh_search.hpp:464
bool ownsSet
If true, we own the reference set.
Definition: lsh_search.hpp:419
arma::vec secondHashWeights
The weights of the second hash.
Definition: lsh_search.hpp:439
const arma::mat & ReferenceSet() const
Return the reference dataset.
Definition: lsh_search.hpp:268
const arma::mat * referenceSet
Reference dataset.
Definition: lsh_search.hpp:417
size_t NumProjections() const
Get the number of projections.
Definition: lsh_search.hpp:271
double hashWidth
The hash width.
Definition: lsh_search.hpp:433
arma::Col< size_t > bucketRowInHashTable
For a particular hash value, points to the row in secondHashTable corresponding to this value...
Definition: lsh_search.hpp:454
size_t distanceEvaluations
The number of distance evaluations.
Definition: lsh_search.hpp:457
std::pair< double, size_t > Candidate
Candidate represents a possible candidate neighbor (distance, index).
Definition: lsh_search.hpp:460
arma::cube projections
The arma::cube containing the projection matrix of each table.
Definition: lsh_search.hpp:427
~LSHSearch()
Clean memory.
void Projections(const arma::cube &projTables)
Change the projection tables (this retrains the LSH model).
Definition: lsh_search.hpp:290
void GetAdditionalProbingBins(const arma::vec &queryCode, const arma::vec &queryCodeNotFloored, const size_t T, arma::mat &additionalProbingBins) const
This function implements the core idea behind Multiprobe LSH.