Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 33 additions & 28 deletions sdk/python/feast/infra/offline_stores/snowflake_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,38 @@ def get_table_column_names_and_types(
"The following source:\n" + query + "\n ... is empty"
)

high_precision_number_columns = [
col["column_name"]
for col in metadata
if col["type_code"] == 0 and col["scale"] == 0 and col["precision"] > 19
]

if high_precision_number_columns:
max_selects = [
f'MAX("{col}") AS "{col}"' for col in high_precision_number_columns
]
query = (
f"SELECT {', '.join(max_selects)} FROM {self.get_table_query_string()}"
)

with GetSnowflakeConnection(config.offline_store) as conn:
result = execute_snowflake_statement(conn, query).fetch_pandas_all()

for col in high_precision_number_columns:
max_value = result[col].iloc[0]
if max_value is not None:
str_length = len(str(int(max_value)))
for row in metadata:
if row["column_name"] == col:
if str_length <= 9:
row["snowflake_type"] = "NUMBER32"
elif str_length <= 19:
row["snowflake_type"] = "NUMBER64"
else:
raise NotImplementedError(
f"Number in column {col} larger than INT64 is not supported"
)

for row in metadata:
if row["type_code"] == 0:
if row["scale"] == 0:
Expand All @@ -253,34 +285,7 @@ def get_table_column_names_and_types(
elif row["precision"] <= 18: # max precision size to ensure INT64
row["snowflake_type"] = "NUMBER64"
else:
column = row["column_name"]

with GetSnowflakeConnection(config.offline_store) as conn:
query = f'SELECT MAX("{column}") AS "{column}" FROM {self.get_table_query_string()}'
result = execute_snowflake_statement(
conn, query
).fetch_pandas_all()
if (
result.dtypes[column].name
in python_int_to_snowflake_type_map
):
row["snowflake_type"] = python_int_to_snowflake_type_map[
result.dtypes[column].name
]
else:
if len(result) > 0:
max_value = result.iloc[0][0]
if max_value is not None and len(str(max_value)) <= 9:
row["snowflake_type"] = "NUMBER32"
continue
elif (
max_value is not None and len(str(max_value)) <= 18
):
row["snowflake_type"] = "NUMBER64"
continue
raise NotImplementedError(
"NaNs or Numbers larger than INT64 are not supported"
)
continue
else:
row["snowflake_type"] = "NUMBERwSCALE"

Expand Down
Loading