https://github.com/halide/Halide
Raw File
Tip revision: 1d07ebdaaa5359eba6f7a15a20bbfbeb9441214c authored by Andrew Adams on 08 February 2022, 21:55:58 UTC
Add comment
Tip revision: 1d07ebd
Substitute.cpp
#include "Substitute.h"
#include "IREquality.h"
#include "IRMutator.h"
#include "Scope.h"

namespace Halide {
namespace Internal {

using std::map;
using std::string;

namespace {

class Substitute : public IRMutator {
    const map<string, Expr> &replace;
    Scope<> hidden;

    Expr find_replacement(const string &s) {
        map<string, Expr>::const_iterator iter = replace.find(s);
        if (iter != replace.end() && !hidden.contains(s)) {
            return iter->second;
        } else {
            return Expr();
        }
    }

public:
    Substitute(const map<string, Expr> &m)
        : replace(m) {
    }

    using IRMutator::visit;

    Expr visit(const Variable *v) override {
        Expr r = find_replacement(v->name);
        if (r.defined()) {
            return r;
        } else {
            return v;
        }
    }

    template<typename T>
    auto visit_let(const T *op) -> decltype(op->body) {
        decltype(op->body) orig = op;

        struct Frame {
            const T *op;
            Expr new_value;
            ScopedBinding<> bind;
        };
        std::vector<Frame> frames;
        decltype(op->body) body;
        bool values_unchanged = true;
        do {
            Expr new_value = mutate(op->value);
            values_unchanged &= new_value.same_as(op->value);
            frames.push_back(Frame{op, std::move(new_value), ScopedBinding<>(hidden, op->name)});
            body = op->body;
            op = body.template as<T>();
        } while (op);

        auto new_body = mutate(body);

        if (values_unchanged &&
            new_body.same_as(body)) {
            return orig;
        } else {
            for (auto it = frames.rbegin(); it != frames.rend(); it++) {
                new_body = T::make(it->op->name, it->new_value, new_body);
            }
            return new_body;
        }
    }

    Expr visit(const Let *op) override {
        return visit_let(op);
    }

    Stmt visit(const LetStmt *op) override {
        return visit_let(op);
    }

    Stmt visit(const For *op) override {
        Expr new_min = mutate(op->min);
        Expr new_extent = mutate(op->extent);
        hidden.push(op->name);
        Stmt new_body = mutate(op->body);
        hidden.pop(op->name);

        if (new_min.same_as(op->min) &&
            new_extent.same_as(op->extent) &&
            new_body.same_as(op->body)) {
            return op;
        } else {
            return For::make(op->name, new_min, new_extent, op->for_type, op->device_api, new_body);
        }
    }
};

}  // namespace

Expr substitute(const string &name, const Expr &replacement, const Expr &expr) {
    map<string, Expr> m;
    m[name] = replacement;
    Substitute s(m);
    return s.mutate(expr);
}

Stmt substitute(const string &name, const Expr &replacement, const Stmt &stmt) {
    map<string, Expr> m;
    m[name] = replacement;
    Substitute s(m);
    return s.mutate(stmt);
}

Expr substitute(const map<string, Expr> &m, const Expr &expr) {
    Substitute s(m);
    return s.mutate(expr);
}

Stmt substitute(const map<string, Expr> &m, const Stmt &stmt) {
    Substitute s(m);
    return s.mutate(stmt);
}

namespace {

class SubstituteExpr : public IRMutator {
public:
    Expr find, replacement;

    using IRMutator::mutate;

    Expr mutate(const Expr &e) override {
        if (equal(e, find)) {
            return replacement;
        } else {
            return IRMutator::mutate(e);
        }
    }
};

}  // namespace

Expr substitute(const Expr &find, const Expr &replacement, const Expr &expr) {
    SubstituteExpr s;
    s.find = find;
    s.replacement = replacement;
    return s.mutate(expr);
}

Stmt substitute(const Expr &find, const Expr &replacement, const Stmt &stmt) {
    SubstituteExpr s;
    s.find = find;
    s.replacement = replacement;
    return s.mutate(stmt);
}

namespace {

/** Substitute an expr for a var in a graph. */
class GraphSubstitute : public IRGraphMutator {
    string var;
    Expr value;

    using IRGraphMutator::visit;

    Expr visit(const Variable *op) override {
        if (op->name == var) {
            return value;
        } else {
            return op;
        }
    }

    Expr visit(const Let *op) override {
        Expr new_value = mutate(op->value);
        if (op->name == var) {
            return Let::make(op->name, new_value, op->body);
        } else {
            return Let::make(op->name, new_value, mutate(op->body));
        }
    }

public:
    GraphSubstitute(const string &var, const Expr &value)
        : var(var), value(value) {
    }
};

/** Substitute an Expr for another Expr in a graph. Unlike substitute,
 * this only checks for shallow equality. */
class GraphSubstituteExpr : public IRGraphMutator {
    Expr find, replace;

public:
    using IRGraphMutator::mutate;

    Expr mutate(const Expr &e) override {
        if (e.same_as(find)) {
            return replace;
        } else {
            return IRGraphMutator::mutate(e);
        }
    }

    GraphSubstituteExpr(const Expr &find, const Expr &replace)
        : find(find), replace(replace) {
    }
};

}  // namespace

Expr graph_substitute(const string &name, const Expr &replacement, const Expr &expr) {
    return GraphSubstitute(name, replacement).mutate(expr);
}

Stmt graph_substitute(const string &name, const Expr &replacement, const Stmt &stmt) {
    return GraphSubstitute(name, replacement).mutate(stmt);
}

Expr graph_substitute(const Expr &find, const Expr &replacement, const Expr &expr) {
    return GraphSubstituteExpr(find, replacement).mutate(expr);
}

Stmt graph_substitute(const Expr &find, const Expr &replacement, const Stmt &stmt) {
    return GraphSubstituteExpr(find, replacement).mutate(stmt);
}

namespace {

class SubstituteInAllLets : public IRGraphMutator {

    using IRGraphMutator::visit;

    Expr visit(const Let *op) override {
        Expr value = mutate(op->value);
        Expr body = mutate(op->body);
        return graph_substitute(op->name, value, body);
    }
};

}  // namespace

Expr substitute_in_all_lets(const Expr &expr) {
    return SubstituteInAllLets().mutate(expr);
}

Stmt substitute_in_all_lets(const Stmt &stmt) {
    return SubstituteInAllLets().mutate(stmt);
}

}  // namespace Internal
}  // namespace Halide
back to top