https://github.com/halide/Halide
Tip revision: 1ebf2598bce0b92a457556dc619139dc7686ee37 authored by Andrew Adams on 14 February 2017, 21:59:30 UTC
Better Affine Var tracking
Better Affine Var tracking
Tip revision: 1ebf259
Prefetch.cpp
#include <algorithm>
#include <map>
#include <string>
#include "Prefetch.h"
#include "IRMutator.h"
#include "Bounds.h"
#include "Scope.h"
#include "Util.h"
namespace Halide {
namespace Internal {
using std::map;
using std::string;
using std::vector;
namespace {
// Prefetching needs a load from a buffer that refers to the same
// original image/param/etc. as the loads already present in a
// statement. This visitor looks for a call whose name matches
// buf_name and builds a copy of that call which differs only in its
// args. The result is stored in 'load'; it is left default-constructed
// when no matching call is visited. A matching call is not recursed
// into, and a later match overwrites an earlier one.
class MakeSimilarLoad : public IRVisitor {
public:
    const string &buf_name;
    const vector<Expr> &args;
    Expr load;

    MakeSimilarLoad(const string &name, const vector<Expr> &args)
        : buf_name(name), args(args) {}

private:
    using IRVisitor::visit;

    void visit(const Call *op) {
        if (op->name != buf_name) {
            // Not the buffer we are after; keep looking inside this call.
            IRVisitor::visit(op);
            return;
        }
        // Clone the matching call, swapping in the requested args.
        load = Call::make(op->type, op->name, args, op->call_type,
                          op->func, op->value_index, op->image, op->param);
    }
};
// Search statement s for a call to 'name' and return an equivalent
// call that uses 'args' as its arguments instead. The returned Expr
// is whatever the MakeSimilarLoad visitor left in its 'load' member.
Expr make_similar_load(Stmt s, const string &name, const vector<Expr> &args) {
    MakeSimilarLoad finder(name, args);
    s.accept(&finder);
    return finder.load;
}
// Build a Box describing the full extent of a buffer with 'dims'
// dimensions. Dimension i spans
// [<buf_name>.min.i, <buf_name>.min.i + <buf_name>.extent.i - 1],
// expressed with symbolic Int(32) variables of those names.
Box buffer_bounds(const string &buf_name, int dims) {
    Box result;
    for (int d = 0; d < dims; d++) {
        const string idx = std::to_string(d);
        Expr lo = Variable::make(Int(32), buf_name + ".min." + idx);
        Expr len = Variable::make(Int(32), buf_name + ".extent." + idx);
        result.push_back(Interval(lo, lo + len - 1));
    }
    return result;
}
// Mutator that injects prefetch calls into loop bodies. While inside a
// ProducerConsumer node, the prefetch directives from that function's
// schedule are active; any For loop whose name ends in "." + the
// directive's var gets prefetches prepended to its body for the
// buffers the body reads (but does not write) at loop_var + offset.
class InjectPrefetch : public IRMutator {
public:
    explicit InjectPrefetch(const map<string, Function> &e) : env(e) { }

private:
    // Environment used to look up the Function for each ProducerConsumer.
    const map<string, Function> &env;
    // Prefetch directives of the innermost enclosing ProducerConsumer,
    // or nullptr when not inside one.
    const vector<Prefetch> *prefetches = nullptr;
    // Bounds of the let-bound names and loop variables currently in scope.
    Scope<Interval> bounds;

private:
    using IRMutator::visit;

    // Track the bounds of the let-bound value while mutating the Let.
    void visit(const Let *op) {
        Interval in = bounds_of_expr_in_scope(op->value, bounds);
        bounds.push(op->name, in);
        IRMutator::visit(op);
        bounds.pop(op->name);
    }

    // Track the bounds of the let-bound value while mutating the LetStmt.
    void visit(const LetStmt *op) {
        Interval in = bounds_of_expr_in_scope(op->value, bounds);
        bounds.push(op->name, in);
        IRMutator::visit(op);
        bounds.pop(op->name);
    }

    // Switch the active prefetch directives to those of the function
    // named by this node for the duration of the visit, then restore.
    void visit(const ProducerConsumer *op) {
        const vector<Prefetch> *old_prefetches = prefetches;

        map<string, Function>::const_iterator iter = env.find(op->name);
        internal_assert(iter != env.end()) << "function not in environment.\n";
        prefetches = &iter->second.schedule().prefetches();
        IRMutator::visit(op);
        prefetches = old_prefetches;
    }

