Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def trino_to_feast_value_type(trino_type_as_str: str) -> ValueType:
"boolean": ValueType.BOOL,
"real": ValueType.FLOAT,
"date": ValueType.STRING,
"binary": ValueType.STRING,
"varbinary": ValueType.STRING,
"json": ValueType.STRING,
}
_trino_type_as_str: str = trino_type_as_str
trino_type_as_str = trino_type_as_str.lower()
Expand All @@ -36,13 +39,18 @@ def trino_to_feast_value_type(trino_type_as_str: str) -> ValueType:
trino_type_as_str = "decimal64"
else:
trino_type_as_str = "decimal32"
else:
trino_type_as_str = "decimal64"

elif trino_type_as_str.startswith("timestamp"):
trino_type_as_str = "timestamp"

elif trino_type_as_str.startswith("varchar"):
trino_type_as_str = "varchar"

elif trino_type_as_str.startswith("char"):
trino_type_as_str = "char"

if trino_type_as_str not in type_map:
raise ValueError(f"Trino type not supported by feast {_trino_type_as_str}")
return type_map[trino_type_as_str]
Expand All @@ -55,7 +63,11 @@ def pa_to_trino_value_type(pa_type_as_str: str) -> str:
trino_type = "{}"
if pa_type_as_str.startswith("list"):
trino_type = "array<{}>"
pa_type_as_str = re.search(r"^list<item:\s(.+)>$", pa_type_as_str).group(1)
match = re.search(r"^list<item:\s(.+)>$", pa_type_as_str)
if match:
pa_type_as_str = match.group(1)
else:
return trino_type.format("varchar")

if pa_type_as_str.startswith("date"):
return trino_type.format("date")
Expand All @@ -67,7 +79,10 @@ def pa_to_trino_value_type(pa_type_as_str: str) -> str:
return trino_type.format("timestamp")

if pa_type_as_str.startswith("decimal"):
return trino_type.format(pa_type_as_str)
# PyArrow renders decimal types as decimal128(10, 2) or decimal256(10, 2),
# but Trino expects just decimal(10, 2)
normalized = re.sub(r"^decimal\d+", "decimal", pa_type_as_str)
return trino_type.format(normalized)

if pa_type_as_str.startswith("map<"):
return trino_type.format("varchar")
Expand All @@ -92,33 +107,43 @@ def pa_to_trino_value_type(pa_type_as_str: str) -> str:
"float": "double",
"double": "double",
"binary": "binary",
"varbinary": "binary",
"string": "varchar",
"char": "varchar",
}
return trino_type.format(type_map[pa_type_as_str])


