diff --git a/src/SkipStages.cpp b/src/SkipStages.cpp index 6f59ca538216..dffa16b16291 100644 --- a/src/SkipStages.cpp +++ b/src/SkipStages.cpp @@ -397,6 +397,21 @@ class SkipStages : public IRMutator { } } + // Builds select(c, t, f) on bools, folding a constant condition to + // its branch and identical branches (t.same_as(f)) to t. + Expr make_select(const Expr &c, const Expr &t, const Expr &f) { + if (is_const_one(c)) { + return t; + } + if (is_const_zero(c)) { + return f; + } + if (t.same_as(f)) { + return t; + } + return select(c, t, f); + } + void merge_func_info(std::map *old, const std::map &new_info, const Expr &used = Expr{}, @@ -468,10 +483,48 @@ class SkipStages : public IRMutator { std::map old; old.swap(func_info); mutate(op->true_value); - merge_func_info(&old, func_info, op->condition); - func_info.clear(); + std::map true_info; + true_info.swap(func_info); mutate(op->false_value); - merge_func_info(&old, func_info, !op->condition); + // func_info now holds the false-branch info. + + // Both-branch ids: select(cond, t, f); make_select collapses it to + // X when both branches contributed the same Expr. One-branch ids: + // AND the predicate with the matching side of the condition. + // NOTE(review): the removed calls passed cond only as merge_func_info's + // `used` argument; `loaded` is now gated on cond too -- confirm intended. 
+ auto merge_into_old = [&](size_t id, const Expr &u, const Expr &l) { + auto [q, inserted] = old.try_emplace(id, FuncInfo{u, l}); + if (!inserted) { + q->second.used = make_or(q->second.used, u); + q->second.loaded = make_or(q->second.loaded, l); + } + }; + + for (auto &p : func_info) { + size_t id = p.first; + auto it_t = true_info.find(id); + Expr u, l; + if (it_t != true_info.end()) { + u = make_select(op->condition, it_t->second.used, p.second.used); + l = make_select(op->condition, it_t->second.loaded, p.second.loaded); + true_info.erase(it_t); + } else { + u = make_and(p.second.used, !op->condition); + l = make_and(p.second.loaded, !op->condition); + } + merge_into_old(id, u, l); + } + // The ids left in true_info are the ones that were only touched + // on the true branch (the ones present in both branches were + // combined and erased in the loop above). + for (auto &p : true_info) { + size_t id = p.first; + Expr u = make_and(p.second.used, op->condition); + Expr l = make_and(p.second.loaded, op->condition); + merge_into_old(id, u, l); + } + func_info.clear(); old.swap(func_info); mutate(op->condition); // Check for any calls in the condition diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index ca43f5f2cf40..7151442a0292 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -219,6 +219,7 @@ tests(GROUPS correctness low_bit_depth_noise.cpp make_struct.cpp many_dimensions.cpp + many_inlined_selects.cpp many_small_extern_stages.cpp many_updates.cpp math.cpp diff --git a/test/correctness/many_inlined_selects.cpp b/test/correctness/many_inlined_selects.cpp new file mode 100644 index 000000000000..c2eecc14230a --- /dev/null +++ b/test/correctness/many_inlined_selects.cpp @@ -0,0 +1,72 @@ +#include "Halide.h" +#include +#include + +// Stress test for skip_stages on a Func whose value is read via a long chain +// of CSE'd let bindings, each carrying a Select gated on one of a small +// set of scalar Params. 
The setup: +// - f has a self-referencing update definition (so it gets +// compute_at(innermost) and shows up in conditionally_used_funcs). +// - A chain of derived Exprs over f is built, each referenced multiple +// times downstream so CSE materialises them as a let chain. +// - Each chain value contains a Select gated on a Param. +// +// This pattern used to scale exponentially in skip_stages: every nested +// Select roughly doubled the size of the .used / .loaded predicate that +// the mutator built up for f, because the boolean form +// `(t && cond) || (f && !cond)` couldn't recognise that the t and f +// sub-predicates were the same Expr coming from a let-stashed FuncInfo +// above. At this chain length the predicate would contain ~2^500 IR +// nodes -- i.e. skip_stages would crash trying to allocate it long +// before any wall-clock timeout fired. Post-fix it lowers in a fraction +// of a second. + +using namespace Halide; + +int main(int argc, char **argv) { + Var x("x"), y("y"), c("c"); + + ImageParam src(Float(32), 3, "src"); + + constexpr int num_params = 8; + std::vector> conds; + conds.reserve(num_params); + for (int i = 0; i < num_params; i++) { + conds.emplace_back("cond" + std::to_string(i)); + } + + // f: self-referencing update -> can't be inlined, becomes a separate + // compute_at(innermost) Func, so it shows up in skip_stages's analysis. + Func f("f"); + f(x, y, c) = src(x, y, c) * 1.5f + 0.5f; + f(x, y, c) = clamp(f(x, y, c), 0.0f, 1.0f); + + // Build a long chain of derived expressions. Each entry references + // the immediately preceding one (chain.back()) plus two pseudo-random + // earlier ones. The dependency on chain.back() guarantees that every + // entry is reachable from the final one, so nothing gets dropped as + // dead. CSE will materialise each entry as a let in the lowered IR. + // Each chain entry's value contains a Select gated on one of the + // Params. 
+ constexpr int chain_len = 500; + std::vector chain; + chain.reserve(chain_len + 3); + chain.push_back(f(x, y, 0)); + chain.push_back(f(x, y, 1)); + chain.push_back(f(x, y, 2)); + for (int i = 0; i < chain_len; i++) { + Expr a = chain.back(); + Expr b = chain[(i * 5 + 1) % chain.size()]; + Expr d = chain[(i * 7 + 2) % chain.size()]; + Expr cond = conds[i % num_params]; + Expr e = select(cond, a * b + d, (a - d) * b) + + cast(i) * 0.0001f; + chain.push_back(e); + } + + Func out("out"); + out(x, y) = chain.back(); + out.compile_jit(get_jit_target_from_environment()); + printf("Success!\n"); + return 0; +}