Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ jobs:
uses: astral-sh/setup-uv@v5
with:
enable-cache: true
- name: Install torch (platform-specific)
run: |
if [[ "$RUNNER_OS" == "Linux" ]]; then
pip install torch==2.2.2+cpu torchvision==0.17.2+cpu \
-f https://download.pytorch.org/whl/torch_stable.html
fi
- name: Install dependencies
run: make install-python-dependencies-ci
- name: Test Python
Expand Down
36 changes: 36 additions & 0 deletions sdk/python/feast/infra/offline_stores/offline_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Expand All @@ -38,6 +39,7 @@
from feast.on_demand_feature_view import OnDemandFeatureView
from feast.repo_config import RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.torch_wrapper import get_torch

if TYPE_CHECKING:
from feast.saved_dataset import ValidationReference
Expand Down Expand Up @@ -137,6 +139,40 @@ def to_arrow(

return features_table

def to_tensor(
    self,
    kind: str = "torch",
    default_value: Any = float("nan"),
    timeout: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Converts historical features into a dictionary of 1D torch tensors or lists (for non-numeric types).

    A column becomes a tensor when its first non-missing value is an int,
    float or bool and, if any entry had to be filled, ``default_value`` is
    numeric as well. Every other column (including an empty one) is returned
    as a plain Python list with missing entries replaced by ``default_value``.

    Args:
        kind: "torch" (default and only supported kind).
        default_value: Value to replace missing (None or NaN) entries.
        timeout: Optional timeout for query execution.

    Returns:
        Dict[str, Union[torch.Tensor, List]]: Feature column name -> tensor or list.

    Raises:
        ValueError: If ``kind`` is anything other than "torch".
        ImportError: If torch cannot be imported (raised by ``get_torch``).
    """
    if kind != "torch":
        raise ValueError(
            f"Unsupported tensor kind: {kind}. Only 'torch' is supported."
        )
    torch = get_torch()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    df = self.to_df(timeout=timeout)
    numeric_types = (int, float, bool)
    tensor_dict: Dict[str, Any] = {}
    for column in df.columns:
        series = df[column]
        values = series.fillna(default_value).tolist()
        # Decide the column kind from real data, not from the filled list:
        # after fillna() the first element may be the placeholder, which would
        # make e.g. a string column with a leading null look like a float.
        witnesses = series.dropna().tolist()[:1]
        if len(witnesses) == 0 or series.isna().any():
            # The placeholder ends up in the output only when something was
            # missing; in that case it must be tensor-compatible too.
            if len(values) > 0:
                witnesses.append(default_value)
        if witnesses and all(isinstance(w, numeric_types) for w in witnesses):
            tensor_dict[column] = torch.tensor(values, device=device)
        else:
            tensor_dict[column] = values
    return tensor_dict

def to_sql(self) -> str:
"""
Return RetrievalJob generated SQL statement if applicable.
Expand Down
54 changes: 53 additions & 1 deletion sdk/python/feast/online_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List
from typing import TYPE_CHECKING, Any, Dict, List, Union

import pandas as pd
import pyarrow as pa

from feast.feature_view import DUMMY_ENTITY_ID
from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesResponse
from feast.torch_wrapper import get_torch
from feast.type_map import feast_value_type_to_python_type

if TYPE_CHECKING:
import torch

TorchTensor = torch.Tensor
else:
TorchTensor = Any

TIMESTAMP_POSTFIX: str = "__ts"


Expand Down Expand Up @@ -88,3 +96,47 @@ def to_arrow(self, include_event_timestamps: bool = False) -> pa.Table:
"""

return pa.Table.from_pydict(self.to_dict(include_event_timestamps))

def to_tensor(
    self,
    kind: str = "torch",
    default_value: Any = float("nan"),
) -> Dict[str, Union[TorchTensor, List[Any]]]:
    """
    Converts GetOnlineFeaturesResponse features into a dictionary of tensors or lists.

    - Numeric features (int, float, bool) -> torch.Tensor
    - Non-numeric features (e.g., strings) -> list[Any]

    A feature is treated as numeric when its first non-None value is numeric
    and, if any entry was missing, ``default_value`` is numeric too. A feature
    whose values are all missing is typed by ``default_value`` alone.

    Args:
        kind: Backend tensor type. Currently only "torch" is supported.
        default_value: Value to substitute for missing (None) entries.

    Returns:
        Dict[str, Union[torch.Tensor, List[Any]]]: Mapping of feature names to tensors or lists,
        keyed in the order the names appear in the response metadata.

    Raises:
        ValueError: If ``kind`` is not "torch", or if a numeric feature cannot
            be converted to a tensor.
        ImportError: If torch cannot be imported (raised by ``get_torch``).
    """
    if kind != "torch":
        raise ValueError(
            f"Unsupported tensor kind: {kind}. Only 'torch' is supported currently."
        )
    torch = get_torch()
    # The target device does not depend on the feature, so resolve it once.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    feature_dict = self.to_dict(include_event_timestamps=False)
    # dict.fromkeys de-duplicates like set() but keeps the metadata order, so
    # the resulting mapping is deterministic across runs.
    feature_keys = dict.fromkeys(self.proto.metadata.feature_names.val)
    numeric_types = (int, float, bool)
    tensor_dict: Dict[str, Union[TorchTensor, List[Any]]] = {}
    for key in feature_keys:
        raw_values = feature_dict[key]
        values = [v if v is not None else default_value for v in raw_values]
        # Inspect the raw values: once None has been replaced, the first
        # element may be the placeholder rather than real data.
        witnesses = [v for v in raw_values if v is not None][:1]
        if any(v is None for v in raw_values):
            # The placeholder is part of the output, so it must fit as well.
            witnesses.append(default_value)
        if witnesses and all(isinstance(w, numeric_types) for w in witnesses):
            try:
                tensor_dict[key] = torch.tensor(values, device=device)
            except Exception as e:
                raise ValueError(
                    f"Failed to convert values for '{key}' to tensor: {e}"
                ) from e
        else:
            # Return as-is for strings or unsupported types
            tensor_dict[key] = values
    return tensor_dict
39 changes: 39 additions & 0 deletions sdk/python/feast/torch_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import importlib

# Outcome of the one-time torch import attempt performed when this module loads.
_torch = None
_torch_import_error = None
TORCH_AVAILABLE = False


def _import_torch():
    """Attempt to import torch and record the result in the module globals."""
    global _torch, TORCH_AVAILABLE, _torch_import_error
    try:
        module = importlib.import_module("torch")
    except Exception as exc:
        # Deliberately broad: besides a plain ImportError this also covers
        # failures such as missing CUDA libraries.
        TORCH_AVAILABLE = False
        _torch_import_error = exc
    else:
        _torch = module
        TORCH_AVAILABLE = True


_import_torch()


def get_torch():
    """
    Return the imported torch module, or raise an ImportError with guidance.

    The import itself is attempted once at module load; a failure there is
    stored and only surfaces when this function is called, chained to the
    original exception.
    """
    if not TORCH_AVAILABLE:
        error_message = (
            "Torch is not available or failed to import.\n"
            "Original error:\n"
            f"{_torch_import_error}\n\n"
            "If you are on a CPU-only system, make sure you install the CPU-only torch wheel:\n"
            " pip install torch==2.2.2+cpu torchvision==0.17.2+cpu -f https://download.pytorch.org/whl/torch_stable.html\n"
            "Or check your CUDA installation if using GPU torch.\n"
        )
        raise ImportError(error_message) from _torch_import_error
    return _torch
42 changes: 42 additions & 0 deletions sdk/python/tests/unit/online_store/test_online_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from feast.protos.feast.types.Value_pb2 import FloatList as FloatListProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import RegistryConfig
from feast.torch_wrapper import get_torch
from feast.types import ValueType
from feast.utils import _utc_now
from tests.integration.feature_repos.universal.feature_views import TAGS
Expand Down Expand Up @@ -129,6 +130,38 @@ def test_get_online_features() -> None:
assert result["name"] == ["John", "John"]
assert result["trips"] == [7, 7]

tensor_result = store.get_online_features(
features=[
"driver_locations:lon",
"customer_profile:avg_orders_day",
"customer_profile:name",
"customer_driver_combined:trips",
],
entity_rows=[
{"driver_id": 1, "customer_id": "5"},
{"driver_id": 1, "customer_id": 5},
],
full_feature_names=False,
).to_tensor()

assert "lon" in tensor_result
assert "avg_orders_day" in tensor_result
assert "name" in tensor_result
assert "trips" in tensor_result
# Entity values
torch = get_torch()
device = "cuda" if torch.cuda.is_available() else "cpu"
assert torch.equal(
tensor_result["driver_id"], torch.tensor([1, 1], device=device)
)
assert tensor_result["customer_id"] == ["5", "5"]

# Feature values
assert tensor_result["lon"] == ["1.0", "1.0"] # String -> not tensor
assert torch.equal(tensor_result["avg_orders_day"], torch.tensor([1.0, 1.0]))
assert tensor_result["name"] == ["John", "John"]
assert torch.equal(tensor_result["trips"], torch.tensor([7, 7], device=device))

# Ensure features are still in result when keys not found
result = store.get_online_features(
features=["customer_driver_combined:trips"],
Expand All @@ -138,6 +171,15 @@ def test_get_online_features() -> None:

assert "trips" in result

result = store.get_online_features(
features=["customer_driver_combined:trips"],
entity_rows=[{"driver_id": 0, "customer_id": 0}],
full_feature_names=False,
).to_tensor()

assert "trips" in result
assert isinstance(result["trips"], torch.Tensor)

with pytest.raises(KeyError) as excinfo:
_ = store.get_online_features(
features=["driver_locations:lon"],
Expand Down
63 changes: 63 additions & 0 deletions sdk/python/tests/unit/test_offline_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from feast.offline_server import OfflineServer, _init_auth_manager
from feast.repo_config import RepoConfig
from feast.torch_wrapper import get_torch
from tests.utils.cli_repo_creator import CliRunner

PROJECT_NAME = "test_remote_offline"
Expand Down Expand Up @@ -115,7 +116,9 @@ def test_remote_offline_store_apis():
fs = remote_feature_store(server)

_test_get_historical_features_returns_data(fs)
_test_get_historical_features_to_tensor(fs)
_test_get_historical_features_returns_nan(fs)
_test_get_historical_features_to_tensor_with_nan(fs)
_test_offline_write_batch(str(temp_dir), fs)
_test_write_logged_features(str(temp_dir), fs)
_test_pull_latest_from_table_or_query(str(temp_dir), fs)
Expand Down Expand Up @@ -187,6 +190,44 @@ def _test_get_historical_features_returns_data(fs: FeatureStore):
assertpy.assert_that(value).is_not_nan()


def _test_get_historical_features_to_tensor(fs: FeatureStore):
    """Check that to_tensor() on a historical retrieval yields 3-row, NaN-free tensors."""
    event_timestamps = [
        datetime(2021, 4, 12, 10, 59, 42),
        datetime(2021, 4, 12, 8, 12, 10),
        datetime(2021, 4, 12, 16, 40, 26),
    ]
    entity_columns = {
        "driver_id": [1001, 1002, 1003],
        "event_timestamp": event_timestamps,
        "label_driver_reported_satisfaction": [1, 5, 3],
        "val_to_add": [1, 2, 3],
        "val_to_add_2": [10, 20, 30],
    }
    entity_df = pd.DataFrame.from_dict(entity_columns)

    features = [
        "driver_hourly_stats:conv_rate",
        "driver_hourly_stats:acc_rate",
        "driver_hourly_stats:avg_daily_trips",
        "transformed_conv_rate:conv_rate_plus_val1",
        "transformed_conv_rate:conv_rate_plus_val2",
    ]

    tensor_data = fs.get_historical_features(entity_df, features).to_tensor()

    assertpy.assert_that(tensor_data).is_not_none()
    assertpy.assert_that(tensor_data["driver_id"].shape[0]).is_equal_to(3)
    torch = get_torch()
    for column_values in tensor_data.values():
        # Non-numeric columns come back as plain lists and are not checked here.
        if not isinstance(column_values, torch.Tensor):
            continue
        assertpy.assert_that(column_values.shape[0]).is_equal_to(3)
        for element in column_values:
            scalar = element.item()
            assertpy.assert_that(scalar).is_instance_of((float, int))
            assertpy.assert_that(scalar).is_not_nan()


def _test_get_historical_features_returns_nan(fs: FeatureStore):
entity_df = pd.DataFrame.from_dict(
{
Expand Down Expand Up @@ -223,6 +264,28 @@ def _test_get_historical_features_returns_nan(fs: FeatureStore):
assertpy.assert_that(value).is_nan()


def _test_get_historical_features_to_tensor_with_nan(fs: FeatureStore):
    """Check that entities without matching feature rows produce an all-NaN tensor."""
    entity_df = pd.DataFrame.from_dict(
        {
            "driver_id": [9991, 9992],  # IDs with no matching features
            "event_timestamp": [
                datetime(2021, 4, 12, 10, 59, 42),
                datetime(2021, 4, 12, 10, 59, 42),
            ],
        }
    )
    features = ["driver_hourly_stats:conv_rate"]
    job = fs.get_historical_features(entity_df, features)
    tensor_data = job.to_tensor()
    assert "conv_rate" in tensor_data
    values = tensor_data["conv_rate"]
    # conv_rate is a float feature, missing values should be NaN
    torch = get_torch()
    # Assert on the container itself: looping over the elements alone would
    # pass vacuously if the result came back empty.
    assert torch.is_tensor(values)
    assertpy.assert_that(values.shape[0]).is_equal_to(len(entity_df))
    assertpy.assert_that(torch.isnan(values).all().item()).is_true()


def _test_offline_write_batch(temp_dir, fs: FeatureStore):
data_file = os.path.join(
temp_dir, fs.project, "feature_repo/data/driver_stats.parquet"
Expand Down