https://github.com/halide/Halide
Tip revision: bf3b66f9d5ce84bd85f0390c9e198b2e8f7bd554 authored by Z Stern on 24 September 2020, 18:00:10 UTC
Add atomic update support to thread_pool_common parallel for implementations.
Add atomic update support to thread_pool_common parallel for implementations.
Tip revision: bf3b66f
Monotonic.cpp
#include "Monotonic.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Scope.h"
#include "Simplify.h"
#include "Substitute.h"
namespace Halide {
namespace Internal {
/** Stream the enumerator name of a Monotonic value ("Constant",
 * "Increasing", "Decreasing" or "Unknown"). A value outside the
 * enumeration writes nothing. Returns the stream for chaining. */
std::ostream &operator<<(std::ostream &stream, const Monotonic &m) {
    switch (m) {
    case Monotonic::Constant:
        return stream << "Constant";
    case Monotonic::Increasing:
        return stream << "Increasing";
    case Monotonic::Decreasing:
        return stream << "Decreasing";
    case Monotonic::Unknown:
        return stream << "Unknown";
    }
    // Not one of the known enumerators: emit nothing.
    return stream;
}
using std::string;
class MonotonicVisitor : public IRVisitor {
const string &var;
Scope<Monotonic> scope;
void visit(const IntImm *) override {
result = Monotonic::Constant;
}
void visit(const UIntImm *) override {
result = Monotonic::Constant;
}
void visit(const FloatImm *) override {
result = Monotonic::Constant;
}
void visit(const StringImm *) override {
// require() Exprs can includes Strings.
result = Monotonic::Constant;
}
void visit(const Cast *op) override {
op->value.accept(this);
if (op->type.can_represent(op->value.type())) {
// No overflow.
return;
}
if (op->value.type().bits() >= 32 && op->type.bits() >= 32) {
// We assume 32-bit types don't overflow.
return;
}
// A narrowing cast. There may be more cases we can catch, but
// for now we punt.
if (result != Monotonic::Constant) {
result = Monotonic::Unknown;
}
}
void visit(const Variable *op) override {
if (op->name == var) {
result = Monotonic::Increasing;
} else if (scope.contains(op->name)) {
result = scope.get(op->name);
} else {
result = Monotonic::Constant;
}
}
Monotonic flip(Monotonic r) {
switch (r) {
case Monotonic::Increasing:
return Monotonic::Decreasing;
case Monotonic::Decreasing:
return Monotonic::Increasing;
default:
return r;
}
}
Monotonic unify(Monotonic a, Monotonic b) {
if (a == b) {
return a;
}
if (a == Monotonic::Unknown || b == Monotonic::Unknown) {
return Monotonic::Unknown;
}
if (a == Monotonic::Constant) {
return b;
}
if (b == Monotonic::Constant) {
return a;
}
return Monotonic::Unknown;
}
void visit(const Add *op) override {
op->a.accept(this);
Monotonic ra = result;
op->b.accept(this);
Monotonic rb = result;
result = unify(ra, rb);
}
void visit(const Sub *op) override {
op->a.accept(this);
Monotonic ra = result;
op->b.accept(this);
Monotonic rb = result;
result = unify(ra, flip(rb));
}
void visit(const Mul *op) override {
op->a.accept(this);
Monotonic ra = result;
op->b.accept(this);
Monotonic rb = result;
if (ra == Monotonic::Constant && rb == Monotonic::Constant) {
result = Monotonic::Constant;
} else if (is_positive_const(op->a)) {
result = rb;
} else if (is_positive_const(op->b)) {
result = ra;
} else if (is_negative_const(op->a)) {
result = flip(rb);
} else if (is_negative_const(op->b)) {
result = flip(ra);
} else {
result = Monotonic::Unknown;
}
}
void visit(const Div *op) override {
op->a.accept(this);
Monotonic ra = result;
op->b.accept(this);
Monotonic rb = result;
if (ra == Monotonic::Constant && rb == Monotonic::Constant) {
result = Monotonic::Constant;
} else if (is_positive_const(op->b)) {
result = ra;
} else if (is_negative_const(op->b)) {
result = flip(ra);
} else {
result = Monotonic::Unknown;
}
}
void visit(const Mod *op) override {
result = Monotonic::Unknown;
}
void visit(const Min *op) override {
op->a.accept(this);
Monotonic ra = result;
op->b.accept(this);
Monotonic rb = result;
result = unify(ra, rb);
}
void visit(const Max *op) override {
op->a.accept(this);
Monotonic ra = result;
op->b.accept(this);
Monotonic rb = result;
result = unify(ra, rb);
}
void visit_eq(const Expr &a, const Expr &b) {
a.accept(this);
Monotonic ra = result;
b.accept(this);
Monotonic rb = result;
if (ra == Monotonic::Constant && rb == Monotonic::Constant) {
result = Monotonic::Constant;
} else {
result = Monotonic::Unknown;
}
}
void visit(const EQ *op) override {
visit_eq(op->a, op->b);
}
void visit(const NE *op) override {
visit_eq(op->a, op->b);
}
void visit_lt(const Expr &a, const Expr &b) {
a.accept(this);
Monotonic ra = result;
b.accept(this);
Monotonic rb = result;
result = unify(flip(ra), rb);
}
void visit(const LT *op) override {
visit_lt(op->a, op->b);
}
void visit(const LE *op) override {
visit_lt(op->a, op->b);
}
void visit(const GT *op) override {
visit_lt(op->b, op->a);
}
void visit(const GE *op) override {
visit_lt(op->b, op->a);
}
void visit(const And *op) override {
op->a.accept(this);
Monotonic ra = result;
op->b.accept(this);
Monotonic rb = result;
result = unify(ra, rb);
}
void visit(const Or *op) override {
op->a.accept(this);
Monotonic ra = result;
op->b.accept(this);
Monotonic rb = result;
result = unify(ra, rb);
}
void visit(const Not *op) override {
op->a.accept(this);
result = flip(result);
}
void visit(const Select *op) override {
op->condition.accept(this);
Monotonic rcond = result;
op->true_value.accept(this);
Monotonic ra = result;
op->false_value.accept(this);
Monotonic rb = result;
Monotonic unified = unify(ra, rb);
if (rcond == Monotonic::Constant) {
result = unified;
return;
}
bool true_value_ge_false_value = can_prove(op->true_value >= op->false_value);
bool true_value_le_false_value = can_prove(op->true_value <= op->false_value);
bool switches_from_true_to_false = rcond == Monotonic::Decreasing;
bool switches_from_false_to_true = rcond == Monotonic::Increasing;
if (true_value_ge_false_value &&
true_value_le_false_value) {
// The true value equals the false value.
result = ra;
} else if ((unified == Monotonic::Increasing || unified == Monotonic::Constant) &&
((switches_from_false_to_true && true_value_ge_false_value) ||
(switches_from_true_to_false && true_value_le_false_value))) {
// Both paths increase, and the condition makes it switch
// from the lesser path to the greater path.
result = Monotonic::Increasing;
} else if ((unified == Monotonic::Decreasing || unified == Monotonic::Constant) &&
((switches_from_false_to_true && true_value_le_false_value) ||
(switches_from_true_to_false && true_value_ge_false_value))) {
// Both paths decrease, and the condition makes it switch
// from the greater path to the lesser path.
result = Monotonic::Decreasing;
} else {
result = Monotonic::Unknown;
}
}
void visit(const Load *op) override {
op->index.accept(this);
if (result != Monotonic::Constant) {
result = Monotonic::Unknown;
}
}
void visit(const Ramp *op) override {
Expr equiv = op->base + Variable::make(op->base.type(), unique_name('t')) * op->stride;
equiv.accept(this);
}
void visit(const Broadcast *op) override {
op->value.accept(this);
}
void visit(const Call *op) override {
// Some functions are known to be monotonic
if (op->is_intrinsic(Call::likely) ||
op->is_intrinsic(Call::likely_if_innermost) ||
op->is_intrinsic(Call::return_second)) {
op->args.back().accept(this);
return;
}
if (op->is_intrinsic(Call::unsafe_promise_clamped) ||
op->is_intrinsic(Call::promise_clamped)) {
op->args[0].accept(this);
return;
}
if (op->is_intrinsic(Call::require)) {
// require() returns the value of the second arg in all non-failure cases
op->args[1].accept(this);
return;
}
if (!op->is_pure()) {
// Even with constant args, the result could vary from one loop iteration to the next.
result = Monotonic::Unknown;
return;
}
for (size_t i = 0; i < op->args.size(); i++) {
op->args[i].accept(this);
if (result != Monotonic::Constant) {
// One of the args is not constant.
result = Monotonic::Unknown;
return;
}
}
result = Monotonic::Constant;
}
void visit(const Let *op) override {
op->value.accept(this);
if (result == Monotonic::Constant) {
// No point pushing it if it's constant w.r.t the var,
// because unknown variables are treated as constant.
op->body.accept(this);
} else {
scope.push(op->name, result);
op->body.accept(this);
scope.pop(op->name);
}
}
void visit(const Shuffle *op) override {
for (size_t i = 0; i < op->vectors.size(); i++) {
op->vectors[i].accept(this);
if (result != Monotonic::Constant) {
result = Monotonic::Unknown;
return;
}
}
result = Monotonic::Constant;
}
void visit(const VectorReduce *op) override {
op->value.accept(this);
switch (op->op) {
case VectorReduce::Add:
case VectorReduce::Min:
case VectorReduce::Max:
// These reductions are monotonic in the arg
break;
case VectorReduce::Mul:
case VectorReduce::And:
case VectorReduce::Or:
// These ones are not
if (result != Monotonic::Constant) {
result = Monotonic::Unknown;
}
}
}
void visit(const LetStmt *op) override {
internal_error << "Monotonic of statement\n";
}
void visit(const AssertStmt *op) override {
internal_error << "Monotonic of statement\n";
}
void visit(const ProducerConsumer *op) override {
internal_error << "Monotonic of statement\n";
}
void visit(const For *op) override {
internal_error << "Monotonic of statement\n";
}
void visit(const Acquire *op) override {
internal_error << "Monotonic of statement\n";
}
void visit(const Store *op) override {
internal_error << "Monotonic of statement\n";
}
void visit(const Provide *op) override {
internal_error << "Monotonic of statement\n";
}
void visit(const Allocate *op) override {
internal_error << "Monotonic of statement\n";
}
void visit(const Free *op) override {
internal_error << "Monotonic of statement\n";
}
void visit(const Realize *op) override {
internal_error << "Monotonic of statement\n";
}
void visit(const Block *op) override {
internal_error << "Monotonic of statement\n";
}
void visit(const Fork *op) override {
internal_error << "Monotonic of statement\n";
}
void visit(const IfThenElse *op) override {
internal_error << "Monotonic of statement\n";
}
void visit(const Evaluate *op) override {
internal_error << "Monotonic of statement\n";
}
void visit(const Prefetch *op) override {
internal_error << "Monotonic of statement\n";
}
void visit(const Atomic *op) override {
internal_error << "Monotonic of statement\n";
}
public:
Monotonic result;
MonotonicVisitor(const std::string &v, const Scope<Monotonic> &parent)
: var(v), result(Monotonic::Unknown) {
scope.set_containing_scope(&parent);
}
};
/** Classify how `e` varies as the variable named `var` increases.
 * `scope` supplies the monotonicity of other free variables; names not
 * found there are treated as constant. An undefined Expr yields
 * Monotonic::Unknown. */
Monotonic is_monotonic(const Expr &e, const std::string &var, const Scope<Monotonic> &scope) {
    if (e.defined()) {
        MonotonicVisitor visitor(var, scope);
        e.accept(&visitor);
        return visitor.result;
    }
    // Nothing to analyze.
    return Monotonic::Unknown;
}
namespace {
void check_increasing(const Expr &e) {
internal_assert(is_monotonic(e, "x") == Monotonic::Increasing)
<< "Was supposed to be increasing: " << e << "\n";
}
void check_decreasing(const Expr &e) {
internal_assert(is_monotonic(e, "x") == Monotonic::Decreasing)
<< "Was supposed to be decreasing: " << e << "\n";
}
void check_constant(const Expr &e) {
internal_assert(is_monotonic(e, "x") == Monotonic::Constant)
<< "Was supposed to be constant: " << e << "\n";
}
void check_unknown(const Expr &e) {
internal_assert(is_monotonic(e, "x") == Monotonic::Unknown)
<< "Was supposed to be unknown: " << e << "\n";
}
} // namespace
void is_monotonic_test() {
Expr x = Variable::make(Int(32), "x");
Expr y = Variable::make(Int(32), "y");
check_increasing(x);
check_increasing(x + 4);
check_increasing(x + y);
check_increasing(x * 4);
check_increasing(min(x + 4, y + 4));
check_increasing(max(x + y, x - y));
check_increasing(x >= y);
check_increasing(x > y);
check_decreasing(-x);
check_decreasing(x * -4);
check_decreasing(y - x);
check_decreasing(x < y);
check_decreasing(x <= y);
check_unknown(x == y);
check_unknown(x != y);
check_unknown(x * y);
// Not constant despite having constant args, because there's a side-effect.
check_unknown(Call::make(Int(32), "foo", {Expr(3)}, Call::Extern));
check_increasing(select(y == 2, x, x + 4));
check_decreasing(select(y == 2, -x, x * -4));
check_increasing(select(x > 2, x + 1, x));
check_increasing(select(x < 2, x, x + 1));
check_decreasing(select(x > 2, -x - 1, -x));
check_decreasing(select(x < 2, -x, -x - 1));
check_unknown(select(x < 2, x, x - 5));
check_unknown(select(x > 2, x - 5, x));
check_constant(y);
check_increasing(select(x < 17, y, y + 1));
check_increasing(select(x > 17, y, y - 1));
check_decreasing(select(x < 17, y, y - 1));
check_decreasing(select(x > 17, y, y + 1));
check_increasing(select(x % 2 == 0, x + 3, x + 3));
check_constant(select(y > 3, y + 23, y - 65));
std::cout << "is_monotonic test passed" << std::endl;
}
} // namespace Internal
} // namespace Halide