mlpack  master
naive_bayes_classifier.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_HPP
15 #define MLPACK_METHODS_NAIVE_BAYES_NAIVE_BAYES_CLASSIFIER_HPP
16 
17 #include <mlpack/prereqs.hpp>
18 
19 namespace mlpack {
20 namespace naive_bayes {
21 
46 template<typename MatType = arma::mat>
48 {
49  public:
68  NaiveBayesClassifier(const MatType& data,
69  const arma::Row<size_t>& labels,
70  const size_t classes,
71  const bool incrementalVariance = false);
72 
79  NaiveBayesClassifier(const size_t dimensionality = 0,
80  const size_t classes = 0);
81 
97  void Train(const MatType& data,
98  const arma::Row<size_t>& labels,
99  const bool incremental = true);
100 
109  template<typename VecType>
110  void Train(const VecType& point, const size_t label);
111 
126  void Classify(const MatType& data, arma::Row<size_t>& results);
127 
129  const MatType& Means() const { return means; }
131  MatType& Means() { return means; }
132 
134  const MatType& Variances() const { return variances; }
136  MatType& Variances() { return variances; }
137 
139  const arma::vec& Probabilities() const { return probabilities; }
141  arma::vec& Probabilities() { return probabilities; }
142 
144  template<typename Archive>
145  void Serialize(Archive& ar, const unsigned int /* version */);
146 
147  private:
149  MatType means;
151  MatType variances;
153  arma::vec probabilities;
156 };
157 
158 } // namespace naive_bayes
159 } // namespace mlpack
160 
161 // Include implementation.
162 #include "naive_bayes_classifier_impl.hpp"
163 
164 #endif
MatType & Means()
Modify the sample means for each class.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
void Serialize(Archive &ar, const unsigned int)
Serialize the classifier.
const MatType & Means() const
Get the sample means for each class.
The core includes that mlpack expects; standard C++ includes and Armadillo.
The simple Naive Bayes classifier.
void Train(const MatType &data, const arma::Row< size_t > &labels, const bool incremental=true)
Train the Naive Bayes classifier on the given dataset.
MatType means
Sample mean for each class.
MatType & Variances()
Modify the sample variances for each class.
const MatType & Variances() const
Get the sample variances for each class.
arma::vec & Probabilities()
Modify the prior probabilities for each class.
NaiveBayesClassifier(const MatType &data, const arma::Row< size_t > &labels, const size_t classes, const bool incrementalVariance=false)
Initializes the classifier as per the input and then trains it by calculating the sample means and variances.
size_t trainingPoints
Number of training points seen so far.
void Classify(const MatType &data, arma::Row< size_t > &results)
Given a bunch of data points, this function evaluates the class of each of those data points and stores the predicted labels in results.
MatType variances
Sample variances for each class.
const arma::vec & Probabilities() const
Get the prior probabilities for each class.