Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions src/databricks/sqlalchemy/_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,51 @@ def get_column_specification(self, column, **kwargs):


class DatabricksStatementCompiler(compiler.SQLCompiler):
# Names that a bare Databricks named-parameter marker (`:name`) accepts:
# a letter or underscore followed by letters, digits, or underscores.
# Anything outside that set — hyphens, spaces, dots, brackets, a leading
# digit, etc. — must be wrapped in backticks (`:`name``), which the
# Spark/Databricks SQL grammar accepts as a quoted parameter identifier.
_bindname_is_bare_identifier = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")

def bindparam_string(self, name, **kw):
"""Render a bind parameter marker.

Databricks named parameter markers only accept bare identifiers
([A-Za-z_][A-Za-z0-9_]*) out of the box. DataFrame-origin column
names frequently contain hyphens (e.g. ``col-with-hyphen``), which
SQLAlchemy would otherwise pass through verbatim and produce an
invalid marker ``:col-with-hyphen`` — the parser splits on ``-``
and reports UNBOUND_SQL_PARAMETER.

The Spark SQL grammar accepts a quoted form ``:`col-with-hyphen```,
mirroring Oracle's ``:"name"`` pattern. The backticks are *quoting*
only: the parameter's logical name is still the text between them,
so the params dict sent to the driver must keep the original
unquoted key. We therefore emit the backticked marker directly
without populating ``escaped_bind_names`` — leaving the key
translation in ``construct_params`` a no-op.

For bare identifiers (the common case), we fall through to the
default implementation so INSERT/SELECT output stays unchanged.
"""
if (
not kw.get("escaped_from")
and not kw.get("post_compile", False)
and not self._bindname_is_bare_identifier.match(name)
):
accumulate = kw.get("accumulate_bind_names")
if accumulate is not None:
accumulate.add(name)
visited = kw.get("visited_bindparam")
if visited is not None:
visited.append(name)
quoted = f"`{name}`"
if self.state is compiler.CompilerState.COMPILING:
return self.compilation_bindtemplate % {"name": quoted}
return self.bindtemplate % {"name": quoted}
return super().bindparam_string(name, **kw)

def limit_clause(self, select, **kw):
"""Identical to the default implementation of SQLCompiler.limit_clause except it writes LIMIT ALL instead of LIMIT -1,
since Databricks SQL doesn't support the latter.
Expand Down
116 changes: 115 additions & 1 deletion tests/test_local/test_ddl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from sqlalchemy import Column, MetaData, String, Table, Numeric, Integer, create_engine
from sqlalchemy import Column, MetaData, String, Table, Numeric, Integer, create_engine, insert
from sqlalchemy.schema import (
CreateTable,
DropColumnComment,
Expand Down Expand Up @@ -114,3 +114,117 @@ def test_create_table_with_complex_type(self, metadata):
assert "array_array_string ARRAY<ARRAY<STRING>>" in output
assert "map_string_string MAP<STRING,STRING>" in output
assert "variant_col VARIANT" in output


class TestBindParamQuoting(DDLTestBase):
    """Exercise the ``bindparam_string`` override for awkward column names.

    A bare Databricks named-parameter marker (``:name``) only tolerates
    ``[A-Za-z_][A-Za-z0-9_]*``. Without the compiler override, a column such
    as ``col-with-hyphen`` compiles to ``VALUES (:col-with-hyphen)``, which
    the server rejects with UNBOUND_SQL_PARAMETER. The override instead
    emits the backtick-quoted form (``VALUES (:`col-with-hyphen`)``), which
    the Databricks SQL grammar accepts as a quoted parameter identifier.
    """

    def _make_table(self, *column_names):
        # Fresh metadata per call keeps the tests independent of each other.
        return Table(
            "t",
            MetaData(),
            *[Column(cname, String()) for cname in column_names],
        )

    def _compile_insert(self, table, values):
        return insert(table).values(values).compile(bind=self.engine)

    def test_hyphenated_column_renders_backticked_bind_marker(self):
        tbl = self._make_table("col-with-hyphen", "normal_col")
        compiled = self._compile_insert(
            tbl, {"col-with-hyphen": "x", "normal_col": "y"}
        )

        rendered = str(compiled)
        # Hyphenated name is wrapped in backticks at the marker site
        assert ":`col-with-hyphen`" in rendered
        # Plain name is untouched
        assert ":normal_col" in rendered

        # The params dict sent to the driver keeps the ORIGINAL unquoted key
        # — this matches what the Databricks server expects (verified
        # empirically: a backticked marker `:`name`` binds against a plain
        # `name` key in the params dict).
        bound = compiled.construct_params()
        assert bound["col-with-hyphen"] == "x"
        assert bound["normal_col"] == "y"
        assert "`col-with-hyphen`" not in bound

    def test_hyphen_and_underscore_columns_do_not_collide(self):
        """A table containing both ``col-name`` and ``col_name`` must produce
        two distinct bind parameters with two distinct dict keys; otherwise
        one value would silently clobber the other.
        """
        tbl = self._make_table("col-name", "col_name")
        compiled = self._compile_insert(
            tbl, {"col-name": "hyphen_value", "col_name": "underscore_value"}
        )

        rendered = str(compiled)
        assert ":`col-name`" in rendered
        assert ":col_name" in rendered

        bound = compiled.construct_params()
        assert bound["col-name"] == "hyphen_value"
        assert bound["col_name"] == "underscore_value"

    def test_plain_identifier_bind_names_are_unchanged(self):
        """No regression: ordinary column names must not be backticked."""
        tbl = self._make_table("id", "name")
        rendered = str(self._compile_insert(tbl, {"id": "1", "name": "n"}))
        for marker in (":id", ":name"):
            assert marker in rendered
        for marker in (":`id`", ":`name`"):
            assert marker not in rendered

    def test_space_and_dot_in_column_name_also_backticked(self):
        """The bare-identifier check covers all non-[A-Za-z0-9_] characters,
        not just hyphens — spaces, dots, etc. should also be wrapped.
        """
        tbl = self._make_table("col with space", "col.with.dot")
        compiled = self._compile_insert(
            tbl, {"col with space": "s", "col.with.dot": "d"}
        )

        rendered = str(compiled)
        assert ":`col with space`" in rendered
        assert ":`col.with.dot`" in rendered

        bound = compiled.construct_params()
        assert bound["col with space"] == "s"
        assert bound["col.with.dot"] == "d"

    def test_leading_digit_column_is_backticked(self):
        """Databricks bind names cannot start with a digit either."""
        tbl = self._make_table("1col")
        compiled = self._compile_insert(tbl, {"1col": "x"})

        assert ":`1col`" in str(compiled)
        assert compiled.construct_params()["1col"] == "x"
Loading