-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathconflict_resolver.py
More file actions
157 lines (129 loc) · 5.47 KB
/
conflict_resolver.py
File metadata and controls
157 lines (129 loc) · 5.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""Offline-store conflict resolution for LabelView labels.
Applies the configured ConflictPolicy to a DataFrame containing all historical
label rows (from pull_all_from_table_or_query), producing one resolved row per
entity key.
The online store continues to use LAST_WRITE_WINS regardless of policy — this
resolver is for offline/batch reads used in training data generation, the UI
browse/quality endpoints, and any batch pipeline consuming resolved labels.
"""
from typing import List, Optional
import pandas as pd
from feast.labeling.conflict_policy import ConflictPolicy
def resolve_conflicts(
    df: pd.DataFrame,
    join_key_columns: List[str],
    feature_name_columns: List[str],
    timestamp_field: str,
    labeler_field: str,
    conflict_policy: ConflictPolicy,
    labeler_priorities: Optional[List[str]] = None,
) -> pd.DataFrame:
    """Collapse a full label history to one row per entity using a policy.

    Args:
        df: Every historical label row (not deduplicated).
        join_key_columns: Columns that together identify an entity.
        feature_name_columns: Label/feature columns to resolve.
        timestamp_field: Column holding each row's event timestamp.
        labeler_field: Column naming who wrote each row.
        conflict_policy: Strategy used to pick the surviving values.
        labeler_priorities: Labelers ordered from highest to lowest
            priority; consulted only by ``ConflictPolicy.LABELER_PRIORITY``.
            When omitted or empty, that policy falls back to last-write-wins.

    Returns:
        A DataFrame with one resolved row per distinct entity key. An empty
        input is returned as-is (the same object).
    """
    # Nothing to resolve; hand the input straight back.
    if df.empty:
        return df

    if conflict_policy == ConflictPolicy.LABELER_PRIORITY:
        return _resolve_labeler_priority(
            df, join_key_columns, timestamp_field, labeler_field, labeler_priorities
        )

    if conflict_policy == ConflictPolicy.MAJORITY_VOTE:
        return _resolve_majority_vote(
            df, join_key_columns, feature_name_columns, timestamp_field
        )

    # LAST_WRITE_WINS, plus any policy without a dedicated resolver.
    return _resolve_last_write_wins(df, join_key_columns, timestamp_field)
def _resolve_last_write_wins(
df: pd.DataFrame,
join_key_columns: List[str],
timestamp_field: str,
) -> pd.DataFrame:
"""Keep only the row with the latest timestamp per entity."""
df_sorted = df.sort_values(timestamp_field, ascending=False)
return df_sorted.drop_duplicates(subset=join_key_columns, keep="first").reset_index(
drop=True
)
def _resolve_labeler_priority(
df: pd.DataFrame,
join_key_columns: List[str],
timestamp_field: str,
labeler_field: str,
labeler_priorities: Optional[List[str]] = None,
) -> pd.DataFrame:
"""Pick the label from the highest-priority labeler per entity.
If multiple rows exist from the same priority labeler, the latest timestamp
wins. Labelers not in the priority list are ranked lowest.
"""
if not labeler_priorities:
return _resolve_last_write_wins(df, join_key_columns, timestamp_field)
priority_map = {name: i for i, name in enumerate(labeler_priorities)}
max_priority = len(labeler_priorities)
df = df.copy()
df["_priority_rank"] = df[labeler_field].map(
lambda x: priority_map.get(x, max_priority)
)
df_sorted = df.sort_values(
["_priority_rank", timestamp_field], ascending=[True, False]
)
result = df_sorted.drop_duplicates(subset=join_key_columns, keep="first")
result = result.drop(columns=["_priority_rank"])
return result.reset_index(drop=True)
def _resolve_majority_vote(
df: pd.DataFrame,
join_key_columns: List[str],
feature_name_columns: List[str],
timestamp_field: str,
) -> pd.DataFrame:
"""For each entity, pick the most common value per feature column.
Uses the most frequent value across all labelers for each feature.
Ties are broken by recency (latest timestamp wins).
"""
if not feature_name_columns:
return _resolve_last_write_wins(df, join_key_columns, timestamp_field)
groups = df.groupby(join_key_columns, sort=False)
resolved_rows = []
for keys, group in groups:
if not isinstance(keys, tuple):
keys = (keys,)
row = dict(zip(join_key_columns, keys))
for col in feature_name_columns:
value_counts = group[col].value_counts()
if value_counts.empty:
row[col] = None
continue
top_value = value_counts.index[0]
top_count = value_counts.iloc[0]
# Tie-breaking: if multiple values have the same count, pick the
# one with the most recent timestamp
tied = value_counts[value_counts == top_count]
if len(tied) > 1:
tied_values = tied.index.tolist()
tied_rows = group[group[col].isin(tied_values)]
latest_row = tied_rows.sort_values(
timestamp_field, ascending=False
).iloc[0]
row[col] = latest_row[col]
else:
row[col] = top_value
row[timestamp_field] = group[timestamp_field].max()
if "labeler" in group.columns and "labeler" not in join_key_columns:
row["labeler"] = "majority_vote"
resolved_rows.append(row)
if not resolved_rows:
return df.head(0)
result = pd.DataFrame(resolved_rows)
# Preserve column order from original
cols = [c for c in df.columns if c in result.columns]
return result[cols].reset_index(drop=True)