@@ -497,10 +497,18 @@ def to_remote_storage(self) -> List[str]:
return aws_utils.list_s3_files(
self._config.offline_store.region, output_uri
)

elif self._config.offline_store.staging_location.startswith("hdfs://"):
output_uri = os.path.join(
self._config.offline_store.staging_location, str(uuid.uuid4())
)
sdf.write.parquet(output_uri)
spark_session = get_spark_session_or_start_new_with_repoconfig(
store_config=self._config.offline_store
)
return _list_hdfs_files(spark_session, output_uri)
else:
raise NotImplementedError(
"to_remote_storage is only implemented for file:// and s3:// uri schemes"
"to_remote_storage is only implemented for file://, s3:// and hdfs:// uri schemes"
)

else:
@@ -629,6 +637,22 @@ def _list_files_in_folder(folder):
return files


def _list_hdfs_files(spark_session: SparkSession, uri: str) -> List[str]:
    """List the files directly under ``uri`` using the Hadoop FileSystem API via the Spark JVM gateway."""
    jvm = spark_session._jvm
    jsc = spark_session._jsc
    if jvm is None or jsc is None:
        raise RuntimeError("Spark JVM or JavaSparkContext is not available")
    # Resolve the Hadoop FileSystem for the given URI from the active Hadoop configuration.
    conf = jsc.hadoopConfiguration()
    path = jvm.org.apache.hadoop.fs.Path(uri)
    fs = jvm.org.apache.hadoop.fs.FileSystem.get(path.toUri(), conf)
    statuses = fs.listStatus(path)
    # Keep only plain files; any nested directories are skipped.
    files = []
    for f in statuses:
        if f.isFile():
            files.append(f.getPath().toString())
    return files


def _cast_data_frame(
df_new: pyspark.sql.DataFrame, df_existing: pyspark.sql.DataFrame
) -> pyspark.sql.DataFrame:
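
For reference, a minimal usage sketch of the new helper (not part of the diff): it mirrors the hdfs:// branch of to_remote_storage by writing a DataFrame to a unique folder under a staging location and then listing the resulting part files. The hdfs://namenode:8020/tmp/feast_staging URI and the sample DataFrame are assumptions for illustration, and _list_hdfs_files is assumed to be importable from the module changed in this PR.

import os
import uuid

from pyspark.sql import SparkSession

# Assumed staging location; replace with the offline store's configured hdfs:// URI.
STAGING_LOCATION = "hdfs://namenode:8020/tmp/feast_staging"

spark = SparkSession.builder.appName("hdfs_listing_example").getOrCreate()

# Write a small DataFrame to a unique folder under the staging location,
# mirroring what to_remote_storage does for the hdfs:// scheme.
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])
output_uri = os.path.join(STAGING_LOCATION, str(uuid.uuid4()))
df.write.parquet(output_uri)

# List the parquet part files that were just written.
for file_uri in _list_hdfs_files(spark, output_uri):
    print(file_uri)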