Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -482,21 +482,50 @@ def _to_arrow_via_staging(self) -> pyarrow.Table:
if not parquet_paths:
return pyarrow.table({})

normalized_paths = self._normalize_staging_paths(parquet_paths)
dataset = ds.dataset(normalized_paths, format="parquet")
pa_fs, stripped_paths = self._resolve_staging_filesystem(parquet_paths)
dataset = ds.dataset(stripped_paths, format="parquet", filesystem=pa_fs)
return dataset.to_table()

def _normalize_staging_paths(self, paths: List[str]) -> List[str]:
"""Normalize staging paths for PyArrow datasets."""
def _resolve_staging_filesystem(
self, paths: List[str]
) -> Tuple[Optional[pyarrow.fs.FileSystem], List[str]]:
"""Return (pyarrow filesystem, prefix-stripped paths) for staging URIs."""
sample = paths[0]

import pyarrow.fs as pafs

if sample.startswith("s3://") or sample.startswith("s3a://"):
endpoint = os.environ.get("AWS_ENDPOINT_URL_S3") or os.environ.get(
"AWS_S3_ENDPOINT", ""
)
region = getattr(
self._config.offline_store, "region", None
) or os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
kwargs: Dict[str, Any] = {"region": region}
if endpoint:
kwargs["endpoint_override"] = (
endpoint.rstrip("/")
.removeprefix("https://")
.removeprefix("http://")
)
kwargs["scheme"] = "https" if endpoint.startswith("https") else "http"
fs = pafs.S3FileSystem(**kwargs)
stripped = [p.removeprefix("s3a://").removeprefix("s3://") for p in paths]
return fs, stripped

if sample.startswith("gs://"):
fs = pafs.GcsFileSystem()
stripped = [p[len("gs://") :] for p in paths]
return fs, stripped

# Local paths
normalized = []
for path in paths:
if path.startswith("file://"):
normalized.append(path[len("file://") :])
elif "://" in path:
normalized.append(path)
for p in paths:
if p.startswith("file://"):
normalized.append(p[len("file://") :])
else:
normalized.append(path)
return normalized
normalized.append(p)
return None, normalized

