Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 37 additions & 7 deletions pineforge_codegen/analyzer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def __init__(self, ast: Program, filename: str = "<stdin>") -> None:
# Probe: data/validation/udt-method-probe-20-udt-return-from-func.
self._func_udt_return_types: dict[str, str] = {}
self._func_return_type_specs: dict[str, "TypeSpec"] = {}
self._func_param_type_specs: dict[str, list] = {}
# Per-function var_members and series_vars (for call-site cloning)
self._func_var_members: dict[str, list] = {} # func_name -> [(name, PineType, init_str)]
self._func_series_vars: dict[str, set] = {} # func_name -> set[str]
Expand Down Expand Up @@ -1184,6 +1185,16 @@ def _visit_FuncDef(self, node: FuncDef) -> PineType:
if hi > lo:
self._func_ta_ranges[node.name] = (lo, hi)

inferred_param_specs = self._param_type_specs_from_def(node)
for i, param in enumerate(node.params):
if i < len(inferred_param_specs) and inferred_param_specs[i] is not None:
continue
sym = self._symbols.resolve(param)
spec = getattr(sym, "type_spec", None) if sym is not None else None
if spec is not None and i < len(inferred_param_specs):
inferred_param_specs[i] = spec
self._func_param_type_specs[node.name] = inferred_param_specs

self._symbols.exit_scope()

# Detect if function returns a tuple (last stmt is TupleLiteral)
Expand Down Expand Up @@ -1511,6 +1522,20 @@ def _visit_BinOp(self, node: BinOp) -> PineType:

# String concatenation: if either side is STRING, result is STRING
if left_type == PineType.STRING or right_type == PineType.STRING:
def _mark_string_param(expr) -> None:
if not isinstance(expr, Identifier):
return
sym = self._symbols.resolve(expr.name)
if sym is None or not (sym.scope or "").startswith("func_"):
return
if sym.pine_type == PineType.UNKNOWN:
sym.pine_type = PineType.STRING
sym.type_spec = TypeSpec.primitive("string")

if left_type == PineType.STRING:
_mark_string_param(node.right)
if right_type == PineType.STRING:
_mark_string_param(node.left)
return PineType.STRING

# Arithmetic: promote to FLOAT if either side is FLOAT
Expand Down Expand Up @@ -1583,11 +1608,14 @@ def _visit_FuncCall(self, node: FuncCall) -> PineType:
if isinstance(obj, Identifier) and obj.name == "str":
for arg in node.args:
self._visit(arg)
# Most str.* return a string, but a few don't:
# str.tonumber -> float, str.length -> int
# Most str.* return a string, but predicates and index helpers
# are scalar. Keep this aligned with signatures.py and the C++
# emitter's _infer_type path.
if member in ("contains", "startswith", "endswith"):
return PineType.BOOL
if member == "tonumber":
return PineType.FLOAT
if member == "length":
if member in ("length", "pos"):
return PineType.INT
return PineType.STRING

Expand Down Expand Up @@ -1812,9 +1840,11 @@ def _visit_MemberAccess(self, node: MemberAccess) -> PineType:

# syminfo.*
if ns == "syminfo":
if node.member == "mintick":
return PineType.FLOAT
return PineType.STRING
from .. import signatures as _pf_sigs
return _pf_sigs.SYMINFO_VARIABLES.get(
f"syminfo.{node.member}",
PineType.STRING,
)

# color.* constants
if ns == "color":
Expand Down Expand Up @@ -1873,7 +1903,7 @@ def _visit_MemberAccess(self, node: MemberAccess) -> PineType:

# text.* constants (align_left, align_right, etc.)
if ns == "text":
return PineType.INT
return PineType.STRING

# extend.* constants (left, right, both, none)
if ns == "extend":
Expand Down
51 changes: 46 additions & 5 deletions pineforge_codegen/analyzer/call_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
from .contracts import FixnanCallSite, FuncInfo, SecurityCallInfo, TACallSite
from .tables import (
BAR_FIELDS, TA_CLASS_MAP, TA_MULTI_CTOR, TA_NO_CTOR, TA_PERIOD_ARG,
TA_TUPLE_RETURNS,
TA_TUPLE_RETURNS, TA_TUPLE_ELEMENT_COUNTS, TA_COMPUTE_ARGS,
)


Expand Down Expand Up @@ -235,9 +235,14 @@ def _handle_ta_call(self, func_name: str, node: FuncCall) -> PineType:
elif func_name in TA_PERIOD_ARG:
ctor_indices = {TA_PERIOD_ARG[func_name]}

for i, arg in enumerate(all_args):
if i not in ctor_indices and arg is not None:
compute_args.append(arg)
if func_name in TA_COMPUTE_ARGS:
for i in TA_COMPUTE_ARGS[func_name]:
if i < len(all_args) and all_args[i] is not None:
compute_args.append(all_args[i])
else:
for i, arg in enumerate(all_args):
if i not in ctor_indices and arg is not None:
compute_args.append(arg)

