feat: Bump psycopg2 to psycopg3 for all Postgres components #4303
Changes from all commits

---
```diff
@@ -389,3 +389,13 @@ def __init__(self, input_dict: dict):
         super().__init__(
             f"Failed to serialize the provided dictionary into a pandas DataFrame: {input_dict.keys()}"
         )
+
+
+class ZeroRowsQueryResult(Exception):
+    def __init__(self, query: str):
+        super().__init__(f"This query returned zero rows:\n{query}")
+
+
+class ZeroColumnQueryResult(Exception):
+    def __init__(self, query: str):
+        super().__init__(f"This query returned zero columns:\n{query}")
```
**Review comment (Contributor, Author) on lines +394 to +401:** Exceptions to use for stricter handling of type hints of …
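For illustration, a minimal sketch of how these two exceptions can guard cursor results; the `fetch_first_row` helper is hypothetical, not part of the PR:

```python
from feast.errors import ZeroColumnQueryResult, ZeroRowsQueryResult


def fetch_first_row(cur, query: str):
    # Hypothetical helper mirroring the guards this PR adds around cursors.
    cur.execute(query)
    if not cur.description:  # None for statements that return no columns
        raise ZeroColumnQueryResult(query)
    row = cur.fetchone()
    if not row:  # fetchone() returns None on an empty result set
        raise ZeroRowsQueryResult(query)
    return row
```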
---
```diff
@@ -19,11 +19,11 @@
 import pandas as pd
 import pyarrow as pa
 from jinja2 import BaseLoader, Environment
-from psycopg2 import sql
+from psycopg import sql
 from pytz import utc

 from feast.data_source import DataSource
-from feast.errors import InvalidEntityType
+from feast.errors import InvalidEntityType, ZeroColumnQueryResult, ZeroRowsQueryResult
 from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView
 from feast.infra.offline_stores import offline_utils
 from feast.infra.offline_stores.contrib.postgres_offline_store.postgres_source import (
```
```diff
@@ -274,8 +274,10 @@ def to_sql(self) -> str:
     def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table:
         with self._query_generator() as query:
             with _get_conn(self.config.offline_store) as conn, conn.cursor() as cur:
-                conn.set_session(readonly=True)
+                conn.read_only = True
                 cur.execute(query)
+                if not cur.description:
+                    raise ZeroColumnQueryResult(query)
                 fields = [
                     (c.name, pg_type_code_to_arrow(c.type_code))
                     for c in cur.description
```
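`conn.read_only = True` is the psycopg3 replacement for psycopg2's `conn.set_session(readonly=True)`. A minimal sketch, assuming a reachable database (the DSN is a placeholder); note that `read_only` can only be changed while no transaction is open, so it is set before the first statement:

```python
import psycopg

with psycopg.connect("dbname=feast user=postgres") as conn:  # placeholder DSN
    # psycopg2: conn.set_session(readonly=True)
    # psycopg3: read_only is a plain connection attribute.
    conn.read_only = True
    with conn.cursor() as cur:
        cur.execute("SELECT 1")
        # cur.description is None for non-row-returning statements,
        # which is exactly what ZeroColumnQueryResult guards against.
        assert cur.description is not None
```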
```diff
@@ -331,16 +333,19 @@ def _get_entity_df_event_timestamp_range(
             entity_df_event_timestamp.max().to_pydatetime(),
         )
     elif isinstance(entity_df, str):
-        # If the entity_df is a string (SQL query), determine range
-        # from table
+        # If the entity_df is a string (SQL query), determine range from table
         with _get_conn(config.offline_store) as conn, conn.cursor() as cur:
-            (
-                cur.execute(
-                    f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max FROM ({entity_df}) as tmp_alias"
-                ),
-            )
+            query = f"""
+                SELECT
+                    MIN({entity_df_event_timestamp_col}) AS min,
+                    MAX({entity_df_event_timestamp_col}) AS max
+                FROM ({entity_df}) AS tmp_alias
+            """
```
**Review comment (Contributor, Author) on lines +338 to +343:** No updates here, only re-formatting the query.
```diff
+            cur.execute(query)
             res = cur.fetchone()
-            entity_df_event_timestamp_range = (res[0], res[1])
+            if not res:
+                raise ZeroRowsQueryResult(query)
+            entity_df_event_timestamp_range = (res[0], res[1])
     else:
         raise InvalidEntityType(type(entity_df))
```
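One nuance worth spelling out: a MIN/MAX aggregate always returns exactly one row (possibly all NULLs), yet `Cursor.fetchone()` is typed `Optional[...]`, so the new guard both satisfies type checkers and fails loudly if a driver ever returns nothing. A sketch of the pattern with illustrative names:

```python
from typing import Tuple

from feast.errors import ZeroRowsQueryResult


def get_timestamp_range(cur, ts_col: str, source_sql: str) -> Tuple:
    # Mirrors the PR's range query; ts_col and source_sql are illustrative.
    query = f"""
        SELECT
            MIN({ts_col}) AS min,
            MAX({ts_col}) AS max
        FROM ({source_sql}) AS tmp_alias
    """
    cur.execute(query)
    res = cur.fetchone()
    if not res:  # fetchone() -> Optional[tuple]; guard before indexing
        raise ZeroRowsQueryResult(query)
    return (res[0], res[1])
```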
---
```diff
@@ -2,13 +2,22 @@
 import logging
 from collections import defaultdict
 from datetime import datetime
-from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Tuple,
+)

-import psycopg2
 import pytz
-from psycopg2 import sql
-from psycopg2.extras import execute_values
-from psycopg2.pool import SimpleConnectionPool
+from psycopg import sql
+from psycopg.connection import Connection
+from psycopg_pool import ConnectionPool

 from feast import Entity
 from feast.feature_view import FeatureView
```
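A packaging note behind these import changes: psycopg3 is distributed as `psycopg` (with an optional `binary` extra), and pooling moved out of the core into the separate `psycopg-pool` distribution, imported as `psycopg_pool`. Assuming both are installed:

```python
# pip install "psycopg[binary]" psycopg-pool
import psycopg        # psycopg3 core; the module is "psycopg", not "psycopg3"
import psycopg_pool   # connection pooling now lives in its own package

print(psycopg.__version__)  # e.g. "3.1.x"
```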
```diff
@@ -39,15 +48,17 @@ class PostgreSQLOnlineStoreConfig(PostgreSQLConfig):


 class PostgreSQLOnlineStore(OnlineStore):
-    _conn: Optional[psycopg2._psycopg.connection] = None
-    _conn_pool: Optional[SimpleConnectionPool] = None
+    _conn: Optional[Connection] = None
+    _conn_pool: Optional[ConnectionPool] = None

     @contextlib.contextmanager
-    def _get_conn(self, config: RepoConfig):
+    def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:
         assert config.online_store.type == "postgres"

         if config.online_store.conn_type == ConnectionType.pool:
             if not self._conn_pool:
                 self._conn_pool = _get_connection_pool(config.online_store)
+                self._conn_pool.open()
             connection = self._conn_pool.getconn()
             yield connection
             self._conn_pool.putconn(connection)
```
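For context, a minimal sketch of the `psycopg_pool` lifecycle that the new `_get_conn` relies on. Unlike psycopg2's `SimpleConnectionPool`, a `ConnectionPool` must be opened before use (hence the added `self._conn_pool.open()`); the DSN below is a placeholder:

```python
from psycopg_pool import ConnectionPool

pool = ConnectionPool("dbname=feast user=postgres", min_size=1, open=False)
pool.open()  # mirrors the explicit open() added in this PR

conn = pool.getconn()  # borrow a connection from the pool
try:
    with conn.cursor() as cur:
        cur.execute("SELECT 1")
finally:
    pool.putconn(conn)  # always hand it back

pool.close()
```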
```diff
@@ -64,57 +75,56 @@ def online_write_batch(
             Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
         ],
         progress: Optional[Callable[[int], Any]],
+        batch_size: int = 5000,
```
**Review comment (Contributor, Author):** Make configurable, addressing #4036.
```diff
     ) -> None:
-        project = config.project
+        # Format insert values
+        insert_values = []
+        for entity_key, values, timestamp, created_ts in data:
+            entity_key_bin = serialize_entity_key(
+                entity_key,
+                entity_key_serialization_version=config.entity_key_serialization_version,
+            )
+            timestamp = _to_naive_utc(timestamp)
+            if created_ts is not None:
+                created_ts = _to_naive_utc(created_ts)

-        with self._get_conn(config) as conn, conn.cursor() as cur:
-            insert_values = []
-            for entity_key, values, timestamp, created_ts in data:
-                entity_key_bin = serialize_entity_key(
-                    entity_key,
-                    entity_key_serialization_version=config.entity_key_serialization_version,
-                )
-                timestamp = _to_naive_utc(timestamp)
-                if created_ts is not None:
-                    created_ts = _to_naive_utc(created_ts)
-
-                for feature_name, val in values.items():
-                    vector_val = None
-                    if config.online_store.pgvector_enabled:
-                        vector_val = get_list_val_str(val)
-                    insert_values.append(
-                        (
-                            entity_key_bin,
-                            feature_name,
-                            val.SerializeToString(),
-                            vector_val,
-                            timestamp,
-                            created_ts,
-                        )
-                    )
-            # Control the batch so that we can update the progress
-            batch_size = 5000
+            for feature_name, val in values.items():
+                vector_val = None
+                if config.online_store.pgvector_enabled:
+                    vector_val = get_list_val_str(val)
+                insert_values.append(
+                    (
+                        entity_key_bin,
+                        feature_name,
+                        val.SerializeToString(),
+                        vector_val,
+                        timestamp,
+                        created_ts,
+                    )
+                )
+
+        # Create insert query
+        sql_query = sql.SQL(
+            """
+            INSERT INTO {}
+            (entity_key, feature_name, value, vector_value, event_ts, created_ts)
+            VALUES (%s, %s, %s, %s, %s, %s)
```
**Review comment (Contributor, Author):** 1 out of 2 actual changes to the function: we need to explicitly set the number of placeholder values.
```diff
+            ON CONFLICT (entity_key, feature_name) DO
+            UPDATE SET
+                value = EXCLUDED.value,
+                vector_value = EXCLUDED.vector_value,
+                event_ts = EXCLUDED.event_ts,
+                created_ts = EXCLUDED.created_ts;
+            """
+        ).format(sql.Identifier(_table_id(config.project, table)))
+
+        # Push data in batches to online store
```
**Review comment (Contributor, Author) on lines +80 to +121:** No changes here, only moving code further up in the function to make it more readable.
```diff
+        with self._get_conn(config) as conn, conn.cursor() as cur:
             for i in range(0, len(insert_values), batch_size):
                 cur_batch = insert_values[i : i + batch_size]
-                execute_values(
-                    cur,
-                    sql.SQL(
-                        """
-                        INSERT INTO {}
-                        (entity_key, feature_name, value, vector_value, event_ts, created_ts)
-                        VALUES %s
-                        ON CONFLICT (entity_key, feature_name) DO
-                        UPDATE SET
-                            value = EXCLUDED.value,
-                            vector_value = EXCLUDED.vector_value,
-                            event_ts = EXCLUDED.event_ts,
-                            created_ts = EXCLUDED.created_ts;
-                        """,
-                    ).format(sql.Identifier(_table_id(project, table))),
-                    cur_batch,
-                    page_size=batch_size,
-                )
+                cur.executemany(sql_query, cur_batch)
```
**Review comment (Contributor, Author):** 2 out of 2 actual changes to the function: The …
```diff
                 conn.commit()

                 if progress:
                     progress(len(cur_batch))
```
```diff
@@ -172,7 +182,9 @@ def online_read(
             # when we iterate through the keys since they are in the correct order
             values_dict = defaultdict(list)
             for row in rows if rows is not None else []:
-                values_dict[row[0].tobytes()].append(row[1:])
+                values_dict[
+                    row[0] if isinstance(row[0], bytes) else row[0].tobytes()
+                ].append(row[1:])
```
**Review comment (Contributor, Author) on lines +185 to +187:** Only call …
```diff

             for key in keys:
                 if key in values_dict:
```
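Background for this guard: psycopg2 returns `bytea` values as `memoryview`, while psycopg3 returns plain `bytes`, which has no `tobytes()` method. A tiny sketch of the normalization (the helper name is ours, not the PR's):

```python
def to_key_bytes(value) -> bytes:
    # psycopg2 hands back memoryview for bytea; psycopg3 hands back bytes.
    return value if isinstance(value, bytes) else value.tobytes()


assert to_key_bytes(b"abc") == b"abc"              # psycopg3 path
assert to_key_bytes(memoryview(b"abc")) == b"abc"  # psycopg2 path
```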