-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy path__init__.py
More file actions
114 lines (94 loc) · 3.52 KB
/
__init__.py
File metadata and controls
114 lines (94 loc) · 3.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
Aggregation module for Feast.
"""
from datetime import timedelta
from typing import Any, Dict, Iterable, Optional, Tuple
from google.protobuf.duration_pb2 import Duration
from typeguard import typechecked
from feast.protos.feast.core.Aggregation_pb2 import Aggregation as AggregationProto
@typechecked
class Aggregation:
    """
    A user-defined aggregation over a single feature column.

    NOTE: Feast-handled aggregations are not yet supported. This class provides a way
    to register user-defined aggregations.

    Attributes:
        column: Name of the feature column being aggregated.
        function: Name of the aggregation function (e.g. sum, max, min, count, mean).
        time_window: Time window for this aggregation, or None.
        slide_interval: Slide interval of the window; falls back to ``time_window``
            when not supplied (or when a falsy value such as ``timedelta(0)`` is given).
    """

    column: str
    function: str
    time_window: Optional[timedelta]
    slide_interval: Optional[timedelta]

    def __init__(
        self,
        column: Optional[str] = "",
        function: Optional[str] = "",
        time_window: Optional[timedelta] = None,
        slide_interval: Optional[timedelta] = None,
    ):
        # None (or any other falsy value) is normalized to an empty string.
        self.column = column if column else ""
        self.function = function if function else ""
        self.time_window = time_window
        # A missing or zero-length slide interval defaults to the window itself.
        self.slide_interval = slide_interval if slide_interval else self.time_window

    def to_proto(self) -> AggregationProto:
        """
        Serialize this aggregation to its protobuf form.

        ``None`` durations are passed to the proto constructor as ``None``
        (presumably leaving the Duration field unset — verify against the
        protobuf version in use).
        """

        def _as_duration(delta):
            # Convert an optional timedelta into an optional protobuf Duration.
            if delta is None:
                return None
            duration = Duration()
            duration.FromTimedelta(delta)
            return duration

        return AggregationProto(
            column=self.column,
            function=self.function,
            time_window=_as_duration(self.time_window),
            slide_interval=_as_duration(self.slide_interval),
        )

    @classmethod
    def from_proto(cls, agg_proto: AggregationProto):
        """
        Build an Aggregation from its protobuf form.

        A zero-length (or unset) Duration is mapped to ``timedelta(days=0)``,
        never to ``None``.

        NOTE(review): because ``to_proto`` does not populate a ``None`` window,
        an Aggregation created without a window comes back with
        ``timedelta(days=0)`` and therefore does not compare equal to the
        original after a to_proto/from_proto round trip.
        """

        def _as_timedelta(duration):
            # Zero nanoseconds -> explicit zero timedelta; otherwise convert directly.
            if duration.ToNanoseconds() == 0:
                return timedelta(days=0)
            return duration.ToTimedelta()

        return cls(
            column=agg_proto.column,
            function=agg_proto.function,
            time_window=_as_timedelta(agg_proto.time_window),
            slide_interval=_as_timedelta(agg_proto.slide_interval),
        )

    def __eq__(self, other):
        """
        Compare two aggregations field by field.

        Raises:
            TypeError: If ``other`` is not an Aggregation.
        """
        # Defining __eq__ without __hash__ leaves instances unhashable.
        if not isinstance(other, Aggregation):
            raise TypeError("Comparisons should only involve Aggregations.")
        return not (
            self.column != other.column
            or self.function != other.function
            or self.time_window != other.time_window
            or self.slide_interval != other.slide_interval
        )
def aggregation_specs_to_agg_ops(
    agg_specs: Iterable[Any],
    *,
    time_window_unsupported_error_message: str,
) -> Dict[str, Tuple[str, str]]:
    """
    Map aggregation specs to ``{alias: (function, column)}`` operations.

    Each spec must expose ``function`` and ``column`` attributes; ``time_window``
    is optional. The alias is ``f"{function}_{column}"``. A later spec with
    the same alias overwrites an earlier one.

    Args:
        agg_specs: Aggregation-like objects (e.g. ``Aggregation`` instances).
        time_window_unsupported_error_message: Message for the ``ValueError``
            raised when a spec carries a time window.

    Returns:
        Dict mapping each alias to a ``(function, column)`` tuple.

    Raises:
        ValueError: If any spec has a non-empty ``time_window``.
    """
    agg_ops: Dict[str, Tuple[str, str]] = {}
    for agg in agg_specs:
        # ``Aggregation.from_proto`` turns an unset window into ``timedelta(0)``
        # rather than ``None``. Treat any falsy window (None or a zero-length
        # timedelta) as "no time window" so deserialized aggregations are not
        # rejected; only a real, non-zero window is unsupported.
        if getattr(agg, "time_window", None):
            raise ValueError(time_window_unsupported_error_message)
        alias = f"{agg.function}_{agg.column}"
        agg_ops[alias] = (agg.function, agg.column)
    return agg_ops
# Public API re-exported by ``from <package> import *``.
__all__ = [
    "Aggregation",
    "aggregation_specs_to_agg_ops",
]