#include "SkipStages.h"
#include "CSE.h"
#include "Debug.h"
#include "ExprUsesVar.h"
#include "IREquality.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "IRPrinter.h"
#include "Scope.h"
#include "Simplify.h"
#include "Substitute.h"

#include <iterator>
#include <utility>

namespace Halide {
namespace Internal {

using std::set;
using std::string;
using std::vector;

namespace {

/** Returns true if any argument of the call is a Variable whose name
 * begins with `func + "."` and ends with ".buffer". The call type and
 * call name are not inspected here; callers filter on those. */
bool call_args_reference_buffer(const Call *op, const std::string &func) {
    // Build the prefix once rather than once per argument.
    const std::string prefix = func + ".";
    for (const auto &arg : op->args) {
        const Variable *var = arg.as<Variable>();
        if (var &&
            starts_with(var->name, prefix) &&
            ends_with(var->name, ".buffer")) {
            return true;
        }
    }
    return false;
}

/** Returns true if `op` is an extern call that takes one of `func`'s
 * ".buffer" variables as an argument. Extern calls whose name starts
 * with "halide_memoization" are deliberately excluded; they are
 * handled separately by ProductionGuarder. */
bool extern_call_uses_buffer(const Call *op, const std::string &func) {
    if (!op->is_extern()) {
        return false;
    }
    if (starts_with(op->name, "halide_memoization")) {
        return false;
    }
    return call_args_reference_buffer(op, func);
}

/** Visitor that accumulates, in `predicate`, a boolean Expr describing
 * when the visited IR touches `buffer`.
 *
 * `predicate` starts as const_false() and is set to const_true() at:
 *  - a Call, outside the produce node of `buffer`, whose name is
 *    `buffer` or which is an extern call taking one of its ".buffer"
 *    variables;
 *  - a Provide, inside the produce node of `buffer`, that writes to
 *    some other buffer that is not realized within the visited IR.
 *
 * Conditionals (IfThenElse, and Select when `treat_selects_as_guards`
 * is set) combine the predicates of their branches using the
 * condition, unless the condition depends on a name recorded in
 * `varying`, in which case the branches are conservatively or-ed. */
class PredicateFinder : public IRVisitor {
public:
    /** The accumulated predicate. Read it after visiting. */
    Expr predicate;

    /** @param b  name of the buffer to find uses of.
     *  @param s  if true, Select nodes are treated as guards in the
     *            same way as IfThenElse nodes. */
    PredicateFinder(const string &b, bool s)
        : predicate(const_false()),
          buffer(b),
          treat_selects_as_guards(s) {
    }

private:
    using IRVisitor::visit;

    string buffer;
    // Sticky flag: set when a visited expression referenced a name in
    // `varying` (or called something in `in_pipeline`). Visitors that
    // need a fresh answer save, clear and restore it.
    bool varies = false;
    bool treat_selects_as_guards;
    // True while inside the produce node of `buffer`.
    bool in_produce = false;
    // Names whose value may differ within the scope the predicate
    // would be evaluated in: non-trivial loop variables, lets bound
    // to varying values, and allocations.
    Scope<> varying;
    // Names of the ProducerConsumer nodes currently enclosing us.
    Scope<> in_pipeline;
    // Names of the Realize nodes currently enclosing us.
    Scope<> local_buffers;

    void visit(const Variable *op) override {
        if (varying.contains(op->name)) {
            varies = true;
        }
    }

    void visit(const For *op) override {
        op->min.accept(this);
        bool min_varies = varies;
        op->extent.accept(this);

        // A loop of constant extent one with a non-varying min is
        // equivalent to a let of the loop variable to the min, so it
        // does not need to be marked as varying.
        bool should_pop = false;
        if (!is_const_one(op->extent) || min_varies) {
            should_pop = true;
            varying.push(op->name);
        }

        op->body.accept(this);

        if (should_pop) {
            varying.pop(op->name);
        } else if (expr_uses_var(predicate, op->name)) {
            // Keep the predicate well-defined outside this loop by
            // binding the loop variable to its only value.
            predicate = Let::make(op->name, op->min, predicate);
        }
    }

