|
15 | 15 | import decimal |
16 | 16 | import json |
17 | 17 | import logging |
| 18 | +import re |
18 | 19 | import uuid as uuid_module |
19 | 20 | from collections import defaultdict |
20 | 21 | from datetime import datetime, timezone |
@@ -2097,25 +2098,62 @@ def pg_type_code_to_arrow(code: int) -> str: |
2097 | 2098 |
|
2098 | 2099 | def athena_to_feast_value_type(athena_type_as_str: str) -> ValueType: |
2099 | 2100 | # Type names from https://docs.aws.amazon.com/athena/latest/ug/data-types.html |
| 2101 | + athena_type = athena_type_as_str.lower().strip() |
| 2102 | + if athena_type.startswith("array"): |
| 2103 | + inner_type_match = re.search(r'(?:<|\[)(.+)(?:>|\])', athena_type) |
| 2104 | + if inner_type_match: |
| 2105 | + inner_type = inner_type_match.group(1).strip() |
| 2106 | + inner_feast_type = athena_to_feast_value_type(inner_type) |
| 2107 | + |
| 2108 | + list_mapping = { |
| 2109 | + ValueType.BYTES: ValueType.BYTES_LIST, |
| 2110 | + ValueType.STRING: ValueType.STRING_LIST, |
| 2111 | + ValueType.INT32: ValueType.INT32_LIST, |
| 2112 | + ValueType.INT64: ValueType.INT64_LIST, |
| 2113 | + ValueType.DOUBLE: ValueType.DOUBLE_LIST, |
| 2114 | + ValueType.FLOAT: ValueType.FLOAT_LIST, |
| 2115 | + ValueType.BOOL: ValueType.BOOL_LIST, |
| 2116 | + ValueType.UNIX_TIMESTAMP: ValueType.UNIX_TIMESTAMP_LIST, |
| 2117 | + ValueType.MAP: ValueType.MAP_LIST, |
| 2118 | + ValueType.JSON: ValueType.JSON_LIST, |
| 2119 | + ValueType.STRUCT: ValueType.STRUCT_LIST, |
| 2120 | + ValueType.UUID: ValueType.UUID_LIST, |
| 2121 | + ValueType.DECIMAL: ValueType.DECIMAL_LIST, |
| 2122 | + } |
| 2123 | + return list_mapping.get(inner_feast_type, ValueType.VALUE_LIST) |
| 2124 | + return ValueType.VALUE_LIST |
| 2125 | + |
| 2126 | + base_type = re.split(r'[(<\[]', athena_type)[0].strip() |
| 2127 | + |
| 2128 | + if "timestamp" in base_type or "time" in base_type or "date" in base_type: |
| 2129 | + return ValueType.UNIX_TIMESTAMP |
| 2130 | + |
2100 | 2131 | type_map = { |
2101 | | - "null": ValueType.UNKNOWN, |
| 2132 | + "null": ValueType.NULL, |
2102 | 2133 | "boolean": ValueType.BOOL, |
2103 | 2134 | "tinyint": ValueType.INT32, |
2104 | 2135 | "smallint": ValueType.INT32, |
2105 | 2136 | "int": ValueType.INT32, |
| 2137 | + "integer": ValueType.INT32, |
2106 | 2138 | "bigint": ValueType.INT64, |
2107 | 2139 | "double": ValueType.DOUBLE, |
2108 | 2140 | "float": ValueType.FLOAT, |
| 2141 | + "real": ValueType.FLOAT, |
| 2142 | + "decimal": ValueType.DECIMAL, |
2109 | 2143 | "binary": ValueType.BYTES, |
| 2144 | + "varbinary": ValueType.BYTES, |
2110 | 2145 | "char": ValueType.STRING, |
2111 | 2146 | "varchar": ValueType.STRING, |
2112 | 2147 | "string": ValueType.STRING, |
2113 | | - "timestamp": ValueType.UNIX_TIMESTAMP, |
2114 | 2148 | "json": ValueType.JSON, |
2115 | 2149 | "struct": ValueType.STRUCT, |
| 2150 | + "row": ValueType.STRUCT, |
2116 | 2151 | "map": ValueType.MAP, |
| 2152 | + "uuid": ValueType.UUID, |
| 2153 | + "ipaddress": ValueType.STRING, |
2117 | 2154 | } |
2118 | | - return type_map[athena_type_as_str.lower()] |
| 2155 | + |
| 2156 | + return type_map.get(base_type, ValueType.UNKNOWN) |
2119 | 2157 |
|
2120 | 2158 |
|
2121 | 2159 | def pa_to_athena_value_type(pa_type: "pyarrow.DataType") -> str: |
|
0 commit comments