forked from feast-dev/feast
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsnowflake_utils.py
More file actions
431 lines (378 loc) · 16.6 KB
/
Copy pathsnowflake_utils.py
File metadata and controls
431 lines (378 loc) · 16.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
import configparser
import os
import random
import string
from logging import getLogger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, Iterator, List, Optional, Tuple, cast
import pandas as pd
import pyarrow
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from feast.errors import SnowflakeIncompleteConfig, SnowflakeQueryUnknownError
# The Snowflake connector is an optional dependency of Feast. If it cannot be
# imported, re-raise as FeastExtrasDependencyImportError tagged with the
# "snowflake" extra, carrying the original ImportError message.
try:
    import snowflake.connector
    from snowflake.connector import ProgrammingError, SnowflakeConnection
    from snowflake.connector.cursor import SnowflakeCursor
except ImportError as e:
    from feast.errors import FeastExtrasDependencyImportError

    raise FeastExtrasDependencyImportError("snowflake", str(e))

# Disable the connector's internal cursor / connection / network loggers.
# NOTE(review): presumably to keep connector chatter (SQL text, connection
# details) out of Feast's logs — confirm intent.
getLogger("snowflake.connector.cursor").disabled = True
getLogger("snowflake.connector.connection").disabled = True
getLogger("snowflake.connector.network").disabled = True

# Module logger; the SQL statements built below are logged at DEBUG level.
logger = getLogger(__name__)
def execute_snowflake_statement(conn: SnowflakeConnection, query) -> SnowflakeCursor:
    """Run ``query`` on a brand-new cursor of ``conn`` and hand that cursor back.

    Args:
        conn: Open Snowflake connection.
        query: SQL text passed unchanged to ``SnowflakeCursor.execute``.

    Returns:
        The cursor returned by ``execute``, ready for fetching results.

    Raises:
        SnowflakeQueryUnknownError: if ``execute`` returns ``None`` instead of a cursor.
    """
    fresh_cursor = conn.cursor()
    executed = fresh_cursor.execute(query)
    if executed is not None:
        return executed
    raise SnowflakeQueryUnknownError(query)
def get_snowflake_conn(config, autocommit=True) -> SnowflakeConnection:
    """Open a Snowflake connection for the offline store described by ``config``.

    Connection settings are merged from two sources:

    1. the ``[connections.feast_offline_store]`` section of the INI file at
       ``config.config_path`` (only when a path is set and the section exists);
    2. every non-``None`` field of ``config``. These override values from the file.

    ``role``, ``warehouse``, ``database`` and ``schema`` are wrapped in double
    quotes before being passed on. ``schema`` defaults to ``"PUBLIC"``. When
    ``private_key`` is present it is treated as a path to a PEM key file and
    replaced by the DER bytes produced by ``parse_private_key_path``.

    Args:
        config: Offline store config. Must have ``type == "snowflake.offline"``
            and be convertible with ``dict()``.
        autocommit: Forwarded to ``snowflake.connector.connect``.

    Returns:
        An open ``SnowflakeConnection``.

    Raises:
        SnowflakeIncompleteConfig: if a required setting is missing (a KeyError
            while loading the private key or connecting), e.g.
            ``private_key_passphrase`` when ``private_key`` is set.
    """
    assert config.type == "snowflake.offline"
    config_header = "connections.feast_offline_store"

    config_dict = dict(config)

    # Read the optional INI config file. ConfigParser.read() skips missing files,
    # but a None entry would raise TypeError, so only read when a path is set.
    config_reader = configparser.ConfigParser()
    config_path = config_dict.get("config_path")
    if config_path:
        config_reader.read([config_path])

    kwargs: Dict[str, Any] = {}
    if config_reader.has_section(config_header):
        kwargs = dict(config_reader[config_header])

    # Rename the file's "schema" so explicit config values (keyed "schema_") can
    # override it in the update below.
    if "schema" in kwargs:
        kwargs["schema_"] = kwargs.pop("schema")

    kwargs.update((k, v) for k, v in config_dict.items() if v is not None)

    quoted_keys = ("role", "warehouse", "database", "schema_")
    kwargs = {k: (f'"{v}"' if k in quoted_keys else v) for k, v in kwargs.items()}

    if "schema_" in kwargs:
        kwargs["schema"] = kwargs.pop("schema_")
    else:
        kwargs["schema"] = '"PUBLIC"'

    try:
        # https://docs.snowflake.com/en/user-guide/python-connector-example.html#using-key-pair-authentication-key-pair-rotation
        # https://docs.snowflake.com/en/user-guide/key-pair-auth.html#configuring-key-pair-authentication
        # Done inside the try so a missing passphrase surfaces as
        # SnowflakeIncompleteConfig rather than a bare KeyError.
        if "private_key" in kwargs:
            kwargs["private_key"] = parse_private_key_path(
                kwargs["private_key"], kwargs["private_key_passphrase"]
            )
        return snowflake.connector.connect(
            application="feast", autocommit=autocommit, **kwargs
        )
    except KeyError as e:
        raise SnowflakeIncompleteConfig(e) from e
