Skip to content

Commit 5c85d51

Browse files
tswastplamut
authored andcommitted
Fix bug where load_table_from_dataframe could not append to REQUIRED fields. (googleapis#8230)
If a BigQuery schema is supplied as part of the `job_config`, it can be used to set the `nullable` bit correctly on the serialized parquet file.
1 parent 879ef99 commit 5c85d51

3 files changed

Lines changed: 169 additions & 12 deletions

File tree

bigquery/google/cloud/bigquery/_pandas_helpers.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
"""Shared helper functions for connecting BigQuery and pandas."""
1616

17+
import warnings
18+
1719
try:
1820
import pyarrow
1921
import pyarrow.parquet
@@ -107,6 +109,8 @@ def bq_to_arrow_field(bq_field):
107109
if arrow_type:
108110
is_nullable = bq_field.mode.upper() == "NULLABLE"
109111
return pyarrow.field(bq_field.name, arrow_type, nullable=is_nullable)
112+
113+
warnings.warn("Unable to determine type for field '{}'.".format(bq_field.name))
110114
return None
111115

112116

@@ -119,34 +123,58 @@ def bq_to_arrow_array(series, bq_field):
119123
return pyarrow.array(series, type=arrow_type)
120124

121125

122-
def to_parquet(dataframe, bq_schema, filepath):
123-
"""Write dataframe as a Parquet file, according to the desired BQ schema.
124-
125-
This function requires the :mod:`pyarrow` package. Arrow is used as an
126-
intermediate format.
126+
def to_arrow(dataframe, bq_schema):
127+
"""Convert pandas dataframe to Arrow table, using BigQuery schema.
127128
128129
Args:
129130
dataframe (pandas.DataFrame):
130131
DataFrame to convert to convert to Parquet file.
131132
bq_schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
132133
Desired BigQuery schema. Number of columns must match number of
133134
columns in the DataFrame.
134-
filepath (str):
135-
Path to write Parquet file to.
136-
"""
137-
if pyarrow is None:
138-
raise ValueError("pyarrow is required for BigQuery schema conversion.")
139135
136+
Returns:
137+
pyarrow.Table:
138+
Table containing dataframe data, with schema derived from
139+
BigQuery schema.
140+
"""
140141
if len(bq_schema) != len(dataframe.columns):
141142
raise ValueError(
142143
"Number of columns in schema must match number of columns in dataframe."
143144
)
144145

145146
arrow_arrays = []
146147
arrow_names = []
148+
arrow_fields = []
147149
for bq_field in bq_schema:
150+
arrow_fields.append(bq_to_arrow_field(bq_field))
148151
arrow_names.append(bq_field.name)
149152
arrow_arrays.append(bq_to_arrow_array(dataframe[bq_field.name], bq_field))
150153

151-
arrow_table = pyarrow.Table.from_arrays(arrow_arrays, names=arrow_names)
154+
if all((field is not None for field in arrow_fields)):
155+
return pyarrow.Table.from_arrays(
156+
arrow_arrays, schema=pyarrow.schema(arrow_fields)
157+
)
158+
return pyarrow.Table.from_arrays(arrow_arrays, names=arrow_names)
159+
160+
161+
def to_parquet(dataframe, bq_schema, filepath):
162+
"""Write dataframe as a Parquet file, according to the desired BQ schema.
163+
164+
This function requires the :mod:`pyarrow` package. Arrow is used as an
165+
intermediate format.
166+
167+
Args:
168+
dataframe (pandas.DataFrame):
169+
DataFrame to convert to convert to Parquet file.
170+
bq_schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
171+
Desired BigQuery schema. Number of columns must match number of
172+
columns in the DataFrame.
173+
filepath (str):
174+
Path to write Parquet file to.
175+
"""
176+
if pyarrow is None:
177+
raise ValueError("pyarrow is required for BigQuery schema conversion.")
178+
179+
arrow_table = to_arrow(dataframe, bq_schema)
152180
pyarrow.parquet.write_table(arrow_table, filepath)

bigquery/tests/system.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,45 @@ def test_load_table_from_dataframe_w_nulls(self):
699699
self.assertEqual(tuple(table.schema), table_schema)
700700
self.assertEqual(table.num_rows, num_rows)
701701

702+
@unittest.skipIf(pandas is None, "Requires `pandas`")
703+
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
704+
def test_load_table_from_dataframe_w_required(self):
705+
"""Test that a DataFrame with required columns can be uploaded if a
706+
BigQuery schema is specified.
707+
708+
See: https://github.com/googleapis/google-cloud-python/issues/8093
709+
"""
710+
table_schema = (
711+
bigquery.SchemaField("name", "STRING", mode="REQUIRED"),
712+
bigquery.SchemaField("age", "INTEGER", mode="REQUIRED"),
713+
)
714+
715+
records = [{"name": "Chip", "age": 2}, {"name": "Dale", "age": 3}]
716+
dataframe = pandas.DataFrame(records)
717+
job_config = bigquery.LoadJobConfig(schema=table_schema)
718+
dataset_id = _make_dataset_id("bq_load_test")
719+
self.temp_dataset(dataset_id)
720+
table_id = "{}.{}.load_table_from_dataframe_w_required".format(
721+
Config.CLIENT.project, dataset_id
722+
)
723+
724+
# Create the table before loading so that schema mismatch errors are
725+
# identified.
726+
table = retry_403(Config.CLIENT.create_table)(
727+
Table(table_id, schema=table_schema)
728+
)
729+
self.to_delete.insert(0, table)
730+
731+
job_config = bigquery.LoadJobConfig(schema=table_schema)
732+
load_job = Config.CLIENT.load_table_from_dataframe(
733+
dataframe, table_id, job_config=job_config
734+
)
735+
load_job.result()
736+
737+
table = Config.CLIENT.get_table(table)
738+
self.assertEqual(tuple(table.schema), table_schema)
739+
self.assertEqual(table.num_rows, 2)
740+
702741
@unittest.skipIf(pandas is None, "Requires `pandas`")
703742
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
704743
def test_load_table_from_dataframe_w_explicit_schema(self):

bigquery/tests/unit/test__pandas_helpers.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import datetime
1616
import decimal
1717
import functools
18+
import warnings
1819

1920
try:
2021
import pandas
@@ -26,6 +27,7 @@
2627
except ImportError: # pragma: NO COVER
2728
pyarrow = None
2829
import pytest
30+
import pytz
2931

3032
from google.cloud.bigquery import schema
3133

@@ -373,7 +375,7 @@ def test_bq_to_arrow_data_type_w_struct_unknown_subfield(module_under_test):
373375
(
374376
"GEOGRAPHY",
375377
[
376-
"POINT(30, 10)",
378+
"POINT(30 10)",
377379
None,
378380
"LINESTRING (30 10, 10 30, 40 40)",
379381
"POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))",
@@ -440,6 +442,94 @@ def test_bq_to_arrow_array_w_special_floats(module_under_test):
440442
assert roundtrip[3] is None
441443

442444

445+
@pytest.mark.skipIf(pandas is None, "Requires `pandas`")
446+
@pytest.mark.skipIf(pyarrow is None, "Requires `pyarrow`")
447+
def test_to_arrow_w_required_fields(module_under_test):
448+
bq_schema = (
449+
schema.SchemaField("field01", "STRING", mode="REQUIRED"),
450+
schema.SchemaField("field02", "BYTES", mode="REQUIRED"),
451+
schema.SchemaField("field03", "INTEGER", mode="REQUIRED"),
452+
schema.SchemaField("field04", "INT64", mode="REQUIRED"),
453+
schema.SchemaField("field05", "FLOAT", mode="REQUIRED"),
454+
schema.SchemaField("field06", "FLOAT64", mode="REQUIRED"),
455+
schema.SchemaField("field07", "NUMERIC", mode="REQUIRED"),
456+
schema.SchemaField("field08", "BOOLEAN", mode="REQUIRED"),
457+
schema.SchemaField("field09", "BOOL", mode="REQUIRED"),
458+
schema.SchemaField("field10", "TIMESTAMP", mode="REQUIRED"),
459+
schema.SchemaField("field11", "DATE", mode="REQUIRED"),
460+
schema.SchemaField("field12", "TIME", mode="REQUIRED"),
461+
schema.SchemaField("field13", "DATETIME", mode="REQUIRED"),
462+
schema.SchemaField("field14", "GEOGRAPHY", mode="REQUIRED"),
463+
)
464+
dataframe = pandas.DataFrame(
465+
{
466+
"field01": ["hello", "world"],
467+
"field02": [b"abd", b"efg"],
468+
"field03": [1, 2],
469+
"field04": [3, 4],
470+
"field05": [1.25, 9.75],
471+
"field06": [-1.75, -3.5],
472+
"field07": [decimal.Decimal("1.2345"), decimal.Decimal("6.7891")],
473+
"field08": [True, False],
474+
"field09": [False, True],
475+
"field10": [
476+
datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
477+
datetime.datetime(2012, 12, 21, 9, 7, 42, tzinfo=pytz.utc),
478+
],
479+
"field11": [datetime.date(9999, 12, 31), datetime.date(1970, 1, 1)],
480+
"field12": [datetime.time(23, 59, 59, 999999), datetime.time(12, 0, 0)],
481+
"field13": [
482+
datetime.datetime(1970, 1, 1, 0, 0, 0),
483+
datetime.datetime(2012, 12, 21, 9, 7, 42),
484+
],
485+
"field14": [
486+
"POINT(30 10)",
487+
"POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))",
488+
],
489+
}
490+
)
491+
492+
arrow_table = module_under_test.to_arrow(dataframe, bq_schema)
493+
arrow_schema = arrow_table.schema
494+
495+
assert len(arrow_schema) == len(bq_schema)
496+
for arrow_field in arrow_schema:
497+
assert not arrow_field.nullable
498+
499+
500+
@pytest.mark.skipIf(pandas is None, "Requires `pandas`")
501+
@pytest.mark.skipIf(pyarrow is None, "Requires `pyarrow`")
502+
def test_to_arrow_w_unknown_type(module_under_test):
503+
bq_schema = (
504+
schema.SchemaField("field00", "UNKNOWN_TYPE"),
505+
schema.SchemaField("field01", "STRING"),
506+
schema.SchemaField("field02", "BYTES"),
507+
schema.SchemaField("field03", "INTEGER"),
508+
)
509+
dataframe = pandas.DataFrame(
510+
{
511+
"field00": ["whoami", "whatami"],
512+
"field01": ["hello", "world"],
513+
"field02": [b"abd", b"efg"],
514+
"field03": [1, 2],
515+
}
516+
)
517+
518+
with warnings.catch_warnings(record=True) as warned:
519+
arrow_table = module_under_test.to_arrow(dataframe, bq_schema)
520+
arrow_schema = arrow_table.schema
521+
522+
assert len(warned) == 1
523+
warning = warned[0]
524+
assert "field00" in str(warning)
525+
526+
assert len(arrow_schema) == len(bq_schema)
527+
assert arrow_schema[0].name == "field00"
528+
assert arrow_schema[1].name == "field01"
529+
assert arrow_schema[2].name == "field02"
530+
assert arrow_schema[3].name == "field03"
531+
532+
443533
@pytest.mark.skipIf(pandas is None, "Requires `pandas`")
444534
def test_to_parquet_without_pyarrow(module_under_test, monkeypatch):
445535
monkeypatch.setattr(module_under_test, "pyarrow", None)

0 commit comments

Comments
 (0)