Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/python-api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1505,6 +1505,8 @@ To keep the original table around instead of dropping it, pass the ``keep_table=

table.transform(types={"age": int}, keep_table="original_table")

``CHECK`` constraints are preserved across a transform. If a column referenced by a constraint is renamed the constraint is updated to match, and if such a column is dropped the constraint is dropped along with it.

This method raises a ``sqlite_utils.db.TransformError`` exception if the table cannot be transformed, usually because there are existing constraints or indexes that are incompatible with modifications to the columns.

.. _python_api_transform_alter_column_types:
Expand Down
162 changes: 161 additions & 1 deletion sqlite_utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,148 @@ class TransformError(Exception):
pass


def _tokenize_sql(sql: str) -> List[Tuple[str, str]]:
# Split SQL into (kind, text) tokens. Enough to walk a CREATE TABLE
# statement while respecting string literals, quoted identifiers, comments
# and nesting - not a full parser.
tokens = []
i, n = 0, len(sql)
while i < n:
c = sql[i]
if c in " \t\r\n":
j = i + 1
while j < n and sql[j] in " \t\r\n":
j += 1
tokens.append(("ws", sql[i:j]))
elif sql[i : i + 2] == "--":
j = sql.find("\n", i)
j = n if j == -1 else j
tokens.append(("comment", sql[i:j]))
elif sql[i : i + 2] == "/*":
j = sql.find("*/", i + 2)
j = n if j == -1 else j + 2
tokens.append(("comment", sql[i:j]))
elif c in "'\"`":
j = i + 1
while j < n:
if sql[j] == c:
if sql[j : j + 2] == c + c:
j += 2
continue
j += 1
break
j += 1
tokens.append(("string" if c == "'" else "quoted", sql[i:j]))
elif c == "[":
j = sql.find("]", i + 1)
j = n if j == -1 else j + 1
tokens.append(("quoted", sql[i:j]))
elif c.isalnum() or c in "_$":
j = i + 1
while j < n and (sql[j].isalnum() or sql[j] in "_$"):
j += 1
tokens.append(("word", sql[i:j]))
else:
j = i + 1
tokens.append(("punct", c))
i = j
return tokens


def _capture_paren_inner(
tokens: List[Tuple[str, str]], open_index: int
) -> Tuple[str, int]:
# tokens[open_index] is the opening "(" - return the text inside the
# matching parentheses and the index just past the closing ")".
depth = 0
parts = []
i, n = open_index, len(tokens)
while i < n:
kind, text = tokens[i]
if kind == "punct" and text == "(":
depth += 1
if depth > 1:
parts.append(text)
elif kind == "punct" and text == ")":
depth -= 1
if depth == 0:
return "".join(parts), i + 1
parts.append(text)
else:
parts.append(text)
i += 1
return "".join(parts), i


def _extract_check_constraints(create_table_sql: str) -> List[str]:
# CHECK constraints (column-level and table-level) live only in the stored
# CREATE TABLE SQL, not in any PRAGMA. Return the expression inside each one.
# Every CHECK keyword sits at the top level of the table body regardless of
# whether it is attached to a column or the table.
tokens = _tokenize_sql(create_table_sql)
checks = []
depth = 0
started = False
i, n = 0, len(tokens)
while i < n:
kind, text = tokens[i]
if kind == "punct" and text == "(":
depth += 1
started = True
elif kind == "punct" and text == ")":
depth -= 1
if started and depth == 0:
break
elif started and depth == 1 and kind == "word" and text.upper() == "CHECK":
j = i + 1
while j < n and tokens[j][0] in ("ws", "comment"):
j += 1
if j < n and tokens[j] == ("punct", "("):
inner, after = _capture_paren_inner(tokens, j)
checks.append(inner.strip())
i = after
continue
i += 1
return checks


def _unquote_identifier(text: str) -> str:
if len(text) >= 2:
if text[0] == '"' and text[-1] == '"':
return text[1:-1].replace('""', '"')
if text[0] == "`" and text[-1] == "`":
return text[1:-1].replace("``", "`")
if text[0] == "[" and text[-1] == "]":
return text[1:-1]
return text


