Raw File
UniquifyVariableNames.cpp
#include "UniquifyVariableNames.h"
#include "IREquality.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "IRVisitor.h"
#include "Scope.h"
#include "Var.h"
#include <sstream>

namespace Halide {
namespace Internal {

using std::pair;
using std::string;
using std::vector;

namespace {
class UniquifyVariableNames : public IRMutator {

    using IRMutator::visit;

    // The mapping from old names to new names
    Scope<string> renaming;

    // Get a new previously unused name for a let binding or for loop,
    // and push it onto the renaming. Will return the original name if
    // possible, but pushes unconditionally to simplify cleanup.
    string make_new_name(const string &base) {
        if (!renaming.contains(base)) {
            renaming.push(base, base);
            return base;
        }
        for (size_t i = std::max((size_t)1, renaming.count(base));; i++) {
            string candidate = base + "_" + std::to_string(i);
            if (!renaming.contains(candidate)) {
                // Reserve this name for this base name
                renaming.push(base, candidate);
                // And reserve the generated name forever more (will not be popped)
                renaming.push(candidate, candidate);
                return candidate;
            }
        }
    }

    template<typename LetOrLetStmt>
    auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) {
        struct Frame {
            const LetOrLetStmt *op;
            Expr value;
            string new_name;
        };

        vector<Frame> frames;
        decltype(op->body) result;
        while (op) {
            frames.emplace_back();
            auto &f = frames.back();
            f.op = op;
            f.value = mutate(op->value);
            f.new_name = make_new_name(op->name);
            result = op->body;
            op = result.template as<LetOrLetStmt>();
        }

        result = mutate(result);

        for (auto it = frames.rbegin(); it != frames.rend(); it++) {
            renaming.pop(it->op->name);
            if (it->new_name == it->op->name &&
                result.same_as(it->op->body) &&
                it->op->value.same_as(it->value)) {
                result = it->op;
            } else {
                result = LetOrLetStmt::make(it->new_name, it->value, result);
            }
        }

        return result;
    }

    Stmt visit(const LetStmt *op) override {
        return visit_let(op);
    }

    Expr visit(const Let *op) override {
        return visit_let(op);
    }

    Stmt visit(const For *op) override {
        Expr min = mutate(op->min);
        Expr extent = mutate(op->extent);
        string new_name = make_new_name(op->name);
        Stmt body = mutate(op->body);
        renaming.pop(op->name);

        if (new_name == op->name &&
            body.same_as(op->body) &&
            min.same_as(op->min) &&
            extent.same_as(op->extent)) {
            return op;
        } else {
            return For::make(new_name, min, extent, op->for_type, op->device_api, body);
        }
    }

    Expr visit(const Variable *op) override {
        if (renaming.contains(op->name)) {
            string new_name = renaming.get(op->name);
            if (new_name != op->name) {
                return Variable::make(op->type, new_name);
            }
        }
        return op;
    }

public:
    UniquifyVariableNames(const Scope<string> *free_vars) {
        renaming.set_containing_scope(free_vars);
    }
};

class FindFreeVars : public IRVisitor {

    using IRVisitor::visit;

    Scope<> scope;

    void visit(const Variable *op) override {
        if (!scope.contains(op->name)) {
            free_vars.push(op->name, op->name);
        }
    }

    template<typename T>
    void visit_let(const T *op) {
        vector<ScopedBinding<>> frame;
        decltype(op->body) body;
        do {
            op->value.accept(this);
            frame.emplace_back(scope, op->name);
            body = op->body;
            op = body.template as<T>();
        } while (op);
        body.accept(this);
    }

    void visit(const Let *op) override {
        visit_let(op);
    }

    void visit(const LetStmt *op) override {
        visit_let(op);
    }

    void visit(const For *op) override {
        op->min.accept(this);
        op->extent.accept(this);
        {
            ScopedBinding<> bind(scope, op->name);
            op->body.accept(this);
        }
    }

public:
    Scope<string> free_vars;
};
}  // namespace

Stmt uniquify_variable_names(const Stmt &s) {
    FindFreeVars finder;
    s.accept(&finder);
    UniquifyVariableNames u(&finder.free_vars);
    return u.mutate(s);
}

void check(vector<pair<Var, Expr>> in,
           vector<pair<Var, Expr>> out) {
    Stmt in_stmt = Evaluate::make(0), out_stmt = Evaluate::make(0);
    for (auto it = in.rbegin(); it != in.rend(); it++) {
        in_stmt = LetStmt::make(it->first.name(), it->second, in_stmt);
    }
    for (auto it = out.rbegin(); it != out.rend(); it++) {
        out_stmt = LetStmt::make(it->first.name(), it->second, out_stmt);
    }

    Stmt s = uniquify_variable_names(in_stmt);

    internal_assert(equal(s, out_stmt))
        << "Failure in uniquify_variable_names\n"
        << "Input:\n"
        << in_stmt << "\n"
        << "Produced:\n"
        << s << "\n"
        << "Correct output:\n"
        << out_stmt << "\n";
}

void uniquify_variable_names_test() {
    Var x("x"), x_1("x_1"), x_2("x_2"), x_3{"x_3"};
    Var y("y"), y_1("y_1"), y_2("y_2"), y_3{"y_3"};

    // Stmts with all names already unique should be unchanged
    check({{x, 3},
           {y, x}},
          {{x, 3},
           {y, x}});

    // Shadowed definitions of Vars should be given unique names
    check({{x, 3},
           {y, x},
           {x, x + y},
           {y, x + y},
           {x, x + y},
           {y, x + y}},
          {{x, 3},
           {y, x},
           {x_1, x + y},
           {y_1, x_1 + y},
           {x_2, x_1 + y_1},
           {y_2, x_2 + y_1}});

    // Check a case with a free var after then end of the scope of a let of the same name
    check({{x, Let::make(y.name(), 3, y)},      // y is bound
           {x, y}},                             // This is not the same y. It's free and can't be renamed.
          {{x, Let::make(y_1.name(), 3, y_1)},  // We rename the bound one
           {x_1, y}});

    // An existing in-scope use of one of the names that would be
    // autogenerated should be skipped over
    check({{x_1, 8},
           {x, 3},
           {y, x},
           {x, x + y},
           {y, x + y},
           {x, x + y},
           {y, x + y}},
          {{x_1, 8},
           {x, 3},
           {y, x},
           {x_2, x + y},
           {y_1, x_2 + y},
           {x_3, x_2 + y_1},
           {y_2, x_3 + y_1}});

    // Check parallel bindings. The scope doesn't overlap so they can keep their name
    check({{x, Let::make(y.name(), 3, y)},
           {x, Let::make(y.name(), 4, y)}},
          {{x, Let::make(y.name(), 3, y)},
           {x_1, Let::make(y.name(), 4, y)}});

    std::cout << "uniquify_variable_names test passed" << std::endl;
}

}  // namespace Internal
}  // namespace Halide
back to top