#include "Simplify.h"
#include "Simplify_Internal.h"
#include "CSE.h"
#include "IRMutator.h"
#include "Substitute.h"

#include <cstdint>
namespace Halide {
namespace Internal {
using std::map;
using std::ostringstream;
using std::pair;
using std::string;
using std::vector;
// Out-of-class definition of the static member Simplify::debug_indent. It is
// only compiled when at least one of LOG_EXPR_MUTATIONS / LOG_STMT_MUTATIONS
// evaluates to non-zero, and it starts at 0.
// NOTE(review): presumably the current nesting depth used to indent the
// mutation log output -- confirm against Simplify_Internal.h.
#if LOG_EXPR_MUTATIONS || LOG_STMT_MUTATIONS
int Simplify::debug_indent = 0;
#endif
// Build a simplifier.
//
//  r  - stored as remove_dead_lets.
//  bi - scope of symbolic intervals known to the caller. It is dereferenced
//       unconditionally, so it must not be null.
//  ai - installed as the containing scope of alignment_info.
//
// no_float_simplify always starts out false.
Simplify::Simplify(bool r, const Scope<Interval> *bi, const Scope<ModulusRemainder> *ai)
    : remove_dead_lets(r), no_float_simplify(false) {
    alignment_info.set_containing_scope(ai);

    // Import only those interval endpoints that are constant integers; an
    // entry with neither endpoint constant contributes nothing.
    for (auto entry = bi->cbegin(); entry != bi->cend(); ++entry) {
        ConstBounds imported;

        const int64_t *lo = as_const_int(entry.value().min);
        if (lo) {
            imported.min = *lo;
            imported.min_defined = true;
        }

        const int64_t *hi = as_const_int(entry.value().max);
        if (hi) {
            imported.max = *hi;
            imported.max_defined = true;
        }

        if (!imported.min_defined && !imported.max_defined) {
            continue;
        }
        bounds_info.push(entry.name(), imported);
    }
}
void Simplify::found_buffer_reference(const string &name, size_t dimensions) {
for (size_t i = 0; i < dimensions; i++) {
string stride = name + ".stride." + std::to_string(i);
if (var_info.contains(stride)) {
var_info.ref(stride).old_uses++;
}
string min = name + ".min." + std::to_string(i);
if (var_info.contains(min)) {
var_info.ref(min).old_uses++;
}
}
if (var_info.contains(name)) {
var_info.ref(name).old_uses++;
}
}
// If `e` is not a vector and as_const_float recognises it as a constant,
// store that constant in *f and return true. Otherwise return false and leave
// *f untouched.
bool Simplify::const_float(const Expr &e, double *f) {
    if (e.type().is_vector()) {
        return false;
    }
    const double *value = as_const_float(e);
    if (value == nullptr) {
        return false;
    }
    *f = *value;
    return true;
}
// If `e` is not a vector and as_const_int recognises it as a constant, store
// that constant in *i and return true. Otherwise return false and leave *i
// untouched.
bool Simplify::const_int(const Expr &e, int64_t *i) {
    if (e.type().is_vector()) {
        return false;
    }
    const int64_t *value = as_const_int(e);
    if (value == nullptr) {
        return false;
    }
    *i = *value;
    return true;
}
// If `e` is not a vector and as_const_uint recognises it as a constant, store
// that constant in *u and return true. Otherwise return false and leave *u
// untouched.
bool Simplify::const_uint(const Expr &e, uint64_t *u) {
    if (e.type().is_vector()) {
        return false;
    }
    const uint64_t *value = as_const_uint(e);
    if (value == nullptr) {
        return false;
    }
    *u = *value;
    return true;
}
// Record what follows from `fact` being false. Everything recorded here is
// undone again by ~ScopedFact.
//
// Depending on the shape of `fact`:
//  - Variable:        push a VarInfo whose replacement is const_false.
//  - v != c:          (so v == c) push a VarInfo whose replacement is c.
//  - v < c / c < v:   learn a lower / upper bound on v.
//  - v <= c / c <= v: learn a strict lower / upper bound on v.
//  - a || b:          both a and b are false.
//  - !a:              a is true.
//  - anything else:   insert into the simplifier's set of falsehoods, and
//                     remember it locally if it was not already present.
//
// pop_list stores raw pointers to Variable nodes reached through `fact`, and
// the destructor reads their names, so those nodes must stay alive for as
// long as this ScopedFact does.
void Simplify::ScopedFact::learn_false(const Expr &fact) {
    Simplify::VarInfo info;
    info.old_uses = info.new_uses = 0;
    if (const Variable *v = fact.as<Variable>()) {
        info.replacement = const_false(fact.type().lanes());
        simplify->var_info.push(v->name, info);
        pop_list.push_back(v);
    } else if (const NE *ne = fact.as<NE>()) {
        const Variable *v = ne->a.as<Variable>();
        if (v && is_const(ne->b)) {
            info.replacement = ne->b;
            simplify->var_info.push(v->name, info);
            pop_list.push_back(v);
        }
    } else if (const LT *lt = fact.as<LT>()) {
        const Variable *v = lt->a.as<Variable>();
        const int64_t *i = as_const_int(lt->b);
        if (v && i) {
            // !(v < i)
            learn_lower_bound(v, *i);
        }
        v = lt->b.as<Variable>();
        i = as_const_int(lt->a);
        if (v && i) {
            // !(i < v)
            learn_upper_bound(v, *i);
        }
    } else if (const LE *le = fact.as<LE>()) {
        const Variable *v = le->a.as<Variable>();
        const int64_t *i = as_const_int(le->b);
        // *i + 1 would be signed overflow (undefined behaviour) when
        // *i == INT64_MAX. In that case the fact is unsatisfiable anyway, so
        // just learn nothing, which is always sound.
        if (v && i && *i != INT64_MAX) {
            // !(v <= i)
            learn_lower_bound(v, *i + 1);
        }
        v = le->b.as<Variable>();
        i = as_const_int(le->a);
        // Likewise *i - 1 must not be evaluated for INT64_MIN.
        if (v && i && *i != INT64_MIN) {
            // !(i <= v)
            learn_upper_bound(v, *i - 1);
        }
    } else if (const Or *o = fact.as<Or>()) {
        // Both must be false
        learn_false(o->a);
        learn_false(o->b);
    } else if (const Not *n = fact.as<Not>()) {
        learn_true(n->a);
    } else if (simplify->falsehoods.insert(fact).second) {
        // Only facts this object newly inserted are erased by the destructor.
        falsehoods.push_back(fact);
    }
}
void Simplify::ScopedFact::learn_upper_bound(const Variable *v, int64_t val) {
ConstBounds b;
if (simplify->bounds_info.contains(v->name)) {
b = simplify->bounds_info.get(v->name);
}
if (b.max_defined && b.max < val) return;
b.max_defined = true;
b.max = val;
simplify->bounds_info.push(v->name, b);
bounds_pop_list.push_back(v);
}
void Simplify::ScopedFact::learn_lower_bound(const Variable *v, int64_t val) {
ConstBounds b;
if (simplify->bounds_info.contains(v->name)) {
b = simplify->bounds_info.get(v->name);
}
if (b.min_defined && b.min > val) return;
b.min_defined = true;
b.min = val;
simplify->bounds_info.push(v->name, b);
bounds_pop_list.push_back(v);
}
// Record what follows from `fact` being true. Everything recorded here is
// undone again by ~ScopedFact.
//
// Depending on the shape of `fact`:
//  - Variable:        push a VarInfo whose replacement is const_true.
//  - v == c:          push a VarInfo whose replacement is c.
//  - v < c / c < v:   learn a strict upper / lower bound on v.
//  - v <= c / c <= v: learn an upper / lower bound on v.
//  - a && b:          both a and b are true.
//  - !a:              a is false.
//  - anything else:   insert into the simplifier's set of truths, and
//                     remember it locally if it was not already present.
//
// pop_list stores raw pointers to Variable nodes reached through `fact`, and
// the destructor reads their names, so those nodes must stay alive for as
// long as this ScopedFact does.
void Simplify::ScopedFact::learn_true(const Expr &fact) {
    Simplify::VarInfo info;
    info.old_uses = info.new_uses = 0;
    if (const Variable *v = fact.as<Variable>()) {
        info.replacement = const_true(fact.type().lanes());
        simplify->var_info.push(v->name, info);
        pop_list.push_back(v);
    } else if (const EQ *eq = fact.as<EQ>()) {
        const Variable *v = eq->a.as<Variable>();
        if (v && is_const(eq->b)) {
            info.replacement = eq->b;
            simplify->var_info.push(v->name, info);
            pop_list.push_back(v);
        }
    } else if (const LT *lt = fact.as<LT>()) {
        const Variable *v = lt->a.as<Variable>();
        const int64_t *i = as_const_int(lt->b);
        // *i - 1 would be signed overflow (undefined behaviour) when
        // *i == INT64_MIN. In that case the fact is unsatisfiable anyway, so
        // just learn nothing, which is always sound.
        if (v && i && *i != INT64_MIN) {
            // v < i
            learn_upper_bound(v, *i - 1);
        }
        v = lt->b.as<Variable>();
        i = as_const_int(lt->a);
        // Likewise *i + 1 must not be evaluated for INT64_MAX.
        if (v && i && *i != INT64_MAX) {
            // i < v
            learn_lower_bound(v, *i + 1);
        }
    } else if (const LE *le = fact.as<LE>()) {
        const Variable *v = le->a.as<Variable>();
        const int64_t *i = as_const_int(le->b);
        if (v && i) {
            // v <= i
            learn_upper_bound(v, *i);
        }
        v = le->b.as<Variable>();
        i = as_const_int(le->a);
        if (v && i) {
            // i <= v
            learn_lower_bound(v, *i);
        }
    } else if (const And *a = fact.as<And>()) {
        // Both must be true
        learn_true(a->a);
        learn_true(a->b);
    } else if (const Not *n = fact.as<Not>()) {
        learn_false(n->a);
    } else if (simplify->truths.insert(fact).second) {
        // Only facts this object newly inserted are erased by the destructor.
        truths.push_back(fact);
    }
}
// Undo everything this object recorded in the owning simplifier: pop one
// var_info entry per variable in pop_list, pop one bounds_info entry per
// variable in bounds_pop_list, and erase the truths and falsehoods that this
// object inserted.
Simplify::ScopedFact::~ScopedFact() {
    for (const Variable *var : pop_list) {
        simplify->var_info.pop(var->name);
    }
    for (const Variable *var : bounds_pop_list) {
        simplify->bounds_info.pop(var->name);
    }
    for (const auto &learned : truths) {
        simplify->truths.erase(learned);
    }
    for (const auto &learned : falsehoods) {
        simplify->falsehoods.erase(learned);
    }
}
// Simplify an expression: build a Simplify mutator from the given options and
// scopes, and return the result of running it over `e` (passing a null second
// argument to mutate).
Expr simplify(Expr e, bool remove_dead_lets,
              const Scope<Interval> &bounds,
              const Scope<ModulusRemainder> &alignment) {
    Simplify simplifier(remove_dead_lets, &bounds, &alignment);
    return simplifier.mutate(e, nullptr);
}
// Simplify a statement: build a Simplify mutator from the given options and
// scopes, and return the result of running it over `s`.
Stmt simplify(Stmt s, bool remove_dead_lets,
              const Scope<Interval> &bounds,
              const Scope<ModulusRemainder> &alignment) {
    Simplify simplifier(remove_dead_lets, &bounds, &alignment);
    return simplifier.mutate(s);
}
// Mutator whose Expr hook hands every expression it is given to simplify(),
// using simplify()'s default arguments. The remaining mutate overloads are
// inherited unchanged from IRMutator2.
class SimplifyExprs : public IRMutator2 {
public:
    using IRMutator2::mutate;

