Skip to content

Commit 29b0076

Browse files
committed
user-defined filters in conf
1 parent ab3e36a commit 29b0076

File tree

4 files changed

+39
-22
lines changed

4 files changed

+39
-22
lines changed

src/app/linear_method/async_sgd.h

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ class AsyncSGDServer : public ISGDCompNode {
6060
// model_ = new KVStore<Key, V, AdaGradEntry<V>, SGDState<V>>();
6161
// }
6262
}
63-
// CHECK_NOTNULL(model_)->set_state(state);
6463
}
6564

6665
virtual ~AsyncSGDServer() {
@@ -139,7 +138,7 @@ class AsyncSGDServer : public ISGDCompNode {
139138
V z = 0;
140139
V sqrt_n = 0;
141140

142-
void Get(V const* data, void* state) {
141+
void Set(const V* data, void* state) {
143142
SGDState* st = (SGDState*) state;
144143
// update model
145144
V w_old = w;
@@ -155,11 +154,15 @@ class AsyncSGDServer : public ISGDCompNode {
155154
st->UpdateWeight(w, w_old);
156155
}
157156

158-
void Set(V* data, void* state) { *data = w; }
157+
void Get(V* data, void* state) { *data = w; }
159158
};
160159

161160
};
162161

162+
/**
163+
* @brief A worker node
164+
*
165+
*/
163166
template <typename V>
164167
class AsyncSGDWorker : public ISGDCompNode {
165168
public:
@@ -191,6 +194,11 @@ class AsyncSGDWorker : public ISGDCompNode {
191194
}
192195

193196
private:
197+
/**
198+
* @brief process a data file
199+
*
200+
* @param load
201+
*/
194202
void UpdateModel(const Workload& load) {
195203
LOG(INFO) << MyNodeID() << ": accept workload " << load.id();
196204
VLOG(1) << "workload data: " << load.data().ShortDebugString();
@@ -213,13 +221,20 @@ class AsyncSGDWorker : public ISGDCompNode {
213221

214222
// pull the weight
215223
auto req = Parameter::Request(id);
224+
for (int i = 0; i < conf_.pull_filter_size(); ++i) {
225+
*req.add_filter() = conf_.pull_filter(i);
226+
}
216227
model_.Pull(req, key, [this, id]() { ComputeGradient(id); });
217228
}
218229

219230
while (processed_batch_ < id) { usleep(500); }
220231
LOG(INFO) << MyNodeID() << ": finished workload " << load.id();
221232
}
222233

234+
/**
235+
*
236+
* @param id minibatch id
237+
*/
223238
void ComputeGradient(int id) {
224239
mu_.lock();
225240
auto Y = data_[id].first;
@@ -248,19 +263,17 @@ class AsyncSGDWorker : public ISGDCompNode {
248263

249264
// push the gradient
250265
auto req = Parameter::Request(id);
251-
// LL << grad;
266+
for (int i = 0; i < conf_.push_filter_size(); ++i) {
267+
// add filters
268+
auto filter = conf_.push_filter(i);
269+
if (filter.type() == FilterConfig::KEY_CACHING) {
270+
filter.set_clear_cache_if_done(true);
271+
}
272+
*req.add_filter() = filter;
273+
}
252274
model_.Push(req, model_[id].key, grad);
253275
model_.clear(id);
254276

255-
256-
// msg->add_filter(FilterConfig::KEY_CACHING)->set_clear_cache_if_done(true);
257-
// int nbytes = conf_.async_sgd().fixing_float_by_nbytes();
258-
// if (nbytes) {
259-
// auto conf = msg->add_filter(FilterConfig::FIXING_FLOAT)->add_fixed_point();
260-
// conf->set_num_bytes(nbytes);
261-
// }
262-
263-
264277
++ processed_batch_;
265278
}
266279

src/app/linear_method/proto/linear.proto

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package PS.LM;
22
import "data/proto/data.proto";
33
import "learner/proto/bcd.proto";
4-
4+
import "filter/proto/filter.proto";
55
message Config {
66
optional DataConfig training_data = 1;
77
optional DataConfig validation_data = 2;
@@ -17,6 +17,8 @@ message Config {
1717
optional SGDConfig async_sgd = 17;
1818
optional BCDConfig darlin = 15;
1919

20+
repeated FilterConfig push_filter = 13;
21+
repeated FilterConfig pull_filter = 14;
2022
}
2123

2224
extend BCDConfig {
@@ -62,9 +64,6 @@ message SGDConfig {
6264
// filtered feature.
6365
optional float countmin_n = 8 [default = 1e8];
6466
optional int32 countmin_k = 7 [default = 2];
65-
66-
// if > 0, then use *fixing_float_by_nbytes* bytes to encode float during communication
67-
optional int32 fixing_float_by_nbytes = 13 [default = 0];
6867
}
6968

7069
message LossConfig {

src/filter/fixing_float.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,23 @@ class FixingFloatFilter : public Filter {
3030
auto type = msg->task.value_type(i);
3131
if (type == DataType::FLOAT) {
3232
CHECK_GT(filter_conf->fixed_point_size(), k);
33-
msg->value[i] = convert<float>(msg->value[i], encode, filter_conf->mutable_fixed_point(k++));
33+
msg->value[i] = convert<float>(
34+
msg->value[i], encode, filter_conf->num_bytes(),
35+
filter_conf->mutable_fixed_point(k++));
3436
}
3537
if (type == DataType::DOUBLE) {
3638
CHECK_GT(filter_conf->fixed_point_size(), k);
37-
msg->value[i] = convert<double>(msg->value[i], encode, filter_conf->mutable_fixed_point(k++));
39+
msg->value[i] = convert<double>(
40+
msg->value[i], encode, filter_conf->num_bytes(),
41+
filter_conf->mutable_fixed_point(k++));
3842
}
3943
}
4044
}
4145

4246
// decode / encode an array
4347
template <typename V>
44-
SArray<char> convert(const SArray<char>& array, bool encode, FilterConfig::FixedFloatConfig* conf) {
45-
int nbytes = conf->num_bytes();
48+
SArray<char> convert(const SArray<char>& array, bool encode, int nbytes,
49+
FilterConfig::FixedFloatConfig* conf) {
4650
CHECK_GT(nbytes, 0);
4751
CHECK_LT(nbytes, 8);
4852
double ratio = static_cast<double>(1 << (nbytes*8)) - 2;

src/filter/proto/filter.proto

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@ message FilterConfig {
1414
// if the task is done, then clear the cache (to save memory)
1515
optional bool clear_cache_if_done = 20 [default = false];
1616

17+
optional int32 num_bytes = 5 [default = 3];
18+
1719
message FixedFloatConfig {
1820
optional float min_value = 1 [default = -1];
1921
optional float max_value = 2 [default = 1];
20-
optional int32 num_bytes = 3 [default = 3];
2122
}
2223
repeated FixedFloatConfig fixed_point = 4;
2324

0 commit comments

Comments
 (0)