-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathobnet.hpp
More file actions
165 lines (134 loc) · 4.63 KB
/
obnet.hpp
File metadata and controls
165 lines (134 loc) · 4.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
/**
* @file obnet.hpp
* @brief Output blending network - only supports two modulator (h) levels, 0 and 1,
* and can only be trained with SGD (exactly one example per trainBatch() call).
*/
#ifndef __OBNET_HPP
#define __OBNET_HPP
#include "data.hpp"
/**
* \brief A modulatory network architecture which uses two plain backprop networks,
* each of which is trained separately. When the network is run, each subnetwork is run
* and the output generated by interpolating between the subnet outputs.
*/
class OutputBlendingNet : public Net {
private:
/**
* \brief the modulator (or h)
*/
double modulator;
public:
/**
* \brief Constructor - does not initialise the weights to random values so
* that we can reinitialise networks.
* \param nlayers number of layers
* \param layerCounts array of layer counts
*/
OutputBlendingNet(int nlayers,const int *layerCounts) : Net() {
// we create two networks, one for each modulator level.
net0 = new BPNet(nlayers,layerCounts);
net1 = new BPNet(nlayers,layerCounts);
interpolatedOutputs = new double [net0->getOutputCount()];
}
/**
* \brief destructor to delete subnets and outputs
*/
virtual ~OutputBlendingNet(){
delete net0;
delete net1;
delete [] interpolatedOutputs;
}
virtual int getLayerSize(int n) const {
return net0->getLayerSize(n);
}
virtual int getLayerCount() const {
return net0->getLayerCount();
}
virtual void setH(double h){
modulator = h;
}
virtual double getH() const {
return modulator;
}
virtual void setInputs(double *d) {
// a bit inefficient, since we should only need to do this
// for the network currently being trained.
net0->setInputs(d);
net1->setInputs(d);
}
virtual double *getOutputs() const {
// constructed during the update
return interpolatedOutputs;
}
virtual int getDataSize() const {
// need room for the two (equally-sized) nets
return net0->getDataSize()*2;
}
virtual void save(double *buf) const {
// just save the two networks, one after the other
net0->save(buf);
buf+=net0->getDataSize();
net1->save(buf);
}
virtual void load(double *buf){
net0->load(buf);
buf+=net0->getDataSize();
net1->load(buf);
}
protected:
Net *net0; //!< the network trained by h=0 examples
Net *net1; //!< the network trained by h=1 examples
double *interpolatedOutputs; //!< the interpolated result after update()
virtual void initWeights(double initr){
net0->initWeights(initr);
net1->initWeights(initr);
}
/**
* \brief Update the two networks, and interpolate linearly between the
* outputs with the modulator.
*/
virtual void update(){
net0->update();
net1->update();
// interpolate the outputs
double *o0 = net0->getOutputs();
double *o1 = net1->getOutputs();
double h = getH();
for(int i=0;i<getOutputCount();i++){
interpolatedOutputs[i] = h*o1[i] + (1.0-h)*o0[i];
}
}
double lastError = -1;
/**
* \brief Train the network - see Net::trainBatch for more details, but this version
* is only suitable for SGD; it can only accept one example.
*/
virtual double trainBatch(ExampleSet& ex,int start,int num,double eta){
/** \bug can only use SGD for now; how this works in batching
could be tricky. */
if(num!=1)
std::runtime_error("num!=1 (i.e. batch training) not implemented");
// what we do here depends on the modulator for the first and only
// example
double hzero = (ex.getH(start)<0.5);
Net *net = hzero ? net0 : net1;
double e = net->trainBatch(ex,start,1,eta);
// return avg of 0/1 error rate, so this will change once every two cycles;
// but the first one will just be the error for h=0
double rv;
if(lastError<0){
// nothing yet, we've done one h=0, return it.
lastError=e;
rv=e;
} else if(hzero) {
// this is the h=0 the second and subsequent times; return the last mean.
rv=lastError;
} else {
// this is the h=1 - calculate the new mean and return it. Set this to
// also be the value that will be returned on the next h=0 run.
lastError = rv = (e+lastError)*0.5;
}
return rv;
}
};
#endif /* __OBNET_HPP */