Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 109 additions & 6 deletions libs/openant-core/parsers/zig/call_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,22 @@ def build(self) -> Dict[str, Any]:
# Build an index of function names to IDs for resolution
name_to_ids = self._build_name_index()

# Build per-file simple const fn-alias bindings (`const f = handler;`)
# so that a later `f()` resolves to `handler`.
alias_to_target = self._build_alias_index(name_to_ids)

# For each function, find calls in its body
for func_id, func_info in self.functions.items():
code = func_info.get("code", "")
file_path = func_info.get("file_path", "")

# Parse the function code to find call sites
calls = self._find_calls_in_code(code)
calls = self._find_calls_in_code(code, file_path)

# Resolve each call to a function ID
for call_name in calls:
resolved_ids = self._resolve_call(
call_name, file_path, name_to_ids
call_name, file_path, name_to_ids, alias_to_target
)
for resolved_id in resolved_ids:
if resolved_id != func_id: # No self-calls
Expand Down Expand Up @@ -219,7 +223,59 @@ def _build_name_index(self) -> Dict[str, List[str]]:

return name_to_ids

def _find_calls_in_code(self, code: str) -> Set[str]:
def _build_alias_index(
self, name_to_ids: Dict[str, List[str]]
) -> Dict[str, Dict[str, str]]:
"""Index simple const fn-aliases per file: `const f = handler;` -> {f: handler}.

Only bindings whose right-hand side is a bare identifier naming a known
function are tracked (a genuine fn alias), so arbitrary const dataflow
(`const x = 1;`) is ignored. Scoped per file to avoid cross-file leaks.
"""
alias_to_target: Dict[str, Dict[str, str]] = defaultdict(dict)

for func_info in self.functions.values():
file_path = func_info.get("file_path", "")
code = func_info.get("code", "")
if not code:
continue
try:
tree = self.parser.parse(code.encode("utf-8"))
except Exception:
continue
self._collect_aliases_from_node(
tree.root_node,
code.encode("utf-8"),
name_to_ids,
alias_to_target[file_path],
)

return alias_to_target

def _collect_aliases_from_node(
self,
node: Node,
source: bytes,
name_to_ids: Dict[str, List[str]],
aliases: Dict[str, str],
) -> None:
"""Collect `const <alias> = <known-fn>;` bindings from a parse tree."""
if node.type in ("variable_declaration", "VarDecl"):
ident_children = [
c for c in node.children if c.type in ("identifier", "IDENTIFIER")
]
# A simple alias is exactly: const <alias> = <target-identifier>;
if len(ident_children) == 2:
alias_name = self._get_node_text(ident_children[0], source)
target_name = self._get_node_text(ident_children[1], source)
# Only record when the target is a known function name.
if alias_name and target_name in name_to_ids:
aliases[alias_name] = target_name

for child in node.children:
self._collect_aliases_from_node(child, source, name_to_ids, aliases)

def _find_calls_in_code(self, code: str, caller_file: str = "") -> Set[str]:
"""Find all function calls in a code snippet."""
calls = set()

Expand All @@ -230,11 +286,32 @@ def _find_calls_in_code(self, code: str) -> Set[str]:
# Fallback to regex-based extraction
calls = self._find_calls_with_regex(code)

# Filter out builtins
calls = {c for c in calls if c not in self.ZIG_BUILTINS and not c.startswith("@")}
# Filter out builtins, but NEVER filter a name that a same-file user
# function actually defines. A user fn whose name collides with a
# ZIG_BUILTINS entry (e.g. `expect`) must keep its edge. Scope the
# shadow check to the caller's own file so a builtin call is not
# spuriously linked to an unrelated same-named user fn elsewhere.
shadowing = self._same_file_function_names(caller_file)
calls = {
c
for c in calls
if c in shadowing or (c not in self.ZIG_BUILTINS and not c.startswith("@"))
}

return calls

def _same_file_function_names(self, caller_file: str) -> Set[str]:
"""Names of user functions defined in `caller_file` (same-file scope)."""
if not caller_file:
return set()
names: Set[str] = set()
for func_info in self.functions.values():
if func_info.get("file_path") == caller_file:
name = func_info.get("name", "")
if name:
names.add(name)
return names

