Skip to content

Commit ba4c68e

Browse files
Add Set type test for On Demand Feature Views
Co-authored-by: franciscojavierarceo <4163062+franciscojavierarceo@users.noreply.github.com>
1 parent e11795d commit ba4c68e

File tree

1 file changed

+168
-0
lines changed

1 file changed

+168
-0
lines changed

sdk/python/tests/unit/test_on_demand_python_transformation.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
Float64,
3434
Int64,
3535
PdfBytes,
36+
Set,
3637
String,
3738
UnixTimestamp,
3839
ValueType,
@@ -1731,3 +1732,170 @@ def docling_transform_docs(inputs: dict[str, Any]):
17311732
"Let's have fun with Natural Language Processing on PDFs."
17321733
],
17331734
}
1735+
1736+
1737+
def test_python_transformation_with_set_types():
    """Check Set-typed request fields and outputs in a python-mode on-demand feature view.

    A local, SQLite-backed feature store is built in a temporary directory,
    driver stats are written to the online store, and an on-demand feature
    view that passes Set inputs through (and derives counts and a membership
    flag from them) is queried twice: once with set inputs and once with
    list inputs containing duplicates.
    """
    with tempfile.TemporaryDirectory() as workdir:
        repo_config = RepoConfig(
            project="test_set_types",
            registry=os.path.join(workdir, "registry.db"),
            provider="local",
            entity_key_serialization_version=3,
            online_store=SqliteOnlineStoreConfig(
                path=os.path.join(workdir, "online.db")
            ),
        )
        fs = FeatureStore(config=repo_config)

        # Entity keyed on driver_id
        driver_entity = Entity(
            name="driver", join_keys=["driver_id"], value_type=ValueType.INT64
        )

        # Fifteen days of hourly driver stats ending at the current hour
        window_end = datetime.now().replace(microsecond=0, second=0, minute=0)
        window_start = window_end - timedelta(days=15)
        driver_ids = [1001, 1002, 1003]
        stats_df = create_driver_hourly_stats_df(driver_ids, window_start, window_end)
        parquet_path = os.path.join(workdir, "driver_stats.parquet")
        stats_df.to_parquet(path=parquet_path, allow_truncated_timestamps=True)

        stats_source = FileSource(
            name="driver_hourly_stats_source",
            path=parquet_path,
            timestamp_field="event_timestamp",
            created_timestamp_column="created",
        )

        stats_fv = FeatureView(
            name="driver_hourly_stats",
            entities=[driver_entity],
            ttl=timedelta(days=0),
            schema=[
                Field(name="conv_rate", dtype=Float32),
                Field(name="acc_rate", dtype=Float32),
                Field(name="avg_daily_trips", dtype=Int64),
            ],
            online=True,
            source=stats_source,
        )

        # Request-time inputs declared as Set types
        set_request = RequestSource(
            name="request_source",
            schema=[
                Field(name="visited_locations", dtype=Set(String)),
                Field(name="favorite_numbers", dtype=Set(Int64)),
            ],
        )

        # On-demand view: echoes the sets and derives sizes plus an "NYC" flag
        @on_demand_feature_view(
            sources=[set_request, stats_fv],
            schema=[
                Field(name="unique_locations", dtype=Set(String)),
                Field(name="location_count", dtype=Int64),
                Field(name="unique_numbers", dtype=Set(Int64)),
                Field(name="number_count", dtype=Int64),
                Field(name="has_favorite_location", dtype=Bool),
            ],
            mode="python",
        )
        def set_processor_view(inputs: dict[str, Any]) -> dict[str, Any]:
            locations = inputs["visited_locations"]
            numbers = inputs["favorite_numbers"]
            return {
                "unique_locations": locations,
                "location_count": [len(row) for row in locations],
                "unique_numbers": numbers,
                "number_count": [len(row) for row in numbers],
                "has_favorite_location": ["NYC" in row for row in locations],
            }

        # Register everything, then populate the online store
        fs.apply([driver_entity, stats_source, stats_fv, set_processor_view])
        fs.write_to_online_store(feature_view_name="driver_hourly_stats", df=stats_df)

        # Feature references of the on-demand view, shared by both lookups
        set_feature_refs = [
            "set_processor_view:unique_locations",
            "set_processor_view:location_count",
            "set_processor_view:unique_numbers",
            "set_processor_view:number_count",
            "set_processor_view:has_favorite_location",
        ]

        def first_row(response: dict[str, Any]) -> dict[str, Any]:
            # Each lookup sends a single entity row, so keep only index 0.
            return {name: value[0] for name, value in response.items()}

        # --- Lookup 1: inputs supplied as Python sets ---
        # The literals repeat members on purpose; a set keeps one of each.
        set_rows = [
            {
                "driver_id": 1001,
                "visited_locations": {"NYC", "LA", "SF", "NYC"},
                "favorite_numbers": {1, 2, 3, 2, 1},
            }
        ]
        set_response = fs.get_online_features(
            entity_rows=set_rows,
            features=[
                "driver_hourly_stats:conv_rate",
                "driver_hourly_stats:avg_daily_trips",
                *set_feature_refs,
            ],
        ).to_dict()
        result = first_row(set_response)

        # Set outputs must come back as sets with correctly typed members
        for feature_name, member_type in (
            ("unique_locations", str),
            ("unique_numbers", int),
        ):
            assert isinstance(result[feature_name], set)
            assert all(isinstance(item, member_type) for item in result[feature_name])

        assert isinstance(result["location_count"], int)
        assert isinstance(result["number_count"], int)
        assert isinstance(result["has_favorite_location"], bool)

        # Values: three distinct members each, and "NYC" is present
        assert result["unique_locations"] == {"NYC", "LA", "SF"}
        assert result["location_count"] == 3

        assert result["unique_numbers"] == {1, 2, 3}
        assert result["number_count"] == 3

        assert result["has_favorite_location"] is True

        # --- Lookup 2: inputs supplied as lists with duplicates ---
        # The assertions below expect them back as deduplicated sets.
        list_rows = [
            {
                "driver_id": 1002,
                "visited_locations": ["Boston", "Boston", "Seattle", "Portland"],
                "favorite_numbers": [7, 8, 9, 7],
            }
        ]
        list_response = fs.get_online_features(
            entity_rows=list_rows,
            features=set_feature_refs,
        ).to_dict()
        result_list = first_row(list_response)

        assert isinstance(result_list["unique_locations"], set)
        assert result_list["unique_locations"] == {"Boston", "Seattle", "Portland"}
        assert result_list["location_count"] == 3

        assert isinstance(result_list["unique_numbers"], set)
        assert result_list["unique_numbers"] == {7, 8, 9}
        assert result_list["number_count"] == 3

        # "NYC" is not among the listed locations
        assert result_list["has_favorite_location"] is False

0 commit comments

Comments
 (0)