-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathtype_map.py
More file actions
88 lines (74 loc) · 2.95 KB
/
type_map.py
File metadata and controls
88 lines (74 loc) · 2.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from datetime import timezone
from typing import List
import pyarrow as pa
from feast.protos.feast.types import Value_pb2
from feast.types import Array, PrimitiveFeastType
# Arrow timestamp type with second resolution and a UTC timezone. It is the
# key used for timestamp entries in the Arrow -> proto tables below and the
# type arrow_array_to_array_of_proto compares against before casting to int64.
PA_TIMESTAMP_TYPE = pa.timestamp("s", tz=timezone.utc)
# Scalar Arrow data type -> name of the Value proto field that
# arrow_array_to_array_of_proto populates for an element of that type.
ARROW_TYPE_TO_PROTO_FIELD = {
    pa.int32(): "int32_val",
    pa.int64(): "int64_val",
    pa.float32(): "float_val",
    pa.float64(): "double_val",
    pa.bool_(): "bool_val",
    pa.string(): "string_val",
    pa.binary(): "bytes_val",
    PA_TIMESTAMP_TYPE: "unix_timestamp_val",
}
# Element type of an Arrow list -> name of the Value proto field that holds the
# whole list. Keyed by the list's value type, not by the list type itself.
ARROW_LIST_TYPE_TO_PROTO_FIELD = {
    pa.int32(): "int32_list_val",
    pa.int64(): "int64_list_val",
    pa.float32(): "float_list_val",
    pa.float64(): "double_list_val",
    pa.bool_(): "bool_list_val",
    pa.string(): "string_list_val",
    pa.binary(): "bytes_list_val",
    PA_TIMESTAMP_TYPE: "unix_timestamp_list_val",
}
# Element type of an Arrow list -> proto list message class used to wrap the
# Python list (constructed as ``cls(val=...)``). Keyed by the list's value type.
ARROW_LIST_TYPE_TO_PROTO_LIST_CLASS = {
    pa.int32(): Value_pb2.Int32List,
    pa.int64(): Value_pb2.Int64List,
    pa.float32(): Value_pb2.FloatList,
    pa.float64(): Value_pb2.DoubleList,
    pa.bool_(): Value_pb2.BoolList,
    pa.string(): Value_pb2.StringList,
    pa.binary(): Value_pb2.BytesList,
    # Timestamp lists are wrapped in Int64List; arrow_array_to_array_of_proto
    # casts them to lists of int64 before wrapping.
    PA_TIMESTAMP_TYPE: Value_pb2.Int64List,
}
# Feast type -> Arrow data type, for both primitive types and Array(...) of
# each primitive type.
FEAST_TYPE_TO_ARROW_TYPE = {
    PrimitiveFeastType.INT32: pa.int32(),
    PrimitiveFeastType.INT64: pa.int64(),
    PrimitiveFeastType.FLOAT32: pa.float32(),
    PrimitiveFeastType.FLOAT64: pa.float64(),
    PrimitiveFeastType.STRING: pa.string(),
    PrimitiveFeastType.BYTES: pa.binary(),
    PrimitiveFeastType.BOOL: pa.bool_(),
    # NOTE(review): this timestamp type (and the list variant below) is built
    # without a timezone, whereas PA_TIMESTAMP_TYPE is UTC-aware. Presumably
    # the two do not compare equal, so a type taken from this table would not
    # match the PA_TIMESTAMP_TYPE keys of the Arrow -> proto tables -- confirm
    # this is intended.
    PrimitiveFeastType.UNIX_TIMESTAMP: pa.timestamp("s"),
    Array(PrimitiveFeastType.INT32): pa.list_(pa.int32()),
    Array(PrimitiveFeastType.INT64): pa.list_(pa.int64()),
    Array(PrimitiveFeastType.FLOAT32): pa.list_(pa.float32()),
    Array(PrimitiveFeastType.FLOAT64): pa.list_(pa.float64()),
    Array(PrimitiveFeastType.STRING): pa.list_(pa.string()),
    Array(PrimitiveFeastType.BYTES): pa.list_(pa.binary()),
    Array(PrimitiveFeastType.BOOL): pa.list_(pa.bool_()),
    Array(PrimitiveFeastType.UNIX_TIMESTAMP): pa.list_(pa.timestamp("s")),
}
def arrow_array_to_array_of_proto(
    arrow_type: pa.DataType, arrow_array: pa.Array
) -> List[Value_pb2.Value]:
    """Convert an Arrow array into a list of ``Value`` protos, one per element.

    Args:
        arrow_type: Arrow data type describing ``arrow_array``. A
            ``pa.ListType`` selects the list conversion path (looked up by the
            list's value type); any other type is looked up as a scalar type.
        arrow_array: The Arrow array whose elements are converted.

    Returns:
        A list with one ``Value`` per element of ``arrow_array``, in order.
        Null elements (scalar or list) are returned as an unset ``Value()``.

    Raises:
        KeyError: If ``arrow_type`` (or, for lists, its value type) has no
            entry in the Arrow -> proto lookup tables.
    """
    if isinstance(arrow_type, pa.ListType):
        element_type = arrow_type.value_type
        proto_list_class = ARROW_LIST_TYPE_TO_PROTO_LIST_CLASS[element_type]
        proto_field_name = ARROW_LIST_TYPE_TO_PROTO_FIELD[element_type]
        if element_type == PA_TIMESTAMP_TYPE:
            # Timestamp lists are wrapped in Int64List, so convert the
            # elements to their integer representation first.
            arrow_array = arrow_array.cast(pa.list_(pa.int64()))
        return [
            # A null list must stay null: wrapping None in the list class
            # would yield an empty list message and mark the field as set,
            # making "null" indistinguishable from "empty list".
            Value_pb2.Value()
            if row is None
            else Value_pb2.Value(**{proto_field_name: proto_list_class(val=row)})
            for row in arrow_array.tolist()
        ]

    proto_field_name = ARROW_TYPE_TO_PROTO_FIELD[arrow_type]
    if arrow_type == PA_TIMESTAMP_TYPE:
        # unix_timestamp_val receives the integer form of the timestamp.
        arrow_array = arrow_array.cast(pa.int64())
    return [
        Value_pb2.Value()
        if item is None
        else Value_pb2.Value(**{proto_field_name: item})
        for item in arrow_array.tolist()
    ]