https://github.com/halide/Halide
Tip revision: 0181dd9d660d58bf1e84a12d58898f0f4f0df16e authored by Andrew Adams on 07 November 2022, 22:03:38 UTC
Revert formatting changes
Revert formatting changes
Tip revision: 0181dd9
FindIntrinsics.cpp
#include "FindIntrinsics.h"
#include "CSE.h"
#include "CodeGen_Internal.h"
#include "ConciseCasts.h"
#include "IRMatch.h"
#include "IRMutator.h"
#include "Simplify.h"
namespace Halide {
namespace Internal {
using namespace Halide::ConciseCasts;
namespace {
// Decides whether FindIntrinsics should try to pattern-match ops of type t.
// Currently, we only try to find and replace intrinsics for vector types
// that aren't bools, i.e. vectors whose elements are at least 8 bits wide.
bool find_intrinsics_for_type(const Type &t) {
    if (!t.is_vector()) {
        return false;
    }
    return t.bits() >= 8;
}
// Returns `a` cast to the type with twice as many bits (Type::widen()).
// The target type is read before `a` is moved into Cast::make, because the
// evaluation order of function arguments is unspecified.
Expr widen(Expr a) {
    const Type wide_type = a.type().widen();
    Expr widened = Cast::make(wide_type, std::move(a));
    return widened;
}
// Returns `a` cast to the type with half as many bits (Type::narrow()).
// This is a plain Cast, so it may drop information. The target type is read
// before `a` is moved, since argument evaluation order is unspecified.
Expr narrow(Expr a) {
    const Type narrow_type = a.type().narrow();
    Expr narrowed = Cast::make(narrow_type, std::move(a));
    return narrowed;
}
// Attempts to express x in the half-width type of the same type code.
// Yields an undefined Expr when lossless_cast cannot do so without losing
// information.
Expr lossless_narrow(const Expr &x) {
    const Type half_width = x.type().narrow();
    return lossless_cast(half_width, x);
}
// Remove a widening cast even if it changes the sign of the result.
// First tries the half-width type with the same type code as x. If that
// fails, it falls back to the half-width unsigned type. Returns an undefined
// Expr when neither lossless narrowing succeeds.
Expr strip_widening_cast(const Expr &x) {
    Expr stripped = lossless_narrow(x);
    if (!stripped.defined()) {
        const Type unsigned_half = x.type().narrow().with_code(halide_type_uint);
        stripped = lossless_cast(unsigned_half, x);
    }
    return stripped;
}
// Saturating cast of `a` to its half-width type, i.e. values outside that
// type's range are clamped to its min/max rather than wrapped.
Expr saturating_narrow(const Expr &a) {
    return saturating_cast(a.type().narrow(), a);
}
// Returns true iff t is an integral type where overflow is undefined:
// signed integers that are 32 bits or wider.
bool no_overflow_int(Type t) {
    if (!t.is_int()) {
        return false;
    }
    return t.bits() >= 32;
}
// Returns true iff t does not have a well defined overflow behavior:
// floating-point types, plus the signed integer types accepted by
// no_overflow_int.
bool no_overflow(Type t) {
    if (t.is_float()) {
        return true;
    }
    return no_overflow_int(t);
}
// If there's a widening add or subtract in the first e.type().bits() / 2 - 1
// levels down a tree of adds or subtracts, we know there's enough headroom for
// another add without overflow. For example, it is safe to add to
// (widening_add(x, y) - z) without overflow.
//
// max_depth bounds how many levels of Add/Sub/same-width Cast are explored.
// A depth of zero or less means "not provably safe".
bool is_safe_for_add(const Expr &e, int max_depth) {
    if (max_depth <= 0) {
        return false;
    }
    const int remaining = max_depth - 1;

    if (const Add *add = e.as<Add>()) {
        return is_safe_for_add(add->a, remaining) || is_safe_for_add(add->b, remaining);
    }
    if (const Sub *sub = e.as<Sub>()) {
        return is_safe_for_add(sub->a, remaining) || is_safe_for_add(sub->b, remaining);
    }
    if (const Cast *cast = e.as<Cast>()) {
        const int to_bits = cast->type.bits();
        const int from_bits = cast->value.type().bits();
        if (to_bits > from_bits) {
            // A widening cast leaves headroom in the wider type.
            return true;
        }
        if (to_bits == from_bits) {
            // Same-width casts (e.g. sign changes) are looked through.
            return is_safe_for_add(cast->value, remaining);
        }
        // Narrowing casts give no guarantee.
        return false;
    }
    return Call::as_intrinsic(e, {Call::widening_add, Call::widening_sub, Call::widen_right_add, Call::widen_right_sub}) != nullptr;
}
bool is_safe_for_add(const Expr &e) {
return is_safe_for_add(e, e.type().bits() / 2 - 1);
}
// We want to find and remove an add of 'round' from e. This is not
// the same thing as just subtracting round: we specifically want
// to remove an addition of exactly round.
//
// The search walks Add nodes (both sides, left first) and the positive (left)
// side of Sub nodes. It replaces the first leaf provably equal to `round` with
// zero. Returns an undefined Expr when no such term is found.
Expr find_and_subtract(const Expr &e, const Expr &round) {
    if (const Add *add = e.as<Add>()) {
        Expr lhs = find_and_subtract(add->a, round);
        if (lhs.defined()) {
            return Add::make(lhs, add->b);
        }
        Expr rhs = find_and_subtract(add->b, round);
        if (rhs.defined()) {
            return Add::make(add->a, rhs);
        }
        return Expr();
    }
    if (const Sub *sub = e.as<Sub>()) {
        // We can't recurse into the negative part of a subtract.
        Expr lhs = find_and_subtract(sub->a, round);
        return lhs.defined() ? Sub::make(lhs, sub->b) : Expr();
    }
    return can_prove(e == round) ? make_zero(e.type()) : Expr();
}
// Tries to rewrite a shift_left/shift_right call `c` into the corresponding
// rounding_shift_left/rounding_shift_right. This applies when the shifted
// operand contains an explicit addition of the rounding offset for that shift.
// The offset is removed from the operand.
// Returns an undefined Expr if `c` is not one of those two shifts, or if no
// removable rounding term can be found (or removing it is not provably safe).
Expr to_rounding_shift(const Call *c) {
    if (c->is_intrinsic(Call::shift_left) || c->is_intrinsic(Call::shift_right)) {
        internal_assert(c->args.size() == 2);
        Expr a = c->args[0];
        Expr b = c->args[1];

        // Helper to make the appropriate shift. It rounds in the same direction
        // as the original shift `c`.
        auto rounding_shift = [&](const Expr &a, const Expr &b) {
            if (c->is_intrinsic(Call::shift_right)) {
                return rounding_shift_right(a, b);
            } else {
                return rounding_shift_left(a, b);
            }
        };

        // The rounding offset for the shift we have. It is a scalar. If `a` is
        // a widening_add, the offset is compared against that add's operands,
        // so it is built in the narrower operand type.
        Type round_type = a.type().with_lanes(1);
        if (Call::as_intrinsic(a, {Call::widening_add})) {
            round_type = round_type.narrow();
        }
        Expr round;
        if (c->is_intrinsic(Call::shift_right)) {
            // (1 << max(b, 0)) / 2: this is 2^(b-1) for b > 0 and 0 (1 / 2) otherwise.
            round = (make_one(round_type) << max(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2;
        } else {
            // Mirror of the above for left shifts: only a non-positive shift amount
            // produces a nonzero offset. NOTE(review): this presumably relies on
            // a shift by a negative amount shifting the other way — confirm.
            round = (make_one(round_type) >> min(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2;
        }
        // Input expressions are simplified before running find_intrinsics, but b
        // has been lifted here so we need to lower_intrinsics before simplifying
        // and re-lifting. Should we move this code into the FindIntrinsics class
        // to make it easier to lift round?
        round = lower_intrinsics(round);
        round = simplify(round);
        round = find_intrinsics(round);

        // We can always handle widening adds: if either operand is the rounding
        // offset, shift the other (widened) operand with rounding instead.
        if (const Call *add = Call::as_intrinsic(a, {Call::widening_add})) {
            if (can_prove(lower_intrinsics(add->args[0] == round))) {
                return rounding_shift(cast(add->type, add->args[1]), b);
            } else if (can_prove(lower_intrinsics(add->args[1] == round))) {
                return rounding_shift(cast(add->type, add->args[0]), b);
            }
        }

        // For widen_right_add only the (narrow) right operand is checked
        // against the offset.
        if (const Call *add = Call::as_intrinsic(a, {Call::widen_right_add})) {
            if (can_prove(lower_intrinsics(add->args[1] == round))) {
                return rounding_shift(cast(add->type, add->args[0]), b);
            }
        }

        // Also need to handle the annoying case of a reinterpret wrapping a widen_right_add
        // TODO: this pattern makes me want to change the semantics of this op.
        if (const Reinterpret *reinterp = a.as<Reinterpret>()) {
            if (reinterp->type.bits() == reinterp->value.type().bits()) {
                if (const Call *add = Call::as_intrinsic(reinterp->value, {Call::widen_right_add})) {
                    if (can_prove(lower_intrinsics(add->args[1] == round))) {
                        // We expect the first operand to be a reinterpret (this is the
                        // shape FindIntrinsics builds for mixed-signedness widen_right_add).
                        const Reinterpret *reinterp_a = add->args[0].as<Reinterpret>();
                        internal_assert(reinterp_a) << "Failed: " << add->args[0] << "\n";
                        return rounding_shift(reinterp_a->value, b);
                    }
                }
            }
        }

        // If it wasn't a widening or saturating add, we might still
        // be able to safely accept the rounding.
        Expr a_less_round = find_and_subtract(a, round);
        if (a_less_round.defined()) {
            // We found and removed the rounding. However, we may have just changed
            // behavior due to overflow. This is still safe if the type is not
            // overflowing, or we can find a widening add or subtract in the tree
            // of adds/subtracts. This is a common pattern, e.g.
            // rounding_halving_add(a, b) = shift_round(widening_add(a, b) + 1, 1).
            // TODO: This could be done with bounds inference instead of this hack
            // if it supported intrinsics like widening_add and tracked bounds for
            // types other than int32.
            if (no_overflow(a.type()) || is_safe_for_add(a_less_round)) {
                return rounding_shift(simplify(a_less_round), b);
            }
        }
    }

    return Expr();
}
/** IR mutator that pattern-matches arithmetic and replaces it with Halide's
 * fixed-point intrinsics (widening_add/sub/mul, widen_right_add/sub/mul,
 * saturating_add/sub, halving_add/sub, rounding_halving_add,
 * rounding_shift_left/right, mul_shift_right, ...).
 *
 * Every visitor first checks find_intrinsics_for_type(op->type), so only
 * vector types of at least 8 bits are rewritten. Most successful rewrites
 * call mutate() on their result, so that further patterns (in particular the
 * rewrite rules in visit(const Cast *) and visit(const Call *)) can fire on
 * the newly introduced intrinsics. */
class FindIntrinsics : public IRMutator {
protected:
    using IRMutator::visit;

    // Wildcards used by the IRMatcher rewrite rules below.
    IRMatcher::Wild<0> x;
    IRMatcher::Wild<1> y;
    IRMatcher::Wild<2> z;
    IRMatcher::WildConst<0> c0;
    IRMatcher::WildConst<1> c1;

    /** Rewrites an addition as widening_add when both operands can be losslessly
     * narrowed to half width. Otherwise it uses widen_right_add when exactly
     * one operand can. */
    Expr visit(const Add *op) override {
        if (!find_intrinsics_for_type(op->type)) {
            return IRMutator::visit(op);
        }

        Expr a = mutate(op->a);
        Expr b = mutate(op->b);

        // Try widening both from the same signedness as the result, and from uint.
        for (halide_type_code_t code : {op->type.code(), halide_type_uint}) {
            Type narrow = op->type.narrow().with_code(code);
            Expr narrow_a = lossless_cast(narrow, a);
            Expr narrow_b = lossless_cast(narrow, b);

            if (narrow_a.defined() && narrow_b.defined()) {
                Expr result = widening_add(narrow_a, narrow_b);
                if (result.type() != op->type) {
                    result = Cast::make(op->type, result);
                }
                return mutate(result);
            }
        }

        if (op->type.is_int_or_uint() && op->type.bits() > 8) {
            // Look for widen_right_add intrinsics.
            // Yes, this duplicates the loop above, but we want to check the
            // op->type.code() first, and the opposite signedness as well.
            for (halide_type_code_t code : {op->type.code(), halide_type_uint, halide_type_int}) {
                Type narrow = op->type.narrow().with_code(code);
                // Pulling casts out of VectorReduce nodes breaks too much codegen, skip for now.
                Expr narrow_a = (a.node_type() == IRNodeType::VectorReduce) ? Expr() : lossless_cast(narrow, a);
                Expr narrow_b = (b.node_type() == IRNodeType::VectorReduce) ? Expr() : lossless_cast(narrow, b);

                // This case should have been handled by the above check for widening_add.
                internal_assert(!(narrow_a.defined() && narrow_b.defined()))
                    << "find_intrinsics failed to find a widening_add: " << a << " + " << b << "\n";

                if (narrow_a.defined()) {
                    Expr result;
                    if (b.type().code() != narrow_a.type().code()) {
                        // Need to do a safe reinterpret.
                        Type t = b.type().with_code(code);
                        result = widen_right_add(reinterpret(t, b), narrow_a);
                        internal_assert(result.type() != op->type);
                        result = reinterpret(op->type, result);
                    } else {
                        result = widen_right_add(b, narrow_a);
                    }
                    internal_assert(result.type() == op->type);
                    // Re-mutate, exactly as the narrow_b branch below does, so the
                    // widen_right_* combining rules in visit(const Call *) also get a
                    // chance to fire when the narrow operand was on the left.
                    return mutate(result);
                } else if (narrow_b.defined()) {
                    Expr result;
                    if (a.type().code() != narrow_b.type().code()) {
                        // Need to do a safe reinterpret.
                        Type t = a.type().with_code(code);
                        result = widen_right_add(reinterpret(t, a), narrow_b);
                        internal_assert(result.type() != op->type);
                        result = reinterpret(op->type, result);
                    } else {
                        result = widen_right_add(a, narrow_b);
                    }
                    internal_assert(result.type() == op->type);
                    return mutate(result);
                }
            }
        }

        // TODO: there can be widen_right_add + widen_right_add simplification rules.
        // i.e. widen_right_add(a, b) + widen_right_add(c, d) = (a + c) + widening_add(b, d)

        if (a.same_as(op->a) && b.same_as(op->b)) {
            return op;
        } else {
            return Add::make(a, b);
        }
    }

    /** Rewrites a subtraction as widening_sub (or as widening_add of a losslessly
     * negated operand) when both operands narrow. Otherwise, a losslessly
     * negatable b turns it into an Add, and as a last resort it becomes a
     * widen_right_sub when b alone narrows. */
    Expr visit(const Sub *op) override {
        if (!find_intrinsics_for_type(op->type)) {
            return IRMutator::visit(op);
        }

        Expr a = mutate(op->a);
        Expr b = mutate(op->b);

        // Try widening both from the same type as the result, and from uint.
        for (halide_type_code_t code : {op->type.code(), halide_type_uint}) {
            Type narrow = op->type.narrow().with_code(code);
            Expr narrow_a = lossless_cast(narrow, a);
            Expr narrow_b = lossless_cast(narrow, b);

            if (narrow_a.defined() && narrow_b.defined()) {
                Expr negative_narrow_b = lossless_negate(narrow_b);
                Expr result;
                if (negative_narrow_b.defined()) {
                    result = widening_add(narrow_a, negative_narrow_b);
                } else {
                    result = widening_sub(narrow_a, narrow_b);
                }
                if (result.type() != op->type) {
                    result = Cast::make(op->type, result);
                }
                return mutate(result);
            }
        }

        Expr negative_b = lossless_negate(b);
        if (negative_b.defined()) {
            // NOTE(review): this Add is returned without being re-mutated, so it
            // is not itself turned into a widen_right_add during this visit.
            return Add::make(a, negative_b);
        }

        // Run after the lossless_negate check, because we want that to turn into an widen_right_add if relevant.
        if (op->type.is_int_or_uint() && op->type.bits() > 8) {
            // Look for widen_right_sub intrinsics.
            // Yes, this duplicates the loop above, but we want to check the
            // op->type.code() first, and the opposite signedness as well.
            for (halide_type_code_t code : {op->type.code(), halide_type_uint, halide_type_int}) {
                Type narrow = op->type.narrow().with_code(code);
                Expr narrow_b = lossless_cast(narrow, b);

                if (narrow_b.defined()) {
                    Expr result;
                    if (a.type().code() != narrow_b.type().code()) {
                        // Need to do a safe reinterpret.
                        Type t = a.type().with_code(code);
                        result = widen_right_sub(reinterpret(t, a), narrow_b);
                        internal_assert(result.type() != op->type);
                        result = reinterpret(op->type, result);
                    } else {
                        result = widen_right_sub(a, narrow_b);
                    }
                    internal_assert(result.type() == op->type);
                    return mutate(result);
                }
            }
        }

        if (a.same_as(op->a) && b.same_as(op->b)) {
            return op;
        } else {
            return Sub::make(a, b);
        }
    }

    /** Distributes constant multipliers through Add/Sub and turns integer
     * multiplication by a constant power of two into a left shift. It then
     * looks for widening_mul (both operands narrow) or widen_right_mul (one
     * operand narrow). */
    Expr visit(const Mul *op) override {
        if (!find_intrinsics_for_type(op->type)) {
            return IRMutator::visit(op);
        }

        if (as_const_int(op->b) || as_const_uint(op->b)) {
            // Distribute constants through add/sub. Do this before we muck everything up with widening
            // intrinsics.
            // TODO: Only do this for widening?
            // TODO: Try to do this with IRMatcher::rewriter. The challenge is managing the narrowing/widening casts,
            // and doing constant folding without the simplifier undoing the work.
            if (const Add *add_a = op->a.as<Add>()) {
                return mutate(Add::make(simplify(Mul::make(add_a->a, op->b)), simplify(Mul::make(add_a->b, op->b))));
            } else if (const Sub *sub_a = op->a.as<Sub>()) {
                return mutate(Sub::make(simplify(Mul::make(sub_a->a, op->b)), simplify(Mul::make(sub_a->b, op->b))));
            }
        }

        Expr a = mutate(op->a);
        Expr b = mutate(op->b);

        // Rewrite multiplies to shifts if possible.
        if (op->type.is_int() || op->type.is_uint()) {
            int pow2 = 0;
            if (is_const_power_of_two_integer(a, &pow2)) {
                return mutate(b << cast(UInt(b.type().bits()), pow2));
            } else if (is_const_power_of_two_integer(b, &pow2)) {
                return mutate(a << cast(UInt(a.type().bits()), pow2));
            }
        }

        // We're applying this to float, which seems OK? float16 * float16 -> float32 is a widening multiply?
        // This uses strip_widening_cast to ignore the signedness of the narrow value.
        Expr narrow_a = strip_widening_cast(a);
        Expr narrow_b = strip_widening_cast(b);

        if (narrow_a.defined() && narrow_b.defined() &&
            (narrow_a.type().is_int_or_uint() == narrow_b.type().is_int_or_uint() ||
             narrow_a.type().is_float() == narrow_b.type().is_float())) {
            Expr result = widening_mul(narrow_a, narrow_b);
            if (result.type() != op->type) {
                result = Cast::make(op->type, result);
            }
            return mutate(result);
        }

        if (op->type.is_int_or_uint() && op->type.bits() > 8) {
            // Look for widen_right_mul intrinsics.
            // Yes, this duplicates the search above, but we want to check the
            // op->type.code() first, and the opposite signedness as well.
            for (halide_type_code_t code : {op->type.code(), halide_type_uint, halide_type_int}) {
                Type narrow = op->type.narrow().with_code(code);
                Expr narrow_a = lossless_cast(narrow, a);
                Expr narrow_b = lossless_cast(narrow, b);

                // This case should have been handled by the above check for widening_mul.
                internal_assert(!(narrow_a.defined() && narrow_b.defined()))
                    << "find_intrinsics failed to find a widening_mul: " << a << " * " << b << "\n";

                if (narrow_a.defined()) {
                    Expr result;
                    if (b.type().code() != narrow_a.type().code()) {
                        // Need to do a safe reinterpret.
                        Type t = b.type().with_code(code);
                        result = widen_right_mul(reinterpret(t, b), narrow_a);
                        internal_assert(result.type() != op->type);
                        result = reinterpret(op->type, result);
                    } else {
                        result = widen_right_mul(b, narrow_a);
                    }
                    internal_assert(result.type() == op->type);
                    // Re-mutate, matching the narrow_b branch below.
                    return mutate(result);
                } else if (narrow_b.defined()) {
                    Expr result;
                    if (a.type().code() != narrow_b.type().code()) {
                        // Need to do a safe reinterpret.
                        Type t = a.type().with_code(code);
                        result = widen_right_mul(reinterpret(t, a), narrow_b);
                        internal_assert(result.type() != op->type);
                        result = reinterpret(op->type, result);
                    } else {
                        result = widen_right_mul(a, narrow_b);
                    }
                    internal_assert(result.type() == op->type);
                    return mutate(result);
                }
            }
        }

        if (a.same_as(op->a) && b.same_as(op->b)) {
            return op;
        } else {
            return Mul::make(a, b);
        }
    }

    /** Turns integer division by a constant power of two into a right shift. */
    Expr visit(const Div *op) override {
        if (!find_intrinsics_for_type(op->type)) {
            return IRMutator::visit(op);
        }

        Expr a = mutate(op->a);
        Expr b = mutate(op->b);

        int shift_amount;
        if (is_const_power_of_two_integer(b, &shift_amount) && op->type.is_int_or_uint()) {
            return mutate(a >> make_const(UInt(a.type().bits()), shift_amount));
        }

        if (a.same_as(op->a) && b.same_as(op->b)) {
            return op;
        } else {
            return Div::make(a, b);
        }
    }

    // We don't handle Mod because we don't have any patterns that look for bitwise and vs.
    // mod.

    /** Shared Min/Max handler. If the first operand is a widening cast and the
     * second operand can be losslessly cast to that cast's source type, the
     * min/max is done in the narrow type and the cast is moved outside. */
    template<class MinOrMax>
    Expr visit_min_or_max(const MinOrMax *op) {
        if (!find_intrinsics_for_type(op->type)) {
            return IRMutator::visit(op);
        }

        Expr a = mutate(op->a);
        Expr b = mutate(op->b);

        if (const Cast *cast_a = a.as<Cast>()) {
            Expr cast_b = lossless_cast(cast_a->value.type(), b);
            if (cast_a->type.can_represent(cast_a->value.type()) && cast_b.defined()) {
                // This is a widening cast that can be moved outside the min.
                return mutate(Cast::make(cast_a->type, MinOrMax::make(cast_a->value, cast_b)));
            }
        }
        if (a.same_as(op->a) && b.same_as(op->b)) {
            return op;
        } else {
            return MinOrMax::make(a, b);
        }
    }

    Expr visit(const Min *op) override {
        return visit_min_or_max(op);
    }

    Expr visit(const Max *op) override {
        return visit_min_or_max(op);
    }

    /** Recognizes saturating add/sub/narrow, averaging (halving_*),
     * multiply-keep-high-bits and rounding-shift patterns that appear under a
     * cast to an integer type. It also drops redundant intermediate casts. */
    Expr visit(const Cast *op) override {
        if (!find_intrinsics_for_type(op->type)) {
            return IRMutator::visit(op);
        }

        Expr value = mutate(op->value);

        // This mutator can generate redundant casts. We can't use the simplifier because it
        // undoes some of the intrinsic lowering here, and it causes some problems due to
        // factoring (instead of distributing) constants.
        if (const Cast *cast = value.as<Cast>()) {
            if (cast->type.can_represent(cast->value.type()) || cast->type.can_represent(op->type)) {
                // The intermediate cast is redundant.
                value = cast->value;
            }
        }

        if (op->type.is_int() || op->type.is_uint()) {
            // Bounds of op->type, expressed in the (wider) type of the cast's operand.
            Expr lower = cast(value.type(), op->type.min());
            Expr upper = cast(value.type(), op->type.max());

            auto rewrite = IRMatcher::rewriter(value, op->type);

            Type op_type_wide = op->type.widen();
            Type signed_type_wide = op_type_wide.with_code(halide_type_int);
            Type unsigned_type = op->type.with_code(halide_type_uint);

            // Give concise names to various predicates we want to use in
            // rewrite rules below.
            int bits = op->type.bits();
            auto is_x_same_int = op->type.is_int() && is_int(x, bits);
            auto is_x_same_uint = op->type.is_uint() && is_uint(x, bits);
            auto is_x_same_int_or_uint = is_x_same_int || is_x_same_uint;
            auto x_y_same_sign = (is_int(x) && is_int(y)) || (is_uint(x) && is_uint(y));
            auto is_y_narrow_uint = op->type.is_uint() && is_uint(y, bits / 2);
            if (
                // Saturating patterns
                rewrite(max(min(widening_add(x, y), upper), lower),
                        saturating_add(x, y),
                        is_x_same_int_or_uint) ||

                rewrite(max(min(widening_sub(x, y), upper), lower),
                        saturating_sub(x, y),
                        is_x_same_int_or_uint) ||

                rewrite(min(cast(signed_type_wide, widening_add(x, y)), upper),
                        saturating_add(x, y),
                        is_x_same_uint) ||

                rewrite(min(widening_add(x, y), upper),
                        saturating_add(x, y),
                        op->type.is_uint() && is_x_same_uint) ||

                rewrite(max(widening_sub(x, y), lower),
                        saturating_sub(x, y),
                        op->type.is_uint() && is_x_same_uint) ||

                // Saturating narrow patterns.
                rewrite(max(min(x, upper), lower),
                        saturating_cast(op->type, x)) ||

                rewrite(min(x, upper),
                        saturating_cast(op->type, x),
                        is_uint(x)) ||

                // Averaging patterns
                //
                // We have a slight preference for rounding_halving_add over
                // using halving_add when unsigned, because x86 supports it.

                rewrite(shift_right(widening_add(x, c0), 1),
                        rounding_halving_add(x, c0 - 1),
                        c0 > 0 && is_x_same_uint) ||

                rewrite(shift_right(widening_add(x, y), 1),
                        halving_add(x, y),
                        is_x_same_int_or_uint) ||

                rewrite(shift_right(widening_add(x, c0), c1),
                        rounding_shift_right(x, cast(op->type, c1)),
                        c0 == shift_left(1, c1 - 1) && is_x_same_int_or_uint) ||

                rewrite(shift_right(widening_add(x, c0), c1),
                        shift_right(rounding_halving_add(x, cast(op->type, fold(c0 - 1))), cast(op->type, fold(c1 - 1))),
                        c0 > 0 && c1 > 0 && is_x_same_uint) ||

                rewrite(shift_right(widening_add(x, y), c0),
                        shift_right(halving_add(x, y), cast(op->type, fold(c0 - 1))),
                        c0 > 0 && is_x_same_int_or_uint) ||

                rewrite(shift_right(widening_sub(x, y), 1),
                        halving_sub(x, y),
                        is_x_same_int_or_uint) ||

                rewrite(halving_add(widening_add(x, y), 1),
                        rounding_halving_add(x, y),
                        is_x_same_int_or_uint) ||

                rewrite(halving_add(widening_add(x, 1), y),
                        rounding_halving_add(x, y),
                        is_x_same_int_or_uint) ||

                rewrite(rounding_shift_right(widening_add(x, y), 1),
                        rounding_halving_add(x, y),
                        is_x_same_int_or_uint) ||

                // Multiply-keep-high-bits patterns.
                rewrite(max(min(shift_right(widening_mul(x, y), z), upper), lower),
                        mul_shift_right(x, y, cast(unsigned_type, z)),
                        is_x_same_int_or_uint && x_y_same_sign && is_uint(z)) ||

                rewrite(max(min(rounding_shift_right(widening_mul(x, y), z), upper), lower),
                        rounding_mul_shift_right(x, y, cast(unsigned_type, z)),
                        is_x_same_int_or_uint && x_y_same_sign && is_uint(z)) ||

                rewrite(min(shift_right(widening_mul(x, y), z), upper),
                        mul_shift_right(x, y, cast(unsigned_type, z)),
                        is_x_same_uint && x_y_same_sign && is_uint(z)) ||

                rewrite(min(rounding_shift_right(widening_mul(x, y), z), upper),
                        rounding_mul_shift_right(x, y, cast(unsigned_type, z)),
                        is_x_same_uint && x_y_same_sign && is_uint(z)) ||

                // We don't need saturation for the full upper half of a multiply.
                // For signed integers, this is almost true, except for when x and y
                // are both the most negative value. For these, we only need saturation
                // at the upper bound.

                rewrite(min(shift_right(widening_mul(x, y), c0), upper),
                        mul_shift_right(x, y, cast(unsigned_type, c0)),
                        is_x_same_int && x_y_same_sign && c0 >= bits - 1) ||

                rewrite(min(rounding_shift_right(widening_mul(x, y), c0), upper),
                        rounding_mul_shift_right(x, y, cast(unsigned_type, c0)),
                        is_x_same_int && x_y_same_sign && c0 >= bits - 1) ||

                rewrite(shift_right(widening_mul(x, y), c0),
                        mul_shift_right(x, y, cast(unsigned_type, c0)),
                        is_x_same_int_or_uint && x_y_same_sign && c0 >= bits) ||

                rewrite(rounding_shift_right(widening_mul(x, y), c0),
                        rounding_mul_shift_right(x, y, cast(unsigned_type, c0)),
                        is_x_same_int_or_uint && x_y_same_sign && c0 >= bits) ||

                // We can also match on smaller shifts if one of the args is
                // narrow. We don't do this for signed (yet), because the
                // saturation issue is tricky.
                rewrite(shift_right(widening_mul(x, cast(op->type, y)), c0),
                        mul_shift_right(x, cast(op->type, y), cast(unsigned_type, c0)),
                        is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) ||

                rewrite(rounding_shift_right(widening_mul(x, cast(op->type, y)), c0),
                        rounding_mul_shift_right(x, cast(op->type, y), cast(unsigned_type, c0)),
                        is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) ||

                rewrite(shift_right(widening_mul(cast(op->type, y), x), c0),
                        mul_shift_right(cast(op->type, y), x, cast(unsigned_type, c0)),
                        is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) ||

                rewrite(rounding_shift_right(widening_mul(cast(op->type, y), x), c0),
                        rounding_mul_shift_right(cast(op->type, y), x, cast(unsigned_type, c0)),
                        is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) ||

                // Halving subtract patterns
                rewrite(shift_right(cast(op_type_wide, widening_sub(x, y)), 1),
                        halving_sub(x, y),
                        is_x_same_int_or_uint) ||

                false) {
                internal_assert(rewrite.result.type() == op->type)
                    << "Rewrite changed type: " << Expr(op) << " -> " << rewrite.result << "\n";
                return mutate(rewrite.result);
            }

            // When the argument is a widened rounding shift, we might not need the widening.
            // When there is saturation, we can only avoid the widening if we know the shift is
            // a right shift. Without saturation, we can ignore the widening.
            auto is_x_wide_int = op->type.is_int() && is_int(x, bits * 2);
            auto is_x_wide_uint = op->type.is_uint() && is_uint(x, bits * 2);
            auto is_x_wide_int_or_uint = is_x_wide_int || is_x_wide_uint;
            // We can't do everything we want here with rewrite rules alone. So, we rewrite them
            // to rounding_shifts with the widening still in place, and narrow it after the rewrite
            // succeeds.
            // clang-format off
            if (rewrite(max(min(rounding_shift_right(x, y), upper), lower), rounding_shift_right(x, y), is_x_wide_int_or_uint) ||
                rewrite(rounding_shift_right(x, y), rounding_shift_right(x, y), is_x_wide_int_or_uint) ||
                rewrite(rounding_shift_left(x, y), rounding_shift_left(x, y), is_x_wide_int_or_uint) ||
                false) {
                const Call *shift = Call::as_intrinsic(rewrite.result, {Call::rounding_shift_right, Call::rounding_shift_left});
                internal_assert(shift);
                bool is_saturated = op->value.as<Max>() || op->value.as<Min>();
                Expr a = lossless_cast(op->type, shift->args[0]);
                Expr b = lossless_cast(op->type.with_code(shift->args[1].type().code()), shift->args[1]);
                if (a.defined() && b.defined()) {
                    if (!is_saturated ||
                        (shift->is_intrinsic(Call::rounding_shift_right) && can_prove(b >= 0)) ||
                        (shift->is_intrinsic(Call::rounding_shift_left) && can_prove(b <= 0))) {
                        return mutate(Call::make(op->type, shift->name, {a, b}, Call::PureIntrinsic));
                    }
                }
            }
            // clang-format on
        }

        if (value.same_as(op->value)) {
            return op;
        } else if (op->type != value.type()) {
            return Cast::make(op->type, value);
        } else {
            return value;
        }
    }

    /** Simplifies combinations of widening intrinsics (widen_right_add/sub
     * chains, saturating casts of widening ops, redundant widening inside a
     * saturating narrow). It also recognizes halving/rounding patterns for
     * non-overflowing types, hoists widening casts out of widening_mul/add/sub,
     * and turns shifts into widening or rounding shifts. */
    Expr visit(const Call *op) override {
        if (!find_intrinsics_for_type(op->type)) {
            return IRMutator::visit(op);
        }

        Expr mutated = IRMutator::visit(op);
        op = mutated.as<Call>();
        if (!op) {
            return mutated;
        }

        auto rewrite = IRMatcher::rewriter(op, op->type);
        if (rewrite(intrin(Call::abs, widening_sub(x, y)), cast(op->type, intrin(Call::absd, x, y))) ||
            false) {
            return rewrite.result;
        }

        const int bits = op->type.bits();
        const auto is_x_same_int = op->type.is_int() && is_int(x, bits);
        const auto is_x_same_uint = op->type.is_uint() && is_uint(x, bits);
        const auto is_x_same_int_or_uint = is_x_same_int || is_x_same_uint;
        auto x_y_same_sign = (is_int(x) == is_int(y)) || (is_uint(x) && is_uint(y));
        Type unsigned_type = op->type.with_code(halide_type_uint);
        const auto is_x_wider_int_or_uint = (op->type.is_int() && is_int(x, 2 * bits)) || (op->type.is_uint() && is_uint(x, 2 * bits));
        Type opposite_type = op->type.is_int() ? op->type.with_code(halide_type_uint) : op->type.with_code(halide_type_int);
        const auto is_x_wider_opposite_int = (op->type.is_int() && is_uint(x, 2 * bits)) || (op->type.is_uint() && is_int(x, 2 * bits));

        if (
            // Simplify extending patterns.
            // (x + widen(y)) + widen(z) = x + widening_add(y, z).
            rewrite(widen_right_add(widen_right_add(x, y), z),
                    x + widening_add(y, z),
                    // We only care about integers, this should be trivially true.
                    is_x_same_int_or_uint) ||

            // (x - widen(y)) - widen(z) = x - widening_add(y, z).
            rewrite(widen_right_sub(widen_right_sub(x, y), z),
                    x - widening_add(y, z),
                    // We only care about integers, this should be trivially true.
                    is_x_same_int_or_uint) ||

            // (x + widen(y)) - widen(z) = x + cast(t, widening_sub(y, z))
            // cast (reinterpret) is needed only for uints.
            rewrite(widen_right_sub(widen_right_add(x, y), z),
                    x + widening_sub(y, z),
                    is_x_same_int) ||
            rewrite(widen_right_sub(widen_right_add(x, y), z),
                    x + cast(op->type, widening_sub(y, z)),
                    is_x_same_uint) ||

            // (x - widen(y)) + widen(z) = x + cast(t, widening_sub(z, y))
            // cast (reinterpret) is needed only for uints.
            rewrite(widen_right_add(widen_right_sub(x, y), z),
                    x + widening_sub(z, y),
                    is_x_same_int) ||
            rewrite(widen_right_add(widen_right_sub(x, y), z),
                    x + cast(op->type, widening_sub(z, y)),
                    is_x_same_uint) ||

            // Saturating patterns.
            rewrite(saturating_cast(op->type, widening_add(x, y)),
                    saturating_add(x, y),
                    is_x_same_int_or_uint) ||
            rewrite(saturating_cast(op->type, widening_sub(x, y)),
                    saturating_sub(x, y),
                    is_x_same_int_or_uint) ||
            rewrite(saturating_cast(op->type, shift_right(widening_mul(x, y), z)),
                    mul_shift_right(x, y, cast(unsigned_type, z)),
                    is_x_same_int_or_uint && x_y_same_sign && is_uint(z)) ||
            rewrite(saturating_cast(op->type, rounding_shift_right(widening_mul(x, y), z)),
                    rounding_mul_shift_right(x, y, cast(unsigned_type, z)),
                    is_x_same_int_or_uint && x_y_same_sign && is_uint(z)) ||
            // We can remove unnecessary widening if we are then performing a saturating narrow.
            // This is similar to the logic inside `visit_min_or_max`.
            (((bits <= 32) &&
              // Examples:
              // i8_sat(int16(i8)) -> i8
              // u8_sat(uint16(u8)) -> u8
              rewrite(saturating_cast(op->type, cast(op->type.widen(), x)),
                      x,
                      is_x_same_int_or_uint)) ||
             ((bits <= 16) &&
              // Examples:
              // i8_sat(int32(i16)) -> i8_sat(i16)
              // u8_sat(uint32(u16)) -> u8_sat(u16)
              (rewrite(saturating_cast(op->type, cast(op->type.widen().widen(), x)),
                       saturating_cast(op->type, x),
                       is_x_wider_int_or_uint) ||
               // Examples:
               // i8_sat(uint32(u16)) -> i8_sat(u16)
               // u8_sat(int32(i16)) -> u8_sat(i16)
               rewrite(saturating_cast(op->type, cast(opposite_type.widen().widen(), x)),
                       saturating_cast(op->type, x),
                       is_x_wider_opposite_int) ||
               false))) ||
            false) {
            return mutate(rewrite.result);
        }

        if (no_overflow(op->type)) {
            // clang-format off
            if (rewrite(halving_add(x + y, 1), rounding_halving_add(x, y)) ||
                rewrite(halving_add(x, y + 1), rounding_halving_add(x, y)) ||
                rewrite(halving_add(x + 1, y), rounding_halving_add(x, y)) ||
                rewrite(halving_add(x, 1), rounding_shift_right(x, 1)) ||
                rewrite(shift_right(x + y, 1), halving_add(x, y)) ||
                rewrite(shift_right(x - y, 1), halving_sub(x, y)) ||
                rewrite(rounding_shift_right(x + y, 1), rounding_halving_add(x, y)) ||
                false) {
                return mutate(rewrite.result);
            }
            // clang-format on
        }

        // Move widening casts inside widening arithmetic outside the arithmetic,
        // e.g. widening_mul(widen(u8), widen(i8)) -> widen(widening_mul(u8, i8)).
        if (op->is_intrinsic(Call::widening_mul)) {
            internal_assert(op->args.size() == 2);
            Expr narrow_a = strip_widening_cast(op->args[0]);
            Expr narrow_b = strip_widening_cast(op->args[1]);
            if (narrow_a.defined() && narrow_b.defined()) {
                return mutate(Cast::make(op->type, widening_mul(narrow_a, narrow_b)));
            }
        } else if (op->is_intrinsic(Call::widening_add) && (op->type.bits() >= 16)) {
            internal_assert(op->args.size() == 2);
            for (halide_type_code_t t : {op->type.code(), halide_type_uint}) {
                Type narrow_t = op->type.narrow().narrow().with_code(t);
                Expr narrow_a = lossless_cast(narrow_t, op->args[0]);
                Expr narrow_b = lossless_cast(narrow_t, op->args[1]);
                if (narrow_a.defined() && narrow_b.defined()) {
                    return mutate(Cast::make(op->type, widening_add(narrow_a, narrow_b)));
                }
            }
        } else if (op->is_intrinsic(Call::widening_sub) && (op->type.bits() >= 16)) {
            internal_assert(op->args.size() == 2);
            for (halide_type_code_t t : {op->type.code(), halide_type_uint}) {
                Type narrow_t = op->type.narrow().narrow().with_code(t);
                Expr narrow_a = lossless_cast(narrow_t, op->args[0]);
                Expr narrow_b = lossless_cast(narrow_t, op->args[1]);
                if (narrow_a.defined() && narrow_b.defined()) {
                    return mutate(Cast::make(op->type, widening_sub(narrow_a, narrow_b)));
                }
            }
        }
        // TODO: do we want versions of widen_right_add here?

        if (op->is_intrinsic(Call::shift_right) || op->is_intrinsic(Call::shift_left)) {
            // Try to turn this into a widening shift.
            internal_assert(op->args.size() == 2);
            Expr a_narrow = lossless_narrow(op->args[0]);
            Expr b_narrow = lossless_narrow(op->args[1]);
            if (a_narrow.defined() && b_narrow.defined()) {
                Expr result = op->is_intrinsic(Call::shift_left) ? widening_shift_left(a_narrow, b_narrow) : widening_shift_right(a_narrow, b_narrow);
                if (result.type() != op->type) {
                    result = Cast::make(op->type, result);
                }
                return mutate(result);
            }

            // Try to turn this into a rounding shift.
            Expr rounding_shift = to_rounding_shift(op);
            if (rounding_shift.defined()) {
                return mutate(rounding_shift);
            }
        }

        if (op->is_intrinsic(Call::rounding_shift_left) || op->is_intrinsic(Call::rounding_shift_right)) {
            // Try to do the rounding shift in the narrower type. This is only done
            // when the shift direction is provable from the sign of the shift amount.
            internal_assert(op->args.size() == 2);
            Expr a_narrow = lossless_narrow(op->args[0]);
            Expr b_narrow = lossless_narrow(op->args[1]);
            if (a_narrow.defined() && b_narrow.defined()) {
                Expr result;
                if (op->is_intrinsic(Call::rounding_shift_right) && can_prove(b_narrow > 0)) {
                    result = rounding_shift_right(a_narrow, b_narrow);
                } else if (op->is_intrinsic(Call::rounding_shift_left) && can_prove(b_narrow < 0)) {
                    result = rounding_shift_left(a_narrow, b_narrow);
                } else {
                    return op;
                }
                if (result.type() != op->type) {
                    result = Cast::make(op->type, result);
                }
                return mutate(result);
            }
        }
        return op;
    }
};
// Substitute in let values that have an output vector
// type wider than all the types of other variables
// referenced. This can't cause combinatorial explosion,
// because each let in a chain has a wider value than the
// ones it refers to.
//
// find_intrinsics() runs this pass immediately before FindIntrinsics and then
// applies CSE, presumably so that widening arithmetic split across lets is
// visible to the pattern matching and can be hoisted back out afterwards.
class SubstituteInWideningLets : public IRMutator {
    using IRMutator::visit;

    // Returns true iff every leaf input of e has fewer bits than e's own type.
    // Leaves are Variables, Loads, and any Call that is not a pure intrinsic;
    // pure intrinsics (e.g. an existing widening_add) and all other node kinds
    // (arithmetic, constants, ...) are recursed through by the default visitor.
    bool widens(const Expr &e) {
        class AllInputsNarrowerThan : public IRVisitor {
            // Bit width of the expression being tested.
            int bits;

            using IRVisitor::visit;

            void visit(const Variable *op) override {
                result &= op->type.bits() < bits;
            }

            void visit(const Load *op) override {
                // Loads are leaves: their children are not visited.
                result &= op->type.bits() < bits;
            }

            void visit(const Call *op) override {
                if (op->is_pure() && op->is_intrinsic()) {
                    // Look through pure intrinsics to their arguments.
                    IRVisitor::visit(op);
                } else {
                    // Any other call is an opaque leaf.
                    result &= op->type.bits() < bits;
                }
            }

        public:
            AllInputsNarrowerThan(Type t)
                : bits(t.bits()) {
            }
            // Becomes (and stays) false once any leaf is at least `bits` wide.
            bool result = true;
        } widens(e.type());
        e.accept(&widens);
        return widens.result;
    }

    // Names of lets chosen for substitution, mapped to the (already mutated)
    // values that replace references to them.
    Scope<Expr> replacements;

    // Replace references to substituted lets with their values.
    Expr visit(const Variable *op) override {
        if (replacements.contains(op->name)) {
            return replacements.get(op->name);
        } else {
            return op;
        }
    }

    // Shared implementation for Let and LetStmt. Walks a chain of directly
    // nested lets of the same node type iteratively (do/while over
    // body.as<T>()) instead of recursing through each body, decides per let
    // whether to substitute its value, mutates the innermost body, and finally
    // rebuilds the lets that were not substituted away.
    template<typename T>
    auto visit_let(const T *op) -> decltype(op->body) {
        // One pending let. `bind` is bound only when the let's value was chosen
        // for substitution; frames whose `bind` is unbound are re-emitted as
        // lets during the rebuild below. Popping a frame destroys its
        // ScopedBinding (presumably removing the name from `replacements`).
        struct Frame {
            std::string name;
            Expr new_value;
            ScopedBinding<Expr> bind;
            Frame(const std::string &name, const Expr &new_value, ScopedBinding<Expr> &&bind)
                : name(name), new_value(new_value), bind(std::move(bind)) {
            }
        };
        // Pending lets in the order they will appear, outermost first: the
        // rebuild pops from the back, so the last frame wraps the body first.
        std::vector<Frame> frames;
        decltype(op->body) body;
        do {
            body = op->body;
            Expr value = op->value;
            bool should_replace = find_intrinsics_for_type(value.type()) && widens(value);

            // We can only substitute in pure stuff. Isolate all
            // impure subexpressions and leave them behind here as
            // lets.
            class LeaveBehindSubexpressions : public IRMutator {
                using IRMutator::visit;

                Expr visit(const Call *op) override {
                    if (!op->is_pure() || !op->is_intrinsic()) {
                        // Only enter pure intrinsics (e.g. existing uses of widening_add)
                        std::string name = unique_name('t');
                        // Left behind as a let (unbound binding); the call is
                        // replaced by a reference to that let.
                        frames.emplace_back(name, op, ScopedBinding<Expr>{});
                        return Variable::make(op->type, name);
                    } else {
                        return IRMutator::visit(op);
                    }
                }

                Expr visit(const Load *op) override {
                    // Never enter loads. They can be impure and none
                    // of our patterns match them.
                    std::string name = unique_name('t');
                    frames.emplace_back(name, op, ScopedBinding<Expr>{});
                    return Variable::make(op->type, name);
                }

                // The enclosing visit_let's frame list; extracted lets are
                // appended before the let they came from, so they end up
                // enclosing it (and any substituted copies of its value).
                std::vector<Frame> &frames;

            public:
                LeaveBehindSubexpressions(std::vector<Frame> &frames)
                    : frames(frames) {
                }
            } extractor(frames);

            if (should_replace) {
                size_t start_of_new_lets = frames.size();
                value = extractor.mutate(value);
                // Mutate any subexpressions the extractor decided to
                // leave behind, in case they in turn depend on lets
                // we've decided to substitute in.
                for (size_t i = start_of_new_lets; i < frames.size(); i++) {
                    frames[i].new_value = mutate(frames[i].new_value);
                }

                // Check it wasn't lifted entirely
                should_replace = !value.as<Variable>();
            }

            // TODO: If it's an int32/64 vector, it may be
            // implicitly widening because overflow is UB. Hard to
            // see how to handle this without worrying about
            // combinatorial explosion of substitutions.
            // Apply substitutions chosen for earlier lets in the chain.
            value = mutate(value);
            // Binds op->name in `replacements` only if should_replace is true.
            ScopedBinding<Expr> bind(should_replace, replacements, op->name, value);
            frames.emplace_back(op->name, value, std::move(bind));
            op = body.template as<T>();
        } while (op);

        // Mutate the innermost body with every substitution in scope.
        body = mutate(body);

        // Rebuild the kept lets from the inside out, releasing each binding as
        // its frame is popped.
        while (!frames.empty()) {
            if (!frames.back().bind.bound()) {
                body = T::make(frames.back().name, frames.back().new_value, body);
            }
            frames.pop_back();
        }

        return body;
    }

    Expr visit(const Let *op) override {
        return visit_let(op);
    }

    Stmt visit(const LetStmt *op) override {
        return visit_let(op);
    }
};
} // namespace
Stmt find_intrinsics(const Stmt &s) {
    // Substitute widening lets first so their values are visible to the
    // intrinsic rewriter, then rewrite.
    Stmt substituted = SubstituteInWideningLets().mutate(s);
    Stmt rewritten = FindIntrinsics().mutate(substituted);
    // Substitution may have duplicated widening ops; CSE lets them be
    // hoisted back out.
    return common_subexpression_elimination(rewritten);
}
Expr find_intrinsics(const Expr &e) {
    // Same three stages as the Stmt overload: substitute widening lets,
    // rewrite into intrinsics, then CSE to re-hoist duplicated widening ops.
    Expr substituted = SubstituteInWideningLets().mutate(e);
    Expr rewritten = FindIntrinsics().mutate(substituted);
    return common_subexpression_elimination(rewritten);
}
Expr lower_widen_right_add(const Expr &a, const Expr &b) {
    // Only the right operand is narrow; widen it, then add normally.
    Expr wide_b = widen(b);
    return a + wide_b;
}
Expr lower_widen_right_mul(const Expr &a, const Expr &b) {
    // Only the right operand is narrow; widen it, then multiply normally.
    Expr wide_b = widen(b);
    return a * wide_b;
}
Expr lower_widen_right_sub(const Expr &a, const Expr &b) {
    // Only the right operand is narrow; widen it, then subtract normally.
    Expr wide_b = widen(b);
    return a - wide_b;
}
Expr lower_widening_add(const Expr &a, const Expr &b) {
    // Widen both operands to double width before adding.
    Expr wide_a = widen(a);
    Expr wide_b = widen(b);
    return wide_a + wide_b;
}
Expr lower_widening_mul(const Expr &a, const Expr &b) {
    // Widen both operands to double width before multiplying.
    Expr wide_a = widen(a);
    Expr wide_b = widen(b);
    return wide_a * wide_b;
}
Expr lower_widening_sub(const Expr &a, const Expr &b) {
    // The difference of two unsigned values can be negative, so an unsigned
    // widening subtract yields a signed result of twice the width. Other
    // types just widen.
    Type wide = a.type().widen();
    Type result_type = wide.is_uint() ? wide.with_code(halide_type_int) : wide;
    Expr wide_a = Cast::make(result_type, a);
    Expr wide_b = Cast::make(result_type, b);
    return wide_a - wide_b;
}
Expr lower_widening_shift_left(const Expr &a, const Expr &b) {
    // Widen first so bits shifted out of the narrow type are kept.
    Expr wide_a = widen(a);
    return wide_a << b;
}
Expr lower_widening_shift_right(const Expr &a, const Expr &b) {
    // Widen a, then shift in the wider type.
    Expr wide_a = widen(a);
    return wide_a >> b;
}
Expr lower_rounding_shift_left(const Expr &a, const Expr &b) {
    // A negative b shifts right and drops low bits. Round by adding the most
    // significant dropped bit, which is the low bit of a << (b + 1). The
    // select masks the correction to zero when b >= 0 (nothing dropped).
    Expr is_right_shift = select(b < 0, make_one(a.type()), make_zero(a.type()));
    Expr shifted = a << b;
    Expr rounding_bit = is_right_shift & (a << (b + 1));
    return simplify(shifted + rounding_bit);
}
Expr lower_rounding_shift_right(const Expr &a, const Expr &b) {
    if (is_positive_const(b)) {
        // rounding_shift_right(a, b) == (a + 2^(b-1)) >> b. Since
        // rounding_halving_add(a, r) == (a + r + 1) >> 1 without overflow,
        // choosing r = 2^(b-1) - 1 gives (a + 2^(b-1)) >> 1, and a further
        // shift by b - 1 completes the rounding shift. We prefer the rounding
        // average instruction (we could use either), because the
        // non-rounding one is missing on x86.
        Expr shift = simplify(b - 1);
        // Build the rounding constant directly in a's type. Writing it as
        // (1 << shift) with a plain int literal makes it an Int(32)
        // expression, which overflows once shift reaches 31. That happens
        // for 64-bit operands, e.g. the widened product in the
        // rounding_mul_shift_right fallback for uint32 with q = 32.
        Expr round = simplify((make_one(a.type()) << shift) - 1);
        return rounding_halving_add(a, round) >> shift;
    }
    // Shift right, then add one to the result if bits were dropped
    // (because b > 0) and the most significant dropped bit was a one.
    // That bit is the low bit of a >> (b - 1); the select masks the
    // correction to zero when b <= 0.
    Expr b_positive = select(b > 0, make_one(a.type()), make_zero(a.type()));
    return simplify((a >> b) + (b_positive & (a >> (b - 1))));
}
Expr lower_saturating_add(const Expr &a, const Expr &b) {
    internal_assert(a.type() == b.type());
    // Avoid widening arithmetic (the wider type may not be supported) by
    // pre-clamping a so that a + b cannot leave the representable range:
    // for b < 0 we need a >= min - b, and for b > 0 we need a <= max - b.
    Expr lowest_a = a.type().min() - min(b, 0);
    Expr highest_a = a.type().max() - max(b, 0);
    return simplify(clamp(a, lowest_a, highest_a)) + b;
}
Expr lower_saturating_sub(const Expr &a, const Expr &b) {
    internal_assert(a.type() == b.type());
    // Lower saturating sub without using widening arithmetic, which may require
    // types that aren't supported. Pre-clamp a so that a - b stays in range:
    // for b > 0 we need a >= min + b, and for b < 0 we need a <= max + b.
    // Neither bound itself can overflow, because max(b, 0) >= 0 is added to the
    // minimum and min(b, 0) <= 0 is added to the maximum.
    return simplify(clamp(a, a.type().min() + max(b, 0), a.type().max() + min(b, 0))) - b;
}
Expr lower_saturating_cast(const Type &t, const Expr &a) {
    Type src = a.type();

    if (t.is_float() && src.is_float()) {
        // For float to float, guarantee infinities are always pinned to range.
        // When narrowing, clamp while still in the wider source type.
        return t.bits() < src.bits() ?
                   cast(t, clamp(a, t.min(), t.max())) :
                   clamp(cast(t, a), t.min(), t.max());
    }

    if (src == t) {
        // Nothing to saturate.
        return a;
    }

    if (src.is_float() && !t.is_float() && t.bits() >= src.bits()) {
        // Limits for Int(2^n) or UInt(2^n) are not exactly representable in
        // Float(2^n), so clamp the low end in float and handle the high end
        // with a select.
        Expr floored = max(a, t.min());  // min values turn out to be always representable
        // This comparison depends on t.max() rounding upward, which should
        // always be the case as it is one less than a representable value,
        // thus the one larger is always the closest.
        Expr too_big = floored >= cast(floored.type(), t.max());
        return select(too_big, t.max(), cast(t, floored));
    }

    // Otherwise clamp against whichever of t's limits lossless_cast can express
    // in the source type (no clamp on a side where it can't), then cast.
    // Unsigned sources are never below t.min(), so no lower bound is used.
    Expr lower = src.is_uint() ? Expr() : lossless_cast(src, t.min());
    Expr upper = lossless_cast(src, t.max());
    Expr clamped = a;
    if (lower.defined() && upper.defined()) {
        clamped = clamp(a, lower, upper);
    } else if (lower.defined()) {
        clamped = max(a, lower);
    } else if (upper.defined()) {
        clamped = min(a, upper);
    }
    return cast(t, std::move(clamped));
}
Expr lower_halving_add(const Expr &a, const Expr &b) {
    internal_assert(a.type() == b.type());
    // (a + b) >> 1 without overflow: bits shared by a and b contribute fully,
    // bits where they differ contribute half.
    // Borrowed from http://aggregate.org/MAGIC/#Average%20of%20Integers
    Expr shared_bits = a & b;
    Expr half_of_differing = (a ^ b) >> 1;
    return shared_bits + half_of_differing;
}
Expr lower_halving_sub(const Expr &a, const Expr &b) {
    internal_assert(a.type() == b.type());
    Expr avg_with_complement = rounding_halving_add(a, ~b);

    if (!a.type().is_uint()) {
        // For 2s-complement signed integers, -y == ~y + 1, so
        //   (x - y) / 2 = (x + (~y + 1)) / 2 = rounding_halving_add(x, ~y)
        return avg_with_complement;
    }

    // Unsigned, illustrated in 8-bit:
    //   (x - y) / 2
    //   = (x + 256 - y) / 2 - 128
    //   = (x + (255 - y) + 1) / 2 - 128
    //   = (x + ~y + 1) / 2 - 128
    //   = rounding_halving_add(x, ~y) - 128
    //   = rounding_halving_add(x, ~y) + 128 (due to 2s-complement wrap-around)
    uint64_t half_range = uint64_t{1} << (a.type().bits() - 1);
    return avg_with_complement + make_const(avg_with_complement.type(), half_range);
}
Expr lower_rounding_halving_add(const Expr &a, const Expr &b) {
    internal_assert(a.type() == b.type());
    // halving_add truncates; add back one exactly when a + b is odd, i.e. when
    // the low bits of a and b differ.
    Expr odd_sum = (a ^ b) & 1;
    return halving_add(a, b) + odd_sum;
}
Expr lower_sorted_avg(const Expr &a, const Expr &b) {
    // Average of a and b assuming a <= b: since the gap b - a is non-negative,
    // adding half of it to a stays within [a, b] and needs no widening.
    Type t = a.type();
    if (t.is_int()) {
        // For signed types the gap can exceed t.max() (e.g. int8 a = -100,
        // b = 100 gives 200), so b - a would overflow and the arithmetic shift
        // would then yield a negative half-gap. The gap always fits in the
        // unsigned type of the same width, so compute and halve it there. Half
        // of any such gap is at most t.max(), so it reinterprets back
        // losslessly, and a + half_gap <= b cannot overflow.
        Type ut = t.with_code(halide_type_uint);
        Expr gap = reinterpret(ut, b) - reinterpret(ut, a);
        return a + reinterpret(t, gap >> 1);
    }
    return a + ((b - a) >> 1);
}
Expr lower_mul_shift_right(const Expr &a, const Expr &b, const Expr &q) {
    internal_assert(a.type() == b.type());
    const int bits = a.type().bits();
    // Shift that makes this an exact upper-half ("full precision") multiply:
    // one bit less for signed types.
    const int full_q = a.type().is_int() ? bits - 1 : bits;

    if (can_prove(q < full_q)) {
        // Try to reach full precision by multiplying one of the operands and
        // the denominator by a constant. Only done when q is short of full
        // precision, so the new mul_shift_right can't re-enter this rewrite,
        // which avoids infinite loops despite "lowering" to the same operation.
        Expr missing_q = full_q - q;
        internal_assert(missing_q.type().bits() == b.type().bits());
        // Returns x << missing_q if that folds to a constant and loses no bits,
        // otherwise an undefined Expr.
        auto scale_exactly = [&](const Expr &x) {
            Expr scaled = simplify(x << missing_q);
            bool exact = is_const(scaled) && can_prove(scaled >> missing_q == x);
            return exact ? scaled : Expr();
        };
        Expr scaled_b = scale_exactly(b);
        if (scaled_b.defined()) {
            return mul_shift_right(a, scaled_b, full_q);
        }
        Expr scaled_a = scale_exactly(a);
        if (scaled_a.defined()) {
            return mul_shift_right(scaled_a, b, full_q);
        }
    }

    if (can_prove(q > bits)) {
        // If q is bigger than the narrow type, write it as an exact upper
        // half multiply, followed by an extra shift.
        Expr upper_half = mul_shift_right(a, b, bits);
        return upper_half >> simplify(q - bits);
    }

    // If all else fails, just widen, shift, and narrow. The result may not fit
    // the narrow type unless the shift provably discards at least `bits` bits.
    Expr shifted = widening_mul(a, b) >> q;
    return can_prove(q >= bits) ? narrow(shifted) : saturating_narrow(shifted);
}
Expr lower_rounding_mul_shift_right(const Expr &a, const Expr &b, const Expr &q) {
    internal_assert(a.type() == b.type());
    const int bits = a.type().bits();
    // Shift that makes this an exact upper-half ("full precision") multiply:
    // one bit less for signed types.
    const int full_q = a.type().is_int() ? bits - 1 : bits;

    // Try to rewrite this to a "full precision" multiply by multiplying
    // one of the operands and the denominator by a constant. We only do this
    // if it isn't already full precision. This avoids infinite loops despite
    // "lowering" this to another mul_shift_right operation.
    if (can_prove(q < full_q)) {
        Expr missing_q = full_q - q;
        internal_assert(missing_q.type().bits() == b.type().bits());
        // Returns x << missing_q if that folds to a constant and loses no bits,
        // otherwise an undefined Expr.
        auto scale_exactly = [&](const Expr &x) {
            Expr scaled = simplify(x << missing_q);
            bool exact = is_const(scaled) && can_prove(scaled >> missing_q == x);
            return exact ? scaled : Expr();
        };
        Expr scaled_b = scale_exactly(b);
        if (scaled_b.defined()) {
            return rounding_mul_shift_right(a, scaled_b, full_q);
        }
        Expr scaled_a = scale_exactly(a);
        if (scaled_a.defined()) {
            return rounding_mul_shift_right(scaled_a, b, full_q);
        }
    }

    // If all else fails, just widen, shift, and narrow. The result may not fit
    // the narrow type unless the shift provably discards at least `bits` bits.
    Expr shifted = rounding_shift_right(widening_mul(a, b), q);
    return can_prove(q >= bits) ? narrow(shifted) : saturating_narrow(shifted);
}
Expr lower_intrinsic(const Call *op) {
    // Two-argument intrinsics whose lowering depends only on those arguments.
    struct BinaryLowering {
        Call::IntrinsicOp intrinsic;
        Expr (*lower)(const Expr &, const Expr &);
    };
    static const BinaryLowering binary_lowerings[] = {
        {Call::widen_right_add, lower_widen_right_add},
        {Call::widen_right_mul, lower_widen_right_mul},
        {Call::widen_right_sub, lower_widen_right_sub},
        {Call::widening_add, lower_widening_add},
        {Call::widening_mul, lower_widening_mul},
        {Call::widening_sub, lower_widening_sub},
        {Call::saturating_add, lower_saturating_add},
        {Call::saturating_sub, lower_saturating_sub},
        {Call::widening_shift_left, lower_widening_shift_left},
        {Call::widening_shift_right, lower_widening_shift_right},
        {Call::rounding_shift_right, lower_rounding_shift_right},
        {Call::rounding_shift_left, lower_rounding_shift_left},
        {Call::halving_add, lower_halving_add},
        {Call::halving_sub, lower_halving_sub},
        {Call::rounding_halving_add, lower_rounding_halving_add},
        {Call::sorted_avg, lower_sorted_avg},
    };
    for (const BinaryLowering &entry : binary_lowerings) {
        if (op->is_intrinsic(entry.intrinsic)) {
            internal_assert(op->args.size() == 2);
            return entry.lower(op->args[0], op->args[1]);
        }
    }

    // saturating_cast takes one argument and also needs the call's type.
    if (op->is_intrinsic(Call::saturating_cast)) {
        internal_assert(op->args.size() == 1);
        return lower_saturating_cast(op->type, op->args[0]);
    }

    // The multiply-shift intrinsics take (a, b, q).
    if (op->is_intrinsic(Call::rounding_mul_shift_right)) {
        internal_assert(op->args.size() == 3);
        return lower_rounding_mul_shift_right(op->args[0], op->args[1], op->args[2]);
    }
    if (op->is_intrinsic(Call::mul_shift_right)) {
        internal_assert(op->args.size() == 3);
        return lower_mul_shift_right(op->args[0], op->args[1], op->args[2]);
    }

    // Not an intrinsic handled here.
    return Expr();
}
namespace {
class LowerIntrinsics : public IRMutator {
    using IRMutator::visit;

    Expr visit(const Call *op) override {
        // Expand intrinsics that lower_intrinsic understands, and keep mutating
        // the expansion: a lowering may emit further intrinsics (e.g.
        // lower_halving_sub produces rounding_halving_add). Other calls are
        // visited normally.
        Expr lowered = lower_intrinsic(op);
        if (!lowered.defined()) {
            return IRMutator::visit(op);
        }
        return mutate(lowered);
    }
};
} // namespace
Expr lower_intrinsics(const Expr &e) {
    // Replace every lowerable intrinsic in e with plain arithmetic.
    LowerIntrinsics lowerer;
    return lowerer.mutate(e);
}
Stmt lower_intrinsics(const Stmt &s) {
    // Replace every lowerable intrinsic in s with plain arithmetic.
    LowerIntrinsics lowerer;
    return lowerer.mutate(s);
}
} // namespace Internal
} // namespace Halide