Skip to content

Commit ad8edbc

Browse files
author
Jesse Whitehouse
committed
Refactor dialect's has_table method to extract _describe_table_extended
This passes the HasTableTest test group in the dialect compliance test suite. Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>
1 parent 48604ef commit ad8edbc

File tree

1 file changed

+88
-27
lines changed

1 file changed

+88
-27
lines changed

src/databricks/sqlalchemy/__init__.py

Lines changed: 88 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import sqlalchemy
55
from sqlalchemy import event, DDL
6-
from sqlalchemy.engine import Engine, default, reflection
6+
from sqlalchemy.engine import Engine, default, reflection, Connection, Row, CursorResult
77
from sqlalchemy.engine.interfaces import (
88
ReflectedForeignKeyConstraint,
99
ReflectedPrimaryKeyConstraint,
@@ -31,9 +31,35 @@
3131
# Subclass of DefaultImpl tagged with the "databricks" dialect name.
# NOTE(review): DefaultImpl's import is not visible here; presumably it is
# Alembic's DefaultImpl, confirm against the file's import block.
class DatabricksImpl(DefaultImpl):
3232
__dialect__ = "databricks"
3333

34+
3435
# Substrings searched for in a database error message to recognise a missing
# table or view. DBR >12.x uses underscores in error messages, so both
# spellings are kept.
DBR_LTE_12_NOT_FOUND_STRING = "Table or view not found"
3536
DBR_GT_12_NOT_FOUND_STRING = "TABLE_OR_VIEW_NOT_FOUND"
3637

38+
39+
def _match_table_not_found_string(message: str) -> bool:
    """Tell whether an error message reports a missing table or view.

    The message is checked against both known spellings of the error:
    DBR_LTE_12_NOT_FOUND_STRING and DBR_GT_12_NOT_FOUND_STRING.

    Returns True if either substring occurs in ``message``, otherwise False.
    """
    if DBR_LTE_12_NOT_FOUND_STRING in message:
        return True
    return DBR_GT_12_NOT_FOUND_STRING in message
47+
48+
49+
def _describe_table_extended_result_to_dict(result: CursorResult) -> dict:
    """Collapse the rows of DESCRIBE TABLE EXTENDED into a ``{col_name: data_type}`` dict.

    DESCRIBE TABLE EXTENDED reports every value, including CONSTRAINT
    descriptions, in the ``data_type`` column, keyed by ``col_name``.
    Rows whose ``col_name`` is an empty string exist only for spacing and
    are dropped. When a ``col_name`` appears more than once, the last row wins.
    """
    described = {}
    for row in result:
        label = row.col_name
        if label == "":
            # Spacer row: carries no information.
            continue
        described[label] = row.data_type
    return described
61+
62+
3763
COLUMN_TYPE_MAP = {
3864
"boolean": sqlalchemy.types.Boolean,
3965
"smallint": sqlalchemy.types.SmallInteger,
@@ -54,6 +80,7 @@ class DatabricksImpl(DefaultImpl):
5480
"date": sqlalchemy.types.Date,
5581
}
5682

83+
5784
class DatabricksDialect(default.DefaultDialect):
5885
"""This dialect implements only those methods required to pass our e2e tests"""
5986

@@ -156,6 +183,46 @@ def get_columns(self, connection, table_name, schema=None, **kwargs):
156183

157184
return columns
158185

186+
def _describe_table_extended(
    self,
    connection: Connection,
    table_name: str,
    catalog_name: Optional[str] = None,
    schema_name: Optional[str] = None,
    expect_result=True,
):
    """Run DESCRIBE TABLE EXTENDED on a table and return a dictionary of the result.

    This method is the fastest way to check for the presence of a table in a schema.

    Args:
        connection: Connection the statement is executed on.
        table_name: Name of the table (or view) to describe.
        catalog_name: Catalog holding the table; falls back to ``self.catalog``.
        schema_name: Schema holding the table; falls back to ``self.schema``.
        expect_result: If False, this method returns None as the output dict
            isn't required (e.g. for a pure existence check).

    Returns:
        The dict built by ``_describe_table_extended_result_to_dict`` from the
        fetched rows, or None when ``expect_result`` is False.

    Raises:
        sqlalchemy.exc.NoSuchTableError: if the database error message matches
            one of the known "table not found" strings.
        DatabaseError: any other database error is re-raised unchanged.
    """

    _target_catalog = catalog_name or self.catalog
    _target_schema = schema_name or self.schema

    # DESCRIBE TABLE EXTENDED in DBR doesn't support parameterised inputs, so
    # the name has to be interpolated into the statement text. Double every
    # backtick inside a name part so it cannot close the quoted identifier
    # early and smuggle extra SQL into the statement.
    _target = ".".join(
        "`{}`".format(str(part).replace("`", "``"))
        for part in (_target_catalog, _target_schema, table_name)
    )
    stmt = DDL(f"DESCRIBE TABLE EXTENDED {_target}")

    try:
        result = connection.execute(stmt).all()
    except DatabaseError as e:
        if _match_table_not_found_string(str(e)):
            raise sqlalchemy.exc.NoSuchTableError(
                f"No such table {table_name}"
            ) from e
        # Not a "table not found" error: propagate it untouched.
        raise

    if not expect_result:
        return None

    return _describe_table_extended_result_to_dict(result)
225+
159226
@reflection.cache
160227
def get_pk_constraint(
161228
self,
@@ -169,16 +236,18 @@ def get_pk_constraint(
169236
"""
170237

171238
with self.get_connection_cursor(connection) as cursor:
172-
173239
try:
174240
# DESCRIBE TABLE EXTENDED doesn't support parameterised inputs :(
175-
result = cursor.execute(f"DESCRIBE TABLE EXTENDED {table_name}").fetchall()
241+
result = cursor.execute(
242+
f"DESCRIBE TABLE EXTENDED {table_name}"
243+
).fetchall()
176244
except ServerOperationError as e:
177245
if DBR_GT_12_NOT_FOUND_STRING in str(
178246
e
179247
) or DBR_LTE_12_NOT_FOUND_STRING in str(e):
180-
raise sqlalchemy.exc.NoSuchTableError(f"No such table {table_name}") from e
181-
248+
raise sqlalchemy.exc.NoSuchTableError(
249+
f"No such table {table_name}"
250+
) from e
182251

183252
# DESCRIBE TABLE EXTENDED doesn't give a deterministic name to the field where
184253
# a primary key constraint will be found in its output. So we cycle through its
@@ -237,7 +306,9 @@ def get_foreign_keys(
237306
if DBR_GT_12_NOT_FOUND_STRING in str(
238307
e
239308
) or DBR_LTE_12_NOT_FOUND_STRING in str(e):
240-
raise sqlalchemy.exc.NoSuchTableError(f"No such table {table_name}") from e
309+
raise sqlalchemy.exc.NoSuchTableError(
310+
f"No such table {table_name}"
311+
) from e
241312

242313
# DESCRIBE TABLE EXTENDED doesn't give a deterministic name to the field where
243314
# a foreign key constraint will be found in its output. So we cycle through its
@@ -333,29 +404,20 @@ def do_rollback(self, dbapi_connection):
333404
def has_table(
    self, connection, table_name, schema=None, catalog=None, **kwargs
) -> bool:
    """For internal dialect use, check the existence of a particular table
    or view in the database.

    Args:
        connection: Connection used to run the existence check.
        table_name: Name of the table or view to look for.
        schema: Schema to look in; passed through as ``schema_name``.
        catalog: Catalog to look in; passed through as ``catalog_name``.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        True if ``_describe_table_extended`` completes for the target,
        False if it raises ``sqlalchemy.exc.NoSuchTableError``. Any other
        exception propagates to the caller.
    """

    try:
        self._describe_table_extended(
            connection=connection,
            table_name=table_name,
            catalog_name=catalog,
            schema_name=schema,
            # Only existence matters here, so skip building the result dict.
            expect_result=False,
        )
    except sqlalchemy.exc.NoSuchTableError:
        return False
    return True
359421

360422
def get_connection_cursor(self, connection):
361423
"""Added for backwards compatibility with 1.3.x"""
@@ -372,8 +434,7 @@ def get_connection_cursor(self, connection):
372434

373435
@reflection.cache
374436
def get_schema_names(self, connection, **kw):
375-
"""Return a list of all schema names available in the database.
376-
"""
437+
"""Return a list of all schema names available in the database."""
377438
stmt = DDL("SHOW SCHEMAS")
378439
result = connection.execute(stmt)
379440
schema_list = [row[0] for row in result]

0 commit comments

Comments
 (0)