is_static = self._global_scope and all(self._is_static_expression(arg) for arg in compute_args)
site = TACallSite(
Expand Down Expand Up @@ -309,6 +314,27 @@ def _handle_request_call(self, func_name: str, node: FuncCall) -> PineType:

returns_tuple = isinstance(expr_node, TupleLiteral)
tuple_size = len(expr_node.elements) if returns_tuple else 0
if not returns_tuple and isinstance(expr_node, FuncCall):
expr_func = None
expr_ns = None
if (isinstance(expr_node.callee, MemberAccess)
and isinstance(expr_node.callee.object, Identifier)):
expr_ns = expr_node.callee.object.name
expr_func = expr_node.callee.member
if expr_ns == "ta":
if expr_func == "vwap":
merged_v = list(expr_node.args)
for i, pname in enumerate(["source", "anchor", "stdev_mult"]):
if pname in expr_node.kwargs:
while len(merged_v) <= i:
merged_v.append(None)
if merged_v[i] is None:
merged_v[i] = expr_node.kwargs[pname]
if len(merged_v) >= 3:
expr_func = "vwap_bands"
if expr_func in TA_TUPLE_RETURNS:
returns_tuple = True
tuple_size = TA_TUPLE_ELEMENT_COUNTS.get(expr_func, 0)

gaps_node = all_args[3] if len(all_args) > 3 else None
lookahead_node = all_args[4] if len(all_args) > 4 else None
Expand Down Expand Up @@ -899,6 +925,18 @@ def _handle_user_func_call(self, func_name: str, node: FuncCall) -> PineType:
arg = node.args[p_idx]
if isinstance(arg, Identifier) and arg.name in BAR_FIELDS:
self._series_bar_fields.add(arg.name)
elif isinstance(arg, Identifier):
sym = self._symbols.resolve(arg.name)
spec = getattr(sym, "type_spec", None) if sym is not None else None
if spec is not None and spec.kind in ("array", "map", "matrix"):
continue
if sym is not None:
sym.is_series = True
if sym.scope and sym.scope.startswith("func_"):
caller_name = sym.scope[5:]
self._func_series_vars.setdefault(caller_name, set()).add(arg.name)
else:
self._series_vars.add(arg.name)

# Per-call-site cloning: if this function has TA calls or series vars,
# track call sites so codegen can create per-call-site variants.
Expand Down Expand Up @@ -1041,7 +1079,10 @@ def _rep(m: re.Match) -> str:
# Per-param TypeSpec: declared hints are authoritative; for untyped
# params, infer from the call-site argument's type_spec (so an untyped
# ``s`` used as a string, or a UDT passed by value, emits correctly).
param_specs = self._param_type_specs_from_def(func_def)
param_specs = list(
getattr(self, "_func_param_type_specs", {}).get(func_name)
or self._param_type_specs_from_def(func_def)
)
arg_specs = [self._type_spec_from_expr(arg) for arg in node.args]
for i in range(len(param_specs)):
if param_specs[i] is None and i < len(arg_specs):
Expand Down
22 changes: 20 additions & 2 deletions pineforge_codegen/analyzer/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@
TA_PERIOD_ARG = {
"sma": 1, "ema": 1, "rma": 1, "rsi": 1, "atr": 0,
"highest": 1, "lowest": 1, "change": 1,
"wma": 1, "hma": 1, "stdev": 1,
"wma": 1, "hma": 1,
# Task 6
"sum": 1,
# Task 7 Batch 1
Expand All @@ -119,7 +119,7 @@
"mom": 1, "roc": 1, "rising": 1, "falling": 1, "cci": 1,
# cum has no period arg — handled in TA_NO_CTOR
# Task 7 Batch 3
"variance": 1, "median": 1, "highestbars": 1, "lowestbars": 1,
"median": 1, "highestbars": 1, "lowestbars": 1,
# Batch 4
"cmo": 1, "cog": 1, "correlation": 2,
"percentile_nearest_rank": 1, "percentile_linear_interpolation": 1,
Expand All @@ -130,6 +130,14 @@

# Functions that return tuples
TA_TUPLE_RETURNS = {"macd", "supertrend", "dmi", "bb", "kc", "vwap_bands"}
TA_TUPLE_ELEMENT_COUNTS = {
"macd": 3,
"supertrend": 2,
"dmi": 3,
"bb": 3,
"kc": 3,
"vwap_bands": 3,
}

# Functions with multiple constructor args
TA_MULTI_CTOR = {
Expand All @@ -156,6 +164,16 @@
"bbw": [1, 2], # length, mult
"kcw": [1, 2], # length, mult
"tr": [0], # handle_na (compile-time bool)
"stdev": [1, 2], # length, biased
"variance": [1, 2], # length, biased
}

# Compute-arg indices: which positional args are forwarded to ``.compute()``.
# Entries here override the default analyzer behavior of forwarding every
# non-constructor argument.
TA_COMPUTE_ARGS = {
"stdev": [0],
"variance": [0],
}

# No-state functions (no constructor args, stateless or self-contained)
Expand Down
Loading
Loading