-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathlabel_view.py
More file actions
446 lines (383 loc) · 17 KB
/
label_view.py
File metadata and controls
446 lines (383 loc) · 17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
import copy
import warnings
from datetime import timedelta
from typing import Any, Dict, List, Optional, Type
from google.protobuf.duration_pb2 import Duration
from google.protobuf.message import Message
from typeguard import typechecked
from feast.base_feature_view import BaseFeatureView
from feast.data_source import DataSource
from feast.entity import Entity
from feast.feature_view_projection import FeatureViewProjection
from feast.field import Field
from feast.labeling.conflict_policy import ConflictPolicy
from feast.proto_utils import serialize_data_source
from feast.protos.feast.core.LabelView_pb2 import LabelView as LabelViewProto
from feast.protos.feast.core.LabelView_pb2 import LabelViewMeta as LabelViewMetaProto
from feast.protos.feast.core.LabelView_pb2 import LabelViewSpec as LabelViewSpecProto
from feast.types import String as FeastString
from feast.types import from_value_type
from feast.value_type import ValueType
# Runs at import time and changes the process-wide warning filters: matching
# DeprecationWarnings are shown only on their first occurrence, regardless of
# the location that raised them.
warnings.simplefilter("once", DeprecationWarning)
@typechecked
class LabelView(BaseFeatureView):
    """A LabelView manages mutable labels decoupled from immutable feature data.

    A LabelView defines a mutable set of labels or annotations that are kept
    separate from the immutable feature data stored in regular FeatureViews.
    It supports multi-labeler workflows where different sources (human
    reviewers, automated safety scanners, reward models) can independently
    write labels for the same entity keys.

    .. note::
        **Enforcement scope:**

        - ``conflict_policy`` is enforced for **offline store reads** (training
          data generation, UI browse/quality endpoints, batch pipelines). The
          online store always uses LAST_WRITE_WINS for low-latency serving.
        - The offline store always retains full write history (all writes are
          appended). The online store keeps only the latest value per entity
          key.

    Attributes:
        name: The unique name of the label view.
        entities: The list of entity names associated with this label view.
        ttl: How long labels are valid for online serving. ``timedelta(0)``
            means labels never expire.
        source: The data source (typically a ``PushSource``) feeding label
            data.
        entity_columns: The entity key columns in the schema.
        features: The label columns (non-entity fields in the schema).
        online: Whether labels are served from the online store.
        description: A human-readable description.
        tags: Arbitrary key-value metadata.
        owner: Owner email or identifier.
        labeler_field: Name of the schema field that identifies who wrote the
            label (default ``"labeler"``).
        conflict_policy: How conflicting labels from different labelers are
            resolved (default ``ConflictPolicy.LAST_WRITE_WINS``). Enforced for
            offline store reads (training, UI). Online store uses
            LAST_WRITE_WINS.
        reference_feature_view: Optional name of the ``FeatureView`` whose
            entities this label view annotates. Stored as an empty string
            when no name was provided.
    """

    name: str
    entities: List[str]
    ttl: Optional[timedelta]
    source: Optional[DataSource]
    entity_columns: List[Field]
    features: List[Field]
    online: bool
    description: str
    tags: Dict[str, str]
    owner: str
    labeler_field: str
    conflict_policy: ConflictPolicy
    reference_feature_view: Optional[str]

    def __init__(
        self,
        *,
        name: str,
        source: Optional[DataSource] = None,
        schema: Optional[List[Field]] = None,
        entities: Optional[List[Entity]] = None,
        ttl: Optional[timedelta] = timedelta(days=0),
        online: bool = True,
        description: str = "",
        tags: Optional[Dict[str, str]] = None,
        owner: str = "",
        labeler_field: str = "labeler",
        conflict_policy: ConflictPolicy = ConflictPolicy.LAST_WRITE_WINS,
        reference_feature_view: Optional[str] = None,
    ):
        """Creates a LabelView object.

        Args:
            name: The unique name of this label view.
            source: The data source for ingesting labels, typically a
                ``PushSource``. If ``None``, labels can only be written
                programmatically via ``FeatureStore.push()``.
            schema: The list of ``Field`` objects describing both entity
                columns and label columns. Entity columns are identified
                by matching against entity join keys.
            entities: The list of ``Entity`` objects whose join keys are
                used to key the labels.
            ttl: The time-to-live for labels in the online store.
                ``timedelta(0)`` means labels never expire. ``None`` means
                the label view inherits the default TTL.
            online: Whether this label view should be materialized to the
                online store for low-latency serving.
            description: A human-readable description of what the labels
                represent.
            tags: A dictionary of key-value pairs for arbitrary metadata.
            owner: The owner of this label view, typically an email address.
            labeler_field: The name of the field in the schema that
                identifies the labeler. Defaults to ``"labeler"``.
            conflict_policy: The policy for resolving conflicting labels
                from different labelers. Defaults to
                ``ConflictPolicy.LAST_WRITE_WINS``. Enforced for offline
                store reads (training, UI). Online store uses
                LAST_WRITE_WINS.
            reference_feature_view: The name of the ``FeatureView`` whose
                entities this label view annotates. This is informational
                and does not create a hard dependency.

        Raises:
            ValueError: If two entities share a join key, or if a schema
                field matching a join key has a dtype that differs from the
                entity's declared (non-UNKNOWN) value type.
        """
        self.ttl = ttl
        self.source = source
        schema_fields = schema or []
        entity_list = entities or []

        self.entities = [entity.name for entity in entity_list]
        join_keys: List[str] = [entity.join_key for entity in entity_list]
        if len(set(join_keys)) < len(join_keys):
            raise ValueError(
                "A label view should not have entities that share a join key."
            )

        # Join keys are unique at this point, so one lookup table replaces a
        # rescan of the entity list for every schema field.
        entity_by_join_key = {entity.join_key: entity for entity in entity_list}

        # Split the schema: fields named after a join key are entity columns,
        # everything else is a label column.
        features: List[Field] = []
        self.entity_columns = []
        for field in schema_fields:
            entity = entity_by_join_key.get(field.name)
            if entity is None:
                features.append(field)
                continue
            self.entity_columns.append(field)
            # Only entities with an explicit value type constrain the dtype.
            if (
                entity.value_type != ValueType.UNKNOWN
                and from_value_type(entity.value_type) != field.dtype
            ):
                raise ValueError(
                    f"Entity {entity.name} has type {entity.value_type}, "
                    f"which does not match the inferred type {field.dtype}."
                )

        # Synthesize an entity column for every join key the schema omitted,
        # falling back to a string dtype when the entity's type is UNKNOWN.
        declared_names = {column.name for column in self.entity_columns}
        for join_key in join_keys:
            if join_key in declared_names:
                continue
            entity = entity_by_join_key[join_key]
            if entity.value_type != ValueType.UNKNOWN:
                dtype = from_value_type(entity.value_type)
            else:
                dtype = FeastString
            self.entity_columns.append(Field(name=join_key, dtype=dtype))

        self.labeler_field = labeler_field
        self.conflict_policy = conflict_policy
        # Normalized so the attribute is always a string ("" when absent).
        self.reference_feature_view = reference_feature_view or ""

        super().__init__(
            name=name,
            features=features,
            description=description,
            tags=tags,
            owner=owner,
            source=source,
        )
        self.projection.view_type = "labelView"
        self.online = online

    def __hash__(self):
        # Defining __eq__ below would otherwise leave the class unhashable,
        # so the base class hash is re-exposed explicitly.
        return super().__hash__()

    def __copy__(self):
        """Returns a copy of this label view.

        The constructor is called without ``entities``, so the entity names,
        features, entity columns and projection it derives are overwritten
        afterwards with (shallow) copies of this instance's values.
        """
        lv = LabelView(
            name=self.name,
            ttl=self.ttl,
            source=self.source,
            schema=self.schema,
            tags=self.tags,
            online=self.online,
            description=self.description,
            owner=self.owner,
            labeler_field=self.labeler_field,
            conflict_policy=self.conflict_policy,
            reference_feature_view=self.reference_feature_view or None,
        )
        lv.entities = list(self.entities)
        lv.features = copy.copy(self.features)
        lv.entity_columns = copy.copy(self.entity_columns)
        lv.projection = copy.copy(self.projection)
        lv.version = self.version
        lv.current_version_number = self.current_version_number
        return lv

    def __eq__(self, other):
        """Compares two label views field by field.

        Entity names and entity columns are compared order-insensitively.

        Raises:
            TypeError: If ``other`` is not a ``LabelView``.
        """
        if not isinstance(other, LabelView):
            raise TypeError("Comparisons should only involve LabelView class objects.")

        if not super().__eq__(other):
            return False

        if (
            sorted(self.entities) != sorted(other.entities)
            or self.ttl != other.ttl
            or self.online != other.online
            or sorted(self.entity_columns) != sorted(other.entity_columns)
            or self.labeler_field != other.labeler_field
            or self.conflict_policy != other.conflict_policy
            or self.reference_feature_view != other.reference_feature_view
        ):
            return False

        return True

    @property
    def join_keys(self) -> List[str]:
        """The entity join key column names for this label view."""
        return [ec.name for ec in self.entity_columns]

    @property
    def schema(self) -> List[Field]:
        """The full schema including both entity columns and label columns.

        Duplicate fields are removed. The order is deterministic: entity
        columns first, then label columns, each keeping its first occurrence.
        """
        # dict.fromkeys de-duplicates by hash/equality like a set would, but
        # keeps insertion order, so the result does not vary between runs.
        return list(dict.fromkeys(self.entity_columns + self.features))

    @property
    def batch_source(self) -> Optional[DataSource]:
        """The batch data source for this label view.

        If the source is a ``PushSource``, returns its underlying
        ``batch_source``. Otherwise returns the source directly (``None``
        when no source is configured).

        This property enables compatibility with offline store
        ``get_historical_features`` implementations.
        """
        from feast.data_source import PushSource

        if self.source is None:
            return None
        if isinstance(self.source, PushSource):
            return self.source.batch_source
        return self.source

    @property
    def stream_source(self) -> Optional[DataSource]:
        """The stream data source for this label view.

        Returns the source if it is a ``PushSource``, ``None`` otherwise.
        """
        from feast.data_source import PushSource

        # isinstance is already False for None, so no separate None check.
        if isinstance(self.source, PushSource):
            return self.source
        return None

    # --- Labeling method helpers (parsed from tags) ---
    # _TAG_PREFIX_PROFILE is matched as a complete tag key; the other three
    # are key prefixes followed by a field name.
    _TAG_PREFIX_PROFILE = "feast.io/labeling-method"
    _TAG_PREFIX_ROLE = "feast.io/field-role:"
    _TAG_PREFIX_VALUES = "feast.io/label-values:"
    _TAG_PREFIX_WIDGET = "feast.io/label-widget:"

    @property
    def labeling_method(self) -> str:
        """The labeling method for this label view.

        Parsed from the ``feast.io/labeling-method`` tag. Supported
        methods: ``table`` (default), ``document-span``,
        ``entity-form``, ``active-learning``.
        """
        return self.tags.get(self._TAG_PREFIX_PROFILE, "table")

    @property
    def annotation_config(self) -> Dict[str, Any]:
        """Structured annotation configuration derived from tags.

        Returns a dict with::

            {
                "profile": "document-span",
                "field_roles": {"source_document": "content_ref", ...},
                "label_values": {"relevance": ["relevant", "irrelevant"]},
                "label_widgets": {"relevance": "binary", ...},
            }

        Label values are read from a comma-separated tag value; each item is
        stripped of surrounding whitespace.
        """
        field_roles: Dict[str, str] = {}
        label_values: Dict[str, List[str]] = {}
        label_widgets: Dict[str, str] = {}

        for key, value in self.tags.items():
            # The text after the matched prefix is the field name.
            if key.startswith(self._TAG_PREFIX_ROLE):
                field_name = key[len(self._TAG_PREFIX_ROLE) :]
                field_roles[field_name] = value
            elif key.startswith(self._TAG_PREFIX_VALUES):
                field_name = key[len(self._TAG_PREFIX_VALUES) :]
                label_values[field_name] = [v.strip() for v in value.split(",")]
            elif key.startswith(self._TAG_PREFIX_WIDGET):
                field_name = key[len(self._TAG_PREFIX_WIDGET) :]
                label_widgets[field_name] = value

        return {
            "profile": self.labeling_method,
            "field_roles": field_roles,
            "label_values": label_values,
            "label_widgets": label_widgets,
        }

    def ensure_valid(self):
        """Validates the label view configuration.

        Raises:
            ValueError: If the label view has no name (from ``BaseFeatureView``)
                or no entities.
        """
        super().ensure_valid()
        if not self.entities:
            raise ValueError("Label view has no entities.")

    @property
    def proto_class(self) -> Type[Message]:
        """The protobuf message class for LabelView."""
        return LabelViewProto

    def to_proto(self) -> LabelViewProto:
        """Converts this LabelView to its protobuf representation.

        The ``ttl`` field is left unset when ``self.ttl`` is ``None``.

        Returns:
            A ``LabelViewProto`` message with the spec and metadata populated.
        """
        meta = LabelViewMetaProto()
        if self.created_timestamp:
            meta.created_timestamp.FromDatetime(self.created_timestamp)
        if self.last_updated_timestamp:
            meta.last_updated_timestamp.FromDatetime(self.last_updated_timestamp)

        ttl_duration = None
        if self.ttl is not None:
            ttl_duration = Duration()
            ttl_duration.FromTimedelta(self.ttl)

        source_proto = serialize_data_source(self.source)

        spec = LabelViewSpecProto(
            name=self.name,
            entities=self.entities,
            entity_columns=[field.to_proto() for field in self.entity_columns],
            features=[feature.to_proto() for feature in self.features],
            description=self.description,
            tags=self.tags,
            owner=self.owner,
            ttl=ttl_duration,
            online=self.online,
            source=source_proto,
            labeler_field=self.labeler_field,
            conflict_policy=self.conflict_policy.to_proto(),  # type: ignore[arg-type]
            reference_feature_view=self.reference_feature_view or "",
        )
        return LabelViewProto(spec=spec, meta=meta)

    @classmethod
    def from_proto(cls, label_view_proto: LabelViewProto) -> "LabelView":
        """Creates a LabelView from a protobuf representation.

        Args:
            label_view_proto: A ``LabelViewProto`` message to deserialize.

        Returns:
            A ``LabelView`` instance populated from the protobuf data.
        """
        spec = label_view_proto.spec
        meta = label_view_proto.meta

        source = DataSource.from_proto(spec.source) if spec.HasField("source") else None

        # NOTE(review): a zero-length ttl maps to timedelta(0). Since to_proto
        # leaves ttl unset for ``None``, a ``None`` ttl comes back as
        # timedelta(0) after a round trip - confirm this is intended.
        if spec.ttl.ToNanoseconds() == 0:
            ttl = timedelta(days=0)
        else:
            ttl = spec.ttl.ToTimedelta()

        label_view = cls(
            name=spec.name,
            description=spec.description,
            tags=dict(spec.tags),
            owner=spec.owner,
            online=spec.online,
            ttl=ttl,
            source=source,
            labeler_field=spec.labeler_field or "labeler",
            conflict_policy=ConflictPolicy.from_proto(spec.conflict_policy),
            reference_feature_view=(spec.reference_feature_view or None),
        )

        # The constructor received no schema or entities, so the stored
        # values are restored directly from the proto.
        label_view.entities = list(spec.entities)
        label_view.features = [
            Field.from_proto(field_proto) for field_proto in spec.features
        ]
        label_view.entity_columns = [
            Field.from_proto(field_proto) for field_proto in spec.entity_columns
        ]

        # Rebuild the projection now that the features are in place.
        label_view.projection = FeatureViewProjection.from_definition(label_view)
        label_view.projection.view_type = "labelView"

        if meta.HasField("created_timestamp"):
            label_view.created_timestamp = meta.created_timestamp.ToDatetime()
        if meta.HasField("last_updated_timestamp"):
            label_view.last_updated_timestamp = (
                meta.last_updated_timestamp.ToDatetime()
            )

        return label_view