https://github.com/halide/Halide
Raw File
Tip revision: b5e8217dfbe905ef1f30f7bd584a83dfb9e2260e authored by Steven Johnson on 19 January 2023, 19:40:52 UTC
Remove the watchdog timer from generator_main(). It was intended to kill pathologically slow builds, but in the environment it was added for (Google build servers), it ended up being redundant to existing mechanisms, and removing it allows us to remove a dependency on threading libraries in libHalide.
Tip revision: b5e8217
Substitute.cpp
#include "Substitute.h"
#include "IREquality.h"
#include "IRMutator.h"
#include "Scope.h"

namespace Halide {
namespace Internal {

using std::map;
using std::string;

namespace {

class Substitute : public IRMutator {
    const map<string, Expr> &replace;
    Scope<> hidden;

    Expr find_replacement(const string &s) {
        map<string, Expr>::const_iterator iter = replace.find(s);
        if (iter != replace.end() && !hidden.contains(s)) {
            return iter->second;
        } else {
            return Expr();
        }
    }

public:
    Substitute(const map<string, Expr> &m)
        : replace(m) {
    }

    using IRMutator::visit;

    Expr visit(const Variable *v) override {
        Expr r = find_replacement(v->name);
        if (r.defined()) {
            return r;
        } else {
            return v;
        }
    }

    template<typename T>
    auto visit_let(const T *op) -> decltype(op->body) {
        decltype(op->body) orig = op;

        struct Frame {
            const T *op;
            Expr new_value;
            ScopedBinding<> bind;
        };
        std::vector<Frame> frames;
        decltype(op->body) body;
        bool values_unchanged = true;
        do {
            Expr new_value = mutate(op->value);
            values_unchanged &= new_value.same_as(op->value);
            frames.push_back(Frame{op, std::move(new_value), ScopedBinding<>(hidden, op->name)});
            body = op->body;
            op = body.template as<T>();
        } while (op);

        auto new_body = mutate(body);

        if (values_unchanged &&
            new_body.same_as(body)) {
            return orig;
        } else {
            for (auto it = frames.rbegin(); it != frames.rend(); it++) {
                new_body = T::make(it->op->name, it->new_value, new_body);
            }
            return new_body;
        }
    }

    Expr visit(const Let *op) override {
        return visit_let(op);
    }

    Stmt visit(const LetStmt *op) override {
        return visit_let(op);
    }

    Stmt visit(const For *op) override {
        Expr new_min = mutate(op->min);
        Expr new_extent = mutate(op->extent);
        hidden.push(op->name);
        Stmt new_body = mutate(op->body);
        hidden.pop(op->name);

        if (new_min.same_as(op->min) &&
            new_extent.same_as(op->extent) &&
            new_body.same_as(op->body)) {
            return op;
        } else {
            return For::make(op->name, new_min, new_extent, op->for_type, op->device_api, new_body);
        }
    }
};

}  // namespace

Expr substitute(const string &name, const Expr &replacement, const Expr &expr) {
    map<string, Expr> m;
    m[name] = replacement;
    Substitute s(m);
    return s.mutate(expr);
}

Stmt substitute(const string &name, const Expr &replacement, const Stmt &stmt) {
    map<string, Expr> m;
    m[name] = replacement;
    Substitute s(m);
    return s.mutate(stmt);
}

Expr substitute(const map<string, Expr> &m, const Expr &expr) {
    Substitute s(m);
    return s.mutate(expr);
}

Stmt substitute(const map<string, Expr> &m, const Stmt &stmt) {
    Substitute s(m);
    return s.mutate(stmt);
}

namespace {

class SubstituteExpr : public IRMutator {
public:
    Expr find, replacement;

    using IRMutator::mutate;

    Expr mutate(const Expr &e) override {
        if (equal(e, find)) {
            return replacement;
        } else {
            return IRMutator::mutate(e);
        }
    }
};

}  // namespace

Expr substitute(const Expr &find, const Expr &replacement, const Expr &expr) {
    SubstituteExpr s;
    s.find = find;
    s.replacement = replacement;
    return s.mutate(expr);
}

Stmt substitute(const Expr &find, const Expr &replacement, const Stmt &stmt) {
    SubstituteExpr s;
    s.find = find;
    s.replacement = replacement;
    return s.mutate(stmt);
}

namespace {

/** Substitute an expr for a var in a graph. */
class GraphSubstitute : public IRGraphMutator {
    string var;
    Expr value;

    using IRGraphMutator::visit;

    Expr visit(const Variable *op) override {
        if (op->name == var) {
            return value;
        } else {
            return op;
        }
    }

    Expr visit(const Let *op) override {
        Expr new_value = mutate(op->value);
        if (op->name == var) {
            return Let::make(op->name, new_value, op->body);
        } else {
            return Let::make(op->name, new_value, mutate(op->body));
        }
    }

public:
    GraphSubstitute(const string &var, const Expr &value)
        : var(var), value(value) {
    }
};

/** Substitute an Expr for another Expr in a graph. Unlike substitute,
 * this only checks for shallow equality. */
class GraphSubstituteExpr : public IRGraphMutator {
    Expr find, replace;

public:
    using IRGraphMutator::mutate;

    Expr mutate(const Expr &e) override {
        if (e.same_as(find)) {
            return replace;
        } else {
            return IRGraphMutator::mutate(e);
        }
    }

    GraphSubstituteExpr(const Expr &find, const Expr &replace)
        : find(find), replace(replace) {
    }
};

}  // namespace

Expr graph_substitute(const string &name, const Expr &replacement, const Expr &expr) {
    return GraphSubstitute(name, replacement).mutate(expr);
}

Stmt graph_substitute(const string &name, const Expr &replacement, const Stmt &stmt) {
    return GraphSubstitute(name, replacement).mutate(stmt);
}

Expr graph_substitute(const Expr &find, const Expr &replacement, const Expr &expr) {
    return GraphSubstituteExpr(find, replacement).mutate(expr);
}

Stmt graph_substitute(const Expr &find, const Expr &replacement, const Stmt &stmt) {
    return GraphSubstituteExpr(find, replacement).mutate(stmt);
}

namespace {

class SubstituteInAllLets : public IRGraphMutator {

    using IRGraphMutator::visit;

    Expr visit(const Let *op) override {
        Expr value = mutate(op->value);
        Expr body = mutate(op->body);
        return graph_substitute(op->name, value, body);
    }
};

}  // namespace

Expr substitute_in_all_lets(const Expr &expr) {
    return SubstituteInAllLets().mutate(expr);
}

Stmt substitute_in_all_lets(const Stmt &stmt) {
    return SubstituteInAllLets().mutate(stmt);
}

}  // namespace Internal
}  // namespace Halide
back to top