    /** Shared implementation for Let and LetStmt. Walks a chain of
     * directly nested lets in a loop, visits the innermost body, then
     * re-wraps the predicate in every let whose name it uses
     * (innermost first) so that it remains valid outside the chain. */
    template<typename T>
    void visit_let(const T *op) {
        struct Frame {
            const T *op;
            ScopedBinding<> binding;
        };
        vector<Frame> frames;

        decltype(op->body) body;
        do {
            // Determine whether this let's value, taken on its own,
            // varies.
            bool old_varies = varies;
            varies = false;
            op->value.accept(this);

            // The binding is constructed with `varies` as its first
            // argument.
            // NOTE(review): presumably ScopedBinding only pushes the
            // name when that condition is true - confirm in Scope.h.
            frames.push_back(Frame{op, ScopedBinding<>(varies, varying, op->name)});

            varies |= old_varies;
            body = op->body;
            op = body.template as<T>();
        } while (op);

        body.accept(this);

        for (auto it = frames.rbegin(); it != frames.rend(); ++it) {
            if (expr_uses_var(predicate, it->op->name)) {
                predicate = Let::make(it->op->name, it->op->value, predicate);
            }
        }
    }

    void visit(const LetStmt *op) override {
        visit_let(op);
    }

    void visit(const Let *op) override {
        visit_let(op);
    }

    void visit(const ProducerConsumer *op) override {
        ScopedBinding<> bind(in_pipeline, op->name);
        if (op->is_producer && op->name == buffer) {
            ScopedValue<bool> sv(in_produce, true);
            IRVisitor::visit(op);
        } else {
            IRVisitor::visit(op);
        }
    }

    // Logical operators with eager constant folding

    /** a && b, folded when either side is a constant or both sides
     * are equal. */
    Expr make_and(Expr a, Expr b) {
        if (is_const_zero(a) || is_const_one(b)) {
            return a;
        } else if (is_const_zero(b) || is_const_one(a)) {
            return b;
        } else if (equal(a, b)) {
            return a;
        } else {
            return a && b;
        }
    }

    /** a || b, folded when either side is a constant or both sides
     * are equal. */
    Expr make_or(Expr a, Expr b) {
        if (is_const_zero(a) || is_const_one(b)) {
            return b;
        } else if (is_const_zero(b) || is_const_one(a)) {
            return a;
        } else if (equal(a, b)) {
            return a;
        } else {
            return a || b;
        }
    }

    /** select(a, b, c), folded to and/or/not when any operand is a
     * constant. */
    Expr make_select(const Expr &a, Expr b, Expr c) {
        if (is_const_one(a)) {
            return b;
        } else if (is_const_zero(a)) {
            return c;
        } else if (is_const_one(b)) {
            return make_or(a, c);
        } else if (is_const_zero(b)) {
            return make_and(make_not(a), c);
        } else if (is_const_one(c)) {
            return make_or(make_not(a), b);
        } else if (is_const_zero(c)) {
            return make_and(a, b);
        } else {
            return select(a, b, c);
        }
    }

    /** !a, folded when a is a constant. */
    Expr make_not(const Expr &a) {
        if (is_const_one(a)) {
            return make_zero(a.type());
        } else if (is_const_zero(a)) {
            return make_one(a.type());
        } else {
            return !a;
        }
    }

