// https://github.com/halide/Halide
// Raw File
// Tip revision: cdfd2b747f3deca917a4371036833076f602fe7c authored by Andrew Adams on 10 February 2016, 01:16:20 UTC
// Reschedule camera pipe
// Tip revision: cdfd2b7
// StorageFlattening.cpp
#include <sstream>

#include "StorageFlattening.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Scope.h"
#include "Bounds.h"

namespace Halide {
namespace Internal {

using std::ostringstream;
using std::string;
using std::vector;
using std::map;

namespace {
// IR visitor that records whether any Call node it encounters has
// call type Call::Image. Used by uses_extern_image below.
class UsesExternImage : public IRVisitor {
public:
    // Set to true once a Call with call_type == Call::Image is seen.
    bool result;

    UsesExternImage() : result(false) {}

private:
    using IRVisitor::visit;

    void visit(const Call *c) {
        if (c->call_type != Call::Image) {
            // Not an image call itself; keep walking via the base visitor.
            IRVisitor::visit(c);
            return;
        }
        result = true;
    }
};

// Returns true if walking the statement s with a UsesExternImage
// visitor finds a Call whose call type is Call::Image.
inline bool uses_extern_image(Stmt s) {
    UsesExternImage checker;
    s.accept(&checker);
    const bool found = checker.result;
    return found;
}
}

class FlattenDimensions : public IRMutator {
public:
    FlattenDimensions(const vector<Function> &outputs, const map<string, Function> &e)
        : outputs(outputs), env(e) {}
    Scope<int> scope;
private:
    const vector<Function> &outputs;
    const map<string, Function> &env;
    Scope<int> realizations;

    Expr flatten_args(const string &name, const vector<Expr> &args,
                      bool internal) {
        Expr idx = 0;
        vector<Expr> mins(args.size()), strides(args.size());

        for (size_t i = 0; i < args.size(); i++) {
            string dim = std::to_string(i);
            string stride_name = name + ".stride." + dim;
            string min_name = name + ".min." + dim;
            string stride_name_constrained = stride_name + ".constrained";
            string min_name_constrained = min_name + ".constrained";
            if (scope.contains(stride_name_constrained)) {
                stride_name = stride_name_constrained;
            }
            if (scope.contains(min_name_constrained)) {
                min_name = min_name_constrained;
            }
            strides[i] = Variable::make(Int(32), stride_name);
            mins[i] = Variable::make(Int(32), min_name);
        }

        if (internal) {
            // f(x, y) -> f[(x-xmin)*xstride + (y-ymin)*ystride] This
            // strategy makes sense when we expect x to cancel with
            // something in xmin.  We use this for internal allocations
            for (size_t i = 0; i < args.size(); i++) {
                idx += (args[i] - mins[i]) * strides[i];
            }
        } else {
            // f(x, y) -> f[x*stride + y*ystride - (xstride*xmin +
            // ystride*ymin)]. The idea here is that the last term
            // will be pulled outside the inner loop. We use this for
            // external buffers, where the mins and strides are likely
            // to be symbolic
            Expr base = 0;
            for (size_t i = 0; i < args.size(); i++) {
                idx += args[i] * strides[i];
                base += mins[i] * strides[i];
            }
            idx -= base;
        }

        return idx;
    }

    using IRMutator::visit;

