@@ -36,6 +36,18 @@ import inspect
3636import numpy as np
3737
3838
def _forbid_instantiation(klass, subclasses_instead=True):
    """
    Raise a TypeError stating that an abstract class cannot be instantiated.

    Parameters
    ----------
    klass : type
        The abstract class whose instantiation was attempted.
    subclasses_instead : bool, default True
        If True, the names of the direct subclasses of ``klass`` are
        appended to the error message as suggested alternatives.

    Raises
    ------
    TypeError
        Always raised; this function never returns.
    """
    msg = ' {} is an abstract class thus cannot be initialized.'.format(
        klass.__name__
    )
    if subclasses_instead:
        # ``__subclasses__`` is a method: it has to be called. Iterating over
        # the method object itself fails with an unrelated "object is not
        # iterable" TypeError and the intended message is never produced.
        subclasses = [cls.__name__ for cls in klass.__subclasses__()]
        msg += ' Use one of the subclasses instead: {}'.format(
            ' , '.join(subclasses)
        )
    raise TypeError(msg)
49+
50+
3951cdef wrap_scalar_function(const shared_ptr[CFunction]& sp_func):
4052 """
4153 Wrap a C++ scalar Function in a ScalarFunction object.
@@ -2574,7 +2586,7 @@ cdef object box_scalar_udf_context(const CScalarUdfContext& c_context):
25742586 return context
25752587
25762588
2577- cdef _scalar_udf_callback (user_function, const CScalarUdfContext& c_context, inputs):
2589+ cdef _udf_callback (user_function, const CScalarUdfContext& c_context, inputs):
25782590 """
25792591 Helper callback function used to wrap the ScalarUdfContext from Python to C++
25802592 execution.
@@ -2591,8 +2603,30 @@ def _get_scalar_udf_context(memory_pool, batch_length):
25912603 return context
25922604
25932605
2594- def register_scalar_function (func , function_name , function_doc , in_types ,
2595- out_type ):
# CRegisterUdf: pointer type for the C-level UDF registration routines
# (RegisterScalarFunction and RegisterTabularFunction are both stored in
# fields of this type). Such a routine receives the Python callable, the
# callback wrapper, the UDF options and a function registry pointer (NULL is
# passed when the caller supplies no registry) and returns a CStatus.
2606+ ctypedef CStatus (* CRegisterUdf)(PyObject* function, function[CallbackUdf] wrapper,
2607+ const CUdfOptions& options, CFunctionRegistry* registry)
2608+
cdef class RegisterUdf(_Weakrefable):
    """
    Holder object for a C-level UDF registration function pointer.

    The pointer is kept in the ``register_func`` C attribute so that it can
    be handed around as an ordinary Python object.
    """
    # C function pointer of type CRegisterUdf carried by this instance.
    cdef CRegisterUdf register_func

    cdef void init(self, const CRegisterUdf register_func):
        # Record which registration routine this holder stands for.
        self.register_func = register_func
2614+
2615+
cdef get_register_scalar_function():
    """
    Build a RegisterUdf holder whose pointer is RegisterScalarFunction.
    """
    # __new__ allocates the extension type without going through __init__.
    cdef RegisterUdf holder = RegisterUdf.__new__(RegisterUdf)
    holder.register_func = RegisterScalarFunction
    return holder
2620+
2621+
cdef get_register_tabular_function():
    """
    Build a RegisterUdf holder whose pointer is RegisterTabularFunction.
    """
    # __new__ allocates the extension type without going through __init__.
    cdef RegisterUdf holder = RegisterUdf.__new__(RegisterUdf)
    holder.register_func = RegisterTabularFunction
    return holder
2626+
2627+
2628+ def register_scalar_function (func , function_name , function_doc , in_types , out_type ,
2629+ func_registry = None ):
25962630 """
25972631 Register a user-defined scalar function.
25982632
@@ -2633,6 +2667,8 @@ def register_scalar_function(func, function_name, function_doc, in_types,
26332667 arity.
26342668 out_type : DataType
26352669 Output type of the function.
2670+ func_registry : FunctionRegistry
2671+ Optional function registry to use instead of the default global one.
26362672
26372673 Examples
26382674 --------
@@ -2662,14 +2698,106 @@ def register_scalar_function(func, function_name, function_doc, in_types,
26622698 21
26632699 ]
26642700 """
2701+ return _register_scalar_like_function(get_register_scalar_function(),
2702+ func, function_name, function_doc, in_types,
2703+ out_type, func_registry)
2704+
2705+
def register_tabular_function(func, function_name, function_doc, in_types, out_type,
                              func_registry=None):
    """
    Register a user-defined tabular function.

    A tabular function is one accepting a context argument of type
    ScalarUdfContext and returning a generator of struct arrays.
    The in_types argument must be empty and the out_type argument
    specifies a schema. Each struct array must have field types
    corresponding to the schema.

    Parameters
    ----------
    func : callable
        A callable implementing the user-defined function.
        The only argument is the context argument of type
        ScalarUdfContext. It must return a callable that
        returns on each invocation a StructArray matching
        the out_type, where an empty array indicates end.
    function_name : str
        Name of the function. This name must be globally unique.
    function_doc : dict
        A dictionary object with keys "summary" (str),
        and "description" (str).
    in_types : Dict[str, DataType]
        Must be an empty dictionary (reserved for future use).
    out_type : Union[Schema, DataType]
        Schema of the function's output, or a corresponding flat struct type.
    func_registry : FunctionRegistry
        Optional function registry to use instead of the default global one.
    """
    cdef:
        shared_ptr[CSchema] c_out_schema
        shared_ptr[CDataType] c_struct_type

    if isinstance(out_type, Schema):
        # A Schema is turned into a struct type built from the schema's
        # fields, so the shared registration helper always gets a DataType.
        c_out_schema = pyarrow_unwrap_schema(out_type)
        with nogil:
            c_struct_type = <shared_ptr[CDataType]> make_shared[CStructType](
                deref(c_out_schema).fields())
        out_type = pyarrow_wrap_data_type(c_struct_type)

    registrar = get_register_tabular_function()
    return _register_scalar_like_function(registrar, func, function_name,
                                          function_doc, in_types, out_type,
                                          func_registry)
