Skip to content

Commit a59d51e

Browse files
Anson Qian and sbinet
authored and committed
ARROW-4734: [Go] Add option to write a header for CSV writer
@sbinet
Author: Anson Qian <abq@uber.com>
Closes apache#3866 from anson627/ARROW-4734 and squashes the following commits:
233df16 <Anson Qian> Update go.mod and go.sum
5e887cd <Anson Qian> Better error handling
af73b2e <Anson Qian> Better error handling
50ec667 <Anson Qian> Create new schema when read header
66daa0b <Anson Qian> Address code review
e39affe <Anson Qian> Add option for both read and write
5bfbb61 <Anson Qian> ARROW-4734: Add option to write a header for CSV writer
1 parent fd0b90a commit a59d51e

10 files changed

Lines changed: 261 additions & 1 deletion

File tree

go/arrow/Gopkg.lock

Lines changed: 18 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

go/arrow/Gopkg.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,7 @@
1717
[[constraint]]
1818
name = "github.com/stretchr/testify"
1919
version = "1.2.0"
20+
21+
[[constraint]]
22+
name = "github.com/pkg/errors"
23+
version = "0.8.1"

go/arrow/csv/common.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,19 @@ func WithCRLF(useCRLF bool) Option {
104104
}
105105
}
106106

