22#include " ps.h"
33#include " parameter/parameter.h"
44#include " util/parallel_ordered_match.h"
5+ #include " filter/frequency_filter.h"
56namespace PS {
67// TODO doc, and filter
78/* *
@@ -92,7 +93,7 @@ class KVVector : public Parameter {
9293 SliceKOFVMessage<K>(request, krs, msgs);
9394 }
9495 virtual void GetValue (Message* msg);
95- virtual void SetValue (Message* msg);
96+ virtual void SetValue (const Message* msg);
9697 using Parameter::Push;
9798 using Parameter::Pull;
9899 protected:
@@ -103,13 +104,31 @@ class KVVector : public Parameter {
103104 std::unordered_map<int , Buffer> buffer_; // <channel, Buffer>
104105
105106 std::mutex mu_; // protect the structure of data_ and buffer_
107+
108+ // filter tail keys
109+ FreqencyFilter<Key, uint8> freq_filter_;
106110};
107111
108112template <typename K, typename V>
109- void KVVector<K,V>::SetValue(Message* msg) {
113+ void KVVector<K,V>::SetValue(const Message* msg) {
110114 // do check
111115 SArray<K> recv_key (msg->key );
112116 if (recv_key.empty ()) return ;
117+
118+ // filter request
119+ if (msg->task .param ().has_tail_filter () && msg->task .request ()) {
120+ const auto & tail = msg->task .param ().tail_filter ();
121+ CHECK (tail.insert_count ());
122+ CHECK_EQ (msg->value .size (), 1 );
123+ SArray<uint8> count (msg->value [0 ]);
124+ CHECK_EQ (count.size (), recv_key.size ());
125+ if (freq_filter_.Empty ()) {
126+ freq_filter_.Resize (tail.countmin_n (), tail.countmin_k ());
127+ }
128+ freq_filter_.InsertKeys (recv_key, count);
129+ return ;
130+ }
131+
113132 int chl = msg->task .key_channel ();
114133 mu_.lock ();
115134 auto & kv = data_[chl];
@@ -167,6 +186,15 @@ void KVVector<K,V>::GetValue(Message* msg) {
167186 // do check
168187 SArray<K> recv_key (msg->key );
169188 if (recv_key.empty ()) return ;
189+
190+ // filter request
191+ if (msg->task .param ().has_tail_filter ()) {
192+ const auto & tail = msg->task .param ().tail_filter ();
193+ CHECK (tail.has_freq_threshold ());
194+ msg->key = freq_filter_.QueryKeys (recv_key, tail.freq_threshold ());
195+ return ;
196+ }
197+
170198 Lock l (mu_);
171199 auto & kv = data_[msg->task .key_channel ()];
172200 CHECK_EQ (kv.key .size () * k_, kv.value .size ());
@@ -195,9 +223,7 @@ template <typename K, typename V>
195223int KVVector<K,V>::Pull(const Task& request, const SArray<K>& keys,
196224 Message::Callback callback) {
197225 Lock l (mu_);
198- int chl = request.key_channel ();
199- if (keys.empty () ) CHECK_EQ (data_.count (chl), 1 ) << " empty channel " << chl;
200- auto & kv = data_[chl];
226+ auto & kv = data_[request.key_channel ()];
201227 if (!keys.empty ()) kv.key = keys;
202228 kv.value = SArray<V>(kv.key .size ()*k_, 0 );
203229
0 commit comments