Skip to content

Commit 8e6ab54

Browse files
authored
feat(udf)!: switch ScalarFunction.evaluate to ColumnarValue API (closes #62) (#64)
1 parent 1a81099 commit 8e6ab54

11 files changed

Lines changed: 712 additions & 142 deletions

File tree

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.datafusion;
21+
22+
import java.util.Objects;
23+
24+
import org.apache.arrow.vector.FieldVector;
25+
import org.apache.arrow.vector.types.pojo.ArrowType;
26+
27+
/**
28+
* The value of a scalar UDF argument or result: either a per-row {@link Array} of length {@code
29+
* rowCount}, or a {@link Scalar} (length-1 vector) that the framework broadcasts.
30+
*
31+
* <p>Mirrors DataFusion's {@code datafusion::logical_expr::ColumnarValue} enum. Use {@link
32+
* #array(FieldVector)} and {@link #scalar(FieldVector)} factories rather than constructing the
33+
* records directly so the length invariants are enforced consistently.
34+
*/
35+
public sealed interface ColumnarValue permits ColumnarValue.Array, ColumnarValue.Scalar {
36+
37+
/** The underlying Arrow vector. For {@link Scalar} this vector has {@code valueCount == 1}. */
38+
FieldVector vector();
39+
40+
/** Convenience: the vector's declared Arrow type. */
41+
default ArrowType dataType() {
42+
return vector().getField().getType();
43+
}
44+
45+
/** Wrap an arbitrary-length vector as an {@link Array}. */
46+
static ColumnarValue array(FieldVector vector) {
47+
return new Array(Objects.requireNonNull(vector, "vector"));
48+
}
49+
50+
/**
51+
* Wrap a length-1 vector as a {@link Scalar}.
52+
*
53+
* @throws IllegalArgumentException if {@code vector.getValueCount() != 1}
54+
*/
55+
static ColumnarValue scalar(FieldVector vector) {
56+
Objects.requireNonNull(vector, "vector");
57+
if (vector.getValueCount() != 1) {
58+
throw new IllegalArgumentException(
59+
"Scalar vector must have valueCount == 1, got " + vector.getValueCount());
60+
}
61+
return new Scalar(vector);
62+
}
63+
64+
/** Per-row Arrow vector of length equal to the batch row count. */
65+
record Array(FieldVector vector) implements ColumnarValue {
66+
public Array {
67+
Objects.requireNonNull(vector, "vector");
68+
}
69+
}
70+
71+
/** Length-1 Arrow vector representing a single value broadcast across all rows. */
72+
record Scalar(FieldVector vector) implements ColumnarValue {
73+
public Scalar {
74+
Objects.requireNonNull(vector, "vector");
75+
if (vector.getValueCount() != 1) {
76+
throw new IllegalArgumentException(
77+
"Scalar vector must have valueCount == 1, got " + vector.getValueCount());
78+
}
79+
}
80+
}
81+
}

core/src/main/java/org/apache/datafusion/ScalarFunction.java

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import java.util.List;
2323

2424
import org.apache.arrow.memory.BufferAllocator;
25-
import org.apache.arrow.vector.FieldVector;
2625
import org.apache.arrow.vector.types.pojo.ArrowType;
2726

2827
/**
@@ -46,7 +45,9 @@ public interface ScalarFunction {
4645
*/
4746
List<ArrowType> argTypes();
4847

49-
/** Declared return type. The returned {@link FieldVector} must have this exact type. */
48+
/**
49+
* Declared return type. The returned {@link ColumnarValue}'s vector must have this exact type.
50+
*/
5051
ArrowType returnType();
5152

5253
/**
@@ -59,14 +60,16 @@ public interface ScalarFunction {
5960
/**
6061
* Compute the function result for one input batch.
6162
*
62-
* @param allocator the {@link BufferAllocator} that MUST be used for any new {@link FieldVector}
63+
* @param allocator the {@link BufferAllocator} that MUST be used for any new Arrow vector
6364
* allocation, including the result. Buffers allocated from other allocators will not survive
6465
* the JNI handoff.
65-
* @param args one {@link FieldVector} per declared argument, all of the same length. These are
66-
* read-only views; the implementation must NOT close them.
67-
* @return a {@link FieldVector} of the declared return type and the same length as the inputs.
68-
* Ownership transfers to the framework on return; the implementation must NOT close the
69-
* returned vector.
66+
* @param args the per-arg {@link ColumnarValue}s and the batch row count. Each {@link
67+
* ColumnarValue} is a read-only view; the implementation must NOT close its underlying
68+
* vector.
69+
* @return a {@link ColumnarValue} of the declared return type. If {@link ColumnarValue.Array},
70+
* the underlying vector must have length {@code args.rowCount()}; if {@link
71+
* ColumnarValue.Scalar}, length 1. Ownership of the returned vector transfers to the
72+
* framework; the implementation must NOT close it.
7073
*/
71-
FieldVector evaluate(BufferAllocator allocator, List<FieldVector> args);
74+
ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args);
7275
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.datafusion;
21+
22+
import java.util.List;
23+
import java.util.Objects;
24+
25+
/**
26+
* Bundle of inputs passed to {@link ScalarFunction#evaluate}: the per-arg {@link ColumnarValue}s
27+
* (in declared order) and the batch row count DataFusion is driving.
28+
*
29+
* <p>Mirrors DataFusion's {@code datafusion::logical_expr::ScalarFunctionArgs}. {@code rowCount} is
30+
* the only channel by which an array-returning UDF without array-typed inputs (all-scalar args, or
31+
* nullary) can size its output. Nullary UDFs that prefer to broadcast a single value should return
32+
* {@link ColumnarValue#scalar(org.apache.arrow.vector.FieldVector) ColumnarValue.scalar(...)}
33+
* instead, which removes the need to consult {@code rowCount}.
34+
*/
35+
public record ScalarFunctionArgs(List<ColumnarValue> args, int rowCount) {
36+
public ScalarFunctionArgs {
37+
args = List.copyOf(Objects.requireNonNull(args, "args"));
38+
if (rowCount < 0) {
39+
throw new IllegalArgumentException("rowCount must be >= 0, got " + rowCount);
40+
}
41+
}
42+
}

