Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/databricks/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,16 @@ def named_parameters_to_dbsqlparams_v2(parameters: List[Any]):
return dbsqlparams


def resolve_databricks_sql_integer_type(integer):
    """Return the narrowest Databricks SQL integer type able to hold *integer*.

    Values within the signed 8-bit range map to TINYINT, values within the
    signed 32-bit range map to INTEGER, and anything larger (or smaller)
    falls back to BIGINT. No range check is performed beyond 64 bits.
    """
    # Signed 8-bit range: [-128, 127]
    if -(2**7) <= integer < 2**7:
        return DbSqlType.TINYINT
    # Signed 32-bit range: [-2147483648, 2147483647]
    if -(2**31) <= integer < 2**31:
        return DbSqlType.INTEGER
    return DbSqlType.BIGINT


def infer_types(params: list[DbSqlParameter]):
type_lookup_table = {
str: DbSqlType.STRING,
Expand Down Expand Up @@ -568,6 +578,10 @@ def infer_types(params: list[DbSqlParameter]):
cast_exp = calculate_decimal_cast_string(param.value)
_type = DbsqlDynamicDecimalType(cast_exp)

# int() requires special handling because one Python type can be cast to multiple SQL types (INT, BIGINT, TINYINT)
if _type == DbSqlType.INTEGER:
_type = resolve_databricks_sql_integer_type(param.value)

# VOID / NULL types must be passed in a unique way as TSparkParameters with no value
if _type == DbSqlType.VOID:
new_params.append(DbSqlParameter(name=_name, type=DbSqlType.VOID))
Expand Down
63 changes: 42 additions & 21 deletions tests/unit/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,36 @@


class TestTSparkParameterConversion(object):
def test_conversion_e2e(self):
@pytest.mark.parametrize(
"input_value, expected_type",
[
("a", "STRING"),
(1, "TINYINT"),
(1000, "INTEGER"),
(9223372036854775807, "BIGINT"), # Max value of a signed 64-bit integer
(True, "BOOLEAN"),
(1.0, "FLOAT"),
],
)
def test_conversion_e2e(self, input_value, expected_type):
"""This behaviour falls back to Python's default string formatting of numbers"""
assert named_parameters_to_tsparkparams(
["a", 1, True, 1.0, DbSqlParameter(value="1.0", type=DbSqlType.DECIMAL)]
) == [
TSparkParameter(
name="", type="STRING", value=TSparkParameterValue(stringValue="a")
),
TSparkParameter(
name="", type="INTEGER", value=TSparkParameterValue(stringValue="1")
),
TSparkParameter(
name="", type="BOOLEAN", value=TSparkParameterValue(stringValue="True")
),
TSparkParameter(
name="", type="FLOAT", value=TSparkParameterValue(stringValue="1.0")
),
output = named_parameters_to_tsparkparams([input_value])
expected = TSparkParameter(
name="",
type=expected_type,
value=TSparkParameterValue(stringValue=str(input_value)),
)
assert output == [expected]

def test_conversion_e2e_decimal(self):
input = DbSqlParameter(value="1.0", type=DbSqlType.DECIMAL)
output = named_parameters_to_tsparkparams([input])
assert output == [
TSparkParameter(
name="",
type="DECIMAL(2,1)",
value=TSparkParameterValue(stringValue="1.0"),
),
)
]

def test_basic_conversions_v1(self):
Expand Down Expand Up @@ -69,10 +77,24 @@ def test_infer_types_dict(self):
with pytest.raises(ValueError):
infer_types([DbSqlParameter("", {1: 1})])

def test_infer_types_integer(self):
input = DbSqlParameter("", 1)
@pytest.mark.parametrize(
"input_value, expected_type",
[
(-128, DbSqlType.TINYINT),
(127, DbSqlType.TINYINT),
(-2147483649, DbSqlType.BIGINT),
(-2147483648, DbSqlType.INTEGER),
(2147483647, DbSqlType.INTEGER),
(-9223372036854775808, DbSqlType.BIGINT),
(9223372036854775807, DbSqlType.BIGINT),
],
)
def test_infer_types_integer(self, input_value, expected_type):
input = DbSqlParameter("", input_value)
output = infer_types([input])
assert output == [DbSqlParameter("", "1", DbSqlType.INTEGER)]
assert output == [
DbSqlParameter("", str(input_value), expected_type)
], f"{output[0].type} received, expected {expected_type}"

def test_infer_types_boolean(self):
input = DbSqlParameter("", True)
Expand Down Expand Up @@ -101,7 +123,6 @@ def test_infer_types_decimal(self):
assert x.type.value == "DECIMAL(2,1)"

def test_infer_types_none(self):

input = DbSqlParameter("", None)
output: List[DbSqlParameter] = infer_types([input])

Expand Down