#include "HexagonOptimize.h"
#include "Bounds.h"
#include "CSE.h"
#include "ConciseCasts.h"
#include "ExprUsesVar.h"
#include "HexagonAlignment.h"
#include "IREquality.h"
#include "IRMatch.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "Lerp.h"
#include "Scope.h"
#include "Simplify.h"
#include "Substitute.h"
#include <algorithm>
#include <map>
#include <set>
#include <unordered_map>
#include <utility>
namespace Halide {
namespace Internal {
using std::pair;
using std::set;
using std::string;
using std::vector;
using namespace Halide::ConciseCasts;
Expr native_interleave(const Expr &x) {
string fn;
switch (x.type().bits()) {
case 8:
fn = "halide.hexagon.interleave.vb";
break;
case 16:
fn = "halide.hexagon.interleave.vh";
break;
case 32:
fn = "halide.hexagon.interleave.vw";
break;
default:
internal_error << "Cannot interleave native vectors of type " << x.type() << "\n";
}
return Call::make(x.type(), fn, {x}, Call::PureExtern);
}
Expr native_deinterleave(const Expr &x) {
string fn;
switch (x.type().bits()) {
case 8:
fn = "halide.hexagon.deinterleave.vb";
break;
case 16:
fn = "halide.hexagon.deinterleave.vh";
break;
case 32:
fn = "halide.hexagon.deinterleave.vw";
break;
default:
internal_error << "Cannot deinterleave native vectors of type " << x.type() << "\n";
}
return Call::make(x.type(), fn, {x}, Call::PureExtern);
}
bool is_native_interleave_op(const Expr &x, const char *name) {
const Call *c = x.as<Call>();
if (!c || c->args.size() != 1) return false;
return starts_with(c->name, name);
}
bool is_native_interleave(const Expr &x) {
return is_native_interleave_op(x, "halide.hexagon.interleave");
}
bool is_native_deinterleave(const Expr &x) {
return is_native_interleave_op(x, "halide.hexagon.deinterleave");
}
namespace {
// Broadcast to an unknown number of lanes, for making patterns.
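// For example, bc(wild_u16) matches a Broadcast of any u16 scalar,
// whatever the lane count of the expression being matched.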
Expr bc(Expr x) {
return Broadcast::make(std::move(x), 0);
}
// This mutator rewrites patterns with an unknown number of lanes to
// have the specified number of lanes.
class WithLanes : public IRMutator {
using IRMutator::visit;
int lanes;
Type with_lanes(Type t) const {
return t.with_lanes(lanes);
}
Expr visit(const Cast *op) override {
if (op->type.lanes() != lanes) {
return Cast::make(with_lanes(op->type), mutate(op->value));
} else {
return IRMutator::visit(op);
}
}
Expr visit(const Variable *op) override {
if (op->type.lanes() != lanes) {
return Variable::make(with_lanes(op->type), op->name);
} else {
return op;
}
}
Expr visit(const Broadcast *op) override {
if (op->type.lanes() != lanes) {
return Broadcast::make(op->value, lanes);
} else {
return IRMutator::visit(op);
}
}
public:
WithLanes(int lanes)
: lanes(lanes) {
}
};
Expr with_lanes(const Expr &x, int lanes) {
return WithLanes(lanes).mutate(x);
}
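// For example, with_lanes(wild_u16x * bc(wild_u16), 64) produces a copy
// of the pattern in which every vector type has exactly 64 lanes; this is
// how the lane-agnostic rewrites below (e.g. cast_rewrites) are
// concretized before substitution.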
struct Pattern {
enum Flags {
InterleaveResult = 1 << 0, // After evaluating the pattern, interleave native vectors of the result.
SwapOps01 = 1 << 1, // Swap operands 0 and 1 prior to substitution.
SwapOps12 = 1 << 2, // Swap operands 1 and 2 prior to substitution.
ExactLog2Op1 = 1 << 3, // Replace operand 1 with its log base 2, if the log base 2 is exact.
ExactLog2Op2 = 1 << 4, // Same as above, but for operand 2.
BeginExactLog2Op = 1, // BeginExactLog2Op and EndExactLog2Op ensure that we check only op1 and op2
EndExactLog2Op = 3, // for ExactLog2Op
DeinterleaveOp0 = 1 << 5, // Prior to evaluating the pattern, deinterleave native vectors of operand 0.
DeinterleaveOp1 = 1 << 6, // Same as above, but for operand 1.
DeinterleaveOp2 = 1 << 7,
DeinterleaveOps = DeinterleaveOp0 | DeinterleaveOp1 | DeinterleaveOp2,
BeginDeinterleaveOp = 0, // BeginDeinterleaveOp and EndDeinterleaveOp ensure that we check only the
EndDeinterleaveOp = 3, // three deinterleave flags DeinterleaveOp0, 1 and 2.
// Many patterns match instructions that widen only operand 0,
// which need to both deinterleave operand 0 and then
// re-interleave the result.
ReinterleaveOp0 = InterleaveResult | DeinterleaveOp0,
NarrowOp0 = 1 << 10, // Replace operand 0 with its half-width equivalent.
NarrowOp1 = 1 << 11, // Same as above, but for operand 1.
NarrowOp2 = 1 << 12,
NarrowOp3 = 1 << 13,
NarrowOps = NarrowOp0 | NarrowOp1 | NarrowOp2 | NarrowOp3,
NarrowUnsignedOp0 = 1 << 15, // Similar to the above, but narrow to an unsigned half width type.
NarrowUnsignedOp1 = 1 << 16,
NarrowUnsignedOp2 = 1 << 17,
NarrowUnsignedOps = NarrowUnsignedOp0 | NarrowUnsignedOp1 | NarrowUnsignedOp2,
v65orLater = 1 << 21, // Pattern should be matched only for v65 target or later
v66orLater = 1 << 22, // Pattern should be matched only for v66 target or later
};
string intrin; // Name of the intrinsic
Expr pattern; // The pattern to match against
int flags;
Pattern() = default;
Pattern(const string &intrin, Expr p, int flags = 0)
: intrin(intrin), pattern(std::move(p)), flags(flags) {
}
};
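// As an illustration of the encoding, the pattern
// {"halide.hexagon.mpy.vub.vub", wild_u16x * wild_u16x,
//  Pattern::InterleaveResult | Pattern::NarrowOps}
// (used in OptimizePatterns below) matches a u16 vector multiply whose
// operands can be losslessly narrowed to u8 (NarrowOps), rewrites it to a
// call of halide.hexagon.mpy.vub.vub on the narrowed operands, and then
// interleaves the result (InterleaveResult), since the underlying widening
// HVX instructions produce their results split into even/odd vector pairs.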
Expr wild_u8 = Variable::make(UInt(8), "*");
Expr wild_u16 = Variable::make(UInt(16), "*");
Expr wild_u32 = Variable::make(UInt(32), "*");
Expr wild_u64 = Variable::make(UInt(64), "*");
Expr wild_i8 = Variable::make(Int(8), "*");
Expr wild_i16 = Variable::make(Int(16), "*");
Expr wild_i32 = Variable::make(Int(32), "*");
Expr wild_i64 = Variable::make(Int(64), "*");
Expr wild_u8x = Variable::make(Type(Type::UInt, 8, 0), "*");
Expr wild_u16x = Variable::make(Type(Type::UInt, 16, 0), "*");
Expr wild_u32x = Variable::make(Type(Type::UInt, 32, 0), "*");
Expr wild_u64x = Variable::make(Type(Type::UInt, 64, 0), "*");
Expr wild_i8x = Variable::make(Type(Type::Int, 8, 0), "*");
Expr wild_i16x = Variable::make(Type(Type::Int, 16, 0), "*");
Expr wild_i32x = Variable::make(Type(Type::Int, 32, 0), "*");
Expr wild_i64x = Variable::make(Type(Type::Int, 64, 0), "*");
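// A wildcard vector type with zero lanes (the *x wildcards above) matches
// vectors with any number of lanes; the scalar wildcards match only
// scalars, and typically appear inside bc() in the patterns below.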
// Check if a pattern with flags 'flags' is supported on the target.
bool check_pattern_target(int flags, const Target &target) {
if ((flags & (Pattern::v65orLater)) &&
!target.features_any_of({Target::HVX_v65, Target::HVX_v66})) {
return false;
}
if ((flags & (Pattern::v66orLater)) &&
!target.features_any_of({Target::HVX_v66})) {
return false;
}
return true;
}
// Check if the matches satisfy the given pattern flags, and mutate the matches
// as specified by the flags.
bool process_match_flags(vector<Expr> &matches, int flags) {
// The Pattern::Narrow*Op* flags are ordered such that the operand
// corresponds to the bit (with operand 0 corresponding to the least
// significant bit), so we can check for them all in a loop.
for (size_t i = 0; i < matches.size(); i++) {
Type t = matches[i].type();
Type target_t = t.with_bits(t.bits() / 2);
if (flags & (Pattern::NarrowOp0 << i)) {
matches[i] = lossless_cast(target_t, matches[i]);
} else if (flags & (Pattern::NarrowUnsignedOp0 << i)) {
matches[i] = lossless_cast(target_t.with_code(Type::UInt), matches[i]);
}
if (!matches[i].defined()) return false;
}
for (size_t i = Pattern::BeginExactLog2Op; i < Pattern::EndExactLog2Op; i++) {
// This flag is mainly to capture shifts. When the operand of a div or
// mul is a power of 2, we can use a shift instead.
if (flags & (Pattern::ExactLog2Op1 << (i - Pattern::BeginExactLog2Op))) {
int pow;
if (is_const_power_of_two_integer(matches[i], &pow)) {
matches[i] = cast(matches[i].type().with_lanes(1), pow);
} else {
return false;
}
}
}
for (size_t i = Pattern::BeginDeinterleaveOp; i < Pattern::EndDeinterleaveOp; i++) {
if (flags & (Pattern::DeinterleaveOp0 << (i - Pattern::BeginDeinterleaveOp))) {
internal_assert(matches[i].type().is_vector());
matches[i] = native_deinterleave(matches[i]);
}
}
if (flags & Pattern::SwapOps01) {
internal_assert(matches.size() >= 2);
std::swap(matches[0], matches[1]);
}
if (flags & Pattern::SwapOps12) {
internal_assert(matches.size() >= 3);
std::swap(matches[1], matches[2]);
}
return true;
}
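// A worked example: the pattern
// {"halide.hexagon.shl.vh.h", wild_i16x * bc(wild_i16), Pattern::ExactLog2Op1}
// applied to x * 4 (x an i16 vector) initially yields matches = {x, 4}.
// ExactLog2Op1 replaces the 4 with its log base 2, giving matches = {x, 2},
// so the expression is rewritten to halide.hexagon.shl.vh.h(x, 2).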
// Replace an expression with the one specified by a pattern.
Expr replace_pattern(Expr x, const vector<Expr> &matches, const Pattern &p) {
x = Call::make(x.type(), p.intrin, matches, Call::PureExtern);
if (p.flags & Pattern::InterleaveResult) {
// The pattern wants us to interleave the result.
x = native_interleave(x);
}
return x;
}
bool is_double_vector(const Expr &x, const Target &target) {
int native_vector_lanes = target.natural_vector_size(x.type());
return x.type().lanes() % (2 * native_vector_lanes) == 0;
}
// Attempt to apply one of the patterns to x. If a match is
// successful, the expression is replaced with a call using the
// matched operands. Prior to substitution, the matches are mutated
// with op_mutator.
Expr apply_patterns(Expr x, const vector<Pattern> &patterns, const Target &target, IRMutator *op_mutator) {
debug(3) << "apply_patterns " << x << "\n";
vector<Expr> matches;
for (const Pattern &p : patterns) {
if (!check_pattern_target(p.flags, target)) {
continue;
}
if (expr_match(p.pattern, x, matches)) {
debug(3) << "matched " << p.pattern << "\n";
debug(3) << "matches:\n";
for (Expr i : matches) {
debug(3) << i << "\n";
}
if (!process_match_flags(matches, p.flags)) {
continue;
}
// Don't apply the pattern if it involves an interleave
// and the result is not a multiple of two native vectors.
// See https://github.com/halide/Halide/issues/1582
if ((p.flags & Pattern::InterleaveResult) && !is_double_vector(x, target)) {
continue;
}
// Mutate the operands with the given mutator.
for (Expr &op : matches) {
op = op_mutator->mutate(op);
}
x = replace_pattern(x, matches, p);
debug(3) << "rewrote to: " << x << "\n";
return x;
}
}
return x;
}
// Return a negated version of x if the negation can be done without
// overflow; otherwise return an undefined Expr.
Expr lossless_negate(const Expr &x) {
const Mul *m = x.as<Mul>();
if (m) {
Expr a = lossless_negate(m->a);
if (a.defined()) {
return Mul::make(a, m->b);
}
Expr b = lossless_negate(m->b);
if (b.defined()) {
return Mul::make(m->a, b);
}
}
if (is_negative_negatable_const(x) || is_positive_const(x)) {
return simplify(-x);
}
return Expr();
}
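// For example, lossless_negate(x * 3) returns x * -3 and lossless_negate(7)
// returns -7, but lossless_negate(x) for a non-constant x returns an
// undefined Expr, because the negation might overflow.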
template<typename T>
Expr apply_commutative_patterns(const T *op, const vector<Pattern> &patterns, const Target &target, IRMutator *mutator) {
Expr ret = apply_patterns(op, patterns, target, mutator);
if (!ret.same_as(op)) return ret;
// Try commuting the op
Expr commuted = T::make(op->b, op->a);
ret = apply_patterns(commuted, patterns, target, mutator);
if (!ret.same_as(commuted)) return ret;
return op;
}
typedef pair<Expr, Expr> MulExpr;
// If ty is scalar, and x is a vector, try to remove a broadcast
// from x prior to using lossless_cast on it.
Expr unbroadcast_lossless_cast(Type ty, Expr x) {
if (ty.lanes() == 1 && x.type().lanes() > 1) {
if (const Broadcast *bc = x.as<Broadcast>()) {
x = bc->value;
}
}
if (ty.lanes() != x.type().lanes()) {
return Expr();
}
return lossless_cast(ty, x);
}
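// For example, if s is an i8 scalar, then for x = broadcast(i16(s), n) and
// ty = Int(8), the broadcast is stripped and lossless_cast(Int(8), i16(s))
// recovers s; without removing the broadcast, the lane counts would not
// match and the cast would fail.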
// Try to extract a list of multiplies of the form a_ty*b_ty added
// together, such that op is equivalent to the sum of the
// multiplies in 'mpys', added to 'rest'.
// The return value is the number of multiplies found; the difference
// between mpys.size() and the return value is the number of
// expressions we pretended were multiplied by 1.
int find_mpy_ops(const Expr &op, Type a_ty, Type b_ty, int max_mpy_count,
vector<MulExpr> &mpys, Expr &rest) {
if ((int)mpys.size() >= max_mpy_count) {
rest = rest.defined() ? Add::make(rest, op) : op;
return 0;
}
// If the add is also widening, remove the cast.
int mpy_bits = std::max(a_ty.bits(), b_ty.bits()) * 2;
Expr maybe_mul = op;
if (op.type().bits() == mpy_bits * 2) {
if (const Cast *cast = op.as<Cast>()) {
if (cast->value.type().bits() == mpy_bits) {
maybe_mul = cast->value;
}
}
}
if (const Mul *mul = maybe_mul.as<Mul>()) {
Expr a = unbroadcast_lossless_cast(a_ty, mul->a);
Expr b = unbroadcast_lossless_cast(b_ty, mul->b);
if (a.defined() && b.defined()) {
mpys.emplace_back(a, b);
return 1;
} else {
// Try to commute the op.
a = unbroadcast_lossless_cast(a_ty, mul->b);
b = unbroadcast_lossless_cast(b_ty, mul->a);
if (a.defined() && b.defined()) {
mpys.emplace_back(a, b);
return 1;
}
}
} else if (const Add *add = op.as<Add>()) {
int mpy_count = 0;
mpy_count += find_mpy_ops(add->a, a_ty, b_ty, max_mpy_count, mpys, rest);
mpy_count += find_mpy_ops(add->b, a_ty, b_ty, max_mpy_count, mpys, rest);
return mpy_count;
} else if (const Sub *sub = op.as<Sub>()) {
// Try to rewrite subs as adds.
if (const Mul *mul_b = sub->b.as<Mul>()) {
if (is_positive_const(mul_b->a) || is_negative_negatable_const(mul_b->a)) {
Expr add_b = Mul::make(simplify(-mul_b->a), mul_b->b);
int mpy_count = 0;
mpy_count += find_mpy_ops(sub->a, a_ty, b_ty, max_mpy_count, mpys, rest);
mpy_count += find_mpy_ops(add_b, a_ty, b_ty, max_mpy_count, mpys, rest);
return mpy_count;
} else if (is_positive_const(mul_b->b) || is_negative_negatable_const(mul_b->b)) {
Expr add_b = Mul::make(mul_b->a, simplify(-mul_b->b));
int mpy_count = 0;
mpy_count += find_mpy_ops(sub->a, a_ty, b_ty, max_mpy_count, mpys, rest);
mpy_count += find_mpy_ops(add_b, a_ty, b_ty, max_mpy_count, mpys, rest);
return mpy_count;
}
}
}
// Attempt to pretend this op is multiplied by 1.
Expr as_a = unbroadcast_lossless_cast(a_ty, op);
Expr as_b = unbroadcast_lossless_cast(b_ty, op);
if (as_a.defined()) {
mpys.emplace_back(as_a, make_one(b_ty));
} else if (as_b.defined()) {
mpys.emplace_back(make_one(a_ty), as_b);
} else {
rest = rest.defined() ? Add::make(rest, op) : op;
}
return 0;
}
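// A worked example: with a_ty = UInt(8, lanes) and b_ty = UInt(8, lanes),
// the expression u16(a) * u16(b) + u16(c) produces mpys = {(a, b), (c, 1)}
// with 'rest' left undefined, and returns 1: one real multiply was found,
// and c was treated as c * 1.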
// Perform peephole optimizations on the IR, adding appropriate
// interleave and deinterleave calls.
class OptimizePatterns : public IRMutator {
private:
using IRMutator::visit;
Target target;
Expr visit(const Mul *op) override {
static const vector<Pattern> scalar_muls = {
// Vector by scalar widening multiplies.
{"halide.hexagon.mpy.vub.ub", wild_u16x * bc(wild_u16), Pattern::InterleaveResult | Pattern::NarrowOps},
{"halide.hexagon.mpy.vub.b", wild_i16x * bc(wild_i16), Pattern::InterleaveResult | Pattern::NarrowUnsignedOp0 | Pattern::NarrowOp1},
{"halide.hexagon.mpy.vuh.uh", wild_u32x * bc(wild_u32), Pattern::InterleaveResult | Pattern::NarrowOps},
{"halide.hexagon.mpy.vh.h", wild_i32x * bc(wild_i32), Pattern::InterleaveResult | Pattern::NarrowOps},
// Multiplication by powers of 2.
{"halide.hexagon.shl.vub.b", wild_u8x * bc(wild_u8), Pattern::ExactLog2Op1},
{"halide.hexagon.shl.vuh.h", wild_u16x * bc(wild_u16), Pattern::ExactLog2Op1},
{"halide.hexagon.shl.vuw.w", wild_u32x * bc(wild_u32), Pattern::ExactLog2Op1},
{"halide.hexagon.shl.vb.b", wild_i8x * bc(wild_i8), Pattern::ExactLog2Op1},
{"halide.hexagon.shl.vh.h", wild_i16x * bc(wild_i16), Pattern::ExactLog2Op1},
{"halide.hexagon.shl.vw.w", wild_i32x * bc(wild_i32), Pattern::ExactLog2Op1},
// Non-widening scalar multiplication.
{"halide.hexagon.mul.vh.b", wild_i16x * bc(wild_i16), Pattern::NarrowOp1},
{"halide.hexagon.mul.vw.h", wild_i32x * bc(wild_i32), Pattern::NarrowOp1},
// TODO: There's also mul.vw.b. We currently generate mul.vw.h
// instead. I'm not sure mul.vw.b is faster, it might even be
// slower due to the extra step in broadcasting the scalar up to
// 32 bits.
};
static const vector<Pattern> muls = {
// Widening multiplication
{"halide.hexagon.mpy.vub.vub", wild_u16x * wild_u16x, Pattern::InterleaveResult | Pattern::NarrowOps},
{"halide.hexagon.mpy.vuh.vuh", wild_u32x * wild_u32x, Pattern::InterleaveResult | Pattern::NarrowOps},
{"halide.hexagon.mpy.vb.vb", wild_i16x * wild_i16x, Pattern::InterleaveResult | Pattern::NarrowOps},
{"halide.hexagon.mpy.vh.vh", wild_i32x * wild_i32x, Pattern::InterleaveResult | Pattern::NarrowOps},
{"halide.hexagon.mpy.vub.vb", wild_i16x * wild_i16x, Pattern::InterleaveResult | Pattern::NarrowUnsignedOp0 | Pattern::NarrowOp1},
{"halide.hexagon.mpy.vh.vuh", wild_i32x * wild_i32x, Pattern::InterleaveResult | Pattern::NarrowOp0 | Pattern::NarrowUnsignedOp1},
// We need to check for the commuted versions of these patterns
// before the more general patterns below catch these ops. The
// other fix for this would be to break this into a third group of
// multiply patterns, so the commuted versions of these would get
// matched first.
{"halide.hexagon.mpy.vub.vb", wild_i16x * wild_i16x, Pattern::InterleaveResult | Pattern::NarrowOp0 | Pattern::NarrowUnsignedOp1 | Pattern::SwapOps01},
{"halide.hexagon.mpy.vh.vuh", wild_i32x * wild_i32x, Pattern::InterleaveResult | Pattern::NarrowUnsignedOp0 | Pattern::NarrowOp1 | Pattern::SwapOps01},
// One operand widening multiplication.
{"halide.hexagon.mul.vw.vh", wild_i32x * wild_i32x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1},
{"halide.hexagon.mul.vw.vuh", wild_i32x * wild_i32x, Pattern::ReinterleaveOp0 | Pattern::NarrowUnsignedOp1},
{"halide.hexagon.mul.vuw.vuh", wild_u32x * wild_u32x, Pattern::ReinterleaveOp0 | Pattern::NarrowUnsignedOp1},
};
if (op->type.is_vector()) {
Expr new_expr = apply_commutative_patterns(op, scalar_muls, target, this);
if (!new_expr.same_as(op)) {
return new_expr;
}
new_expr = apply_commutative_patterns(op, muls, target, this);
if (!new_expr.same_as(op)) {
return new_expr;
}
}
return IRMutator::visit(op);
}
// Helpers to generate horizontally reducing multiply operations.
static Expr halide_hexagon_add_2mpy(Type result_type, const string &suffix, Expr v0, Expr v1, Expr c0, Expr c1) {
Expr call = Call::make(result_type, "halide.hexagon.add_2mpy" + suffix, {std::move(v0), std::move(v1), std::move(c0), std::move(c1)}, Call::PureExtern);
return native_interleave(call);
}
static Expr halide_hexagon_add_2mpy(Type result_type, const string &suffix, Expr v01, Expr c01) {
return Call::make(result_type, "halide.hexagon.add_2mpy" + suffix, {std::move(v01), std::move(c01)}, Call::PureExtern);
}
static Expr halide_hexagon_add_4mpy(Type result_type, const string &suffix, Expr v01, Expr c01) {
return Call::make(result_type, "halide.hexagon.add_4mpy" + suffix, {std::move(v01), std::move(c01)}, Call::PureExtern);
}
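// The suffixes name the operand types using the convention used throughout
// this file: a leading 'v' marks a vector operand, and ub/b/uh/h/uw/w denote
// unsigned/signed 8-, 16- and 32-bit elements. For example, ".vh.b" takes a
// vector of 16-bit elements and scalar 8-bit coefficients.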
// We'll try to sort the mpys based on each mpy's first expression.
// For this to work, the first expressions must either all be loads
// or all be slice_vector shuffles.
static void sort_mpy_exprs(vector<MulExpr> &mpys) {
struct LoadCompare {
bool operator()(const MulExpr &m1, const MulExpr &m2) const {
if (!m1.first.defined() || !m2.first.defined()) {
return false;
}
const Load *m1_load = m1.first.as<Load>();
const Load *m2_load = m2.first.as<Load>();
internal_assert(m1_load && m2_load);
const Ramp *m1_ramp = m1_load->index.as<Ramp>();
const Ramp *m2_ramp = m2_load->index.as<Ramp>();
internal_assert(m1_ramp && m2_ramp);
return can_prove(m1_ramp->base < m2_ramp->base);
}
};
const Shuffle *first_shuffle = mpys[0].first.as<Shuffle>();
if (first_shuffle) {
for (MulExpr &m : mpys) {
const Shuffle *shuffle = m.first.as<Shuffle>();
if (!shuffle || !shuffle->is_slice()) {
return;
}
}
std::stable_sort(mpys.begin(), mpys.end(),
[](const MulExpr &m1, const MulExpr &m2) {
return m1.first.as<Shuffle>()->slice_begin() < m2.first.as<Shuffle>()->slice_begin();
});
return;
} else if (const Load *first_load = mpys[0].first.as<Load>()) {
const Ramp *first_ramp = first_load->index.as<Ramp>();
if (!first_ramp) {
return;
}
for (MulExpr &m : mpys) {
const Load *load = m.first.as<Load>();
if (!load ||
load->name != first_load->name ||
!load->index.as<Ramp>()) {
return;
}
}
std::stable_sort(mpys.begin(), mpys.end(), LoadCompare());
}
}
Expr visit(const Add *op) override {
// vmpa, vdmpy, and vrmpy instructions are hard to match with
// patterns, so we match them manually here.
// Try to find vrmpy opportunities first, which consume 4 operands.
if (op->type.is_vector() && (op->type.bits() == 16 || op->type.bits() == 32)) {
int lanes = op->type.lanes();
vector<MulExpr> mpys;
Expr rest;
string suffix;
int mpy_count = 0;
// Try to find a vector*scalar multiply first, which will
// match a subset of the expressions that vector*vector
// matches.
if (op->type.is_uint()) {
mpy_count = find_mpy_ops(op, UInt(8, lanes), UInt(8), 4, mpys, rest);
suffix = ".vub.ub";
} else {
mpy_count = find_mpy_ops(op, UInt(8, lanes), Int(8), 4, mpys, rest);
suffix = ".vub.b";
}
if (mpy_count > 0 && mpys.size() == 4) {
// It's possible that permuting the order of the
// multiply operands can simplify the shuffle away.
// So, give yourself a fighting chance by ordering the
// mpys in the ascending order of their start lanes (if all
// are slice_vectors) or in the ascending order of their
// load indices if all are loads from the same buffer.
sort_mpy_exprs(mpys);
Expr a0123 = Shuffle::make_interleave({mpys[0].first, mpys[1].first, mpys[2].first, mpys[3].first});
a0123 = simplify(a0123);
// We can generate this op for 16 bits, but it's only
// faster to do so if the interleave simplifies away.
if (op->type.bits() == 32 || !a0123.as<Shuffle>()) {
Expr b0123 = Shuffle::make_interleave({mpys[0].second, mpys[1].second, mpys[2].second, mpys[3].second});
b0123 = simplify(b0123);
b0123 = reinterpret(Type(b0123.type().code(), 32, 1), b0123);
Expr new_expr = halide_hexagon_add_4mpy(op->type, suffix, a0123, b0123);
if (op->type.bits() == 16) {
// It's actually safe to use this op on 16 bit
// results, we just need to narrow the
// result. Overflow can occur, but will still
// produce the same result thanks to 2's
// complement arithmetic.
new_expr = Call::make(op->type, "halide.hexagon.pack.vw", {new_expr}, Call::PureExtern);
}
if (rest.defined()) {
new_expr = Add::make(new_expr, rest);
}
return mutate(new_expr);
}
}
// Now try to match vector*vector vrmpy expressions.
mpys.clear();
rest = Expr();
if (op->type.is_uint()) {
mpy_count = find_mpy_ops(op, UInt(8, lanes), UInt(8, lanes), 4, mpys, rest);
suffix = ".vub.vub";
} else {
mpy_count = find_mpy_ops(op, Int(8, lanes), Int(8, lanes), 4, mpys, rest);
suffix = ".vb.vb";
}
// TODO: suffix = ".vub.vb"
if (mpy_count > 0 && mpys.size() == 4) {
// It's possible that permuting the order of the
// multiply operands can simplify the shuffle away.
// So, give yourself a fighting chance by ordering the
// mpys in the ascending order of their start lanes (if all
// are slice_vectors) or in the ascending order of their
// load indices if all are loads from the same buffer.
sort_mpy_exprs(mpys);
Expr a0123 = Shuffle::make_interleave({mpys[0].first, mpys[1].first, mpys[2].first, mpys[3].first});
Expr b0123 = Shuffle::make_interleave({mpys[0].second, mpys[1].second, mpys[2].second, mpys[3].second});
a0123 = simplify(a0123);
b0123 = simplify(b0123);
// We can generate this op for 16 bits, but it's only
// faster to do so if the interleave simplifies away.
if (op->type.bits() == 32 || (!a0123.as<Shuffle>() && !b0123.as<Shuffle>())) {
Expr new_expr = halide_hexagon_add_4mpy(op->type, suffix, a0123, b0123);
if (op->type.bits() == 16) {
// It's actually safe to use this op on 16 bit
// results, we just need to narrow the
// result. Overflow can occur, but will still
// produce the same result thanks to 2's
// complement arithmetic.
new_expr = Call::make(op->type, "halide.hexagon.pack.vw", {new_expr}, Call::PureExtern);
}
if (rest.defined()) {
new_expr = Add::make(new_expr, rest);
}
return mutate(new_expr);
}
}
}
// Find opportunities to use vdmpy or vmpa.
if (op->type.is_vector() && (op->type.bits() == 16 || op->type.bits() == 32)) {
int lanes = op->type.lanes();
vector<MulExpr> mpys;
Expr rest;
string vmpa_suffix;
string vdmpy_suffix;
int mpy_count = 0;
// Try to find vector*scalar multiplies.
if (op->type.bits() == 16) {
mpy_count = find_mpy_ops(op, UInt(8, lanes), Int(8), 2, mpys, rest);
vmpa_suffix = ".vub.vub.b.b";
vdmpy_suffix = ".vub.b";
} else if (op->type.bits() == 32) {
mpy_count = find_mpy_ops(op, Int(16, lanes), Int(8), 2, mpys, rest);
vmpa_suffix = ".vh.vh.b.b";
vdmpy_suffix = ".vh.b";
}
if (mpy_count > 0 && mpys.size() == 2) {
// It's possible that permuting the order of the
// multiply operands can simplify the shuffle away.
// So, give yourself a fighting chance by ordering the
// mpys in the ascending order of their start lanes (if all
// are slice_vectors) or in the ascending order of their
// load indices if all are loads from the same buffer.
sort_mpy_exprs(mpys);
Expr a01 = Shuffle::make_interleave({mpys[0].first, mpys[1].first});
a01 = simplify(a01);
// TODO: This requires the operands to be in a
// particular order. It should be more robust... but
// this is pretty tough to do, other than simply
// trying all permutations.
Expr new_expr;
if (!a01.as<Shuffle>() || vmpa_suffix.empty()) {
Expr b01 = Shuffle::make_interleave({mpys[0].second, mpys[1].second});
b01 = simplify(b01);
b01 = reinterpret(Type(b01.type().code(), 16, 1), b01);
new_expr = halide_hexagon_add_2mpy(op->type, vdmpy_suffix, a01, b01);
} else {
new_expr = halide_hexagon_add_2mpy(op->type, vmpa_suffix, mpys[0].first, mpys[1].first, mpys[0].second, mpys[1].second);
}
if (rest.defined()) {
new_expr = Add::make(new_expr, rest);
}
return mutate(new_expr);
}
}
static const vector<Pattern> adds = {
// Use accumulating versions of vmpa, vdmpy, vrmpy instructions when possible.
{"halide.hexagon.acc_add_2mpy.vh.vub.vub.b.b", wild_i16x + halide_hexagon_add_2mpy(Int(16, 0), ".vub.vub.b.b", wild_u8x, wild_u8x, wild_i8, wild_i8), Pattern::ReinterleaveOp0},
{"halide.hexagon.acc_add_2mpy.vw.vh.vh.b.b", wild_i32x + halide_hexagon_add_2mpy(Int(32, 0), ".vh.vh.b.b", wild_i16x, wild_i16x, wild_i8, wild_i8), Pattern::ReinterleaveOp0},
{"halide.hexagon.acc_add_2mpy.vh.vub.b", wild_i16x + halide_hexagon_add_2mpy(Int(16, 0), ".vub.b", wild_u8x, wild_i16)},
{"halide.hexagon.acc_add_2mpy.vw.vh.b", wild_i32x + halide_hexagon_add_2mpy(Int(32, 0), ".vh.b", wild_i16x, wild_i16)},
{"halide.hexagon.acc_add_4mpy.vw.vub.b", wild_i32x + halide_hexagon_add_4mpy(Int(32, 0), ".vub.b", wild_u8x, wild_i32)},
{"halide.hexagon.acc_add_4mpy.vuw.vub.ub", wild_u32x + halide_hexagon_add_4mpy(UInt(32, 0), ".vub.ub", wild_u8x, wild_u32)},
{"halide.hexagon.acc_add_4mpy.vuw.vub.vub", wild_u32x + halide_hexagon_add_4mpy(UInt(32, 0), ".vub.vub", wild_u8x, wild_u8x)},
{"halide.hexagon.acc_add_4mpy.vw.vub.vb", wild_i32x + halide_hexagon_add_4mpy(Int(32, 0), ".vub.vb", wild_u8x, wild_i8x)},
{"halide.hexagon.acc_add_4mpy.vw.vb.vb", wild_i32x + halide_hexagon_add_4mpy(Int(32, 0), ".vb.vb", wild_i8x, wild_i8x)},
// Widening adds. There are other instructions that add two vub and two vuh but do not widen.
// To differentiate those from the widening ones, we encode the return type in the name here.
{"halide.hexagon.add_vuh.vub.vub", wild_u16x + wild_u16x, Pattern::InterleaveResult | Pattern::NarrowOps},
{"halide.hexagon.add_vuw.vuh.vuh", wild_u32x + wild_u32x, Pattern::InterleaveResult | Pattern::NarrowOps},
{"halide.hexagon.add_vw.vh.vh", wild_i32x + wild_i32x, Pattern::InterleaveResult | Pattern::NarrowOps},
// Widening multiply-accumulates with a scalar.
{"halide.hexagon.add_mpy.vuh.vub.ub", wild_u16x + wild_u16x * bc(wild_u16), Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2},
{"halide.hexagon.add_mpy.vh.vub.b", wild_i16x + wild_i16x * bc(wild_i16), Pattern::ReinterleaveOp0 | Pattern::NarrowUnsignedOp1 | Pattern::NarrowOp2},
{"halide.hexagon.add_mpy.vuw.vuh.uh", wild_u32x + wild_u32x * bc(wild_u32), Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2},
{"halide.hexagon.add_mpy.vuh.vub.ub", wild_u16x + bc(wild_u16) * wild_u16x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2 | Pattern::SwapOps12},
{"halide.hexagon.add_mpy.vh.vub.b", wild_i16x + bc(wild_i16) * wild_i16x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowUnsignedOp2 | Pattern::SwapOps12},
{"halide.hexagon.add_mpy.vuw.vuh.uh", wild_u32x + bc(wild_u32) * wild_u32x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2 | Pattern::SwapOps12},
// These patterns aren't exactly right because the instruction
// saturates the result. However, this is really the instruction
// that we want to use in most cases, and we can exploit the fact
// that 32 bit signed arithmetic overflow is undefined to argue
// that these patterns are not completely incorrect.
{"halide.hexagon.satw_add_mpy.vw.vh.h", wild_i32x + wild_i32x * bc(wild_i32), Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2},
{"halide.hexagon.satw_add_mpy.vw.vh.h", wild_i32x + bc(wild_i32) * wild_i32x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2 | Pattern::SwapOps12},
// Widening multiply-accumulates.
{"halide.hexagon.add_mpy.vuh.vub.vub", wild_u16x + wild_u16x * wild_u16x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2},
{"halide.hexagon.add_mpy.vuw.vuh.vuh", wild_u32x + wild_u32x * wild_u32x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2},
{"halide.hexagon.add_mpy.vh.vb.vb", wild_i16x + wild_i16x * wild_i16x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2},
{"halide.hexagon.add_mpy.vw.vh.vh", wild_i32x + wild_i32x * wild_i32x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowOp2},
{"halide.hexagon.add_mpy.vh.vub.vb", wild_i16x + wild_i16x * wild_i16x, Pattern::ReinterleaveOp0 | Pattern::NarrowUnsignedOp1 | Pattern::NarrowOp2},
{"halide.hexagon.add_mpy.vw.vh.vuh", wild_i32x + wild_i32x * wild_i32x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowUnsignedOp2},
{"halide.hexagon.add_mpy.vh.vub.vb", wild_i16x + wild_i16x * wild_i16x, Pattern::ReinterleaveOp0 | Pattern::NarrowOp1 | Pattern::NarrowUnsignedOp2 | Pattern::SwapOps12},
{"halide.hexagon.add_mpy.vw.vh.vuh", wild_i32x + wild_i32x * wild_i32x, Pattern::ReinterleaveOp0 | Pattern::NarrowUnsignedOp1 | Pattern::NarrowOp2 | Pattern::SwapOps12},
// Shift-accumulates.
{"halide.hexagon.add_shr.vw.vw.uw", wild_i32x + (wild_i32x >> bc(wild_u32))},
{"halide.hexagon.add_shl.vw.vw.uw", wild_i32x + (wild_i32x << bc(wild_u32))},
{"halide.hexagon.add_shl.vw.vw.uw", wild_u32x + (wild_u32x << bc(wild_u32))},
{"halide.hexagon.add_shr.vw.vw.uw", wild_i32x + (wild_i32x / bc(wild_i32)), Pattern::ExactLog2Op2},
{"halide.hexagon.add_shl.vw.vw.uw", wild_i32x + (wild_i32x * bc(wild_i32)), Pattern::ExactLog2Op2},
{"halide.hexagon.add_shl.vw.vw.uw", wild_u32x + (wild_u32x * bc(wild_u32)), Pattern::ExactLog2Op2},
{"halide.hexagon.add_shl.vw.vw.uw", wild_i32x + (bc(wild_i32) * wild_i32x), Pattern::ExactLog2Op1 | Pattern::SwapOps12},
{"halide.hexagon.add_shl.vw.vw.uw", wild_u32x + (bc(wild_u32) * wild_u32x), Pattern::ExactLog2Op1 | Pattern::SwapOps12},
{"halide.hexagon.add_shl.vh.vh.uh", wild_i16x + (wild_i16x << bc(wild_u16)), Pattern::v65orLater},
{"halide.hexagon.add_shl.vh.vh.uh", wild_u16x + (wild_u16x << bc(wild_u16)), Pattern::v65orLater},
{"halide.hexagon.add_shl.vh.vh.uh", wild_i16x + (bc(wild_i16) << wild_u16x), Pattern::SwapOps12 | Pattern::v65orLater},
{"halide.hexagon.add_shl.vh.vh.uh", wild_u16x + (bc(wild_u16) << wild_u16x), Pattern::SwapOps12 | Pattern::v65orLater},
{"halide.hexagon.add_shr.vh.vh.uh", wild_i16x + (wild_i16x >> bc(wild_u16)), Pattern::v65orLater},
{"halide.hexagon.add_shr.vh.vh.uh", wild_i16x + (wild_i16x / bc(wild_i16)), Pattern::ExactLog2Op2 | Pattern::v65orLater},
{"halide.hexagon.add_shl.vh.vh.uh", wild_i16x + (wild_i16x * bc(wild_i16)), Pattern::ExactLog2Op2 | Pattern::v65orLater},
{"halide.hexagon.add_shl.vh.vh.uh", wild_u16x + (wild_u16x * bc(wild_u16)), Pattern::ExactLog2Op2 | Pattern::v65orLater},
{"halide.hexagon.add_shl.vh.vh.uh", wild_i16x + (bc(wild_i16) * wild_i16x), Pattern::ExactLog2Op1 | Pattern::SwapOps12 | Pattern::v65orLater},
{"halide.hexagon.add_shl.vh.vh.uh", wild_u16x + (bc(wild_u16) * wild_u16x), Pattern::ExactLog2Op1 | Pattern::SwapOps12 | Pattern::v65orLater},
// Non-widening multiply-accumulates with a scalar.
{"halide.hexagon.add_mul.vh.vh.b", wild_i16x + wild_i16x * bc(wild_i16), Pattern::NarrowOp2},
{"halide.hexagon.add_mul.vw.vw.h", wild_i32x + wild_i32x * bc(wild_i32), Pattern::NarrowOp2},
{"halide.hexagon.add_mul.vh.vh.b", wild_i16x + bc(wild_i16) * wild_i16x, Pattern::NarrowOp1 | Pattern::SwapOps12},
{"halide.hexagon.add_mul.vw.vw.h", wild_i32x + bc(wild_i32) * wild_i32x, Pattern::NarrowOp1 | Pattern::SwapOps12},
// TODO: There's also an add_mul.vw.vw.b
// This pattern is very general, so it must come last.
{"halide.hexagon.add_mul.vh.vh.vh", wild_i16x + wild_i16x * wild_i16x},
};
if (op->type.is_vector()) {
Expr new_expr = apply_commutative_patterns(op, adds, target, this);
if (!new_expr.same_as(op)) {
return new_expr;
}
}
return IRMutator::visit(op);
}
Expr visit(const Sub *op) override {
if (op->type.is_vector()) {
// Try negating op->b, using an add pattern if successful.
Expr neg_b = lossless_negate(op->b);
if (neg_b.defined()) {
return mutate(op->a + neg_b);
} else {
static const vector<Pattern> subs = {
// Widening subtracts. There are other instructions that subtract two vub and two vuh but do not widen.
// To differentiate those from the widening ones, we encode the return type in the name here.
{"halide.hexagon.sub_vuh.vub.vub", wild_u16x - wild_u16x, Pattern::InterleaveResult | Pattern::NarrowOps},
{"halide.hexagon.sub_vh.vub.vub", wild_i16x - wild_i16x, Pattern::InterleaveResult | Pattern::NarrowUnsignedOps},
{"halide.hexagon.sub_vuw.vuh.vuh", wild_u32x - wild_u32x, Pattern::InterleaveResult | Pattern::NarrowOps},
{"halide.hexagon.sub_vw.vuh.vuh", wild_i32x - wild_i32x, Pattern::InterleaveResult | Pattern::NarrowUnsignedOps},
{"halide.hexagon.sub_vw.vh.vh", wild_i32x - wild_i32x, Pattern::InterleaveResult | Pattern::NarrowOps},
};
Expr new_expr = apply_patterns(op, subs, target, this);
if (!new_expr.same_as(op)) {
return new_expr;
}
}
}
return IRMutator::visit(op);
}
Expr visit(const Max *op) override {
Expr expr = IRMutator::visit(op);
if (op->type.is_vector()) {
// This pattern is weird (two operands must match, result
// needs 1 added) and we're unlikely to need another
// pattern for max, so just match it directly.
static const pair<string, Expr> cl[] = {
{"halide.hexagon.cls.vh", max(count_leading_zeros(wild_i16x), count_leading_zeros(~wild_i16x))},
{"halide.hexagon.cls.vw", max(count_leading_zeros(wild_i32x), count_leading_zeros(~wild_i32x))},
};
vector<Expr> matches;
for (const auto &i : cl) {
if (expr_match(i.second, expr, matches) && equal(matches[0], matches[1])) {
return Call::make(op->type, i.first, {matches[0]}, Call::PureExtern) + 1;
}
}
}
return expr;
}
Expr visit(const Cast *op) override {
// Separate these so we can do some special handling below.
static const vector<Pattern> trunc_mpy = {
// Multiply keep high half
{"halide.hexagon.trunc_mpy.vw.vw", i32((wild_i64x * wild_i64x) / wild_i64), Pattern::NarrowOps},
// Scalar multiply keep high half, with multiplication by 2.
{"halide.hexagon.trunc_satw_mpy2.vh.h", i16_sat((wild_i32x * bc(wild_i32)) / wild_i32), Pattern::NarrowOps},
{"halide.hexagon.trunc_satw_mpy2.vh.h", i16_sat((bc(wild_i32) * wild_i32x) / wild_i32), Pattern::NarrowOps | Pattern::SwapOps01},
// Scalar and vector multiply keep high half, with multiplication by 2, and rounding.
{"halide.hexagon.trunc_satw_mpy2_rnd.vh.h", i16_sat((wild_i32x * bc(wild_i32) + wild_i32) / wild_i32), Pattern::NarrowOps},
{"halide.hexagon.trunc_satw_mpy2_rnd.vh.h", i16_sat((bc(wild_i32) * wild_i32x + wild_i32) / wild_i32), Pattern::NarrowOps | Pattern::SwapOps01},
{"halide.hexagon.trunc_satw_mpy2_rnd.vh.vh", i16_sat((wild_i32x * wild_i32x + wild_i32) / wild_i32), Pattern::NarrowOps},
{"halide.hexagon.trunc_satdw_mpy2_rnd.vw.vw", i32_sat((wild_i64x * wild_i64x + wild_i64) / wild_i64), Pattern::NarrowOps},
// Vector multiply keep high half, with multiplication by 2.
{"halide.hexagon.trunc_satdw_mpy2.vw.vw", i32_sat((wild_i64x * wild_i64x) / wild_i64), Pattern::NarrowOps},
};
static const vector<Pattern> casts = {
// Averaging
{"halide.hexagon.avg.vub.vub", u8((wild_u16x + wild_u16x) / 2), Pattern::NarrowOps},
{"halide.hexagon.avg.vuh.vuh", u16((wild_u32x + wild_u32x) / 2), Pattern::NarrowOps},
{"halide.hexagon.avg.vh.vh", i16((wild_i32x + wild_i32x) / 2), Pattern::NarrowOps},
{"halide.hexagon.avg.vw.vw", i32((wild_i64x + wild_i64x) / 2), Pattern::NarrowOps},
{"halide.hexagon.avg.vb.vb", i8((wild_i16x + wild_i16x) / 2), Pattern::NarrowOps | Pattern::v65orLater},
{"halide.hexagon.avg.vuw.vuw", u32((wild_u64x + wild_u64x) / 2), Pattern::NarrowOps | Pattern::v65orLater},
{"halide.hexagon.avg_rnd.vub.vub", u8((wild_u16x + wild_u16x + 1) / 2), Pattern::NarrowOps},
{"halide.hexagon.avg_rnd.vuh.vuh", u16((wild_u32x + wild_u32x + 1) / 2), Pattern::NarrowOps},
{"halide.hexagon.avg_rnd.vh.vh", i16((wild_i32x + wild_i32x + 1) / 2), Pattern::NarrowOps},
{"halide.hexagon.avg_rnd.vw.vw", i32((wild_i64x + wild_i64x + 1) / 2), Pattern::NarrowOps},
{"halide.hexagon.navg.vub.vub", i8_sat((wild_i16x - wild_i16x) / 2), Pattern::NarrowUnsignedOps},
{"halide.hexagon.navg.vh.vh", i16_sat((wild_i32x - wild_i32x) / 2), Pattern::NarrowOps},
{"halide.hexagon.navg.vw.vw", i32_sat((wild_i64x - wild_i64x) / 2), Pattern::NarrowOps},
// vnavg.uw doesn't exist.
// Saturating add/subtract
{"halide.hexagon.satub_add.vub.vub", u8_sat(wild_u16x + wild_u16x), Pattern::NarrowOps},
{"halide.hexagon.satuh_add.vuh.vuh", u16_sat(wild_u32x + wild_u32x), Pattern::NarrowOps},
{"halide.hexagon.satuw_add.vuw.vuw", u32_sat(wild_u64x + wild_u64x), Pattern::NarrowOps},
{"halide.hexagon.sath_add.vh.vh", i16_sat(wild_i32x + wild_i32x), Pattern::NarrowOps},
{"halide.hexagon.satw_add.vw.vw", i32_sat(wild_i64x + wild_i64x), Pattern::NarrowOps},
{"halide.hexagon.satub_sub.vub.vub", u8_sat(wild_i16x - wild_i16x), Pattern::NarrowUnsignedOps},
{"halide.hexagon.satuh_sub.vuh.vuh", u16_sat(wild_i32x - wild_i32x), Pattern::NarrowUnsignedOps},
{"halide.hexagon.sath_sub.vh.vh", i16_sat(wild_i32x - wild_i32x), Pattern::NarrowOps},
{"halide.hexagon.satw_sub.vw.vw", i32_sat(wild_i64x - wild_i64x), Pattern::NarrowOps},
// Saturating narrowing casts with rounding
{"halide.hexagon.trunc_satub_rnd.vh", u8_sat((wild_i32x + 128) / 256), Pattern::DeinterleaveOp0 | Pattern::NarrowOp0},
{"halide.hexagon.trunc_satb_rnd.vh", i8_sat((wild_i32x + 128) / 256), Pattern::DeinterleaveOp0 | Pattern::NarrowOp0},
{"halide.hexagon.trunc_satuh_rnd.vw", u16_sat((wild_i64x + 32768) / 65536), Pattern::DeinterleaveOp0 | Pattern::NarrowOp0},
{"halide.hexagon.trunc_sath_rnd.vw", i16_sat((wild_i64x + 32768) / 65536), Pattern::DeinterleaveOp0 | Pattern::NarrowOp0},
// Saturating narrowing casts
{"halide.hexagon.trunc_satub_shr.vh.uh", u8_sat(wild_i16x >> wild_u16), Pattern::DeinterleaveOp0},
{"halide.hexagon.trunc_satuh_shr.vw.uw", u16_sat(wild_i32x >> wild_u32), Pattern::DeinterleaveOp0},
{"halide.hexagon.trunc_sath_shr.vw.uw", i16_sat(wild_i32x >> wild_u32), Pattern::DeinterleaveOp0},
{"halide.hexagon.trunc_satub_shr.vh.uh", u8_sat(wild_i16x / wild_i16), Pattern::DeinterleaveOp0 | Pattern::ExactLog2Op1},
{"halide.hexagon.trunc_satuh_shr.vw.uw", u16_sat(wild_i32x / wild_i32), Pattern::DeinterleaveOp0 | Pattern::ExactLog2Op1},
{"halide.hexagon.trunc_sath_shr.vw.uw", i16_sat(wild_i32x / wild_i32), Pattern::DeinterleaveOp0 | Pattern::ExactLog2Op1},
// For some of the following narrowing casts, we have the choice of
// non-interleaving or interleaving instructions. Because we don't
// know which one we prefer during pattern matching, we match the
// non-interleaving versions for now and replace them with the
// instructions that interleave later if it makes sense.
// Saturating narrowing casts. These may interleave later with trunc_sat.
{"halide.hexagon.pack_satub.vh", u8_sat(wild_i16x)},
{"halide.hexagon.pack_satuh.vw", u16_sat(wild_i32x)},
{"halide.hexagon.pack_satb.vh", i8_sat(wild_i16x)},
{"halide.hexagon.pack_sath.vw", i16_sat(wild_i32x)},
// We don't have a vpack equivalent to this one, so we match it directly.
{"halide.hexagon.trunc_satuh.vuw", u16_sat(wild_u32x), Pattern::DeinterleaveOp0},
// Narrowing casts. These may interleave later with trunclo.
{"halide.hexagon.packhi.vh", u8(wild_u16x / 256)},
{"halide.hexagon.packhi.vh", u8(wild_i16x / 256)},
{"halide.hexagon.packhi.vh", i8(wild_u16x / 256)},
{"halide.hexagon.packhi.vh", i8(wild_i16x / 256)},
{"halide.hexagon.packhi.vw", u16(wild_u32x / 65536)},
{"halide.hexagon.packhi.vw", u16(wild_i32x / 65536)},
{"halide.hexagon.packhi.vw", i16(wild_u32x / 65536)},
{"halide.hexagon.packhi.vw", i16(wild_i32x / 65536)},
// Narrowing with shifting.
{"halide.hexagon.trunc_shr.vw.uw", i16(wild_i32x >> wild_u32), Pattern::DeinterleaveOp0},
{"halide.hexagon.trunc_shr.vw.uw", i16(wild_i32x / wild_i32), Pattern::DeinterleaveOp0 | Pattern::ExactLog2Op1},
// Narrowing casts. These may interleave later with trunc.
{"halide.hexagon.pack.vh", u8(wild_u16x)},
{"halide.hexagon.pack.vh", u8(wild_i16x)},
{"halide.hexagon.pack.vh", i8(wild_u16x)},
{"halide.hexagon.pack.vh", i8(wild_i16x)},
{"halide.hexagon.pack.vw", u16(wild_u32x)},
{"halide.hexagon.pack.vw", u16(wild_i32x)},
{"halide.hexagon.pack.vw", i16(wild_u32x)},
{"halide.hexagon.pack.vw", i16(wild_i32x)},
// Widening casts
{"halide.hexagon.zxt.vub", u16(wild_u8x), Pattern::InterleaveResult},
{"halide.hexagon.zxt.vub", i16(wild_u8x), Pattern::InterleaveResult},
{"halide.hexagon.zxt.vuh", u32(wild_u16x), Pattern::InterleaveResult},
{"halide.hexagon.zxt.vuh", i32(wild_u16x), Pattern::InterleaveResult},
{"halide.hexagon.sxt.vb", u16(wild_i8x), Pattern::InterleaveResult},
{"halide.hexagon.sxt.vb", i16(wild_i8x), Pattern::InterleaveResult},
{"halide.hexagon.sxt.vh", u32(wild_i16x), Pattern::InterleaveResult},
{"halide.hexagon.sxt.vh", i32(wild_i16x), Pattern::InterleaveResult},
};
// To hit more of the patterns we want, rewrite "double casts"
// as two-stage casts. This also avoids letting vector casts
// fall through to LLVM, which will generate large unoptimized
// shuffles.
static const vector<pair<Expr, Expr>> cast_rewrites = {
// Saturating narrowing
{u8_sat(wild_u32x), u8_sat(u16_sat(wild_u32x))},
{u8_sat(wild_i32x), u8_sat(i16_sat(wild_i32x))},
{i8_sat(wild_u32x), i8_sat(u16_sat(wild_u32x))},
{i8_sat(wild_i32x), i8_sat(i16_sat(wild_i32x))},
// Narrowing
{u8(wild_u32x), u8(u16(wild_u32x))},
{u8(wild_i32x), u8(i16(wild_i32x))},
{i8(wild_u32x), i8(u16(wild_u32x))},
{i8(wild_i32x), i8(i16(wild_i32x))},
// Widening
{u32(wild_u8x), u32(u16(wild_u8x))},
{u32(wild_i8x), u32(i16(wild_i8x))},
{i32(wild_u8x), i32(u16(wild_u8x))},
{i32(wild_i8x), i32(i16(wild_i8x))},
};
if (op->type.is_vector()) {
Expr cast = op;
// Truncating multiplies require special care, because the
// simplifier can cause them to have denominators we do not expect.
// If the simplifier cancels a factor out of these patterns, we can
// still use them, but we have to inject the factor back into the
// expression.
vector<Expr> matches;
for (const Pattern &p : trunc_mpy) {
if (!check_pattern_target(p.flags, target)) {
continue;
}
if (expr_match(p.pattern, cast, matches)) {
int log2_denominator = 0;
if (matches.size() == 4) {
// Rounding patterns have 4 operands, with the rounding
// in the 3rd operand and the denominator in the 4th.
if (!is_const_power_of_two_integer(matches[3], &log2_denominator)) {
continue;
}
if (!can_prove(matches[2] * 2 == i64(1) << log2_denominator)) {
continue;
}
} else {
// The non-rounding patterns have the denominator in the 3rd operand.
if (!is_const_power_of_two_integer(matches[2], &log2_denominator)) {
continue;
}
}
// Drop the divisor and the rounding.
matches.resize(2);
// If the power of 2 is not exactly 2^(bits of the result
// type), we need to scale up the operand accordingly.
int shift = cast.type().bits() - log2_denominator;
if (p.intrin.find("mpy2") != std::string::npos) {
// Account for a built-in factor of 2.
shift -= 1;
}
// We need to build the scale into one of the operands,
// which must be a constant.
if (is_const(matches[0])) {
matches[0] = simplify(matches[0] * (1 << shift));
} else {
// Just assume this is a constant. If it is, this will
// work correctly (will simplify only if it doesn't
// overflow the narrower type). If not, we will probably
// fail to satisfy the match flags, but it also might
// work...
matches[1] = simplify(matches[1] * (1 << shift));
}
if (!process_match_flags(matches, p.flags)) {
continue;
}
// Mutate the operands with the given mutator.
for (Expr &op : matches) {
op = mutate(op);
}
cast = replace_pattern(cast, matches, p);
return cast;
}
}
Expr new_expr = apply_patterns(cast, casts, target, this);
if (!new_expr.same_as(cast)) {
return new_expr;
}
// If we didn't find a pattern, try using one of the
// rewrites above.
for (auto i : cast_rewrites) {
if (expr_match(i.first, cast, matches)) {
debug(3) << "rewriting cast to: " << i.first << " from " << cast << "\n";
Expr replacement = with_lanes(i.second, op->type.lanes());
Expr expr = substitute("*", matches[0], replacement);
return mutate(expr);
}
}
}
return IRMutator::visit(op);
}
Expr visit(const Call *op) override {
if (op->is_intrinsic(Call::lerp)) {
// We need to lower lerps now to optimize the arithmetic
// that they generate.
internal_assert(op->args.size() == 3);
return mutate(lower_lerp(op->args[0], op->args[1], op->args[2]));
} else {
return IRMutator::visit(op);
}
}
public:
OptimizePatterns(Target t) {
target = t;
}
};
// Attempt to cancel out redundant interleave/deinterleave pairs. The
// basic strategy is to push interleavings toward the end of the
// program, using the fact that interleaves can pass through pointwise
// IR operations. When an interleave collides with a deinterleave,
// they cancel out.
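// For example, native_interleave(a) + native_interleave(b) is rewritten to
// native_interleave(a + b) by visit_binary below, and a subsequent
// native_deinterleave of that sum then cancels entirely, leaving a + b.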
class EliminateInterleaves : public IRMutator {
Scope<bool> vars;
// We need to know when loads are a multiple of 2 native vectors.
int native_vector_bits;
// Alignment analyzer for loads and stores
HexagonAlignmentAnalyzer alignment_analyzer;
// Check if x is an expression that is either an interleave, or
// transitively yields an interleave (e.g. through a Let or a variable
// with a deinterleaved alternative).
bool yields_removable_interleave(const Expr &x) {
if (is_native_interleave(x)) {
return true;
}
if (const Let *let = x.as<Let>()) {
return yields_removable_interleave(let->body);
}
const Variable *var = x.as<Variable>();
if (var && vars.contains(var->name + ".deinterleaved")) {
return true;
}
return false;
}
// Check if x either has a removable interleave, or it can pretend
// to be an interleave at no cost (a scalar or a broadcast).
bool yields_interleave(const Expr &x) {
if (yields_removable_interleave(x)) {
return true;
}
// These yield an interleave, but we shouldn't
// deinterleave them if we want to remove an actual
// interleave.
if (x.type().is_scalar() || x.as<Broadcast>()) {
return true;
}
if (const Let *let = x.as<Let>()) {
return yields_interleave(let->body);
}
// This is different from the deinterleaved lets handled in
// yields_removable_interleave. These are lets that can be
// deinterleaved freely, but are not actually interleaves.
const Variable *var = x.as<Variable>();
if (var && vars.contains(var->name + ".weak_deinterleaved")) {
return true;
}
return false;
}
// Check that at least one of exprs is an interleave that should
// be removed, and that all of the exprs can yield an interleave.
bool yields_removable_interleave(const vector<Expr> &exprs) {
bool any_is_interleave = false;
for (const Expr &i : exprs) {
if (yields_removable_interleave(i)) {
any_is_interleave = true;
} else if (!yields_interleave(i)) {
return false;
}
}
return any_is_interleave;
}
// Assuming that x is an expression that can yield an interleave
// (see yields_interleave above), return the underlying expression
// with the interleave removed.
Expr remove_interleave(Expr x) {
if (is_native_interleave(x)) {
return x.as<Call>()->args[0];
} else if (x.type().is_scalar() || x.as<Broadcast>()) {
return x;
}
if (const Variable *var = x.as<Variable>()) {
if (vars.contains(var->name + ".deinterleaved")) {
return Variable::make(var->type, var->name + ".deinterleaved");
} else if (vars.contains(var->name + ".weak_deinterleaved")) {
return Variable::make(var->type, var->name + ".weak_deinterleaved");
}
}
if (const Let *let = x.as<Let>()) {
Expr body = remove_interleave(let->body);
if (!body.same_as(let->body)) {
return Let::make(let->name, let->value, body);
} else {
return x;
}
}
internal_error << "Expression '" << x << "' does not yield an interleave.\n";
return x;
}
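// For example, remove_interleave(native_interleave(x)) returns x, while
// remove_interleave of a scalar or a Broadcast returns it unchanged, since
// those are trivially their own interleaves.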
template<typename T>
Expr visit_binary(const T *op) {
Expr expr;
Expr a = mutate(op->a);
Expr b = mutate(op->b);
if (yields_removable_interleave({a, b})) {
a = remove_interleave(a);
b = remove_interleave(b);
expr = T::make(a, b);
expr = native_interleave(expr);
} else if (!a.same_as(op->a) || !b.same_as(op->b)) {
expr = T::make(a, b);
} else {
expr = op;
}
return expr;
}
Expr visit(const Add *op) override {
return visit_binary(op);
}
Expr visit(const Sub *op) override {
return visit_binary(op);
}
Expr visit(const Mul *op) override {
return visit_binary(op);
}
Expr visit(const Div *op) override {
return visit_binary(op);
}
Expr visit(const Mod *op) override {
return visit_binary(op);
}
Expr visit(const Min *op) override {
return visit_binary(op);
}
Expr visit(const Max *op) override {
return visit_binary(op);
}
Expr visit(const Select *op) override {
Expr true_value = mutate(op->true_value);
Expr false_value = mutate(op->false_value);
Expr cond = mutate(op->condition);
// The condition isn't a vector, so we can just check if we
// should move an interleave from the true/false values.
if (cond.type().is_scalar() && yields_removable_interleave({true_value, false_value})) {
true_value = remove_interleave(true_value);
false_value = remove_interleave(false_value);
return native_interleave(Select::make(cond, true_value, false_value));
} else if (!cond.same_as(op->condition) ||
!true_value.same_as(op->true_value) ||
!false_value.same_as(op->false_value)) {
return Select::make(cond, true_value, false_value);
} else {
return op;
}
}
template<typename NodeType, typename LetType>
NodeType visit_let(const LetType *op) {
Expr value = mutate(op->value);
string deinterleaved_name;
NodeType body;
// Other code in this mutator needs to be able to tell the
// difference between a Let that yields a deinterleave, and a
// let that has a removable deinterleave. Lets that can
// pretend to be deinterleaved at no cost are given an
// alternative let labelled "weak_deinterleaved", while lets
// that have a removable interleave are given an alternative
// let labelled "deinterleaved".
if (yields_removable_interleave(value)) {
// We can provide a deinterleaved version of this let value.
deinterleaved_name = op->name + ".deinterleaved";
vars.push(deinterleaved_name, true);
body = mutate(op->body);
vars.pop(deinterleaved_name);
} else if (yields_interleave(value)) {
// We have a soft deinterleaved version of this let value.
deinterleaved_name = op->name + ".weak_deinterleaved";
vars.push(deinterleaved_name, true);
body = mutate(op->body);
vars.pop(deinterleaved_name);
} else {
body = mutate(op->body);
}
if (value.same_as(op->value) && body.same_as(op->body)) {
return op;
} else if (body.same_as(op->body)) {
// If the body didn't change, we must not have used the deinterleaved value.
return LetType::make(op->name, value, body);
} else {
// We need to rewrap the body with new lets.
NodeType result = body;
bool deinterleaved_used = stmt_or_expr_uses_var(result, deinterleaved_name);
bool interleaved_used = stmt_or_expr_uses_var(result, op->name);
if (deinterleaved_used && interleaved_used) {
// The body uses both the interleaved and
// deinterleaved version of this let. Generate both
// lets, using the deinterleaved one to generate the
// interleaved one.
Expr deinterleaved = remove_interleave(value);
// If we actually removed an interleave from the
// value, re-interleave it to get the interleaved let
// value.
Expr interleaved = Variable::make(deinterleaved.type(), deinterleaved_name);
if (!deinterleaved.same_as(value)) {
interleaved = native_interleave(interleaved);
}
result = LetType::make(op->name, interleaved, result);
return LetType::make(deinterleaved_name, deinterleaved, result);
} else if (deinterleaved_used) {
// Only the deinterleaved value is used, we can eliminate the interleave.
return LetType::make(deinterleaved_name, remove_interleave(value), result);
} else if (interleaved_used) {
// Only the original value is used, regenerate the let.
return LetType::make(op->name, value, result);
} else {
// The let must have been dead.
internal_assert(!stmt_or_expr_uses_var(op->body, op->name))
<< "EliminateInterleaves eliminated a non-dead let.\n";
return NodeType();
}
}
}
Expr visit(const Let *op) override {
Expr expr = visit_let<Expr>(op);
// Lift interleaves out of Let expression bodies.
const Let *let = expr.as<Let>();
if (let && yields_removable_interleave(let->body)) {
expr = native_interleave(Let::make(let->name, let->value, remove_interleave(let->body)));
}
return expr;
}
Stmt visit(const LetStmt *op) override {
return visit_let<Stmt>(op);
}
Expr visit(const Cast *op) override {
if (op->type.bits() == op->value.type().bits()) {
// We can only move interleaves through casts of the same size.
Expr value = mutate(op->value);
if (yields_removable_interleave(value)) {
value = remove_interleave(value);
return native_interleave(Cast::make(op->type, value));
} else if (!value.same_as(op->value)) {
return Cast::make(op->type, value);
} else {
return op;
}
} else {
return IRMutator::visit(op);
}
}
static bool is_interleavable(const Call *op) {
// These calls can have interleaves moved from operands to the
// result...
static const set<string> interleavable = {
Call::get_intrinsic_name(Call::bitwise_and),
Call::get_intrinsic_name(Call::bitwise_not),
Call::get_intrinsic_name(Call::bitwise_xor),
Call::get_intrinsic_name(Call::bitwise_or),
Call::get_intrinsic_name(Call::shift_left),
Call::get_intrinsic_name(Call::shift_right),
Call::get_intrinsic_name(Call::abs),
Call::get_intrinsic_name(Call::absd)};
if (interleavable.count(op->name) != 0) return true;
// ...these calls cannot. Furthermore, these calls have the
// same return type as the arguments, which means our test
// below will be inaccurate.
static const set<string> not_interleavable = {
"halide.hexagon.interleave.vb",
"halide.hexagon.interleave.vh",
"halide.hexagon.interleave.vw",
"halide.hexagon.deinterleave.vb",
"halide.hexagon.deinterleave.vh",
"halide.hexagon.deinterleave.vw",
Call::get_intrinsic_name(Call::hvx_gather),
Call::get_intrinsic_name(Call::hvx_scatter),
Call::get_intrinsic_name(Call::hvx_scatter_acc),
};
if (not_interleavable.count(op->name) != 0) return false;
if (starts_with(op->name, "halide.hexagon.")) {
// We assume that any hexagon intrinsic is interleavable
// as long as all of the vector operands have the same
// number of lanes and lane width as the return type.
for (Expr i : op->args) {
if (i.type().is_scalar()) continue;
if (i.type().bits() != op->type.bits() || i.type().lanes() != op->type.lanes()) {
return false;
}
}
}
return true;
}
Expr visit(const Call *op) override {
vector<Expr> args(op->args);
// mutate all the args.
bool changed = false;
for (Expr &i : args) {
Expr new_i = mutate(i);
changed = changed || !new_i.same_as(i);
i = new_i;
}
// For a few operations, we have a choice of several
// instructions: an interleaving or a non-interleaving
// variant. We handle this by generating the instruction that
// does not deinterleave, and then opportunistically selecting
// the interleaving alternative when it allows an interleave to
// cancel out.
static std::map<string, string> deinterleaving_alts = {
{"halide.hexagon.pack.vh", "halide.hexagon.trunc.vh"},
{"halide.hexagon.pack.vw", "halide.hexagon.trunc.vw"},
{"halide.hexagon.packhi.vh", "halide.hexagon.trunclo.vh"},
{"halide.hexagon.packhi.vw", "halide.hexagon.trunclo.vw"},
{"halide.hexagon.pack_satub.vh", "halide.hexagon.trunc_satub.vh"},
{"halide.hexagon.pack_sath.vw", "halide.hexagon.trunc_sath.vw"},
{"halide.hexagon.pack_satuh.vw", "halide.hexagon.trunc_satuh.vw"},
};
// The reverse mapping of the above.
static std::map<string, string> interleaving_alts = {
{"halide.hexagon.trunc.vh", "halide.hexagon.pack.vh"},
{"halide.hexagon.trunc.vw", "halide.hexagon.pack.vw"},
{"halide.hexagon.trunclo.vh", "halide.hexagon.packhi.vh"},
{"halide.hexagon.trunclo.vw", "halide.hexagon.packhi.vw"},
{"halide.hexagon.trunc_satub.vh", "halide.hexagon.pack_satub.vh"},
{"halide.hexagon.trunc_sath.vw", "halide.hexagon.pack_sath.vw"},
{"halide.hexagon.trunc_satuh.vw", "halide.hexagon.pack_satuh.vw"},
};
if (is_native_deinterleave(op) && yields_interleave(args[0])) {
// This is a deinterleave of an interleave! Remove them both.
return remove_interleave(args[0]);
} else if (is_interleavable(op) && yields_removable_interleave(args)) {
// All the arguments yield interleaves (and one of
// them is an interleave), create a new call with the
// interleave removed from the arguments.
for (Expr &i : args) {
i = remove_interleave(i);
}
Expr expr = Call::make(op->type, op->name, args, op->call_type,
op->func, op->value_index, op->image, op->param);
// Add the interleave back to the result of the call.
return native_interleave(expr);
} else if (deinterleaving_alts.find(op->name) != deinterleaving_alts.end() &&
yields_removable_interleave(args)) {
// This call has a deinterleaving alternative, and the
// arguments are interleaved, so we should use the
// alternative instead.
for (Expr &i : args) {
i = remove_interleave(i);
}
return Call::make(op->type, deinterleaving_alts[op->name], args, op->call_type);
} else if (interleaving_alts.count(op->name) && is_native_deinterleave(args[0])) {
// This is an interleaving alternative with a
// deinterleave, which can be generated when we
// deinterleave storage. Revert back to the interleaving
// op so we can remove the deinterleave.
Expr arg = args[0].as<Call>()->args[0];
return Call::make(op->type, interleaving_alts[op->name], {arg}, op->call_type,
op->func, op->value_index, op->image, op->param);
} else if (changed) {
return Call::make(op->type, op->name, args, op->call_type,
op->func, op->value_index, op->image, op->param);
} else {
return op;
}
}
// Track whether buffers are interleaved or not.
enum class BufferState {
Unknown, // We don't know if this buffer is interleaved or not.
Interleaved, // We know the buffer is interleaved.
NotInterleaved, // We know the buffer is not interleaved.
};
Scope<BufferState> buffers;
// False for buffers that have any loads or stores that are unaligned
Scope<bool> aligned_buffer_access;
// Buffers we should deinterleave the storage of.
Scope<bool> deinterleave_buffers;
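// In summary: the first pass over an allocation checks that every
// store to the buffer is an unpredicated vector store whose value
// yields an interleave (with at least one removable interleave), and
// that all accesses are aligned. If so, a second pass rewrites
//   buf[i] = native_interleave(v)  -->  buf[i] = v
// and wraps each load of the buffer in native_interleave, giving the
// new interleaves a chance to cancel downstream.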
Stmt visit(const Allocate *op) override {
Expr condition = mutate(op->condition);
// First, we need to mutate the op, to pull native interleaves
// down, and to gather information about the loads and stores.
buffers.push(op->name, BufferState::Unknown);
// Assume buffers are accessed by aligned loads and stores by default.
aligned_buffer_access.push(op->name, true);
Stmt body = mutate(op->body);
bool deinterleave = (buffers.get(op->name) == BufferState::Interleaved) &&
(aligned_buffer_access.get(op->name) == true);
buffers.pop(op->name);
// Second, if we decided it would be useful to deinterleave
// the storage of this buffer, do so now.
if (deinterleave) {
deinterleave_buffers.push(op->name, true);
body = mutate(op->body);
deinterleave_buffers.pop(op->name);
}
aligned_buffer_access.pop(op->name);
if (!body.same_as(op->body) || !condition.same_as(op->condition)) {
return Allocate::make(op->name, op->type, op->memory_type,
op->extents, condition, body,
op->new_expr, op->free_function);
} else {
return op;
}
}
Stmt visit(const Store *op) override {
Expr predicate = mutate(op->predicate);
Expr value = mutate(op->value);
Expr index = mutate(op->index);
if (buffers.contains(op->name)) {
// When inspecting the stores to a buffer, update the state.
BufferState &state = buffers.ref(op->name);
if (!is_one(predicate) || !op->value.type().is_vector()) {
// TODO(psuriana): This store is predicated or scalar. Mark
// the buffer as not interleaved for now.
state = BufferState::NotInterleaved;
} else if (yields_removable_interleave(value)) {
// The value yields a removable interleave. If we aren't tracking
// this buffer, mark it as interleaved.
if (state == BufferState::Unknown) {
state = BufferState::Interleaved;
}
} else if (!yields_interleave(value)) {
// The value does not yield an interleave. Mark the
// buffer as not interleaved.
state = BufferState::NotInterleaved;
} else {
// The value yields an interleave, but is not a removable
// interleave itself, so we don't want to change the buffer
// state.
}
internal_assert(aligned_buffer_access.contains(op->name) && "Buffer not found in scope");
bool &aligned_accesses = aligned_buffer_access.ref(op->name);
int64_t aligned_offset = 0;
if (!alignment_analyzer.is_aligned(op, &aligned_offset)) {
aligned_accesses = false;
}
}
if (deinterleave_buffers.contains(op->name)) {
// We're deinterleaving this buffer, remove the interleave
// from the store.
internal_assert(is_one(predicate)) << "The store shouldn't have been predicated.\n";
value = remove_interleave(value);
}
if (predicate.same_as(op->predicate) && value.same_as(op->value) && index.same_as(op->index)) {
return op;
} else {
return Store::make(op->name, value, index, op->param, predicate, op->alignment);
}
}
Expr visit(const Load *op) override {
if (buffers.contains(op->name)) {
if ((op->type.lanes() * op->type.bits()) % (native_vector_bits * 2) == 0) {
// This is a double vector load, we might be able to
// deinterleave the storage of this buffer.
// We don't want to actually do anything to the buffer
// state here. We know we can interleave the load if
// necessary, but we don't want to cause it to be
// interleaved unless it is a useful improvement,
// which is only true if any of the stores are
// actually interleaved (and don't just yield an
// interleave).
internal_assert(aligned_buffer_access.contains(op->name) && "Buffer not found in scope");
bool &aligned_accesses = aligned_buffer_access.ref(op->name);
int64_t aligned_offset = 0;
if (!alignment_analyzer.is_aligned(op, &aligned_offset)) {
aligned_accesses = false;
}
} else {
// This is not a double vector load, so we can't
// deinterleave the storage of this buffer.
BufferState &state = buffers.ref(op->name);
state = BufferState::NotInterleaved;
}
}
Expr expr = IRMutator::visit(op);
if (deinterleave_buffers.contains(op->name)) {
expr = native_interleave(expr);
}
return expr;
}
using IRMutator::visit;
public:
EliminateInterleaves(int native_vector_bytes)
: native_vector_bits(native_vector_bytes * 8), alignment_analyzer(native_vector_bytes) {
}
};
// After eliminating interleaves, there may be some that remain. This
// mutator attempts to replace pairs of an interleave and another
// operation with an alternative operation that does not require the
// interleave. It's important to do this after all other efforts to
// eliminate the interleaves, otherwise this might eat some
// interleaves that could have cancelled with other operations.
class FuseInterleaves : public IRMutator {
Expr visit(const Call *op) override {
// This is a list of {f, g} pairs such that g(x) is
// equivalent to interleave(f(x)).
static const std::vector<std::pair<string, string>> non_deinterleaving_alts = {
{"halide.hexagon.zxt.vub", "halide.hexagon.unpack.vub"},
{"halide.hexagon.sxt.vb", "halide.hexagon.unpack.vb"},
{"halide.hexagon.zxt.vuh", "halide.hexagon.unpack.vuh"},
{"halide.hexagon.sxt.vh", "halide.hexagon.unpack.vh"},
};
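// For example, interleave(halide.hexagon.zxt.vub(x)) is rewritten to
// halide.hexagon.unpack.vub(x), folding the interleave into the
// operation.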
if (is_native_interleave(op)) {
if (const Call *arg = op->args[0].as<Call>()) {
for (const auto &i : non_deinterleaving_alts) {
if (arg->name == i.first) {
std::vector<Expr> args = arg->args;
for (Expr &j : args) {
j = mutate(j);
}
return Call::make(op->type, i.second, args, Call::PureExtern);
}
}
}
}
return IRMutator::visit(op);
}
using IRMutator::visit;
};
// Find an upper bound of bounds.max - bounds.min.
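// For example, given bounds {min(x, 100), min(x + 3, 100)}, peeling
// off the common min(_, 100) gives span_of_bounds({x, x + 3})
// = (x + 3) - x, which callers then simplify to 3. Naively returning
// min(x + 3, 100) - min(x, 100) would not simplify to a constant.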
Expr span_of_bounds(const Interval &bounds) {
internal_assert(bounds.is_bounded());
const Min *min_min = bounds.min.as<Min>();
const Max *min_max = bounds.min.as<Max>();
const Min *max_min = bounds.max.as<Min>();
const Max *max_max = bounds.max.as<Max>();
const Add *min_add = bounds.min.as<Add>();
const Add *max_add = bounds.max.as<Add>();
const Sub *min_sub = bounds.min.as<Sub>();
const Sub *max_sub = bounds.max.as<Sub>();
if (min_min && max_min && equal(min_min->b, max_min->b)) {
return span_of_bounds({min_min->a, max_min->a});
} else if (min_max && max_max && equal(min_max->b, max_max->b)) {
return span_of_bounds({min_max->a, max_max->a});
} else if (min_add && max_add && equal(min_add->b, max_add->b)) {
return span_of_bounds({min_add->a, max_add->a});
} else if (min_sub && max_sub && equal(min_sub->b, max_sub->b)) {
return span_of_bounds({min_sub->a, max_sub->a});
} else {
return bounds.max - bounds.min;
}
}
// Replace indirect loads with dynamic_shuffle intrinsics where
// possible.
class OptimizeShuffles : public IRMutator {
int lut_alignment;
Scope<Interval> bounds;
std::vector<std::pair<string, Expr>> lets;
using IRMutator::visit;
template<typename NodeType, typename T>
NodeType visit_let(const T *op) {
// We only care about vector lets.
if (op->value.type().is_vector()) {
bounds.push(op->name, bounds_of_expr_in_scope(op->value, bounds));
}
NodeType node = IRMutator::visit(op);
if (op->value.type().is_vector()) {
bounds.pop(op->name);
}
return node;
}
Expr visit(const Let *op) override {
lets.emplace_back(op->name, op->value);
Expr expr = visit_let<Expr>(op);
lets.pop_back();
return expr;
}
Stmt visit(const LetStmt *op) override {
return visit_let<Stmt>(op);
}
Expr visit(const Load *op) override {
if (!is_one(op->predicate)) {
// TODO(psuriana): We shouldn't mess with predicated loads for now.
return IRMutator::visit(op);
}
if (!op->type.is_vector() || op->index.as<Ramp>()) {
// Don't handle scalar or simple vector loads.
return IRMutator::visit(op);
}
Expr index = mutate(op->index);
Interval unaligned_index_bounds = bounds_of_expr_in_scope(index, bounds);
if (unaligned_index_bounds.is_bounded()) {
// We want to try both the unaligned and aligned
// bounds. The unaligned bounds might fit in 256 elements,
// while the aligned bounds do not.
int align = lut_alignment / op->type.bytes();
Interval aligned_index_bounds = {
(unaligned_index_bounds.min / align) * align,
((unaligned_index_bounds.max + align) / align) * align - 1};
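// Worked example: with a 128-byte LUT alignment and 2-byte elements,
// align = 64, so unaligned index bounds of [10, 70] become aligned
// bounds of [(10 / 64) * 64, ((70 + 64) / 64) * 64 - 1] = [0, 127].
// Both spans fit within 256 elements here, but the aligned bounds
// allow the LUT loads themselves to be aligned.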
ModulusRemainder alignment(align, 0);
for (Interval index_bounds : {aligned_index_bounds, unaligned_index_bounds}) {
Expr index_span = span_of_bounds(index_bounds);
index_span = common_subexpression_elimination(index_span);
index_span = simplify(index_span);
if (can_prove(index_span < 256)) {
// This is a lookup within an up to 256 element array. We
// can use dynamic_shuffle for this.
int const_extent = as_const_int(index_span) ? *as_const_int(index_span) + 1 : 256;
Expr base = simplify(index_bounds.min);
// Load all of the possible indices loaded from the
// LUT. Note that for clamped ramps, this loads up to 1
// vector past the max. CodeGen_Hexagon::allocation_padding
// returns a native vector size to account for this.
Expr lut = Load::make(op->type.with_lanes(const_extent), op->name,
Ramp::make(base, 1, const_extent),
op->image, op->param, const_true(const_extent), alignment);
// We know the size of the LUT is not more than 256, so we
// can safely cast the index to 8 bit, which
// dynamic_shuffle requires.
index = simplify(cast(UInt(8).with_lanes(op->type.lanes()), index - base));
return Call::make(op->type, "dynamic_shuffle", {lut, index, 0, const_extent - 1}, Call::PureIntrinsic);
}
// Only the first iteration of this loop is aligned.
alignment = ModulusRemainder();
}
}
if (!index.same_as(op->index)) {
return Load::make(op->type, op->name, index, op->image, op->param, op->predicate, op->alignment);
} else {
return op;
}
}
public:
OptimizeShuffles(int lut_alignment)
: lut_alignment(lut_alignment) {
}
};
// Attempt to generate vtmpy instructions. This requires that all lets
// be substituted prior to running, and so must be an IRGraphMutator.
class VtmpyGenerator : public IRGraphMutator {
private:
using IRMutator::visit;
typedef pair<Expr, size_t> LoadIndex;
// Check if vectors a and b point to the same buffer with the base of a
// shifted by diff, i.e. base(a) = base(b) + diff.
bool is_base_shifted(const Expr &a, const Expr &b, int diff) {
Expr maybe_load_a = calc_load(a);
Expr maybe_load_b = calc_load(b);
if (maybe_load_a.defined() && maybe_load_b.defined()) {
const Load *load_a = maybe_load_a.as<Load>();
const Load *load_b = maybe_load_b.as<Load>();
if (load_a->name == load_b->name) {
Expr base_diff = simplify(load_a->index - load_b->index - diff);
if (is_const(base_diff, 0)) {
return true;
}
}
}
return false;
}
// Return the load expression of the first vector if all the vectors
// in exprs are contiguous and point to the same buffer.
Expr are_contiguous_vectors(const vector<Expr> &exprs) {
if (exprs.empty()) {
return Expr();
}
// If the concat shuffle simplifies away (i.e. is no longer a
// concat), the vectors were contiguous: the simplifier merges
// adjacent loads whose bases differ by the vector size.
Expr concat = simplify(Shuffle::make_concat(exprs));
const Shuffle *maybe_shuffle = concat.as<Shuffle>();
if (!maybe_shuffle || !maybe_shuffle->is_concat()) {
return calc_load(exprs[0]);
}
return Expr();
}
// Returns the load indicating the vector start index. If the vector
// is a slice, returns a load with the ramp shifted by the slice_begin
// expr.
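// For example (illustrative): for
//   slice(load(buf, ramp(base, 1, 128)), begin = 2, stride = 1, lanes = 64)
// this returns a load whose index is ramp(base, 1, 128) + 2, i.e. the
// start index shifted by the slice offset. The result is only used for
// comparing load bases, not for code generation.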
Expr calc_load(const Expr &e) {
if (const Cast *maybe_cast = e.as<Cast>()) {
return calc_load(maybe_cast->value);
}
if (const Shuffle *maybe_shuffle = e.as<Shuffle>()) {
if (maybe_shuffle->is_slice() && maybe_shuffle->slice_stride() == 1) {
Expr maybe_load = calc_load(maybe_shuffle->vectors[0]);
if (!maybe_load.defined()) {
return Expr();
}
const Load *res = maybe_load.as<Load>();
Expr shifted_load = Load::make(res->type, res->name, res->index + maybe_shuffle->slice_begin(),
res->image, res->param, res->predicate, ModulusRemainder());
return shifted_load;
} else if (maybe_shuffle->is_concat()) {
return are_contiguous_vectors(maybe_shuffle->vectors);
}
}
if (const Load *maybe_load = e.as<Load>()) {
const Ramp *maybe_ramp = maybe_load->index.as<Ramp>();
if (maybe_ramp && is_const(maybe_ramp->stride, 1)) {
return maybe_load;
}
}
return Expr();
}
// Comparator for sorting Load Exprs of the same buffer.
static bool loads_comparator(const LoadIndex &a, const LoadIndex &b) {
if (a.first.defined() && b.first.defined()) {
const Load *load_a = a.first.as<Load>();
const Load *load_b = b.first.as<Load>();
if (load_a->name == load_b->name) {
Expr base_diff = simplify(load_b->index - load_a->index);
if (is_positive_const(base_diff)) {
return true;
}
} else {
return load_a->name < load_b->name;
}
}
return false;
}
// Vtmpy helps in sliding window ops of the form a*v0 + b*v1 + v2.
// Conditions required:
//   The start indices of v0, v1 and v2 must differ by the vector stride.
//   Currently the only supported stride is 1.
// TODO: Add support for any stride.
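// For example (a sketch, assuming three dense u8 loads from the same
// buffer whose bases differ by 1, each widened to 16 bits):
//   widen(in[x - 1]) * a + widen(in[x]) * b + widen(in[x + 1])
// matches when the last term's coefficient is 1, and becomes (modulo
// interleaving of the result):
//   halide.hexagon.vtmpy.vub.vub.b.b(in[x - 1], in[x + 1], a, b)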
Expr visit(const Add *op) override {
// Find opportunities for vtmpy.
if (op && op->type.is_vector() && (op->type.bits() == 16 || op->type.bits() == 32)) {
int lanes = op->type.lanes();
vector<MulExpr> mpys;
Expr rest;
string vtmpy_suffix;
// Finding more than 100 such expressions is rare.
// Setting the limit to 100 makes sure we don't miss anything
// in most cases, while also not spending unreasonable time
// just looking for vtmpy patterns.
const int max_mpy_ops = 100;
if (op->type.bits() == 16) {
find_mpy_ops(op, UInt(8, lanes), Int(8), max_mpy_ops, mpys, rest);
vtmpy_suffix = ".vub.vub.b.b";
if (mpys.size() < 3) {
mpys.clear();
rest = Expr();
find_mpy_ops(op, Int(8, lanes), Int(8), max_mpy_ops, mpys, rest);
vtmpy_suffix = ".vb.vb.b.b";
}
} else if (op->type.bits() == 32) {
find_mpy_ops(op, Int(16, lanes), Int(8), max_mpy_ops, mpys, rest);
vtmpy_suffix = ".vh.vh.b.b";
}
if (mpys.size() >= 3) {
const size_t mpy_size = mpys.size();
// Used to put loads with different buffers in different buckets.
std::unordered_map<string, vector<LoadIndex>> loads;
// To keep track of indices selected for vtmpy.
std::unordered_map<size_t, bool> vtmpy_indices;
vector<Expr> vtmpy_exprs;
Expr new_expr;
for (size_t i = 0; i < mpy_size; i++) {
Expr curr_load = calc_load(mpys[i].first);
if (curr_load.defined()) {
loads[curr_load.as<Load>()->name].emplace_back(curr_load, i);
}
// Terms whose loads we can't analyze are skipped here; they are
// re-added in the recombination loop below when vtmpys are found.
}
for (auto iter = loads.begin(); iter != loads.end(); iter++) {
// Sort the bucket and compare bases of 3 adjacent vectors
// at a time. If they differ by the vector stride, we've
// found a vtmpy.
// It doesn't seem to be easy to write a comparator that
// implements a strict weak ordering, so we use stable_sort
// instead of sort; at the very least, the relative order of
// tied elements is preserved.
std::stable_sort(iter->second.begin(), iter->second.end(), loads_comparator);
size_t vec_size = iter->second.size();
for (size_t i = 0; i + 2 < vec_size; i++) {
Expr v0 = iter->second[i].first;
Expr v1 = iter->second[i + 1].first;
Expr v2 = iter->second[i + 2].first;
size_t v0_idx = iter->second[i].second;
size_t v1_idx = iter->second[i + 1].second;
size_t v2_idx = iter->second[i + 2].second;
if (is_const(mpys[v2_idx].second, 1) &&
is_base_shifted(v2, v1, 1) &&
is_base_shifted(v1, v0, 1)) {
vtmpy_indices[v0_idx] = true;
vtmpy_indices[v1_idx] = true;
vtmpy_indices[v2_idx] = true;
vtmpy_exprs.emplace_back(native_interleave(Call::make(op->type,
"halide.hexagon.vtmpy" + vtmpy_suffix,
{mpys[v0_idx].first, mpys[v2_idx].first,
mpys[v0_idx].second, mpys[v1_idx].second},
Call::PureExtern)));
// Skip ahead, since we cannot use the same indices again.
i = i + 2;
}
}
}
// If we found any vtmpys, recombine the expression from
// vtmpy_exprs, the remaining non-vtmpy terms, and rest.
if (!vtmpy_exprs.empty()) {
for (size_t i = 0; i < mpy_size; i++) {
if (vtmpy_indices[i]) {
continue;
}
// Expressions were put into mpys after un-broadcasting
// them, so re-broadcast first and then call lossless_cast.
Expr a = mpys[i].first;
Expr b = mpys[i].second;
int lanes = op->type.lanes();
if (a.type().is_scalar())
a = Broadcast::make(a, lanes);
if (b.type().is_scalar())
b = Broadcast::make(b, lanes);
Expr mpy_a = lossless_cast(op->type, a);
Expr mpy_b = lossless_cast(op->type, b);
Expr mpy_res = mpy_a * mpy_b;
new_expr = new_expr.defined() ? new_expr + mpy_res : mpy_res;
}
for (size_t i = 0; i < vtmpy_exprs.size(); i++) {
new_expr = new_expr.defined() ? new_expr + vtmpy_exprs[i] : vtmpy_exprs[i];
}
if (rest.defined()) {
new_expr = new_expr + rest;
}
return mutate(new_expr);
}
}
}
return IRMutator::visit(op);
}
};
// Convert some expressions to an equivalent form which can be better
// optimized in later stages for Hexagon.
class RearrangeExpressions : public IRMutator {
private:
using IRMutator::visit;
Expr visit(const Mul *op) override {
if (!op->type.is_vector()) {
// Only do this for vectors (where we have vmpa).
return IRMutator::visit(op);
}
if (op->a.as<Broadcast>() && !op->b.as<Broadcast>()) {
// Ensure broadcasts always occur as operand 1, not operand 0.
return mutate(op->b * op->a);
}
if (op->b.as<Broadcast>() && op->a.type().is_int() &&
(op->a.type().bits() == 16 || op->a.type().bits() == 32)) {
// Distributing broadcasts helps create more vmpa
// instructions, by exposing more adds of muls. Since muls
// are generally widening ops, we need not check whether
// op->a is a sum of widening casts.
if (const Add *add = op->a.as<Add>()) {
// simplify() ensures that if add->a is itself
// a scalar multiplication, then we combine the two
// broadcasts produced in add->a * op->b. For example, if
// add->a = bc * i16, then simplify combines
// bc * op->b into a single expression.
// Since the simplifier is used on the individual operands
// after distributing the broadcast, the mutation does not
// compete with the simplifier [the next mutation always occurs
// on the simplified operands]. For example,
// consider the initial expression:
// ((v0 * bc(x) + v1 * bc(y)) + v2) * bc(z)
// Mutation sequence:
// Step 1:
// mutate((v0 * bc(x) + v1 * bc(y)) * bc(z) + v2 * bc(z))
// Step 2:
// mutate((v0 * bc(x) + v1 * bc(y)) * bc(z)) + mutate(v2 * bc(z))
// Step 3 [Result]:
// ((v0 * bc(x * z) + v1 * bc(y * z)) + v2 * bc(z))
return mutate(simplify(add->a * op->b) + simplify(add->b * op->b));
} else if (const Sub *sub = op->a.as<Sub>()) {
return mutate(simplify(sub->a * op->b) - simplify(sub->b * op->b));
}
}
return IRMutator::visit(op);
}
};
// Try generating vgathers instead of shuffles.
// At present, we request VTCM memory with the single-page allocation
// flag for all store_in allocations, so it is always safe to generate
// a vgather.
// Expressions which generate vgathers are of the form:
//     out(x) = lut(foo(x))
// For vgathers, both out and lut should be in VTCM, in a single page.
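// A minimal sketch of a pipeline that can take this path (hypothetical
// Funcs, Vars, and schedule; the key requirement is that both the LUT
// and the output are stored in VTCM):
//   Func lut, out; Var x;
//   lut(x) = ...;
//   out(x) = lut(clamp(idx(x), 0, lut_size - 1));
//   lut.compute_root().store_in(MemoryType::VTCM);
//   out.compute_root().store_in(MemoryType::VTCM).vectorize(x, 64);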
class ScatterGatherGenerator : public IRMutator {
Scope<Interval> bounds;
std::unordered_map<string, const Allocate *> allocations;
using IRMutator::visit;
template<typename NodeType, typename T>
NodeType visit_let(const T *op) {
// We only care about vector lets.
if (op->value.type().is_vector()) {
bounds.push(op->name, bounds_of_expr_in_scope(op->value, bounds));
}
NodeType node = IRMutator::visit(op);
if (op->value.type().is_vector()) {
bounds.pop(op->name);
}
return node;
}
Expr visit(const Let *op) override {
return visit_let<Expr>(op);
}
Stmt visit(const LetStmt *op) override {
return visit_let<Stmt>(op);
}
Stmt visit(const Allocate *op) override {
// Record the allocation so loads and stores can look it up.
allocations[op->name] = op;
return IRMutator::visit(op);
}
// Try to match expressions of the form:
//     out(x) = lut(foo(x))
// to generate vgathers. Here, out and lut should have the
// store_in(MemoryType::VTCM) directive. If a vgather is found, return
// a Call Expr to the vgather, otherwise return Expr().
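// For example (illustrative), a matched 16-bit lookup produces a call
// of the form:
//   gather(dst_base, dst_index, src, size - 1, new_index)
// which later code generation can lower to a vgather writing directly
// into the destination buffer in VTCM.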
Expr make_gather(const Load *op, Expr dst_base, Expr dst_index) {
Type ty = op->type;
const Allocate *alloc = allocations[op->name];
// The lut should be in VTCM.
if (!alloc || alloc->memory_type != MemoryType::VTCM) {
return Expr();
}
// HVX has only 16 or 32-bit gathers. Predicated vgathers are not
// supported yet.
if (op->index.as<Ramp>() || !is_one(op->predicate) || !ty.is_vector() ||
ty.bits() == 8) {
return Expr();
}
Expr index = mutate(ty.bytes() * op->index);
Interval index_bounds = bounds_of_expr_in_scope(index, bounds);
if (ty.bits() == 16 && index_bounds.is_bounded()) {
Expr index_span = span_of_bounds(index_bounds);
index_span = common_subexpression_elimination(index_span);
index_span = simplify(index_span);
// We need to downcast the index values to 16-bit signed, so
// all the indices must be less than 1 << 15.
if (!can_prove(index_span < std::numeric_limits<int16_t>::max())) {
return Expr();
}
}
// Calculate the size of the buffer lut in bytes.
Expr size = ty.bytes();
for (size_t i = 0; i < alloc->extents.size(); i++) {
size *= alloc->extents[i];
}
Expr src = Variable::make(Handle(), op->name);
Expr new_index = mutate(cast(ty.with_code(Type::Int), index));
dst_index = mutate(dst_index);
return Call::make(ty, "gather", {std::move(dst_base), dst_index, src, size - 1, new_index},
Call::Intrinsic);
}
// Checks if the Store node can be replaced with a scatter-accumulate.
// If so, returns the new value to be used for the scatter-accumulate;
// otherwise, returns the original store value.
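// For example, a store of the form
//   buf[idx] = buf[idx] + v   (or v + buf[idx])
// matches lhs + wild and returns v, which then becomes the value
// operand of an hvx_scatter_acc.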
Expr is_scatter_acc(const Store *op) {
Expr lhs = Load::make(op->value.type(), op->name, op->index, Buffer<>(),
Parameter(), const_true(op->value.type().lanes()), op->alignment);
Expr wild = Variable::make(op->value.type(), "*");
vector<Expr> matches;
if (expr_match(lhs + wild, op->value, matches) ||
expr_match(wild + lhs, op->value, matches)) {
// Scatter accumulate found.
return matches[0];
}
return op->value;
}
Stmt visit(const Store *op) override {
// HVX has only 16 or 32-bit gathers and scatters. Predicated
// versions are not supported yet.
Type ty = op->value.type();
if (!is_one(op->predicate) || !ty.is_vector() || ty.bits() == 8) {
return IRMutator::visit(op);
}
// To use vgathers, the destination address must be VTCM memory.
const Allocate *alloc = allocations[op->name];
if (!alloc || alloc->memory_type != MemoryType::VTCM) {
return IRMutator::visit(op);
}
// The source for a gather must also be a buffer in VTCM.
if (op->index.as<Ramp>() && op->value.as<Load>()) {
// Check for vgathers
Expr dst_base = Variable::make(Handle(), op->name);
Expr dst_index = op->index.as<Ramp>()->base;
Expr value = make_gather(op->value.as<Load>(), dst_base, dst_index);
if (value.defined()) {
// Found a vgather.
// Function make_gather already mutates all the call arguments,
// so there is no need to mutate again.
return Evaluate::make(value);
}
}
// Check for scatter/scatter-accumulate.
if (op->index.as<Ramp>()) {
return IRMutator::visit(op);
}
// Calculate the size of the buffer in bytes.
Expr size = ty.bytes();
for (size_t i = 0; i < alloc->extents.size(); i++) {
size *= alloc->extents[i];
}
// Check for scatter-acc.
Expr value = is_scatter_acc(op);
Call::IntrinsicOp intrinsic = Call::hvx_scatter;
if (!value.same_as(op->value)) {
// It's a scatter-accumulate
intrinsic = Call::hvx_scatter_acc;
}
Expr buffer = Variable::make(Handle(), op->name);
Expr index = mutate(cast(ty.with_code(Type::Int), ty.bytes() * op->index));
value = mutate(value);
Stmt scatter = Evaluate::make(Call::make(ty, intrinsic,
{buffer, size - 1, index, value}, Call::Intrinsic));
return scatter;
}
};
// Scatter-gather instructions on Hexagon are asynchronous and hence
// require a scatter-release store followed by a vector load from the
// same address. This stalls the pipeline until all previous
// scatter-gather operations have finished. The operations are not
// ordered with respect to load and store operations either.
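// For example (illustrative IR), without a barrier the load below
// could execute before the asynchronous gather has landed:
//   gather(out, ...)
//   ... = out[ramp(base, 1, lanes)]
// This pass turns it into:
//   gather(out, ...)
//   hvx_scatter_release(out)
//   ... = out[ramp(base, 1, lanes)]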
class SyncronizationBarriers : public IRMutator {
// Keep track of all scatter-gather operations in flight which could cause
// a hazard in the future.
std::map<string, vector<const Stmt *>> in_flight;
// Trail of For Blocks to reach a stmt.
vector<const Stmt *> curr_path;
// Current Stmt being mutated.
const Stmt *curr = nullptr;
// Track the Stmts before which a scatter-release must be inserted,
// and the buffer each release applies to.
std::map<const Stmt *, Expr> sync;
using IRMutator::visit;
Expr visit(const Call *op) override {
if (op->is_intrinsic(Call::hvx_scatter) ||
op->is_intrinsic(Call::hvx_scatter_acc) ||
op->is_intrinsic(Call::hvx_gather)) {
string name = op->args[0].as<Variable>()->name;
// Check if the scatter-gather encountered conflicts with any
// previous operation. If yes, insert a scatter-release.
check_hazard(name);
in_flight[name] = curr_path;
}
return IRMutator::visit(op);
}
Stmt visit(const For *op) override {
// Keep a trail of the For blocks encountered.
curr_path.push_back(curr);
Stmt s = IRMutator::visit(op);
curr_path.pop_back();
return s;
}
// Creates an entry in the sync map for the stmt requiring a
// scatter-release instruction before it.
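// For example, if a gather in flight was issued inside
//   for y { for x1 { ...gather(out, ...)... } }
// and the conflicting access occurs in a sibling loop for x2 under
// the same for y, the two paths first differ at the sibling For
// block, so the scatter-release is placed immediately before that
// block rather than before the access itself.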
void check_hazard(const string &name) {
if (in_flight.find(name) == in_flight.end()) {
return;
}
// Sync needed. Add the scatter-release before the first differing For
// loop block between curr_path and the hazard source location.
size_t min_size = std::min(in_flight[name].size(), curr_path.size());
size_t i = 0;
// Find the first different For loop block.
for (; i < min_size; i++) {
if (in_flight[name][i] != curr_path[i]) {
break;
}
}
if (i < curr_path.size()) {
// Place scatter-release before the first different For loop block.
sync[curr_path[i]] = Variable::make(Handle(), name);
} else {
// Need to add the scatter-release before the curr stmt.
sync[curr] = Variable::make(Handle(), name);
}
in_flight.clear();
}
Expr visit(const Load *op) override {
// Resolve scatter-load hazard.
check_hazard(op->name);
return IRMutator::visit(op);
}
Stmt visit(const Store *op) override {
// Resolve scatter-store and gather-store hazards.
check_hazard(op->name);
return IRMutator::visit(op);
}
public:
using IRMutator::mutate;
Stmt mutate(const Stmt &s) override {
curr = &s;
Stmt new_s = IRMutator::mutate(s);
// Wrap the stmt with scatter-release if any hazard was detected.
if (sync.find(&s) != sync.end()) {
Stmt scatter_sync =
Evaluate::make(Call::make(Int(32), Call::hvx_scatter_release, {sync[&s]}, Call::Intrinsic));
return Block::make(scatter_sync, new_s);
}
return new_s;
}
};
} // namespace
Stmt optimize_hexagon_shuffles(const Stmt &s, int lut_alignment) {
// Replace indirect and other complicated loads with
// dynamic_shuffle (vlut) calls.
return OptimizeShuffles(lut_alignment).mutate(s);
}
Stmt vtmpy_generator(Stmt s) {
// Generate vtmpy instructions where possible.
s = substitute_in_all_lets(s);
s = VtmpyGenerator().mutate(s);
s = common_subexpression_elimination(s);
return s;
}
Stmt scatter_gather_generator(Stmt s) {
// Generate vscatter-vgather instructions if target >= v65.
s = substitute_in_all_lets(s);
s = ScatterGatherGenerator().mutate(s);
s = SyncronizationBarriers().mutate(s);
s = common_subexpression_elimination(s);
return s;
}
Stmt optimize_hexagon_instructions(Stmt s, Target t) {
// Convert some expressions to an equivalent form which can be better
// optimized in later stages for Hexagon.
s = RearrangeExpressions().mutate(s);
// Peephole optimize for Hexagon instructions. These can generate
// interleaves and deinterleaves alongside the HVX intrinsics.
s = OptimizePatterns(t).mutate(s);
// Try to eliminate any redundant interleave/deinterleave pairs.
s = EliminateInterleaves(t.natural_vector_size(Int(8))).mutate(s);
// There may be interleaves left over that we can fuse with other
// operations.
s = FuseInterleaves().mutate(s);
return s;
}
} // namespace Internal
} // namespace Halide