Skip to content

Commit 1c8853b

Browse files
authored
ARROW-18112: [Go] Remaining Scalar Arithmetic (apache#14777)
Authored-by: Matt Topol <zotthewizard@gmail.com> Signed-off-by: Matt Topol <zotthewizard@gmail.com>
1 parent 1d9f778 commit 1c8853b

20 files changed

Lines changed: 3181 additions & 160 deletions

dev/release/rat_exclude_files.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ go/arrow/compute/go.sum
141141
go/arrow/compute/datumkind_string.go
142142
go/arrow/compute/funckind_string.go
143143
go/arrow/compute/internal/kernels/compareoperator_string.go
144+
go/arrow/compute/internal/kernels/roundmode_string.go
144145
go/arrow/compute/internal/kernels/_lib/vendored/*
145146
go/*.tmpldata
146147
go/*.s

go/arrow/compute/arithmetic.go

Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,44 @@ import (
2525
"github.com/apache/arrow/go/v11/arrow"
2626
"github.com/apache/arrow/go/v11/arrow/compute/internal/exec"
2727
"github.com/apache/arrow/go/v11/arrow/compute/internal/kernels"
28+
"github.com/apache/arrow/go/v11/arrow/decimal128"
29+
"github.com/apache/arrow/go/v11/arrow/decimal256"
30+
"github.com/apache/arrow/go/v11/arrow/scalar"
31+
)
32+
33+
type (
34+
RoundOptions = kernels.RoundOptions
35+
RoundMode = kernels.RoundMode
36+
RoundToMultipleOptions = kernels.RoundToMultipleOptions
37+
)
38+
39+
const (
40+
// Round to nearest integer less than or equal in magnitude (aka "floor")
41+
RoundDown = kernels.RoundDown
42+
// Round to nearest integer greater than or equal in magnitude (aka "ceil")
43+
RoundUp = kernels.RoundUp
44+
// Get integral part without fractional digits (aka "trunc")
45+
RoundTowardsZero = kernels.TowardsZero
46+
// Round negative values with DOWN and positive values with UP
47+
RoundTowardsInfinity = kernels.AwayFromZero
48+
// Round ties with DOWN (aka "round half towards negative infinity")
49+
RoundHalfDown = kernels.HalfDown
50+
// Round ties with UP (aka "round half towards positive infinity")
51+
RoundHalfUp = kernels.HalfUp
52+
// Round ties with TowardsZero (aka "round half away from infinity")
53+
RoundHalfTowardsZero = kernels.HalfTowardsZero
54+
// Round ties with AwayFromZero (aka "round half towards infinity")
55+
RoundHalfTowardsInfinity = kernels.HalfAwayFromZero
56+
// Round ties to nearest even integer
57+
RoundHalfToEven = kernels.HalfToEven
58+
// Round ties to nearest odd integer
59+
RoundHalfToOdd = kernels.HalfToOdd
60+
)
61+
62+
var (
63+
DefaultRoundOptions = RoundOptions{NDigits: 0, Mode: RoundHalfToEven}
64+
DefaultRoundToMultipleOptions = RoundToMultipleOptions{
65+
Multiple: scalar.NewFloat64Scalar(1), Mode: RoundHalfToEven}
2866
)
2967

3068
type arithmeticFunction struct {
@@ -121,6 +159,7 @@ func (fn *arithmeticFloatingPointFunc) DispatchBest(vals ...arrow.DataType) (exe
121159
return fn.DispatchExact(vals...)
122160
}
123161

162+
// function that promotes only decimal arguments to float64
124163
type arithmeticDecimalToFloatingPointFunc struct {
125164
arithmeticFunction
126165
}
@@ -156,6 +195,46 @@ func (fn *arithmeticDecimalToFloatingPointFunc) DispatchBest(vals ...arrow.DataT
156195
return fn.DispatchExact(vals...)
157196
}
158197

198+
// function that promotes only integer arguments to float64
199+
type arithmeticIntegerToFloatingPointFunc struct {
200+
arithmeticFunction
201+
}
202+
203+
func (fn *arithmeticIntegerToFloatingPointFunc) Execute(ctx context.Context, opts FunctionOptions, args ...Datum) (Datum, error) {
204+
return execInternal(ctx, fn, opts, -1, args...)
205+
}
206+
207+
func (fn *arithmeticIntegerToFloatingPointFunc) DispatchBest(vals ...arrow.DataType) (exec.Kernel, error) {
208+
if err := fn.checkArity(len(vals)); err != nil {
209+
return nil, err
210+
}
211+
212+
if err := fn.checkDecimals(vals...); err != nil {
213+
return nil, err
214+
}
215+
216+
if kn, err := fn.DispatchExact(vals...); err == nil {
217+
return kn, nil
218+
}
219+
220+
ensureDictionaryDecoded(vals...)
221+
if len(vals) == 2 {
222+
replaceNullWithOtherType(vals...)
223+
}
224+
225+
for i, t := range vals {
226+
if arrow.IsInteger(t.ID()) {
227+
vals[i] = arrow.PrimitiveTypes.Float64
228+
}
229+
}
230+
231+
if dt := commonNumeric(vals...); dt != nil {
232+
replaceTypes(dt, vals...)
233+
}
234+
235+
return fn.DispatchExact(vals...)
236+
}
237+
159238
var (
160239
addDoc FunctionDoc
161240
)
@@ -382,6 +461,25 @@ func RegisterScalarArithmetic(reg FunctionRegistry) {
382461
}{
383462
{"sqrt_unchecked", kernels.OpSqrt, decPromoteNone},
384463
{"sqrt", kernels.OpSqrtChecked, decPromoteNone},
464+
{"sin_unchecked", kernels.OpSin, decPromoteNone},
465+
{"sin", kernels.OpSinChecked, decPromoteNone},
466+
{"cos_unchecked", kernels.OpCos, decPromoteNone},
467+
{"cos", kernels.OpCosChecked, decPromoteNone},
468+
{"tan_unchecked", kernels.OpTan, decPromoteNone},
469+
{"tan", kernels.OpTanChecked, decPromoteNone},
470+
{"asin_unchecked", kernels.OpAsin, decPromoteNone},
471+
{"asin", kernels.OpAsinChecked, decPromoteNone},
472+
{"acos_unchecked", kernels.OpAcos, decPromoteNone},
473+
{"acos", kernels.OpAcosChecked, decPromoteNone},
474+
{"atan", kernels.OpAtan, decPromoteNone},
475+
{"ln_unchecked", kernels.OpLn, decPromoteNone},
476+
{"ln", kernels.OpLnChecked, decPromoteNone},
477+
{"log10_unchecked", kernels.OpLog10, decPromoteNone},
478+
{"log10", kernels.OpLog10Checked, decPromoteNone},
479+
{"log2_unchecked", kernels.OpLog2, decPromoteNone},
480+
{"log2", kernels.OpLog2Checked, decPromoteNone},
481+
{"log1p_unchecked", kernels.OpLog1p, decPromoteNone},
482+
{"log1p", kernels.OpLog1pChecked, decPromoteNone},
385483
}
386484

387485
for _, o := range ops {
@@ -396,6 +494,28 @@ func RegisterScalarArithmetic(reg FunctionRegistry) {
396494
reg.AddFunction(fn, false)
397495
}
398496

497+
ops = []struct {
498+
funcName string
499+
op kernels.ArithmeticOp
500+
decPromote decimalPromotion
501+
}{
502+
{"atan2", kernels.OpAtan2, decPromoteNone},
503+
{"logb_unchecked", kernels.OpLogb, decPromoteNone},
504+
{"logb", kernels.OpLogbChecked, decPromoteNone},
505+
}
506+
507+
for _, o := range ops {
508+
fn := &arithmeticFloatingPointFunc{arithmeticFunction{*NewScalarFunction(o.funcName, Binary(), addDoc), decPromoteNone}}
509+
kns := kernels.GetArithmeticFloatingPointKernels(o.op)
510+
for _, k := range kns {
511+
if err := fn.AddKernel(k); err != nil {
512+
panic(err)
513+
}
514+
}
515+
516+
reg.AddFunction(fn, false)
517+
}
518+
399519
fn = &arithmeticFunction{*NewScalarFunction("sign", Unary(), addDoc), decPromoteNone}
400520
kns = kernels.GetArithmeticUnaryFixedIntOutKernels(arrow.PrimitiveTypes.Int8, kernels.OpSign)
401521
for _, k := range kns {
@@ -446,6 +566,15 @@ func RegisterScalarArithmetic(reg FunctionRegistry) {
446566
reg.AddFunction(fn, false)
447567
}
448568

569+
fn = &arithmeticFunction{*NewScalarFunction("bit_wise_not", Unary(), EmptyFuncDoc), decPromoteNone}
570+
for _, k := range kernels.GetBitwiseUnaryKernels() {
571+
if err := fn.AddKernel(k); err != nil {
572+
panic(err)
573+
}
574+
}
575+
576+
reg.AddFunction(fn, false)
577+
449578
shiftOps := []struct {
450579
funcName string
451580
dir kernels.ShiftDir
@@ -467,6 +596,67 @@ func RegisterScalarArithmetic(reg FunctionRegistry) {
467596
}
468597
reg.AddFunction(fn, false)
469598
}
599+
600+
floorFn := &arithmeticIntegerToFloatingPointFunc{arithmeticFunction{*NewScalarFunction("floor", Unary(), EmptyFuncDoc), decPromoteNone}}
601+
kns = kernels.GetSimpleRoundKernels(kernels.RoundDown)
602+
for _, k := range kns {
603+
if err := floorFn.AddKernel(k); err != nil {
604+
panic(err)
605+
}
606+
}
607+
floorFn.AddNewKernel([]exec.InputType{exec.NewIDInput(arrow.DECIMAL128)},
608+
kernels.OutputFirstType, kernels.FixedRoundDecimalExec[decimal128.Num](kernels.RoundDown), nil)
609+
floorFn.AddNewKernel([]exec.InputType{exec.NewIDInput(arrow.DECIMAL256)},
610+
kernels.OutputFirstType, kernels.FixedRoundDecimalExec[decimal256.Num](kernels.RoundDown), nil)
611+
reg.AddFunction(floorFn, false)
612+
613+
ceilFn := &arithmeticIntegerToFloatingPointFunc{arithmeticFunction{*NewScalarFunction("ceil", Unary(), EmptyFuncDoc), decPromoteNone}}
614+
kns = kernels.GetSimpleRoundKernels(kernels.RoundUp)
615+
for _, k := range kns {
616+
if err := ceilFn.AddKernel(k); err != nil {
617+
panic(err)
618+
}
619+
}
620+
ceilFn.AddNewKernel([]exec.InputType{exec.NewIDInput(arrow.DECIMAL128)},
621+
kernels.OutputFirstType, kernels.FixedRoundDecimalExec[decimal128.Num](kernels.RoundUp), nil)
622+
ceilFn.AddNewKernel([]exec.InputType{exec.NewIDInput(arrow.DECIMAL256)},
623+
kernels.OutputFirstType, kernels.FixedRoundDecimalExec[decimal256.Num](kernels.RoundUp), nil)
624+
reg.AddFunction(ceilFn, false)
625+
626+
truncFn := &arithmeticIntegerToFloatingPointFunc{arithmeticFunction{*NewScalarFunction("trunc", Unary(), EmptyFuncDoc), decPromoteNone}}
627+
kns = kernels.GetSimpleRoundKernels(kernels.TowardsZero)
628+
for _, k := range kns {
629+
if err := truncFn.AddKernel(k); err != nil {
630+
panic(err)
631+
}
632+
}
633+
truncFn.AddNewKernel([]exec.InputType{exec.NewIDInput(arrow.DECIMAL128)},
634+
kernels.OutputFirstType, kernels.FixedRoundDecimalExec[decimal128.Num](kernels.TowardsZero), nil)
635+
truncFn.AddNewKernel([]exec.InputType{exec.NewIDInput(arrow.DECIMAL256)},
636+
kernels.OutputFirstType, kernels.FixedRoundDecimalExec[decimal256.Num](kernels.TowardsZero), nil)
637+
reg.AddFunction(truncFn, false)
638+
639+
roundFn := &arithmeticIntegerToFloatingPointFunc{arithmeticFunction{*NewScalarFunction("round", Unary(), EmptyFuncDoc), decPromoteNone}}
640+
kns = kernels.GetRoundUnaryKernels(kernels.InitRoundState, kernels.UnaryRoundExec)
641+
for _, k := range kns {
642+
if err := roundFn.AddKernel(k); err != nil {
643+
panic(err)
644+
}
645+
}
646+
647+
roundFn.defaultOpts = DefaultRoundOptions
648+
reg.AddFunction(roundFn, false)
649+
650+
roundToMultipleFn := &arithmeticIntegerToFloatingPointFunc{arithmeticFunction{*NewScalarFunction("round_to_multiple", Unary(), EmptyFuncDoc), decPromoteNone}}
651+
kns = kernels.GetRoundUnaryKernels(kernels.InitRoundToMultipleState, kernels.UnaryRoundToMultipleExec)
652+
for _, k := range kns {
653+
if err := roundToMultipleFn.AddKernel(k); err != nil {
654+
panic(err)
655+
}
656+
}
657+
658+
roundToMultipleFn.defaultOpts = DefaultRoundToMultipleOptions
659+
reg.AddFunction(roundToMultipleFn, false)
470660
}
471661

472662
func impl(ctx context.Context, fn string, opts ArithmeticOptions, left, right Datum) (Datum, error) {
@@ -596,3 +786,99 @@ func ShiftRight(ctx context.Context, opts ArithmeticOptions, lhs, rhs Datum) (Da
596786
}
597787
return CallFunction(ctx, fn, nil, lhs, rhs)
598788
}
789+
790+
func Sin(ctx context.Context, opts ArithmeticOptions, arg Datum) (Datum, error) {
791+
fn := "sin"
792+
if opts.NoCheckOverflow {
793+
fn += "_unchecked"
794+
}
795+
return CallFunction(ctx, fn, nil, arg)
796+
}
797+
798+
func Cos(ctx context.Context, opts ArithmeticOptions, arg Datum) (Datum, error) {
799+
fn := "cos"
800+
if opts.NoCheckOverflow {
801+
fn += "_unchecked"
802+
}
803+
return CallFunction(ctx, fn, nil, arg)
804+
}
805+
806+
func Tan(ctx context.Context, opts ArithmeticOptions, arg Datum) (Datum, error) {
807+
fn := "tan"
808+
if opts.NoCheckOverflow {
809+
fn += "_unchecked"
810+
}
811+
return CallFunction(ctx, fn, nil, arg)
812+
}
813+
814+
func Asin(ctx context.Context, opts ArithmeticOptions, arg Datum) (Datum, error) {
815+
fn := "asin"
816+
if opts.NoCheckOverflow {
817+
fn += "_unchecked"
818+
}
819+
return CallFunction(ctx, fn, nil, arg)
820+
}
821+
822+
func Acos(ctx context.Context, opts ArithmeticOptions, arg Datum) (Datum, error) {
823+
fn := "acos"
824+
if opts.NoCheckOverflow {
825+
fn += "_unchecked"
826+
}
827+
return CallFunction(ctx, fn, nil, arg)
828+
}
829+
830+
func Atan(ctx context.Context, arg Datum) (Datum, error) {
831+
return CallFunction(ctx, "atan", nil, arg)
832+
}
833+
834+
func Atan2(ctx context.Context, x, y Datum) (Datum, error) {
835+
return CallFunction(ctx, "atan2", nil, x, y)
836+
}
837+
838+
func Ln(ctx context.Context, opts ArithmeticOptions, arg Datum) (Datum, error) {
839+
fn := "ln"
840+
if opts.NoCheckOverflow {
841+
fn += "_unchecked"
842+
}
843+
return CallFunction(ctx, fn, nil, arg)
844+
}
845+
846+
func Log10(ctx context.Context, opts ArithmeticOptions, arg Datum) (Datum, error) {
847+
fn := "log10"
848+
if opts.NoCheckOverflow {
849+
fn += "_unchecked"
850+
}
851+
return CallFunction(ctx, fn, nil, arg)
852+
}
853+
854+
func Log2(ctx context.Context, opts ArithmeticOptions, arg Datum) (Datum, error) {
855+
fn := "log2"
856+
if opts.NoCheckOverflow {
857+
fn += "_unchecked"
858+
}
859+
return CallFunction(ctx, fn, nil, arg)
860+
}
861+
862+
func Log1p(ctx context.Context, opts ArithmeticOptions, arg Datum) (Datum, error) {
863+
fn := "log1p"
864+
if opts.NoCheckOverflow {
865+
fn += "_unchecked"
866+
}
867+
return CallFunction(ctx, fn, nil, arg)
868+
}
869+
870+
func Logb(ctx context.Context, opts ArithmeticOptions, x, base Datum) (Datum, error) {
871+
fn := "logb"
872+
if opts.NoCheckOverflow {
873+
fn += "_unchecked"
874+
}
875+
return CallFunction(ctx, fn, nil, x, base)
876+
}
877+
878+
func Round(ctx context.Context, opts RoundOptions, arg Datum) (Datum, error) {
879+
return CallFunction(ctx, "round", &opts, arg)
880+
}
881+
882+
func RoundToMultiple(ctx context.Context, opts RoundToMultipleOptions, arg Datum) (Datum, error) {
883+
return CallFunction(ctx, "round_to_multiple", &opts, arg)
884+
}

0 commit comments

Comments
 (0)