Skip to content

Commit e15542f

Browse files
author
Jesse Whitehouse
committed
Refactor primary and foreign key parsing
This now PASSES test_get_primary_keys and test_get_pk_constraint Needs further refactor into actual parse methods... Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com>
1 parent ad8edbc commit e15542f

File tree

3 files changed

+165
-114
lines changed

3 files changed

+165
-114
lines changed

src/databricks/sqlalchemy/__init__.py

Lines changed: 147 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import re
2-
from typing import Any, Optional
2+
from typing import Any, Optional, List
33

44
import sqlalchemy
55
from sqlalchemy import event, DDL
@@ -19,6 +19,7 @@
1919
from databricks.sqlalchemy.utils import (
2020
extract_identifier_groups_from_string,
2121
extract_identifiers_from_string,
22+
extract_three_level_identifier_from_constraint_string
2223
)
2324

2425
try:
@@ -32,6 +33,10 @@ class DatabricksImpl(DefaultImpl):
3233
__dialect__ = "databricks"
3334

3435

36+
import logging
37+
38+
logger = logging.getLogger(__name__)
39+
3540
DBR_LTE_12_NOT_FOUND_STRING = "Table or view not found"
3641
DBR_GT_12_NOT_FOUND_STRING = "TABLE_OR_VIEW_NOT_FOUND"
3742

@@ -60,6 +65,133 @@ def _describe_table_extended_result_to_dict(result: CursorResult) -> dict:
6065
return result_dict
6166

6267

