-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathgrpc_server.py
More file actions
187 lines (166 loc) · 6.47 KB
/
grpc_server.py
File metadata and controls
187 lines (166 loc) · 6.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import logging
import threading
from collections.abc import Mapping
from concurrent import futures
from typing import Optional, Union
import grpc
import pandas as pd
from feast.data_source import PushMode
from feast.errors import FeatureServiceNotFoundException, PushSourceNotFoundException
from feast.feature_service import FeatureService
from feast.feature_store import FeatureStore
from feast.protos.feast.serving.GrpcServer_pb2 import (
PushResponse,
WriteToOnlineStoreResponse,
)
from feast.protos.feast.serving.GrpcServer_pb2_grpc import (
GrpcFeatureServerServicer,
add_GrpcFeatureServerServicer_to_server,
)
from feast.protos.feast.serving.ServingService_pb2 import (
GetOnlineFeaturesRequest,
GetOnlineFeaturesResponse,
)
# Module-level logger, named after this module, shared by the servicer and
# the server factory below.
logger = logging.getLogger(__name__)
def parse(features):
    """Turn a flat ``name -> value`` mapping into a single-row DataFrame.

    Used for the untyped ``request.features`` field of push/write requests.
    Every key becomes one column holding exactly one value (looked up via
    ``features.get``), so the result always has one row, or is an empty
    DataFrame when ``features`` is empty.
    """
    columns = {name: [features.get(name)] for name in features.keys()}
    return pd.DataFrame.from_dict(columns)
def parse_typed(typed_features):
    """Turn a ``name -> typed value`` mapping into a single-row DataFrame.

    Used for the ``request.typed_features`` field of push/write requests.
    Each value is expected to expose a ``WhichOneof("val")`` oneof (as a
    protobuf message does). It is unwrapped as follows:

    * no case set, or the ``null_val`` case -> ``None``;
    * a case whose payload has a ``.val`` attribute -> that container
      copied into a ``dict`` (if it is a ``Mapping``) or a ``list``;
    * any other case -> the payload itself.
    """

    def _unwrap(value):
        # Which member of the "val" oneof is populated (None if unset).
        kind = value.WhichOneof("val")
        if kind in (None, "null_val"):
            return None
        payload = getattr(value, kind)
        if not hasattr(payload, "val"):
            return payload
        # Wrapper types keep their items in `.val`; copy them into plain
        # Python containers so pandas stores them as ordinary objects.
        inner = payload.val
        return dict(inner) if isinstance(inner, Mapping) else list(inner)

    columns = {name: [_unwrap(value)] for name, value in typed_features.items()}
    return pd.DataFrame.from_dict(columns)