2749+
2750+
2751+ def _register_scalar_like_function (register_func , func , function_name , function_doc , in_types ,
2752+ out_type , func_registry = None ):
2753+ """
2754+ Register a user-defined scalar-like function.
2755+
2756+ A scalar-like function is a callable accepting a first
2757+ context argument of type ScalarUdfContext as well as
2758+ possibly additional Arrow arguments, and returning
2759+ an Arrow result appropriate for the kind of function.
2760+ A scalar function and a tabular function are examples
2761+ for scalar-like functions.
2762+ This function is normally not called directly but via
2763+ register_scalar_function or register_tabular_function.
2764+
2765+ Parameters
2766+ ----------
2767+ register_func: object
2768+ An object holding a CRegisterUdf in a "register_func" attribute. Use
2769+ get_register_scalar_function() for a scalar function and
2770+ get_register_tabular_function() for a tabular function.
2771+ func : callable
2772+ A callable implementing the user-defined function.
2773+ See register_scalar_function and
2774+ register_tabular_function for details.
2775+
2776+ function_name : str
2777+ Name of the function. This name must be globally unique.
2778+ function_doc : dict
2779+ A dictionary object with keys "summary" (str),
2780+ and "description" (str).
2781+ in_types : Dict[str, DataType]
2782+ A dictionary mapping function argument names to
2783+ their respective DataType.
2784+ See register_scalar_function and
2785+ register_tabular_function for details.
2786+ out_type : DataType
2787+ Output type of the function.
2788+ func_registry : FunctionRegistry
2789+ Optional function registry to use instead of the default global one.
2790+ """
2791+ cdef:
2792+ CRegisterUdf c_register_func
26662793 c_string c_func_name
26672794 CArity c_arity
26682795 CFunctionDoc c_func_doc
26692796 vector[shared_ptr[CDataType]] c_in_types
26702797 PyObject* c_function
26712798 shared_ptr[CDataType] c_out_type
2672- CScalarUdfOptions c_options
2799+ CUdfOptions c_options
2800+ CFunctionRegistry* c_func_registry
26732801
26742802 if callable (func):
26752803 c_function = < PyObject* > func
@@ -2711,5 +2839,51 @@ def register_scalar_function(func, function_name, function_doc, in_types,
27112839 c_options.input_types = c_in_types
27122840 c_options.output_type = c_out_type
27132841
2714- check_status(RegisterScalarFunction(c_function,
2715- < function[CallbackUdf]> & _scalar_udf_callback, c_options))
2842+ if func_registry is None :
2843+ c_func_registry = NULL
2844+ else :
2845+ c_func_registry = (< FunctionRegistry> func_registry).registry
2846+
2847+ c_register_func = (< RegisterUdf> register_func).register_func
2848+
2849+ check_status(c_register_func(c_function,
2850+ < function[CallbackUdf]> & _udf_callback,
2851+ c_options, c_func_registry))
2852+
2853+
def call_tabular_function(function_name, args=None, func_registry=None):
    """
    Get a record batch iterator from a tabular function.

    Parameters
    ----------
    function_name : str
        Name of the function.
    args : iterable
        The arguments to pass to the function. Accepted types depend
        on the specific function. Currently, only an empty args is supported.
    func_registry : FunctionRegistry
        Optional function registry to use instead of the default global one.

    Returns
    -------
    The result of RecordBatchReader.from_batches, built from the schema of
    the reader returned by the function call and from that reader's batches.
    """
    cdef:
        c_string c_name
        vector[CDatum] c_call_args
        CFunctionRegistry* c_registry
        shared_ptr[CRecordBatchReader] c_batch_reader
        RecordBatchReader wrapped

    c_name = tobytes(function_name)

    # A NULL registry pointer is passed when no registry is supplied.
    c_registry = NULL
    if func_registry is not None:
        c_registry = (<FunctionRegistry> func_registry).registry

    # A missing argument list is treated as an empty one.
    _pack_compute_args([] if args is None else args, &c_call_args)

    with nogil:
        c_batch_reader = GetResultValue(CallTabularFunction(
            c_name, c_call_args, c_registry))

    # Wrap the C++ reader so it can be iterated from Python.
    wrapped = RecordBatchReader.__new__(RecordBatchReader)
    wrapped.reader = c_batch_reader
    schema = pyarrow_wrap_schema(deref(c_batch_reader).schema())
    return RecordBatchReader.from_batches(schema, wrapped)
0 commit comments