@@ -352,13 +352,36 @@ def persist(
    ):
        """
        Run the retrieval and persist the results in the same offline store used for read.
        Note that persisting is done only within the scope of the Spark session when the warehouse directory is local.
        """
        assert isinstance(storage, SavedDatasetSparkStorage)
        table_name = storage.spark_options.table
        if not table_name:
            raise ValueError("Cannot persist, table_name is not defined")
        if self._has_remote_warehouse_in_config():
            file_format = storage.spark_options.file_format
            if not file_format:
                self.to_spark_df().write.saveAsTable(table_name)
            else:
                self.to_spark_df().write.format(file_format).saveAsTable(table_name)
        else:
            self.to_spark_df().createOrReplaceTempView(table_name)
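
# --- Illustrative sketch (not part of the diff above): the two persistence paths
# --- that persist() chooses between, shown with a plain local SparkSession.
# --- The names `spark` and `df` are assumptions made only for this example.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").appName("persist-sketch").getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])

# saveAsTable writes the data through the catalog; with a remote warehouse or a
# Hive metastore, the table outlives this Spark session.
df.write.mode("overwrite").saveAsTable("example_saved_table")

# createOrReplaceTempView only registers the DataFrame under a name for SQL
# queries; it is scoped to this session and disappears when the session stops.
df.createOrReplaceTempView("example_temp_view")
spark.sql("SELECT * FROM example_temp_view").show()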

    def _has_remote_warehouse_in_config(self) -> bool:
        """
        Check whether the Spark session config defines a Hive metastore URI,
        or a warehouse directory that is not a local path.
        """
        try:
            # A configured Hive metastore URI implies a remote, persistent warehouse.
            self.spark_session.conf.get("hive.metastore.uris")
            return True
        except Exception:
            # No metastore URI set; fall back to inspecting the warehouse directory.
            # A file: path means a local, session-scoped warehouse.
            warehouse_dir = self.spark_session.conf.get("spark.sql.warehouse.dir")
            if warehouse_dir and warehouse_dir.startswith("file:"):
                return False
            else:
                return True
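
# --- Illustrative sketch (not part of the diff above): how the heuristic in
# --- _has_remote_warehouse_in_config behaves on a default local session.
# --- `spark` is an assumed local SparkSession created just for this example.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").appName("warehouse-check-sketch").getOrCreate()

# On a plain local session the warehouse dir is usually a file: URI such as
# file:/.../spark-warehouse, so the check above would return False.
print(spark.conf.get("spark.sql.warehouse.dir"))

# hive.metastore.uris is normally unset here, so conf.get raises an exception,
# which is why the method wraps the lookup in try/except.
try:
    print(spark.conf.get("hive.metastore.uris"))
except Exception:
    print("hive.metastore.uris not configured -> no remote Hive metastore")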

    def supports_remote_storage_export(self) -> bool:
        return self._config.offline_store.staging_location is not None