43 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP 44 #define MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP 61 template<
typename SortPolicy = NearestNeighborSort>
182 void Train(
const arma::mat& referenceSet,
183 const size_t numProj,
184 const size_t numTables,
188 const arma::cube& projection = arma::cube());
211 void Search(
const arma::mat& querySet,
213 arma::Mat<size_t>& resultingNeighbors,
214 arma::mat& distances,
215 const size_t numTablesToSearch = 0,
236 void Search(
const size_t k,
237 arma::Mat<size_t>& resultingNeighbors,
238 arma::mat& distances,
239 const size_t numTablesToSearch = 0,
251 static double ComputeRecall(
const arma::Mat<size_t>& foundNeighbors,
252 const arma::Mat<size_t>& realNeighbors);
259 template<
typename Archive>
260 void Serialize(Archive& ar,
const unsigned int version);
313 template<
typename VecType>
315 arma::uvec& referenceIndices,
316 size_t numTablesToSearch,
317 const size_t T)
const;
332 void BaseCase(
const size_t queryIndex,
333 const arma::uvec& referenceIndices,
335 arma::Mat<size_t>& neighbors,
336 arma::mat& distances)
const;
352 void BaseCase(
const size_t queryIndex,
353 const arma::uvec& referenceIndices,
355 const arma::mat& querySet,
356 arma::Mat<size_t>& neighbors,
357 arma::mat& distances)
const;
374 const arma::vec& queryCodeNotFloored,
376 arma::mat& additionalProbingBins)
const;
386 const arma::vec& scores)
const;
466 return !SortPolicy::IsBetter(c2.first, c1.first);
471 typedef std::priority_queue<Candidate, std::vector<Candidate>,
CandidateCmp>
484 #include "lsh_search_impl.hpp" bool PerturbationExpand(std::vector< bool > &A) const
Inline function used by GetAdditionalProbingBins.
size_t numTables
The number of hash tables.
size_t DistanceEvaluations() const
Return the number of distance evaluations performed.
void Serialize(Archive &ar, const unsigned int version)
Serialize the LSH model.
Linear algebra utility functions, generally performed on matrices or vectors.
std::vector< arma::Col< size_t > > secondHashTable
The final hash table; should be (< secondHashSize) vectors each with (<= bucketSize) elements...
The core includes that mlpack expects; standard C++ includes and Armadillo.
LSHSearch()
Create an untrained LSH model.
void Search(const arma::mat &querySet, const size_t k, arma::Mat< size_t > &resultingNeighbors, arma::mat &distances, const size_t numTablesToSearch=0, const size_t T=0)
Compute the nearest neighbors of the points in the given query set and store the output in the given ...
size_t bucketSize
The bucket size of the second hash.
void ReturnIndicesFromTable(const VecType &queryPoint, arma::uvec &referenceIndices, size_t numTablesToSearch, const size_t T) const
This function takes a query and hashes it into each of the hash tables to get keys for the query and ...
arma::Col< size_t > bucketContentSize
The number of elements present in each hash bucket; should be secondHashSize.
LSHSearch & operator=(const LSHSearch &other)
Copy the given LSH model.
const arma::mat & Offsets() const
Get the offsets 'b' for each of the projections. (One 'b' per column.)
void BaseCase(const size_t queryIndex, const arma::uvec &referenceIndices, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances) const
This is a helper function that computes the distance of the query to the neighbor candidates and appr...
bool PerturbationValid(const std::vector< bool > &A) const
Return true if perturbation set A is valid.
const arma::cube & Projections()
Get the projection tables.
The LSHSearch class; this class builds a hash on the reference set and uses this hash to compute the ...
BOOST_TEMPLATE_CLASS_VERSION(template< typename SortPolicy >, mlpack::neighbor::LSHSearch< SortPolicy >, 1)
Set the serialization version of the LSHSearch class.
static double ComputeRecall(const arma::Mat< size_t > &foundNeighbors, const arma::Mat< size_t > &realNeighbors)
Compute the recall (% of neighbors found) given the neighbors returned by LSHSearch::Search and a "gr...
void Train(const arma::mat &referenceSet, const size_t numProj, const size_t numTables, const double hashWidth=0.0, const size_t secondHashSize=99901, const size_t bucketSize=500, const arma::cube &projection=arma::cube())
Train the LSH model on the given dataset.
arma::mat offsets
The list of the offsets 'b' for each of the projection for each table.
Compare two candidates based on the distance.
std::priority_queue< Candidate, std::vector< Candidate >, CandidateCmp > CandidateList
Use a priority queue to represent the list of candidate neighbors.
const std::vector< arma::Col< size_t > > & SecondHashTable() const
Get the second hash table.
size_t BucketSize() const
Get the bucket size of the second hash.
bool PerturbationShift(std::vector< bool > &A) const
Inline function used by GetAdditionalProbingBins.
double PerturbationScore(const std::vector< bool > &A, const arma::vec &scores) const
Returns the score of a perturbation vector generated by perturbation set A.
const arma::vec & SecondHashWeights() const
Get the weights of the second hash.
size_t & DistanceEvaluations()
Modify the number of distance evaluations performed.
size_t secondHashSize
The big prime representing the size of the second hash.
size_t numProj
The number of projections.
bool operator()(const Candidate &c1, const Candidate &c2)
bool ownsSet
If true, we own the reference set.
arma::vec secondHashWeights
The weights of the second hash.
const arma::mat & ReferenceSet() const
Return the reference dataset.
const arma::mat * referenceSet
Reference dataset.
size_t NumProjections() const
Get the number of projections.
double hashWidth
The hash width.
arma::Col< size_t > bucketRowInHashTable
For a particular hash value, points to the row in secondHashTable corresponding to this value...
size_t distanceEvaluations
The number of distance evaluations.
std::pair< double, size_t > Candidate
Candidate represents a possible candidate neighbor (distance, index).
arma::cube projections
The arma::cube containing the projection matrix of each table.
~LSHSearch()
Clean memory.
void Projections(const arma::cube &projTables)
Change the projection tables (this retrains the LSH model).
void GetAdditionalProbingBins(const arma::vec &queryCode, const arma::vec &queryCodeNotFloored, const size_t T, arma::mat &additionalProbingBins) const
This function implements the core idea behind Multiprobe LSH.