Skip to content

Commit 898c908

Browse files
committed
fixed_point_filter
1 parent 56e9968 commit 898c908

File tree

21 files changed

+491
-37
lines changed

21 files changed

+491
-37
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ LDFLAGS += $(THIRD_LIB) -lpthread -lrt
2020
PS_LIB = build/libps.a
2121
PS_MAIN = build/libpsmain.a
2222

23-
all: ps app build/hello
23+
all: ps app
2424
clean:
2525
rm -rf build
2626

example/linear/ctr/online_l1lr.conf

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,14 @@ async_sgd {
3434
algo: FTRL
3535

3636
# The size of minibatch
37-
minibatch : 10000
37+
minibatch : 100000
3838

3939
# The number of data passes
40-
num_data_pass: 10
40+
num_data_pass: 6
4141

4242
# features which occurs <= *tail_feature_freq* will be filtered before
4343
# training. it save both memory and bandwidth.
44-
tail_feature_freq : 4
44+
tail_feature_freq : 0
4545

4646
# It controls the countmin size. We filter the tail features by countmin, which
4747
# is more efficient than hash, but still is the memory bottleneck for servers. A

src/app/linear_method/async_sgd.h

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ struct SGDState {
3838

3939
void update() {
4040
if (reporter) {
41-
SGDProgress prog; prog.set_nnz(nnz);
41+
SGDProgress prog;
42+
prog.set_nnz(nnz);
43+
prog.set_weight_sum(weight_sum); weight_sum = 0;
44+
prog.set_delta_sum(delta_sum); delta_sum = 0;
4245
reporter->report(prog);
4346
}
4447
}
@@ -49,13 +52,18 @@ struct SGDState {
4952
} else if (new_weight != 0 && old_weight == 0) {
5053
++ nnz;
5154
}
55+
weight_sum += new_weight * new_weight;
56+
V delta = new_weight - old_weight;
57+
delta_sum += delta * delta;
5258
}
5359

5460
shared_ptr<LearningRate<V>> lr;
5561
shared_ptr<Penalty<V>> h;
5662

5763
int iter = 0;
5864
size_t nnz = 0;
65+
V weight_sum = 0;
66+
V delta_sum = 0;
5967
V max_delta = 1.0; // maximal change of weight
6068
MonitorSlaver<SGDProgress>* reporter = nullptr;
6169
};
@@ -213,7 +221,6 @@ class AsyncSGDWorker : public ISGDCompNode, public LinearMethod {
213221
data.add_file(call.data().file(idx[j]));
214222
}
215223
}
216-
217224
reader.setReader(data, sgd.minibatch(), sgd.data_buf());
218225
reader.setFilter(sgd.countmin_n(), sgd.countmin_k(), sgd.tail_feature_freq());
219226
reader.start();
@@ -262,7 +269,6 @@ class AsyncSGDWorker : public ISGDCompNode, public LinearMethod {
262269
prog.set_num_examples_processed(
263270
prog.num_examples_processed() + Xw.size());
264271
this->reporter_.report(prog);
265-
// LL << prog.objective(0) << " " << prog.auc(0);
266272

267273
// compute the gradient
268274
SArray<V> grad(X->cols());
@@ -278,6 +284,13 @@ class AsyncSGDWorker : public ISGDCompNode, public LinearMethod {
278284
model_.clear(id);
279285

280286
++ processed_batch_;
287+
288+
// auto we = w.eigenArray();
289+
// auto ge = grad.eigenArray();
290+
// LL << we.minCoeff() << " " << we.maxCoeff() << " "
291+
// << w.mean() << " " << w.std() << " "
292+
// << ge.minCoeff() << " " << ge.maxCoeff() << " "
293+
// << grad.mean() << " " << grad.std();
281294
}
282295

283296
private:

src/app/linear_method/loss.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class ScalarLoss : public Loss<T> {
5959
if (gradient.size() != 0) CHECK_EQ(gradient.size(), X->cols());
6060
if (diag_hessian.size() != 0) CHECK_EQ(diag_hessian.size(), X->cols());
6161

62+
if (!y.size()) return;
6263
compute(y.eigenArray(), X, Xw.eigenArray(), gradient.eigenArray(), diag_hessian.eigenArray());
6364
}
6465

src/app/linear_method/proto/linear.pb.h

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/app/linear_method/proto/linear.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ message SGDConfig {
4040

4141
optional int32 minibatch = 2 [default = 1000];
4242

43-
optional int32 data_buf = 12 [default = 100]; // in mb
43+
optional int32 data_buf = 12 [default = 1000]; // in mb
4444
optional bool ada_grad = 5 [default = true];
4545

4646
optional int32 max_delay = 4 [default = 0];

src/filter/compressing.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
namespace PS {
55

66
class CompressingFilter : public Filter {
7+
public:
78
void encode(const MessagePtr& msg) {
89
auto conf = find(FilterConfig::COMPRESSING, msg);
910
if (!conf) return;

src/filter/fixing_float.h

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
#pragma once
#include "filter/filter.h"
#include <math.h>
#include <time.h>
namespace PS {

// Lossy "fixed point" compression filter. encode() maps each float/double
// value onto the interval [conf.min_value(), conf.max_value()], quantizes it
// into conf.num_bytes() little-endian bytes using stochastic rounding (so the
// quantization error is unbiased in expectation), and decode() maps the bytes
// back to floating point.
class FixingFloatFilter : public Filter {
 public:
  // Quantizes every FLOAT/DOUBLE value array carried by *msg* in place.
  void encode(const MessagePtr& msg) {
    auto conf = CHECK_NOTNULL(find(FilterConfig::FIXING_FLOAT, msg))->fixed_point();
    int n = msg->value.size();
    CHECK_EQ(n, msg->task.value_type_size());
    for (int i = 0; i < n; ++i) {
      auto type = msg->task.value_type(i);
      if (type == DataType::FLOAT) {
        msg->value[i] = encode(SArray<float>(msg->value[i]), conf);
      }
      if (type == DataType::DOUBLE) {
        msg->value[i] = encode(SArray<double>(msg->value[i]), conf);
      }
    }
  }

  // Restores the floating point arrays produced by encode().
  void decode(const MessagePtr& msg) {
    auto conf = CHECK_NOTNULL(find(FilterConfig::FIXING_FLOAT, msg))->fixed_point();
    int n = msg->value.size();
    CHECK_EQ(n, msg->task.value_type_size());
    for (int i = 0; i < n; ++i) {
      auto type = msg->task.value_type(i);
      if (type == DataType::FLOAT) {
        msg->value[i] = decode<float>(msg->value[i], conf);
      }
      // Bug fix: DOUBLE arrays are quantized by encode() above but were never
      // decoded, so the receiver was left holding the raw quantized bytes.
      if (type == DataType::DOUBLE) {
        msg->value[i] = decode<double>(msg->value[i], conf);
      }
    }
  }

 private:
  // Cheap LCG-based coin flip used for stochastic rounding.
  inline bool boolrand(int* seed) {
    *seed = (214013 * *seed + 2531011);
    return ((*seed >> 16) & 0x1) == 0;
  }

  // Quantizes *data* into num_bytes()-byte little-endian fixed-point values.
  // Values outside [min_value, max_value] are clamped to the interval.
  template <typename V>
  SArray<char> encode(const SArray<V>& data, const FilterConfig::FixedConfig& conf) {
    int nbytes = conf.num_bytes();
    CHECK_GT(nbytes, 0);
    CHECK_LT(nbytes, 8);
    // NOTE(review): only nbytes*4 of the available nbytes*8 bits are used
    // for the value range; presumably intentional headroom — confirm before
    // changing, since it defines the wire format.
    V ratio = static_cast<V>(1 << (nbytes * 4));
    V min_v = static_cast<V>(conf.min_value());
    V max_v = static_cast<V>(conf.max_value());
    V bin = max_v - min_v;
    CHECK_GT(bin, 0);

    SArray<char> res(data.size() * nbytes);
    char* res_ptr = res.data();
    int seed = time(NULL);

    for (size_t i = 0; i < data.size(); ++i) {
      // Clamp into the configured interval, then scale to [0, ratio].
      V proj = data[i] > max_v ? max_v : data[i] < min_v ? min_v : data[i];
      V tmp = (proj - min_v) / bin * ratio;
      // Stochastic rounding: round up with probability ~0.5.
      uint64 r = static_cast<uint64>(floor(tmp)) + boolrand(&seed);

      // Emit the nbytes least significant bytes, little-endian.
      for (int j = 0; j < nbytes; ++j) {
        *(res_ptr++) = static_cast<char>(r & 0xFF);
        r = r >> 8;
      }
    }
    return res;
  }

  // Inverse of encode(): rebuilds each value from its nbytes bytes and maps
  // it back onto [min_value, max_value].
  template <typename V>
  SArray<V> decode(const SArray<char>& data, const FilterConfig::FixedConfig& conf) {
    int nbytes = conf.num_bytes();
    // Robustness: mirror encode()'s sanity checks; nbytes == 0 would divide
    // by zero below, and a non-positive bin yields garbage.
    CHECK_GT(nbytes, 0);
    CHECK_LT(nbytes, 8);
    V ratio = static_cast<V>(1 << (nbytes * 4));
    V min_v = static_cast<V>(conf.min_value());
    V max_v = static_cast<V>(conf.max_value());
    V bin = max_v - min_v;
    CHECK_GT(bin, 0);

    size_t n = data.size() / nbytes;
    SArray<V> res(n);
    char* data_ptr = data.data();
    for (size_t i = 0; i < n; ++i) {
      // Bug fix: accumulate in an integer (a float accumulator drops low
      // bits for wide values), and cast each byte through unsigned char —
      // a plain (possibly signed) char >= 0x80 would sign-extend and
      // corrupt the high bits of r.
      uint64 r = 0;
      for (int j = 0; j < nbytes; ++j) {
        r |= static_cast<uint64>(static_cast<unsigned char>(*(data_ptr++))) << (8 * j);
      }
      res[i] = static_cast<V>(r) / ratio * bin + min_v;
    }
    return res;
  }
};

} // namespace PS

src/filter/key_caching.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
namespace PS {
55

66
class KeyCachingFilter : public Filter {
7+
public:
78
// thread safe
89
void encode(const MessagePtr& msg) {
910
// if (!msg->task.has_key_range()) return;

0 commit comments

Comments (0)