    /** Shared implementation for IfThenElse and (optionally) Select.
     * Computes a predicate for each branch and for the condition
     * independently, then combines them with whatever predicate was
     * already accumulated before this node. `false_case` may be
     * undefined. */
    template<typename T>
    void visit_conditional(const Expr &condition, T true_case, T false_case) {
        Expr old_predicate = predicate;

        predicate = const_false();
        true_case.accept(this);
        Expr true_predicate = predicate;

        predicate = const_false();
        if (false_case.defined()) {
            false_case.accept(this);
        }
        Expr false_predicate = predicate;

        // Find out whether the condition itself varies, and whether
        // evaluating it uses the buffer.
        bool old_varies = varies;
        predicate = const_false();
        varies = false;
        condition.accept(this);

        predicate = make_or(predicate, old_predicate);
        if (varies) {
            // The condition can't be used in the predicate, so
            // conservatively assume either branch may run.
            predicate = make_or(predicate, make_or(true_predicate, false_predicate));
        } else {
            predicate = make_or(predicate, make_select(condition, true_predicate, false_predicate));
        }

        varies = varies || old_varies;
    }

    void visit(const Select *op) override {
        if (treat_selects_as_guards) {
            visit_conditional(op->condition, op->true_value, op->false_value);
        } else {
            IRVisitor::visit(op);
        }
    }

    void visit(const IfThenElse *op) override {
        visit_conditional(op->condition, op->then_case, op->else_case);
    }

    void visit(const Call *op) override {
        // A call to something currently being produced or consumed
        // around us is treated as varying.
        varies |= in_pipeline.contains(op->name);

        IRVisitor::visit(op);

        if (!in_produce && (op->name == buffer || extern_call_uses_buffer(op, buffer))) {
            predicate = const_true();
        }
    }

    void visit(const Provide *op) override {
        IRVisitor::visit(op);
        // Inside the produce node of `buffer`, a store to any other
        // buffer that is not realized within the visited IR forces
        // the predicate to true.
        if (in_produce && op->name != buffer && !local_buffers.contains(op->name)) {
            predicate = const_true();
        }
    }

    void visit(const Realize *op) override {
        ScopedBinding<> bind(local_buffers, op->name);
        IRVisitor::visit(op);
    }

    void visit(const Allocate *op) override {
        // This code works to ensure expressions depending on an
        // allocation don't get moved outside the allocation and are
        // marked as varying if predicate depends on the value of the
        // allocation.
        ScopedBinding<>
            bind_host_ptr(varying, op->name),
            bind_buffer(varying, op->name + ".buffer");
        IRVisitor::visit(op);
    }
};

/** Mutator that guards the production of `buffer` with
 * `compute_predicate`, and guards memoization cache calls that take
 * one of its ".buffer" variables with the appropriate predicate. */
class ProductionGuarder : public IRMutator {
public:
    /** @param b          name of the buffer whose production is guarded.
     *  @param compute_p  condition under which the buffer must be computed.
     *  @param alloc_p    condition under which the buffer must be allocated. */
    ProductionGuarder(const string &b, Expr compute_p, Expr alloc_p)
        : buffer(b),
          compute_predicate(std::move(compute_p)),
          alloc_predicate(std::move(alloc_p)) {
    }

private:
    string buffer;
    Expr compute_predicate;
    Expr alloc_predicate;

    using IRMutator::visit;

    /** Returns true if the memoization call `op` takes one of
     * `buffer`'s ".buffer" variables as an argument. `op` must be an
     * extern call whose name starts with "halide_memoization". */
    bool memoize_call_uses_buffer(const Call *op) {
        internal_assert(op->call_type == Call::Extern);
        internal_assert(starts_with(op->name, "halide_memoization"));
        return call_args_reference_buffer(op, buffer);
    }

