mlpack  master
serialization.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_TESTS_SERIALIZATION_HPP
13 #define MLPACK_TESTS_SERIALIZATION_HPP
14 
15 #include <boost/serialization/serialization.hpp>
16 #include <boost/archive/xml_iarchive.hpp>
17 #include <boost/archive/xml_oarchive.hpp>
18 #include <boost/archive/text_iarchive.hpp>
19 #include <boost/archive/text_oarchive.hpp>
20 #include <boost/archive/binary_iarchive.hpp>
21 #include <boost/archive/binary_oarchive.hpp>
22 #include <mlpack/core.hpp>
23 
24 #include <boost/test/unit_test.hpp>
25 #include "test_tools.hpp"
26 
27 namespace mlpack {
28 
29 // Test function for loading and saving Armadillo objects.
30 template<typename CubeType,
31  typename IArchiveType,
32  typename OArchiveType>
33 void TestArmadilloSerialization(arma::Cube<CubeType>& x)
34 {
35  // First save it.
36  std::ofstream ofs("test", std::ios::binary);
37  OArchiveType o(ofs);
38 
39  bool success = true;
40  try
41  {
42  o << BOOST_SERIALIZATION_NVP(x);
43  }
44  catch (boost::archive::archive_exception& e)
45  {
46  success = false;
47  }
48 
49  BOOST_REQUIRE_EQUAL(success, true);
50  ofs.close();
51 
52  // Now load it.
53  arma::Cube<CubeType> orig(x);
54  success = true;
55  std::ifstream ifs("test", std::ios::binary);
56  IArchiveType i(ifs);
57 
58  try
59  {
60  i >> BOOST_SERIALIZATION_NVP(x);
61  }
62  catch (boost::archive::archive_exception& e)
63  {
64  success = false;
65  }
66 
67  BOOST_REQUIRE_EQUAL(success, true);
68 
69  BOOST_REQUIRE_EQUAL(x.n_rows, orig.n_rows);
70  BOOST_REQUIRE_EQUAL(x.n_cols, orig.n_cols);
71  BOOST_REQUIRE_EQUAL(x.n_elem_slice, orig.n_elem_slice);
72  BOOST_REQUIRE_EQUAL(x.n_slices, orig.n_slices);
73  BOOST_REQUIRE_EQUAL(x.n_elem, orig.n_elem);
74 
75  for(size_t slice = 0; slice != x.n_slices; ++slice){
76  auto const &orig_slice = orig.slice(slice);
77  auto const &x_slice = x.slice(slice);
78  for (size_t i = 0; i < x.n_cols; ++i){
79  for (size_t j = 0; j < x.n_rows; ++j){
80  if (double(orig_slice(j, i)) == 0.0)
81  BOOST_REQUIRE_SMALL(double(x_slice(j, i)), 1e-8);
82  else
83  BOOST_REQUIRE_CLOSE(double(orig_slice(j, i)), double(x_slice(j, i)), 1e-8);
84  }
85  }
86  }
87 
88  remove("test");
89 }
90 
91 // Test all serialization strategies.
92 template<typename CubeType>
93 void TestAllArmadilloSerialization(arma::Cube<CubeType>& x)
94 {
95  TestArmadilloSerialization<CubeType, boost::archive::xml_iarchive,
96  boost::archive::xml_oarchive>(x);
97  TestArmadilloSerialization<CubeType, boost::archive::text_iarchive,
98  boost::archive::text_oarchive>(x);
99  TestArmadilloSerialization<CubeType, boost::archive::binary_iarchive,
100  boost::archive::binary_oarchive>(x);
101 }
102 
103 // Test function for loading and saving Armadillo objects.
104 template<typename MatType,
105  typename IArchiveType,
106  typename OArchiveType>
108 {
109  // First save it.
110  std::ofstream ofs("test", std::ios::binary);
111  OArchiveType o(ofs);
112 
113  bool success = true;
114  try
115  {
116  o << BOOST_SERIALIZATION_NVP(x);
117  }
118  catch (boost::archive::archive_exception& e)
119  {
120  success = false;
121  }
122 
123  BOOST_REQUIRE_EQUAL(success, true);
124  ofs.close();
125 
126  // Now load it.
127  MatType orig(x);
128  success = true;
129  std::ifstream ifs("test", std::ios::binary);
130  IArchiveType i(ifs);
131 
132  try
133  {
134  i >> BOOST_SERIALIZATION_NVP(x);
135  }
136  catch (boost::archive::archive_exception& e)
137  {
138  success = false;
139  }
140 
141  BOOST_REQUIRE_EQUAL(success, true);
142 
143  BOOST_REQUIRE_EQUAL(x.n_rows, orig.n_rows);
144  BOOST_REQUIRE_EQUAL(x.n_cols, orig.n_cols);
145  BOOST_REQUIRE_EQUAL(x.n_elem, orig.n_elem);
146 
147  for (size_t i = 0; i < x.n_cols; ++i)
148  for (size_t j = 0; j < x.n_rows; ++j)
149  if (double(orig(j, i)) == 0.0)
150  BOOST_REQUIRE_SMALL(double(x(j, i)), 1e-8);
151  else
152  BOOST_REQUIRE_CLOSE(double(orig(j, i)), double(x(j, i)), 1e-8);
153 
154  remove("test");
155 }
156 
157 // Test all serialization strategies.
158 template<typename MatType>
160 {
161  TestArmadilloSerialization<MatType, boost::archive::xml_iarchive,
162  boost::archive::xml_oarchive>(x);
163  TestArmadilloSerialization<MatType, boost::archive::text_iarchive,
164  boost::archive::text_oarchive>(x);
165  TestArmadilloSerialization<MatType, boost::archive::binary_iarchive,
166  boost::archive::binary_oarchive>(x);
167 }
168 
169 // Save and load an mlpack object.
170 // The re-loaded copy is placed in 'newT'.
171 template<typename T, typename IArchiveType, typename OArchiveType>
172 void SerializeObject(T& t, T& newT)
173 {
174  std::ofstream ofs("test", std::ios::binary);
175  OArchiveType o(ofs);
176 
177  bool success = true;
178  try
179  {
180  o << data::CreateNVP(t, "t");
181  }
182  catch (boost::archive::archive_exception& e)
183  {
184  success = false;
185  }
186  ofs.close();
187 
188  BOOST_REQUIRE_EQUAL(success, true);
189 
190  std::ifstream ifs("test", std::ios::binary);
191  IArchiveType i(ifs);
192 
193  try
194  {
195  i >> data::CreateNVP(newT, "t");
196  }
197  catch (boost::archive::archive_exception& e)
198  {
199  success = false;
200  }
201  ifs.close();
202 
203  BOOST_REQUIRE_EQUAL(success, true);
204 }
205 
206 // Test mlpack serialization with all three archive types.
207 template<typename T>
208 void SerializeObjectAll(T& t, T& xmlT, T& textT, T& binaryT)
209 {
210  SerializeObject<T, boost::archive::text_iarchive,
211  boost::archive::text_oarchive>(t, textT);
212  SerializeObject<T, boost::archive::binary_iarchive,
213  boost::archive::binary_oarchive>(t, binaryT);
214  SerializeObject<T, boost::archive::xml_iarchive,
215  boost::archive::xml_oarchive>(t, xmlT);
216 }
217 
218 // Save and load a non-default-constructible mlpack object.
219 template<typename T, typename IArchiveType, typename OArchiveType>
220 void SerializePointerObject(T* t, T*& newT)
221 {
222  std::ofstream ofs("test", std::ios::binary);
223  OArchiveType o(ofs);
224 
225  bool success = true;
226  try
227  {
228  o << data::CreateNVP(*t, "t");
229  }
230  catch (boost::archive::archive_exception& e)
231  {
232  success = false;
233  }
234  ofs.close();
235 
236  BOOST_REQUIRE_EQUAL(success, true);
237 
238  std::ifstream ifs("test", std::ios::binary);
239  IArchiveType i(ifs);
240 
241  try
242  {
243  newT = new T(i);
244  }
245  catch (std::exception& e)
246  {
247  success = false;
248  }
249  ifs.close();
250 
251  BOOST_REQUIRE_EQUAL(success, true);
252 }
253 
254 template<typename T>
255 void SerializePointerObjectAll(T* t, T*& xmlT, T*& textT, T*& binaryT)
256 {
257  SerializePointerObject<T, boost::archive::text_iarchive,
258  boost::archive::text_oarchive>(t, textT);
259  SerializePointerObject<T, boost::archive::binary_iarchive,
260  boost::archive::binary_oarchive>(t, binaryT);
261  SerializePointerObject<T, boost::archive::xml_iarchive,
262  boost::archive::xml_oarchive>(t, xmlT);
263 }
264 
/**
 * Utility function to check the equality of an original Armadillo matrix
 * with three other matrices. Only the declaration is here; the definition
 * lives in another file.
 *
 * @param x Original matrix.
 * @param xmlX Matrix to compare against x (presumably reloaded from an XML
 *     archive; confirm against the definition).
 * @param textX Matrix to compare against x (presumably reloaded from a text
 *     archive).
 * @param binaryX Matrix to compare against x (presumably reloaded from a
 *     binary archive).
 */
void CheckMatrices(const arma::mat& x,
                   const arma::mat& xmlX,
                   const arma::mat& textX,
                   const arma::mat& binaryX);

// Overload of CheckMatrices() for matrices of size_t elements.
void CheckMatrices(const arma::Mat<size_t>& x,
                   const arma::Mat<size_t>& xmlX,
                   const arma::Mat<size_t>& textX,
                   const arma::Mat<size_t>& binaryX);
275 
276 } // namespace mlpack
277 
278 #endif
void SerializePointerObject(T *t, T *&newT)
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: binarize.hpp:18
FirstShim< T > CreateNVP(T &t, const std::string &name, typename std::enable_if_t< HasSerialize< T >::value > *=0)
Call this function to produce a name-value pair; this is similar to BOOST_SERIALIZATION_NVP(), but should be used for types that have a Serialize() function (or contain a type that has a Serialize() function) instead of a serialize() function.
void CheckMatrices(const arma::mat &x, const arma::mat &xmlX, const arma::mat &textX, const arma::mat &binaryX)
void TestArmadilloSerialization(arma::Cube< CubeType > &x)
void SerializePointerObjectAll(T *t, T *&xmlT, T *&textT, T *&binaryT)
void SerializeObjectAll(T &t, T &xmlT, T &textT, T &binaryT)
Include all of the base components required to write MLPACK methods, and the main MLPACK Doxygen documentation.
void SerializeObject(T &t, T &newT)
void TestAllArmadilloSerialization(arma::Cube< CubeType > &x)