_TRINO_TO_PA_TYPE_MAP = {
_TRINO_TO_PA_TYPE_MAP: Dict[str, pa.DataType] = {
"null": pa.null(),
"boolean": pa.bool_(),
"date": pa.date32(),
"tinyint": pa.int8(),
"smallint": pa.int16(),
"integer": pa.int32(),
"int": pa.int32(),
"bigint": pa.int64(),
"double": pa.float64(),
"binary": pa.binary(),
"varbinary": pa.binary(),
"char": pa.string(),
"json": pa.string(),
"real": pa.float32(),
}


def _trino_array_item_type(trino_type_as_str: str) -> str | None:
if trino_type_as_str.startswith("array(") and trino_type_as_str.endswith(")"):
return trino_type_as_str[6:-1].strip()
return None


def trino_to_pa_value_type(trino_type_as_str: str) -> pa.DataType:
trino_type_as_str = trino_type_as_str.lower()
trino_type_as_str = trino_type_as_str.lower().strip()

_is_list: bool = False
if trino_type_as_str.startswith("array"):
_is_list = True
trino_type_as_str = re.search(r"^array\((\w+)\)$", trino_type_as_str).group(1)
array_item_type = _trino_array_item_type(trino_type_as_str)
if array_item_type is not None:
return pa.list_(trino_to_pa_value_type(array_item_type))

if trino_type_as_str.startswith("decimal"):
search_precision = re.search(
Expand All @@ -127,20 +152,24 @@ def trino_to_pa_value_type(trino_type_as_str: str) -> pa.DataType:
if search_precision:
precision = int(search_precision.group(1))
if precision > 32:
pa_type = pa.float64()
return pa.float64()
else:
pa_type = pa.float32()
return pa.float32()
return pa.float64()

elif trino_type_as_str.startswith("timestamp"):
pa_type = pa.timestamp("us")
if trino_type_as_str.startswith("timestamp"):
return pa.timestamp("us")

elif trino_type_as_str.startswith("varchar"):
pa_type = pa.string()
if trino_type_as_str.startswith("varchar"):
return pa.string()

if trino_type_as_str.startswith("char"):
return pa.string()

if trino_type_as_str.startswith("row("):
return pa.string()

else:
pa_type = _TRINO_TO_PA_TYPE_MAP[trino_type_as_str]
if trino_type_as_str.startswith("map("):
return pa.string()

if _is_list:
return pa.list_(pa_type)
else:
return pa_type
return _TRINO_TO_PA_TYPE_MAP[trino_type_as_str]
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import pyarrow as pa
import pytest

from feast import ValueType
from feast.infra.offline_stores.contrib.trino_offline_store.trino_type_map import (
_trino_array_item_type,
pa_to_trino_value_type,
trino_to_feast_value_type,
trino_to_pa_value_type,
)


class TestTrinoArrayItemType:
def test_simple_type(self) -> None:
assert _trino_array_item_type("array(bigint)") == "bigint"

def test_parameterized_type(self) -> None:
assert _trino_array_item_type("array(varchar(10))") == "varchar(10)"

def test_parameterized_with_comma(self) -> None:
assert _trino_array_item_type("array(decimal(10, 2))") == "decimal(10, 2)"

def test_nested_array(self) -> None:
assert _trino_array_item_type("array(array(varchar))") == "array(varchar)"

def test_complex_row(self) -> None:
assert (
_trino_array_item_type("array(row(x bigint, y varchar(10)))")
== "row(x bigint, y varchar(10))"
)

def test_not_an_array(self) -> None:
assert _trino_array_item_type("varchar") is None

def test_partial_prefix(self) -> None:
assert _trino_array_item_type("array") is None
assert _trino_array_item_type("array(") is None


class TestTrinoToFeastValueType:
def test_simple_types(self) -> None:
assert trino_to_feast_value_type("boolean") == ValueType.BOOL
assert trino_to_feast_value_type("bigint") == ValueType.INT64
assert trino_to_feast_value_type("integer") == ValueType.INT32
assert trino_to_feast_value_type("int") == ValueType.INT32
assert trino_to_feast_value_type("double") == ValueType.DOUBLE
assert trino_to_feast_value_type("real") == ValueType.FLOAT
assert trino_to_feast_value_type("date") == ValueType.STRING
assert trino_to_feast_value_type("tinyint") == ValueType.INT32
assert trino_to_feast_value_type("smallint") == ValueType.INT32

def test_parameterized_varchar(self) -> None:
assert trino_to_feast_value_type("varchar(10)") == ValueType.STRING

def test_parameterized_char(self) -> None:
assert trino_to_feast_value_type("char(10)") == ValueType.STRING
assert trino_to_feast_value_type("char") == ValueType.STRING

def test_timestamp_with_precision(self) -> None:
assert trino_to_feast_value_type("timestamp(3)") == ValueType.UNIX_TIMESTAMP
assert trino_to_feast_value_type("timestamp") == ValueType.UNIX_TIMESTAMP

def test_decimal_with_precision(self) -> None:
assert trino_to_feast_value_type("decimal(10, 2)") == ValueType.FLOAT
assert trino_to_feast_value_type("decimal(38, 2)") == ValueType.DOUBLE
assert trino_to_feast_value_type("decimal(32)") == ValueType.FLOAT
assert trino_to_feast_value_type("decimal(33)") == ValueType.DOUBLE

def test_bare_decimal(self) -> None:
assert trino_to_feast_value_type("decimal") == ValueType.DOUBLE

def test_binary_types(self) -> None:
assert trino_to_feast_value_type("binary") == ValueType.STRING
assert trino_to_feast_value_type("varbinary") == ValueType.STRING

def test_json(self) -> None:
assert trino_to_feast_value_type("json") == ValueType.STRING

def test_unsupported_type(self) -> None:
with pytest.raises(ValueError, match="Trino type not supported"):
trino_to_feast_value_type("unknown_type")


class TestTrinoToPaValueType:
def test_simple_types(self) -> None:
assert trino_to_pa_value_type("boolean") == pa.bool_()
assert trino_to_pa_value_type("bigint") == pa.int64()
assert trino_to_pa_value_type("integer") == pa.int32()
assert trino_to_pa_value_type("int") == pa.int32()
assert trino_to_pa_value_type("double") == pa.float64()
assert trino_to_pa_value_type("real") == pa.float32()
assert trino_to_pa_value_type("date") == pa.date32()
assert trino_to_pa_value_type("tinyint") == pa.int8()
assert trino_to_pa_value_type("smallint") == pa.int16()

def test_parameterized_varchar(self) -> None:
assert trino_to_pa_value_type("varchar(10)") == pa.string()

def test_parameterized_char(self) -> None:
assert trino_to_pa_value_type("char(10)") == pa.string()
assert trino_to_pa_value_type("char") == pa.string()

def test_binary_types(self) -> None:
assert trino_to_pa_value_type("binary") == pa.binary()
assert trino_to_pa_value_type("varbinary") == pa.binary()

def test_json(self) -> None:
assert trino_to_pa_value_type("json") == pa.string()

def test_timestamp(self) -> None:
assert trino_to_pa_value_type("timestamp") == pa.timestamp("us")
assert trino_to_pa_value_type("timestamp(3)") == pa.timestamp("us")

def test_decimal_bare(self) -> None:
assert trino_to_pa_value_type("decimal") == pa.float64()

def test_decimal_with_precision(self) -> None:
assert trino_to_pa_value_type("decimal(10, 2)") == pa.float32()
assert trino_to_pa_value_type("decimal(38, 2)") == pa.float64()
assert trino_to_pa_value_type("decimal(32)") == pa.float32()
assert trino_to_pa_value_type("decimal(33)") == pa.float64()

def test_array_simple(self) -> None:
assert trino_to_pa_value_type("array(bigint)") == pa.list_(pa.int64())

def test_array_parameterized_varchar(self) -> None:
assert trino_to_pa_value_type("array(varchar(10))") == pa.list_(pa.string())

def test_array_parameterized_decimal(self) -> None:
assert trino_to_pa_value_type("array(decimal(10, 2))") == pa.list_(pa.float32())

def test_array_nested(self) -> None:
assert trino_to_pa_value_type("array(array(bigint))") == pa.list_(
pa.list_(pa.int64())
)

def test_row_type(self) -> None:
assert trino_to_pa_value_type("row(x bigint)") == pa.string()
assert trino_to_pa_value_type("row(x bigint, y varchar)") == pa.string()

def test_map_type(self) -> None:
assert trino_to_pa_value_type("map(varchar, bigint)") == pa.string()

def test_array_of_row(self) -> None:
assert trino_to_pa_value_type(
"array(row(x bigint, y varchar(10)))"
) == pa.list_(pa.string())

def test_unsupported_type(self) -> None:
with pytest.raises(KeyError):
trino_to_pa_value_type("unknown_type")


class TestPaToTrinoValueType:
def test_simple_types(self) -> None:
assert pa_to_trino_value_type(str(pa.bool_())) == "boolean"
assert pa_to_trino_value_type(str(pa.int8())) == "tinyint"
assert pa_to_trino_value_type(str(pa.int16())) == "smallint"
assert pa_to_trino_value_type(str(pa.int32())) == "int"
assert pa_to_trino_value_type(str(pa.int64())) == "bigint"
assert pa_to_trino_value_type(str(pa.float32())) == "double"
assert pa_to_trino_value_type(str(pa.float64())) == "double"
assert pa_to_trino_value_type(str(pa.binary())) == "binary"

def test_string(self) -> None:
assert pa_to_trino_value_type(str(pa.string())) == "varchar"
assert pa_to_trino_value_type("large_string") == "varchar"
assert pa_to_trino_value_type("char") == "varchar"

def test_varbinary(self) -> None:
assert pa_to_trino_value_type("varbinary") == "binary"

def test_date(self) -> None:
assert pa_to_trino_value_type(str(pa.date32())) == "date"

def test_timestamp(self) -> None:
assert pa_to_trino_value_type(str(pa.timestamp("us"))) == "timestamp"
assert (
pa_to_trino_value_type(str(pa.timestamp("us", tz="UTC")))
== "timestamp with time zone"
)

def test_decimal128(self) -> None:
assert pa_to_trino_value_type(str(pa.decimal128(10, 2))) == "decimal(10, 2)"

def test_decimal256(self) -> None:
assert pa_to_trino_value_type(str(pa.decimal256(10, 2))) == "decimal(10, 2)"

def test_list(self) -> None:
assert pa_to_trino_value_type(str(pa.list_(pa.int64()))) == "array<bigint>"

def test_list_of_string(self) -> None:
assert pa_to_trino_value_type(str(pa.list_(pa.string()))) == "array<varchar>"

def test_map_degrades_to_varchar(self) -> None:
type_str = str(pa.map_(pa.string(), pa.int64()))
assert pa_to_trino_value_type(type_str) == "varchar"

def test_struct_degrades_to_varchar(self) -> None:
type_str = str(pa.struct([("x", pa.int64())]))
assert pa_to_trino_value_type(type_str) == "varchar"

def test_null(self) -> None:
assert pa_to_trino_value_type(str(pa.null())) == "null"
Loading