    void visit(const Realize *realize) {
        realizations.push(realize->name, 1);

        Stmt body = mutate(realize->body);

        // Compute the size
        std::vector<Expr> extents;
        for (size_t i = 0; i < realize->bounds.size(); i++) {
          extents.push_back(realize->bounds[i].extent);
          extents[i] = mutate(extents[i]);
        }
        Expr condition = mutate(realize->condition);

        realizations.pop(realize->name);

        vector<int> storage_permutation;
        {
            map<string, Function>::const_iterator iter = env.find(realize->name);
            internal_assert(iter != env.end()) << "Realize node refers to function not in environment.\n";
            const vector<string> &storage_dims = iter->second.schedule().storage_dims();
            const vector<string> &args = iter->second.args();
            for (size_t i = 0; i < storage_dims.size(); i++) {
                for (size_t j = 0; j < args.size(); j++) {
                    if (args[j] == storage_dims[i]) {
                        storage_permutation.push_back((int)j);
                    }
                }
                internal_assert(storage_permutation.size() == i+1);
            }
        }

        internal_assert(storage_permutation.size() == realize->bounds.size());

        stmt = body;
        for (size_t idx = 0; idx < realize->types.size(); idx++) {
            string buffer_name = realize->name;
            if (realize->types.size() > 1) {
                buffer_name = buffer_name + '.' + std::to_string(idx);
            }

            // Make the names for the mins, extents, and strides
            int dims = realize->bounds.size();
            vector<string> min_name(dims), extent_name(dims), stride_name(dims);
            for (int i = 0; i < dims; i++) {
                string d = std::to_string(i);
                min_name[i] = buffer_name + ".min." + d;
                stride_name[i] = buffer_name + ".stride." + d;
                extent_name[i] = buffer_name + ".extent." + d;
            }
            vector<Expr> min_var(dims), extent_var(dims), stride_var(dims);
            for (int i = 0; i < dims; i++) {
                min_var[i] = Variable::make(Int(32), min_name[i]);
                extent_var[i] = Variable::make(Int(32), extent_name[i]);
                stride_var[i] = Variable::make(Int(32), stride_name[i]);
            }

            // Promote the type to be a multiple of 8 bits
            Type t = realize->types[idx].with_bits(realize->types[idx].bytes() * 8);

            // Create a buffer_t object for this allocation.
            vector<Expr> args(dims*3 + 2);
            //args[0] = Call::make(Handle(), Call::null_handle, vector<Expr>(), Call::Intrinsic);
            Expr first_elem = Load::make(t, buffer_name, 0, Buffer(), Parameter());
            args[0] = Call::make(Handle(), Call::address_of, {first_elem}, Call::Intrinsic);
            args[1] = make_zero(realize->types[idx]);
            for (int i = 0; i < dims; i++) {
                args[3*i+2] = min_var[i];
                args[3*i+3] = extent_var[i];
                args[3*i+4] = stride_var[i];
            }
            Expr buf = Call::make(Handle(), Call::create_buffer_t,
                                  args, Call::Intrinsic);
            stmt = LetStmt::make(buffer_name + ".buffer",
                                 buf,
                                 stmt);

            // Make the allocation node
            stmt = Allocate::make(buffer_name, t, extents, condition, stmt);

            // Compute the strides
            for (int i = (int)realize->bounds.size()-1; i > 0; i--) {
                int prev_j = storage_permutation[i-1];
                int j = storage_permutation[i];
                Expr stride = stride_var[prev_j] * extent_var[prev_j];
                stmt = LetStmt::make(stride_name[j], stride, stmt);
            }
            // Innermost stride is one
            if (dims > 0) {
                int innermost = storage_permutation.empty() ? 0 : storage_permutation[0];
                stmt = LetStmt::make(stride_name[innermost], 1, stmt);
            }

            // Assign the mins and extents stored
            for (size_t i = realize->bounds.size(); i > 0; i--) {
                stmt = LetStmt::make(min_name[i-1], realize->bounds[i-1].min, stmt);
                stmt = LetStmt::make(extent_name[i-1], realize->bounds[i-1].extent, stmt);
            }
        }
    }

    struct ProvideValue {
        Expr value;
        string name;
    };

    void flatten_provide_values(vector<ProvideValue> &values, const Provide *provide) {
        values.resize(provide->values.size());

        for (size_t i = 0; i < values.size(); i++) {
            Expr value = mutate(provide->values[i]);

            // Promote the type to be a multiple of 8 bits
            Type t = value.type().with_bits(value.type().bytes() * 8);
            if (t.bits() != value.type().bits()) {
                value = Cast::make(t, value);
            }

            values[i].value = value;
            if (values.size() > 1) {
                values[i].name = provide->name + "." + std::to_string(i);
            } else {
                values[i].name = provide->name;
            }
        }
    }

