Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ uv run pytest tests/integration/ -v
# Lint
uv run ruff check src/ tests/

# Type check (strict mode; pre-existing lark type errors are expected)
# Type check (strict mode; should report no errors)
uv run mypy src/pycel2sql/
```

Expand Down
105 changes: 52 additions & 53 deletions src/pycel2sql/_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
PatternType,
)
from pycel2sql.dialect._base import IndexAdvisor
from pycel2sql.schema import Schema
from pycel2sql.schema import Schema, field_is_json

# Annotation-only alias for a Lark parse-tree node (see _converter.py).
TreeT = Tree[Token]

class IndexAnalyzer(Interpreter):

class IndexAnalyzer(Interpreter[Token, None]):
"""Lightweight Lark Interpreter that walks the parse tree detecting index-worthy patterns.

Does NOT generate SQL - only collects patterns.
Expand All @@ -40,7 +43,7 @@ def __init__(
def patterns(self) -> list[IndexPattern]:
return list(self._patterns.values())

def visit(self, tree: Tree) -> Any:
def visit(self, tree: TreeT) -> Any:
if isinstance(tree, Token):
return None
return super().visit(tree)
Expand All @@ -56,64 +59,64 @@ def _add_pattern(self, pattern: IndexPattern) -> None:

# --- Tree walking ---

def _visit_children(self, tree: Tree) -> None:
def _visit_children(self, tree: TreeT) -> None:
for child in tree.children:
if isinstance(child, Tree):
self.visit(child)

# Top-level passthrough handlers
def expr(self, tree: Tree) -> None:
def expr(self, tree: TreeT) -> None:
self._visit_children(tree)

def conditionalor(self, tree: Tree) -> None:
def conditionalor(self, tree: TreeT) -> None:
self._visit_children(tree)

def conditionaland(self, tree: Tree) -> None:
def conditionaland(self, tree: TreeT) -> None:
self._visit_children(tree)

def addition(self, tree: Tree) -> None:
def addition(self, tree: TreeT) -> None:
self._visit_children(tree)

def multiplication(self, tree: Tree) -> None:
def multiplication(self, tree: TreeT) -> None:
self._visit_children(tree)

def unary(self, tree: Tree) -> None:
def unary(self, tree: TreeT) -> None:
self._visit_children(tree)

def member(self, tree: Tree) -> None:
def member(self, tree: TreeT) -> None:
self._visit_children(tree)

def primary(self, tree: Tree) -> None:
def primary(self, tree: TreeT) -> None:
self._visit_children(tree)

def paren_expr(self, tree: Tree) -> None:
def paren_expr(self, tree: TreeT) -> None:
self._visit_children(tree)

def literal(self, tree: Tree) -> None:
def literal(self, tree: TreeT) -> None:
pass

def list_lit(self, tree: Tree) -> None:
def list_lit(self, tree: TreeT) -> None:
self._visit_children(tree)

def map_lit(self, tree: Tree) -> None:
def map_lit(self, tree: TreeT) -> None:
self._visit_children(tree)

def exprlist(self, tree: Tree) -> None:
def exprlist(self, tree: TreeT) -> None:
self._visit_children(tree)

def mapinits(self, tree: Tree) -> None:
def mapinits(self, tree: TreeT) -> None:
self._visit_children(tree)

def fieldinits(self, tree: Tree) -> None:
def fieldinits(self, tree: TreeT) -> None:
self._visit_children(tree)

def ident(self, tree: Tree) -> None:
def ident(self, tree: TreeT) -> None:
pass

def ident_arg(self, tree: Tree) -> None:
def ident_arg(self, tree: TreeT) -> None:
func_name = str(tree.children[0])
args_node = tree.children[1] if len(tree.children) > 1 else None
args = args_node.children if args_node is not None else []
args = args_node.children if isinstance(args_node, Tree) else []

if func_name == "matches" and len(args) >= 1:
col = self._extract_column_name(args[0])
Expand All @@ -127,43 +130,43 @@ def ident_arg(self, tree: Tree) -> None:

self._visit_children(tree)

def dot_ident_arg(self, tree: Tree) -> None:
def dot_ident_arg(self, tree: TreeT) -> None:
self._visit_children(tree)

def dot_ident(self, tree: Tree) -> None:
def dot_ident(self, tree: TreeT) -> None:
pass

def member_index(self, tree: Tree) -> None:
def member_index(self, tree: TreeT) -> None:
self._visit_children(tree)

def member_object(self, tree: Tree) -> None:
def member_object(self, tree: TreeT) -> None:
pass

# Operator prefix handlers
def addition_add(self, tree: Tree) -> None:
def addition_add(self, tree: TreeT) -> None:
self._visit_children(tree)

def addition_sub(self, tree: Tree) -> None:
def addition_sub(self, tree: TreeT) -> None:
self._visit_children(tree)

def multiplication_mul(self, tree: Tree) -> None:
def multiplication_mul(self, tree: TreeT) -> None:
self._visit_children(tree)

def multiplication_div(self, tree: Tree) -> None:
def multiplication_div(self, tree: TreeT) -> None:
self._visit_children(tree)

def multiplication_mod(self, tree: Tree) -> None:
def multiplication_mod(self, tree: TreeT) -> None:
self._visit_children(tree)

def unary_not(self, tree: Tree) -> None:
def unary_not(self, tree: TreeT) -> None:
pass

def unary_neg(self, tree: Tree) -> None:
def unary_neg(self, tree: TreeT) -> None:
pass

# --- Key detection points ---

def relation(self, tree: Tree) -> None:
def relation(self, tree: TreeT) -> None:
"""Detect comparison operators -> COMPARISON pattern."""
children = tree.children
if len(children) == 2:
Expand Down Expand Up @@ -197,28 +200,28 @@ def relation(self, tree: Tree) -> None:
self._visit_children(tree)

# Relation prefix handlers
def relation_eq(self, tree: Tree) -> None:
def relation_eq(self, tree: TreeT) -> None:
self._visit_children(tree)

def relation_ne(self, tree: Tree) -> None:
def relation_ne(self, tree: TreeT) -> None:
self._visit_children(tree)

def relation_lt(self, tree: Tree) -> None:
def relation_lt(self, tree: TreeT) -> None:
self._visit_children(tree)

def relation_le(self, tree: Tree) -> None:
def relation_le(self, tree: TreeT) -> None:
self._visit_children(tree)

def relation_gt(self, tree: Tree) -> None:
def relation_gt(self, tree: TreeT) -> None:
self._visit_children(tree)

def relation_ge(self, tree: Tree) -> None:
def relation_ge(self, tree: TreeT) -> None:
self._visit_children(tree)

def relation_in(self, tree: Tree) -> None:
def relation_in(self, tree: TreeT) -> None:
self._visit_children(tree)

def member_dot(self, tree: Tree) -> None:
def member_dot(self, tree: TreeT) -> None:
"""Detect JSON field access -> JSON_ACCESS pattern."""
obj = tree.children[0]
field_name = str(tree.children[1])
Expand All @@ -235,7 +238,7 @@ def member_dot(self, tree: Tree) -> None:

self._visit_children(tree)

def member_dot_arg(self, tree: Tree) -> None:
def member_dot_arg(self, tree: TreeT) -> None:
"""Detect matches() -> REGEX_MATCH and comprehensions -> ARRAY/JSON_ARRAY_COMPREHENSION."""
obj = tree.children[0]
method_name = str(tree.children[1])
Expand Down Expand Up @@ -271,7 +274,7 @@ def member_dot_arg(self, tree: Tree) -> None:

# --- Helper methods ---

def _extract_column_name(self, tree: Tree | Token) -> str | None:
def _extract_column_name(self, tree: TreeT | Token) -> str | None:
"""Extract a column/field name from a tree node."""
node = tree
while isinstance(node, Tree):
Expand All @@ -287,12 +290,12 @@ def _extract_column_name(self, tree: Tree | Token) -> str | None:
return str(node)
return None

def _extract_table_name(self, tree: Tree | Token) -> str:
def _extract_table_name(self, tree: TreeT | Token) -> str:
"""Extract the root table name from a tree."""
root = self._get_root_ident(tree)
return root or ""

def _get_root_ident(self, tree: Tree | Token) -> str | None:
def _get_root_ident(self, tree: TreeT | Token) -> str | None:
node = tree
while isinstance(node, Tree):
if node.data == "ident":
Expand All @@ -307,7 +310,7 @@ def _get_root_ident(self, tree: Tree | Token) -> str | None:
return str(node)
return None

def _get_first_field(self, obj: Tree | Token, fallback: str) -> str:
def _get_first_field(self, obj: TreeT | Token, fallback: str) -> str:
node = obj
while isinstance(node, Tree):
if node.data == "member_dot":
Expand All @@ -325,11 +328,7 @@ def _get_first_field(self, obj: Tree | Token, fallback: str) -> str:
return fallback

def _is_field_json(self, table_name: str, field_name: str) -> bool:
schema = self._schemas.get(table_name)
if not schema:
return False
field = schema.find_field(field_name)
return field is not None and (field.is_json or field.is_jsonb)
return field_is_json(self._schemas, table_name, field_name)


def _pattern_priority(pattern: PatternType) -> int:
Expand All @@ -346,7 +345,7 @@ def _pattern_priority(pattern: PatternType) -> int:


def analyze_patterns(
tree: Tree,
tree: TreeT,
advisor: IndexAdvisor,
schemas: dict[str, Schema] | None = None,
) -> list[IndexRecommendation]:
Expand Down
Loading