mlpack master
softmax_regression.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_SOFTMAX_REGRESSION_SOFTMAX_REGRESSION_HPP
13 #define MLPACK_METHODS_SOFTMAX_REGRESSION_SOFTMAX_REGRESSION_HPP
14 
15 #include <mlpack/prereqs.hpp>
17 
19 
20 namespace mlpack {
21 namespace regression {
22 
62 template<
63  template<typename> class OptimizerType = mlpack::optimization::L_BFGS
64 >
66 {
67  public:
77  SoftmaxRegression(const size_t inputSize = 0,
78  const size_t numClasses = 0,
79  const bool fitIntercept = false);
80 
94  SoftmaxRegression(const arma::mat& data,
95  const arma::Row<size_t>& labels,
96  const size_t numClasses,
97  const double lambda = 0.0001,
98  const bool fitIntercept = false);
99 
109  SoftmaxRegression(OptimizerType<SoftmaxRegressionFunction>& optimizer);
110 
119  void Predict(const arma::mat& testData, arma::Row<size_t>& predictions) const;
120 
129  double ComputeAccuracy(const arma::mat& testData,
130  const arma::Row<size_t>& labels) const;
131 
140  double Train(OptimizerType<SoftmaxRegressionFunction>& optimizer);
141 
149  double Train(const arma::mat &data, const arma::Row<size_t>& labels,
150  const size_t numClasses);
151 
153  size_t& NumClasses() { return numClasses; }
155  size_t NumClasses() const { return numClasses; }
156 
158  double& Lambda() { return lambda; }
160  double Lambda() const { return lambda; }
161 
163  bool FitIntercept() const { return fitIntercept; }
164 
166  arma::mat& Parameters() { return parameters; }
168  const arma::mat& Parameters() const { return parameters; }
169 
171  size_t FeatureSize() const
172  { return fitIntercept ? parameters.n_cols - 1 :
173  parameters.n_cols; }
174 
178  template<typename Archive>
179  void Serialize(Archive& ar, const unsigned int /* version */)
180  {
182 
183  ar & CreateNVP(parameters, "parameters");
184  ar & CreateNVP(numClasses, "numClasses");
185  ar & CreateNVP(lambda, "lambda");
186  ar & CreateNVP(fitIntercept, "fitIntercept");
187  }
188 
189  private:
191  arma::mat parameters;
193  size_t numClasses;
195  double lambda;
198 };
199 
200 } // namespace regression
201 } // namespace mlpack
202 
203 // Include implementation.
204 #include "softmax_regression_impl.hpp"
205 
206 #endif
double & Lambda()
Sets the regularization parameter.
SoftmaxRegression(const size_t inputSize=0, const size_t numClasses=0, const bool fitIntercept=false)
Initialize the SoftmaxRegression without performing training.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
arma::mat & Parameters()
Get the model parameters.
The core includes that mlpack expects; standard C++ includes and Armadillo.
FirstShim< T > CreateNVP(T &t, const std::string &name, typename std::enable_if_t< HasSerialize< T >::value > *=0)
Call this function to produce a name-value pair; this is similar to BOOST_SERIALIZATION_NVP(), but should be used for types that have a Serialize() function (or contain a type that has a Serialize() function) instead of a serialize() function.
const arma::mat & Parameters() const
Get the model parameters.
double Train(OptimizerType< SoftmaxRegressionFunction > &optimizer)
Train the softmax regression model with the given optimizer.
Softmax Regression is a classifier which can be used for classification when the data available can t...
size_t & NumClasses()
Sets the number of classes.
size_t NumClasses() const
Gets the number of classes.
double Lambda() const
Gets the regularization parameter.
bool FitIntercept() const
Gets the intercept term flag. We can't change this after training.
void Serialize(Archive &ar, const unsigned int)
Serialize the SoftmaxRegression model.
double lambda
L2-regularization constant.
size_t FeatureSize() const
Gets the feature size of the training data.
void Predict(const arma::mat &testData, arma::Row< size_t > &predictions) const
Predict the class labels for the provided feature points.
arma::mat parameters
Parameters after optimization.
The generic L-BFGS optimizer, which uses a back-tracking line search algorithm to minimize a function...
Definition: lbfgs.hpp:34
double ComputeAccuracy(const arma::mat &testData, const arma::Row< size_t > &labels) const
Computes accuracy of the learned model given the feature data and the labels associated with each dat...