@@ -60,7 +60,6 @@ class AsyncSGDServer : public ISGDCompNode {
6060 // model_ = new KVStore<Key, V, AdaGradEntry<V>, SGDState<V>>();
6161 // }
6262 }
63- // CHECK_NOTNULL(model_)->set_state(state);
6463 }
6564
6665 virtual ~AsyncSGDServer () {
@@ -139,7 +138,7 @@ class AsyncSGDServer : public ISGDCompNode {
139138 V z = 0 ;
140139 V sqrt_n = 0 ;
141140
142- void Get (V const * data, void * state) {
141+ void Set ( const V * data, void * state) {
143142 SGDState* st = (SGDState*) state;
144143 // update model
145144 V w_old = w;
@@ -155,11 +154,15 @@ class AsyncSGDServer : public ISGDCompNode {
155154 st->UpdateWeight (w, w_old);
156155 }
157156
158- void Set (V* data, void * state) { *data = w; }
157+ void Get (V* data, void * state) { *data = w; }
159158 };
160159
161160};
162161
162+ /* *
163+ * @brief A worker node
164+ *
165+ */
163166template <typename V>
164167class AsyncSGDWorker : public ISGDCompNode {
165168 public:
@@ -191,6 +194,11 @@ class AsyncSGDWorker : public ISGDCompNode {
191194 }
192195
193196 private:
197+ /* *
198+ * @brief process a data file
199+ *
200+ * @param load
201+ */
194202 void UpdateModel (const Workload& load) {
195203 LOG (INFO) << MyNodeID () << " : accept workload " << load.id ();
196204 VLOG (1 ) << " workload data: " << load.data ().ShortDebugString ();
@@ -213,13 +221,20 @@ class AsyncSGDWorker : public ISGDCompNode {
213221
214222 // pull the weight
215223 auto req = Parameter::Request (id);
224+ for (int i = 0 ; i < conf_.pull_filter_size (); ++i) {
225+ *req.add_filter () = conf_.pull_filter (i);
226+ }
216227 model_.Pull (req, key, [this , id]() { ComputeGradient (id); });
217228 }
218229
219230 while (processed_batch_ < id) { usleep (500 ); }
220231 LOG (INFO) << MyNodeID () << " : finished workload " << load.id ();
221232 }
222233
234+ /* *
235+ *
236+ * @param id minibatch id
237+ */
223238 void ComputeGradient (int id) {
224239 mu_.lock ();
225240 auto Y = data_[id].first ;
@@ -248,19 +263,17 @@ class AsyncSGDWorker : public ISGDCompNode {
248263
249264 // push the gradient
250265 auto req = Parameter::Request (id);
251- // LL << grad;
266+ for (int i = 0 ; i < conf_.push_filter_size (); ++i) {
267+ // add filters
268+ auto filter = conf_.push_filter (i);
269+ if (filter.type () == FilterConfig::KEY_CACHING) {
270+ filter.set_clear_cache_if_done (true );
271+ }
272+ *req.add_filter () = filter;
273+ }
252274 model_.Push (req, model_[id].key , grad);
253275 model_.clear (id);
254276
255-
256- // msg->add_filter(FilterConfig::KEY_CACHING)->set_clear_cache_if_done(true);
257- // int nbytes = conf_.async_sgd().fixing_float_by_nbytes();
258- // if (nbytes) {
259- // auto conf = msg->add_filter(FilterConfig::FIXING_FLOAT)->add_fixed_point();
260- // conf->set_num_bytes(nbytes);
261- // }
262-
263-
264277 ++ processed_batch_;
265278 }
266279
0 commit comments