mlpack  master
mean_shift.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
14 #define MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
15 
16 #include <mlpack/prereqs.hpp>
20 #include <boost/utility.hpp>
21 
22 namespace mlpack {
23 namespace meanshift {
24 
46 template<bool UseKernel = false,
47  typename KernelType = kernel::GaussianKernel,
48  typename MatType = arma::mat>
49 class MeanShift
50 {
51  public:
63  MeanShift(const double radius = 0,
64  const size_t maxIterations = 1000,
65  const KernelType kernel = KernelType());
66 
73  double EstimateRadius(const MatType& data, const double ratio = 0.2);
74 
84  void Cluster(const MatType& data,
85  arma::Col<size_t>& assignments,
86  arma::mat& centroids,
87  bool useSeeds = true);
88 
90  size_t MaxIterations() const { return maxIterations; }
92  size_t& MaxIterations() { return maxIterations; }
93 
95  double Radius() const { return radius; }
97  void Radius(double radius);
98 
100  const KernelType& Kernel() const { return kernel; }
102  KernelType& Kernel() { return kernel; }
103 
104  private:
118  void GenSeeds(const MatType& data,
119  const double binSize,
120  const int minFreq,
121  MatType& seeds);
122 
131  template<bool ApplyKernel = UseKernel>
132  typename std::enable_if<ApplyKernel, bool>::type
133  CalculateCentroid(const MatType& data,
134  const std::vector<size_t>& neighbors,
135  const std::vector<double>& distances,
136  arma::colvec& centroid);
137 
146  template<bool ApplyKernel = UseKernel>
147  typename std::enable_if<!ApplyKernel, bool>::type
148  CalculateCentroid(const MatType& data,
149  const std::vector<size_t>& neighbors,
150  const std::vector<double>&, /*unused*/
151  arma::colvec& centroid);
152 
158  double radius;
159 
162 
164  KernelType kernel;
165 };
166 
167 } // namespace meanshift
168 } // namespace mlpack
169 
170 // Include implementation.
171 #include "mean_shift_impl.hpp"
172 
173 #endif // MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
double EstimateRadius(const MatType &data, const double ratio=0.2)
Give an estimate of the radius based on the given dataset.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
KernelType & Kernel()
Modify the kernel.
Definition: mean_shift.hpp:102
The core includes that mlpack expects; standard C++ includes and Armadillo.
This class implements mean shift clustering.
Definition: mean_shift.hpp:49
size_t MaxIterations() const
Get the maximum number of iterations.
Definition: mean_shift.hpp:90
MeanShift(const double radius=0, const size_t maxIterations=1000, const KernelType kernel=KernelType())
Create a mean shift object and set the parameters with which mean shift will be run.
void Cluster(const MatType &data, arma::Col< size_t > &assignments, arma::mat &centroids, bool useSeeds=true)
Perform mean shift clustering on the data, returning a list of cluster assignments and centroids.
std::enable_if< ApplyKernel, bool >::type CalculateCentroid(const MatType &data, const std::vector< size_t > &neighbors, const std::vector< double > &distances, arma::colvec &centroid)
Use the kernel to calculate a new centroid given the dataset and valid neighbors.
size_t & MaxIterations()
Set the maximum number of iterations.
Definition: mean_shift.hpp:92
const KernelType & Kernel() const
Get the kernel.
Definition: mean_shift.hpp:100
double radius
If the distance between two centroids is less than the radius, one of them will be removed.
Definition: mean_shift.hpp:158
KernelType kernel
Instantiated kernel.
Definition: mean_shift.hpp:164
The standard Gaussian kernel.
double Radius() const
Get the radius.
Definition: mean_shift.hpp:95
size_t maxIterations
Maximum number of iterations before giving up.
Definition: mean_shift.hpp:161
void GenSeeds(const MatType &data, const double binSize, const int minFreq, MatType &seeds)
To speed things up, we can generate some seeds from the dataset and use them as initial centroids rather than all the points in the dataset.