Skip to content
17 changes: 16 additions & 1 deletion sdk/python/feast/type_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from collections import defaultdict
from datetime import datetime, timezone
from typing import (
Expand Down Expand Up @@ -297,7 +298,7 @@ def _type_err(item, dtype):
None,
),
ValueType.FLOAT: ("float_val", lambda x: float(x), None),
ValueType.DOUBLE: ("double_val", lambda x: x, {float, np.float64}),
ValueType.DOUBLE: ("double_val", lambda x: x, {float, np.float64, int, np.int_}),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see issue #3884

ValueType.STRING: ("string_val", lambda x: str(x), None),
ValueType.BYTES: ("bytes_val", lambda x: x, {bytes}),
ValueType.BOOL: ("bool_val", lambda x: x, {bool, np.bool_, int, np.int_}),
Expand Down Expand Up @@ -353,6 +354,19 @@ def _python_value_to_proto_value(
feast_value_type
]

# Bytes to array type conversion
if isinstance(sample, (bytes, bytearray)):
# Bytes of an array containing elements of bytes not supported
if feast_value_type == ValueType.BYTES_LIST:
raise _type_err(sample, ValueType.BYTES_LIST)

json_value = json.loads(sample)
if isinstance(json_value, list):
if feast_value_type == ValueType.BOOL_LIST:
json_value = [bool(item) for item in json_value]
return [ProtoValue(**{field_name: proto_type(val=json_value)})] # type: ignore
raise _type_err(sample, valid_types[0])

if sample is not None and not all(
type(item) in valid_types for item in sample
):
Expand Down Expand Up @@ -631,6 +645,7 @@ def redshift_to_feast_value_type(redshift_type_as_str: str) -> ValueType:
"varchar": ValueType.STRING,
"timestamp": ValueType.UNIX_TIMESTAMP,
"timestamptz": ValueType.UNIX_TIMESTAMP,
"super": ValueType.BYTES,
# skip date, geometry, hllsketch, time, timetz
}

Expand Down
32 changes: 32 additions & 0 deletions sdk/python/tests/unit/test_type_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,35 @@ def test_python_values_to_proto_values_bool(values):
converted = feast_value_type_to_python_type(protos[0])

assert converted is bool(values[0])


@pytest.mark.parametrize(
"values, value_type, expected",
(
(np.array([b"[1,2,3]"]), ValueType.INT64_LIST, [1, 2, 3]),
(np.array([b"[1,2,3]"]), ValueType.INT32_LIST, [1, 2, 3]),
(np.array([b"[1.5,2.5,3.5]"]), ValueType.FLOAT_LIST, [1.5, 2.5, 3.5]),
(np.array([b"[1.5,2.5,3.5]"]), ValueType.DOUBLE_LIST, [1.5, 2.5, 3.5]),
(np.array([b'["a","b","c"]']), ValueType.STRING_LIST, ["a", "b", "c"]),
(np.array([b"[true,false]"]), ValueType.BOOL_LIST, [True, False]),
(np.array([b"[1,0]"]), ValueType.BOOL_LIST, [True, False]),
(np.array([None]), ValueType.STRING_LIST, None),
([b"[1,2,3]"], ValueType.INT64_LIST, [1, 2, 3]),
([b"[1,2,3]"], ValueType.INT32_LIST, [1, 2, 3]),
([b"[1.5,2.5,3.5]"], ValueType.FLOAT_LIST, [1.5, 2.5, 3.5]),
([b"[1.5,2.5,3.5]"], ValueType.DOUBLE_LIST, [1.5, 2.5, 3.5]),
([b'["a","b","c"]'], ValueType.STRING_LIST, ["a", "b", "c"]),
([b"[true,false]"], ValueType.BOOL_LIST, [True, False]),
([b"[1,0]"], ValueType.BOOL_LIST, [True, False]),
([None], ValueType.STRING_LIST, None),
),
)
def test_python_values_to_proto_values_bytes_to_list(values, value_type, expected):
protos = python_values_to_proto_values(values, value_type)
converted = feast_value_type_to_python_type(protos[0])
assert converted == expected


def test_python_values_to_proto_values_bytes_to_list_not_supported():
with pytest.raises(TypeError):
_ = python_values_to_proto_values([b"[]"], ValueType.BYTES_LIST)