13 #ifndef MLPACK_CORE_DATA_SPLIT_DATA_HPP 14 #define MLPACK_CORE_DATA_SPLIT_DATA_HPP 48 template<
typename T,
typename U>
49 void Split(
const arma::Mat<T>& input,
50 const arma::Row<U>& inputLabel,
51 arma::Mat<T>& trainData,
52 arma::Mat<T>& testData,
53 arma::Row<U>& trainLabel,
54 arma::Row<U>& testLabel,
55 const double testRatio)
57 const size_t testSize =
static_cast<size_t>(input.n_cols * testRatio);
58 const size_t trainSize = input.n_cols - testSize;
59 trainData.set_size(input.n_rows, trainSize);
60 testData.set_size(input.n_rows, testSize);
61 trainLabel.set_size(trainSize);
62 testLabel.set_size(testSize);
64 const arma::Col<size_t> order =
65 arma::shuffle(arma::linspace<arma::Col<size_t>>(0, input.n_cols - 1,
68 for (
size_t i = 0; i != trainSize; ++i)
70 trainData.col(i) = input.col(order[i]);
71 trainLabel(i) = inputLabel(order[i]);
74 for (
size_t i = 0; i != testSize; ++i)
76 testData.col(i) = input.col(order[i + trainSize]);
77 testLabel(i) = inputLabel(order[i + trainSize]);
103 void Split(
const arma::Mat<T>& input,
104 arma::Mat<T>& trainData,
105 arma::Mat<T>& testData,
106 const double testRatio)
108 const size_t testSize =
static_cast<size_t>(input.n_cols * testRatio);
109 const size_t trainSize = input.n_cols - testSize;
110 trainData.set_size(input.n_rows, trainSize);
111 testData.set_size(input.n_rows, testSize);
113 const arma::Col<size_t> order =
114 arma::shuffle(arma::linspace<arma::Col<size_t>>(0, input.n_cols -1,
117 for (
size_t i = 0; i != trainSize; ++i)
119 trainData.col(i) = input.col(order[i]);
121 for (
size_t i = 0; i != testSize; ++i)
123 testData.col(i) = input.col(order[i + trainSize]);
// Convenience overload for labeled data: instead of filling output parameters,
// it returns the split as a tuple of (Mat<T>, Mat<T>, Row<U>, Row<U>), built
// from trainData, (presumably) testData, trainLabel and testLabel.
146 template<
typename T,
typename U>
147 std::tuple<arma::Mat<T>, arma::Mat<T>, arma::Row<U>, arma::Row<U>>
// NOTE(review): this listing jumps from line 147 to 149, so the line with the
// function's name and its first parameter is not visible here.
149 const arma::Row<U>& inputLabel,
150 const double testRatio)
// Local containers filled in by the output-parameter Split() call below.
152 arma::Mat<T> trainData;
153 arma::Mat<T> testData;
154 arma::Row<U> trainLabel;
155 arma::Row<U> testLabel;
157 Split(input, inputLabel, trainData, testData, trainLabel, testLabel,
// NOTE(review): line 158 is missing from this listing; it presumably passes
// testRatio and closes the call — TODO confirm against the real source.
// Move the filled containers into the tuple rather than copying them.
160 return std::make_tuple(std::move(trainData),
// NOTE(review): line 161 is missing; given the declared tuple element types it
// presumably moves testData here — confirm.
162 std::move(trainLabel),
163 std::move(testLabel));
// Convenience overload for unlabeled data: returns (trainData, testData) as a
// tuple instead of filling output parameters.
// NOTE(review): the template header and the line with this function's name and
// first parameter (line 184 in the fused numbering) are not visible in this
// listing.
183 std::tuple<arma::Mat<T>, arma::Mat<T>>
185 const double testRatio)
187 arma::Mat<T> trainData;
188 arma::Mat<T> testData;
// Delegates to the Split(input, trainData, testData, testRatio) overload that
// writes into output parameters.
189 Split(input, trainData, testData, testRatio);
// Move the results into the tuple rather than copying the matrices.
191 return std::make_tuple(std::move(trainData),
192 std::move(testData));
void Split(const arma::Mat< T > &input, const arma::Row< U > &inputLabel, arma::Mat< T > &trainData, arma::Mat< T > &testData, arma::Row< U > &trainLabel, arma::Row< U > &testLabel, const double testRatio)
Given an input dataset and labels, split into a training set and test set.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects: standard C++ includes and Armadillo.