core/src/main/java/org/apache/datafusion/internal/JniBridge.java

Lines changed: 72 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
package org.apache.datafusion.internal;
2121

22+
import java.util.ArrayList;
2223
import java.util.List;
2324

2425
import org.apache.arrow.c.ArrowArray;
@@ -27,7 +28,9 @@
2728
import org.apache.arrow.memory.RootAllocator;
2829
import org.apache.arrow.vector.FieldVector;
2930
import org.apache.arrow.vector.VectorSchemaRoot;
31+
import org.apache.datafusion.ColumnarValue;
3032
import org.apache.datafusion.ScalarFunction;
33+
import org.apache.datafusion.ScalarFunctionArgs;
3134

3235
/** Internal trampoline invoked from native code on every UDF call. Not part of the public API. */
3336
public final class JniBridge {
@@ -40,54 +43,100 @@ public final class JniBridge {
4043

4144
private JniBridge() {}
4245

46+
/** argKind byte signalling a {@link ColumnarValue.Array} arg. */
47+
private static final byte KIND_ARRAY = 0;
48+
49+
/** argKind byte signalling a {@link ColumnarValue.Scalar} arg. */
50+
private static final byte KIND_SCALAR = 1;
51+
4352
/**
4453
* Invoke a scalar UDF for one batch. Called from native code; not for application use.
4554
*
46-
* @param impl the registered {@link ScalarFunction} implementation
47-
* @param argsArrayAddr address of a populated {@code FFI_ArrowArray} struct holding the input
48-
* batch as a struct array (one field per UDF argument)
49-
* @param argsSchemaAddr address of the matching {@code FFI_ArrowSchema}
50-
* @param resultArrayAddr address of an empty {@code FFI_ArrowArray} the bridge writes into
51-
* @param resultSchemaAddr address of an empty {@code FFI_ArrowSchema} the bridge writes into
52-
* @param expectedRowCount the row count the result vector must have
55+
* <p>Args arrive split into two struct arrays: {@code arrayArgs*} of length {@code rowCount}
56+
* holding the {@link ColumnarValue.Array} arguments in their relative order, and {@code
57+
* scalarArgs*} of length 1 holding the {@link ColumnarValue.Scalar} arguments. {@code argKinds}
58+
* records the original positional order so the bridge can interleave them back into a single
59+
* {@code List<ColumnarValue>} for the user.
60+
*
61+
* @return {@link #KIND_ARRAY} if the UDF returned an Array, {@link #KIND_SCALAR} if it returned a
62+
* Scalar. The native caller uses this to reconstruct the right {@code ColumnarValue} variant.
5363
*/
54-
public static void invokeScalarUdf(
64+
public static byte invokeScalarUdf(
5565
ScalarFunction impl,
56-
long argsArrayAddr,
57-
long argsSchemaAddr,
66+
long arrayArgsArrayAddr,
67+
long arrayArgsSchemaAddr,
68+
long scalarArgsArrayAddr,
69+
long scalarArgsSchemaAddr,
70+
byte[] argKinds,
5871
long resultArrayAddr,
5972
long resultSchemaAddr,
60-
int expectedRowCount) {
61-
ArrowArray argsArr = ArrowArray.wrap(argsArrayAddr);
62-
ArrowSchema argsSch = ArrowSchema.wrap(argsSchemaAddr);
73+
int rowCount) {
74+
ArrowArray arrayArr = ArrowArray.wrap(arrayArgsArrayAddr);
75+
ArrowSchema arraySch = ArrowSchema.wrap(arrayArgsSchemaAddr);
76+
ArrowArray scalarArr = ArrowArray.wrap(scalarArgsArrayAddr);
77+
ArrowSchema scalarSch = ArrowSchema.wrap(scalarArgsSchemaAddr);
6378
ArrowArray resultArr = ArrowArray.wrap(resultArrayAddr);
6479
ArrowSchema resultSch = ArrowSchema.wrap(resultSchemaAddr);
6580

66-
try (VectorSchemaRoot root = Data.importVectorSchemaRoot(ALLOCATOR, argsArr, argsSch, null)) {
67-
List<FieldVector> argVectors = root.getFieldVectors();
81+
try (VectorSchemaRoot arrayRoot =
82+
Data.importVectorSchemaRoot(ALLOCATOR, arrayArr, arraySch, null);
83+
VectorSchemaRoot scalarRoot =
84+
Data.importVectorSchemaRoot(ALLOCATOR, scalarArr, scalarSch, null)) {
6885

69-
FieldVector result = impl.evaluate(ALLOCATOR, argVectors);
86+
List<FieldVector> arrayFields = arrayRoot.getFieldVectors();
87+
List<FieldVector> scalarFields = scalarRoot.getFieldVectors();
88+
89+
List<ColumnarValue> args = new ArrayList<>(argKinds.length);
90+
int arrayIdx = 0;
91+
int scalarIdx = 0;
92+
for (byte kind : argKinds) {
93+
if (kind == KIND_ARRAY) {
94+
args.add(ColumnarValue.array(arrayFields.get(arrayIdx++)));
95+
} else if (kind == KIND_SCALAR) {
96+
args.add(ColumnarValue.scalar(scalarFields.get(scalarIdx++)));
97+
} else {
98+
throw new IllegalStateException("Unknown argKind byte: " + kind);
99+
}
100+
}
101+
102+
ColumnarValue result = impl.evaluate(ALLOCATOR, new ScalarFunctionArgs(args, rowCount));
70103

71104
if (result == null) {
72105
throw new IllegalStateException("ScalarFunction.evaluate returned null");
73106
}
74-
if (result.getValueCount() != expectedRowCount) {
107+
108+
FieldVector resultVec = result.vector();
109+
byte resultKind;
110+
int expectedLen;
111+
if (result instanceof ColumnarValue.Array) {
112+
resultKind = KIND_ARRAY;
113+
expectedLen = rowCount;
114+
} else {
115+
resultKind = KIND_SCALAR;
116+
expectedLen = 1;
117+
}
118+
119+
if (resultVec.getValueCount() != expectedLen) {
75120
try {
76121
throw new IllegalStateException(
77-
"ScalarFunction.evaluate returned vector with "
78-
+ result.getValueCount()
122+
"ScalarFunction.evaluate returned "
123+
+ (resultKind == KIND_ARRAY ? "Array" : "Scalar")
124+
+ " vector with "
125+
+ resultVec.getValueCount()
79126
+ " rows; expected "
80-
+ expectedRowCount);
127+
+ expectedLen);
81128
} finally {
82-
result.close();
129+
resultVec.close();
83130
}
84131
}
85132

86133
try {
87-
Data.exportVector(ALLOCATOR, result, null, resultArr, resultSch);
134+
Data.exportVector(ALLOCATOR, resultVec, null, resultArr, resultSch);
88135
} finally {
89-
result.close();
136+
resultVec.close();
90137
}
138+
139+
return resultKind;
91140
}
92141
}
93142
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.datafusion;
21+
22+
import static org.junit.jupiter.api.Assertions.assertEquals;
23+
import static org.junit.jupiter.api.Assertions.assertSame;
24+
import static org.junit.jupiter.api.Assertions.assertThrows;
25+
26+
import org.apache.arrow.memory.BufferAllocator;
27+
import org.apache.arrow.memory.RootAllocator;
28+
import org.apache.arrow.vector.IntVector;
29+
import org.apache.arrow.vector.types.pojo.ArrowType;
30+
import org.junit.jupiter.api.Test;
31+
32+
class ColumnarValueTest {
33+
34+
private static final ArrowType INT32 = new ArrowType.Int(32, true);
35+
36+
@Test
37+
void array_factory_returnsArrayVariant() {
38+
try (BufferAllocator allocator = new RootAllocator();
39+
IntVector v = new IntVector("v", allocator)) {
40+
v.allocateNew(3);
41+
v.setValueCount(3);
42+
ColumnarValue cv = ColumnarValue.array(v);
43+
assertSame(v, cv.vector());
44+
assertEquals(INT32, cv.dataType());
45+
}
46+
}
47+
48+
@Test
49+
void scalar_factory_returnsScalarVariant() {
50+
try (BufferAllocator allocator = new RootAllocator();
51+
IntVector v = new IntVector("v", allocator)) {
52+
v.allocateNew(1);
53+
v.set(0, 42);
54+
v.setValueCount(1);
55+
ColumnarValue cv = ColumnarValue.scalar(v);
56+
assertSame(v, cv.vector());
57+
assertEquals(INT32, cv.dataType());
58+
}
59+
}
60+
61+
@Test
62+
void scalar_factory_rejectsNonOneLength() {
63+
try (BufferAllocator allocator = new RootAllocator();
64+
IntVector v = new IntVector("v", allocator)) {
65+
v.allocateNew(2);
66+
v.setValueCount(2);
67+
IllegalArgumentException ex =
68+
assertThrows(IllegalArgumentException.class, () -> ColumnarValue.scalar(v));
69+
assertEquals("Scalar vector must have valueCount == 1, got 2", ex.getMessage());
70+
}
71+
}
72+
73+
@Test
74+
void array_factory_rejectsNull() {
75+
assertThrows(NullPointerException.class, () -> ColumnarValue.array(null));
76+
}
77+
78+
@Test
79+
void scalar_factory_rejectsNull() {
80+
assertThrows(NullPointerException.class, () -> ColumnarValue.scalar(null));
81+
}
82+
}

0 commit comments

Comments
 (0)