    Expr visit(const Call *op) override {
        if ((op->name == "halide_memoization_cache_lookup") &&
            memoize_call_uses_buffer(op)) {
            // We need to guard call to halide_memoization_cache_lookup to only
            // be executed if the corresponding buffer is allocated. We ignore
            // the compute_predicate since in the case that alloc_predicate is
            // true but compute_predicate is false, the consumer would still load
            // data from the buffer even if it won't actually use the result,
            // hence, we need to allocate some scratch memory for the consumer
            // to load from. For memoized func, the memory might already be in
            // the cache, so we perform the lookup instead of allocating a new one.
            return Call::make(op->type, Call::if_then_else,
                              {alloc_predicate, op, 0}, Call::PureIntrinsic);
        } else if ((op->name == "halide_memoization_cache_store") &&
                   memoize_call_uses_buffer(op)) {
            // We need to wrap the halide_memoization_cache_store with the
            // compute_predicate, since the data to be written is only valid if
            // the producer of the buffer is executed.
            return Call::make(op->type, Call::if_then_else,
                              {compute_predicate, op, 0}, Call::PureIntrinsic);
        } else {
            return IRMutator::visit(op);
        }
    }

    Stmt visit(const ProducerConsumer *op) override {
        // Mutate the children first, then wrap the body of the
        // produce node for `buffer` in an IfThenElse on
        // compute_predicate. Predicates that depend on vectorized
        // variables are not handled here: StageSkipper replaces them
        // with const_true() and never constructs a ProductionGuarder
        // for them.
        Stmt stmt = IRMutator::visit(op);

        if (op->is_producer) {
            op = stmt.as<ProducerConsumer>();
            internal_assert(op);
            if (op->name == buffer) {
                Stmt body = IfThenElse::make(compute_predicate, op->body);
                stmt = ProducerConsumer::make(op->name, op->is_producer, body);
            }
        }

        return stmt;
    }
};

/** Mutator that finds the Realize node for `func`, computes the
 * predicates under which it must be computed and allocated, and, if
 * the compute predicate is not trivially true, rewrites the Realize
 * so that production and allocation are guarded by them. */
class StageSkipper : public IRMutator {
public:
    /** @param f  name of the Func to try to skip. */
    explicit StageSkipper(const string &f)
        : func(f) {
    }

private:
    string func;
    using IRMutator::visit;

    // Names that depend on an enclosing vectorized loop variable,
    // including the loop variables themselves.
    Scope<> vector_vars;
    bool in_vector_loop = false;

    Stmt visit(const For *op) override {
        const bool old_in_vector_loop = in_vector_loop;
        const bool is_vectorized = (op->for_type == ForType::Vectorized);

        // We want to be sure that the predicate doesn't vectorize.
        if (is_vectorized) {
            vector_vars.push(op->name);
            in_vector_loop = true;
        }

        Stmt stmt = IRMutator::visit(op);

        if (is_vectorized) {
            vector_vars.pop(op->name);
        }
        in_vector_loop = old_in_vector_loop;

        return stmt;
    }

    Stmt visit(const LetStmt *op) override {
        // Walks a chain of directly nested LetStmts in a loop,
        // recording which of them depend on a vectorized variable,
        // then mutates the innermost body and rebuilds the chain.
        struct Frame {
            const LetStmt *op;
            bool vector_var;
        };
        vector<Frame> frames;
        Stmt result;

        while (op) {
            bool vector_var = in_vector_loop && expr_uses_vars(op->value, vector_vars);
            frames.push_back({op, vector_var});
            if (vector_var) {
                vector_vars.push(op->name);
            }
            result = op->body;
            op = result.as<LetStmt>();
        }

        result = mutate(result);

        for (auto it = frames.rbegin(); it != frames.rend(); ++it) {
            if (it->vector_var) {
                vector_vars.pop(it->op->name);
            }
            result = LetStmt::make(it->op->name, it->op->value, result);
        }

        return result;
    }

