@@ -25,6 +25,44 @@ import (
2525 "github.com/apache/arrow/go/v11/arrow"
2626 "github.com/apache/arrow/go/v11/arrow/compute/internal/exec"
2727 "github.com/apache/arrow/go/v11/arrow/compute/internal/kernels"
28+ "github.com/apache/arrow/go/v11/arrow/decimal128"
29+ "github.com/apache/arrow/go/v11/arrow/decimal256"
30+ "github.com/apache/arrow/go/v11/arrow/scalar"
31+ )
32+
33+ type (
34+ RoundOptions = kernels.RoundOptions
35+ RoundMode = kernels.RoundMode
36+ RoundToMultipleOptions = kernels.RoundToMultipleOptions
37+ )
38+
39+ const (
40+ // Round to nearest integer less than or equal in magnitude (aka "floor")
41+ RoundDown = kernels .RoundDown
42+ // Round to nearest integer greater than or equal in magnitude (aka "ceil")
43+ RoundUp = kernels .RoundUp
44+ // Get integral part without fractional digits (aka "trunc")
45+ RoundTowardsZero = kernels .TowardsZero
46+ // Round negative values with DOWN and positive values with UP
47+ RoundTowardsInfinity = kernels .AwayFromZero
48+ // Round ties with DOWN (aka "round half towards negative infinity")
49+ RoundHalfDown = kernels .HalfDown
50+ // Round ties with UP (aka "round half towards positive infinity")
51+ RoundHalfUp = kernels .HalfUp
52+ // Round ties with TowardsZero (aka "round half away from infinity")
53+ RoundHalfTowardsZero = kernels .HalfTowardsZero
54+ // Round ties with AwayFromZero (aka "round half towards infinity")
55+ RoundHalfTowardsInfinity = kernels .HalfAwayFromZero
56+ // Round ties to nearest even integer
57+ RoundHalfToEven = kernels .HalfToEven
58+ // Round ties to nearest odd integer
59+ RoundHalfToOdd = kernels .HalfToOdd
60+ )
61+
62+ var (
63+ DefaultRoundOptions = RoundOptions {NDigits : 0 , Mode : RoundHalfToEven }
64+ DefaultRoundToMultipleOptions = RoundToMultipleOptions {
65+ Multiple : scalar .NewFloat64Scalar (1 ), Mode : RoundHalfToEven }
2866)
2967
3068type arithmeticFunction struct {
@@ -121,6 +159,7 @@ func (fn *arithmeticFloatingPointFunc) DispatchBest(vals ...arrow.DataType) (exe
121159 return fn .DispatchExact (vals ... )
122160}
123161
162+ // function that promotes only decimal arguments to float64
124163type arithmeticDecimalToFloatingPointFunc struct {
125164 arithmeticFunction
126165}
@@ -156,6 +195,46 @@ func (fn *arithmeticDecimalToFloatingPointFunc) DispatchBest(vals ...arrow.DataT
156195 return fn .DispatchExact (vals ... )
157196}
158197
198+ // function that promotes only integer arguments to float64
199+ type arithmeticIntegerToFloatingPointFunc struct {
200+ arithmeticFunction
201+ }
202+
203+ func (fn * arithmeticIntegerToFloatingPointFunc ) Execute (ctx context.Context , opts FunctionOptions , args ... Datum ) (Datum , error ) {
204+ return execInternal (ctx , fn , opts , - 1 , args ... )
205+ }
206+
207+ func (fn * arithmeticIntegerToFloatingPointFunc ) DispatchBest (vals ... arrow.DataType ) (exec.Kernel , error ) {
208+ if err := fn .checkArity (len (vals )); err != nil {
209+ return nil , err
210+ }
211+
212+ if err := fn .checkDecimals (vals ... ); err != nil {
213+ return nil , err
214+ }
215+
216+ if kn , err := fn .DispatchExact (vals ... ); err == nil {
217+ return kn , nil
218+ }
219+
220+ ensureDictionaryDecoded (vals ... )
221+ if len (vals ) == 2 {
222+ replaceNullWithOtherType (vals ... )
223+ }
224+
225+ for i , t := range vals {
226+ if arrow .IsInteger (t .ID ()) {
227+ vals [i ] = arrow .PrimitiveTypes .Float64
228+ }
229+ }
230+
231+ if dt := commonNumeric (vals ... ); dt != nil {
232+ replaceTypes (dt , vals ... )
233+ }
234+
235+ return fn .DispatchExact (vals ... )
236+ }
237+
159238var (
160239 addDoc FunctionDoc
161240)
@@ -382,6 +461,25 @@ func RegisterScalarArithmetic(reg FunctionRegistry) {
382461 }{
383462 {"sqrt_unchecked" , kernels .OpSqrt , decPromoteNone },
384463 {"sqrt" , kernels .OpSqrtChecked , decPromoteNone },
464+ {"sin_unchecked" , kernels .OpSin , decPromoteNone },
465+ {"sin" , kernels .OpSinChecked , decPromoteNone },
466+ {"cos_unchecked" , kernels .OpCos , decPromoteNone },
467+ {"cos" , kernels .OpCosChecked , decPromoteNone },
468+ {"tan_unchecked" , kernels .OpTan , decPromoteNone },
469+ {"tan" , kernels .OpTanChecked , decPromoteNone },
470+ {"asin_unchecked" , kernels .OpAsin , decPromoteNone },
471+ {"asin" , kernels .OpAsinChecked , decPromoteNone },
472+ {"acos_unchecked" , kernels .OpAcos , decPromoteNone },
473+ {"acos" , kernels .OpAcosChecked , decPromoteNone },
474+ {"atan" , kernels .OpAtan , decPromoteNone },
475+ {"ln_unchecked" , kernels .OpLn , decPromoteNone },
476+ {"ln" , kernels .OpLnChecked , decPromoteNone },
477+ {"log10_unchecked" , kernels .OpLog10 , decPromoteNone },
478+ {"log10" , kernels .OpLog10Checked , decPromoteNone },
479+ {"log2_unchecked" , kernels .OpLog2 , decPromoteNone },
480+ {"log2" , kernels .OpLog2Checked , decPromoteNone },
481+ {"log1p_unchecked" , kernels .OpLog1p , decPromoteNone },
482+ {"log1p" , kernels .OpLog1pChecked , decPromoteNone },
385483 }
386484
387485 for _ , o := range ops {
@@ -396,6 +494,28 @@ func RegisterScalarArithmetic(reg FunctionRegistry) {
396494 reg .AddFunction (fn , false )
397495 }
398496
497+ ops = []struct {
498+ funcName string
499+ op kernels.ArithmeticOp
500+ decPromote decimalPromotion
501+ }{
502+ {"atan2" , kernels .OpAtan2 , decPromoteNone },
503+ {"logb_unchecked" , kernels .OpLogb , decPromoteNone },
504+ {"logb" , kernels .OpLogbChecked , decPromoteNone },
505+ }
506+
507+ for _ , o := range ops {
508+ fn := & arithmeticFloatingPointFunc {arithmeticFunction {* NewScalarFunction (o .funcName , Binary (), addDoc ), decPromoteNone }}
509+ kns := kernels .GetArithmeticFloatingPointKernels (o .op )
510+ for _ , k := range kns {
511+ if err := fn .AddKernel (k ); err != nil {
512+ panic (err )
513+ }
514+ }
515+
516+ reg .AddFunction (fn , false )
517+ }
518+
399519 fn = & arithmeticFunction {* NewScalarFunction ("sign" , Unary (), addDoc ), decPromoteNone }
400520 kns = kernels .GetArithmeticUnaryFixedIntOutKernels (arrow .PrimitiveTypes .Int8 , kernels .OpSign )
401521 for _ , k := range kns {
@@ -446,6 +566,15 @@ func RegisterScalarArithmetic(reg FunctionRegistry) {
446566 reg .AddFunction (fn , false )
447567 }
448568
569+ fn = & arithmeticFunction {* NewScalarFunction ("bit_wise_not" , Unary (), EmptyFuncDoc ), decPromoteNone }
570+ for _ , k := range kernels .GetBitwiseUnaryKernels () {
571+ if err := fn .AddKernel (k ); err != nil {
572+ panic (err )
573+ }
574+ }
575+
576+ reg .AddFunction (fn , false )
577+
449578 shiftOps := []struct {
450579 funcName string
451580 dir kernels.ShiftDir
@@ -467,6 +596,67 @@ func RegisterScalarArithmetic(reg FunctionRegistry) {
467596 }
468597 reg .AddFunction (fn , false )
469598 }
599+
600+ floorFn := & arithmeticIntegerToFloatingPointFunc {arithmeticFunction {* NewScalarFunction ("floor" , Unary (), EmptyFuncDoc ), decPromoteNone }}
601+ kns = kernels .GetSimpleRoundKernels (kernels .RoundDown )
602+ for _ , k := range kns {
603+ if err := floorFn .AddKernel (k ); err != nil {
604+ panic (err )
605+ }
606+ }
607+ floorFn .AddNewKernel ([]exec.InputType {exec .NewIDInput (arrow .DECIMAL128 )},
608+ kernels .OutputFirstType , kernels.FixedRoundDecimalExec [decimal128.Num ](kernels .RoundDown ), nil )
609+ floorFn .AddNewKernel ([]exec.InputType {exec .NewIDInput (arrow .DECIMAL256 )},
610+ kernels .OutputFirstType , kernels.FixedRoundDecimalExec [decimal256.Num ](kernels .RoundDown ), nil )
611+ reg .AddFunction (floorFn , false )
612+
613+ ceilFn := & arithmeticIntegerToFloatingPointFunc {arithmeticFunction {* NewScalarFunction ("ceil" , Unary (), EmptyFuncDoc ), decPromoteNone }}
614+ kns = kernels .GetSimpleRoundKernels (kernels .RoundUp )
615+ for _ , k := range kns {
616+ if err := ceilFn .AddKernel (k ); err != nil {
617+ panic (err )
618+ }
619+ }
620+ ceilFn .AddNewKernel ([]exec.InputType {exec .NewIDInput (arrow .DECIMAL128 )},
621+ kernels .OutputFirstType , kernels.FixedRoundDecimalExec [decimal128.Num ](kernels .RoundUp ), nil )
622+ ceilFn .AddNewKernel ([]exec.InputType {exec .NewIDInput (arrow .DECIMAL256 )},
623+ kernels .OutputFirstType , kernels.FixedRoundDecimalExec [decimal256.Num ](kernels .RoundUp ), nil )
624+ reg .AddFunction (ceilFn , false )
625+
626+ truncFn := & arithmeticIntegerToFloatingPointFunc {arithmeticFunction {* NewScalarFunction ("trunc" , Unary (), EmptyFuncDoc ), decPromoteNone }}
627+ kns = kernels .GetSimpleRoundKernels (kernels .TowardsZero )
628+ for _ , k := range kns {
629+ if err := truncFn .AddKernel (k ); err != nil {
630+ panic (err )
631+ }
632+ }
633+ truncFn .AddNewKernel ([]exec.InputType {exec .NewIDInput (arrow .DECIMAL128 )},
634+ kernels .OutputFirstType , kernels.FixedRoundDecimalExec [decimal128.Num ](kernels .TowardsZero ), nil )
635+ truncFn .AddNewKernel ([]exec.InputType {exec .NewIDInput (arrow .DECIMAL256 )},
636+ kernels .OutputFirstType , kernels.FixedRoundDecimalExec [decimal256.Num ](kernels .TowardsZero ), nil )
637+ reg .AddFunction (truncFn , false )
638+
639+ roundFn := & arithmeticIntegerToFloatingPointFunc {arithmeticFunction {* NewScalarFunction ("round" , Unary (), EmptyFuncDoc ), decPromoteNone }}
640+ kns = kernels .GetRoundUnaryKernels (kernels .InitRoundState , kernels .UnaryRoundExec )
641+ for _ , k := range kns {
642+ if err := roundFn .AddKernel (k ); err != nil {
643+ panic (err )
644+ }
645+ }
646+
647+ roundFn .defaultOpts = DefaultRoundOptions
648+ reg .AddFunction (roundFn , false )
649+
650+ roundToMultipleFn := & arithmeticIntegerToFloatingPointFunc {arithmeticFunction {* NewScalarFunction ("round_to_multiple" , Unary (), EmptyFuncDoc ), decPromoteNone }}
651+ kns = kernels .GetRoundUnaryKernels (kernels .InitRoundToMultipleState , kernels .UnaryRoundToMultipleExec )
652+ for _ , k := range kns {
653+ if err := roundToMultipleFn .AddKernel (k ); err != nil {
654+ panic (err )
655+ }
656+ }
657+
658+ roundToMultipleFn .defaultOpts = DefaultRoundToMultipleOptions
659+ reg .AddFunction (roundToMultipleFn , false )
470660}
471661
472662func impl (ctx context.Context , fn string , opts ArithmeticOptions , left , right Datum ) (Datum , error ) {
@@ -596,3 +786,99 @@ func ShiftRight(ctx context.Context, opts ArithmeticOptions, lhs, rhs Datum) (Da
596786 }
597787 return CallFunction (ctx , fn , nil , lhs , rhs )
598788}
789+
790+ func Sin (ctx context.Context , opts ArithmeticOptions , arg Datum ) (Datum , error ) {
791+ fn := "sin"
792+ if opts .NoCheckOverflow {
793+ fn += "_unchecked"
794+ }
795+ return CallFunction (ctx , fn , nil , arg )
796+ }
797+
798+ func Cos (ctx context.Context , opts ArithmeticOptions , arg Datum ) (Datum , error ) {
799+ fn := "cos"
800+ if opts .NoCheckOverflow {
801+ fn += "_unchecked"
802+ }
803+ return CallFunction (ctx , fn , nil , arg )
804+ }
805+
806+ func Tan (ctx context.Context , opts ArithmeticOptions , arg Datum ) (Datum , error ) {
807+ fn := "tan"
808+ if opts .NoCheckOverflow {
809+ fn += "_unchecked"
810+ }
811+ return CallFunction (ctx , fn , nil , arg )
812+ }
813+
814+ func Asin (ctx context.Context , opts ArithmeticOptions , arg Datum ) (Datum , error ) {
815+ fn := "asin"
816+ if opts .NoCheckOverflow {
817+ fn += "_unchecked"
818+ }
819+ return CallFunction (ctx , fn , nil , arg )
820+ }
821+
822+ func Acos (ctx context.Context , opts ArithmeticOptions , arg Datum ) (Datum , error ) {
823+ fn := "acos"
824+ if opts .NoCheckOverflow {
825+ fn += "_unchecked"
826+ }
827+ return CallFunction (ctx , fn , nil , arg )
828+ }
829+
830+ func Atan (ctx context.Context , arg Datum ) (Datum , error ) {
831+ return CallFunction (ctx , "atan" , nil , arg )
832+ }
833+
834+ func Atan2 (ctx context.Context , x , y Datum ) (Datum , error ) {
835+ return CallFunction (ctx , "atan2" , nil , x , y )
836+ }
837+
838+ func Ln (ctx context.Context , opts ArithmeticOptions , arg Datum ) (Datum , error ) {
839+ fn := "ln"
840+ if opts .NoCheckOverflow {
841+ fn += "_unchecked"
842+ }
843+ return CallFunction (ctx , fn , nil , arg )
844+ }
845+
846+ func Log10 (ctx context.Context , opts ArithmeticOptions , arg Datum ) (Datum , error ) {
847+ fn := "log10"
848+ if opts .NoCheckOverflow {
849+ fn += "_unchecked"
850+ }
851+ return CallFunction (ctx , fn , nil , arg )
852+ }
853+
854+ func Log2 (ctx context.Context , opts ArithmeticOptions , arg Datum ) (Datum , error ) {
855+ fn := "log2"
856+ if opts .NoCheckOverflow {
857+ fn += "_unchecked"
858+ }
859+ return CallFunction (ctx , fn , nil , arg )
860+ }
861+
862+ func Log1p (ctx context.Context , opts ArithmeticOptions , arg Datum ) (Datum , error ) {
863+ fn := "log1p"
864+ if opts .NoCheckOverflow {
865+ fn += "_unchecked"
866+ }
867+ return CallFunction (ctx , fn , nil , arg )
868+ }
869+
870+ func Logb (ctx context.Context , opts ArithmeticOptions , x , base Datum ) (Datum , error ) {
871+ fn := "logb"
872+ if opts .NoCheckOverflow {
873+ fn += "_unchecked"
874+ }
875+ return CallFunction (ctx , fn , nil , x , base )
876+ }
877+
878+ func Round (ctx context.Context , opts RoundOptions , arg Datum ) (Datum , error ) {
879+ return CallFunction (ctx , "round" , & opts , arg )
880+ }
881+
882+ func RoundToMultiple (ctx context.Context , opts RoundToMultipleOptions , arg Datum ) (Datum , error ) {
883+ return CallFunction (ctx , "round_to_multiple" , & opts , arg )
884+ }
0 commit comments