107+
func WithHeader() Option {
108+
return func(cfg config) {
109+
switch cfg := cfg.(type) {
110+
case *Reader:
111+
cfg.header = true
112+
case *Writer:
113+
cfg.header = true
114+
default:
115+
panic(fmt.Errorf("arrow/csv: unknown config type %T", cfg))
116+
}
117+
}
118+
}
119+
107120
func validate(schema *arrow.Schema) {
108121
for i, f := range schema.Fields() {
109122
switch ft := f.Type.(type) {

go/arrow/csv/reader.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@ import (
2020
"encoding/csv"
2121
"io"
2222
"strconv"
23+
"sync"
2324
"sync/atomic"
2425

2526
"github.com/apache/arrow/go/arrow"
2627
"github.com/apache/arrow/go/arrow/array"
2728
"github.com/apache/arrow/go/arrow/internal/debug"
2829
"github.com/apache/arrow/go/arrow/memory"
30+
"github.com/pkg/errors"
2931
)
3032

3133
// Reader wraps encoding/csv.Reader and creates array.Records from a schema.
@@ -43,6 +45,9 @@ type Reader struct {
4345
next func() bool
4446

4547
mem memory.Allocator
48+
49+
header bool
50+
once sync.Once
4651
}
4752

4853
// NewReader returns a reader that reads from the CSV file and creates
@@ -76,6 +81,28 @@ func NewReader(r io.Reader, schema *arrow.Schema, opts ...Option) *Reader {
7681
return rr
7782
}
7883

84+
func (r *Reader) readHeader() error {
85+
records, err := r.r.Read()
86+
if err != nil {
87+
return errors.Wrapf(err, "arrow/csv: could not read header from file")
88+
}
89+
90+
if len(records) != len(r.schema.Fields()) {
91+
return ErrMismatchFields
92+
}
93+
94+
fields := make([]arrow.Field, len(records))
95+
for idx, name := range records {
96+
fields[idx] = r.schema.Field(idx)
97+
fields[idx].Name = name
98+
}
99+
100+
meta := r.schema.Metadata()
101+
r.schema = arrow.NewSchema(fields, &meta)
102+
r.bld = array.NewRecordBuilder(r.mem, r.schema)
103+
return nil
104+
}
105+
79106
// Err returns the last error encountered during the iteration over the
80107
// underlying CSV file.
81108
func (r *Reader) Err() error { return r.err }
@@ -92,6 +119,12 @@ func (r *Reader) Record() array.Record { return r.cur }
92119
// Next panics if the number of records extracted from a CSV row does not match
93120
// the number of fields of the associated schema.
94121
func (r *Reader) Next() bool {
122+
if r.header {
123+
r.once.Do(func() {
124+
r.err = r.readHeader()
125+
})
126+
}
127+
95128
if r.cur != nil {
96129
r.cur.Release()
97130
r.cur = nil

go/arrow/csv/reader_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,93 @@ rec[1]["str"]: ["str-2"]
249249
}
250250
}
251251

252+
func TestCSVReaderWithHeader(t *testing.T) {
253+
mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
254+
defer mem.AssertSize(t, 0)
255+
256+
raw, err := ioutil.ReadFile("testdata/header.csv")
257+
if err != nil {
258+
t.Fatal(err)
259+
}
260+
261+
schema := arrow.NewSchema(
262+
[]arrow.Field{
263+
arrow.Field{Name: "0", Type: arrow.FixedWidthTypes.Boolean},
264+
arrow.Field{Name: "1", Type: arrow.PrimitiveTypes.Int8},
265+
arrow.Field{Name: "2", Type: arrow.PrimitiveTypes.Int16},
266+
arrow.Field{Name: "3", Type: arrow.PrimitiveTypes.Int32},
267+
arrow.Field{Name: "4", Type: arrow.PrimitiveTypes.Int64},
268+
arrow.Field{Name: "5", Type: arrow.PrimitiveTypes.Uint8},
269+
arrow.Field{Name: "6", Type: arrow.PrimitiveTypes.Uint16},
270+
arrow.Field{Name: "7", Type: arrow.PrimitiveTypes.Uint32},
271+
arrow.Field{Name: "8", Type: arrow.PrimitiveTypes.Uint64},
272+
arrow.Field{Name: "9", Type: arrow.PrimitiveTypes.Float32},
273+
arrow.Field{Name: "10", Type: arrow.PrimitiveTypes.Float64},
274+
arrow.Field{Name: "11", Type: arrow.BinaryTypes.String},
275+
},
276+
nil,
277+
)
278+
279+
r := csv.NewReader(bytes.NewReader(raw), schema,
280+
csv.WithAllocator(mem),
281+
csv.WithComment('#'), csv.WithComma(';'),
282+
csv.WithHeader(),
283+
)
284+
defer r.Release()
285+
286+
r.Retain()
287+
r.Release()
288+
289+
out := new(bytes.Buffer)
290+
291+
n := 0
292+
for r.Next() {
293+
rec := r.Record()
294+
for i, col := range rec.Columns() {
295+
fmt.Fprintf(out, "rec[%d][%q]: %v\n", n, rec.ColumnName(i), col)
296+
}
297+
n++
298+
}
299+
300+
if got, want := n, 2; got != want {
301+
t.Fatalf("invalid number of rows: got=%d, want=%d", got, want)
302+
}
303+
304+
want := `rec[0]["bool"]: [true]
305+
rec[0]["i8"]: [-1]
306+
rec[0]["i16"]: [-1]
307+
rec[0]["i32"]: [-1]
308+
rec[0]["i64"]: [-1]
309+
rec[0]["u8"]: [1]
310+
rec[0]["u16"]: [1]
311+
rec[0]["u32"]: [1]
312+
rec[0]["u64"]: [1]
313+
rec[0]["f32"]: [1.1]
314+
rec[0]["f64"]: [1.1]
315+
rec[0]["str"]: ["str-1"]
316+
rec[1]["bool"]: [false]
317+
rec[1]["i8"]: [-2]
318+
rec[1]["i16"]: [-2]
319+
rec[1]["i32"]: [-2]
320+
rec[1]["i64"]: [-2]
321+
rec[1]["u8"]: [2]
322+
rec[1]["u16"]: [2]
323+
rec[1]["u32"]: [2]
324+
rec[1]["u64"]: [2]
325+
rec[1]["f32"]: [2.2]
326+
rec[1]["f64"]: [2.2]
327+
rec[1]["str"]: ["str-2"]
328+
`
329+
330+
if got, want := out.String(), want; got != want {
331+
t.Fatalf("invalid output:\ngot= %s\nwant=%s\n", got, want)
332+
}
333+
334+
if r.Err() != nil {
335+
t.Fatalf("unexpected error: %v", r.Err())
336+
}
337+
}
338+
252339
func TestCSVReaderWithChunk(t *testing.T) {
253340
mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
254341
defer mem.AssertSize(t, 0)

go/arrow/csv/testdata/header.csv

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
bool;i8;i16;i32;i64;u8;u16;u32;u64;f32;f64;str
19+
true;-1;-1;-1;-1;1;1;1;1;1.1;1.1;str-1
20+
false;-2;-2;-2;-2;2;2;2;2;2.2;2.2;str-2

go/arrow/csv/writer.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"encoding/csv"
2121
"io"
2222
"strconv"
23+
"sync"
2324

2425
"github.com/apache/arrow/go/arrow"
2526
"github.com/apache/arrow/go/arrow/array"
@@ -29,6 +30,8 @@ import (
2930
type Writer struct {
3031
w *csv.Writer // underlying encoding/csv writer that rows are written to
3132
schema *arrow.Schema // schema whose field names writeHeader emits as the header row
33+
header bool // set by WithHeader; when true, Write emits a header row first
34+
once sync.Once // ensures Write invokes writeHeader at most once
3235
}
3336

3437
// NewWriter returns a writer that writes array.Records to the CSV file
@@ -55,6 +58,16 @@ func (w *Writer) Write(record array.Record) error {
5558
return ErrMismatchFields
5659
}
5760

61+
var err error
62+
if w.header {
63+
w.once.Do(func() {
64+
err = w.writeHeader()
65+
})
66+
if err != nil {
67+
return err
68+
}
69+
}
70+
5871
recs := make([][]string, record.NumRows())
5972
for i := range recs {
6073
recs[i] = make([]string, record.NumCols())
@@ -139,3 +152,14 @@ func (w *Writer) Flush() error {
139152
func (w *Writer) Error() error {
140153
return w.w.Error()
141154
}
155+
156+
func (w *Writer) writeHeader() error {
157+
headers := make([]string, len(w.schema.Fields()))
158+
for i := range headers {
159+
headers[i] = w.schema.Field(i).Name
160+
}
161+
if err := w.w.Write(headers); err != nil {
162+
return err
163+
}
164+
return nil
165+
}

go/arrow/csv/writer_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,65 @@ true;1;1;1;1;2;2;2;2;0.2;0.2;str-2
182182
}
183183
}
184184

185+
func TestCSVWriterWithHeader(t *testing.T) {
186+
f := new(bytes.Buffer)
187+
188+
pool := memory.NewCheckedAllocator(memory.NewGoAllocator())
189+
defer pool.AssertSize(t, 0)
190+
schema := arrow.NewSchema(
191+
[]arrow.Field{
192+
{Name: "bool", Type: arrow.FixedWidthTypes.Boolean},
193+
{Name: "i8", Type: arrow.PrimitiveTypes.Int8},
194+
{Name: "i16", Type: arrow.PrimitiveTypes.Int16},
195+
{Name: "i32", Type: arrow.PrimitiveTypes.Int32},
196+
{Name: "i64", Type: arrow.PrimitiveTypes.Int64},
197+
{Name: "u8", Type: arrow.PrimitiveTypes.Uint8},
198+
{Name: "u16", Type: arrow.PrimitiveTypes.Uint16},
199+
{Name: "u32", Type: arrow.PrimitiveTypes.Uint32},
200+
{Name: "u64", Type: arrow.PrimitiveTypes.Uint64},
201+
{Name: "f32", Type: arrow.PrimitiveTypes.Float32},
202+
{Name: "f64", Type: arrow.PrimitiveTypes.Float64},
203+
{Name: "str", Type: arrow.BinaryTypes.String},
204+
},
205+
nil,
206+
)
207+
208+
b := array.NewRecordBuilder(pool, schema)
209+
defer b.Release()
210+
211+
b.Field(0).(*array.BooleanBuilder).AppendValues([]bool{true, false, true}, nil)
212+
b.Field(1).(*array.Int8Builder).AppendValues([]int8{-1, 0, 1}, nil)
213+
b.Field(2).(*array.Int16Builder).AppendValues([]int16{-1, 0, 1}, nil)
214+
b.Field(3).(*array.Int32Builder).AppendValues([]int32{-1, 0, 1}, nil)
215+
b.Field(4).(*array.Int64Builder).AppendValues([]int64{-1, 0, 1}, nil)
216+
b.Field(5).(*array.Uint8Builder).AppendValues([]uint8{0, 1, 2}, nil)
217+
b.Field(6).(*array.Uint16Builder).AppendValues([]uint16{0, 1, 2}, nil)
218+
b.Field(7).(*array.Uint32Builder).AppendValues([]uint32{0, 1, 2}, nil)
219+
b.Field(8).(*array.Uint64Builder).AppendValues([]uint64{0, 1, 2}, nil)
220+
b.Field(9).(*array.Float32Builder).AppendValues([]float32{0.0, 0.1, 0.2}, nil)
221+
b.Field(10).(*array.Float64Builder).AppendValues([]float64{0.0, 0.1, 0.2}, nil)
222+
b.Field(11).(*array.StringBuilder).AppendValues([]string{"str-0", "str-1", "str-2"}, nil)
223+
224+
rec := b.NewRecord()
225+
defer rec.Release()
226+
227+
w := csv.NewWriter(f, schema, csv.WithComma(';'), csv.WithCRLF(false), csv.WithHeader())
228+
err := w.Write(rec)
229+
if err != nil {
230+
t.Fatal(err)
231+
}
232+
233+
want := `bool;i8;i16;i32;i64;u8;u16;u32;u64;f32;f64;str
234+
true;-1;-1;-1;-1;0;0;0;0;0;0;str-0
235+
false;0;0;0;0;1;1;1;1;0.1;0.1;str-1
236+
true;1;1;1;1;2;2;2;2;0.2;0.2;str-2
237+
`
238+
239+
if got, want := f.String(), want; strings.Compare(got, want) != 0 {
240+
t.Fatalf("invalid output:\ngot=%s\nwant=%s\n", got, want)
241+
}
242+
}
243+
185244
func BenchmarkWrite(b *testing.B) {
186245
pool := memory.NewCheckedAllocator(memory.NewGoAllocator())
187246
defer pool.AssertSize(b, 0)

go/arrow/go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ module github.com/apache/arrow/go/arrow
1818

1919
require (
2020
github.com/davecgh/go-spew v1.1.0 // indirect
21+
github.com/pkg/errors v0.8.1
2122
github.com/pmezard/go-difflib v1.0.0 // indirect
2223
github.com/stretchr/testify v1.2.0
2324
)

go/arrow/go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
22
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
3+
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
4+
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
35
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
46
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
57
github.com/stretchr/testify v1.2.0 h1:LThGCOvhuJic9Gyd1VBCkhyUXmO8vKaBFvBsJ2k03rg=

0 commit comments

Comments
 (0)