mlpack  master
split_data.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_CORE_DATA_SPLIT_DATA_HPP
14 #define MLPACK_CORE_DATA_SPLIT_DATA_HPP
15 
#include <mlpack/prereqs.hpp>

#include <stdexcept>
17 
18 namespace mlpack {
19 namespace data {
48 template<typename T, typename U>
49 void Split(const arma::Mat<T>& input,
50  const arma::Row<U>& inputLabel,
51  arma::Mat<T>& trainData,
52  arma::Mat<T>& testData,
53  arma::Row<U>& trainLabel,
54  arma::Row<U>& testLabel,
55  const double testRatio)
56 {
57  const size_t testSize = static_cast<size_t>(input.n_cols * testRatio);
58  const size_t trainSize = input.n_cols - testSize;
59  trainData.set_size(input.n_rows, trainSize);
60  testData.set_size(input.n_rows, testSize);
61  trainLabel.set_size(trainSize);
62  testLabel.set_size(testSize);
63 
64  const arma::Col<size_t> order =
65  arma::shuffle(arma::linspace<arma::Col<size_t>>(0, input.n_cols - 1,
66  input.n_cols));
67 
68  for (size_t i = 0; i != trainSize; ++i)
69  {
70  trainData.col(i) = input.col(order[i]);
71  trainLabel(i) = inputLabel(order[i]);
72  }
73 
74  for (size_t i = 0; i != testSize; ++i)
75  {
76  testData.col(i) = input.col(order[i + trainSize]);
77  testLabel(i) = inputLabel(order[i + trainSize]);
78  }
79 }
80 
102 template<typename T>
103 void Split(const arma::Mat<T>& input,
104  arma::Mat<T>& trainData,
105  arma::Mat<T>& testData,
106  const double testRatio)
107 {
108  const size_t testSize = static_cast<size_t>(input.n_cols * testRatio);
109  const size_t trainSize = input.n_cols - testSize;
110  trainData.set_size(input.n_rows, trainSize);
111  testData.set_size(input.n_rows, testSize);
112 
113  const arma::Col<size_t> order =
114  arma::shuffle(arma::linspace<arma::Col<size_t>>(0, input.n_cols -1,
115  input.n_cols));
116 
117  for (size_t i = 0; i != trainSize; ++i)
118  {
119  trainData.col(i) = input.col(order[i]);
120  }
121  for (size_t i = 0; i != testSize; ++i)
122  {
123  testData.col(i) = input.col(order[i + trainSize]);
124  }
125 }
126 
146 template<typename T,typename U>
147 std::tuple<arma::Mat<T>, arma::Mat<T>, arma::Row<U>, arma::Row<U>>
148 Split(const arma::Mat<T>& input,
149  const arma::Row<U>& inputLabel,
150  const double testRatio)
151 {
152  arma::Mat<T> trainData;
153  arma::Mat<T> testData;
154  arma::Row<U> trainLabel;
155  arma::Row<U> testLabel;
156 
157  Split(input, inputLabel, trainData, testData, trainLabel, testLabel,
158  testRatio);
159 
160  return std::make_tuple(std::move(trainData),
161  std::move(testData),
162  std::move(trainLabel),
163  std::move(testLabel));
164 }
165 
182 template<typename T>
183 std::tuple<arma::Mat<T>, arma::Mat<T>>
184 Split(const arma::Mat<T>& input,
185  const double testRatio)
186 {
187  arma::Mat<T> trainData;
188  arma::Mat<T> testData;
189  Split(input, trainData, testData, testRatio);
190 
191  return std::make_tuple(std::move(trainData),
192  std::move(testData));
193 }
194 
195 } // namespace data
196 } // namespace mlpack
197 
198 #endif
void Split(const arma::Mat< T > &input, const arma::Row< U > &inputLabel, arma::Mat< T > &trainData, arma::Mat< T > &testData, arma::Row< U > &trainLabel, arma::Row< U > &testLabel, const double testRatio)
Given an input dataset and labels, split into a training set and test set.
Definition: split_data.hpp:49
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
The core includes that mlpack expects; standard C++ includes and Armadillo.