class GrpcFeatureServer(GrpcFeatureServerServicer):
    """gRPC servicer exposing Push, WriteToOnlineStore and GetOnlineFeatures.

    Construction refreshes the feature registry once, synchronously. After
    that, a chain of daemon ``threading.Timer`` threads refreshes it every
    ``registry_ttl_sec`` seconds until :meth:`stop_registry_refresh` is called.
    """

    fs: FeatureStore
    # Set by stop_registry_refresh() and checked before each refresh is
    # scheduled. The misspelled name is kept to avoid breaking external code
    # that may reference it.
    _shuting_down: bool = False
    # The most recently scheduled refresh timer, if any.
    _active_timer: Optional[threading.Timer] = None

    def __init__(self, fs: FeatureStore, registry_ttl_sec: int = 5):
        """Create the servicer and start the periodic registry refresh.

        Args:
            fs: Feature store that backs every RPC.
            registry_ttl_sec: Seconds between background registry refreshes.

        Raises:
            Whatever ``fs.refresh_registry()`` raises during the initial,
            synchronous refresh.
        """
        self.fs = fs
        self.registry_ttl_sec = registry_ttl_sec
        super().__init__()
        self._async_refresh()

    @staticmethod
    def _request_to_df(request, context) -> Optional[pd.DataFrame]:
        """Build a one-row DataFrame from a push/write request's features.

        Uses ``request.typed_features`` when it is set, otherwise
        ``request.features``. If both are set, this sets INVALID_ARGUMENT on
        ``context`` and returns None.
        """
        if request.features and request.typed_features:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(
                "Only one of features or typed_features may be set, not both"
            )
            return None
        if request.typed_features:
            return parse_typed(request.typed_features)
        return parse(request.features)

    def Push(self, request, context):
        """Push one row of features to the push source ``request.stream_feature_view``.

        ``request.to`` selects the destination: ``"online"``, ``"offline"`` or
        ``"online_and_offline"``.

        Returns:
            ``PushResponse(status=True)`` on success. Otherwise returns
            ``PushResponse(status=False)`` and sets the gRPC status on
            ``context``: INVALID_ARGUMENT for a malformed request, an
            unsupported ``to`` or an unknown push source; INTERNAL for any
            other failure.
        """
        push_modes = {
            "offline": PushMode.OFFLINE,
            "online": PushMode.ONLINE,
            "online_and_offline": PushMode.ONLINE_AND_OFFLINE,
        }
        try:
            df = self._request_to_df(request, context)
            if df is None:
                return PushResponse(status=False)
            to = push_modes.get(request.to)
            if to is None:
                # An unsupported destination is a client error. Report it
                # directly instead of raising into the generic INTERNAL
                # handler below.
                message = (
                    f"{request.to} is not a supported push format. Please specify one of these ['online', 'offline', "
                    f"'online_and_offline']."
                )
                logger.warning("%s", message)
                context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
                context.set_details(message)
                return PushResponse(status=False)
            self.fs.push(
                push_source_name=request.stream_feature_view,
                df=df,
                allow_registry_cache=request.allow_registry_cache,
                to=to,
            )
        except PushSourceNotFoundException as e:
            logger.exception(str(e))
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(str(e))
            return PushResponse(status=False)
        except Exception as e:
            logger.exception(str(e))
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(str(e))
            return PushResponse(status=False)
        return PushResponse(status=True)

    def WriteToOnlineStore(self, request, context):
        """Deprecated: write one row of features directly to the online store.

        Writes to the feature view ``request.feature_view_name``. Prefer
        :meth:`Push`.

        Returns:
            ``WriteToOnlineStoreResponse(status=True)`` on success. Otherwise
            returns ``status=False`` and sets INVALID_ARGUMENT (conflicting
            feature fields) or INTERNAL (any other failure) on ``context``.
        """
        logger.warning(
            "write_to_online_store is deprecated. Please consider using Push instead"
        )
        try:
            df = self._request_to_df(request, context)
            if df is None:
                return WriteToOnlineStoreResponse(status=False)
            self.fs.write_to_online_store(
                feature_view_name=request.feature_view_name,
                df=df,
                allow_registry_cache=request.allow_registry_cache,
            )
        except Exception as e:
            logger.exception(str(e))
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(str(e))
            return WriteToOnlineStoreResponse(status=False)
        return WriteToOnlineStoreResponse(status=True)

    def GetOnlineFeatures(self, request: GetOnlineFeaturesRequest, context):
        """Serve online feature values for ``request.entities``.

        The features come either from a named feature service
        (``request.feature_service``) or from the explicit list in
        ``request.features.val``.

        Returns:
            The feature store's ``GetOnlineFeaturesResponse`` proto. If the
            named feature service does not exist, returns an empty
            ``GetOnlineFeaturesResponse`` and sets NOT_FOUND on ``context``.
        """
        if request.HasField("feature_service"):
            logger.info("Requesting feature service: %s", request.feature_service)
            try:
                features: Union[list[str], FeatureService] = (
                    self.fs.get_feature_service(
                        request.feature_service, allow_cache=True
                    )
                )
            except FeatureServiceNotFoundException as e:
                logger.error("Feature service %s not found", request.feature_service)
                # An unknown feature service is a failed lookup on the
                # client's input, not a server fault.
                context.set_code(grpc.StatusCode.NOT_FOUND)
                context.set_details(str(e))
                return GetOnlineFeaturesResponse()
        else:
            features = list(request.features.val)
        return self.fs.get_online_features(
            features,
            request.entities,
            request.full_feature_names,
        ).proto

    def stop_registry_refresh(self) -> None:
        """Stop the background registry refresh loop.

        Cancels the pending timer and prevents any new one from being
        scheduled. A refresh that is already running is allowed to finish,
        and at most one further refresh may still fire.
        """
        self._shuting_down = True
        timer = self._active_timer
        if timer is not None:
            timer.cancel()

    def _async_refresh(self):
        """Refresh the registry now, then schedule the next periodic refresh.

        Errors from this direct call propagate to the caller (e.g.
        ``__init__``).
        """
        self.fs.refresh_registry()
        self._schedule_next_refresh()

    def _scheduled_refresh(self):
        """Timer callback: refresh the registry and reschedule.

        Failures are logged rather than raised. An exception escaping a timer
        thread would end the refresh chain and leave the registry stale
        permanently.
        """
        try:
            self.fs.refresh_registry()
        except Exception:
            logger.exception("Registry refresh failed; will retry on next interval")
        self._schedule_next_refresh()

    def _schedule_next_refresh(self):
        """Arm a daemon timer for the next refresh unless shutdown was requested."""
        if self._shuting_down:
            return
        timer = threading.Timer(self.registry_ttl_sec, self._scheduled_refresh)
        # Daemon, so a pending refresh never keeps the process alive after
        # the server has stopped.
        timer.daemon = True
        self._active_timer = timer
        timer.start()
def get_grpc_server(
    address: str,
    fs: FeatureStore,
    max_workers: int,
    registry_ttl_sec: int,
):
    """Build (but do not start) a gRPC server for the Feast feature server.

    The returned server has two services registered and is bound,
    insecurely, to ``address``:

    * ``GrpcFeatureServer`` backed by ``fs``. Constructing it performs the
      first registry refresh and schedules later ones every
      ``registry_ttl_sec`` seconds.
    * The standard gRPC health service, running non-blocking on its own
      thread pool.

    ``max_workers`` sizes both the server's thread pool and the health
    service's pool. The caller is responsible for ``start()``ing the server.
    """
    # Imported lazily so the health-checking package is only needed when a
    # server is actually built.
    from grpc_health.v1 import health, health_pb2_grpc

    logger.info(f"Initializing gRPC server on {address}")
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))

    feature_servicer = GrpcFeatureServer(fs, registry_ttl_sec=registry_ttl_sec)
    add_GrpcFeatureServerServicer_to_server(feature_servicer, server)

    health_pool = futures.ThreadPoolExecutor(max_workers=max_workers)
    health_servicer = health.HealthServicer(
        experimental_non_blocking=True,
        experimental_thread_pool=health_pool,
    )
    health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)

    server.add_insecure_port(address)
    return server