@@ -34,6 +34,26 @@ class DatabricksImpl(DefaultImpl):
3434DBR_LTE_12_NOT_FOUND_STRING = "Table or view not found"
3535DBR_GT_12_NOT_FOUND_STRING = "TABLE_OR_VIEW_NOT_FOUND"
3636
37+ COLUMN_TYPE_MAP = {
38+ "boolean" : sqlalchemy .types .Boolean ,
39+ "smallint" : sqlalchemy .types .SmallInteger ,
40+ "int" : sqlalchemy .types .Integer ,
41+ "bigint" : sqlalchemy .types .BigInteger ,
42+ "float" : sqlalchemy .types .Float ,
43+ "double" : sqlalchemy .types .Float ,
44+ "string" : sqlalchemy .types .String ,
45+ "varchar" : sqlalchemy .types .String ,
46+ "char" : sqlalchemy .types .String ,
47+ "binary" : sqlalchemy .types .String ,
48+ "array" : sqlalchemy .types .String ,
49+ "map" : sqlalchemy .types .String ,
50+ "struct" : sqlalchemy .types .String ,
51+ "uniontype" : sqlalchemy .types .String ,
52+ "decimal" : sqlalchemy .types .Numeric ,
53+ "timestamp" : sqlalchemy .types .DateTime ,
54+ "date" : sqlalchemy .types .Date ,
55+ }
56+
3757class DatabricksDialect (default .DefaultDialect ):
3858 """This dialect implements only those methods required to pass our e2e tests"""
3959
@@ -111,26 +131,6 @@ def get_columns(self, connection, table_name, schema=None, **kwargs):
111131 Additional column attributes may be present.
112132 """
113133
114- _type_map = {
115- "boolean" : sqlalchemy .types .Boolean ,
116- "smallint" : sqlalchemy .types .SmallInteger ,
117- "int" : sqlalchemy .types .Integer ,
118- "bigint" : sqlalchemy .types .BigInteger ,
119- "float" : sqlalchemy .types .Float ,
120- "double" : sqlalchemy .types .Float ,
121- "string" : sqlalchemy .types .String ,
122- "varchar" : sqlalchemy .types .String ,
123- "char" : sqlalchemy .types .String ,
124- "binary" : sqlalchemy .types .String ,
125- "array" : sqlalchemy .types .String ,
126- "map" : sqlalchemy .types .String ,
127- "struct" : sqlalchemy .types .String ,
128- "uniontype" : sqlalchemy .types .String ,
129- "decimal" : sqlalchemy .types .Numeric ,
130- "timestamp" : sqlalchemy .types .DateTime ,
131- "date" : sqlalchemy .types .Date ,
132- }
133-
134134 with self .get_connection_cursor (connection ) as cur :
135135 resp = cur .columns (
136136 catalog_name = self .catalog ,
@@ -147,7 +147,7 @@ def get_columns(self, connection, table_name, schema=None, **kwargs):
147147 _col_type = re .search (r"^\w+" , col .TYPE_NAME ).group (0 )
148148 this_column = {
149149 "name" : col .COLUMN_NAME ,
150- "type" : _type_map [_col_type .lower ()],
150+ "type" : COLUMN_TYPE_MAP [_col_type .lower ()],
151151 "nullable" : bool (col .NULLABLE ),
152152 "default" : col .COLUMN_DEF ,
153153 "autoincrement" : False if col .IS_AUTO_INCREMENT == "NO" else True ,
0 commit comments