https://github.com/halide/Halide
Tip revision: f9e4c7878385f43cf88cca23d5bd663233e9e7da authored by Steven Johnson on 27 April 2021, 19:14:54 UTC
Add support for dynamic tensors to hannk (#5942)
Add support for dynamic tensors to hannk (#5942)
Tip revision: f9e4c78
Monotonic.cpp
#include "Monotonic.h"
#include "Bounds.h"
#include "IROperator.h"
#include "IRVisitor.h"
#include "Scope.h"
#include "Simplify.h"
#include "Substitute.h"
namespace Halide {
namespace Internal {
/** Print the name of a Monotonic classification. A value outside the
 * enumeration prints nothing. */
std::ostream &operator<<(std::ostream &stream, const Monotonic &m) {
    const char *label = nullptr;
    switch (m) {
    case Monotonic::Constant:
        label = "Constant";
        break;
    case Monotonic::Increasing:
        label = "Increasing";
        break;
    case Monotonic::Decreasing:
        label = "Decreasing";
        break;
    case Monotonic::Unknown:
        label = "Unknown";
        break;
    }
    if (label != nullptr) {
        stream << label;
    }
    return stream;
}
using std::string;
namespace {
/** If e is a signed integer constant, return a pointer to its value. If it
 * is an unsigned constant that fits in int64_t, return its storage viewed
 * as int64_t. Otherwise return nullptr. */
const int64_t *as_const_int_or_uint(const Expr &e) {
    const int64_t *signed_value = as_const_int(e);
    if (signed_value != nullptr) {
        return signed_value;
    }
    const uint64_t *unsigned_value = as_const_uint(e);
    const bool fits_in_int64 =
        unsigned_value != nullptr &&
        *unsigned_value <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
    // Signed and unsigned variants of the same integer type may alias, so
    // reading the uint64_t storage through an int64_t pointer is fine here.
    return fits_in_int64 ? reinterpret_cast<const int64_t *>(unsigned_value) : nullptr;
}
/** True when the derivative bound is exactly [0, 0], i.e. no dependence. */
bool is_constant(const ConstantInterval &a) {
    const int64_t zero = 0;
    return a.is_single_point(zero);
}
/** True unless the interval is known to lie entirely at or above zero. */
bool may_be_negative(const ConstantInterval &a) {
    return !(a.has_lower_bound() && a.min >= 0);
}
/** True unless the interval is known to lie entirely at or below zero. */
bool may_be_positive(const ConstantInterval &a) {
    return !(a.has_upper_bound() && a.max <= 0);
}
bool is_monotonic_increasing(const ConstantInterval &a) {
return !may_be_negative(a);
}
bool is_monotonic_decreasing(const ConstantInterval &a) {
return !may_be_positive(a);
}
/** Convert a Monotonic classification to the widest derivative bound
 * consistent with it. */
ConstantInterval to_interval(Monotonic m) {
    if (m == Monotonic::Constant) {
        return ConstantInterval::single_point(0);
    }
    if (m == Monotonic::Increasing) {
        return ConstantInterval::bounded_below(0);
    }
    if (m == Monotonic::Decreasing) {
        return ConstantInterval::bounded_above(0);
    }
    // Monotonic::Unknown (or an out-of-range value): nothing is known.
    return ConstantInterval::everything();
}
/** Classify a derivative bound. Constant takes precedence over Increasing
 * and Decreasing, since [0, 0] satisfies both. */
Monotonic to_monotonic(const ConstantInterval &x) {
    if (is_constant(x)) {
        return Monotonic::Constant;
    }
    if (is_monotonic_increasing(x)) {
        return Monotonic::Increasing;
    }
    return is_monotonic_decreasing(x) ? Monotonic::Decreasing : Monotonic::Unknown;
}
/** The smallest interval containing both a and b. */
ConstantInterval unify(const ConstantInterval &a, const ConstantInterval &b) {
    ConstantInterval hull = ConstantInterval::make_union(a, b);
    return hull;
}
/** The smallest interval containing both the interval a and the value b.
 * Missing bounds of a stay missing, because include() only adjusts
 * bounds that exist. */
ConstantInterval unify(const ConstantInterval &a, int64_t b) {
    // Start from a itself. A default-constructed ConstantInterval is
    // unbounded, so starting from it would discard a and always return
    // "everything".
    ConstantInterval result(a);
    result.include(b);
    return result;
}
// Helpers for doing arithmetic on ConstantIntervals that avoid generating
// expressions of pos_inf/neg_inf.
/** Interval sum. A bound of the sum exists only when both operands have
 * that bound; a missing bound is never turned into an infinite value. */
ConstantInterval add(const ConstantInterval &a, const ConstantInterval &b) {
    const bool lower = a.has_lower_bound() && b.has_lower_bound();
    const bool upper = a.has_upper_bound() && b.has_upper_bound();
    ConstantInterval sum;
    sum.min_defined = lower;
    sum.max_defined = upper;
    if (lower) {
        sum.min = a.min + b.min;
    }
    if (upper) {
        sum.max = a.max + b.max;
    }
    return sum;
}
/** Shift an interval by the constant b. */
ConstantInterval add(const ConstantInterval &a, int64_t b) {
    const ConstantInterval offset = ConstantInterval::single_point(b);
    return add(a, offset);
}
/** Interval negation. The old upper bound becomes the new lower bound and
 * vice versa; a missing bound stays missing, with its value set to 0. */
ConstantInterval negate(const ConstantInterval &r) {
    ConstantInterval flipped;
    flipped.min_defined = r.has_upper_bound();
    flipped.max_defined = r.has_lower_bound();
    flipped.min = flipped.min_defined ? -r.max : 0;
    flipped.max = flipped.max_defined ? -r.min : 0;
    return flipped;
}
/** Interval difference a - b, computed as a + (-b). */
ConstantInterval sub(const ConstantInterval &a, const ConstantInterval &b) {
    const ConstantInterval minus_b = negate(b);
    return add(a, minus_b);
}
/** Shift an interval down by the constant b. */
ConstantInterval sub(const ConstantInterval &a, int64_t b) {
    const ConstantInterval offset = ConstantInterval::single_point(b);
    return sub(a, offset);
}
/** Scale an interval by the constant b. A negative factor flips the
 * interval first, and the remaining scaling is by |b|. Products are not
 * checked for overflow. */
ConstantInterval multiply(const ConstantInterval &a, int64_t b) {
    const bool flip = b < 0;
    ConstantInterval scaled = flip ? negate(a) : a;
    const int64_t factor = flip ? -b : b;
    if (scaled.has_lower_bound()) {
        scaled.min *= factor;
    }
    if (scaled.has_upper_bound()) {
        scaled.max *= factor;
    }
    return scaled;
}
/** Scale an interval by b when b is an integer constant that fits in
 * int64_t. Otherwise nothing is known. */
ConstantInterval multiply(const ConstantInterval &a, const Expr &b) {
    const int64_t *factor = as_const_int_or_uint(b);
    return factor != nullptr ? multiply(a, *factor) : ConstantInterval::everything();
}
// Bound the product of two intervals. A missing lower bound stands for
// -infinity and a missing upper bound for +infinity.
//
// Every finite corner product (a.min or a.max times b.min or b.max) that
// can be formed is collected, and their min/max give a candidate result.
// Each missing operand bound then removes whichever result bound it could
// push to infinity, depending on the possible signs of the other operand.
// If no corner product can be formed (one operand is completely
// unbounded), the result is unbounded.
//
// NOTE(review): corner products are plain int64_t multiplications and are
// not checked for overflow.
ConstantInterval multiply(const ConstantInterval &a, const ConstantInterval &b) {
// Scratch space for up to four corner products; [bounds_begin, bounds_end)
// is the filled prefix.
int64_t bounds[4];
int64_t *bounds_begin = &bounds[0];
int64_t *bounds_end = &bounds[0];
if (a.has_lower_bound() && b.has_lower_bound()) {
*bounds_end++ = a.min * b.min;
}
if (a.has_lower_bound() && b.has_upper_bound()) {
*bounds_end++ = a.min * b.max;
}
if (a.has_upper_bound() && b.has_lower_bound()) {
*bounds_end++ = a.max * b.min;
}
if (a.has_upper_bound() && b.has_upper_bound()) {
*bounds_end++ = a.max * b.max;
}
if (bounds_begin != bounds_end) {
// Hull of the finite corner products.
ConstantInterval result = {
*std::min_element(bounds_begin, bounds_end),
*std::max_element(bounds_begin, bounds_end),
};
// There *must* be a better way than this... Even
// cutting half the cases with swapping isn't that much help.
// a unbounded below: a negative b can drive the product to +infinity,
// a positive b can drive it to -infinity.
if (!a.has_lower_bound()) {
if (may_be_negative(b)) result.max_defined = false; // NOLINT
if (may_be_positive(b)) result.min_defined = false; // NOLINT
}
// a unbounded above: a negative b can drive the product to -infinity,
// a positive b can drive it to +infinity.
if (!a.has_upper_bound()) {
if (may_be_negative(b)) result.min_defined = false; // NOLINT
if (may_be_positive(b)) result.max_defined = false; // NOLINT
}
// The same two cases with the roles of a and b swapped.
if (!b.has_lower_bound()) {
if (may_be_negative(a)) result.max_defined = false; // NOLINT
if (may_be_positive(a)) result.min_defined = false; // NOLINT
}
if (!b.has_upper_bound()) {
if (may_be_negative(a)) result.min_defined = false; // NOLINT
if (may_be_positive(a)) result.max_defined = false; // NOLINT
}
return result;
} else {
// No corner product exists: at least one operand has no bounds at all.
return ConstantInterval::everything();
}
}
/** Bound the change of (expr / b) given a bound on the change of expr.
 * A negative divisor flips the interval first. The lower bound is then
 * rounded down and the upper bound rounded up, since floor division can
 * shift the step by one in either direction. */
ConstantInterval divide(const ConstantInterval &a, int64_t b) {
    const bool flip = b < 0;
    ConstantInterval quotient = flip ? negate(a) : a;
    const int64_t divisor = flip ? -b : b;
    if (quotient.has_lower_bound()) {
        quotient.min = div_imp(quotient.min, divisor);
    }
    if (quotient.has_upper_bound()) {
        // ceil(max / divisor), written in terms of floor division.
        quotient.max = div_imp(quotient.max + divisor - 1, divisor);
    }
    return quotient;
}
/** IR visitor that bounds how much an expression can change per unit step
 * of the variable `var`. The answer is left in `result`:
 *   [0, 0]              - the expression does not depend on var;
 *   interval >= 0       - non-decreasing in var;
 *   interval <= 0       - non-increasing in var;
 *   anything else       - unknown.
 * Visiting a statement is an internal error. */
class DerivativeBounds : public IRVisitor {
    // The variable we are differentiating with respect to.
    const string &var;

    // Derivative bounds of variables bound by enclosing Lets (when they
    // depend on var), chained onto the caller-supplied parent scope.
    Scope<ConstantInterval> scope;
    // Constant bounds of enclosing Let values, used to bound the difference
    // between the two arms of a Select.
    Scope<Interval> bounds;

    void visit(const IntImm *) override {
        result = ConstantInterval::single_point(0);
    }

    void visit(const UIntImm *) override {
        result = ConstantInterval::single_point(0);
    }

    void visit(const FloatImm *) override {
        result = ConstantInterval::single_point(0);
    }

    void visit(const StringImm *) override {
        // require() Exprs can include Strings.
        result = ConstantInterval::single_point(0);
    }

    void visit(const Cast *op) override {
        op->value.accept(this);

        if (op->type.can_represent(op->value.type())) {
            // No overflow.
            return;
        }

        if (op->value.type().bits() >= 32 && op->type.bits() >= 32) {
            // We assume 32-bit types don't overflow.
            return;
        }

        // A narrowing cast. There may be more cases we can catch, but
        // for now we punt.
        if (!is_constant(result)) {
            result = ConstantInterval::everything();
        }
    }

    void visit(const Variable *op) override {
        if (op->name == var) {
            result = ConstantInterval::single_point(1);
        } else if (scope.contains(op->name)) {
            result = scope.get(op->name);
        } else {
            result = ConstantInterval::single_point(0);
        }
    }

    void visit(const Add *op) override {
        op->a.accept(this);
        ConstantInterval ra = result;
        op->b.accept(this);
        ConstantInterval rb = result;
        result = add(ra, rb);
    }

    void visit(const Sub *op) override {
        op->a.accept(this);
        ConstantInterval ra = result;
        op->b.accept(this);
        ConstantInterval rb = result;
        result = sub(ra, rb);
    }

    void visit(const Mul *op) override {
        if (op->type.is_scalar()) {
            op->a.accept(this);
            ConstantInterval ra = result;
            op->b.accept(this);
            ConstantInterval rb = result;

            // This is essentially the product rule: a*rb + b*ra
            // but only implemented for the case where a or b is constant.
            if (const int64_t *b = as_const_int_or_uint(op->b)) {
                result = multiply(ra, *b);
            } else if (const int64_t *a = as_const_int_or_uint(op->a)) {
                result = multiply(rb, *a);
            } else {
                result = ConstantInterval::everything();
            }
        } else {
            result = ConstantInterval::everything();
        }
    }

    void visit(const Div *op) override {
        if (op->type.is_scalar()) {
            op->a.accept(this);
            ConstantInterval ra = result;

            if (const int64_t *b = as_const_int_or_uint(op->b)) {
                result = divide(ra, *b);
            } else {
                result = ConstantInterval::everything();
            }
        } else {
            result = ConstantInterval::everything();
        }
    }

    void visit(const Mod *op) override {
        result = ConstantInterval::everything();
    }

    void visit(const Min *op) override {
        op->a.accept(this);
        ConstantInterval ra = result;
        op->b.accept(this);
        ConstantInterval rb = result;
        result = unify(ra, rb);
    }

    void visit(const Max *op) override {
        op->a.accept(this);
        ConstantInterval ra = result;
        op->b.accept(this);
        ConstantInterval rb = result;
        result = unify(ra, rb);
    }

    void visit_eq(const Expr &a, const Expr &b) {
        a.accept(this);
        ConstantInterval ra = result;
        b.accept(this);
        ConstantInterval rb = result;
        if (is_constant(ra) && is_constant(rb)) {
            result = ConstantInterval::single_point(0);
        } else {
            // If the result is bounded, limit it to [-1, 1]. The largest
            // difference possible is flipping from true to false or false
            // to true.
            result = ConstantInterval(-1, 1);
        }
    }

    void visit(const EQ *op) override {
        visit_eq(op->a, op->b);
    }

    void visit(const NE *op) override {
        visit_eq(op->a, op->b);
    }

    // Shared by LT/LE (a < b) and, with operands swapped, GT/GE. The
    // comparison tends to become true as b grows and false as a grows.
    void visit_lt(const Expr &a, const Expr &b) {
        a.accept(this);
        ConstantInterval ra = result;
        b.accept(this);
        ConstantInterval rb = result;
        result = unify(negate(ra), rb);
        // If the result is bounded, limit it to [-1, 1]. The largest
        // difference possible is flipping from true to false or false
        // to true.
        if (result.has_lower_bound()) {
            result.min = std::min<int64_t>(std::max<int64_t>(result.min, -1), 1);
        }
        if (result.has_upper_bound()) {
            result.max = std::min<int64_t>(std::max<int64_t>(result.max, -1), 1);
        }
        // Between adjacent values of var a comparison either flips or stays
        // the same, so a change of 0 must always be allowed. Without this,
        // e.g. x < -x would yield [-1, -1] ("flips on every step"). Select
        // would then apply the arm-difference bump on every step and
        // misclassify select(x < -x, x + 10, x) as decreasing. Including 0
        // only widens the interval, so the sign classification of the
        // comparison itself is unchanged.
        result.include(0);
    }

    void visit(const LT *op) override {
        visit_lt(op->a, op->b);
    }

    void visit(const LE *op) override {
        visit_lt(op->a, op->b);
    }

    void visit(const GT *op) override {
        visit_lt(op->b, op->a);
    }

    void visit(const GE *op) override {
        visit_lt(op->b, op->a);
    }

    void visit(const And *op) override {
        op->a.accept(this);
        ConstantInterval ra = result;
        op->b.accept(this);
        ConstantInterval rb = result;
        result = unify(ra, rb);
    }

    void visit(const Or *op) override {
        op->a.accept(this);
        ConstantInterval ra = result;
        op->b.accept(this);
        ConstantInterval rb = result;
        result = unify(ra, rb);
    }

    void visit(const Not *op) override {
        op->a.accept(this);
        result = negate(result);
    }

    void visit(const Select *op) override {
        // The result is the unified bounds, added to the "bump" that happens when switching from true to false.
        if (op->type.is_scalar()) {
            op->condition.accept(this);
            ConstantInterval rcond = result;

            op->true_value.accept(this);
            ConstantInterval ra = result;
            op->false_value.accept(this);
            ConstantInterval rb = result;
            ConstantInterval unified = unify(ra, rb);

            // TODO: How to handle unsigned values?
            Expr delta = simplify(op->true_value - op->false_value);

            Interval delta_bounds = find_constant_bounds(delta, bounds);
            ConstantInterval adjusted_delta;
            // TODO: Maybe we can do something with one-sided intervals?
            if (delta_bounds.is_bounded()) {
                ConstantInterval delta_low = multiply(rcond, delta_bounds.min);
                ConstantInterval delta_high = multiply(rcond, delta_bounds.max);
                adjusted_delta = ConstantInterval::make_union(delta_low, delta_high);
            } else {
                delta.accept(this);
                ConstantInterval rdelta = result;
                adjusted_delta = multiply(rcond, rdelta);
            }
            result = add(unified, adjusted_delta);
        } else {
            result = ConstantInterval::everything();
        }
    }

    void visit(const Load *op) override {
        op->index.accept(this);
        if (!is_constant(result)) {
            result = ConstantInterval::everything();
        }
    }

    void visit(const Ramp *op) override {
        Expr equiv = op->base + Variable::make(op->base.type(), unique_name('t')) * op->stride;
        equiv.accept(this);
    }

    void visit(const Broadcast *op) override {
        op->value.accept(this);
    }

    void visit(const Call *op) override {
        // Some functions are known to be monotonic
        if (Call::as_tag(op) ||
            op->is_intrinsic(Call::return_second)) {
            op->args.back().accept(this);
            return;
        }

        if (op->is_intrinsic(Call::unsafe_promise_clamped) ||
            op->is_intrinsic(Call::promise_clamped)) {
            op->args[0].accept(this);
            return;
        }

        if (op->is_intrinsic(Call::require)) {
            // require() returns the value of the second arg in all non-failure cases
            op->args[1].accept(this);
            return;
        }

        if (!op->is_pure()) {
            // Even with constant args, the result could vary from one loop iteration to the next.
            result = ConstantInterval::everything();
            return;
        }

        // A pure call is constant w.r.t. var exactly when all of its args
        // are. Do not consult `result` before visiting the args: on entry it
        // still holds whatever the previously visited expression left there.
        for (const Expr &arg : op->args) {
            arg.accept(this);
            if (!is_constant(result)) {
                // One of the args is not constant.
                result = ConstantInterval::everything();
                return;
            }
        }
        result = ConstantInterval::single_point(0);
    }

    void visit(const Let *op) override {
        op->value.accept(this);

        ScopedBinding<Interval> bounds_binding(bounds, op->name, find_constant_bounds(op->value, bounds));

        if (is_constant(result)) {
            // No point pushing it if it's constant w.r.t the var,
            // because unknown variables are treated as constant.
            op->body.accept(this);
        } else {
            ScopedBinding<ConstantInterval> scope_binding(scope, op->name, result);
            op->body.accept(this);
        }
    }

    void visit(const Shuffle *op) override {
        for (size_t i = 0; i < op->vectors.size(); i++) {
            op->vectors[i].accept(this);
            if (!is_constant(result)) {
                result = ConstantInterval::everything();
                return;
            }
        }
        result = ConstantInterval::single_point(0);
    }

    void visit(const VectorReduce *op) override {
        op->value.accept(this);
        switch (op->op) {
        case VectorReduce::Add:
        case VectorReduce::SaturatingAdd:
            result = multiply(result, op->value.type().lanes() / op->type.lanes());
            break;
        case VectorReduce::Min:
        case VectorReduce::Max:
            // These reductions are monotonic in the arg
            break;
        case VectorReduce::Mul:
        case VectorReduce::And:
        case VectorReduce::Or:
            // These ones are not
            if (!is_constant(result)) {
                result = ConstantInterval::everything();
            }
        }
    }

    void visit(const LetStmt *op) override {
        internal_error << "Monotonic of statement\n";
    }

    void visit(const AssertStmt *op) override {
        internal_error << "Monotonic of statement\n";
    }

    void visit(const ProducerConsumer *op) override {
        internal_error << "Monotonic of statement\n";
    }

    void visit(const For *op) override {
        internal_error << "Monotonic of statement\n";
    }

    void visit(const Acquire *op) override {
        internal_error << "Monotonic of statement\n";
    }

    void visit(const Store *op) override {
        internal_error << "Monotonic of statement\n";
    }

    void visit(const Provide *op) override {
        internal_error << "Monotonic of statement\n";
    }

    void visit(const Allocate *op) override {
        internal_error << "Monotonic of statement\n";
    }

    void visit(const Free *op) override {
        internal_error << "Monotonic of statement\n";
    }

    void visit(const Realize *op) override {
        internal_error << "Monotonic of statement\n";
    }

    void visit(const Block *op) override {
        internal_error << "Monotonic of statement\n";
    }

    void visit(const Fork *op) override {
        internal_error << "Monotonic of statement\n";
    }

    void visit(const IfThenElse *op) override {
        internal_error << "Monotonic of statement\n";
    }

    void visit(const Evaluate *op) override {
        internal_error << "Monotonic of statement\n";
    }

    void visit(const Prefetch *op) override {
        internal_error << "Monotonic of statement\n";
    }

    void visit(const Atomic *op) override {
        internal_error << "Monotonic of statement\n";
    }

public:
    // Derivative bound of the most recently visited expression.
    ConstantInterval result;

    DerivativeBounds(const std::string &v, const Scope<ConstantInterval> &parent)
        : var(v), result(ConstantInterval::everything()) {
        scope.set_containing_scope(&parent);
    }
};
} // namespace
/** Bound the per-step change of e with respect to var. `scope` supplies
 * derivative bounds for free variables; any other free variable is treated
 * as constant. An undefined Expr yields an unbounded interval. */
ConstantInterval derivative_bounds(const Expr &e, const std::string &var, const Scope<ConstantInterval> &scope) {
    if (e.defined()) {
        DerivativeBounds visitor(var, scope);
        e.accept(&visitor);
        return visitor.result;
    }
    return ConstantInterval::everything();
}
/** Classify e as constant/increasing/decreasing/unknown in var, using
 * derivative bounds for free variables from `scope`. An undefined Expr is
 * Unknown. */
Monotonic is_monotonic(const Expr &e, const std::string &var, const Scope<ConstantInterval> &scope) {
    return e.defined() ? to_monotonic(derivative_bounds(e, var, scope)) : Monotonic::Unknown;
}
/** Overload taking Monotonic classifications for free variables. Each
 * classification is widened to its equivalent derivative interval before
 * delegating to the ConstantInterval overload. */
Monotonic is_monotonic(const Expr &e, const std::string &var, const Scope<Monotonic> &scope) {
    if (!e.defined()) {
        return Monotonic::Unknown;
    }
    Scope<ConstantInterval> as_intervals;
    for (auto it = scope.cbegin(); it != scope.cend(); ++it) {
        as_intervals.push(it.name(), to_interval(it.value()));
    }
    return is_monotonic(e, var, as_intervals);
}
namespace {
void check_increasing(const Expr &e) {
internal_assert(is_monotonic(e, "x") == Monotonic::Increasing)
<< "Was supposed to be increasing: " << e << "\n";
}
void check_decreasing(const Expr &e) {
internal_assert(is_monotonic(e, "x") == Monotonic::Decreasing)
<< "Was supposed to be decreasing: " << e << "\n";
}
void check_constant(const Expr &e) {
internal_assert(is_monotonic(e, "x") == Monotonic::Constant)
<< "Was supposed to be constant: " << e << "\n";
}
void check_unknown(const Expr &e) {
internal_assert(is_monotonic(e, "x") == Monotonic::Unknown)
<< "Was supposed to be unknown: " << e << "\n";
}
} // namespace
/** Self-test for is_monotonic. Each check classifies an expression with
 * respect to the variable "x"; "y" is a free variable treated as constant.
 * Prints a message on success. */
void is_monotonic_test() {
    Expr x = Variable::make(Int(32), "x");
    Expr y = Variable::make(Int(32), "y");

    check_increasing(x);
    check_increasing(x + 4);
    check_increasing(x + y);
    check_increasing(x * 4);
    check_increasing(x / 4);
    check_increasing(min(x + 4, y + 4));
    check_increasing(max(x + y, x - y));
    check_increasing(x >= y);
    check_increasing(x > y);

    check_decreasing(-x);
    check_decreasing(x * -4);
    check_decreasing(x / -4);
    check_decreasing(y - x);
    check_decreasing(x < y);
    check_decreasing(x <= y);

    check_unknown(x == y);
    check_unknown(x != y);

    // Same comparisons with the operands swapped.
    check_increasing(y <= x);
    check_increasing(y < x);
    check_decreasing(y >= x);
    check_decreasing(y > x);

    check_unknown(x * y);

    // Not constant despite having constant args, because there's a side-effect.
    check_unknown(Call::make(Int(32), "foo", {Expr(3)}, Call::Extern));

    check_increasing(select(y == 2, x, x + 4));
    check_decreasing(select(y == 2, -x, x * -4));

    check_unknown(select(x > 2, x - 2, x));
    check_unknown(select(x < 2, x, x - 2));
    check_unknown(select(x > 2, -x + 2, -x));
    check_unknown(select(x < 2, -x, -x + 2));
    check_increasing(select(x > 2, x - 1, x));
    check_increasing(select(x < 2, x, x - 1));
    check_decreasing(select(x > 2, -x + 1, -x));
    check_decreasing(select(x < 2, -x, -x + 1));
    check_unknown(select(x < 2, x, x - 5));
    check_unknown(select(x > 2, x - 5, x));

    check_constant(y);

    check_increasing(select(x < 17, y, y + 1));
    check_increasing(select(x > 17, y, y - 1));
    check_decreasing(select(x < 17, y, y - 1));
    check_decreasing(select(x > 17, y, y + 1));

    check_increasing(select(x % 2 == 0, x + 3, x + 3));

    check_constant(select(y > 3, y + 23, y - 65));

    check_decreasing(select(2 <= x, 0, 1));
    check_increasing(select(2 <= x, 0, 1) + x);
    check_decreasing(-min(x, 16));

    check_unknown(select(0 < x, max(min(x, 4), 3), 4));

    std::cout << "is_monotonic test passed" << std::endl;
}
} // namespace Internal
} // namespace Halide