def _rewrite_check_expression(
expression: str, rename: Dict[str, str], drop: Set[str]
) -> Optional[str]:
# Apply column renames to identifiers referenced by a CHECK expression.
# Returns None if the expression references a dropped column, in which case
# the constraint can no longer be enforced and should be discarded rather
# than producing a table that fails to build.
rename_lower = {k.lower(): v for k, v in rename.items()}
drop_lower = {d.lower() for d in drop}
out = []
for kind, text in _tokenize_sql(expression):
key = None
if kind == "word":
key = text.lower()
elif kind == "quoted":
key = _unquote_identifier(text).lower()
if key is not None:
if key in drop_lower:
return None
if key in rename_lower:
out.append(quote_identifier(rename_lower[key]))
continue
out.append(text)
return "".join(out)


ForeignKeyIndicator = Union[
str,
ForeignKey,
Expand Down Expand Up @@ -977,6 +1119,7 @@ def create_table_sql(
extracts: Optional[Union[Dict[str, str], List[str]]] = None,
if_not_exists: bool = False,
strict: bool = False,
check_constraints: Optional[List[str]] = None,
) -> str:
"""
Returns the SQL ``CREATE TABLE`` statement for creating the specified table.
Expand All @@ -993,6 +1136,7 @@ def create_table_sql(
:param extracts: List or dictionary of columns to be extracted during inserts, see :ref:`python_api_extracts`
:param if_not_exists: Use ``CREATE TABLE IF NOT EXISTS``
:param strict: Apply STRICT mode to table
:param check_constraints: List of ``CHECK`` constraint expressions to add as table-level constraints, for example ``["age >= 0"]``
"""
if hash_id_columns and (hash_id is None):
hash_id = "id"
Expand Down Expand Up @@ -1094,15 +1238,21 @@ def sort_key(p):
extra_pk = ",\n PRIMARY KEY ({pks})".format(
pks=", ".join([quote_identifier(p) for p in pk])
)
extra_checks = ""
if check_constraints:
extra_checks = "".join(
",\n CHECK ({})".format(check) for check in check_constraints
)
columns_sql = ",\n".join(column_defs)
sql = """CREATE TABLE {if_not_exists}{table} (
{columns_sql}{extra_pk}
{columns_sql}{extra_pk}{extra_checks}
){strict};
""".format(
if_not_exists="IF NOT EXISTS " if if_not_exists else "",
table=quote_identifier(name),
columns_sql=columns_sql,
extra_pk=extra_pk,
extra_checks=extra_checks,
strict=" STRICT" if strict and self.supports_strict else "",
)
return sql
Expand Down Expand Up @@ -2136,6 +2286,15 @@ def transform_sql(
if column_order is not None:
column_order = [rename.get(col) or col for col in column_order]

# CHECK constraints are not exposed by any PRAGMA, so pull them out of the
# stored schema and carry them across, applying any renames and dropping
# constraints that reference a removed column.
create_table_checks = []
for check in _extract_check_constraints(self.schema):
rewritten = _rewrite_check_expression(check, rename, set(drop))
if rewritten is not None:
create_table_checks.append(rewritten)

sqls = []
sqls.append(
self.db.create_table_sql(
Expand All @@ -2147,6 +2306,7 @@ def transform_sql(
foreign_keys=create_table_foreign_keys,
column_order=column_order,
strict=self.strict,
check_constraints=create_table_checks or None,
).strip()
)

Expand Down
142 changes: 141 additions & 1 deletion tests/test_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from sqlite_utils.db import ForeignKey, TransformError
from sqlite_utils.db import (
ForeignKey,
TransformError,
_extract_check_constraints,
_rewrite_check_expression,
)
from sqlite_utils.utils import OperationalError
import pytest

Expand Down Expand Up @@ -659,3 +664,138 @@ def test_transform_with_unique_constraint_implicit_index(fresh_db):
"You must manually drop this index prior to running this transformation and manually recreate the new index after running this transformation."
in str(excinfo.value)
)


def test_transform_preserves_check_constraints(fresh_db):
fresh_db.execute(
"CREATE TABLE dogs (\n"
" id integer primary key,\n"
" age integer CHECK (age >= 0),\n"
" name text CHECK (length(name) > 0),\n"
" CHECK (age < 200)\n"
")"
)
dogs = fresh_db["dogs"]
assert dogs.transform_sql(tmp_suffix="suffix")[0] == (
'CREATE TABLE "dogs_new_suffix" (\n'
' "id" INTEGER PRIMARY KEY,\n'
' "age" INTEGER,\n'
' "name" TEXT,\n'
" CHECK (age >= 0),\n"
" CHECK (length(name) > 0),\n"
" CHECK (age < 200)\n"
");"
)
dogs.transform()
# Constraints must still be enforced after the transform
with pytest.raises(Exception):
dogs.insert({"id": 1, "age": -1, "name": "Cleo"})
with pytest.raises(Exception):
dogs.insert({"id": 2, "age": 5, "name": ""})
dogs.insert({"id": 3, "age": 5, "name": "Cleo"})
assert dogs.count == 1


def test_transform_rewrites_check_constraint_on_rename(fresh_db):
fresh_db.execute(
"CREATE TABLE dogs (id integer primary key, age integer CHECK (age >= 0))"
)
dogs = fresh_db["dogs"]
dogs.transform(rename={"age": "years"})
assert 'CHECK ("years" >= 0)' in dogs.schema
# The renamed constraint is still enforced, on the new column name
with pytest.raises(Exception):
dogs.insert({"id": 1, "years": -1})
dogs.insert({"id": 2, "years": 4})
assert dogs.count == 1


def test_transform_rewrites_check_constraint_to_name_needing_quotes(fresh_db):
# Renaming a checked column to a reserved word or a name with a space must
# still produce valid SQL, just like a plain rename does
fresh_db.execute(
"CREATE TABLE dogs (id integer primary key, age integer CHECK (age >= 0))"
)
dogs = fresh_db["dogs"]
dogs.transform(rename={"age": "order by"})
assert 'CHECK ("order by" >= 0)' in dogs.schema
with pytest.raises(Exception):
dogs.insert({"id": 1, "order by": -1})
dogs.insert({"id": 2, "order by": 4})
assert dogs.count == 1


def test_transform_drops_check_constraint_for_dropped_column(fresh_db):
fresh_db.execute(
"CREATE TABLE dogs (id integer primary key, age integer CHECK (age >= 0), "
"name text)"
)
dogs = fresh_db["dogs"]
# Dropping a column referenced by a CHECK drops the constraint rather than
# producing a table that fails to build
dogs.transform(drop=["age"])
assert "CHECK" not in dogs.schema
assert dogs.columns_dict == {"id": int, "name": str}


def test_transform_check_constraints_are_idempotent(fresh_db):
fresh_db.execute(
"CREATE TABLE dogs (id integer primary key, age integer CHECK (age >= 0))"
)
dogs = fresh_db["dogs"]
dogs.transform()
once = dogs.schema
dogs.transform()
assert dogs.schema == once
assert "CHECK (age >= 0)" in once


@pytest.mark.parametrize(
"sql,expected",
[
(
"CREATE TABLE t (id integer, age integer CHECK (age >= 0), "
"CHECK (age < 200))",
["age >= 0", "age < 200"],
),
# Nested parentheses inside the expression
("CREATE TABLE t (a int CHECK (a > 0 AND (b < 10)))", ["a > 0 AND (b < 10)"]),
# A comma and the word CHECK inside a string literal must not confuse it
(
"CREATE TABLE t (name text CHECK (name != 'a,b CHECK ('))",
["name != 'a,b CHECK ('"],
),
# CONSTRAINT name form, and no whitespace before the parenthesis
("CREATE TABLE t (a int CONSTRAINT positive CHECK(a>0))", ["a>0"]),
# A DEFAULT expression in parentheses is not a CHECK
("CREATE TABLE t (a int DEFAULT (1 + 2), b int CHECK (b > 0))", ["b > 0"]),
("CREATE TABLE t (a int, b text)", []),
],
)
def test_extract_check_constraints(sql, expected):
assert _extract_check_constraints(sql) == expected


@pytest.mark.parametrize(
"expression,rename,drop,expected",
[
("age >= 0", {"age": "years"}, set(), '"years" >= 0'),
# A substring match must not be rewritten
(
"agent = 1 AND age > 0",
{"age": "years"},
set(),
'agent = 1 AND "years" > 0',
),
("length(name) > 0", {"name": "full_name"}, set(), 'length("full_name") > 0'),
# A rename to a name needing quoting stays valid
("age >= 0", {"age": "order by"}, set(), '"order by" >= 0'),
# Identifiers inside string literals are left alone
("x != 'age'", {"age": "years"}, set(), "x != 'age'"),
# Referencing a dropped column removes the constraint entirely
("age >= 0", {}, {"age"}, None),
("age >= 0", {}, {"weight"}, "age >= 0"),
],
)
def test_rewrite_check_expression(expression, rename, drop, expected):
assert _rewrite_check_expression(expression, rename, drop) == expected