# TODO(sfc-gh-madkins): Remove the dependency on this write_pandas function by falling
# back to the native Snowflake Python connector. The current blocker is that
# datetime64[ns] values are read incorrectly by Snowflake and must be coerced to
# datetime64[ns, UTC] first.
def write_pandas(
    conn: SnowflakeConnection,
    df: pd.DataFrame,
    table_name: str,
    database: Optional[str] = None,
    schema: Optional[str] = None,
    chunk_size: Optional[int] = None,
    compression: str = "gzip",
    on_error: str = "abort_statement",
    parallel: int = 4,
    quote_identifiers: bool = True,
    auto_create_table: bool = False,
    create_temp_table: bool = False,
):
    """Write a pandas DataFrame to a Snowflake table through a temporary stage.

    The DataFrame is dumped into Parquet files (one per chunk). The files are PUT
    into a freshly created temporary stage, and their data is then copied into
    the table with ``COPY INTO``. The stage files are purged by the copy.

    Unlike ``snowflake.connector.pandas_tools.write_pandas``, from which it is
    adapted, this function returns ``None``. Failures surface as exceptions.

    Example usage:

        import pandas
        from feast.infra.utils.snowflake_utils import write_pandas

        df = pandas.DataFrame([('Mark', 10), ('Luke', 20)], columns=['name', 'balance'])
        write_pandas(cnx, df, 'customers')

    Args:
        conn: Connection to be used to communicate with Snowflake.
        df: DataFrame to write.
        table_name: Name of the table to insert into.
        database: Database the schema and table are in. If not provided, the
            connection's default is used (Default value = None).
        schema: Schema the table is in. If not provided, the connection's
            default is used (Default value = None).
        chunk_size: Number of rows per uploaded Parquet file. If not provided,
            all rows go into a single file (Default value = None).
        compression: Compression used for the Parquet files; only "gzip" or
            "snappy" are accepted (Default value = 'gzip').
        on_error: Action to take when the COPY INTO statement fails; see
            https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#copy-options-copyoptions
            (Default value = 'abort_statement').
        parallel: Number of threads used when uploading chunks; see
            https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters
            (Default value = 4).
        quote_identifiers: By default, identifiers (database, schema, table and
            column names from ``df.columns``) are double-quoted. If False, they
            are passed to Snowflake unquoted, so Snowflake will coerce them to
            uppercase (Default value = True).
        auto_create_table: If True, create the table (``IF NOT EXISTS``) with a
            column for each DataFrame column, typed from the staged files.
        create_temp_table: If True, the auto-created table is a temporary table.

    Raises:
        ProgrammingError: if ``database`` is given without ``schema``, or
            ``compression`` is invalid.
    """
    cursor: SnowflakeCursor = conn.cursor()
    try:
        stage_name = create_temporary_sfc_stage(cursor)
        upload_df(df, cursor, stage_name, chunk_size, parallel, compression)
        copy_uploaded_data_to_table(
            cursor,
            stage_name,
            list(df.columns),
            table_name,
            database,
            schema,
            compression,
            on_error,
            quote_identifiers,
            auto_create_table,
            create_temp_table,
        )
    finally:
        # On success copy_uploaded_data_to_table has already closed this cursor.
        # This also releases it when any step raises.
        # NOTE(review): relies on SnowflakeCursor.close() being a no-op on an
        # already-closed cursor.
        cursor.close()
