#include "FindIntrinsics.h" #include "CSE.h" #include "CodeGen_Internal.h" #include "ConciseCasts.h" #include "IRMatch.h" #include "IRMutator.h" #include "Simplify.h" namespace Halide { namespace Internal { using namespace Halide::ConciseCasts; namespace { bool find_intrinsics_for_type(const Type &t) { // Currently, we only try to find and replace intrinsics for vector types that aren't bools. return t.is_vector() && t.bits() >= 8; } Expr widen(Expr a) { Type result_type = a.type().widen(); return Cast::make(result_type, std::move(a)); } Expr narrow(Expr a) { Type result_type = a.type().narrow(); return Cast::make(result_type, std::move(a)); } Expr lossless_narrow(const Expr &x) { return lossless_cast(x.type().narrow(), x); } // Remove a widening cast even if it changes the sign of the result. Expr strip_widening_cast(const Expr &x) { Expr narrow = lossless_narrow(x); if (narrow.defined()) { return narrow; } return lossless_cast(x.type().narrow().with_code(halide_type_uint), x); } Expr saturating_narrow(const Expr &a) { Type narrow = a.type().narrow(); return saturating_cast(narrow, a); } // Returns true iff t is an integral type where overflow is undefined bool no_overflow_int(Type t) { return t.is_int() && t.bits() >= 32; } // Returns true iff t does not have a well defined overflow behavior. bool no_overflow(Type t) { return t.is_float() || no_overflow_int(t); } // If there's a widening add or subtract in the first e.type().bits() / 2 - 1 // levels down a tree of adds or subtracts, we know there's enough headroom for // another add without overflow. For example, it is safe to add to // (widening_add(x, y) - z) without overflow. bool is_safe_for_add(const Expr &e, int max_depth) { if (max_depth-- <= 0) { return false; } if (const Add *add = e.as()) { return is_safe_for_add(add->a, max_depth) || is_safe_for_add(add->b, max_depth); } else if (const Sub *sub = e.as()) { return is_safe_for_add(sub->a, max_depth) || is_safe_for_add(sub->b, max_depth); } else if (const Cast *cast = e.as()) { if (cast->type.bits() > cast->value.type().bits()) { return true; } else if (cast->type.bits() == cast->value.type().bits()) { return is_safe_for_add(cast->value, max_depth); } } else if (Call::as_intrinsic(e, {Call::widening_add, Call::widening_sub, Call::widen_right_add, Call::widen_right_sub})) { return true; } return false; } bool is_safe_for_add(const Expr &e) { return is_safe_for_add(e, e.type().bits() / 2 - 1); } // We want to find and remove an add of 'round' from e. This is not // the same thing as just subtracting round, we specifically want // to remove an addition of exactly round. Expr find_and_subtract(const Expr &e, const Expr &round) { if (const Add *add = e.as()) { Expr a = find_and_subtract(add->a, round); if (a.defined()) { return Add::make(a, add->b); } Expr b = find_and_subtract(add->b, round); if (b.defined()) { return Add::make(add->a, b); } } else if (const Sub *sub = e.as()) { Expr a = find_and_subtract(sub->a, round); if (a.defined()) { return Sub::make(a, sub->b); } // We can't recurse into the negatve part of a subtract. } else if (can_prove(e == round)) { return make_zero(e.type()); } return Expr(); } Expr to_rounding_shift(const Call *c) { if (c->is_intrinsic(Call::shift_left) || c->is_intrinsic(Call::shift_right)) { internal_assert(c->args.size() == 2); Expr a = c->args[0]; Expr b = c->args[1]; // Helper to make the appropriate shift. 
        // Helper to make the appropriate shift.
        auto rounding_shift = [&](const Expr &a, const Expr &b) {
            if (c->is_intrinsic(Call::shift_right)) {
                return rounding_shift_right(a, b);
            } else {
                return rounding_shift_left(a, b);
            }
        };

        // The rounding offset for the shift we have.
        Type round_type = a.type().with_lanes(1);
        if (Call::as_intrinsic(a, {Call::widening_add})) {
            round_type = round_type.narrow();
        }
        Expr round;
        if (c->is_intrinsic(Call::shift_right)) {
            round = (make_one(round_type) << max(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2;
        } else {
            round = (make_one(round_type) >> min(cast(b.type().with_bits(round_type.bits()), b), 0)) / 2;
        }
        // Input expressions are simplified before running find_intrinsics, but b
        // has been lifted here, so we need to lower_intrinsics before simplifying
        // and re-lifting. Should we move this code into the FindIntrinsics class
        // to make it easier to lift round?
        round = lower_intrinsics(round);
        round = simplify(round);
        round = find_intrinsics(round);

        // We can always handle widening adds.
        if (const Call *add = Call::as_intrinsic(a, {Call::widening_add})) {
            if (can_prove(lower_intrinsics(add->args[0] == round))) {
                return rounding_shift(cast(add->type, add->args[1]), b);
            } else if (can_prove(lower_intrinsics(add->args[1] == round))) {
                return rounding_shift(cast(add->type, add->args[0]), b);
            }
        }
        if (const Call *add = Call::as_intrinsic(a, {Call::widen_right_add})) {
            if (can_prove(lower_intrinsics(add->args[1] == round))) {
                return rounding_shift(cast(add->type, add->args[0]), b);
            }
        }
        // Also need to handle the annoying case of a reinterpret wrapping a widen_right_add.
        // TODO: this pattern makes me want to change the semantics of this op.
        if (const Reinterpret *reinterp = a.as<Reinterpret>()) {
            if (reinterp->type.bits() == reinterp->value.type().bits()) {
                if (const Call *add = Call::as_intrinsic(reinterp->value, {Call::widen_right_add})) {
                    if (can_prove(lower_intrinsics(add->args[1] == round))) {
                        // We expect the first operand to be a reinterpret.
                        const Reinterpret *reinterp_a = add->args[0].as<Reinterpret>();
                        internal_assert(reinterp_a) << "Failed: " << add->args[0] << "\n";
                        return rounding_shift(reinterp_a->value, b);
                    }
                }
            }
        }

        // If it wasn't a widening or saturating add, we might still
        // be able to safely accept the rounding.
        Expr a_less_round = find_and_subtract(a, round);
        if (a_less_round.defined()) {
            // We found and removed the rounding. However, we may have just changed
            // behavior due to overflow. This is still safe if the type is not
            // overflowing, or if we can find a widening add or subtract in the tree
            // of adds/subtracts. This is a common pattern, e.g.
            // rounding_halving_add(a, b) = shift_round(widening_add(a, b) + 1, 1).
            // TODO: This could be done with bounds inference instead of this hack
            // if it supported intrinsics like widening_add and tracked bounds for
            // types other than int32.
            if (no_overflow(a.type()) || is_safe_for_add(a_less_round)) {
                return rounding_shift(simplify(a_less_round), b);
            }
        }
    }

    return Expr();
}

class FindIntrinsics : public IRMutator {
protected:
    using IRMutator::visit;

    IRMatcher::Wild<0> x;
    IRMatcher::Wild<1> y;
    IRMatcher::Wild<2> z;
    IRMatcher::WildConst<0> c0;
    IRMatcher::WildConst<1> c1;

    Expr visit(const Add *op) override {
        if (!find_intrinsics_for_type(op->type)) {
            return IRMutator::visit(op);
        }

        Expr a = mutate(op->a);
        Expr b = mutate(op->b);

        // Try widening both from the same signedness as the result, and from uint.
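        // For example (hypothetical u8 operands u8_a, u8_b, with op->type == UInt(16)):
        //   u16(u8_a) + u16(u8_b)  -->  widening_add(u8_a, u8_b)
        // The matching below is done via lossless_cast rather than pattern syntax.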
        for (halide_type_code_t code : {op->type.code(), halide_type_uint}) {
            Type narrow = op->type.narrow().with_code(code);
            Expr narrow_a = lossless_cast(narrow, a);
            Expr narrow_b = lossless_cast(narrow, b);

            if (narrow_a.defined() && narrow_b.defined()) {
                Expr result = widening_add(narrow_a, narrow_b);
                if (result.type() != op->type) {
                    result = Cast::make(op->type, result);
                }
                return mutate(result);
            }
        }

        if (op->type.is_int_or_uint() && op->type.bits() > 8) {
            // Look for widen_right_add intrinsics.
            // Yes, this duplicates code, but we want to check op->type.code() first,
            // and the opposite code as well.
            for (halide_type_code_t code : {op->type.code(), halide_type_uint, halide_type_int}) {
                Type narrow = op->type.narrow().with_code(code);
                // Pulling casts out of VectorReduce nodes breaks too much codegen, skip for now.
                Expr narrow_a = (a.node_type() == IRNodeType::VectorReduce) ? Expr() : lossless_cast(narrow, a);
                Expr narrow_b = (b.node_type() == IRNodeType::VectorReduce) ? Expr() : lossless_cast(narrow, b);

                // This case should have been handled by the above check for widening_add.
                internal_assert(!(narrow_a.defined() && narrow_b.defined()))
                    << "find_intrinsics failed to find a widening_add: " << a << " + " << b << "\n";

                if (narrow_a.defined()) {
                    Expr result;
                    if (b.type().code() != narrow_a.type().code()) {
                        // Need to do a safe reinterpret.
                        Type t = b.type().with_code(code);
                        result = widen_right_add(reinterpret(t, b), narrow_a);
                        internal_assert(result.type() != op->type);
                        result = reinterpret(op->type, result);
                    } else {
                        result = widen_right_add(b, narrow_a);
                    }
                    internal_assert(result.type() == op->type);
                    return result;
                } else if (narrow_b.defined()) {
                    Expr result;
                    if (a.type().code() != narrow_b.type().code()) {
                        // Need to do a safe reinterpret.
                        Type t = a.type().with_code(code);
                        result = widen_right_add(reinterpret(t, a), narrow_b);
                        internal_assert(result.type() != op->type);
                        result = reinterpret(op->type, result);
                    } else {
                        result = widen_right_add(a, narrow_b);
                    }
                    internal_assert(result.type() == op->type);
                    return mutate(result);
                }
            }
        }

        // TODO: there can be widen_right_add + widen_right_add simplification rules,
        // i.e. widen_right_add(a, b) + widen_right_add(c, d) = (a + c) + widening_add(b, d).

        if (a.same_as(op->a) && b.same_as(op->b)) {
            return op;
        } else {
            return Add::make(a, b);
        }
    }

    Expr visit(const Sub *op) override {
        if (!find_intrinsics_for_type(op->type)) {
            return IRMutator::visit(op);
        }

        Expr a = mutate(op->a);
        Expr b = mutate(op->b);

        // Try widening both from the same type as the result, and from uint.
        for (halide_type_code_t code : {op->type.code(), halide_type_uint}) {
            Type narrow = op->type.narrow().with_code(code);
            Expr narrow_a = lossless_cast(narrow, a);
            Expr narrow_b = lossless_cast(narrow, b);

            if (narrow_a.defined() && narrow_b.defined()) {
                Expr negative_narrow_b = lossless_negate(narrow_b);
                Expr result;
                if (negative_narrow_b.defined()) {
                    result = widening_add(narrow_a, negative_narrow_b);
                } else {
                    result = widening_sub(narrow_a, narrow_b);
                }
                if (result.type() != op->type) {
                    result = Cast::make(op->type, result);
                }
                return mutate(result);
            }
        }

        Expr negative_b = lossless_negate(b);
        if (negative_b.defined()) {
            return Add::make(a, negative_b);
        }

        // Run after the lossless_negate check, because we want that to turn into a widen_right_add if relevant.
        if (op->type.is_int_or_uint() && op->type.bits() > 8) {
            // Look for widen_right_sub intrinsics.
            // Yes, this duplicates code, but we want to check op->type.code() first,
            // and the opposite code as well.
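            // For example (hypothetical operands u16_a and u8_b, op->type == UInt(16)):
            //   u16_a - u16(u8_b)  -->  widen_right_sub(u16_a, u8_b)
            // The reinterpret path below handles a narrow operand of the
            // opposite signedness.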
            for (halide_type_code_t code : {op->type.code(), halide_type_uint, halide_type_int}) {
                Type narrow = op->type.narrow().with_code(code);
                Expr narrow_b = lossless_cast(narrow, b);

                if (narrow_b.defined()) {
                    Expr result;
                    if (a.type().code() != narrow_b.type().code()) {
                        // Need to do a safe reinterpret.
                        Type t = a.type().with_code(code);
                        result = widen_right_sub(reinterpret(t, a), narrow_b);
                        internal_assert(result.type() != op->type);
                        result = reinterpret(op->type, result);
                    } else {
                        result = widen_right_sub(a, narrow_b);
                    }
                    internal_assert(result.type() == op->type);
                    return mutate(result);
                }
            }
        }

        if (a.same_as(op->a) && b.same_as(op->b)) {
            return op;
        } else {
            return Sub::make(a, b);
        }
    }

    Expr visit(const Mul *op) override {
        if (!find_intrinsics_for_type(op->type)) {
            return IRMutator::visit(op);
        }

        if (as_const_int(op->b) || as_const_uint(op->b)) {
            // Distribute constants through add/sub. Do this before we muck everything up with widening
            // intrinsics.
            // TODO: Only do this for widening?
            // TODO: Try to do this with IRMatcher::rewriter. The challenge is managing the narrowing/widening casts,
            // and doing constant folding without the simplifier undoing the work.
            if (const Add *add_a = op->a.as<Add>()) {
                return mutate(Add::make(simplify(Mul::make(add_a->a, op->b)),
                                        simplify(Mul::make(add_a->b, op->b))));
            } else if (const Sub *sub_a = op->a.as<Sub>()) {
                return mutate(Sub::make(simplify(Mul::make(sub_a->a, op->b)),
                                        simplify(Mul::make(sub_a->b, op->b))));
            }
        }

        Expr a = mutate(op->a);
        Expr b = mutate(op->b);

        // Rewrite multiplies to shifts if possible.
        if (op->type.is_int() || op->type.is_uint()) {
            int pow2 = 0;
            if (is_const_power_of_two_integer(a, &pow2)) {
                return mutate(b << cast(UInt(b.type().bits()), pow2));
            } else if (is_const_power_of_two_integer(b, &pow2)) {
                return mutate(a << cast(UInt(a.type().bits()), pow2));
            }
        }

        // We're applying this to float, which seems OK? float16 * float16 -> float32 is a widening multiply?
        // This uses strip_widening_cast to ignore the signedness of the narrow value.
        Expr narrow_a = strip_widening_cast(a);
        Expr narrow_b = strip_widening_cast(b);
        if (narrow_a.defined() && narrow_b.defined() &&
            (narrow_a.type().is_int_or_uint() == narrow_b.type().is_int_or_uint() ||
             narrow_a.type().is_float() == narrow_b.type().is_float())) {
            Expr result = widening_mul(narrow_a, narrow_b);
            if (result.type() != op->type) {
                result = Cast::make(op->type, result);
            }
            return mutate(result);
        }

        if (op->type.is_int_or_uint() && op->type.bits() > 8) {
            // Look for widen_right_mul intrinsics.
            // Yes, this duplicates code, but we want to check op->type.code() first,
            // and the opposite code as well.
            for (halide_type_code_t code : {op->type.code(), halide_type_uint, halide_type_int}) {
                Type narrow = op->type.narrow().with_code(code);
                Expr narrow_a = lossless_cast(narrow, a);
                Expr narrow_b = lossless_cast(narrow, b);

                // This case should have been handled by the above check for widening_mul.
                internal_assert(!(narrow_a.defined() && narrow_b.defined()))
                    << "find_intrinsics failed to find a widening_mul: " << a << " * " << b << "\n";

                if (narrow_a.defined()) {
                    Expr result;
                    if (b.type().code() != narrow_a.type().code()) {
                        // Need to do a safe reinterpret.
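                        // E.g. (hypothetical operands i16_b and u8_a): i16(u8_a) * i16_b
                        // finds the narrow operand as a u8, so the other operand is
                        // reinterpreted to u16, multiplied with widen_right_mul, and the
                        // result reinterpreted back to the signed type:
                        //   reinterpret(i16, widen_right_mul(reinterpret(u16, i16_b), u8_a))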
                        Type t = b.type().with_code(code);
                        result = widen_right_mul(reinterpret(t, b), narrow_a);
                        internal_assert(result.type() != op->type);
                        result = reinterpret(op->type, result);
                    } else {
                        result = widen_right_mul(b, narrow_a);
                    }
                    internal_assert(result.type() == op->type);
                    return result;
                } else if (narrow_b.defined()) {
                    Expr result;
                    if (a.type().code() != narrow_b.type().code()) {
                        // Need to do a safe reinterpret.
                        Type t = a.type().with_code(code);
                        result = widen_right_mul(reinterpret(t, a), narrow_b);
                        internal_assert(result.type() != op->type);
                        result = reinterpret(op->type, result);
                    } else {
                        result = widen_right_mul(a, narrow_b);
                    }
                    internal_assert(result.type() == op->type);
                    return mutate(result);
                }
            }
        }

        if (a.same_as(op->a) && b.same_as(op->b)) {
            return op;
        } else {
            return Mul::make(a, b);
        }
    }

    Expr visit(const Div *op) override {
        if (!find_intrinsics_for_type(op->type)) {
            return IRMutator::visit(op);
        }

        Expr a = mutate(op->a);
        Expr b = mutate(op->b);

        int shift_amount;
        if (is_const_power_of_two_integer(b, &shift_amount) && op->type.is_int_or_uint()) {
            return mutate(a >> make_const(UInt(a.type().bits()), shift_amount));
        }

        if (a.same_as(op->a) && b.same_as(op->b)) {
            return op;
        } else {
            return Div::make(a, b);
        }
    }

    // We don't handle Mod because we don't have any patterns that look for bitwise and vs. mod.

    template<typename MinOrMax>
    Expr visit_min_or_max(const MinOrMax *op) {
        if (!find_intrinsics_for_type(op->type)) {
            return IRMutator::visit(op);
        }

        Expr a = mutate(op->a);
        Expr b = mutate(op->b);

        if (const Cast *cast_a = a.as<Cast>()) {
            Expr cast_b = lossless_cast(cast_a->value.type(), b);
            if (cast_a->type.can_represent(cast_a->value.type()) && cast_b.defined()) {
                // This is a widening cast that can be moved outside the min/max.
                return mutate(Cast::make(cast_a->type, MinOrMax::make(cast_a->value, cast_b)));
            }
        }

        if (a.same_as(op->a) && b.same_as(op->b)) {
            return op;
        } else {
            return MinOrMax::make(a, b);
        }
    }

    Expr visit(const Min *op) override {
        return visit_min_or_max(op);
    }

    Expr visit(const Max *op) override {
        return visit_min_or_max(op);
    }

    Expr visit(const Cast *op) override {
        if (!find_intrinsics_for_type(op->type)) {
            return IRMutator::visit(op);
        }

        Expr value = mutate(op->value);

        // This mutator can generate redundant casts. We can't use the simplifier because it
        // undoes some of the intrinsic lowering here, and it causes some problems due to
        // factoring (instead of distributing) constants.
        if (const Cast *cast = value.as<Cast>()) {
            if (cast->type.can_represent(cast->value.type()) || cast->type.can_represent(op->type)) {
                // The intermediate cast is redundant.
                value = cast->value;
            }
        }

        if (op->type.is_int() || op->type.is_uint()) {
            Expr lower = cast(value.type(), op->type.min());
            Expr upper = cast(value.type(), op->type.max());

            auto rewrite = IRMatcher::rewriter(value, op->type);

            Type op_type_wide = op->type.widen();
            Type signed_type_wide = op_type_wide.with_code(halide_type_int);
            Type unsigned_type = op->type.with_code(halide_type_uint);
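            // For example, when casting to Int(8), lower/upper are -128 and 127
            // in the value's (wider) type, so the first rule below turns
            //   i8(max(min(widening_add(i8_x, i8_y), 127), -128))
            // into saturating_add(i8_x, i8_y). (Hypothetical operands i8_x, i8_y.)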
            // Give concise names to various predicates we want to use in
            // rewrite rules below.
            int bits = op->type.bits();
            auto is_x_same_int = op->type.is_int() && is_int(x, bits);
            auto is_x_same_uint = op->type.is_uint() && is_uint(x, bits);
            auto is_x_same_int_or_uint = is_x_same_int || is_x_same_uint;
            auto x_y_same_sign = (is_int(x) && is_int(y)) || (is_uint(x) && is_uint(y));
            auto is_y_narrow_uint = op->type.is_uint() && is_uint(y, bits / 2);
            if (
                // Saturating patterns
                rewrite(max(min(widening_add(x, y), upper), lower),
                        saturating_add(x, y),
                        is_x_same_int_or_uint) ||

                rewrite(max(min(widening_sub(x, y), upper), lower),
                        saturating_sub(x, y),
                        is_x_same_int_or_uint) ||

                rewrite(min(cast(signed_type_wide, widening_add(x, y)), upper),
                        saturating_add(x, y),
                        is_x_same_uint) ||

                rewrite(min(widening_add(x, y), upper),
                        saturating_add(x, y),
                        op->type.is_uint() && is_x_same_uint) ||

                rewrite(max(widening_sub(x, y), lower),
                        saturating_sub(x, y),
                        op->type.is_uint() && is_x_same_uint) ||

                // Saturating narrow patterns.
                rewrite(max(min(x, upper), lower),
                        saturating_cast(op->type, x)) ||

                rewrite(min(x, upper),
                        saturating_cast(op->type, x),
                        is_uint(x)) ||

                // Averaging patterns
                //
                // We have a slight preference for rounding_halving_add over
                // using halving_add when unsigned, because x86 supports it.
                rewrite(shift_right(widening_add(x, c0), 1),
                        rounding_halving_add(x, c0 - 1),
                        c0 > 0 && is_x_same_uint) ||

                rewrite(shift_right(widening_add(x, y), 1),
                        halving_add(x, y),
                        is_x_same_int_or_uint) ||

                rewrite(shift_right(widening_add(x, c0), c1),
                        rounding_shift_right(x, cast(op->type, c1)),
                        c0 == shift_left(1, c1 - 1) && is_x_same_int_or_uint) ||

                rewrite(shift_right(widening_add(x, c0), c1),
                        shift_right(rounding_halving_add(x, cast(op->type, fold(c0 - 1))), cast(op->type, fold(c1 - 1))),
                        c0 > 0 && c1 > 0 && is_x_same_uint) ||

                rewrite(shift_right(widening_add(x, y), c0),
                        shift_right(halving_add(x, y), cast(op->type, fold(c0 - 1))),
                        c0 > 0 && is_x_same_int_or_uint) ||

                rewrite(shift_right(widening_sub(x, y), 1),
                        halving_sub(x, y),
                        is_x_same_int_or_uint) ||

                rewrite(halving_add(widening_add(x, y), 1),
                        rounding_halving_add(x, y),
                        is_x_same_int_or_uint) ||

                rewrite(halving_add(widening_add(x, 1), y),
                        rounding_halving_add(x, y),
                        is_x_same_int_or_uint) ||

                rewrite(rounding_shift_right(widening_add(x, y), 1),
                        rounding_halving_add(x, y),
                        is_x_same_int_or_uint) ||

                // Multiply-keep-high-bits patterns.
                rewrite(max(min(shift_right(widening_mul(x, y), z), upper), lower),
                        mul_shift_right(x, y, cast(unsigned_type, z)),
                        is_x_same_int_or_uint && x_y_same_sign && is_uint(z)) ||

                rewrite(max(min(rounding_shift_right(widening_mul(x, y), z), upper), lower),
                        rounding_mul_shift_right(x, y, cast(unsigned_type, z)),
                        is_x_same_int_or_uint && x_y_same_sign && is_uint(z)) ||

                rewrite(min(shift_right(widening_mul(x, y), z), upper),
                        mul_shift_right(x, y, cast(unsigned_type, z)),
                        is_x_same_uint && x_y_same_sign && is_uint(z)) ||

                rewrite(min(rounding_shift_right(widening_mul(x, y), z), upper),
                        rounding_mul_shift_right(x, y, cast(unsigned_type, z)),
                        is_x_same_uint && x_y_same_sign && is_uint(z)) ||

                // We don't need saturation for the full upper half of a multiply.
                // For signed integers, this is almost true, except for when x and y
                // are both the most negative value. For these, we only need saturation
                // at the upper bound.
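                // Worked edge case for i8: (-128 * -128) >> 7 == 128, which
                // exceeds i8's upper bound of 127, so the signed rules below
                // keep the min() against 'upper' but can drop the max()
                // against 'lower'.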
                rewrite(min(shift_right(widening_mul(x, y), c0), upper),
                        mul_shift_right(x, y, cast(unsigned_type, c0)),
                        is_x_same_int && x_y_same_sign && c0 >= bits - 1) ||

                rewrite(min(rounding_shift_right(widening_mul(x, y), c0), upper),
                        rounding_mul_shift_right(x, y, cast(unsigned_type, c0)),
                        is_x_same_int && x_y_same_sign && c0 >= bits - 1) ||

                rewrite(shift_right(widening_mul(x, y), c0),
                        mul_shift_right(x, y, cast(unsigned_type, c0)),
                        is_x_same_int_or_uint && x_y_same_sign && c0 >= bits) ||

                rewrite(rounding_shift_right(widening_mul(x, y), c0),
                        rounding_mul_shift_right(x, y, cast(unsigned_type, c0)),
                        is_x_same_int_or_uint && x_y_same_sign && c0 >= bits) ||

                // We can also match on smaller shifts if one of the args is
                // narrow. We don't do this for signed (yet), because the
                // saturation issue is tricky.
                rewrite(shift_right(widening_mul(x, cast(op->type, y)), c0),
                        mul_shift_right(x, cast(op->type, y), cast(unsigned_type, c0)),
                        is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) ||

                rewrite(rounding_shift_right(widening_mul(x, cast(op->type, y)), c0),
                        rounding_mul_shift_right(x, cast(op->type, y), cast(unsigned_type, c0)),
                        is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) ||

                rewrite(shift_right(widening_mul(cast(op->type, y), x), c0),
                        mul_shift_right(cast(op->type, y), x, cast(unsigned_type, c0)),
                        is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) ||

                rewrite(rounding_shift_right(widening_mul(cast(op->type, y), x), c0),
                        rounding_mul_shift_right(cast(op->type, y), x, cast(unsigned_type, c0)),
                        is_x_same_int_or_uint && is_y_narrow_uint && c0 >= bits / 2) ||

                // Halving subtract patterns
                rewrite(shift_right(cast(op_type_wide, widening_sub(x, y)), 1),
                        halving_sub(x, y),
                        is_x_same_int_or_uint) ||

                false) {
                internal_assert(rewrite.result.type() == op->type)
                    << "Rewrite changed type: " << Expr(op) << " -> " << rewrite.result << "\n";
                return mutate(rewrite.result);
            }

            // When the argument is a widened rounding shift, we might not need the widening.
            // When there is saturation, we can only avoid the widening if we know the shift is
            // a right shift. Without saturation, we can ignore the widening.
            auto is_x_wide_int = op->type.is_int() && is_int(x, bits * 2);
            auto is_x_wide_uint = op->type.is_uint() && is_uint(x, bits * 2);
            auto is_x_wide_int_or_uint = is_x_wide_int || is_x_wide_uint;
            // We can't do everything we want here with rewrite rules alone. So, we rewrite them
            // to rounding_shifts with the widening still in place, and narrow it after the rewrite
            // succeeds.
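            // For example (hypothetical i16_x, op->type == Int(16)):
            //   i16_sat(rounding_shift_right(i32(i16_x), 3))
            // can drop both the widening and the saturation, becoming
            //   rounding_shift_right(i16_x, 3),
            // because a rounded right shift by a provably non-negative amount
            // of a widened value always fits back in the original type.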
            // clang-format off
            if (rewrite(max(min(rounding_shift_right(x, y), upper), lower),
                        rounding_shift_right(x, y),
                        is_x_wide_int_or_uint) ||

                rewrite(rounding_shift_right(x, y),
                        rounding_shift_right(x, y),
                        is_x_wide_int_or_uint) ||

                rewrite(rounding_shift_left(x, y),
                        rounding_shift_left(x, y),
                        is_x_wide_int_or_uint) ||

                false) {
                const Call *shift = Call::as_intrinsic(rewrite.result, {Call::rounding_shift_right, Call::rounding_shift_left});
                internal_assert(shift);
                bool is_saturated = op->value.as<Max>() || op->value.as<Min>();
                Expr a = lossless_cast(op->type, shift->args[0]);
                Expr b = lossless_cast(op->type.with_code(shift->args[1].type().code()), shift->args[1]);
                if (a.defined() && b.defined()) {
                    if (!is_saturated ||
                        (shift->is_intrinsic(Call::rounding_shift_right) && can_prove(b >= 0)) ||
                        (shift->is_intrinsic(Call::rounding_shift_left) && can_prove(b <= 0))) {
                        return mutate(Call::make(op->type, shift->name, {a, b}, Call::PureIntrinsic));
                    }
                }
            }
            // clang-format on
        }

        if (value.same_as(op->value)) {
            return op;
        } else if (op->type != value.type()) {
            return Cast::make(op->type, value);
        } else {
            return value;
        }
    }

    Expr visit(const Call *op) override {
        if (!find_intrinsics_for_type(op->type)) {
            return IRMutator::visit(op);
        }

        Expr mutated = IRMutator::visit(op);
        op = mutated.as<Call>();
        if (!op) {
            return mutated;
        }

        auto rewrite = IRMatcher::rewriter(op, op->type);
        if (rewrite(intrin(Call::abs, widening_sub(x, y)), cast(op->type, intrin(Call::absd, x, y))) ||
            false) {
            return rewrite.result;
        }

        const int bits = op->type.bits();
        const auto is_x_same_int = op->type.is_int() && is_int(x, bits);
        const auto is_x_same_uint = op->type.is_uint() && is_uint(x, bits);
        const auto is_x_same_int_or_uint = is_x_same_int || is_x_same_uint;
        auto x_y_same_sign = (is_int(x) == is_int(y)) || (is_uint(x) && is_uint(y));
        Type unsigned_type = op->type.with_code(halide_type_uint);
        const auto is_x_wider_int_or_uint = (op->type.is_int() && is_int(x, 2 * bits)) ||
                                            (op->type.is_uint() && is_uint(x, 2 * bits));
        Type opposite_type = op->type.is_int() ? op->type.with_code(halide_type_uint) : op->type.with_code(halide_type_int);
        const auto is_x_wider_opposite_int = (op->type.is_int() && is_uint(x, 2 * bits)) ||
                                             (op->type.is_uint() && is_int(x, 2 * bits));

        if (
            // Simplify extending patterns.
            // (x + widen(y)) + widen(z) = x + widening_add(y, z).
            rewrite(widen_right_add(widen_right_add(x, y), z),
                    x + widening_add(y, z),
                    // We only care about integers, this should be trivially true.
                    is_x_same_int_or_uint) ||

            // (x - widen(y)) - widen(z) = x - widening_add(y, z).
            rewrite(widen_right_sub(widen_right_sub(x, y), z),
                    x - widening_add(y, z),
                    // We only care about integers, this should be trivially true.
                    is_x_same_int_or_uint) ||

            // (x + widen(y)) - widen(z) = x + cast(t, widening_sub(y, z))
            // cast (reinterpret) is needed only for uints.
            rewrite(widen_right_sub(widen_right_add(x, y), z),
                    x + widening_sub(y, z),
                    is_x_same_int) ||
            rewrite(widen_right_sub(widen_right_add(x, y), z),
                    x + cast(op->type, widening_sub(y, z)),
                    is_x_same_uint) ||

            // (x - widen(y)) + widen(z) = x + cast(t, widening_sub(z, y))
            // cast (reinterpret) is needed only for uints.
            rewrite(widen_right_add(widen_right_sub(x, y), z),
                    x + widening_sub(z, y),
                    is_x_same_int) ||
            rewrite(widen_right_add(widen_right_sub(x, y), z),
                    x + cast(op->type, widening_sub(z, y)),
                    is_x_same_uint) ||

            // Saturating patterns.
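            // E.g. (hypothetical i16 operands):
            //   i16_sat(widening_add(i16_x, i16_y))  -->  saturating_add(i16_x, i16_y)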
            rewrite(saturating_cast(op->type, widening_add(x, y)),
                    saturating_add(x, y),
                    is_x_same_int_or_uint) ||
            rewrite(saturating_cast(op->type, widening_sub(x, y)),
                    saturating_sub(x, y),
                    is_x_same_int_or_uint) ||
            rewrite(saturating_cast(op->type, shift_right(widening_mul(x, y), z)),
                    mul_shift_right(x, y, cast(unsigned_type, z)),
                    is_x_same_int_or_uint && x_y_same_sign && is_uint(z)) ||
            rewrite(saturating_cast(op->type, rounding_shift_right(widening_mul(x, y), z)),
                    rounding_mul_shift_right(x, y, cast(unsigned_type, z)),
                    is_x_same_int_or_uint && x_y_same_sign && is_uint(z)) ||

            // We can remove unnecessary widening if we are then performing a saturating narrow.
            // This is similar to the logic inside `visit_min_or_max`.
            (((bits <= 32) &&
              // Examples:
              //   i8_sat(int16(i8)) -> i8
              //   u8_sat(uint16(u8)) -> u8
              rewrite(saturating_cast(op->type, cast(op->type.widen(), x)),
                      x,
                      is_x_same_int_or_uint)) ||
             ((bits <= 16) &&
              // Examples:
              //   i8_sat(int32(i16)) -> i8_sat(i16)
              //   u8_sat(uint32(u16)) -> u8_sat(u16)
              (rewrite(saturating_cast(op->type, cast(op->type.widen().widen(), x)),
                       saturating_cast(op->type, x),
                       is_x_wider_int_or_uint) ||
               // Examples:
               //   i8_sat(uint32(u16)) -> i8_sat(u16)
               //   u8_sat(int32(i16)) -> u8_sat(i16)
               rewrite(saturating_cast(op->type, cast(opposite_type.widen().widen(), x)),
                       saturating_cast(op->type, x),
                       is_x_wider_opposite_int) ||
               false))) ||
            false) {
            return mutate(rewrite.result);
        }

        if (no_overflow(op->type)) {
            // clang-format off
            if (rewrite(halving_add(x + y, 1), rounding_halving_add(x, y)) ||
                rewrite(halving_add(x, y + 1), rounding_halving_add(x, y)) ||
                rewrite(halving_add(x + 1, y), rounding_halving_add(x, y)) ||
                rewrite(halving_add(x, 1), rounding_shift_right(x, 1)) ||
                rewrite(shift_right(x + y, 1), halving_add(x, y)) ||
                rewrite(shift_right(x - y, 1), halving_sub(x, y)) ||
                rewrite(rounding_shift_right(x + y, 1), rounding_halving_add(x, y)) ||
                false) {
                return mutate(rewrite.result);
            }
            // clang-format on
        }

        // Move widening casts inside widening arithmetic outside the arithmetic,
        // e.g. widening_mul(widen(u8), widen(i8)) -> widen(widening_mul(u8, i8)).
        if (op->is_intrinsic(Call::widening_mul)) {
            internal_assert(op->args.size() == 2);
            Expr narrow_a = strip_widening_cast(op->args[0]);
            Expr narrow_b = strip_widening_cast(op->args[1]);
            if (narrow_a.defined() && narrow_b.defined()) {
                return mutate(Cast::make(op->type, widening_mul(narrow_a, narrow_b)));
            }
        } else if (op->is_intrinsic(Call::widening_add) && (op->type.bits() >= 16)) {
            internal_assert(op->args.size() == 2);
            for (halide_type_code_t t : {op->type.code(), halide_type_uint}) {
                Type narrow_t = op->type.narrow().narrow().with_code(t);
                Expr narrow_a = lossless_cast(narrow_t, op->args[0]);
                Expr narrow_b = lossless_cast(narrow_t, op->args[1]);
                if (narrow_a.defined() && narrow_b.defined()) {
                    return mutate(Cast::make(op->type, widening_add(narrow_a, narrow_b)));
                }
            }
        } else if (op->is_intrinsic(Call::widening_sub) && (op->type.bits() >= 16)) {
            internal_assert(op->args.size() == 2);
            for (halide_type_code_t t : {op->type.code(), halide_type_uint}) {
                Type narrow_t = op->type.narrow().narrow().with_code(t);
                Expr narrow_a = lossless_cast(narrow_t, op->args[0]);
                Expr narrow_b = lossless_cast(narrow_t, op->args[1]);
                if (narrow_a.defined() && narrow_b.defined()) {
                    return mutate(Cast::make(op->type, widening_sub(narrow_a, narrow_b)));
                }
            }
        }

        // TODO: do we want versions of widen_right_add here?

        if (op->is_intrinsic(Call::shift_right) || op->is_intrinsic(Call::shift_left)) {
            // Try to turn this into a widening shift.
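            // E.g. when both args narrow losslessly (hypothetical i8_x, i8_n):
            //   shift_left(i16(i8_x), i16(i8_n))  -->  widening_shift_left(i8_x, i8_n)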
            internal_assert(op->args.size() == 2);
            Expr a_narrow = lossless_narrow(op->args[0]);
            Expr b_narrow = lossless_narrow(op->args[1]);
            if (a_narrow.defined() && b_narrow.defined()) {
                Expr result = op->is_intrinsic(Call::shift_left) ? widening_shift_left(a_narrow, b_narrow) :
                                                                   widening_shift_right(a_narrow, b_narrow);
                if (result.type() != op->type) {
                    result = Cast::make(op->type, result);
                }
                return mutate(result);
            }

            // Try to turn this into a rounding shift.
            Expr rounding_shift = to_rounding_shift(op);
            if (rounding_shift.defined()) {
                return mutate(rounding_shift);
            }
        }

        if (op->is_intrinsic(Call::rounding_shift_left) || op->is_intrinsic(Call::rounding_shift_right)) {
            // Try to narrow this to a rounding shift on narrower operands.
            internal_assert(op->args.size() == 2);
            Expr a_narrow = lossless_narrow(op->args[0]);
            Expr b_narrow = lossless_narrow(op->args[1]);
            if (a_narrow.defined() && b_narrow.defined()) {
                Expr result;
                if (op->is_intrinsic(Call::rounding_shift_right) && can_prove(b_narrow > 0)) {
                    result = rounding_shift_right(a_narrow, b_narrow);
                } else if (op->is_intrinsic(Call::rounding_shift_left) && can_prove(b_narrow < 0)) {
                    result = rounding_shift_left(a_narrow, b_narrow);
                } else {
                    return op;
                }
                if (result.type() != op->type) {
                    result = Cast::make(op->type, result);
                }
                return mutate(result);
            }
        }

        return op;
    }
};

// Substitute in let values that have an output vector
// type wider than all the types of other variables
// referenced. This can't cause combinatorial explosion,
// because each let in a chain has a wider value than the
// ones it refers to.
class SubstituteInWideningLets : public IRMutator {
    using IRMutator::visit;

    bool widens(const Expr &e) {
        class AllInputsNarrowerThan : public IRVisitor {
            int bits;

            using IRVisitor::visit;

            void visit(const Variable *op) override {
                result &= op->type.bits() < bits;
            }

            void visit(const Load *op) override {
                result &= op->type.bits() < bits;
            }

            void visit(const Call *op) override {
                if (op->is_pure() && op->is_intrinsic()) {
                    IRVisitor::visit(op);
                } else {
                    result &= op->type.bits() < bits;
                }
            }

        public:
            AllInputsNarrowerThan(Type t)
                : bits(t.bits()) {
            }
            bool result = true;
        } widens(e.type());

        e.accept(&widens);
        return widens.result;
    }

    Scope<Expr> replacements;
    Expr visit(const Variable *op) override {
        if (replacements.contains(op->name)) {
            return replacements.get(op->name);
        } else {
            return op;
        }
    }

    template<typename T>
    auto visit_let(const T *op) -> decltype(op->body) {
        struct Frame {
            std::string name;
            Expr new_value;
            ScopedBinding<Expr> bind;
            Frame(const std::string &name, const Expr &new_value, ScopedBinding<Expr> &&bind)
                : name(name), new_value(new_value), bind(std::move(bind)) {
            }
        };
        std::vector<Frame> frames;
        decltype(op->body) body;

        do {
            body = op->body;
            Expr value = op->value;
            bool should_replace = find_intrinsics_for_type(value.type()) && widens(value);

            // We can only substitute in pure stuff. Isolate all
            // impure subexpressions and leave them behind here as
            // lets.
            class LeaveBehindSubexpressions : public IRMutator {
                using IRMutator::visit;

                Expr visit(const Call *op) override {
                    if (!op->is_pure() || !op->is_intrinsic()) {
                        // Only enter pure intrinsics (e.g. existing uses of widening_add).
                        std::string name = unique_name('t');
                        frames.emplace_back(name, op, ScopedBinding<Expr>{});
                        return Variable::make(op->type, name);
                    } else {
                        return IRMutator::visit(op);
                    }
                }

                Expr visit(const Load *op) override {
                    // Never enter loads. They can be impure and none
                    // of our patterns match them.
                    std::string name = unique_name('t');
                    frames.emplace_back(name, op, ScopedBinding<Expr>{});
                    return Variable::make(op->type, name);
                }

                std::vector<Frame> &frames;

            public:
                LeaveBehindSubexpressions(std::vector<Frame> &frames)
                    : frames(frames) {
                }
            } extractor(frames);

            if (should_replace) {
                size_t start_of_new_lets = frames.size();
                value = extractor.mutate(value);
                // Mutate any subexpressions the extractor decided to
                // leave behind, in case they in turn depend on lets
                // we've decided to substitute in.
                for (size_t i = start_of_new_lets; i < frames.size(); i++) {
                    frames[i].new_value = mutate(frames[i].new_value);
                }

                // Check it wasn't lifted entirely.
                should_replace = !value.as<Variable>();
            }

            // TODO: If it's an int32/64 vector, it may be
            // implicitly widening because overflow is UB. Hard to
            // see how to handle this without worrying about
            // combinatorial explosion of substitutions.

            value = mutate(value);
            ScopedBinding<Expr> bind(should_replace, replacements, op->name, value);
            frames.emplace_back(op->name, value, std::move(bind));
            op = body.template as<T>();
        } while (op);

        body = mutate(body);

        while (!frames.empty()) {
            if (!frames.back().bind.bound()) {
                body = T::make(frames.back().name, frames.back().new_value, body);
            }
            frames.pop_back();
        }

        return body;
    }

    Expr visit(const Let *op) override {
        return visit_let(op);
    }

    Stmt visit(const LetStmt *op) override {
        return visit_let(op);
    }
};

}  // namespace

Stmt find_intrinsics(const Stmt &s) {
    Stmt stmt = SubstituteInWideningLets().mutate(s);
    stmt = FindIntrinsics().mutate(stmt);
    // In case we want to hoist widening ops back out.
    stmt = common_subexpression_elimination(stmt);
    return stmt;
}

Expr find_intrinsics(const Expr &e) {
    Expr expr = SubstituteInWideningLets().mutate(e);
    expr = FindIntrinsics().mutate(expr);
    expr = common_subexpression_elimination(expr);
    return expr;
}

Expr lower_widen_right_add(const Expr &a, const Expr &b) {
    return a + widen(b);
}

Expr lower_widen_right_mul(const Expr &a, const Expr &b) {
    return a * widen(b);
}

Expr lower_widen_right_sub(const Expr &a, const Expr &b) {
    return a - widen(b);
}

Expr lower_widening_add(const Expr &a, const Expr &b) {
    return widen(a) + widen(b);
}

Expr lower_widening_mul(const Expr &a, const Expr &b) {
    return widen(a) * widen(b);
}

Expr lower_widening_sub(const Expr &a, const Expr &b) {
    Type wide = a.type().widen();
    if (wide.is_uint()) {
        wide = wide.with_code(halide_type_int);
    }
    return Cast::make(wide, a) - Cast::make(wide, b);
}

Expr lower_widening_shift_left(const Expr &a, const Expr &b) {
    return widen(a) << b;
}

Expr lower_widening_shift_right(const Expr &a, const Expr &b) {
    return widen(a) >> b;
}

Expr lower_rounding_shift_left(const Expr &a, const Expr &b) {
    // Shift left, then add one to the result if bits were dropped
    // (because b < 0) and the most significant dropped bit was a one.
    Expr b_negative = select(b < 0, make_one(a.type()), make_zero(a.type()));
    return simplify((a << b) + (b_negative & (a << (b + 1))));
}

Expr lower_rounding_shift_right(const Expr &a, const Expr &b) {
    if (is_positive_const(b)) {
        // We can handle the rounding with an averaging instruction. We prefer
        // the rounding average instruction (we could use either), because the
        // non-rounding one is missing on x86.
        Expr shift = simplify(b - 1);
        Expr round = simplify(cast(a.type(), (1 << shift) - 1));
        return rounding_halving_add(a, round) >> shift;
    }
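    // Worked example of the branch above (u8, b == 3):
    // rounding_shift_right(x, 3) == (x + 4) >> 3. With shift == 2 and
    // round == 3, rounding_halving_add(x, 3) == (x + 3 + 1) >> 1 == (x + 4) >> 1,
    // and the further >> 2 yields (x + 4) >> 3, as required.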
    // Shift right, then add one to the result if bits were dropped
    // (because b > 0) and the most significant dropped bit was a one.
    Expr b_positive = select(b > 0, make_one(a.type()), make_zero(a.type()));
    return simplify((a >> b) + (b_positive & (a >> (b - 1))));
}

Expr lower_saturating_add(const Expr &a, const Expr &b) {
    internal_assert(a.type() == b.type());
    // Lower saturating add without using widening arithmetic, which may require
    // types that aren't supported.
    return simplify(clamp(a, a.type().min() - min(b, 0), a.type().max() - max(b, 0))) + b;
}

Expr lower_saturating_sub(const Expr &a, const Expr &b) {
    internal_assert(a.type() == b.type());
    // Lower saturating sub without using widening arithmetic, which may require
    // types that aren't supported.
    return simplify(clamp(a, a.type().min() + max(b, 0), a.type().max() + min(b, 0))) - b;
}

Expr lower_saturating_cast(const Type &t, const Expr &a) {
    // For float to float, guarantee infinities are always pinned to range.
    if (t.is_float() && a.type().is_float()) {
        if (t.bits() < a.type().bits()) {
            return cast(t, clamp(a, t.min(), t.max()));
        } else {
            return clamp(cast(t, a), t.min(), t.max());
        }
    } else if (a.type() != t) {
        // Limits for Int(2^n) or UInt(2^n) are not exactly representable in Float(2^n).
        if (a.type().is_float() && !t.is_float() && t.bits() >= a.type().bits()) {
            Expr e = max(a, t.min());  // min values turn out to be always representable
            // This line depends on t.max() rounding upward, which should always
            // be the case as it is one less than a representable value, thus
            // the one larger is always the closest.
            e = select(e >= cast(e.type(), t.max()), t.max(), cast(t, e));
            return e;
        } else {
            Expr min_bound;
            if (!a.type().is_uint()) {
                min_bound = lossless_cast(a.type(), t.min());
            }
            Expr max_bound = lossless_cast(a.type(), t.max());

            Expr e;
            if (min_bound.defined() && max_bound.defined()) {
                e = clamp(a, min_bound, max_bound);
            } else if (min_bound.defined()) {
                e = max(a, min_bound);
            } else if (max_bound.defined()) {
                e = min(a, max_bound);
            } else {
                e = a;
            }
            return cast(t, std::move(e));
        }
    }
    return a;
}

Expr lower_halving_add(const Expr &a, const Expr &b) {
    internal_assert(a.type() == b.type());
    // Borrowed from http://aggregate.org/MAGIC/#Average%20of%20Integers
    return (a & b) + ((a ^ b) >> 1);
}

Expr lower_halving_sub(const Expr &a, const Expr &b) {
    internal_assert(a.type() == b.type());
    Expr e = rounding_halving_add(a, ~b);
    if (a.type().is_uint()) {
        // An explanation in 8-bit:
        //   (x - y) / 2
        //   = (x + 256 - y) / 2 - 128
        //   = (x + (255 - y) + 1) / 2 - 128
        //   = (x + ~y + 1) / 2 - 128
        //   = rounding_halving_add(x, ~y) - 128
        //   = rounding_halving_add(x, ~y) + 128  (due to 2s-complement wrap-around)
        return e + make_const(e.type(), (uint64_t)1 << (a.type().bits() - 1));
    } else {
        // For 2s-complement signed integers, negating is done by flipping the
        // bits and adding one, so:
        //   (x - y) / 2
        //   = (x + (-y)) / 2
        //   = (x + (~y + 1)) / 2
        //   = rounding_halving_add(x, ~y)
        return e;
    }
}

Expr lower_rounding_halving_add(const Expr &a, const Expr &b) {
    internal_assert(a.type() == b.type());
    return halving_add(a, b) + ((a ^ b) & 1);
}

Expr lower_sorted_avg(const Expr &a, const Expr &b) {
    // b > a, so the following works without widening.
    return a + ((b - a) >> 1);
}

Expr lower_mul_shift_right(const Expr &a, const Expr &b, const Expr &q) {
    internal_assert(a.type() == b.type());
    int full_q = a.type().bits();
    if (a.type().is_int()) {
        full_q -= 1;
    }
    if (can_prove(q < full_q)) {
        // Try to rewrite this to a "full precision" multiply by multiplying
        // one of the operands and the denominator by a constant. We only do this
        // if it isn't already full precision.
        // This avoids infinite loops despite "lowering" this to another
        // mul_shift_right operation.
        Expr missing_q = full_q - q;
        internal_assert(missing_q.type().bits() == b.type().bits());
        Expr new_b = simplify(b << missing_q);
        if (is_const(new_b) && can_prove(new_b >> missing_q == b)) {
            return mul_shift_right(a, new_b, full_q);
        }
        Expr new_a = simplify(a << missing_q);
        if (is_const(new_a) && can_prove(new_a >> missing_q == a)) {
            return mul_shift_right(new_a, b, full_q);
        }
    }

    if (can_prove(q > a.type().bits())) {
        // If q is bigger than the narrow type, write it as an exact upper
        // half multiply, followed by an extra shift.
        Expr result = mul_shift_right(a, b, a.type().bits());
        result = result >> simplify(q - a.type().bits());
        return result;
    }

    // If all else fails, just widen, shift, and narrow.
    Expr result = widening_mul(a, b) >> q;
    if (!can_prove(q >= a.type().bits())) {
        result = saturating_narrow(result);
    } else {
        result = narrow(result);
    }
    return result;
}

Expr lower_rounding_mul_shift_right(const Expr &a, const Expr &b, const Expr &q) {
    internal_assert(a.type() == b.type());
    int full_q = a.type().bits();
    if (a.type().is_int()) {
        full_q -= 1;
    }
    // Try to rewrite this to a "full precision" multiply by multiplying
    // one of the operands and the denominator by a constant. We only do this
    // if it isn't already full precision. This avoids infinite loops despite
    // "lowering" this to another mul_shift_right operation.
    if (can_prove(q < full_q)) {
        Expr missing_q = full_q - q;
        internal_assert(missing_q.type().bits() == b.type().bits());
        Expr new_b = simplify(b << missing_q);
        if (is_const(new_b) && can_prove(new_b >> missing_q == b)) {
            return rounding_mul_shift_right(a, new_b, full_q);
        }
        Expr new_a = simplify(a << missing_q);
        if (is_const(new_a) && can_prove(new_a >> missing_q == a)) {
            return rounding_mul_shift_right(new_a, b, full_q);
        }
    }

    // If all else fails, just widen, shift, and narrow.
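    // Worked example of the full-precision rewrite above (u8, q == 6, b == 3):
    // missing_q == 2, new_b == 3 << 2 == 12, and
    //   rounding_mul_shift_right(a, 12, 8) == (a * 12 + 128) >> 8
    //                                      == (a * 3 + 32) >> 6,
    // which is exactly rounding_mul_shift_right(a, 3, 6).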
    Expr result = rounding_shift_right(widening_mul(a, b), q);
    if (!can_prove(q >= a.type().bits())) {
        result = saturating_narrow(result);
    } else {
        result = narrow(result);
    }
    return result;
}

Expr lower_intrinsic(const Call *op) {
    if (op->is_intrinsic(Call::widen_right_add)) {
        internal_assert(op->args.size() == 2);
        return lower_widen_right_add(op->args[0], op->args[1]);
    } else if (op->is_intrinsic(Call::widen_right_mul)) {
        internal_assert(op->args.size() == 2);
        return lower_widen_right_mul(op->args[0], op->args[1]);
    } else if (op->is_intrinsic(Call::widen_right_sub)) {
        internal_assert(op->args.size() == 2);
        return lower_widen_right_sub(op->args[0], op->args[1]);
    } else if (op->is_intrinsic(Call::widening_add)) {
        internal_assert(op->args.size() == 2);
        return lower_widening_add(op->args[0], op->args[1]);
    } else if (op->is_intrinsic(Call::widening_mul)) {
        internal_assert(op->args.size() == 2);
        return lower_widening_mul(op->args[0], op->args[1]);
    } else if (op->is_intrinsic(Call::widening_sub)) {
        internal_assert(op->args.size() == 2);
        return lower_widening_sub(op->args[0], op->args[1]);
    } else if (op->is_intrinsic(Call::saturating_add)) {
        internal_assert(op->args.size() == 2);
        return lower_saturating_add(op->args[0], op->args[1]);
    } else if (op->is_intrinsic(Call::saturating_sub)) {
        internal_assert(op->args.size() == 2);
        return lower_saturating_sub(op->args[0], op->args[1]);
    } else if (op->is_intrinsic(Call::saturating_cast)) {
        internal_assert(op->args.size() == 1);
        return lower_saturating_cast(op->type, op->args[0]);
    } else if (op->is_intrinsic(Call::widening_shift_left)) {
        internal_assert(op->args.size() == 2);
        return lower_widening_shift_left(op->args[0], op->args[1]);
    } else if (op->is_intrinsic(Call::widening_shift_right)) {
        internal_assert(op->args.size() == 2);
        return lower_widening_shift_right(op->args[0], op->args[1]);
    } else if (op->is_intrinsic(Call::rounding_shift_right)) {
        internal_assert(op->args.size() == 2);
        return lower_rounding_shift_right(op->args[0], op->args[1]);
    } else if (op->is_intrinsic(Call::rounding_shift_left)) {
        internal_assert(op->args.size() == 2);
        return lower_rounding_shift_left(op->args[0], op->args[1]);
    } else if (op->is_intrinsic(Call::halving_add)) {
        internal_assert(op->args.size() == 2);
        return lower_halving_add(op->args[0], op->args[1]);
    } else if (op->is_intrinsic(Call::halving_sub)) {
        internal_assert(op->args.size() == 2);
        return lower_halving_sub(op->args[0], op->args[1]);
    } else if (op->is_intrinsic(Call::rounding_halving_add)) {
        internal_assert(op->args.size() == 2);
        return lower_rounding_halving_add(op->args[0], op->args[1]);
    } else if (op->is_intrinsic(Call::rounding_mul_shift_right)) {
        internal_assert(op->args.size() == 3);
        return lower_rounding_mul_shift_right(op->args[0], op->args[1], op->args[2]);
    } else if (op->is_intrinsic(Call::mul_shift_right)) {
        internal_assert(op->args.size() == 3);
        return lower_mul_shift_right(op->args[0], op->args[1], op->args[2]);
    } else if (op->is_intrinsic(Call::sorted_avg)) {
        internal_assert(op->args.size() == 2);
        return lower_sorted_avg(op->args[0], op->args[1]);
    } else {
        return Expr();
    }
}

namespace {

class LowerIntrinsics : public IRMutator {
    using IRMutator::visit;

    Expr visit(const Call *op) override {
        Expr lowered = lower_intrinsic(op);
        if (lowered.defined()) {
            return mutate(lowered);
        }
        return IRMutator::visit(op);
    }
};

}  // namespace

Expr lower_intrinsics(const Expr &e) {
    return LowerIntrinsics().mutate(e);
}

Stmt lower_intrinsics(const Stmt &s) {
    return LowerIntrinsics().mutate(s);
}

}  // namespace Internal
}  // namespace Halide