https://github.com/halide/Halide
Tip revision: 3e80492fe120f2d02cc014b1d73652e755c7ba11 authored by Marcos Slomp on 02 August 2023, 17:14:49 UTC
replacing ostringstream sentinel by nil ostream
StorageFolding.cpp
#include "StorageFolding.h"
#include "Bounds.h"
#include "CSE.h"
#include "Debug.h"
#include "ExprUsesVar.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "IRPrinter.h"
#include "Monotonic.h"
#include "Simplify.h"
#include "Substitute.h"
#include <utility>
namespace Halide {
namespace Internal {
namespace {
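// Round x up to the next power of two; x is assumed to be positive.
// For example, next_power_of_two(5) == 8 and next_power_of_two(8) == 8.
// Automatically chosen fold factors are rounded up with this, presumably
// so that the folding modulus can be lowered to a cheap bitwise mask.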
int64_t next_power_of_two(int64_t x) {
return static_cast<int64_t>(1) << static_cast<int64_t>(std::ceil(std::log2(x)));
}
using std::map;
using std::string;
using std::vector;
// Count the number of producers of a particular func.
class CountProducers : public IRVisitor {
const std::string &name;
void visit(const ProducerConsumer *op) override {
if (op->is_producer && (op->name == name)) {
count++;
} else {
IRVisitor::visit(op);
}
}
using IRVisitor::visit;
public:
int count = 0;
CountProducers(const std::string &name)
: name(name) {
}
};
int count_producers(const Stmt &in, const std::string &name) {
CountProducers counter(name);
in.accept(&counter);
return counter.count;
}
// Fold the storage of a function in a particular dimension by a particular factor
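// (i.e. turn accesses at coordinate y into accesses at y % factor).
// For a hypothetical schedule such as
//
//   Func f, g;
//   Var x, y;
//   f(x, y) = x + y;
//   g(x, y) = f(x, y) + f(x, y + 1);
//   f.store_root().compute_at(g, y).fold_storage(y, 4);
//
// every Call and Provide coordinate of f in the folded dimension is
// rewritten from y to y % 4, so the realization of f can be shrunk to
// four rows of backing storage. A fold factor of 1 collapses the
// coordinate to 0 entirely.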
class FoldStorageOfFunction : public IRMutator {
string func;
int dim;
Expr factor;
string dynamic_footprint;
using IRMutator::visit;
Expr visit(const Call *op) override {
Expr expr = IRMutator::visit(op);
op = expr.as<Call>();
internal_assert(op);
if (op->name == func && op->call_type == Call::Halide) {
vector<Expr> args = op->args;
internal_assert(dim < (int)args.size());
args[dim] = is_const_one(factor) ? 0 : (args[dim] % factor);
expr = Call::make(op->type, op->name, args, op->call_type,
op->func, op->value_index, op->image, op->param);
} else if (op->name == Call::buffer_crop) {
Expr source = op->args[2];
const Variable *buf_var = source.as<Variable>();
if (buf_var &&
starts_with(buf_var->name, func + ".") &&
ends_with(buf_var->name, ".buffer")) {
// We are taking a crop of a folded buffer. For now
// we'll just assert that the crop doesn't wrap
// around, so that the crop doesn't need to be treated
// as a folded buffer too. But to take the crop, we
// need to use folded coordinates, and then restore
// the non-folded min after the crop operation.
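// For instance, with a fold factor of 8, a crop whose folded min is 6
// and whose extent is 4 would span folded coordinates 6..9, wrapping
// past the fold boundary (6 + 4 > 8); the require() injected below
// rejects that case with halide_error_bad_extern_fold.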
// Pull out the expressions we need
internal_assert(op->args.size() >= 5);
Expr mins_arg = op->args[3];
Expr extents_arg = op->args[4];
const Call *mins_call = mins_arg.as<Call>();
const Call *extents_call = extents_arg.as<Call>();
internal_assert(mins_call && extents_call);
vector<Expr> mins = mins_call->args;
const vector<Expr> &extents = extents_call->args;
internal_assert(dim < (int)mins.size() && dim < (int)extents.size());
Expr old_min = mins[dim];
Expr old_extent = extents[dim];
// Rewrite the crop args
mins[dim] = old_min % factor;
Expr new_mins = Call::make(type_of<int *>(), Call::make_struct, mins, Call::Intrinsic);
vector<Expr> new_args = op->args;
new_args[3] = new_mins;
expr = Call::make(op->type, op->name, new_args, op->call_type);
// Inject the assertion
Expr no_wraparound = mins[dim] + extents[dim] <= factor;
Expr valid_min = old_min;
if (!dynamic_footprint.empty()) {
// If the footprint is being tracked dynamically, it's
// not enough to just check we don't overlap a
// fold. We also need to check the min against the
// valid min.
// TODO: dynamic footprint is no longer the min, and may be tracked separately on producer and consumer sides (head vs tail)
valid_min =
Load::make(Int(32), dynamic_footprint, 0, Buffer<>(), Parameter(), const_true(), ModulusRemainder());
Expr check = (old_min >= valid_min &&
(old_min + old_extent - 1) < valid_min + factor);
no_wraparound = no_wraparound && check;
}
Expr error = Call::make(Int(32), "halide_error_bad_extern_fold",
{Expr(func), Expr(dim), old_min, old_extent, valid_min, factor},
Call::Extern);
expr = Call::make(op->type, Call::require,
{no_wraparound, expr, error}, Call::Intrinsic);
// Restore the correct min coordinate
expr = Call::make(op->type, Call::buffer_set_bounds,
{expr, dim, old_min, old_extent}, Call::Extern);
}
}
return expr;
}
Stmt visit(const Provide *op) override {
Stmt stmt = IRMutator::visit(op);
op = stmt.as<Provide>();
internal_assert(op);
if (op->name == func) {
vector<Expr> args = op->args;
args[dim] = is_const_one(factor) ? 0 : (args[dim] % factor);
stmt = Provide::make(op->name, op->values, args, op->predicate);
}
return stmt;
}
public:
FoldStorageOfFunction(string f, int d, Expr e, string p)
: func(std::move(f)), dim(d), factor(std::move(e)), dynamic_footprint(std::move(p)) {
}
};
// Inject dynamic folding checks against a tracked live range.
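// These are injected when an explicit fold_storage() factor was given
// but monotonicity of the footprint in the loop variable could not be
// proven statically. Rather than rejecting the schedule, head/tail
// counters (plus a semaphore for async schedules) track the live range
// at runtime, and halide_error_bad_fold or
// halide_error_fold_factor_too_small is raised if an iteration steps
// backwards over the fold or touches more than fold_factor entries.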
class InjectFoldingCheck : public IRMutator {
Function func;
string head, tail, loop_var;
Expr sema_var;
int dim;
bool in_produce = false;
const StorageDim &storage_dim;
using IRMutator::visit;
Stmt visit(const ProducerConsumer *op) override {
if (op->name == func.name()) {
Stmt body = op->body;
if (op->is_producer) {
if (func.has_extern_definition()) {
// We'll update the valid min at the buffer_crop call.
in_produce = true;
body = mutate(op->body);
} else {
// Update valid range based on bounds written to.
Box b = box_provided(body, func.name());
Expr old_leading_edge =
Load::make(Int(32), head + "_next", 0, Buffer<>(), Parameter(), const_true(), ModulusRemainder());
internal_assert(!b.empty());
// Track the logical address range the memory
// currently represents.
Expr new_leading_edge;
if (storage_dim.fold_forward) {
new_leading_edge = max(b[dim].max, old_leading_edge);
} else {
new_leading_edge = min(b[dim].min, old_leading_edge);
}
string new_leading_edge_var_name = unique_name('t');
Expr new_leading_edge_var = Variable::make(Int(32), new_leading_edge_var_name);
Stmt update_leading_edge =
Store::make(head, new_leading_edge_var, 0, Parameter(), const_true(), ModulusRemainder());
Stmt update_next_leading_edge =
Store::make(head + "_next", new_leading_edge_var, 0, Parameter(), const_true(), ModulusRemainder());
// Check the region being written to in this
// iteration lies within the range of coordinates
// currently represented.
Expr fold_non_monotonic_error =
Call::make(Int(32), "halide_error_bad_fold",
{func.name(), storage_dim.var, loop_var},
Call::Extern);
Expr in_valid_range;
if (storage_dim.fold_forward) {
in_valid_range = b[dim].min > new_leading_edge - storage_dim.fold_factor;
} else {
in_valid_range = b[dim].max < new_leading_edge + storage_dim.fold_factor;
}
Stmt check_in_valid_range =
AssertStmt::make(in_valid_range, fold_non_monotonic_error);
Expr extent = b[dim].max - b[dim].min + 1;
// Separately check the extent for *this* loop iteration fits.
Expr fold_too_small_error =
Call::make(Int(32), "halide_error_fold_factor_too_small",
{func.name(), storage_dim.var, storage_dim.fold_factor, loop_var, extent},
Call::Extern);
Stmt check_extent =
AssertStmt::make(extent <= storage_dim.fold_factor, fold_too_small_error);
Stmt checks = Block::make({check_extent, check_in_valid_range,
update_leading_edge, update_next_leading_edge});
if (func.schedule().async()) {
Expr to_acquire;
if (storage_dim.fold_forward) {
to_acquire = new_leading_edge_var - old_leading_edge;
} else {
to_acquire = old_leading_edge - new_leading_edge_var;
}
body = Block::make(checks, body);
body = Acquire::make(sema_var, to_acquire, body);
body = LetStmt::make(new_leading_edge_var_name, new_leading_edge, body);
} else {
checks = LetStmt::make(new_leading_edge_var_name, new_leading_edge, checks);
body = Block::make(checks, body);
}
}
} else {
// Check the accessed range against the valid range.
Box b = box_required(body, func.name());
if (b.empty()) {
// Must be used in an extern call (TODO: assert this;
// TODO: what if it's used in both an extern call and
// native Halide code?). We'll update the valid min at
// the buffer_crop call.
in_produce = false;
body = mutate(op->body);
} else {
Expr leading_edge =
Load::make(Int(32), tail + "_next", 0, Buffer<>(), Parameter(), const_true(), ModulusRemainder());
if (func.schedule().async()) {
Expr new_leading_edge;
if (storage_dim.fold_forward) {
new_leading_edge = b[dim].min - 1 + storage_dim.fold_factor;
} else {
new_leading_edge = b[dim].max + 1 - storage_dim.fold_factor;
}
string new_leading_edge_name = unique_name('t');
Expr new_leading_edge_var = Variable::make(Int(32), new_leading_edge_name);
Expr to_release;
if (storage_dim.fold_forward) {
to_release = new_leading_edge_var - leading_edge;
} else {
to_release = leading_edge - new_leading_edge_var;
}
Expr release_producer =
Call::make(Int(32), "halide_semaphore_release", {sema_var, to_release}, Call::Extern);
// The consumer is going to get its own forked copy of the footprint, so it needs to update it too.
Stmt update_leading_edge = Store::make(tail, new_leading_edge_var, 0, Parameter(), const_true(), ModulusRemainder());
update_leading_edge = Block::make(Store::make(tail + "_next", new_leading_edge_var, 0, Parameter(), const_true(), ModulusRemainder()),
update_leading_edge);
update_leading_edge = Block::make(Evaluate::make(release_producer), update_leading_edge);
update_leading_edge = LetStmt::make(new_leading_edge_name, new_leading_edge, update_leading_edge);
body = Block::make(update_leading_edge, body);
} else {
Expr check;
if (storage_dim.fold_forward) {
check = (b[dim].min > leading_edge - storage_dim.fold_factor && b[dim].max <= leading_edge);
} else {
check = (b[dim].max < leading_edge + storage_dim.fold_factor && b[dim].min >= leading_edge);
}
Expr bad_fold_error = Call::make(Int(32), "halide_error_bad_fold",
{func.name(), storage_dim.var, loop_var},
Call::Extern);
body = Block::make(AssertStmt::make(check, bad_fold_error), body);
}
}
}
return ProducerConsumer::make(op->name, op->is_producer, body);
} else {
return IRMutator::visit(op);
}
}
Stmt visit(const LetStmt *op) override {
if (starts_with(op->name, func.name() + ".") &&
ends_with(op->name, ".tmp_buffer")) {
Stmt body = op->body;
Expr buf = Variable::make(type_of<halide_buffer_t *>(), op->name);
if (in_produce) {
// We're taking a crop of the buffer to act as an output
// to an extern stage. Update the valid min or max
// coordinate accordingly.
Expr leading_edge;
if (storage_dim.fold_forward) {
leading_edge =
Call::make(Int(32), Call::buffer_get_max, {buf, dim}, Call::Extern);
} else {
leading_edge =
Call::make(Int(32), Call::buffer_get_min, {buf, dim}, Call::Extern);
}
Stmt update_leading_edge =
Store::make(head, leading_edge, 0, Parameter(), const_true(), ModulusRemainder());
body = Block::make(update_leading_edge, body);
// We don't need to make sure the min is moving
// monotonically, because we can't do sliding window on
// extern stages, so we don't have to worry about whether
// we're preserving valid values from previous loop
// iterations.
if (func.schedule().async()) {
Expr old_leading_edge =
Load::make(Int(32), head, 0, Buffer<>(), Parameter(), const_true(), ModulusRemainder());
Expr to_acquire;
if (storage_dim.fold_forward) {
to_acquire = leading_edge - old_leading_edge;
} else {
to_acquire = old_leading_edge - leading_edge;
}
body = Acquire::make(sema_var, to_acquire, body);
}
} else {
// We're taking a crop of the buffer to act as an input
// to an extern stage. Update the valid min or max
// coordinate accordingly.
Expr leading_edge;
if (storage_dim.fold_forward) {
leading_edge =
Call::make(Int(32), Call::buffer_get_min, {buf, dim}, Call::Extern) - 1 + storage_dim.fold_factor;
} else {
leading_edge =
Call::make(Int(32), Call::buffer_get_max, {buf, dim}, Call::Extern) + 1 - storage_dim.fold_factor;
}
Stmt update_leading_edge =
Store::make(tail, leading_edge, 0, Parameter(), const_true(), ModulusRemainder());
body = Block::make(update_leading_edge, body);
if (func.schedule().async()) {
Expr old_leading_edge =
Load::make(Int(32), tail, 0, Buffer<>(), Parameter(), const_true(), ModulusRemainder());
Expr to_release;
if (storage_dim.fold_forward) {
to_release = leading_edge - old_leading_edge;
} else {
to_release = old_leading_edge - leading_edge;
}
Expr release_producer =
Call::make(Int(32), "halide_semaphore_release", {sema_var, to_release}, Call::Extern);
body = Block::make(Evaluate::make(release_producer), body);
}
}
return LetStmt::make(op->name, op->value, body);
} else {
return LetStmt::make(op->name, op->value, mutate(op->body));
}
}
public:
InjectFoldingCheck(Function func,
string head, string tail,
string loop_var, Expr sema_var,
int dim, const StorageDim &storage_dim)
: func(std::move(func)),
head(std::move(head)), tail(std::move(tail)), loop_var(std::move(loop_var)), sema_var(std::move(sema_var)),
dim(dim), storage_dim(storage_dim) {
}
};
struct Semaphore {
string name;
Expr var;
Expr init;
};
class HasExternConsumer : public IRVisitor {
using IRVisitor::visit;
void visit(const Variable *op) override {
if (op->name == func + ".buffer") {
result = true;
}
}
const std::string &func;
public:
HasExternConsumer(const std::string &func)
: func(func) {
}
bool result = false;
};
class VectorAccessOfFoldedDim : public IRVisitor {
using IRVisitor::visit;
void visit(const Provide *op) override {
if (op->name == func) {
internal_assert(dim < (int)op->args.size());
if (expr_uses_vars(op->args[dim], vector_vars)) {
result = true;
}
} else {
IRVisitor::visit(op);
}
}
void visit(const Call *op) override {
if (op->name == func &&
op->call_type == Call::Halide) {
internal_assert(dim < (int)op->args.size());
if (expr_uses_vars(op->args[dim], vector_vars)) {
result = true;
}
} else {
IRVisitor::visit(op);
}
}
template<typename LetOrLetStmt>
void visit_let(const LetOrLetStmt *op) {
op->value.accept(this);
bool is_vec = expr_uses_vars(op->value, vector_vars);
ScopedBinding<> bind(is_vec, vector_vars, op->name);
op->body.accept(this);
}
void visit(const Let *op) override {
visit_let(op);
}
void visit(const LetStmt *op) override {
visit_let(op);
}
void visit(const For *op) override {
ScopedBinding<> bind(op->for_type == ForType::Vectorized,
vector_vars, op->name);
IRVisitor::visit(op);
}
Scope<> vector_vars;
const string &func;
int dim;
public:
bool result = false;
VectorAccessOfFoldedDim(const string &func, int dim)
: func(func), dim(dim) {
}
};
// Attempt to fold the storage of a particular function in a statement
class AttemptStorageFoldingOfFunction : public IRMutator {
Function func;
bool explicit_only;
using IRMutator::visit;
Stmt visit(const ProducerConsumer *op) override {
if (op->name == func.name()) {
// Can't proceed into the pipeline for this func
return op;
} else {
return IRMutator::visit(op);
}
}
Stmt visit(const For *op) override {
if (op->for_type != ForType::Serial && op->for_type != ForType::Unrolled) {
// We can't proceed into a parallel for loop.
// TODO: If there's no overlap between the region touched
// by the threads as this loop counter varies
// (i.e. there's no cross-talk between threads), then it's
// safe to proceed.
return op;
}
Stmt stmt;
Stmt body = op->body;
Box provided = box_provided(body, func.name());
Box required = box_required(body, func.name());
// For storage folding, we don't care about conditional reads.
required.used = Expr();
Box box = box_union(provided, required);
Expr loop_var = Variable::make(Int(32), op->name);
Expr loop_min = Variable::make(Int(32), op->name + ".loop_min");
Expr loop_max = Variable::make(Int(32), op->name + ".loop_max");
string dynamic_footprint;
Scope<Interval> bounds;
bounds.push(op->name, Interval(op->min, simplify(op->min + op->extent - 1)));
Scope<Interval> steady_bounds;
steady_bounds.push(op->name, Interval(simplify(op->min + 1), simplify(op->min + op->extent - 1)));
HasExternConsumer has_extern_consumer(func.name());
body.accept(&has_extern_consumer);
// Try each dimension in turn from outermost in
for (size_t i = box.size(); i > 0; i--) {
int dim = (int)(i - 1);
if (!box[dim].is_bounded()) {
continue;
}
Expr min = simplify(common_subexpression_elimination(box[dim].min));
Expr max = simplify(common_subexpression_elimination(box[dim].max));
if (is_const(min) || is_const(max)) {
debug(3) << "\nNot considering folding " << func.name()
<< " over for loop over " << op->name
<< " dimension " << i - 1 << "\n"
<< " because the min or max are constants."
<< "Min: " << min << "\n"
<< "Max: " << max << "\n";
continue;
}
Expr min_provided, max_provided, min_required, max_required;
if (func.schedule().async() && !explicit_only) {
if (!provided.empty()) {
min_provided = simplify(provided[dim].min);
max_provided = simplify(provided[dim].max);
}
if (!required.empty()) {
min_required = simplify(required[dim].min);
max_required = simplify(required[dim].max);
}
}
string sema_name = func.name() + ".folding_semaphore." + unique_name('_');
Expr sema_var = Variable::make(type_of<halide_semaphore_t *>(), sema_name);
// Consider the initial iteration and steady state
// separately for all these proofs.
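// (Sliding window typically makes the first iteration's footprint
// larger than the steady-state one: a producer slid over y might
// compute rows [0, 2] on the first iteration but only one new row per
// iteration after that, so the fold factor must cover the larger of
// the two extents.)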
Expr loop_var = Variable::make(Int(32), op->name);
Expr steady_state = (op->min < loop_var);
Expr min_steady = simplify(substitute(steady_state, const_true(), min), true, steady_bounds);
Expr max_steady = simplify(substitute(steady_state, const_true(), max), true, steady_bounds);
Expr min_initial = simplify(substitute(steady_state, const_false(), min), true, bounds);
Expr max_initial = simplify(substitute(steady_state, const_false(), max), true, bounds);
Expr extent_initial = simplify(substitute(loop_var, op->min, max_initial - min_initial + 1), true, bounds);
Expr extent_steady = simplify(max_steady - min_steady + 1, true, steady_bounds);
Expr extent = Max::make(extent_initial, extent_steady);
extent = simplify(common_subexpression_elimination(extent), true, bounds);
// Find the StorageDim corresponding to dim.
const std::vector<StorageDim> &storage_dims = func.schedule().storage_dims();
auto storage_dim_i = std::find_if(storage_dims.begin(), storage_dims.end(),
[&](const StorageDim &i) { return i.var == func.args()[dim]; });
internal_assert(storage_dim_i != storage_dims.end());
const StorageDim &storage_dim = *storage_dim_i;
Expr explicit_factor;
if (!is_pure(min) ||
!is_pure(max) ||
has_extern_consumer.result ||
expr_uses_var(min, op->name) ||
expr_uses_var(max, op->name)) {
// We only use the explicit fold factor if the fold is
// relevant for this loop. If the fold isn't relevant
// for this loop, the added asserts will be too
// conservative.
explicit_factor = storage_dim.fold_factor;
}
debug(3) << "\nConsidering folding " << func.name()
<< " over for loop over " << op->name
<< " dimension " << i - 1 << "\n"
<< "Min: " << min << "\n"
<< "Max: " << max << "\n"
<< "Extent: " << extent << "\n"
<< "explicit_factor: " << explicit_factor << "\n";
// First, attempt to detect if the loop is monotonically
// increasing or decreasing (if we allow automatic folding).
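// For example, a stage that touches rows [y, y + 2] of this func
// inside a serial loop over y has min == y and max == y + 2, both
// monotonically increasing in y, so rows that fall behind the current
// window can safely be overwritten and a forward fold applies.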
bool can_fold_forwards = false, can_fold_backwards = false;
if (!explicit_only) {
// We can't clobber data that will be read later. If
// async, the producer can't un-release slots in the
// circular buffer.
can_fold_forwards = (is_monotonic(min, op->name) == Monotonic::Increasing);
can_fold_backwards = (is_monotonic(max, op->name) == Monotonic::Decreasing);
if (func.schedule().async()) {
// Our semaphore acquire primitive can't take
// negative values, so we can't un-acquire slots
// in the circular buffer.
can_fold_forwards &= (is_monotonic(max_provided, op->name) == Monotonic::Increasing);
can_fold_backwards &= (is_monotonic(min_provided, op->name) == Monotonic::Decreasing);
// We need to be able to analyze the required footprint to know how much to release
can_fold_forwards &= min_required.defined();
can_fold_backwards &= max_required.defined();
}
}
// Uncomment to pretend that static analysis always fails (for testing)
// can_fold_forwards = can_fold_backwards = false;
if (!can_fold_forwards && !can_fold_backwards) {
if (explicit_factor.defined()) {
// If we didn't find a monotonic dimension, and we
// have an explicit fold factor, we need to
// dynamically check that the min/max do in fact
// monotonically increase/decrease. We'll allocate
// some stack space to store the valid footprint,
// update it outside produce nodes, and check it
// outside consume nodes.
string head, tail;
if (func.schedule().async()) {
// If we're async, we need to keep a separate
// counter for the producer and consumer. They
// are coupled by a semaphore. The counter
// represents the max index the producer may
// write to. The invariant is that the
// semaphore count is the difference between
// the counters. So...
//
// when folding forwards, semaphore == head - tail
// when folding backwards, semaphore == tail - head
//
// We'll initialize to head = tail, and
// semaphore = 0. Every time the producer or
// consumer wants to move the counter, it must
// also acquire or release the semaphore to
// prevent them from diverging too far.
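// At the schedule level this means that for something like the
// hypothetical
//
//   f.store_root().compute_at(g, y).async().fold_storage(y, 4);
//
// the producer task may fill at most 4 rows beyond the oldest row the
// consumer still needs: it acquires the semaphore before advancing its
// head, and the consumer releases the semaphore as its reads move
// forward.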
dynamic_footprint = func.name() + ".folding_semaphore." + op->name + unique_name('_');
head = dynamic_footprint + ".head";
tail = dynamic_footprint + ".tail";
} else {
dynamic_footprint = func.name() + "." + op->name + unique_name('_') + ".head";
head = tail = dynamic_footprint;
}
body = InjectFoldingCheck(func,
head, tail,
op->name,
sema_var,
dim,
storage_dim)
.mutate(body);
if (storage_dim.fold_forward) {
can_fold_forwards = true;
} else {
can_fold_backwards = true;
}
} else {
// Can't do much with this dimension
if (!explicit_only) {
debug(3) << "Not folding because loop min or max not monotonic in the loop variable\n"
<< "min_initial = " << min_initial << "\n"
<< "min_steady = " << min_steady << "\n"
<< "max_initial = " << max_initial << "\n"
<< "max_steady = " << max_steady << "\n";
} else {
debug(3) << "Not folding because there is no explicit storage folding factor\n";
}
continue;
}
}
internal_assert(can_fold_forwards || can_fold_backwards);
Expr factor;
if (explicit_factor.defined()) {
if (dynamic_footprint.empty() && !func.schedule().async()) {
// We were able to prove monotonicity
// statically, but we may need a runtime
// assertion for maximum extent. In many cases
// it will simplify away. For async schedules
// it gets dynamically tracked anyway.
Expr error = Call::make(Int(32), "halide_error_fold_factor_too_small",
{func.name(), storage_dim.var, explicit_factor, op->name, extent},
Call::Extern);
body = Block::make(AssertStmt::make(extent <= explicit_factor, error), body);
}
factor = explicit_factor;
} else {
// The max of the extent over all values of the loop variable must be a constant
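// For example, if each iteration touches a window of 3 rows of the
// func, the constant bound found here is 3 and the factor becomes
// next_power_of_two(3) == 4. If neither this nor the fallback search
// below finds a constant bound of at most max_fold, the dimension is
// skipped.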
Scope<Interval> scope;
scope.push(op->name, Interval(loop_min, loop_max));
Expr max_extent = find_constant_bound(extent, Direction::Upper, scope);
scope.pop(op->name);
const int max_fold = 1024;
const int64_t *const_max_extent = as_const_int(max_extent);
if (const_max_extent && *const_max_extent <= max_fold) {
factor = static_cast<int>(next_power_of_two(*const_max_extent));
} else {
// Try a little harder to find a bounding power of two
int e = max_fold * 2;
bool success = false;
while (e > 0 && can_prove(extent <= e / 2)) {
success = true;
e /= 2;
}
if (success) {
factor = e;
} else {
debug(3) << "Not folding because extent not bounded by a constant not greater than " << max_fold << "\n"
<< "extent = " << extent << "\n"
<< "max extent = " << max_extent << "\n";
// Try the next dimension
continue;
}
}
}
internal_assert(factor.defined());
if (!explicit_factor.defined()) {
VectorAccessOfFoldedDim vector_access_of_folded_dim{func.name(), dim};
body.accept(&vector_access_of_folded_dim);
if (vector_access_of_folded_dim.result) {
user_warning
<< "Not folding Func " << func.name() << " along dimension " << func.args()[dim]
<< " because there is vectorized access to that Func in that dimension and "
<< "storage folding was not explicitly requested in the schedule. In previous "
<< "versions of Halide this would have folded with factor " << factor
<< ". To restore the old behavior add " << func.name()
<< ".fold_storage(" << func.args()[dim] << ", " << factor
<< ") to your schedule.\n";
// Try the next dimension
continue;
}
}
debug(3) << "Proceeding with factor " << factor << "\n";
Fold fold = {(int)i - 1, factor};
dims_folded.push_back(fold);
{
string head;
if (!dynamic_footprint.empty() && func.schedule().async()) {
head = dynamic_footprint + ".head";
} else {
head = dynamic_footprint;
}
body = FoldStorageOfFunction(func.name(), (int)i - 1, factor, head).mutate(body);
}
// If the producer is async, it can run ahead by
// some amount controlled by a semaphore.
if (func.schedule().async()) {
Semaphore sema;
sema.name = sema_name;
sema.var = sema_var;
sema.init = 0;
if (dynamic_footprint.empty()) {
// We are going to do the sem acquires and releases using static analysis of the boxes accessed.
sema.init = factor;
// Do the analysis of how much to acquire and release statically
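// For a forward fold where the producer writes one new row per
// iteration and the consumer stops needing one old row per iteration,
// both to_acquire and to_release work out to 1 in the steady state;
// the select()s below widen them to the full shared extent on the
// first and last iterations respectively.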
Expr to_acquire, to_release;
if (can_fold_forwards) {
Expr max_provided_prev = substitute(op->name, loop_var - 1, max_provided);
Expr min_required_next = substitute(op->name, loop_var + 1, min_required);
to_acquire = max_provided - max_provided_prev; // This is the first time we use these entries
to_release = min_required_next - min_required; // This is the last time we use these entries
} else {
internal_assert(can_fold_backwards);
Expr min_provided_prev = substitute(op->name, loop_var - 1, min_provided);
Expr max_required_next = substitute(op->name, loop_var + 1, max_required);
to_acquire = min_provided_prev - min_provided; // This is the first time we use these entries
to_release = max_required - max_required_next; // This is the last time we use these entries
}
if (provided.used.defined()) {
to_acquire = select(provided.used, to_acquire, 0);
}
// We should always release the required region, even if we don't use it.
// On the first iteration, we need to acquire the extent of the region shared
// between the producer and consumer, and we need to release it on the last
// iteration.
to_acquire = select(loop_var > loop_min, to_acquire, extent);
to_release = select(loop_var < loop_max, to_release, extent);
// We may need dynamic assertions that a positive
// amount of the semaphore is acquired/released,
// and that the semaphore is initialized to a
// positive value. If we are able to prove it,
// these checks will simplify away.
string to_release_name = unique_name('t');
Expr to_release_var = Variable::make(Int(32), to_release_name);
string to_acquire_name = unique_name('t');
Expr to_acquire_var = Variable::make(Int(32), to_acquire_name);
Expr bad_fold_error =
Call::make(Int(32), "halide_error_bad_fold",
{func.name(), storage_dim.var, op->name},
Call::Extern);
Expr release_producer =
Call::make(Int(32), "halide_semaphore_release", {sema.var, to_release_var}, Call::Extern);
Stmt release = Evaluate::make(release_producer);
Stmt check_release = AssertStmt::make(to_release_var >= 0 && to_release_var <= factor, bad_fold_error);
release = Block::make(check_release, release);
release = LetStmt::make(to_release_name, to_release, release);
Stmt check_acquire = AssertStmt::make(to_acquire_var >= 0 && to_acquire_var <= factor, bad_fold_error);
body = Block::make(body, release);
body = Acquire::make(sema.var, to_acquire_var, body);
body = Block::make(check_acquire, body);
body = LetStmt::make(to_acquire_name, to_acquire, body);
} else {
// We injected runtime tracking and semaphore logic already
}
dims_folded.back().semaphore = sema;
}
if (!dynamic_footprint.empty()) {
if (func.schedule().async()) {
dims_folded.back().head = dynamic_footprint + ".head";
dims_folded.back().tail = dynamic_footprint + ".tail";
} else {
dims_folded.back().head = dynamic_footprint;
dims_folded.back().tail.clear();
}
dims_folded.back().fold_forward = storage_dim.fold_forward;
}
Expr min_next = substitute(op->name, loop_var + 1, min);
if (can_prove(max < min_next)) {
// There's no overlapping usage between loop
// iterations, so we can continue to search
// for further folding opportunities
// recursively.
} else if (!body.same_as(op->body)) {
stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
break;
} else {
stmt = op;
debug(3) << "Not folding because loop min or max not monotonic in the loop variable\n"
<< "min = " << min << "\n"
<< "max = " << max << "\n";
break;
}
}
// If there's no communication of values from one loop
// iteration to the next (which may happen due to sliding),
// then we're safe to fold an inner loop.
if (box_contains(provided, required)) {
body = mutate(body);
}
if (body.same_as(op->body)) {
stmt = op;
} else {
stmt = For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
}
if (func.schedule().async() && !dynamic_footprint.empty()) {
// Step the counters backwards over the entire extent of
// the realization, in case we're in an inner loop and are
// going to run this loop again with the same
// semaphore. Our invariant is that the difference between
// the two counters is the semaphore.
//
// Doing this instead of synchronizing and resetting the
// counters and semaphores lets producers advance to the
// next scanline while a consumer is still on the last few
// pixels of the previous scanline.
Expr head = Load::make(Int(32), dynamic_footprint + ".head", 0, Buffer<>(), Parameter(), const_true(), ModulusRemainder());
Expr tail = Load::make(Int(32), dynamic_footprint + ".tail", 0, Buffer<>(), Parameter(), const_true(), ModulusRemainder());
Expr step = Variable::make(Int(32), func.name() + ".extent." + std::to_string(dims_folded.back().dim)) + dims_folded.back().factor;
Stmt reset_head = Store::make(dynamic_footprint + ".head_next", head - step, 0, Parameter(), const_true(), ModulusRemainder());
Stmt reset_tail = Store::make(dynamic_footprint + ".tail_next", tail - step, 0, Parameter(), const_true(), ModulusRemainder());
stmt = Block::make({stmt, reset_head, reset_tail});
}
return stmt;
}
public:
struct Fold {
int dim;
Expr factor;
Semaphore semaphore;
string head, tail;
bool fold_forward;
};
vector<Fold> dims_folded;
AttemptStorageFoldingOfFunction(Function f, bool explicit_only)
: func(std::move(f)), explicit_only(explicit_only) {
}
};
// Look for opportunities for storage folding in a statement
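// For each Realize node this runs AttemptStorageFoldingOfFunction,
// shrinks the Realize bounds of every folded dimension to [0, factor),
// and wraps the realization in the semaphore and head/tail counter
// allocations that the fold requires. For example, a realization of f
// over rows [y_min, y_min + 100) folded by a factor of 4 ends up
// backed by just four rows, addressed as y % 4.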
class StorageFolding : public IRMutator {
const map<string, Function> &env;
using IRMutator::visit;
Stmt visit(const Realize *op) override {
Stmt body = mutate(op->body);
// Get the function associated with this realization, which
// contains the explicit fold directives from the schedule.
auto func_it = env.find(op->name);
Function func = func_it != env.end() ? func_it->second : Function();
// Don't attempt automatic storage folding if there is
// more than one produce node for this func.
bool explicit_only = count_producers(body, op->name) != 1;
AttemptStorageFoldingOfFunction folder(func, explicit_only);
if (explicit_only) {
debug(3) << "Attempting to fold " << op->name << " explicitly\n";
} else {
debug(3) << "Attempting to fold " << op->name << " automatically or explicitly\n";
}
body = folder.mutate(body);
if (body.same_as(op->body)) {
return op;
} else if (folder.dims_folded.empty()) {
return Realize::make(op->name, op->types, op->memory_type, op->bounds, op->condition, body);
} else {
Region bounds = op->bounds;
// Collapse down the extent in the folded dimension
for (const auto &dim : folder.dims_folded) {
int d = dim.dim;
Expr f = dim.factor;
internal_assert(d >= 0 &&
d < (int)bounds.size());
bounds[d] = Range(0, f);
}
Stmt stmt = Realize::make(op->name, op->types, op->memory_type, bounds, op->condition, body);
// Each fold may have an associated semaphore that needs initialization, along with some counters
for (const auto &fold : folder.dims_folded) {
auto sema = fold.semaphore;
if (sema.var.defined()) {
Expr sema_space = Call::make(type_of<halide_semaphore_t *>(), "halide_make_semaphore",
{sema.init}, Call::Extern);
stmt = LetStmt::make(sema.name, sema_space, stmt);
}
Expr init;
if (fold.fold_forward) {
init = op->bounds[fold.dim].min;
} else {
init = op->bounds[fold.dim].min + op->bounds[fold.dim].extent - 1;
}
if (!fold.head.empty()) {
stmt = Block::make(Store::make(fold.head + "_next", init, 0, Parameter(), const_true(), ModulusRemainder()), stmt);
stmt = Allocate::make(fold.head + "_next", Int(32), MemoryType::Stack, {}, const_true(), stmt);
stmt = Block::make(Store::make(fold.head, init, 0, Parameter(), const_true(), ModulusRemainder()), stmt);
stmt = Allocate::make(fold.head, Int(32), MemoryType::Stack, {}, const_true(), stmt);
}
if (!fold.tail.empty()) {
internal_assert(func.schedule().async()) << "Expected a single counter for synchronous folding";
stmt = Block::make(Store::make(fold.tail + "_next", init, 0, Parameter(), const_true(), ModulusRemainder()), stmt);
stmt = Allocate::make(fold.tail + "_next", Int(32), MemoryType::Stack, {}, const_true(), stmt);
stmt = Block::make(Store::make(fold.tail, init, 0, Parameter(), const_true(), ModulusRemainder()), stmt);
stmt = Allocate::make(fold.tail, Int(32), MemoryType::Stack, {}, const_true(), stmt);
}
}
return stmt;
}
}
public:
StorageFolding(const map<string, Function> &env)
: env(env) {
}
};
} // namespace
Stmt storage_folding(const Stmt &s, const std::map<std::string, Function> &env) {
return StorageFolding(env).mutate(s);
}
} // namespace Internal
} // namespace Halide