Merged
186 changes: 126 additions & 60 deletions sdk/python/feast/infra/online_stores/contrib/postgres.py
@@ -4,6 +4,7 @@
from datetime import datetime
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Generator,
@@ -12,18 +13,24 @@
Optional,
Sequence,
Tuple,
Union,
)

import pytz
-from psycopg import sql
+from psycopg import AsyncConnection, sql
 from psycopg.connection import Connection
-from psycopg_pool import ConnectionPool
+from psycopg_pool import AsyncConnectionPool, ConnectionPool

from feast import Entity
from feast.feature_view import FeatureView
from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key
from feast.infra.online_stores.online_store import OnlineStore
-from feast.infra.utils.postgres.connection_utils import _get_conn, _get_connection_pool
+from feast.infra.utils.postgres.connection_utils import (
+    _get_conn,
+    _get_conn_async,
+    _get_connection_pool,
+    _get_connection_pool_async,
+)
from feast.infra.utils.postgres.postgres_config import ConnectionType, PostgreSQLConfig
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
@@ -51,6 +58,9 @@ class PostgreSQLOnlineStore(OnlineStore):
_conn: Optional[Connection] = None
_conn_pool: Optional[ConnectionPool] = None

_conn_async: Optional[AsyncConnection] = None
_conn_pool_async: Optional[AsyncConnectionPool] = None

@contextlib.contextmanager
def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:
assert config.online_store.type == "postgres"
@@ -67,6 +77,24 @@ def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:
self._conn = _get_conn(config.online_store)
yield self._conn

@contextlib.asynccontextmanager
async def _get_conn_async(
self, config: RepoConfig
) -> AsyncGenerator[AsyncConnection, Any]:
if config.online_store.conn_type == ConnectionType.pool:
if not self._conn_pool_async:
self._conn_pool_async = await _get_connection_pool_async(
config.online_store
)
await self._conn_pool_async.open()
connection = await self._conn_pool_async.getconn()
yield connection
await self._conn_pool_async.putconn(connection)
else:
if not self._conn_async:
self._conn_async = await _get_conn_async(config.online_store)
yield self._conn_async
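
For reference (an annotation, not part of the diff): a minimal sketch of how this async context manager would be consumed. `repo_config` is assumed to be a loaded `RepoConfig` whose online store is this PostgreSQL store; the query is a stand-in.

```python
import asyncio

from feast.infra.online_stores.contrib.postgres import PostgreSQLOnlineStore

async def ping_online_store(repo_config) -> None:
    store = PostgreSQLOnlineStore()
    # Depending on conn_type, this yields a pooled connection or a cached
    # singleton AsyncConnection.
    async with store._get_conn_async(repo_config) as conn:
        async with conn.cursor() as cur:
            await cur.execute("SELECT 1")
            print(await cur.fetchone())  # -> (1,)

# asyncio.run(ping_online_store(repo_config))
```

Note that the pool branch calls `putconn` only after a clean exit from the `yield`; a `try`/`finally` would also return the connection if the consumer raises.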

def online_write_batch(
self,
config: RepoConfig,
@@ -135,69 +163,107 @@ def online_read(
entity_keys: List[EntityKeyProto],
requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
-        result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
+        keys = self._prepare_keys(entity_keys, config.entity_key_serialization_version)
+        query, params = self._construct_query_and_params(
+            config, table, keys, requested_features
+        )

-        project = config.project
         with self._get_conn(config) as conn, conn.cursor() as cur:
-            # Collecting all the keys to a list allows us to make fewer round trips
-            # to PostgreSQL
-            keys = []
-            for entity_key in entity_keys:
-                keys.append(
-                    serialize_entity_key(
-                        entity_key,
-                        entity_key_serialization_version=config.entity_key_serialization_version,
-                    )
-                )
+            cur.execute(query, params)
+            rows = cur.fetchall()

-            if not requested_features:
-                cur.execute(
-                    sql.SQL(
-                        """
-                        SELECT entity_key, feature_name, value, event_ts
-                        FROM {} WHERE entity_key = ANY(%s);
-                        """
-                    ).format(
-                        sql.Identifier(_table_id(project, table)),
-                    ),
-                    (keys,),
-                )
-            else:
-                cur.execute(
-                    sql.SQL(
-                        """
-                        SELECT entity_key, feature_name, value, event_ts
-                        FROM {} WHERE entity_key = ANY(%s) and feature_name = ANY(%s);
-                        """
-                    ).format(
-                        sql.Identifier(_table_id(project, table)),
-                    ),
-                    (keys, requested_features),
-                )
+        return self._process_rows(keys, rows)

-            rows = cur.fetchall()
+    async def online_read_async(
+        self,
+        config: RepoConfig,
+        table: FeatureView,
+        entity_keys: List[EntityKeyProto],
+        requested_features: Optional[List[str]] = None,
+    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
+        keys = self._prepare_keys(entity_keys, config.entity_key_serialization_version)
+        query, params = self._construct_query_and_params(
+            config, table, keys, requested_features
+        )

-            # Since we don't know the order returned from PostgreSQL we'll need
-            # to construct a dict to be able to quickly look up the correct row
-            # when we iterate through the keys since they are in the correct order
-            values_dict = defaultdict(list)
-            for row in rows if rows is not None else []:
-                values_dict[
-                    row[0] if isinstance(row[0], bytes) else row[0].tobytes()
-                ].append(row[1:])

-            for key in keys:
-                if key in values_dict:
-                    value = values_dict[key]
-                    res = {}
-                    for feature_name, value_bin, event_ts in value:
-                        val = ValueProto()
-                        val.ParseFromString(bytes(value_bin))
-                        res[feature_name] = val
-                    result.append((event_ts, res))
-                else:
-                    result.append((None, None))
+        async with self._get_conn_async(config) as conn:
+            async with conn.cursor() as cur:
+                await cur.execute(query, params)
+                rows = await cur.fetchall()

+        return self._process_rows(keys, rows)

Contributor Author (on _construct_query_and_params): No new logic, just moved to a separate method so it can be re-used.

    @staticmethod
    def _construct_query_and_params(
config: RepoConfig,
table: FeatureView,
keys: List[bytes],
requested_features: Optional[List[str]] = None,
) -> Tuple[sql.Composed, Union[Tuple[List[bytes], List[str]], Tuple[List[bytes]]]]:
"""Construct the SQL query based on the given parameters."""
if requested_features:
query = sql.SQL(
"""
SELECT entity_key, feature_name, value, event_ts
FROM {} WHERE entity_key = ANY(%s) AND feature_name = ANY(%s);
"""
).format(
sql.Identifier(_table_id(config.project, table)),
)
params = (keys, requested_features)
else:
query = sql.SQL(
"""
SELECT entity_key, feature_name, value, event_ts
FROM {} WHERE entity_key = ANY(%s);
"""
).format(
sql.Identifier(_table_id(config.project, table)),
)
params = (keys, [])
return query, params
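
As an aside (not in the diff): because the table name is interpolated with `sql.Identifier`, it is safely quoted, while `keys` and `requested_features` travel as `%s` parameters. A sketch of inspecting the composed query; the table id and conninfo are made up.

```python
import psycopg
from psycopg import sql

query = sql.SQL(
    "SELECT entity_key, feature_name, value, event_ts FROM {} WHERE entity_key = ANY(%s);"
).format(sql.Identifier("test_project_driver_hourly_stats"))  # hypothetical _table_id() output

with psycopg.connect("dbname=feast user=feast") as conn:  # assumed conninfo
    print(query.as_string(conn))
    # SELECT entity_key, feature_name, value, event_ts
    # FROM "test_project_driver_hourly_stats" WHERE entity_key = ANY(%s);
```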

Contributor Author (on _prepare_keys): No new logic, just moved to a separate method so it can be re-used.

    @staticmethod
    def _prepare_keys(
entity_keys: List[EntityKeyProto], entity_key_serialization_version: int
) -> List[bytes]:
"""Prepare all keys in a list to make fewer round trips to the database."""
return [
serialize_entity_key(
entity_key,
entity_key_serialization_version=entity_key_serialization_version,
)
for entity_key in entity_keys
]
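
To make the inputs concrete (illustration only, field values invented): the helper serializes each `EntityKeyProto` once, preserving order, so the results can be matched back to rows later.

```python
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto

entity_key = EntityKeyProto(
    join_keys=["driver_id"],                     # hypothetical join key
    entity_values=[ValueProto(int64_val=1001)],  # hypothetical entity value
)
keys = PostgreSQLOnlineStore._prepare_keys(
    [entity_key], entity_key_serialization_version=2
)
assert isinstance(keys[0], bytes)  # one serialized key per entity key, in order
```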

Contributor Author (on _process_rows): No new logic, just moved to a separate method so it can be re-used.

    @staticmethod
    def _process_rows(
keys: List[bytes], rows: List[Tuple]
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
"""Transform the retrieved rows in the desired output.

PostgreSQL may return rows in an unpredictable order. Therefore, `values_dict`
is created to quickly look up the correct row using the keys, since these are
actually in the correct order.
"""
values_dict = defaultdict(list)
for row in rows if rows is not None else []:
values_dict[
row[0] if isinstance(row[0], bytes) else row[0].tobytes()
].append(row[1:])

result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
for key in keys:
if key in values_dict:
value = values_dict[key]
res = {}
for feature_name, value_bin, event_ts in value:
val = ValueProto()
val.ParseFromString(bytes(value_bin))
res[feature_name] = val
result.append((event_ts, res))
else:
result.append((None, None))
return result
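
A small worked example (annotation only; the key bytes, feature name, and timestamp are invented) showing both branches: a key with rows and a key without.

```python
from datetime import datetime

from feast.protos.feast.types.Value_pb2 import Value as ValueProto

val = ValueProto(double_val=0.5)
rows = [(b"key-1", "conv_rate", val.SerializeToString(), datetime(2024, 1, 1))]

result = PostgreSQLOnlineStore._process_rows([b"key-1", b"key-2"], rows)
assert result[0][0] == datetime(2024, 1, 1)          # event_ts returned for key-1
assert result[0][1]["conv_rate"].double_val == 0.5   # value deserialized from bytes
assert result[1] == (None, None)                     # key-2 matched no rows
```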

def update(
25 changes: 23 additions & 2 deletions sdk/python/feast/infra/utils/postgres/connection_utils.py
@@ -4,8 +4,8 @@
import pandas as pd
import psycopg
import pyarrow as pa
-from psycopg.connection import Connection
-from psycopg_pool import ConnectionPool
+from psycopg import AsyncConnection, Connection
+from psycopg_pool import AsyncConnectionPool, ConnectionPool

from feast.infra.utils.postgres.postgres_config import PostgreSQLConfig
from feast.type_map import arrow_to_pg_type
@@ -21,6 +21,16 @@ def _get_conn(config: PostgreSQLConfig) -> Connection:
return conn


async def _get_conn_async(config: PostgreSQLConfig) -> AsyncConnection:
"""Get a psycopg `AsyncConnection`."""
conn = await psycopg.AsyncConnection.connect(
conninfo=_get_conninfo(config),
keepalives_idle=config.keepalives_idle,
**_get_conn_kwargs(config),
)
return conn
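
A minimal usage sketch (not in the diff), assuming `config` is a populated `PostgreSQLConfig`:

```python
import asyncio

async def check_connection(config) -> None:
    conn = await _get_conn_async(config)
    try:
        async with conn.cursor() as cur:
            await cur.execute("SELECT version();")
            print(await cur.fetchone())
    finally:
        await conn.close()  # AsyncConnection.close() is a coroutine

# asyncio.run(check_connection(config))
```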


def _get_connection_pool(config: PostgreSQLConfig) -> ConnectionPool:
"""Get a psycopg `ConnectionPool`."""
return ConnectionPool(
@@ -32,6 +42,17 @@ def _get_connection_pool(config: PostgreSQLConfig) -> ConnectionPool:
)


async def _get_connection_pool_async(config: PostgreSQLConfig) -> AsyncConnectionPool:
"""Get a psycopg `AsyncConnectionPool`."""
return AsyncConnectionPool(
conninfo=_get_conninfo(config),
min_size=config.min_conn,
max_size=config.max_conn,
open=False,
kwargs=_get_conn_kwargs(config),
)
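
Because the pool is created with `open=False`, no connections exist until the caller awaits `pool.open()`, which is what `PostgreSQLOnlineStore._get_conn_async` does above. A standalone lifecycle sketch with a made-up conninfo:

```python
import asyncio

from psycopg_pool import AsyncConnectionPool

async def main() -> None:
    pool = AsyncConnectionPool(
        conninfo="dbname=feast user=feast",  # hypothetical conninfo
        min_size=1,
        max_size=4,
        open=False,  # defer connecting
    )
    await pool.open()            # connections are established here
    conn = await pool.getconn()  # check a connection out...
    try:
        async with conn.cursor() as cur:
            await cur.execute("SELECT 1")
    finally:
        await pool.putconn(conn)  # ...and always hand it back
    await pool.close()

asyncio.run(main())
```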


def _get_conninfo(config: PostgreSQLConfig) -> str:
"""Get the `conninfo` argument required for connection objects."""
return (
@@ -488,7 +488,7 @@ def test_online_retrieval_with_event_timestamps(environment, universal_data_sources


@pytest.mark.integration
-@pytest.mark.universal_online_stores(only=["redis", "dynamodb"])
+@pytest.mark.universal_online_stores(only=["redis", "dynamodb", "postgres"])
def test_async_online_retrieval_with_event_timestamps(
environment, universal_data_sources
):
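
Adding "postgres" to the marker enables the existing async retrieval test for this store. In rough outline (a sketch, assuming `store` is a configured `FeatureStore` and that `get_online_features_async` is the async entry point; the feature reference and entity row are made up):

```python
import asyncio

async def read_features(store) -> None:
    response = await store.get_online_features_async(
        features=["driver_hourly_stats:conv_rate"],  # hypothetical feature reference
        entity_rows=[{"driver_id": 1001}],           # hypothetical entity row
    )
    print(response.to_dict())

# asyncio.run(read_features(store))
```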