diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index cc31735512ba..3514d0538fa4 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -452,6 +452,16 @@ std::string TypeHolder::ToString(const std::vector& types) { return ss.str(); } +std::vector TypeHolder::FromTypes( + const std::vector>& types) { + std::vector type_holders; + type_holders.reserve(types.size()); + for (const auto& type : types) { + type_holders.emplace_back(type); + } + return type_holders; +} + // ---------------------------------------------------------------------- FloatingPointType::Precision HalfFloatType::precision() const { diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 4bf8fe7fabb9..59e109e93940 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -264,6 +264,9 @@ struct ARROW_EXPORT TypeHolder { } static std::string ToString(const std::vector&); + + static std::vector FromTypes( + const std::vector>& types); }; ARROW_EXPORT diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index c75c5bf189ba..283f5328376b 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -36,6 +36,18 @@ import inspect import numpy as np +def _forbid_instantiation(klass, subclasses_instead=True): + msg = '{} is an abstract class thus cannot be initialized.'.format( + klass.__name__ + ) + if subclasses_instead: + subclasses = [cls.__name__ for cls in klass.__subclasses__()] + msg += ' Use one of the subclasses instead: {}'.format( + ', '.join(subclasses) + ) + raise TypeError(msg) + + cdef wrap_scalar_function(const shared_ptr[CFunction]& sp_func): """ Wrap a C++ scalar Function in a ScalarFunction object. 
@@ -2574,7 +2586,7 @@ cdef object box_scalar_udf_context(const CScalarUdfContext& c_context): return context -cdef _scalar_udf_callback(user_function, const CScalarUdfContext& c_context, inputs): +cdef _udf_callback(user_function, const CScalarUdfContext& c_context, inputs): """ Helper callback function used to wrap the ScalarUdfContext from Python to C++ execution. @@ -2591,8 +2603,30 @@ def _get_scalar_udf_context(memory_pool, batch_length): return context -def register_scalar_function(func, function_name, function_doc, in_types, - out_type): +ctypedef CStatus (*CRegisterUdf)(PyObject* function, function[CallbackUdf] wrapper, + const CUdfOptions& options, CFunctionRegistry* registry) + +cdef class RegisterUdf(_Weakrefable): + cdef CRegisterUdf register_func + + cdef void init(self, const CRegisterUdf register_func): + self.register_func = register_func + + +cdef get_register_scalar_function(): + cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf) + reg.register_func = RegisterScalarFunction + return reg + + +cdef get_register_tabular_function(): + cdef RegisterUdf reg = RegisterUdf.__new__(RegisterUdf) + reg.register_func = RegisterTabularFunction + return reg + + +def register_scalar_function(func, function_name, function_doc, in_types, out_type, + func_registry=None): """ Register a user-defined scalar function. @@ -2633,6 +2667,8 @@ def register_scalar_function(func, function_name, function_doc, in_types, arity. out_type : DataType Output type of the function. + func_registry : FunctionRegistry + Optional function registry to use instead of the default global one. 
Examples -------- @@ -2662,14 +2698,106 @@ def register_scalar_function(func, function_name, function_doc, in_types, 21 ] """ + return _register_scalar_like_function(get_register_scalar_function(), + func, function_name, function_doc, in_types, + out_type, func_registry) + + +def register_tabular_function(func, function_name, function_doc, in_types, out_type, + func_registry=None): + """ + Register a user-defined tabular function. + + A tabular function is one accepting a context argument of type + ScalarUdfContext and returning a generator of struct arrays. + The in_types argument must be empty and the out_type argument + specifies a schema. Each struct array must have field types + corresponding to the schema. + + Parameters + ---------- + func : callable + A callable implementing the user-defined function. + The only argument is the context argument of type + ScalarUdfContext. It must return a callable that + returns on each invocation a StructArray matching + the out_type, where an empty array indicates end. + function_name : str + Name of the function. This name must be globally unique. + function_doc : dict + A dictionary object with keys "summary" (str), + and "description" (str). + in_types : Dict[str, DataType] + Must be an empty dictionary (reserved for future use). + out_type : Union[Schema, DataType] + Schema of the function's output, or a corresponding flat struct type. + func_registry : FunctionRegistry + Optional function registry to use instead of the default global one. 
+ """ cdef: + shared_ptr[CSchema] c_schema + shared_ptr[CDataType] c_type + + if isinstance(out_type, Schema): + c_schema = pyarrow_unwrap_schema(out_type) + with nogil: + c_type = make_shared[CStructType](deref(c_schema).fields()) + out_type = pyarrow_wrap_data_type(c_type) + return _register_scalar_like_function(get_register_tabular_function(), + func, function_name, function_doc, in_types, + out_type, func_registry) + + +def _register_scalar_like_function(register_func, func, function_name, function_doc, in_types, + out_type, func_registry=None): + """ + Register a user-defined scalar-like function. + + A scalar-like function is a callable accepting a first + context argument of type ScalarUdfContext as well as + possibly additional Arrow arguments, and returning + an Arrow result appropriate for the kind of function. + A scalar function and a tabular function are examples + of scalar-like functions. + This function is normally not called directly but via + register_scalar_function or register_tabular_function. + + Parameters + ---------- + register_func : object + An object holding a CRegisterUdf in a "register_func" attribute. Use + get_register_scalar_function() for a scalar function and + get_register_tabular_function() for a tabular function. + func : callable + A callable implementing the user-defined function. + See register_scalar_function and + register_tabular_function for details. + + function_name : str + Name of the function. This name must be globally unique. + function_doc : dict + A dictionary object with keys "summary" (str), + and "description" (str). + in_types : Dict[str, DataType] + A dictionary mapping function argument names to + their respective DataType. + See register_scalar_function and + register_tabular_function for details. + out_type : DataType + Output type of the function. + func_registry : FunctionRegistry + Optional function registry to use instead of the default global one. 
+ """ + cdef: + CRegisterUdf c_register_func c_string c_func_name CArity c_arity CFunctionDoc c_func_doc vector[shared_ptr[CDataType]] c_in_types PyObject* c_function shared_ptr[CDataType] c_out_type - CScalarUdfOptions c_options + CUdfOptions c_options + CFunctionRegistry* c_func_registry if callable(func): c_function = func @@ -2711,5 +2839,51 @@ def register_scalar_function(func, function_name, function_doc, in_types, c_options.input_types = c_in_types c_options.output_type = c_out_type - check_status(RegisterScalarFunction(c_function, - &_scalar_udf_callback, c_options)) + if func_registry is None: + c_func_registry = NULL + else: + c_func_registry = (func_registry).registry + + c_register_func = (register_func).register_func + + check_status(c_register_func(c_function, + &_udf_callback, + c_options, c_func_registry)) + + +def call_tabular_function(function_name, args=None, func_registry=None): + """ + Get a record batch iterator from a tabular function. + + Parameters + ---------- + function_name : str + Name of the function. + args : iterable + The arguments to pass to the function. Accepted types depend + on the specific function. Currently, only an empty args is supported. + func_registry : FunctionRegistry + Optional function registry to use instead of the default global one. 
+ """ + cdef: + c_string c_func_name + vector[CDatum] c_args + CFunctionRegistry* c_func_registry + shared_ptr[CRecordBatchReader] c_reader + RecordBatchReader reader + + c_func_name = tobytes(function_name) + if func_registry is None: + c_func_registry = NULL + else: + c_func_registry = (func_registry).registry + if args is None: + args = [] + _pack_compute_args(args, &c_args) + + with nogil: + c_reader = GetResultValue(CallTabularFunction( + c_func_name, c_args, c_func_registry)) + reader = RecordBatchReader.__new__(RecordBatchReader) + reader.reader = c_reader + return RecordBatchReader.from_batches(pyarrow_wrap_schema(deref(c_reader).schema()), reader) diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index 5f1610c384fb..38ff60f380d0 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -32,24 +32,13 @@ from pyarrow.lib cimport * from pyarrow.lib import ArrowTypeError, frombytes, tobytes, _pc from pyarrow.includes.libarrow_dataset cimport * from pyarrow._compute cimport Expression, _bind +from pyarrow._compute import _forbid_instantiation from pyarrow._fs cimport FileSystem, FileInfo, FileSelector from pyarrow._csv cimport ( ConvertOptions, ParseOptions, ReadOptions, WriteOptions) from pyarrow.util import _is_iterable, _is_path_like, _stringify_path -def _forbid_instantiation(klass, subclasses_instead=True): - msg = '{} is an abstract class thus cannot be initialized.'.format( - klass.__name__ - ) - if subclasses_instead: - subclasses = [cls.__name__ for cls in klass.__subclasses__] - msg += ' Use one of the subclasses instead: {}'.format( - ', '.join(subclasses) - ) - raise TypeError(msg) - - _orc_fileformat = None _orc_imported = False diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 1ee6c40f4232..f455b81411a2 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -80,7 +80,9 @@ list_functions, _group_by, # Udf + call_tabular_function, register_scalar_function, + 
register_tabular_function, ScalarUdfContext, # Expressions Expression, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index df6a883afe98..537e051a9a89 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -480,6 +480,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: vector[shared_ptr[CField]] GetAllFieldsByName(const c_string& name) int GetFieldIndex(const c_string& name) vector[int] GetAllFieldIndices(const c_string& name) + const vector[shared_ptr[CField]] fields() int num_fields() c_string ToString() @@ -800,6 +801,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: const shared_ptr[CSchema]& schema, int64_t num_rows, const vector[shared_ptr[CArray]]& columns) + CResult[shared_ptr[CStructArray]] ToStructArray() const + @staticmethod CResult[shared_ptr[CRecordBatch]] FromStructArray( const shared_ptr[CArray]& array) @@ -2805,12 +2808,20 @@ cdef extern from "arrow/util/byte_size.h" namespace "arrow::util" nogil: ctypedef PyObject* CallbackUdf(object user_function, const CScalarUdfContext& context, object inputs) -cdef extern from "arrow/python/udf.h" namespace "arrow::py": + +cdef extern from "arrow/api.h" namespace "arrow" nogil: + + cdef cppclass CRecordBatchIterator "arrow::RecordBatchIterator"( + CIterator[shared_ptr[CRecordBatch]]): + pass + + +cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil: cdef cppclass CScalarUdfContext" arrow::py::ScalarUdfContext": CMemoryPool *pool int64_t batch_length - cdef cppclass CScalarUdfOptions" arrow::py::ScalarUdfOptions": + cdef cppclass CUdfOptions" arrow::py::UdfOptions": c_string func_name CArity arity CFunctionDoc func_doc @@ -2818,4 +2829,12 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py": shared_ptr[CDataType] output_type CStatus RegisterScalarFunction(PyObject* function, - function[CallbackUdf] wrapper, const CScalarUdfOptions& options) + function[CallbackUdf] wrapper, const 
CUdfOptions& options, + CFunctionRegistry* registry) + + CStatus RegisterTabularFunction(PyObject* function, + function[CallbackUdf] wrapper, const CUdfOptions& options, + CFunctionRegistry* registry) + + CResult[shared_ptr[CRecordBatchReader]] CallTabularFunction( + const c_string& func_name, const vector[CDatum]& args, CFunctionRegistry* registry) diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd index b75eafcdeea4..160379708490 100644 --- a/python/pyarrow/includes/libarrow_dataset.pxd +++ b/python/pyarrow/includes/libarrow_dataset.pxd @@ -25,13 +25,6 @@ from pyarrow.includes.libarrow cimport * from pyarrow.includes.libarrow_fs cimport * -cdef extern from "arrow/api.h" namespace "arrow" nogil: - - cdef cppclass CRecordBatchIterator "arrow::RecordBatchIterator"( - CIterator[shared_ptr[CRecordBatch]]): - pass - - cdef extern from "arrow/dataset/plan.h" namespace "arrow::dataset::internal" nogil: cdef void Initialize() diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 81bf47c0ade0..763bdf5a034f 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -17,37 +17,121 @@ #include "arrow/python/udf.h" #include "arrow/compute/function.h" +#include "arrow/compute/kernel.h" #include "arrow/python/common.h" +#include "arrow/util/checked_cast.h" namespace arrow { - -using compute::ExecResult; -using compute::ExecSpan; - namespace py { namespace { -struct PythonUdf : public compute::KernelState { - ScalarUdfWrapperCallback cb; +struct PythonUdfKernelState : public compute::KernelState { + explicit PythonUdfKernelState(std::shared_ptr function) + : function(function) { + Py_INCREF(function->obj()); + } + + // function needs to be destroyed at process exit + // and Python may no longer be initialized. 
+ ~PythonUdfKernelState() { + if (_Py_IsFinalizing()) { + function->detach(); + } + } + std::shared_ptr function; - std::shared_ptr output_type; +}; - PythonUdf(ScalarUdfWrapperCallback cb, std::shared_ptr function, - const std::shared_ptr& output_type) - : cb(cb), function(function), output_type(output_type) {} +struct PythonUdfKernelInit { + explicit PythonUdfKernelInit(std::shared_ptr function) + : function(function) { + Py_INCREF(function->obj()); + } // function needs to be destroyed at process exit // and Python may no longer be initialized. - ~PythonUdf() { + ~PythonUdfKernelInit() { if (_Py_IsFinalizing()) { function->detach(); } } - Status Exec(compute::KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + Result> operator()( + compute::KernelContext*, const compute::KernelInitArgs&) { + return std::make_unique(function); + } + + std::shared_ptr function; +}; + +struct PythonTableUdfKernelInit { + PythonTableUdfKernelInit(std::shared_ptr function_maker, + UdfWrapperCallback cb) + : function_maker(function_maker), cb(cb) { + Py_INCREF(function_maker->obj()); + } + + // function needs to be destroyed at process exit + // and Python may no longer be initialized. 
+ ~PythonTableUdfKernelInit() { + if (_Py_IsFinalizing()) { + function_maker->detach(); + } + } + + Result> operator()( + compute::KernelContext* ctx, const compute::KernelInitArgs&) { + ScalarUdfContext scalar_udf_context{ctx->memory_pool(), /*batch_length=*/0}; + std::unique_ptr function; + RETURN_NOT_OK(SafeCallIntoPython([this, &scalar_udf_context, &function] { + OwnedRef empty_tuple(PyTuple_New(0)); + function = std::make_unique( + cb(function_maker->obj(), scalar_udf_context, empty_tuple.obj())); + RETURN_NOT_OK(CheckPyError()); + return Status::OK(); + })); + if (!PyCallable_Check(function->obj())) { + return Status::TypeError("Expected a callable Python object."); + } + return std::make_unique( + std::move(function)); + } + + std::shared_ptr function_maker; + UdfWrapperCallback cb; +}; + +struct PythonUdf : public PythonUdfKernelState { + PythonUdf(std::shared_ptr function, UdfWrapperCallback cb, + std::vector input_types, compute::OutputType output_type) + : PythonUdfKernelState(function), + cb(cb), + input_types(input_types), + output_type(output_type) {} + + UdfWrapperCallback cb; + std::vector input_types; + compute::OutputType output_type; + TypeHolder resolved_type; + + Result ResolveType(compute::KernelContext* ctx, + const std::vector& types) { + if (input_types == types) { + if (!resolved_type) { + ARROW_ASSIGN_OR_RAISE(resolved_type, output_type.Resolve(ctx, input_types)); + } + return resolved_type; + } + return output_type.Resolve(ctx, types); + } + + Status Exec(compute::KernelContext* ctx, const compute::ExecSpan& batch, + compute::ExecResult* out) { + auto state = arrow::internal::checked_cast(ctx->state()); + std::shared_ptr& function = state->function; const int num_args = batch.num_values(); - ScalarUdfContext udf_context{ctx->memory_pool(), batch.length}; + ScalarUdfContext scalar_udf_context{ctx->memory_pool(), batch.length}; OwnedRef arg_tuple(PyTuple_New(num_args)); RETURN_NOT_OK(CheckPyError()); @@ -63,13 +147,17 @@ struct PythonUdf : 
public compute::KernelState { } } - OwnedRef result(cb(function->obj(), udf_context, arg_tuple.obj())); + OwnedRef result(cb(function->obj(), scalar_udf_context, arg_tuple.obj())); RETURN_NOT_OK(CheckPyError()); // unwrapping the output for expected output type if (is_array(result.obj())) { ARROW_ASSIGN_OR_RAISE(std::shared_ptr val, unwrap_array(result.obj())); - if (!output_type->Equals(*val->type())) { - return Status::TypeError("Expected output datatype ", output_type->ToString(), + ARROW_ASSIGN_OR_RAISE(TypeHolder type, ResolveType(ctx, batch.GetTypes())); + if (type.type == NULLPTR) { + return Status::TypeError("expected output datatype is null"); + } + if (*type.type != *val->type()) { + return Status::TypeError("Expected output datatype ", type.type->ToString(), ", but function returned datatype ", val->type()->ToString()); } @@ -83,16 +171,15 @@ struct PythonUdf : public compute::KernelState { } }; -Status PythonUdfExec(compute::KernelContext* ctx, const ExecSpan& batch, - ExecResult* out) { +Status PythonUdfExec(compute::KernelContext* ctx, const compute::ExecSpan& batch, + compute::ExecResult* out) { auto udf = static_cast(ctx->kernel()->data.get()); return SafeCallIntoPython([&]() -> Status { return udf->Exec(ctx, batch, out); }); } -} // namespace - -Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback wrapper, - const ScalarUdfOptions& options) { +Status RegisterUdf(PyObject* user_function, compute::KernelInit kernel_init, + UdfWrapperCallback wrapper, const UdfOptions& options, + compute::FunctionRegistry* registry) { if (!PyCallable_Check(user_function)) { return Status::TypeError("Expected a callable Python object."); } @@ -105,21 +192,110 @@ Status RegisterScalarFunction(PyObject* user_function, ScalarUdfWrapperCallback } compute::OutputType output_type(options.output_type); auto udf_data = std::make_shared( - wrapper, std::make_shared(user_function), options.output_type); + std::make_shared(user_function), wrapper, + 
TypeHolder::FromTypes(options.input_types), options.output_type); compute::ScalarKernel kernel( compute::KernelSignature::Make(std::move(input_types), std::move(output_type), options.arity.is_varargs), - PythonUdfExec); + PythonUdfExec, kernel_init); kernel.data = std::move(udf_data); kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE; kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE; RETURN_NOT_OK(scalar_func->AddKernel(std::move(kernel))); - auto registry = compute::GetFunctionRegistry(); + if (registry == NULLPTR) { + registry = compute::GetFunctionRegistry(); + } RETURN_NOT_OK(registry->AddFunction(std::move(scalar_func))); return Status::OK(); } -} // namespace py +} // namespace +Status RegisterScalarFunction(PyObject* user_function, UdfWrapperCallback wrapper, + const UdfOptions& options, + compute::FunctionRegistry* registry) { + return RegisterUdf( + user_function, + PythonUdfKernelInit{std::make_shared(user_function)}, wrapper, + options, registry); +} + +Status RegisterTabularFunction(PyObject* user_function, UdfWrapperCallback wrapper, + const UdfOptions& options, + compute::FunctionRegistry* registry) { + if (options.arity.num_args != 0 || options.arity.is_varargs) { + return Status::NotImplemented("tabular function of non-null arity"); + } + if (options.output_type->id() != Type::type::STRUCT) { + return Status::Invalid("tabular function with non-struct output"); + } + return RegisterUdf( + user_function, + PythonTableUdfKernelInit{std::make_shared(user_function), wrapper}, + wrapper, options, registry); +} + +Result> CallTabularFunction( + const std::string& func_name, const std::vector& args, + compute::FunctionRegistry* registry) { + if (args.size() != 0) { + return Status::NotImplemented("non-empty arguments to tabular function"); + } + if (registry == NULLPTR) { + registry = compute::GetFunctionRegistry(); + } + ARROW_ASSIGN_OR_RAISE(auto func, registry->GetFunction(func_name)); + if (func->kind() != 
compute::Function::SCALAR) { + return Status::Invalid("tabular function of non-scalar kind"); + } + auto arity = func->arity(); + if (arity.num_args != 0 || arity.is_varargs) { + return Status::NotImplemented("tabular function of non-null arity"); + } + auto kernels = + arrow::internal::checked_pointer_cast(func)->kernels(); + if (kernels.size() != 1) { + return Status::NotImplemented("tabular function with non-single kernel"); + } + const compute::ScalarKernel* kernel = kernels[0]; + auto out_type = kernel->signature->out_type(); + if (out_type.kind() != compute::OutputType::FIXED) { + return Status::Invalid("tabular kernel of non-fixed kind"); + } + auto datatype = out_type.type(); + if (datatype->id() != Type::type::STRUCT) { + return Status::Invalid("tabular kernel with non-struct output"); + } + auto struct_type = arrow::internal::checked_cast(datatype.get()); + auto schema = ::arrow::schema(struct_type->fields()); + std::vector in_types; + ARROW_ASSIGN_OR_RAISE(auto func_exec, + GetFunctionExecutor(func_name, in_types, NULLPTR, registry)); + auto next_func = + [schema, + func_exec = std::move(func_exec)]() -> Result> { + std::vector args; + // passed_length of -1 or 0 with args.size() of 0 leads to an empty ExecSpanIterator + // in exec.cc and to never invoking the source function, so 1 is passed instead + // TODO: GH-33612: Support batch size in user-defined tabular functions + ARROW_ASSIGN_OR_RAISE(auto datum, func_exec->Execute(args, /*passed_length=*/1)); + if (!datum.is_array()) { + return Status::Invalid("UDF result of non-array kind"); + } + std::shared_ptr array = datum.make_array(); + if (array->length() == 0) { + return IterationTraits>::End(); + } + ARROW_ASSIGN_OR_RAISE(auto batch, RecordBatch::FromStructArray(std::move(array))); + if (!schema->Equals(batch->schema())) { + return Status::Invalid("UDF result with shape not conforming to schema"); + } + return std::move(batch); + }; + return 
RecordBatchReader::MakeFromIterator(MakeFunctionIterator(std::move(next_func)), + schema); +} + +} // namespace py } // namespace arrow diff --git a/python/pyarrow/src/arrow/python/udf.h b/python/pyarrow/src/arrow/python/udf.h index 9a3666459fd8..cde97d9cb916 100644 --- a/python/pyarrow/src/arrow/python/udf.h +++ b/python/pyarrow/src/arrow/python/udf.h @@ -21,6 +21,8 @@ #include "arrow/compute/function.h" #include "arrow/compute/registry.h" #include "arrow/python/platform.h" +#include "arrow/record_batch.h" +#include "arrow/util/iterator.h" #include "arrow/python/common.h" #include "arrow/python/pyarrow.h" @@ -33,7 +35,7 @@ namespace py { // TODO: TODO(ARROW-16041): UDF Options are not exposed to the Python // users. This feature will be included when extending to provide advanced // options for the users. -struct ARROW_PYTHON_EXPORT ScalarUdfOptions { +struct ARROW_PYTHON_EXPORT UdfOptions { std::string func_name; compute::Arity arity; compute::FunctionDoc func_doc; @@ -47,13 +49,22 @@ struct ARROW_PYTHON_EXPORT ScalarUdfContext { int64_t batch_length; }; -using ScalarUdfWrapperCallback = std::function; /// \brief register a Scalar user-defined-function from Python -Status ARROW_PYTHON_EXPORT RegisterScalarFunction(PyObject* user_function, - ScalarUdfWrapperCallback wrapper, - const ScalarUdfOptions& options); +Status ARROW_PYTHON_EXPORT RegisterScalarFunction( + PyObject* user_function, UdfWrapperCallback wrapper, + const UdfOptions& options, compute::FunctionRegistry* registry = NULLPTR); + +/// \brief register a Table user-defined-function from Python +Status ARROW_PYTHON_EXPORT RegisterTabularFunction( + PyObject* user_function, UdfWrapperCallback wrapper, + const UdfOptions& options, compute::FunctionRegistry* registry = NULLPTR); + +Result> ARROW_PYTHON_EXPORT CallTabularFunction( + const std::string& func_name, const std::vector& args, + compute::FunctionRegistry* registry = NULLPTR); } // namespace py diff --git a/python/pyarrow/table.pxi 
b/python/pyarrow/table.pxi index bcc428a4cb29..1791faad2827 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -2550,6 +2550,20 @@ cdef class RecordBatch(_PandasConvertible): CRecordBatch.FromStructArray(struct_array.sp_array)) return pyarrow_wrap_batch(c_record_batch) + def to_struct_array(self): + """ + Convert to a struct array. + """ + cdef: + shared_ptr[CRecordBatch] c_record_batch + shared_ptr[CArray] c_array + + c_record_batch = pyarrow_unwrap_batch(self) + with nogil: + c_array = GetResultValue( + deref(c_record_batch).ToStructArray()) + return pyarrow_wrap_array(c_array) + def _export_to_c(self, out_ptr, out_schema_ptr=0): """ Export to a C ArrowArray struct, given its pointer. diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index e711619582d2..6a67e0bae9c7 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -31,7 +31,7 @@ ds = None -def mock_udf_context(batch_length=10): +def mock_scalar_udf_context(batch_length=10): from pyarrow._compute import _get_scalar_udf_context return _get_scalar_udf_context(pa.default_memory_pool(), batch_length) @@ -248,7 +248,7 @@ def check_scalar_function(func_fixture, if all_scalar: batch_length = 1 - expected_output = function(mock_udf_context(batch_length), *inputs) + expected_output = function(mock_scalar_udf_context(batch_length), *inputs) func = pc.get_function(name) assert func.name == name @@ -266,7 +266,7 @@ def check_scalar_function(func_fixture, assert result_table.column(0).chunks[0] == expected_output -def test_scalar_udf_array_unary(unary_func_fixture): +def test_udf_array_unary(unary_func_fixture): check_scalar_function(unary_func_fixture, [ pa.array([10, 20], pa.int64()) @@ -274,7 +274,7 @@ def test_scalar_udf_array_unary(unary_func_fixture): ) -def test_scalar_udf_array_binary(binary_func_fixture): +def test_udf_array_binary(binary_func_fixture): check_scalar_function(binary_func_fixture, [ pa.array([10, 20], pa.int64()), 
@@ -283,7 +283,7 @@ def test_scalar_udf_array_binary(binary_func_fixture): ) -def test_scalar_udf_array_ternary(ternary_func_fixture): +def test_udf_array_ternary(ternary_func_fixture): check_scalar_function(ternary_func_fixture, [ pa.array([10, 20], pa.int64()), @@ -293,7 +293,7 @@ def test_scalar_udf_array_ternary(ternary_func_fixture): ) -def test_scalar_udf_array_varargs(varargs_func_fixture): +def test_udf_array_varargs(varargs_func_fixture): check_scalar_function(varargs_func_fixture, [ pa.array([2, 3], pa.int64()), @@ -464,7 +464,7 @@ def identity(ctx, val): in_types, out_type) -def test_udf_context(unary_func_fixture): +def test_scalar_udf_context(unary_func_fixture): # Check the memory_pool argument is properly propagated proxy_pool = pa.proxy_memory_pool(pa.default_memory_pool()) _, func_name = unary_func_fixture @@ -504,3 +504,112 @@ def test_input_lifetime(unary_func_fixture): # Calling a UDF should not have kept `v` alive longer than required v = None assert proxy_pool.bytes_allocated() == 0 + + +def _record_batch_from_iters(schema, *iters): + arrays = [pa.array(list(v), type=schema[i].type) + for i, v in enumerate(iters)] + return pa.RecordBatch.from_arrays(arrays=arrays, schema=schema) + + +def _record_batch_for_range(schema, n): + return _record_batch_from_iters(schema, + range(n, n + 10), + range(n + 1, n + 11)) + + +def make_udt_func(schema, batch_gen): + def udf_func(ctx): + class UDT: + def __init__(self): + self.caller = None + + def __call__(self, ctx): + try: + if self.caller is None: + self.caller, ctx = batch_gen(ctx).send, None + batch = self.caller(ctx) + except StopIteration: + arrays = [pa.array([], type=field.type) + for field in schema] + batch = pa.RecordBatch.from_arrays( + arrays=arrays, schema=schema) + return batch.to_struct_array() + return UDT() + return udf_func + + +def datasource1_direct(): + """A short dataset""" + schema = datasource1_schema() + + class Generator: + def __init__(self): + self.n = 3 + + def __call__(self, 
ctx): + if self.n == 0: + batch = _record_batch_from_iters(schema, [], []) + else: + self.n -= 1 + batch = _record_batch_for_range(schema, self.n) + return batch.to_struct_array() + return lambda ctx: Generator() + + +def datasource1_generator(): + schema = datasource1_schema() + + def batch_gen(ctx): + for n in range(3, 0, -1): + # ctx = + yield _record_batch_for_range(schema, n - 1) + return make_udt_func(schema, batch_gen) + + +def datasource1_exception(): + schema = datasource1_schema() + + def batch_gen(ctx): + for n in range(3, 0, -1): + # ctx = + yield _record_batch_for_range(schema, n - 1) + raise RuntimeError("datasource1_exception") + return make_udt_func(schema, batch_gen) + + +def datasource1_schema(): + return pa.schema([('', pa.int32()), ('', pa.int32())]) + + +def datasource1_args(func, func_name): + func_doc = {"summary": f"{func_name} UDT", + "description": f"test {func_name} UDT"} + in_types = {} + out_type = pa.struct([("", pa.int32()), ("", pa.int32())]) + return func, func_name, func_doc, in_types, out_type + + +def _test_datasource1_udt(func_maker): + schema = datasource1_schema() + func = func_maker() + func_name = func_maker.__name__ + func_args = datasource1_args(func, func_name) + pc.register_tabular_function(*func_args) + n = 3 + for item in pc.call_tabular_function(func_name): + n -= 1 + assert item == _record_batch_for_range(schema, n) + + +def test_udt_datasource1_direct(): + _test_datasource1_udt(datasource1_direct) + + +def test_udt_datasource1_generator(): + _test_datasource1_udt(datasource1_generator) + + +def test_udt_datasource1_exception(): + with pytest.raises(RuntimeError, match='datasource1_exception'): + _test_datasource1_udt(datasource1_exception)