From 0cb4b5a8c9677c8dcc1b63b59f36bff12cab8093 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 20 Nov 2022 02:52:38 -0500 Subject: [PATCH 01/16] ARROW-17676: [C++] [Python] User-defined tabular functions --- python/pyarrow/_compute.pyx | 128 ++++++++++++- python/pyarrow/_dataset.pyx | 45 +---- python/pyarrow/compute.py | 3 + python/pyarrow/includes/libarrow.pxd | 20 +- python/pyarrow/includes/libarrow_dataset.pxd | 7 - python/pyarrow/src/arrow/python/udf.cc | 181 +++++++++++++++++-- python/pyarrow/src/arrow/python/udf.h | 16 +- python/pyarrow/tests/test_udf.py | 40 ++++ 8 files changed, 365 insertions(+), 75 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 659af0afba37..9b482d4ef240 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -36,6 +36,18 @@ import inspect import numpy as np +def _forbid_instantiation(klass, subclasses_instead=True): + msg = '{} is an abstract class thus cannot be initialized.'.format( + klass.__name__ + ) + if subclasses_instead: + subclasses = [cls.__name__ for cls in klass.__subclasses__] + msg += ' Use one of the subclasses instead: {}'.format( + ', '.join(subclasses) + ) + raise TypeError(msg) + + cdef wrap_scalar_function(const shared_ptr[CFunction]& sp_func): """ Wrap a C++ scalar Function in a ScalarFunction object. @@ -141,6 +153,38 @@ cdef wrap_hash_aggregate_kernel(const CHashAggregateKernel* c_kernel): return kernel +cdef class RecordBatchIterator(_Weakrefable): + """An iterator over a sequence of record batches.""" + cdef: + # An object that must be kept alive with the iterator. + object iterator_owner + # Iterator is a non-POD type and Cython uses offsetof, leading + # to a compiler warning unless wrapped like so + shared_ptr[CRecordBatchIterator] iterator + + def __init__(self): + _forbid_instantiation(self.__class__, subclasses_instead=False) + + @staticmethod + cdef wrap(object owner, CRecordBatchIterator iterator): + cdef RecordBatchIterator self = \ + RecordBatchIterator.__new__(RecordBatchIterator) + self.iterator_owner = owner + self.iterator = make_shared[CRecordBatchIterator](move(iterator)) + return self + + def __iter__(self): + return self + + def __next__(self): + cdef shared_ptr[CRecordBatch] record_batch + with nogil: + record_batch = GetResultValue(move(self.iterator.get().Next())) + if record_batch == NULL: + raise StopIteration + return pyarrow_wrap_batch(record_batch) + + cdef class Kernel(_Weakrefable): """ A kernel object. @@ -2558,8 +2602,55 @@ def _get_scalar_udf_context(memory_pool, batch_length): return context -def register_scalar_function(func, function_name, function_doc, in_types, - out_type): +def udf_result_from_record_batch(record_batch): + cdef: + shared_ptr[CRecordBatch] c_record_batch + CResult[shared_ptr[CArray]] c_res_array + + c_record_batch = pyarrow_unwrap_batch(record_batch) + c_res_array = deref(c_record_batch).ToStructArray() + return pyarrow_wrap_array(GetResultValue(c_res_array)) + + +ctypedef CStatus (*CRegisterScalarLikeFunction)(PyObject* function, + function[CallbackUdf] wrapper, const CScalarUdfOptions& options, + CFunctionRegistry* registry) + +cdef class RegisterScalarLikeFunction(_Weakrefable): + cdef CRegisterScalarLikeFunction register_func + + cdef void init(self, const CRegisterScalarLikeFunction register_func): + self.register_func = register_func + + +cdef GetRegisterScalarFunction(): + cdef RegisterScalarLikeFunction reg = RegisterScalarLikeFunction.__new__(RegisterScalarLikeFunction) + reg.register_func = RegisterScalarFunction + return reg + + +cdef GetRegisterTabularFunction(): + cdef RegisterScalarLikeFunction reg = RegisterScalarLikeFunction.__new__(RegisterScalarLikeFunction) + reg.register_func = RegisterTabularFunction + return reg + + +def register_scalar_function(func, function_name, function_doc, in_types, out_type, + func_registry=None): + return register_scalar_like_function(GetRegisterScalarFunction(), + func, function_name, function_doc, in_types, + out_type, func_registry) + + +def register_tabular_function(func, function_name, function_doc, in_types, out_type, + func_registry=None): + return register_scalar_like_function(GetRegisterTabularFunction(), + func, function_name, function_doc, in_types, + out_type, func_registry) + + +def register_scalar_like_function(register_func, func, function_name, function_doc, in_types, + out_type, func_registry=None): """ Register a user-defined scalar function. @@ -2574,6 +2665,10 @@ def register_scalar_function(func, function_name, function_doc, in_types, Parameters ---------- + register_func: object + An object holding a CRegisterScalarLikeFunction in + a "register_func" attribute, such as: + GetRegisterScalarFunction, GetRegisterTabularFunction func : callable A callable implementing the user-defined function. The first argument is the context argument of type @@ -2600,6 +2695,8 @@ def register_scalar_function(func, function_name, function_doc, in_types, arity. out_type : DataType Output type of the function. + func_registry : FunctionRegistry + Optional function registry to use instead of the default global one. Examples -------- @@ -2630,6 +2727,7 @@ def register_scalar_function(func, function_name, function_doc, in_types, ] """ cdef: + CRegisterScalarLikeFunction c_register_func c_string c_func_name CArity c_arity CFunctionDoc c_func_doc @@ -2637,6 +2735,7 @@ def register_scalar_function(func, function_name, function_doc, in_types, PyObject* c_function shared_ptr[CDataType] c_out_type CScalarUdfOptions c_options + CFunctionRegistry* c_func_registry if callable(func): c_function = func @@ -2678,5 +2777,26 @@ def register_scalar_function(func, function_name, function_doc, in_types, c_options.input_types = c_in_types c_options.output_type = c_out_type - check_status(RegisterScalarFunction(c_function, - &_scalar_udf_callback, c_options)) + if func_registry is None: + c_func_registry = NULL + else: + c_func_registry = (func_registry).registry + + c_register_func = (register_func).register_func + + check_status(c_register_func(c_function, + &_scalar_udf_callback, + c_options, c_func_registry)) + +def get_record_batches_from_tabular_function(function_name, func_registry=None): + cdef: + c_string c_func_name + CFunctionRegistry* c_func_registry + + c_func_name = tobytes(function_name) + if func_registry is None: + c_func_registry = NULL + else: + c_func_registry = (func_registry).registry + + return RecordBatchIterator.wrap(None, move(GetResultValue(GetRecordBatchesFromTabularFunction(c_func_name, c_func_registry)))) diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index 5d4cf95087de..02452c6af3a7 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -32,24 +32,13 @@ from pyarrow.lib cimport * from pyarrow.lib import ArrowTypeError, frombytes, tobytes, _pc from pyarrow.includes.libarrow_dataset cimport * from pyarrow._compute cimport Expression, _bind +from pyarrow._compute import _forbid_instantiation from pyarrow._fs cimport FileSystem, FileInfo, FileSelector from pyarrow._csv cimport ( ConvertOptions, ParseOptions, ReadOptions, WriteOptions) from pyarrow.util import _is_iterable, _is_path_like, _stringify_path -def _forbid_instantiation(klass, subclasses_instead=True): - msg = '{} is an abstract class thus cannot be initialized.'.format( - klass.__name__ - ) - if subclasses_instead: - subclasses = [cls.__name__ for cls in klass.__subclasses__] - msg += ' Use one of the subclasses instead: {}'.format( - ', '.join(subclasses) - ) - raise TypeError(msg) - - _orc_fileformat = None _orc_imported = False @@ -2167,38 +2156,6 @@ cdef class UnionDatasetFactory(DatasetFactory): self.union_factory = sp.get() -cdef class RecordBatchIterator(_Weakrefable): - """An iterator over a sequence of record batches.""" - cdef: - # An object that must be kept alive with the iterator. - object iterator_owner - # Iterator is a non-POD type and Cython uses offsetof, leading - # to a compiler warning unless wrapped like so - shared_ptr[CRecordBatchIterator] iterator - - def __init__(self): - _forbid_instantiation(self.__class__, subclasses_instead=False) - - @staticmethod - cdef wrap(object owner, CRecordBatchIterator iterator): - cdef RecordBatchIterator self = \ - RecordBatchIterator.__new__(RecordBatchIterator) - self.iterator_owner = owner - self.iterator = make_shared[CRecordBatchIterator](move(iterator)) - return self - - def __iter__(self): - return self - - def __next__(self): - cdef shared_ptr[CRecordBatch] record_batch - with nogil: - record_batch = GetResultValue(move(self.iterator.get().Next())) - if record_batch == NULL: - raise StopIteration - return pyarrow_wrap_batch(record_batch) - - class TaggedRecordBatch(collections.namedtuple( "TaggedRecordBatch", ["record_batch", "fragment"])): """ diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 265d75f6f6b0..62b4644307d9 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -80,7 +80,10 @@ list_functions, _group_by, # Udf + get_record_batches_from_tabular_function, register_scalar_function, + register_tabular_function, + udf_result_from_record_batch, ScalarUdfContext, # Expressions Expression, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index bc82a420897d..a801878f7a32 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -796,6 +796,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: const shared_ptr[CSchema]& schema, int64_t num_rows, const vector[shared_ptr[CArray]]& columns) + CResult[shared_ptr[CStructArray]] ToStructArray() const + @staticmethod CResult[shared_ptr[CRecordBatch]] FromStructArray( const shared_ptr[CArray]& array) @@ -2771,6 +2773,14 @@ cdef extern from "arrow/util/byte_size.h" namespace "arrow::util" nogil: ctypedef PyObject* CallbackUdf(object user_function, const CScalarUdfContext& context, object inputs) + +cdef extern from "arrow/api.h" namespace "arrow" nogil: + + cdef cppclass CRecordBatchIterator "arrow::RecordBatchIterator"( + CIterator[shared_ptr[CRecordBatch]]): + pass + + cdef extern from "arrow/python/udf.h" namespace "arrow::py": cdef cppclass CScalarUdfContext" arrow::py::ScalarUdfContext": CMemoryPool *pool @@ -2784,4 +2794,12 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py": shared_ptr[CDataType] output_type CStatus RegisterScalarFunction(PyObject* function, - function[CallbackUdf] wrapper, const CScalarUdfOptions& options) + function[CallbackUdf] wrapper, const CScalarUdfOptions& options, + CFunctionRegistry* registry) + + CStatus RegisterTabularFunction(PyObject* function, + function[CallbackUdf] wrapper, const CScalarUdfOptions& options, + CFunctionRegistry* registry) + + CResult[CRecordBatchIterator] GetRecordBatchesFromTabularFunction( + const c_string& func_name, CFunctionRegistry* registry) diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd index e69c3cbcaf29..bd653288c72e 100644 --- a/python/pyarrow/includes/libarrow_dataset.pxd +++ b/python/pyarrow/includes/libarrow_dataset.pxd @@ -25,13 +25,6 @@ from pyarrow.includes.libarrow cimport * from pyarrow.includes.libarrow_fs cimport * -cdef extern from "arrow/api.h" namespace "arrow" nogil: - - cdef cppclass CRecordBatchIterator "arrow::RecordBatchIterator"( - CIterator[shared_ptr[CRecordBatch]]): - pass - - cdef extern from "arrow/dataset/plan.h" namespace "arrow::dataset::internal" nogil: cdef void Initialize() diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 81bf47c0ade0..2f18fabb2116 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -17,35 +17,92 @@ #include "arrow/python/udf.h" #include "arrow/compute/function.h" +#include "arrow/compute/kernel.h" #include "arrow/python/common.h" +#include "arrow/util/checked_cast.h" namespace arrow { using compute::ExecResult; using compute::ExecSpan; +using compute::Function; +using compute::OutputType; +using compute::ScalarFunction; +using compute::ScalarKernel; +using internal::checked_cast; namespace py { namespace { -struct PythonUdf : public compute::KernelState { - ScalarUdfWrapperCallback cb; +struct PythonScalarUdfKernelState : public compute::KernelState { + explicit PythonScalarUdfKernelState(std::shared_ptr function) + : function(function) {} + std::shared_ptr function; - std::shared_ptr output_type; +}; - PythonUdf(ScalarUdfWrapperCallback cb, std::shared_ptr function, - const std::shared_ptr& output_type) - : cb(cb), function(function), output_type(output_type) {} +struct PythonScalarUdfKernelInit { + explicit PythonScalarUdfKernelInit(std::shared_ptr function) + : function(function) {} // function needs to be destroyed at process exit // and Python may no longer be initialized. - ~PythonUdf() { + ~PythonScalarUdfKernelInit() { if (_Py_IsFinalizing()) { function->detach(); } } + Result> operator()( + compute::KernelContext*, const compute::KernelInitArgs&) { + return std::make_unique(function); + } + + std::shared_ptr function; +}; + +struct PythonTableUdfKernelInit { + PythonTableUdfKernelInit(std::shared_ptr function_maker, + ScalarUdfWrapperCallback cb) + : function_maker(function_maker), cb(cb) { + Py_INCREF(function_maker->obj()); + } + + Result> operator()( + compute::KernelContext* ctx, const compute::KernelInitArgs&) { + ScalarUdfContext udf_context{ctx->memory_pool(), /*batch_length=*/0}; + std::unique_ptr function; + RETURN_NOT_OK(SafeCallIntoPython([this, &udf_context, &function] { + OwnedRef empty_tuple(PyTuple_New(0)); + function = std::make_unique( + cb(function_maker->obj(), udf_context, empty_tuple.obj())); + RETURN_NOT_OK(CheckPyError()); + return Status::OK(); + })); + if (!PyCallable_Check(function->obj())) { + return Status::TypeError("Expected a callable Python object."); + } + return std::make_unique( + std::move(function)); + } + + std::shared_ptr function_maker; + ScalarUdfWrapperCallback cb; +}; + +struct PythonUdf : public PythonScalarUdfKernelState { + PythonUdf(std::shared_ptr function, ScalarUdfWrapperCallback cb, + compute::OutputType output_type) + : PythonScalarUdfKernelState(function), cb(cb), output_type(output_type) {} + + ScalarUdfWrapperCallback cb; + compute::OutputType output_type; + Status Exec(compute::KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + auto state = + ::arrow::internal::checked_cast(ctx->state()); + std::shared_ptr& function = state->function; const int num_args = batch.num_values(); ScalarUdfContext udf_context{ctx->memory_pool(), batch.length}; @@ -68,8 +125,12 @@ struct PythonUdf : public compute::KernelState { // unwrapping the output for expected output type if (is_array(result.obj())) { ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_array(result.obj())); - if (!output_type->Equals(*val->type())) { - return Status::TypeError("Expected output datatype ", output_type->ToString(), + ARROW_ASSIGN_OR_RAISE(TypeHolder type, output_type.Resolve(ctx, batch.GetTypes())); + if (type.type == NULLPTR) { + return Status::TypeError("expected output datatype is null"); + } + if (*type.type != *val->type()) { + return Status::TypeError("Expected output datatype ", type.type->ToString(), ", but function returned datatype ", val->type()->ToString()); } @@ -89,10 +150,11 @@ Status PythonUdfExec(compute::KernelContext* ctx, const ExecSpan& batch, return SafeCallIntoPython([&]() -> Status { return udf->Exec(ctx, batch, out); }); } -} // namespace - -Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback wrapper, - const ScalarUdfOptions& options) { +Status RegisterScalarLikeFunction(PyObject* user_function, + compute::KernelInit kernel_init, + ScalarUdfWrapperCallback wrapper, + const ScalarUdfOptions& options, + compute::FunctionRegistry* registry) { if (!PyCallable_Check(user_function)) { return Status::TypeError("Expected a callable Python object."); } @@ -105,21 +167,108 @@ Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback } compute::OutputType output_type(options.output_type); auto udf_data = std::make_shared( - wrapper, std::make_shared(user_function), options.output_type); + std::make_shared(user_function), wrapper, options.output_type); compute::ScalarKernel kernel( compute::KernelSignature::Make(std::move(input_types), std::move(output_type), options.arity.is_varargs), - PythonUdfExec); + PythonUdfExec, kernel_init); kernel.data = std::move(udf_data); kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE; kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE; RETURN_NOT_OK(scalar_func->AddKernel(std::move(kernel))); - auto registry = compute::GetFunctionRegistry(); + if (registry == NULLPTR) { + registry = compute::GetFunctionRegistry(); + } RETURN_NOT_OK(registry->AddFunction(std::move(scalar_func))); return Status::OK(); } +} // namespace + +Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback wrapper, + const ScalarUdfOptions& options, + compute::FunctionRegistry* registry) { + return RegisterScalarLikeFunction( + user_function, + PythonScalarUdfKernelInit{std::make_shared(user_function)}, wrapper, + options, registry); +} + +Status RegisterTabularFunction(PyObject* user_function, ScalarUdfWrapperCallback wrapper, + const ScalarUdfOptions& options, + compute::FunctionRegistry* registry) { + if (options.arity.num_args != 0 || options.arity.is_varargs) { + return Status::Invalid("tabular function must have no arguments"); + } + return RegisterScalarLikeFunction( + user_function, + PythonTableUdfKernelInit{std::make_shared(user_function), wrapper}, + wrapper, options, registry); +} + +namespace { + +Result> RecordBatchFromArray( + std::shared_ptr schema, std::shared_ptr array) { + auto& data = const_cast&>(array->data()); + if (data->child_data.size() != static_cast(schema->num_fields())) { + return Status::Invalid("UDF result with shape not conforming to schema"); + } + return RecordBatch::Make(std::move(schema), data->length, std::move(data->child_data)); +} + +} // namespace + +Result GetRecordBatchesFromTabularFunction( + const std::string& func_name, compute::FunctionRegistry* registry) { + if (registry == NULLPTR) { + registry = compute::GetFunctionRegistry(); + } + ARROW_ASSIGN_OR_RAISE(auto func, registry->GetFunction(func_name)); + if (func->kind() != Function::SCALAR) { + return Status::Invalid("tabular function of non-scalar kind"); + } + auto arity = func->arity(); + if (arity.num_args != 0 || arity.is_varargs) { + return Status::Invalid("tabular function of non-null arity"); + } + auto kernels = ::arrow::internal::checked_pointer_cast(func)->kernels(); + if (kernels.size() != 1) { + return Status::Invalid("tabular function with non-single kernel"); + } + const ScalarKernel* kernel = kernels[0]; + auto out_type = kernel->signature->out_type(); + if (out_type.kind() != OutputType::FIXED) { + return Status::Invalid("tabular kernel of non-fixed kind"); + } + auto datatype = out_type.type(); + if (datatype->id() != Type::type::STRUCT) { + return Status::Invalid("tabular kernel with non-struct output"); + } + auto fields = checked_cast(datatype.get())->fields(); + auto schema = ::arrow::schema(fields); + std::vector in_types; + ARROW_ASSIGN_OR_RAISE(auto func_exec, + GetFunctionExecutor(func_name, in_types, NULLPTR, registry)); + auto next_func = [schema, func_name, + func_exec]() -> Result> { + std::vector args; + // passed_length of -1 or 0 with args.size() of 0 leads to an empty ExecSpanIterator + // in exec.cc and to never invoking the source function, so 1 is passed instead + ARROW_ASSIGN_OR_RAISE(auto datum, func_exec->Execute(args, /*passed_length=*/1)); + if (!datum.is_array()) { + return Status::Invalid("UDF result of non-array kind"); + } + std::shared_ptr array = datum.make_array(); + if (array->length() == 0) { + return IterationTraits>::End(); + } + return RecordBatchFromArray(std::move(schema), std::move(array)); + }; + return MakeFunctionIterator(std::move(next_func)); +} + } // namespace py } // namespace arrow diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index 9a3666459fd8..c75f3aaae39d 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -21,6 +21,8 @@ #include "arrow/compute/function.h" #include "arrow/compute/registry.h" #include "arrow/python/platform.h" +#include "arrow/record_batch.h" +#include "arrow/util/iterator.h" #include "arrow/python/common.h" #include "arrow/python/pyarrow.h" @@ -51,9 +53,17 @@ using ScalarUdfWrapperCallback = std::function; /// \brief register a Scalar user-defined-function from Python -Status ARROW_PYTHON_EXPORT RegisterScalarFunction(PyObject* user_function, - ScalarUdfWrapperCallback wrapper, - const ScalarUdfOptions& options); +Status ARROW_PYTHON_EXPORT RegisterScalarFunction( + PyObject* user_function, ScalarUdfWrapperCallback wrapper, + const ScalarUdfOptions& options, compute::FunctionRegistry* registry = NULLPTR); + +/// \brief register a Table user-defined-function from Python +Status ARROW_PYTHON_EXPORT RegisterTabularFunction( + PyObject* user_function, ScalarUdfWrapperCallback wrapper, + const ScalarUdfOptions& options, compute::FunctionRegistry* registry = NULLPTR); + +Result ARROW_PYTHON_EXPORT GetRecordBatchesFromTabularFunction( + const std::string& func_name, compute::FunctionRegistry* registry = NULLPTR); } // namespace py diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index e711619582d2..d210e902c301 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -504,3 +504,43 @@ def test_input_lifetime(unary_func_fixture): # Calling a UDF should not have kept `v` alive longer than required v = None assert proxy_pool.bytes_allocated() == 0 + + +def _record_batch_from_iters(schema, *iters): + arrays = [pa.array(list(v), type=schema[i].type) for i, v in enumerate(iters)] + return pa.RecordBatch.from_arrays(arrays=arrays, schema=schema) + + +def _record_batch_for_range(schema, n): + return _record_batch_from_iters(schema, range(n, n + 10), range(n + 1, n + 11)) + + +def datasource1(ctx): + """A short dataset""" + import pyarrow as pa + schema = pa.schema([('', pa.int32()), ('', pa.int32())]) + class Generator: + def __init__(self): + self.n = 3 + def __call__(self, ctx): + if self.n == 0: + batch = _record_batch_from_iters(schema, [], []) + else: + self.n -= 1 + batch = _record_batch_for_range(schema, self.n) + return pc.udf_result_from_record_batch(batch) + return Generator() + + +def test_udt(): + func = datasource1 + func_name = "datasource1" + func_doc = {"summary": "datasource1 UDT", "description": "test datasource1 UDT"} + in_types = {} + out_type = pa.struct([("", pa.int32()), ("", pa.int32())]) + schema = pa.schema([('', pa.int32()), ('', pa.int32())]) + pc.register_tabular_function(func, func_name, func_doc, in_types, out_type) + n = 3; + for item in pc.get_record_batches_from_tabular_function(func_name): + n -= 1 + assert item == _record_batch_for_range(schema, n) From d0cc3f17b1d29166850fe3576885b95185f4334b Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 20 Nov 2022 03:10:18 -0500 Subject: [PATCH 02/16] lint --- python/pyarrow/_compute.pyx | 1 + python/pyarrow/includes/libarrow.pxd | 6 +++--- python/pyarrow/tests/test_udf.py | 14 ++++++++++---- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 9b482d4ef240..323a5b9fb038 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2788,6 +2788,7 @@ def register_scalar_like_function(register_func, func, function_name, function_d &_scalar_udf_callback, c_options, c_func_registry)) + def get_record_batches_from_tabular_function(function_name, func_registry=None): cdef: c_string c_func_name diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index a801878f7a32..fa9d77a9e7bd 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2798,8 +2798,8 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py": CFunctionRegistry* registry) CStatus RegisterTabularFunction(PyObject* function, - function[CallbackUdf] wrapper, const CScalarUdfOptions& options, - CFunctionRegistry* registry) + function[CallbackUdf] wrapper, const CScalarUdfOptions& options, + CFunctionRegistry* registry) CResult[CRecordBatchIterator] GetRecordBatchesFromTabularFunction( - const c_string& func_name, CFunctionRegistry* registry) + const c_string& func_name, CFunctionRegistry* registry) diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index d210e902c301..7e86dc8b113f 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -507,21 +507,26 @@ def test_input_lifetime(unary_func_fixture): def _record_batch_from_iters(schema, *iters): - arrays = [pa.array(list(v), type=schema[i].type) for i, v in enumerate(iters)] + arrays = [pa.array(list(v), type=schema[i].type) + for i, v in enumerate(iters)] return pa.RecordBatch.from_arrays(arrays=arrays, schema=schema) def _record_batch_for_range(schema, n): - return _record_batch_from_iters(schema, range(n, n + 10), range(n + 1, n + 11)) + return _record_batch_from_iters(schema, + range(n, n + 10), + range(n + 1, n + 11)) def datasource1(ctx): """A short dataset""" import pyarrow as pa schema = pa.schema([('', pa.int32()), ('', pa.int32())]) + class Generator: def __init__(self): self.n = 3 + def __call__(self, ctx): if self.n == 0: batch = _record_batch_from_iters(schema, [], []) @@ -535,12 +540,13 @@ def __call__(self, ctx): def test_udt(): func = datasource1 func_name = "datasource1" - func_doc = {"summary": "datasource1 UDT", "description": "test datasource1 UDT"} + func_doc = {"summary": "datasource1 UDT", + "description": "test datasource1 UDT"} in_types = {} out_type = pa.struct([("", pa.int32()), ("", pa.int32())]) schema = pa.schema([('', pa.int32()), ('', pa.int32())]) pc.register_tabular_function(func, func_name, func_doc, in_types, out_type) - n = 3; + n = 3 for item in pc.get_record_batches_from_tabular_function(func_name): n -= 1 assert item == _record_batch_for_range(schema, n) From 2b2986c40c8f87f3ba4f9b6583bcef58ddc1bae8 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 20 Nov 2022 04:51:38 -0500 Subject: [PATCH 03/16] add docs --- python/pyarrow/_compute.pyx | 120 ++++++++++++++++++++++++++++++------ 1 file changed, 102 insertions(+), 18 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 323a5b9fb038..0ad5b5d3628f 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2603,6 +2603,16 @@ def _get_scalar_udf_context(memory_pool, batch_length): def udf_result_from_record_batch(record_batch): + """ + Convert a record batch to a UDF result. + + A UDF result is a struct array appropriate for returning from a UDF. + + Parameters + ---------- + record_batch: object + An object holding a wrapped CRecordBatch + """ cdef: shared_ptr[CRecordBatch] c_record_batch CResult[shared_ptr[CArray]] c_res_array @@ -2637,20 +2647,6 @@ cdef GetRegisterTabularFunction(): def register_scalar_function(func, function_name, function_doc, in_types, out_type, func_registry=None): - return register_scalar_like_function(GetRegisterScalarFunction(), - func, function_name, function_doc, in_types, - out_type, func_registry) - - -def register_tabular_function(func, function_name, function_doc, in_types, out_type, - func_registry=None): - return register_scalar_like_function(GetRegisterTabularFunction(), - func, function_name, function_doc, in_types, - out_type, func_registry) - - -def register_scalar_like_function(register_func, func, function_name, function_doc, in_types, - out_type, func_registry=None): """ Register a user-defined scalar function. @@ -2665,10 +2661,6 @@ def register_scalar_like_function(register_func, func, function_name, function_d Parameters ---------- - register_func: object - An object holding a CRegisterScalarLikeFunction in - a "register_func" attribute, such as: - GetRegisterScalarFunction, GetRegisterTabularFunction func : callable A callable implementing the user-defined function. The first argument is the context argument of type @@ -2726,6 +2718,88 @@ def register_scalar_like_function(register_func, func, function_name, function_d 21 ] """ + return register_scalar_like_function(GetRegisterScalarFunction(), + func, function_name, function_doc, in_types, + out_type, func_registry) + + +def register_tabular_function(func, function_name, function_doc, in_types, out_type, + func_registry=None): + """ + Register a user-defined tabular function. + + A tabular function is one accepting a context argument of type + ScalarUdfContext and returning a generator of struct arrays. + The in_types argument must be empty and the out_type argument + specifies a schema. Each struct array must have field types + correspoding to the schema. + + Parameters + ---------- + func : callable + A callable implementing the user-defined function. + The only argument is the context argument of type + ScalarUdfContext. It must return a callable that + returns on each invocation a StructArray matching + the out_type, where an empty array indicates end. + function_name : str + Name of the function. This name must be globally unique. + function_doc : dict + A dictionary object with keys "summary" (str), + and "description" (str). + in_types : Dict[str, DataType] + Must be an empty dictionary. + out_type : DataType + Output type of the function. + func_registry : FunctionRegistry + Optional function registry to use instead of the default global one. + """ + return register_scalar_like_function(GetRegisterTabularFunction(), + func, function_name, function_doc, in_types, + out_type, func_registry) + + +def register_scalar_like_function(register_func, func, function_name, function_doc, in_types, + out_type, func_registry=None): + """ + Register a user-defined scalar-like function. + + A scalar-like function is a callable accepting a first + context argument of type ScalarUdfContext as well as + possibly additional Arrow arguments, and returning a + an Arrow result appropriate for the kind of function. + A scalar function and a tabular function are examples + for scalar-like functions. + This function is normally not called directly but via + register_scalar_function or register_tabular_function. + + Parameters + ---------- + register_func: object + An object holding a CRegisterScalarLikeFunction in + a "register_func" attribute. Use + GetRegisterScalarFunction() for a scalar function and + GetRegisterTabularFunction() for a tabular function. + func : callable + A callable implementing the user-defined function. + See register_scalar_function and + register_tabular_function for details. + + function_name : str + Name of the function. This name must be globally unique. + function_doc : dict + A dictionary object with keys "summary" (str), + and "description" (str). + in_types : Dict[str, DataType] + A dictionary mapping function argument names to + their respective DataType. + See register_scalar_function and + register_tabular_function for details. + out_type : DataType + Output type of the function. + func_registry : FunctionRegistry + Optional function registry to use instead of the default global one. + """ cdef: CRegisterScalarLikeFunction c_register_func c_string c_func_name @@ -2790,6 +2864,16 @@ def register_scalar_like_function(register_func, func, function_name, function_d def get_record_batches_from_tabular_function(function_name, func_registry=None): + """ + Get a record batch iterator from a tabular function. + + Parameters + ---------- + function_name : str + Name of the function. + func_registry : FunctionRegistry + Optional function registry to use instead of the default global one. + """ cdef: c_string c_func_name CFunctionRegistry* c_func_registry From 3e8b0ad3c0154bb089eb6c810e80649571fb9c5e Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 20 Nov 2022 05:47:40 -0500 Subject: [PATCH 04/16] lint --- python/pyarrow/_compute.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 0ad5b5d3628f..cf08ed752c35 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2610,7 +2610,7 @@ def udf_result_from_record_batch(record_batch): Parameters ---------- - record_batch: object + record_batch : object An object holding a wrapped CRecordBatch """ cdef: From 8f38e9501d1f4b075bab6117b29006f0225b5248 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 20 Nov 2022 15:51:11 -0500 Subject: [PATCH 05/16] fix tabular next-function --- python/pyarrow/src/arrow/python/udf.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 2f18fabb2116..473792d90b04 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -251,8 +251,9 @@ Result GetRecordBatchesFromTabularFunction( std::vector in_types; ARROW_ASSIGN_OR_RAISE(auto func_exec, GetFunctionExecutor(func_name, in_types, NULLPTR, registry)); - auto next_func = [schema, func_name, - func_exec]() -> Result> { + auto next_func = + [schema = std::move(schema), + func_exec = std::move(func_exec)]() -> Result> { std::vector args; // passed_length of -1 or 0 with args.size() of 0 leads to an empty ExecSpanIterator // in exec.cc and to never invoking the source function, so 1 is passed instead From 345e961d9246af63ce0ab29c8ef9ef70512f87ba Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Tue, 22 Nov 2022 10:52:28 -0500 Subject: [PATCH 06/16] requested changes --- python/pyarrow/_compute.pyx | 20 ------- python/pyarrow/compute.py | 1 - python/pyarrow/src/arrow/python/udf.cc | 32 ++++------- python/pyarrow/table.pxi | 14 +++++ python/pyarrow/tests/test_udf.py | 79 ++++++++++++++++++++++---- 5 files changed, 94 insertions(+), 52 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index cf08ed752c35..9237df8fd900 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2602,26 +2602,6 @@ def _get_scalar_udf_context(memory_pool, batch_length): return context -def udf_result_from_record_batch(record_batch): - """ - Convert a record batch to a UDF result. - - A UDF result is a struct array appropriate for returning from a UDF. - - Parameters - ---------- - record_batch : object - An object holding a wrapped CRecordBatch - """ - cdef: - shared_ptr[CRecordBatch] c_record_batch - CResult[shared_ptr[CArray]] c_res_array - - c_record_batch = pyarrow_unwrap_batch(record_batch) - c_res_array = deref(c_record_batch).ToStructArray() - return pyarrow_wrap_array(GetResultValue(c_res_array)) - - ctypedef CStatus (*CRegisterScalarLikeFunction)(PyObject* function, function[CallbackUdf] wrapper, const CScalarUdfOptions& options, CFunctionRegistry* registry) diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 62b4644307d9..50830f524011 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -83,7 +83,6 @@ get_record_batches_from_tabular_function, register_scalar_function, register_tabular_function, - udf_result_from_record_batch, ScalarUdfContext, # Expressions Expression, diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 473792d90b04..f72ae578bee5 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -22,15 +22,6 @@ #include "arrow/util/checked_cast.h" namespace arrow { - -using compute::ExecResult; -using compute::ExecSpan; -using compute::Function; -using compute::OutputType; -using compute::ScalarFunction; -using compute::ScalarKernel; -using internal::checked_cast; - namespace py { namespace { @@ -99,9 +90,9 @@ struct PythonUdf : public PythonScalarUdfKernelState { ScalarUdfWrapperCallback cb; compute::OutputType output_type; - Status Exec(compute::KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - auto state = - ::arrow::internal::checked_cast(ctx->state()); + Status Exec(compute::KernelContext* ctx, const compute::ExecSpan& batch, + compute::ExecResult* out) { + auto state = arrow::internal::checked_cast(ctx->state()); std::shared_ptr& function = state->function; const int num_args = batch.num_values(); ScalarUdfContext udf_context{ctx->memory_pool(), batch.length}; @@ -144,8 +135,8 @@ struct PythonUdf : public PythonScalarUdfKernelState { } }; -Status PythonUdfExec(compute::KernelContext* ctx, const ExecSpan& batch, - ExecResult* out) { +Status PythonUdfExec(compute::KernelContext* ctx, const compute::ExecSpan& batch, + compute::ExecResult* out) { auto udf = static_cast(ctx->kernel()->data.get()); return SafeCallIntoPython([&]() -> Status { return udf->Exec(ctx, batch, out); }); } @@ -226,27 +217,29 @@ Result GetRecordBatchesFromTabularFunction( registry = compute::GetFunctionRegistry(); } ARROW_ASSIGN_OR_RAISE(auto func, registry->GetFunction(func_name)); - if (func->kind() != Function::SCALAR) { + if (func->kind() != compute::Function::SCALAR) { return Status::Invalid("tabular function of non-scalar kind"); } auto arity = func->arity(); if (arity.num_args != 0 || arity.is_varargs) { return Status::Invalid("tabular function of non-null arity"); } - auto kernels = ::arrow::internal::checked_pointer_cast(func)->kernels(); + auto kernels = + arrow::internal::checked_pointer_cast(func)->kernels(); if (kernels.size() != 1) { return Status::Invalid("tabular function with non-single kernel"); } - const ScalarKernel* kernel = kernels[0]; + const compute::ScalarKernel* kernel = kernels[0]; auto out_type = kernel->signature->out_type(); - if (out_type.kind() != OutputType::FIXED) { + if (out_type.kind() != compute::OutputType::FIXED) { return Status::Invalid("tabular kernel of non-fixed kind"); } auto datatype = out_type.type(); if (datatype->id() != Type::type::STRUCT) { return Status::Invalid("tabular kernel with non-struct output"); } - auto fields = checked_cast(datatype.get())->fields(); + auto fields = + arrow::internal::checked_cast(datatype.get())->fields(); auto schema = ::arrow::schema(fields); std::vector in_types; ARROW_ASSIGN_OR_RAISE(auto func_exec, @@ -271,5 +264,4 @@ Result GetRecordBatchesFromTabularFunction( } } // namespace py - } // namespace arrow diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 5c58ae61f191..2b1e50733ccf 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -2498,6 +2498,20 @@ cdef class RecordBatch(_PandasConvertible): CRecordBatch.FromStructArray(struct_array.sp_array)) return pyarrow_wrap_batch(c_record_batch) + def to_struct_array(self): + """ + Convert to a struct array. + """ + cdef: + shared_ptr[CRecordBatch] c_record_batch + shared_ptr[CArray] c_array + + c_record_batch = pyarrow_unwrap_batch(self) + with nogil: + c_array = GetResultValue( + deref(c_record_batch).ToStructArray()) + return pyarrow_wrap_array(c_array) + def _export_to_c(self, out_ptr, out_schema_ptr=0): """ Export to a C ArrowArray struct, given its pointer. diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 7e86dc8b113f..de990be5f34f 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -518,10 +518,28 @@ def _record_batch_for_range(schema, n): range(n + 1, n + 11)) -def datasource1(ctx): +def make_udt_func(schema, batch_gen): + def udf_func(ctx): + class UDT: + def __init__(self): + self.caller = None + def __call__(self, ctx): + try: + if self.caller is None: + self.caller, ctx = batch_gen(ctx).send, None + batch = self.caller(ctx) + except StopIteration: + arrays = [pa.array([], type=field.type) for field in schema] + batch = pa.RecordBatch.from_arrays(arrays=arrays, schema=schema) + return batch.to_struct_array() + return UDT() + return udf_func + + +def datasource1_direct(): """A short dataset""" import pyarrow as pa - schema = pa.schema([('', pa.int32()), ('', pa.int32())]) + schema = datasource1_schema() class Generator: def __init__(self): @@ -533,20 +551,59 @@ def __call__(self, ctx): else: self.n -= 1 batch = _record_batch_for_range(schema, self.n) - return pc.udf_result_from_record_batch(batch) - return Generator() + return batch.to_struct_array() + return lambda ctx: Generator() -def test_udt(): - func = datasource1 - func_name = "datasource1" - func_doc = {"summary": "datasource1 UDT", - "description": "test datasource1 UDT"} +def datasource1_generator(): + schema = datasource1_schema() + def batch_gen(ctx): + for n in range(3, 0, -1): + ctx = yield _record_batch_for_range(schema, n - 1) + return make_udt_func(schema, batch_gen) + + +def datasource1_exception(): + schema = datasource1_schema() + def batch_gen(ctx): + for n in range(3, 0, -1): + ctx = yield _record_batch_for_range(schema, n - 1) + raise RuntimeError("datasource1_exception") + return make_udt_func(schema, batch_gen) + + +def datasource1_schema(): + return pa.schema([('', pa.int32()), ('', pa.int32())]) + + +def datasource1_args(func, func_name): + func_doc = {"summary": f"{func_name} UDT", + "description": "test {func_name} UDT"} in_types = {} out_type = pa.struct([("", pa.int32()), ("", pa.int32())]) - schema = pa.schema([('', pa.int32()), ('', pa.int32())]) - pc.register_tabular_function(func, func_name, func_doc, in_types, out_type) + return func, func_name, func_doc, in_types, out_type + + +def _test_datasource1_udt(func_maker): + schema = datasource1_schema() + func = func_maker() + func_name = func_maker.__name__ + func_args = datasource1_args(func, func_name) + pc.register_tabular_function(*func_args) n = 3 for item in pc.get_record_batches_from_tabular_function(func_name): n -= 1 assert item == _record_batch_for_range(schema, n) + + +def test_udt_datasource1_direct(): + _test_datasource1_udt(datasource1_direct) + + +def test_udt_datasource1_generator(): + _test_datasource1_udt(datasource1_generator) + + +def test_udt_datasource1_exception(): + with pytest.raises(RuntimeError, match='datasource1_exception'): + _test_datasource1_udt(datasource1_exception) From 2fcc55373f9729d82d55bce50897943fb184b57e Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Wed, 23 Nov 2022 03:59:31 -0500 Subject: [PATCH 07/16] more requested fixes --- python/pyarrow/src/arrow/python/udf.cc | 14 ++++++++------ python/pyarrow/tests/test_udf.py | 9 +++++++-- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index f72ae578bee5..e5139b4478c5 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -190,7 +190,10 @@ Status RegisterTabularFunction(PyObject* user_function, ScalarUdfWrapperCallback const ScalarUdfOptions& options, compute::FunctionRegistry* registry) { if (options.arity.num_args != 0 || options.arity.is_varargs) { - return Status::Invalid("tabular function must have no arguments"); + return Status::NotImplemented("tabular function of non-null arity"); + } + if (options.output_type->id() != Type::type::STRUCT) { + return Status::Invalid("tabular function with non-struct output"); } return RegisterScalarLikeFunction( user_function, @@ -222,12 +225,12 @@ Result GetRecordBatchesFromTabularFunction( } auto arity = func->arity(); if (arity.num_args != 0 || arity.is_varargs) { - return Status::Invalid("tabular function of non-null arity"); + return Status::NotImplemented("tabular function of non-null arity"); } auto kernels = arrow::internal::checked_pointer_cast(func)->kernels(); if (kernels.size() != 1) { - return Status::Invalid("tabular function with non-single kernel"); + return Status::NotImplemented("tabular function with non-single kernel"); } const compute::ScalarKernel* kernel = kernels[0]; auto out_type = kernel->signature->out_type(); @@ -238,9 +241,8 @@ Result GetRecordBatchesFromTabularFunction( if (datatype->id() != Type::type::STRUCT) { return Status::Invalid("tabular kernel with non-struct output"); } - auto fields = - arrow::internal::checked_cast(datatype.get())->fields(); - auto schema = ::arrow::schema(fields); + auto struct_type = arrow::internal::checked_cast(datatype.get()); + auto schema = ::arrow::schema(struct_type->fields()); std::vector in_types; ARROW_ASSIGN_OR_RAISE(auto func_exec, GetFunctionExecutor(func_name, in_types, NULLPTR, registry)); diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index de990be5f34f..50b21e721c17 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -523,14 +523,17 @@ def udf_func(ctx): class UDT: def __init__(self): self.caller = None + def __call__(self, ctx): try: if self.caller is None: self.caller, ctx = batch_gen(ctx).send, None batch = self.caller(ctx) except StopIteration: - arrays = [pa.array([], type=field.type) for field in schema] - batch = pa.RecordBatch.from_arrays(arrays=arrays, schema=schema) + arrays = [pa.array([], type=field.type) + for field in schema] + batch = pa.RecordBatch.from_arrays( + arrays=arrays, schema=schema) return batch.to_struct_array() return UDT() return udf_func @@ -557,6 +560,7 @@ def __call__(self, ctx): def datasource1_generator(): schema = datasource1_schema() + def batch_gen(ctx): for n in range(3, 0, -1): ctx = yield _record_batch_for_range(schema, n - 1) @@ -565,6 +569,7 @@ def batch_gen(ctx): def datasource1_exception(): schema = datasource1_schema() + def batch_gen(ctx): for n in range(3, 0, -1): ctx = yield _record_batch_for_range(schema, n - 1) From 3f35ccd63fad1d4a470084da8153367c1e4ab725 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Wed, 23 Nov 2022 04:29:58 -0500 Subject: [PATCH 08/16] lint --- python/pyarrow/tests/test_udf.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 50b21e721c17..64db9d81c4e7 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -541,7 +541,6 @@ def __call__(self, ctx): def datasource1_direct(): """A short dataset""" - import pyarrow as pa schema = datasource1_schema() class Generator: @@ -563,7 +562,8 @@ def datasource1_generator(): def batch_gen(ctx): for n in range(3, 0, -1): - ctx = yield _record_batch_for_range(schema, n - 1) + #ctx = + yield _record_batch_for_range(schema, n - 1) return make_udt_func(schema, batch_gen) @@ -572,7 +572,8 @@ def datasource1_exception(): def batch_gen(ctx): for n in range(3, 0, -1): - ctx = yield _record_batch_for_range(schema, n - 1) + #ctx = + yield _record_batch_for_range(schema, n - 1) raise RuntimeError("datasource1_exception") return make_udt_func(schema, batch_gen) From 3c30eee1ef54349f0f7b41ab1dd1ae4309a182c6 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 24 Nov 2022 13:53:05 -0500 Subject: [PATCH 09/16] lint --- python/pyarrow/tests/test_udf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 64db9d81c4e7..e50b7ae7b6e6 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -562,7 +562,7 @@ def datasource1_generator(): def batch_gen(ctx): for n in range(3, 0, -1): - #ctx = + # ctx = yield _record_batch_for_range(schema, n - 1) return make_udt_func(schema, batch_gen) @@ -572,7 +572,7 @@ def datasource1_exception(): def batch_gen(ctx): for n in range(3, 0, -1): - #ctx = + # ctx = yield _record_batch_for_range(schema, n - 1) raise RuntimeError("datasource1_exception") return make_udt_func(schema, batch_gen) From feaa95771e9f21559a2f28b1368602833e9cfd9d Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Tue, 6 Dec 2022 17:31:40 -0500 Subject: [PATCH 10/16] requested changes --- python/pyarrow/_compute.pyx | 50 +++++++++++++++++--------- python/pyarrow/compute.py | 2 +- python/pyarrow/includes/libarrow.pxd | 5 +-- python/pyarrow/src/arrow/python/udf.cc | 8 +++-- python/pyarrow/src/arrow/python/udf.h | 5 +-- python/pyarrow/tests/test_udf.py | 2 +- 6 files changed, 47 insertions(+), 25 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 9237df8fd900..6947ad0dd186 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2613,13 +2613,13 @@ cdef class RegisterScalarLikeFunction(_Weakrefable): self.register_func = register_func -cdef GetRegisterScalarFunction(): +cdef get_register_scalar_function(): cdef RegisterScalarLikeFunction reg = RegisterScalarLikeFunction.__new__(RegisterScalarLikeFunction) reg.register_func = RegisterScalarFunction return reg -cdef GetRegisterTabularFunction(): +cdef get_register_tabular_function(): cdef RegisterScalarLikeFunction reg = RegisterScalarLikeFunction.__new__(RegisterScalarLikeFunction) reg.register_func = RegisterTabularFunction return reg @@ -2698,9 +2698,9 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty 21 ] """ - return register_scalar_like_function(GetRegisterScalarFunction(), - func, function_name, function_doc, in_types, - out_type, func_registry) + return _register_scalar_like_function(get_register_scalar_function(), + func, function_name, function_doc, in_types, + out_type, func_registry) def register_tabular_function(func, function_name, function_doc, in_types, out_type, @@ -2728,19 +2728,28 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t A dictionary object with keys "summary" (str), and "description" (str). in_types : Dict[str, DataType] - Must be an empty dictionary. - out_type : DataType - Output type of the function. + Must be an empty dictionary (reserved for future use). + out_type : Union[Schema, DataType] + Schema of the function's output, or a corresponding flat struct type. func_registry : FunctionRegistry Optional function registry to use instead of the default global one. """ - return register_scalar_like_function(GetRegisterTabularFunction(), - func, function_name, function_doc, in_types, - out_type, func_registry) + cdef: + shared_ptr[CSchema] c_schema + shared_ptr[CDataType] c_type + if isinstance(out_type, Schema): + c_schema = pyarrow_unwrap_schema(out_type) + with nogil: + c_type = make_shared[CStructType](deref(c_schema).fields()) + out_type = pyarrow_wrap_data_type(c_type) + return _register_scalar_like_function(get_register_tabular_function(), + func, function_name, function_doc, in_types, + out_type, func_registry) -def register_scalar_like_function(register_func, func, function_name, function_doc, in_types, - out_type, func_registry=None): + +def _register_scalar_like_function(register_func, func, function_name, function_doc, in_types, + out_type, func_registry=None): """ Register a user-defined scalar-like function. @@ -2758,8 +2767,8 @@ def register_scalar_like_function(register_func, func, function_name, function_d register_func: object An object holding a CRegisterScalarLikeFunction in a "register_func" attribute. Use - GetRegisterScalarFunction() for a scalar function and - GetRegisterTabularFunction() for a tabular function. + get_register_scalar_function() for a scalar function and + get_register_tabular_function() for a tabular function. func : callable A callable implementing the user-defined function. See register_scalar_function and @@ -2843,7 +2852,7 @@ def register_scalar_like_function(register_func, func, function_name, function_d c_options, c_func_registry)) -def get_record_batches_from_tabular_function(function_name, func_registry=None): +def call_tabular_function(function_name, args=None, func_registry=None): """ Get a record batch iterator from a tabular function. @@ -2851,11 +2860,15 @@ def get_record_batches_from_tabular_function(function_name, func_registry=None): ---------- function_name : str Name of the function. + args : iterable + The arguments to pass to the function. Accepted types depend + on the specific function. Currently, only an empty args is supported. func_registry : FunctionRegistry Optional function registry to use instead of the default global one. """ cdef: c_string c_func_name + vector[CDatum] c_args CFunctionRegistry* c_func_registry c_func_name = tobytes(function_name) @@ -2863,5 +2876,8 @@ def get_record_batches_from_tabular_function(function_name, func_registry=None): c_func_registry = NULL else: c_func_registry = (func_registry).registry + if args is None: + args = [] + _pack_compute_args(args, &c_args) - return RecordBatchIterator.wrap(None, move(GetResultValue(GetRecordBatchesFromTabularFunction(c_func_name, c_func_registry)))) + return RecordBatchIterator.wrap(None, move(GetResultValue(CallTabularFunction(c_func_name, c_args, c_func_registry)))) diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 50830f524011..d7d4912c8af9 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -80,7 +80,7 @@ list_functions, _group_by, # Udf - get_record_batches_from_tabular_function, + call_tabular_function, register_scalar_function, register_tabular_function, ScalarUdfContext, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index fa9d77a9e7bd..d0d51c0f54d7 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -477,6 +477,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: vector[shared_ptr[CField]] GetAllFieldsByName(const c_string& name) int GetFieldIndex(const c_string& name) vector[int] GetAllFieldIndices(const c_string& name) + const vector[shared_ptr[CField]] fields() int num_fields() c_string ToString() @@ -2801,5 +2802,5 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py": function[CallbackUdf] wrapper, const CScalarUdfOptions& options, CFunctionRegistry* registry) - CResult[CRecordBatchIterator] GetRecordBatchesFromTabularFunction( - const c_string& func_name, CFunctionRegistry* registry) + CResult[CRecordBatchIterator] CallTabularFunction( + const c_string& func_name, const vector[CDatum]& args, CFunctionRegistry* registry) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index e5139b4478c5..7cbf14ff26c1 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -214,8 +214,12 @@ Result> RecordBatchFromArray( } // namespace -Result GetRecordBatchesFromTabularFunction( - const std::string& func_name, compute::FunctionRegistry* registry) { +Result CallTabularFunction( + const std::string& func_name, const std::vector& args, + compute::FunctionRegistry* registry) { + if (args.size() != 0) { + return Status::NotImplemented("non-empty arguments to tabular function"); + } if (registry == NULLPTR) { registry = compute::GetFunctionRegistry(); } diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index c75f3aaae39d..cbd88b92fe57 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -62,8 +62,9 @@ Status ARROW_PYTHON_EXPORT RegisterTabularFunction( PyObject* user_function, ScalarUdfWrapperCallback wrapper, const ScalarUdfOptions& options, compute::FunctionRegistry* registry = NULLPTR); -Result ARROW_PYTHON_EXPORT GetRecordBatchesFromTabularFunction( - const std::string& func_name, compute::FunctionRegistry* registry = NULLPTR); +Result ARROW_PYTHON_EXPORT CallTabularFunction( + const std::string& func_name, const std::vector& args, + compute::FunctionRegistry* registry = NULLPTR); } // namespace py diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index e50b7ae7b6e6..e1dbc7e01f4c 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -597,7 +597,7 @@ def _test_datasource1_udt(func_maker): func_args = datasource1_args(func, func_name) pc.register_tabular_function(*func_args) n = 3 - for item in pc.get_record_batches_from_tabular_function(func_name): + for item in pc.call_tabular_function(func_name): n -= 1 assert item == _record_batch_for_range(schema, n) From 37451c7d30248ee3437871c0033e5663f3072223 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Wed, 7 Dec 2022 03:42:38 -0500 Subject: [PATCH 11/16] RecordBatchReader API --- python/pyarrow/_compute.pyx | 41 +++++--------------------- python/pyarrow/includes/libarrow.pxd | 4 +-- python/pyarrow/src/arrow/python/udf.cc | 7 +++-- python/pyarrow/src/arrow/python/udf.h | 2 +- 4 files changed, 15 insertions(+), 39 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 6947ad0dd186..981e26985757 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -153,38 +153,6 @@ cdef wrap_hash_aggregate_kernel(const CHashAggregateKernel* c_kernel): return kernel -cdef class RecordBatchIterator(_Weakrefable): - """An iterator over a sequence of record batches.""" - cdef: - # An object that must be kept alive with the iterator. - object iterator_owner - # Iterator is a non-POD type and Cython uses offsetof, leading - # to a compiler warning unless wrapped like so - shared_ptr[CRecordBatchIterator] iterator - - def __init__(self): - _forbid_instantiation(self.__class__, subclasses_instead=False) - - @staticmethod - cdef wrap(object owner, CRecordBatchIterator iterator): - cdef RecordBatchIterator self = \ - RecordBatchIterator.__new__(RecordBatchIterator) - self.iterator_owner = owner - self.iterator = make_shared[CRecordBatchIterator](move(iterator)) - return self - - def __iter__(self): - return self - - def __next__(self): - cdef shared_ptr[CRecordBatch] record_batch - with nogil: - record_batch = GetResultValue(move(self.iterator.get().Next())) - if record_batch == NULL: - raise StopIteration - return pyarrow_wrap_batch(record_batch) - - cdef class Kernel(_Weakrefable): """ A kernel object. @@ -2870,6 +2838,8 @@ def call_tabular_function(function_name, args=None, func_registry=None): c_string c_func_name vector[CDatum] c_args CFunctionRegistry* c_func_registry + shared_ptr[CRecordBatchReader] c_reader + RecordBatchReader reader c_func_name = tobytes(function_name) if func_registry is None: @@ -2880,4 +2850,9 @@ def call_tabular_function(function_name, args=None, func_registry=None): args = [] _pack_compute_args(args, &c_args) - return RecordBatchIterator.wrap(None, move(GetResultValue(CallTabularFunction(c_func_name, c_args, c_func_registry)))) + with nogil: + c_reader = GetResultValue(CallTabularFunction( + c_func_name, c_args, c_func_registry)) + reader = RecordBatchReader.__new__(RecordBatchReader) + reader.reader = c_reader + return RecordBatchReader.from_batches(pyarrow_wrap_schema(deref(c_reader).schema()), reader) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index d0d51c0f54d7..d94a14c259c0 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2782,7 +2782,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: pass -cdef extern from "arrow/python/udf.h" namespace "arrow::py": +cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil: cdef cppclass CScalarUdfContext" arrow::py::ScalarUdfContext": CMemoryPool *pool int64_t batch_length @@ -2802,5 +2802,5 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py": function[CallbackUdf] wrapper, const CScalarUdfOptions& options, CFunctionRegistry* registry) - CResult[CRecordBatchIterator] CallTabularFunction( + CResult[shared_ptr[CRecordBatchReader]] CallTabularFunction( const c_string& func_name, const vector[CDatum]& args, CFunctionRegistry* registry) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 7cbf14ff26c1..5e6f2b0ec101 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -214,7 +214,7 @@ Result> RecordBatchFromArray( } // namespace -Result CallTabularFunction( +Result> CallTabularFunction( const std::string& func_name, const std::vector& args, compute::FunctionRegistry* registry) { if (args.size() != 0) { @@ -251,7 +251,7 @@ Result CallTabularFunction( ARROW_ASSIGN_OR_RAISE(auto func_exec, GetFunctionExecutor(func_name, in_types, NULLPTR, registry)); auto next_func = - [schema = std::move(schema), + [schema, func_exec = std::move(func_exec)]() -> Result> { std::vector args; // passed_length of -1 or 0 with args.size() of 0 leads to an empty ExecSpanIterator @@ -266,7 +266,8 @@ Result CallTabularFunction( } return RecordBatchFromArray(std::move(schema), std::move(array)); }; - return MakeFunctionIterator(std::move(next_func)); + return RecordBatchReader::MakeFromIterator(MakeFunctionIterator(std::move(next_func)), + schema); } } // namespace py diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index cbd88b92fe57..798251d680e8 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -62,7 +62,7 @@ Status ARROW_PYTHON_EXPORT RegisterTabularFunction( PyObject* user_function, ScalarUdfWrapperCallback wrapper, const ScalarUdfOptions& options, compute::FunctionRegistry* registry = NULLPTR); -Result ARROW_PYTHON_EXPORT CallTabularFunction( +Result> ARROW_PYTHON_EXPORT CallTabularFunction( const std::string& func_name, const std::vector& args, compute::FunctionRegistry* registry = NULLPTR); From d0c8f5f21544ada0bbfb7758f813f6954fdfb7f3 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 15 Dec 2022 09:32:06 -0500 Subject: [PATCH 12/16] requested fixes --- cpp/src/arrow/engine/substrait/serde_test.cc | 2 +- cpp/src/arrow/type.cc | 10 ++ cpp/src/arrow/type.h | 3 + python/pyarrow/_compute.pxd | 6 +- python/pyarrow/_compute.pyx | 54 ++++----- python/pyarrow/compute.py | 2 +- python/pyarrow/includes/libarrow.pxd | 10 +- python/pyarrow/src/arrow/python/udf.cc | 120 +++++++++++-------- python/pyarrow/src/arrow/python/udf.h | 16 +-- python/pyarrow/tests/test_udf.py | 12 +- 10 files changed, 136 insertions(+), 99 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index c2c38e7a0fb9..fa4f411f93d8 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -4037,7 +4037,7 @@ TEST(Substrait, SetRelationBasic) { compute::SortOptions sort_options( {compute::SortKey("A", compute::SortOrder::Ascending)}); - CheckRoundTripResult(dummy_schema, std::move(expected_table), exec_context, buf, {}, + CheckRoundTripResult(std::move(expected_table), exec_context, buf, {}, conversion_options, &sort_options); } diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index cc31735512ba..92d44062e803 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -452,6 +452,16 @@ std::string TypeHolder::ToString(const std::vector& types) { return ss.str(); } +std::vector TypeHolder::FromTypes( + const std::vector>& types) { + std::vector type_holders; + type_holders.reserve(types.size()); + for (const auto& type : types) { + type_holders.emplace_back(type); + } + return std::move(type_holders); +} + // ---------------------------------------------------------------------- FloatingPointType::Precision HalfFloatType::precision() const { diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 4bf8fe7fabb9..59e109e93940 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -264,6 +264,9 @@ struct ARROW_EXPORT TypeHolder { } static std::string ToString(const std::vector&); + + static std::vector FromTypes( + const std::vector>& types); }; ARROW_EXPORT diff --git a/python/pyarrow/_compute.pxd b/python/pyarrow/_compute.pxd index 8b09cbd445e1..8f9441debee5 100644 --- a/python/pyarrow/_compute.pxd +++ b/python/pyarrow/_compute.pxd @@ -21,11 +21,11 @@ from pyarrow.lib cimport * from pyarrow.includes.common cimport * from pyarrow.includes.libarrow cimport * -cdef class ScalarUdfContext(_Weakrefable): +cdef class UdfContext(_Weakrefable): cdef: - CScalarUdfContext c_context + CUdfContext c_context - cdef void init(self, const CScalarUdfContext& c_context) + cdef void init(self, const CUdfContext& c_context) cdef class FunctionOptions(_Weakrefable): cdef: diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index bc3881a7d335..75b6cd48a0b2 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2519,7 +2519,7 @@ cdef CExpression _bind(Expression filter, Schema schema) except *: deref(pyarrow_unwrap_schema(schema).get()))) -cdef class ScalarUdfContext: +cdef class UdfContext: """ Per-invocation function context/state. @@ -2531,7 +2531,7 @@ cdef class ScalarUdfContext: raise TypeError("Do not call {}'s constructor directly" .format(self.__class__.__name__)) - cdef void init(self, const CScalarUdfContext &c_context): + cdef void init(self, const CUdfContext &c_context): self.c_context = c_context @property @@ -2580,48 +2580,47 @@ cdef inline CFunctionDoc _make_function_doc(dict func_doc) except *: return f_doc -cdef object box_scalar_udf_context(const CScalarUdfContext& c_context): - cdef ScalarUdfContext context = ScalarUdfContext.__new__(ScalarUdfContext) +cdef object box_udf_context(const CUdfContext& c_context): + cdef UdfContext context = UdfContext.__new__(UdfContext) context.init(c_context) return context -cdef _scalar_udf_callback(user_function, const CScalarUdfContext& c_context, inputs): +cdef _udf_callback(user_function, const CUdfContext& c_context, inputs): """ - Helper callback function used to wrap the ScalarUdfContext from Python to C++ + Helper callback function used to wrap the UdfContext from Python to C++ execution. """ - context = box_scalar_udf_context(c_context) + context = box_udf_context(c_context) return user_function(context, *inputs) -def _get_scalar_udf_context(memory_pool, batch_length): - cdef CScalarUdfContext c_context +def _get_udf_context(memory_pool, batch_length): + cdef CUdfContext c_context c_context.pool = maybe_unbox_memory_pool(memory_pool) c_context.batch_length = batch_length - context = box_scalar_udf_context(c_context) + context = box_udf_context(c_context) return context -ctypedef CStatus (*CRegisterScalarLikeFunction)(PyObject* function, - function[CallbackUdf] wrapper, const CScalarUdfOptions& options, - CFunctionRegistry* registry) +ctypedef CStatus (*CRegisterUdf)(PyObject* function, function[CallbackUdf] wrapper, + const CUdfOptions& options, CFunctionRegistry* registry) -cdef class RegisterScalarLikeFunction(_Weakrefable): - cdef CRegisterScalarLikeFunction register_func +cdef class RegisterUdf(_Weakrefable): + cdef CRegisterUdf register_func - cdef void init(self, const CRegisterScalarLikeFunction register_func): + cdef void init(self, const CRegisterUdf register_func): self.register_func = register_func cdef get_register_scalar_function(): - cdef RegisterScalarLikeFunction reg = RegisterScalarLikeFunction.__new__(RegisterScalarLikeFunction) + cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf) reg.register_func = RegisterScalarFunction return reg cdef get_register_tabular_function(): - cdef RegisterScalarLikeFunction reg = RegisterScalarLikeFunction.__new__(RegisterScalarLikeFunction) + cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf) reg.register_func = RegisterTabularFunction return reg @@ -2645,7 +2644,7 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty func : callable A callable implementing the user-defined function. The first argument is the context argument of type - ScalarUdfContext. + UdfContext. Then, it must take arguments equal to the number of in_types defined. It must return an Array or Scalar matching the out_type. It must return a Scalar if @@ -2710,7 +2709,7 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t Register a user-defined tabular function. A tabular function is one accepting a context argument of type - ScalarUdfContext and returning a generator of struct arrays. + UdfContext and returning a generator of struct arrays. The in_types argument must be empty and the out_type argument specifies a schema. Each struct array must have field types correspoding to the schema. @@ -2720,7 +2719,7 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t func : callable A callable implementing the user-defined function. The only argument is the context argument of type - ScalarUdfContext. It must return a callable that + UdfContext. It must return a callable that returns on each invocation a StructArray matching the out_type, where an empty array indicates end. function_name : str @@ -2755,7 +2754,7 @@ def _register_scalar_like_function(register_func, func, function_name, function_ Register a user-defined scalar-like function. A scalar-like function is a callable accepting a first - context argument of type ScalarUdfContext as well as + context argument of type UdfContext as well as possibly additional Arrow arguments, and returning a an Arrow result appropriate for the kind of function. A scalar function and a tabular function are examples @@ -2766,8 +2765,7 @@ def _register_scalar_like_function(register_func, func, function_name, function_ Parameters ---------- register_func: object - An object holding a CRegisterScalarLikeFunction in - a "register_func" attribute. Use + An object holding a CRegisterUdf in a "register_func" attribute. Use get_register_scalar_function() for a scalar function and get_register_tabular_function() for a tabular function. func : callable @@ -2791,14 +2789,14 @@ def _register_scalar_like_function(register_func, func, function_name, function_ Optional function registry to use instead of the default global one. """ cdef: - CRegisterScalarLikeFunction c_register_func + CRegisterUdf c_register_func c_string c_func_name CArity c_arity CFunctionDoc c_func_doc vector[shared_ptr[CDataType]] c_in_types PyObject* c_function shared_ptr[CDataType] c_out_type - CScalarUdfOptions c_options + CUdfOptions c_options CFunctionRegistry* c_func_registry if callable(func): @@ -2846,10 +2844,10 @@ def _register_scalar_like_function(register_func, func, function_name, function_ else: c_func_registry = (func_registry).registry - c_register_func = (register_func).register_func + c_register_func = (register_func).register_func check_status(c_register_func(c_function, - &_scalar_udf_callback, + &_udf_callback, c_options, c_func_registry)) diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index d7d4912c8af9..7c312a69564f 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -83,7 +83,7 @@ call_tabular_function, register_scalar_function, register_tabular_function, - ScalarUdfContext, + UdfContext, # Expressions Expression, ) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 8813132efd34..d182a79fcde4 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2798,7 +2798,7 @@ cdef extern from "arrow/util/byte_size.h" namespace "arrow::util" nogil: int64_t TotalBufferSize(const CRecordBatch& record_batch) int64_t TotalBufferSize(const CTable& table) -ctypedef PyObject* CallbackUdf(object user_function, const CScalarUdfContext& context, object inputs) +ctypedef PyObject* CallbackUdf(object user_function, const CUdfContext& context, object inputs) cdef extern from "arrow/api.h" namespace "arrow" nogil: @@ -2809,11 +2809,11 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil: - cdef cppclass CScalarUdfContext" arrow::py::ScalarUdfContext": + cdef cppclass CUdfContext" arrow::py::UdfContext": CMemoryPool *pool int64_t batch_length - cdef cppclass CScalarUdfOptions" arrow::py::ScalarUdfOptions": + cdef cppclass CUdfOptions" arrow::py::UdfOptions": c_string func_name CArity arity CFunctionDoc func_doc @@ -2821,11 +2821,11 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil: shared_ptr[CDataType] output_type CStatus RegisterScalarFunction(PyObject* function, - function[CallbackUdf] wrapper, const CScalarUdfOptions& options, + function[CallbackUdf] wrapper, const CUdfOptions& options, CFunctionRegistry* registry) CStatus RegisterTabularFunction(PyObject* function, - function[CallbackUdf] wrapper, const CScalarUdfOptions& options, + function[CallbackUdf] wrapper, const CUdfOptions& options, CFunctionRegistry* registry) CResult[shared_ptr[CRecordBatchReader]] CallTabularFunction( diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 5e6f2b0ec101..8345c835846b 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -26,20 +26,32 @@ namespace py { namespace { -struct PythonScalarUdfKernelState : public compute::KernelState { - explicit PythonScalarUdfKernelState(std::shared_ptr function) - : function(function) {} +struct PythonUdfKernelState : public compute::KernelState { + explicit PythonUdfKernelState(std::shared_ptr function) + : function(function) { + Py_INCREF(function->obj()); + } + + // function needs to be destroyed at process exit + // and Python may no longer be initialized. + ~PythonUdfKernelState() { + if (_Py_IsFinalizing()) { + function->detach(); + } + } std::shared_ptr function; }; -struct PythonScalarUdfKernelInit { - explicit PythonScalarUdfKernelInit(std::shared_ptr function) - : function(function) {} +struct PythonUdfKernelInit { + explicit PythonUdfKernelInit(std::shared_ptr function) + : function(function) { + Py_INCREF(function->obj()); + } // function needs to be destroyed at process exit // and Python may no longer be initialized. - ~PythonScalarUdfKernelInit() { + ~PythonUdfKernelInit() { if (_Py_IsFinalizing()) { function->detach(); } @@ -47,7 +59,7 @@ struct PythonScalarUdfKernelInit { Result> operator()( compute::KernelContext*, const compute::KernelInitArgs&) { - return std::make_unique(function); + return std::make_unique(function); } std::shared_ptr function; @@ -55,14 +67,22 @@ struct PythonScalarUdfKernelInit { struct PythonTableUdfKernelInit { PythonTableUdfKernelInit(std::shared_ptr function_maker, - ScalarUdfWrapperCallback cb) + UdfWrapperCallback cb) : function_maker(function_maker), cb(cb) { Py_INCREF(function_maker->obj()); } + // function needs to be destroyed at process exit + // and Python may no longer be initialized. + ~PythonTableUdfKernelInit() { + if (_Py_IsFinalizing()) { + function_maker->detach(); + } + } + Result> operator()( compute::KernelContext* ctx, const compute::KernelInitArgs&) { - ScalarUdfContext udf_context{ctx->memory_pool(), /*batch_length=*/0}; + UdfContext udf_context{ctx->memory_pool(), /*batch_length=*/0}; std::unique_ptr function; RETURN_NOT_OK(SafeCallIntoPython([this, &udf_context, &function] { OwnedRef empty_tuple(PyTuple_New(0)); @@ -74,28 +94,44 @@ struct PythonTableUdfKernelInit { if (!PyCallable_Check(function->obj())) { return Status::TypeError("Expected a callable Python object."); } - return std::make_unique( + return std::make_unique( std::move(function)); } std::shared_ptr function_maker; - ScalarUdfWrapperCallback cb; + UdfWrapperCallback cb; }; -struct PythonUdf : public PythonScalarUdfKernelState { - PythonUdf(std::shared_ptr function, ScalarUdfWrapperCallback cb, - compute::OutputType output_type) - : PythonScalarUdfKernelState(function), cb(cb), output_type(output_type) {} +struct PythonUdf : public PythonUdfKernelState { + PythonUdf(std::shared_ptr function, UdfWrapperCallback cb, + std::vector input_types, compute::OutputType output_type) + : PythonUdfKernelState(function), + cb(cb), + input_types(input_types), + output_type(output_type) {} - ScalarUdfWrapperCallback cb; + UdfWrapperCallback cb; + std::vector input_types; compute::OutputType output_type; + TypeHolder resolved_type; + + Result ResolveType(compute::KernelContext* ctx, + const std::vector& types) { + if (input_types == types) { + if (!resolved_type) { + ARROW_ASSIGN_OR_RAISE(resolved_type, output_type.Resolve(ctx, input_types)); + } + return resolved_type; + } + return output_type.Resolve(ctx, types); + } Status Exec(compute::KernelContext* ctx, const compute::ExecSpan& batch, compute::ExecResult* out) { - auto state = arrow::internal::checked_cast(ctx->state()); + auto state = arrow::internal::checked_cast(ctx->state()); std::shared_ptr& function = state->function; const int num_args = batch.num_values(); - ScalarUdfContext udf_context{ctx->memory_pool(), batch.length}; + UdfContext udf_context{ctx->memory_pool(), batch.length}; OwnedRef arg_tuple(PyTuple_New(num_args)); RETURN_NOT_OK(CheckPyError()); @@ -116,7 +152,7 @@ struct PythonUdf : public PythonScalarUdfKernelState { // unwrapping the output for expected output type if (is_array(result.obj())) { ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_array(result.obj())); - ARROW_ASSIGN_OR_RAISE(TypeHolder type, output_type.Resolve(ctx, batch.GetTypes())); + ARROW_ASSIGN_OR_RAISE(TypeHolder type, ResolveType(ctx, batch.GetTypes())); if (type.type == NULLPTR) { return Status::TypeError("expected output datatype is null"); } @@ -141,11 +177,9 @@ Status PythonUdfExec(compute::KernelContext* ctx, const compute::ExecSpan& batch return SafeCallIntoPython([&]() -> Status { return udf->Exec(ctx, batch, out); }); } -Status RegisterScalarLikeFunction(PyObject* user_function, - compute::KernelInit kernel_init, - ScalarUdfWrapperCallback wrapper, - const ScalarUdfOptions& options, - compute::FunctionRegistry* registry) { +Status RegisterUdf(PyObject* user_function, compute::KernelInit kernel_init, + UdfWrapperCallback wrapper, const UdfOptions& options, + compute::FunctionRegistry* registry) { if (!PyCallable_Check(user_function)) { return Status::TypeError("Expected a callable Python object."); } @@ -158,7 +192,8 @@ Status RegisterScalarLikeFunction(PyObject* user_function, } compute::OutputType output_type(options.output_type); auto udf_data = std::make_shared( - std::make_shared(user_function), wrapper, options.output_type); + std::make_shared(user_function), wrapper, + TypeHolder::FromTypes(options.input_types), options.output_type); compute::ScalarKernel kernel( compute::KernelSignature::Make(std::move(input_types), std::move(output_type), options.arity.is_varargs), @@ -177,17 +212,17 @@ Status RegisterScalarLikeFunction(PyObject* user_function, } // namespace -Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback wrapper, - const ScalarUdfOptions& options, +Status RegisterScalarFunction(PyObject* user_function, UdfWrapperCallback wrapper, + const UdfOptions& options, compute::FunctionRegistry* registry) { - return RegisterScalarLikeFunction( + return RegisterUdf( user_function, - PythonScalarUdfKernelInit{std::make_shared(user_function)}, wrapper, + PythonUdfKernelInit{std::make_shared(user_function)}, wrapper, options, registry); } -Status RegisterTabularFunction(PyObject* user_function, ScalarUdfWrapperCallback wrapper, - const ScalarUdfOptions& options, +Status RegisterTabularFunction(PyObject* user_function, UdfWrapperCallback wrapper, + const UdfOptions& options, compute::FunctionRegistry* registry) { if (options.arity.num_args != 0 || options.arity.is_varargs) { return Status::NotImplemented("tabular function of non-null arity"); @@ -195,25 +230,12 @@ Status RegisterTabularFunction(PyObject* user_function, ScalarUdfWrapperCallback if (options.output_type->id() != Type::type::STRUCT) { return Status::Invalid("tabular function with non-struct output"); } - return RegisterScalarLikeFunction( + return RegisterUdf( user_function, PythonTableUdfKernelInit{std::make_shared(user_function), wrapper}, wrapper, options, registry); } -namespace { - -Result> RecordBatchFromArray( - std::shared_ptr schema, std::shared_ptr array) { - auto& data = const_cast&>(array->data()); - if (data->child_data.size() != static_cast(schema->num_fields())) { - return Status::Invalid("UDF result with shape not conforming to schema"); - } - return RecordBatch::Make(std::move(schema), data->length, std::move(data->child_data)); -} - -} // namespace - Result> CallTabularFunction( const std::string& func_name, const std::vector& args, compute::FunctionRegistry* registry) { @@ -264,7 +286,11 @@ Result> CallTabularFunction( if (array->length() == 0) { return IterationTraits>::End(); } - return RecordBatchFromArray(std::move(schema), std::move(array)); + ARROW_ASSIGN_OR_RAISE(auto batch, RecordBatch::FromStructArray(std::move(array))); + if (!schema->Equals(batch->schema())) { + return Status::Invalid("UDF result with shape not conforming to schema"); + } + return std::move(batch); }; return RecordBatchReader::MakeFromIterator(MakeFunctionIterator(std::move(next_func)), schema); diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index 798251d680e8..58a820564442 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -35,7 +35,7 @@ namespace py { // TODO: TODO(ARROW-16041): UDF Options are not exposed to the Python // users. This feature will be included when extending to provide advanced // options for the users. -struct ARROW_PYTHON_EXPORT ScalarUdfOptions { +struct ARROW_PYTHON_EXPORT UdfOptions { std::string func_name; compute::Arity arity; compute::FunctionDoc func_doc; @@ -44,23 +44,23 @@ struct ARROW_PYTHON_EXPORT ScalarUdfOptions { }; /// \brief A context passed as the first argument of scalar UDF functions. -struct ARROW_PYTHON_EXPORT ScalarUdfContext { +struct ARROW_PYTHON_EXPORT UdfContext { MemoryPool* pool; int64_t batch_length; }; -using ScalarUdfWrapperCallback = std::function; +using UdfWrapperCallback = std::function; /// \brief register a Scalar user-defined-function from Python Status ARROW_PYTHON_EXPORT RegisterScalarFunction( - PyObject* user_function, ScalarUdfWrapperCallback wrapper, - const ScalarUdfOptions& options, compute::FunctionRegistry* registry = NULLPTR); + PyObject* user_function, UdfWrapperCallback wrapper, + const UdfOptions& options, compute::FunctionRegistry* registry = NULLPTR); /// \brief register a Table user-defined-function from Python Status ARROW_PYTHON_EXPORT RegisterTabularFunction( - PyObject* user_function, ScalarUdfWrapperCallback wrapper, - const ScalarUdfOptions& options, compute::FunctionRegistry* registry = NULLPTR); + PyObject* user_function, UdfWrapperCallback wrapper, + const UdfOptions& options, compute::FunctionRegistry* registry = NULLPTR); Result> ARROW_PYTHON_EXPORT CallTabularFunction( const std::string& func_name, const std::vector& args, diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index e1dbc7e01f4c..168339054ee0 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -32,8 +32,8 @@ def mock_udf_context(batch_length=10): - from pyarrow._compute import _get_scalar_udf_context - return _get_scalar_udf_context(pa.default_memory_pool(), batch_length) + from pyarrow._compute import _get_udf_context + return _get_udf_context(pa.default_memory_pool(), batch_length) class MyError(RuntimeError): @@ -266,7 +266,7 @@ def check_scalar_function(func_fixture, assert result_table.column(0).chunks[0] == expected_output -def test_scalar_udf_array_unary(unary_func_fixture): +def test_udf_array_unary(unary_func_fixture): check_scalar_function(unary_func_fixture, [ pa.array([10, 20], pa.int64()) @@ -274,7 +274,7 @@ def test_scalar_udf_array_unary(unary_func_fixture): ) -def test_scalar_udf_array_binary(binary_func_fixture): +def test_udf_array_binary(binary_func_fixture): check_scalar_function(binary_func_fixture, [ pa.array([10, 20], pa.int64()), @@ -283,7 +283,7 @@ def test_scalar_udf_array_binary(binary_func_fixture): ) -def test_scalar_udf_array_ternary(ternary_func_fixture): +def test_udf_array_ternary(ternary_func_fixture): check_scalar_function(ternary_func_fixture, [ pa.array([10, 20], pa.int64()), @@ -293,7 +293,7 @@ def test_scalar_udf_array_ternary(ternary_func_fixture): ) -def test_scalar_udf_array_varargs(varargs_func_fixture): +def test_udf_array_varargs(varargs_func_fixture): check_scalar_function(varargs_func_fixture, [ pa.array([2, 3], pa.int64()), From 4236e37aaeef7fb92de4b241f3202f4e8a40463d Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 15 Dec 2022 10:52:00 -0500 Subject: [PATCH 13/16] fix copy elision --- cpp/src/arrow/type.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 92d44062e803..3514d0538fa4 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -459,7 +459,7 @@ std::vector TypeHolder::FromTypes( for (const auto& type : types) { type_holders.emplace_back(type); } - return std::move(type_holders); + return type_holders; } // ---------------------------------------------------------------------- From 8a4d820131ba919147a0704ff91acf100c745384 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Fri, 23 Dec 2022 03:41:29 -0500 Subject: [PATCH 14/16] revert to scalar-UDF naming --- python/pyarrow/_compute.pxd | 6 +++--- python/pyarrow/_compute.pyx | 28 +++++++++++++------------- python/pyarrow/compute.py | 2 +- python/pyarrow/includes/libarrow.pxd | 4 ++-- python/pyarrow/src/arrow/python/udf.cc | 10 ++++----- python/pyarrow/src/arrow/python/udf.h | 4 ++-- python/pyarrow/tests/test_udf.py | 10 ++++----- 7 files changed, 32 insertions(+), 32 deletions(-) diff --git a/python/pyarrow/_compute.pxd b/python/pyarrow/_compute.pxd index 8f9441debee5..8b09cbd445e1 100644 --- a/python/pyarrow/_compute.pxd +++ b/python/pyarrow/_compute.pxd @@ -21,11 +21,11 @@ from pyarrow.lib cimport * from pyarrow.includes.common cimport * from pyarrow.includes.libarrow cimport * -cdef class UdfContext(_Weakrefable): +cdef class ScalarUdfContext(_Weakrefable): cdef: - CUdfContext c_context + CScalarUdfContext c_context - cdef void init(self, const CUdfContext& c_context) + cdef void init(self, const CScalarUdfContext& c_context) cdef class FunctionOptions(_Weakrefable): cdef: diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 75b6cd48a0b2..283f5328376b 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2519,7 +2519,7 @@ cdef CExpression _bind(Expression filter, Schema schema) except *: deref(pyarrow_unwrap_schema(schema).get()))) -cdef class UdfContext: +cdef class ScalarUdfContext: """ Per-invocation function context/state. @@ -2531,7 +2531,7 @@ cdef class UdfContext: raise TypeError("Do not call {}'s constructor directly" .format(self.__class__.__name__)) - cdef void init(self, const CUdfContext &c_context): + cdef void init(self, const CScalarUdfContext &c_context): self.c_context = c_context @property @@ -2580,26 +2580,26 @@ cdef inline CFunctionDoc _make_function_doc(dict func_doc) except *: return f_doc -cdef object box_udf_context(const CUdfContext& c_context): - cdef UdfContext context = UdfContext.__new__(UdfContext) +cdef object box_scalar_udf_context(const CScalarUdfContext& c_context): + cdef ScalarUdfContext context = ScalarUdfContext.__new__(ScalarUdfContext) context.init(c_context) return context -cdef _udf_callback(user_function, const CUdfContext& c_context, inputs): +cdef _udf_callback(user_function, const CScalarUdfContext& c_context, inputs): """ - Helper callback function used to wrap the UdfContext from Python to C++ + Helper callback function used to wrap the ScalarUdfContext from Python to C++ execution. """ - context = box_udf_context(c_context) + context = box_scalar_udf_context(c_context) return user_function(context, *inputs) -def _get_udf_context(memory_pool, batch_length): - cdef CUdfContext c_context +def _get_scalar_udf_context(memory_pool, batch_length): + cdef CScalarUdfContext c_context c_context.pool = maybe_unbox_memory_pool(memory_pool) c_context.batch_length = batch_length - context = box_udf_context(c_context) + context = box_scalar_udf_context(c_context) return context @@ -2644,7 +2644,7 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty func : callable A callable implementing the user-defined function. The first argument is the context argument of type - UdfContext. + ScalarUdfContext. Then, it must take arguments equal to the number of in_types defined. It must return an Array or Scalar matching the out_type. It must return a Scalar if @@ -2709,7 +2709,7 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t Register a user-defined tabular function. A tabular function is one accepting a context argument of type - UdfContext and returning a generator of struct arrays. + ScalarUdfContext and returning a generator of struct arrays. The in_types argument must be empty and the out_type argument specifies a schema. Each struct array must have field types correspoding to the schema. @@ -2719,7 +2719,7 @@ def register_tabular_function(func, function_name, function_doc, in_types, out_t func : callable A callable implementing the user-defined function. The only argument is the context argument of type - UdfContext. It must return a callable that + ScalarUdfContext. It must return a callable that returns on each invocation a StructArray matching the out_type, where an empty array indicates end. function_name : str @@ -2754,7 +2754,7 @@ def _register_scalar_like_function(register_func, func, function_name, function_ Register a user-defined scalar-like function. A scalar-like function is a callable accepting a first - context argument of type UdfContext as well as + context argument of type ScalarUdfContext as well as possibly additional Arrow arguments, and returning a an Arrow result appropriate for the kind of function. A scalar function and a tabular function are examples diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 7c312a69564f..d7d4912c8af9 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -83,7 +83,7 @@ call_tabular_function, register_scalar_function, register_tabular_function, - UdfContext, + ScalarUdfContext, # Expressions Expression, ) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index d182a79fcde4..344bb2f69166 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2798,7 +2798,7 @@ cdef extern from "arrow/util/byte_size.h" namespace "arrow::util" nogil: int64_t TotalBufferSize(const CRecordBatch& record_batch) int64_t TotalBufferSize(const CTable& table) -ctypedef PyObject* CallbackUdf(object user_function, const CUdfContext& context, object inputs) +ctypedef PyObject* CallbackUdf(object user_function, const CScalarUdfContext& context, object inputs) cdef extern from "arrow/api.h" namespace "arrow" nogil: @@ -2809,7 +2809,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil: - cdef cppclass CUdfContext" arrow::py::UdfContext": + cdef cppclass CScalarUdfContext" arrow::py::ScalarUdfContext": CMemoryPool *pool int64_t batch_length diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 8345c835846b..101dfc2ffb32 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -82,12 +82,12 @@ struct PythonTableUdfKernelInit { Result> operator()( compute::KernelContext* ctx, const compute::KernelInitArgs&) { - UdfContext udf_context{ctx->memory_pool(), /*batch_length=*/0}; + ScalarUdfContext scalar_udf_context{ctx->memory_pool(), /*batch_length=*/0}; std::unique_ptr function; - RETURN_NOT_OK(SafeCallIntoPython([this, &udf_context, &function] { + RETURN_NOT_OK(SafeCallIntoPython([this, &scalar_udf_context, &function] { OwnedRef empty_tuple(PyTuple_New(0)); function = std::make_unique( - cb(function_maker->obj(), udf_context, empty_tuple.obj())); + cb(function_maker->obj(), scalar_udf_context, empty_tuple.obj())); RETURN_NOT_OK(CheckPyError()); return Status::OK(); })); @@ -131,7 +131,7 @@ struct PythonUdf : public PythonUdfKernelState { auto state = arrow::internal::checked_cast(ctx->state()); std::shared_ptr& function = state->function; const int num_args = batch.num_values(); - UdfContext udf_context{ctx->memory_pool(), batch.length}; + ScalarUdfContext scalar_udf_context{ctx->memory_pool(), batch.length}; OwnedRef arg_tuple(PyTuple_New(num_args)); RETURN_NOT_OK(CheckPyError()); @@ -147,7 +147,7 @@ struct PythonUdf : public PythonUdfKernelState { } } - OwnedRef result(cb(function->obj(), udf_context, arg_tuple.obj())); + OwnedRef result(cb(function->obj(), scalar_udf_context, arg_tuple.obj())); RETURN_NOT_OK(CheckPyError()); // unwrapping the output for expected output type if (is_array(result.obj())) { diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index 58a820564442..cde97d9cb916 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -44,13 +44,13 @@ struct ARROW_PYTHON_EXPORT UdfOptions { }; /// \brief A context passed as the first argument of scalar UDF functions. -struct ARROW_PYTHON_EXPORT UdfContext { +struct ARROW_PYTHON_EXPORT ScalarUdfContext { MemoryPool* pool; int64_t batch_length; }; using UdfWrapperCallback = std::function; + PyObject* user_function, const ScalarUdfContext& context, PyObject* inputs)>; /// \brief register a Scalar user-defined-function from Python Status ARROW_PYTHON_EXPORT RegisterScalarFunction( diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index 168339054ee0..6a67e0bae9c7 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -31,9 +31,9 @@ ds = None -def mock_udf_context(batch_length=10): - from pyarrow._compute import _get_udf_context - return _get_udf_context(pa.default_memory_pool(), batch_length) +def mock_scalar_udf_context(batch_length=10): + from pyarrow._compute import _get_scalar_udf_context + return _get_scalar_udf_context(pa.default_memory_pool(), batch_length) class MyError(RuntimeError): @@ -248,7 +248,7 @@ def check_scalar_function(func_fixture, if all_scalar: batch_length = 1 - expected_output = function(mock_udf_context(batch_length), *inputs) + expected_output = function(mock_scalar_udf_context(batch_length), *inputs) func = pc.get_function(name) assert func.name == name @@ -464,7 +464,7 @@ def identity(ctx, val): in_types, out_type) -def test_udf_context(unary_func_fixture): +def test_scalar_udf_context(unary_func_fixture): # Check the memory_pool argument is properly propagated proxy_pool = pa.proxy_memory_pool(pa.default_memory_pool()) _, func_name = unary_func_fixture From 5076b978c347ebf7ab1e699e3ea5bb802ac13a11 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sat, 7 Jan 2023 15:21:10 -0500 Subject: [PATCH 15/16] fix merge --- cpp/src/arrow/engine/substrait/serde_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 9356fedd588a..2916782fe048 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -4220,8 +4220,8 @@ TEST(Substrait, SetRelationBasic) { compute::SortOptions sort_options( {compute::SortKey("A", compute::SortOrder::Ascending)}); - CheckRoundTripResult(std::move(expected_table), exec_context, buf, {}, - conversion_options, &sort_options); + CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options, + &sort_options); } TEST(Substrait, PlanWithAsOfJoinExtension) { From dc61c55437a8a7510b36b78d52bde9bab3648cbd Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Wed, 11 Jan 2023 09:38:04 -0500 Subject: [PATCH 16/16] add todo --- python/pyarrow/src/arrow/python/udf.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 101dfc2ffb32..763bdf5a034f 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -278,6 +278,7 @@ Result> CallTabularFunction( std::vector args; // passed_length of -1 or 0 with args.size() of 0 leads to an empty ExecSpanIterator // in exec.cc and to never invoking the source function, so 1 is passed instead + // TODO: GH-33612: Support batch size in user-defined tabular functions ARROW_ASSIGN_OR_RAISE(auto datum, func_exec->Execute(args, /*passed_length=*/1)); if (!datum.is_array()) { return Status::Invalid("UDF result of non-array kind");