def to_feast_df(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
Unit tests for SparkRetrievalJob._resolve_staging_filesystem.

Verifies that the correct PyArrow filesystem and prefix-stripped paths
are returned for S3, S3A, GCS, file://, and plain local paths.
"""

from unittest.mock import MagicMock, patch

import pytest

from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
SparkRetrievalJob,
)


@pytest.fixture()
def retrieval_job():
"""Minimal SparkRetrievalJob with a mock config that has no offline_store region."""
job = object.__new__(SparkRetrievalJob)
config = MagicMock()
config.offline_store.region = None
job._config = config
return job


class TestResolveS3Filesystem:
def test_s3_scheme_returns_s3_filesystem(self, retrieval_job):
with patch("pyarrow.fs.S3FileSystem") as mock_s3:
mock_s3.return_value = MagicMock(name="s3fs")
fs, paths = retrieval_job._resolve_staging_filesystem(
["s3://my-bucket/path/a.parquet", "s3://my-bucket/path/b.parquet"]
)
mock_s3.assert_called_once()
assert fs is mock_s3.return_value
assert paths == ["my-bucket/path/a.parquet", "my-bucket/path/b.parquet"]

def test_s3a_scheme_strips_prefix(self, retrieval_job):
with patch("pyarrow.fs.S3FileSystem") as mock_s3:
mock_s3.return_value = MagicMock(name="s3fs")
fs, paths = retrieval_job._resolve_staging_filesystem(
["s3a://bucket/dir/file.parquet"]
)
assert paths == ["bucket/dir/file.parquet"]

def test_s3_with_minio_endpoint(self, retrieval_job, monkeypatch):
monkeypatch.setenv("AWS_ENDPOINT_URL_S3", "http://minio.local:9000")
monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1")
with patch("pyarrow.fs.S3FileSystem") as mock_s3:
mock_s3.return_value = MagicMock(name="s3fs")
retrieval_job._resolve_staging_filesystem(["s3://bucket/file.parquet"])
call_kwargs = mock_s3.call_args[1]
assert call_kwargs["endpoint_override"] == "minio.local:9000"
assert call_kwargs["scheme"] == "http"

def test_s3_with_https_endpoint(self, retrieval_job, monkeypatch):
monkeypatch.setenv("AWS_ENDPOINT_URL_S3", "https://s3.custom.corp")
with patch("pyarrow.fs.S3FileSystem") as mock_s3:
mock_s3.return_value = MagicMock(name="s3fs")
retrieval_job._resolve_staging_filesystem(["s3://bucket/file.parquet"])
call_kwargs = mock_s3.call_args[1]
assert call_kwargs["endpoint_override"] == "s3.custom.corp"
assert call_kwargs["scheme"] == "https"

def test_s3_falls_back_to_aws_s3_endpoint_env(self, retrieval_job, monkeypatch):
monkeypatch.delenv("AWS_ENDPOINT_URL_S3", raising=False)
monkeypatch.setenv("AWS_S3_ENDPOINT", "http://legacy-minio:9000")
with patch("pyarrow.fs.S3FileSystem") as mock_s3:
mock_s3.return_value = MagicMock(name="s3fs")
retrieval_job._resolve_staging_filesystem(["s3://bucket/file.parquet"])
call_kwargs = mock_s3.call_args[1]
assert "endpoint_override" in call_kwargs

def test_s3_no_endpoint_no_override(self, retrieval_job, monkeypatch):
monkeypatch.delenv("AWS_ENDPOINT_URL_S3", raising=False)
monkeypatch.delenv("AWS_S3_ENDPOINT", raising=False)
with patch("pyarrow.fs.S3FileSystem") as mock_s3:
mock_s3.return_value = MagicMock(name="s3fs")
retrieval_job._resolve_staging_filesystem(["s3://bucket/file.parquet"])
call_kwargs = mock_s3.call_args[1]
assert "endpoint_override" not in call_kwargs
assert "scheme" not in call_kwargs

def test_s3_region_from_offline_store_config(self, retrieval_job):
retrieval_job._config.offline_store.region = "eu-west-1"
with patch("pyarrow.fs.S3FileSystem") as mock_s3:
mock_s3.return_value = MagicMock(name="s3fs")
retrieval_job._resolve_staging_filesystem(["s3://bucket/file.parquet"])
call_kwargs = mock_s3.call_args[1]
assert call_kwargs["region"] == "eu-west-1"

def test_s3_region_fallback_to_env(self, retrieval_job, monkeypatch):
retrieval_job._config.offline_store.region = None
monkeypatch.setenv("AWS_DEFAULT_REGION", "ap-southeast-1")
with patch("pyarrow.fs.S3FileSystem") as mock_s3:
mock_s3.return_value = MagicMock(name="s3fs")
retrieval_job._resolve_staging_filesystem(["s3://bucket/file.parquet"])
call_kwargs = mock_s3.call_args[1]
assert call_kwargs["region"] == "ap-southeast-1"


class TestResolveGCSFilesystem:
def test_gs_scheme_returns_gcs_filesystem(self, retrieval_job):
with patch("pyarrow.fs.GcsFileSystem") as mock_gcs:
mock_gcs.return_value = MagicMock(name="gcsfs")
fs, paths = retrieval_job._resolve_staging_filesystem(
["gs://my-bucket/path/a.parquet", "gs://my-bucket/path/b.parquet"]
)
mock_gcs.assert_called_once()
assert fs is mock_gcs.return_value
assert paths == ["my-bucket/path/a.parquet", "my-bucket/path/b.parquet"]


class TestResolveLocalFilesystem:
def test_file_scheme_stripped(self, retrieval_job):
fs, paths = retrieval_job._resolve_staging_filesystem(
["file:///tmp/staging/a.parquet"]
)
assert fs is None
assert paths == ["/tmp/staging/a.parquet"]

def test_plain_local_path_unchanged(self, retrieval_job):
fs, paths = retrieval_job._resolve_staging_filesystem(
["/tmp/staging/a.parquet", "/tmp/staging/b.parquet"]
)
assert fs is None
assert paths == ["/tmp/staging/a.parquet", "/tmp/staging/b.parquet"]

def test_mixed_file_and_plain_paths(self, retrieval_job):
fs, paths = retrieval_job._resolve_staging_filesystem(
["file:///tmp/a.parquet", "/tmp/b.parquet"]
)
assert fs is None
assert paths == ["/tmp/a.parquet", "/tmp/b.parquet"]
Loading