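This PR extends Feast's Spark offline store so that `entity_df` may be a `pyspark.sql.DataFrame`, in addition to the previously supported `pandas.DataFrame` and SQL-string forms. The diff threads the new type through the `get_historical_features` signature and the three helpers that consume `entity_df`.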
@@ -125,7 +125,7 @@ def get_historical_features(
     config: RepoConfig,
     feature_views: List[FeatureView],
     feature_refs: List[str],
-    entity_df: Union[pandas.DataFrame, str],
+    entity_df: Union[pandas.DataFrame, str, pyspark.sql.DataFrame],
     registry: Registry,
     project: str,
     full_feature_names: bool = False,
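For context, a hedged sketch of how a caller could exercise the widened signature through Feast's public `FeatureStore` API; the repo path, entity column, and feature reference below are hypothetical placeholders.

```python
from datetime import datetime

from feast import FeatureStore
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
store = FeatureStore(repo_path=".")  # assumes a repo configured with the Spark offline store

# Build the entity dataframe directly in Spark instead of pandas.
entity_df = spark.createDataFrame(
    [("driver_1", datetime(2022, 5, 1, 12, 0))],
    ["driver_id", "event_timestamp"],
)

job = store.get_historical_features(
    entity_df=entity_df,  # previously limited to pandas.DataFrame or a SQL string
    features=["driver_hourly_stats:avg_daily_trips"],  # hypothetical feature reference
)
training_df = job.to_df()
```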
@@ -473,15 +473,16 @@ def _get_entity_df_event_timestamp_range(
             entity_df_event_timestamp.min().to_pydatetime(),
             entity_df_event_timestamp.max().to_pydatetime(),
         )
-    elif isinstance(entity_df, str):
+    elif isinstance(entity_df, str) or isinstance(entity_df, pyspark.sql.DataFrame):
         # If the entity_df is a string (SQL query), determine range
         # from table
-        df = spark_session.sql(entity_df).select(entity_df_event_timestamp_col)
-
-        # Checks if executing entity sql resulted in any data
-        if df.rdd.isEmpty():
-            raise EntitySQLEmptyResults(entity_df)
-
+        if isinstance(entity_df, str):
+            df = spark_session.sql(entity_df).select(entity_df_event_timestamp_col)
+            # Checks if executing entity sql resulted in any data
+            if df.rdd.isEmpty():
+                raise EntitySQLEmptyResults(entity_df)
+        else:
+            df = entity_df
         # TODO(kzhang132): need utc conversion here.
 
         entity_df_event_timestamp_range = (
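The hunk truncates just before the aggregation, but both branches are arranged to leave a Spark DataFrame in `df` from which the range is derived. An illustrative sketch of that final step (not necessarily the file's exact code):

```python
from pyspark.sql import functions as F

# One aggregation pass over the entity dataframe yields both endpoints.
row = df.agg(
    F.min(entity_df_event_timestamp_col).alias("min_ts"),
    F.max(entity_df_event_timestamp_col).alias("max_ts"),
).collect()[0]
entity_df_event_timestamp_range = (row["min_ts"], row["max_ts"])
```

Note that the pre-existing TODO about UTC conversion now also covers the new DataFrame branch, which bypasses the SQL path entirely.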
@@ -499,8 +500,11 @@ def _get_entity_schema(
 ) -> Dict[str, np.dtype]:
     if isinstance(entity_df, pd.DataFrame):
         return dict(zip(entity_df.columns, entity_df.dtypes))
-    elif isinstance(entity_df, str):
-        entity_spark_df = spark_session.sql(entity_df)
+    elif isinstance(entity_df, str) or isinstance(entity_df, pyspark.sql.DataFrame):
+        if isinstance(entity_df, str):
+            entity_spark_df = spark_session.sql(entity_df)
+        else:
+            entity_spark_df = entity_df
         return dict(
             zip(
                 entity_spark_df.columns,
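This helper works because a Spark DataFrame, whether produced by `spark_session.sql(...)` or passed in directly, exposes its schema as parallel `columns` and `dtypes` sequences. A small illustration of those attributes (the truncated lines above presumably convert Spark's type strings into numpy dtypes):

```python
spark_df = spark.createDataFrame([(1, "a")], ["id", "val"])
print(spark_df.columns)  # ['id', 'val']
print(spark_df.dtypes)   # [('id', 'bigint'), ('val', 'string')]
```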
@@ -526,6 +530,9 @@ def _upload_entity_df(
     elif isinstance(entity_df, str):
         spark_session.sql(entity_df).createOrReplaceTempView(table_name)
         return
+    elif isinstance(entity_df, pyspark.sql.DataFrame):
+        entity_df.createOrReplaceTempView(table_name)
+        return
     else:
         raise InvalidEntityType(type(entity_df))
 
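The new branch mirrors the SQL branch: both finish by registering a temp view under `table_name`, so the downstream point-in-time join SQL can reference every flavor of `entity_df` through a single table name. A minimal sketch of that mechanism (the view name here is illustrative; the real `table_name` is generated by the caller):

```python
# Registering a DataFrame as a temp view makes it addressable from Spark SQL.
entity_df.createOrReplaceTempView("feast_entity_df")
spark_session.sql("SELECT COUNT(*) FROM feast_entity_df").show()
```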