def write_parquet(
    conn: SnowflakeConnection,
    path: Path,
    dataset_schema: pyarrow.Schema,
    table_name: str,
    database: Optional[str] = None,
    schema: Optional[str] = None,
    compression: str = "gzip",
    on_error: str = "abort_statement",
    parallel: int = 4,
    quote_identifiers: bool = True,
    auto_create_table: bool = False,
    create_temp_table: bool = False,
):
    """Load a local directory of Parquet files into a Snowflake table.

    Every entry of ``path`` is PUT into a freshly created temporary stage. The
    staged data is then copied into the table with ``COPY INTO``, selecting the
    columns named in ``dataset_schema``.

    Args:
        conn: Connection to be used to communicate with Snowflake.
        path: Local directory whose entries are uploaded (see ``upload_local_pq``).
        dataset_schema: Arrow schema of the files. Its field names are the
            columns that are copied.
        table_name: Name of the table to insert into.
        database: Database the schema and table are in (Default value = None).
        schema: Schema the table is in (Default value = None).
        compression: "gzip" or "snappy"; selects the COMPRESSION option used by
            COPY INTO (Default value = 'gzip').
        on_error: COPY INTO ON_ERROR option (Default value = 'abort_statement').
        parallel: PUT PARALLEL option (Default value = 4).
        quote_identifiers: Whether to double-quote identifiers (Default value = True).
        auto_create_table: If True, create the table (``IF NOT EXISTS``) from
            the staged files' inferred schema.
        create_temp_table: If True, the auto-created table is a temporary table.

    Raises:
        ProgrammingError: if ``database`` is given without ``schema``, or
            ``compression`` is invalid.
    """
    cursor: SnowflakeCursor = conn.cursor()
    try:
        stage_name = create_temporary_sfc_stage(cursor)
        columns = [field.name for field in dataset_schema]
        upload_local_pq(path, cursor, stage_name, parallel)
        copy_uploaded_data_to_table(
            cursor,
            stage_name,
            columns,
            table_name,
            database,
            schema,
            compression,
            on_error,
            quote_identifiers,
            auto_create_table,
            create_temp_table,
        )
    finally:
        # Release the cursor even when a step raises. On success it has already
        # been closed by copy_uploaded_data_to_table.
        # NOTE(review): assumes SnowflakeCursor.close() is safe to call twice.
        cursor.close()
