Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/python/feast/feature_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,11 @@ def async_refresh():

@asynccontextmanager
async def lifespan(app: FastAPI):
await store.initialize()
async_refresh()
yield
stop_refresh()
await store.close()

app = FastAPI(lifespan=lifespan)

Expand Down
8 changes: 8 additions & 0 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2157,6 +2157,14 @@ def list_saved_datasets(
self.project, allow_cache=allow_cache, tags=tags
)

async def initialize(self) -> None:
"""Initialize long-lived clients and/or resources needed for accessing datastores"""
await self._get_provider().initialize(self.config)

async def close(self) -> None:
"""Cleanup any long-lived clients and/or resources"""
await self._get_provider().close()


def _print_materialization_log(
start_date, end_date, num_feature_views: int, online_store: str
Expand Down
74 changes: 56 additions & 18 deletions sdk/python/feast/infra/online_stores/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import contextlib
import itertools
import logging
from datetime import datetime
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union

from aiobotocore.config import AioConfig
from pydantic import StrictBool, StrictStr

from feast import Entity, FeatureView, utils
Expand Down Expand Up @@ -75,6 +77,9 @@ class DynamoDBOnlineStoreConfig(FeastConfigBaseModel):
session_based_auth: bool = False
"""AWS session based client authentication"""

max_pool_connections: int = 10
"""Max number of connections for async Dynamodb operations"""


class DynamoDBOnlineStore(OnlineStore):
"""
Expand All @@ -87,7 +92,14 @@ class DynamoDBOnlineStore(OnlineStore):

_dynamodb_client = None
_dynamodb_resource = None
_aioboto_session = None

async def initialize(self, config: RepoConfig):
await _get_aiodynamodb_client(
config.online_store.region, config.online_store.max_pool_connections
)

async def close(self):
await _aiodynamodb_close()

@property
def async_supported(self) -> SupportedAsyncMethods:
Expand Down Expand Up @@ -326,15 +338,17 @@ def to_tbl_resp(raw_client_response):
batches.append(batch)
entity_id_batches.append(entity_id_batch)

async with self._get_aiodynamodb_client(online_config.region) as client:
response_batches = await asyncio.gather(
*[
client.batch_get_item(
RequestItems=entity_id_batch,
)
for entity_id_batch in entity_id_batches
]
)
client = await _get_aiodynamodb_client(
online_config.region, online_config.max_pool_connections
)
Comment on lines -329 to +343
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we're now reusing the client and doing context management via initialize and close

response_batches = await asyncio.gather(
*[
client.batch_get_item(
RequestItems=entity_id_batch,
)
for entity_id_batch in entity_id_batches
]
)

result_batches = []
for batch, response in zip(batches, response_batches):
Expand All @@ -349,14 +363,6 @@ def to_tbl_resp(raw_client_response):

return list(itertools.chain(*result_batches))

def _get_aioboto_session(self):
if self._aioboto_session is None:
self._aioboto_session = session.get_session()
return self._aioboto_session

def _get_aiodynamodb_client(self, region: str):
return self._get_aioboto_session().create_client("dynamodb", region_name=region)

def _get_dynamodb_client(
self,
region: str,
Expand Down Expand Up @@ -489,6 +495,38 @@ def _to_client_batch_get_payload(online_config, table_name, batch):
}


_aioboto_session = None
_aioboto_client = None


def _get_aioboto_session():
global _aioboto_session
if _aioboto_session is None:
logger.debug("initializing the aiobotocore session")
_aioboto_session = session.get_session()
return _aioboto_session


async def _get_aiodynamodb_client(region: str, max_pool_connections: int):
global _aioboto_client
if _aioboto_client is None:
logger.debug("initializing the aiobotocore dynamodb client")
client_context = _get_aioboto_session().create_client(
"dynamodb",
region_name=region,
config=AioConfig(max_pool_connections=max_pool_connections),
)
context_stack = contextlib.AsyncExitStack()
_aioboto_client = await context_stack.enter_async_context(client_context)
return _aioboto_client


async def _aiodynamodb_close():
global _aioboto_client
if _aioboto_client:
await _aioboto_client.close()


def _initialize_dynamodb_client(
region: str,
endpoint_url: Optional[str] = None,
Expand Down
6 changes: 6 additions & 0 deletions sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,3 +422,9 @@ def retrieve_online_documents(
raise NotImplementedError(
f"Online store {self.__class__.__name__} does not support online retrieval"
)

async def initialize(self, config: RepoConfig) -> None:
pass

async def close(self) -> None:
pass
6 changes: 6 additions & 0 deletions sdk/python/feast/infra/passthrough_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,3 +518,9 @@ def get_table_column_names_and_types_from_data_source(
return self.offline_store.get_table_column_names_and_types_from_data_source(
config=config, data_source=data_source
)

async def initialize(self, config: RepoConfig) -> None:
await self.online_store.initialize(config)

async def close(self) -> None:
await self.online_store.close()
8 changes: 8 additions & 0 deletions sdk/python/feast/infra/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,14 @@ def get_table_column_names_and_types_from_data_source(
"""
pass

@abstractmethod
async def initialize(self, config: RepoConfig) -> None:
pass

@abstractmethod
async def close(self) -> None:
pass


def get_provider(config: RepoConfig) -> Provider:
if "." not in config.provider:
Expand Down
4 changes: 4 additions & 0 deletions sdk/python/pytest.ini
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
[pytest]
asyncio_mode = auto

markers =
universal_offline_stores: mark a test as using all offline stores.
universal_online_stores: mark a test as using all online stores.
Expand All @@ -7,6 +9,8 @@ env =
IS_TEST=True

filterwarnings =
error::_pytest.warning_types.PytestConfigWarning
error::_pytest.warning_types.PytestUnhandledCoroutineWarning
Comment on lines +12 to +13
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fail if any async test functions are skipped bc of missing plugins

ignore::DeprecationWarning:pyspark.sql.pandas.*:
ignore::DeprecationWarning:pyspark.sql.connect.*:
ignore::DeprecationWarning:httpx.*:
Expand Down
Loading