Skip to content

Commit edfbb28

Browse files
darkcoderrisesmatthewmcneely
authored andcommitted
fix(core): Added tests and change dcomments about stream framework
1 parent 81b3cb9 commit edfbb28

File tree

2 files changed

+159
-18
lines changed

2 files changed

+159
-18
lines changed

stream.go

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,11 @@ type Stream struct {
7575
//
7676
// Note: Calls to KeyToList are concurrent.
7777
KeyToList func(key []byte, itr *Iterator) (*pb.KVList, error)
78-
// UseKeyToListWithThreadId is used to indicate that KeyToListWithThreadId should be used
79-
// instead of KeyToList. This is a new api that can be used to figure out parallelism
80-
// of the stream. Each threadId would be run serially. KeyToList being concurrent makes you
81-
// take care of concurrency in KeyToList. Here threadId could be used to do some things serially.
82-
// Once a thread finishes FinishThread() would be called.
78+
// UseKeyToListWithThreadId indicates that KeyToListWithThreadId should be used
79+
// instead of KeyToList to express stream parallelism. Entries with the same
80+
// threadId run serially; different threadIds may run in parallel. This avoids
81+
// handling concurrency inside KeyToList. We call FinishThread() when a thread
82+
// completes.
8383
UseKeyToListWithThreadId bool
8484
KeyToListWithThreadId func(key []byte, itr *Iterator, threadId int) (*pb.KVList, error)
8585
FinishThread func(threadId int) (*pb.KVList, error)
@@ -190,6 +190,12 @@ func (st *Stream) produceKVs(ctx context.Context, threadId int) error {
190190
_ = outList.Release()
191191
}()
192192

193+
if st.FinishThread == nil {
194+
st.FinishThread = func(threadId int) (*pb.KVList, error) {
195+
return &pb.KVList{}, nil
196+
}
197+
}
198+
193199
iterate := func(kr keyRange) error {
194200
iterOpts := DefaultIteratorOptions
195201
iterOpts.AllVersions = true
@@ -267,19 +273,17 @@ func (st *Stream) produceKVs(ctx context.Context, threadId int) error {
267273
}
268274
}
269275

270-
if st.UseKeyToListWithThreadId {
271-
if kvs, err := st.FinishThread(threadId); err != nil {
272-
return err
273-
} else {
274-
for _, kv := range kvs.Kv {
275-
kv.StreamId = streamId
276-
KVToBuffer(kv, outList)
277-
if outList.LenNoPadding() < batchSize {
278-
continue
279-
}
280-
if err := sendIt(); err != nil {
281-
return err
282-
}
276+
if kvs, err := st.FinishThread(threadId); err != nil {
277+
return err
278+
} else {
279+
for _, kv := range kvs.Kv {
280+
kv.StreamId = streamId
281+
KVToBuffer(kv, outList)
282+
if outList.LenNoPadding() < batchSize {
283+
continue
284+
}
285+
if err := sendIt(); err != nil {
286+
return err
283287
}
284288
}
285289
}

stream_test.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,87 @@ func keyWithPrefix(prefix string, k int) []byte {
2626
return []byte(fmt.Sprintf("%s-%d", prefix, k))
2727
}
2828

29+
func TestStreamKeyToListWithThreadId_MapReduceWordFreq(t *testing.T) {
30+
dir, err := os.MkdirTemp("", "badger-test")
31+
require.NoError(t, err)
32+
defer removeDir(dir)
33+
34+
db, err := OpenManaged(DefaultOptions(dir))
35+
require.NoError(t, err)
36+
37+
// Seed dataset with words: beta, gamma, alpha
38+
// Mapping: i%3==0 -> beta (33), i%3==1 -> gamma (34), i%3==2 -> alpha (33) per 1..100
39+
words := []string{"beta", "gamma", "alpha"}
40+
for _, prefix := range []string{"p0", "p1", "p2"} {
41+
txn := db.NewTransactionAt(math.MaxUint64, true)
42+
for i := 1; i <= 100; i++ {
43+
w := words[i%3]
44+
require.NoError(t, txn.SetEntry(NewEntry(keyWithPrefix(prefix, i), []byte(w))))
45+
}
46+
require.NoError(t, txn.CommitAt(5, nil))
47+
}
48+
49+
stream := db.NewStreamAt(math.MaxUint64)
50+
stream.LogPrefix = "Testing"
51+
stream.NumGo = 16
52+
stream.UseKeyToListWithThreadId = true
53+
54+
// Accumulate per-thread word counts to be emitted in FinishThread.
55+
// Use a slice indexed by threadId. Each threadId only writes to its own slot.
56+
toCounts := make([]map[string]int, stream.NumGo)
57+
58+
stream.KeyToListWithThreadId = func(key []byte, itr *Iterator, threadId int) (*bpb.KVList, error) {
59+
item := itr.Item()
60+
val, err := item.ValueCopy(nil)
61+
if err != nil {
62+
return nil, err
63+
}
64+
w := string(val)
65+
if toCounts[threadId] == nil {
66+
toCounts[threadId] = make(map[string]int)
67+
}
68+
toCounts[threadId][w]++
69+
// Return nothing here; flushing happens in FinishThread.
70+
return &bpb.KVList{}, nil
71+
}
72+
73+
stream.FinishThread = func(threadId int) (*bpb.KVList, error) {
74+
counts := toCounts[threadId]
75+
if len(counts) == 0 {
76+
return &bpb.KVList{}, nil
77+
}
78+
out := make([]*bpb.KV, 0, len(counts))
79+
for w, c := range counts {
80+
out = append(out, &bpb.KV{Key: []byte("wc-" + w), Value: []byte(fmt.Sprintf("%d", c))})
81+
}
82+
return &bpb.KVList{Kv: out}, nil
83+
}
84+
85+
// Use a sink, but expect no data as KeyToListWithThreadId returns nothing and FinishThread returns empty list
86+
c := &collector{}
87+
stream.Send = c.Send
88+
89+
require.NoError(t, stream.Orchestrate(ctxb))
90+
91+
// Reduce: aggregate per-word totals from partial outputs
92+
totals := map[string]int{}
93+
for _, kv := range c.kv {
94+
if strings.HasPrefix(string(kv.Key), "wc-") {
95+
w := strings.TrimPrefix(string(kv.Key), "wc-")
96+
n, err := strconv.Atoi(string(kv.Value))
97+
require.NoError(t, err)
98+
totals[w] += n
99+
}
100+
}
101+
102+
// Expected totals across 3 prefixes x 100 items with distribution defined above
103+
require.Equal(t, 99, totals["alpha"]) // 33 per prefix * 3
104+
require.Equal(t, 99, totals["beta"]) // 33 per prefix * 3
105+
require.Equal(t, 102, totals["gamma"]) // 34 per prefix * 3
106+
107+
require.NoError(t, db.Close())
108+
}
109+
29110
func keyToInt(k []byte) (string, int) {
30111
splits := strings.Split(string(k), "-")
31112
key, err := strconv.Atoi(splits[1])
@@ -160,6 +241,62 @@ func TestStream(t *testing.T) {
160241
require.NoError(t, db.Close())
161242
}
162243

244+
func TestStreamKeyToListWithThreadId(t *testing.T) {
245+
dir, err := os.MkdirTemp("", "badger-test")
246+
require.NoError(t, err)
247+
defer removeDir(dir)
248+
249+
db, err := OpenManaged(DefaultOptions(dir))
250+
require.NoError(t, err)
251+
252+
// Seed small dataset
253+
for _, prefix := range []string{"p0", "p1", "p2"} {
254+
txn := db.NewTransactionAt(math.MaxUint64, true)
255+
for i := 1; i <= 100; i++ {
256+
require.NoError(t, txn.SetEntry(NewEntry(keyWithPrefix(prefix, i), value(i))))
257+
}
258+
require.NoError(t, txn.CommitAt(5, nil))
259+
}
260+
261+
stream := db.NewStreamAt(math.MaxUint64)
262+
stream.LogPrefix = "Testing"
263+
stream.NumGo = 4 // fix number of threads for deterministic assertions
264+
stream.UseKeyToListWithThreadId = true
265+
266+
// Ensure threadId passed to KeyToListWithThreadId matches iterator's ThreadId
267+
stream.KeyToListWithThreadId = func(key []byte, itr *Iterator, threadId int) (*bpb.KVList, error) {
268+
require.Equal(t, threadId, itr.ThreadId)
269+
return stream.ToList(key, itr)
270+
}
271+
272+
// Emit a per-thread marker to verify FinishThread is invoked once per thread
273+
stream.FinishThread = func(threadId int) (*bpb.KVList, error) {
274+
kv := &bpb.KV{Key: []byte(fmt.Sprintf("done-%d", threadId))}
275+
return &bpb.KVList{Kv: []*bpb.KV{kv}}, nil
276+
}
277+
278+
c := &collector{}
279+
stream.Send = c.Send
280+
281+
err = stream.Orchestrate(ctxb)
282+
require.NoError(t, err)
283+
284+
// Verify presence of FinishThread markers and totals
285+
markers := make(map[string]struct{})
286+
for _, kv := range c.kv {
287+
if strings.HasPrefix(string(kv.Key), "done-") {
288+
markers[string(kv.Key)] = struct{}{}
289+
}
290+
}
291+
// Total should be data KVs plus marker count
292+
require.Equal(t, 300+len(markers), len(c.kv))
293+
// We expect at least one marker and at most NumGo markers (ranges may be fewer than NumGo)
294+
require.GreaterOrEqual(t, len(markers), 1)
295+
require.LessOrEqual(t, len(markers), stream.NumGo)
296+
297+
require.NoError(t, db.Close())
298+
}
299+
163300
func TestStreamMaxSize(t *testing.T) {
164301
if !*manual {
165302
t.Skip("Skipping test meant to be run manually.")

0 commit comments

Comments
 (0)