Skip to content

Commit a1d9b51

Browse files
authored
apacheGH-32916: [C++] [Python] User-defined tabular functions (apache#14682)
See https://issues.apache.org/jira/browse/ARROW-17676 * Closes: apache#32916 Lead-authored-by: Yaron Gvili <rtpsw@hotmail.com> Co-authored-by: rtpsw <rtpsw@hotmail.com> Signed-off-by: Weston Pace <weston.pace@gmail.com>
1 parent 295c664 commit a1d9b51

11 files changed

Lines changed: 566 additions & 66 deletions

File tree

cpp/src/arrow/type.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,16 @@ std::string TypeHolder::ToString(const std::vector<TypeHolder>& types) {
452452
return ss.str();
453453
}
454454

455+
std::vector<TypeHolder> TypeHolder::FromTypes(
456+
const std::vector<std::shared_ptr<DataType>>& types) {
457+
std::vector<TypeHolder> type_holders;
458+
type_holders.reserve(types.size());
459+
for (const auto& type : types) {
460+
type_holders.emplace_back(type);
461+
}
462+
return type_holders;
463+
}
464+
455465
// ----------------------------------------------------------------------
456466

457467
FloatingPointType::Precision HalfFloatType::precision() const {

cpp/src/arrow/type.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,9 @@ struct ARROW_EXPORT TypeHolder {
264264
}
265265

266266
static std::string ToString(const std::vector<TypeHolder>&);
267+
268+
static std::vector<TypeHolder> FromTypes(
269+
const std::vector<std::shared_ptr<DataType>>& types);
267270
};
268271

269272
ARROW_EXPORT

python/pyarrow/_compute.pyx

Lines changed: 180 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ import inspect
3636
import numpy as np
3737

3838

39+
def _forbid_instantiation(klass, subclasses_instead=True):
40+
msg = '{} is an abstract class thus cannot be initialized.'.format(
41+
klass.__name__
42+
)
43+
if subclasses_instead:
44+
subclasses = [cls.__name__ for cls in klass.__subclasses__]
45+
msg += ' Use one of the subclasses instead: {}'.format(
46+
', '.join(subclasses)
47+
)
48+
raise TypeError(msg)
49+
50+
3951
cdef wrap_scalar_function(const shared_ptr[CFunction]& sp_func):
4052
"""
4153
Wrap a C++ scalar Function in a ScalarFunction object.
@@ -2574,7 +2586,7 @@ cdef object box_scalar_udf_context(const CScalarUdfContext& c_context):
25742586
return context
25752587

25762588

2577-
cdef _scalar_udf_callback(user_function, const CScalarUdfContext& c_context, inputs):
2589+
cdef _udf_callback(user_function, const CScalarUdfContext& c_context, inputs):
25782590
"""
25792591
Helper callback function used to wrap the ScalarUdfContext from Python to C++
25802592
execution.
@@ -2591,8 +2603,30 @@ def _get_scalar_udf_context(memory_pool, batch_length):
25912603
return context
25922604

25932605

2594-
def register_scalar_function(func, function_name, function_doc, in_types,
2595-
out_type):
2606+
# C-level signature of a UDF registration entry point. Both
# RegisterScalarFunction and RegisterTabularFunction are stored through
# this pointer type.
ctypedef CStatus (*CRegisterUdf)(PyObject* function,
                                 function[CallbackUdf] wrapper,
                                 const CUdfOptions& options,
                                 CFunctionRegistry* registry)


cdef class RegisterUdf(_Weakrefable):
    """
    Holder object carrying a CRegisterUdf function pointer.

    The pointer is kept in the ``register_func`` attribute so that it can
    be handed around as a regular Python object.
    """
    cdef CRegisterUdf register_func

    cdef void init(self, const CRegisterUdf register_func):
        # Remember which registration entry point this holder stands for.
        self.register_func = register_func
2614+
2615+
2616+
cdef get_register_scalar_function():
    """
    Return a RegisterUdf holder whose pointer is RegisterScalarFunction.
    """
    cdef RegisterUdf holder = RegisterUdf.__new__(RegisterUdf)
    holder.init(RegisterScalarFunction)
    return holder
2620+
2621+
2622+
cdef get_register_tabular_function():
    """
    Return a RegisterUdf holder whose pointer is RegisterTabularFunction.
    """
    cdef RegisterUdf holder = RegisterUdf.__new__(RegisterUdf)
    holder.init(RegisterTabularFunction)
    return holder
2626+
2627+
2628+
def register_scalar_function(func, function_name, function_doc, in_types, out_type,
2629+
func_registry=None):
25962630
"""
25972631
Register a user-defined scalar function.
25982632
@@ -2633,6 +2667,8 @@ def register_scalar_function(func, function_name, function_doc, in_types,
26332667
arity.
26342668
out_type : DataType
26352669
Output type of the function.
2670+
func_registry : FunctionRegistry
2671+
Optional function registry to use instead of the default global one.
26362672
26372673
Examples
26382674
--------
@@ -2662,14 +2698,106 @@ def register_scalar_function(func, function_name, function_doc, in_types,
26622698
21
26632699
]
26642700
"""
2701+
return _register_scalar_like_function(get_register_scalar_function(),
2702+
func, function_name, function_doc, in_types,
2703+
out_type, func_registry)
2704+
2705+
2706+
def register_tabular_function(func, function_name, function_doc, in_types, out_type,
                              func_registry=None):
    """
    Register a user-defined tabular function.

    A tabular function takes a single context argument of type
    ScalarUdfContext and produces a sequence of struct arrays.
    No input types may be declared (in_types must be empty), and
    out_type describes the schema of the output. The field types
    of every produced struct array must correspond to that schema.

    Parameters
    ----------
    func : callable
        A callable implementing the user-defined function.
        It receives only the context argument of type
        ScalarUdfContext and must return a callable that,
        on each invocation, returns a StructArray matching
        out_type. An empty array signals the end.
    function_name : str
        Name of the function. This name must be globally unique.
    function_doc : dict
        A dictionary object with keys "summary" (str),
        and "description" (str).
    in_types : Dict[str, DataType]
        Must be an empty dictionary (reserved for future use).
    out_type : Union[Schema, DataType]
        Schema of the function's output, or a corresponding flat struct type.
    func_registry : FunctionRegistry
        Optional function registry to use instead of the default global one.
    """
    cdef:
        shared_ptr[CSchema] c_out_schema
        shared_ptr[CDataType] c_struct_type

    if isinstance(out_type, Schema):
        # A Schema is converted into a struct type built from its fields,
        # so the common registration path only has to deal with a DataType.
        c_out_schema = pyarrow_unwrap_schema(out_type)
        with nogil:
            c_struct_type = <shared_ptr[CDataType]>make_shared[CStructType](
                deref(c_out_schema).fields())
        out_type = pyarrow_wrap_data_type(c_struct_type)

    register_holder = get_register_tabular_function()
    return _register_scalar_like_function(
        register_holder, func, function_name, function_doc, in_types,
        out_type, func_registry)
2749+
2750+
2751+
def _register_scalar_like_function(register_func, func, function_name, function_doc, in_types,
2752+
out_type, func_registry=None):
2753+
"""
2754+
Register a user-defined scalar-like function.
2755+
2756+
A scalar-like function is a callable accepting a first
2757+
context argument of type ScalarUdfContext as well as
2758+
possibly additional Arrow arguments, and returning
2759+
an Arrow result appropriate for the kind of function.
2760+
A scalar function and a tabular function are examples
2761+
of scalar-like functions.
2762+
This function is normally not called directly but via
2763+
register_scalar_function or register_tabular_function.
2764+
2765+
Parameters
2766+
----------
2767+
register_func: object
2768+
An object holding a CRegisterUdf in a "register_func" attribute. Use
2769+
get_register_scalar_function() for a scalar function and
2770+
get_register_tabular_function() for a tabular function.
2771+
func : callable
2772+
A callable implementing the user-defined function.
2773+
See register_scalar_function and
2774+
register_tabular_function for details.
2775+
2776+
function_name : str
2777+
Name of the function. This name must be globally unique.
2778+
function_doc : dict
2779+
A dictionary object with keys "summary" (str),
2780+
and "description" (str).
2781+
in_types : Dict[str, DataType]
2782+
A dictionary mapping function argument names to
2783+
their respective DataType.
2784+
See register_scalar_function and
2785+
register_tabular_function for details.
2786+
out_type : DataType
2787+
Output type of the function.
2788+
func_registry : FunctionRegistry
2789+
Optional function registry to use instead of the default global one.
2790+
"""
2791+
cdef:
2792+
CRegisterUdf c_register_func
26662793
c_string c_func_name
26672794
CArity c_arity
26682795
CFunctionDoc c_func_doc
26692796
vector[shared_ptr[CDataType]] c_in_types
26702797
PyObject* c_function
26712798
shared_ptr[CDataType] c_out_type
2672-
CScalarUdfOptions c_options
2799+
CUdfOptions c_options
2800+
CFunctionRegistry* c_func_registry
26732801

26742802
if callable(func):
26752803
c_function = <PyObject*>func
@@ -2711,5 +2839,51 @@ def register_scalar_function(func, function_name, function_doc, in_types,
27112839
c_options.input_types = c_in_types
27122840
c_options.output_type = c_out_type
27132841

2714-
check_status(RegisterScalarFunction(c_function,
2715-
<function[CallbackUdf]> &_scalar_udf_callback, c_options))
2842+
if func_registry is None:
2843+
c_func_registry = NULL
2844+
else:
2845+
c_func_registry = (<FunctionRegistry>func_registry).registry
2846+
2847+
c_register_func = (<RegisterUdf>register_func).register_func
2848+
2849+
check_status(c_register_func(c_function,
2850+
<function[CallbackUdf]> &_udf_callback,
2851+
c_options, c_func_registry))
2852+
2853+
2854+
def call_tabular_function(function_name, args=None, func_registry=None):
    """
    Call a tabular function and obtain a reader over its record batches.

    Parameters
    ----------
    function_name : str
        Name of the function.
    args : iterable
        The arguments to pass to the function. Accepted types depend
        on the specific function. Currently, only an empty args is supported.
    func_registry : FunctionRegistry
        Optional function registry to use instead of the default global one.

    Returns
    -------
    The object produced by RecordBatchReader.from_batches, given the
    schema of the function's reader and the reader itself.
    """
    cdef:
        c_string c_func_name
        vector[CDatum] c_args
        CFunctionRegistry* c_func_registry = NULL
        shared_ptr[CRecordBatchReader] c_batch_reader
        RecordBatchReader wrapped

    c_func_name = tobytes(function_name)
    # A NULL registry pointer is passed through when no registry is given.
    if func_registry is not None:
        c_func_registry = (<FunctionRegistry>func_registry).registry
    _pack_compute_args([] if args is None else args, &c_args)

    with nogil:
        c_batch_reader = GetResultValue(CallTabularFunction(
            c_func_name, c_args, c_func_registry))

    # Wrap the C++ reader without running RecordBatchReader's initializer.
    wrapped = RecordBatchReader.__new__(RecordBatchReader)
    wrapped.reader = c_batch_reader
    schema = pyarrow_wrap_schema(deref(c_batch_reader).schema())
    return RecordBatchReader.from_batches(schema, wrapped)

python/pyarrow/_dataset.pyx

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,13 @@ from pyarrow.lib cimport *
3232
from pyarrow.lib import ArrowTypeError, frombytes, tobytes, _pc
3333
from pyarrow.includes.libarrow_dataset cimport *
3434
from pyarrow._compute cimport Expression, _bind
35+
from pyarrow._compute import _forbid_instantiation
3536
from pyarrow._fs cimport FileSystem, FileInfo, FileSelector
3637
from pyarrow._csv cimport (
3738
ConvertOptions, ParseOptions, ReadOptions, WriteOptions)
3839
from pyarrow.util import _is_iterable, _is_path_like, _stringify_path
3940

4041

41-
def _forbid_instantiation(klass, subclasses_instead=True):
42-
msg = '{} is an abstract class thus cannot be initialized.'.format(
43-
klass.__name__
44-
)
45-
if subclasses_instead:
46-
subclasses = [cls.__name__ for cls in klass.__subclasses__]
47-
msg += ' Use one of the subclasses instead: {}'.format(
48-
', '.join(subclasses)
49-
)
50-
raise TypeError(msg)
51-
52-
5342
_orc_fileformat = None
5443
_orc_imported = False
5544

python/pyarrow/compute.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@
8080
list_functions,
8181
_group_by,
8282
# Udf
83+
call_tabular_function,
8384
register_scalar_function,
85+
register_tabular_function,
8486
ScalarUdfContext,
8587
# Expressions
8688
Expression,

python/pyarrow/includes/libarrow.pxd

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
480480
vector[shared_ptr[CField]] GetAllFieldsByName(const c_string& name)
481481
int GetFieldIndex(const c_string& name)
482482
vector[int] GetAllFieldIndices(const c_string& name)
483+
const vector[shared_ptr[CField]] fields()
483484
int num_fields()
484485
c_string ToString()
485486

@@ -800,6 +801,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
800801
const shared_ptr[CSchema]& schema, int64_t num_rows,
801802
const vector[shared_ptr[CArray]]& columns)
802803

804+
CResult[shared_ptr[CStructArray]] ToStructArray() const
805+
803806
@staticmethod
804807
CResult[shared_ptr[CRecordBatch]] FromStructArray(
805808
const shared_ptr[CArray]& array)
@@ -2805,17 +2808,33 @@ cdef extern from "arrow/util/byte_size.h" namespace "arrow::util" nogil:
28052808

28062809
ctypedef PyObject* CallbackUdf(object user_function, const CScalarUdfContext& context, object inputs)
28072810

2808-
cdef extern from "arrow/python/udf.h" namespace "arrow::py":
2811+
2812+
cdef extern from "arrow/api.h" namespace "arrow" nogil:
2813+
2814+
cdef cppclass CRecordBatchIterator "arrow::RecordBatchIterator"(
2815+
CIterator[shared_ptr[CRecordBatch]]):
2816+
pass
2817+
2818+
2819+
cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil:
28092820
cdef cppclass CScalarUdfContext" arrow::py::ScalarUdfContext":
28102821
CMemoryPool *pool
28112822
int64_t batch_length
28122823

2813-
cdef cppclass CScalarUdfOptions" arrow::py::ScalarUdfOptions":
2824+
cdef cppclass CUdfOptions" arrow::py::UdfOptions":
28142825
c_string func_name
28152826
CArity arity
28162827
CFunctionDoc func_doc
28172828
vector[shared_ptr[CDataType]] input_types
28182829
shared_ptr[CDataType] output_type
28192830

28202831
CStatus RegisterScalarFunction(PyObject* function,
2821-
function[CallbackUdf] wrapper, const CScalarUdfOptions& options)
2832+
function[CallbackUdf] wrapper, const CUdfOptions& options,
2833+
CFunctionRegistry* registry)
2834+
2835+
CStatus RegisterTabularFunction(PyObject* function,
2836+
function[CallbackUdf] wrapper, const CUdfOptions& options,
2837+
CFunctionRegistry* registry)
2838+
2839+
CResult[shared_ptr[CRecordBatchReader]] CallTabularFunction(
2840+
const c_string& func_name, const vector[CDatum]& args, CFunctionRegistry* registry)

python/pyarrow/includes/libarrow_dataset.pxd

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,6 @@ from pyarrow.includes.libarrow cimport *
2525
from pyarrow.includes.libarrow_fs cimport *
2626

2727

28-
cdef extern from "arrow/api.h" namespace "arrow" nogil:
29-
30-
cdef cppclass CRecordBatchIterator "arrow::RecordBatchIterator"(
31-
CIterator[shared_ptr[CRecordBatch]]):
32-
pass
33-
34-
3528
cdef extern from "arrow/dataset/plan.h" namespace "arrow::dataset::internal" nogil:
3629

3730
cdef void Initialize()

0 commit comments

Comments
 (0)