68+
def _extract_pk_from_dte_result(result: dict) -> ReflectedPrimaryKeyConstraint:
69+
"""Return a dictionary with the keys:
70+
71+
constrained_columns
72+
a list of column names that make up the primary key. Results is an empty list
73+
if no PRIMARY KEY is defined.
74+
75+
name
76+
the name of the primary key constraint
77+
78+
Today, DESCRIBE TABLE EXTENDED doesn't give a deterministic name to the field where
79+
a primary key constraint will be found in its output. So we cycle through its
80+
output looking for a match that includes "PRIMARY KEY". This is brittle. We
81+
could optionally make two roundtrips: the first would query information_schema
82+
for the name of the primary key constraint on this table, and a second to
83+
DESCRIBE TABLE EXTENDED, at which point we would know the name of the constraint.
84+
But for now we instead assume that Python list comprehension is faster than a
85+
network roundtrip.
86+
"""
87+
88+
# find any rows that contain "PRIMARY KEY" as the `data_type`
89+
filtered_rows = [(k, v) for k, v in result.items() if "PRIMARY KEY" in v]
90+
91+
# bail if no primary key was found
92+
if not filtered_rows:
93+
return {"constrained_columns": [], "name": None}
94+
95+
# there should only ever be one PRIMARY KEY that matches
96+
if len(filtered_rows) > 1:
97+
logger.warning(
98+
"Found more than one primary key constraint in DESCRIBE TABLE EXTENDED output. "
99+
"This is unexpected. Please report this as a bug. "
100+
"Only the first primary key constraint will be returned."
101+
)
102+
103+
# target is a tuple of (constraint_name, constraint_string)
104+
target = filtered_rows[0]
105+
name = target[0]
106+
_constraint_string = target[1]
107+
column_list = extract_identifiers_from_string(_constraint_string)
108+
109+
return {"constrained_columns": column_list, "name": name}
110+
111+
112+
def _extract_single_fk_dict_from_dte_result_row(
113+
table_name: str, schema_name: Optional[str], fk_name: str, fk_constraint_string: str
114+
) -> dict:
115+
"""
116+
"""
117+
118+
# SQLAlchemy's ComponentReflectionTest::test_get_foreign_keys is strange in that it
119+
# expects the `referred_schema` member of the outputted dictionary to be None if
120+
# a `schema` argument was not passed to the dialect's `get_foreign_keys` method
121+
referred_table_dict = extract_three_level_identifier_from_constraint_string(fk_constraint_string)
122+
referred_table = referred_table_dict["table"]
123+
if schema_name:
124+
referred_schema = referred_table_dict["schema"]
125+
else:
126+
referred_schema = None
127+
128+
_extracted = extract_identifier_groups_from_string(fk_constraint_string)
129+
constrained_columns_str, referred_columns_str = (
130+
_extracted[0],
131+
_extracted[1],
132+
)
133+
134+
constrainted_columns = extract_identifiers_from_string(constrained_columns_str)
135+
referred_columns = extract_identifiers_from_string(referred_columns_str)
136+
137+
return {
138+
"constrained_columns": constrainted_columns,
139+
"name": fk_name,
140+
"referred_table": referred_table,
141+
"referred_columns": referred_columns,
142+
"referred_schema": referred_schema,
143+
}
144+
145+
146+
def _extract_fk_from_dte_result(
147+
table_name: str, schema_name: Optional[str], result: dict
148+
) -> ReflectedForeignKeyConstraint:
149+
"""Return a list of dictionaries with the keys:
150+
151+
constrained_columns
152+
a list of column names that make up the foreign key
153+
154+
name
155+
the name of the foreign key constraint
156+
157+
referred_table
158+
the name of the table that the foreign key references
159+
160+
referred_columns
161+
a list of column names that are referenced by the foreign key
162+
163+
Returns an empty list if no foreign key is defined.
164+
165+
Today, DESCRIBE TABLE EXTENDED doesn't give a deterministic name to the field where
166+
a foreign key constraint will be found in its output. So we cycle through its
167+
output looking for a match that includes "FOREIGN KEY". This is brittle. We
168+
could optionally make two roundtrips: the first would query information_schema
169+
for the name of the foreign key constraint on this table, and a second to
170+
DESCRIBE TABLE EXTENDED, at which point we would know the name of the constraint.
171+
But for now we instead assume that Python list comprehension is faster than a
172+
network roundtrip.
173+
"""
174+
175+
# find any rows that contain "FOREIGN_KEY" as the `data_type`
176+
filtered_rows = [(k, v) for k, v in result.items() if "FOREIGN KEY" in v]
177+
178+
# bail if no foreign key was found
179+
if not filtered_rows:
180+
return []
181+
182+
constraint_list = []
183+
184+
# target is a tuple of (constraint_name, constraint_string)
185+
for target in filtered_rows:
186+
_constraint_name, _constraint_string = target
187+
this_constraint_dict = _extract_single_fk_dict_from_dte_result_row(
188+
table_name, schema_name, _constraint_name, _constraint_string
189+
)
190+
constraint_list.append(this_constraint_dict)
191+
192+
return constraint_list
193+
194+
63195
COLUMN_TYPE_MAP = {
64196
"boolean": sqlalchemy.types.Boolean,
65197
"smallint": sqlalchemy.types.SmallInteger,
@@ -235,125 +367,26 @@ def get_pk_constraint(
235367
table_name`.
236368
"""
237369

238-
with self.get_connection_cursor(connection) as cursor:
239-
try:
240-
# DESCRIBE TABLE EXTENDED doesn't support parameterised inputs :(
241-
result = cursor.execute(
242-
f"DESCRIBE TABLE EXTENDED {table_name}"
243-
).fetchall()
244-
except ServerOperationError as e:
245-
if DBR_GT_12_NOT_FOUND_STRING in str(
246-
e
247-
) or DBR_LTE_12_NOT_FOUND_STRING in str(e):
248-
raise sqlalchemy.exc.NoSuchTableError(
249-
f"No such table {table_name}"
250-
) from e
251-
252-
# DESCRIBE TABLE EXTENDED doesn't give a deterministic name to the field where
253-
# a primary key constraint will be found in its output. So we cycle through its
254-
# output looking for a match that includes "PRIMARY KEY". This is brittle. We
255-
# could optionally make two roundtrips: the first would query information_schema
256-
# for the name of the primary key constraint on this table, and a second to
257-
# DESCRIBE TABLE EXTENDED, at which point we would know the name of the constraint.
258-
# But for now we instead assume that Python list comprehension is faster than a
259-
# network roundtrip.
260-
dte_dict = {row["col_name"]: row["data_type"] for row in result}
261-
target = [(k, v) for k, v in dte_dict.items() if "PRIMARY KEY" in v]
262-
if target:
263-
name, _constraint_string = target[0]
264-
column_list = extract_identifiers_from_string(_constraint_string)
265-
else:
266-
name, column_list = None, None
267-
268-
return {"constrained_columns": column_list, "name": name}
370+
result = self._describe_table_extended(
371+
connection=connection,
372+
table_name=table_name,
373+
schema_name=schema,
374+
)
375+
376+
return _extract_pk_from_dte_result(result)
269377

270378
def get_foreign_keys(
271379
self, connection, table_name, schema=None, **kw
272380
) -> ReflectedForeignKeyConstraint:
273-
"""Return information about foreign_keys in `table_name`.
274-
275-
Given a :class:`_engine.Connection`, a string
276-
`table_name`, and an optional string `schema`, return foreign
277-
key information as a list of dicts with these keys:
278-
279-
name
280-
the constraint's name
381+
"""Return information about foreign_keys in `table_name`."""
281382

282-
constrained_columns
283-
a list of column names that make up the foreign key
284-
285-
referred_schema
286-
the name of the referred schema
287-
288-
referred_table
289-
the name of the referred table
290-
291-
referred_columns
292-
a list of column names in the referred table that correspond to
293-
constrained_columns
294-
"""
295-
"""Return information about the primary key constraint on
296-
table_name`.
297-
"""
298-
299-
try:
300-
with self.get_connection_cursor(connection) as cursor:
301-
# DESCRIBE TABLE EXTENDED doesn't support parameterised inputs :(
302-
result = cursor.execute(
303-
f"DESCRIBE TABLE EXTENDED {schema + '.' if schema else ''}{table_name}"
304-
).fetchall()
305-
except ServerOperationError as e:
306-
if DBR_GT_12_NOT_FOUND_STRING in str(
307-
e
308-
) or DBR_LTE_12_NOT_FOUND_STRING in str(e):
309-
raise sqlalchemy.exc.NoSuchTableError(
310-
f"No such table {table_name}"
311-
) from e
312-
313-
# DESCRIBE TABLE EXTENDED doesn't give a deterministic name to the field where
314-
# a foreign key constraint will be found in its output. So we cycle through its
315-
# output looking for a match that includes "FOREIGN KEY". This is brittle. We
316-
# could optionally make two roundtrips: the first would query information_schema
317-
# for the name of the foreign key constraint on this table, and a second to
318-
# DESCRIBE TABLE EXTENDED, at which point we would know the name of the constraint.
319-
# But for now we instead assume that Python list comprehension is faster than a
320-
# network roundtrip.
321-
dte_dict = {row["col_name"]: row["data_type"] for row in result}
322-
target = [(k, v) for k, v in dte_dict.items() if "FOREIGN KEY" in v]
323-
324-
def extract_constraint_dict_from_target(target):
325-
if target:
326-
name, _constraint_string = target
327-
_extracted = extract_identifier_groups_from_string(_constraint_string)
328-
constrained_columns_str, referred_columns_str = (
329-
_extracted[0],
330-
_extracted[1],
331-
)
332-
333-
constrained_columns = extract_identifiers_from_string(
334-
constrained_columns_str
335-
)
336-
referred_columns = extract_identifiers_from_string(referred_columns_str)
337-
referred_table = str(table_name)
338-
else:
339-
name, constrained_columns, referred_columns, referred_table = (
340-
None,
341-
None,
342-
None,
343-
None,
344-
)
345-
346-
return {
347-
"constrained_columns": constrained_columns,
348-
"name": name,
349-
"referred_table": referred_table,
350-
"referred_columns": referred_columns,
351-
}
383+
result = self._describe_table_extended(
384+
connection=connection,
385+
table_name=table_name,
386+
schema_name=schema,
387+
)
352388

353-
if target:
354-
return [extract_constraint_dict_from_target(i) for i in target]
355-
else:
356-
return []
389+
return _extract_fk_from_dte_result(table_name, schema, result)
357390

358391
def get_indexes(self, connection, table_name, schema=None, **kw):
359392
"""Return information about indexes in `table_name`.

src/databricks/sqlalchemy/requirements.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,9 @@ def unique_constraint_reflection(self):
176176
Databricks supports unique constraints but they are not implemented in this dialect.
177177
"""
178178
return sqlalchemy.testing.exclusions.closed()
179+
180+
@property
181+
def reflects_pk_names(self):
182+
"""Target driver reflects the name of primary key constraints."""
183+
184+
return sqlalchemy.testing.exclusions.open()

src/databricks/sqlalchemy/test_local/test_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from databricks.sqlalchemy.utils import (
33
extract_identifiers_from_string,
44
extract_identifier_groups_from_string,
5+
extract_three_level_identifier_from_constraint_string
56
)
67

78

@@ -36,3 +37,14 @@ def test_extract_identifer_batches(input, expected):
3637
assert (
3738
extract_identifier_groups_from_string(input) == expected
3839
), "Failed to extract identifier groups from string"
40+
41+
def test_extract_3l_namespace_from_constraint_string():
42+
43+
input = "FOREIGN KEY (`parent_user_id`) REFERENCES `main`.`pysql_dialect_compliance`.`users` (`user_id`)"
44+
expected = {
45+
"catalog": "main",
46+
"schema": "pysql_dialect_compliance",
47+
"table": "users"
48+
}
49+
50+
assert extract_three_level_identifier_from_constraint_string(input) == expected, "Failed to extract 3L namespace from constraint string"

0 commit comments

Comments
 (0)