From 696c085ae88c660e7469dd53744e364609be4314 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sun, 17 May 2026 14:15:57 -0700 Subject: [PATCH 1/3] Fix exponential blow-up in skip_stages on chains of let-bound selects In SkipStages::visit(Select), the two branches' per-Func .used / .loaded predicates were combined as `(t_used && cond) || (f_used && !cond)`. When both branches contributed the same Expr -- which is exactly what happens when both branches read the same let-stashed FuncInfo from an outer let -- make_or could not recognise the And nodes as equivalent (they aren't same_as even when their operands are), so the predicate roughly doubled in size at every nested Select. A long chain of CSE'd lets where each let value contains a Select then drove the predicate size to 2^N, well past the point where allocating the IR is feasible. Combine the two branches with `select(cond, t, f)` instead, and add a make_select helper that collapses `select(c, X, X) -> X` and the constant-cond cases. When both branches contributed the same Expr, make_select drops the condition immediately and the chain stays linear. The new correctness test (many_inlined_selects.cpp) constructs a 500- element CSE'd let chain whose values each carry a Param-gated Select, then feeds the chain into a final Select. With the bug present this test would not terminate -- skip_stages would crash allocating ~2^500 IR nodes long before any reasonable timeout fired. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/SkipStages.cpp | 59 +++++++++++++++- test/correctness/CMakeLists.txt | 1 + test/correctness/many_inlined_selects.cpp | 86 +++++++++++++++++++++++ 3 files changed, 143 insertions(+), 3 deletions(-) create mode 100644 test/correctness/many_inlined_selects.cpp diff --git a/src/SkipStages.cpp b/src/SkipStages.cpp index 6f59ca538216..5550690380d1 100644 --- a/src/SkipStages.cpp +++ b/src/SkipStages.cpp @@ -397,6 +397,21 @@ class SkipStages : public IRMutator { } } + // select on bools that collapses select(c, X, X) -> X and the + // constant-cond cases. + Expr make_select(const Expr &c, const Expr &t, const Expr &f) { + if (is_const_one(c)) { + return t; + } + if (is_const_zero(c)) { + return f; + } + if (t.same_as(f)) { + return t; + } + return select(c, t, f); + } + void merge_func_info(std::map *old, const std::map &new_info, const Expr &used = Expr{}, @@ -468,10 +483,48 @@ class SkipStages : public IRMutator { std::map old; old.swap(func_info); mutate(op->true_value); - merge_func_info(&old, func_info, op->condition); - func_info.clear(); + std::map true_info; + true_info.swap(func_info); mutate(op->false_value); - merge_func_info(&old, func_info, !op->condition); + // func_info now holds the false-branch info. + + // Combine the two branches with select(cond, t, f). A missing + // entry on a side means used=false / loaded=false for that branch. + auto combine = [&](const Expr &t, const Expr &f) { + Expr tt = t.defined() ? t : const_false(); + Expr ff = f.defined() ? f : const_false(); + return make_select(op->condition, tt, ff); + }; + + // Walk every id present in either branch. + for (auto &p : func_info) { + size_t id = p.first; + auto it_t = true_info.find(id); + Expr t_u, t_l; + if (it_t != true_info.end()) { + t_u = it_t->second.used; + t_l = it_t->second.loaded; + true_info.erase(it_t); + } + FuncInfo merged{combine(t_u, p.second.used), + combine(t_l, p.second.loaded)}; + auto [q, inserted] = old.try_emplace(id, merged); + if (!inserted) { + q->second.used = make_or(q->second.used, merged.used); + q->second.loaded = make_or(q->second.loaded, merged.loaded); + } + } + for (auto &p : true_info) { + size_t id = p.first; + FuncInfo merged{combine(p.second.used, Expr()), + combine(p.second.loaded, Expr())}; + auto [q, inserted] = old.try_emplace(id, merged); + if (!inserted) { + q->second.used = make_or(q->second.used, merged.used); + q->second.loaded = make_or(q->second.loaded, merged.loaded); + } + } + func_info.clear(); old.swap(func_info); mutate(op->condition); // Check for any calls in the condition diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index ca43f5f2cf40..7151442a0292 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -219,6 +219,7 @@ tests(GROUPS correctness low_bit_depth_noise.cpp make_struct.cpp many_dimensions.cpp + many_inlined_selects.cpp many_small_extern_stages.cpp many_updates.cpp math.cpp diff --git a/test/correctness/many_inlined_selects.cpp b/test/correctness/many_inlined_selects.cpp new file mode 100644 index 000000000000..a7fb915e5b79 --- /dev/null +++ b/test/correctness/many_inlined_selects.cpp @@ -0,0 +1,86 @@ +#include "Halide.h" +#include +#include + +// Stress test for skip_stages on a Func whose value is read via a long chain +// of CSE'd let bindings, each carrying a Select gated on an independent +// scalar Param. The setup: +// - F has a self-referencing update definition (so it gets +// compute_at(innermost) and shows up in conditionally_used_funcs). +// - A chain of derived Exprs over F is built, each referenced multiple +// times downstream so CSE materialises them as a let chain. +// - Each chain value contains a Select gated on a Param. +// - A final Select feeds the whole chain into the output. +// +// This pattern used to scale exponentially in skip_stages: every nested +// Select roughly doubled the size of the .used / .loaded predicate that +// the mutator built up for F, because the boolean form +// `(t && cond) || (f && !cond)` couldn't recognise that the t and f +// sub-predicates were the same Expr coming from a let-stashed FuncInfo +// above. At this chain length the predicate would contain ~2^500 IR +// nodes — i.e. skip_stages would crash trying to allocate it long +// before any wall-clock timeout fired. Post-fix it lowers in a fraction +// of a second. + +using namespace Halide; + +int main(int argc, char **argv) { + Var x("x"), y("y"), c("c"); + + ImageParam src(Float(32), 3, "src"); + Param gate("gate"); + + constexpr int num_params = 8; + std::vector> conds; + conds.reserve(num_params); + for (int i = 0; i < num_params; i++) { + conds.emplace_back("cond" + std::to_string(i)); + } + + // F: self-referencing update -> can't be inlined, becomes a separate + // compute_at(innermost) Func, so it shows up in skip_stages's analysis. + Func F("F"); + F(x, y, c) = src(x, y, c) * 1.5f + 0.5f; + F(x, y, c) = clamp(F(x, y, c), 0.0f, 1.0f); + + // -- Build a long chain of derived expressions. Each Expr is held in a + // C++ variable and referenced multiple times by subsequent Exprs, so + // CSE will materialise each as a let in the lowered IR. + constexpr int chain_len = 500; + std::vector chain; + chain.reserve(chain_len + 3); + chain.push_back(F(x, y, 0)); + chain.push_back(F(x, y, 1)); + chain.push_back(F(x, y, 2)); + for (int i = 0; i < chain_len; i++) { + Expr a = chain[(i * 3) % chain.size()]; + Expr b = chain[(i * 5 + 1) % chain.size()]; + Expr d = chain[(i * 7 + 2) % chain.size()]; + // Each chain entry's value contains a select gated on one of the + // Params. The let cascade in skip_stages's analysis marks + // the whole chain as interesting and the mutator builds a + // predicate for each let in turn. + Expr cond = conds[i % num_params]; + Expr e = select(cond, a * b + d, (a - d) * b) + + cast(i) * 0.0001f; + chain.push_back(e); + } + + // -- Final expression is a single select gated on a runtime param. + // The branches reference the *last* (and therefore transitively every) + // chain entry, so the in-condition cascade marks every let interesting. + Expr branch_t = chain.back() + chain[chain.size() / 2]; + Expr branch_f = chain.back() - chain[chain.size() / 2]; + Func out("out"); + out(x, y) = select(gate, branch_t, branch_f); + + Target target = get_jit_target_from_environment(); + int vec = target.natural_vector_size(); + Var tx("tx"), yo("yo"), yi("yi"); + out.split(x, x, tx, vec, TailStrategy::GuardWithIf).vectorize(tx) + .split(y, yo, yi, 64, TailStrategy::GuardWithIf).parallel(yo); + + out.compile_jit(target); + printf("Success!\n"); + return 0; +} From 69a4a169a3ed6fed79cc89d32dac469e5f94829c Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sun, 17 May 2026 14:23:52 -0700 Subject: [PATCH 2/3] Clean up branch-merging in SkipStages::visit(Select) When an id is only touched on one branch of the Select, the previous code passed an undefined Expr to a `combine` helper that then turned `undefined` into const_false and built a `select(cond, X, false)` -- which is just `X && cond` dressed up as a select. Call make_and directly in those cases and keep make_select for the both-branches case, where the `select(c, X, X) -> X` collapse is the whole point. Also factor the "merge into old" body into a small helper to remove the duplication. No behaviour change. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/SkipStages.cpp | 48 +++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/src/SkipStages.cpp b/src/SkipStages.cpp index 5550690380d1..dffa16b16291 100644 --- a/src/SkipStages.cpp +++ b/src/SkipStages.cpp @@ -488,41 +488,41 @@ class SkipStages : public IRMutator { mutate(op->false_value); // func_info now holds the false-branch info. - // Combine the two branches with select(cond, t, f). A missing - // entry on a side means used=false / loaded=false for that branch. - auto combine = [&](const Expr &t, const Expr &f) { - Expr tt = t.defined() ? t : const_false(); - Expr ff = f.defined() ? f : const_false(); - return make_select(op->condition, tt, ff); + // Ids touched on both branches: combine with select(cond, t, f), + // so make_select can collapse `select(c, X, X) -> X` when the + // two branches contributed the same Expr. + // Ids touched on only one branch: AND the predicate with the + // appropriate side of the condition. + auto merge_into_old = [&](size_t id, const Expr &u, const Expr &l) { + auto [q, inserted] = old.try_emplace(id, FuncInfo{u, l}); + if (!inserted) { + q->second.used = make_or(q->second.used, u); + q->second.loaded = make_or(q->second.loaded, l); + } }; - // Walk every id present in either branch. for (auto &p : func_info) { size_t id = p.first; auto it_t = true_info.find(id); - Expr t_u, t_l; + Expr u, l; if (it_t != true_info.end()) { - t_u = it_t->second.used; - t_l = it_t->second.loaded; + u = make_select(op->condition, it_t->second.used, p.second.used); + l = make_select(op->condition, it_t->second.loaded, p.second.loaded); true_info.erase(it_t); + } else { + u = make_and(p.second.used, !op->condition); + l = make_and(p.second.loaded, !op->condition); } - FuncInfo merged{combine(t_u, p.second.used), - combine(t_l, p.second.loaded)}; - auto [q, inserted] = old.try_emplace(id, merged); - if (!inserted) { - q->second.used = make_or(q->second.used, merged.used); - q->second.loaded = make_or(q->second.loaded, merged.loaded); - } + merge_into_old(id, u, l); } + // The ids left in true_info are the ones that were only touched + // on the true branch (the ones present in both branches were + // combined and erased in the loop above). for (auto &p : true_info) { size_t id = p.first; - FuncInfo merged{combine(p.second.used, Expr()), - combine(p.second.loaded, Expr())}; - auto [q, inserted] = old.try_emplace(id, merged); - if (!inserted) { - q->second.used = make_or(q->second.used, merged.used); - q->second.loaded = make_or(q->second.loaded, merged.loaded); - } + Expr u = make_and(p.second.used, op->condition); + Expr l = make_and(p.second.loaded, op->condition); + merge_into_old(id, u, l); } func_info.clear(); old.swap(func_info); From 83b9cdfc5481074344433da91757758ed9176740 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sun, 17 May 2026 14:29:52 -0700 Subject: [PATCH 3/3] Simplify many_inlined_selects test Lowercase the Func name, drop the unnecessary top-level select and output schedule, and make each chain entry depend on chain.back() so nothing gets eliminated as dead. The test still reproduces the pre-fix exponential blow-up (verified by reverting the fix: it times out at 30s on a 500-element chain). Co-Authored-By: Claude Opus 4.7 (1M context) --- test/correctness/many_inlined_selects.cpp | 56 +++++++++-------------- 1 file changed, 21 insertions(+), 35 deletions(-) diff --git a/test/correctness/many_inlined_selects.cpp b/test/correctness/many_inlined_selects.cpp index a7fb915e5b79..c2eecc14230a 100644 --- a/test/correctness/many_inlined_selects.cpp +++ b/test/correctness/many_inlined_selects.cpp @@ -5,20 +5,19 @@ // Stress test for skip_stages on a Func whose value is read via a long chain // of CSE'd let bindings, each carrying a Select gated on an independent // scalar Param. The setup: -// - F has a self-referencing update definition (so it gets +// - f has a self-referencing update definition (so it gets // compute_at(innermost) and shows up in conditionally_used_funcs). -// - A chain of derived Exprs over F is built, each referenced multiple +// - A chain of derived Exprs over f is built, each referenced multiple // times downstream so CSE materialises them as a let chain. // - Each chain value contains a Select gated on a Param. -// - A final Select feeds the whole chain into the output. // // This pattern used to scale exponentially in skip_stages: every nested // Select roughly doubled the size of the .used / .loaded predicate that -// the mutator built up for F, because the boolean form +// the mutator built up for f, because the boolean form // `(t && cond) || (f && !cond)` couldn't recognise that the t and f // sub-predicates were the same Expr coming from a let-stashed FuncInfo // above. At this chain length the predicate would contain ~2^500 IR -// nodes — i.e. skip_stages would crash trying to allocate it long +// nodes -- i.e. skip_stages would crash trying to allocate it long // before any wall-clock timeout fired. Post-fix it lowers in a fraction // of a second. @@ -28,7 +27,6 @@ int main(int argc, char **argv) { Var x("x"), y("y"), c("c"); ImageParam src(Float(32), 3, "src"); - Param gate("gate"); constexpr int num_params = 8; std::vector> conds; @@ -37,50 +35,38 @@ int main(int argc, char **argv) { conds.emplace_back("cond" + std::to_string(i)); } - // F: self-referencing update -> can't be inlined, becomes a separate + // f: self-referencing update -> can't be inlined, becomes a separate // compute_at(innermost) Func, so it shows up in skip_stages's analysis. - Func F("F"); - F(x, y, c) = src(x, y, c) * 1.5f + 0.5f; - F(x, y, c) = clamp(F(x, y, c), 0.0f, 1.0f); + Func f("f"); + f(x, y, c) = src(x, y, c) * 1.5f + 0.5f; + f(x, y, c) = clamp(f(x, y, c), 0.0f, 1.0f); - // -- Build a long chain of derived expressions. Each Expr is held in a - // C++ variable and referenced multiple times by subsequent Exprs, so - // CSE will materialise each as a let in the lowered IR. + // Build a long chain of derived expressions. Each entry references + // the immediately preceding one (chain.back()) plus two pseudo-random + // earlier ones. The dependency on chain.back() guarantees that every + // entry is reachable from the final one, so nothing gets dropped as + // dead. CSE will materialise each entry as a let in the lowered IR. + // Each chain entry's value contains a Select gated on one of the + // Params. constexpr int chain_len = 500; std::vector chain; chain.reserve(chain_len + 3); - chain.push_back(F(x, y, 0)); - chain.push_back(F(x, y, 1)); - chain.push_back(F(x, y, 2)); + chain.push_back(f(x, y, 0)); + chain.push_back(f(x, y, 1)); + chain.push_back(f(x, y, 2)); for (int i = 0; i < chain_len; i++) { - Expr a = chain[(i * 3) % chain.size()]; + Expr a = chain.back(); Expr b = chain[(i * 5 + 1) % chain.size()]; Expr d = chain[(i * 7 + 2) % chain.size()]; - // Each chain entry's value contains a select gated on one of the - // Params. The let cascade in skip_stages's analysis marks - // the whole chain as interesting and the mutator builds a - // predicate for each let in turn. Expr cond = conds[i % num_params]; Expr e = select(cond, a * b + d, (a - d) * b) + cast(i) * 0.0001f; chain.push_back(e); } - // -- Final expression is a single select gated on a runtime param. - // The branches reference the *last* (and therefore transitively every) - // chain entry, so the in-condition cascade marks every let interesting. - Expr branch_t = chain.back() + chain[chain.size() / 2]; - Expr branch_f = chain.back() - chain[chain.size() / 2]; Func out("out"); - out(x, y) = select(gate, branch_t, branch_f); - - Target target = get_jit_target_from_environment(); - int vec = target.natural_vector_size(); - Var tx("tx"), yo("yo"), yi("yi"); - out.split(x, x, tx, vec, TailStrategy::GuardWithIf).vectorize(tx) - .split(y, yo, yi, 64, TailStrategy::GuardWithIf).parallel(yo); - - out.compile_jit(target); + out(x, y) = chain.back(); + out.compile_jit(get_jit_target_from_environment()); printf("Success!\n"); return 0; }