     Union,
     cast,
 )
+from urllib.parse import urlparse
 
 if TYPE_CHECKING:
     from feast.saved_dataset import ValidationReference
 import pandas
 import pandas as pd
 import pyarrow
+import pyarrow.dataset as ds
 import pyarrow.parquet as pq
 import pyspark
 from pydantic import StrictStr
@@ -445,8 +447,43 @@ def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
 
     def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
         """Return dataset as pyarrow Table synchronously"""
+        if self._should_use_staging_for_arrow():
+            return self._to_arrow_via_staging()
+
         return pyarrow.Table.from_pandas(self._to_df_internal(timeout=timeout))
 
+    def _should_use_staging_for_arrow(self) -> bool:
+        offline_store = getattr(self._config, "offline_store", None)
+        return bool(
+            isinstance(offline_store, SparkOfflineStoreConfig)
+            and getattr(offline_store, "staging_location", None)
+        )
+
+    def _to_arrow_via_staging(self) -> pyarrow.Table:
+        paths = self.to_remote_storage()
+        if not paths:
+            return pyarrow.table({})
+
+        parquet_paths = _filter_parquet_files(paths)
+        if not parquet_paths:
+            return pyarrow.table({})
+
+        normalized_paths = self._normalize_staging_paths(parquet_paths)
+        dataset = ds.dataset(normalized_paths, format="parquet")
+        return dataset.to_table()
+
+    def _normalize_staging_paths(self, paths: List[str]) -> List[str]:
+        """Normalize staging paths for PyArrow datasets."""
+        normalized = []
+        for path in paths:
+            if path.startswith("file://"):
+                normalized.append(path[len("file://") :])
+            elif "://" in path:
+                normalized.append(path)
+            else:
+                normalized.append(path)
+        return normalized
+
     def to_feast_df(
         self,
         validation_reference: Optional["ValidationReference"] = None,
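
For reviewers who want to see the new Arrow path end to end: a minimal, runnable sketch of the round trip `_to_arrow_via_staging` performs, using a local temporary directory to stand in for the staging location (the file names and the tiny write step are illustrative, not part of this change).

```python
import os
import tempfile

import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq

# Stage two small Parquet files to stand in for what to_remote_storage() returns.
staging = tempfile.mkdtemp()
for i in range(2):
    pq.write_table(
        pa.table({"driver_id": [i], "rate": [0.1 * i]}),
        os.path.join(staging, f"part-{i}.parquet"),
    )

# Spark also leaves non-data markers such as _SUCCESS next to the Parquet parts.
paths = [os.path.join(staging, name) for name in os.listdir(staging)]
paths.append(os.path.join(staging, "_SUCCESS"))

# Same steps as the helpers above: keep only .parquet files, strip file://, read as one table.
parquet_paths = [p for p in paths if p.endswith(".parquet")]
normalized = [p[len("file://") :] if p.startswith("file://") else p for p in parquet_paths]
table = ds.dataset(normalized, format="parquet").to_table()
print(table.num_rows)  # 2
```
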
@@ -508,55 +545,53 @@ def supports_remote_storage_export(self) -> bool:
 
     def to_remote_storage(self) -> List[str]:
510547 """Currently only works for local and s3-based staging locations"""
-        if self.supports_remote_storage_export():
-            sdf: pyspark.sql.DataFrame = self.to_spark_df()
-
-            if self._config.offline_store.staging_location.startswith("/"):
-                local_file_staging_location = os.path.abspath(
-                    self._config.offline_store.staging_location
-                )
-
-                # write to staging location
-                output_uri = os.path.join(
-                    str(local_file_staging_location), str(uuid.uuid4())
-                )
-                sdf.write.parquet(output_uri)
-
-                return _list_files_in_folder(output_uri)
-            elif self._config.offline_store.staging_location.startswith("s3://"):
-                from feast.infra.utils import aws_utils
-
-                spark_compatible_s3_staging_location = (
-                    self._config.offline_store.staging_location.replace(
-                        "s3://", "s3a://"
-                    )
-                )
-
-                # write to staging location
-                output_uri = os.path.join(
-                    str(spark_compatible_s3_staging_location), str(uuid.uuid4())
-                )
-                sdf.write.parquet(output_uri)
-
-                return aws_utils.list_s3_files(
-                    self._config.offline_store.region, output_uri
-                )
-            elif self._config.offline_store.staging_location.startswith("hdfs://"):
-                output_uri = os.path.join(
-                    self._config.offline_store.staging_location, str(uuid.uuid4())
-                )
-                sdf.write.parquet(output_uri)
-                spark_session = get_spark_session_or_start_new_with_repoconfig(
-                    store_config=self._config.offline_store
-                )
-                return _list_hdfs_files(spark_session, output_uri)
-            else:
-                raise NotImplementedError(
-                    "to_remote_storage is only implemented for file://, s3:// and hdfs:// uri schemes"
-                )
+        if not self.supports_remote_storage_export():
+            raise NotImplementedError()
+
+        sdf: pyspark.sql.DataFrame = self.to_spark_df()
+        staging_location = self._config.offline_store.staging_location
+
+        if staging_location.startswith("/"):
+            local_file_staging_location = os.path.abspath(staging_location)
+            output_uri = os.path.join(local_file_staging_location, str(uuid.uuid4()))
+            sdf.write.parquet(output_uri)
+            return _list_files_in_folder(output_uri)
+        elif staging_location.startswith("s3://"):
+            from feast.infra.utils import aws_utils
 
+            spark_compatible_s3_staging_location = staging_location.replace(
+                "s3://", "s3a://"
+            )
+            output_uri = os.path.join(
+                spark_compatible_s3_staging_location, str(uuid.uuid4())
+            )
+            sdf.write.parquet(output_uri)
+            s3_uri_for_listing = output_uri.replace("s3a://", "s3://", 1)
+            return aws_utils.list_s3_files(
+                self._config.offline_store.region, s3_uri_for_listing
+            )
+        elif staging_location.startswith("gs://"):
+            output_uri = os.path.join(staging_location, str(uuid.uuid4()))
+            sdf.write.parquet(output_uri)
+            return _list_gcs_files(output_uri)
+        elif staging_location.startswith(("wasbs://", "abfs://", "abfss://")) or (
+            staging_location.startswith("https://")
+            and ".blob.core.windows.net" in staging_location
+        ):
+            output_uri = os.path.join(staging_location, str(uuid.uuid4()))
+            sdf.write.parquet(output_uri)
+            return _list_azure_files(output_uri)
+        elif staging_location.startswith("hdfs://"):
+            output_uri = os.path.join(staging_location, str(uuid.uuid4()))
+            sdf.write.parquet(output_uri)
+            spark_session = get_spark_session_or_start_new_with_repoconfig(
+                store_config=self._config.offline_store
+            )
+            return _list_hdfs_files(spark_session, output_uri)
         else:
-            raise NotImplementedError()
+            raise NotImplementedError(
+                "to_remote_storage is only implemented for file://, s3://, gs://, azure, and hdfs uri schemes"
+            )
 
     @property
     def metadata(self) -> Optional[RetrievalMetadata]:
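
The s3:// to s3a:// round trip in the new s3 branch is the easy part to get wrong, so here is the same path arithmetic run standalone (pure string handling, no Spark or AWS calls; the bucket name is made up).

```python
import os
import uuid

staging_location = "s3://my-bucket/feast-staging"  # hypothetical bucket

# Spark writes through the Hadoop s3a connector, so the write URI uses s3a://.
spark_compatible = staging_location.replace("s3://", "s3a://")
output_uri = os.path.join(spark_compatible, str(uuid.uuid4()))

# The boto3-based listing helper expects a plain s3:// URI, so map it back before listing.
s3_uri_for_listing = output_uri.replace("s3a://", "s3://", 1)

assert output_uri.startswith("s3a://my-bucket/feast-staging/")
assert s3_uri_for_listing.startswith("s3://my-bucket/feast-staging/")
```
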
@@ -789,6 +824,10 @@ def _list_files_in_folder(folder):
     return files
 
 
+def _filter_parquet_files(paths: List[str]) -> List[str]:
+    return [path for path in paths if path.endswith(".parquet")]
+
+
 def _list_hdfs_files(spark_session: SparkSession, uri: str) -> List[str]:
     jvm = spark_session._jvm
     jsc = spark_session._jsc
@@ -805,6 +844,81 @@ def _list_hdfs_files(spark_session: SparkSession, uri: str) -> List[str]:
     return files
 
 
+def _list_gcs_files(path: str) -> List[str]:
+    try:
+        from google.cloud import storage
+    except ImportError as e:
+        from feast.errors import FeastExtrasDependencyImportError
+
+        raise FeastExtrasDependencyImportError("gcp", str(e))
+
+    assert path.startswith("gs://"), "GCS path must start with gs://"
+    bucket_path = path[len("gs://") :]
+    if "/" in bucket_path:
+        bucket, prefix = bucket_path.split("/", 1)
+    else:
+        bucket, prefix = bucket_path, ""
+
+    client = storage.Client()
+    bucket_obj = client.bucket(bucket)
+    blobs = bucket_obj.list_blobs(prefix=prefix)
+
+    files = []
+    for blob in blobs:
+        if not blob.name.endswith("/"):
+            files.append(f"gs://{bucket}/{blob.name}")
+    return files
+
+
+def _list_azure_files(path: str) -> List[str]:
+    try:
+        from azure.identity import DefaultAzureCredential
+        from azure.storage.blob import BlobServiceClient
+    except ImportError as e:
+        from feast.errors import FeastExtrasDependencyImportError
+
+        raise FeastExtrasDependencyImportError("azure", str(e))
+
+    parsed = urlparse(path)
+    scheme = parsed.scheme
+
+    if scheme in ("wasbs", "abfs", "abfss"):
+        if "@" not in parsed.netloc:
+            raise ValueError("Azure staging URI must include container@account host")
+        container, account_host = parsed.netloc.split("@", 1)
+        account_url = f"https://{account_host}"
+        base = f"{scheme}://{container}@{account_host}"
+        prefix = parsed.path.lstrip("/")
+    else:
+        account_url = f"{parsed.scheme}://{parsed.netloc}"
+        container_and_prefix = parsed.path.lstrip("/").split("/", 1)
+        container = container_and_prefix[0]
+        base = f"{account_url}/{container}"
+        prefix = container_and_prefix[1] if len(container_and_prefix) > 1 else ""
+
+    credential = os.environ.get("AZURE_STORAGE_KEY") or os.environ.get(
+        "AZURE_STORAGE_ACCOUNT_KEY"
+    )
+    if credential:
+        client = BlobServiceClient(account_url=account_url, credential=credential)
+    else:
+        default_credential = DefaultAzureCredential(
+            exclude_shared_token_cache_credential=True
+        )
+        client = BlobServiceClient(
+            account_url=account_url, credential=default_credential
+        )
+
+    container_client = client.get_container_client(container)
+    blobs = container_client.list_blobs(name_starts_with=prefix if prefix else None)
+
+    files = []
+    for blob in blobs:
+        if not blob.name.endswith("/"):
+            files.append(f"{base}/{blob.name}")
+    return files
+
+
 def _cast_data_frame(
     df_new: pyspark.sql.DataFrame, df_existing: pyspark.sql.DataFrame
 ) -> pyspark.sql.DataFrame:
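
To make the Azure URI handling in `_list_azure_files` easier to review, here is the container/account/prefix split run standalone against the two accepted URI shapes (account and container names are made up; no Azure SDK calls are made).

```python
from urllib.parse import urlparse


def split_azure_uri(path: str):
    """Mirror of the container / account_url / prefix split in _list_azure_files."""
    parsed = urlparse(path)
    if parsed.scheme in ("wasbs", "abfs", "abfss"):
        # wasbs://container@account.blob.core.windows.net/prefix
        container, account_host = parsed.netloc.split("@", 1)
        return container, f"https://{account_host}", parsed.path.lstrip("/")
    # https://account.blob.core.windows.net/container/prefix
    container, _, prefix = parsed.path.lstrip("/").partition("/")
    return container, f"{parsed.scheme}://{parsed.netloc}", prefix


print(split_azure_uri("wasbs://staging@myaccount.blob.core.windows.net/feast/out"))
# ('staging', 'https://myaccount.blob.core.windows.net', 'feast/out')
print(split_azure_uri("https://myaccount.blob.core.windows.net/staging/feast/out"))
# ('staging', 'https://myaccount.blob.core.windows.net', 'feast/out')
```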