Skip to content

Commit 08303e6

Browse files
fix count aggregate function (evaluate only not null fields) (#453)
Co-authored-by: Sam Willis <sam.willis@gmail.com>
1 parent bafeaa1 commit 08303e6

File tree

5 files changed

+123
-5
lines changed

5 files changed

+123
-5
lines changed

.changeset/curly-doors-deny.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
"@tanstack/db-ivm": patch
3+
"@tanstack/db": patch
4+
---
5+
6+
fix count aggregate function (evaluate only not null field values like SQL count)

packages/db-ivm/src/operators/groupBy.ts

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,16 @@ export function sum<T>(
165165
/**
166166
* Creates a count aggregate function
167167
*/
168-
export function count<T>(): AggregateFunction<T, number, number> {
168+
export function count<T>(
169+
valueExtractor: (value: T) => any = (v) => v
170+
): AggregateFunction<T, number, number> {
169171
return {
170-
preMap: () => 1,
172+
// Count only not-null values (the `== null` comparison gives true for both null and undefined)
173+
preMap: (data: T) => (valueExtractor(data) == null ? 0 : 1),
171174
reduce: (values: Array<[number, number]>) => {
172175
let totalCount = 0
173-
for (const [_, multiplicity] of values) {
174-
totalCount += multiplicity
176+
for (const [nullMultiplier, multiplicity] of values) {
177+
totalCount += nullMultiplier * multiplicity
175178
}
176179
return totalCount
177180
},

packages/db-ivm/tests/operators/groupBy.test.ts

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,76 @@ describe(`Operators`, () => {
304304
expect(latestMessage.getInner()).toEqual(expectedDeleteResult)
305305
})
306306

307+
test(`with count (only not-null values)`, () => {
308+
const graph = new D2()
309+
const input = graph.newInput<{
310+
category: string
311+
amount: number | null
312+
}>()
313+
let latestMessage: any = null
314+
const messages: Array<MultiSet<any>> = []
315+
316+
input.pipe(
317+
groupBy(
318+
(data) => ({
319+
category: data.category,
320+
}),
321+
{
322+
countNotNull: count((data) => data.amount),
323+
count: count(),
324+
}
325+
),
326+
output((message) => {
327+
latestMessage = message
328+
messages.push(message)
329+
})
330+
)
331+
332+
graph.finalize()
333+
334+
// Initial data
335+
input.sendData(
336+
new MultiSet([
337+
[{ category: `A`, amount: 10 }, 1],
338+
[{ category: `B`, amount: 10 }, 1],
339+
[{ category: `A`, amount: null }, 1],
340+
[{ category: `B`, amount: null }, 1],
341+
])
342+
)
343+
344+
graph.run()
345+
346+
// Verify we have the latest message
347+
expect(latestMessage).not.toBeNull()
348+
349+
const expectedResult = [
350+
[
351+
[
352+
`{"category":"A"}`,
353+
{
354+
category: `A`,
355+
countNotNull: 1,
356+
count: 2,
357+
},
358+
],
359+
1,
360+
],
361+
[
362+
[
363+
`{"category":"B"}`,
364+
{
365+
category: `B`,
366+
countNotNull: 1,
367+
count: 2,
368+
},
369+
],
370+
1,
371+
],
372+
]
373+
374+
expect(latestMessage.getInner()).toEqual(expectedResult)
375+
})
376+
307377
test(`with avg and count aggregates`, () => {
308378
const graph = new D2()
309379
const input = graph.newInput<{

packages/db/src/query/compiler/group-by.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,12 +349,17 @@ function getAggregateFunction(aggExpr: Aggregate) {
349349
return typeof value === `number` ? value : value != null ? Number(value) : 0
350350
}
351351

352+
// Create a raw value extractor function for the expression to aggregate
353+
const rawValueExtractor = ([, namespacedRow]: [string, NamespacedRow]) => {
354+
return compiledExpr(namespacedRow)
355+
}
356+
352357
// Return the appropriate aggregate function
353358
switch (aggExpr.name.toLowerCase()) {
354359
case `sum`:
355360
return sum(valueExtractor)
356361
case `count`:
357-
return count() // count() doesn't need a value extractor
362+
return count(rawValueExtractor)
358363
case `avg`:
359364
return avg(valueExtractor)
360365
case `min`:

packages/db/tests/query/group-by.test.ts

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,40 @@ function createGroupByTests(autoIndex: `off` | `eager`): void {
180180
expect(customer3?.max_amount).toBe(250)
181181
})
182182

183+
test(`group by customer_id with count aggregation (not null only)`, () => {
184+
const customerSummary = createLiveQueryCollection({
185+
startSync: true,
186+
query: (q) =>
187+
q
188+
.from({ orders: ordersCollection })
189+
.groupBy(({ orders }) => orders.customer_id)
190+
.select(({ orders }) => ({
191+
customer_id: orders.customer_id,
192+
total: count(orders.id),
193+
onlyNotNull: count(orders.sales_rep_id),
194+
})),
195+
})
196+
197+
expect(customerSummary.size).toBe(3) // 3 customers
198+
199+
// Customer 1: orders 1, 2, 7 (total: 3, onlyNotNull: 3)
200+
const customer1 = customerSummary.get(1)
201+
expect(customer1?.total).toBe(3)
202+
expect(customer1?.onlyNotNull).toBe(3)
203+
204+
// Customer 2: orders 3, 4 (total: 2, onlyNotNull: 2)
205+
const customer2 = customerSummary.get(2)
206+
expect(customer2).toBeDefined()
207+
expect(customer2?.total).toBe(2)
208+
expect(customer2?.onlyNotNull).toBe(2)
209+
210+
// Customer 3: orders 5, 6 (total: 2, onlyNotNull: 1)
211+
const customer3 = customerSummary.get(3)
212+
expect(customer3).toBeDefined()
213+
expect(customer3?.total).toBe(2)
214+
expect(customer3?.onlyNotNull).toBe(1)
215+
})
216+
183217
test(`group by status`, () => {
184218
const statusSummary = createLiveQueryCollection({
185219
startSync: true,

0 commit comments

Comments
 (0)