Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 88 additions & 43 deletions sdk/python/feast/infra/materialization/aws_lambda/lambda_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Callable, List, Literal, Optional, Sequence, Union

import boto3
from botocore.config import Config
from pydantic import StrictStr
from tqdm import tqdm

Expand All @@ -33,6 +34,8 @@
from feast.version import get_version

DEFAULT_BATCH_SIZE = 10_000
DEFAULT_TIMEOUT = 600
LAMBDA_TIMEOUT_RETRIES = 5

logger = logging.getLogger(__name__)

Expand All @@ -52,11 +55,16 @@ class LambdaMaterializationEngineConfig(FeastConfigBaseModel):

@dataclass
class LambdaMaterializationJob(MaterializationJob):
def __init__(
    self,
    job_id: str,
    status: MaterializationJobStatus,
    error: Optional[BaseException] = None,
) -> None:
    """Create a materialization job record.

    Args:
        job_id: Identifier stored on the job.
        status: Status value later returned by ``status()``.
        error: Optional exception describing why the job failed.
            Defaults to ``None``, so existing two-argument callers
            keep working.
    """
    super().__init__()
    self._job_id: str = job_id
    self._status = status
    # Keep the caller-supplied error instead of unconditionally
    # resetting it to None.
    self._error = error

def status(self) -> MaterializationJobStatus:
    """Return the status value stored on this job."""
    current_status = self._status
    return current_status
Expand Down Expand Up @@ -97,7 +105,7 @@ def update(
PackageType="Image",
Role=self.repo_config.batch_engine.lambda_role,
Code={"ImageUri": self.repo_config.batch_engine.materialization_image},
Timeout=600,
Timeout=DEFAULT_TIMEOUT,
Tags={
"feast-owned": "True",
"project": project,
Expand Down Expand Up @@ -149,7 +157,8 @@ def __init__(
self.lambda_name = f"feast-materialize-{self.repo_config.project}"
if len(self.lambda_name) > 64:
self.lambda_name = self.lambda_name[:64]
self.lambda_client = boto3.client("lambda")
config = Config(read_timeout=DEFAULT_TIMEOUT + 10)
self.lambda_client = boto3.client("lambda", config=config)

def materialize(
self, registry, tasks: List[MaterializationTask]
Expand Down Expand Up @@ -200,47 +209,83 @@ def _materialize_one(
)

paths = offline_job.to_remote_storage()
max_workers = len(paths) if len(paths) <= 20 else 20
executor = ThreadPoolExecutor(max_workers=max_workers)
futures = []

for path in paths:
payload = {
FEATURE_STORE_YAML_ENV_NAME: self.feature_store_base64,
"view_name": feature_view.name,
"view_type": "batch",
"path": path,
}
# Invoke a lambda to materialize this file.

logger.info("Invoking materialization for %s", path)
futures.append(
executor.submit(
self.lambda_client.invoke,
FunctionName=self.lambda_name,
InvocationType="RequestResponse",
Payload=json.dumps(payload),
)
if (num_files := len(paths)) == 0:
logger.warning("No values to update for the given time range.")
return LambdaMaterializationJob(
job_id=job_id, status=MaterializationJobStatus.SUCCEEDED
)
else:
max_workers = num_files if num_files <= 20 else 20
executor = ThreadPoolExecutor(max_workers=max_workers)
futures = []

for path in paths:
payload = {
FEATURE_STORE_YAML_ENV_NAME: self.feature_store_base64,
"view_name": feature_view.name,
"view_type": "batch",
"path": path,
}
# Invoke a lambda to materialize this file.

logger.info("Invoking materialization for %s", path)
futures.append(
executor.submit(
self.invoke_with_retries,
FunctionName=self.lambda_name,
InvocationType="RequestResponse",
Payload=json.dumps(payload),
)
)

done, not_done = wait(futures)
logger.info("Done: %s Not Done: %s", done, not_done)
for f in done:
response = f.result()
output = json.loads(response["Payload"].read())
done, not_done = wait(futures)
logger.info("Done: %s Not Done: %s", done, not_done)
errors = []
for f in done:
response, payload = f.result()

logger.info(
f"Ingested task; request id {response['ResponseMetadata']['RequestId']}, "
f"Output: {output}"
)
logger.info(
f"Ingested task; request id {response['ResponseMetadata']['RequestId']}, "
f"Output: {payload}"
)
if "errorMessage" in payload.keys():
errors.append(payload["errorMessage"])

for f in not_done:
response = f.result()
logger.error(f"Ingestion failed: {response}")
for f in not_done:
response, payload = f.result()
logger.error(f"Ingestion failed: {response=}, {payload=}")

return LambdaMaterializationJob(
job_id=job_id,
status=MaterializationJobStatus.SUCCEEDED
if not not_done
else MaterializationJobStatus.ERROR,
)
if len(not_done) == 0 and len(errors) == 0:
return LambdaMaterializationJob(
job_id=job_id, status=MaterializationJobStatus.SUCCEEDED
)
else:
return LambdaMaterializationJob(
job_id=job_id,
status=MaterializationJobStatus.ERROR,
error=RuntimeError(
f"Lambda functions did not finish successfully: {errors}"
),
)

def invoke_with_retries(self, **kwargs):
    """Invoke the Lambda function and retry if it times out.

    The Lambda function may time out initially if many values are updated
    and DynamoDB throttles requests. As soon as the DynamoDB tables
    are scaled up, the Lambda function can succeed upon retry with higher
    throughput.

    At most ``LAMBDA_TIMEOUT_RETRIES`` invocations are made. Only a
    payload whose ``errorMessage`` contains ``"Task timed out after"``
    triggers another attempt; any other outcome (success or a different
    error) is returned immediately.

    Args:
        **kwargs: Forwarded unchanged to ``self.lambda_client.invoke``.

    Returns:
        A ``(response, payload)`` tuple for the last attempt: the raw
        invoke response and its JSON-decoded payload (``{}`` when the
        payload decodes to a falsy value such as ``null``).
    """
    for attempt in range(1, LAMBDA_TIMEOUT_RETRIES + 1):
        response = self.lambda_client.invoke(**kwargs)
        payload = json.loads(response["Payload"].read()) or {}

        # A Lambda handler may return any JSON value; only a dict can
        # carry an "errorMessage", so anything else is not a timeout.
        error_message = (
            payload.get("errorMessage", "") if isinstance(payload, dict) else ""
        )
        if "Task timed out after" not in error_message:
            break

        request_id = response["ResponseMetadata"]["RequestId"]
        if attempt < LAMBDA_TIMEOUT_RETRIES:
            logger.warning(
                "Retrying lambda function after lambda timeout in request %s",
                request_id,
            )
        else:
            # No attempts left: do not claim a retry that will not happen.
            logger.warning(
                "Lambda function timed out in request %s; giving up after %d attempts",
                request_id,
                LAMBDA_TIMEOUT_RETRIES,
            )
    return response, payload
Loading