def copy_uploaded_data_to_table(
    cursor: SnowflakeCursor,
    stage_name: str,
    columns: List[str],
    table_name: str,
    database: Optional[str] = None,
    schema: Optional[str] = None,
    compression: str = "gzip",
    on_error: str = "abort_statement",
    quote_identifiers: bool = True,
    auto_create_table: bool = False,
    create_temp_table: bool = False,
):
    """Copy the Parquet files staged in ``stage_name`` into a table.

    Optionally auto-creates the table first. It uses a temporary file format
    and ``INFER_SCHEMA`` to type the columns, creates the table with
    ``CREATE [TEMP] TABLE IF NOT EXISTS``, and always drops the file format again.

    Args:
        cursor: Cursor used for every statement. On success it is closed at the
            end, because the COPY INTO result cursor is this same cursor.
        stage_name: Name of the stage holding the uploaded files.
        columns: Column names to copy, in table order.
        table_name: Target table.
        database: Database qualifier; requires ``schema`` when given.
        schema: Schema qualifier.
        compression: "gzip" or "snappy".
        on_error: COPY INTO ON_ERROR option.
        quote_identifiers: Whether to double-quote database/schema/table/column names.
        auto_create_table: Whether to create the table from the inferred schema.
        create_temp_table: Whether the auto-created table is TEMP.

    Raises:
        ProgrammingError: if ``database`` is given without ``schema`` or
            ``compression`` is not one of the supported values.
        SnowflakeQueryUnknownError: if ``cursor.execute`` returns ``None`` for
            the INFER_SCHEMA or COPY INTO statement.
    """
    if database is not None and schema is None:
        raise ProgrammingError(
            "Schema has to be provided to write_pandas when a database is provided"
        )
    # This dictionary maps the compression algorithm to Snowflake put copy into command type
    # https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#type-parquet
    compression_map = {"gzip": "auto", "snappy": "snappy"}
    if compression not in compression_map:
        raise ProgrammingError(
            "Invalid compression '{}', only acceptable values are: {}".format(
                compression, compression_map.keys()
            )
        )

    quote = '"' if quote_identifiers else ""
    # Database and schema prefixes are only added when truthy.
    location = (
        (f"{quote}{database}{quote}." if database else "")
        + (f"{quote}{schema}{quote}." if schema else "")
        + f"{quote}{table_name}{quote}"
    )

    if quote_identifiers:
        quoted_columns = '"' + '","'.join(columns) + '"'
    else:
        quoted_columns = ",".join(columns)

    if auto_create_table:
        file_format_name = create_file_format(compression, compression_map, cursor)
        try:
            infer_schema_sql = f"SELECT COLUMN_NAME, TYPE FROM table(infer_schema(location=>'@\"{stage_name}\"', file_format=>'{file_format_name}'))"
            logger.debug(f"inferring schema with '{infer_schema_sql}'")
            result_cursor = cursor.execute(infer_schema_sql, _is_internal=True)
            if result_cursor is None:
                raise SnowflakeQueryUnknownError(infer_schema_sql)
            result = cast(List[Tuple[str, str]], result_cursor.fetchall())
            column_type_mapping: Dict[str, str] = dict(result)
            # Infer schema can return the columns out of order depending on the chunking we do when uploading
            # so we have to iterate through the dataframe columns to make sure we create the table with its
            # columns in order
            create_table_columns = ", ".join(
                [f"{quote}{c}{quote} {column_type_mapping[c]}" for c in columns]
            )
            create_table_sql = (
                f"CREATE {'TEMP ' if create_temp_table else ''}TABLE IF NOT EXISTS {location} "
                f"({create_table_columns})"
                f" /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
            )
            logger.debug(f"auto creating table with '{create_table_sql}'")
            cursor.execute(create_table_sql, _is_internal=True)
        finally:
            # create_file_format issues CREATE FILE FORMAT without TEMPORARY, so
            # drop it even when inference or table creation fails; otherwise
            # the object would be left behind in the schema.
            drop_file_format_sql = f"DROP FILE FORMAT IF EXISTS {file_format_name}"
            logger.debug(f"dropping file format with '{drop_file_format_sql}'")
            cursor.execute(drop_file_format_sql, _is_internal=True)

    # in Snowflake, all parquet data is stored in a single column, $1, so we must select columns explicitly
    # see (https://docs.snowflake.com/en/user-guide/script-data-load-transform-parquet.html)
    if quote_identifiers:
        parquet_columns = "$1:" + ",$1:".join(f'"{c}"' for c in columns)
    else:
        parquet_columns = "$1:" + ",$1:".join(columns)
    copy_into_sql = (
        "COPY INTO {location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
        "({columns}) "
        'FROM (SELECT {parquet_columns} FROM @"{stage_name}") '
        "FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression}) "
        "PURGE=TRUE ON_ERROR={on_error}"
    ).format(
        location=location,
        columns=quoted_columns,
        parquet_columns=parquet_columns,
        stage_name=stage_name,
        compression=compression_map[compression],
        on_error=on_error,
    )
    logger.debug("copying into with '{}'".format(copy_into_sql))
    # Snowflake returns the original cursor if the query execution succeeded.
    result_cursor = cursor.execute(copy_into_sql, _is_internal=True)
    if result_cursor is None:
        raise SnowflakeQueryUnknownError(copy_into_sql)
    result_cursor.close()