    // Return a statement that prefetches the region 'box' of buffer
    // 'buf_name' and then runs 'body'. 'body' must contain a call to
    // buf_name so that a similar load can be constructed from it. If
    // the box has no dimensions, 'body' is returned unchanged.
    Stmt add_prefetch(const string &buf_name, const Box &box, Stmt body) {
        // A zero-dimensional box has no dimension 0 from which to build
        // a prefetch extent; indexing prefetch_extent[0] below would be
        // an out-of-bounds access. Emit no prefetch in that case.
        if (box.size() == 0) {
            return body;
        }

        // Construct the bounds to be prefetched.
        vector<Expr> prefetch_min;
        vector<Expr> prefetch_extent;
        for (size_t i = 0; i < box.size(); i++) {
            prefetch_min.push_back(box[i].min);
            prefetch_extent.push_back(box[i].max - box[i].min + 1);
        }

        // Construct an array of index expressions to construct
        // address_of calls with. The first 2 dimensions are handled
        // by (up to) 2D prefetches, the rest we will generate loops
        // to define.
        vector<string> index_names(box.size());
        vector<Expr> indices(box.size());
        for (size_t i = 0; i < box.size(); i++) {
            index_names[i] = "prefetch_" + buf_name + "." + std::to_string(i);
            indices[i] = i < 2 ? prefetch_min[i] : Variable::make(Int(32), index_names[i]);
        }

        // Make a load at the index and get the address.
        Expr prefetch_load = make_similar_load(body, buf_name, indices);
        internal_assert(prefetch_load.defined());
        Type type = prefetch_load.type();
        Expr prefetch_addr = Call::make(Handle(), Call::address_of, {prefetch_load}, Call::Intrinsic);

        Stmt prefetch;
        Expr stride_0 = Variable::make(Int(32), buf_name + ".stride.0");
        // TODO: This is inefficient if stride_0 != 1, because memory
        // potentially not accessed will be prefetched, and it will be
        // fetched multiple times. The right way to handle this would
        // be to set up a prefetch for each individual element of the
        // buffer, in case it is sparse, and then try to optimize the
        // prefetch to fetch dense ranges of addresses. This is hard
        // to do statically.
        Expr extent_0_bytes = prefetch_extent[0] * stride_0 * type.bytes();
        if (box.size() == 1) {
            // The prefetch is only 1 dimensional, just emit a flat prefetch.
            prefetch = Evaluate::make(Call::make(Int(32), Call::prefetch,
                                                 {prefetch_addr, extent_0_bytes},
                                                 Call::PureIntrinsic));
        } else {
            // Make a 2D prefetch.
            Expr stride_1 = Variable::make(Int(32), buf_name + ".stride.1");
            Expr stride_1_bytes = stride_1 * type.bytes();
            prefetch = Evaluate::make(Call::make(Int(32), Call::prefetch_2d,
                                                 {prefetch_addr, extent_0_bytes, prefetch_extent[1], stride_1_bytes},
                                                 Call::PureIntrinsic));

            // Make loops for the rest of the dimensions (possibly zero).
            for (size_t i = 2; i < box.size(); i++) {
                prefetch = For::make(index_names[i], prefetch_min[i], prefetch_extent[i],
                                     ForType::Serial, DeviceAPI::None,
                                     prefetch);
            }
        }

        // We should only prefetch buffers that are used.
        if (box.maybe_unused()) {
            prefetch = IfThenElse::make(box.used, prefetch);
        }

        return Block::make({prefetch, body});
    }

    // Mutate the loop body, then, for each active prefetch directive
    // whose var matches this loop, prepend prefetches for the buffers
    // read in the body.
    void visit(const For *op) {
        // Add loop variable to interval scope for any inner loop prefetch
        Expr loop_var = Variable::make(Int(32), op->name);
        bounds.push(op->name, Interval(loop_var, loop_var));
        Stmt body = mutate(op->body);
        bounds.pop(op->name);

        if (prefetches) {
            for (const Prefetch &p : *prefetches) {
                if (!ends_with(op->name, "." + p.var)) {
                    continue;
                }
                // Add loop variable + prefetch offset to interval scope for box computation
                Expr fetch_at = loop_var + p.offset;
                bounds.push(op->name, Interval(fetch_at, fetch_at));
                map<string, Box> boxes_read = boxes_required(body, bounds);
                bounds.pop(op->name);

                // Don't prefetch buffers that are written to. We assume that these already
                // have good locality.
                // TODO: This is not a good assumption. It would be better to have the
                // prefetch directive specify the buffer that we want to prefetch, instead
                // of trying to figure out which buffers should be prefetched. This would also
                // mean that we don't need the "make_similar_load" hack, because we can make
                // calls the standard way (using the ImageParam/Function object referenced in
                // the prefetch).
                map<string, Box> boxes_written = boxes_provided(body, bounds);
                for (const auto &b : boxes_written) {
                    auto it = boxes_read.find(b.first);
                    if (it != boxes_read.end()) {
                        debug(2) << "Not prefetching buffer " << it->first
                                 << " also written in loop " << op->name << "\n";
                        boxes_read.erase(it);
                    }
                }

                // TODO: Only prefetch the newly accessed data from the previous iteration.
                // This should use boxes_touched (instead of boxes_required) so we exclude memory
                // either read or written.
                for (const auto &b : boxes_read) {
                    const string &buf_name = b.first;

                    // Only prefetch the region that is in bounds. The local is
                    // deliberately not called 'bounds' so that it does not
                    // shadow the member scope of the same name.
                    Box buf_bounds = buffer_bounds(buf_name, b.second.size());
                    Box prefetch_box = box_intersection(b.second, buf_bounds);

                    body = add_prefetch(buf_name, prefetch_box, body);
                }
            }
        }

        // Only rebuild the loop if the body actually changed.
        if (!body.same_as(op->body)) {
            stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
        } else {
            stmt = op;
        }
    }
};
} // namespace
// Entry point of the pass: run the InjectPrefetch mutator over s,
// using env to look up the Function for each ProducerConsumer node,
// and return the mutated statement.
Stmt inject_prefetch(Stmt s, const map<string, Function> &env) {
    InjectPrefetch injector(env);
    return injector.mutate(s);
}
}
}