Raw File
SplitTuples.cpp
#include "SplitTuples.h"
#include "Bounds.h"
#include "IRMutator.h"

namespace Halide {
namespace Internal {

using std::map;
using std::pair;
using std::set;
using std::string;
using std::vector;

namespace {

// Collect all value indices of internal halide calls.
class FindCallValueIndices : public IRVisitor {
public:
    const string func;
    map<string, set<int>> func_value_indices;

    using IRVisitor::visit;

    void visit(const Call *call) {
        IRVisitor::visit(call);
        if ((call->call_type == Call::Halide) && call->func.defined()) {
            func_value_indices[call->name].insert(call->value_index);
        }
    }
};

// Visitor and helper function to test if a piece of IR uses an extern image.
class UsesExternImage : public IRVisitor {
    using IRVisitor::visit;

    void visit(const Call *c) {
        if (c->call_type == Call::Image) {
            result = true;
        } else {
            IRVisitor::visit(c);
        }
    }
public:
    UsesExternImage() : result(false) {}
    bool result;
};

inline bool uses_extern_image(Stmt s) {
    UsesExternImage uses;
    s.accept(&uses);
    return uses.result;
}

class SplitTuples : public IRMutator2 {
    using IRMutator2::visit;

    map<string, set<int>> func_value_indices;

    Stmt visit(const Realize *op) override {
        ScopedBinding<int> bind(realizations, op->name, 0);
        if (op->types.size() > 1) {
            // Make a nested set of realize nodes for each tuple element
            Stmt body = mutate(op->body);
            for (int i = (int)op->types.size() - 1; i >= 0; i--) {
                body = Realize::make(op->name + "." + std::to_string(i), {op->types[i]}, op->memory_type, op->bounds, op->condition, body);
            }
            return body;
        } else {
            return IRMutator2::visit(op);
        }
    }

    Stmt visit(const For *op) override {
        map<string, set<int>> old_func_value_indices = func_value_indices;

        FindCallValueIndices find;
        op->body.accept(&find);

        func_value_indices = find.func_value_indices;
        Stmt stmt = IRMutator2::visit(op);
        func_value_indices = old_func_value_indices;
        return stmt;
    }

    Stmt visit(const Prefetch *op) override {
        if (!op->prefetch.param.defined() && (op->types.size() > 1)) {
            Stmt body = mutate(op->body);
            // Split the prefetch from a multi-dimensional halide tuple to
            // prefetches of each tuple element. Keep only prefetches of
            // elements that are actually used in the loop body.
            const auto &indices = func_value_indices.find(op->name);
            internal_assert(indices != func_value_indices.end());

            for (const auto &idx : indices->second) {
                internal_assert(idx < (int)op->types.size());
                body = Prefetch::make(op->name + "." + std::to_string(idx), {op->types[(idx)]}, op->bounds, op->prefetch, op->condition, body);
            }
            return body;
        } else {
            return IRMutator2::visit(op);
        }
    }

    Expr visit(const Call *op) override {
        if (op->call_type == Call::Halide) {
            auto it = env.find(op->name);
            internal_assert(it != env.end());
            Function f = it->second;
            string name = op->name;
            if (f.outputs() > 1) {
                name += "." + std::to_string(op->value_index);
            }
            vector<Expr> args;
            for (Expr e : op->args) {
                args.push_back(mutate(e));
            }
            // It's safe to hook up the pointer to the function
            // unconditionally. This expr never gets held by a
            // Function, so there can't be a cycle. We do this even
            // for scalar provides.
            return Call::make(op->type, name, args, op->call_type, f.get_contents());
        } else {
            return IRMutator2::visit(op);
        }
    }

    Stmt visit(const Provide *op) override {
        if (op->values.size() == 1) {
            return IRMutator2::visit(op);
        }

        // Detect if the provide needs to be lowered atomically. By
        // this we mean can we compute and store the values one at a
        // time (not atomic), or must we compute them all, and then
        // store them all (atomic).
        bool atomic = false;
        if (!realizations.contains(op->name) &&
            uses_extern_image(op)) {
            // If the provide is an output (it's not inside a
            // realization), and it uses an input, then the input
            // might alias with the output. We'd better just do it
            // atomically.
            atomic = true;
        } else {
            // If the boxes provided and required might overlap,
            // the provide must be done atomically.
            Box provided = box_provided(op, op->name);
            Box required = box_required(op, op->name);
            atomic = boxes_overlap(provided, required);
        }

        // Mutate the args
        vector<Expr> args;
        for (Expr e : op->args) {
            args.push_back(mutate(e));
        }

        // Get the Function
        auto it = env.find(op->name);
        internal_assert(it != env.end());
        Function f = it->second;

        // Build a list of scalar provide statements, and a list of
        // lets to wrap them.
        vector<Stmt> provides;
        vector<pair<string, Expr>> lets;

        for (size_t i = 0; i < op->values.size(); i++) {
            string name = op->name + "." + std::to_string(i);
            string var_name = name + ".value";
            Expr val = mutate(op->values[i]);
            if (!is_undef(val) && atomic) {
                lets.push_back({ var_name, val });
                val = Variable::make(val.type(), var_name);
            }
            provides.push_back(Provide::make(name, {val}, args));
        }

        Stmt result = Block::make(provides);

        while (!lets.empty()) {
            auto p = lets.back();
            lets.pop_back();
            result = LetStmt::make(p.first, p.second, result);
        }

        return result;
    }

    const map<string, Function> &env;
    Scope<int> realizations;

public:

    SplitTuples(const map<string, Function> &e) : env(e) {}
};

}  // namespace

Stmt split_tuples(Stmt s, const map<string, Function> &env) {
    return SplitTuples(env).mutate(s);
}

}  // namespace Internal
}  // namespace Halide
back to top