def upload_df(
    df: pd.DataFrame,
    cursor: SnowflakeCursor,
    stage_name: str,
    chunk_size: Optional[int] = None,
    parallel: int = 4,
    compression: str = "gzip",
):
    """Dump ``df`` to local Parquet files, one per chunk, and PUT each into a stage.

    Files are written to a temporary directory and deleted right after upload.

    Args:
        df: DataFrame to upload.
        cursor: Cursor used to communicate with Snowflake.
        stage_name: Name of the target stage in the Snowflake connection.
        chunk_size: Number of rows per file. If not provided, all rows go into
            a single file (Default value = None).
        parallel: Number of threads used when uploading chunks; see
            https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters
            (Default value = 4).
        compression: Parquet compression codec passed to ``DataFrame.to_parquet``.
            The callers in this file restrict it to "gzip" or "snappy"
            (Default value = 'gzip').
    """
    if chunk_size is None:
        # One chunk for the whole frame. max(..., 1) keeps the step positive for
        # an empty frame. Otherwise chunk_helper would evaluate range(0, 0, 0)
        # and raise ValueError. An empty frame now simply uploads no files.
        chunk_size = max(len(df), 1)
    with TemporaryDirectory() as tmp_folder:
        for i, chunk in chunk_helper(df, chunk_size):
            chunk_path = os.path.join(tmp_folder, "file{}.txt".format(i))
            # Dump chunk into parquet file
            chunk.to_parquet(
                chunk_path,
                compression=compression,
                use_deprecated_int96_timestamps=True,
            )
            # Upload parquet file; escape backslashes and quotes for the SQL literal.
            upload_sql = (
                "PUT /* Python:feast.infra.utils.snowflake_utils.upload_df() */ "
                "'file://{path}' @\"{stage_name}\" PARALLEL={parallel}"
            ).format(
                path=chunk_path.replace("\\", "\\\\").replace("'", "\\'"),
                stage_name=stage_name,
                parallel=parallel,
            )
            logger.debug(f"uploading files with '{upload_sql}'")
            cursor.execute(upload_sql, _is_internal=True)
            # Remove chunk file so at most one chunk is on disk at a time.
            os.remove(chunk_path)
def upload_local_pq(
    path: Path, cursor: SnowflakeCursor, stage_name: str, parallel: int = 4,
):
    """PUT every entry of the local directory ``path`` into the stage ``stage_name``.

    Entries are uploaded one PUT statement at a time, in ``Path.iterdir()`` order.

    Args:
        path: Path to the Parquet dataset directory on disk.
        cursor: Cursor used to communicate with Snowflake.
        stage_name: Name of the target stage in the Snowflake connection.
        parallel: Number of threads used when uploading; see
            https://docs.snowflake.com/en/sql-reference/sql/put.html#optional-parameters
            (Default value = 4).
    """
    for entry in path.iterdir():
        # Escape backslashes first, then single quotes, for the quoted file URL.
        escaped_path = str(entry).replace("\\", "\\\\").replace("'", "\\'")
        put_statement = (
            "PUT /* Python:feast.infra.utils.snowflake_utils.upload_local_pq() */ "
            f"'file://{escaped_path}' @\"{stage_name}\" PARALLEL={parallel}"
        )
        logger.debug(f"uploading files with '{put_statement}'")
        cursor.execute(put_statement, _is_internal=True)
