Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 56 additions & 3 deletions src/SkipStages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,21 @@ class SkipStages : public IRMutator {
}
}

// Eagerly-folding constructor for a select over boolean predicates.
// Checks are tried in this order:
//   - condition is constant one  -> the true value `t`
//   - condition is constant zero -> the false value `f`
//   - t.same_as(f)               -> `t` (select(c, X, X) collapses to X)
// Only when none of these apply is an actual select(c, t, f) built.
Expr make_select(const Expr &c, const Expr &t, const Expr &f) {
    if (!is_const_one(c)) {
        // The condition is not trivially true, so the false branch or
        // a genuine select may still be required.
        if (is_const_zero(c)) {
            return f;
        }
        if (!t.same_as(f)) {
            return select(c, t, f);
        }
    }
    // Either the condition is constant one, or both branches are the
    // same Expr; the true value is the answer in both cases.
    return t;
}
Comment on lines +402 to +413
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This (and the helpers above) seems like the kind of thing that is likely to be useful elsewhere in the compiler, maybe even duplicated already. I suppose we didn't want to call simplify(Select::make(...)) here for some reason.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True enough, but in this context eager simplification is absolutely required for this to not blow up. If I just made this the default behavior of operator&&/operator||/select I wouldn't have that clear guarantee at this point in the code.


void merge_func_info(std::map<size_t, FuncInfo> *old,
const std::map<size_t, FuncInfo> &new_info,
const Expr &used = Expr{},
Expand Down Expand Up @@ -468,10 +483,48 @@ class SkipStages : public IRMutator {
std::map<size_t, FuncInfo> old;
old.swap(func_info);
mutate(op->true_value);
merge_func_info(&old, func_info, op->condition);
func_info.clear();
std::map<size_t, FuncInfo> true_info;
true_info.swap(func_info);
mutate(op->false_value);
merge_func_info(&old, func_info, !op->condition);
// func_info now holds the false-branch info.

// Ids touched on both branches: combine with select(cond, t, f),
// so make_select can collapse `select(c, X, X) -> X` when the
// two branches contributed the same Expr.
// Ids touched on only one branch: AND the predicate with the
// appropriate side of the condition.
auto merge_into_old = [&](size_t id, const Expr &u, const Expr &l) {
auto [q, inserted] = old.try_emplace(id, FuncInfo{u, l});
if (!inserted) {
q->second.used = make_or(q->second.used, u);
q->second.loaded = make_or(q->second.loaded, l);
}
};

for (auto &p : func_info) {
size_t id = p.first;
auto it_t = true_info.find(id);
Expr u, l;
if (it_t != true_info.end()) {
u = make_select(op->condition, it_t->second.used, p.second.used);
l = make_select(op->condition, it_t->second.loaded, p.second.loaded);
true_info.erase(it_t);
} else {
u = make_and(p.second.used, !op->condition);
l = make_and(p.second.loaded, !op->condition);
}
merge_into_old(id, u, l);
}
// The ids left in true_info are the ones that were only touched
// on the true branch (the ones present in both branches were
// combined and erased in the loop above).
for (auto &p : true_info) {
size_t id = p.first;
Expr u = make_and(p.second.used, op->condition);
Expr l = make_and(p.second.loaded, op->condition);
merge_into_old(id, u, l);
}
func_info.clear();
old.swap(func_info);
mutate(op->condition); // Check for any calls in the condition

Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ tests(GROUPS correctness
low_bit_depth_noise.cpp
make_struct.cpp
many_dimensions.cpp
many_inlined_selects.cpp
many_small_extern_stages.cpp
many_updates.cpp
math.cpp
Expand Down
72 changes: 72 additions & 0 deletions test/correctness/many_inlined_selects.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#include "Halide.h"
#include <cstdio>
#include <vector>

// Stress test for skip_stages on a Func whose value is read via a long chain
// of CSE'd let bindings, each carrying a Select gated on one of a small set
// of scalar Param<bool>s (reused cyclically along the chain). The setup:
// - f has a self-referencing update definition (so it gets
// compute_at(innermost) and shows up in conditionally_used_funcs).
// - A chain of derived Exprs over f is built, each referenced multiple
// times downstream so CSE materialises them as a let chain.
// - Each chain value contains a Select gated on a Param<bool>.
//
// This pattern used to scale exponentially in skip_stages: every nested
// Select roughly doubled the size of the .used / .loaded predicate that
// the mutator built up for f, because the boolean form
// `(t && cond) || (f && !cond)` couldn't recognise that the t and f
// sub-predicates were the same Expr coming from a let-stashed FuncInfo
// above. At this chain length the predicate would contain ~2^500 IR
// nodes -- i.e. skip_stages would crash trying to allocate it long
// before any wall-clock timeout fired. Post-fix it lowers in a fraction
// of a second.

using namespace Halide;

int main(int argc, char **argv) {
Var x("x"), y("y"), c("c");

ImageParam src(Float(32), 3, "src");

constexpr int num_params = 8;
std::vector<Param<bool>> conds;
conds.reserve(num_params);
for (int i = 0; i < num_params; i++) {
conds.emplace_back("cond" + std::to_string(i));
}

// f: self-referencing update -> can't be inlined, becomes a separate
// compute_at(innermost) Func, so it shows up in skip_stages's analysis.
Func f("f");
f(x, y, c) = src(x, y, c) * 1.5f + 0.5f;
f(x, y, c) = clamp(f(x, y, c), 0.0f, 1.0f);

// Build a long chain of derived expressions. Each entry references
// the immediately preceding one (chain.back()) plus two pseudo-random
// earlier ones. The dependency on chain.back() guarantees that every
// entry is reachable from the final one, so nothing gets dropped as
// dead. CSE will materialise each entry as a let in the lowered IR.
// Each chain entry's value contains a Select gated on one of the
// Param<bool>s.
constexpr int chain_len = 500;
std::vector<Expr> chain;
chain.reserve(chain_len + 3);
chain.push_back(f(x, y, 0));
chain.push_back(f(x, y, 1));
chain.push_back(f(x, y, 2));
for (int i = 0; i < chain_len; i++) {
Expr a = chain.back();
Expr b = chain[(i * 5 + 1) % chain.size()];
Expr d = chain[(i * 7 + 2) % chain.size()];
Expr cond = conds[i % num_params];
Expr e = select(cond, a * b + d, (a - d) * b) +
cast<float>(i) * 0.0001f;
chain.push_back(e);
}

Func out("out");
out(x, y) = chain.back();
out.compile_jit(get_jit_target_from_environment());
printf("Success!\n");
return 0;
}
Loading