Skip to content

Commit d123277

Browse files
authored
ARROW-17600: [Go] Implement Casting for Nested types (apache#14056)
Authored-by: Matt Topol <zotthewizard@gmail.com> Signed-off-by: Matt Topol <zotthewizard@gmail.com>
1 parent 21491ec commit d123277

17 files changed

Lines changed: 821 additions & 37 deletions

ci/scripts/go_test.sh

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,20 @@
1919

2020
set -ex
2121

22+
# simplistic semver comparison
23+
verlte() {
24+
[ "$1" = "`echo -e "$1\n$2" | sort -V | head -n1`" ]
25+
}
26+
verlt() {
27+
[ "$1" = "$2" ] && return 1 || verlte $1 $2
28+
}
29+
2230
ver=`go env GOVERSION`
2331

2432
source_dir=${1}/go
2533

2634
testargs="-race"
27-
if [[ "${ver#go}" =~ ^1\.1[8-9] ]] && [ "$(go env GOOS)" != "darwin" ]; then
35+
if verlte "1.18" "${ver#go}" && [ "$(go env GOOS)" != "darwin" ]; then
2836
# asan not supported on darwin/amd64
2937
testargs="-asan"
3038
fi
@@ -65,6 +73,11 @@ fi
6573

6674
go test $testargs -tags $TAGS ./...
6775

76+
# only test compute when Go is >= 1.18
77+
if verlte "1.18" "${ver#go}"; then
78+
go test $testargs -tags $TAGS ./compute/...
79+
fi
80+
6881
popd
6982

7083
export PARQUET_TEST_DATA=${1}/cpp/submodules/parquet-testing/data

go/arrow/array/array_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ type testDataType struct {
3434
func (d *testDataType) ID() arrow.Type { return d.id }
3535
func (d *testDataType) Name() string { panic("implement me") }
3636
func (d *testDataType) BitWidth() int { return 8 }
37+
func (d *testDataType) Bytes() int { return 1 }
3738
func (d *testDataType) Fingerprint() string { return "" }
3839
func (testDataType) Layout() arrow.DataTypeLayout { return arrow.DataTypeLayout{} }
3940
func (testDataType) String() string { return "" }

go/arrow/array/struct.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,40 @@ type Struct struct {
3636
fields []arrow.Array
3737
}
3838

39+
// NewStructArray constructs a new Struct Array out of the columns passed
40+
// in and the field names. The length of all cols must be the same and
41+
// there should be the same number of columns as names.
42+
func NewStructArray(cols []arrow.Array, names []string) (*Struct, error) {
43+
return NewStructArrayWithNulls(cols, names, nil, 0, 0)
44+
}
45+
46+
// NewStructArrayWithNulls is like NewStructArray as a convenience function,
47+
// but also takes in a null bitmap, the number of nulls, and an optional offset
48+
// to use for creating the Struct Array.
49+
func NewStructArrayWithNulls(cols []arrow.Array, names []string, nullBitmap *memory.Buffer, nullCount int, offset int) (*Struct, error) {
50+
if len(cols) != len(names) {
51+
return nil, fmt.Errorf("%w: mismatching number of fields and child arrays", arrow.ErrInvalid)
52+
}
53+
if len(cols) == 0 {
54+
return nil, fmt.Errorf("%w: can't infer struct array length with 0 child arrays", arrow.ErrInvalid)
55+
}
56+
length := cols[0].Len()
57+
children := make([]arrow.ArrayData, len(cols))
58+
fields := make([]arrow.Field, len(cols))
59+
for i, c := range cols {
60+
if length != c.Len() {
61+
return nil, fmt.Errorf("%w: mismatching child array lengths", arrow.ErrInvalid)
62+
}
63+
children[i] = c.Data()
64+
fields[i].Name = names[i]
65+
fields[i].Type = c.DataType()
66+
fields[i].Nullable = true
67+
}
68+
data := NewData(arrow.StructOf(fields...), length, []*memory.Buffer{nullBitmap}, children, nullCount, offset)
69+
defer data.Release()
70+
return NewStructData(data), nil
71+
}
72+
3973
// NewStructData returns a new Struct array value from data.
4074
func NewStructData(data arrow.ArrayData) *Struct {
4175
a := &Struct{}

go/arrow/compute/cast.go

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323

2424
"github.com/apache/arrow/go/v10/arrow"
2525
"github.com/apache/arrow/go/v10/arrow/array"
26+
"github.com/apache/arrow/go/v10/arrow/bitutil"
2627
"github.com/apache/arrow/go/v10/arrow/compute/internal/exec"
2728
"github.com/apache/arrow/go/v10/arrow/compute/internal/kernels"
2829
)
@@ -150,6 +151,156 @@ func CastFromExtension(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.Exec
150151
return nil
151152
}
152153

154+
func CastList[SrcOffsetT, DestOffsetT int32 | int64](ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error {
155+
var (
156+
opts = ctx.State.(kernels.CastState)
157+
childType = out.Type.(arrow.NestedType).Fields()[0].Type
158+
input = &batch.Values[0].Array
159+
offsets = exec.GetSpanOffsets[SrcOffsetT](input, 1)
160+
isDowncast = kernels.SizeOf[SrcOffsetT]() > kernels.SizeOf[DestOffsetT]()
161+
)
162+
163+
out.Buffers[0] = input.Buffers[0]
164+
out.Buffers[1] = input.Buffers[1]
165+
166+
if input.Offset != 0 && len(input.Buffers[0].Buf) > 0 {
167+
out.Buffers[0].WrapBuffer(ctx.AllocateBitmap(input.Len))
168+
bitutil.CopyBitmap(input.Buffers[0].Buf, int(input.Offset), int(input.Len),
169+
out.Buffers[0].Buf, 0)
170+
}
171+
172+
// Handle list offsets
173+
// Several cases possible:
174+
// - The source offset is non-zero, in which case we slice the
175+
// underlying values and shift the list offsets (regardless of
176+
// their respective types)
177+
// - the source offset is zero but the source and destination types
178+
// have different list offset types, in which case we cast the offsets
179+
// - otherwise we simply keep the original offsets
180+
if isDowncast {
181+
if offsets[input.Len] > SrcOffsetT(kernels.MaxOf[DestOffsetT]()) {
182+
return fmt.Errorf("%w: array of type %s too large to convert to %s",
183+
arrow.ErrInvalid, input.Type, out.Type)
184+
}
185+
}
186+
187+
values := input.Children[0].MakeArray()
188+
defer values.Release()
189+
190+
if input.Offset != 0 {
191+
out.Buffers[1].WrapBuffer(
192+
ctx.Allocate(out.Type.(arrow.OffsetsDataType).
193+
OffsetTypeTraits().BytesRequired(int(input.Len) + 1)))
194+
195+
shiftedOffsets := exec.GetSpanOffsets[DestOffsetT](out, 1)
196+
for i := 0; i < int(input.Len)+1; i++ {
197+
shiftedOffsets[i] = DestOffsetT(offsets[i] - offsets[0])
198+
}
199+
200+
values = array.NewSlice(values, int64(offsets[0]), int64(offsets[input.Len]))
201+
defer values.Release()
202+
} else if kernels.SizeOf[SrcOffsetT]() != kernels.SizeOf[DestOffsetT]() {
203+
out.Buffers[1].WrapBuffer(ctx.Allocate(out.Type.(arrow.OffsetsDataType).
204+
OffsetTypeTraits().BytesRequired(int(input.Len) + 1)))
205+
206+
kernels.DoStaticCast(exec.GetSpanOffsets[SrcOffsetT](input, 1),
207+
exec.GetSpanOffsets[DestOffsetT](out, 1))
208+
}
209+
210+
// handle values
211+
opts.ToType = childType
212+
213+
castedValues, err := CastArray(ctx.Ctx, values, &opts)
214+
if err != nil {
215+
return err
216+
}
217+
defer castedValues.Release()
218+
219+
out.Children = make([]exec.ArraySpan, 1)
220+
out.Children[0].SetMembers(castedValues.Data())
221+
for i, b := range out.Children[0].Buffers {
222+
if b.Owner != nil && b.Owner != values.Data().Buffers()[i] {
223+
b.Owner.Retain()
224+
b.SelfAlloc = true
225+
}
226+
}
227+
return nil
228+
}
229+
230+
func CastStruct(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error {
231+
var (
232+
opts = ctx.State.(kernels.CastState)
233+
inType = batch.Values[0].Array.Type.(*arrow.StructType)
234+
outType = out.Type.(*arrow.StructType)
235+
inFieldCount = len(inType.Fields())
236+
outFieldCount = len(outType.Fields())
237+
)
238+
239+
fieldsToSelect := make([]int, outFieldCount)
240+
for i := range fieldsToSelect {
241+
fieldsToSelect[i] = -1
242+
}
243+
244+
outFieldIndex := 0
245+
for inFieldIndex := 0; inFieldIndex < inFieldCount && outFieldIndex < outFieldCount; inFieldIndex++ {
246+
inField := inType.Field(inFieldIndex)
247+
outField := outType.Field(outFieldIndex)
248+
if inField.Name == outField.Name {
249+
if inField.Nullable && !outField.Nullable {
250+
return fmt.Errorf("%w: cannot cast nullable field to non-nullable field: %s %s",
251+
arrow.ErrType, inType, outType)
252+
}
253+
fieldsToSelect[outFieldIndex] = inFieldIndex
254+
outFieldIndex++
255+
}
256+
}
257+
258+
if outFieldIndex < outFieldCount {
259+
return fmt.Errorf("%w: struct fields don't match or are in the wrong order: Input: %s Output: %s",
260+
arrow.ErrType, inType, outType)
261+
}
262+
263+
input := &batch.Values[0].Array
264+
if len(input.Buffers[0].Buf) > 0 {
265+
out.Buffers[0].WrapBuffer(ctx.AllocateBitmap(input.Len))
266+
bitutil.CopyBitmap(input.Buffers[0].Buf, int(input.Offset), int(input.Len),
267+
out.Buffers[0].Buf, 0)
268+
}
269+
270+
out.Children = make([]exec.ArraySpan, outFieldCount)
271+
for outFieldIndex, idx := range fieldsToSelect {
272+
values := input.Children[idx].MakeArray()
273+
defer values.Release()
274+
values = array.NewSlice(values, input.Offset, input.Len)
275+
defer values.Release()
276+
277+
opts.ToType = outType.Field(outFieldIndex).Type
278+
castedValues, err := CastArray(ctx.Ctx, values, &opts)
279+
if err != nil {
280+
return err
281+
}
282+
defer castedValues.Release()
283+
284+
out.Children[outFieldIndex].TakeOwnership(castedValues.Data())
285+
}
286+
return nil
287+
}
288+
289+
func addListCast[SrcOffsetT, DestOffsetT int32 | int64](fn *castFunction, inType arrow.Type) error {
290+
kernel := exec.NewScalarKernel([]exec.InputType{exec.NewIDInput(inType)},
291+
kernels.OutputTargetType, CastList[SrcOffsetT, DestOffsetT], nil)
292+
kernel.NullHandling = exec.NullComputedNoPrealloc
293+
kernel.MemAlloc = exec.MemNoPrealloc
294+
return fn.AddTypeCast(inType, kernel)
295+
}
296+
297+
func addStructToStructCast(fn *castFunction) error {
298+
kernel := exec.NewScalarKernel([]exec.InputType{exec.NewIDInput(arrow.STRUCT)},
299+
kernels.OutputTargetType, CastStruct, nil)
300+
kernel.NullHandling = exec.NullComputedNoPrealloc
301+
return fn.AddTypeCast(arrow.STRUCT, kernel)
302+
}
303+
153304
func addCastFuncs(fn []*castFunction) {
154305
for _, f := range fn {
155306
f.AddNewTypeCast(arrow.EXTENSION, []exec.InputType{exec.NewIDInput(arrow.EXTENSION)},
@@ -165,6 +316,12 @@ func initCastTable() {
165316
addCastFuncs(getNumericCasts())
166317
addCastFuncs(getBinaryLikeCasts())
167318
addCastFuncs(getTemporalCasts())
319+
addCastFuncs(getNestedCasts())
320+
321+
nullToExt := newCastFunction("cast_extension", arrow.EXTENSION)
322+
nullToExt.AddNewTypeCast(arrow.NULL, []exec.InputType{exec.NewExactInput(arrow.Null)},
323+
kernels.OutputTargetType, kernels.CastFromNull, exec.NullComputedNoPrealloc, exec.MemNoPrealloc)
324+
castTable[arrow.EXTENSION] = nullToExt
168325
}
169326

170327
func getCastFunction(to arrow.DataType) (*castFunction, error) {
@@ -178,6 +335,51 @@ func getCastFunction(to arrow.DataType) (*castFunction, error) {
178335
return nil, fmt.Errorf("%w: unsupported cast to %s", arrow.ErrNotImplemented, to)
179336
}
180337

338+
func getNestedCasts() []*castFunction {
339+
out := make([]*castFunction, 0)
340+
341+
addKernels := func(fn *castFunction, kernels []exec.ScalarKernel) {
342+
for _, k := range kernels {
343+
if err := fn.AddTypeCast(k.Signature.InputTypes[0].MatchID(), k); err != nil {
344+
panic(err)
345+
}
346+
}
347+
}
348+
349+
castLists := newCastFunction("cast_list", arrow.LIST)
350+
addKernels(castLists, kernels.GetCommonCastKernels(arrow.LIST, kernels.OutputTargetType))
351+
if err := addListCast[int32, int32](castLists, arrow.LIST); err != nil {
352+
panic(err)
353+
}
354+
if err := addListCast[int64, int32](castLists, arrow.LARGE_LIST); err != nil {
355+
panic(err)
356+
}
357+
out = append(out, castLists)
358+
359+
castLargeLists := newCastFunction("cast_large_list", arrow.LARGE_LIST)
360+
addKernels(castLargeLists, kernels.GetCommonCastKernels(arrow.LARGE_LIST, kernels.OutputTargetType))
361+
if err := addListCast[int32, int64](castLargeLists, arrow.LIST); err != nil {
362+
panic(err)
363+
}
364+
if err := addListCast[int64, int64](castLargeLists, arrow.LARGE_LIST); err != nil {
365+
panic(err)
366+
}
367+
out = append(out, castLargeLists)
368+
369+
castFsl := newCastFunction("cast_fixed_size_list", arrow.FIXED_SIZE_LIST)
370+
addKernels(castFsl, kernels.GetCommonCastKernels(arrow.FIXED_SIZE_LIST, kernels.OutputTargetType))
371+
out = append(out, castFsl)
372+
373+
castStruct := newCastFunction("cast_struct", arrow.STRUCT)
374+
addKernels(castStruct, kernels.GetCommonCastKernels(arrow.STRUCT, kernels.OutputTargetType))
375+
if err := addStructToStructCast(castStruct); err != nil {
376+
panic(err)
377+
}
378+
out = append(out, castStruct)
379+
380+
return out
381+
}
382+
181383
func getBooleanCasts() []*castFunction {
182384
fn := newCastFunction("cast_boolean", arrow.BOOL)
183385
kns := kernels.GetBooleanCastKernels()

0 commit comments

Comments
 (0)