Skip to content

Commit 4928cad

Browse files
committed
freq_filter at kv_vector
1 parent 6a2652c commit 4928cad

File tree

11 files changed

+94
-76
lines changed

11 files changed

+94
-76
lines changed

src/app/linear_method/async_sgd.h

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -220,11 +220,7 @@ class AsyncSGDWorker : public ISGDCompNode {
220220
<< data.second->rows() << "-by-" << data.second->cols();
221221

222222
// pull the weight
223-
auto req = Parameter::Request(id);
224-
// TODO
225-
// for (int i = 0; i < conf_.pull_filter_size(); ++i) {
226-
// *req.add_filter() = conf_.pull_filter(i);
227-
// }
223+
auto req = Parameter::Request(id, -1, {}, sgd.pull_filter());
228224
model_.Pull(req, key, [this, id]() { ComputeGradient(id); });
229225
}
230226

@@ -263,16 +259,7 @@ class AsyncSGDWorker : public ISGDCompNode {
263259
loss_->compute({Y, X, Xw.SMatrix()}, {grad.SMatrix()});
264260

265261
// push the gradient
266-
auto req = Parameter::Request(id);
267-
// TODO
268-
// for (int i = 0; i < conf_.push_filter_size(); ++i) {
269-
// // add filters
270-
// auto filter = conf_.push_filter(i);
271-
// if (filter.type() == FilterConfig::KEY_CACHING) {
272-
// filter.set_clear_cache_if_done(true);
273-
// }
274-
// *req.add_filter() = filter;
275-
// }
262+
auto req = Parameter::Request(id, -1, {}, conf_.async_sgd().push_filter());
276263
model_.Push(req, model_[id].key, grad);
277264
model_.clear(id);
278265

src/app/linear_method/darlin.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,6 @@ class DarlinWorker : public BCDWorker<Real>, public DarlinCompNode {
285285

286286
virtual void PreprocessData(int time, Message* request) {
287287
BCDWorker<Real>::PreprocessData(time, request);
288-
const BCDCall& call = request->task.bcd();
289288
// dual_ = exp(y.*(X_*w_))
290289
if (bcd_conf_.init_w().type() == ParamInitConfig::ZERO) {
291290
dual_.SetValue(1); // an optimizatoin

src/app/linear_method/proto/linear.proto

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ message Config {
1919
optional SGDConfig async_sgd = 17;
2020
optional BCDConfig darlin = 15;
2121

22-
repeated FilterConfig push_filter = 13;
23-
repeated FilterConfig pull_filter = 14;
2422
}
2523

2624
extend BCDConfig {
@@ -67,6 +65,9 @@ message SGDConfig {
6765
// filtered feature.
6866
optional float countmin_n = 8 [default = 1e8];
6967
optional int32 countmin_k = 7 [default = 2];
68+
69+
repeated FilterConfig push_filter = 13;
70+
repeated FilterConfig pull_filter = 14;
7071
}
7172

7273
message LossConfig {

src/learner/bcd.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,22 @@ void BCDScheduler::ProcessResponse(Message* response) {
3232
}
3333

3434
void BCDScheduler::Run() {
35-
WaitWorkersReady();
3635
LoadData();
3736
PreprocesseData();
3837
DivideFeatureBlocks();
3938
}
4039

4140
void BCDScheduler::LoadData() {
4241
// wait workers have load the data
42+
WaitWorkersReady();
4343
auto load_time = tic();
4444
int n = sys_.manager().num_workers();
4545
while (load_data_ < n) usleep(500);
4646
if (hit_cache_ > 0) {
4747
CHECK_EQ(hit_cache_, n) << "clear the local caches";
4848
NOTICE("Hit local caches for the training data");
4949
}
50-
NOTICE ("Loaded %lld examples in %g sec", g_train_info_.num_ex(), toc(load_time));
50+
NOTICE ("Loaded %lu examples in %g sec", g_train_info_.num_ex(), toc(load_time));
5151
}
5252

5353
void BCDScheduler::PreprocesseData() {

src/parameter/kv_layer.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class KVLayer : public Parameter {
7777
virtual void Slice(const Message& request, const std::vector<Range<Key>>& krs,
7878
std::vector<Message*>* msgs);
7979
virtual void GetValue(Message* msg);
80-
virtual void SetValue(Message* msg);
80+
virtual void SetValue(const Message* msg);
8181
protected:
8282
std::mutex mu_;
8383
std::unordered_map<int, SArray<V>> layer_;
@@ -94,7 +94,7 @@ int KVLayer<V, Updater>::Push(const Task& task, V* data, size_t size, bool zero_
9494
val.CopyFrom(data, size);
9595
}
9696
Message push(task, kServerGroup);
97-
Range<Key>(0, size).to(push.task.mutable_key_range());
97+
Range<Key>(0, size).To(push.task.mutable_key_range());
9898
push.add_value(val);
9999
return Parameter::Push(&push);
100100
}
@@ -109,7 +109,7 @@ int KVLayer<V, Updater>::Pull(
109109
layer_[id] = SArray<V>(data, size, false);
110110
}
111111
Message pull(task, kServerGroup);
112-
Range<Key>(0, size).to(pull.task.mutable_key_range());
112+
Range<Key>(0, size).To(pull.task.mutable_key_range());
113113
if (callback) pull.callback = callback;
114114
return Parameter::Pull(&pull);
115115
}
@@ -129,14 +129,14 @@ void KVLayer<V, Updater>::Slice(
129129
// a tiny layer, sent it to server k
130130
int k = (key * 991) % n;
131131
if (i == k) {
132-
kr.to(mut_kr);
132+
kr.To(mut_kr);
133133
} else {
134-
Range<Key>(0,0).to(mut_kr);
134+
Range<Key>(0,0).To(mut_kr);
135135
msg->valid = false; // invalid msg will not be sent
136136
}
137137
} else {
138138
// evenly parititon the data into all server nodes
139-
kr.evenDivide(n, i).to(mut_kr);
139+
kr.EvenDivide(n, i).To(mut_kr);
140140
}
141141
}
142142

@@ -176,7 +176,7 @@ void KVLayer<V, Updater>::GetValue(Message* msg) {
176176
}
177177

178178
template <typename V, class Updater>
179-
void KVLayer<V, Updater>::SetValue(Message* msg) {
179+
void KVLayer<V, Updater>::SetValue(const Message* msg) {
180180
// Lock l(mu_);
181181
CHECK_EQ(msg->value.size(), 1);
182182
SArray<V> recv_data(msg->value[0]);

src/parameter/kv_map.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class KVMap : public Parameter {
5454
}
5555

5656
virtual void GetValue(Message* msg);
57-
virtual void SetValue(Message* msg);
57+
virtual void SetValue(const Message* msg);
5858

5959
protected:
6060
int k_;
@@ -75,7 +75,7 @@ void KVMap<K,V,E,S>::GetValue(Message* msg) {
7575
}
7676

7777
template <typename K, typename V, typename E, typename S>
78-
void KVMap<K,V,E,S>::SetValue(Message* msg) {
78+
void KVMap<K,V,E,S>::SetValue(const Message* msg) {
7979
SArray<K> key(msg->key);
8080
size_t n = key.size();
8181
CHECK_EQ(msg->value.size(), 1);

src/parameter/kv_vector.h

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "ps.h"
33
#include "parameter/parameter.h"
44
#include "util/parallel_ordered_match.h"
5+
#include "filter/frequency_filter.h"
56
namespace PS {
67
// TODO doc, and filter
78
/**
@@ -92,7 +93,7 @@ class KVVector : public Parameter {
9293
SliceKOFVMessage<K>(request, krs, msgs);
9394
}
9495
virtual void GetValue(Message* msg);
95-
virtual void SetValue(Message* msg);
96+
virtual void SetValue(const Message* msg);
9697
using Parameter::Push;
9798
using Parameter::Pull;
9899
protected:
@@ -103,13 +104,31 @@ class KVVector : public Parameter {
103104
std::unordered_map<int, Buffer> buffer_; // <channel, Buffer>
104105

105106
std::mutex mu_; // protect the structure of data_ and buffer_
107+
108+
// filter tail keys
109+
FreqencyFilter<Key, uint8> freq_filter_;
106110
};
107111

108112
template <typename K, typename V>
109-
void KVVector<K,V>::SetValue(Message* msg) {
113+
void KVVector<K,V>::SetValue(const Message* msg) {
110114
// do check
111115
SArray<K> recv_key(msg->key);
112116
if (recv_key.empty()) return;
117+
118+
// filter request
119+
if (msg->task.param().has_tail_filter() && msg->task.request()) {
120+
const auto& tail = msg->task.param().tail_filter();
121+
CHECK(tail.insert_count());
122+
CHECK_EQ(msg->value.size(), 1);
123+
SArray<uint8> count(msg->value[0]);
124+
CHECK_EQ(count.size(), recv_key.size());
125+
if (freq_filter_.Empty()) {
126+
freq_filter_.Resize(tail.countmin_n(), tail.countmin_k());
127+
}
128+
freq_filter_.InsertKeys(recv_key, count);
129+
return;
130+
}
131+
113132
int chl = msg->task.key_channel();
114133
mu_.lock();
115134
auto& kv = data_[chl];
@@ -167,6 +186,15 @@ void KVVector<K,V>::GetValue(Message* msg) {
167186
// do check
168187
SArray<K> recv_key(msg->key);
169188
if (recv_key.empty()) return;
189+
190+
// filter request
191+
if (msg->task.param().has_tail_filter()) {
192+
const auto& tail = msg->task.param().tail_filter();
193+
CHECK(tail.has_freq_threshold());
194+
msg->key = freq_filter_.QueryKeys(recv_key, tail.freq_threshold());
195+
return;
196+
}
197+
170198
Lock l(mu_);
171199
auto& kv = data_[msg->task.key_channel()];
172200
CHECK_EQ(kv.key.size() * k_, kv.value.size());
@@ -195,9 +223,7 @@ template <typename K, typename V>
195223
int KVVector<K,V>::Pull(const Task& request, const SArray<K>& keys,
196224
Message::Callback callback) {
197225
Lock l(mu_);
198-
int chl = request.key_channel();
199-
if (keys.empty() ) CHECK_EQ(data_.count(chl), 1) << "empty channel " << chl;
200-
auto& kv = data_[chl];
226+
auto& kv = data_[request.key_channel()];
201227
if (!keys.empty()) kv.key = keys;
202228
kv.value = SArray<V>(kv.key.size()*k_, 0);
203229

src/parameter/parameter.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ class Parameter : public Customer {
1111

1212
typedef std::initializer_list<int> Timestamps;
1313
typedef ::google::protobuf::RepeatedPtrField<FilterConfig> Filters;
14-
// typedef std::initializer_list<FilterConfig> Filters;
15-
1614
/**
1715
* @brief Creats a request task
1816
*
@@ -60,12 +58,12 @@ class Parameter : public Customer {
6058

6159
/// @brief Set the values in "msg" into into my data strcuture, e.g..
6260
/// my_val_[msg->key[0]] = msg->value(0)[0];
63-
virtual void SetValue(Message* msg) = 0;
61+
virtual void SetValue(const Message* msg) = 0;
6462

6563
/// @brief the message contains the backup KV pairs sent by the master node of the key
6664
/// segment to its replica node. merge these pairs into my replica, say
6765
/// replica_[msg->sender] = ...
68-
virtual void SetReplica(Message* msg) { }
66+
virtual void SetReplica(const Message* msg) { }
6967

7068
/// @brief retrieve the replica. a new server node replacing a dead server will first
7169
/// ask for the dead's replica node for the data

src/parameter/proto/param.proto

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ message ParamCall {
1010

1111
optional TailKeyFilter tail_filter = 3;
1212

13-
optional bool insert_key = 5;
14-
optional bool gather = 6;
13+
// optional bool insert_key = 5;
14+
// optional bool gather = 6;
1515

1616
// it's a replica request
1717
optional bool replica = 10;
@@ -43,6 +43,6 @@ message TailKeyFilter {
4343
optional bool insert_count = 1;
4444
optional int32 freq_threshold = 2;
4545
optional bool query_value = 3;
46-
optional int32 countmin_n = 4;
47-
optional int32 countmin_k = 5;
46+
optional int32 countmin_n = 4 [default = 1000000];
47+
optional int32 countmin_k = 5 [default = 2];
4848
}

src/util/file.cc

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ File* File::open(const DataConfig& name, const char* const flag) {
9393
auto f = new File(des, filename);
9494
return f;
9595
}
96-
#if USE_S3
96+
#if USE_S3
9797
else if (s3file(filename)) {
9898
std::string cmd = "curl -s -X GET "+s3FileUrl(filename);
9999
// .gz
@@ -105,7 +105,7 @@ File* File::open(const DataConfig& name, const char* const flag) {
105105
auto f = new File(des, filename);
106106
return f;
107107
}
108-
#endif // USE_S3
108+
#endif // USE_S3
109109
else {
110110
return open(filename, flag);
111111
}
@@ -314,10 +314,10 @@ std::vector<std::string> s3GetFileNamesFromXml(const char* fbuf, int fsize, cons
314314
xmlFree(str);
315315
}
316316
}
317-
xmlXPathFreeObject(xpathObj);
317+
xmlXPathFreeObject(xpathObj);
318318
}
319-
xmlXPathFreeContext(xpathCtx);
320-
xmlFreeDoc(doc);
319+
xmlXPathFreeContext(xpathCtx);
320+
xmlFreeDoc(doc);
321321
return files;
322322
}
323323
#endif // USE_S3
@@ -343,9 +343,7 @@ std::vector<std::string> readFilenamesInDirectory(const DataConfig& directory) {
343343
std::vector<std::string> files;
344344
string cmd = hadoopFS(directory.hdfs()) + " -ls " + dirname;
345345

346-
if (FLAGS_verbose) {
347-
LI << "readFilenamesInDirectory hdfs ls [" << cmd << "]";
348-
}
346+
VLOG(1) << "readFilenamesInDirectory hdfs ls [" << cmd << "]";
349347

350348
FILE* des = popen(cmd.c_str(), "r"); CHECK(des);
351349
char line[10000];
@@ -365,7 +363,7 @@ std::vector<std::string> readFilenamesInDirectory(const DataConfig& directory) {
365363
pclose(des);
366364
return files;
367365
}
368-
#if USE_S3
366+
#if USE_S3
369367
else if (s3file(dirname)) {
370368
// open xml
371369
std::string cmd = "curl -s -X GET "+s3DirectoryUrl(dirname);
@@ -387,7 +385,7 @@ std::vector<std::string> readFilenamesInDirectory(const DataConfig& directory) {
387385
free (fbuf);
388386
return files;
389387
}
390-
#endif // USE_S3
388+
#endif // USE_S3
391389
else {
392390
return readFilenamesInDirectory(dirname);
393391
}

0 commit comments

Comments
 (0)