Skip to content

Commit 359eab5

Browse files
authored
ARROW-17532: [Go][Compute] Implement Numeric Cast functions (apache#13992)
Authored-by: Matt Topol <zotthewizard@gmail.com> Signed-off-by: Matt Topol <zotthewizard@gmail.com>
1 parent 74dae61 commit 359eab5

25 files changed

Lines changed: 43387 additions & 40 deletions

go/arrow/array/data.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ func (d *Data) Release() {
159159
// DataType returns the DataType of the data.
160160
func (d *Data) DataType() arrow.DataType { return d.dtype }
161161

162+
func (d *Data) SetNullN(n int) { d.nulls = n }
163+
162164
// NullN returns the number of nulls.
163165
func (d *Data) NullN() int { return d.nulls }
164166

go/arrow/array/decimal128.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ func (b *Decimal128Builder) unmarshalOne(dec *json.Decoder) error {
266266
}
267267

268268
var out *big.Float
269+
var tmp big.Int
269270

270271
switch v := t.(type) {
271272
case float64:
@@ -275,12 +276,12 @@ func (b *Decimal128Builder) unmarshalOne(dec *json.Decoder) error {
275276
// what got me the closest equivalent values with the values
276277
// that I tested with, and there isn't a good way to push
277278
// an option all the way down here to control it.
278-
out, _, err = big.ParseFloat(v, 10, 128, big.ToNearestAway)
279+
out, _, err = big.ParseFloat(v, 10, 127, big.ToNearestAway)
279280
if err != nil {
280281
return err
281282
}
282283
case json.Number:
283-
out, _, err = big.ParseFloat(v.String(), 10, 128, big.ToNearestAway)
284+
out, _, err = big.ParseFloat(v.String(), 10, 127, big.ToNearestAway)
284285
if err != nil {
285286
return err
286287
}
@@ -295,7 +296,7 @@ func (b *Decimal128Builder) unmarshalOne(dec *json.Decoder) error {
295296
}
296297
}
297298

298-
val, _ := out.Mul(out, big.NewFloat(math.Pow10(int(b.dtype.Scale)))).Int(nil)
299+
val, _ := out.Mul(out, big.NewFloat(math.Pow10(int(b.dtype.Scale)))).Int(&tmp)
299300
b.Append(decimal128.FromBigInt(val))
300301
return nil
301302
}

go/arrow/array/decimal256.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,12 +275,12 @@ func (b *Decimal256Builder) unmarshalOne(dec *json.Decoder) error {
275275
// what got me the closest equivalent values with the values
276276
// that I tested with, and there isn't a good way to push
277277
// an option all the way down here to control it.
278-
out, _, err = big.ParseFloat(v, 10, 256, big.ToNearestAway)
278+
out, _, err = big.ParseFloat(v, 10, 255, big.ToNearestAway)
279279
if err != nil {
280280
return err
281281
}
282282
case json.Number:
283-
out, _, err = big.ParseFloat(v.String(), 10, 256, big.ToNearestAway)
283+
out, _, err = big.ParseFloat(v.String(), 10, 255, big.ToNearestAway)
284284
if err != nil {
285285
return err
286286
}

go/arrow/compute/cast.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ func addCastFuncs(fn []*castFunction) {
141141
func initCastTable() {
142142
castTable = make(map[arrow.Type]*castFunction)
143143
addCastFuncs(getBooleanCasts())
144+
addCastFuncs(getNumericCasts())
144145
}
145146

146147
func getCastFunction(to arrow.DataType) (*castFunction, error) {
@@ -167,6 +168,67 @@ func getBooleanCasts() []*castFunction {
167168
return []*castFunction{fn}
168169
}
169170

171+
func getNumericCasts() []*castFunction {
172+
out := make([]*castFunction, 0)
173+
174+
getFn := func(name string, ty arrow.Type, kns []exec.ScalarKernel) *castFunction {
175+
fn := newCastFunction(name, ty)
176+
for _, k := range kns {
177+
if err := fn.AddTypeCast(k.Signature.InputTypes[0].MatchID(), k); err != nil {
178+
panic(err)
179+
}
180+
}
181+
return fn
182+
}
183+
184+
out = append(out, getFn("cast_int8", arrow.INT8, kernels.GetCastToInteger[int8](arrow.PrimitiveTypes.Int8)))
185+
out = append(out, getFn("cast_int16", arrow.INT16, kernels.GetCastToInteger[int8](arrow.PrimitiveTypes.Int16)))
186+
187+
castInt32 := getFn("cast_int32", arrow.INT32, kernels.GetCastToInteger[int32](arrow.PrimitiveTypes.Int32))
188+
castInt32.AddTypeCast(arrow.DATE32,
189+
kernels.GetZeroCastKernel(arrow.DATE32,
190+
exec.NewExactInput(arrow.FixedWidthTypes.Date32),
191+
exec.NewOutputType(arrow.PrimitiveTypes.Int32)))
192+
castInt32.AddTypeCast(arrow.TIME32,
193+
kernels.GetZeroCastKernel(arrow.TIME32,
194+
exec.NewIDInput(arrow.TIME32), exec.NewOutputType(arrow.PrimitiveTypes.Int32)))
195+
out = append(out, castInt32)
196+
197+
castInt64 := getFn("cast_int64", arrow.INT64, kernels.GetCastToInteger[int64](arrow.PrimitiveTypes.Int64))
198+
castInt64.AddTypeCast(arrow.DATE64,
199+
kernels.GetZeroCastKernel(arrow.DATE64,
200+
exec.NewIDInput(arrow.DATE64),
201+
exec.NewOutputType(arrow.PrimitiveTypes.Int64)))
202+
castInt64.AddTypeCast(arrow.TIME64,
203+
kernels.GetZeroCastKernel(arrow.TIME64,
204+
exec.NewIDInput(arrow.TIME64),
205+
exec.NewOutputType(arrow.PrimitiveTypes.Int64)))
206+
castInt64.AddTypeCast(arrow.DURATION,
207+
kernels.GetZeroCastKernel(arrow.DURATION,
208+
exec.NewIDInput(arrow.DURATION),
209+
exec.NewOutputType(arrow.PrimitiveTypes.Int64)))
210+
castInt64.AddTypeCast(arrow.TIMESTAMP,
211+
kernels.GetZeroCastKernel(arrow.TIMESTAMP,
212+
exec.NewIDInput(arrow.TIMESTAMP),
213+
exec.NewOutputType(arrow.PrimitiveTypes.Int64)))
214+
out = append(out, castInt64)
215+
216+
out = append(out, getFn("cast_uint8", arrow.UINT8, kernels.GetCastToInteger[uint8](arrow.PrimitiveTypes.Uint8)))
217+
out = append(out, getFn("cast_uint16", arrow.UINT16, kernels.GetCastToInteger[uint16](arrow.PrimitiveTypes.Uint16)))
218+
out = append(out, getFn("cast_uint32", arrow.UINT32, kernels.GetCastToInteger[uint32](arrow.PrimitiveTypes.Uint32)))
219+
out = append(out, getFn("cast_uint64", arrow.UINT64, kernels.GetCastToInteger[uint64](arrow.PrimitiveTypes.Uint64)))
220+
221+
out = append(out, getFn("cast_half_float", arrow.FLOAT16, kernels.GetCommonCastKernels(arrow.FLOAT16, exec.NewOutputType(arrow.FixedWidthTypes.Float16))))
222+
out = append(out, getFn("cast_float", arrow.FLOAT32, kernels.GetCastToFloating[float32](arrow.PrimitiveTypes.Float32)))
223+
out = append(out, getFn("cast_double", arrow.FLOAT64, kernels.GetCastToFloating[float64](arrow.PrimitiveTypes.Float64)))
224+
225+
// cast to decimal128
226+
out = append(out, getFn("cast_decimal", arrow.DECIMAL128, kernels.GetCastToDecimal128()))
227+
// cast to decimal256
228+
out = append(out, getFn("cast_decimal256", arrow.DECIMAL256, kernels.GetCastToDecimal256()))
229+
return out
230+
}
231+
170232
// CastDatum is a convenience function for casting a Datum to another type.
171233
// It is equivalent to calling CallFunction(ctx, "cast", opts, Datum) and
172234
// should work for Scalar, Array or ChunkedArray Datums.

0 commit comments

Comments
 (0)