|
33 | 33 | Float64, |
34 | 34 | Int64, |
35 | 35 | PdfBytes, |
| 36 | + Set, |
36 | 37 | String, |
37 | 38 | UnixTimestamp, |
38 | 39 | ValueType, |
@@ -1731,3 +1732,170 @@ def docling_transform_docs(inputs: dict[str, Any]): |
1731 | 1732 | "Let's have fun with Natural Language Processing on PDFs." |
1732 | 1733 | ], |
1733 | 1734 | } |
| 1735 | + |
| 1736 | + |
| 1737 | +def test_python_transformation_with_set_types(): |
| 1738 | + """Test that Set types work correctly in on-demand feature views.""" |
| 1739 | + with tempfile.TemporaryDirectory() as data_dir: |
| 1740 | + store = FeatureStore( |
| 1741 | + config=RepoConfig( |
| 1742 | + project="test_set_types", |
| 1743 | + registry=os.path.join(data_dir, "registry.db"), |
| 1744 | + provider="local", |
| 1745 | + entity_key_serialization_version=3, |
| 1746 | + online_store=SqliteOnlineStoreConfig( |
| 1747 | + path=os.path.join(data_dir, "online.db") |
| 1748 | + ), |
| 1749 | + ) |
| 1750 | + ) |
| 1751 | + |
| 1752 | + # Create a simple driver entity |
| 1753 | + driver = Entity( |
| 1754 | + name="driver", join_keys=["driver_id"], value_type=ValueType.INT64 |
| 1755 | + ) |
| 1756 | + |
| 1757 | + # Generate test data |
| 1758 | + end_date = datetime.now().replace(microsecond=0, second=0, minute=0) |
| 1759 | + start_date = end_date - timedelta(days=15) |
| 1760 | + driver_entities = [1001, 1002, 1003] |
| 1761 | + driver_df = create_driver_hourly_stats_df(driver_entities, start_date, end_date) |
| 1762 | + driver_stats_path = os.path.join(data_dir, "driver_stats.parquet") |
| 1763 | + driver_df.to_parquet(path=driver_stats_path, allow_truncated_timestamps=True) |
| 1764 | + |
| 1765 | + driver_stats_source = FileSource( |
| 1766 | + name="driver_hourly_stats_source", |
| 1767 | + path=driver_stats_path, |
| 1768 | + timestamp_field="event_timestamp", |
| 1769 | + created_timestamp_column="created", |
| 1770 | + ) |
| 1771 | + |
| 1772 | + driver_stats_fv = FeatureView( |
| 1773 | + name="driver_hourly_stats", |
| 1774 | + entities=[driver], |
| 1775 | + ttl=timedelta(days=0), |
| 1776 | + schema=[ |
| 1777 | + Field(name="conv_rate", dtype=Float32), |
| 1778 | + Field(name="acc_rate", dtype=Float32), |
| 1779 | + Field(name="avg_daily_trips", dtype=Int64), |
| 1780 | + ], |
| 1781 | + online=True, |
| 1782 | + source=driver_stats_source, |
| 1783 | + ) |
| 1784 | + |
| 1785 | + # Request source with Set types |
| 1786 | + request_source = RequestSource( |
| 1787 | + name="request_source", |
| 1788 | + schema=[ |
| 1789 | + Field(name="visited_locations", dtype=Set(String)), |
| 1790 | + Field(name="favorite_numbers", dtype=Set(Int64)), |
| 1791 | + ], |
| 1792 | + ) |
| 1793 | + |
| 1794 | + # On-demand feature view that processes sets |
| 1795 | + @on_demand_feature_view( |
| 1796 | + sources=[request_source, driver_stats_fv], |
| 1797 | + schema=[ |
| 1798 | + Field(name="unique_locations", dtype=Set(String)), |
| 1799 | + Field(name="location_count", dtype=Int64), |
| 1800 | + Field(name="unique_numbers", dtype=Set(Int64)), |
| 1801 | + Field(name="number_count", dtype=Int64), |
| 1802 | + Field(name="has_favorite_location", dtype=Bool), |
| 1803 | + ], |
| 1804 | + mode="python", |
| 1805 | + ) |
| 1806 | + def set_processor_view(inputs: dict[str, Any]) -> dict[str, Any]: |
| 1807 | + output = {} |
| 1808 | + # Sets automatically deduplicate |
| 1809 | + output["unique_locations"] = inputs["visited_locations"] |
| 1810 | + output["location_count"] = [ |
| 1811 | + len(locs) for locs in inputs["visited_locations"] |
| 1812 | + ] |
| 1813 | + output["unique_numbers"] = inputs["favorite_numbers"] |
| 1814 | + output["number_count"] = [len(nums) for nums in inputs["favorite_numbers"]] |
| 1815 | + output["has_favorite_location"] = [ |
| 1816 | + "NYC" in locs for locs in inputs["visited_locations"] |
| 1817 | + ] |
| 1818 | + return output |
| 1819 | + |
| 1820 | + # Apply the feature store objects |
| 1821 | + store.apply([driver, driver_stats_source, driver_stats_fv, set_processor_view]) |
| 1822 | + |
| 1823 | + # Write to online store |
| 1824 | + store.write_to_online_store(feature_view_name="driver_hourly_stats", df=driver_df) |
| 1825 | + |
| 1826 | + # Test online feature retrieval with Set types |
| 1827 | + entity_rows = [ |
| 1828 | + { |
| 1829 | + "driver_id": 1001, |
| 1830 | + "visited_locations": {"NYC", "LA", "SF", "NYC"}, # Duplicate NYC |
| 1831 | + "favorite_numbers": {1, 2, 3, 2, 1}, # Duplicates |
| 1832 | + } |
| 1833 | + ] |
| 1834 | + |
| 1835 | + online_response = store.get_online_features( |
| 1836 | + entity_rows=entity_rows, |
| 1837 | + features=[ |
| 1838 | + "driver_hourly_stats:conv_rate", |
| 1839 | + "driver_hourly_stats:avg_daily_trips", |
| 1840 | + "set_processor_view:unique_locations", |
| 1841 | + "set_processor_view:location_count", |
| 1842 | + "set_processor_view:unique_numbers", |
| 1843 | + "set_processor_view:number_count", |
| 1844 | + "set_processor_view:has_favorite_location", |
| 1845 | + ], |
| 1846 | + ).to_dict() |
| 1847 | + |
| 1848 | + result = {name: value[0] for name, value in online_response.items()} |
| 1849 | + |
| 1850 | + # Type assertions - verify Set types are returned as sets |
| 1851 | + assert isinstance(result["unique_locations"], set) |
| 1852 | + assert all(isinstance(loc, str) for loc in result["unique_locations"]) |
| 1853 | + |
| 1854 | + assert isinstance(result["unique_numbers"], set) |
| 1855 | + assert all(isinstance(num, int) for num in result["unique_numbers"]) |
| 1856 | + |
| 1857 | + assert isinstance(result["location_count"], int) |
| 1858 | + assert isinstance(result["number_count"], int) |
| 1859 | + assert isinstance(result["has_favorite_location"], bool) |
| 1860 | + |
| 1861 | + # Value assertions - verify deduplication worked |
| 1862 | + assert result["unique_locations"] == {"NYC", "LA", "SF"} |
| 1863 | + assert result["location_count"] == 3 # Duplicate "NYC" was removed |
| 1864 | + |
| 1865 | + assert result["unique_numbers"] == {1, 2, 3} |
| 1866 | + assert result["number_count"] == 3 # Duplicates were removed |
| 1867 | + |
| 1868 | + assert result["has_favorite_location"] is True # NYC is in the set |
| 1869 | + |
| 1870 | + # Test with list input that gets converted to set |
| 1871 | + entity_rows_with_list = [ |
| 1872 | + { |
| 1873 | + "driver_id": 1002, |
| 1874 | + "visited_locations": ["Boston", "Boston", "Seattle", "Portland"], # List with duplicates |
| 1875 | + "favorite_numbers": [7, 8, 9, 7], # List with duplicates |
| 1876 | + } |
| 1877 | + ] |
| 1878 | + |
| 1879 | + online_response_list = store.get_online_features( |
| 1880 | + entity_rows=entity_rows_with_list, |
| 1881 | + features=[ |
| 1882 | + "set_processor_view:unique_locations", |
| 1883 | + "set_processor_view:location_count", |
| 1884 | + "set_processor_view:unique_numbers", |
| 1885 | + "set_processor_view:number_count", |
| 1886 | + "set_processor_view:has_favorite_location", |
| 1887 | + ], |
| 1888 | + ).to_dict() |
| 1889 | + |
| 1890 | + result_list = {name: value[0] for name, value in online_response_list.items()} |
| 1891 | + |
| 1892 | + # Verify list input was converted to set and deduplicated |
| 1893 | + assert isinstance(result_list["unique_locations"], set) |
| 1894 | + assert result_list["unique_locations"] == {"Boston", "Seattle", "Portland"} |
| 1895 | + assert result_list["location_count"] == 3 |
| 1896 | + |
| 1897 | + assert isinstance(result_list["unique_numbers"], set) |
| 1898 | + assert result_list["unique_numbers"] == {7, 8, 9} |
| 1899 | + assert result_list["number_count"] == 3 |
| 1900 | + |
| 1901 | + assert result_list["has_favorite_location"] is False # NYC not in the set |
0 commit comments