Skip to content

Commit 11b765c

Browse files
style: ruff format spark.py
Signed-off-by: abhijeet-dhumal <abhijeetdhumal652@gmail.com>
1 parent c21053c commit 11b765c

2 files changed

Lines changed: 137 additions & 5 deletions

File tree

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -503,12 +503,10 @@ def _resolve_staging_filesystem(
503503
) or os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
504504
kwargs: Dict[str, Any] = {"region": region}
505505
if endpoint:
506-
kwargs["endpoint_override"] = endpoint.rstrip("/").replace(
507-
"https://", ""
508-
).replace("http://", "")
509-
kwargs["scheme"] = (
510-
"https" if endpoint.startswith("https") else "http"
506+
kwargs["endpoint_override"] = (
507+
endpoint.rstrip("/").replace("https://", "").replace("http://", "")
511508
)
509+
kwargs["scheme"] = "https" if endpoint.startswith("https") else "http"
512510
fs = pafs.S3FileSystem(**kwargs)
513511
stripped = [p.replace("s3a://", "").replace("s3://", "") for p in paths]
514512
return fs, stripped
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
"""
2+
Unit tests for SparkRetrievalJob._resolve_staging_filesystem.
3+
4+
Verifies that the correct PyArrow filesystem and prefix-stripped paths
5+
are returned for S3, S3A, GCS, file://, and plain local paths.
6+
"""
7+
8+
from unittest.mock import MagicMock, patch
9+
10+
import pytest
11+
12+
from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
13+
SparkRetrievalJob,
14+
)
15+
16+
17+
@pytest.fixture()
18+
def retrieval_job():
19+
"""Minimal SparkRetrievalJob with a mock config that has no offline_store region."""
20+
job = object.__new__(SparkRetrievalJob)
21+
config = MagicMock()
22+
config.offline_store.region = None
23+
job._config = config
24+
return job
25+
26+
27+
class TestResolveS3Filesystem:
28+
def test_s3_scheme_returns_s3_filesystem(self, retrieval_job):
29+
with patch("pyarrow.fs.S3FileSystem") as mock_s3:
30+
mock_s3.return_value = MagicMock(name="s3fs")
31+
fs, paths = retrieval_job._resolve_staging_filesystem(
32+
["s3://my-bucket/path/a.parquet", "s3://my-bucket/path/b.parquet"]
33+
)
34+
mock_s3.assert_called_once()
35+
assert fs is mock_s3.return_value
36+
assert paths == ["my-bucket/path/a.parquet", "my-bucket/path/b.parquet"]
37+
38+
def test_s3a_scheme_strips_prefix(self, retrieval_job):
39+
with patch("pyarrow.fs.S3FileSystem") as mock_s3:
40+
mock_s3.return_value = MagicMock(name="s3fs")
41+
fs, paths = retrieval_job._resolve_staging_filesystem(
42+
["s3a://bucket/dir/file.parquet"]
43+
)
44+
assert paths == ["bucket/dir/file.parquet"]
45+
46+
def test_s3_with_minio_endpoint(self, retrieval_job, monkeypatch):
47+
monkeypatch.setenv("AWS_ENDPOINT_URL_S3", "http://minio.local:9000")
48+
monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1")
49+
with patch("pyarrow.fs.S3FileSystem") as mock_s3:
50+
mock_s3.return_value = MagicMock(name="s3fs")
51+
retrieval_job._resolve_staging_filesystem(["s3://bucket/file.parquet"])
52+
call_kwargs = mock_s3.call_args[1]
53+
assert call_kwargs["endpoint_override"] == "minio.local:9000"
54+
assert call_kwargs["scheme"] == "http"
55+
56+
def test_s3_with_https_endpoint(self, retrieval_job, monkeypatch):
57+
monkeypatch.setenv("AWS_ENDPOINT_URL_S3", "https://s3.custom.corp")
58+
with patch("pyarrow.fs.S3FileSystem") as mock_s3:
59+
mock_s3.return_value = MagicMock(name="s3fs")
60+
retrieval_job._resolve_staging_filesystem(["s3://bucket/file.parquet"])
61+
call_kwargs = mock_s3.call_args[1]
62+
assert call_kwargs["endpoint_override"] == "s3.custom.corp"
63+
assert call_kwargs["scheme"] == "https"
64+
65+
def test_s3_falls_back_to_aws_s3_endpoint_env(self, retrieval_job, monkeypatch):
66+
monkeypatch.delenv("AWS_ENDPOINT_URL_S3", raising=False)
67+
monkeypatch.setenv("AWS_S3_ENDPOINT", "http://legacy-minio:9000")
68+
with patch("pyarrow.fs.S3FileSystem") as mock_s3:
69+
mock_s3.return_value = MagicMock(name="s3fs")
70+
retrieval_job._resolve_staging_filesystem(["s3://bucket/file.parquet"])
71+
call_kwargs = mock_s3.call_args[1]
72+
assert "endpoint_override" in call_kwargs
73+
74+
def test_s3_no_endpoint_no_override(self, retrieval_job, monkeypatch):
75+
monkeypatch.delenv("AWS_ENDPOINT_URL_S3", raising=False)
76+
monkeypatch.delenv("AWS_S3_ENDPOINT", raising=False)
77+
with patch("pyarrow.fs.S3FileSystem") as mock_s3:
78+
mock_s3.return_value = MagicMock(name="s3fs")
79+
retrieval_job._resolve_staging_filesystem(["s3://bucket/file.parquet"])
80+
call_kwargs = mock_s3.call_args[1]
81+
assert "endpoint_override" not in call_kwargs
82+
assert "scheme" not in call_kwargs
83+
84+
def test_s3_region_from_offline_store_config(self, retrieval_job):
85+
retrieval_job._config.offline_store.region = "eu-west-1"
86+
with patch("pyarrow.fs.S3FileSystem") as mock_s3:
87+
mock_s3.return_value = MagicMock(name="s3fs")
88+
retrieval_job._resolve_staging_filesystem(["s3://bucket/file.parquet"])
89+
call_kwargs = mock_s3.call_args[1]
90+
assert call_kwargs["region"] == "eu-west-1"
91+
92+
def test_s3_region_fallback_to_env(self, retrieval_job, monkeypatch):
93+
retrieval_job._config.offline_store.region = None
94+
monkeypatch.setenv("AWS_DEFAULT_REGION", "ap-southeast-1")
95+
with patch("pyarrow.fs.S3FileSystem") as mock_s3:
96+
mock_s3.return_value = MagicMock(name="s3fs")
97+
retrieval_job._resolve_staging_filesystem(["s3://bucket/file.parquet"])
98+
call_kwargs = mock_s3.call_args[1]
99+
assert call_kwargs["region"] == "ap-southeast-1"
100+
101+
102+
class TestResolveGCSFilesystem:
103+
def test_gs_scheme_returns_gcs_filesystem(self, retrieval_job):
104+
with patch("pyarrow.fs.GcsFileSystem") as mock_gcs:
105+
mock_gcs.return_value = MagicMock(name="gcsfs")
106+
fs, paths = retrieval_job._resolve_staging_filesystem(
107+
["gs://my-bucket/path/a.parquet", "gs://my-bucket/path/b.parquet"]
108+
)
109+
mock_gcs.assert_called_once()
110+
assert fs is mock_gcs.return_value
111+
assert paths == ["my-bucket/path/a.parquet", "my-bucket/path/b.parquet"]
112+
113+
114+
class TestResolveLocalFilesystem:
115+
def test_file_scheme_stripped(self, retrieval_job):
116+
fs, paths = retrieval_job._resolve_staging_filesystem(
117+
["file:///tmp/staging/a.parquet"]
118+
)
119+
assert fs is None
120+
assert paths == ["/tmp/staging/a.parquet"]
121+
122+
def test_plain_local_path_unchanged(self, retrieval_job):
123+
fs, paths = retrieval_job._resolve_staging_filesystem(
124+
["/tmp/staging/a.parquet", "/tmp/staging/b.parquet"]
125+
)
126+
assert fs is None
127+
assert paths == ["/tmp/staging/a.parquet", "/tmp/staging/b.parquet"]
128+
129+
def test_mixed_file_and_plain_paths(self, retrieval_job):
130+
fs, paths = retrieval_job._resolve_staging_filesystem(
131+
["file:///tmp/a.parquet", "/tmp/b.parquet"]
132+
)
133+
assert fs is None
134+
assert paths == ["/tmp/a.parquet", "/tmp/b.parquet"]

0 commit comments

Comments
 (0)