-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtestSaveLoad.cpp
More file actions
112 lines (88 loc) · 2.33 KB
/
testSaveLoad.cpp
File metadata and controls
112 lines (88 loc) · 2.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
/**
* @file testSaveLoad.cpp
* @brief Tests of loading and saving. These work by generating
* random networks, training them briefly, cycling them through a
* save to disk and a load back, and checking that the parameters
* are unchanged.
*
*/
#include <cstdio>
#include <iostream>
#include <memory>
#include <vector>
#include <boost/test/unit_test.hpp>
#include "test.hpp"
/** \addtogroup saveloadtests save and load tests.
* \ingroup tests
* @{
*/
BOOST_AUTO_TEST_SUITE(saveload)
/**
 * Save/load round-trip check shared by all the test cases below.
 *
 * Builds a 4-3-2 network of type `tp`, trains it a little on a single
 * toy example, saves it to a temporary file and loads it back, then
 * requires that the loaded network has the same type, the same data
 * size and exactly equal parameters (exact equality is intended: a
 * save/load cycle must not perturb any weight or bias).
 *
 * Both networks and both parameter buffers are owned by RAII objects,
 * so nothing leaks when a BOOST_REQUIRE aborts the test case (Boost.Test
 * aborts a REQUIRE-level failure by throwing).
 *
 * @param tp the network type to exercise
 */
void testSaveLoad(NetType tp){
    // generate a new network: 3 layers of sizes 4 -> 3 -> 2
    int layers[3] = {4, 3, 2};
    std::unique_ptr<Net> n(NetFactory::makeNet(tp,3,layers));

    // generate a toy example. Doesn't matter what it is.
    // NOTE(review): constructor args look like (examples, inputs, outputs, h-levels)
    // given the 4 inputs / 2 outputs written below -- confirm against ExampleSet.
    ExampleSet e(1,4,2,1);
    double *p = e.getInputs(0);
    *p++=0;
    *p++=2;
    *p++=3;
    *p=1;
    p = e.getOutputs(0);
    *p++=100;
    *p=20;
    e.setH(0,0);

    // train it a little so the parameters are not just their initial values.
    // The meaning of the SGDParams constructor arguments is defined by Net.
    Net::SGDParams parms(10,e,100);
    n->trainSGD(e,parms);

    // snapshot the trained parameters in memory
    std::vector<double> oldData(n->getDataSize());
    n->save(oldData.data());

    // now save the net to disk and load it back
    const char *const fname = "foo.net";
    NetFactory::save(fname,n.get());
    std::unique_ptr<Net> saved(NetFactory::load(fname));
    // the file is only needed for the round trip; don't leave it behind
    std::remove(fname);

    BOOST_REQUIRE(saved != nullptr);
    BOOST_REQUIRE(n->type == saved->type);
    BOOST_REQUIRE(n->getDataSize() == saved->getDataSize());

    // snapshot the reloaded parameters and compare element by element
    std::vector<double> savedData(saved->getDataSize());
    saved->save(savedData.data());
    BOOST_REQUIRE(oldData == savedData);
}
/**
 * \brief Round-trips a PLAIN network through save/load and requires
 * its weights and biases to come back unchanged.
 */
BOOST_AUTO_TEST_CASE(saveloadplain)
{
    testSaveLoad(NetType::PLAIN);
}
/**
 * \brief Round-trips an OUTPUTBLENDING network through save/load and
 * requires its weights and biases to come back unchanged.
 */
BOOST_AUTO_TEST_CASE(saveloadob)
{
    testSaveLoad(NetType::OUTPUTBLENDING);
}
/**
 * \brief Round-trips an HINPUT (h-as-input) network through save/load
 * and requires its weights and biases to come back unchanged.
 */
BOOST_AUTO_TEST_CASE(saveloadhin)
{
    testSaveLoad(NetType::HINPUT);
}
/**
 * \brief Round-trips a UESMANN network through save/load and requires
 * its weights and biases to come back unchanged.
 */
BOOST_AUTO_TEST_CASE(saveloadues)
{
    testSaveLoad(NetType::UESMANN);
}
/**
* @}
*/
BOOST_AUTO_TEST_SUITE_END()