-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathAPOP.h
More file actions
166 lines (140 loc) · 3.54 KB
/
APOP.h
File metadata and controls
166 lines (140 loc) · 3.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
/*
* APOP.h
*
* Created on: Jul 20, 2016
* Author: mason
*/
#ifndef APOP_H_
#define APOP_H_
#include "MyLib.h"
#include "Alphabet.h"
#include "Node.h"
#include "Graph.h"
#include "APParam.h"
// for sparse features
struct APParams {
public:
APParam W;
PAlphabet elems;
int nVSize;
int nDim;
public:
APParams() {
nVSize = 0;
nDim = 0;
elems = NULL;
}
inline void exportAdaParams(ModelUpdate& ada) {
ada.addParam(&W);
}
inline void initialWeights(int nOSize) {
if (nVSize == 0) {
std::cout << "please check the alphabet" << std::endl;
return;
}
nDim = nOSize;
W.initial(nOSize, nVSize);
}
//random initialization
inline void initial(PAlphabet alpha, int nOSize, int base = 1) {
assert(base >= 1);
elems = alpha;
nVSize = base * elems->size();
if (base > 1) {
std::cout << "nVSize: " << nVSize << ", Alpha Size = " << elems->size() << ", Require more Alpha."<< std::endl;
elems->set_fixed_flag(false);
}
initialWeights(nOSize);
}
inline int getFeatureId(const string& strFeat) {
int idx = elems->from_string(strFeat);
if(!elems->m_b_fixed && elems->m_size >= nVSize){
std::cout << "AP Alphabet stopped collecting features" << std::endl;
elems->set_fixed_flag(true);
}
return idx;
}
};
//only implemented sparse linear node.
//non-linear transformations are not support,
class APNode : public Node {
public:
APParams* param;
vector<int> ins;
public:
APNode() : Node() {
ins.clear();
param = NULL;
node_type = "apnode";
}
inline void setParam(APParams* paramInit) {
param = paramInit;
}
inline void clearValue() {
Node::clearValue();
ins.clear();
}
public:
//notice the output
void forward(Graph *cg, const vector<string>& x) {
int featId;
int featSize = x.size();
for (int idx = 0; idx < featSize; idx++) {
featId = param->getFeatureId(x[idx]);
if (featId >= 0) {
ins.push_back(featId);
}
}
degree = 0;
cg->addNode(this);
}
public:
inline void compute(bool bTrain) {
param->W.value(ins, val, bTrain);
}
//no output losses
void backward() {
//assert(param != NULL);
param->W.loss(ins, loss);
}
public:
inline PExecute generate(bool bTrain);
// better to rewrite for deep understanding
inline bool typeEqual(PNode other) {
bool result = Node::typeEqual(other);
if (!result) return false;
APNode* conv_other = (APNode*)other;
if (param != conv_other->param) {
return false;
}
return true;
}
};
class APExecute :public Execute {
public:
bool bTrain;
public:
inline void forward() {
int count = batch.size();
for (int idx = 0; idx < count; idx++) {
APNode* ptr = (APNode*)batch[idx];
ptr->compute(bTrain);
ptr->forward_drop(bTrain);
}
}
inline void backward() {
int count = batch.size();
for (int idx = 0; idx < count; idx++) {
APNode* ptr = (APNode*)batch[idx];
ptr->backward_drop();
ptr->backward();
}
}
};
// Create a fresh executor seeded with this node.
// NOTE(review): the raw new follows the framework's convention —
// ownership of the returned executor is presumably taken by the
// graph/driver code; confirm against Graph's cleanup path.
inline PExecute APNode::generate(bool bTrain) {
    APExecute* executor = new APExecute();
    executor->bTrain = bTrain;
    executor->batch.push_back(this);
    return executor;
}
#endif /* APOP_H_ */