Skip to content

Commit 6a2652c

Browse files
committed
refactor range and localizer
1 parent f49f101 commit 6a2652c

File tree

14 files changed

+82
-82
lines changed

14 files changed

+82
-82
lines changed

src/app/linear_method/darlin.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class DarlinScheduler : public BCDScheduler {
6262
}
6363
// block info
6464
auto blk = fea_blk_[order[i]];
65-
blk.second.to(cmd->mutable_key());
65+
blk.second.To(cmd->mutable_key());
6666
cmd->add_fea_grp(blk.first);
6767

6868
// time stamp
@@ -193,7 +193,7 @@ class DarlinServer : public BCDServer<Real>, public DarlinCompNode {
193193
auto col_range = model_[grp].key.FindRange(g_key_range);
194194

195195
// none of my bussiness
196-
if (MyKeyRange().setIntersection(g_key_range).empty()) return;
196+
if (MyKeyRange().SetIntersection(g_key_range).empty()) return;
197197

198198
// aggregate all workers' local gradients
199199
model_.WaitReceivedRequest(time, kWorkerGroup);
@@ -354,7 +354,7 @@ class DarlinWorker : public BCDWorker<Real>, public DarlinCompNode {
354354
ThreadPool pool(num_threads);
355355
int npart = num_threads * 1; // could use a larger partition number
356356
for (int i = 0; i < npart; ++i) {
357-
auto thr_range = col_range.evenDivide(npart, i);
357+
auto thr_range = col_range.EvenDivide(npart, i);
358358
if (thr_range.empty()) continue;
359359
auto gr = thr_range - col_range.begin();
360360
pool.add([this, grp, thr_range, gr, &G, &U]() {
@@ -500,7 +500,7 @@ class DarlinWorker : public BCDWorker<Real>, public DarlinCompNode {
500500
ThreadPool pool(FLAGS_num_threads);
501501
int npart = FLAGS_num_threads;
502502
for (int i = 0; i < npart; ++i) {
503-
auto thr_range = row_range.evenDivide(npart, i);
503+
auto thr_range = row_range.EvenDivide(npart, i);
504504
if (thr_range.empty()) continue;
505505
pool.add([this, grp, thr_range, col_range, delta_w]() {
506506
UpdateDual(grp, thr_range, col_range, delta_w);

src/learner/bcd.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ void BCDScheduler::DivideFeatureBlocks() {
8787
n = std::max((int)std::ceil(nnz_per_row * bcd_conf_.feature_block_ratio()), 1);
8888
}
8989
for (int i = 0; i < n; ++i) {
90-
auto block = Range<Key>(info.min_key(), info.max_key()).evenDivide(n, i);
90+
auto block = Range<Key>(info.min_key(), info.max_key()).EvenDivide(n, i);
9191
if (block.empty()) continue;
9292
fea_blk_.push_back(std::make_pair(info.id(), block));
9393
}

src/learner/bcd.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ class BCDWorker : public BCDCompNode<V> {
310310
Localizer<Key, double> *localizer = new Localizer<Key, double>();
311311

312312
VLOG(1) << "counting unique key [" << i + 1 << "/" << grp_size << "]";
313-
localizer->countUniqIndex(slot_reader_.index(grp), &uniq_key, &key_cnt);
313+
localizer->CountUniqIndex(slot_reader_.index(grp), &uniq_key, &key_cnt);
314314
VLOG(1) << "finished counting [" << i + 1 << "/" << grp_size << "]";
315315

316316
// push key and count to servers
@@ -333,7 +333,7 @@ class BCDWorker : public BCDCompNode<V> {
333333
pull_msg.callback = [this, grp, localizer, i, grp_size]() mutable {
334334
// localize the training matrix
335335
VLOG(1) << "remap index [" << i + 1 << "/" << grp_size << "]";
336-
auto X = localizer->remapIndex(grp, model_[grp].key, &slot_reader_);
336+
auto X = localizer->RemapIndex(grp, model_[grp].key, &slot_reader_);
337337
delete localizer;
338338
slot_reader_.clear(grp);
339339
if (!X) return;

src/learner/sgd.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,14 @@ class MinibatchReader {
127127
SArray<Key> uniq_key;
128128
SArray<uint8> key_cnt;
129129
Localizer<Key, V> localizer;
130-
localizer.countUniqIndex(data[1], &uniq_key, &key_cnt);
130+
localizer.CountUniqIndex(data[1], &uniq_key, &key_cnt);
131131

132132
// filter keys
133133
filter_.InsertKeys(uniq_key, key_cnt);
134134
key = filter_.QueryKeys(uniq_key, key_freq_);
135135

136136
// remap keys
137-
X = localizer.remapIndex(key);
137+
X = localizer.RemapIndex(key);
138138
return true;
139139
}
140140

src/parameter/parameter.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ class Parameter : public Customer {
2828
int ts = Message::kInvalidTime,
2929
const Timestamps& wait = {},
3030
const Filters& filters = Filters(),
31-
const Range<Key>& key_range = Range<Key>::all()) {
31+
const Range<Key>& key_range = Range<Key>::All()) {
3232
Task req; req.set_request(true);
3333
req.set_key_channel(channel);
3434
if (ts > Message::kInvalidTime) req.set_time(ts);
3535
for (int t : wait) req.add_wait_time(t);
3636
for (const auto& f : filters) *req.add_filter() = f;
37-
key_range.to(req.mutable_key_range());
37+
key_range.To(req.mutable_key_range());
3838
return req;
3939
}
4040

src/system/assigner.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace PS {
88
// assign *node* with proper rank_id, key_range, etc..
99
class NodeAssigner {
1010
public:
11-
NodeAssigner(int num_servers, Range<Key> key_range = Range<Key>::all()) {
11+
NodeAssigner(int num_servers, Range<Key> key_range = Range<Key>::All()) {
1212
num_servers_ = num_servers;
1313
key_range_ = key_range;
1414
}
@@ -18,13 +18,13 @@ class NodeAssigner {
1818
Range<Key> kr = key_range_;
1919
int rank = 0;
2020
if (node->role() == Node::SERVER) {
21-
kr = key_range_.evenDivide(num_servers_, server_rank_);
21+
kr = key_range_.EvenDivide(num_servers_, server_rank_);
2222
rank = server_rank_ ++;
2323
} else if (node->role() == Node::WORKER) {
2424
rank = worker_rank_ ++;
2525
}
2626
node->set_rank(rank);
27-
kr.to(node->mutable_key());
27+
kr.To(node->mutable_key());
2828
}
2929

3030
void remove(const Node& node) {

src/system/message.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ template <typename T> void Message::set_key(const SArray<T>& key) {
7171
if (has_key()) clear_key();
7272
task.set_has_key(true);
7373
this->key = SArray<char>(key);
74-
if (!task.has_key_range()) Range<Key>::all().to(task.mutable_key_range());
74+
if (!task.has_key_range()) Range<Key>::All().To(task.mutable_key_range());
7575
}
7676

7777
template <typename T> void Message::add_value(const SArray<T>& value) {
@@ -115,19 +115,19 @@ template <typename K> void SliceKOFVMessage(
115115
Range<Key> msg_key_range(msg.task.key_range());
116116
for (int i = 0; i < n; ++i) {
117117
if (i == 0) {
118-
K k = (K)msg_key_range.project(krs[0].begin());
118+
K k = (K)msg_key_range.Project(krs[0].begin());
119119
pos[0] = std::lower_bound(key.begin(), key.end(), k) - key.begin();
120120
} else {
121121
CHECK_EQ(krs[i-1].end(), krs[i].begin());
122122
}
123-
K k = (K)msg_key_range.project(krs[i].end());
123+
K k = (K)msg_key_range.Project(krs[i].end());
124124
pos[i+1] = std::lower_bound(key.begin(), key.end(), k) - key.begin();
125125
}
126126

127127
// split the message according to *pos*
128128
for (int i = 0; i < n; ++i) {
129129
Message* ret = CHECK_NOTNULL((*rets)[i]);
130-
if (krs[i].setIntersection(msg_key_range).empty()) {
130+
if (krs[i].SetIntersection(msg_key_range).empty()) {
131131
// the remote node does not maintain this key range. mark this message as
132132
// valid, which will not be sent
133133
ret->valid = false;

src/system/remote_node.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ void RemoteNode::AddGroupNode(RemoteNode* rnode) {
3434
int pos = 0;
3535
Range<Key> kr(rnode->node.key());
3636
while (pos < group.size()) {
37-
if (kr.inLeft(Range<Key>(group[pos]->node.key()))) {
37+
if (kr.InLeft(Range<Key>(group[pos]->node.key()))) {
3838
break;
3939
}
4040
++ pos;

src/util/dense_matrix.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ class DenseMatrix : public Matrix<V> {
3636
virtual MatrixPtr<V> rowBlock(SizeR range) const {
3737
if (colMajor()) CHECK_EQ(range, SizeR(0, rows()));
3838
auto info = info_;
39-
range.to(info.mutable_row());
39+
range.To(info.mutable_row());
4040
info.set_nnz(range.size() * cols());
4141
return MatrixPtr<V>(new DenseMatrix<V>(info, value_.Segment(range*cols())));
4242
}
4343

4444
virtual MatrixPtr<V> colBlock(SizeR range) const {
4545
if (rowMajor()) CHECK_EQ(range, SizeR(0, cols()));
4646
auto info = info_;
47-
range.to(info.mutable_col());
47+
range.To(info.mutable_col());
4848
info.set_nnz(range.size() * rows());
4949
return MatrixPtr<V>(new DenseMatrix<V>(info, value_.Segment(range*rows())));
5050
}
@@ -68,8 +68,8 @@ void DenseMatrix<V>::resize(
6868
size_t rows, size_t cols, size_t nnz, bool row_major) {
6969
info_.set_type(MatrixInfo::DENSE);
7070
info_.set_row_major(row_major);
71-
SizeR(0, rows).to(info_.mutable_row());
72-
SizeR(0, cols).to(info_.mutable_col());
71+
SizeR(0, rows).To(info_.mutable_row());
72+
SizeR(0, cols).To(info_.mutable_col());
7373
nnz = rows * cols;
7474
// CHECK_EQ(nnz, rows*cols);
7575
info_.set_nnz(nnz);

src/util/localizer.h

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,32 +17,32 @@ class Localizer {
1717

1818
// find the unique indeces with their number of occrus in *idx*
1919
void countUniqIndex(const SArray<I>& idx, SArray<I>* uniq_idx) {
20-
countUniqIndex<char>(idx, uniq_idx, nullptr);
20+
CountUniqIndex<char>(idx, uniq_idx, nullptr);
2121
}
22-
void countUniqIndex(const MatrixPtr<V>&mat, SArray<I>* uniq_idx) {
23-
countUniqIndex<char>(mat, uniq_idx, nullptr);
22+
void CountUniqIndex(const MatrixPtr<V>&mat, SArray<I>* uniq_idx) {
23+
CountUniqIndex<char>(mat, uniq_idx, nullptr);
2424
}
25-
template<typename C> void countUniqIndex(
25+
template<typename C> void CountUniqIndex(
2626
const SArray<I>& idx, SArray<I>* uniq_idx, SArray<C>* idx_frq);
2727

28-
template<typename C> void countUniqIndex(
28+
template<typename C> void CountUniqIndex(
2929
const MatrixPtr<V>& mat, SArray<I>* uniq_idx, SArray<C>* idx_frq);
3030

3131
// return a matrix with index mapped: idx_dict[i] -> i. Any index does not exists
3232
// in *idx_dict* is dropped. Assume *idx_dict* is ordered
33-
MatrixPtr<V> remapIndex(int grp_id, const SArray<I>& idx_dict, SlotReader* reader) const;
33+
MatrixPtr<V> RemapIndex(int grp_id, const SArray<I>& idx_dict, SlotReader* reader) const;
3434

3535
// valid only if used countUniqIndex(mat, ...) before
36-
MatrixPtr<V> remapIndex(const SArray<I>& idx_dict);
36+
MatrixPtr<V> RemapIndex(const SArray<I>& idx_dict);
3737

38-
MatrixPtr<V> remapIndex(
38+
MatrixPtr<V> RemapIndex(
3939
const MatrixInfo& info, const SArray<size_t>& offset,
4040
const SArray<I>& index, const SArray<V>& value,
4141
const SArray<I>& idx_dict) const;
4242

43-
void clear() { pair_.clear(); }
43+
void Clear() { pair_.clear(); }
4444

45-
size_t memSize() {
45+
size_t MemSize() {
4646
return pair_.size() * sizeof(Pair) + (mat_ == nullptr ? 0 : mat_->memSize());
4747
}
4848
private:
@@ -56,17 +56,17 @@ class Localizer {
5656

5757
template<typename I, typename V>
5858
template<typename C>
59-
void Localizer<I,V>::countUniqIndex(
59+
void Localizer<I,V>::CountUniqIndex(
6060
const MatrixPtr<V>& mat, SArray<I>* uniq_idx, SArray<C>* idx_frq) {
6161
mat_ = std::static_pointer_cast<SparseMatrix<I,V>>(mat);
62-
countUniqIndex(mat_->index(), uniq_idx, idx_frq);
62+
CountUniqIndex(mat_->index(), uniq_idx, idx_frq);
6363
}
6464

6565

6666

6767
template<typename I, typename V>
6868
template<typename C>
69-
void Localizer<I,V>::countUniqIndex(
69+
void Localizer<I,V>::CountUniqIndex(
7070
const SArray<I>& idx, SArray<I>* uniq_idx, SArray<C>* idx_frq) {
7171
if (idx.empty()) return;
7272
CHECK(uniq_idx);
@@ -108,22 +108,22 @@ void Localizer<I,V>::countUniqIndex(
108108
}
109109

110110
template<typename I, typename V>
111-
MatrixPtr<V> Localizer<I,V>::remapIndex(const SArray<I>& idx_dict) {
111+
MatrixPtr<V> Localizer<I,V>::RemapIndex(const SArray<I>& idx_dict) {
112112
CHECK(mat_);
113-
return remapIndex(mat_->info(), mat_->offset(), mat_->index(), mat_->value(), idx_dict);
113+
return RemapIndex(mat_->info(), mat_->offset(), mat_->index(), mat_->value(), idx_dict);
114114
}
115115

116116
template<typename I, typename V>
117-
MatrixPtr<V> Localizer<I, V>::remapIndex(
117+
MatrixPtr<V> Localizer<I, V>::RemapIndex(
118118
int grp_id, const SArray<I>& idx_dict, SlotReader* reader) const {
119119
SArray<V> val;
120120
auto info = reader->info<V>(grp_id);
121121
if (info.type() == MatrixInfo::SPARSE) val = reader->value<V>(grp_id);
122-
return remapIndex(info, reader->offset(grp_id), reader->index(grp_id), val, idx_dict);
122+
return RemapIndex(info, reader->offset(grp_id), reader->index(grp_id), val, idx_dict);
123123
}
124124

125125
template<typename I, typename V>
126-
MatrixPtr<V> Localizer<I, V>::remapIndex(
126+
MatrixPtr<V> Localizer<I, V>::RemapIndex(
127127
const MatrixInfo& info, const SArray<size_t>& offset,
128128
const SArray<I>& index, const SArray<V>& value,
129129
const SArray<I>& idx_dict) const {
@@ -182,9 +182,9 @@ MatrixPtr<V> Localizer<I, V>::remapIndex(
182182
new_info.set_nnz(new_index.size());
183183
SizeR local(0, idx_dict.size());
184184
if (new_info.row_major()) {
185-
local.to(new_info.mutable_col());
185+
local.To(new_info.mutable_col());
186186
} else {
187-
local.to(new_info.mutable_row());
187+
local.To(new_info.mutable_row());
188188
}
189189
// LL << curr_o << " " << local.end() << " " << curr_j;
190190
return MatrixPtr<V>(new SparseMatrix<uint32, V>(new_info, new_offset, new_index, new_value));

0 commit comments

Comments
 (0)