def _extract_calls_from_node(
self, node: Node, source: bytes, calls: Set[str]
) -> None:
Expand All @@ -243,7 +320,25 @@ def _extract_calls_from_node(
if node.type in ("call_expr", "call_expression", "CallExpr"):
# Get the function being called
for child in node.children:
if child.type in ("identifier", "IDENTIFIER", "field_access"):
if child.type in (
"identifier",
"IDENTIFIER",
"field_access",
"field_expression",
):
# For a field access / expression (e.g. `o.m` or `C{}.m`),
# the method name is the trailing identifier child. Prefer
# that over text-splitting, which is brittle when the
# receiver itself contains punctuation (e.g. `C{}.m`).
if child.type in ("field_access", "field_expression"):
method_name = None
for sub in child.children:
if sub.type in ("identifier", "IDENTIFIER"):
method_name = self._get_node_text(sub, source)
if method_name:
calls.add(method_name) # the method name
calls.add(self._get_node_text(child, source)) # full path
break
call_name = self._get_node_text(child, source)
# Handle method calls (obj.method)
if "." in call_name:
Expand Down Expand Up @@ -286,6 +381,7 @@ def _resolve_call(
call_name: str,
caller_file: str,
name_to_ids: Dict[str, List[str]],
alias_to_target: Dict[str, Dict[str, str]] | None = None,
) -> List[str]:
"""
Resolve a call name to function ID(s).
Expand All @@ -295,6 +391,13 @@ def _resolve_call(
2. Imported files
3. Unique name match
"""
# Resolve a same-file const fn-alias (`const f = handler; f()`) to its
# target function name before looking up candidates.
if alias_to_target is not None:
target = alias_to_target.get(caller_file, {}).get(call_name)
if target is not None:
call_name = target

candidates = name_to_ids.get(call_name, [])

if not candidates:
Expand Down
63 changes: 49 additions & 14 deletions libs/openant-core/parsers/zig/function_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,22 @@ class FunctionExtractor:

ZIG_LANGUAGE = Language(ts_zig.language())

# Real tree-sitter-zig node types for container declarations. The container
# body (struct/enum/union/opaque) is a child of a `variable_declaration`
# (`const Foo = struct {...}`). Legacy names (ContainerDecl/VarDecl) are kept
# for forward/back compatibility with other grammar revisions.
_CONTAINER_BODY_TYPES = frozenset(
{
"struct_declaration",
"enum_declaration",
"union_declaration",
"opaque_declaration",
"container_decl",
"ContainerDecl",
}
)
_VAR_DECL_TYPES = frozenset({"variable_declaration", "VarDecl"})

def __init__(self, repo_path: str, scan_results: Dict[str, Any]):
self.repo_path = Path(repo_path).resolve()
self.scan_results = scan_results
Expand Down Expand Up @@ -103,19 +119,30 @@ def _walk_node(
func_id = f"{file_path}:{func_info['qualified_name']}"
functions[func_id] = func_info

elif node.type == "VarDecl":
# Check if this is a struct/enum definition
elif node.type in self._VAR_DECL_TYPES:
# Check if this is a struct/enum/union/opaque container definition.
struct_info = self._extract_struct_from_var_decl(node, source, file_path)
if struct_info:
struct_id = f"{file_path}:{struct_info['name']}"
structs[struct_id] = struct_info
# Extract methods within the struct
self._extract_struct_methods(
node, source, file_path, struct_info["name"], functions
)

elif node.type == "container_decl" or node.type == "ContainerDecl":
# Direct struct/enum declarations
# Recurse into the container body with the struct name as
# context, so member functions are qualified `Foo.method`
# rather than being re-emitted as bare `method` by the generic
# recursion below.
for child in node.children:
self._walk_node(
child,
source,
file_path,
functions,
structs,
imports,
struct_info["name"],
)
return

elif node.type in self._CONTAINER_BODY_TYPES:
# Direct struct/enum declarations (anonymous container).
struct_info = self._extract_container(node, source, file_path)
if struct_info:
struct_id = f"{file_path}:{struct_info['name']}"
Expand Down Expand Up @@ -145,7 +172,11 @@ def _extract_function(

for child in node.children:
if child.type == "identifier" or child.type == "IDENTIFIER":
name = self._get_node_text(child, source)
# The FIRST identifier is the function name. A later identifier
# child is the return type (e.g. `fn makeWidget() Widget`) and
# must not overwrite the name.
if name is None:
name = self._get_node_text(child, source)
elif child.type == "parameters" or child.type == "ParamDeclList":
parameters = self._extract_parameters(child, source)

Expand Down Expand Up @@ -196,8 +227,9 @@ def _extract_struct_from_var_decl(

for child in node.children:
if child.type == "identifier" or child.type == "IDENTIFIER":
name = self._get_node_text(child, source)
elif child.type == "container_decl" or child.type == "ContainerDecl":
if name is None:
name = self._get_node_text(child, source)
elif child.type in self._CONTAINER_BODY_TYPES:
is_struct = True

if name and is_struct:
Expand Down Expand Up @@ -261,8 +293,11 @@ def _classify_function(self, name: str, file_path: str) -> str:
"""Classify the function type based on name and context."""
name_lower = name.lower()

# Test functions
if name_lower.startswith("test") or "_test" in name_lower:
# Test functions. Anchor on the underscore-delimited test convention
# (`test_foo`, `foo_test`, or a bare `test`). A camelCase identifier
# that merely starts with "test" (e.g. `testConnection`) is an ordinary
# function, not a zig `test "..." {}` block.
if name_lower == "test" or name_lower.startswith("test_") or name_lower.endswith("_test"):
return "test"

# Init/constructor patterns
Expand Down
Loading
Loading