    // Lower a set of provides
    Stmt flatten_provide_atomic(const Provide *provide) {
        vector<ProvideValue> values;
        flatten_provide_values(values, provide);

        bool is_output = false;
        for (Function f : outputs) {
            is_output |= f.name() == provide->name;
            if (is_output) break;
        }

        Stmt result;
        for (size_t i = 0; i < values.size(); i++) {
            const ProvideValue &cv = values[i];

            Expr idx = mutate(flatten_args(cv.name, provide->args, !is_output));
            Expr var = Variable::make(cv.value.type(), cv.name + ".value");
            Stmt store = Store::make(cv.name, var, idx);

            if (result.defined()) {
                result = Block::make(result, store);
            } else {
                result = store;
            }
        }

        for (size_t i = values.size(); i > 0; i--) {
            const ProvideValue &cv = values[i-1];

            result = LetStmt::make(cv.name + ".value", cv.value, result);
        }
        return result;
    }

    Stmt flatten_provide(const Provide *provide) {
        vector<ProvideValue> values;
        flatten_provide_values(values, provide);

        bool is_output = false;
        for (Function f : outputs) {
            is_output |= f.name() == provide->name;
        }

        Stmt result;
        for (size_t i = 0; i < values.size(); i++) {
            const ProvideValue &cv = values[i];

            Expr idx = mutate(flatten_args(cv.name, provide->args, !is_output));
            Stmt store = Store::make(cv.name, cv.value, idx);

            if (result.defined()) {
                result = Block::make(result, store);
            } else {
                result = store;
            }
        }
        return result;
    }

    void visit(const Provide *provide) {
        Stmt result;

        // Handle the provide atomically if necessary. This logic is
        // currently very conservative, it will lower many provides
        // atomically that do not require it.
        if (provide->values.size() == 1) {
            // If there is only one value, we don't need to worry
            // about atomicity.
            result = flatten_provide(provide);
        } else if (!realizations.contains(provide->name) &&
                   uses_extern_image(provide)) {
            // If the provide is not a realization and it uses an
            // input image, it might be aliased. Flatten it atomically
            // because we can't prove the boxes don't overlap.
            result = flatten_provide_atomic(provide);
        } else {
            Box provided = box_provided(Stmt(provide), provide->name);
            Box required = box_required(Stmt(provide), provide->name);

            if (boxes_overlap(provided, required)) {
                // The boxes provided and required might overlap, so
                // the provide must be done atomically.
                result = flatten_provide_atomic(provide);
            } else {
                // The boxes don't overlap.
                result = flatten_provide(provide);
            }
        }
        stmt = result;
    }

    void visit(const Call *call) {

        if (call->call_type == Call::Extern || call->call_type == Call::Intrinsic) {
            vector<Expr> args(call->args.size());
            bool changed = false;
            for (size_t i = 0; i < args.size(); i++) {
                args[i] = mutate(call->args[i]);
                if (!args[i].same_as(call->args[i])) changed = true;
            }
            if (!changed) {
                expr = call;
            } else {
                expr = Call::make(call->type, call->name, args, call->call_type);
            }
        } else {
            string name = call->name;
            if (call->call_type == Call::Halide &&
                call->func.outputs() > 1) {
                name = name + '.' + std::to_string(call->value_index);
            }

            bool is_output = false;
            for (Function f : outputs) {
                is_output |= f.name() == call->name;
            }

            bool is_input = env.find(call->name) == env.end();

            // Promote the type to be a multiple of 8 bits
            Type t = call->type.with_bits(call->type.bytes() * 8);

            Expr idx = mutate(flatten_args(name, call->args, !(is_output || is_input)));
            expr = Load::make(t, name, idx, call->image, call->param);

            if (call->type.bits() != t.bits()) {
                expr = Cast::make(call->type, expr);
            }
        }
    }

    void visit(const LetStmt *let) {
        // Discover constrained versions of things.
        bool constrained_version_exists = ends_with(let->name, ".constrained");
        if (constrained_version_exists) {
            scope.push(let->name, 0);
        }

        IRMutator::visit(let);

        if (constrained_version_exists) {
            scope.pop(let->name);
        }
    }
};


// Entry point of the pass: runs a FlattenDimensions mutator, built
// from the given outputs and environment, over s and returns the
// mutated statement.
Stmt storage_flattening(Stmt s,
                        const vector<Function> &outputs,
                        const map<string, Function> &env) {
    FlattenDimensions flattener(outputs, env);
    return flattener.mutate(s);
}

}
}
// back to top