#include "Reduction.h" #include "IR.h" #include "IREquality.h" #include "IRMutator.h" #include "IROperator.h" #include "IRVisitor.h" #include "Simplify.h" #include "Var.h" #include namespace Halide { namespace Internal { namespace { void check(const Expr &pred, std::vector &expected) { std::vector result; split_into_ands(pred, result); bool is_equal = true; if (result.size() != expected.size()) { is_equal = false; } else { for (size_t i = 0; i < expected.size(); ++i) { if (!equal(simplify(result[i]), simplify(expected[i]))) { is_equal = false; break; } } } if (!is_equal) { std::cout << "Expect predicate " << pred << " to be split into:\n"; for (const auto &e : expected) { std::cout << " " << e << "\n"; } std::cout << "Got:\n"; for (const auto &e : result) { std::cout << " " << e << "\n"; } internal_error << "\n"; } } } // namespace void split_predicate_test() { Expr x = Var("x"), y = Var("y"), z = Var("z"), w = Var("w"); { std::vector expected; expected.push_back(z < 10); check(z < 10, expected); } { std::vector expected; expected.push_back((x < y) || (x == 10)); check((x < y) || (x == 10), expected); } { std::vector expected; expected.push_back(x < y); expected.push_back(x == 10); check((x < y) && (x == 10), expected); } { std::vector expected; expected.push_back(x < y); expected.push_back(x == 10); expected.push_back(y == z); check((x < y) && (x == 10) && (y == z), expected); } { std::vector expected; expected.push_back((w == 1) || ((x == 10) && (y == z))); check((w == 1) || ((x == 10) && (y == z)), expected); } { std::vector expected; expected.push_back(x < y); expected.push_back((w == 1) || ((x == 10) && (y == z))); check((x < y) && ((w == 1) || ((x == 10) && (y == z))), expected); } std::cout << "Split predicate test passed" << std::endl; } struct ReductionDomainContents { mutable RefCount ref_count; std::vector domain; Expr predicate; bool frozen; ReductionDomainContents() : predicate(const_true()), frozen(false) { } // Pass an IRVisitor through to all Exprs referenced in the ReductionDomainContents void accept(IRVisitor *visitor) { for (const ReductionVariable &rvar : domain) { if (rvar.min.defined()) { rvar.min.accept(visitor); } if (rvar.extent.defined()) { rvar.extent.accept(visitor); } } if (predicate.defined()) { predicate.accept(visitor); } } // Pass an IRMutator through to all Exprs referenced in the ReductionDomainContents void mutate(IRMutator *mutator) { for (ReductionVariable &rvar : domain) { if (rvar.min.defined()) { rvar.min = mutator->mutate(rvar.min); } if (rvar.extent.defined()) { rvar.extent = mutator->mutate(rvar.extent); } } if (predicate.defined()) { predicate = mutator->mutate(predicate); } } }; template<> RefCount &ref_count(const ReductionDomainContents *p) noexcept { return p->ref_count; } template<> void destroy(const ReductionDomainContents *p) { delete p; } ReductionDomain::ReductionDomain(const std::vector &domain) : contents(new ReductionDomainContents) { contents->domain = domain; } ReductionDomain ReductionDomain::deep_copy() const { if (!contents.defined()) { return ReductionDomain(); } ReductionDomain copy(contents->domain); copy.contents->predicate = contents->predicate; copy.contents->frozen = contents->frozen; return copy; } const std::vector &ReductionDomain::domain() const { return contents->domain; } namespace { class DropSelfReferences : public IRMutator { using IRMutator::visit; Expr visit(const Variable *op) override { if (op->reduction_domain.defined()) { user_assert(op->reduction_domain.same_as(domain)) << "An RDom's predicate may only refer to its own RVars, " << " not the RVars of some other RDom. " << "Cannot set the predicate to : " << predicate << "\n"; return Variable::make(op->type, op->name); } else { return op; } } public: Expr predicate; const ReductionDomain &domain; DropSelfReferences(Expr p, const ReductionDomain &d) : predicate(std::move(p)), domain(d) { } }; } // namespace void ReductionDomain::set_predicate(const Expr &p) { // The predicate can refer back to the RDom. We need to break // those cycles to prevent a leak. contents->predicate = DropSelfReferences(p, *this).mutate(p); } void ReductionDomain::where(Expr predicate) { set_predicate(simplify(contents->predicate && std::move(predicate))); } Expr ReductionDomain::predicate() const { return contents->predicate; } std::vector ReductionDomain::split_predicate() const { std::vector predicates; split_into_ands(contents->predicate, predicates); return predicates; } void ReductionDomain::freeze() { contents->frozen = true; } bool ReductionDomain::frozen() const { return contents->frozen; } void ReductionDomain::accept(IRVisitor *visitor) const { if (contents.defined()) { contents->accept(visitor); } } void ReductionDomain::mutate(IRMutator *mutator) { if (contents.defined()) { contents->mutate(mutator); } } } // namespace Internal } // namespace Halide