|
45 | 45 |
|
46 | 46 | class TestOnDemandPythonTransformation(unittest.TestCase): |
47 | 47 | def setUp(self): |
48 | | - with tempfile.TemporaryDirectory() as data_dir: |
49 | | - self.store = FeatureStore( |
50 | | - config=RepoConfig( |
51 | | - project="test_on_demand_python_transformation", |
52 | | - registry=os.path.join(data_dir, "registry.db"), |
53 | | - provider="local", |
54 | | - entity_key_serialization_version=3, |
55 | | - online_store=SqliteOnlineStoreConfig( |
56 | | - path=os.path.join(data_dir, "online.db") |
57 | | - ), |
58 | | - ) |
| 48 | + self.data_dir = tempfile.mkdtemp() |
| 49 | + data_dir = self.data_dir |
| 50 | + self.store = FeatureStore( |
| 51 | + config=RepoConfig( |
| 52 | + project="test_on_demand_python_transformation", |
| 53 | + registry=os.path.join(data_dir, "registry.db"), |
| 54 | + provider="local", |
| 55 | + entity_key_serialization_version=3, |
| 56 | + online_store=SqliteOnlineStoreConfig( |
| 57 | + path=os.path.join(data_dir, "online.db") |
| 58 | + ), |
59 | 59 | ) |
| 60 | + ) |
60 | 61 |
|
61 | | - # Generate test data. |
62 | | - end_date = datetime.now().replace(microsecond=0, second=0, minute=0) |
63 | | - start_date = end_date - timedelta(days=15) |
| 62 | + # Generate test data. |
| 63 | + end_date = datetime.now().replace(microsecond=0, second=0, minute=0) |
| 64 | + start_date = end_date - timedelta(days=15) |
64 | 65 |
|
65 | | - driver_entities = [1001, 1002, 1003, 1004, 1005] |
66 | | - driver_df = create_driver_hourly_stats_df( |
67 | | - driver_entities, start_date, end_date |
68 | | - ) |
69 | | - driver_stats_path = os.path.join(data_dir, "driver_stats.parquet") |
70 | | - driver_df.to_parquet( |
71 | | - path=driver_stats_path, allow_truncated_timestamps=True |
72 | | - ) |
| 66 | + driver_entities = [1001, 1002, 1003, 1004, 1005] |
| 67 | + driver_df = create_driver_hourly_stats_df(driver_entities, start_date, end_date) |
| 68 | + driver_stats_path = os.path.join(data_dir, "driver_stats.parquet") |
| 69 | + driver_df.to_parquet(path=driver_stats_path, allow_truncated_timestamps=True) |
73 | 70 |
|
74 | | - driver = Entity( |
75 | | - name="driver", join_keys=["driver_id"], value_type=ValueType.INT64 |
76 | | - ) |
| 71 | + driver = Entity( |
| 72 | + name="driver", join_keys=["driver_id"], value_type=ValueType.INT64 |
| 73 | + ) |
77 | 74 |
|
78 | | - driver_stats_source = FileSource( |
79 | | - name="driver_hourly_stats_source", |
80 | | - path=driver_stats_path, |
81 | | - timestamp_field="event_timestamp", |
82 | | - created_timestamp_column="created", |
83 | | - ) |
84 | | - input_request_source = RequestSource( |
85 | | - name="counter_source", |
86 | | - schema=[ |
87 | | - Field(name="counter", dtype=Int64), |
88 | | - Field(name="input_datetime", dtype=UnixTimestamp), |
89 | | - ], |
90 | | - ) |
| 75 | + driver_stats_source = FileSource( |
| 76 | + name="driver_hourly_stats_source", |
| 77 | + path=driver_stats_path, |
| 78 | + timestamp_field="event_timestamp", |
| 79 | + created_timestamp_column="created", |
| 80 | + ) |
| 81 | + input_request_source = RequestSource( |
| 82 | + name="counter_source", |
| 83 | + schema=[ |
| 84 | + Field(name="counter", dtype=Int64), |
| 85 | + Field(name="input_datetime", dtype=UnixTimestamp), |
| 86 | + ], |
| 87 | + ) |
91 | 88 |
|
92 | | - driver_stats_fv = FeatureView( |
93 | | - name="driver_hourly_stats", |
94 | | - entities=[driver], |
95 | | - ttl=timedelta(days=0), |
96 | | - schema=[ |
97 | | - Field(name="conv_rate", dtype=Float32), |
98 | | - Field(name="acc_rate", dtype=Float32), |
99 | | - Field(name="avg_daily_trips", dtype=Int64), |
100 | | - ], |
101 | | - online=True, |
102 | | - source=driver_stats_source, |
103 | | - ) |
| 89 | + driver_stats_fv = FeatureView( |
| 90 | + name="driver_hourly_stats", |
| 91 | + entities=[driver], |
| 92 | + ttl=timedelta(days=0), |
| 93 | + schema=[ |
| 94 | + Field(name="conv_rate", dtype=Float32), |
| 95 | + Field(name="acc_rate", dtype=Float32), |
| 96 | + Field(name="avg_daily_trips", dtype=Int64), |
| 97 | + ], |
| 98 | + online=True, |
| 99 | + source=driver_stats_source, |
| 100 | + ) |
104 | 101 |
|
105 | | - driver_stats_entity_less_fv = FeatureView( |
106 | | - name="driver_hourly_stats_no_entity", |
107 | | - entities=[], |
108 | | - ttl=timedelta(days=0), |
109 | | - schema=[ |
110 | | - Field(name="conv_rate", dtype=Float32), |
111 | | - Field(name="acc_rate", dtype=Float32), |
112 | | - Field(name="avg_daily_trips", dtype=Int64), |
113 | | - ], |
114 | | - online=True, |
115 | | - source=driver_stats_source, |
116 | | - ) |
| 102 | + driver_stats_entity_less_fv = FeatureView( |
| 103 | + name="driver_hourly_stats_no_entity", |
| 104 | + entities=[], |
| 105 | + ttl=timedelta(days=0), |
| 106 | + schema=[ |
| 107 | + Field(name="conv_rate", dtype=Float32), |
| 108 | + Field(name="acc_rate", dtype=Float32), |
| 109 | + Field(name="avg_daily_trips", dtype=Int64), |
| 110 | + ], |
| 111 | + online=True, |
| 112 | + source=driver_stats_source, |
| 113 | + ) |
117 | 114 |
|
118 | | - @on_demand_feature_view( |
119 | | - sources=[driver_stats_fv], |
120 | | - schema=[Field(name="conv_rate_plus_acc_pandas", dtype=Float64)], |
121 | | - mode="pandas", |
122 | | - ) |
123 | | - def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame: |
124 | | - df = pd.DataFrame() |
125 | | - df["conv_rate_plus_acc_pandas"] = ( |
126 | | - inputs["conv_rate"] + inputs["acc_rate"] |
127 | | - ) |
128 | | - return df |
| 115 | + @on_demand_feature_view( |
| 116 | + sources=[driver_stats_fv], |
| 117 | + schema=[Field(name="conv_rate_plus_acc_pandas", dtype=Float64)], |
| 118 | + mode="pandas", |
| 119 | + ) |
| 120 | + def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame: |
| 121 | + df = pd.DataFrame() |
| 122 | + df["conv_rate_plus_acc_pandas"] = inputs["conv_rate"] + inputs["acc_rate"] |
| 123 | + return df |
129 | 124 |
|
130 | | - @on_demand_feature_view( |
131 | | - sources=[driver_stats_fv[["conv_rate", "acc_rate"]]], |
132 | | - schema=[Field(name="conv_rate_plus_acc_python", dtype=Float64)], |
133 | | - mode="python", |
134 | | - ) |
135 | | - def python_view(inputs: dict[str, Any]) -> dict[str, Any]: |
136 | | - output: dict[str, Any] = { |
137 | | - "conv_rate_plus_acc_python": conv_rate + acc_rate |
| 125 | + @on_demand_feature_view( |
| 126 | + sources=[driver_stats_fv[["conv_rate", "acc_rate"]]], |
| 127 | + schema=[Field(name="conv_rate_plus_acc_python", dtype=Float64)], |
| 128 | + mode="python", |
| 129 | + ) |
| 130 | + def python_view(inputs: dict[str, Any]) -> dict[str, Any]: |
| 131 | + output: dict[str, Any] = { |
| 132 | + "conv_rate_plus_acc_python": conv_rate + acc_rate |
| 133 | + for conv_rate, acc_rate in zip(inputs["conv_rate"], inputs["acc_rate"]) |
| 134 | + } |
| 135 | + return output |
| 136 | + |
| 137 | + @on_demand_feature_view( |
| 138 | + sources=[driver_stats_fv[["conv_rate", "acc_rate"]]], |
| 139 | + schema=[ |
| 140 | + Field(name="conv_rate_plus_val1_python", dtype=Float64), |
| 141 | + Field(name="conv_rate_plus_val2_python", dtype=Float64), |
| 142 | + ], |
| 143 | + mode="python", |
| 144 | + ) |
| 145 | + def python_demo_view(inputs: dict[str, Any]) -> dict[str, Any]: |
| 146 | + output: dict[str, Any] = { |
| 147 | + "conv_rate_plus_val1_python": [ |
| 148 | + conv_rate + acc_rate |
138 | 149 | for conv_rate, acc_rate in zip( |
139 | 150 | inputs["conv_rate"], inputs["acc_rate"] |
140 | 151 | ) |
141 | | - } |
142 | | - return output |
143 | | - |
144 | | - @on_demand_feature_view( |
145 | | - sources=[driver_stats_fv[["conv_rate", "acc_rate"]]], |
146 | | - schema=[ |
147 | | - Field(name="conv_rate_plus_val1_python", dtype=Float64), |
148 | | - Field(name="conv_rate_plus_val2_python", dtype=Float64), |
149 | 152 | ], |
150 | | - mode="python", |
151 | | - ) |
152 | | - def python_demo_view(inputs: dict[str, Any]) -> dict[str, Any]: |
153 | | - output: dict[str, Any] = { |
154 | | - "conv_rate_plus_val1_python": [ |
155 | | - conv_rate + acc_rate |
156 | | - for conv_rate, acc_rate in zip( |
157 | | - inputs["conv_rate"], inputs["acc_rate"] |
158 | | - ) |
159 | | - ], |
160 | | - "conv_rate_plus_val2_python": [ |
161 | | - conv_rate + acc_rate |
162 | | - for conv_rate, acc_rate in zip( |
163 | | - inputs["conv_rate"], inputs["acc_rate"] |
164 | | - ) |
165 | | - ], |
166 | | - } |
167 | | - return output |
168 | | - |
169 | | - @on_demand_feature_view( |
170 | | - sources=[driver_stats_fv[["conv_rate", "acc_rate"]]], |
171 | | - schema=[ |
172 | | - Field(name="conv_rate_plus_acc_python_singleton", dtype=Float64), |
173 | | - Field( |
174 | | - name="conv_rate_plus_acc_python_singleton_array", |
175 | | - dtype=Array(Float64), |
176 | | - ), |
| 153 | + "conv_rate_plus_val2_python": [ |
| 154 | + conv_rate + acc_rate |
| 155 | + for conv_rate, acc_rate in zip( |
| 156 | + inputs["conv_rate"], inputs["acc_rate"] |
| 157 | + ) |
177 | 158 | ], |
178 | | - mode="python", |
179 | | - singleton=True, |
| 159 | + } |
| 160 | + return output |
| 161 | + |
| 162 | + @on_demand_feature_view( |
| 163 | + sources=[driver_stats_fv[["conv_rate", "acc_rate"]]], |
| 164 | + schema=[ |
| 165 | + Field(name="conv_rate_plus_acc_python_singleton", dtype=Float64), |
| 166 | + Field( |
| 167 | + name="conv_rate_plus_acc_python_singleton_array", |
| 168 | + dtype=Array(Float64), |
| 169 | + ), |
| 170 | + ], |
| 171 | + mode="python", |
| 172 | + singleton=True, |
| 173 | + ) |
| 174 | + def python_singleton_view(inputs: dict[str, Any]) -> dict[str, Any]: |
| 175 | + output: dict[str, Any] = dict(conv_rate_plus_acc_python=float("-inf")) |
| 176 | + output["conv_rate_plus_acc_python_singleton"] = ( |
| 177 | + inputs["conv_rate"] + inputs["acc_rate"] |
180 | 178 | ) |
181 | | - def python_singleton_view(inputs: dict[str, Any]) -> dict[str, Any]: |
182 | | - output: dict[str, Any] = dict(conv_rate_plus_acc_python=float("-inf")) |
183 | | - output["conv_rate_plus_acc_python_singleton"] = ( |
184 | | - inputs["conv_rate"] + inputs["acc_rate"] |
185 | | - ) |
186 | | - output["conv_rate_plus_acc_python_singleton_array"] = [0.1, 0.2, 0.3] |
187 | | - return output |
| 179 | + output["conv_rate_plus_acc_python_singleton_array"] = [0.1, 0.2, 0.3] |
| 180 | + return output |
188 | 181 |
|
189 | | - @on_demand_feature_view( |
190 | | - sources=[ |
191 | | - driver_stats_fv[["conv_rate", "acc_rate"]], |
192 | | - input_request_source, |
193 | | - ], |
194 | | - schema=[ |
195 | | - Field(name="conv_rate_plus_acc", dtype=Float64), |
196 | | - Field(name="current_datetime", dtype=UnixTimestamp), |
197 | | - Field(name="counter", dtype=Int64), |
198 | | - Field(name="input_datetime", dtype=UnixTimestamp), |
| 182 | + @on_demand_feature_view( |
| 183 | + sources=[ |
| 184 | + driver_stats_fv[["conv_rate", "acc_rate"]], |
| 185 | + input_request_source, |
| 186 | + ], |
| 187 | + schema=[ |
| 188 | + Field(name="conv_rate_plus_acc", dtype=Float64), |
| 189 | + Field(name="current_datetime", dtype=UnixTimestamp), |
| 190 | + Field(name="counter", dtype=Int64), |
| 191 | + Field(name="input_datetime", dtype=UnixTimestamp), |
| 192 | + ], |
| 193 | + mode="python", |
| 194 | + write_to_online_store=True, |
| 195 | + ) |
| 196 | + def python_stored_writes_feature_view( |
| 197 | + inputs: dict[str, Any], |
| 198 | + ) -> dict[str, Any]: |
| 199 | + output: dict[str, Any] = { |
| 200 | + "conv_rate_plus_acc": [ |
| 201 | + conv_rate + acc_rate |
| 202 | + for conv_rate, acc_rate in zip( |
| 203 | + inputs["conv_rate"], inputs["acc_rate"] |
| 204 | + ) |
199 | 205 | ], |
200 | | - mode="python", |
201 | | - write_to_online_store=True, |
202 | | - ) |
203 | | - def python_stored_writes_feature_view( |
204 | | - inputs: dict[str, Any], |
205 | | - ) -> dict[str, Any]: |
206 | | - output: dict[str, Any] = { |
207 | | - "conv_rate_plus_acc": [ |
208 | | - conv_rate + acc_rate |
209 | | - for conv_rate, acc_rate in zip( |
210 | | - inputs["conv_rate"], inputs["acc_rate"] |
211 | | - ) |
212 | | - ], |
213 | | - "current_datetime": [datetime.now() for _ in inputs["conv_rate"]], |
214 | | - "counter": [c + 1 for c in inputs["counter"]], |
215 | | - "input_datetime": [d for d in inputs["input_datetime"]], |
216 | | - } |
217 | | - return output |
| 206 | + "current_datetime": [datetime.now() for _ in inputs["conv_rate"]], |
| 207 | + "counter": [c + 1 for c in inputs["counter"]], |
| 208 | + "input_datetime": [d for d in inputs["input_datetime"]], |
| 209 | + } |
| 210 | + return output |
218 | 211 |
|
219 | | - self.store.apply( |
220 | | - [ |
221 | | - driver, |
222 | | - driver_stats_source, |
223 | | - driver_stats_fv, |
224 | | - pandas_view, |
225 | | - python_view, |
226 | | - python_singleton_view, |
227 | | - python_demo_view, |
228 | | - driver_stats_entity_less_fv, |
229 | | - python_stored_writes_feature_view, |
230 | | - ] |
231 | | - ) |
232 | | - self.store.write_to_online_store( |
233 | | - feature_view_name="driver_hourly_stats", df=driver_df |
234 | | - ) |
235 | | - assert driver_stats_fv.entity_columns == [ |
236 | | - Field(name=driver.join_key, dtype=from_value_type(driver.value_type)) |
| 212 | + self.store.apply( |
| 213 | + [ |
| 214 | + driver, |
| 215 | + driver_stats_source, |
| 216 | + driver_stats_fv, |
| 217 | + pandas_view, |
| 218 | + python_view, |
| 219 | + python_singleton_view, |
| 220 | + python_demo_view, |
| 221 | + driver_stats_entity_less_fv, |
| 222 | + python_stored_writes_feature_view, |
237 | 223 | ] |
238 | | - assert driver_stats_entity_less_fv.entity_columns == [DUMMY_ENTITY_FIELD] |
| 224 | + ) |
| 225 | + self.store.write_to_online_store( |
| 226 | + feature_view_name="driver_hourly_stats", df=driver_df |
| 227 | + ) |
| 228 | + assert driver_stats_fv.entity_columns == [ |
| 229 | + Field(name=driver.join_key, dtype=from_value_type(driver.value_type)) |
| 230 | + ] |
| 231 | + assert driver_stats_entity_less_fv.entity_columns == [DUMMY_ENTITY_FIELD] |
239 | 232 |
|
240 | | - assert len(self.store.list_all_feature_views()) == 7 |
241 | | - assert len(self.store.list_feature_views()) == 2 |
242 | | - assert len(self.store.list_on_demand_feature_views()) == 5 |
243 | | - assert len(self.store.list_stream_feature_views()) == 0 |
| 233 | + assert len(self.store.list_all_feature_views()) == 7 |
| 234 | + assert len(self.store.list_feature_views()) == 2 |
| 235 | + assert len(self.store.list_on_demand_feature_views()) == 5 |
| 236 | + assert len(self.store.list_stream_feature_views()) == 0 |
| 237 | + |
| 238 | + def tearDown(self): |
| 239 | + import shutil |
| 240 | + |
| 241 | + if hasattr(self, "data_dir"): |
| 242 | + shutil.rmtree(self.data_dir, ignore_errors=True) |
244 | 243 |
|
245 | 244 | def test_setup(self): |
246 | 245 | pass |
|
0 commit comments