@retry(
    wait=wait_exponential(multiplier=1, max=4),
    retry=retry_if_exception_type(ProgrammingError),
    stop=stop_after_attempt(5),
    reraise=True,
)
def create_file_format(
    compression: str, compression_map: Dict[str, str], cursor: SnowflakeCursor
) -> str:
    """Create a Parquet file format with a random five-letter name and return that name.

    The returned name already includes its surrounding double quotes. On any
    ``ProgrammingError`` the whole function is retried with a fresh random
    name, up to 5 attempts with exponential backoff. The last error is
    re-raised. NOTE(review): the retry presumably guards against name
    collisions — confirm.

    Args:
        compression: Key into ``compression_map``.
        compression_map: Maps ``compression`` to the Snowflake COMPRESSION value.
        cursor: Cursor used to issue the CREATE FILE FORMAT statement.
    """
    letters = [random.choice(string.ascii_lowercase) for _ in range(5)]
    file_format_name = '"{}"'.format("".join(letters))
    codec = compression_map[compression]
    file_format_sql = (
        "CREATE FILE FORMAT {} "
        "/* Python:snowflake.connector.pandas_tools.write_pandas() */ "
        "TYPE=PARQUET COMPRESSION={}"
    ).format(file_format_name, codec)
    logger.debug(f"creating file format with '{file_format_sql}'")
    cursor.execute(file_format_sql, _is_internal=True)
    return file_format_name
@retry(
    wait=wait_exponential(multiplier=1, max=4),
    retry=retry_if_exception_type(ProgrammingError),
    stop=stop_after_attempt(5),
    reraise=True,
)
def create_temporary_sfc_stage(cursor: SnowflakeCursor) -> str:
    """Create a temporary stage with a random five-letter name and return the bare name.

    The name is double-quoted in the SQL, but returned without quotes. On any
    ``ProgrammingError`` the function is retried with a fresh name, up to 5
    attempts with exponential backoff. The last error is re-raised.

    Raises:
        SnowflakeQueryUnknownError: if ``cursor.execute`` returns ``None``.
    """
    letters = [random.choice(string.ascii_lowercase) for _ in range(5)]
    stage_name = "".join(letters)
    create_stage_sql = (
        "create temporary stage /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
        f'"{stage_name}"'
    )
    logger.debug(f"creating stage with '{create_stage_sql}'")
    executed = cursor.execute(create_stage_sql, _is_internal=True)
    if executed is not None:
        # Drain the statement's result set before handing the cursor back.
        executed.fetchall()
        return stage_name
    raise SnowflakeQueryUnknownError(create_stage_sql)
def chunk_helper(lst: pd.DataFrame, n: int) -> Iterator[Tuple[int, pd.DataFrame]]:
    """Yield ``(chunk_index, chunk)`` pairs of consecutive ``n``-row slices of ``lst``.

    This works like calling ``enumerate`` over the chunks. The last chunk may be
    shorter than ``n``, and an empty ``lst`` yields nothing.
    """
    chunk_starts = range(0, len(lst), n)
    for chunk_index, start in enumerate(chunk_starts):
        yield chunk_index, lst[start : start + n]
def parse_private_key_path(
    key_path: str, private_key_passphrase: Optional[str] = None
) -> bytes:
    """Load a PEM private key from disk and return it as unencrypted PKCS#8 DER bytes.

    Args:
        key_path: Path to the PEM-encoded private key file.
        private_key_passphrase: Passphrase protecting the key. Pass ``None`` or
            an empty string for an unencrypted key. The cryptography library
            requires ``password=None`` for unencrypted keys, and the previous
            unconditional ``.encode()`` crashed on ``None``.

    Returns:
        The key serialized as DER / PKCS#8 with no encryption.
    """
    password = private_key_passphrase.encode() if private_key_passphrase else None
    with open(key_path, "rb") as key:
        pem_bytes = key.read()
    p_key = serialization.load_pem_private_key(
        pem_bytes,
        password=password,
        backend=default_backend(),
    )
    return p_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )