#include "Substitute.h"
#include "IREquality.h"
#include "IRMutator.h"
#include "Scope.h"

namespace Halide {
namespace Internal {

using std::map;
using std::string;

namespace {

/** Mutator that replaces Variable nodes by name using a caller-supplied
 * name -> Expr map. Names that are re-bound by an enclosing Let, LetStmt
 * or For are recorded in `hidden` while their body is being visited and
 * are not replaced there. */
class Substitute : public IRMutator {
    // Held by reference: the map must outlive this mutator.
    const map<string, Expr> &replace;

    // Names currently shadowed by an enclosing binding.
    Scope<> hidden;

    /** Return the replacement for `s`, or an undefined Expr if `s` has no
     * entry in the map or is currently hidden. */
    Expr find_replacement(const string &s) {
        auto iter = replace.find(s);
        if (iter != replace.end() && !hidden.contains(s)) {
            return iter->second;
        } else {
            return Expr();
        }
    }

public:
    explicit Substitute(const map<string, Expr> &m)
        : replace(m) {
    }

    using IRMutator::visit;

    /** Replace the variable if a visible replacement exists; otherwise
     * return the node unchanged. */
    Expr visit(const Variable *v) override {
        Expr r = find_replacement(v->name);
        if (r.defined()) {
            return r;
        } else {
            return v;
        }
    }

    /** Shared implementation for Let and LetStmt (T is the node type; the
     * return type is whatever type T's body has).
     *
     * A chain of directly nested lets of the same kind is walked with a
     * loop instead of one recursive visit per let. Each let's value is
     * mutated before its own name is hidden, so the value still sees any
     * outer replacement for that name. */
    template<typename T>
    auto visit_let(const T *op) -> decltype(op->body) {
        decltype(op->body) orig = op;

        // One frame per let in the chain. The ScopedBinding keeps the let's
        // name in `hidden` until `frames` is destroyed on return.
        struct Frame {
            const T *op;
            Expr new_value;
            ScopedBinding<> bind;
        };
        std::vector<Frame> frames;
        decltype(op->body) body;
        bool values_unchanged = true;
        do {
            Expr new_value = mutate(op->value);
            values_unchanged &= new_value.same_as(op->value);
            frames.push_back(Frame{op, std::move(new_value), ScopedBinding<>(hidden, op->name)});
            body = op->body;
            // Continue only while the body is itself a let of the same kind.
            op = body.template as<T>();
        } while (op);

        auto new_body = mutate(body);

        if (values_unchanged && new_body.same_as(body)) {
            // Nothing changed anywhere in the chain: hand back the original node.
            return orig;
        } else {
            // Rebuild the chain from the innermost let outwards.
            for (auto it = frames.rbegin(); it != frames.rend(); ++it) {
                new_body = T::make(it->op->name, it->new_value, new_body);
            }
            return new_body;
        }
    }

    Expr visit(const Let *op) override {
        return visit_let(op);
    }

    Stmt visit(const LetStmt *op) override {
        return visit_let(op);
    }

    /** The loop variable shadows any replacement with the same name inside
     * the loop body; min and extent are mutated before the name is hidden. */
    Stmt visit(const For *op) override {
        Expr new_min = mutate(op->min);
        Expr new_extent = mutate(op->extent);

        Stmt new_body;
        {
            // Same binding helper as visit_let: the name is popped when this
            // scope is left, however it is left.
            ScopedBinding<> bind(hidden, op->name);
            new_body = mutate(op->body);
        }

        if (new_min.same_as(op->min) &&
            new_extent.same_as(op->extent) &&
            new_body.same_as(op->body)) {
            return op;
        } else {
            return For::make(op->name, new_min, new_extent, op->for_type, op->device_api, new_body);
        }
    }
};

}  // namespace

/** Substitute `replacement` for variables named `name` in `expr`. */
Expr substitute(const string &name, const Expr &replacement, const Expr &expr) {
    map<string, Expr> m;
    m[name] = replacement;
    Substitute s(m);
    return s.mutate(expr);
}

/** Substitute `replacement` for variables named `name` in `stmt`. */
Stmt substitute(const string &name, const Expr &replacement, const Stmt &stmt) {
    map<string, Expr> m;
    m[name] = replacement;
    Substitute s(m);
    return s.mutate(stmt);
}

/** Apply every name -> Expr replacement in `m` to `expr`. */
Expr substitute(const map<string, Expr> &m, const Expr &expr) {
    Substitute s(m);
    return s.mutate(expr);
}

/** Apply every name -> Expr replacement in `m` to `stmt`. */
Stmt substitute(const map<string, Expr> &m, const Stmt &stmt) {
    Substitute s(m);
    return s.mutate(stmt);
}

namespace {

/** Mutator that replaces every sub-expression for which
 * `equal(e, find)` holds with `replacement`. A matched expression is
 * returned as-is; the mutator does not descend into it. */
class SubstituteExpr : public IRMutator {
public:
    Expr find, replacement;

