#include "CSE.h" #include "CodeGen_Internal.h" #include "ConciseCasts.h" #include "FindIntrinsics.h" #include "IRMatch.h" #include "IRMutator.h" #include "Simplify.h" namespace Halide { namespace Internal { namespace { /** * Distribute constant RHS widening shift lefts as multiplies. * This is an extremely unfortunate mess. Unfortunately, the * simplifier needs to lift constant multiplications due to its * cost model. This transformation is very architecture and data- * type specific (e.g. useful on ARM and HVX due to a plethora of * dot product / widening multiply instructions). */ class DistributeShiftsAsMuls : public IRMutator { public: DistributeShiftsAsMuls(const bool multiply_adds) : multiply_adds(multiply_adds) { } private: const bool multiply_adds; static bool is_cast(const Expr &e, Type value_t) { if (const Cast *cast = e.as()) { return cast->value.type() == value_t; } return false; } static Expr distribute(const Expr &a, const Expr &b) { if (const Add *add = a.as()) { return Add::make(distribute(add->a, b), distribute(add->b, b)); } else if (const Sub *sub = a.as()) { Expr sub_a = distribute(sub->a, b); Expr sub_b = distribute(sub->b, b); Expr negative_sub_b = lossless_negate(sub_b); if (negative_sub_b.defined()) { return Add::make(sub_a, negative_sub_b); } else { return Sub::make(sub_a, sub_b); } } else if (const Cast *cast = a.as()) { Expr cast_b = lossless_cast(b.type().with_bits(cast->value.type().bits()), b); if (cast_b.defined()) { Expr mul = widening_mul(cast->value, cast_b); if (mul.type().bits() <= cast->type.bits()) { if (mul.type() != cast->type) { mul = Cast::make(cast->type, mul); } return mul; } } } else if (const Call *add = Call::as_intrinsic(a, {Call::widening_add})) { Expr add_a = Cast::make(add->type, add->args[0]); Expr add_b = Cast::make(add->type, add->args[1]); add_a = distribute(add_a, b); add_b = distribute(add_b, b); // If add_a and add_b are the same kind of cast, we should remake a widening add. const Cast *add_a_cast = add_a.as(); const Cast *add_b_cast = add_b.as(); if (add_a_cast && add_b_cast && add_a_cast->value.type() == add->args[0].type() && add_b_cast->value.type() == add->args[1].type()) { return widening_add(add_a_cast->value, add_b_cast->value); } else { return Add::make(add_a, add_b); } } else if (const Call *sub = Call::as_intrinsic(a, {Call::widening_sub})) { Expr sub_a = Cast::make(sub->type, sub->args[0]); Expr sub_b = Cast::make(sub->type, sub->args[1]); sub_a = distribute(sub_a, b); sub_b = distribute(sub_b, b); Expr negative_sub_b = lossless_negate(sub_b); if (negative_sub_b.defined()) { sub_b = negative_sub_b; } // If sub_a and sub_b are the same kind of cast, we should remake a widening sub. 
            const Cast *sub_a_cast = sub_a.as<Cast>();
            const Cast *sub_b_cast = sub_b.as<Cast>();
            if (sub_a_cast && sub_b_cast &&
                sub_a_cast->value.type() == sub->args[0].type() &&
                sub_b_cast->value.type() == sub->args[1].type()) {
                if (negative_sub_b.defined()) {
                    return widening_add(sub_a_cast->value, sub_b_cast->value);
                } else {
                    return widening_sub(sub_a_cast->value, sub_b_cast->value);
                }
            } else {
                if (negative_sub_b.defined()) {
                    return Add::make(sub_a, sub_b);
                } else {
                    return Sub::make(sub_a, sub_b);
                }
            }
        } else if (const Call *mul = Call::as_intrinsic(a, {Call::widening_mul})) {
            // Try to fold the multiplicand into either operand of the widening multiply.
            Expr mul_a = Cast::make(mul->type, mul->args[0]);
            Expr mul_b = Cast::make(mul->type, mul->args[1]);
            mul_a = distribute(mul_a, b);
            if (const Cast *mul_a_cast = mul_a.as<Cast>()) {
                if (mul_a_cast->value.type() == mul->args[0].type()) {
                    return widening_mul(mul_a_cast->value, mul->args[1]);
                }
            }
            mul_b = distribute(mul_b, b);
            if (const Cast *mul_b_cast = mul_b.as<Cast>()) {
                if (mul_b_cast->value.type() == mul->args[1].type()) {
                    return widening_mul(mul->args[0], mul_b_cast->value);
                }
            }
        }
        return simplify(Mul::make(a, b));
    }

    Expr distribute_shift(const Call *op) {
        if (op->is_intrinsic(Call::shift_left)) {
            if (const uint64_t *const_b = as_const_uint(op->args[1])) {
                Expr a = op->args[0];
                // Only rewrite widening shifts.
                const Cast *cast_a = a.as<Cast>();
                bool is_widening_cast = cast_a && cast_a->type.bits() >= cast_a->value.type().bits() * 2;
                if (is_widening_cast ||
                    Call::as_intrinsic(a, {Call::widening_add, Call::widening_mul, Call::widening_sub})) {
                    // Rewrite (a << b) as (a * (1 << b)) and distribute the multiply.
                    const uint64_t const_m = 1ull << *const_b;
                    Expr b = make_const(a.type(), const_m);
                    return mutate(distribute(a, b));
                }
            }
        } else if (op->is_intrinsic(Call::widening_shift_left)) {
            if (const uint64_t *const_b = as_const_uint(op->args[1])) {
                const uint64_t const_m = 1ull << *const_b;
                Expr b = make_const(op->type, const_m);
                Expr a = Cast::make(op->type, op->args[0]);
                return mutate(distribute(a, b));
            }
        }
        return IRMutator::visit(op);
    }

    template<typename T>
    Expr visit_add_sub(const T *op) {
        if (multiply_adds) {
            Expr a, b;
            if (const Call *a_call = op->a.template as<Call>()) {
                if (a_call->is_intrinsic({Call::shift_left, Call::widening_shift_left})) {
                    a = distribute_shift(a_call);
                }
            }
            if (const Call *b_call = op->b.template as<Call>()) {
                if (b_call->is_intrinsic({Call::shift_left, Call::widening_shift_left})) {
                    b = distribute_shift(b_call);
                }
            }
            if (a.defined() && b.defined()) {
                return T::make(a, b);
            } else if (a.defined()) {
                b = mutate(op->b);
                return T::make(a, b);
            } else if (b.defined()) {
                a = mutate(op->a);
                return T::make(a, b);
            } else {
                return IRMutator::visit(op);
            }
        } else {
            return IRMutator::visit(op);
        }
    }

    using IRMutator::visit;

    Expr visit(const Call *op) override {
        if (multiply_adds) {
            // When multiply_adds is set, shifts are only distributed where
            // they feed an Add or Sub (see visit_add_sub).
            return IRMutator::visit(op);
        } else {
            return distribute_shift(op);
        }
    }

    Expr visit(const Add *op) override {
        return visit_add_sub(op);
    }

    Expr visit(const Sub *op) override {
        return visit_add_sub(op);
    }
};

}  // namespace

Stmt distribute_shifts(const Stmt &s, const bool multiply_adds) {
    return DistributeShiftsAsMuls(multiply_adds).mutate(s);
}

}  // namespace Internal
}  // namespace Halide
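// A minimal usage sketch (the statement `s` and the call site are
// assumptions for illustration, not part of this file):
//
//   // After lowering, turn constant widening shift-lefts back into
//   // multiplies, but only where they feed adds/subs, so that backends
//   // can pattern-match multiply-accumulate / dot-product instructions:
//   Stmt distributed = distribute_shifts(s, /*multiply_adds=*/true);
//
//   // Or distribute every constant widening shift unconditionally:
//   Stmt all = distribute_shifts(s, /*multiply_adds=*/false);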