    Stmt visit(const Realize *op) override {
        if (op->name != func) {
            return IRMutator::visit(op);
        }

        debug(3) << "Finding compute predicate for " << op->name << "\n";
        PredicateFinder find_compute(op->name, true);
        op->body.accept(&find_compute);

        debug(3) << "Simplifying compute predicate for " << op->name << ": " << find_compute.predicate << "\n";
        Expr compute_predicate = simplify(common_subexpression_elimination(find_compute.predicate));

        debug(3) << "Compute predicate for " << op->name << " : " << compute_predicate << "\n";

        if (expr_uses_vars(compute_predicate, vector_vars)) {
            // Don't try to skip stages if the predicate may vary
            // per lane. This will just unvectorize the
            // production, which is probably contrary to the
            // intent of the user.
            compute_predicate = const_true();
        }

        if (is_const_one(compute_predicate)) {
            // Always computed: nothing to guard.
            return IRMutator::visit(op);
        }

        debug(3) << "Finding allocate predicate for " << op->name << "\n";
        // Selects are not treated as guards for the allocation
        // predicate (second constructor argument is false).
        PredicateFinder find_alloc(op->name, false);
        op->body.accept(&find_alloc);

        debug(3) << "Simplifying allocate predicate for " << op->name << "\n";
        Expr alloc_predicate = simplify(common_subexpression_elimination(find_alloc.predicate));

        debug(3) << "Allocate predicate for " << op->name << " : " << alloc_predicate << "\n";

        ProductionGuarder g(op->name, compute_predicate, alloc_predicate);
        Stmt body = g.mutate(op->body);

        debug(3) << "Done guarding computation for " << op->name << "\n";

        return Realize::make(op->name, op->types, op->memory_type,
                             op->bounds, alloc_predicate, body);
    }
};

// Find Funcs where at least one of the consume nodes only uses the
// Func conditionally. We may want to guard the production of these
// Funcs.
class MightBeSkippable : public IRVisitor { using IRVisitor::visit; bool in_conditional_stmt{false}; void visit(const Call *op) override { IRVisitor::visit(op); if (op->call_type == Call::Halide) { unconditionally_used.insert(op->name); } } void visit(const IfThenElse *op) override { op->condition.accept(this); std::set old; unconditionally_used.swap(old); ScopedValue old_in_conditional(in_conditional_stmt, true); op->then_case.accept(this); std::set used_in_true; used_in_true.swap(unconditionally_used); if (op->else_case.defined()) { op->else_case.accept(this); } // Take the set intersection of the true and false paths, and add them to the set. std::set_intersection(used_in_true.begin(), used_in_true.end(), unconditionally_used.begin(), unconditionally_used.end(), std::inserter(old, old.begin())); unconditionally_used.swap(old); } void visit(const Select *op) override { op->condition.accept(this); std::set old; unconditionally_used.swap(old); op->true_value.accept(this); std::set used_in_true; used_in_true.swap(unconditionally_used); op->false_value.accept(this); // Again, take the set intersection std::set_intersection(used_in_true.begin(), used_in_true.end(), unconditionally_used.begin(), unconditionally_used.end(), std::inserter(old, old.begin())); unconditionally_used.swap(old); } void visit(const ProducerConsumer *op) override { if (!op->is_producer) { op->body.accept(this); if (!unconditionally_used.count(op->name) || in_conditional_stmt) { // This Func has a least one consume clause in which // it is only used conditionally. candidates.insert(op->name); } } else { IRVisitor::visit(op); // Calls inside the produce don't count - that's the block of code we intend to guard. unconditionally_used.erase(op->name); } } set unconditionally_used; public: set candidates; }; } // namespace Stmt skip_stages(Stmt stmt, const vector &order) { // Don't consider the last stage, because it's the output, so it's // never skippable. 
MightBeSkippable check; stmt.accept(&check); for (size_t i = order.size() - 1; i > 0; i--) { debug(2) << "skip_stages checking " << order[i - 1] << "\n"; if (check.candidates.count(order[i - 1])) { debug(2) << "skip_stages can skip " << order[i - 1] << "\n"; StageSkipper skipper(order[i - 1]); Stmt new_stmt = skipper.mutate(stmt); if (!new_stmt.same_as(stmt)) { // Might have made earlier stages skippable too new_stmt.accept(&check); } stmt = new_stmt; } } return stmt; } } // namespace Internal } // namespace Halide