    using IRMutator::mutate;

    Expr mutate(const Expr &e) override {
        if (equal(e, find)) {
            return replacement;
        } else {
            return IRMutator::mutate(e);
        }
    }
};

}  // namespace

/** Replace occurrences of the expression `find` in `expr` with `replacement`. */
Expr substitute(const Expr &find, const Expr &replacement, const Expr &expr) {
    SubstituteExpr s;
    s.find = find;
    s.replacement = replacement;
    return s.mutate(expr);
}

/** Replace occurrences of the expression `find` in `stmt` with `replacement`. */
Stmt substitute(const Expr &find, const Expr &replacement, const Stmt &stmt) {
    SubstituteExpr s;
    s.find = find;
    s.replacement = replacement;
    return s.mutate(stmt);
}

namespace {

/** Substitute an expr for a var in a graph. */
class GraphSubstitute : public IRGraphMutator {
    string var;
    Expr value;

    using IRGraphMutator::visit;

    Expr visit(const Variable *op) override {
        if (op->name == var) {
            return value;
        } else {
            return op;
        }
    }

    // A Let that re-binds `var` shadows it: its value is still mutated but
    // its body is left untouched. The Let node is always rebuilt.
    // NOTE(review): only Let is special-cased in this class; LetStmt and For
    // bindings of `var` are not treated as shadowing here. Confirm that is
    // intended for the Stmt overload of graph_substitute.
    Expr visit(const Let *op) override {
        Expr new_value = mutate(op->value);
        if (op->name == var) {
            return Let::make(op->name, new_value, op->body);
        } else {
            return Let::make(op->name, new_value, mutate(op->body));
        }
    }

public:
    GraphSubstitute(const string &var, const Expr &value)
        : var(var), value(value) {
    }
};

/** Substitute an Expr for another Expr in a graph. Unlike substitute,
 * this only checks for shallow equality. */
class GraphSubstituteExpr : public IRGraphMutator {
    Expr find, replace;

public:
    using IRGraphMutator::mutate;

    Expr mutate(const Expr &e) override {
        // same_as, not equal(): see the class comment.
        if (e.same_as(find)) {
            return replace;
        } else {
            return IRGraphMutator::mutate(e);
        }
    }

    GraphSubstituteExpr(const Expr &find, const Expr &replace)
        : find(find), replace(replace) {
    }
};

}  // namespace

/** Graph-mutator form of substitute-by-name for an Expr. */
Expr graph_substitute(const string &name, const Expr &replacement, const Expr &expr) {
    return GraphSubstitute(name, replacement).mutate(expr);
}

/** Graph-mutator form of substitute-by-name for a Stmt. */
Stmt graph_substitute(const string &name, const Expr &replacement, const Stmt &stmt) {
    return GraphSubstitute(name, replacement).mutate(stmt);
}

/** Graph-mutator form of substitute-by-expression (matched with same_as) for an Expr. */
Expr graph_substitute(const Expr &find, const Expr &replacement, const Expr &expr) {
    return GraphSubstituteExpr(find, replacement).mutate(expr);
}

/** Graph-mutator form of substitute-by-expression (matched with same_as) for a Stmt. */
Stmt graph_substitute(const Expr &find, const Expr &replacement, const Stmt &stmt) {
    return GraphSubstituteExpr(find, replacement).mutate(stmt);
}

namespace {

/** Removes Let expressions by substituting each let's value for its name
 * in the let's body. Value and body are both mutated first, so nested
 * lets are already gone when the substitution runs. Only the Let visitor
 * is overridden here. */
class SubstituteInAllLets : public IRGraphMutator {

    using IRGraphMutator::visit;

    Expr visit(const Let *op) override {
        Expr value = mutate(op->value);
        Expr body = mutate(op->body);
        return graph_substitute(op->name, value, body);
    }
};

}  // namespace

/** Substitute the value of every Let expression in `expr` into its body. */
Expr substitute_in_all_lets(const Expr &expr) {
    return SubstituteInAllLets().mutate(expr);
}

/** Substitute the value of every Let expression found in `stmt` into its body. */
Stmt substitute_in_all_lets(const Stmt &stmt) {
    return SubstituteInAllLets().mutate(stmt);
}

}  // namespace Internal
}  // namespace Halide