diff --git a/pineforge_codegen/analyzer/base.py b/pineforge_codegen/analyzer/base.py index b7e4d78..0befe92 100644 --- a/pineforge_codegen/analyzer/base.py +++ b/pineforge_codegen/analyzer/base.py @@ -156,6 +156,7 @@ def __init__(self, ast: Program, filename: str = "") -> None: # Probe: data/validation/udt-method-probe-20-udt-return-from-func. self._func_udt_return_types: dict[str, str] = {} self._func_return_type_specs: dict[str, "TypeSpec"] = {} + self._func_param_type_specs: dict[str, list] = {} # Per-function var_members and series_vars (for call-site cloning) self._func_var_members: dict[str, list] = {} # func_name -> [(name, PineType, init_str)] self._func_series_vars: dict[str, set] = {} # func_name -> set[str] @@ -1184,6 +1185,16 @@ def _visit_FuncDef(self, node: FuncDef) -> PineType: if hi > lo: self._func_ta_ranges[node.name] = (lo, hi) + inferred_param_specs = self._param_type_specs_from_def(node) + for i, param in enumerate(node.params): + if i < len(inferred_param_specs) and inferred_param_specs[i] is not None: + continue + sym = self._symbols.resolve(param) + spec = getattr(sym, "type_spec", None) if sym is not None else None + if spec is not None and i < len(inferred_param_specs): + inferred_param_specs[i] = spec + self._func_param_type_specs[node.name] = inferred_param_specs + self._symbols.exit_scope() # Detect if function returns a tuple (last stmt is TupleLiteral) @@ -1511,6 +1522,20 @@ def _visit_BinOp(self, node: BinOp) -> PineType: # String concatenation: if either side is STRING, result is STRING if left_type == PineType.STRING or right_type == PineType.STRING: + def _mark_string_param(expr) -> None: + if not isinstance(expr, Identifier): + return + sym = self._symbols.resolve(expr.name) + if sym is None or not (sym.scope or "").startswith("func_"): + return + if sym.pine_type == PineType.UNKNOWN: + sym.pine_type = PineType.STRING + sym.type_spec = TypeSpec.primitive("string") + + if left_type == PineType.STRING: + _mark_string_param(node.right) + if right_type == PineType.STRING: + _mark_string_param(node.left) return PineType.STRING # Arithmetic: promote to FLOAT if either side is FLOAT @@ -1583,11 +1608,14 @@ def _visit_FuncCall(self, node: FuncCall) -> PineType: if isinstance(obj, Identifier) and obj.name == "str": for arg in node.args: self._visit(arg) - # Most str.* return a string, but a few don't: - # str.tonumber -> float, str.length -> int + # Most str.* return a string, but predicates and index helpers + # are scalar. Keep this aligned with signatures.py and the C++ + # emitter's _infer_type path. + if member in ("contains", "startswith", "endswith"): + return PineType.BOOL if member == "tonumber": return PineType.FLOAT - if member == "length": + if member in ("length", "pos"): return PineType.INT return PineType.STRING @@ -1812,9 +1840,11 @@ def _visit_MemberAccess(self, node: MemberAccess) -> PineType: # syminfo.* if ns == "syminfo": - if node.member == "mintick": - return PineType.FLOAT - return PineType.STRING + from .. import signatures as _pf_sigs + return _pf_sigs.SYMINFO_VARIABLES.get( + f"syminfo.{node.member}", + PineType.STRING, + ) # color.* constants if ns == "color": @@ -1873,7 +1903,7 @@ def _visit_MemberAccess(self, node: MemberAccess) -> PineType: # text.* constants (align_left, align_right, etc.) if ns == "text": - return PineType.INT + return PineType.STRING # extend.* constants (left, right, both, none) if ns == "extend": diff --git a/pineforge_codegen/analyzer/call_handlers.py b/pineforge_codegen/analyzer/call_handlers.py index 552bcbc..11ef1c0 100644 --- a/pineforge_codegen/analyzer/call_handlers.py +++ b/pineforge_codegen/analyzer/call_handlers.py @@ -75,7 +75,7 @@ from .contracts import FixnanCallSite, FuncInfo, SecurityCallInfo, TACallSite from .tables import ( BAR_FIELDS, TA_CLASS_MAP, TA_MULTI_CTOR, TA_NO_CTOR, TA_PERIOD_ARG, - TA_TUPLE_RETURNS, + TA_TUPLE_RETURNS, TA_TUPLE_ELEMENT_COUNTS, TA_COMPUTE_ARGS, ) @@ -235,9 +235,14 @@ def _handle_ta_call(self, func_name: str, node: FuncCall) -> PineType: elif func_name in TA_PERIOD_ARG: ctor_indices = {TA_PERIOD_ARG[func_name]} - for i, arg in enumerate(all_args): - if i not in ctor_indices and arg is not None: - compute_args.append(arg) + if func_name in TA_COMPUTE_ARGS: + for i in TA_COMPUTE_ARGS[func_name]: + if i < len(all_args) and all_args[i] is not None: + compute_args.append(all_args[i]) + else: + for i, arg in enumerate(all_args): + if i not in ctor_indices and arg is not None: + compute_args.append(arg) is_static = self._global_scope and all(self._is_static_expression(arg) for arg in compute_args) site = TACallSite( @@ -309,6 +314,27 @@ def _handle_request_call(self, func_name: str, node: FuncCall) -> PineType: returns_tuple = isinstance(expr_node, TupleLiteral) tuple_size = len(expr_node.elements) if returns_tuple else 0 + if not returns_tuple and isinstance(expr_node, FuncCall): + expr_func = None + expr_ns = None + if (isinstance(expr_node.callee, MemberAccess) + and isinstance(expr_node.callee.object, Identifier)): + expr_ns = expr_node.callee.object.name + expr_func = expr_node.callee.member + if expr_ns == "ta": + if expr_func == "vwap": + merged_v = list(expr_node.args) + for i, pname in enumerate(["source", "anchor", "stdev_mult"]): + if pname in expr_node.kwargs: + while len(merged_v) <= i: + merged_v.append(None) + if merged_v[i] is None: + merged_v[i] = expr_node.kwargs[pname] + if len(merged_v) >= 3: + expr_func = "vwap_bands" + if expr_func in TA_TUPLE_RETURNS: + returns_tuple = True + tuple_size = TA_TUPLE_ELEMENT_COUNTS.get(expr_func, 0) gaps_node = all_args[3] if len(all_args) > 3 else None lookahead_node = all_args[4] if len(all_args) > 4 else None @@ -899,6 +925,18 @@ def _handle_user_func_call(self, func_name: str, node: FuncCall) -> PineType: arg = node.args[p_idx] if isinstance(arg, Identifier) and arg.name in BAR_FIELDS: self._series_bar_fields.add(arg.name) + elif isinstance(arg, Identifier): + sym = self._symbols.resolve(arg.name) + spec = getattr(sym, "type_spec", None) if sym is not None else None + if spec is not None and spec.kind in ("array", "map", "matrix"): + continue + if sym is not None: + sym.is_series = True + if sym.scope and sym.scope.startswith("func_"): + caller_name = sym.scope[5:] + self._func_series_vars.setdefault(caller_name, set()).add(arg.name) + else: + self._series_vars.add(arg.name) # Per-call-site cloning: if this function has TA calls or series vars, # track call sites so codegen can create per-call-site variants. @@ -1041,7 +1079,10 @@ def _rep(m: re.Match) -> str: # Per-param TypeSpec: declared hints are authoritative; for untyped # params, infer from the call-site argument's type_spec (so an untyped # ``s`` used as a string, or a UDT passed by value, emits correctly). - param_specs = self._param_type_specs_from_def(func_def) + param_specs = list( + getattr(self, "_func_param_type_specs", {}).get(func_name) + or self._param_type_specs_from_def(func_def) + ) arg_specs = [self._type_spec_from_expr(arg) for arg in node.args] for i in range(len(param_specs)): if param_specs[i] is None and i < len(arg_specs): diff --git a/pineforge_codegen/analyzer/tables.py b/pineforge_codegen/analyzer/tables.py index 558f836..2d4faf1 100644 --- a/pineforge_codegen/analyzer/tables.py +++ b/pineforge_codegen/analyzer/tables.py @@ -110,7 +110,7 @@ TA_PERIOD_ARG = { "sma": 1, "ema": 1, "rma": 1, "rsi": 1, "atr": 0, "highest": 1, "lowest": 1, "change": 1, - "wma": 1, "hma": 1, "stdev": 1, + "wma": 1, "hma": 1, # Task 6 "sum": 1, # Task 7 Batch 1 @@ -119,7 +119,7 @@ "mom": 1, "roc": 1, "rising": 1, "falling": 1, "cci": 1, # cum has no period arg — handled in TA_NO_CTOR # Task 7 Batch 3 - "variance": 1, "median": 1, "highestbars": 1, "lowestbars": 1, + "median": 1, "highestbars": 1, "lowestbars": 1, # Batch 4 "cmo": 1, "cog": 1, "correlation": 2, "percentile_nearest_rank": 1, "percentile_linear_interpolation": 1, @@ -130,6 +130,14 @@ # Functions that return tuples TA_TUPLE_RETURNS = {"macd", "supertrend", "dmi", "bb", "kc", "vwap_bands"} +TA_TUPLE_ELEMENT_COUNTS = { + "macd": 3, + "supertrend": 2, + "dmi": 3, + "bb": 3, + "kc": 3, + "vwap_bands": 3, +} # Functions with multiple constructor args TA_MULTI_CTOR = { @@ -156,6 +164,16 @@ "bbw": [1, 2], # length, mult "kcw": [1, 2], # length, mult "tr": [0], # handle_na (compile-time bool) + "stdev": [1, 2], # length, biased + "variance": [1, 2], # length, biased +} + +# Compute-arg indices: which positional args are forwarded to ``.compute()``. +# Entries here override the default analyzer behavior of forwarding every +# non-constructor argument. +TA_COMPUTE_ARGS = { + "stdev": [0], + "variance": [0], } # No-state functions (no constructor args, stateless or self-contained) diff --git a/pineforge_codegen/codegen/base.py b/pineforge_codegen/codegen/base.py index 6558728..7561308 100644 --- a/pineforge_codegen/codegen/base.py +++ b/pineforge_codegen/codegen/base.py @@ -53,6 +53,7 @@ SKIP_VAR_TYPES, SYMINFO_MEMBER_MAP, COLOR_CONST_MAP, + ARRAY_NEW_CTORS, ARRAY_METHODS, MAP_METHODS, MATRIX_METHODS, @@ -65,6 +66,15 @@ _merge_kwargs, ) +TA_TUPLE_RESULT_TYPES = { + "macd": "ta::MACDResult", + "supertrend": "ta::SupertrendResult", + "dmi": "ta::DMIResult", + "bb": "ta::BBResult", + "kc": "ta::KCResult", + "vwap_bands": "ta::VWAPBandsResult", +} + # (TA_IMPLICIT_COMPUTE / TA_COMPUTE_ARGS now imported from .tables above.) # (TA_IMPLICIT_COMPUTE_FULL / TA_IMPLICIT_APPEND / PINE_TYPE_TO_CPP / @@ -424,12 +434,15 @@ def __init__(self, ctx: AnalyzerContext) -> None: self._security_calls: list[dict] = [self._normalize_security_call(item) for item in ctx.security_calls] # Current function parameter types (set during _emit_func_def) self._current_func_param_types: dict[str, str] = {} + self._current_func_param_specs: dict[str, "TypeSpec"] = {} # Current function params that are series (const Series&) self._current_func_series_params: set[str] = set() # Locals declared in the function currently being emitted (symbol table loses them after analysis) self._current_func_locals: set[str] = set() + self._current_func_local_types: dict[str, str] = {} # for-in loop iterator names (must resolve member access, not enum fallback) self._current_loop_vars: set[str] = set() + self._current_loop_var_specs: dict[str, "TypeSpec"] = {} # Track array variables for codegen self._array_vars: set[str] = set() # Track map variables for codegen @@ -854,31 +867,33 @@ def _register_global_aggregate_member_types(self) -> None: spec = self._matrix_specs.get(recv_name) or TypeSpec.matrix(TypeSpec.primitive("float")) self._matrix_specs[name] = spec self._collection_types[name] = spec - elif ns == "array" and fn in ( - "new", - "new_float", - "new_int", - "new_bool", - "new_string", - "from", - ): + elif ns == "array" and fn in ({"new", "from"} | set(ARRAY_NEW_CTORS)): self._array_vars.add(name) + spec = self._type_spec_from_expr(expr) or self._array_spec_for_name(name) + self._collection_types[name] = spec elif ns == "map" and fn == "new": self._map_vars.add(name) - # Also register var/varip matrix members from AST nodes so that - # the typed-matrix gate checks see the correct element spec. + # Also register var/varip aggregate members from AST nodes so that + # class-member declarations see the precise collection type before + # on_bar emits the initializer. This is required for unannotated + # drawing arrays such as ``var boxes = array.new_box()``. var_decl_map: dict[str, FuncCall] = {} for stmt in (self.ctx.ast.body if hasattr(self.ctx, "ast") else []): if isinstance(stmt, VarDecl) and isinstance(stmt.value, FuncCall): var_decl_map[stmt.name] = stmt.value for name, _ptype, _init_str in self.ctx.var_members: - if name in self._matrix_specs: - continue expr = var_decl_map.get(name) if expr is None: continue fn2, ns2 = self._resolve_callee(expr.callee) + if ns2 == "array" and fn2 in ({"new", "from"} | set(ARRAY_NEW_CTORS)): + self._array_vars.add(name) + spec2 = self._type_spec_from_expr(expr) or self._array_spec_for_name(name) + self._collection_types[name] = spec2 + continue + if name in self._matrix_specs: + continue if ns2 == "matrix" and fn2 == "new": targs2 = self._template_args_from_call(expr) if hasattr(expr, "annotations") else [] elem_spec2 = self._type_spec_from_hint_name(targs2[0]) if targs2 else TypeSpec.primitive("float") @@ -1226,6 +1241,9 @@ def generate(self) -> str: # request.security TA call-sites read at a history offset (``ta.ema(...)[k>=1]``). # Maps sec_id -> set of TA call-site indices needing an HTF history Series. self._security_ta_hist_idx_by_sec: dict[int, set[int]] = {} + # request.security helper-call results read at a history offset + # (``myHelper()[k]``). Maps (sec_id, node-id) -> backing Series metadata. + self._security_expr_hist_by_node: dict[tuple[int, int], dict] = {} lines: list[str] = [] @@ -1318,19 +1336,26 @@ def generate(self) -> str: for field in sorted( self._security_ohlc_hist_fields_by_sec.get(sec_id, ()) ): + ctype = self._security_bar_hist_type(field) lines.append( - f" Series {self._security_ohlc_hist_series_cpp(sec_id, field)}{_mbb};" + f" Series<{ctype}> {self._security_ohlc_hist_series_cpp(sec_id, field)}{_mbb};" ) self._security_ta_hist_idx_by_sec[sec_id] = ( self._collect_security_ta_hist_indices(expr_node) ) for name in self._security_ta_hist_series_names(sec_id): lines.append(f" Series {name}{_mbb};") + self._emit_security_expr_hist_members(sec_id, expr_node, lines, _mbb) continue if returns_tuple and tuple_size and tuple_size > 0 and isinstance(expr_node, TupleLiteral): hist_fields: set[str] = set() for el in expr_node.elements: hist_fields |= self._collect_security_ohlc_hist_fields(el) + for name in item.get("mutable_globals", []) or []: + info = self._global_mutable_infos.get(name) + if info is not None: + for stmt in getattr(info, "source_stmts", []) or []: + hist_fields |= self._collect_security_ohlc_hist_fields(stmt) self._security_ohlc_hist_fields_by_sec[sec_id] = hist_fields for i, el in enumerate(expr_node.elements): ctype = self._infer_cpp_type_for_security_elem(el) @@ -1338,20 +1363,31 @@ def generate(self) -> str: lines.append(f" {ctype} _req_sec_{sec_id}_{i}{{}};") else: lines.append(f" {ctype} _req_sec_{sec_id}_{i} = na();") + elif returns_tuple and tuple_size and tuple_size > 0: + self._security_ohlc_hist_fields_by_sec[sec_id] = ( + self._collect_security_ohlc_hist_fields_for_call(item) + ) + site = self._get_ta_site(expr_node) + ta_name = self._ta_name_from_site(site) if site is not None else "" + ctype = TA_TUPLE_RESULT_TYPES.get(ta_name, "std::tuple") + default = self._security_tuple_result_default(ctype, tuple_size) + lines.append(f" {ctype} _req_sec_{sec_id} = {default};") else: - self._security_ohlc_hist_fields_by_sec[sec_id] = self._collect_security_ohlc_hist_fields( - expr_node + self._security_ohlc_hist_fields_by_sec[sec_id] = ( + self._collect_security_ohlc_hist_fields_for_call(item) ) lines.append(f" double _req_sec_{sec_id} = na();") for field in sorted(self._security_ohlc_hist_fields_by_sec.get(sec_id, ())): + ctype = self._security_bar_hist_type(field) lines.append( - f" Series {self._security_ohlc_hist_series_cpp(sec_id, field)}{_mbb};" + f" Series<{ctype}> {self._security_ohlc_hist_series_cpp(sec_id, field)}{_mbb};" ) self._security_ta_hist_idx_by_sec[sec_id] = ( self._collect_security_ta_hist_indices(expr_node) ) for name in self._security_ta_hist_series_names(sec_id): lines.append(f" Series {name}{_mbb};") + self._emit_security_expr_hist_members(sec_id, expr_node, lines, _mbb) if self._security_calls: lines.append(' std::unordered_map> _security_helper_series_;') @@ -1377,7 +1413,7 @@ def generate(self) -> str: # 3. TA members for site in self.ctx.ta_call_sites: lines.append(f" {site.class_name} {site.member_name};") - if getattr(site, "is_static", False): + if self._ta_site_uses_precalc(site): vtype = self._ta_return_type(site) lines.append(f" std::vector<{vtype}> _precalc_{site.member_name};") lines.append(" bool _use_precalc = false;") @@ -1762,7 +1798,7 @@ def _is_skip_expr(self, node) -> bool: if self._is_chart_point_callee(node.callee): return False func_name, namespace = self._resolve_callee(node.callee) - if func_name in SKIP_FUNC_NAMES: + if namespace is None and func_name in SKIP_FUNC_NAMES: return True if namespace in SKIP_NAMESPACES: return True diff --git a/pineforge_codegen/codegen/drawing.py b/pineforge_codegen/codegen/drawing.py index 474deb2..6c9c19b 100644 --- a/pineforge_codegen/codegen/drawing.py +++ b/pineforge_codegen/codegen/drawing.py @@ -255,7 +255,7 @@ def _emit_chart_point(self, func_name: str, node: FuncCall) -> str: vals = self._merge_drawing_args(node, ["price"]) price = (self._visit_expr(vals["price"]) if vals.get("price") is not None else "current_bar_.close") - return (f"ChartPoint{{ .index=bar_index_, " + return (f"ChartPoint{{ .index=(int64_t)(pine_bar_index()), " f".time=(int64_t)current_bar_.timestamp, .price=({price}) }}") if func_name == "from_index": vals = self._merge_drawing_args(node, ["index", "price"]) diff --git a/pineforge_codegen/codegen/emit_top.py b/pineforge_codegen/codegen/emit_top.py index f3a3e99..b34f620 100644 --- a/pineforge_codegen/codegen/emit_top.py +++ b/pineforge_codegen/codegen/emit_top.py @@ -470,7 +470,7 @@ def _emit_on_bar(self, lines: list[str]) -> None: if _bname in self._var_names: continue _bexpr = BAR_BUILTINS.get(_bname) - if _bexpr is None or f"{_bname}(" in _bexpr: + if _bexpr is None or _bexpr.strip().startswith(f"{_bname}("): continue _bsafe = self._safe_name(_bname) lines.append(f" if (is_first_tick_) {_bsafe}.push({_bexpr});") @@ -776,10 +776,12 @@ def _emit_func_def(self, fi: FuncInfo, lines: list[str], call_site_idx: int | No # Determine param types and set context for type inference inside body param_strs = [] self._current_func_param_types = {} + self._current_func_param_specs = {} self._current_func_series_params = set() self._udt_param_udt = {} func_sv = self.ctx.func_series_vars.get(fi.name, set()) for i, p in enumerate(node.params): + spec = None if is_udt and i == 0 and fi.udt_type_name: # A method receiver whose type is a drawing primitive # (egoigor's ``method slope(line ln)``) must emit ``Line&`` not @@ -821,6 +823,9 @@ def _emit_func_def(self, fi: FuncInfo, lines: list[str], call_site_idx: int | No cpp_t = "double" param_strs.append(f"{cpp_t} {self._safe_name(p)}") self._current_func_param_types[p] = cpp_t + if spec is not None: + self._current_func_param_specs[p] = spec + self._current_func_param_specs[self._safe_name(p)] = spec # Determine return type: tuple, UDT, or scalar. # The UDT branch handles user functions whose body is ``T.new(...)``; @@ -873,6 +878,7 @@ def _emit_func_def(self, fi: FuncInfo, lines: list[str], call_site_idx: int | No self._current_instance_name = None prev_func_locals = self._current_func_locals + prev_func_local_types = self._current_func_local_types prev_func_body = getattr(self, "_current_func_body", None) prev_func_name = getattr(self, "_active_func_name", None) # The function body is the lexical scope used by the UDT-alias analysis @@ -886,6 +892,7 @@ def _emit_func_def(self, fi: FuncInfo, lines: list[str], call_site_idx: int | No prev_ptr_alias = self._udt_ptr_alias_locals self._udt_ptr_alias_locals = set() self._current_func_locals = {n for n, _, _ in self.ctx.func_var_members.get(fi.name, [])} + self._current_func_local_types = {} # Plain (non-persistent) scalar locals are emitted inline and live in # no other set; collect them so the unknown-identifier guard in # _visit_ident does not mistake them for undeclared symbols. @@ -956,9 +963,11 @@ def _emit_func_def(self, fi: FuncInfo, lines: list[str], call_site_idx: int | No lines.append(" }") self._current_func_param_types = {} + self._current_func_param_specs = {} self._current_func_series_params = set() self._udt_param_udt = {} self._current_func_locals = prev_func_locals + self._current_func_local_types = prev_func_local_types self._current_func_body = prev_func_body self._active_func_name = prev_func_name self._udt_ptr_alias_locals = prev_ptr_alias @@ -1031,10 +1040,22 @@ class member but its initializer was dropped, leaving the member ``na`` lines.append(" }") def _emit_precalculate_and_run(self, lines: list[str]) -> None: - has_static_ta = any(getattr(site, "is_static", False) for site in self.ctx.ta_call_sites) + has_static_ta = any(self._ta_site_uses_precalc(site) for site in self.ctx.ta_call_sites) if not has_static_ta: return + replayed_source_series: list[str] = [] + for stmt in self.ctx.ast.body: + if not isinstance(stmt, VarDecl): + continue + if stmt.name not in self._global_member_vars: + continue + if not (isinstance(stmt.value, FuncCall) and self._is_source_input(stmt.value)): + continue + if stmt.name in self.ctx.series_vars: + replayed_source_series.append(self._safe_name(stmt.name)) + replayed_source_series = sorted(set(replayed_source_series)) + lines.append(" void precalculate(const Bar* bars, int n) {") lines.append(" _use_precalc = false;") lines.append(" if (n <= 0 || bars == nullptr) return;") @@ -1042,13 +1063,13 @@ def _emit_precalculate_and_run(self, lines: list[str]) -> None: # Resize precalculated vectors for site in self.ctx.ta_call_sites: - if getattr(site, "is_static", False): + if self._ta_site_uses_precalc(site): lines.append(f" _precalc_{site.member_name}.resize(n);") # Reset indicators to clean slate lines.append("") for site in self.ctx.ta_call_sites: - if getattr(site, "is_static", False): + if self._ta_site_uses_precalc(site): resolved = [self._resolve_known(a) for a in site.ctor_args] safe_resolved = [] for r in resolved: @@ -1059,6 +1080,13 @@ def _emit_precalculate_and_run(self, lines: list[str]) -> None: lines.append("") for field_name in sorted(self.ctx.series_bar_fields): lines.append(f" _s_{field_name}.clear();") + for safe in replayed_source_series: + lines.append(f" {safe}.clear();") + if self._script_has_input_source(): + lines.append(" _src_open_.clear(); _src_high_.clear(); _src_low_.clear();") + lines.append(" _src_close_.clear(); _src_volume_.clear();") + lines.append(" _src_hl2_.clear(); _src_hlc3_.clear();") + lines.append(" _src_ohlc4_.clear(); _src_hlcc4_.clear();") # Start precalculation loop lines.append("") @@ -1081,8 +1109,8 @@ def _emit_precalculate_and_run(self, lines: list[str]) -> None: # entire precalculation, silently corrupting its precalculated # values (e.g. a Bollinger Band's stdev collapsing to 0). Gated on # ``_src_series_active_`` to stay a no-op for scripts with no - # input.source() usage; cleared again before the real run begins - # (BacktestEngine::run() clears every _src_*_ series unconditionally). + # input.source() usage; cleared before and after the precalc pass so + # replayed source history cannot leak into the real run. lines.append(" if (_src_series_active_) {") lines.append(" const double _pc_o = bars[i].open;") lines.append(" const double _pc_h = bars[i].high;") @@ -1137,7 +1165,7 @@ def _emit_precalculate_and_run(self, lines: list[str]) -> None: self._precalc_loop_active = True try: for site in self.ctx.ta_call_sites: - if getattr(site, "is_static", False): + if self._ta_site_uses_precalc(site): compute_args = self._ta_compute_args_for_site(site) compute_args_bars = compute_args.replace("current_bar_.", "bars[i].") lines.append(f" _precalc_{site.member_name}[i] = {site.member_name}.compute({compute_args_bars});") @@ -1149,7 +1177,7 @@ def _emit_precalculate_and_run(self, lines: list[str]) -> None: # Reset indicators and series for the real backtest run lines.append("") for site in self.ctx.ta_call_sites: - if getattr(site, "is_static", False): + if self._ta_site_uses_precalc(site): resolved = [self._resolve_known(a) for a in site.ctor_args] safe_resolved = [] for r in resolved: @@ -1158,6 +1186,13 @@ def _emit_precalculate_and_run(self, lines: list[str]) -> None: for field_name in sorted(self.ctx.series_bar_fields): lines.append(f" _s_{field_name}.clear();") + for safe in replayed_source_series: + lines.append(f" {safe}.clear();") + if self._script_has_input_source(): + lines.append(" _src_open_.clear(); _src_high_.clear(); _src_low_.clear();") + lines.append(" _src_close_.clear(); _src_volume_.clear();") + lines.append(" _src_hl2_.clear(); _src_hlc3_.clear();") + lines.append(" _src_ohlc4_.clear(); _src_hlcc4_.clear();") lines.append("") lines.append(" _use_precalc = true;") diff --git a/pineforge_codegen/codegen/security.py b/pineforge_codegen/codegen/security.py index defbe1e..33e82f1 100644 --- a/pineforge_codegen/codegen/security.py +++ b/pineforge_codegen/codegen/security.py @@ -60,15 +60,20 @@ from ..ast_nodes import ( ASTNode, Assignment, BinOp, BreakStmt, ContinueStmt, ExprStmt, ForStmt, - ForInStmt, FuncCall, FuncDef, Identifier, IfStmt, NumberLiteral, StringLiteral, - Subscript, SwitchStmt, Ternary, TupleAssign, TupleLiteral, UnaryOp, VarDecl, - WhileStmt, + ForInStmt, FuncCall, FuncDef, Identifier, IfStmt, MemberAccess, NumberLiteral, + StringLiteral, Subscript, SwitchStmt, Ternary, TupleAssign, TupleLiteral, + UnaryOp, VarDecl, WhileStmt, ) from ..analyzer import ( FuncInfo, TACallSite, TA_MULTI_CTOR, TA_NO_CTOR, TA_PERIOD_ARG, ) +from .. import signatures as sigs from ..symbols import PineType -from .tables import PINE_TYPE_TO_CPP, SECURITY_OHLC_BAR_FIELDS +from .tables import ( + MATH_FUNC_MAP, PINE_TYPE_TO_CPP, SECURITY_BAR_FIELDS, + SECURITY_BAR_FIELD_EXPRS, SECURITY_BAR_FIELD_TYPES, _math_minmax_na_expr, + _merge_kwargs, +) class SecurityEmitter: @@ -100,6 +105,13 @@ def _resolve_security_tf(self, tf_node, containing_func: str): return self._known_vars[name], None if name in self._input_backed_vars and name in self._input_var_to_call: return None, self._visit_expr(self._input_var_to_call[name]) + global_expr_map = getattr(self.ctx, "global_expr_map", {}) or {} + if name in global_expr_map: + expanded = self._security_tf_runtime_expr( + global_expr_map[name], resolving={name} + ) + if expanded is not None: + return None, expanded # class-scope resolvable (global / input member)? if self._ident_is_resolvable(name): try: @@ -115,10 +127,56 @@ def _resolve_security_tf(self, tf_node, containing_func: str): return None, "input_tf_" # any other expression — visit if it resolves at class scope try: - return None, self._visit_expr(tf_node) + expanded = self._security_tf_runtime_expr(tf_node) + return None, expanded if expanded is not None else self._visit_expr(tf_node) except Exception: return None, "input_tf_" + def _security_tf_runtime_expr(self, node, resolving: set[str] | None = None) -> str | None: + """Render a request.security timeframe expression for registration time. + + Security evaluators are registered before ``on_bar()`` initializes global + variables, so input-backed aliases such as ``tf = useChart ? timeframe.period + : inputTf`` must be expanded to their source expression with direct + ``get_input_*`` reads. Emitting the member name would register with its + default-constructed value (usually an empty string). + """ + if node is None: + return None + resolving = resolving or set() + if isinstance(node, StringLiteral): + return self._visit_expr(node) + if isinstance(node, Identifier): + name = node.name + if name in self._timeframe_period_vars: + return "script_tf_" + if name in self._input_backed_vars and name in self._input_var_to_call: + return self._visit_expr(self._input_var_to_call[name]) + if (name in self._known_vars and name not in self._input_backed_vars + and isinstance(self._known_vars[name], str)): + return self._visit_expr(StringLiteral(value=self._known_vars[name])) + global_expr_map = getattr(self.ctx, "global_expr_map", {}) or {} + if name in global_expr_map and name not in resolving: + resolving.add(name) + out = self._security_tf_runtime_expr(global_expr_map[name], resolving) + resolving.remove(name) + return out + return self._visit_expr(node) + if ( + isinstance(node, MemberAccess) + and isinstance(node.object, Identifier) + and node.object.name == "timeframe" + and node.member == "period" + ): + return "script_tf_" + if isinstance(node, Ternary): + cond = self._security_tf_runtime_expr(node.condition, resolving) + tv = self._security_tf_runtime_expr(node.true_val, resolving) + fv = self._security_tf_runtime_expr(node.false_val, resolving) + if cond is not None and tv is not None and fv is not None: + return f"(({cond}) ? ({tv}) : ({fv}))" + return self._visit_expr(node) + def _resolve_param_tf_from_callsites(self, func_name: str, param_name: str): """For a ``request.security`` whose tf is function parameter ``param_name`` of user function ``func_name``, return ``(tf_str, tf_expr)`` resolved from @@ -276,7 +334,7 @@ def _security_lookup_helper_binding( return None def _literal_int_for_security_index(self, node) -> int | None: - """Integer index for OHLC[ n ] inside request.security (must be literal).""" + """Integer index for bar-field[n] inside request.security (must be literal).""" if isinstance(node, NumberLiteral): v = node.value if isinstance(v, bool): @@ -296,18 +354,20 @@ def _literal_int_for_security_index(self, node) -> int | None: return None def _collect_security_ohlc_hist_fields(self, node) -> set[str]: - """Which OHLC fields need HTF history (subscript index >= 1) for this expression.""" + """Which security bar fields need HTF history (subscript index >= 1).""" out: set[str] = set() def walk(n): if n is None: return if isinstance(n, Subscript) and isinstance(n.object, Identifier): - if n.object.name in SECURITY_OHLC_BAR_FIELDS: + if n.object.name in SECURITY_BAR_FIELDS: idx = self._literal_int_for_security_index(n.index) - # high[0] uses current HTF `bar`; high[k>=1] reads prior completed HTF - # bars from Series history (filled before push in _eval_security_*). - if idx is not None and idx >= 1: + # high[0]/time[0] uses current HTF `bar`; k>=1 reads prior + # completed HTF bars from Series history (filled before push + # in _eval_security_*). Dynamic indices may be >=1 at + # runtime, so they need the same backing Series. + if idx is None or idx >= 1: out.add(n.object.name) if isinstance(n, (list, tuple)): for x in n: @@ -324,9 +384,32 @@ def walk(n): walk(node) return out + def _collect_security_ohlc_hist_fields_for_call(self, item: dict) -> set[str]: + """Collect HTF OHLC history needed by a security expression and any + mutable-global rebinds replayed inside that security evaluator.""" + fields = self._collect_security_ohlc_hist_fields(item.get("expr_node")) + for name in item.get("mutable_globals", []) or []: + info = self._global_mutable_infos.get(name) + if info is None: + continue + for stmt in getattr(info, "source_stmts", []) or []: + fields |= self._collect_security_ohlc_hist_fields(stmt) + return fields + def _security_ohlc_hist_series_cpp(self, sec_id: int, field: str) -> str: return f"_sec{sec_id}_hist_{field}" + def _security_bar_hist_type(self, field: str) -> str: + return SECURITY_BAR_FIELD_TYPES.get(field, "double") + + def _security_bar_field_expr(self, field: str) -> str: + return SECURITY_BAR_FIELD_EXPRS.get(field, f"bar.{field}") + + @staticmethod + def _security_tuple_result_default(cpp_type: str, tuple_size: int) -> str: + vals = ", ".join("na()" for _ in range(max(0, tuple_size))) + return f"{cpp_type}{{{vals}}}" + def _collect_security_ta_hist_indices(self, node) -> set[int]: """Which security TA call-site indices need HTF history (subscript index >= 1). @@ -396,6 +479,172 @@ def _security_ta_hist_series_names(self, sec_id: int) -> list[str]: names.append(self._security_ta_hist_series_cpp(variant["member_name"])) return names + def _collect_security_expr_hist_subscripts( + self, node, resolving: set[str] | None = None + ) -> list[Subscript]: + """Subscripted helper-call results needing security-context history.""" + if node is None: + return [] + if resolving is None: + resolving = set() + + out: list[Subscript] = [] + seen: set[int] = set() + + def add(n: Subscript) -> None: + key = id(n) + if key not in seen: + seen.add(key) + out.append(n) + + def walk(n) -> None: + if n is None: + return + if isinstance(n, Identifier): + global_expr_map = getattr(self.ctx, "global_expr_map", {}) or {} + if n.name in global_expr_map and n.name not in resolving: + resolving.add(n.name) + walk(global_expr_map[n.name]) + resolving.remove(n.name) + return + if ( + isinstance(n, Subscript) + and isinstance(n.object, FuncCall) + and self._get_ta_site(n.object) is None + ): + add(n) + if isinstance(n, (list, tuple)): + for x in n: + walk(x) + return + for _k, v in getattr(n, "__dict__", {}).items(): + if isinstance(v, ASTNode): + walk(v) + elif isinstance(v, (list, tuple)): + for x in v: + if isinstance(x, ASTNode): + walk(x) + + walk(node) + return out + + def _security_expr_hist_series_names(self, sec_id: int) -> list[str]: + names = [] + for (sid, _node_id), meta in sorted(self._security_expr_hist_by_node.items()): + if sid == sec_id: + names.append(meta["name"]) + return names + + def _emit_security_expr_hist_members( + self, sec_id: int, expr_node, lines: list[str], mbb_suffix: str + ) -> None: + for idx, node in enumerate(self._collect_security_expr_hist_subscripts(expr_node)): + cpp_t = self._infer_type(node.object) + if cpp_t not in ("double", "int", "bool"): + cpp_t = "double" + name = f"_sec{sec_id}_expr_hist_{idx}" + self._security_expr_hist_by_node[(sec_id, id(node))] = { + "name": name, + "type": cpp_t, + } + lines.append(f" Series<{cpp_t}> {name}{mbb_suffix};") + + def _build_security_math_call( + self, + sec_id: int, + func_name: str, + node: FuncCall, + ta_range, + ta_results: dict, + resolving: set[str], + security_mutable_names: set[str], + helper_binding_stack: tuple[dict[str, ASTNode], ...], + emitted_lines: list[str] | None, + ) -> str: + visit = lambda arg: self._build_security_expr( + sec_id, + arg, + ta_range, + ta_results, + resolving, + security_mutable_names, + helper_binding_stack, + emitted_lines, + ) + args = _merge_kwargs( + node.args, + node.kwargs, + sigs.get_param_names("math", func_name), + visit, + ) + if func_name == "round" and len(args) == 2: + return f"(std::round({args[0]} * std::pow(10.0, {args[1]})) / std::pow(10.0, {args[1]}))" + if func_name == "round_to_mintick": + x = args[0] if args else "0.0" + return f"round_to_mintick({x})" + if func_name == "todegrees": + x = args[0] if args else "0.0" + return f"({x} * 180.0 / M_PI)" + if func_name == "toradians": + x = args[0] if args else "0.0" + return f"({x} * M_PI / 180.0)" + if func_name == "random": + lo = args[0] if len(args) > 0 else "0.0" + hi = args[1] if len(args) > 1 else "1.0" + seed = args[2] if len(args) > 2 else "0" + call_site = self._random_call_counter + self._random_call_counter += 1 + return f"pine_random({lo}, {call_site}u, {hi}, (uint32_t)({seed}), bar_index_)" + if func_name == "avg" and len(args) > 2: + sum_expr = " + ".join(f"(double)({a})" for a in args) + return f"(({sum_expr}) / {len(args)}.0)" + if func_name in ("min", "max"): + return _math_minmax_na_expr(func_name, args) + if func_name in MATH_FUNC_MAP: + mapped = MATH_FUNC_MAP[func_name] + if "{0}" in mapped: + return mapped.format(*args) + return f"{mapped}({', '.join(args)})" + return f"0.0 /* unsupported: math.{func_name} */" + + def _security_timeframe_expr(self, sec_id: int) -> str: + """C++ expression for the timeframe of a request.security evaluator.""" + info = self._security_eval_info[sec_id] + if info.get("tf"): + return f'"{info["tf"]}"' + if info.get("tf_expr"): + return info["tf_expr"] + return "input_tf_" + + def _build_security_timeframe_member(self, sec_id: int, member: str) -> str | None: + """Lower timeframe.* reads inside request.security to the requested TF.""" + tf = self._security_timeframe_expr(sec_id) + if member == "period": + return tf + if member == "main_period": + return "main_period()" + if member == "multiplier": + return f"tf_multiplier({tf})" + if member == "isintraday": + return f"tf_is_intraday({tf})" + if member == "isminutes": + return f"(tf_is_intraday({tf}) && !tf_is_seconds({tf}))" + if member == "isdaily": + return f"tf_is_daily({tf})" + if member == "isweekly": + return f"tf_is_weekly({tf})" + if member == "ismonthly": + return f"tf_is_monthly({tf})" + if member == "isdwm": + return f"(tf_is_daily({tf}) || tf_is_weekly({tf}) || tf_is_monthly({tf}))" + if member == "isseconds": + return f"tf_is_seconds({tf})" + if member == "in_seconds": + return f"tf_to_seconds({tf})" + if member == "isticks": + return "false" + return None + @staticmethod def _security_series_binding(series_name: str) -> str: return f"@series:{series_name}" @@ -1335,7 +1584,7 @@ def _emit_security_ohlc_hist_pushes(self, sec_id: int, lines: list[str]) -> None lines.append(" if (is_complete) {") for field in fields: lines.append( - f" {self._security_ohlc_hist_series_cpp(sec_id, field)}.push(bar.{field});" + f" {self._security_ohlc_hist_series_cpp(sec_id, field)}.push({self._security_bar_field_expr(field)});" ) lines.append(" }") @@ -1535,6 +1784,8 @@ def emit_security_ta(indices: list[int]) -> None: ) for name in self._security_ta_hist_series_names(sec_id): lines.append(f" {name}.clear();") + for name in self._security_expr_hist_series_names(sec_id): + lines.append(f" {name}.clear();") lines.append(" break;") continue if returns_tuple and tuple_size and tuple_size > 0 and isinstance(expr_node, TupleLiteral): @@ -1559,11 +1810,39 @@ def emit_security_ta(indices: list[int]) -> None: ) for name in self._security_ta_hist_series_names(sec_id): lines.append(f" {name}.clear();") + for name in self._security_expr_hist_series_names(sec_id): + lines.append(f" {name}.clear();") + lines.append(" break;") + elif returns_tuple and tuple_size and tuple_size > 0: + site = self._get_ta_site(expr_node) + ta_name = self._ta_name_from_site(site) if site is not None else "" + ctype = { + "macd": "ta::MACDResult", + "supertrend": "ta::SupertrendResult", + "dmi": "ta::DMIResult", + "bb": "ta::BBResult", + "kc": "ta::KCResult", + "vwap_bands": "ta::VWAPBandsResult", + }.get(ta_name, "std::tuple") + lines.append(f" case {sec_id}:") + lines.append( + f" _req_sec_{sec_id} = " + f"{self._security_tuple_result_default(ctype, tuple_size)};" + ) + for field in sorted(self._security_ohlc_hist_fields_by_sec.get(sec_id, ())): + lines.append( + f" {self._security_ohlc_hist_series_cpp(sec_id, field)}.clear();" + ) + for name in self._security_ta_hist_series_names(sec_id): + lines.append(f" {name}.clear();") + for name in self._security_expr_hist_series_names(sec_id): + lines.append(f" {name}.clear();") lines.append(" break;") else: hist = self._security_ohlc_hist_fields_by_sec.get(sec_id, ()) ta_hist_names = self._security_ta_hist_series_names(sec_id) - if hist or ta_hist_names: + expr_hist_names = self._security_expr_hist_series_names(sec_id) + if hist or ta_hist_names or expr_hist_names: lines.append(f" case {sec_id}:") lines.append(f" _req_sec_{sec_id} = na();") for field in sorted(hist): @@ -1572,6 +1851,8 @@ def emit_security_ta(indices: list[int]) -> None: ) for name in ta_hist_names: lines.append(f" {name}.clear();") + for name in expr_hist_names: + lines.append(f" {name}.clear();") lines.append(" break;") else: lines.append(f" case {sec_id}: _req_sec_{sec_id} = na(); break;") @@ -1619,9 +1900,7 @@ def _build_security_expr( emitted_lines, ) bar_fields = { - "close": "bar.close", "high": "bar.high", - "low": "bar.low", "open": "bar.open", - "volume": "bar.volume", + **SECURITY_BAR_FIELD_EXPRS, "hl2": "((bar.high + bar.low) / 2.0)", "hlc3": "((bar.high + bar.low + bar.close) / 3.0)", "ohlc4": "((bar.open + bar.high + bar.low + bar.close) / 4.0)", @@ -1652,6 +1931,15 @@ def _build_security_expr( resolving.remove(expr_node.name) return resolved + if ( + isinstance(expr_node, MemberAccess) + and isinstance(expr_node.object, Identifier) + and expr_node.object.name == "timeframe" + ): + resolved = self._build_security_timeframe_member(sec_id, expr_node.member) + if resolved is not None: + return resolved + if isinstance(expr_node, Subscript): index_cpp = self._build_security_expr( sec_id, @@ -1682,29 +1970,33 @@ def _build_security_expr( emitted_lines, ) return f"{obj_cpp}[{index_cpp}]" - if expr_node.object.name in SECURITY_OHLC_BAR_FIELDS: + if expr_node.object.name in SECURITY_BAR_FIELDS: + field = expr_node.object.name idx_lit = self._literal_int_for_security_index(expr_node.index) if idx_lit is not None: - bar_map = { - "open": "bar.open", - "high": "bar.high", - "low": "bar.low", - "close": "bar.close", - "volume": "bar.volume", - } if idx_lit == 0: - return bar_map[expr_node.object.name] + return self._security_bar_field_expr(field) if idx_lit >= 1: # lookahead_off: we evaluate when an HTF bar completes; `bar` is that - # bar. On the HTF series, high[0]/close is the current (just-finished) - # bar; high[1] is one HTF bar back = hist[field][0] *before* we push - # `bar` (Series [0] = most recent prior push). high[k] -> hist[k-1]. - field = expr_node.object.name + # bar. On the HTF series, high[0]/time[0] is the current + # (just-finished) bar; high[1]/time[1] is one HTF bar back + # = hist[field][0] *before* we push `bar` (Series [0] = + # most recent prior push). field[k] -> hist[k-1]. hist = self._security_ohlc_hist_series_cpp(sec_id, field) return f"{hist}[{idx_lit - 1}]" - self._codegen_error( - expr_node, - "request.security() OHLC history index must be a literal integer (e.g. high[1])", + if idx_lit is not None: + self._codegen_error( + expr_node, + "request.security() bar-field history index must be non-negative", + ) + hist = self._security_ohlc_hist_series_cpp(sec_id, field) + cpp_t = self._security_bar_hist_type(field) + current = self._security_bar_field_expr(field) + return ( + f"([&]() -> {cpp_t} {{ " + f"int _hidx = (int)({index_cpp}); " + f"return (_hidx <= 0) ? {current} : {hist}[_hidx - 1]; " + f"}}())" ) # Indirect TA binding: ``v = ta.ema(close, 55)`` then @@ -1734,6 +2026,31 @@ def _build_security_expr( ) resolving.remove(expr_node.object.name) return resolved + if ( + isinstance(expr_node.object, FuncCall) + and self._get_ta_site(expr_node.object) is None + ): + meta = self._security_expr_hist_by_node.get((sec_id, id(expr_node))) + hist = meta["name"] if meta else f"_sec{sec_id}_expr_hist_missing" + cpp_t = meta["type"] if meta else "double" + inner = self._build_security_expr( + sec_id, + expr_node.object, + ta_range, + ta_results, + resolving, + security_mutable_names, + helper_binding_stack, + emitted_lines, + ) + return ( + f"([&]() -> {cpp_t} {{ " + f"{cpp_t} _hv = ({inner}); " + f"int _hidx = (int)({index_cpp}); " + f"{cpp_t} _out = (_hidx <= 0) ? _hv : {hist}[_hidx - 1]; " + f"if (is_complete) {hist}.push(_hv); " + f"return _out; }}())" + ) ta_site = self._get_ta_site(expr_node.object) if ta_site is not None: # ``ta.(...)[k]`` inside request.security(): the inner TA call @@ -1850,6 +2167,24 @@ def _build_security_expr( resolving.remove(call_key) return resolved + if ( + isinstance(expr_node, FuncCall) + and isinstance(expr_node.callee, MemberAccess) + and isinstance(expr_node.callee.object, Identifier) + and expr_node.callee.object.name == "math" + ): + return self._build_security_math_call( + sec_id, + expr_node.callee.member, + expr_node, + ta_range, + ta_results, + resolving, + security_mutable_names, + helper_binding_stack, + emitted_lines, + ) + site = self._get_ta_site(expr_node) if site: idx = self._ta_index_by_site_id.get(id(site)) diff --git a/pineforge_codegen/codegen/ta.py b/pineforge_codegen/codegen/ta.py index 00a6858..aff5213 100644 --- a/pineforge_codegen/codegen/ta.py +++ b/pineforge_codegen/codegen/ta.py @@ -28,7 +28,9 @@ from typing import TYPE_CHECKING from ..ast_nodes import ( - Assignment, BinOp, ExprStmt, FuncCall, Ternary, UnaryOp, VarDecl, + Assignment, BinOp, BoolLiteral, ColorLiteral, ExprStmt, FuncCall, + Identifier, MemberAccess, NaLiteral, NumberLiteral, StringLiteral, + Subscript, Ternary, TupleLiteral, UnaryOp, VarDecl, ) from .tables import TA_IMPLICIT_APPEND, TA_IMPLICIT_COMPUTE_FULL @@ -221,6 +223,90 @@ def _ta_compute_args_for_site(self, site: "TACallSite") -> str: return "" + # ------------------------------------------------------------------ + # Precalculation safety + # ------------------------------------------------------------------ + + _PRECALC_BAR_IDENTIFIERS = { + "open", "high", "low", "close", "volume", + "hl2", "hlc3", "ohlc4", "hlcc4", + "time", "time_close", "bar_index", + } + + def _is_precalc_replayed_source_var(self, name: str) -> bool: + """True for top-level ``x = input.source(...)`` variables replayed in + ``precalculate()``. + + The precompute loop explicitly advances native source series and then + replays those source-input assignments before computing static TA + sites. Other user aliases, even when they are statically derived from + bar data (``src = close`` / ``ha_close = close``), are not replayed + there and must therefore use the normal per-bar TA path.""" + ast = getattr(self.ctx, "ast", None) + for stmt in getattr(ast, "body", ()): + if ( + isinstance(stmt, VarDecl) + and stmt.name == name + and isinstance(stmt.value, FuncCall) + and self._is_source_input(stmt.value) + ): + return True + return False + + def _expr_safe_for_ta_precalc(self, expr) -> bool: + if expr is None: + return True + if isinstance(expr, (NumberLiteral, StringLiteral, BoolLiteral, NaLiteral, ColorLiteral)): + return True + if isinstance(expr, Identifier): + if expr.name in self._PRECALC_BAR_IDENTIFIERS: + return True + if self._is_precalc_replayed_source_var(expr.name): + return True + if expr.name in getattr(self.ctx, "series_vars", set()): + return False + return expr.name in getattr(self, "_static_vars", set()) + if isinstance(expr, MemberAccess): + if isinstance(expr.object, Identifier) and ( + expr.object.name.startswith("input") or expr.object.name in getattr(self, "_enum_defs", {}) + ): + return True + return self._expr_safe_for_ta_precalc(expr.object) + if isinstance(expr, BinOp): + return self._expr_safe_for_ta_precalc(expr.left) and self._expr_safe_for_ta_precalc(expr.right) + if isinstance(expr, UnaryOp): + return self._expr_safe_for_ta_precalc(expr.operand) + if isinstance(expr, Ternary): + return ( + self._expr_safe_for_ta_precalc(expr.condition) + and self._expr_safe_for_ta_precalc(expr.true_val) + and self._expr_safe_for_ta_precalc(expr.false_val) + ) + if isinstance(expr, Subscript): + return self._expr_safe_for_ta_precalc(expr.object) and self._expr_safe_for_ta_precalc(expr.index) + if isinstance(expr, TupleLiteral): + return all(self._expr_safe_for_ta_precalc(elem) for elem in expr.elements) + if isinstance(expr, FuncCall): + if isinstance(expr.callee, MemberAccess) and isinstance(expr.callee.object, Identifier): + if expr.callee.object.name in ("math", "str", "color"): + return all(self._expr_safe_for_ta_precalc(arg) for arg in expr.args) + return False + return False + + def _ta_site_uses_precalc(self, site: "TACallSite") -> bool: + """Whether a static TA site can safely read from ``_precalc_*``. + + Static-ness from the analyzer means the expression can be represented + from bar data and constants, but the standalone precompute loop only + replays a narrow subset of per-bar assignments. A user alias such as + ``ha_close = close`` is static in that analyzer sense, yet its Series is + empty during precompute, so ``ta.stdev(ha_close, 20)`` precalculates as + all-``na``. Opting that site out preserves correctness; it simply uses + the ordinary stateful TA object during ``on_bar``.""" + if not getattr(site, "is_static", False): + return False + return all(self._expr_safe_for_ta_precalc(arg) for arg in site.compute_args) + def _security_ta_compute_args_for_site( self, sec_id: int, diff --git a/pineforge_codegen/codegen/tables.py b/pineforge_codegen/codegen/tables.py index 243cb9d..5c217ff 100644 --- a/pineforge_codegen/codegen/tables.py +++ b/pineforge_codegen/codegen/tables.py @@ -55,7 +55,7 @@ def tz_time_field_lambda(field_expr: str, ts_arg: str, tz_arg: str) -> str: """ return ( "[&]() -> int { " - f"std::string _tz = ({tz_arg}); " + f"std::string _tz = pineforge::normalize_timezone_for_posix(({tz_arg})); " f"time_t _secs = (time_t)(({ts_arg}) / 1000); " "struct tm tm_buf; " "if (_tz.empty() || _tz == \"UTC\" || _tz == \"Etc/UTC\") { " @@ -76,11 +76,11 @@ def tz_time_field_lambda(field_expr: str, ts_arg: str, tz_arg: str) -> str: BAR_BUILTINS = { - "bar_index": "bar_index_", + "bar_index": "pine_bar_index()", "time": "current_bar_.timestamp", "time_close": "time_close()", "timenow": "current_bar_.timestamp", - "last_bar_index": "last_bar_index_", + "last_bar_index": "pine_last_bar_index()", "last_bar_time": "last_bar_time_", # time_tradingday: Unix-ms of the session-open of the trading day that # contains the current bar. Backed by pine_time_tradingday() in the engine. @@ -113,7 +113,27 @@ def tz_time_field_lambda(field_expr: str, ts_arg: str, tz_arg: str) -> str: "ohlc4": "((current_bar_.open + current_bar_.high + current_bar_.low + current_bar_.close) / 4.0)", } -# OHLCV identifiers that refer to the *security* (HTF) bar inside ``request.security()``. +# Bar identifiers that refer to the *security* (HTF) bar inside +# ``request.security()``. ``time`` is the HTF bar-open timestamp. +SECURITY_BAR_FIELD_EXPRS = { + "open": "bar.open", + "high": "bar.high", + "low": "bar.low", + "close": "bar.close", + "volume": "bar.volume", + "time": "bar.timestamp", +} +SECURITY_BAR_FIELD_TYPES = { + "open": "double", + "high": "double", + "low": "double", + "close": "double", + "volume": "double", + "time": "int64_t", +} +SECURITY_BAR_FIELDS = frozenset(SECURITY_BAR_FIELD_EXPRS) + +# Backwards-compatible name for consumers that only need the OHLCV subset. SECURITY_OHLC_BAR_FIELDS = frozenset({"open", "high", "low", "close", "volume"}) # Generated C++ runtime function names referenced by the codegen as string @@ -634,6 +654,33 @@ def _matrix_add_col(m: str, args: list) -> str: # Math / String dispatch # --------------------------------------------------------------------------- +def _math_minmax_na_expr(func_name: str, args: list[str]) -> str: + """Emit Pine-compatible math.min/math.max with na propagation. + + ``std::min``/``std::max`` do not propagate NaN consistently because their + comparison is specified in terms of ``operator<``. Pine math.min/max return + ``na`` when any operand is ``na``, so generated clamp-style expressions + like ``math.max(-1, math.min(1, na))`` must stay ``na``. + """ + if not args: + return "na()" + if len(args) == 1: + return f"(double)({args[0]})" + op = "std::max" if func_name == "max" else "std::min" + decls = " ".join( + f"double _v{i} = (double)({arg});" for i, arg in enumerate(args) + ) + guard = " || ".join(f"is_na(_v{i})" for i in range(len(args))) + updates = " ".join( + f"_out = {op}(_out, _v{i});" for i in range(1, len(args)) + ) + return ( + f"([&]() -> double {{ {decls} " + f"if ({guard}) return na(); " + f"double _out = _v0; {updates} return _out; }}())" + ) + + MATH_FUNC_MAP = { "abs": "std::abs", "max": "std::max", "min": "std::min", "ceil": "std::ceil", "floor": "std::floor", "round": "std::round", diff --git a/pineforge_codegen/codegen/types.py b/pineforge_codegen/codegen/types.py index adf79f0..af26683 100644 --- a/pineforge_codegen/codegen/types.py +++ b/pineforge_codegen/codegen/types.py @@ -210,8 +210,14 @@ def _type_spec_from_expr(self, node) -> TypeSpec | None: if isinstance(node, StringLiteral): return TypeSpec.primitive("string") if isinstance(node, Identifier): + loop_specs = getattr(self, "_current_loop_var_specs", None) + if loop_specs and node.name in loop_specs: + return loop_specs[node.name] if node.name in self._collection_types: return self._collection_types[node.name] + param_specs = getattr(self, "_current_func_param_specs", {}) + if node.name in param_specs: + return param_specs[node.name] if node.name in self._udt_var_types: return TypeSpec.udt(self._udt_var_types[node.name]) # Drawing-typed method/function parameter (L.6d / U.5): a ``line ln`` @@ -512,12 +518,29 @@ def _is_int64_builtin_init(self, name: str) -> bool: def _is_udt_lvalue(self, expr) -> str | None: """If ``expr`` is a *user-defined* UDT lvalue (a bare ``Identifier`` that - names a class-scope ``var``/global UDT member, e.g. ``wyckoffSwingLow``), - return its UDT type name; else ``None``. + names a class-scope ``var``/global UDT member, e.g. ``wyckoffSwingLow``, + or an element selected from ``array``), return its UDT type name; + else ``None``. Pine UDTs are reference types, so a local initialised from such an lvalue and then mutated through must write back to the global. Drawing UDTs are handled by the separate ``_uses_drawing`` path and are excluded here.""" + if isinstance(expr, FuncCall): + callee = expr.callee + func_name, namespace = self._resolve_callee(callee) + receiver = None + if namespace == "array" and func_name in ("get", "first", "last") and expr.args: + receiver = expr.args[0] + elif (isinstance(callee, MemberAccess) + and func_name in ("get", "first", "last")): + receiver = callee.object + if receiver is not None: + spec = self._type_spec_from_expr(receiver) + elem = spec.element if spec is not None and spec.kind == "array" else None + if (elem is not None and elem.kind == "udt" and elem.name in self._udt_defs + and elem.name not in DRAWING_TYPE_TO_CPP): + return elem.name + return None if not isinstance(expr, Identifier): return None udt_t = self._udt_var_types.get(expr.name) @@ -791,6 +814,8 @@ def _infer_type(self, node) -> str: return "double" if node.name in self._current_func_param_types: return self._current_func_param_types[node.name] + if node.name in getattr(self, "_current_func_local_types", {}): + return self._current_func_local_types[node.name] sym = self.ctx.symbols.resolve(node.name) if sym is not None and getattr(sym, "type_spec", None) is not None: return self._type_spec_to_cpp(sym.type_spec) @@ -829,9 +854,11 @@ def _infer_type(self, node) -> str: if namespace == "str": if func_name == "split": return "std::vector" + if func_name in ("contains", "startswith", "endswith"): + return "bool" if func_name == "tonumber": return "double" - if func_name == "length": + if func_name in ("length", "pos"): return "int" return "std::string" if namespace == "ta" and func_name == "pivot_point_levels": @@ -881,6 +908,12 @@ def _infer_type(self, node) -> str: # pine_str_tostring); bare reads must declare std::string. if ename == "format": return "std::string" + if ename == "timeframe": + if node.member in ("period", "main_period"): + return "std::string" + if node.member == "multiplier": + return "int" + return "bool" # syminfo.* type inference: look up in SYMINFO_MEMBER_MAP # and derive C++ type from the expression (na() or function call). if ename == "syminfo": @@ -888,6 +921,9 @@ def _infer_type(self, node) -> str: sym_key = f"syminfo.{node.member}" if sym_key in _pf_sigs.SYMINFO_VARIABLES: return PINE_TYPE_TO_CPP.get(_pf_sigs.SYMINFO_VARIABLES[sym_key], "double") + spec = self._type_spec_from_expr(node) + if spec is not None: + return self._type_spec_to_cpp(spec) if isinstance(node, Ternary): tt = self._infer_type(node.true_val) ft = self._infer_type(node.false_val) @@ -895,6 +931,12 @@ def _infer_type(self, node) -> str: return tt if tt.startswith("std::vector") else ft if tt == "std::string" or ft == "std::string": return "std::string" + if tt == "double" or ft == "double": + return "double" + if tt == "int64_t" or ft == "int64_t": + return "int64_t" + if tt == "bool" and ft == "bool": + return "bool" return tt # Block-as-expression cases: read the type of the last statement of # the first branch / case; matches Pine semantics for ``x = if...``. diff --git a/pineforge_codegen/codegen/visit_call.py b/pineforge_codegen/codegen/visit_call.py index c84cf35..ab20cdf 100644 --- a/pineforge_codegen/codegen/visit_call.py +++ b/pineforge_codegen/codegen/visit_call.py @@ -136,6 +136,7 @@ FuncCall, Identifier, MemberAccess, + NaLiteral, TupleLiteral, StringLiteral, ) @@ -160,6 +161,7 @@ SKIP_VAR_TYPES, STR_FUNC_MAP, TIME_FIELD_EXPRS, + _math_minmax_na_expr, _merge_kwargs, _merge_kwargs_with_defaults, tz_time_field_lambda, @@ -209,6 +211,44 @@ class CallVisitor: # Function-call dispatch # ------------------------------------------------------------------ + def _array_init_value_expr(self, elem_spec: TypeSpec | None, value_node) -> str: + if isinstance(value_node, NaLiteral): + if elem_spec is not None and elem_spec.kind == "udt": + return self._default_for_spec(elem_spec) + cpp_type = self._type_spec_to_cpp(elem_spec) + if cpp_type in ("double", "int", "int64_t", "bool", "std::string"): + return f"na<{cpp_type}>()" + return self._default_for_spec(elem_spec) + return self._visit_expr(value_node) + + def _array_method_args( + self, method: str, arg_nodes: list, spec: TypeSpec | None, + ) -> list[str]: + elem_spec = ( + spec.element + if spec is not None and spec.kind == "array" and spec.element is not None + else TypeSpec.primitive("float") + ) + value_arg_indexes = { + "set": {1}, + "push": {0}, + "unshift": {0}, + "insert": {1}, + "fill": {0}, + "includes": {0}, + "indexof": {0}, + "lastindexof": {0}, + "binary_search": {0}, + "binary_search_leftmost": {0}, + "binary_search_rightmost": {0}, + }.get(method, set()) + return [ + self._array_init_value_expr(elem_spec, arg) + if idx in value_arg_indexes + else self._visit_expr(arg) + for idx, arg in enumerate(arg_nodes) + ] + def _visit_func_call(self, node: FuncCall) -> str: callee = node.callee if isinstance(callee, MemberAccess): @@ -270,7 +310,9 @@ def _visit_func_call(self, node: FuncCall) -> str: meth = callee.member raw_args = [self._visit_expr(a) for a in node.args] if recv_spec is not None and recv_spec.kind == "array" and meth in ARRAY_METHODS: - return self._array_method_expr(recv, meth, raw_args, recv_spec) + return self._array_method_expr( + recv, meth, self._array_method_args(meth, node.args, recv_spec), recv_spec + ) if recv_spec is not None and recv_spec.kind == "map" and meth in MAP_METHODS: return self._map_method_expr(recv, meth, raw_args, recv_spec) args = ", ".join(raw_args) @@ -294,7 +336,9 @@ def _visit_func_call(self, node: FuncCall) -> str: return self._map_method_expr(m, meth_raw, margs, self._map_spec_for_name(oname)) if oname in self._array_vars and meth_raw in ARRAY_METHODS: arr = self._safe_name(oname) - margs = [self._visit_expr(a) for a in node.args] + margs = self._array_method_args( + meth_raw, node.args, self._array_spec_for_name(oname) + ) return self._array_method_expr(arr, meth_raw, margs, self._array_spec_for_name(oname)) if oname in self._matrix_specs and meth_raw in MATRIX_METHODS: arr = self._safe_name(oname) @@ -387,9 +431,10 @@ def _visit_func_call(self, node: FuncCall) -> str: if site is not None: compute_args = self._ta_compute_args_for_site(site) ta_mem = self._ta_member_name(site) - if getattr(self, "_precalc_loop_active", False) and getattr(site, "is_static", False): + uses_precalc = self._ta_site_uses_precalc(site) + if getattr(self, "_precalc_loop_active", False) and uses_precalc: return f"_precalc_{ta_mem}[i]" - if getattr(site, "is_static", False): + if uses_precalc: return f"(_use_precalc ? _precalc_{ta_mem}[bar_index_] : (is_first_tick_ ? {ta_mem}.compute({compute_args}) : {ta_mem}.recompute({compute_args})))" return f"(is_first_tick_ ? {ta_mem}.compute({compute_args}) : {ta_mem}.recompute({compute_args}))" @@ -451,18 +496,23 @@ def _visit_func_call(self, node: FuncCall) -> str: # Array method syntax: arr.push(val) where namespace is the array variable name if namespace is not None and namespace in self._array_vars and func_name in ARRAY_METHODS: arr = self._safe_name(namespace) - args = [self._visit_expr(a) for a in node.args] - return self._array_method_expr(arr, func_name, args, self._array_spec_for_name(namespace)) + spec = self._array_spec_for_name(namespace) + args = self._array_method_args(func_name, node.args, spec) + return self._array_method_expr(arr, func_name, args, spec) # Array operations — emit proper C++ vector operations if namespace == "array": if func_name in ("new", "new_float", "new_int", "new_bool", "new_string") or func_name in ARRAY_DRAWING_NEW_CTORS: spec = self._type_spec_from_expr(node) or TypeSpec.array(TypeSpec.primitive("float")) cpp_type = self._type_spec_to_cpp(spec) - init_default = self._default_for_spec(spec.element if spec.element is not None else TypeSpec.primitive("float")) + elem_spec = spec.element if spec.element is not None else TypeSpec.primitive("float") + init_default = self._default_for_spec(elem_spec) if node.args: size_arg = self._visit_expr(node.args[0]) - init_val = self._visit_expr(node.args[1]) if len(node.args) > 1 else init_default + if len(node.args) > 1: + init_val = self._array_init_value_expr(elem_spec, node.args[1]) + else: + init_val = init_default return f"{cpp_type}((size_t)({size_arg}), {init_val})" return f"{cpp_type}()" if func_name == "from": @@ -472,8 +522,8 @@ def _visit_func_call(self, node: FuncCall) -> str: # Method calls: array.method(arr, args...) if func_name in ARRAY_METHODS and node.args: arr = self._visit_expr(node.args[0]) - rest = [self._visit_expr(a) for a in node.args[1:]] spec = self._type_spec_from_expr(node.args[0]) + rest = self._array_method_args(func_name, node.args[1:], spec) return self._array_method_expr(arr, func_name, rest, spec) return "0" @@ -699,6 +749,8 @@ def _visit_func_call(self, node: FuncCall) -> str: is_tz_first = True elif isinstance(node.args[0], StringLiteral): is_tz_first = True + elif self._infer_type(node.args[0]) == "std::string": + is_tz_first = True if is_tz_first: # A single string argument is the timestamp(dateString) @@ -748,7 +800,7 @@ def _visit_func_call(self, node: FuncCall) -> str: sc = args[6] if len(args) > 6 else "0" return ( f"[&]() -> int64_t {{ " - f"std::string _tz = ({tz}); " + f"std::string _tz = pineforge::normalize_timezone_for_posix(({tz})); " f"int _yr = ({yr}); int _mo = ({mo}); int _dy = ({dy}); " f"int _hr = ({hr}); int _min = ({mn}); int _sc = ({sc}); " f"static thread_local std::string _last_tz; " @@ -760,6 +812,7 @@ def _visit_func_call(self, node: FuncCall) -> str: f"struct tm t = {{}}; " f"t.tm_year = _yr - 1900; t.tm_mon = _mo - 1; " f"t.tm_mday = _dy; t.tm_hour = _hr; t.tm_min = _min; t.tm_sec = _sc; " + f"t.tm_isdst = -1; " f"int64_t _res; " f"if (_tz.empty() || _tz == \"UTC\" || _tz == \"Etc/UTC\") {{ " f"_res = (int64_t)timegm(&t) * 1000; " @@ -840,7 +893,15 @@ def _visit_func_call(self, node: FuncCall) -> str: if func_name == "float" and namespace is None and node.args: return f"(double)({self._visit_expr(node.args[0])})" if func_name == "bool" and namespace is None and node.args: - return f"(bool)({self._visit_expr(node.args[0])})" + # Pine v6 bools are two-state. Explicit bool(int/float) treats na + # like false, while a raw C++ cast would make NaN truthy. + x = self._visit_expr(node.args[0]) + return ( + f"[&](){{ auto _pf_v = ({x}); " + f"using _pf_t = std::decay_t; " + f"if constexpr (std::is_same_v<_pf_t, bool>) {{ return _pf_v; }} " + f"else {{ return is_na(_pf_v) ? false : (bool)_pf_v; }} }}()" + ) if func_name == "string" and namespace is None and node.args: # Pine string(x) cast — same emission as str.tostring(x), with # string passthrough and TV-style "true"/"false" for bools @@ -1060,17 +1121,31 @@ def _visit_func_call(self, node: FuncCall) -> str: def _visit_arg_for_series(arg_node, arg_idx): """Visit a function argument, returning Series ref for series params.""" - if arg_idx in _func_series_param_indices and isinstance(arg_node, Identifier): - aname = arg_node.name - # Bar field: pass _s_close instead of current_bar_.close - if aname in BAR_FIELDS or aname in BAR_SERIES_PUSH: - return f"_s_{aname}" - # Series var: pass the Series object directly - if aname in self.ctx.series_vars: - safe = self._safe_name(aname) - if self._active_var_remap and safe in self._active_var_remap: - safe = self._active_var_remap[safe] - return safe + if arg_idx in _func_series_param_indices: + if isinstance(arg_node, Identifier): + aname = arg_node.name + # Bar field: pass _s_close instead of current_bar_.close + if aname in BAR_FIELDS or aname in BAR_SERIES_PUSH: + return f"_s_{aname}" + # Series var: pass the Series object directly + if aname in self.ctx.series_vars: + safe = self._safe_name(aname) + if self._active_var_remap and safe in self._active_var_remap: + safe = self._active_var_remap[safe] + return safe + expr_cpp = self._visit_expr(arg_node) + cpp_t = self._infer_type(arg_node) + if cpp_t not in ("double", "int", "bool"): + cpp_t = "double" + return ( + f"([&]() -> const Series<{cpp_t}>& {{ " + f"static thread_local Series<{cpp_t}> _series_arg; " + f"if (is_first_tick_ && bar_index_ == 0) _series_arg.clear(); " + f"{cpp_t} _sv = ({expr_cpp}); " + f"if (is_first_tick_) _series_arg.push(_sv); " + f"else _series_arg.update(_sv); " + f"return _series_arg; }}())" + ) return self._visit_expr(arg_node) if node.kwargs: @@ -1252,17 +1327,17 @@ def _visit_strategy_call(self, func_name: str, node: FuncCall) -> str: qty_val = self._visit_expr(qty_n) if qty_n else "na()" comment = self._visit_expr(comment_n) if comment_n is not None else '""' oca_val = self._visit_expr(oca_name_n) if oca_name_n is not None else '""' + profit_ticks = "na()" + loss_ticks = "na()" if profit_n and not limit_n: - ticks = self._visit_expr(profit_n) - limit_val = f"(position_entry_price_ + (signed_position_size() > 0 ? 1.0 : -1.0) * ({ticks}) * syminfo_mintick_)" + profit_ticks = self._visit_expr(profit_n) if loss_n and not stop_n: - ticks = self._visit_expr(loss_n) - stop_val = f"(position_entry_price_ - (signed_position_size() > 0 ? 1.0 : -1.0) * ({ticks}) * syminfo_mintick_)" + loss_ticks = self._visit_expr(loss_n) return (f"strategy_exit({exit_id}, {from_id}, {limit_val}, {stop_val}, " f"{trail_pts}, {trail_off}, {trail_pr}, {qty_pct}, {comment}, " - f"{qty_val}, {oca_val})") + f"{qty_val}, {oca_val}, {profit_ticks}, {loss_ticks})") close_comment = self._visit_expr(comment_n) if comment_n is not None else '""' return f"strategy_close({exit_id}, {close_comment})" @@ -1483,23 +1558,12 @@ def _visit_math_call(self, func_name: str, node: FuncCall) -> str: if func_name == "avg" and len(args) > 2: sum_expr = " + ".join(f"(double)({a})" for a in args) return f"(({sum_expr}) / {len(args)}.0)" - if func_name == "max" and len(args) > 2: - result = f"std::max((double)({args[0]}), (double)({args[1]}))" - for a in args[2:]: - result = f"std::max({result}, (double)({a}))" - return result - if func_name == "min" and len(args) > 2: - result = f"std::min((double)({args[0]}), (double)({args[1]}))" - for a in args[2:]: - result = f"std::min({result}, (double)({a}))" - return result + if func_name in ("min", "max"): + return _math_minmax_na_expr(func_name, args) if func_name in MATH_FUNC_MAP: mapped = MATH_FUNC_MAP[func_name] if "{0}" in mapped: return mapped.format(*args) - # std::min/std::max require same types — cast to double - if func_name in ("min", "max") and len(args) == 2: - return f"{mapped}((double)({args[0]}), (double)({args[1]}))" return f"{mapped}({', '.join(args)})" # Unknown math.* — safe fallback return f"0.0 /* unsupported: math.{func_name} */" diff --git a/pineforge_codegen/codegen/visit_expr.py b/pineforge_codegen/codegen/visit_expr.py index 5fb35d2..4f6bce6 100644 --- a/pineforge_codegen/codegen/visit_expr.py +++ b/pineforge_codegen/codegen/visit_expr.py @@ -729,6 +729,18 @@ def _visit_binop(self, node: BinOp) -> str: right = self._visit_expr(node.right) cpp_ops = {"and": "&&", "or": "||"} op = cpp_ops.get(node.op, node.op) + if node.op == "+": + lt = self._infer_type(node.left) + rt = self._infer_type(node.right) + if lt == "std::string" or rt == "std::string": + def _as_string(rendered, inferred): + if inferred == "std::string": + return rendered + if inferred == "bool": + return f'(({rendered}) ? std::string("true") : std::string("false"))' + return f"std::to_string({rendered})" + + return f"({_as_string(left, lt)} + {_as_string(right, rt)})" # PineScript % works on floats — use std::fmod in C++ if node.op == "%": return f"std::fmod((double)({left}), (double)({right}))" diff --git a/pineforge_codegen/codegen/visit_stmt.py b/pineforge_codegen/codegen/visit_stmt.py index 041e19f..8e30e92 100644 --- a/pineforge_codegen/codegen/visit_stmt.py +++ b/pineforge_codegen/codegen/visit_stmt.py @@ -258,6 +258,10 @@ def _visit_var_decl(self, node: VarDecl, lines: list[str], pad: str) -> None: # Global-scope non-var vars are class members — emit assignment, not declaration is_global_member = node.name in self._global_member_vars + def remember_local_type(cpp_type: str | None) -> None: + if cpp_type and not is_global_member: + self._current_func_local_types[node.name] = cpp_type + # Check if it is a static (non-series) global member variable already evaluated inside _inputs_initialized_ block is_static_global_input = False if is_global_member and isinstance(node.value, FuncCall) and self._is_input_call(node.value): @@ -457,6 +461,7 @@ def _visit_var_decl(self, node: VarDecl, lines: list[str], pad: str) -> None: if is_global_member: lines.append(f"{pad}{safe} = {cpp_val};") else: + remember_local_type(cpp_type) lines.append(f"{pad}{cpp_type} {safe} = {cpp_val};") @staticmethod @@ -716,40 +721,37 @@ def _visit_for(self, node: ForStmt, lines: list[str], indent: int) -> None: pad = " " * indent start = self._visit_expr(node.start) end = self._visit_expr(node.end) - var = node.var # new AST uses .var instead of .var_name - if node.step is not None: - # Explicit `by` step: unchanged from before — ascending compare - # (matches every existing corpus use, all positive literal steps). - step = self._visit_expr(node.step) - lines.append(f"{pad}for (int {var} = {start}; {var} <= {end}; {var} += {step}) {{") - else: - # No `by` clause: Pine v6 auto-infers the loop direction from - # start/end — descending (step -1) when start > end, else - # ascending (step +1); see the Pine v6 `for` reference. start/end - # are arbitrary runtime expressions (``for i = array.size(arr)-1 - # to 0`` — a common "iterate backward to safely remove an element - # while iterating" idiom), so the direction can't always be - # resolved at codegen time. Compute start/end into locals ONCE - # (avoids re-evaluating a side-effecting expression, same class - # of bug as nz()'s double-eval) and pick the comparison direction - # at runtime from their relative order — this previously always - # emitted an ascending `<=` loop, which never executes when - # start > end (silently dropping the whole loop body). - fid = self._for_counter - self._for_counter += 1 - s_var, e_var = f"_for_start_{fid}", f"_for_end_{fid}" - lines.append(f"{pad}int {s_var} = ({start}), {e_var} = ({end});") - lines.append( - f"{pad}for (int {var} = {s_var}; " - f"({s_var} <= {e_var}) ? ({var} <= {e_var}) : ({var} >= {e_var}); " - f"{var} += ({s_var} <= {e_var}) ? 1 : -1) {{" - ) + step = self._visit_expr(node.step) if node.step is not None else "1" + var = self._safe_name(node.var) # new AST uses .var instead of .var_name + + # Pine infers loop direction from the initial ``from``/``to`` values for + # both implicit and explicit ``by`` loops. The ``by`` value is a positive + # magnitude; descending loops subtract it. ``to`` can change during the + # loop, so refresh the cached end expression after each iteration while + # keeping the initial direction and step fixed. + fid = self._for_counter + self._for_counter += 1 + s_var = f"_for_start_{fid}" + e_var = f"_for_end_{fid}" + step_var = f"_for_step_{fid}" + down_var = f"_for_down_{fid}" + lines.append(f"{pad}int {s_var} = ({start});") + lines.append(f"{pad}int {e_var} = ({end});") + lines.append(f"{pad}int {step_var} = ({step});") + lines.append(f"{pad}if ({step_var} < 0) {step_var} = -{step_var};") + lines.append(f"{pad}if ({step_var} == 0) {step_var} = 1;") + lines.append(f"{pad}const bool {down_var} = ({s_var} > {e_var});") + lines.append( + f"{pad}for (int {var} = {s_var}; " + f"({down_var} ? ({var} >= {e_var}) : ({var} <= {e_var})); " + f"{var} += ({down_var} ? -{step_var} : {step_var}), {e_var} = ({end})) {{" + ) # Register the loop counter so reads of it inside the body resolve (the # unknown-identifier guard in _visit_ident would otherwise flag it). saved_loop = self._current_loop_vars self._current_loop_vars = set(self._current_loop_vars) - if var: - self._current_loop_vars.add(var) + if node.var: + self._current_loop_vars.add(node.var) _blk_saved = self._push_block_var_remap(node) try: for s in node.body: @@ -787,9 +789,19 @@ def _visit_for_in(self, node, lines: list[str], indent: int) -> None: pad = " " * indent iterable = self._visit_expr(node.iterable) saved_loop = self._current_loop_vars + saved_loop_specs = self._current_loop_var_specs self._current_loop_vars = set(self._current_loop_vars) + self._current_loop_var_specs = dict(self._current_loop_var_specs) + iterable_spec = self._type_spec_from_expr(node.iterable) + elem_spec = ( + iterable_spec.element + if iterable_spec is not None and iterable_spec.kind == "array" + else None + ) if node.var: self._current_loop_vars.add(node.var) + if elem_spec is not None: + self._current_loop_var_specs[node.var] = elem_spec if node.vars: for v in node.vars: if v != "_": @@ -809,6 +821,7 @@ def _visit_for_in(self, node, lines: list[str], indent: int) -> None: self._pop_block_var_remap(_blk_saved) lines.append(f"{pad}}}") self._current_loop_vars = saved_loop + self._current_loop_var_specs = saved_loop_specs def _visit_while(self, node: WhileStmt, lines: list[str], indent: int) -> None: pad = " " * indent diff --git a/pineforge_codegen/parser.py b/pineforge_codegen/parser.py index 74cbbd6..6d1a957 100644 --- a/pineforge_codegen/parser.py +++ b/pineforge_codegen/parser.py @@ -166,6 +166,28 @@ def _recover(self) -> None: # ------------------------------------------------------------------ def _parse_statement(self): + stmt = self._parse_single_statement() + if not self._check(TokenType.COMMA): + return stmt + + stmts: list = [] + self._extend_statement_list(stmts, stmt) + while self._match(TokenType.COMMA): + if self._check(TokenType.NEWLINE) or self._check(TokenType.DEDENT) or self._at_end(): + break + self._extend_statement_list(stmts, self._parse_single_statement()) + return stmts + + @staticmethod + def _extend_statement_list(stmts: list, stmt) -> None: + if stmt is None: + return + if isinstance(stmt, list): + stmts.extend(stmt) + else: + stmts.append(stmt) + + def _parse_single_statement(self): cur = self._current() # Control flow keywords @@ -374,12 +396,24 @@ def _parse_var_decl(self) -> VarDecl | list: first = VarDecl(name=name_tok.value, value=value) self._set_loc(first, start_tok) - # Check for comma-separated additional declarations: x=1, y=2, z=3 - if not self._check(TokenType.COMMA): + # Check for comma-separated additional declarations: x=1, y=2, z=3. + # Other comma-separated simple statements (``a := 1, b := 2`` or + # ``array.fill(a, na), array.set(a, 0, 1)``) are handled by the + # statement wrapper above, so do not greedily consume their comma. + if not ( + self._check(TokenType.COMMA) + and self._peek().type == TokenType.IDENT + and self._peek(2).type == TokenType.EQUALS + ): return first decls = [first] - while self._match(TokenType.COMMA): + while ( + self._check(TokenType.COMMA) + and self._peek().type == TokenType.IDENT + and self._peek(2).type == TokenType.EQUALS + ): + self._advance() st = self._current() n = self._consume(TokenType.IDENT) self._consume(TokenType.EQUALS) diff --git a/pineforge_codegen/signatures.py b/pineforge_codegen/signatures.py index 6ad6e8e..2da4d9f 100644 --- a/pineforge_codegen/signatures.py +++ b/pineforge_codegen/signatures.py @@ -133,8 +133,8 @@ def _ta(short_name: str, *sigs: FuncSig) -> None: # --- Volatility & Range --- _ta("atr", _sig([("length", I)])) _ta("tr", _sig([("handle_na", B, False)])) -_ta("stdev", _sig([("source", F), ("length", I)])) -_ta("variance",_sig([("source", F), ("length", I)])) +_ta("stdev", _sig([("source", F), ("length", I), ("biased", B, True)])) +_ta("variance",_sig([("source", F), ("length", I), ("biased", B, True)])) # --- Trend --- _ta("supertrend", _sig([("factor", F), ("atrPeriod", I)], diff --git a/pineforge_codegen/support_checker.py b/pineforge_codegen/support_checker.py index 15315b0..f9cffb1 100644 --- a/pineforge_codegen/support_checker.py +++ b/pineforge_codegen/support_checker.py @@ -462,6 +462,12 @@ def __init__(self, ast: Program, filename: str = "") -> None: # (``panel.cell(...)``) is a visual sink whose args may carry visual # constants, so it routes through ``_visit_children_const_ok``. self._visual_container_vars: set[str] = set() + # User helpers whose body is only a visual sink (for example + # ``cell(table t, ..., align) => t.cell(..., text_halign=align)``). + # A call to such a helper is itself a visual context, so style + # constants passed through its arguments are safe and should not trip + # the free-expression constant-namespace rejection. + self._visual_sink_funcs: set[str] = set() self._drawing_tuple_vars: set[str] = set() self._func_tuple_drawing_returns: dict[str, list[bool]] = {} # Track whether we are inside an if/ternary condition expression. @@ -520,6 +526,8 @@ def _collect_user_definitions(self, ast: Program) -> None: if tuple_returns: self._func_tuple_drawing_returns[stmt.name] = tuple_returns self._collect_visual_container_params(stmt) + if self._func_body_is_visual_sink(stmt): + self._visual_sink_funcs.add(stmt.name) elif isinstance(stmt, MethodDef): self._user_methods.add(stmt.name) self._collect_visual_container_params(stmt) @@ -535,6 +543,27 @@ def _collect_visual_container_params(self, fn) -> None: if hint and str(hint).replace(" ", "") in _VISUAL_CONTAINER_TYPES: self._visual_container_vars.add(pname) + def _expr_is_visual_sink_call(self, expr: ASTNode | None) -> bool: + if not isinstance(expr, FuncCall): + return False + ns, name = _qualified_name(expr.callee) + if ns is None and name in SKIP_FUNC_NAMES: + return True + if ns is not None and ns in SKIP_NAMESPACES: + return True + if ns in _DRAWING_NOOP_BY_NS and name in _DRAWING_NOOP_BY_NS[ns]: + return True + return ns is not None and ns in self._visual_container_vars + + def _func_body_is_visual_sink(self, fn: FuncDef | MethodDef) -> bool: + if not fn.body: + return False + for stmt in fn.body: + expr = stmt.expr if isinstance(stmt, ExprStmt) else None + if not self._expr_is_visual_sink_call(expr): + return False + return True + @staticmethod def _type_name_contains_drawing(type_name: str | None) -> bool: if not type_name: @@ -1207,6 +1236,10 @@ def _visit_FuncCall(self, node: FuncCall) -> None: self._visit_children_const_ok(node) return + if ns is None and name in self._visual_sink_funcs: + self._visit_children_const_ok(node) + return + self._visit_children(node) def _visit_Identifier(self, node: Identifier) -> None: diff --git a/tests/golden/matrix_eigen_pca.cpp b/tests/golden/matrix_eigen_pca.cpp index b7ab028..9b2de1f 100644 --- a/tests/golden/matrix_eigen_pca.cpp +++ b/tests/golden/matrix_eigen_pca.cpp @@ -96,9 +96,7 @@ static inline std::string _pf_derive_country(const std::string& tickerid) { class GeneratedStrategy : public BacktestEngine { public: ta::SMA _ta_sma_1; - std::vector _precalc__ta_sma_1; ta::SMA _ta_sma_2; - std::vector _precalc__ta_sma_2; ta::SMA _ta_sma_3; ta::SMA _ta_sma_4; ta::SMA _ta_sma_5; @@ -200,47 +198,6 @@ class GeneratedStrategy : public BacktestEngine { } } - void precalculate(const Bar* bars, int n) { - _use_precalc = false; - if (n <= 0 || bars == nullptr) return; - - _precalc__ta_sma_1.resize(n); - _precalc__ta_sma_2.resize(n); - - _ta_sma_1 = ta::SMA(14); - _ta_sma_2 = ta::SMA(14); - - - for (int i = 0; i < n; ++i) { - _precalc__ta_sma_1[i] = _ta_sma_1.compute(v1); - _precalc__ta_sma_2[i] = _ta_sma_2.compute(v2); - } - - _ta_sma_1 = ta::SMA(14); - _ta_sma_2 = ta::SMA(14); - - _use_precalc = true; - } - - void run(const Bar* bars, int n) { - precalculate(bars, n); - BacktestEngine::run(bars, n); - } - - void run(const Bar* input_bars, int n_input, - const std::string& input_tf, - const std::string& script_tf, - bool bar_magnifier = false, - int magnifier_samples = 4, - MagnifierDistribution magnifier_dist = MagnifierDistribution::ENDPOINTS) { - bool needs_dynamic = bar_magnifier || !input_tf.empty() || !script_tf.empty(); - if (needs_dynamic) { - _use_precalc = false; - } else { - precalculate(input_bars, n_input); - } - BacktestEngine::run(input_bars, n_input, input_tf, script_tf, bar_magnifier, magnifier_samples, magnifier_dist); - } }; diff --git a/tests/test_codegen_audit_fixes.py b/tests/test_codegen_audit_fixes.py index 31cba7d..07a364b 100644 --- a/tests/test_codegen_audit_fixes.py +++ b/tests/test_codegen_audit_fixes.py @@ -82,6 +82,7 @@ def test_timestamp_numeric_form_works(): def test_timestamp_tz_form_works(): cpp = _gen('t = timestamp("GMT+2", 2020, 1, 2)\nplot(close)\n') + assert "normalize_timezone_for_posix" in cpp assert "mktime" in cpp diff --git a/tests/test_codegen_drawing_data.py b/tests/test_codegen_drawing_data.py index f29b807..3991e69 100644 --- a/tests/test_codegen_drawing_data.py +++ b/tests/test_codegen_drawing_data.py @@ -36,8 +36,8 @@ def test_line_new_drops_visual_kwargs_keeps_geometry(): "xloc=xloc.bar_index, extend=extend.both, color=color.red, " "style=line.style_dashed, width=2)" ) - assert ("pf_line_new(_pf_lines_, (int64_t)(bar_index_), (double)(current_bar_.close), " - "(int64_t)((bar_index_ + 1)), (double)(current_bar_.open), " + assert ("pf_line_new(_pf_lines_, (int64_t)(pine_bar_index()), (double)(current_bar_.close), " + "(int64_t)((pine_bar_index() + 1)), (double)(current_bar_.open), " "XLoc::bar_index, true, true)") in cpp assert "color" not in cpp.split("pf_line_new")[1].split(";")[0] @@ -56,7 +56,7 @@ def test_line_getters_setters_delete_copy(): assert "double y" in cpp and "pf_line_get_y2(_pf_lines_, ln)" in cpp assert "int64_t x" in cpp and "pf_line_get_x1(_pf_lines_, ln)" in cpp assert "pf_line_set_y2(_pf_lines_, ln, (double)(current_bar_.close))" in cpp - assert "pf_line_set_x2(_pf_lines_, ln, (int64_t)(bar_index_))" in cpp + assert "pf_line_set_x2(_pf_lines_, ln, (int64_t)(pine_bar_index()))" in cpp assert "Line c" in cpp and "pf_line_copy(_pf_lines_, ln)" in cpp assert "pf_line_delete(_pf_lines_, ln)" in cpp @@ -72,8 +72,8 @@ def test_box_new_and_getters(): "l = box.get_left(bx)\n" "box.set_bottom(bx, low)" ) - assert ("pf_box_new(_pf_boxes_, (int64_t)(bar_index_), (double)(current_bar_.high), " - "(int64_t)((bar_index_ + 5)), (double)(current_bar_.low), XLoc::bar_index)") in cpp + assert ("pf_box_new(_pf_boxes_, (int64_t)(pine_bar_index()), (double)(current_bar_.high), " + "(int64_t)((pine_bar_index() + 5)), (double)(current_bar_.low), XLoc::bar_index)") in cpp assert "double t" in cpp and "pf_box_get_top(_pf_boxes_, bx)" in cpp assert "int64_t l" in cpp and "pf_box_get_left(_pf_boxes_, bx)" in cpp assert "pf_box_set_bottom(_pf_boxes_, bx, (double)(current_bar_.low))" in cpp @@ -90,7 +90,7 @@ def test_label_new_text_and_getters(): "s = lb.get_text()\n" "yy = lb.get_y()" ) - assert ('pf_label_new(_pf_labels_, (int64_t)(bar_index_), (double)(current_bar_.close), ' + assert ('pf_label_new(_pf_labels_, (int64_t)(pine_bar_index()), (double)(current_bar_.close), ' 'std::string("hi"), XLoc::bar_index, YLoc::abovebar)') in cpp assert 'pf_label_set_text(_pf_labels_, lb, std::string("bye"))' in cpp assert "std::string s" in cpp and "pf_label_get_text(_pf_labels_, lb)" in cpp @@ -109,8 +109,8 @@ def test_chart_point_and_line_pts_and_linefill(): "lf = linefill.new(ln1, ln2, color.new(color.red, 80))\n" "g = linefill.get_line1(lf)" ) - assert "ChartPoint{ .index=bar_index_, .time=(int64_t)current_bar_.timestamp, .price=(current_bar_.close) }" in cpp - assert "ChartPoint{ .index=(int64_t)((bar_index_ + 3)), .time=na(), .price=(current_bar_.high) }" in cpp + assert "ChartPoint{ .index=(int64_t)(pine_bar_index()), .time=(int64_t)current_bar_.timestamp, .price=(current_bar_.close) }" in cpp + assert "ChartPoint{ .index=(int64_t)((pine_bar_index() + 3)), .time=na(), .price=(current_bar_.high) }" in cpp assert "pf_line_new_pts(_pf_lines_, p1, p2, XLoc::bar_index)" in cpp # linefill drops the color arg. assert "pf_linefill_new(_pf_linefills_, ln1, ln2)" in cpp diff --git a/tests/test_codegen_new.py b/tests/test_codegen_new.py index 3950fb3..b050d49 100644 --- a/tests/test_codegen_new.py +++ b/tests/test_codegen_new.py @@ -226,6 +226,21 @@ def test_timeframe_isdwm_uses_runtime_timeframe_helpers(): assert "x = 0;" not in cpp +def test_timeframe_namespace_uses_requested_tf_inside_security(): + src = ( + '//@version=6\nstrategy("T")\n' + 'w = request.security(syminfo.tickerid, "W", ' + 'timeframe.isintraday ? 1 : timeframe.isweekly ? 2 : 3)\n' + 'm = request.security(syminfo.tickerid, "M", timeframe.ismonthly ? 4 : 5)\n' + 'chart = timeframe.isintraday\n' + ) + cpp = _generate(src) + assert 'tf_is_intraday("W")' in cpp + assert 'tf_is_weekly("W")' in cpp + assert 'tf_is_monthly("M")' in cpp + assert "tf_is_intraday(script_tf_)" in cpp + + def test_hour_two_arg_passes_tz(): """``hour(time, "America/New_York")`` must propagate the tz string into the emitted C++ so the runtime can honor a non-UTC chart. Without this, @@ -238,6 +253,7 @@ def test_hour_two_arg_passes_tz(): ) cpp = _generate(src) assert "America/New_York" in cpp + assert "normalize_timezone_for_posix" in cpp # Two-arg form must use localtime_r (with the TZ env mutation) rather # than just gmtime_r — that is the whole point of the tz argument. assert "localtime_r" in cpp @@ -263,6 +279,7 @@ def test_hour_one_arg_uses_syminfo_timezone(): # TV docs. The default ``SymInfo::timezone`` of "UTC" keeps the # cheap gmtime_r path active for crypto. assert "syminfo_.timezone" in cpp + assert "normalize_timezone_for_posix" in cpp # The chart-display TZ slot must NOT leak into the bar-time lambda; # if it ever does, this test catches the regression. assert "chart_timezone_" not in cpp @@ -717,6 +734,14 @@ def test_math_min_variadic(): assert cpp.count("std::min") >= 2 # nested std::min calls +def test_math_min_max_propagate_na_in_clamp(): + src = "x = math.max(-1.0, math.min(1.0, na))\n" + cpp = _generate_raw(src) + assert "return na();" in cpp + assert "if (is_na(_v0) || is_na(_v1)) return na();" in cpp + assert "std::max((double)((-1.0)), (double)(std::min" not in cpp + + # === Task 9: strategy.closedtrades API + max_drawdown/max_runup === @@ -1154,6 +1179,55 @@ def test_request_security_ta_history_offset_uses_htf_gating(): assert "is_first_tick_ ? _ta_ema_1" not in eval_body +def test_request_security_helper_history_offset_uses_htf_context(): + cpp = _generate(""" +//@version=6 +strategy("T") +clamp(float x) => + math.max(-1.0, math.min(1.0, x)) +score() => + raw = timeframe.isweekly ? clamp(close - open) : na + raw +htfScore = request.security(syminfo.tickerid, "W", score()[1], lookahead=barmerge.lookahead_off) +plot(htfScore) +""") + start = cpp.index("void _eval_security_0(") + end = cpp.index("void evaluate_security(", start) + eval_body = cpp[start:end] + + assert "Series _sec0_expr_hist_0" in cpp + assert "tf_is_weekly(\"W\")" in eval_body + assert "bar.close - bar.open" in eval_body + assert "_req_sec_0 = ([&]() -> double" in eval_body + assert "_sec0_expr_hist_0[_hidx - 1]" in eval_body + assert "if (is_complete) _sec0_expr_hist_0.push(_hv);" in eval_body + assert "_hist_call" not in eval_body + assert "is_first_tick_" not in eval_body + assert "current_bar_" not in eval_body + + +def test_request_security_math_min_max_propagate_na(): + cpp = _generate(""" +//@version=6 +strategy("T") +clamp(float x) => + math.max(-1.0, math.min(1.0, x)) +score() => + v = timeframe.isweekly ? na : close + clamp(v) +w = request.security(syminfo.tickerid, "W", score(), lookahead=barmerge.lookahead_off) +plot(w) +""") + start = cpp.index("void _eval_security_0(") + end = cpp.index("void evaluate_security(", start) + eval_body = cpp[start:end] + + assert 'tf_is_weekly("W")' in eval_body + assert "return na();" in eval_body + assert "if (is_na(_v0) || is_na(_v1)) return na();" in eval_body + assert "std::max((double)((-1.0)), (double)(std::min" not in eval_body + + def test_request_financial_still_na(): cpp = _generate(""" //@version=6 diff --git a/tests/test_codegen_validation_fixes.py b/tests/test_codegen_validation_fixes.py index 48b8735..cbffdea 100644 --- a/tests/test_codegen_validation_fixes.py +++ b/tests/test_codegen_validation_fixes.py @@ -1,6 +1,6 @@ """Regression tests for codegen bugs found by pinescript-scrapper validation. -Covers five fix families: +Covers seven fix families: 1. drawing-handle ``na`` reset/assignment (Box{}/Line{}/... not na()), plus typed ``na`` for string/int/bool declaration init. 2. void drawing setter used as a UDF's last expression / if-branch value. @@ -9,6 +9,14 @@ 4. parser handling of ``T[]`` array-typed function parameters (``float[] arr``, ``line[] ln``) — previously the whole function was dropped. 5. typed drawing array constructors ``array.new_line/box/label/linefill``. + 6. Pine v6 bool casts that must treat ``na`` as ``false`` instead of C++'s + truthy NaN conversion. + 7. ``input.source`` series replayed during TA precompute must be cleared + before the real run. + 8. Pine ``for`` loops infer direction even when an explicit positive ``by`` + step is supplied. + 9. Numeric ternaries promote an ``int`` literal branch to ``double`` when the + other branch is floating-point arithmetic. """ from pineforge_codegen import transpile @@ -126,6 +134,47 @@ def test_security_param_tf_dead_code_falls_back_to_chart_tf(): assert "Unknown variable" not in cpp +def test_security_input_backed_timeframe_alias_expands_at_registration(): + # NicoCashFx shape: request.security receives a global timeframe alias whose + # value is assigned on_bar from input-backed operands. Registration happens + # before on_bar, so emitting the alias member itself registers an empty tf. + cpp = _cpp( + 'useChart = input.bool(false, "Use chart")\n' + 'fixedTf = input.string("30", "Fixed TF", options=["15", "30", "60"])\n' + "tf = useChart ? timeframe.period : fixedTf\n" + 'htf = request.security(syminfo.tickerid, tf, close)\n' + "plot(htf)" + ) + assert 'register_security_eval(0, tf, input_tf_, false, false)' not in cpp + assert 'get_input_bool("Use chart", false)' in cpp + assert 'get_input_string("Fixed TF", std::string("30"))' in cpp + assert 'register_security_eval(0, ((get_input_bool("Use chart", false)) ? (script_tf_) : (get_input_string("Fixed TF", std::string("30")))), input_tf_, false, false)' in cpp + + +def test_bar_index_builtin_uses_public_offset_helper(): + cpp = _cpp( + "fire = bar_index % 200 == 0\n" + "if fire\n" + " strategy.entry(\"L\", strategy.long)\n" + "plot(close)" + ) + assert "std::fmod((double)(pine_bar_index()), (double)(200))" in cpp + + +def test_bar_index_history_series_is_pushed_from_offset_helper(): + cpp = _cpp( + "past = bar_index[6]\n" + "span = bar_index - past\n" + "if span >= 6\n" + " strategy.entry(\"L\", strategy.long)\n" + "plot(close)" + ) + assert "Series bar_index" in cpp + assert "if (is_first_tick_) bar_index.push(pine_bar_index());" in cpp + assert "else bar_index.update(pine_bar_index());" in cpp + assert "(pine_bar_index() - past)" in cpp + + def test_security_param_tf_mixed_with_non_literal_callsite_rejected(): # Two distinct literal tfs PLUS a third call site whose tf isn't a # compile-time literal (a ternary the const-folder can't resolve) -> @@ -302,6 +351,157 @@ def test_drawing_array_constructor_default_value_arg(): assert "std::vector((size_t)(2)" in cpp +def test_untyped_var_drawing_array_constructor_emits_typed_member(): + cpp = _cpp( + "var boxes = array.new_box()\n" + "if bar_index == 0\n" + " b = box.new(bar_index, high, bar_index + 1, low)\n" + " array.push(boxes, b)\n" + "plot(array.size(boxes))" + ) + assert "std::vector boxes;" in cpp + assert "std::vector boxes;" not in cpp + + +def test_comma_separated_statements_and_array_fill_emit_all_side_effects(): + cpp = _cpp( + "var float a = na\n" + "var float b = na\n" + "var float[] xs = array.new_float(3, na)\n" + "var int[] ys = array.new_int(2, na)\n" + "var label[] lbs = array.new_label(2, na)\n" + "if true\n" + " a := 1, b := 2\n" + " array.fill(xs, na), array.set(xs, 1, 7)\n" + " array.fill(ys, na), ys.set(1, na)\n" + " array.fill(lbs, na)\n" + "plot(a + b + array.get(xs, 1))" + ) + assert "a = 1;" in cpp + assert "b = 2;" in cpp + assert "xs = std::vector((size_t)(3), na());" in cpp + assert "ys = std::vector((size_t)(2), na());" in cpp + assert "lbs = std::vector