    Expr mutate(const Expr &e) override {
        Expr simplified = simplify(e);
        return simplified;
    }
};
// Run a SimplifyExprs mutator over the statement `s` and return the result.
Stmt simplify_exprs(Stmt s) {
    SimplifyExprs expr_simplifier;
    return expr_simplifier.mutate(s);
}
// Try to prove that the boolean expression `e` is always true, given the
// variable bounds in `bounds`.
//
// The expression has its likely / likely_if_innermost wrappers removed, goes
// through common_subexpression_elimination, and is simplified. Any Let nodes
// left at the top of the result are peeled off. The function returns true only
// if what remains is the constant one.
//
// When the debug level is above zero and the result is not a constant, the
// failed proof is inspected further: variables are renamed to v0, v1, ... and
// up to 100 random assignments are substituted in to look for a
// counter-example. If none is found, the expression is logged as a possible
// simplifier weakness. This path always returns false.
//
// Asserts (internal_assert) that `e` has boolean type.
bool can_prove(Expr e, const Scope<Interval> &bounds) {
    internal_assert(e.type().is_bool())
        << "Argument to can_prove is not a boolean Expr: " << e << "\n";

    // Remove likelies
    struct RemoveLikelies : public IRMutator2 {
        using IRMutator2::visit;
        Expr visit(const Call *op) override {
            if (op->is_intrinsic(Call::likely) ||
                op->is_intrinsic(Call::likely_if_innermost)) {
                return mutate(op->args[0]);
            } else {
                return IRMutator2::visit(op);
            }
        }
    };
    e = RemoveLikelies().mutate(e);

    e = common_subexpression_elimination(e);

    Expr orig = e;

    e = simplify(e, true, bounds);

    // When terms cancel, the simplifier doesn't always successfully
    // kill the dead lets.
    while (const Let *l = e.as<Let>()) {
        e = l->body;
    }

    // Take a closer look at all failed proof attempts to hunt for
    // simplifier weaknesses
    if (debug::debug_level() > 0 && !is_const(e)) {
        // Gives every variable a short canonical name. Free variables are
        // collected in out_vars (in order of first appearance) so that they
        // can be assigned probe values below.
        struct RenameVariables : public IRMutator2 {
            using IRMutator2::visit;

            Expr visit(const Variable *op) override {
                auto it = vars.find(op->name);
                if (lets.contains(op->name)) {
                    return Variable::make(op->type, lets.get(op->name));
                } else if (it == vars.end()) {
                    std::string name = "v" + std::to_string(count++);
                    vars[op->name] = name;
                    out_vars.emplace_back(op->type, name);
                    return Variable::make(op->type, name);
                } else {
                    return Variable::make(op->type, it->second);
                }
            }

            Expr visit(const Let *op) override {
                std::string name = "v" + std::to_string(count++);
                // The name introduced by a let is not in scope inside its
                // own value, so rename the value before installing the
                // binding. Doing the two mutations as separate statements
                // also makes the numbering independent of the compiler's
                // argument evaluation order.
                Expr value = mutate(op->value);
                ScopedBinding<string> bind(lets, op->name, name);
                Expr body = mutate(op->body);
                return Let::make(name, value, body);
            }

            int count = 0;
            map<string, string> vars;
            Scope<string> lets;
            std::vector<pair<Type, string>> out_vars;
        } renamer;

        e = renamer.mutate(e);

        // Look for a concrete counter-example with random probing
        static std::mt19937 rng(0);
        for (int i = 0; i < 100; i++) {
            map<string, Expr> s;
            for (const auto &p : renamer.out_vars) {
                if (p.first.is_handle()) {
                    // This aint gonna work
                    return false;
                }
                // Probe values lie in [-0x7fff, 0x8000].
                s[p.second] = make_const(p.first, static_cast<int>(rng() & 0xffff) - 0x7fff);
            }
            Expr probe = simplify(substitute(s, e));
            if (const Call *c = probe.as<Call>()) {
                if (c->is_intrinsic(Call::likely) ||
                    c->is_intrinsic(Call::likely_if_innermost)) {
                    probe = c->args[0];
                }
            }
            if (!is_one(probe)) {
                // Found a counter-example, or something that fails to fold
                return false;
            }
        }

        debug(0) << "Failed to prove, but could not find a counter-example:\n " << e << "\n";
        debug(0) << "Original expression:\n" << orig << "\n";
        return false;
    }

    return is_one(e);
}
} // namespace Internal
} // namespace Halide