diff --git a/CLAUDE.md b/CLAUDE.md index 61a3874..041b2c5 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -26,7 +26,7 @@ uv run pytest tests/integration/ -v # Lint uv run ruff check src/ tests/ -# Type check (strict mode; pre-existing lark type errors are expected) +# Type check (strict mode; should report no errors) uv run mypy src/pycel2sql/ ``` diff --git a/src/pycel2sql/_analysis.py b/src/pycel2sql/_analysis.py index 67715dd..81bea96 100644 --- a/src/pycel2sql/_analysis.py +++ b/src/pycel2sql/_analysis.py @@ -18,10 +18,13 @@ PatternType, ) from pycel2sql.dialect._base import IndexAdvisor -from pycel2sql.schema import Schema +from pycel2sql.schema import Schema, field_is_json +# Annotation-only alias for a Lark parse-tree node (see _converter.py). +TreeT = Tree[Token] -class IndexAnalyzer(Interpreter): + +class IndexAnalyzer(Interpreter[Token, None]): """Lightweight Lark Interpreter that walks the parse tree detecting index-worthy patterns. Does NOT generate SQL - only collects patterns. @@ -40,7 +43,7 @@ def __init__( def patterns(self) -> list[IndexPattern]: return list(self._patterns.values()) - def visit(self, tree: Tree) -> Any: + def visit(self, tree: TreeT) -> Any: if isinstance(tree, Token): return None return super().visit(tree) @@ -56,64 +59,64 @@ def _add_pattern(self, pattern: IndexPattern) -> None: # --- Tree walking --- - def _visit_children(self, tree: Tree) -> None: + def _visit_children(self, tree: TreeT) -> None: for child in tree.children: if isinstance(child, Tree): self.visit(child) # Top-level passthrough handlers - def expr(self, tree: Tree) -> None: + def expr(self, tree: TreeT) -> None: self._visit_children(tree) - def conditionalor(self, tree: Tree) -> None: + def conditionalor(self, tree: TreeT) -> None: self._visit_children(tree) - def conditionaland(self, tree: Tree) -> None: + def conditionaland(self, tree: TreeT) -> None: self._visit_children(tree) - def addition(self, tree: Tree) -> None: + def addition(self, tree: TreeT) -> None: 
self._visit_children(tree) - def multiplication(self, tree: Tree) -> None: + def multiplication(self, tree: TreeT) -> None: self._visit_children(tree) - def unary(self, tree: Tree) -> None: + def unary(self, tree: TreeT) -> None: self._visit_children(tree) - def member(self, tree: Tree) -> None: + def member(self, tree: TreeT) -> None: self._visit_children(tree) - def primary(self, tree: Tree) -> None: + def primary(self, tree: TreeT) -> None: self._visit_children(tree) - def paren_expr(self, tree: Tree) -> None: + def paren_expr(self, tree: TreeT) -> None: self._visit_children(tree) - def literal(self, tree: Tree) -> None: + def literal(self, tree: TreeT) -> None: pass - def list_lit(self, tree: Tree) -> None: + def list_lit(self, tree: TreeT) -> None: self._visit_children(tree) - def map_lit(self, tree: Tree) -> None: + def map_lit(self, tree: TreeT) -> None: self._visit_children(tree) - def exprlist(self, tree: Tree) -> None: + def exprlist(self, tree: TreeT) -> None: self._visit_children(tree) - def mapinits(self, tree: Tree) -> None: + def mapinits(self, tree: TreeT) -> None: self._visit_children(tree) - def fieldinits(self, tree: Tree) -> None: + def fieldinits(self, tree: TreeT) -> None: self._visit_children(tree) - def ident(self, tree: Tree) -> None: + def ident(self, tree: TreeT) -> None: pass - def ident_arg(self, tree: Tree) -> None: + def ident_arg(self, tree: TreeT) -> None: func_name = str(tree.children[0]) args_node = tree.children[1] if len(tree.children) > 1 else None - args = args_node.children if args_node is not None else [] + args = args_node.children if isinstance(args_node, Tree) else [] if func_name == "matches" and len(args) >= 1: col = self._extract_column_name(args[0]) @@ -127,43 +130,43 @@ def ident_arg(self, tree: Tree) -> None: self._visit_children(tree) - def dot_ident_arg(self, tree: Tree) -> None: + def dot_ident_arg(self, tree: TreeT) -> None: self._visit_children(tree) - def dot_ident(self, tree: Tree) -> None: + def 
dot_ident(self, tree: TreeT) -> None: pass - def member_index(self, tree: Tree) -> None: + def member_index(self, tree: TreeT) -> None: self._visit_children(tree) - def member_object(self, tree: Tree) -> None: + def member_object(self, tree: TreeT) -> None: pass # Operator prefix handlers - def addition_add(self, tree: Tree) -> None: + def addition_add(self, tree: TreeT) -> None: self._visit_children(tree) - def addition_sub(self, tree: Tree) -> None: + def addition_sub(self, tree: TreeT) -> None: self._visit_children(tree) - def multiplication_mul(self, tree: Tree) -> None: + def multiplication_mul(self, tree: TreeT) -> None: self._visit_children(tree) - def multiplication_div(self, tree: Tree) -> None: + def multiplication_div(self, tree: TreeT) -> None: self._visit_children(tree) - def multiplication_mod(self, tree: Tree) -> None: + def multiplication_mod(self, tree: TreeT) -> None: self._visit_children(tree) - def unary_not(self, tree: Tree) -> None: + def unary_not(self, tree: TreeT) -> None: pass - def unary_neg(self, tree: Tree) -> None: + def unary_neg(self, tree: TreeT) -> None: pass # --- Key detection points --- - def relation(self, tree: Tree) -> None: + def relation(self, tree: TreeT) -> None: """Detect comparison operators -> COMPARISON pattern.""" children = tree.children if len(children) == 2: @@ -197,28 +200,28 @@ def relation(self, tree: Tree) -> None: self._visit_children(tree) # Relation prefix handlers - def relation_eq(self, tree: Tree) -> None: + def relation_eq(self, tree: TreeT) -> None: self._visit_children(tree) - def relation_ne(self, tree: Tree) -> None: + def relation_ne(self, tree: TreeT) -> None: self._visit_children(tree) - def relation_lt(self, tree: Tree) -> None: + def relation_lt(self, tree: TreeT) -> None: self._visit_children(tree) - def relation_le(self, tree: Tree) -> None: + def relation_le(self, tree: TreeT) -> None: self._visit_children(tree) - def relation_gt(self, tree: Tree) -> None: + def relation_gt(self, tree: 
TreeT) -> None: self._visit_children(tree) - def relation_ge(self, tree: Tree) -> None: + def relation_ge(self, tree: TreeT) -> None: self._visit_children(tree) - def relation_in(self, tree: Tree) -> None: + def relation_in(self, tree: TreeT) -> None: self._visit_children(tree) - def member_dot(self, tree: Tree) -> None: + def member_dot(self, tree: TreeT) -> None: """Detect JSON field access -> JSON_ACCESS pattern.""" obj = tree.children[0] field_name = str(tree.children[1]) @@ -235,7 +238,7 @@ def member_dot(self, tree: Tree) -> None: self._visit_children(tree) - def member_dot_arg(self, tree: Tree) -> None: + def member_dot_arg(self, tree: TreeT) -> None: """Detect matches() -> REGEX_MATCH and comprehensions -> ARRAY/JSON_ARRAY_COMPREHENSION.""" obj = tree.children[0] method_name = str(tree.children[1]) @@ -271,7 +274,7 @@ def member_dot_arg(self, tree: Tree) -> None: # --- Helper methods --- - def _extract_column_name(self, tree: Tree | Token) -> str | None: + def _extract_column_name(self, tree: TreeT | Token) -> str | None: """Extract a column/field name from a tree node.""" node = tree while isinstance(node, Tree): @@ -287,12 +290,12 @@ def _extract_column_name(self, tree: Tree | Token) -> str | None: return str(node) return None - def _extract_table_name(self, tree: Tree | Token) -> str: + def _extract_table_name(self, tree: TreeT | Token) -> str: """Extract the root table name from a tree.""" root = self._get_root_ident(tree) return root or "" - def _get_root_ident(self, tree: Tree | Token) -> str | None: + def _get_root_ident(self, tree: TreeT | Token) -> str | None: node = tree while isinstance(node, Tree): if node.data == "ident": @@ -307,7 +310,7 @@ def _get_root_ident(self, tree: Tree | Token) -> str | None: return str(node) return None - def _get_first_field(self, obj: Tree | Token, fallback: str) -> str: + def _get_first_field(self, obj: TreeT | Token, fallback: str) -> str: node = obj while isinstance(node, Tree): if node.data == "member_dot": @@ 
-325,11 +328,7 @@ def _get_first_field(self, obj: Tree | Token, fallback: str) -> str: return fallback def _is_field_json(self, table_name: str, field_name: str) -> bool: - schema = self._schemas.get(table_name) - if not schema: - return False - field = schema.find_field(field_name) - return field is not None and (field.is_json or field.is_jsonb) + return field_is_json(self._schemas, table_name, field_name) def _pattern_priority(pattern: PatternType) -> int: @@ -346,7 +345,7 @@ def _pattern_priority(pattern: PatternType) -> int: def analyze_patterns( - tree: Tree, + tree: TreeT, advisor: IndexAdvisor, schemas: dict[str, Schema] | None = None, ) -> list[IndexRecommendation]: diff --git a/src/pycel2sql/_converter.py b/src/pycel2sql/_converter.py index 61dd1d0..9639c0a 100644 --- a/src/pycel2sql/_converter.py +++ b/src/pycel2sql/_converter.py @@ -33,8 +33,15 @@ validate_field_name, validate_no_null_bytes, ) -from pycel2sql.dialect._base import Dialect -from pycel2sql.schema import Schema +from pycel2sql.dialect._base import Dialect, WriteFunc +from pycel2sql.schema import Schema, field_is_json + +# Annotation-only aliases for Lark parse-tree nodes. Safe at runtime because +# both this module and `_analysis.py` use `from __future__ import annotations`, +# so the subscripted generics are never evaluated. `Branch` mirrors lark's own +# `Tree.children` element type: a child is either a sub-tree or a leaf Token. +TreeT = Tree[Token] +Branch = TreeT | Token def _strip_quotes(s: str) -> str: @@ -76,12 +83,12 @@ def _is_bytes_token(token: Token) -> bool: return token.type == "BYTES_LIT" -def _get_literal_token(tree: Tree) -> Token | None: +def _get_literal_token(tree: Branch) -> Token | None: """Extract the literal token from a deeply-nested expression tree. Walks through the precedence chain to find a literal at the bottom. 
""" - node: Tree | Token = tree + node: TreeT | Token = tree while isinstance(node, Tree): if node.data == "literal": if node.children: @@ -97,7 +104,12 @@ def _get_literal_token(tree: Tree) -> Token | None: return None -def _tree_contains_string_literal(tree: Tree) -> bool: +def _children(node: Branch | None) -> list[Branch]: + """Return a node's children, or [] for a leaf Token / None.""" + return node.children if isinstance(node, Tree) else [] + + +def _tree_contains_string_literal(tree: Branch) -> bool: """Check if a tree contains any string literal at its leaves.""" if isinstance(tree, Token): return _is_string_token(tree) @@ -111,9 +123,9 @@ def _tree_contains_string_literal(tree: Tree) -> bool: ) -def _tree_is_list_literal(tree: Tree) -> bool: +def _tree_is_list_literal(tree: Branch) -> bool: """Check if a tree is a list literal (for IN operator).""" - node: Tree | Token = tree + node: TreeT | Token = tree while isinstance(node, Tree): if node.data == "list_lit": return True @@ -124,9 +136,9 @@ def _tree_is_list_literal(tree: Tree) -> bool: return False -def _unwrap_to_data(tree: Tree, target_data: str) -> Tree | None: +def _unwrap_to_data(tree: Branch, target_data: str) -> TreeT | None: """Unwrap single-child tree nodes to find a node with the given data.""" - node: Tree | Token = tree + node: TreeT | Token = tree while isinstance(node, Tree): if node.data == target_data: return node @@ -137,7 +149,7 @@ def _unwrap_to_data(tree: Tree, target_data: str) -> Tree | None: return None -class Converter(Interpreter): +class Converter(Interpreter[Token, None]): """Converts a CEL Lark parse tree into a SQL WHERE clause string.""" def __init__( @@ -198,7 +210,7 @@ def _add_param(self, value: Any) -> int: self._parameters.append(value) return self._param_count - def _visit_child(self, tree: Tree) -> None: + def _visit_child(self, tree: Branch) -> None: """Visit a child node, incrementing depth.""" self._depth += 1 try: @@ -209,7 +221,7 @@ def _visit_child(self, tree: 
Tree) -> None: # ---- Top-level entry ---- - def visit(self, tree: Tree) -> Any: + def visit(self, tree: Branch) -> Any: """Override to handle Token children transparently.""" if isinstance(tree, Token): # Bare tokens shouldn't appear at this level normally, @@ -220,7 +232,7 @@ def visit(self, tree: Tree) -> Any: # ---- expr: top-level, potentially ternary ---- - def expr(self, tree: Tree) -> None: + def expr(self, tree: TreeT) -> None: children = tree.children if len(children) == 3: # Ternary: condition ? true_val : false_val @@ -241,7 +253,7 @@ def expr(self, tree: Tree) -> None: # ---- Logical operators ---- - def conditionalor(self, tree: Tree) -> None: + def conditionalor(self, tree: TreeT) -> None: children = tree.children if len(children) == 2: self._visit_child(children[0]) @@ -255,7 +267,7 @@ def conditionalor(self, tree: Tree) -> None: f"conditionalor has {len(children)} children", ) - def conditionaland(self, tree: Tree) -> None: + def conditionaland(self, tree: TreeT) -> None: children = tree.children if len(children) == 2: self._visit_child(children[0]) @@ -271,7 +283,7 @@ def conditionaland(self, tree: Tree) -> None: # ---- Comparison / relation ---- - def relation(self, tree: Tree) -> None: + def relation(self, tree: TreeT) -> None: children = tree.children if len(children) == 1: self._visit_child(children[0]) @@ -349,20 +361,25 @@ def relation(self, tree: Tree) -> None: rhs_is_numeric = self._is_numeric_literal(rhs) lhs_is_numeric = self._is_numeric_literal(lhs) if self._is_json_text_extraction(lhs) and rhs_is_numeric: - self._dialect.write_cast_to_numeric( - self._w, - lambda: (self._w.write("("), self._visit_child(lhs), self._w.write(")")), - ) + def write_lhs_paren() -> None: + self._w.write("(") + self._visit_child(lhs) + self._w.write(")") + + self._dialect.write_cast_to_numeric(self._w, write_lhs_paren) self._w.write(f" {sql_op} ") self._visit_child(rhs) return if self._is_json_text_extraction(rhs) and lhs_is_numeric: self._visit_child(lhs) 
self._w.write(f" {sql_op} ") - self._dialect.write_cast_to_numeric( - self._w, - lambda: (self._w.write("("), self._visit_child(rhs), self._w.write(")")), - ) + + def write_rhs_paren() -> None: + self._w.write("(") + self._visit_child(rhs) + self._w.write(")") + + self._dialect.write_cast_to_numeric(self._w, write_rhs_paren) return self._visit_child(lhs) @@ -370,30 +387,30 @@ def relation(self, tree: Tree) -> None: self._visit_child(rhs) # Operator prefix handlers - they just delegate to relation - def relation_eq(self, tree: Tree) -> None: + def relation_eq(self, tree: TreeT) -> None: self._visit_child(tree.children[0]) - def relation_ne(self, tree: Tree) -> None: + def relation_ne(self, tree: TreeT) -> None: self._visit_child(tree.children[0]) - def relation_lt(self, tree: Tree) -> None: + def relation_lt(self, tree: TreeT) -> None: self._visit_child(tree.children[0]) - def relation_le(self, tree: Tree) -> None: + def relation_le(self, tree: TreeT) -> None: self._visit_child(tree.children[0]) - def relation_gt(self, tree: Tree) -> None: + def relation_gt(self, tree: TreeT) -> None: self._visit_child(tree.children[0]) - def relation_ge(self, tree: Tree) -> None: + def relation_ge(self, tree: TreeT) -> None: self._visit_child(tree.children[0]) - def relation_in(self, tree: Tree) -> None: + def relation_in(self, tree: TreeT) -> None: self._visit_child(tree.children[0]) # ---- IN operator ---- - def _visit_in(self, lhs: Tree, rhs: Tree) -> None: + def _visit_in(self, lhs: Branch, rhs: Branch) -> None: """Handle the 'in' operator: x in [1,2,3], x in arr, or x in t.json_arr.""" # x in routes to a dialect-specific membership # predicate; a JSON array can't be tested with plain `= ANY(...)`. 
@@ -406,7 +423,7 @@ def _visit_in(self, lhs: Tree, rhs: Tree) -> None: lambda: self._visit_child(rhs), ) - def _visit_in_json_array(self, lhs: Tree, rhs: Tree) -> bool: + def _visit_in_json_array(self, lhs: Branch, rhs: Branch) -> bool: """Route `x in ` to the JSON-array membership dialect hooks. Returns True if the RHS is a JSON array (schema-declared JSON field, @@ -454,9 +471,9 @@ def write_array() -> None: self._dialect.write_nested_json_array_membership(self._w, write_elem, write_array) return True - def _unwrap_to_member_dot(self, tree: Tree) -> Tree | None: + def _unwrap_to_member_dot(self, tree: Branch) -> TreeT | None: """Strip single-child grammar wrappers and return a member_dot node, if any.""" - node: Tree | Token = tree + node: TreeT | Token = tree while ( isinstance(node, Tree) and node.data in ( @@ -470,9 +487,9 @@ def _unwrap_to_member_dot(self, tree: Tree) -> Tree | None: return node return None - def _obj_is_bare_root(self, obj: Tree | Token) -> bool: + def _obj_is_bare_root(self, obj: TreeT | Token) -> bool: """True if a member_dot operand unwraps to a bare identifier (the table root).""" - node: Tree | Token = obj + node: TreeT | Token = obj while ( isinstance(node, Tree) and node.data in ("member", "primary") @@ -483,7 +500,7 @@ def _obj_is_bare_root(self, obj: Tree | Token) -> bool: # ---- Arithmetic ---- - def addition(self, tree: Tree) -> None: + def addition(self, tree: TreeT) -> None: children = tree.children if len(children) == 1: self._visit_child(children[0]) @@ -559,13 +576,13 @@ def addition(self, tree: Tree) -> None: f"unknown addition operator: {op_name}", ) - def addition_add(self, tree: Tree) -> None: + def addition_add(self, tree: TreeT) -> None: self._visit_child(tree.children[0]) - def addition_sub(self, tree: Tree) -> None: + def addition_sub(self, tree: TreeT) -> None: self._visit_child(tree.children[0]) - def multiplication(self, tree: Tree) -> None: + def multiplication(self, tree: TreeT) -> None: children = tree.children 
if len(children) == 1: self._visit_child(children[0]) @@ -606,18 +623,18 @@ def multiplication(self, tree: Tree) -> None: f"unknown multiplication operator: {op_name}", ) - def multiplication_mul(self, tree: Tree) -> None: + def multiplication_mul(self, tree: TreeT) -> None: self._visit_child(tree.children[0]) - def multiplication_div(self, tree: Tree) -> None: + def multiplication_div(self, tree: TreeT) -> None: self._visit_child(tree.children[0]) - def multiplication_mod(self, tree: Tree) -> None: + def multiplication_mod(self, tree: TreeT) -> None: self._visit_child(tree.children[0]) # ---- Unary ---- - def unary(self, tree: Tree) -> None: + def unary(self, tree: TreeT) -> None: children = tree.children if len(children) == 1: self._visit_child(children[0]) @@ -641,15 +658,15 @@ def unary(self, tree: Tree) -> None: f"unary has {len(children)} children", ) - def unary_not(self, tree: Tree) -> None: + def unary_not(self, tree: TreeT) -> None: pass # Handled by unary() - def unary_neg(self, tree: Tree) -> None: + def unary_neg(self, tree: TreeT) -> None: pass # Handled by unary() # ---- Member access ---- - def member(self, tree: Tree) -> None: + def member(self, tree: TreeT) -> None: if len(tree.children) == 1: self._visit_child(tree.children[0]) else: @@ -658,7 +675,7 @@ def member(self, tree: Tree) -> None: f"member has {len(tree.children)} children", ) - def member_dot(self, tree: Tree) -> None: + def member_dot(self, tree: TreeT) -> None: """Field access: a.b""" obj = tree.children[0] field_name = str(tree.children[1]) @@ -687,12 +704,12 @@ def member_dot(self, tree: Tree) -> None: validate_field_name(field_name) self._w.write(field_name) - def member_dot_arg(self, tree: Tree) -> None: + def member_dot_arg(self, tree: TreeT) -> None: """Method call: a.method(args) or comprehension macro.""" obj = tree.children[0] method_name = str(tree.children[1]) args_node = tree.children[2] if len(tree.children) > 2 else None - args = args_node.children if args_node is not 
None else [] + args = _children(args_node) # Comprehension macros if method_name in ("all", "exists", "exists_one", "map", "filter"): @@ -774,7 +791,7 @@ def member_dot_arg(self, tree: Tree) -> None: f"unknown method: {method_name}", ) - def member_index(self, tree: Tree) -> None: + def member_index(self, tree: TreeT) -> None: """Index access: a[0] or a["key"].""" obj = tree.children[0] index_expr = tree.children[1] @@ -821,12 +838,12 @@ def member_index(self, tree: Tree) -> None: lambda: self._visit_child(index_expr), ) - def member_object(self, tree: Tree) -> None: + def member_object(self, tree: TreeT) -> None: raise UnsupportedExpressionError("object construction not supported in SQL conversion") # ---- Primary expressions ---- - def primary(self, tree: Tree) -> None: + def primary(self, tree: TreeT) -> None: if len(tree.children) == 1: self._visit_child(tree.children[0]) else: @@ -835,7 +852,7 @@ def primary(self, tree: Tree) -> None: f"primary has {len(tree.children)} children", ) - def ident(self, tree: Tree) -> None: + def ident(self, tree: TreeT) -> None: """Bare identifier.""" name = str(tree.children[0]) # Don't validate or alias comprehension iteration variables @@ -846,11 +863,11 @@ def ident(self, tree: Tree) -> None: validate_field_name(resolved) self._w.write(resolved) - def ident_arg(self, tree: Tree) -> None: + def ident_arg(self, tree: TreeT) -> None: """Function call: func(args).""" func_name = str(tree.children[0]) args_node = tree.children[1] if len(tree.children) > 1 else None - args = args_node.children if args_node is not None else [] + args = _children(args_node) # has() function if func_name == "has": @@ -908,22 +925,22 @@ def ident_arg(self, tree: Tree) -> None: self._visit_child(arg) self._w.write(")") - def dot_ident_arg(self, tree: Tree) -> None: + def dot_ident_arg(self, tree: TreeT) -> None: func_name = str(tree.children[0]) self._w.write(f".{func_name}(") if len(tree.children) > 1: args_node = tree.children[1] - for i, arg in 
enumerate(args_node.children): + for i, arg in enumerate(_children(args_node)): if i > 0: self._w.write(", ") self._visit_child(arg) self._w.write(")") - def dot_ident(self, tree: Tree) -> None: + def dot_ident(self, tree: TreeT) -> None: name = str(tree.children[0]) self._w.write(f".{name}") - def paren_expr(self, tree: Tree) -> None: + def paren_expr(self, tree: TreeT) -> None: """Parenthesized expression.""" self._w.write("(") self._visit_child(tree.children[0]) @@ -931,7 +948,7 @@ def paren_expr(self, tree: Tree) -> None: # ---- Literals ---- - def literal(self, tree: Tree) -> None: + def literal(self, tree: TreeT) -> None: token = tree.children[0] if not isinstance(token, Token): raise UnsupportedExpressionError("unexpected literal structure") @@ -956,9 +973,9 @@ def literal(self, tree: Tree) -> None: else: self._w.write(str(val)) elif _is_float_token(token): - val = float(str(token)) + fval = float(str(token)) if self._parameterize: - idx = self._add_param(val) + idx = self._add_param(fval) self._dialect.write_param_placeholder(self._w, idx) else: self._w.write(str(token)) @@ -998,7 +1015,7 @@ def literal(self, tree: Tree) -> None: f"unknown token type: {token.type}", ) - def list_lit(self, tree: Tree) -> None: + def list_lit(self, tree: TreeT) -> None: """List literal: [1, 2, 3] -> ARRAY[1, 2, 3].""" self._dialect.write_array_literal_open(self._w) if tree.children: @@ -1010,7 +1027,7 @@ def list_lit(self, tree: Tree) -> None: self._visit_child(child) self._dialect.write_array_literal_close(self._w) - def map_lit(self, tree: Tree) -> None: + def map_lit(self, tree: TreeT) -> None: """Map literal: {"k": v} -> ROW(v) with .k access.""" self._dialect.write_struct_open(self._w) if tree.children: @@ -1027,13 +1044,13 @@ def map_lit(self, tree: Tree) -> None: self._visit_child(children[i + 1]) self._dialect.write_struct_close(self._w) - def exprlist(self, tree: Tree) -> None: + def exprlist(self, tree: TreeT) -> None: for i, child in enumerate(tree.children): if i 
> 0: self._w.write(", ") self._visit_child(child) - def mapinits(self, tree: Tree) -> None: + def mapinits(self, tree: TreeT) -> None: children = tree.children first = True for i in range(0, len(children), 2): @@ -1042,7 +1059,7 @@ def mapinits(self, tree: Tree) -> None: first = False self._visit_child(children[i + 1]) - def fieldinits(self, tree: Tree) -> None: + def fieldinits(self, tree: TreeT) -> None: children = tree.children first = True for i in range(0, len(children), 2): @@ -1053,7 +1070,7 @@ def fieldinits(self, tree: Tree) -> None: # ---- String functions ---- - def _visit_contains(self, obj: Tree, args: list) -> None: + def _visit_contains(self, obj: Branch, args: list[Branch]) -> None: if len(args) != 1: raise InvalidArgumentsError("contains() requires exactly 1 argument") needle_token = _get_literal_token(args[0]) @@ -1075,7 +1092,7 @@ def _visit_contains(self, obj: Tree, args: list) -> None: lambda: self._visit_child(args[0]), ) - def _visit_starts_with(self, obj: Tree, args: list) -> None: + def _visit_starts_with(self, obj: Branch, args: list[Branch]) -> None: if len(args) != 1: raise InvalidArgumentsError("startsWith() requires exactly 1 argument") token = _get_literal_token(args[0]) @@ -1090,7 +1107,7 @@ def _visit_starts_with(self, obj: Tree, args: list) -> None: self._w.write(f" LIKE '{escaped}%'") self._dialect.write_like_escape(self._w) - def _visit_ends_with(self, obj: Tree, args: list) -> None: + def _visit_ends_with(self, obj: Branch, args: list[Branch]) -> None: if len(args) != 1: raise InvalidArgumentsError("endsWith() requires exactly 1 argument") token = _get_literal_token(args[0]) @@ -1105,7 +1122,7 @@ def _visit_ends_with(self, obj: Tree, args: list) -> None: self._w.write(f" LIKE '%{escaped}'") self._dialect.write_like_escape(self._w) - def _visit_matches_method(self, obj: Tree, args: list) -> None: + def _visit_matches_method(self, obj: Branch, args: list[Branch]) -> None: if len(args) != 1: raise InvalidArgumentsError("matches() 
requires exactly 1 argument") token = _get_literal_token(args[0]) @@ -1122,7 +1139,7 @@ def _visit_matches_method(self, obj: Tree, args: list) -> None: case_insensitive, ) - def _visit_matches_func(self, target: Tree, pattern_expr: Tree) -> None: + def _visit_matches_func(self, target: Branch, pattern_expr: Branch) -> None: token = _get_literal_token(pattern_expr) if not token or not _is_string_token(token): raise InvalidArgumentsError("matches() requires a string literal pattern") @@ -1137,7 +1154,7 @@ def _visit_matches_func(self, target: Tree, pattern_expr: Tree) -> None: case_insensitive, ) - def _visit_size_method(self, obj: Tree) -> None: + def _visit_size_method(self, obj: Branch) -> None: """size() as a method call on an object.""" if self._is_array_expression(obj): self._dialect.write_array_length( @@ -1148,7 +1165,7 @@ def _visit_size_method(self, obj: Tree) -> None: self._visit_child(obj) self._w.write(")") - def _visit_size_func(self, arg: Tree) -> None: + def _visit_size_func(self, arg: Branch) -> None: """size(x) function call.""" if self._is_array_expression(arg): self._dialect.write_array_length( @@ -1159,7 +1176,7 @@ def _visit_size_func(self, arg: Tree) -> None: self._visit_child(arg) self._w.write(")") - def _visit_char_at(self, obj: Tree, args: list) -> None: + def _visit_char_at(self, obj: Branch, args: list[Branch]) -> None: if len(args) != 1: raise InvalidArgumentsError("charAt() requires exactly 1 argument") idx_literal = _get_literal_token(args[0]) @@ -1174,7 +1191,7 @@ def _visit_char_at(self, obj: Tree, args: list) -> None: self._w.write(" + 1") self._w.write(", 1)") - def _visit_index_of(self, obj: Tree, args: list) -> None: + def _visit_index_of(self, obj: Branch, args: list[Branch]) -> None: if len(args) < 1 or len(args) > 2: raise InvalidArgumentsError("indexOf() requires 1 or 2 arguments") @@ -1224,7 +1241,7 @@ def _visit_index_of(self, obj: Tree, args: list) -> None: self._visit_child(args[1]) self._w.write(" - 1 ELSE -1 END") - def 
_visit_last_index_of(self, obj: Tree, args: list) -> None: + def _visit_last_index_of(self, obj: Branch, args: list[Branch]) -> None: if len(args) < 1: raise InvalidArgumentsError("lastIndexOf() requires at least 1 argument") # lastIndexOf(needle) using REVERSE @@ -1242,7 +1259,7 @@ def _visit_last_index_of(self, obj: Tree, args: list) -> None: self._visit_child(args[0]) self._w.write(") + 1 ELSE -1 END") - def _visit_substring(self, obj: Tree, args: list) -> None: + def _visit_substring(self, obj: Branch, args: list[Branch]) -> None: if len(args) < 1 or len(args) > 2: raise InvalidArgumentsError("substring() requires 1 or 2 arguments") @@ -1283,7 +1300,7 @@ def _visit_substring(self, obj: Tree, args: list) -> None: self._w.write(")") self._w.write(")") - def _visit_replace(self, obj: Tree, args: list) -> None: + def _visit_replace(self, obj: Branch, args: list[Branch]) -> None: if len(args) < 2 or len(args) > 3: raise InvalidArgumentsError("replace() requires 2 or 3 arguments") @@ -1306,7 +1323,7 @@ def _visit_replace(self, obj: Tree, args: list) -> None: self._visit_child(args[1]) self._w.write(")") - def _visit_split(self, obj: Tree, args: list) -> None: + def _visit_split(self, obj: Branch, args: list[Branch]) -> None: if len(args) < 1 or len(args) > 2: raise InvalidArgumentsError("split() requires 1 or 2 arguments") @@ -1353,7 +1370,7 @@ def _visit_split(self, obj: Tree, args: list) -> None: else: raise InvalidArgumentsError("split() limit must be an integer literal") - def _visit_join(self, obj: Tree, args: list) -> None: + def _visit_join(self, obj: Branch, args: list[Branch]) -> None: if len(args) > 1: raise InvalidArgumentsError("join() requires 0 or 1 arguments") if len(args) == 0: @@ -1369,7 +1386,7 @@ def _visit_join(self, obj: Tree, args: list) -> None: lambda: self._visit_child(args[0]), ) - def _visit_format(self, obj: Tree, args: list) -> None: + def _visit_format(self, obj: Branch, args: list[Branch]) -> None: if len(args) != 1: raise 
InvalidArgumentsError("format() requires exactly 1 argument (the arg list)") @@ -1417,7 +1434,7 @@ def _visit_format(self, obj: Tree, args: list) -> None: # ---- has() function ---- - def _visit_has(self, args: list) -> None: + def _visit_has(self, args: list[Branch]) -> None: if len(args) != 1: raise InvalidArgumentsError("has() requires exactly 1 argument") arg = args[0] @@ -1465,7 +1482,7 @@ def _visit_has(self, args: list) -> None: # ---- Type casting ---- - def _visit_type_cast(self, type_name: str, args: list) -> None: + def _visit_type_cast(self, type_name: str, args: list[Branch]) -> None: if len(args) != 1: raise InvalidArgumentsError(f"{type_name}() requires exactly 1 argument") @@ -1486,7 +1503,7 @@ def _visit_type_cast(self, type_name: str, args: list) -> None: # ---- Timestamp functions ---- - def _visit_timestamp_func(self, args: list) -> None: + def _visit_timestamp_func(self, args: list[Branch]) -> None: if len(args) == 1: # timestamp("2021-01-01T00:00:00Z") -> CAST('...' 
AS TIMESTAMP WITH TIME ZONE) self._dialect.write_timestamp_cast( @@ -1500,7 +1517,7 @@ def _visit_timestamp_func(self, args: list) -> None: else: raise InvalidArgumentsError("timestamp() requires 1 or 2 arguments") - def _visit_duration_func(self, args: list) -> None: + def _visit_duration_func(self, args: list[Branch]) -> None: if len(args) != 1: raise InvalidArgumentsError("duration() requires exactly 1 argument") token = _get_literal_token(args[0]) @@ -1512,7 +1529,7 @@ def _visit_duration_func(self, args: list) -> None: value, unit = self._parse_duration(raw) self._dialect.write_duration(self._w, value, unit) - def _visit_interval_func(self, args: list) -> None: + def _visit_interval_func(self, args: list[Branch]) -> None: if len(args) != 2: raise InvalidArgumentsError("interval() requires exactly 2 arguments") # interval(value, UNIT) - UNIT is an identifier @@ -1532,7 +1549,7 @@ def _visit_interval_func(self, args: list) -> None: unit, ) - def _visit_datetime_constructor(self, func_name: str, args: list) -> None: + def _visit_datetime_constructor(self, func_name: str, args: list[Branch]) -> None: """Handle date(), time(), datetime() constructors.""" self._w.write(func_name.upper()) self._w.write("(") @@ -1542,7 +1559,7 @@ def _visit_datetime_constructor(self, func_name: str, args: list) -> None: self._visit_child(arg) self._w.write(")") - def _visit_current_datetime(self, func_name: str, args: list) -> None: + def _visit_current_datetime(self, func_name: str, args: list[Branch]) -> None: self._w.write(func_name.upper()) self._w.write("(") for i, arg in enumerate(args): @@ -1551,7 +1568,7 @@ def _visit_current_datetime(self, func_name: str, args: list) -> None: self._visit_child(arg) self._w.write(")") - def _visit_timestamp_extract(self, obj: Tree, method_name: str, args: list) -> None: + def _visit_timestamp_extract(self, obj: Branch, method_name: str, args: list[Branch]) -> None: """Handle timestamp extraction methods: getFullYear(), getMonth(), etc.""" 
part_map = { "getFullYear": "YEAR", @@ -1570,10 +1587,10 @@ def _visit_timestamp_extract(self, obj: Tree, method_name: str, args: list) -> N raise UnsupportedExpressionError(f"unsupported timestamp method: {method_name}") # Check for timezone argument - write_tz = None + write_tz: WriteFunc | None = None if args: - def write_tz(): - return self._visit_child(args[0]) + def write_tz() -> None: + self._visit_child(args[0]) self._dialect.write_extract( self._w, part, lambda: self._visit_child(obj), write_tz @@ -1628,7 +1645,7 @@ def _parse_duration(self, duration_str: str) -> tuple[int, str]: # ---- Comprehensions ---- - def _visit_comprehension(self, source: Tree, macro_name: str, args: list) -> None: + def _visit_comprehension(self, source: Branch, macro_name: str, args: list[Branch]) -> None: """Handle comprehension macros: all, exists, exists_one, map, filter.""" if self._comprehension_depth >= MAX_COMPREHENSION_DEPTH: raise MaxComprehensionDepthExceededError( @@ -1656,12 +1673,12 @@ def _visit_comprehension(self, source: Tree, macro_name: str, args: list) -> Non finally: self._comprehension_depth -= 1 - def _write_unnest_source(self, source: Tree, iter_var: str) -> None: + def _write_unnest_source(self, source: Branch, iter_var: str) -> None: """Write the UNNEST(source) AS var clause.""" self._dialect.write_unnest(self._w, lambda: self._visit_child(source)) self._w.write(f" AS {iter_var}") - def _visit_comp_all(self, source: Tree, args: list) -> None: + def _visit_comp_all(self, source: Branch, args: list[Branch]) -> None: """all(x, pred) -> NOT EXISTS (SELECT 1 FROM UNNEST(src) AS x WHERE NOT (pred))""" if len(args) != 2: raise InvalidArgumentsError("all() requires exactly 2 arguments") @@ -1677,7 +1694,7 @@ def _visit_comp_all(self, source: Tree, args: list) -> None: finally: self._comprehension_vars.discard(iter_var) - def _visit_comp_exists(self, source: Tree, args: list) -> None: + def _visit_comp_exists(self, source: Branch, args: list[Branch]) -> None: 
"""exists(x, pred) -> EXISTS (SELECT 1 FROM UNNEST(src) AS x WHERE pred)""" if len(args) != 2: raise InvalidArgumentsError("exists() requires exactly 2 arguments") @@ -1693,7 +1710,7 @@ def _visit_comp_exists(self, source: Tree, args: list) -> None: finally: self._comprehension_vars.discard(iter_var) - def _visit_comp_exists_one(self, source: Tree, args: list) -> None: + def _visit_comp_exists_one(self, source: Branch, args: list[Branch]) -> None: """exists_one(x, pred) -> (SELECT COUNT(*) FROM UNNEST(src) AS x WHERE pred) = 1""" if len(args) != 2: raise InvalidArgumentsError("exists_one() requires exactly 2 arguments") @@ -1709,7 +1726,7 @@ def _visit_comp_exists_one(self, source: Tree, args: list) -> None: finally: self._comprehension_vars.discard(iter_var) - def _visit_comp_map(self, source: Tree, args: list) -> None: + def _visit_comp_map(self, source: Branch, args: list[Branch]) -> None: """map(x, transform) -> ARRAY(SELECT transform FROM UNNEST(src) AS x)""" if len(args) != 2: raise InvalidArgumentsError("map() requires exactly 2 arguments") @@ -1726,7 +1743,7 @@ def _visit_comp_map(self, source: Tree, args: list) -> None: finally: self._comprehension_vars.discard(iter_var) - def _visit_comp_map_filter(self, source: Tree, args: list) -> None: + def _visit_comp_map_filter(self, source: Branch, args: list[Branch]) -> None: """map(x, filter, transform) -> ARRAY(SELECT transform FROM UNNEST(src) AS x WHERE filter)""" iter_var = self._get_ident_name(args[0]) filter_pred = args[1] @@ -1744,7 +1761,7 @@ def _visit_comp_map_filter(self, source: Tree, args: list) -> None: finally: self._comprehension_vars.discard(iter_var) - def _visit_comp_filter(self, source: Tree, args: list) -> None: + def _visit_comp_filter(self, source: Branch, args: list[Branch]) -> None: """filter(x, pred) -> ARRAY(SELECT x FROM UNNEST(src) AS x WHERE pred)""" if len(args) != 2: raise InvalidArgumentsError("filter() requires exactly 2 arguments") @@ -1779,11 +1796,7 @@ def 
_is_json_variable_root(self, name: str | None) -> bool: return name in self._json_variables def _is_field_json(self, table_name: str, field_name: str) -> bool: - schema = self._schemas.get(table_name) - if not schema: - return False - field = schema.find_field(field_name) - return field is not None and (field.is_json or field.is_jsonb) + return field_is_json(self._schemas, table_name, field_name) def _is_field_jsonb(self, table_name: str, field_name: str) -> bool: schema = self._schemas.get(table_name) @@ -1792,7 +1805,7 @@ def _is_field_jsonb(self, table_name: str, field_name: str) -> bool: field = schema.find_field(field_name) return field is not None and field.is_jsonb - def _is_nested_json_field(self, tree: Tree) -> bool: + def _is_nested_json_field(self, tree: TreeT) -> bool: """Check if a member_dot chain involves a JSON field.""" if tree.data != "member_dot": return False @@ -1803,11 +1816,11 @@ def _is_nested_json_field(self, tree: Tree) -> bool: first_field = self._get_first_field(obj, str(tree.children[1])) return self._is_field_json(table_name, first_field) - def _build_json_path(self, tree: Tree) -> None: + def _build_json_path(self, tree: TreeT) -> None: """Build a JSON path expression from a member_dot chain.""" # Collect the chain: table.json_field.path1.path2... 
parts: list[str] = [] - node = tree + node: Branch = tree while isinstance(node, Tree) and node.data == "member_dot": parts.append(str(node.children[1])) node = node.children[0] @@ -1823,7 +1836,7 @@ def _build_json_path(self, tree: Tree) -> None: def _emit_json_path( self, - root_node: Tree | Token, + root_node: TreeT | Token, parts: list[str], *, root_is_column: bool, @@ -1882,7 +1895,7 @@ def intermediate() -> None: return intermediate current_base = make_base() - def _is_json_text_extraction(self, tree: Tree) -> bool: + def _is_json_text_extraction(self, tree: Branch) -> bool: """Check if a tree represents a JSON text extraction (->>).""" node = tree while isinstance(node, Tree) and node.data in ( @@ -1915,7 +1928,7 @@ def _validate_field_in_schema(self, table_name: str, field_name: str) -> None: f"field '{field_name}' not found in schema for '{table_name}'", ) - def _is_member_dot_array_field(self, tree: Tree) -> bool: + def _is_member_dot_array_field(self, tree: Branch) -> bool: """Check if a tree represents an array field via schema.""" node = tree while isinstance(node, Tree) and node.data in ( @@ -1944,7 +1957,7 @@ def _is_field_array(self, table_name: str, field_name: str) -> bool: field = schema.find_field(field_name) return field is not None and field.repeated - def _is_timestamp_field(self, root_name: str, tree: Tree) -> bool: + def _is_timestamp_field(self, root_name: str, tree: Branch) -> bool: """Check if a field is a timestamp type based on schema or naming.""" _TIMESTAMP_NAMES = {"created_at", "updated_at", "timestamp", "ts"} # Check the root identifier name directly @@ -1960,14 +1973,14 @@ def _is_timestamp_field(self, root_name: str, tree: Tree) -> bool: return False @staticmethod - def _is_numeric_literal(tree: Tree) -> bool: + def _is_numeric_literal(tree: Branch) -> bool: """Check if a tree is a numeric literal (int or float).""" tok = _get_literal_token(tree) if tok is None: return False return _is_int_token(tok) or _is_float_token(tok) or 
_is_uint_token(tok) - def _is_array_expression(self, tree: Tree) -> bool: + def _is_array_expression(self, tree: Branch) -> bool: """Check if a tree produces an array result. Detects: schema array fields, list literals, split(), filter(), map(), @@ -2011,7 +2024,7 @@ def _is_array_expression(self, tree: Tree) -> bool: # ---- Utility helpers ---- - def _get_root_ident(self, tree: Tree | Token) -> str | None: + def _get_root_ident(self, tree: TreeT | Token) -> str | None: """Get the root identifier name from a tree.""" node = tree while isinstance(node, Tree): @@ -2027,7 +2040,7 @@ def _get_root_ident(self, tree: Tree | Token) -> str | None: return str(node) return None - def _get_first_field(self, obj: Tree | Token, fallback: str) -> str: + def _get_first_field(self, obj: TreeT | Token, fallback: str) -> str: """Get the first field name in a member_dot chain.""" node = obj while isinstance(node, Tree): @@ -2047,7 +2060,7 @@ def _get_first_field(self, obj: Tree | Token, fallback: str) -> str: break return fallback - def _get_ident_name(self, tree: Tree) -> str: + def _get_ident_name(self, tree: Branch) -> str: """Extract identifier name from a tree node.""" node = tree while isinstance(node, Tree): @@ -2058,7 +2071,7 @@ def _get_ident_name(self, tree: Tree) -> str: else: raise InvalidArgumentsError( "expected identifier", - f"cannot extract identifier from {tree.data}", + f"cannot extract identifier from {node.data}", ) if isinstance(node, Token) and node.type == "IDENT": return str(node) @@ -2070,7 +2083,7 @@ def _get_ident_name(self, tree: Tree) -> str: def _is_comprehension_var(self, name: str) -> bool: return name in self._comprehension_vars - def _is_duration_expression(self, tree: Tree | Token) -> bool: + def _is_duration_expression(self, tree: TreeT | Token) -> bool: """Check if a tree is specifically a duration/interval expression.""" if isinstance(tree, Token): return False @@ -2084,7 +2097,7 @@ def _is_duration_expression(self, tree: Tree | Token) -> bool: 
and not (isinstance(child, Tree) and child.data == "exprlist") ) - def _is_timestamp_or_duration_context(self, lhs: Tree, rhs: Tree) -> bool: + def _is_timestamp_or_duration_context(self, lhs: Branch, rhs: Branch) -> bool: """Detect if this is a timestamp/duration arithmetic context. We detect this by looking for duration() or interval() or timestamp() calls, @@ -2092,7 +2105,7 @@ def _is_timestamp_or_duration_context(self, lhs: Tree, rhs: Tree) -> bool: """ return self._tree_has_temporal(lhs) or self._tree_has_temporal(rhs) - def _tree_has_temporal(self, tree: Tree | Token) -> bool: + def _tree_has_temporal(self, tree: TreeT | Token) -> bool: """Check if a tree contains temporal expressions (duration, interval, timestamp, etc.).""" if isinstance(tree, Token): return False diff --git a/src/pycel2sql/_utils.py b/src/pycel2sql/_utils.py index 03d3b30..5ea7fda 100644 --- a/src/pycel2sql/_utils.py +++ b/src/pycel2sql/_utils.py @@ -57,6 +57,42 @@ def validate_field_name(name: str) -> None: ) +def validate_sql_field_name( + name: str, + *, + reserved: set[str], + keyword_label: str, + max_length: int = 0, +) -> None: + """Validate a field/identifier name against one dialect's rules. + + ``reserved`` is the dialect's reserved-keyword set, ``keyword_label`` names + the dialect in the (internal) error detail, and ``max_length`` (0 = no limit) + caps the identifier length. Shared by the dialects whose validation differs + only in those three parameters (DuckDB, BigQuery, MySQL, SQLite). 
+ """ + if not name: + raise InvalidFieldNameError( + "field name cannot be empty", + "empty field name provided", + ) + if max_length and len(name) > max_length: + raise InvalidFieldNameError( + "field name too long", + f"field name '{name}' exceeds {max_length} characters", + ) + if not FIELD_NAME_RE.match(name): + raise InvalidFieldNameError( + "invalid field name format", + f"field name '{name}' contains invalid characters", + ) + if name.lower() in reserved: + raise InvalidFieldNameError( + "field name is a reserved SQL keyword", + f"field name '{name}' is a reserved {keyword_label} keyword", + ) + + def escape_like_pattern(pattern: str) -> str: """Escape special characters in a SQL LIKE pattern.""" result = pattern.replace("\\", "\\\\") @@ -100,69 +136,7 @@ def convert_re2_to_posix(re2_pattern: str) -> tuple[str, bool]: Returns (posix_pattern, case_insensitive). """ - if len(re2_pattern) > MAX_REGEX_LENGTH: - raise InvalidRegexPatternError( - "regex pattern too long", - f"pattern length {len(re2_pattern)} exceeds limit {MAX_REGEX_LENGTH}", - ) - - validate_no_null_bytes(re2_pattern, "regex patterns") - - case_insensitive = False - pattern = re2_pattern - - # Extract (?i) flag - if pattern.startswith("(?i)"): - case_insensitive = True - pattern = pattern[4:] - - # Reject unsupported features - if re.search(r"\(\?[!=<]", pattern): - raise InvalidRegexPatternError( - "lookahead/lookbehind not supported", - f"pattern contains lookahead/lookbehind: {re2_pattern}", - ) - if re.search(r"\(\?P<", pattern): - raise InvalidRegexPatternError( - "named captures not supported", - f"pattern contains named captures: {re2_pattern}", - ) - # Reject inline flags other than (?i) at start - if re.search(r"\(\?[imsx]", pattern): - raise InvalidRegexPatternError( - "inline flags not supported", - f"pattern contains inline flags: {re2_pattern}", - ) - - # ReDoS detection - if _REDOS_NESTED_QUANTIFIER.search(pattern): - raise InvalidRegexPatternError( - "potential ReDoS: nested 
quantifiers detected", - f"pattern has nested quantifiers: {re2_pattern}", - ) - - # Check nesting depth - depth = 0 - max_depth = 0 - for ch in pattern: - if ch == "(": - depth += 1 - max_depth = max(max_depth, depth) - elif ch == ")": - depth -= 1 - if max_depth > MAX_REGEX_NESTING: - raise InvalidRegexPatternError( - "regex nesting too deep", - f"pattern nesting depth {max_depth} exceeds limit {MAX_REGEX_NESTING}", - ) - - # Count groups - group_count = pattern.count("(") - pattern.count("(?:") - if group_count > MAX_REGEX_GROUPS: - raise InvalidRegexPatternError( - "too many regex groups", - f"pattern has {group_count} groups, limit is {MAX_REGEX_GROUPS}", - ) + pattern, case_insensitive = _validate_regex_common(re2_pattern) # Convert RE2 shorthand classes to POSIX pattern = pattern.replace("\\d", "[[:digit:]]") diff --git a/src/pycel2sql/dialect/_base.py b/src/pycel2sql/dialect/_base.py index 70646bf..4a0b781 100644 --- a/src/pycel2sql/dialect/_base.py +++ b/src/pycel2sql/dialect/_base.py @@ -36,10 +36,16 @@ class Dialect(ABC): Methods receive a StringIO writer and callback functions for sub-expressions. """ + # CEL type name -> SQL type name. Subclasses populate this; consumed by the + # default ``write_type_name`` below. + TYPE_MAP: dict[str, str] = {} + # --- Literals --- - @abstractmethod - def write_string_literal(self, w: StringIO, value: str) -> None: ... + def write_string_literal(self, w: StringIO, value: str) -> None: + """Default: single-quoted with ``'`` doubled. Override for other escaping.""" + escaped = value.replace("'", "''") + w.write(f"'{escaped}'") @abstractmethod def write_bytes_literal(self, w: StringIO, value: bytes) -> None: ... @@ -49,10 +55,13 @@ def write_param_placeholder(self, w: StringIO, param_index: int) -> None: ... # --- Operators --- - @abstractmethod def write_string_concat( self, w: StringIO, write_lhs: WriteFunc, write_rhs: WriteFunc - ) -> None: ... + ) -> None: + """Default: SQL-standard ``||`` concatenation. 
Override for function-style concat.""" + write_lhs() + w.write(" || ") + write_rhs() @abstractmethod def write_regex_match( @@ -72,8 +81,9 @@ def write_array_membership( @abstractmethod def write_cast_to_numeric(self, w: StringIO, write_expr: WriteFunc) -> None: ... - @abstractmethod - def write_type_name(self, w: StringIO, cel_type_name: str) -> None: ... + def write_type_name(self, w: StringIO, cel_type_name: str) -> None: + """Default: look up ``TYPE_MAP``, falling back to the upper-cased CEL name.""" + w.write(self.TYPE_MAP.get(cel_type_name, cel_type_name.upper())) @abstractmethod def write_epoch_extract(self, w: StringIO, write_expr: WriteFunc) -> None: ... @@ -147,13 +157,17 @@ def write_nested_json_array_membership( # --- Timestamps --- - @abstractmethod - def write_duration(self, w: StringIO, value: int, unit: str) -> None: ... + def write_duration(self, w: StringIO, value: int, unit: str) -> None: + """Default: ``INTERVAL <value> <unit>``. Override where intervals differ (e.g. SQLite).""" + w.write(f"INTERVAL {value} {unit}") - @abstractmethod def write_interval( self, w: StringIO, write_value: WriteFunc, unit: str - ) -> None: ... + ) -> None: + """Default: ``INTERVAL <value> <unit>``. Override where intervals differ (e.g. SQLite).""" + w.write("INTERVAL ") + write_value() + w.write(f" {unit}") @abstractmethod def write_extract( @@ -164,14 +178,18 @@ def write_extract( write_tz: WriteFunc | None, ) -> None: ... - @abstractmethod def write_timestamp_arithmetic( self, w: StringIO, op: str, write_ts: WriteFunc, write_dur: WriteFunc, - ) -> None: ... + ) -> None: + """Default: infix ``<ts> <op> <dur>``. Override for function-style (BigQuery) + or modifier-style (SQLite) arithmetic.""" + write_ts() + w.write(f" {op} ") + write_dur() # --- String Functions --- @@ -221,8 +239,9 @@ def convert_regex(self, re2_pattern: str) -> tuple[str, bool]: ... @abstractmethod def write_struct_open(self, w: StringIO) -> None: ... - @abstractmethod - def write_struct_close(self, w: StringIO) -> None: ...
+ def write_struct_close(self, w: StringIO) -> None: + """Default: ``)``. Pairs with the dialect-specific ``write_struct_open``.""" + w.write(")") # --- Validation --- diff --git a/src/pycel2sql/dialect/bigquery.py b/src/pycel2sql/dialect/bigquery.py index f847f16..3205676 100644 --- a/src/pycel2sql/dialect/bigquery.py +++ b/src/pycel2sql/dialect/bigquery.py @@ -2,12 +2,9 @@ from __future__ import annotations -import re - from io import StringIO -from pycel2sql._errors import InvalidFieldNameError -from pycel2sql._utils import convert_re2_to_re2_native +from pycel2sql._utils import convert_re2_to_re2_native, validate_sql_field_name from pycel2sql.dialect._base import Dialect, WriteFunc # BigQuery reserved keywords @@ -27,8 +24,6 @@ "when", "where", "window", "with", "within", } -_FIELD_NAME_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") - # CEL type name -> BigQuery type name _TYPE_MAP: dict[str, str] = { "bool": "BOOL", @@ -64,6 +59,8 @@ class BigQueryDialect(Dialect): """BigQuery dialect for CEL-to-SQL conversion.""" + TYPE_MAP = _TYPE_MAP + # --- Literals --- def write_string_literal(self, w: StringIO, value: str) -> None: @@ -81,13 +78,7 @@ def write_param_placeholder(self, w: StringIO, param_index: int) -> None: w.write(f"@p{param_index}") # --- Operators --- - - def write_string_concat( - self, w: StringIO, write_lhs: WriteFunc, write_rhs: WriteFunc - ) -> None: - write_lhs() - w.write(" || ") - write_rhs() + # write_string_concat: inherits the base default (`||`). 
def write_regex_match( self, w: StringIO, write_target: WriteFunc, pattern: str, case_insensitive: bool @@ -121,10 +112,6 @@ def write_cast_to_numeric(self, w: StringIO, write_expr: WriteFunc) -> None: write_expr() w.write(" AS FLOAT64)") - def write_type_name(self, w: StringIO, cel_type_name: str) -> None: - sql_type = _TYPE_MAP.get(cel_type_name, cel_type_name.upper()) - w.write(sql_type) - def write_epoch_extract(self, w: StringIO, write_expr: WriteFunc) -> None: w.write("UNIX_SECONDS(") write_expr() @@ -220,16 +207,7 @@ def write_nested_json_array_membership( w.write("))") # --- Timestamps --- - - def write_duration(self, w: StringIO, value: int, unit: str) -> None: - w.write(f"INTERVAL {value} {unit}") - - def write_interval( - self, w: StringIO, write_value: WriteFunc, unit: str - ) -> None: - w.write("INTERVAL ") - write_value() - w.write(f" {unit}") + # write_duration / write_interval: inherit the base `INTERVAL <value> <unit>` defaults. def write_extract( self, @@ -339,8 +317,7 @@ def convert_regex(self, re2_pattern: str) -> tuple[str, bool]: def write_struct_open(self, w: StringIO) -> None: w.write("STRUCT(") - def write_struct_close(self, w: StringIO) -> None: - w.write(")") + # write_struct_close: inherits the base `)` default.
# --- Validation --- @@ -348,26 +325,12 @@ def max_identifier_length(self) -> int: return 300 def validate_field_name(self, name: str) -> None: - if not name: - raise InvalidFieldNameError( - "field name cannot be empty", - "empty field name provided", - ) - if len(name) > 300: - raise InvalidFieldNameError( - "field name too long", - f"field name '{name}' exceeds 300 characters", - ) - if not _FIELD_NAME_RE.match(name): - raise InvalidFieldNameError( - "invalid field name format", - f"field name '{name}' contains invalid characters", - ) - if name.lower() in _BIGQUERY_RESERVED: - raise InvalidFieldNameError( - "field name is a reserved SQL keyword", - f"field name '{name}' is a reserved BigQuery keyword", - ) + validate_sql_field_name( + name, + reserved=_BIGQUERY_RESERVED, + keyword_label="BigQuery", + max_length=300, + ) # --- Capabilities --- diff --git a/src/pycel2sql/dialect/duckdb.py b/src/pycel2sql/dialect/duckdb.py index 29bb64d..79a9f3e 100644 --- a/src/pycel2sql/dialect/duckdb.py +++ b/src/pycel2sql/dialect/duckdb.py @@ -2,12 +2,9 @@ from __future__ import annotations -import re - from io import StringIO -from pycel2sql._errors import InvalidFieldNameError -from pycel2sql._utils import convert_re2_to_re2_native +from pycel2sql._utils import convert_re2_to_re2_native, validate_sql_field_name from pycel2sql.dialect._base import Dialect, WriteFunc # DuckDB reserved keywords @@ -25,8 +22,6 @@ "values", "when", "where", "with", } -_FIELD_NAME_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") - # CEL type name -> DuckDB type name _TYPE_MAP: dict[str, str] = { "bool": "BOOLEAN", @@ -42,11 +37,10 @@ class DuckDBDialect(Dialect): """DuckDB dialect for CEL-to-SQL conversion.""" - # --- Literals --- + TYPE_MAP = _TYPE_MAP - def write_string_literal(self, w: StringIO, value: str) -> None: - escaped = value.replace("'", "''") - w.write(f"'{escaped}'") + # --- Literals --- + # write_string_literal: inherits the base default (single-quote escaping). 
def write_bytes_literal(self, w: StringIO, value: bytes) -> None: hex_str = value.hex().upper() @@ -56,13 +50,7 @@ def write_param_placeholder(self, w: StringIO, param_index: int) -> None: w.write(f"${param_index}") # --- Operators --- - - def write_string_concat( - self, w: StringIO, write_lhs: WriteFunc, write_rhs: WriteFunc - ) -> None: - write_lhs() - w.write(" || ") - write_rhs() + # write_string_concat: inherits the base default (`||`). def write_regex_match( self, w: StringIO, write_target: WriteFunc, pattern: str, case_insensitive: bool @@ -94,10 +82,6 @@ def write_cast_to_numeric(self, w: StringIO, write_expr: WriteFunc) -> None: write_expr() w.write("::DOUBLE") - def write_type_name(self, w: StringIO, cel_type_name: str) -> None: - sql_type = _TYPE_MAP.get(cel_type_name, cel_type_name.upper()) - w.write(sql_type) - def write_epoch_extract(self, w: StringIO, write_expr: WriteFunc) -> None: w.write("EXTRACT(EPOCH FROM ") write_expr() @@ -191,16 +175,7 @@ def write_nested_json_array_membership( w.write(")") # --- Timestamps --- - - def write_duration(self, w: StringIO, value: int, unit: str) -> None: - w.write(f"INTERVAL {value} {unit}") - - def write_interval( - self, w: StringIO, write_value: WriteFunc, unit: str - ) -> None: - w.write("INTERVAL ") - write_value() - w.write(f" {unit}") + # write_duration / write_interval: inherit the base `INTERVAL <value> <unit>` defaults. def write_extract( self, @@ -216,16 +191,7 @@ def write_extract( write_tz() w.write(")") - def write_timestamp_arithmetic( - self, - w: StringIO, - op: str, - write_ts: WriteFunc, - write_dur: WriteFunc, - ) -> None: - write_ts() - w.write(f" {op} ") - write_dur() + # write_timestamp_arithmetic: inherits the base infix `<ts> <op> <dur>` default.
# --- String Functions --- @@ -298,8 +264,7 @@ def convert_regex(self, re2_pattern: str) -> tuple[str, bool]: def write_struct_open(self, w: StringIO) -> None: w.write("ROW(") - def write_struct_close(self, w: StringIO) -> None: - w.write(")") + # write_struct_close: inherits the base `)` default. # --- Validation --- @@ -307,21 +272,9 @@ def max_identifier_length(self) -> int: return 0 # No limit def validate_field_name(self, name: str) -> None: - if not name: - raise InvalidFieldNameError( - "field name cannot be empty", - "empty field name provided", - ) - if not _FIELD_NAME_RE.match(name): - raise InvalidFieldNameError( - "invalid field name format", - f"field name '{name}' contains invalid characters", - ) - if name.lower() in _DUCKDB_RESERVED: - raise InvalidFieldNameError( - "field name is a reserved SQL keyword", - f"field name '{name}' is a reserved DuckDB keyword", - ) + validate_sql_field_name( + name, reserved=_DUCKDB_RESERVED, keyword_label="DuckDB" + ) # --- Capabilities --- diff --git a/src/pycel2sql/dialect/mysql.py b/src/pycel2sql/dialect/mysql.py index 07fd98f..4078403 100644 --- a/src/pycel2sql/dialect/mysql.py +++ b/src/pycel2sql/dialect/mysql.py @@ -2,12 +2,10 @@ from __future__ import annotations -import re - from io import StringIO -from pycel2sql._errors import InvalidFieldNameError, UnsupportedDialectFeatureError -from pycel2sql._utils import convert_re2_to_mysql +from pycel2sql._errors import UnsupportedDialectFeatureError +from pycel2sql._utils import convert_re2_to_mysql, validate_sql_field_name from pycel2sql.dialect._base import Dialect, WriteFunc # MySQL reserved keywords @@ -57,8 +55,6 @@ "year_month", "zerofill", } -_FIELD_NAME_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") - # CEL type name -> MySQL type name _TYPE_MAP: dict[str, str] = { "bool": "UNSIGNED", @@ -74,11 +70,10 @@ class MySQLDialect(Dialect): """MySQL dialect for CEL-to-SQL conversion.""" - # --- Literals --- + TYPE_MAP = _TYPE_MAP - def write_string_literal(self, w: 
StringIO, value: str) -> None: - escaped = value.replace("'", "''") - w.write(f"'{escaped}'") + # --- Literals --- + # write_string_literal: inherits the base default (single-quote escaping). def write_bytes_literal(self, w: StringIO, value: bytes) -> None: hex_str = value.hex().upper() @@ -124,10 +119,6 @@ def write_cast_to_numeric(self, w: StringIO, write_expr: WriteFunc) -> None: write_expr() w.write(" + 0") - def write_type_name(self, w: StringIO, cel_type_name: str) -> None: - sql_type = _TYPE_MAP.get(cel_type_name, cel_type_name.upper()) - w.write(sql_type) - def write_epoch_extract(self, w: StringIO, write_expr: WriteFunc) -> None: w.write("UNIX_TIMESTAMP(") write_expr() @@ -223,16 +214,7 @@ def write_nested_json_array_membership( w.write(")") # --- Timestamps --- - - def write_duration(self, w: StringIO, value: int, unit: str) -> None: - w.write(f"INTERVAL {value} {unit}") - - def write_interval( - self, w: StringIO, write_value: WriteFunc, unit: str - ) -> None: - w.write("INTERVAL ") - write_value() - w.write(f" {unit}") + # write_duration / write_interval: inherit the base `INTERVAL <value> <unit>` defaults. def write_extract( self, @@ -251,16 +233,7 @@ def write_extract( write_expr() w.write(")") - def write_timestamp_arithmetic( - self, - w: StringIO, - op: str, - write_ts: WriteFunc, - write_dur: WriteFunc, - ) -> None: - write_ts() - w.write(f" {op} ") - write_dur() + # write_timestamp_arithmetic: inherits the base infix `<ts> <op> <dur>` default. # --- String Functions --- @@ -326,8 +299,7 @@ def convert_regex(self, re2_pattern: str) -> tuple[str, bool]: def write_struct_open(self, w: StringIO) -> None: w.write("ROW(") - def write_struct_close(self, w: StringIO) -> None: - w.write(")") + # write_struct_close: inherits the base `)` default.
# --- Validation --- @@ -335,26 +307,12 @@ def max_identifier_length(self) -> int: return 64 def validate_field_name(self, name: str) -> None: - if not name: - raise InvalidFieldNameError( - "field name cannot be empty", - "empty field name provided", - ) - if len(name) > 64: - raise InvalidFieldNameError( - "field name too long", - f"field name '{name}' exceeds 64 characters", - ) - if not _FIELD_NAME_RE.match(name): - raise InvalidFieldNameError( - "invalid field name format", - f"field name '{name}' contains invalid characters", - ) - if name.lower() in _MYSQL_RESERVED: - raise InvalidFieldNameError( - "field name is a reserved SQL keyword", - f"field name '{name}' is a reserved MySQL keyword", - ) + validate_sql_field_name( + name, + reserved=_MYSQL_RESERVED, + keyword_label="MySQL", + max_length=64, + ) # --- Capabilities --- diff --git a/src/pycel2sql/dialect/postgres.py b/src/pycel2sql/dialect/postgres.py index 84845b0..a7fbfb3 100644 --- a/src/pycel2sql/dialect/postgres.py +++ b/src/pycel2sql/dialect/postgres.py @@ -28,11 +28,10 @@ class PostgresDialect(Dialect): """PostgreSQL dialect for CEL-to-SQL conversion.""" - # --- Literals --- + TYPE_MAP = _TYPE_MAP - def write_string_literal(self, w: StringIO, value: str) -> None: - escaped = value.replace("'", "''") - w.write(f"'{escaped}'") + # --- Literals --- + # write_string_literal: inherits the base default (single-quote escaping). def write_bytes_literal(self, w: StringIO, value: bytes) -> None: hex_str = value.hex().upper() @@ -42,13 +41,7 @@ def write_param_placeholder(self, w: StringIO, param_index: int) -> None: w.write(f"${param_index}") # --- Operators --- - - def write_string_concat( - self, w: StringIO, write_lhs: WriteFunc, write_rhs: WriteFunc - ) -> None: - write_lhs() - w.write(" || ") - write_rhs() + # write_string_concat: inherits the base default (`||`). 
def write_regex_match( self, w: StringIO, write_target: WriteFunc, pattern: str, case_insensitive: bool @@ -78,10 +71,6 @@ def write_cast_to_numeric(self, w: StringIO, write_expr: WriteFunc) -> None: write_expr() w.write("::numeric") - def write_type_name(self, w: StringIO, cel_type_name: str) -> None: - sql_type = _TYPE_MAP.get(cel_type_name, cel_type_name.upper()) - w.write(sql_type) - def write_epoch_extract(self, w: StringIO, write_expr: WriteFunc) -> None: w.write("EXTRACT(EPOCH FROM ") write_expr() @@ -181,16 +170,7 @@ def write_nested_json_array_membership( w.write(")))") # --- Timestamps --- - - def write_duration(self, w: StringIO, value: int, unit: str) -> None: - w.write(f"INTERVAL {value} {unit}") - - def write_interval( - self, w: StringIO, write_value: WriteFunc, unit: str - ) -> None: - w.write("INTERVAL ") - write_value() - w.write(f" {unit}") + # write_duration / write_interval: inherit the base `INTERVAL <value> <unit>` defaults. def write_extract( self, @@ -206,16 +186,7 @@ def write_extract( write_tz() w.write(")") - def write_timestamp_arithmetic( - self, - w: StringIO, - op: str, - write_ts: WriteFunc, - write_dur: WriteFunc, - ) -> None: - write_ts() - w.write(f" {op} ") - write_dur() + # write_timestamp_arithmetic: inherits the base infix `<ts> <op> <dur>` default. # --- String Functions --- @@ -288,8 +259,7 @@ def convert_regex(self, re2_pattern: str) -> tuple[str, bool]: def write_struct_open(self, w: StringIO) -> None: w.write("ROW(") - def write_struct_close(self, w: StringIO) -> None: - w.write(")") + # write_struct_close: inherits the base `)` default.
# --- Validation --- diff --git a/src/pycel2sql/dialect/spark.py b/src/pycel2sql/dialect/spark.py index 74071e8..7a56533 100644 --- a/src/pycel2sql/dialect/spark.py +++ b/src/pycel2sql/dialect/spark.py @@ -258,11 +258,10 @@ def _convert_re2_to_spark(pattern: str) -> tuple[str, bool]: class SparkDialect(Dialect): """Apache Spark SQL dialect for CEL-to-SQL conversion.""" - # --- Literals --- + TYPE_MAP = _TYPE_MAP - def write_string_literal(self, w: StringIO, value: str) -> None: - escaped = value.replace("'", "''") - w.write(f"'{escaped}'") + # --- Literals --- + # write_string_literal: inherits the base default (single-quote escaping). def write_bytes_literal(self, w: StringIO, value: bytes) -> None: hex_str = value.hex().upper() @@ -314,10 +313,6 @@ def write_cast_to_numeric(self, w: StringIO, write_expr: WriteFunc) -> None: write_expr() w.write(" + 0") - def write_type_name(self, w: StringIO, cel_type_name: str) -> None: - sql_type = _TYPE_MAP.get(cel_type_name, cel_type_name.upper()) - w.write(sql_type) - def write_epoch_extract(self, w: StringIO, write_expr: WriteFunc) -> None: w.write("UNIX_TIMESTAMP(") write_expr() @@ -423,16 +418,7 @@ def write_nested_json_array_membership( w.write(")") # --- Timestamps --- - - def write_duration(self, w: StringIO, value: int, unit: str) -> None: - w.write(f"INTERVAL {value} {unit}") - - def write_interval( - self, w: StringIO, write_value: WriteFunc, unit: str - ) -> None: - w.write("INTERVAL ") - write_value() - w.write(f" {unit}") + # write_duration / write_interval: inherit the base `INTERVAL <value> <unit>` defaults. def write_extract( self, @@ -458,16 +444,7 @@ def write_extract( write_tz() w.write(")") - def write_timestamp_arithmetic( - self, - w: StringIO, - op: str, - write_ts: WriteFunc, - write_dur: WriteFunc, - ) -> None: - write_ts() - w.write(f" {op} ") - write_dur() + # write_timestamp_arithmetic: inherits the base infix `<ts> <op> <dur>` default.
# --- String Functions --- @@ -546,8 +523,7 @@ def convert_regex(self, re2_pattern: str) -> tuple[str, bool]: def write_struct_open(self, w: StringIO) -> None: w.write("struct(") - def write_struct_close(self, w: StringIO) -> None: - w.write(")") + # write_struct_close: inherits the base `)` default. # --- Validation --- diff --git a/src/pycel2sql/dialect/sqlite.py b/src/pycel2sql/dialect/sqlite.py index 8314509..623f291 100644 --- a/src/pycel2sql/dialect/sqlite.py +++ b/src/pycel2sql/dialect/sqlite.py @@ -2,11 +2,10 @@ from __future__ import annotations -import re - from io import StringIO -from pycel2sql._errors import InvalidFieldNameError, UnsupportedDialectFeatureError +from pycel2sql._errors import UnsupportedDialectFeatureError +from pycel2sql._utils import validate_sql_field_name from pycel2sql.dialect._base import Dialect, WriteFunc # SQLite reserved keywords @@ -35,8 +34,6 @@ "virtual", "when", "where", "window", "with", "without", } -_FIELD_NAME_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") - # CEL type name -> SQLite type name _TYPE_MAP: dict[str, str] = { "bool": "INTEGER", @@ -65,11 +62,10 @@ class SQLiteDialect(Dialect): """SQLite dialect for CEL-to-SQL conversion.""" - # --- Literals --- + TYPE_MAP = _TYPE_MAP - def write_string_literal(self, w: StringIO, value: str) -> None: - escaped = value.replace("'", "''") - w.write(f"'{escaped}'") + # --- Literals --- + # write_string_literal: inherits the base default (single-quote escaping). def write_bytes_literal(self, w: StringIO, value: bytes) -> None: hex_str = value.hex().upper() @@ -79,13 +75,7 @@ def write_param_placeholder(self, w: StringIO, param_index: int) -> None: w.write("?") # --- Operators --- - - def write_string_concat( - self, w: StringIO, write_lhs: WriteFunc, write_rhs: WriteFunc - ) -> None: - write_lhs() - w.write(" || ") - write_rhs() + # write_string_concat: inherits the base default (`||`). 
def write_regex_match( self, w: StringIO, write_target: WriteFunc, pattern: str, case_insensitive: bool @@ -112,10 +102,6 @@ def write_cast_to_numeric(self, w: StringIO, write_expr: WriteFunc) -> None: write_expr() w.write(" + 0") - def write_type_name(self, w: StringIO, cel_type_name: str) -> None: - sql_type = _TYPE_MAP.get(cel_type_name, cel_type_name.upper()) - w.write(sql_type) - def write_epoch_extract(self, w: StringIO, write_expr: WriteFunc) -> None: w.write("CAST(strftime('%s', ") write_expr() @@ -346,8 +332,7 @@ def convert_regex(self, re2_pattern: str) -> tuple[str, bool]: def write_struct_open(self, w: StringIO) -> None: w.write("json_object(") - def write_struct_close(self, w: StringIO) -> None: - w.write(")") + # write_struct_close: inherits the base `)` default. # --- Validation --- @@ -355,21 +340,9 @@ def max_identifier_length(self) -> int: return 0 # No limit def validate_field_name(self, name: str) -> None: - if not name: - raise InvalidFieldNameError( - "field name cannot be empty", - "empty field name provided", - ) - if not _FIELD_NAME_RE.match(name): - raise InvalidFieldNameError( - "invalid field name format", - f"field name '{name}' contains invalid characters", - ) - if name.lower() in _SQLITE_RESERVED: - raise InvalidFieldNameError( - "field name is a reserved SQL keyword", - f"field name '{name}' is a reserved SQLite keyword", - ) + validate_sql_field_name( + name, reserved=_SQLITE_RESERVED, keyword_label="SQLite" + ) # --- Capabilities --- diff --git a/src/pycel2sql/schema.py b/src/pycel2sql/schema.py index d7cda82..01d2503 100644 --- a/src/pycel2sql/schema.py +++ b/src/pycel2sql/schema.py @@ -35,3 +35,14 @@ def find_field(self, name: str) -> FieldSchema | None: def __len__(self) -> int: return len(self._fields) + + +def field_is_json( + schemas: dict[str, Schema], table_name: str, field_name: str +) -> bool: + """Return True if ``table_name.field_name`` is a declared JSON/JSONB field.""" + schema = schemas.get(table_name) + if not 
schema: + return False + field = schema.find_field(field_name) + return field is not None and (field.is_json or field.is_jsonb)