Skip to content

Commit 1121bbc

Browse files
authored
ARROW-18111: [Go] Remaining scalar binary arithmetic (shifts, power, bitwise) (apache#14703)
Authored-by: Matt Topol <zotthewizard@gmail.com> Signed-off-by: Matt Topol <zotthewizard@gmail.com>
1 parent ad54d6c commit 1121bbc

11 files changed

Lines changed: 778 additions & 117 deletions

go/arrow/compute/arithmetic.go

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,41 @@ func (fn *arithmeticFloatingPointFunc) DispatchBest(vals ...arrow.DataType) (exe
121121
return fn.DispatchExact(vals...)
122122
}
123123

124+
type arithmeticDecimalToFloatingPointFunc struct {
125+
arithmeticFunction
126+
}
127+
128+
func (fn *arithmeticDecimalToFloatingPointFunc) Execute(ctx context.Context, opts FunctionOptions, args ...Datum) (Datum, error) {
129+
return execInternal(ctx, fn, opts, -1, args...)
130+
}
131+
132+
func (fn *arithmeticDecimalToFloatingPointFunc) DispatchBest(vals ...arrow.DataType) (exec.Kernel, error) {
133+
if err := fn.checkArity(len(vals)); err != nil {
134+
return nil, err
135+
}
136+
137+
if kn, err := fn.DispatchExact(vals...); err == nil {
138+
return kn, nil
139+
}
140+
141+
ensureDictionaryDecoded(vals...)
142+
if len(vals) == 2 {
143+
replaceNullWithOtherType(vals...)
144+
}
145+
146+
for i, t := range vals {
147+
if arrow.IsDecimal(t.ID()) {
148+
vals[i] = arrow.PrimitiveTypes.Float64
149+
}
150+
}
151+
152+
if dt := commonNumeric(vals...); dt != nil {
153+
replaceTypes(dt, vals...)
154+
}
155+
156+
return fn.DispatchExact(vals...)
157+
}
158+
124159
var (
125160
addDoc FunctionDoc
126161
)
@@ -370,6 +405,68 @@ func RegisterScalarArithmetic(reg FunctionRegistry) {
370405
}
371406

372407
reg.AddFunction(fn, false)
408+
409+
ops = []struct {
410+
funcName string
411+
op kernels.ArithmeticOp
412+
decPromote decimalPromotion
413+
}{
414+
{"power_unchecked", kernels.OpPower, decPromoteNone},
415+
{"power", kernels.OpPowerChecked, decPromoteNone},
416+
}
417+
418+
for _, o := range ops {
419+
fn := &arithmeticDecimalToFloatingPointFunc{arithmeticFunction{*NewScalarFunction(o.funcName, Binary(), EmptyFuncDoc), o.decPromote}}
420+
kns := kernels.GetArithmeticBinaryKernels(o.op)
421+
for _, k := range kns {
422+
if err := fn.AddKernel(k); err != nil {
423+
panic(err)
424+
}
425+
}
426+
reg.AddFunction(fn, false)
427+
}
428+
429+
bitWiseOps := []struct {
430+
funcName string
431+
op kernels.BitwiseOp
432+
}{
433+
{"bit_wise_and", kernels.OpBitAnd},
434+
{"bit_wise_or", kernels.OpBitOr},
435+
{"bit_wise_xor", kernels.OpBitXor},
436+
}
437+
438+
for _, o := range bitWiseOps {
439+
fn := &arithmeticFunction{*NewScalarFunction(o.funcName, Binary(), EmptyFuncDoc), decPromoteNone}
440+
kns := kernels.GetBitwiseBinaryKernels(o.op)
441+
for _, k := range kns {
442+
if err := fn.AddKernel(k); err != nil {
443+
panic(err)
444+
}
445+
}
446+
reg.AddFunction(fn, false)
447+
}
448+
449+
shiftOps := []struct {
450+
funcName string
451+
dir kernels.ShiftDir
452+
checked bool
453+
}{
454+
{"shift_left", kernels.ShiftLeft, true},
455+
{"shift_left_unchecked", kernels.ShiftLeft, false},
456+
{"shift_right", kernels.ShiftRight, true},
457+
{"shift_right_unchecked", kernels.ShiftRight, false},
458+
}
459+
460+
for _, o := range shiftOps {
461+
fn := &arithmeticFunction{*NewScalarFunction(o.funcName, Binary(), EmptyFuncDoc), decPromoteNone}
462+
kns := kernels.GetShiftKernels(o.dir, o.checked)
463+
for _, k := range kns {
464+
if err := fn.AddKernel(k); err != nil {
465+
panic(err)
466+
}
467+
}
468+
reg.AddFunction(fn, false)
469+
}
373470
}
374471

375472
func impl(ctx context.Context, fn string, opts ArithmeticOptions, left, right Datum) (Datum, error) {
@@ -463,3 +560,39 @@ func Negate(ctx context.Context, opts ArithmeticOptions, input Datum) (Datum, er
463560
func Sign(ctx context.Context, input Datum) (Datum, error) {
464561
return CallFunction(ctx, "sign", nil, input)
465562
}
563+
564+
// Power returns base**exp for each element in the input arrays. Should work
565+
// for both Arrays and Scalars
566+
func Power(ctx context.Context, opts ArithmeticOptions, base, exp Datum) (Datum, error) {
567+
fn := "power"
568+
if opts.NoCheckOverflow {
569+
fn += "_unchecked"
570+
}
571+
return CallFunction(ctx, fn, nil, base, exp)
572+
}
573+
574+
// ShiftLeft only accepts integral types and shifts each element of the
575+
// first argument to the left by the value of the corresponding element
576+
// in the second argument.
577+
//
578+
// The value to shift by should be >= 0 and < precision of the type.
579+
func ShiftLeft(ctx context.Context, opts ArithmeticOptions, lhs, rhs Datum) (Datum, error) {
580+
fn := "shift_left"
581+
if opts.NoCheckOverflow {
582+
fn += "_unchecked"
583+
}
584+
return CallFunction(ctx, fn, nil, lhs, rhs)
585+
}
586+
587+
// ShiftRight only accepts integral types and shifts each element of the
588+
// first argument to the right by the value of the corresponding element
589+
// in the second argument.
590+
//
591+
// The value to shift by should be >= 0 and < precision of the type.
592+
func ShiftRight(ctx context.Context, opts ArithmeticOptions, lhs, rhs Datum) (Datum, error) {
593+
fn := "shift_right"
594+
if opts.NoCheckOverflow {
595+
fn += "_unchecked"
596+
}
597+
return CallFunction(ctx, fn, nil, lhs, rhs)
598+
}

0 commit comments

Comments
 (0)